import torch import sys import os import subprocess import shutil import tempfile import uuid import gradio as gr from glob import glob from huggingface_hub import snapshot_download # Download models os.makedirs("models", exist_ok=True) snapshot_download( repo_id = "fffiloni/SVFR", local_dir = "./models" ) # List of subdirectories to create inside "checkpoints" subfolders = [ "stable-video-diffusion-img2vid-xt" ] # Create each subdirectory for subfolder in subfolders: os.makedirs(os.path.join("models", subfolder), exist_ok=True) snapshot_download( repo_id = "stabilityai/stable-video-diffusion-img2vid-xt", local_dir = "./models/stable-video-diffusion-img2vid-xt" ) def infer(lq_sequence, task_name): unique_id = str(uuid.uuid4()) output_dir = f"results_{unique_id}" if task_name == "BFR": task_id = "0" elif task_name == "colorization": task_id = "1" elif task_name == "BFR + colorization": task_id = "0,1" try: # Run the inference command subprocess.run( [ "python", "infer.py", "--config", "config/infer.yaml" "--task_ids", f"{task_id}" "--input_path", f"{lq_sequence}" "--output_dir", f"{output_dir}", ], check=True ) # Search for the mp4 file in a subfolder of output_dir output_video = glob(os.path.join(output_dir,"*.mp4")) print(output_video) if output_video: output_video_path = output_video[0] # Get the first match else: output_video_path = None print(output_video_path) return output_video_path except subprocess.CalledProcessError as e: raise gr.Error(f"Error during inference: {str(e)}") with gr.Blocks() as demo: with gr.Column(): with gr.Row(): with gr.Column(): input_seq = gr.Video(label="Video LQ") task_name = gr.Radio( label="Task", choices=["BFR", "colorization", "BFR + colorization"], value="BFR" ) submit_btn = gr.Button("Submit") with gr.Column(): output_res = gr.Video(label="Restored") gr.Examples( examples = [ ["./assert/lq/lq1.mp4", "BFR"], ["./assert/lq/lq2mp4", "BFR + colorization"], ["./assert/lq/lq3.mp4", "colorization"] ], inputs = [input_seq, task_name] ) submit_btn.click( fn = infer, inputs = [input_seq, task_name], outputs = [output_res] ) demo.queue().launch(show_api=False, show_error=True)