Spaces:
Runtime error
Runtime error
Image to video script: make determinist by random seed.
Browse files- xora/examples/image_to_video.py +23 -10
xora/examples/image_to_video.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import torch
|
2 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
3 |
from xora.models.transformers.transformer3d import Transformer3DModel
|
@@ -14,6 +15,8 @@ import os
|
|
14 |
import numpy as np
|
15 |
import cv2
|
16 |
from PIL import Image
|
|
|
|
|
17 |
|
18 |
def load_vae(vae_dir):
|
19 |
vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
@@ -65,9 +68,8 @@ def load_video_to_tensor_with_resize(video_path, target_height=512, target_width
|
|
65 |
frame_resized = center_crop_and_resize(frame_rgb, target_height, target_width)
|
66 |
frames.append(frame_resized)
|
67 |
cap.release()
|
68 |
-
video_np = np.array(frames)
|
69 |
video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
|
70 |
-
video_tensor = (video_tensor / 127.5) - 1.0
|
71 |
return video_tensor
|
72 |
|
73 |
def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
|
@@ -154,9 +156,13 @@ def main():
|
|
154 |
'media_items': media_items,
|
155 |
}
|
156 |
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
-
# Run the pipeline
|
160 |
images = pipeline(
|
161 |
num_inference_steps=args.num_inference_steps,
|
162 |
num_images_per_prompt=args.num_images_per_prompt,
|
@@ -173,20 +179,27 @@ def main():
|
|
173 |
vae_per_channel_normalize=True,
|
174 |
conditioning_method=ConditioningMethod.FIRST_FRAME
|
175 |
).images
|
176 |
-
|
177 |
# Save output video
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
for i in range(images.shape[0]):
|
179 |
video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
|
180 |
video_np = (video_np * 255).astype(np.uint8)
|
181 |
fps = args.frame_rate
|
182 |
height, width = video_np.shape[1:3]
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
(width, height))
|
188 |
for frame in video_np[..., ::-1]:
|
189 |
out.write(frame)
|
|
|
190 |
out.release()
|
191 |
|
192 |
|
|
|
1 |
+
import time
|
2 |
import torch
|
3 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
4 |
from xora.models.transformers.transformer3d import Transformer3DModel
|
|
|
15 |
import numpy as np
|
16 |
import cv2
|
17 |
from PIL import Image
|
18 |
+
from tqdm import tqdm
|
19 |
+
import random
|
20 |
|
21 |
def load_vae(vae_dir):
|
22 |
vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
|
|
|
68 |
frame_resized = center_crop_and_resize(frame_rgb, target_height, target_width)
|
69 |
frames.append(frame_resized)
|
70 |
cap.release()
|
71 |
+
video_np = (np.array(frames) / 127.5) - 1.0
|
72 |
video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
|
|
|
73 |
return video_tensor
|
74 |
|
75 |
def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
|
|
|
156 |
'media_items': media_items,
|
157 |
}
|
158 |
|
159 |
+
start_time = time.time()
|
160 |
+
random.seed(args.seed)
|
161 |
+
np.random.seed(args.seed)
|
162 |
+
torch.manual_seed(args.seed)
|
163 |
+
torch.cuda.manual_seed(args.seed)
|
164 |
+
generator = torch.Generator(device="cuda").manual_seed(args.seed)
|
165 |
|
|
|
166 |
images = pipeline(
|
167 |
num_inference_steps=args.num_inference_steps,
|
168 |
num_images_per_prompt=args.num_images_per_prompt,
|
|
|
179 |
vae_per_channel_normalize=True,
|
180 |
conditioning_method=ConditioningMethod.FIRST_FRAME
|
181 |
).images
|
|
|
182 |
# Save output video
|
183 |
+
def get_unique_filename(base, ext, dir='.', index_range=1000):
|
184 |
+
for i in range(index_range):
|
185 |
+
filename = os.path.join(dir, f"{base}_{i}{ext}")
|
186 |
+
if not os.path.exists(filename):
|
187 |
+
return filename
|
188 |
+
raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")
|
189 |
+
|
190 |
+
|
191 |
for i in range(images.shape[0]):
|
192 |
video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
|
193 |
video_np = (video_np * 255).astype(np.uint8)
|
194 |
fps = args.frame_rate
|
195 |
height, width = video_np.shape[1:3]
|
196 |
+
output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")
|
197 |
+
|
198 |
+
out = cv2.VideoWriter(output_filename, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
199 |
+
|
|
|
200 |
for frame in video_np[..., ::-1]:
|
201 |
out.write(frame)
|
202 |
+
|
203 |
out.release()
|
204 |
|
205 |
|