"""Tensor batching helpers and a heatmap overlay utility."""

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw


def repeat_tensors(tensor, repeat_counts):
    """Repeat each leading-dimension entry of ``tensor`` a given number of times.

    Entry ``i`` (kept as the length-1 slice ``tensor[i:i+1]``) is repeated
    ``repeat_counts[i]`` times along dim 0, and the pieces are concatenated
    in order.

    Args:
        tensor: Tensor with at least one dimension.
        repeat_counts: Sequence of non-negative ints, one per entry to repeat.
            Only the first ``len(repeat_counts)`` entries of ``tensor`` are used.

    Returns:
        A tensor whose dim 0 has length ``sum(repeat_counts)`` (when every
        count refers to an existing entry) and whose trailing dimensions
        match ``tensor``. An empty ``repeat_counts`` yields an empty slice
        of ``tensor``.
    """
    # Repeat only along dim 0; every trailing dimension keeps a factor of 1.
    trailing_ones = [1] * (tensor.ndim - 1)
    repeated_tensors = [
        tensor[i:i + 1].repeat(repeat, *trailing_ones)
        for i, repeat in enumerate(repeat_counts)
    ]
    if not repeated_tensors:
        # torch.cat rejects an empty list; return an empty result that keeps
        # the trailing shape, dtype and device of the input instead.
        return tensor[:0]
    return torch.cat(repeated_tensors, dim=0)


def split_tensors(tensor, split_counts):
    """Split ``tensor`` along dim 0 into consecutive chunks.

    Args:
        tensor: Tensor (or any object sliceable along its first axis).
        split_counts: Iterable of chunk lengths (list, tuple, ...).

    Returns:
        A list with one slice per count, where chunk ``k`` starts at the sum
        of the preceding counts. Slices are taken with ordinary slicing, so
        counts that run past the end of ``tensor`` give shorter (or empty)
        chunks rather than raising.
    """
    chunks = []
    start = 0
    for count in split_counts:
        # Plain-int running offset: works for any iterable of counts, not
        # only lists that can be concatenated with ``[0]``.
        end = start + int(count)
        chunks.append(tensor[start:end])
        start = end
    return chunks


def visualize_heatmap(pil_image, heatmap, bbox=None):
    """Overlay a jet-coloured heatmap on an image, optionally drawing a box.

    Args:
        pil_image: PIL image used as the background; it also fixes the
            output size.
        heatmap: 2-D ``torch.Tensor`` or NumPy array. Values are scaled by
            255, so they are expected to lie in ``[0, 1]``; values outside
            that range are clipped.
        bbox: Optional ``(xmin, ymin, xmax, ymax)``. Each coordinate is
            multiplied by the image width/height, i.e. it is treated as a
            fraction of the image size.

    Returns:
        An RGBA PIL image: the background composited with the heatmap at
        alpha 128, plus a green rectangle when ``bbox`` is given.
    """
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()

    # Out-of-range values would overflow the uint8 cast below and wrap to
    # unrelated intensities, so clamp them first.
    heatmap = np.clip(heatmap, 0.0, 1.0)

    # Resize as an 8-bit grayscale image so the map matches the background.
    gray = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(
        pil_image.size, Image.Resampling.BILINEAR
    )

    # The colormap returns RGBA floats in [0, 1]; keep RGB and go back to 8-bit.
    colored = plt.cm.jet(np.array(gray) / 255.0)
    rgb = (colored[:, :, :3] * 255).astype(np.uint8)

    heat_layer = Image.fromarray(rgb).convert("RGBA")
    heat_layer.putalpha(128)  # uniform ~50% opacity over the whole layer
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heat_layer)

    if bbox is not None:
        width, height = pil_image.size
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        draw.rectangle(
            [xmin * width, ymin * height, xmax * width, ymax * height],
            outline="green",
            width=3,
        )
    return overlay_image


def stack_and_pad(tensor_list):
    """Zero-pad tensors along dim 0 to a common length and stack them.

    Args:
        tensor_list: Non-empty sequence of tensors that may differ in their
            dim-0 length. ``torch.stack`` requires the padded tensors to
            have equal shapes, so the trailing dimensions must agree.

    Returns:
        A tensor of shape ``(len(tensor_list), max_len, *trailing_shape)``
        where shorter inputs are padded with zeros at the end.

    Raises:
        ValueError: If ``tensor_list`` is empty (raised by ``max``).
    """
    max_size = max(t.shape[0] for t in tensor_list)
    padded_list = []
    for t in tensor_list:
        missing = max_size - t.shape[0]
        if missing:
            # new_zeros inherits dtype and device from ``t``; a bare
            # torch.zeros would promote integer inputs to float and break
            # for tensors that live on a non-default device.
            padding = t.new_zeros((missing, *t.shape[1:]))
            t = torch.cat([t, padding], dim=0)
        padded_list.append(t)
    return torch.stack(padded_list)