File size: 1,664 Bytes
9c9498f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import torch
from PIL import Image, ImageDraw
import numpy as np
import matplotlib.pyplot as plt
def repeat_tensors(tensor, repeat_counts):
    """Repeat each row of ``tensor`` along dim 0 by its matching count.

    Row ``i`` is emitted ``repeat_counts[i]`` times (0 drops it), in order.
    Only the first ``min(len(repeat_counts), tensor.shape[0])`` rows/counts
    are used; any surplus on either side is ignored (this matches the
    original per-row slicing implementation).

    Args:
        tensor: tensor with at least one dimension; dim 0 is repeated.
        repeat_counts: sequence of non-negative integers.

    Returns:
        Tensor with the same trailing dims, dtype and device as ``tensor``.
        An empty result (``tensor[:0]``) is returned when nothing is used,
        instead of the ``torch.cat`` error the old implementation raised.
    """
    counts = torch.as_tensor(repeat_counts, dtype=torch.long, device=tensor.device).reshape(-1)
    n = min(counts.numel(), tensor.shape[0])
    if n == 0:
        return tensor[:0]
    # Vectorised replacement for the per-row .repeat() + torch.cat loop.
    return torch.repeat_interleave(tensor[:n], counts[:n], dim=0)
def split_tensors(tensor, split_counts):
    """Split ``tensor`` along dim 0 into consecutive chunks of given sizes.

    Chunk ``k`` is ``tensor[start:start + split_counts[k]]`` where ``start``
    is the sum of the preceding counts. Rows left over after the last count
    are dropped, and counts that overshoot yield shorter (or empty) chunks,
    exactly as slicing does.

    Args:
        tensor: tensor (or any sliceable sequence) to split along dim 0.
        split_counts: any iterable of integer chunk sizes (list, tuple,
            generator, ...). Previously only lists worked, because of
            ``[0] + split_counts``.

    Returns:
        A list with one slice (view) per count.
    """
    pieces = []
    start = 0
    for count in split_counts:
        stop = start + count
        pieces.append(tensor[start:stop])
        start = stop
    return pieces
def visualize_heatmap(pil_image, heatmap, bbox=None, *, alpha=128,
                      bbox_color="green", bbox_width=3):
    """Overlay a jet-colormapped heatmap, and optionally a box, on an image.

    Args:
        pil_image: PIL image to draw on. It is converted to RGBA for
            compositing; the input object itself is not modified.
        heatmap: 2-D array-like or ``torch.Tensor`` of scores expected in
            [0, 1]. It is resized bilinearly to ``pil_image.size``. Values
            outside [0, 1] are clipped and NaNs are treated as 0; before,
            they wrapped around in the uint8 cast and produced wrong colours.
        bbox: optional ``(xmin, ymin, xmax, ymax)`` in coordinates normalised
            to [0, 1]; they are scaled by the image width/height.
        alpha: opacity (0-255) of the heatmap layer. Defaults to 128.
        bbox_color: outline colour of the box. Defaults to ``"green"``.
        bbox_width: outline width of the box in pixels. Defaults to 3.

    Returns:
        A new RGBA ``PIL.Image`` with the heatmap (and box) composited on top.
    """
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu()
        if heatmap.dtype == torch.bfloat16:
            # NumPy has no bfloat16, so .numpy() would raise on it.
            heatmap = heatmap.float()
        heatmap = heatmap.numpy()
    # Clip before the uint8 cast: out-of-range values would otherwise wrap
    # modulo 256 (e.g. 1.5 -> 126, -0.1 -> 231) and NaN is undefined.
    heatmap = np.clip(np.nan_to_num(np.asarray(heatmap), nan=0.0), 0, 1)

    gray = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(
        pil_image.size, Image.Resampling.BILINEAR)
    # The colormap returns RGBA floats in [0, 1]; keep RGB and set alpha below.
    rgba = plt.cm.jet(np.array(gray) / 255.)
    rgb = (rgba[:, :, :3] * 255).astype(np.uint8)
    overlay = Image.fromarray(rgb).convert("RGBA")
    overlay.putalpha(alpha)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), overlay)

    if bbox is not None:
        width, height = pil_image.size
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        draw.rectangle(
            [xmin * width, ymin * height, xmax * width, ymax * height],
            outline=bbox_color, width=bbox_width)
    return overlay_image
def stack_and_pad(tensor_list, padding_value=0):
    """Pad tensors along dim 0 to a common length and stack them.

    Each tensor shorter than the longest one is extended at the end of dim 0
    with ``padding_value``. The padding is created with the tensor's own
    dtype and device. The old ``torch.zeros(...)`` padding was always CPU
    float32, so CUDA inputs failed in ``torch.cat`` and integer inputs were
    silently promoted to float whenever padding occurred.

    Args:
        tensor_list: non-empty sequence of tensors whose shapes agree on all
            dims except dim 0.
        padding_value: fill value for the padded rows. Defaults to 0.

    Returns:
        Tensor of shape ``(len(tensor_list), max_len, *trailing_dims)``.

    Raises:
        ValueError: if ``tensor_list`` is empty.
    """
    if len(tensor_list) == 0:
        raise ValueError("stack_and_pad() requires at least one tensor")
    max_size = max(t.shape[0] for t in tensor_list)
    padded_list = []
    for t in tensor_list:
        missing = max_size - t.shape[0]
        if missing == 0:
            padded_list.append(t)
            continue
        pad = t.new_full((missing, *t.shape[1:]), padding_value)
        padded_list.append(torch.cat([t, pad], dim=0))
    return torch.stack(padded_list)
|