hysts HF staff commited on
Commit
f740706
·
1 Parent(s): dc16691
Files changed (1) hide show
  1. app.py +163 -237
app.py CHANGED
@@ -10,9 +10,7 @@ import numpy as np
10
  import spaces
11
  import torch
12
  import torchvision
13
- from diffusers.utils.import_utils import is_xformers_available
14
  from huggingface_hub import snapshot_download
15
- from packaging import version
16
  from PIL import Image
17
  from scipy.interpolate import PchipInterpolator
18
 
@@ -39,55 +37,40 @@ snapshot_download(
39
  )
40
 
41
 
42
- def get_args():
43
- import argparse
 
44
 
45
- parser = argparse.ArgumentParser()
 
 
 
 
46
 
47
- parser.add_argument("--min_guidance_scale", type=float, default=1.0)
48
- parser.add_argument("--max_guidance_scale", type=float, default=3.0)
49
- parser.add_argument("--middle_max_guidance", type=int, default=0, choices=[0, 1])
50
- parser.add_argument("--with_control", type=int, default=1, choices=[0, 1])
51
 
52
- parser.add_argument("--controlnet_cond_scale", type=float, default=1.0)
53
-
54
- parser.add_argument(
55
- "--dataset",
56
- type=str,
57
- default="videoswap",
58
- )
59
-
60
- parser.add_argument(
61
- "--model",
62
- type=str,
63
- default="checkpoints/framer_512x320",
64
- help="Path to model.",
65
- )
66
-
67
- parser.add_argument("--output_dir", type=str, default="gradio_demo/outputs", help="Path to the output video.")
68
-
69
- parser.add_argument("--seed", type=int, default=42, help="random seed.")
70
-
71
- parser.add_argument("--noise_aug", type=float, default=0.02)
72
-
73
- parser.add_argument("--num_frames", type=int, default=14)
74
- parser.add_argument("--frame_interval", type=int, default=2)
75
-
76
- parser.add_argument("--width", type=int, default=512)
77
- parser.add_argument("--height", type=int, default=320)
78
-
79
- parser.add_argument(
80
- "--num_workers",
81
- type=int,
82
- default=0,
83
- help=(
84
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
85
- ),
86
- )
87
-
88
- args = parser.parse_args()
89
 
90
- return args
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
 
93
  def interpolate_trajectory(points, n_points):
@@ -164,7 +147,7 @@ def get_vis_image(
164
  vis_img = new_img.copy()
165
  # ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
166
 
167
- if idxx >= args.num_frames:
168
  break
169
 
170
  # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
@@ -363,187 +346,6 @@ def validate_and_convert_image(image, target_size=(512, 512)):
363
  return image
364
 
365
 
366
- class Drag:
367
-
368
- @spaces.GPU
369
- def __init__(self, device, args, height, width, model_length, dtype=torch.float16, use_sift=False):
370
- self.device = device
371
- self.dtype = dtype
372
-
373
- unet = UNetSpatioTemporalConditionModel.from_pretrained(
374
- os.path.join(args.model, "unet"),
375
- torch_dtype=torch.float16,
376
- low_cpu_mem_usage=True,
377
- custom_resume=True,
378
- )
379
- unet = unet.to(device, dtype)
380
-
381
- controlnet = ControlNetSVDModel.from_pretrained(
382
- os.path.join(args.model, "controlnet"),
383
- )
384
- controlnet = controlnet.to(device, dtype)
385
-
386
- if is_xformers_available():
387
- import xformers
388
-
389
- xformers_version = version.parse(xformers.__version__)
390
- unet.enable_xformers_memory_efficient_attention()
391
- # controlnet.enable_xformers_memory_efficient_attention()
392
- else:
393
- raise ValueError("xformers is not available. Make sure it is installed correctly")
394
-
395
- pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
396
- "checkpoints/stable-video-diffusion-img2vid-xt",
397
- unet=unet,
398
- controlnet=controlnet,
399
- low_cpu_mem_usage=False,
400
- torch_dtype=torch.float16,
401
- variant="fp16",
402
- local_files_only=True,
403
- )
404
- pipe.to(device)
405
-
406
- self.pipeline = pipe
407
- # self.pipeline.enable_model_cpu_offload()
408
-
409
- self.height = height
410
- self.width = width
411
- self.args = args
412
- self.model_length = model_length
413
- self.use_sift = use_sift
414
-
415
- @spaces.GPU
416
- def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
417
- original_width, original_height = 512, 320 # TODO
418
-
419
- # load_image
420
- image = Image.open(first_frame_path).convert("RGB")
421
- width, height = image.size
422
- image = image.resize((self.width, self.height))
423
-
424
- image_end = Image.open(last_frame_path).convert("RGB")
425
- image_end = image_end.resize((self.width, self.height))
426
-
427
- input_all_points = tracking_points
428
-
429
- sift_track_update = False
430
- anchor_points_flag = None
431
-
432
- if (len(input_all_points) == 0) and self.use_sift:
433
- sift_track_update = True
434
- controlnet_cond_scale = 0.5
435
-
436
- from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
437
- from models_diffusers.sift_match import sift_match
438
-
439
- output_file_sift = os.path.join(args.output_dir, "sift.png")
440
-
441
- # (f, topk, 2), f=2 (before interpolation)
442
- pred_tracks = sift_match(
443
- image,
444
- image_end,
445
- thr=0.5,
446
- topk=5,
447
- method="random",
448
- output_path=output_file_sift,
449
- )
450
-
451
- if pred_tracks is not None:
452
- # interpolate the tracks, following draganything gradio demo
453
- pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=self.model_length)
454
-
455
- anchor_points_flag = torch.zeros((self.model_length, pred_tracks.shape[1])).to(pred_tracks.device)
456
- anchor_points_flag[0] = 1
457
- anchor_points_flag[-1] = 1
458
-
459
- pred_tracks = pred_tracks.permute(1, 0, 2) # (num_points, num_frames, 2)
460
-
461
- else:
462
-
463
- resized_all_points = [
464
- tuple(
465
- [
466
- tuple([int(e1[0] * self.width / original_width), int(e1[1] * self.height / original_height)])
467
- for e1 in e
468
- ]
469
- )
470
- for e in input_all_points
471
- ]
472
-
473
- # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y)
474
- # in image w & h scale
475
-
476
- for idx, splited_track in enumerate(resized_all_points):
477
- if len(splited_track) == 0:
478
- warnings.warn("running without point trajectory control")
479
- continue
480
-
481
- if len(splited_track) == 1: # stationary point
482
- displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
483
- splited_track = tuple([splited_track[0], displacement_point])
484
- # interpolate the track
485
- splited_track = interpolate_trajectory(splited_track, self.model_length)
486
- splited_track = splited_track[: self.model_length]
487
- resized_all_points[idx] = splited_track
488
-
489
- pred_tracks = torch.tensor(resized_all_points) # (num_points, num_frames, 2)
490
-
491
- vis_images = get_vis_image(
492
- target_size=(self.args.height, self.args.width),
493
- points=pred_tracks,
494
- num_frames=self.model_length,
495
- )
496
-
497
- if len(pred_tracks.shape) != 3:
498
- print("pred_tracks.shape", pred_tracks.shape)
499
- with_control = False
500
- controlnet_cond_scale = 0.0
501
- else:
502
- with_control = True
503
- pred_tracks = pred_tracks.permute(1, 0, 2).to(self.device, self.dtype) # (num_frames, num_points, 2)
504
-
505
- point_embedding = None
506
- video_frames = self.pipeline(
507
- image,
508
- image_end,
509
- # trajectory control
510
- with_control=with_control,
511
- point_tracks=pred_tracks,
512
- point_embedding=point_embedding,
513
- with_id_feature=False,
514
- controlnet_cond_scale=controlnet_cond_scale,
515
- # others
516
- num_frames=14,
517
- width=width,
518
- height=height,
519
- # decode_chunk_size=8,
520
- # generator=generator,
521
- motion_bucket_id=motion_bucket_id,
522
- fps=7,
523
- num_inference_steps=30,
524
- # track
525
- sift_track_update=sift_track_update,
526
- anchor_points_flag=anchor_points_flag,
527
- ).frames[0]
528
-
529
- vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
530
- vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
531
- vis_images = [Image.fromarray(img) for img in vis_images]
532
-
533
- # video_frames = [img for sublist in video_frames for img in sublist]
534
- val_save_dir = os.path.join(args.output_dir, "vis_gif.gif")
535
- save_gifs_side_by_side(
536
- video_frames,
537
- vis_images[: self.model_length],
538
- val_save_dir,
539
- target_size=(self.width, self.height),
540
- duration=110,
541
- point_tracks=pred_tracks,
542
- )
543
-
544
- return val_save_dir
545
-
546
-
547
  def reset_states(first_frame_path, last_frame_path, tracking_points):
548
  first_frame_path = None
549
  last_frame_path = None
@@ -561,7 +363,7 @@ def preprocess_image(image):
561
  # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB'))
562
  image_pil = image_pil.resize((512, 320), Image.BILINEAR)
563
 
564
- first_frame_path = os.path.join(args.output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
565
 
566
  image_pil.save(first_frame_path)
567
 
@@ -578,7 +380,7 @@ def preprocess_image_end(image_end):
578
  # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB'))
579
  image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR)
580
 
581
- last_frame_path = os.path.join(args.output_dir, f"last_frame_{str(uuid.uuid4())[:4]}.png")
582
 
583
  image_end_pil.save(last_frame_path)
584
 
@@ -692,7 +494,7 @@ def add_tracking_points(
692
  transparent_layer = 0
693
  for idx, track in enumerate(tracking_points):
694
  # mask = cv2.imread(
695
- # os.path.join(args.output_dir, f"mask_{idx+1}.jpg")
696
  # )
697
  mask = np.zeros((320, 512, 3))
698
  color = color_list[idx + 1]
@@ -737,10 +539,136 @@ def add_tracking_points(
737
  return tracking_points, trajectory_map, trajectory_map_end
738
 
739
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
740
  if __name__ == "__main__":
741
 
742
- args = get_args()
743
- ensure_dirname(args.output_dir)
744
 
745
  color_list = []
746
  for i in range(20):
@@ -771,8 +699,6 @@ if __name__ == "__main__":
771
  3. Interpolate the images (according the path) with a click on "Run" button. <br>"""
772
  )
773
 
774
- # device, args, height, width, model_length
775
- Framer = Drag("cuda", args, 320, 512, 14)
776
  first_frame_path = gr.State()
777
  last_frame_path = gr.State()
778
  tracking_points = gr.State([])
@@ -898,7 +824,7 @@ if __name__ == "__main__":
898
  )
899
 
900
  run_button.click(
901
- fn=Framer.run,
902
  inputs=[first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id],
903
  outputs=output_video,
904
  )
 
10
  import spaces
11
  import torch
12
  import torchvision
 
13
  from huggingface_hub import snapshot_download
 
14
  from PIL import Image
15
  from scipy.interpolate import PchipInterpolator
16
 
 
37
  )
38
 
39
 
40
+ model_id = "checkpoints/framer_512x320"
41
+ device = "cuda"
42
+ dtype = torch.float16
43
 
44
+ OUTPUT_DIR = "gradio_demo/outputs"
45
+ HEIGHT = 320
46
+ WIDTH = 512
47
+ MODEL_LENGTH = 14
48
+ USE_SIFT = False
49
 
 
 
 
 
50
 
51
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
52
+ os.path.join(model_id, "unet"),
53
+ torch_dtype=torch.float16,
54
+ low_cpu_mem_usage=True,
55
+ custom_resume=True,
56
+ )
57
+ unet = unet.to(device, dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ controlnet = ControlNetSVDModel.from_pretrained(
60
+ os.path.join(model_id, "controlnet"),
61
+ )
62
+ controlnet = controlnet.to(device, dtype)
63
+
64
+ pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
65
+ "checkpoints/stable-video-diffusion-img2vid-xt",
66
+ unet=unet,
67
+ controlnet=controlnet,
68
+ low_cpu_mem_usage=False,
69
+ torch_dtype=torch.float16,
70
+ variant="fp16",
71
+ local_files_only=True,
72
+ )
73
+ pipe.to(device)
74
 
75
 
76
  def interpolate_trajectory(points, n_points):
 
147
  vis_img = new_img.copy()
148
  # ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
149
 
150
+ if idxx >= num_frames:
151
  break
152
 
153
  # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
 
346
  return image
347
 
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  def reset_states(first_frame_path, last_frame_path, tracking_points):
350
  first_frame_path = None
351
  last_frame_path = None
 
363
  # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB'))
364
  image_pil = image_pil.resize((512, 320), Image.BILINEAR)
365
 
366
+ first_frame_path = os.path.join(OUTPUT_DIR, f"first_frame_{str(uuid.uuid4())[:4]}.png")
367
 
368
  image_pil.save(first_frame_path)
369
 
 
380
  # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB'))
381
  image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR)
382
 
383
+ last_frame_path = os.path.join(OUTPUT_DIR, f"last_frame_{str(uuid.uuid4())[:4]}.png")
384
 
385
  image_end_pil.save(last_frame_path)
386
 
 
494
  transparent_layer = 0
495
  for idx, track in enumerate(tracking_points):
496
  # mask = cv2.imread(
497
+ # os.path.join(OUTPUT_DIR, f"mask_{idx+1}.jpg")
498
  # )
499
  mask = np.zeros((320, 512, 3))
500
  color = color_list[idx + 1]
 
539
  return tracking_points, trajectory_map, trajectory_map_end
540
 
541
 
542
+ @spaces.GPU
543
+ def run(first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
544
+ original_width, original_height = 512, 320 # TODO
545
+
546
+ # load_image
547
+ image = Image.open(first_frame_path).convert("RGB")
548
+ width, height = image.size
549
+ image = image.resize((WIDTH, HEIGHT))
550
+
551
+ image_end = Image.open(last_frame_path).convert("RGB")
552
+ image_end = image_end.resize((WIDTH, HEIGHT))
553
+
554
+ input_all_points = tracking_points
555
+
556
+ sift_track_update = False
557
+ anchor_points_flag = None
558
+
559
+ if (len(input_all_points) == 0) and USE_SIFT:
560
+ sift_track_update = True
561
+ controlnet_cond_scale = 0.5
562
+
563
+ from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
564
+ from models_diffusers.sift_match import sift_match
565
+
566
+ output_file_sift = os.path.join(OUTPUT_DIR, "sift.png")
567
+
568
+ # (f, topk, 2), f=2 (before interpolation)
569
+ pred_tracks = sift_match(
570
+ image,
571
+ image_end,
572
+ thr=0.5,
573
+ topk=5,
574
+ method="random",
575
+ output_path=output_file_sift,
576
+ )
577
+
578
+ if pred_tracks is not None:
579
+ # interpolate the tracks, following draganything gradio demo
580
+ pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=MODEL_LENGTH)
581
+
582
+ anchor_points_flag = torch.zeros((MODEL_LENGTH, pred_tracks.shape[1])).to(pred_tracks.device)
583
+ anchor_points_flag[0] = 1
584
+ anchor_points_flag[-1] = 1
585
+
586
+ pred_tracks = pred_tracks.permute(1, 0, 2) # (num_points, num_frames, 2)
587
+
588
+ else:
589
+
590
+ resized_all_points = [
591
+ tuple([tuple([int(e1[0] * WIDTH / original_width), int(e1[1] * HEIGHT / original_height)]) for e1 in e])
592
+ for e in input_all_points
593
+ ]
594
+
595
+ # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y)
596
+ # in image w & h scale
597
+
598
+ for idx, splited_track in enumerate(resized_all_points):
599
+ if len(splited_track) == 0:
600
+ warnings.warn("running without point trajectory control")
601
+ continue
602
+
603
+ if len(splited_track) == 1: # stationary point
604
+ displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
605
+ splited_track = tuple([splited_track[0], displacement_point])
606
+ # interpolate the track
607
+ splited_track = interpolate_trajectory(splited_track, MODEL_LENGTH)
608
+ splited_track = splited_track[:MODEL_LENGTH]
609
+ resized_all_points[idx] = splited_track
610
+
611
+ pred_tracks = torch.tensor(resized_all_points) # (num_points, num_frames, 2)
612
+
613
+ vis_images = get_vis_image(
614
+ target_size=(HEIGHT, WIDTH),
615
+ points=pred_tracks,
616
+ num_frames=MODEL_LENGTH,
617
+ )
618
+
619
+ if len(pred_tracks.shape) != 3:
620
+ print("pred_tracks.shape", pred_tracks.shape)
621
+ with_control = False
622
+ controlnet_cond_scale = 0.0
623
+ else:
624
+ with_control = True
625
+ pred_tracks = pred_tracks.permute(1, 0, 2).to(device, dtype) # (num_frames, num_points, 2)
626
+
627
+ point_embedding = None
628
+ video_frames = pipe(
629
+ image,
630
+ image_end,
631
+ # trajectory control
632
+ with_control=with_control,
633
+ point_tracks=pred_tracks,
634
+ point_embedding=point_embedding,
635
+ with_id_feature=False,
636
+ controlnet_cond_scale=controlnet_cond_scale,
637
+ # others
638
+ num_frames=14,
639
+ width=width,
640
+ height=height,
641
+ # decode_chunk_size=8,
642
+ # generator=generator,
643
+ motion_bucket_id=motion_bucket_id,
644
+ fps=7,
645
+ num_inference_steps=30,
646
+ # track
647
+ sift_track_update=sift_track_update,
648
+ anchor_points_flag=anchor_points_flag,
649
+ ).frames[0]
650
+
651
+ vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
652
+ vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
653
+ vis_images = [Image.fromarray(img) for img in vis_images]
654
+
655
+ # video_frames = [img for sublist in video_frames for img in sublist]
656
+ val_save_dir = os.path.join(OUTPUT_DIR, "vis_gif.gif")
657
+ save_gifs_side_by_side(
658
+ video_frames,
659
+ vis_images[:MODEL_LENGTH],
660
+ val_save_dir,
661
+ target_size=(WIDTH, HEIGHT),
662
+ duration=110,
663
+ point_tracks=pred_tracks,
664
+ )
665
+
666
+ return val_save_dir
667
+
668
+
669
  if __name__ == "__main__":
670
 
671
+ ensure_dirname(OUTPUT_DIR)
 
672
 
673
  color_list = []
674
  for i in range(20):
 
699
  3. Interpolate the images (according the path) with a click on "Run" button. <br>"""
700
  )
701
 
 
 
702
  first_frame_path = gr.State()
703
  last_frame_path = gr.State()
704
  tracking_points = gr.State([])
 
824
  )
825
 
826
  run_button.click(
827
+ fn=run,
828
  inputs=[first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id],
829
  outputs=output_video,
830
  )