Spaces:
Runtime error
Runtime error
from pathlib import Path | |
from PIL import Image | |
import torch | |
import yaml | |
import math | |
import torchvision.transforms as T | |
from torchvision.io import read_video,write_video | |
import os | |
import random | |
import numpy as np | |
from torchvision.io import write_video | |
# from kornia.filters import joint_bilateral_blur | |
from kornia.geometry.transform import remap | |
from kornia.utils.grid import create_meshgrid | |
import cv2 | |
def save_video_frames(video_path, img_size=(512, 512)):
    """Decode a video and dump every frame as a PNG under ``data/<video stem>/``.

    Each frame is resized straight to ``img_size`` with Lanczos resampling
    (the aspect ratio is not preserved) and written with a zero-padded,
    five-digit file name (``00000.png``, ``00001.png``, ...).

    Args:
        video_path: Path of the video file to decode (``str`` or path-like).
        img_size: Target size handed to ``PIL.Image.resize``, i.e.
            ``(width, height)``.
    """
    video_path = os.fspath(video_path)
    # read_video returns a 3-tuple (frames, audio, info); the previous
    # two-target unpacking raised ValueError on every call.
    video, _, _ = read_video(video_path, output_format="TCHW")
    # Original author's note: torchvision returns .mov files rotated, so the
    # frames are rotated by -90 degrees to compensate.
    if video_path.endswith('.mov'):
        video = T.functional.rotate(video, -90)
    video_name = Path(video_path).stem
    os.makedirs(f'data/{video_name}', exist_ok=True)
    to_pil = T.ToPILImage()
    for i, frame in enumerate(video):
        ind = str(i).zfill(5)
        image_resized = to_pil(frame).resize(img_size, resample=Image.Resampling.LANCZOS)
        image_resized.save(f'data/{video_name}/{ind}.png')
def video_to_frames(video_path, img_size=(512, 512)):
    """Decode a video into PIL frames letterboxed to ``img_size``.

    Every frame is scaled to fit inside ``img_size`` while keeping its aspect
    ratio, then centred on a black canvas of exactly ``img_size``. The
    previous implementation replaced each decoded frame with a blank canvas
    before resizing, so all returned frames were solid black.

    Args:
        video_path: Path of the video file to decode (``str`` or path-like).
        img_size: Output size as ``(width, height)``, the PIL convention.

    Returns:
        A tuple ``(frames, fps)`` where ``frames`` is a list of PIL images,
        each of size ``img_size``, and ``fps`` is ``video_info["video_fps"]``
        as reported by ``read_video``.
    """
    video_path = os.fspath(video_path)
    video, _, video_info = read_video(video_path, output_format="TCHW")
    # Original author's note: torchvision returns .mov files rotated, so the
    # frames are rotated by -90 degrees to compensate.
    if video_path.endswith('.mov'):
        video = T.functional.rotate(video, -90)

    target_w, target_h = img_size
    to_pil = T.ToPILImage()
    frames = []
    for frame in video:
        image = to_pil(frame)
        # PIL reports size as (width, height).
        src_w, src_h = image.size
        # Largest uniform scale at which the frame still fits in the target.
        scale = min(target_w / src_w, target_h / src_h)
        # round + clamp guards against float error pushing a side off by one.
        new_w = min(target_w, max(1, round(src_w * scale)))
        new_h = min(target_h, max(1, round(src_h * scale)))
        resized = image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)
        # Image.new defaults to a black fill; paste the frame in the centre.
        canvas = Image.new(image.mode, (target_w, target_h))
        canvas.paste(resized, ((target_w - new_w) // 2, (target_h - new_h) // 2))
        frames.append(canvas)
    return frames, video_info["video_fps"]
def add_dict_to_yaml_file(file_path, key, value):
    """Insert or overwrite ``key`` in the top-level mapping of a YAML file.

    The file is created if it does not exist. An existing but empty file is
    treated as an empty mapping.

    Args:
        file_path: Path of the YAML file to update.
        key: Top-level key to add or replace.
        value: Value to store under ``key``.
    """
    data = {}
    # If the file already exists, load its contents into the data dictionary
    if os.path.exists(file_path):
        with open(file_path, 'r') as file:
            # safe_load returns None for an empty document; fall back to an
            # empty dict so the assignment below does not raise TypeError.
            data = yaml.safe_load(file) or {}
    # Add or update the key-value pair
    data[key] = value
    # Save the data back to the YAML file
    with open(file_path, 'w') as file:
        yaml.dump(data, file)
def isinstance_str(x: object, cls_name: str):
    """
    Report whether any class in x's method resolution order is *named* cls_name.

    The comparison is purely by name, so the class object itself never has to
    be imported. Handy when patching third-party classes.
    """
    ancestry = x.__class__.__mro__
    return any(klass.__name__ == cls_name for klass in ancestry)
def batch_cosine_sim(x, y):
    """Return the pairwise cosine similarity between the rows of x and y.

    Args:
        x: A tensor, or a list/tuple of tensors that is concatenated along
            dim 0 first.
        y: Same as ``x``.

    Returns:
        ``x_normalized @ y_normalized.T``, where each operand is divided by
        its L2 norm along the last dimension.
    """
    # isinstance (not an exact type check) so tuples and list subclasses of
    # tensors are concatenated too.
    if isinstance(x, (list, tuple)):
        x = torch.cat(x, dim=0)
    if isinstance(y, (list, tuple)):
        y = torch.cat(y, dim=0)
    # NOTE(review): an all-zero row has zero norm and yields NaN here, as in
    # the original implementation.
    x = x / x.norm(dim=-1, keepdim=True)
    y = y / y.norm(dim=-1, keepdim=True)
    return x @ y.T
def load_imgs(data_path, n_frames, device='cuda', pil=False):
    """Load the first ``n_frames`` numbered frames from a directory.

    Frame ``i`` is read from ``<data_path>/<i as 5 digits>.jpg`` when that
    file exists, otherwise from the ``.png`` file with the same stem.

    Args:
        data_path: Directory containing the numbered frame images.
        n_frames: Number of frames to load, starting at index 0.
        device: Device the stacked tensor is moved to.
        pil: When true, also return the opened PIL images.

    Returns:
        The frames concatenated along a new leading dimension and moved to
        ``device``; with ``pil=True`` a ``(tensor, pil_images)`` tuple.
    """
    to_tensor = T.ToTensor()
    pil_frames = []
    tensors = []
    for idx in range(n_frames):
        frame_path = os.path.join(data_path, "%05d.jpg" % idx)
        if not os.path.exists(frame_path):
            # No JPEG for this index: fall back to the PNG variant.
            frame_path = os.path.join(data_path, "%05d.png" % idx)
        frame = Image.open(frame_path)
        pil_frames.append(frame)
        # [None] adds the leading batch axis so the frames can be concatenated.
        tensors.append(to_tensor(frame)[None])
    batch = torch.cat(tensors).to(device)
    return (batch, pil_frames) if pil else batch
def save_video(raw_frames, save_path, fps=10):
    """Encode a batch of frames to an H.264 video file.

    Args:
        raw_frames: Frame tensor; it is scaled by 255, cast to uint8, moved
            to the CPU and permuted with ``(0, 2, 3, 1)`` before writing.
            Presumably float values in [0, 1] laid out as (T, C, H, W);
            confirm against callers.
        save_path: Destination file path.
        fps: Frames per second of the written video.
    """
    # NOTE(review): values are not clamped before the uint8 cast, so inputs
    # outside [0, 1] are not handled here.
    as_bytes = (raw_frames * 255).to(torch.uint8)
    channels_last = as_bytes.cpu().permute(0, 2, 3, 1)
    encoder_options = dict(
        crf="18",       # constant rate factor: lower means higher quality
        preset="slow",  # x264 speed/size trade-off (ultrafast ... veryslow)
    )
    write_video(
        save_path,
        channels_last,
        fps=fps,
        video_codec="libx264",
        options=encoder_options,
    )
def seed_everything(seed):
    """Seed torch (CPU and current CUDA device), ``random`` and NumPy with ``seed``."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)