Spaces: vztu /
Runtime error

nanushio committed
Commit 6ab99a7 · 1 Parent(s): f318285

- [MINOR] [SCRIPT] [CREATE] 1. create app.py

Files changed (1):
1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
+ import gradio as gr
+
+ import torch
+
+ import argparse
+ import pickle as pkl
+
+ import decord
+ from decord import VideoReader
+ import numpy as np
+ import yaml
+
+ from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
+ from cover.models import COVER
+
+ # ImageNet mean/std (scaled to 0-255) for the technical and aesthetic branches
+ mean, std = (
+     torch.FloatTensor([123.675, 116.28, 103.53]),
+     torch.FloatTensor([58.395, 57.12, 57.375]),
+ )
+
+ # CLIP mean/std (scaled to 0-255) for the semantic branch
+ mean_clip, std_clip = (
+     torch.FloatTensor([122.77, 116.75, 104.09]),
+     torch.FloatTensor([68.50, 66.63, 70.32]),
+ )
+
+ def fuse_results(results: list):
+     # the overall score is the sum of the semantic, technical and aesthetic scores
+     x = results[0] + results[1] + results[2]
+     return {
+         "semantic":  results[0],
+         "technical": results[1],
+         "aesthetic": results[2],
+         "overall":   x,
+     }
+
+ def inference_one_video(input_video):
+     """
+     BASIC SETTINGS
+     """
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     if torch.cuda.is_available():
+         # only touch the CUDA context when a GPU is actually present,
+         # so the Space also runs on CPU-only hardware
+         torch.cuda.empty_cache()
+         torch.backends.cudnn.benchmark = True
+     with open("./cover.yml", "r") as f:
+         opt = yaml.safe_load(f)
+
+     dopt = opt["data"]["val-ytugc"]["args"]
+     temporal_samplers = {}
+     for stype, sopt in dopt["sample_types"].items():
+         temporal_samplers[stype] = UnifiedFrameSampler(
+             sopt["clip_len"] // sopt["t_frag"],
+             sopt["t_frag"],
+             sopt["frame_interval"],
+             sopt["num_clips"],
+         )
+
+     """
+     LOAD MODEL
+     """
+     evaluator = COVER(**opt["model"]["args"]).to(device)
+     state_dict = torch.load(opt["test_load_path"], map_location=device)
+
+     # strict=False avoids errors about the missing prompt_learner
+     # weights of clip-iqa+ and the cross-gate modules
+     evaluator.load_state_dict(state_dict["state_dict"], strict=False)
+
+     """
+     TESTING
+     """
+     views, _ = spatial_temporal_view_decomposition(
+         input_video, dopt["sample_types"], temporal_samplers
+     )
+
+     for k, v in views.items():
+         num_clips = dopt["sample_types"][k].get("num_clips", 1)
+         if k == 'technical' or k == 'aesthetic':
+             views[k] = (
+                 ((v.permute(1, 2, 3, 0) - mean) / std)
+                 .permute(3, 0, 1, 2)
+                 .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
+                 .transpose(0, 1)
+                 .to(device)
+             )
+         elif k == 'semantic':
+             views[k] = (
+                 ((v.permute(1, 2, 3, 0) - mean_clip) / std_clip)
+                 .permute(3, 0, 1, 2)
+                 .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
+                 .transpose(0, 1)
+                 .to(device)
+             )
+
+     results = [r.mean().item() for r in evaluator(views)]
+     pred_score = fuse_results(results)
+     return pred_score
+
+ # Define the input and output components for Gradio
+ # (gr.Video / gr.JSON replace the deprecated gr.inputs / gr.outputs namespaces
+ #  removed in Gradio 4; gr.Video passes the uploaded file's path, which the
+ #  COVER pipeline reads with decord)
+ video_input = gr.Video(label="Input Video")
+ output_label = gr.JSON(label="Scores")
+
+ # Create the Gradio interface
+ gradio_app = gr.Interface(fn=inference_one_video, inputs=video_input, outputs=output_label)
+
+ if __name__ == "__main__":
+     gradio_app.launch()
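
For local testing outside the Space UI, the scoring function can be driven directly. A minimal sketch, assuming a checkout of the COVER repository next to app.py, its cover.yml config, the pretrained weights referenced by test_load_path, and a hypothetical examples/sample.mp4 clip:

# run_local.py (hypothetical helper, kept next to app.py)
from app import inference_one_video

# cover.yml must provide the keys app.py reads:
#   data.val-ytugc.args.sample_types.<branch>.{clip_len, t_frag, frame_interval, num_clips}
#   model.args, and test_load_path (a checkpoint containing a 'state_dict' entry)
scores = inference_one_video("examples/sample.mp4")
print(scores)  # {'semantic': ..., 'technical': ..., 'aesthetic': ..., 'overall': ...}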