yoavhacohen commited on
Commit
41b1cab
·
unverified ·
2 Parent(s): 00c2119 4bb89c5

Merge pull request #9 from LightricksResearch/image-to-video-for-zeev

Browse files
Files changed (2) hide show
  1. requirements.txt +3 -1
  2. xora/examples/image_to_video.py +140 -44
requirements.txt CHANGED
@@ -3,4 +3,6 @@ diffusers==0.28.2
3
  transformers==4.44.2
4
  sentencepiece>=0.1.96
5
  accelerate
6
- einops
 
 
 
3
  transformers==4.44.2
4
  sentencepiece>=0.1.96
5
  accelerate
6
+ einops
7
+ matplotlib
8
+ opencv-python
xora/examples/image_to_video.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
3
  from xora.models.transformers.transformer3d import Transformer3DModel
@@ -9,6 +10,13 @@ from transformers import T5EncoderModel, T5Tokenizer
9
  import safetensors.torch
10
  import json
11
  import argparse
 
 
 
 
 
 
 
12
 
13
  def load_vae(vae_dir):
14
  vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
@@ -34,78 +42,166 @@ def load_scheduler(scheduler_dir):
34
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
35
  return RectifiedFlowScheduler.from_config(scheduler_config)
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def main():
38
- # Parse command line arguments
39
- parser = argparse.ArgumentParser(description='Load models from separate directories')
40
- parser.add_argument('--separate_dir', type=str, required=True, help='Path to the directory containing unet, vae, and scheduler subdirectories')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  args = parser.parse_args()
42
 
43
  # Paths for the separate mode directories
44
- separate_dir = Path(args.separate_dir)
45
- unet_dir = separate_dir / 'unet'
46
- vae_dir = separate_dir / 'vae'
47
- scheduler_dir = separate_dir / 'scheduler'
48
 
49
  # Load models
50
  vae = load_vae(vae_dir)
51
  unet = load_unet(unet_dir)
52
  scheduler = load_scheduler(scheduler_dir)
53
-
54
- # Patchifier (remains the same)
55
  patchifier = SymmetricPatchifier(patch_size=1)
56
-
57
- # text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to("cuda")
58
- # tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
59
 
60
  # Use submodels for the pipeline
61
  submodel_dict = {
62
- "transformer": unet, # using unet for transformer
63
  "patchifier": patchifier,
64
- "text_encoder": None,
65
- "tokenizer": None,
66
  "scheduler": scheduler,
67
  "vae": vae,
68
  }
69
 
70
- model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
71
- pipeline = VideoPixArtAlphaPipeline(
72
- **submodel_dict
73
- ).to("cuda")
74
-
75
- num_inference_steps = 20
76
- num_images_per_prompt = 1
77
- guidance_scale = 3
78
- height = 512
79
- width = 768
80
- num_frames = 57
81
- frame_rate = 25
82
-
83
- # Sample input stays the same
84
- sample = torch.load("/opt/sample_media.pt")
85
- for key, item in sample.items():
86
- if item is not None:
87
- sample[key] = item.cuda()
88
 
89
- # media_items = torch.load("/opt/sample_media.pt")
 
 
 
 
 
90
 
91
- # Generate images (video frames)
92
  images = pipeline(
93
- num_inference_steps=num_inference_steps,
94
- num_images_per_prompt=num_images_per_prompt,
95
- guidance_scale=guidance_scale,
96
- generator=None,
97
  output_type="pt",
98
  callback_on_step_end=None,
99
- height=height,
100
- width=width,
101
- num_frames=num_frames,
102
- frame_rate=frame_rate,
103
  **sample,
104
  is_video=True,
105
  vae_per_channel_normalize=True,
 
106
  ).images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- print("Generated video frames.")
109
 
110
  if __name__ == "__main__":
111
  main()
 
1
+ import time
2
  import torch
3
  from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
4
  from xora.models.transformers.transformer3d import Transformer3DModel
 
10
  import safetensors.torch
11
  import json
12
  import argparse
13
+ from xora.utils.conditioning_method import ConditioningMethod
14
+ import os
15
+ import numpy as np
16
+ import cv2
17
+ from PIL import Image
18
+ from tqdm import tqdm
19
+ import random
20
 
21
  def load_vae(vae_dir):
22
  vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
 
42
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
43
  return RectifiedFlowScheduler.from_config(scheduler_config)
44
 
45
+ def center_crop_and_resize(frame, target_height, target_width):
46
+ h, w, _ = frame.shape
47
+ aspect_ratio_target = target_width / target_height
48
+ aspect_ratio_frame = w / h
49
+ if aspect_ratio_frame > aspect_ratio_target:
50
+ new_width = int(h * aspect_ratio_target)
51
+ x_start = (w - new_width) // 2
52
+ frame_cropped = frame[:, x_start:x_start + new_width]
53
+ else:
54
+ new_height = int(w / aspect_ratio_target)
55
+ y_start = (h - new_height) // 2
56
+ frame_cropped = frame[y_start:y_start + new_height, :]
57
+ frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
58
+ return frame_resized
59
+
60
+ def load_video_to_tensor_with_resize(video_path, target_height=512, target_width=768):
61
+ cap = cv2.VideoCapture(video_path)
62
+ frames = []
63
+ while True:
64
+ ret, frame = cap.read()
65
+ if not ret:
66
+ break
67
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
68
+ frame_resized = center_crop_and_resize(frame_rgb, target_height, target_width)
69
+ frames.append(frame_resized)
70
+ cap.release()
71
+ video_np = (np.array(frames) / 127.5) - 1.0
72
+ video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
73
+ return video_tensor
74
+
75
+ def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
76
+ image = Image.open(image_path).convert("RGB")
77
+ image_np = np.array(image)
78
+ frame_resized = center_crop_and_resize(image_np, target_height, target_width)
79
+ frame_tensor = torch.tensor(frame_resized).permute(2, 0, 1).float()
80
+ frame_tensor = (frame_tensor / 127.5) - 1.0
81
+ # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
82
+ return frame_tensor.unsqueeze(0).unsqueeze(2)
83
+
84
  def main():
85
+ parser = argparse.ArgumentParser(description='Load models from separate directories and run the pipeline.')
86
+
87
+ # Directories
88
+ parser.add_argument('--ckpt_dir', type=str, required=True,
89
+ help='Path to the directory containing unet, vae, and scheduler subdirectories')
90
+ parser.add_argument('--video_path', type=str,
91
+ help='Path to the input video file (first frame used)')
92
+ parser.add_argument('--image_path', type=str,
93
+ help='Path to the input image file')
94
+ parser.add_argument('--seed', type=int, default="171198")
95
+
96
+ # Pipeline parameters
97
+ parser.add_argument('--num_inference_steps', type=int, default=40, help='Number of inference steps')
98
+ parser.add_argument('--num_images_per_prompt', type=int, default=1, help='Number of images per prompt')
99
+ parser.add_argument('--guidance_scale', type=float, default=3, help='Guidance scale for the pipeline')
100
+ parser.add_argument('--height', type=int, default=512, help='Height of the output video frames')
101
+ parser.add_argument('--width', type=int, default=768, help='Width of the output video frames')
102
+ parser.add_argument('--num_frames', type=int, default=121, help='Number of frames to generate in the output video')
103
+ parser.add_argument('--frame_rate', type=int, default=25, help='Frame rate for the output video')
104
+
105
+ # Prompts
106
+ parser.add_argument('--prompt', type=str,
107
+ default='A man wearing a black leather jacket and blue jeans is riding a Harley Davidson motorcycle down a paved road. The man has short brown hair and is wearing a black helmet. The motorcycle is a dark red color with a large front fairing. The road is surrounded by green grass and trees. There is a gas station on the left side of the road with a red and white sign that says "Oil" and "Diner".',
108
+ help='Text prompt to guide generation')
109
+ parser.add_argument('--negative_prompt', type=str,
110
+ default='worst quality, inconsistent motion, blurry, jittery, distorted',
111
+ help='Negative prompt for undesired features')
112
+
113
  args = parser.parse_args()
114
 
115
  # Paths for the separate mode directories
116
+ ckpt_dir = Path(args.ckpt_dir)
117
+ unet_dir = ckpt_dir / 'unet'
118
+ vae_dir = ckpt_dir / 'vae'
119
+ scheduler_dir = ckpt_dir / 'scheduler'
120
 
121
  # Load models
122
  vae = load_vae(vae_dir)
123
  unet = load_unet(unet_dir)
124
  scheduler = load_scheduler(scheduler_dir)
 
 
125
  patchifier = SymmetricPatchifier(patch_size=1)
126
+ text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(
127
+ "cuda")
128
+ tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
129
 
130
  # Use submodels for the pipeline
131
  submodel_dict = {
132
+ "transformer": unet,
133
  "patchifier": patchifier,
134
+ "text_encoder": text_encoder,
135
+ "tokenizer": tokenizer,
136
  "scheduler": scheduler,
137
  "vae": vae,
138
  }
139
 
140
+ pipeline = VideoPixArtAlphaPipeline(**submodel_dict).to("cuda")
141
+
142
+ # Load media (video or image)
143
+ if args.video_path:
144
+ media_items = load_video_to_tensor_with_resize(args.video_path, args.height, args.width).unsqueeze(0)
145
+ elif args.image_path:
146
+ media_items = load_image_to_tensor_with_resize(args.image_path, args.height, args.width)
147
+ else:
148
+ raise ValueError("Either --video_path or --image_path must be provided.")
149
+
150
+ # Prepare input for the pipeline
151
+ sample = {
152
+ "prompt": args.prompt,
153
+ 'prompt_attention_mask': None,
154
+ 'negative_prompt': args.negative_prompt,
155
+ 'negative_prompt_attention_mask': None,
156
+ 'media_items': media_items,
157
+ }
158
 
159
+ start_time = time.time()
160
+ random.seed(args.seed)
161
+ np.random.seed(args.seed)
162
+ torch.manual_seed(args.seed)
163
+ torch.cuda.manual_seed(args.seed)
164
+ generator = torch.Generator(device="cuda").manual_seed(args.seed)
165
 
 
166
  images = pipeline(
167
+ num_inference_steps=args.num_inference_steps,
168
+ num_images_per_prompt=args.num_images_per_prompt,
169
+ guidance_scale=args.guidance_scale,
170
+ generator=generator,
171
  output_type="pt",
172
  callback_on_step_end=None,
173
+ height=args.height,
174
+ width=args.width,
175
+ num_frames=args.num_frames,
176
+ frame_rate=args.frame_rate,
177
  **sample,
178
  is_video=True,
179
  vae_per_channel_normalize=True,
180
+ conditioning_method=ConditioningMethod.FIRST_FRAME
181
  ).images
182
+ # Save output video
183
+ def get_unique_filename(base, ext, dir='.', index_range=1000):
184
+ for i in range(index_range):
185
+ filename = os.path.join(dir, f"{base}_{i}{ext}")
186
+ if not os.path.exists(filename):
187
+ return filename
188
+ raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")
189
+
190
+
191
+ for i in range(images.shape[0]):
192
+ video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
193
+ video_np = (video_np * 255).astype(np.uint8)
194
+ fps = args.frame_rate
195
+ height, width = video_np.shape[1:3]
196
+ output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")
197
+
198
+ out = cv2.VideoWriter(output_filename, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
199
+
200
+ for frame in video_np[..., ::-1]:
201
+ out.write(frame)
202
+
203
+ out.release()
204
 
 
205
 
206
  if __name__ == "__main__":
207
  main()