"""Gradio demo application for SVFR video face restoration.

This script is based on the original project by https://huggingface.co/fffiloni.
URL: https://huggingface.co/spaces/fffiloni/SVFR-demo/blob/main/app.py

Modifications made:
- Synced the infer code updates from GitHub repo.
- Added an inpainting option to enhance functionality.

Author of modifications: https://github.com/wangzhiyaoo
Date: 2025/01/15
"""

import argparse
import os
import random
import shutil
import subprocess
import sys
import tempfile
import uuid
import warnings
from glob import glob

# NOTE: torch is imported before decord, as in the original import order.
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import gradio as gr
import numpy as np
import torchvision.transforms as transforms
from diffusers import AutoencoderKLTemporalDecoder
from diffusers.schedulers import EulerDiscreteScheduler
from huggingface_hub import snapshot_download
from moviepy.editor import VideoFileClip
from omegaconf import OmegaConf
from PIL import Image
from torchvision.utils import save_image
from transformers import CLIPVisionModelWithProjection

warnings.filterwarnings("ignore")

import cv2
import decord

from src.dataset.dataset import (
    crop_resize_img,
    get_affine_transform,
    get_union_bbox,
    mean_face_lm5p_256,
    process_bbox,
)
from src.dataset.face_align.align import AlignImage
from src.models import model_insightface_360k
from src.models.id_proj import IDProjConvModel
from src.models.svfr_adapter.unet_3d_svd_condition_ip import UNet3DConditionSVDModel
# pipeline
from src.pipelines.pipeline import LQ2VideoLongSVDPipeline
from src.utils.util import (
    save_videos_grid,
    seed_everything,
)

# Download models
os.makedirs("models", exist_ok=True)

snapshot_download(
    repo_id="fffiloni/SVFR",
    local_dir="./models",
)

# List of subdirectories to create inside "models"
subfolders = [
    "stable-video-diffusion-img2vid-xt",
]

# Create each subdirectory
for subfolder in subfolders:
    os.makedirs(os.path.join("models", subfolder), exist_ok=True)

snapshot_download(
    repo_id="stabilityai/stable-video-diffusion-img2vid-xt",
    local_dir="./models/stable-video-diffusion-img2vid-xt",
)

# SPACE_ID may be absent (e.g. when running locally); treat that as "not the
# shared UI" instead of crashing with a KeyError at import time.
is_shared_ui = "fffiloni/SVFR-demo" in os.environ.get("SPACE_ID", "")


def process_video(input_path, max_duration=4):
    """Convert a video to MP4 and trim it to at most ``max_duration`` seconds.

    A fresh temporary directory is created for every call so that concurrent
    requests cannot overwrite or delete each other's files.

    Args:
        input_path: Path of the uploaded video file.
        max_duration: Maximum clip length in seconds; longer videos are cut
            to their first ``max_duration`` seconds.

    Returns:
        A ``(output_path, temp_folder)`` tuple. ``output_path`` is the
        processed ``trimmed.mp4`` inside ``temp_folder``; the caller is
        responsible for removing ``temp_folder`` when done.

    Raises:
        Exception: Any error raised while reading or writing the video is
            re-raised after the temporary directory has been removed.
    """
    temp_folder = tempfile.mkdtemp(prefix="temp_video_output_")
    try:
        file_extension = os.path.splitext(input_path)[1].lower()

        # Convert to MP4 if not already MP4
        if file_extension != ".mp4":
            mp4_path = os.path.join(temp_folder, "converted.mp4")
            with VideoFileClip(input_path) as clip:
                clip.write_videofile(mp4_path, codec="libx264")
        else:
            mp4_path = input_path  # Already MP4

        # Load video and trim if necessary
        output_path = os.path.join(temp_folder, "trimmed.mp4")
        with VideoFileClip(mp4_path) as clip:
            if clip.duration > max_duration:
                trimmed_clip = clip.subclip(0, max_duration)
                try:
                    trimmed_clip.write_videofile(output_path, codec="libx264")
                finally:
                    trimmed_clip.close()
            elif mp4_path != output_path:
                # No trimming needed: just place a copy at output_path
                shutil.copy(mp4_path, output_path)
    except Exception:
        # Do not leak the temporary directory when processing fails.
        shutil.rmtree(temp_folder, ignore_errors=True)
        raise

    return output_path, temp_folder


BASE_DIR = '.'
config = OmegaConf.load("./config/infer.yaml")

# --- Load the pretrained Stable Video Diffusion components -----------------
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    f"{BASE_DIR}/{config.pretrained_model_name_or_path}",
    subfolder="vae",
    variant="fp16")

val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
    f"{BASE_DIR}/{config.pretrained_model_name_or_path}",
    subfolder="scheduler")

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    f"{BASE_DIR}/{config.pretrained_model_name_or_path}",
    subfolder="image_encoder",
    variant="fp16")

unet = UNet3DConditionSVDModel.from_pretrained(
    f"{BASE_DIR}/{config.pretrained_model_name_or_path}",
    subfolder="unet",
    variant="fp16")

# Face detector / aligner used to locate faces in the input frames.
weight_dir = 'models/face_align'
det_path = os.path.join(BASE_DIR, weight_dir, 'yoloface_v5m.pt')
align_instance = AlignImage("cuda", det_path=det_path)

# PIL/ndarray image -> tensor normalised to [-1, 1].
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

import torch.nn as nn


class InflatedConv3d(nn.Conv2d):
    """A ``Conv2d`` whose forward pass simply delegates to the parent class."""

    def forward(self, x):
        x = super().forward(x)
        return x


# Add ref channel: widen the UNet input convolution to 12 input channels by
# appending 4 zero-initialised channels to the existing weights. The fine-tuned
# checkpoint loaded below (strict=True) then overwrites these weights.
old_weights = unet.conv_in.weight
old_bias = unet.conv_in.bias
new_conv1 = InflatedConv3d(
    12,
    old_weights.shape[0],
    kernel_size=unet.conv_in.kernel_size,
    stride=unet.conv_in.stride,
    padding=unet.conv_in.padding,
    bias=old_bias is not None,
)
param = torch.zeros((320, 4, 3, 3), requires_grad=True)
new_conv1.weight = torch.nn.Parameter(torch.cat((old_weights, param), dim=1))
if old_bias is not None:
    new_conv1.bias = old_bias
unet.conv_in = new_conv1
unet.config["in_channels"] = 12
unet.config.in_channels = 12

id_linear = IDProjConvModel(in_channels=512, out_channels=1024).to(device='cuda')

# load pretrained weights
unet_checkpoint_path = os.path.join(BASE_DIR, config.unet_checkpoint_path)
unet.load_state_dict(
    torch.load(unet_checkpoint_path, map_location="cpu"),
    strict=True,
)

id_linear_checkpoint_path = os.path.join(BASE_DIR, config.id_linear_checkpoint_path)
id_linear.load_state_dict(
    torch.load(id_linear_checkpoint_path, map_location="cpu"),
    strict=True,
)

net_arcface = model_insightface_360k.getarcface(f'{BASE_DIR}/{config.net_arcface_checkpoint_path}').eval().to(device="cuda")

# Map the configured dtype name onto the torch dtype used for inference.
_WEIGHT_DTYPES = {
    "fp16": torch.float16,
    "fp32": torch.float32,
    "bf16": torch.bfloat16,
}
if config.weight_dtype not in _WEIGHT_DTYPES:
    raise ValueError(
        f"Do not support weight dtype: {config.weight_dtype} during training"
    )
weight_dtype = _WEIGHT_DTYPES[config.weight_dtype]

image_encoder.to(weight_dtype)
vae.to(weight_dtype)
unet.to(weight_dtype)
id_linear.to(weight_dtype)
net_arcface.requires_grad_(False).to(weight_dtype)

pipe = LQ2VideoLongSVDPipeline(
    unet=unet,
    image_encoder=image_encoder,
    vae=vae,
    scheduler=val_noise_scheduler,
    feature_extractor=None
)
pipe = pipe.to("cuda", dtype=unet.dtype)


def get_overlap_slide_window_indices(video_length, window_size, window_overlap):
    """Split ``range(video_length)`` into consecutive sliding windows.

    Args:
        video_length: Total number of frames.
        window_size: Maximum number of frames per window.
        window_overlap: Number of frames shared by neighbouring windows.

    Returns:
        A list of index lists. Each window starts ``window_size -
        window_overlap`` frames after the previous one and is truncated at
        ``video_length``.

    Raises:
        ValueError: If ``window_overlap`` is not smaller than ``window_size``
            (the window would never advance).
    """
    stride = window_size - window_overlap
    if stride <= 0:
        raise ValueError(
            "window_overlap must be smaller than window_size "
            f"(got window_size={window_size}, window_overlap={window_overlap})"
        )
    return [
        list(range(start, min(start + window_size, video_length)))
        for start in range(0, video_length, stride)
    ]


def gen(args, pipe):
    """Run SVFR restoration on one video and write the results to disk.

    The face region is located on every 5th frame, the union box is cropped
    from all frames and resized to 512x512, and the pipeline is run twice:
    once on the first window to obtain a reference frame (from which an
    identity embedding is computed), then on the full clip.

    Args:
        args: Namespace-like object with the attributes ``output_dir``,
            ``seed``, ``input_path``, ``task_ids`` (list containing any of
            0, 1, 2), ``mask_path`` (path or ``None``) and
            ``restore_frames`` (bool).
        pipe: The restoration pipeline to call.

    Side effects:
        Writes ``<name>_<seed>_gen.mp4``, ``<name>_<seed>_ori.mp4`` and
        ``ref_img_align.png`` into ``BASE_DIR/args.output_dir``; when
        ``args.restore_frames`` is true, also writes one PNG per frame
        under ``result_frames/<name>_<seed>/``.

    Raises:
        ValueError: If the video has no frames or no face is detected in any
            sampled frame.
    """
    face_size = 512  # side length every cropped face frame is resized to

    save_dir = f"{BASE_DIR}/{args.output_dir}"
    os.makedirs(save_dir, exist_ok=True)

    seed_input = args.seed
    seed_everything(seed_input)

    video_path = args.input_path
    task_ids = args.task_ids

    # Boolean map of mask pixels to paint white (inpainting task only).
    white_positions = None
    if 2 in task_ids and args.mask_path is not None:
        with Image.open(args.mask_path) as mask_file:
            mask = mask_file.convert("L")
        if mask.size != (face_size, face_size):
            # The mask is applied to the resized face crop, so it has to match
            # its resolution; nearest-neighbour keeps the mask binary.
            mask = mask.resize((face_size, face_size), Image.NEAREST)
        white_positions = np.array(mask) == 255

    print('task_ids:', task_ids)
    task_prompt = [1 if task in task_ids else 0 for task in range(3)]
    print("task_prompt:", task_prompt)

    video_stem = os.path.splitext(os.path.basename(video_path))[0]

    cap = decord.VideoReader(video_path, fault_tol=1)
    total_frames = len(cap)
    if total_frames == 0:
        raise ValueError("Input video contains no frames.")
    drive_idx_list = list(range(total_frames))

    # Crop faces from the video for further processing
    bbox_list = []
    frame_interval = 5
    for drive_idx in drive_idx_list[::frame_interval]:
        frame = cap[drive_idx].asnumpy()
        # Channel order is reversed before detection (RGB <-> BGR swap).
        _, _, bboxes_list = align_instance(frame[:, :, [2, 1, 0]], maxface=True)
        if len(bboxes_list) == 0:
            continue
        x1, y1, ww, hh = bboxes_list[0]
        bbox_list.append([x1, y1, x1 + ww, y1 + hh])
    if not bbox_list:
        raise ValueError("No face was detected in the input video.")
    bbox = get_union_bbox(bbox_list)
    bbox_s = process_bbox(bbox, expand_radio=0.4, height=frame.shape[0], width=frame.shape[1])

    vid_gt = []
    for drive_idx in drive_idx_list:
        frame = cap[drive_idx].asnumpy()
        face_img = Image.fromarray(frame)
        face_img = crop_resize_img(face_img, bbox_s)
        face_img = face_img.resize((face_size, face_size))
        if 1 in task_ids:
            # Colorization task: feed a grayscale version (kept as 3 channels)
            face_img = face_img.convert("L").convert("RGB")
        image_array = np.array(face_img)
        if white_positions is not None:
            image_array[white_positions] = [255, 255, 255]  # mask for inpainting task
        vid_gt.append(np.float32(image_array / 255.))

    # HWC in [0, 1] -> CHW in [-1, 1], stacked along a leading frame axis.
    vid_lq = [(torch.from_numpy(frame).permute(2, 0, 1) - 0.5) / 0.5 for frame in vid_gt]
    lq_frames = torch.stack(vid_lq, dim=0)
    task_id_input = torch.tensor(task_prompt)

    window_overlap = 0
    inter_frame_list = get_overlap_slide_window_indices(total_frames, config.data.n_sample_frames, window_overlap)
    first_window = inter_frame_list[0]

    height, width = lq_frames.shape[-2:]

    print("Generating the first clip...")
    output = pipe(
        lq_frames[first_window].to("cuda").to(weight_dtype),  # lq
        None,  # ref concat
        torch.zeros((1, len(first_window), 49, 1024)).to("cuda").to(weight_dtype),  # encoder_hidden_states
        task_id_input.to("cuda").to(weight_dtype),
        height=height,
        width=width,
        num_frames=len(first_window),
        decode_chunk_size=config.decode_chunk_size,
        noise_aug_strength=config.noise_aug_strength,
        min_guidance_scale=config.min_appearance_guidance_scale,
        max_guidance_scale=config.max_appearance_guidance_scale,
        overlap=config.overlap,
        frames_per_batch=len(first_window),
        num_inference_steps=50,
        i2i_noise_strength=config.i2i_noise_strength,
    )
    video = output.frames

    # Last frame of the first generated clip serves as the reference image.
    # NOTE(review): indexing implies video[0] is (C, T, H, W) — confirm against the pipeline.
    ref_img_tensor = video[0][:, -1]
    ref_img = (ref_img_tensor * 0.5 + 0.5).clamp(0, 1) * 255.
    # .float() first: NumPy cannot represent bfloat16 tensors.
    ref_img = ref_img.permute(1, 2, 0).float().cpu().numpy().astype(np.uint8)

    # NOTE(review): raises IndexError if no face is found in the generated reference frame.
    pts5 = align_instance(ref_img[:, :, [2, 1, 0]], maxface=True)[0][0]

    warp_mat = get_affine_transform(pts5, mean_face_lm5p_256 * height / 256)
    # cv2.warpAffine expects dsize as (width, height).
    ref_img = cv2.warpAffine(np.array(Image.fromarray(ref_img)), warp_mat, (width, height), flags=cv2.INTER_CUBIC)
    ref_img = to_tensor(ref_img).to("cuda").to(weight_dtype)

    save_image(ref_img * 0.5 + 0.5, f"{save_dir}/ref_img_align.png")

    ref_img = F.interpolate(ref_img.unsqueeze(0)[:, :, 0:224, 16:240], size=[112, 112], mode='bilinear')
    # Inference only: avoid building an autograd graph for the ID projection.
    with torch.no_grad():
        _, id_feature_conv = net_arcface(ref_img)
        id_embedding = id_linear(id_feature_conv)

    print('Generating all video clips...')
    video = pipe(
        lq_frames.to("cuda").to(weight_dtype),  # lq
        ref_img_tensor.to("cuda").to(weight_dtype),
        id_embedding.unsqueeze(1).repeat(1, len(lq_frames), 1, 1).to("cuda").to(weight_dtype),  # encoder_hidden_states
        task_id_input.to("cuda").to(weight_dtype),
        height=height,
        width=width,
        num_frames=total_frames,
        decode_chunk_size=config.decode_chunk_size,
        noise_aug_strength=config.noise_aug_strength,
        min_guidance_scale=config.min_appearance_guidance_scale,
        max_guidance_scale=config.max_appearance_guidance_scale,
        overlap=config.overlap,
        frames_per_batch=config.data.n_sample_frames,
        num_inference_steps=config.num_inference_steps,
        i2i_noise_strength=config.i2i_noise_strength,
    ).frames

    # Back from [-1, 1] to [0, 1] before writing the videos.
    video = (video * 0.5 + 0.5).clamp(0, 1).cpu()
    save_videos_grid(video, f"{save_dir}/{video_stem}_{seed_input}_gen.mp4", n_rows=1, fps=25)

    # (T, C, H, W) -> (1, C, T, H, W), the same layout the restored video is saved in.
    lq_frames = lq_frames.permute(1, 0, 2, 3).unsqueeze(0)
    lq_frames = (lq_frames * 0.5 + 0.5).clamp(0, 1).cpu()
    save_videos_grid(lq_frames, f"{save_dir}/{video_stem}_{seed_input}_ori.mp4", n_rows=1, fps=25)

    if args.restore_frames:
        video = video.squeeze(0)
        frames_dir = os.path.join(save_dir, "result_frames", f"{video_stem}_{seed_input}")
        os.makedirs(frames_dir, exist_ok=True)
        print(frames_dir)
        for i in range(video.shape[1]):
            save_image(video[:, i], os.path.join(frames_dir, f'{i:08d}.png'))


def random_seed():
    """Return a random integer seed in the inclusive range [0, 10000]."""
    return random.randint(0, 10000)


def infer(lq_sequence, task_name, mask, seed, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: restore the uploaded video and return the result paths.

    Args:
        lq_sequence: Path of the uploaded low-quality video.
        task_name: Selected task names ("BFR", "Colorization", "Inpainting").
        mask: Path of the inpainting mask image, or ``None``.
        seed: Random seed (converted with ``int``).
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``(face_region_video_path, output_video_path)``; either is ``None``
        when the corresponding file was not produced.

    Raises:
        gr.Error: If no video was provided or anything fails during inference.
    """
    if lq_sequence is None:
        raise gr.Error("Please provide an input video.")

    output_dir = f"results_{uuid.uuid4()}"

    task_mapping = {
        "BFR": 0,
        "Colorization": 1,
        "Inpainting": 2
    }
    task_ids = [task_mapping[task] for task in task_name if task in task_mapping]

    temp_input_folder = None
    try:
        if is_shared_ui:
            lq_sequence, temp_input_folder = process_video(lq_sequence)

        # Build the namespace directly: parsing sys.argv here would pick up
        # (and choke on) the command line used to launch the server.
        args = argparse.Namespace(
            task_ids=task_ids,
            input_path=f"{lq_sequence}",
            output_dir=f"{output_dir}",
            # Keep None as None so gen() can tell that no mask was supplied.
            mask_path=mask if mask else None,
            seed=int(seed),
            restore_frames=False,
        )

        gen(args, pipe)

        # Search for the mp4 files written into output_dir
        output_video = glob(os.path.join(output_dir, "*gen.mp4"))
        face_region_video = glob(os.path.join(output_dir, "*ori.mp4"))

        output_video_path = output_video[0] if output_video else None
        face_region_video_path = face_region_video[0] if face_region_video else None

        print(output_video_path, face_region_video_path)
        return face_region_video_path, output_video_path
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error during inference: {str(e)}") from e
    finally:
        if temp_input_folder is not None:
            # Clean up temporary input folder
            shutil.rmtree(temp_input_folder, ignore_errors=True)
        torch.cuda.empty_cache()


css="""
div#col-container{
    margin: 0 auto;
    max-width: 982px;
}
"""

# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation was lost — verify the row/column grouping against the live UI.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# SVFR: A Unified Framework for Generalized Video Face Restoration")
        gr.Markdown("SVFR is a unified framework for face video restoration that supports tasks such as BFR, Colorization, Inpainting, and their combinations within one cohesive system.")
        gr.HTML("""
Duplicate this Space
""")
        with gr.Row():
            with gr.Column():
                input_seq = gr.Video(label="Video LQ")
                task_name = gr.CheckboxGroup(
                    label="Task",
                    choices=["BFR", "Colorization", "Inpainting"],
                    value=["BFR"]  # default
                )
                mask_input = gr.Image(type="filepath",label="Inpainting Mask")
                with gr.Row():
                    seed_input = gr.Number(label="Seed", value=77, precision=0)
                    random_seed_btn = gr.Button("🎲",scale=1,elem_id="dice-btn")
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear")
            with gr.Column():
                output_face = gr.Video(label="Face Region Input")
                output_res = gr.Video(label="Restored")
                gr.Examples(
                    examples = [
                        ["./assert/lq/lq1.mp4", ["BFR"],None],
                        ["./assert/lq/lq2.mp4", ["BFR", "Colorization"],None],
                        ["./assert/lq/lq3.mp4", ["BFR", "Colorization", "Inpainting"],"./assert/mask/lq3.png"]
                    ],
                    inputs = [input_seq, task_name, mask_input]
                )

    random_seed_btn.click(
        fn=random_seed,
        inputs=[],
        outputs=seed_input
    )
    submit_btn.click(
        fn = infer,
        inputs = [input_seq, task_name, mask_input,seed_input],
        outputs = [output_face,output_res]
    )
    # Reset every input and output component to its initial value.
    clear_btn.click(
        fn=lambda: [None,["BFR"],None,77,None,None],
        inputs=None,
        outputs=[input_seq, task_name, mask_input, seed_input, output_face, output_res]
    )

demo.queue().launch(show_api=False, show_error=True)