import torch
import sys
import os
import subprocess
import shutil
import tempfile
import uuid
import gradio as gr
from glob import glob
from huggingface_hub import snapshot_download

# Download the SVFR model weights from the Hugging Face Hub
os.makedirs("models", exist_ok=True)

snapshot_download(
    repo_id = "fffiloni/SVFR",
    local_dir = "./models"  
)

# List of subdirectories to create inside "models"
subfolders = [
    "stable-video-diffusion-img2vid-xt"
]
# Create each subdirectory
for subfolder in subfolders:
    os.makedirs(os.path.join("models", subfolder), exist_ok=True)

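# Download the Stable Video Diffusion img2vid-xt base weights into the matching subfolder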
snapshot_download(
    repo_id = "stabilityai/stable-video-diffusion-img2vid-xt",
    local_dir = "./models/stable-video-diffusion-img2vid-xt"  
)

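# Run the SVFR inference script on the uploaded low-quality video for the selected
# task(s) and return the path of the restored mp4, or None if no output is found.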
def infer(lq_sequence, task_name):
    
    unique_id = str(uuid.uuid4())
    output_dir = f"results_{unique_id}"

    if task_name == "BFR":
        task_id = "0"
    elif task_name == "colorization":
        task_id = "1"
    elif task_name == "BFR + colorization":
        task_id = "0,1"
    else:
        raise gr.Error(f"Unknown task: {task_name}")
    
    try:
        # Run the inference command
        subprocess.run(
            [
                "python", "infer.py",
                "--config", "config/infer.yaml"
                "--task_ids", f"{task_id}"
                "--input_path", f"{lq_sequence}"
                "--output_dir", f"{output_dir}",
            ],
            check=True
        )

        # Search output_dir (and any subfolders) for the generated mp4 file
        output_video = glob(os.path.join(output_dir, "**", "*.mp4"), recursive=True)
        print(output_video)
        
        if output_video:
            output_video_path = output_video[0]  # Get the first match
        else:
            output_video_path = None
        
        print(output_video_path)
        return output_video_path
    
    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error during inference: {str(e)}")

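# Gradio UI: low-quality video input and task selector on the left, restored video
# output and ready-made examples on the right.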
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_seq = gr.Video(label="Video LQ")
                task_name = gr.Radio(
                    label="Task", 
                    choices=["BFR", "colorization", "BFR + colorization"], 
                    value="BFR"
                )
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output_res = gr.Video(label="Restored")
                gr.Examples(
                    examples = [
                        ["./assert/lq/lq1.mp4", "BFR"],
                        ["./assert/lq/lq2mp4", "BFR + colorization"],
                        ["./assert/lq/lq3.mp4", "colorization"]
                    ],
                    inputs = [input_seq, task_name]
                )
    
    submit_btn.click(
        fn = infer,
        inputs = [input_seq, task_name],
        outputs = [output_res]
    )

demo.queue().launch(show_api=False, show_error=True)