fffiloni commited on
Commit
2c5156d
·
verified ·
1 Parent(s): 7840297

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -19
app.py CHANGED
@@ -8,6 +8,7 @@ from transformers import T5EncoderModel, T5Tokenizer
8
 
9
  from datetime import datetime
10
  import random
 
11
 
12
  from huggingface_hub import hf_hub_download
13
 
@@ -38,26 +39,27 @@ pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_id, tokenizer=tokeniz
38
  # Add this near the top after imports
39
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
40
 
41
- def find_and_move_object_to_cpu():
42
- for obj in gc.get_objects():
43
- try:
44
- if isinstance(obj, torch.nn.Module):
45
- if any(param.is_cuda for param in obj.parameters()):
46
- obj.to('cpu')
47
- if any(buf.is_cuda for buf in obj.buffers()):
48
- obj.to('cpu')
49
- except Exception as e:
50
- pass
51
-
52
- def clear_gpu():
53
- torch.cuda.empty_cache()
54
- torch.cuda.synchronize()
55
- gc.collect()
56
 
57
  def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
58
  # Move everything to CPU initially
59
  pipe.to("cpu")
60
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
61
 
62
  lora_path = "checkpoints/"
63
  weight_name = "orbit_left_lora_weights.safetensors" if orbit_type == "Left" else "orbit_up_lora_weights.safetensors"
@@ -74,7 +76,6 @@ def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True))
74
  torch.cuda.empty_cache()
75
 
76
  prompt = f"{prompt}. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
77
- image = load_image(image_path)
78
  seed = random.randint(0, 2**8 - 1)
79
 
80
  with torch.inference_mode():
@@ -94,11 +95,31 @@ def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True))
94
  torch.cuda.empty_cache()
95
  gc.collect()
96
 
97
- # Generate output video
98
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
99
- export_to_video(video.frames[0], f"output_{timestamp}.mp4", fps=8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- return f"output_{timestamp}.mp4"
102
 
103
  # Set up Gradio UI
104
  with gr.Blocks(analytics_enabled=False) as demo:
 
8
 
9
  from datetime import datetime
10
  import random
11
+ from moviepy.editor import VideoFileClip
12
 
13
  from huggingface_hub import hf_hub_download
14
 
 
39
  # Add this near the top after imports
40
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
41
 
42
+ def calculate_resize_dimensions(width, height, max_width=1024):
43
+ """Calculate new dimensions maintaining aspect ratio"""
44
+ if width <= max_width:
45
+ return width, height
46
+
47
+ aspect_ratio = height / width
48
+ new_width = max_width
49
+ new_height = int(max_width * aspect_ratio)
50
+ return new_width, new_height
 
 
 
 
 
 
51
 
52
  def infer(image_path, prompt, orbit_type, progress=gr.Progress(track_tqdm=True)):
53
  # Move everything to CPU initially
54
  pipe.to("cpu")
55
  torch.cuda.empty_cache()
56
+
57
+ # Load and get original image dimensions
58
+ image = load_image(image_path)
59
+ original_width, original_height = image.size
60
+
61
+ # Calculate target dimensions maintaining aspect ratio
62
+ target_width, target_height = calculate_resize_dimensions(original_width, original_height)
63
 
64
  lora_path = "checkpoints/"
65
  weight_name = "orbit_left_lora_weights.safetensors" if orbit_type == "Left" else "orbit_up_lora_weights.safetensors"
 
76
  torch.cuda.empty_cache()
77
 
78
  prompt = f"{prompt}. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
 
79
  seed = random.randint(0, 2**8 - 1)
80
 
81
  with torch.inference_mode():
 
95
  torch.cuda.empty_cache()
96
  gc.collect()
97
 
98
+ # Generate initial output video
99
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
100
+ temp_path = f"output_{timestamp}_temp.mp4"
101
+ output_path = f"output_{timestamp}.mp4"
102
+
103
+ # Export initial video
104
+ export_to_video(video.frames[0], temp_path, fps=8)
105
+
106
+ # Resize using moviepy with h264 codec
107
+ video_clip = VideoFileClip(temp_path)
108
+ resized_clip = video_clip.resize(width=target_width, height=target_height)
109
+ resized_clip.write_videofile(
110
+ output_path,
111
+ codec='libx264',
112
+ fps=8,
113
+ preset='medium',
114
+ ffmpeg_params=['-crf', '23']
115
+ )
116
+
117
+ # Cleanup
118
+ video_clip.close()
119
+ resized_clip.close()
120
+ os.remove(temp_path)
121
 
122
+ return output_path
123
 
124
  # Set up Gradio UI
125
  with gr.Blocks(analytics_enabled=False) as demo: