Linoy Tsaban commited on
Commit
8832b9b
·
1 Parent(s): dd28623

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +121 -0
utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from PIL import Image
3
+ import torch
4
+ import yaml
5
+ import math
6
+
7
+ import torchvision.transforms as T
8
+ from torchvision.io import read_video,write_video
9
+ import os
10
+ import random
11
+ import numpy as np
12
+ from torchvision.io import write_video
13
+ # from kornia.filters import joint_bilateral_blur
14
+ from kornia.geometry.transform import remap
15
+ from kornia.utils.grid import create_meshgrid
16
+ import cv2
17
+
18
+ def save_video_frames(video_path, img_size=(512,512)):
19
+ video, _, _ = read_video(video_path, output_format="TCHW")
20
+ # rotate video -90 degree if video is .mov format. this is a weird bug in torchvision
21
+ if video_path.endswith('.mov'):
22
+ video = T.functional.rotate(video, -90)
23
+ video_name = Path(video_path).stem
24
+ os.makedirs(f'data/{video_name}', exist_ok=True)
25
+ for i in range(len(video)):
26
+ ind = str(i).zfill(5)
27
+ image = T.ToPILImage()(video[i])
28
+ image_resized = image.resize((img_size), resample=Image.Resampling.LANCZOS)
29
+ image_resized.save(f'data/{video_name}/{ind}.png')
30
+
31
+ def video_to_frames(video_path, img_size=(512,512)):
32
+ video, _, _ = read_video(video_path, output_format="TCHW")
33
+ # rotate video -90 degree if video is .mov format. this is a weird bug in torchvision
34
+ if video_path.endswith('.mov'):
35
+ video = T.functional.rotate(video, -90)
36
+ video_name = Path(video_path).stem
37
+ # os.makedirs(f'data/{video_name}', exist_ok=True)
38
+ frames = []
39
+ for i in range(len(video)):
40
+ ind = str(i).zfill(5)
41
+ image = T.ToPILImage()(video[i])
42
+ image_resized = image.resize((img_size), resample=Image.Resampling.LANCZOS)
43
+ # image_resized.save(f'data/{video_name}/{ind}.png')
44
+ frames.append(image_resized)
45
+ return frames
46
+
47
+ def add_dict_to_yaml_file(file_path, key, value):
48
+ data = {}
49
+
50
+ # If the file already exists, load its contents into the data dictionary
51
+ if os.path.exists(file_path):
52
+ with open(file_path, 'r') as file:
53
+ data = yaml.safe_load(file)
54
+
55
+ # Add or update the key-value pair
56
+ data[key] = value
57
+
58
+ # Save the data back to the YAML file
59
+ with open(file_path, 'w') as file:
60
+ yaml.dump(data, file)
61
+
62
+ def isinstance_str(x: object, cls_name: str):
63
+ """
64
+ Checks whether x has any class *named* cls_name in its ancestry.
65
+ Doesn't require access to the class's implementation.
66
+
67
+ Useful for patching!
68
+ """
69
+
70
+ for _cls in x.__class__.__mro__:
71
+ if _cls.__name__ == cls_name:
72
+ return True
73
+
74
+ return False
75
+
76
+
77
+ def batch_cosine_sim(x, y):
78
+ if type(x) is list:
79
+ x = torch.cat(x, dim=0)
80
+ if type(y) is list:
81
+ y = torch.cat(y, dim=0)
82
+ x = x / x.norm(dim=-1, keepdim=True)
83
+ y = y / y.norm(dim=-1, keepdim=True)
84
+ similarity = x @ y.T
85
+ return similarity
86
+
87
+
88
+ def load_imgs(data_path, n_frames, device='cuda', pil=False):
89
+ imgs = []
90
+ pils = []
91
+ for i in range(n_frames):
92
+ img_path = os.path.join(data_path, "%05d.jpg" % i)
93
+ if not os.path.exists(img_path):
94
+ img_path = os.path.join(data_path, "%05d.png" % i)
95
+ img_pil = Image.open(img_path)
96
+ pils.append(img_pil)
97
+ img = T.ToTensor()(img_pil).unsqueeze(0)
98
+ imgs.append(img)
99
+ if pil:
100
+ return torch.cat(imgs).to(device), pils
101
+ return torch.cat(imgs).to(device)
102
+
103
+
104
+ def save_video(raw_frames, save_path, fps=10):
105
+ video_codec = "libx264"
106
+ video_options = {
107
+ "crf": "18", # Constant Rate Factor (lower value = higher quality, 18 is a good balance)
108
+ "preset": "slow", # Encoding preset (e.g., ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow)
109
+ }
110
+
111
+ frames = (raw_frames * 255).to(torch.uint8).cpu().permute(0, 2, 3, 1)
112
+ write_video(save_path, frames, fps=fps, video_codec=video_codec, options=video_options)
113
+
114
+
115
+ def seed_everything(seed):
116
+ torch.manual_seed(seed)
117
+ torch.cuda.manual_seed(seed)
118
+ random.seed(seed)
119
+ np.random.seed(seed)
120
+
121
+