ford442 commited on
Commit
be81944
·
verified ·
1 Parent(s): 1504958

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -20
app.py CHANGED
@@ -21,15 +21,12 @@ from PIL import Image
21
  import tempfile
22
  import os
23
  import gc
24
- from openai import OpenAI
25
  import csv
26
  from datetime import datetime
27
 
28
-
29
  # Load Hugging Face token if needed
30
  hf_token = os.getenv("HF_TOKEN")
31
- openai_api_key = os.getenv("OPENAI_API_KEY")
32
- client = OpenAI(api_key=openai_api_key)
33
  system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
34
  system_prompt_i2v_path = "assets/system_prompt_i2v.txt"
35
  with open(system_prompt_t2v_path, "r") as f:
@@ -48,7 +45,7 @@ vae_dir = Path(model_path) / "vae"
48
  unet_dir = Path(model_path) / "unet"
49
  scheduler_dir = Path(model_path) / "scheduler"
50
 
51
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
 
53
  DATA_DIR = "/data"
54
  os.makedirs(DATA_DIR, exist_ok=True)
@@ -57,7 +54,6 @@ LOG_FILE_PATH = os.path.join("/data", "user_requests.csv")
57
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
58
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
59
 
60
-
61
  if not os.path.exists(LOG_FILE_PATH):
62
  with open(LOG_FILE_PATH, "w", newline="") as f:
63
  writer = csv.writer(f)
@@ -80,7 +76,6 @@ if not os.path.exists(LOG_FILE_PATH):
80
  ]
81
  )
82
 
83
-
84
  @lru_cache(maxsize=128)
85
  def log_request(
86
  request_type,
@@ -123,7 +118,6 @@ def log_request(
123
  except Exception as e:
124
  print(f"Error logging request: {e}")
125
 
126
-
127
  def compute_clip_embedding(text=None, image=None):
128
  """
129
  Compute CLIP embedding for a given text or image.
@@ -138,7 +132,6 @@ def compute_clip_embedding(text=None, image=None):
138
  embedding = outputs.detach().cpu().numpy().flatten().tolist()
139
  return embedding
140
 
141
-
142
  def load_vae(vae_dir):
143
  vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
144
  vae_config_path = vae_dir / "config.json"
@@ -149,7 +142,6 @@ def load_vae(vae_dir):
149
  vae.load_state_dict(vae_state_dict)
150
  return vae.to(device=device, dtype=torch.bfloat16)
151
 
152
-
153
  def load_unet(unet_dir):
154
  unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
155
  unet_config_path = unet_dir / "config.json"
@@ -159,13 +151,11 @@ def load_unet(unet_dir):
159
  transformer.load_state_dict(unet_state_dict, strict=True)
160
  return transformer.to(device=device, dtype=torch.bfloat16)
161
 
162
-
163
  def load_scheduler(scheduler_dir):
164
  scheduler_config_path = scheduler_dir / "scheduler_config.json"
165
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
166
  return RectifiedFlowScheduler.from_config(scheduler_config)
167
 
168
-
169
  # Helper function for image processing
170
  def center_crop_and_resize(frame, target_height, target_width):
171
  h, w, _ = frame.shape
@@ -182,7 +172,6 @@ def center_crop_and_resize(frame, target_height, target_width):
182
  frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
183
  return frame_resized
184
 
185
-
186
  def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
187
  image = Image.open(image_path).convert("RGB")
188
  image_np = np.array(image)
@@ -191,7 +180,6 @@ def load_image_to_tensor_with_resize(image_path, target_height=512, target_width
191
  frame_tensor = (frame_tensor / 127.5) - 1.0
192
  return frame_tensor.unsqueeze(0).unsqueeze(2)
193
 
194
-
195
  def enhance_prompt_if_enabled(prompt, enhance_toggle, type="t2v"):
196
  if not enhance_toggle:
197
  print("Enhance toggle is off, Prompt: ", prompt)
@@ -215,7 +203,6 @@ def enhance_prompt_if_enabled(prompt, enhance_toggle, type="t2v"):
215
  print(f"Error: {e}")
216
  return prompt
217
 
218
-
219
  # Preset options for resolution and frame configuration
220
  preset_options = [
221
  {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
@@ -247,7 +234,6 @@ preset_options = [
247
  {"label": "512x320, 257 frames", "width": 512, "height": 320, "num_frames": 257},
248
  ]
249
 
250
-
251
  # Function to toggle visibility of sliders based on preset selection
252
  def preset_changed(preset):
253
  if preset != "Custom":
@@ -270,7 +256,6 @@ def preset_changed(preset):
270
  gr.update(visible=True),
271
  )
272
 
273
-
274
  # Load models
275
  vae = load_vae(vae_dir)
276
  unet = load_unet(unet_dir)
@@ -288,7 +273,6 @@ pipeline = XoraVideoPipeline(
288
  vae=vae,
289
  ).to(device)
290
 
291
-
292
  def generate_video_from_text(
293
  prompt="",
294
  enhance_prompt_toggle=False,
@@ -490,7 +474,6 @@ def generate_video_from_image(
490
 
491
  return output_path
492
 
493
-
494
  def create_advanced_options():
495
  with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
496
  seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
@@ -531,7 +514,6 @@ def create_advanced_options():
531
  num_frames_slider,
532
  ]
533
 
534
-
535
  # Define the Gradio interface with tabs
536
  with gr.Blocks(theme=gr.themes.Soft()) as iface:
537
  with gr.Row(elem_id="title-row"):
 
21
  import tempfile
22
  import os
23
  import gc
 
24
  import csv
25
  from datetime import datetime
26
 
 
27
  # Load Hugging Face token if needed
28
  hf_token = os.getenv("HF_TOKEN")
29
+
 
30
  system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
31
  system_prompt_i2v_path = "assets/system_prompt_i2v.txt"
32
  with open(system_prompt_t2v_path, "r") as f:
 
45
  unet_dir = Path(model_path) / "unet"
46
  scheduler_dir = Path(model_path) / "scheduler"
47
 
48
+ device = torch.device("cuda")
49
 
50
  DATA_DIR = "/data"
51
  os.makedirs(DATA_DIR, exist_ok=True)
 
54
  clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
55
  clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
56
 
 
57
  if not os.path.exists(LOG_FILE_PATH):
58
  with open(LOG_FILE_PATH, "w", newline="") as f:
59
  writer = csv.writer(f)
 
76
  ]
77
  )
78
 
 
79
  @lru_cache(maxsize=128)
80
  def log_request(
81
  request_type,
 
118
  except Exception as e:
119
  print(f"Error logging request: {e}")
120
 
 
121
  def compute_clip_embedding(text=None, image=None):
122
  """
123
  Compute CLIP embedding for a given text or image.
 
132
  embedding = outputs.detach().cpu().numpy().flatten().tolist()
133
  return embedding
134
 
 
135
  def load_vae(vae_dir):
136
  vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
137
  vae_config_path = vae_dir / "config.json"
 
142
  vae.load_state_dict(vae_state_dict)
143
  return vae.to(device=device, dtype=torch.bfloat16)
144
 
 
145
  def load_unet(unet_dir):
146
  unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
147
  unet_config_path = unet_dir / "config.json"
 
151
  transformer.load_state_dict(unet_state_dict, strict=True)
152
  return transformer.to(device=device, dtype=torch.bfloat16)
153
 
 
154
  def load_scheduler(scheduler_dir):
155
  scheduler_config_path = scheduler_dir / "scheduler_config.json"
156
  scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
157
  return RectifiedFlowScheduler.from_config(scheduler_config)
158
 
 
159
  # Helper function for image processing
160
  def center_crop_and_resize(frame, target_height, target_width):
161
  h, w, _ = frame.shape
 
172
  frame_resized = cv2.resize(frame_cropped, (target_width, target_height))
173
  return frame_resized
174
 
 
175
  def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
176
  image = Image.open(image_path).convert("RGB")
177
  image_np = np.array(image)
 
180
  frame_tensor = (frame_tensor / 127.5) - 1.0
181
  return frame_tensor.unsqueeze(0).unsqueeze(2)
182
 
 
183
  def enhance_prompt_if_enabled(prompt, enhance_toggle, type="t2v"):
184
  if not enhance_toggle:
185
  print("Enhance toggle is off, Prompt: ", prompt)
 
203
  print(f"Error: {e}")
204
  return prompt
205
 
 
206
  # Preset options for resolution and frame configuration
207
  preset_options = [
208
  {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
 
234
  {"label": "512x320, 257 frames", "width": 512, "height": 320, "num_frames": 257},
235
  ]
236
 
 
237
  # Function to toggle visibility of sliders based on preset selection
238
  def preset_changed(preset):
239
  if preset != "Custom":
 
256
  gr.update(visible=True),
257
  )
258
 
 
259
  # Load models
260
  vae = load_vae(vae_dir)
261
  unet = load_unet(unet_dir)
 
273
  vae=vae,
274
  ).to(device)
275
 
 
276
  def generate_video_from_text(
277
  prompt="",
278
  enhance_prompt_toggle=False,
 
474
 
475
  return output_path
476
 
 
477
  def create_advanced_options():
478
  with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
479
  seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
 
514
  num_frames_slider,
515
  ]
516
 
 
517
  # Define the Gradio interface with tabs
518
  with gr.Blocks(theme=gr.themes.Soft()) as iface:
519
  with gr.Row(elem_id="title-row"):