osanseviero commited on
Commit
7f86517
·
1 Parent(s): e2a8fab

Create new file

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ import bisect
3
+ import multiprocessing as mp
4
+ from collections import deque
5
+ import cv2
6
+ import torch
7
+
8
+ from detectron2.data import MetadataCatalog
9
+ from detectron2.engine.defaults import DefaultPredictor
10
+ from detectron2.utils.video_visualizer import VideoVisualizer
11
+ from detectron2.utils.visualizer import ColorMode, Visualizer
12
+ import argparse
13
+ import glob
14
+ import multiprocessing as mp
15
+ import numpy as np
16
+ import os
17
+ import tempfile
18
+ import time
19
+ import warnings
20
+ import cv2
21
+ import subprocess
22
+ import tqdm
23
+
24
+ from detectron2.config import get_cfg
25
+ from detectron2.data.detection_utils import read_image
26
+ from detectron2.utils.logger import setup_logger
27
+
28
+ import gradio as gr
29
+
30
+ TOTAL_FRAMES = 40
31
+
32
+ subprocess.run(["git", "clone", "https://github.com/wjf5203/VNext"])
33
+
34
+ def setup_cfg(cfg):
35
+ # load config from file and command-line arguments
36
+ cfg = get_cfg()
37
+ # To use demo for Panoptic-DeepLab, please uncomment the following two lines.
38
+ # from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config # noqa
39
+ # add_panoptic_deeplab_config(cfg)
40
+ cfg.merge_from_file("VNext/configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml")
41
+ # Set score_threshold for builtin models
42
+ cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.5
43
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
44
+ cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5
45
+ cfg.freeze()
46
+ return cfg
47
+
48
+ predictor = DefaultPredictor(setup_cfg({}))
49
+ metadata = MetadataCatalog.get("__unused")
50
+
51
+ def run_on_video(video, total_frames):
52
+ video_visualizer = VideoVisualizer(metadata, ColorMode.IMAGE)
53
+
54
+ def _frame_from_video(video):
55
+ while video.isOpened():
56
+ success, frame = video.read()
57
+ if success:
58
+ yield frame
59
+ else:
60
+ break
61
+
62
+ def process_predictions(frame, predictions):
63
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
64
+ if "panoptic_seg" in predictions:
65
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
66
+ vis_frame = video_visualizer.draw_panoptic_seg_predictions(
67
+ frame, panoptic_seg.to("cpu"), segments_info
68
+ )
69
+ elif "instances" in predictions:
70
+ predictions = predictions["instances"].to("cpu")
71
+ vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
72
+ elif "sem_seg" in predictions:
73
+ vis_frame = video_visualizer.draw_sem_seg(
74
+ frame, predictions["sem_seg"].argmax(dim=0).to("cpu")
75
+ )
76
+
77
+ # Converts Matplotlib RGB format to OpenCV BGR format
78
+ vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
79
+ return vis_frame
80
+
81
+ frame_gen = _frame_from_video(video)
82
+ i = 0
83
+ for frame in frame_gen:
84
+ i += 1
85
+ if i == total_frames:
86
+ return
87
+ yield process_predictions(frame, predictor(frame))
88
+
89
+
90
+ def inference(video):
91
+ video = cv2.VideoCapture(video)
92
+
93
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
94
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
95
+
96
+ frames_per_second = video.get(cv2.CAP_PROP_FPS)
97
+ num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
98
+ print(num_frames)
99
+
100
+ if num_frames>TOTAL_FRAMES:
101
+ num_frames=TOTAL_FRAMES
102
+
103
+ codec, file_ext = (
104
+ ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
105
+ )
106
+ print(codec, file_ext)
107
+ output_fname = "result.mp4"
108
+ output_file = cv2.VideoWriter(
109
+ filename=output_fname,
110
+ fourcc=cv2.VideoWriter_fourcc(*codec),
111
+ fps=float(frames_per_second),
112
+ frameSize=(width, height),
113
+ isColor=True,
114
+ )
115
+ for vis_frame in tqdm.tqdm(run_on_video(video, num_frames), total=num_frames):
116
+ output_file.write(vis_frame)
117
+ video.release()
118
+ output_file.release()
119
+
120
+ out_file = tempfile.NamedTemporaryFile(suffix="out.mp4", delete=False)
121
+ subprocess.run(f"ffmpeg -y -loglevel quiet -stats -i {output_fname} -c:v libx264 {out_file.name}".split())
122
+ return out_file.name
123
+
124
+ video_interface = gr.Interface(
125
+ fn=inference,
126
+ inputs=[
127
+ gr.Video(type="file"),
128
+ ],
129
+ outputs=gr.Video(type="file", format="mp4"),
130
+ examples=[
131
+ ["inps.mp4"],
132
+ ],
133
+ allow_flagging=False,
134
+ allow_screenshot=False,
135
+ ).launch(debug=True)
136
+
137
+
138
+