Spaces: Running on Zero
import torch | |
from PIL import Image, ImageDraw | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def repeat_tensors(tensor, repeat_counts):
    """Repeat each leading row of ``tensor`` a given number of times.

    Row ``i`` (the slice ``tensor[i:i+1]``) is repeated ``repeat_counts[i]``
    times along dim 0, and the pieces are concatenated in order. Only the
    first ``len(repeat_counts)`` rows of ``tensor`` are read.

    Args:
        tensor: Tensor with at least one dimension.
        repeat_counts: Iterable of non-negative ints, one per row to use.

    Returns:
        Tensor whose trailing dims match ``tensor.shape[1:]``. When
        ``repeat_counts`` is empty this is a zero-row slice of ``tensor``
        rather than an error.
    """
    trailing_ones = [1] * (tensor.ndim - 1)
    repeated = [
        tensor[i:i + 1].repeat(count, *trailing_ones)
        for i, count in enumerate(repeat_counts)
    ]
    # torch.cat rejects an empty list; zero counts simply means zero rows.
    if not repeated:
        return tensor[:0]
    return torch.cat(repeated, dim=0)
def split_tensors(tensor, split_counts):
    """Split ``tensor`` along dim 0 into consecutive chunks of given sizes.

    Args:
        tensor: Tensor (or other sliceable sequence) to split on its first dim.
        split_counts: Iterable of ints (list, tuple, or 1-D tensor) giving the
            size of each chunk in order.

    Returns:
        List with one slice per count. Chunk ``i`` covers rows
        ``[sum(counts[:i]), sum(counts[:i+1]))``. Rows past ``sum(counts)``
        are not included in any chunk.
    """
    # Running offsets as plain Python ints: works for any iterable of counts
    # (the old ``[0] + split_counts`` only accepted lists).
    bounds = [0]
    for count in split_counts:
        bounds.append(bounds[-1] + int(count))
    return [tensor[start:stop] for start, stop in zip(bounds, bounds[1:])]
def visualize_heatmap(pil_image, heatmap, bbox=None):
    """Overlay a jet-colored heatmap, and optionally a box, on a PIL image.

    Args:
        pil_image: Source PIL image; the heatmap is resized to its size.
        heatmap: Array or tensor of scores, expected in [0, 1]. Values
            outside that range are clipped; NaNs are treated as 0.
        bbox: Optional ``(xmin, ymin, xmax, ymax)`` given as fractions of the
            image width/height; drawn as a green rectangle of width 3.

    Returns:
        RGBA PIL image: the input with the colored heatmap alpha-composited
        at alpha 128 (about 50% opacity).
    """
    if isinstance(heatmap, torch.Tensor):
        # Cast to float32 first: .numpy() raises on bfloat16 tensors.
        heatmap = heatmap.detach().float().cpu().numpy()
    scores = np.nan_to_num(np.asarray(heatmap, dtype=np.float32), nan=0.0)
    # Clip before quantizing; out-of-range floats would wrap around in uint8.
    gray = Image.fromarray((np.clip(scores, 0.0, 1.0) * 255).astype(np.uint8))
    gray = gray.resize(pil_image.size, Image.Resampling.BILINEAR)
    colored = plt.cm.jet(np.array(gray) / 255.)
    colored = Image.fromarray((colored[:, :, :3] * 255).astype(np.uint8)).convert("RGBA")
    colored.putalpha(128)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), colored)
    if bbox is not None:
        width, height = pil_image.size
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        draw.rectangle(
            [xmin * width, ymin * height, xmax * width, ymax * height],
            outline="green",
            width=3,
        )
    return overlay_image
def stack_and_pad(tensor_list):
    """Zero-pad tensors along dim 0 to a common length, then stack them.

    Args:
        tensor_list: Non-empty sequence of tensors that share ``shape[1:]``;
            only their first dimension may differ.

    Returns:
        Tensor of shape ``(len(tensor_list), max_len, *shape[1:])``, where
        shorter inputs are padded at the end with zeros of their own dtype
        and device.

    Raises:
        ValueError: if ``tensor_list`` is empty.
    """
    if len(tensor_list) == 0:
        raise ValueError("stack_and_pad() requires at least one tensor")
    max_size = max(t.shape[0] for t in tensor_list)
    padded_list = []
    for t in tensor_list:
        missing = max_size - t.shape[0]
        if missing:
            # new_zeros inherits t's dtype and device; torch.zeros would
            # default to float32 on CPU (device mismatch / dtype promotion).
            t = torch.cat([t, t.new_zeros(missing, *t.shape[1:])], dim=0)
        padded_list.append(t)
    return torch.stack(padded_list)