fffiloni commited on
Commit
38f6355
·
verified ·
1 Parent(s): 23de020

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ import os
4
+ import subprocess
5
+ import shutil
6
+ import tempfile
7
+ import uuid
8
+ import gradio as gr
9
+ from glob import glob
10
+ from huggingface_hub import snapshot_download
11
+
12
+ # Download models
13
+ os.makedirs("models", exist_ok=True)
14
+
15
+ snapshot_download(
16
+ repo_id = "fffiloni/SVFR",
17
+ local_dir = "./models"
18
+ )
19
+
20
+ # List of subdirectories to create inside "checkpoints"
21
+ subfolders = [
22
+ "stable-video-diffusion-img2vid-xt"
23
+ ]
24
+ # Create each subdirectory
25
+ for subfolder in subfolders:
26
+ os.makedirs(os.path.join("models", subfolder), exist_ok=True)
27
+
28
+ snapshot_download(
29
+ repo_id = "stabilityai/stable-video-diffusion-img2vid-xt",
30
+ local_dir = "./models/stable-video-diffusion-img2vid-xt"
31
+ )
32
+
33
+ def infer(lq_sequence, task_name):
34
+
35
+ unique_id = str(uuid.uuid4())
36
+ output_dir = f"results_{unique_id}"
37
+
38
+ if task_name == "BFR":
39
+ task_id = "0"
40
+ elif task_name == "colorization":
41
+ task_id = "1"
42
+ elif task_name == "BFR + colorization":
43
+ task_id = "0,1"
44
+
45
+ try:
46
+ # Run the inference command
47
+ subprocess.run(
48
+ [
49
+ "python", "infer.py",
50
+ "--config", "config/infer.yaml"
51
+ "--task_ids", f"{task_id}"
52
+ "--input_path", f"{lq_sequence}"
53
+ "--output_dir", f"{output_dir}",
54
+ ],
55
+ check=True
56
+ )
57
+
58
+ # Search for the mp4 file in a subfolder of output_dir
59
+ output_video = glob(os.path.join(output_dir,"*.mp4"))
60
+ print(output_video)
61
+
62
+ if output_video:
63
+ output_video_path = output_video[0] # Get the first match
64
+ else:
65
+ output_video_path = None
66
+
67
+ print(output_video_path)
68
+ return output_video_path
69
+
70
+ except subprocess.CalledProcessError as e:
71
+ raise gr.Error(f"Error during inference: {str(e)}")
72
+
73
+ with gr.Blocks() as demo:
74
+ with gr.Column():
75
+ with gr.Row():
76
+ with gr.Column():
77
+ input_seq = gr.Video(label="Video LQ")
78
+ task_name = gr.Radio(
79
+ label="Task",
80
+ choices=["BFR", "colorization", "BFR + colorization"],
81
+ value="BFR"
82
+ )
83
+ submit_btn = gr.Button("Submit")
84
+ with gr.Column():
85
+ output_res = gr.Video(label="Restored")
86
+ gr.Examples(
87
+ examples = [
88
+ ["./assert/lq/lq1.mp4", "BFR"],
89
+ ["./assert/lq/lq2mp4", "BFR + colorization"],
90
+ ["./assert/lq/lq3.mp4", "colorization"]
91
+ ],
92
+ inputs = [input_seq, task_name]
93
+ )
94
+
95
+ submit_btn.click(
96
+ fn = infer,
97
+ inputs = [input_seq, task_name],
98
+ outputs = [output_res]
99
+ )
100
+
101
+ demo.queue().launch(show_api=False, show_error=True)