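# NOTE: the next line is page-header residue from where this file was captured,
# not part of the module: "Spaces: Build error / Build error".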
import io | |
import math | |
import os | |
import PIL.Image | |
import numpy as np | |
import imageio.v3 as iio | |
import warnings | |
import torch | |
import torchvision.transforms.functional as TF | |
from scipy.ndimage import binary_dilation, binary_erosion | |
import cv2 | |
import re | |
import matplotlib.pyplot as plt | |
from matplotlib import animation | |
from IPython.display import HTML, Image, display | |
# Optional thumbnail size (in pixels) read by IImage.generate_display():
# when set, the image is resized so that its longer edge equals this value
# before being rendered for display; None displays at full resolution.
IMG_THUMBSIZE = None
def torch2np(x, vmin=-1, vmax=1):
    """Convert an image tensor to a ``uint8`` numpy array in (B, H, W, C) layout.

    Args:
        x: ``torch.Tensor`` expected as (B, C, H, W). 3-D input gets a leading
            batch axis and 2-D input gets batch and channel axes (a warning
            is emitted in both cases).
        vmin: Value mapped to 0 for non-``uint8`` tensors.
        vmax: Value mapped to 255 for non-``uint8`` tensors.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with shape (B, H, W, C).

    Raises:
        NotImplementedError: if the tensor is not ``uint8`` and ``vmin`` or
            ``vmax`` is ``None``.
    """
    if x.ndim != 4:
        # raise Exception("Please only use (B,C,H,W) torch tensors!")
        warnings.warn(
            "Warning! Shape of the image was not provided in (B,C,H,W) format, the shape was inferred automatically!")
        if x.ndim == 3:
            x = x[None]
        elif x.ndim == 2:
            x = x[None, None]
    x = x.detach().cpu()
    # The dtype must be inspected BEFORE any float conversion; otherwise the
    # uint8 pass-through can never trigger and raw 0..255 data would be
    # clipped to [vmin, vmax] as if it were normalized.
    if x.dtype == torch.uint8:
        # copy() so the result does not alias the tensor's storage
        return x.permute(0, 2, 3, 1).numpy().copy()
    if vmin is None or vmax is None:
        raise NotImplementedError()
    x = x.float()
    x = 255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin)
    return x.permute(0, 2, 3, 1).to(torch.uint8).numpy()
class IImage:
    """
    Generic media storage. Can store both images and videos.
    Stores data as a numpy array by default.
    Can be viewed in a jupyter notebook.

    ``self.data`` is kept as a 4-D array in (B, H, W, C) layout; inputs given
    as numpy arrays or torch tensors are converted to ``uint8``. A container
    with more than one frame is treated as a video.
    """

    @staticmethod
    def open(path):
        """Read an image or animation from ``path`` and return it as an IImage.

        The absolute path is remembered in the ``link`` attribute.
        """
        # The handle is closed as soon as the pixels and metadata were read.
        with iio.imopen(path, 'r') as iio_obj:
            data = iio_obj.read()
            try:
                # Preferred way to learn whether a batch of frames was read.
                if not iio_obj.properties().is_batch:
                    data = data[None]
            except AttributeError:
                # Fallback when .properties() is unavailable: a "duration"
                # metadata entry is taken as the marker of an animation.
                if "duration" not in iio_obj.metadata():
                    data = data[None]
        if data.ndim == 3:
            # (B, H, W) without a channel axis -> (B, H, W, 1)
            data = data[..., None]
        image = IImage(data)
        image.link = os.path.abspath(path)
        return image

    @staticmethod
    def normalized(x, dims=(-1, -2)):
        """Min-max normalize ``x`` over ``dims`` and wrap it in an IImage.

        ``x`` needs ``amin``/``amax`` with a keepdim argument (presumably a
        torch tensor -- confirm against callers). Constant input maps to 0
        instead of producing NaN from a 0/0 division.
        """
        low = x.amin(dims, True)
        span = x.amax(dims, True) - low
        span = torch.where(span == 0, torch.ones_like(span), span)
        return IImage((x - low) / span, 0)

    def numpy(self):
        """Return the underlying (B, H, W, C) array (no copy)."""
        return self.data

    def torch(self, vmin=-1, vmax=1):
        """Return the data as a float tensor, channels first, scaled to [vmin, vmax].

        The tensor is placed on ``self.device`` (see ``cuda()`` / ``cpu()``).
        """
        if self.data.ndim == 3:
            data = self.data.transpose(2, 0, 1) / 255.
        else:
            data = self.data.transpose(0, 3, 1, 2) / 255.
        return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)

    def cuda(self):
        """Make ``torch()`` return tensors on 'cuda'. Returns ``self``."""
        self.device = 'cuda'
        return self

    def cpu(self):
        """Make ``torch()`` return tensors on 'cpu'. Returns ``self``."""
        self.device = 'cpu'
        return self

    def pil(self):
        """Return a PIL image for a single frame, otherwise a list of PIL images."""
        frames = []
        for x in self.data:
            if x.shape[-1] == 1:
                # PIL expects (H, W) for single-channel data
                x = x[..., 0]
            frames.append(PIL.Image.fromarray(x))
        if len(frames) == 1:
            return frames[0]
        return frames

    def is_iimage(self):
        """Duck-typing marker; always ``True``."""
        return True

    @property
    def shape(self):
        """Shape of the underlying array, (B, H, W, C)."""
        return self.data.shape

    @property
    def size(self):
        """Frame size as ``(width, height)``."""
        return (self.data.shape[-2], self.data.shape[-3])

    def setFps(self, fps):
        """Set the playback rate and re-render the notebook representation."""
        self.fps = fps
        self.generate_display()
        return self

    def __init__(self, x, vmin=-1, vmax=1, fps=None):
        """Build the container from an array, tensor, IImage or PIL image.

        Args:
            x: ``np.ndarray`` (2-D, 3-D or 4-D), ``torch.Tensor``, ``IImage``
                or ``PIL.Image.Image``.
            vmin, vmax: Value range of ``x``; only used for torch tensors.
            fps: Playback rate; defaults to 1 for fewer than 10 frames and
                30 otherwise.

        Raises:
            TypeError: if ``x`` is of an unsupported type.
        """
        if isinstance(x, np.ndarray):
            self.data = x.astype(np.uint8)  # astype already copies
            if self.data.ndim == 2:
                self.data = self.data[None, ..., None]
            if self.data.ndim == 3:
                warnings.warn(
                    "Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)")
                self.data = self.data[None]
        elif isinstance(x, torch.Tensor):
            self.data = torch2np(x, vmin, vmax)
        elif isinstance(x, IImage):
            self.data = x.data.copy()  # Simple Copy
        elif isinstance(x, PIL.Image.Image):
            self.data = np.array(x)
            if self.data.ndim == 2:
                self.data = self.data[..., None]  # (H,W,C)
            self.data = self.data[None]  # (B,H,W,C)
        else:
            # Without this, the failure would surface later as a confusing
            # AttributeError on the missing ``data`` attribute.
            raise TypeError(
                f"Cannot build an IImage from {type(x).__name__}")
        self.display_str = None
        self.device = 'cpu'
        self.fps = fps if fps is not None else (
            1 if len(self.data) < 10 else 30)
        self.link = None

    def generate_display(self):
        """Render and cache the notebook representation.

        Videos are rendered through ``Animation`` into a string; single
        images are encoded as PNG bytes. If ``IMG_THUMBSIZE`` is set, the
        media is first resized so its longer edge equals that value.
        """
        if IMG_THUMBSIZE is not None:
            width, height = self.size
            if height < width:
                thumb = self.resize(
                    (height * IMG_THUMBSIZE // width, IMG_THUMBSIZE))
            else:
                thumb = self.resize(
                    (IMG_THUMBSIZE, width * IMG_THUMBSIZE // height))
        else:
            thumb = self
        if self.is_video():
            self.anim = Animation(thumb.data, fps=self.fps)
            self.anim.render()
            self.display_str = self.anim.anim_str
        else:
            buffer = io.BytesIO()
            data = thumb.data[0]
            if data.shape[-1] == 1:
                data = data[..., 0]
            PIL.Image.fromarray(data).save(buffer, "PNG")
            self.display_str = buffer.getvalue()
        return self.display_str

    def resize(self, size, *args, **kwargs):
        """Resize every frame with PIL and return a new IImage.

        Args:
            size: ``None`` (no-op), an ``(height, width)`` pair, or an int.
                An int keeps the aspect ratio: the longer edge becomes
                ``size`` by default, the shorter edge when
                ``use_small_edge_when_int=True`` is passed.
            *args, **kwargs: Forwarded to ``PIL.Image.Image.resize``.
                ``resample`` (or the legacy ``filter``) selects the filter
                and defaults to bicubic.

        Returns:
            ``self`` when nothing has to change, otherwise a new IImage.
        """
        if size is None:
            return self
        use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False)
        # Backward compatibility
        resample = kwargs.pop('filter', PIL.Image.BICUBIC)
        resample = kwargs.pop('resample', resample)
        if isinstance(size, int):
            h, w = self.data.shape[1:3]
            aspect_ratio = h / w
            pick = max if use_small_edge_when_int else min
            size = (pick(size, int(size * aspect_ratio)),
                    pick(size, int(size / aspect_ratio)))
        # PIL wants (width, height); tuple() lets list inputs compare equal.
        target = tuple(size[::-1])
        if self.size == target:
            return self
        return stack([
            IImage(self[i].pil().resize(
                target, *args, resample=resample, **kwargs))
            for i in range(len(self.data))
        ])

    def pad(self, padding, *args, **kwargs):
        """Pad via ``torchvision.transforms.functional.pad`` and return a new IImage."""
        return IImage(TF.pad(self.torch(0), padding=padding, *args, **kwargs), 0)

    def padx(self, multiplier, *args, **kwargs):
        """Pad right/bottom so width and height become multiples of ``multiplier``."""
        size = np.array(self.size)
        extra = np.ceil(size / multiplier).astype(int) * multiplier - size
        # plain Python ints: numpy scalars are not reliably accepted as padding
        return self.pad([0, 0, int(extra[0]), int(extra[1])], *args, **kwargs)

    def pad2wh(self, w=0, h=0, **kwargs):
        """Pad right/bottom until the frame is at least ``w`` x ``h``."""
        cw, ch = self.size
        return self.pad([0, 0, max(0, w - cw), max(0, h - ch)], **kwargs)

    def pad2square(self, *args, **kwargs):
        """Pad the shorter side symmetrically so the frame becomes square."""
        width, height = self.size
        if width > height:
            dx = width - height
            return self.pad([0, dx // 2, 0, dx - dx // 2], *args, **kwargs)
        elif width < height:
            dx = height - width
            return self.pad([dx // 2, 0, dx - dx // 2, 0], *args, **kwargs)
        return self

    def crop2square(self, *args, **kwargs):
        """Center-crop the longer side so the frame becomes square."""
        width, height = self.size
        if width > height:
            dx = width - height
            return self.crop([dx // 2, 0, height, height], *args, **kwargs)
        elif width < height:
            dx = height - width
            return self.crop([0, dx // 2, width, width], *args, **kwargs)
        return self

    def alpha(self):
        """Return the last channel as a single-channel IImage."""
        return IImage(self.data[..., -1, None], fps=self.fps)

    def rgb(self):
        """Return an RGB version (converted frame by frame through PIL)."""
        frames = self.pil()
        if not isinstance(frames, list):
            return IImage(frames.convert('RGB'), fps=self.fps)
        # pil() yields a list for videos, which has no .convert of its own
        data = np.stack([np.array(frame.convert('RGB')) for frame in frames])
        return IImage(data, fps=self.fps)

    def png(self):
        """Append a fully opaque (255) channel to the data."""
        opaque = 255 * np.ones_like(self.data)[..., :1]
        return IImage(np.concatenate([self.data, opaque], -1))

    def grid(self, nrows=None, ncols=None):
        """Tile all frames into one image.

        ``nrows`` takes precedence over ``ncols``; with neither given a
        5-column grid is used (with a warning). Missing cells are zero-filled.
        """
        if nrows is not None:
            ncols = math.ceil(self.data.shape[0] / nrows)
        elif ncols is not None:
            nrows = math.ceil(self.data.shape[0] / ncols)
        else:
            warnings.warn(
                "No dimensions specified, creating a grid with 5 columns (default)")
            ncols = 5
            nrows = math.ceil(self.data.shape[0] / ncols)
        pad = nrows * ncols - self.data.shape[0]
        data = np.pad(self.data, ((0, pad), (0, 0), (0, 0), (0, 0)))
        rows = [np.concatenate(x, 1, dtype=np.uint8)
                for x in np.array_split(data, nrows)]
        return IImage(np.concatenate(rows, 0, dtype=np.uint8)[None])

    def hstack(self):
        """Join all frames side by side into a single image."""
        return IImage(np.concatenate(self.data, 1, dtype=np.uint8)[None])

    def vstack(self):
        """Join all frames top to bottom into a single image."""
        return IImage(np.concatenate(self.data, 0, dtype=np.uint8)[None])

    def vsplit(self, number_of_splits):
        """Split along the height axis and return the pieces as frames."""
        return IImage(np.concatenate(np.split(self.data, number_of_splits, 1)))

    def hsplit(self, number_of_splits):
        """Split along the width axis and return the pieces as frames."""
        return IImage(np.concatenate(np.split(self.data, number_of_splits, 2)))

    def heatmap(self, resize=None, cmap=None):
        """Apply an OpenCV colormap to every frame and optionally resize.

        Args:
            resize: Forwarded to ``resize()`` with
                ``use_small_edge_when_int=True``; ``None`` keeps the size.
            cmap: OpenCV colormap id; ``None`` means ``cv2.COLORMAP_JET``.
                It is resolved here rather than in the signature so that
                defining the class does not require OpenCV.
        """
        if cmap is None:
            cmap = cv2.COLORMAP_JET
        data = np.stack([cv2.cvtColor(cv2.applyColorMap(
            x, cmap), cv2.COLOR_BGR2RGB) for x in self.data])
        return IImage(data).resize(resize, use_small_edge_when_int=True)

    def display(self):
        """Show the media through IPython's ``display``; returns ``self``."""
        try:
            # resolves to the module-level IPython function, not this method
            display(self)
        except Exception:
            # best effort: e.g. no notebook front-end available
            print("No display")
        return self

    def _morphology(self, operation, iterations, args, kwargs):
        """Run a scipy binary morphology ``operation`` on the data.

        Unless the caller supplies a structuring element, a 3x3 cross acting
        on (H, W) only is used. scipy's default for 4-D input also spans the
        batch and channel axes, which makes erosion wipe every pixel of
        single-channel data and lets dilation bleed across frames.
        """
        if not args and 'structure' not in kwargs:
            cross = np.zeros((3, 3), dtype=bool)
            cross[1, :] = True
            cross[:, 1] = True
            kwargs = dict(kwargs, structure=cross[None, :, :, None])
        result = operation(self.data, *args, iterations=iterations, **kwargs)
        return IImage((result * 255.).astype(np.uint8))

    def dilate(self, iterations=1, *args, **kwargs):
        """Binary-dilate the data (non-zero = foreground); result is 0/255.

        ``iterations == 0`` returns an unchanged copy. Extra arguments are
        forwarded to ``scipy.ndimage.binary_dilation``.
        """
        if iterations == 0:
            return IImage(self.data)
        return self._morphology(binary_dilation, iterations, args, kwargs)

    def erode(self, iterations=1, *args, **kwargs):
        """Binary-erode the data (non-zero = foreground); result is 0/255.

        ``iterations == 0`` returns an unchanged copy, mirroring ``dilate``.
        Extra arguments are forwarded to ``scipy.ndimage.binary_erosion``.
        """
        if iterations == 0:
            return IImage(self.data)
        return self._morphology(binary_erosion, iterations, args, kwargs)

    def hull(self):
        """Return, per frame, the filled convex hull of all contours.

        Frames without any contour yield an all-zero mask.
        """
        convex_hulls = []
        for frame in self.data:
            # [-2] picks the contour list on both the 2-value (OpenCV 4) and
            # 3-value (OpenCV 3) return signatures.
            contours = cv2.findContours(
                frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
            canvas = np.zeros(frame.shape, np.uint8)
            if len(contours) > 0:
                points = np.concatenate(
                    [x.astype(np.int32) for x in contours])
                canvas = cv2.drawContours(
                    canvas, [cv2.convexHull(points)], -1, (255, 0, 0), -1)
            convex_hulls.append(canvas)
        return IImage(np.array(convex_hulls))

    def is_video(self):
        """``True`` when more than one frame is stored."""
        return self.data.shape[0] > 1

    def __getitem__(self, idx):
        """Select frames; an int gives one frame, a slice gives a sub-clip.

        Out-of-range integers raise ``IndexError``.
        """
        frames = self.data[idx]
        if frames.ndim == self.data.ndim - 1:
            # integer index dropped the batch axis -> restore it
            frames = frames[None]
        return IImage(frames, fps=self.fps)

    def _repr_png_(self):
        """Notebook hook: PNG bytes for single images, ``None`` for videos."""
        if self.is_video():
            return None
        if self.display_str is None:
            self.generate_display()
        return self.display_str

    def _repr_html_(self):
        """Notebook hook: rendered animation for videos, ``None`` for images."""
        if not self.is_video():
            return None
        if self.display_str is None:
            self.generate_display()
        return self.display_str

    def save(self, path):
        """Write the media to ``path`` and return ``self``.

        Videos are saved through the matplotlib animation (with the
        "pillow" writer for ``.apng``); images are saved with PIL.
        """
        _, ext = os.path.splitext(path)
        if self.is_video():
            # rendering creates self.anim, which holds the animation object
            if self.display_str is None:
                self.generate_display()
            if ext == ".apng":
                self.anim.anim_obj.save(path, writer="pillow")
            else:
                self.anim.anim_obj.save(path)
        else:
            data = self.data if self.data.ndim == 3 else self.data[0]
            if data.shape[-1] == 1:
                data = data[:, :, 0]
            PIL.Image.fromarray(data).save(path)
        return self

    def write(self, text, center=(0, 25), font_scale=0.8, color=(255, 255, 255), thickness=2):
        """Draw ``text`` onto every frame with ``cv2.putText``.

        ``text`` is a single string for all frames or a list with one entry
        per frame; ``center`` is passed to OpenCV as the text origin.
        """
        if not isinstance(text, list):
            text = [text for _ in self.data]
        data = np.stack([cv2.putText(x.copy(), t, center, cv2.FONT_HERSHEY_COMPLEX,
                                     font_scale, color, thickness) for x, t in zip(self.data, text)])
        return IImage(data)

    def append_text(self, text, padding, font_scale=0.8, color=(255, 255, 255), thickness=2, scale_factor=0.9, center=(0, 0), fill=0):
        """Pad one side of the frames and write word-wrapped text there.

        Args:
            text: One string for all frames, or a list with one per frame.
            padding: Four-element padding with exactly one non-zero entry
                (index 0..3 = left, top, right, bottom); that side receives
                the text.
            font_scale: Initial font scale; multiplied by ``scale_factor``
                repeatedly until the wrapped text fits.
            color, thickness: Passed to ``cv2.putText``.
            scale_factor: Shrink factor applied per fitting attempt.
            center: ``(x, y)`` offset of the text inside the padded area.
            fill: Fill value for the padded area.

        Returns:
            A new IImage with the padded, annotated frames.
        """
        assert np.count_nonzero(padding) == 1
        axis_padding = np.nonzero(padding)[0][0]
        scale_padding = padding[axis_padding]
        frame_h, frame_w = self.shape[1], self.shape[2]
        y_0 = 0
        x_0 = 0
        if axis_padding == 0:
            width = scale_padding
            y_max = frame_h
        elif axis_padding == 1:
            width = frame_w
            y_max = scale_padding
        elif axis_padding == 2:
            x_0 = frame_w
            width = scale_padding
            y_max = frame_h
        elif axis_padding == 3:
            width = frame_w
            y_0 = frame_h
            y_max = frame_h + scale_padding
        width -= center[0]
        x_0 += center[0]
        y_0 += center[1]
        padded = self.pad(padding, fill=fill)

        def text_width(line, _font_scale):
            return cv2.getTextSize(
                line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0]

        def wrap_text(text, width, _font_scale):
            # The capture group keeps each separator, so the list alternates
            # word, separator, word, ...
            pieces = re.split('( |-|_|/|\n)', text)
            words = pieces[0::2]
            sep_list = pieces[1::2]
            lines = []
            current_line = words[0]
            for word, separator in zip(words[1:], sep_list):
                if text_width(current_line + separator + word, _font_scale) <= width:
                    current_line += separator + word
                elif text_width(current_line, _font_scale) <= width:
                    # the separator at a line break is dropped
                    lines.append(current_line)
                    current_line = word
                else:
                    # a single word is wider than the area: signal "no fit"
                    return []
            if text_width(current_line, _font_scale) <= width:
                lines.append(current_line)
            else:
                return []
            return lines

        def wrap_text_and_scale(text, width, _font_scale, y_0, y_max):
            height = y_max + 1
            while height > y_max:
                text_lines = wrap_text(text, width, _font_scale)
                if len(text) > 0 and len(text_lines) == 0:
                    height = y_max + 1
                else:
                    line_height = cv2.getTextSize(
                        text_lines[0], cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][1]
                    height = line_height * len(text_lines) + y_0
                # scale font if out of frame
                if height > y_max:
                    _font_scale = _font_scale * scale_factor
            return text_lines, line_height, _font_scale

        if not isinstance(text, list):
            text = [text for _ in padded.data]
        else:
            assert len(text) == len(padded.data)
        result = []
        for x, t in zip(padded.data, text):
            x = x.copy()
            text_lines, line_height, _font_scale = wrap_text_and_scale(
                t, width, font_scale, y_0, y_max)
            y = line_height
            for line in text_lines:
                x = cv2.putText(
                    x, line, (x_0, y_0 + y), cv2.FONT_HERSHEY_COMPLEX, _font_scale, color, thickness)
                y += line_height
            result.append(x)
        return IImage(np.stack(result))

    # ========== OPERATORS =============

    def __or__(self, other):
        """``a | b``: place ``b`` to the right of ``a`` (concatenate widths)."""
        # TODO: fix for variable sizes
        return IImage(np.concatenate([self.data, other.data], 2))

    def __truediv__(self, other):
        """``a / b``: place ``b`` below ``a`` (concatenate heights)."""
        # TODO: fix for variable sizes
        return IImage(np.concatenate([self.data, other.data], 1))

    def __and__(self, other):
        """``a & b``: append the frames of ``b`` after those of ``a``."""
        return IImage(np.concatenate([self.data, other.data], 0))

    def __add__(self, other):
        """``a + b``: 50/50 blend of the two images."""
        return IImage(0.5 * self.data + 0.5 * other.data)

    def __mul__(self, other):
        """``a * b``: ``a.data / 255 * b.data`` for an IImage operand,
        ``a.data * b / 255`` for any other operand."""
        if isinstance(other, IImage):
            return IImage(self.data / 255. * other.data)
        return IImage(self.data * other / 255.)

    def __xor__(self, other):
        """``a ^ b``: 50/50 blend, except that ``a`` keeps full weight where
        all channels of ``b`` are zero."""
        other_is_empty = other.data.sum(-1, keepdims=True) == 0
        return IImage(0.5 * self.data + 0.5 * other.data
                      + 0.5 * self.data * other_is_empty)

    def __invert__(self):
        """``~a``: invert intensities (``255 - value``)."""
        return IImage(255 - self.data)

    __rmul__ = __mul__

    def bbox(self):
        """Return ``cv2.boundingRect`` of every frame as a list."""
        return [cv2.boundingRect(x) for x in self.data]

    def fill_bbox(self, bbox_list, fill=255):
        """Fill every ``(x, y, w, h)`` box of ``bbox_list`` with ``fill`` in all frames."""
        data = self.data.copy()
        for x, y, w, h in bbox_list:
            data[:, y:y + h, x:x + w, :] = fill
        return IImage(data)

    def crop(self, bbox):
        """Crop all frames to ``(w, h)`` from the origin or to ``(x, y, w, h)``."""
        assert len(bbox) in [2, 4]
        if len(bbox) == 2:
            x, y = 0, 0
            w, h = bbox
        else:
            x, y, w, h = bbox
        return IImage(self.data[:, y:y + h, x:x + w, :])
def stack(images, axis=0):
    """Concatenate the ``data`` arrays of several IImage objects.

    Args:
        images: Iterable of objects exposing a ``data`` array.
        axis: Axis to concatenate along; the default 0 appends frames.

    Returns:
        A new ``IImage`` wrapping the concatenated array.
    """
    arrays = []
    for image in images:
        arrays.append(image.data)
    merged = np.concatenate(arrays, axis)
    return IImage(merged)
class Animation:
    """Matplotlib-based renderer that turns a stack of frames into an animation.

    ``ANIMATION_MODE`` selects the string form produced by ``render()``:
    ``Animation.HTML`` uses ``to_html5_video()``, ``Animation.JS`` uses
    ``to_jshtml()``.
    """
    JS = 0
    HTML = 1
    ANIMATION_MODE = HTML

    def __init__(self, frames, fps=30):
        """Store the frames; nothing is rendered until ``render()`` is called.

        Args:
            frames (np.ndarray): Frame stack; ``render()`` indexes it as
                ``frames[i, :, :, :]`` and reads height/width from
                ``shape[1]`` / ``shape[2]``.
            fps: Playback rate; the frame interval is ``1000 / fps`` ms.
        """
        self.frames = frames
        self.fps = fps
        self.anim_obj = None
        self.anim_str = None

    def render(self):
        """Build the matplotlib animation and its string representation.

        Returns:
            The ``matplotlib.animation.FuncAnimation`` object, which is also
            stored in ``anim_obj``; the rendered string goes to ``anim_str``.
        """
        # figsize is in inches; with dpi=1 the figure is exactly W x H pixels
        size = (self.frames.shape[2], self.frames.shape[1])
        self.fig = plt.figure(figsize=size, dpi=1)
        plt.axis('off')
        img = plt.imshow(self.frames[0], cmap='gray')
        self.fig.subplots_adjust(0, 0, 1, 1)

        def show_frame(i):
            img.set_data(self.frames[i, :, :, :])

        self.anim_obj = animation.FuncAnimation(
            self.fig,
            show_frame,
            frames=self.frames.shape[0],
            interval=1000 / self.fps
        )
        # Close the figure created above explicitly instead of relying on
        # pyplot's notion of the "current" figure.
        plt.close(self.fig)
        if Animation.ANIMATION_MODE == Animation.HTML:
            self.anim_str = self.anim_obj.to_html5_video()
        elif Animation.ANIMATION_MODE == Animation.JS:
            self.anim_str = self.anim_obj.to_jshtml()
        return self.anim_obj

    def _repr_html_(self):
        """Notebook hook: render on first use and return the cached string."""
        if self.anim_obj is None:
            self.render()
        return self.anim_str