|
import torch |
|
from PIL import Image, ImageDraw |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
def repeat_tensors(tensor, repeat_counts):
    """Repeat each dim-0 slice of ``tensor`` a per-slice number of times.

    Slice ``i`` (taken as ``tensor[i:i + 1]`` so the leading dimension is
    kept) is tiled ``repeat_counts[i]`` times along dim 0, and the tiled
    pieces are concatenated in order. Only the first ``len(repeat_counts)``
    slices are visited; any later slices of ``tensor`` are not included.

    Args:
        tensor: Tensor with at least one dimension.
        repeat_counts: Iterable of repeat counts; entry ``i`` applies to
            slice ``i`` of ``tensor``.

    Returns:
        A tensor with the same trailing dimensions as ``tensor``. When
        ``repeat_counts`` is empty, an empty tensor of shape
        ``(0, *tensor.shape[1:])`` with the dtype and device of ``tensor``.
    """
    # Only dim 0 is tiled; every trailing dimension gets a repeat factor of 1.
    trailing_ones = [1] * (tensor.ndim - 1)
    pieces = [
        tensor[i:i + 1].repeat(count, *trailing_ones)
        for i, count in enumerate(repeat_counts)
    ]
    if not pieces:
        # torch.cat refuses an empty list, so build the empty result directly.
        return tensor.new_empty((0, *tensor.shape[1:]))
    return torch.cat(pieces, dim=0)
|
|
|
def split_tensors(tensor, split_counts):
    """Split ``tensor`` along dim 0 into consecutive chunks of given sizes.

    Chunk ``k`` is ``tensor[start_k:start_k + split_counts[k]]``, where
    ``start_k`` is the sum of all earlier counts. Ordinary slice semantics
    apply: if the counts sum to less than ``tensor.shape[0]`` the remaining
    rows are not returned, and a chunk reaching past the end is truncated.

    Args:
        tensor: Tensor (or any object sliceable along its first axis).
        split_counts: Iterable of integer chunk sizes (list, tuple, ...).

    Returns:
        A list with one chunk per entry of ``split_counts``, in order.
    """
    chunks = []
    start = 0
    # Track a running offset instead of materialising a cumsum tensor; this
    # also avoids requiring ``split_counts`` to be a list specifically.
    for count in split_counts:
        end = start + count
        chunks.append(tensor[start:end])
        start = end
    return chunks
|
|
|
def visualize_heatmap(pil_image, heatmap, bbox=None):
    """Blend a jet-coloured heatmap over an image, optionally drawing a box.

    Args:
        pil_image: PIL image used as the background. Its size determines the
            size the heatmap is resized to and the size of the result.
        heatmap: NumPy array or ``torch.Tensor`` of intensities. Values are
            scaled by 255 and cast to ``uint8``, so they are expected to lie
            in ``[0, 1]``; values outside that range are clipped. Presumably
            a single-channel 2-D map (the colormap output is indexed as
            ``[:, :, :3]``) -- confirm against callers.
        bbox: Optional ``(xmin, ymin, xmax, ymax)``. The coordinates are
            multiplied by the image width/height before drawing, i.e. they
            are treated as fractions of the image size.

    Returns:
        An RGBA PIL image: ``pil_image`` alpha-composited with the coloured
        heatmap at alpha 128, with a green rectangle drawn when ``bbox`` is
        given.
    """
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()

    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (or be platform-dependent) and show up as wrong colours.
    heatmap = np.clip(heatmap, 0.0, 1.0)
    gray = Image.fromarray((heatmap * 255).astype(np.uint8))
    gray = gray.resize(pil_image.size, Image.Resampling.BILINEAR)

    # The colormap returns RGBA floats in [0, 1]; keep RGB and re-quantise.
    colored = plt.cm.jet(np.array(gray) / 255.)
    colored = (colored[:, :, :3] * 255).astype(np.uint8)

    overlay = Image.fromarray(colored).convert("RGBA")
    overlay.putalpha(128)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), overlay)

    if bbox is not None:
        width, height = pil_image.size
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        draw.rectangle(
            [xmin * width, ymin * height, xmax * width, ymax * height],
            outline="green",
            width=3,
        )
    return overlay_image
|
|
|
def stack_and_pad(tensor_list):
    """Zero-pad tensors along dim 0 to a common length and stack them.

    Each tensor shorter than the longest one is extended with zeros at the
    end of dim 0. The padding takes the dtype and device of the tensor being
    padded, so integer or non-CPU inputs keep their dtype and device.

    Args:
        tensor_list: Non-empty iterable of tensors. ``torch.stack`` requires
            the padded tensors to have identical shapes, so all dimensions
            after the first must already agree.

    Returns:
        A tensor of shape ``(len(tensor_list), max_len, *trailing_dims)``
        where ``max_len`` is the largest dim-0 size among the inputs.

    Raises:
        ValueError: If ``tensor_list`` is empty.
    """
    tensors = list(tensor_list)
    if not tensors:
        raise ValueError("stack_and_pad() requires at least one tensor")

    max_size = max(t.shape[0] for t in tensors)
    padded_list = []
    for t in tensors:
        missing = max_size - t.shape[0]
        if missing:
            # new_zeros (unlike torch.zeros) inherits dtype and device from t.
            padding = t.new_zeros((missing, *t.shape[1:]))
            t = torch.cat([t, padding], dim=0)
        padded_list.append(t)
    return torch.stack(padded_list)
|
|