hpoghos commited on
Commit
d67a615
·
1 Parent(s): 15809db
Files changed (3) hide show
  1. .gitignore +4 -1
  2. app.py +83 -25
  3. t2v_enhanced/model_func.py +22 -6
.gitignore CHANGED
@@ -14,4 +14,7 @@ t2v_enhanced/logs
14
  t2v_enhanced/slurm_logs
15
  t2v_enhanced/lightning_logs
16
  t2v_enhanced/results
17
- t2v_enhanced/gradio_output
 
 
 
 
14
  t2v_enhanced/slurm_logs
15
  t2v_enhanced/lightning_logs
16
  t2v_enhanced/results
17
+ t2v_enhanced/gradio_output
18
+ gradio_output/
19
+ lightning_logs/
20
+ t2v_enhanced/
app.py CHANGED
@@ -5,6 +5,7 @@ import argparse
5
  import datetime
6
  from pathlib import Path
7
  import torch
 
8
  import gradio as gr
9
  import tempfile
10
  import yaml
@@ -40,7 +41,10 @@ cfg_v2v = {'downscale': 1, 'upscale_size': (1280, 720), 'model_id': 'damo/Video-
40
  # ----- Initialization -----
41
  # --------------------------
42
  ms_model = init_modelscope(device)
43
- # zs_model = init_zeroscope(device)
 
 
 
44
  stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
45
  msxl_model = init_v2v_model(cfg_v2v)
46
 
@@ -50,7 +54,8 @@ inference_generator = torch.Generator(device="cuda")
50
  # -------------------------
51
  # ----- Functionality -----
52
  # -------------------------
53
- def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance, where_to_log=result_fol):
 
54
  now = datetime.datetime.now()
55
  name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
56
 
@@ -59,18 +64,59 @@ def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, n_
59
  else:
60
  num_frames = int(num_frames.split(" ")[0])
61
 
62
- n_autoreg_gen = num_frames/8-8
63
 
64
  inference_generator.manual_seed(seed)
65
- short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
66
- stream_long_gen(prompt, short_video, n_autoreg_gen, n_prompt, seed, t, image_guidance, name, stream_cli, stream_model)
 
 
 
 
 
 
 
67
  video_path = opj(where_to_log, name+".mp4")
68
  return video_path
69
 
70
- def enhance(prompt, input_to_enhance):
 
 
71
  encoded_video = video2video(prompt, input_to_enhance, result_fol, cfg_v2v, msxl_model)
72
  return encoded_video
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # --------------------------
76
  # ----- Gradio-Demo UI -----
@@ -117,30 +163,32 @@ with gr.Blocks() as demo:
117
  with gr.Row():
118
  with gr.Column():
119
  with gr.Row():
120
- num_frames = gr.Dropdown(["24", "32", "40", "48", "56", "80 - only on local", "240 - only on local", "600 - only on local", "1200 - only on local", "10000 - only on local"], label="Number of Video Frames: Default is 56", info="For >80 frames use local workstation!")
121
  with gr.Row():
122
  prompt_stage1 = gr.Textbox(label='Textual Prompt', placeholder="Ex: Dog running on the street.")
123
  with gr.Row():
124
- image_stage1 = gr.Image(label='Image Prompt (only required for I2V base models)', show_label=True, scale=1, show_download_button=True)
125
  with gr.Column():
126
  video_stage1 = gr.Video(label='Long Video Preview', show_label=True, interactive=False, scale=2, show_download_button=True)
127
  with gr.Row():
128
- run_button_stage1 = gr.Button("Long Video Preview Generation")
 
 
 
129
 
130
  with gr.Row():
131
  with gr.Column():
132
  with gr.Accordion('Advanced options', open=False):
133
  model_name_stage1 = gr.Dropdown(
134
- choices=["T2V: ModelScope", "T2V: ZeroScope", "I2V: AnimateDiff"],
135
- label="Base Model. Default is ModelScope",
136
- info="Currently supports only ModelScope. We will add more options later!",
137
  )
138
  model_name_stage2 = gr.Dropdown(
139
- choices=["ModelScope-XL", "Another", "Another"],
140
- label="Enhancement Model. Default is ModelScope-XL",
141
- info="Currently supports only ModelScope-XL. We will add more options later!",
142
  )
143
- n_prompt = gr.Textbox(label="Optional Negative Prompt", value='')
144
  seed = gr.Slider(label='Seed', minimum=0, maximum=65536, value=33,step=1,)
145
 
146
  t = gr.Slider(label="Timesteps", minimum=0, maximum=100, value=50, step=1,)
@@ -148,9 +196,25 @@ with gr.Blocks() as demo:
148
 
149
  with gr.Column():
150
  with gr.Row():
151
- video_stage2 = gr.Video(label='Enhanced Long Video', show_label=True, interactive=False, height=473, show_download_button=True)
152
- with gr.Row():
153
- run_button_stage2 = gr.Button("Long Video Enhancement")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  '''
155
  '''
156
  gr.HTML(
@@ -174,12 +238,6 @@ with gr.Blocks() as demo:
174
  </div>
175
  """)
176
 
177
- inputs_t2v = [prompt_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, n_prompt, seed, t, image_guidance]
178
- run_button_stage1.click(fn=generate, inputs=inputs_t2v, outputs=video_stage1,)
179
-
180
- inputs_v2v = [prompt_stage1, video_stage1]
181
- run_button_stage2.click(fn=enhance, inputs=inputs_v2v, outputs=video_stage2,)
182
-
183
 
184
  if on_huggingspace:
185
  demo.queue(max_size=20)
 
5
  import datetime
6
  from pathlib import Path
7
  import torch
8
+ import spaces
9
  import gradio as gr
10
  import tempfile
11
  import yaml
 
41
  # ----- Initialization -----
42
  # --------------------------
43
  ms_model = init_modelscope(device)
44
+ # # zs_model = init_zeroscope(device)
45
+ ad_model = init_animatediff(device)
46
+ svd_model = init_svd(device)
47
+ sdxl_model = init_sdxl(device)
48
  stream_cli, stream_model = init_streamingt2v_model(ckpt_file_streaming_t2v, result_fol)
49
  msxl_model = init_v2v_model(cfg_v2v)
50
 
 
54
  # -------------------------
55
  # ----- Functionality -----
56
  # -------------------------
57
+ @spaces.GPU
58
+ def generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance, where_to_log=result_fol):
59
  now = datetime.datetime.now()
60
  name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
61
 
 
64
  else:
65
  num_frames = int(num_frames.split(" ")[0])
66
 
67
+ n_autoreg_gen = num_frames//8-8
68
 
69
  inference_generator.manual_seed(seed)
70
+
71
+ if model_name_stage1 == "ModelScopeT2V (text to video)":
72
+ short_video = ms_short_gen(prompt, ms_model, inference_generator, t, device)
73
+ elif model_name_stage1 == "AnimateDiff (text to video)":
74
+ short_video = ad_short_gen(prompt, ad_model, inference_generator, t, device)
75
+ elif model_name_stage1 == "SVD (image to video)":
76
+ short_video = svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t, device)
77
+
78
+ stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, name, stream_cli, stream_model)
79
  video_path = opj(where_to_log, name+".mp4")
80
  return video_path
81
 
82
+ def enhance(prompt, input_to_enhance, num_frames=None, image=None, model_name_stage1=None, model_name_stage2=None, seed=33, t=50, image_guidance=9.5, result_fol=result_fol):
83
+ if input_to_enhance is None:
84
+ input_to_enhance = generate(prompt, num_frames, image, model_name_stage1, model_name_stage2, seed, t, image_guidance)
85
  encoded_video = video2video(prompt, input_to_enhance, result_fol, cfg_v2v, msxl_model)
86
  return encoded_video
87
 
88
+ def change_visibility(value):
89
+ if value == "SVD (image to video)":
90
+ return gr.Image(label='Image Prompt (if not attached then SDXL will be used to generate the starting image)', show_label=True, scale=1, show_download_button=False, interactive=True, type='pil')
91
+ else:
92
+ return gr.Image(label='Image Prompt (first select Image-to-Video model from advanced options to enable image upload)', show_label=True, scale=1, show_download_button=False, interactive=False, type='pil')
93
+
94
+
95
+ examples = [
96
+ ["Camera moving in a wide bright ice cave.",
97
+ None, "24 - frames", None, "ModelScopeT2V (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
98
+ ["Explore the coral gardens of the sea: witness the kaleidoscope of colors and shapes as coral reefs provide shelter for a myriad of marine life.",
99
+ None, "24 - frames", None, "ModelScopeT2V (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
100
+ ["Experience the dance of jellyfish: float through mesmerizing swarms of jellyfish, pulsating with otherworldly grace and beauty.",
101
+ None, "24 - frames", None, "ModelScopeT2V (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
102
+ ["Discover the secret language of bees: delve into the complex communication system that allows bees to coordinate their actions and navigate the world.",
103
+ None, "24 - frames", None, "AnimateDiff (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
104
+ ["A beagle reading a paper.",
105
+ None, "24 - frames", None, "AnimateDiff (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
106
+ ["Beautiful Paris Day and Night Hyperlapse.",
107
+ None, "24 - frames", None, "AnimateDiff (text to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
108
+ ["Fishes swimming in ocean camera moving, cinematic.",
109
+ None, "24 - frames", "__assets__/fish.jpg", "SVD (image to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
110
+ ["A squirrel on a table full of big nuts.",
111
+ None, "24 - frames", "__assets__/squirrel.jpg", "SVD (image to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
112
+ ["Ants, beetles and centipede nest.",
113
+ None, "24 - frames", None, "SVD (image to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
114
+ ]
115
+
116
+ # examples = [
117
+ # ["Fishes swimming in ocean camera moving, cinematic.",
118
+ # None, "24 - frames", "__assets__/fish.jpg", "SVD (image to video)", "MS-Vid2Vid-XL", 33, 50, 9.0],
119
+ # ]
120
 
121
  # --------------------------
122
  # ----- Gradio-Demo UI -----
 
163
  with gr.Row():
164
  with gr.Column():
165
  with gr.Row():
166
+ num_frames = gr.Dropdown(["24 - frames", "32 - frames", "40 - frames", "48 - frames", "56 - frames", "80 - recommended to run on local GPUs", "240 - recommended to run on local GPUs", "600 - recommended to run on local GPUs", "1200 - recommended to run on local GPUs", "10000 - recommended to run on local GPUs"], label="Number of Video Frames", info="For >56 frames use local workstation!", value="24 - frames")
167
  with gr.Row():
168
  prompt_stage1 = gr.Textbox(label='Textual Prompt', placeholder="Ex: Dog running on the street.")
169
  with gr.Row():
170
+ image_stage1 = gr.Image(label='Image Prompt (first select Image-to-Video model from advanced options to enable image upload)', show_label=True, scale=1, show_download_button=False, interactive=False, type='pil')
171
  with gr.Column():
172
  video_stage1 = gr.Video(label='Long Video Preview', show_label=True, interactive=False, scale=2, show_download_button=True)
173
  with gr.Row():
174
+ with gr.Row():
175
+ run_button_stage1 = gr.Button("long Video Generation (faster preview)")
176
+ with gr.Row():
177
+ run_button_stage2 = gr.Button("long Video Generation")
178
 
179
  with gr.Row():
180
  with gr.Column():
181
  with gr.Accordion('Advanced options', open=False):
182
  model_name_stage1 = gr.Dropdown(
183
+ choices=["ModelScopeT2V (text to video)", "AnimateDiff (text to video)", "SVD (image to video)"],
184
+ label="Base Model",
185
+ value="ModelScopeT2V (text to video)"
186
  )
187
  model_name_stage2 = gr.Dropdown(
188
+ choices=["MS-Vid2Vid-XL"],
189
+ label="Enhancement Model",
190
+ value="MS-Vid2Vid-XL"
191
  )
 
192
  seed = gr.Slider(label='Seed', minimum=0, maximum=65536, value=33,step=1,)
193
 
194
  t = gr.Slider(label="Timesteps", minimum=0, maximum=100, value=50, step=1,)
 
196
 
197
  with gr.Column():
198
  with gr.Row():
199
+ video_stage2 = gr.Video(label='Long Video', show_label=True, interactive=False, height=588, show_download_button=True)
200
+
201
+ model_name_stage1.change(fn=change_visibility, inputs=[model_name_stage1], outputs=image_stage1)
202
+
203
+ inputs_t2v = [prompt_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, seed, t, image_guidance]
204
+ run_button_stage1.click(fn=generate, inputs=inputs_t2v, outputs=video_stage1,)
205
+
206
+ inputs_v2v = [prompt_stage1, video_stage1, num_frames, image_stage1, model_name_stage1, model_name_stage2, seed, t, image_guidance]
207
+
208
+ # gr.Examples(examples=examples,
209
+ # inputs=inputs_v2v,
210
+ # outputs=video_stage2,
211
+ # fn=enhance,
212
+ # run_on_click=False,
213
+ # # cache_examples=on_huggingspace,
214
+ # cache_examples=False,
215
+ # )
216
+ run_button_stage2.click(fn=enhance, inputs=inputs_v2v, outputs=video_stage2,)
217
+
218
  '''
219
  '''
220
  gr.HTML(
 
238
  </div>
239
  """)
240
 
 
 
 
 
 
 
241
 
242
  if on_huggingspace:
243
  demo.queue(max_size=20)
t2v_enhanced/model_func.py CHANGED
@@ -51,15 +51,20 @@ def sdxl_image_gen(prompt, sdxl_model):
51
  return image
52
 
53
  def svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"):
54
- if image is None or image == "":
55
  image = sdxl_image_gen(prompt, sdxl_model)
56
  image = image.resize((576, 576))
57
  image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
58
- else:
59
  image = load_image(image)
60
  image = resize_and_keep(image)
61
  image = center_crop(image)
62
  image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
 
 
 
 
 
63
 
64
  frames = svd_model(image, decode_chunk_size=8, generator=inference_generator).frames[0]
65
  frames = torch.stack([transform(frame) for frame in frames])
@@ -70,9 +75,10 @@ def svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t=2
70
  return frames
71
 
72
 
73
- def stream_long_gen(prompt, short_video, n_autoreg_gen, n_prompt, seed, t, image_guidance, result_file_stem, stream_cli, stream_model):
74
  trainer = stream_cli.trainer
75
  trainer.limit_predict_batches = 1
 
76
  trainer.predict_cfg = {
77
  "predict_dir": stream_cli.config["result_fol"].as_posix(),
78
  "result_file_stem": result_file_stem,
@@ -93,7 +99,8 @@ def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True):
93
  pad = cfg_v2v['pad']
94
 
95
  now = datetime.datetime.now()
96
- name = prompt[:100].replace(" ", "_") + "_" + str(now.time()).replace(":", "_").replace(".", "_")
 
97
  enhanced_video_mp4 = opj(where_to_log, name+"_enhanced.mp4")
98
 
99
  video_frames = imageio.mimread(video)
@@ -107,11 +114,20 @@ def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True):
107
  video = [pad_to_fit(frame, upscale_size) for frame in video]
108
  # video = [np.array(frame) for frame in video]
109
 
110
- imageio.mimsave(opj(where_to_log, 'temp.mp4'), video, fps=8)
111
 
112
  p_input = {
113
- 'video_path': opj(where_to_log, 'temp.mp4'),
114
  'text': prompt
115
  }
116
  output_video_path = model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]
 
 
 
 
 
 
 
 
 
117
  return enhanced_video_mp4
 
51
  return image
52
 
53
  def svd_short_gen(image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"):
54
+ if image is None:
55
  image = sdxl_image_gen(prompt, sdxl_model)
56
  image = image.resize((576, 576))
57
  image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
58
+ elif type(image) is str:
59
  image = load_image(image)
60
  image = resize_and_keep(image)
61
  image = center_crop(image)
62
  image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
63
+ else:
64
+ image = Image.fromarray(np.uint8(image))
65
+ image = resize_and_keep(image)
66
+ image = center_crop(image)
67
+ image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
68
 
69
  frames = svd_model(image, decode_chunk_size=8, generator=inference_generator).frames[0]
70
  frames = torch.stack([transform(frame) for frame in frames])
 
75
  return frames
76
 
77
 
78
+ def stream_long_gen(prompt, short_video, n_autoreg_gen, seed, t, image_guidance, result_file_stem, stream_cli, stream_model):
79
  trainer = stream_cli.trainer
80
  trainer.limit_predict_batches = 1
81
+
82
  trainer.predict_cfg = {
83
  "predict_dir": stream_cli.config["result_fol"].as_posix(),
84
  "result_file_stem": result_file_stem,
 
99
  pad = cfg_v2v['pad']
100
 
101
  now = datetime.datetime.now()
102
+ now = str(now.time()).replace(":", "_").replace(".", "_")
103
+ name = prompt[:100].replace(" ", "_") + "_" + now
104
  enhanced_video_mp4 = opj(where_to_log, name+"_enhanced.mp4")
105
 
106
  video_frames = imageio.mimread(video)
 
114
  video = [pad_to_fit(frame, upscale_size) for frame in video]
115
  # video = [np.array(frame) for frame in video]
116
 
117
+ imageio.mimsave(opj(where_to_log, 'temp_'+now+'.mp4'), video, fps=8)
118
 
119
  p_input = {
120
+ 'video_path': opj(where_to_log, 'temp_'+now+'.mp4'),
121
  'text': prompt
122
  }
123
  output_video_path = model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]
124
+
125
+ # Remove padding
126
+ video_frames = imageio.mimread(enhanced_video_mp4)
127
+ video_frames_square = []
128
+ for frame in video_frames:
129
+ frame = frame[:, 280:-280, :]
130
+ video_frames_square.append(frame)
131
+ imageio.mimsave(enhanced_video_mp4, video_frames_square)
132
+
133
  return enhanced_video_mp4