xi0v Fabrice-TIERCELIN commited on
Commit
d3e5f59
·
verified ·
1 Parent(s): 7187257

Allow the user to force the model selection (and fix autorun) (#18)

Browse files

- Allow the user to force the model selection (and fix autorun) (83088413eb5c5d1b7f6cdf1076c41cf6895e8f45)


Co-authored-by: Fabrice TIERCELIN <[email protected]>

Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -34,7 +34,7 @@ def sample(
34
  noise_aug_strength: float = 0.1,
35
  decoding_t: int = 3,
36
  frame_format: str = "webp",
37
- version: str = "svd_xt",
38
  device: str = "cuda",
39
  output_folder: str = "outputs",
40
  ):
@@ -49,7 +49,7 @@ def sample(
49
  base_count = len(glob(os.path.join(output_folder, "*.mp4")))
50
  video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
51
 
52
- if 14 < fps_id:
53
  frames = fps25Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
54
  else:
55
  frames = fps14Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
@@ -105,6 +105,7 @@ with gr.Blocks() as demo:
105
  noise_aug_strength = gr.Slider(label="Noise strength", info="The noise to add", value=0.1, minimum=0, maximum=1, step=0.1)
106
  decoding_t = gr.Slider(label="Decoding", info="Number of frames decoded at a time; this eats more VRAM; reduce if necessary", value=3, minimum=1, maximum=5, step=1)
107
  frame_format = gr.Radio([["*.png", "png"], ["*.webp", "webp"], ["*.jpeg", "jpeg"], ["*.gif", "gif"], ["*.bmp", "bmp"]], label="Image format for result", info="File extention", value="webp", interactive=True)
 
108
  seed = gr.Slider(label="Seed", value=42, randomize=True, minimum=0, maximum=max_64_bit_int, step=1)
109
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
110
 
@@ -115,18 +116,18 @@ with gr.Blocks() as demo:
115
  gallery = gr.Gallery(label="Generated frames")
116
 
117
  image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
118
- generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id, noise_aug_strength, decoding_t, frame_format], outputs=[video, gallery, seed], api_name="video")
119
 
120
  gr.Examples(
121
  examples=[
122
- ["Examples/Fire.webp", 25, 127, 0.1, 3, "png", 42, True],
123
- ["Examples/Town.jpeg", 25, 127, 0.1, 3, "png", 42, True],
124
- ["Examples/Water.png", 25, 127, 0.1, 3, "png", 42, True]
125
  ],
126
- inputs=[image, fps_id, motion_bucket_id, noise_aug_strength, decoding_t, frame_format, seed, randomize_seed],
127
  outputs=[video, gallery, seed],
128
  fn=sample,
129
- run_on_click=False,
130
  cache_examples=False,
131
  )
132
 
 
34
  noise_aug_strength: float = 0.1,
35
  decoding_t: int = 3,
36
  frame_format: str = "webp",
37
+ version: str = "auto",
38
  device: str = "cuda",
39
  output_folder: str = "outputs",
40
  ):
 
49
  base_count = len(glob(os.path.join(output_folder, "*.mp4")))
50
  video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
51
 
52
+ if version == "svdxt" or (14 < fps_id and version != "svd"):
53
  frames = fps25Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
54
  else:
55
  frames = fps14Pipe(image, decode_chunk_size=decoding_t, generator=generator, motion_bucket_id=motion_bucket_id, noise_aug_strength=noise_aug_strength, num_frames=25).frames[0]
 
105
  noise_aug_strength = gr.Slider(label="Noise strength", info="The noise to add", value=0.1, minimum=0, maximum=1, step=0.1)
106
  decoding_t = gr.Slider(label="Decoding", info="Number of frames decoded at a time; this eats more VRAM; reduce if necessary", value=3, minimum=1, maximum=5, step=1)
107
  frame_format = gr.Radio([["*.png", "png"], ["*.webp", "webp"], ["*.jpeg", "jpeg"], ["*.gif", "gif"], ["*.bmp", "bmp"]], label="Image format for result", info="File extention", value="webp", interactive=True)
108
+ version = gr.Radio([["Auto", "auto"], ["SVD (trained on 14 f/s)", "svd"], ["SVD-XT (trained on 25 f/s)", "svdxt"]], label="Model", info="Trained model", value="auto", interactive=True)
109
  seed = gr.Slider(label="Seed", value=42, randomize=True, minimum=0, maximum=max_64_bit_int, step=1)
110
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
111
 
 
116
  gallery = gr.Gallery(label="Generated frames")
117
 
118
  image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
119
+ generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id, noise_aug_strength, decoding_t, frame_format, version], outputs=[video, gallery, seed], api_name="video")
120
 
121
  gr.Examples(
122
  examples=[
123
+ ["Examples/Fire.webp", 42, True, 127, 25, 0.1, 3, "png", "auto"],
124
+ ["Examples/Water.png", 42, True, 127, 25, 0.1, 3, "png", "auto"],
125
+ ["Examples/Town.jpeg", 42, True, 127, 25, 0.1, 3, "png", "auto"]
126
  ],
127
+ inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id, noise_aug_strength, decoding_t, frame_format, version],
128
  outputs=[video, gallery, seed],
129
  fn=sample,
130
+ run_on_click=True,
131
  cache_examples=False,
132
  )
133