wwen1997 commited on
Commit
6cc7fb3
·
verified ·
1 Parent(s): 5cdf939

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -148
app.py CHANGED
@@ -527,6 +527,154 @@ class Drag:
527
  return val_save_dir
528
 
529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
  if __name__ == "__main__":
531
 
532
  args = get_args()
@@ -564,153 +712,6 @@ if __name__ == "__main__":
564
  last_frame_path = gr.State()
565
  tracking_points = gr.State([])
566
 
567
- def reset_states(first_frame_path, last_frame_path, tracking_points):
568
- first_frame_path = gr.State()
569
- last_frame_path = gr.State()
570
- tracking_points = gr.State([])
571
-
572
- return first_frame_path, last_frame_path, tracking_points
573
-
574
-
575
- def preprocess_image(image):
576
-
577
- image_pil = image2pil(image.name)
578
-
579
- raw_w, raw_h = image_pil.size
580
- # resize_ratio = max(512 / raw_w, 320 / raw_h)
581
- # image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
582
- # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB'))
583
- image_pil = image_pil.resize((512, 320), Image.BILINEAR)
584
-
585
- first_frame_path = os.path.join(args.output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
586
-
587
- image_pil.save(first_frame_path)
588
-
589
- return first_frame_path, first_frame_path, gr.State([])
590
-
591
-
592
- def preprocess_image_end(image_end):
593
-
594
- image_end_pil = image2pil(image_end.name)
595
-
596
- raw_w, raw_h = image_end_pil.size
597
- # resize_ratio = max(512 / raw_w, 320 / raw_h)
598
- # image_end_pil = image_end_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
599
- # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB'))
600
- image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR)
601
-
602
- last_frame_path = os.path.join(args.output_dir, f"last_frame_{str(uuid.uuid4())[:4]}.png")
603
-
604
- image_end_pil.save(last_frame_path)
605
-
606
- return last_frame_path, last_frame_path, gr.State([])
607
-
608
-
609
- def add_drag(tracking_points):
610
- tracking_points.constructor_args['value'].append([])
611
- return tracking_points
612
-
613
-
614
- def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
615
- tracking_points.constructor_args['value'].pop()
616
- transparent_background = Image.open(first_frame_path).convert('RGBA')
617
- transparent_background_end = Image.open(last_frame_path).convert('RGBA')
618
- w, h = transparent_background.size
619
- transparent_layer = np.zeros((h, w, 4))
620
-
621
- for track in tracking_points.constructor_args['value']:
622
- if len(track) > 1:
623
- for i in range(len(track)-1):
624
- start_point = track[i]
625
- end_point = track[i+1]
626
- vx = end_point[0] - start_point[0]
627
- vy = end_point[1] - start_point[1]
628
- arrow_length = np.sqrt(vx**2 + vy**2)
629
- if i == len(track)-2:
630
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
631
- else:
632
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
633
- else:
634
- cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
635
-
636
- transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
637
- trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
638
- trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
639
-
640
- return tracking_points, trajectory_map, trajectory_map_end
641
-
642
-
643
- def delete_last_step(tracking_points, first_frame_path, last_frame_path):
644
- tracking_points.constructor_args['value'][-1].pop()
645
- transparent_background = Image.open(first_frame_path).convert('RGBA')
646
- transparent_background_end = Image.open(last_frame_path).convert('RGBA')
647
- w, h = transparent_background.size
648
- transparent_layer = np.zeros((h, w, 4))
649
-
650
- for track in tracking_points.constructor_args['value']:
651
- if len(track) > 1:
652
- for i in range(len(track)-1):
653
- start_point = track[i]
654
- end_point = track[i+1]
655
- vx = end_point[0] - start_point[0]
656
- vy = end_point[1] - start_point[1]
657
- arrow_length = np.sqrt(vx**2 + vy**2)
658
- if i == len(track)-2:
659
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
660
- else:
661
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
662
- else:
663
- cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
664
-
665
- transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
666
- trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
667
- trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
668
-
669
- return tracking_points, trajectory_map, trajectory_map_end
670
-
671
-
672
- def add_tracking_points(tracking_points, first_frame_path, last_frame_path, evt: gr.SelectData): # SelectData is a subclass of EventData
673
- print(f"You selected {evt.value} at {evt.index} from {evt.target}")
674
- tracking_points.constructor_args['value'][-1].append(evt.index)
675
-
676
- transparent_background = Image.open(first_frame_path).convert('RGBA')
677
- transparent_background_end = Image.open(last_frame_path).convert('RGBA')
678
-
679
- w, h = transparent_background.size
680
- transparent_layer = 0
681
- for idx, track in enumerate(tracking_points.constructor_args['value']):
682
- # mask = cv2.imread(
683
- # os.path.join(args.output_dir, f"mask_{idx+1}.jpg")
684
- # )
685
- mask = np.zeros((320, 512, 3))
686
- color = color_list[idx+1]
687
- transparent_layer = mask[:, :, 0].reshape(h, w, 1) * color.reshape(1, 1, -1) + transparent_layer
688
-
689
- if len(track) > 1:
690
- for i in range(len(track)-1):
691
- start_point = track[i]
692
- end_point = track[i+1]
693
- vx = end_point[0] - start_point[0]
694
- vy = end_point[1] - start_point[1]
695
- arrow_length = np.sqrt(vx**2 + vy**2)
696
- if i == len(track)-2:
697
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
698
- else:
699
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
700
- else:
701
- cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
702
-
703
- transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
704
- alpha_coef = 0.99
705
- im2_data = transparent_layer.getdata()
706
- new_im2_data = [(r, g, b, int(a * alpha_coef)) for r, g, b, a in im2_data]
707
- transparent_layer.putdata(new_im2_data)
708
-
709
- trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
710
- trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
711
-
712
- return tracking_points, trajectory_map, trajectory_map_end
713
-
714
  with gr.Row():
715
  with gr.Column(scale=1):
716
  image_upload_button = gr.UploadButton(label="Upload Start Image", file_types=["image"])
@@ -798,4 +799,4 @@ if __name__ == "__main__":
798
 
799
  run_button.click(Framer.run, [first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id], output_video)
800
 
801
- demo.launch(max_threads=1)
 
527
  return val_save_dir
528
 
529
 
530
+ def reset_states(first_frame_path, last_frame_path, tracking_points):
531
+ first_frame_path = gr.State()
532
+ last_frame_path = gr.State()
533
+ tracking_points = gr.State([])
534
+
535
+ return first_frame_path, last_frame_path, tracking_points
536
+
537
+
538
+ def preprocess_image(image):
539
+
540
+ image_pil = image2pil(image.name)
541
+
542
+ raw_w, raw_h = image_pil.size
543
+ # resize_ratio = max(512 / raw_w, 320 / raw_h)
544
+ # image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
545
+ # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB'))
546
+ image_pil = image_pil.resize((512, 320), Image.BILINEAR)
547
+
548
+ first_frame_path = os.path.join(args.output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
549
+
550
+ image_pil.save(first_frame_path)
551
+
552
+ return first_frame_path, first_frame_path, gr.State([])
553
+
554
+
555
+ def preprocess_image_end(image_end):
556
+
557
+ image_end_pil = image2pil(image_end.name)
558
+
559
+ raw_w, raw_h = image_end_pil.size
560
+ # resize_ratio = max(512 / raw_w, 320 / raw_h)
561
+ # image_end_pil = image_end_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
562
+ # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB'))
563
+ image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR)
564
+
565
+ last_frame_path = os.path.join(args.output_dir, f"last_frame_{str(uuid.uuid4())[:4]}.png")
566
+
567
+ image_end_pil.save(last_frame_path)
568
+
569
+ return last_frame_path, last_frame_path, gr.State([])
570
+
571
+
572
+ def add_drag(tracking_points):
573
+ tracking_points.constructor_args['value'].append([])
574
+ return tracking_points
575
+
576
+
577
+ def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
578
+ tracking_points.constructor_args['value'].pop()
579
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
580
+ transparent_background_end = Image.open(last_frame_path).convert('RGBA')
581
+ w, h = transparent_background.size
582
+ transparent_layer = np.zeros((h, w, 4))
583
+
584
+ for track in tracking_points.constructor_args['value']:
585
+ if len(track) > 1:
586
+ for i in range(len(track)-1):
587
+ start_point = track[i]
588
+ end_point = track[i+1]
589
+ vx = end_point[0] - start_point[0]
590
+ vy = end_point[1] - start_point[1]
591
+ arrow_length = np.sqrt(vx**2 + vy**2)
592
+ if i == len(track)-2:
593
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
594
+ else:
595
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
596
+ else:
597
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
598
+
599
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
600
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
601
+ trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
602
+
603
+ return tracking_points, trajectory_map, trajectory_map_end
604
+
605
+
606
+ def delete_last_step(tracking_points, first_frame_path, last_frame_path):
607
+ tracking_points.constructor_args['value'][-1].pop()
608
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
609
+ transparent_background_end = Image.open(last_frame_path).convert('RGBA')
610
+ w, h = transparent_background.size
611
+ transparent_layer = np.zeros((h, w, 4))
612
+
613
+ for track in tracking_points.constructor_args['value']:
614
+ if len(track) > 1:
615
+ for i in range(len(track)-1):
616
+ start_point = track[i]
617
+ end_point = track[i+1]
618
+ vx = end_point[0] - start_point[0]
619
+ vy = end_point[1] - start_point[1]
620
+ arrow_length = np.sqrt(vx**2 + vy**2)
621
+ if i == len(track)-2:
622
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
623
+ else:
624
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
625
+ else:
626
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
627
+
628
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
629
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
630
+ trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
631
+
632
+ return tracking_points, trajectory_map, trajectory_map_end
633
+
634
+
635
+ def add_tracking_points(tracking_points, first_frame_path, last_frame_path, evt: gr.SelectData): # SelectData is a subclass of EventData
636
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
637
+ tracking_points.constructor_args['value'][-1].append(evt.index)
638
+
639
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
640
+ transparent_background_end = Image.open(last_frame_path).convert('RGBA')
641
+
642
+ w, h = transparent_background.size
643
+ transparent_layer = 0
644
+ for idx, track in enumerate(tracking_points.constructor_args['value']):
645
+ # mask = cv2.imread(
646
+ # os.path.join(args.output_dir, f"mask_{idx+1}.jpg")
647
+ # )
648
+ mask = np.zeros((320, 512, 3))
649
+ color = color_list[idx+1]
650
+ transparent_layer = mask[:, :, 0].reshape(h, w, 1) * color.reshape(1, 1, -1) + transparent_layer
651
+
652
+ if len(track) > 1:
653
+ for i in range(len(track)-1):
654
+ start_point = track[i]
655
+ end_point = track[i+1]
656
+ vx = end_point[0] - start_point[0]
657
+ vy = end_point[1] - start_point[1]
658
+ arrow_length = np.sqrt(vx**2 + vy**2)
659
+ if i == len(track)-2:
660
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
661
+ else:
662
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
663
+ else:
664
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
665
+
666
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
667
+ alpha_coef = 0.99
668
+ im2_data = transparent_layer.getdata()
669
+ new_im2_data = [(r, g, b, int(a * alpha_coef)) for r, g, b, a in im2_data]
670
+ transparent_layer.putdata(new_im2_data)
671
+
672
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
673
+ trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
674
+
675
+ return tracking_points, trajectory_map, trajectory_map_end
676
+
677
+
678
  if __name__ == "__main__":
679
 
680
  args = get_args()
 
712
  last_frame_path = gr.State()
713
  tracking_points = gr.State([])
714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715
  with gr.Row():
716
  with gr.Column(scale=1):
717
  image_upload_button = gr.UploadButton(label="Upload Start Image", file_types=["image"])
 
799
 
800
  run_button.click(Framer.run, [first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id], output_video)
801
 
802
+ demo.launch()