vicellst-official / visualization.py
polejowska's picture
Update visualization.py
8abff99 verified
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
from constants import COLORS
from utils import fig2img
def visualize_prediction(
    pil_img, output_dict, threshold=0.7, id2label=None, display_mask=False, mask=None
):
    """Draw the detections scoring above ``threshold`` on top of ``pil_img``.

    Every kept box is outlined in a colour taken from ``COLORS`` and
    annotated with its 1-based running index and its score.

    Args:
        pil_img: Background image, passed straight to ``ax.imshow``.
        output_dict: Mapping with ``"scores"`` and ``"boxes"`` entries that
            support comparison with a scalar, boolean-mask indexing and
            ``.tolist()``. Each box unpacks as ``(xmin, ymin, xmax, ymax)``.
        threshold: Only detections whose score is strictly greater than this
            value are drawn.
        id2label: Accepted for backward compatibility. The annotation shows
            only the running index and the score, so class names are never
            rendered and this argument is not consulted.
        display_mask: When true and ``mask`` is given, overlay the mask on
            the image at 50% opacity.
        mask: Array-like mask; every element greater than zero is treated as
            foreground. Presumably has the same height/width as ``pil_img``
            -- confirm against the caller.

    Returns:
        Whatever ``fig2img`` returns for the rendered figure.
    """
    from itertools import cycle

    keep = output_dict["scores"] > threshold
    boxes = output_dict["boxes"][keep].tolist()
    scores = output_dict["scores"][keep].tolist()

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        ax.imshow(pil_img)

        if display_mask and mask is not None:
            # Build the overlay as uint8 explicitly: inheriting the mask's
            # dtype would turn 255 into True for bool masks and overflow for
            # int8 masks.
            binary_mask = np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8)
            ax.imshow(binary_mask, alpha=0.5, cmap="viridis")

        # cycle() repeats the palette indefinitely, so no detection is
        # dropped however many boxes pass the threshold.
        annotated = zip(scores, boxes, cycle(COLORS))
        for counter, (score, (xmin, ymin, xmax, ymax), color) in enumerate(
            annotated, start=1
        ):
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=2,
                )
            )
            ax.text(
                xmin,
                ymin,
                f"[{counter}] {score:0.2f}",
                fontsize=8,
                bbox=dict(facecolor="yellow", alpha=0.5),
            )

        ax.axis("off")
        # NOTE(review): assumes fig2img has finished rendering by the time it
        # returns, so the figure can be released right after -- confirm.
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call leaks one figure.
        plt.close(fig)
def visualize_attention_map(pil_img, attention_map):
    """Render one panel per attention head, overlaid on ``pil_img``.

    Only the last entry of ``attention_map`` and the first item along its
    leading axis are visualised.

    Args:
        pil_img: Background image; must expose ``.size`` as
            ``(width, height)`` and be accepted by ``ax.imshow``.
        attention_map: Indexable collection of tensors. The last one is
            indexed as ``[0, head, :, :]`` and its second axis gives the
            number of heads. Presumably one tensor per layer -- confirm
            against the caller.

    Returns:
        Whatever ``fig2img`` returns for the rendered figure.
    """
    last_attention = attention_map[-1].detach().cpu()
    n_heads = last_attention.shape[1]

    # Extent that imshow assigns to the background image. Reusing it for the
    # overlay stretches the attention matrix over the whole image instead of
    # resetting the axis limits to the matrix's own size.
    width, height = pil_img.size
    image_extent = (-0.5, width - 0.5, height - 0.5, -0.5)

    # squeeze=False always yields a 2-D array of Axes, so ``axes.flat`` also
    # works when there is a single head.
    fig, axes = plt.subplots(
        nrows=1, ncols=n_heads, figsize=(n_heads * 4, 4), squeeze=False
    )
    try:
        for i, ax in enumerate(axes.flat):
            head_weights = last_attention[0, i, :, :].squeeze().numpy()
            ax.imshow(pil_img)
            ax.imshow(head_weights, alpha=0.7, cmap="viridis", extent=image_extent)
            ax.set_title(f"Head {i+1}")
            ax.axis("off")
        fig.tight_layout()
        # NOTE(review): assumes fig2img has finished rendering by the time it
        # returns, so the figure can be released right after -- confirm.
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call leaks one figure.
        plt.close(fig)