fffiloni AJohn123 committed on
Commit 25d1b89 · verified · 1 Parent(s): bc166ca

Files changed (5):
  1. ORIGINAL_README.md +6 -0
  2. README.md +3 -1
  3. app.py +385 -40
  4. infer.py +29 -5
  5. src/dataset/dataset.py +68 -0
ORIGINAL_README.md CHANGED
@@ -108,6 +108,8 @@ https://github.com/user-attachments/assets/efdac23c-0ba5-4dad-ab8c-48904af5dd89
 
 ## 🚀 Getting Started
 
+> Note: It is recommended to use a GPU with 16GB or more VRAM.
+
 ## Setup
 
 Use the following command to install a conda environment for SVFR from scratch:
@@ -200,6 +202,10 @@ The code of SVFR is released under the MIT License. There is no limitation for b
 
 **The pretrained models we provided with this library are available for non-commercial research purposes only, including both auto-downloading models and manual-downloading models.**
 
+## Acknowledgments
+
+This work is built on the architecture of [Sonic](https://github.com/jixiaozhong/Sonic).
+
 
 ## BibTex
 ```
README.md CHANGED
@@ -10,4 +10,6 @@ pinned: false
 short_description: Unified Framework for Generalized Video Face Restoration
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
app.py CHANGED
@@ -1,3 +1,15 @@
+"""
+This script is based on the original project by https://huggingface.co/fffiloni.
+URL: https://huggingface.co/spaces/fffiloni/SVFR-demo/blob/main/app.py
+
+Modifications made:
+- Synced the infer code updates from GitHub repo.
+- Added an inpainting option to enhance functionality.
+
+Author of modifications: https://github.com/wangzhiyaoo
+Date: 2025/01/15
+"""
+
 import torch
 import sys
 import os
@@ -8,6 +20,45 @@ import uuid
 import gradio as gr
 from glob import glob
 from huggingface_hub import snapshot_download
+import random
+
+import argparse
+import warnings
+import os
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from PIL import Image
+import random
+
+from omegaconf import OmegaConf
+from diffusers import AutoencoderKLTemporalDecoder
+from diffusers.schedulers import EulerDiscreteScheduler
+from transformers import CLIPVisionModelWithProjection
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from src.models.svfr_adapter.unet_3d_svd_condition_ip import UNet3DConditionSVDModel
+
+# pipeline
+from src.pipelines.pipeline import LQ2VideoLongSVDPipeline
+
+from src.utils.util import (
+    save_videos_grid,
+    seed_everything,
+)
+from torchvision.utils import save_image
+
+from src.models.id_proj import IDProjConvModel
+from src.models import model_insightface_360k
+
+from src.dataset.face_align.align import AlignImage
+
+warnings.filterwarnings("ignore")
+
+import decord
+import cv2
+from src.dataset.dataset import get_affine_transform, mean_face_lm5p_256, get_union_bbox, process_bbox, crop_resize_img
+
 
 # Download models
 os.makedirs("models", exist_ok=True)
@@ -30,44 +81,326 @@ snapshot_download(
     local_dir = "./models/stable-video-diffusion-img2vid-xt"
 )
 
-def infer(lq_sequence, task_name):
+BASE_DIR = '.'
+
+config = OmegaConf.load("./config/infer.yaml")
+vae = AutoencoderKLTemporalDecoder.from_pretrained(
+    f"{BASE_DIR}/{config.pretrained_model_name_or_path}",
+    subfolder="vae",
+    variant="fp16")
+
+val_noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+    f"{BASE_DIR}/{config.pretrained_model_name_or_path}",
+    subfolder="scheduler")
+
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    f"{BASE_DIR}/{config.pretrained_model_name_or_path}",
+    subfolder="image_encoder",
+    variant="fp16")
+unet = UNet3DConditionSVDModel.from_pretrained(
+    f"{BASE_DIR}/{config.pretrained_model_name_or_path}",
+    subfolder="unet",
+    variant="fp16")
+
+weight_dir = 'models/face_align'
+det_path = os.path.join(BASE_DIR, weight_dir, 'yoloface_v5m.pt')
+align_instance = AlignImage("cuda", det_path=det_path)
+
+to_tensor = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+])
+
+import torch.nn as nn
+class InflatedConv3d(nn.Conv2d):
+    def forward(self, x):
+        x = super().forward(x)
+        return x
+# Add ref channel
+old_weights = unet.conv_in.weight
+old_bias = unet.conv_in.bias
+new_conv1 = InflatedConv3d(
+    12,
+    old_weights.shape[0],
+    kernel_size=unet.conv_in.kernel_size,
+    stride=unet.conv_in.stride,
+    padding=unet.conv_in.padding,
+    bias=True if old_bias is not None else False,
+)
+param = torch.zeros((320, 4, 3, 3), requires_grad=True)
+new_conv1.weight = torch.nn.Parameter(torch.cat((old_weights, param), dim=1))
+if old_bias is not None:
+    new_conv1.bias = old_bias
+unet.conv_in = new_conv1
+unet.config["in_channels"] = 12
+unet.config.in_channels = 12
+
+
+id_linear = IDProjConvModel(in_channels=512, out_channels=1024).to(device='cuda')
+
+# load pretrained weights
+unet_checkpoint_path = os.path.join(BASE_DIR, config.unet_checkpoint_path)
+unet.load_state_dict(
+    torch.load(unet_checkpoint_path, map_location="cpu"),
+    strict=True,
+)
+
+id_linear_checkpoint_path = os.path.join(BASE_DIR, config.id_linear_checkpoint_path)
+id_linear.load_state_dict(
+    torch.load(id_linear_checkpoint_path, map_location="cpu"),
+    strict=True,
+)
+
+net_arcface = model_insightface_360k.getarcface(f'{BASE_DIR}/{config.net_arcface_checkpoint_path}').eval().to(device="cuda")
+
+if config.weight_dtype == "fp16":
+    weight_dtype = torch.float16
+elif config.weight_dtype == "fp32":
+    weight_dtype = torch.float32
+elif config.weight_dtype == "bf16":
+    weight_dtype = torch.bfloat16
+else:
+    raise ValueError(
+        f"Do not support weight dtype: {config.weight_dtype} during training"
+    )
+
+image_encoder.to(weight_dtype)
+vae.to(weight_dtype)
+unet.to(weight_dtype)
+id_linear.to(weight_dtype)
+net_arcface.requires_grad_(False).to(weight_dtype)
+
+pipe = LQ2VideoLongSVDPipeline(
+    unet=unet,
+    image_encoder=image_encoder,
+    vae=vae,
+    scheduler=val_noise_scheduler,
+    feature_extractor=None
+
+)
+pipe = pipe.to("cuda", dtype=unet.dtype)
+
+def gen(args,pipe):
+    save_dir = f"{BASE_DIR}/{args.output_dir}"
+    os.makedirs(save_dir,exist_ok=True)
+
+    seed_input = args.seed
+    seed_everything(seed_input)
+
+    video_path = args.input_path
+    task_ids = args.task_ids
+
+    if 2 in task_ids and args.mask_path is not None:
+        mask_path = args.mask_path
+        mask = Image.open(mask_path).convert("L")
+        mask_array = np.array(mask)
+
+        white_positions = mask_array == 255
+
+    print('task_ids:',task_ids)
+    task_prompt = [0,0,0]
+    for i in range(3):
+        if i in task_ids:
+            task_prompt[i] = 1
+    print("task_prompt:",task_prompt)
+
+    video_name = video_path.split('/')[-1]
+    # print(video_name)
+
+    if os.path.exists(os.path.join(save_dir, "result_frames", video_name[:-4])):
+        print(os.path.join(save_dir, "result_frames", video_name[:-4]))
+        # continue
+
+    cap = decord.VideoReader(video_path, fault_tol=1)
+    total_frames = len(cap)
+    T = total_frames #
+    print("total_frames:",total_frames)
+    step=1
+    drive_idx_start = 0
+    drive_idx_list = list(range(drive_idx_start, drive_idx_start + T * step, step))
+    assert len(drive_idx_list) == T
+
+    # Crop faces from the video for further processing
+    bbox_list = []
+    frame_interval = 5
+    for frame_count, drive_idx in enumerate(drive_idx_list):
+        if frame_count % frame_interval != 0:
+            continue
+        frame = cap[drive_idx].asnumpy()
+        _, _, bboxes_list = align_instance(frame[:,:,[2,1,0]], maxface=True)
+        if bboxes_list==[]:
+            continue
+        x1, y1, ww, hh = bboxes_list[0]
+        x2, y2 = x1 + ww, y1 + hh
+        bbox = [x1, y1, x2, y2]
+        bbox_list.append(bbox)
+    bbox = get_union_bbox(bbox_list)
+    bbox_s = process_bbox(bbox, expand_radio=0.4, height=frame.shape[0], width=frame.shape[1])
+
+    imSameIDs = []
+    vid_gt = []
+    for i, drive_idx in enumerate(drive_idx_list):
+        frame = cap[drive_idx].asnumpy()
+        imSameID = Image.fromarray(frame)
+        imSameID = crop_resize_img(imSameID, bbox_s)
+        imSameID = imSameID.resize((512,512))
+        if 1 in task_ids:
+            imSameID = imSameID.convert("L") # Convert to grayscale
+            imSameID = imSameID.convert("RGB")
+        image_array = np.array(imSameID)
+        if 2 in task_ids and args.mask_path is not None:
+            image_array[white_positions] = [255, 255, 255] # mask for inpainting task
+        vid_gt.append(np.float32(image_array/255.))
+        imSameIDs.append(imSameID)
+
+    vid_lq = [(torch.from_numpy(frame).permute(2,0,1) - 0.5) / 0.5 for frame in vid_gt]
+
+    val_data = dict(
+        pixel_values_vid_lq = torch.stack(vid_lq,dim=0),
+        # pixel_values_ref_img=self.to_tensor(target_image),
+        # pixel_values_ref_concat_img=self.to_tensor(imSrc2),
+        task_ids=task_ids,
+        task_id_input=torch.tensor(task_prompt),
+        total_frames=total_frames,
+    )
+
+    window_overlap=0
+    inter_frame_list = get_overlap_slide_window_indices(val_data["total_frames"],config.data.n_sample_frames,window_overlap)
+
+    lq_frames = val_data["pixel_values_vid_lq"]
+    task_ids = val_data["task_ids"]
+    task_id_input = val_data["task_id_input"]
+    height, width = val_data["pixel_values_vid_lq"].shape[-2:]
+
+    print("Generating the first clip...")
+    output = pipe(
+        lq_frames[inter_frame_list[0]].to("cuda").to(weight_dtype), # lq
+        None, # ref concat
+        torch.zeros((1, len(inter_frame_list[0]), 49, 1024)).to("cuda").to(weight_dtype), # encoder_hidden_states
+        task_id_input.to("cuda").to(weight_dtype),
+        height=height,
+        width=width,
+        num_frames=len(inter_frame_list[0]),
+        decode_chunk_size=config.decode_chunk_size,
+        noise_aug_strength=config.noise_aug_strength,
+        min_guidance_scale=config.min_appearance_guidance_scale,
+        max_guidance_scale=config.max_appearance_guidance_scale,
+        overlap=config.overlap,
+        frames_per_batch=len(inter_frame_list[0]),
+        num_inference_steps=50,
+        i2i_noise_strength=config.i2i_noise_strength,
+    )
+    video = output.frames
+
+    ref_img_tensor = video[0][:,-1]
+    ref_img = (video[0][:,-1] *0.5+0.5).clamp(0,1) * 255.
+    ref_img = ref_img.permute(1,2,0).cpu().numpy().astype(np.uint8)
+
+    pts5 = align_instance(ref_img[:,:,[2,1,0]], maxface=True)[0][0]
+
+    warp_mat = get_affine_transform(pts5, mean_face_lm5p_256 * height/256)
+    ref_img = cv2.warpAffine(np.array(Image.fromarray(ref_img)), warp_mat, (height, width), flags=cv2.INTER_CUBIC)
+    ref_img = to_tensor(ref_img).to("cuda").to(weight_dtype)
+
+    save_image(ref_img*0.5 + 0.5,f"{save_dir}/ref_img_align.png")
+
+    ref_img = F.interpolate(ref_img.unsqueeze(0)[:, :, 0:224, 16:240], size=[112, 112], mode='bilinear')
+    _, id_feature_conv = net_arcface(ref_img)
+    id_embedding = id_linear(id_feature_conv)
+
+    print('Generating all video clips...')
+    video = pipe(
+        lq_frames.to("cuda").to(weight_dtype), # lq
+        ref_img_tensor.to("cuda").to(weight_dtype),
+        id_embedding.unsqueeze(1).repeat(1, len(lq_frames), 1, 1).to("cuda").to(weight_dtype), # encoder_hidden_states
+        task_id_input.to("cuda").to(weight_dtype),
+        height=height,
+        width=width,
+        num_frames=val_data["total_frames"], #frame_num,
+        decode_chunk_size=config.decode_chunk_size,
+        noise_aug_strength=config.noise_aug_strength,
+        min_guidance_scale=config.min_appearance_guidance_scale,
+        max_guidance_scale=config.max_appearance_guidance_scale,
+        overlap=config.overlap,
+        frames_per_batch=config.data.n_sample_frames,
+        num_inference_steps=config.num_inference_steps,
+        i2i_noise_strength=config.i2i_noise_strength,
+    ).frames
+
+
+    video = (video*0.5 + 0.5).clamp(0, 1)
+    video = torch.cat([video.to(device="cuda")], dim=0).cpu()
+    save_videos_grid(video, f"{save_dir}/{video_name[:-4]}_{seed_input}_gen.mp4", n_rows=1, fps=25)
+
+    lq_frames = lq_frames.permute(1,0,2,3).unsqueeze(0)
+    lq_frames = (lq_frames * 0.5 + 0.5).clamp(0, 1).to(device="cuda").cpu()
+    save_videos_grid(lq_frames, f"{save_dir}/{video_name[:-4]}_{seed_input}_ori.mp4", n_rows=1, fps=25)
+
+    if args.restore_frames:
+        video = video.squeeze(0)
+        os.makedirs(os.path.join(save_dir, "result_frames", f"{video_name[:-4]}_{seed_input}"), exist_ok=True)
+        print(os.path.join(save_dir, "result_frames", video_name[:-4]))
+        for i in range(video.shape[1]):
+            save_frames_path = os.path.join(f"{save_dir}/result_frames", f"{video_name[:-4]}_{seed_input}", f'{i:08d}.png')
+            save_image(video[:,i], save_frames_path)
+
+
+def get_overlap_slide_window_indices(video_length, window_size, window_overlap):
+    inter_frame_list = []
+    for j in range(0, video_length, window_size-window_overlap):
+        inter_frame_list.append([e % video_length for e in range(j, min(j + window_size, video_length))])
+
+    return inter_frame_list
+
+
+
+def random_seed():
+    return random.randint(0, 10000)
+
+def infer(lq_sequence, task_name, mask, seed):
 
     unique_id = str(uuid.uuid4())
     output_dir = f"results_{unique_id}"
 
-    if task_name == "BFR":
-        task_id = "0"
-    elif task_name == "colorization":
-        task_id = "1"
-    elif task_name == "BFR + colorization":
-        task_id = "0,1"
+    task_mapping = {
+        "BFR": 0,
+        "Colorization": 1,
+        "Inpainting": 2
+    }
+
+    task_ids = [task_mapping[task] for task in task_name if task in task_mapping]
+    # task_id = ",".join(task_ids)
 
     try:
-        # Run the inference command
-        subprocess.run(
-            [
-                "python", "infer.py",
-                "--config", "config/infer.yaml",
-                "--task_ids", f"{task_id}",
-                "--input_path", f"{lq_sequence}",
-                "--output_dir", f"{output_dir}",
-            ],
-            check=True
-        )
+        parser = argparse.ArgumentParser()
+        args = parser.parse_args()
+        args.task_ids = task_ids
+        args.input_path = f"{lq_sequence}"
+        args.output_dir = f"{output_dir}"
+        args.mask_path = f"{mask}"
+        args.seed = int(seed)
+        args.restore_frames = False
+
+        gen(args,pipe)
 
         # Search for the mp4 file in a subfolder of output_dir
-        output_video = glob(os.path.join(output_dir,"*.mp4"))
-        print(output_video)
+        output_video = glob(os.path.join(output_dir,"*gen.mp4"))
+        face_region_video = glob(os.path.join(output_dir,"*ori.mp4"))
+        # print(face_region_video,output_video)
 
         if output_video:
             output_video_path = output_video[0] # Get the first match
+            face_region_video_path = face_region_video[0] # Get the first match
         else:
             output_video_path = None
+            face_region_video = None
 
-        print(output_video_path)
-        return output_video_path
+        print(output_video_path,face_region_video_path)
+        torch.cuda.empty_cache()
+        return face_region_video_path,output_video_path
 
     except subprocess.CalledProcessError as e:
+        torch.cuda.empty_cache()
         raise gr.Error(f"Error during inference: {str(e)}")
 
 css="""
@@ -91,38 +424,50 @@ with gr.Blocks(css=css) as demo:
             <a href="https://arxiv.org/pdf/2501.01235">
                 <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
             </a>
-            <a href="https://huggingface.co/spaces/fffiloni/SVFR-demo?duplicate=true">
-                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
-            </a>
-            <a href="https://huggingface.co/fffiloni">
-                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
-            </a>
         </div>
     """)
     with gr.Row():
        with gr.Column():
            input_seq = gr.Video(label="Video LQ")
-            task_name = gr.Radio(
+            task_name = gr.CheckboxGroup(
                label="Task",
-                choices=["BFR", "colorization", "BFR + colorization"],
-                value="BFR"
+                choices=["BFR", "Colorization", "Inpainting"],
+                value=["BFR"] # default
            )
-            submit_btn = gr.Button("Submit")
+            mask_input = gr.Image(type="filepath",label="Inpainting Mask")
+            with gr.Row():
+                seed_input = gr.Number(label="Seed", value=77, precision=0)
+                random_seed_btn = gr.Button("🎲",scale=1,elem_id="dice-btn")
+            submit_btn = gr.Button("Submit", variant="primary")
+            clear_btn = gr.Button("Clear")
        with gr.Column():
+            output_face = gr.Video(label="Face Region Input")
            output_res = gr.Video(label="Restored")
    gr.Examples(
        examples = [
-            ["./assert/lq/lq1.mp4", "BFR"],
-            ["./assert/lq/lq2.mp4", "BFR + colorization"],
-            ["./assert/lq/lq3.mp4", "colorization"]
+            ["./assert/lq/lq1.mp4", ["BFR"],None],
+            ["./assert/lq/lq2.mp4", ["BFR", "Colorization"],None],
+            ["./assert/lq/lq3.mp4", ["BFR", "Colorization", "Inpainting"],"./assert/mask/lq3.png"]
        ],
-        inputs = [input_seq, task_name]
+        inputs = [input_seq, task_name, mask_input]
    )
-
+
+    random_seed_btn.click(
+        fn=random_seed,
+        inputs=[],
+        outputs=seed_input
+    )
+
+
    submit_btn.click(
        fn = infer,
-        inputs = [input_seq, task_name],
-        outputs = [output_res]
+        inputs = [input_seq, task_name, mask_input,seed_input],
+        outputs = [output_face,output_res]
+    )
+    clear_btn.click(
+        fn=lambda: [None,["BFR"],None,77,None,None],
+        inputs=None,
+        outputs=[input_seq, task_name, mask_input, seed_input, output_face, output_res]
    )
 
-demo.queue().launch(show_api=False, show_error=True)
+demo.queue().launch(show_api=False, show_error=True, server_port=1203)
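For reference, the batching behaviour of the new `get_overlap_slide_window_indices` helper (used by `gen()` to split a clip into frame windows, the first of which seeds the reference image) can be seen in a minimal sketch; the function body is copied from the diff above, and the `video_length`/`window_size`/`window_overlap` values are hypothetical:

```
# Sketch only: same logic as the helper added in app.py above; example values are hypothetical.
def get_overlap_slide_window_indices(video_length, window_size, window_overlap):
    inter_frame_list = []
    for j in range(0, video_length, window_size - window_overlap):
        inter_frame_list.append([e % video_length for e in range(j, min(j + window_size, video_length))])
    return inter_frame_list

print(get_overlap_slide_window_indices(10, 4, 0))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(get_overlap_slide_window_indices(10, 4, 1))
# [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9], [9]]
```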
infer.py CHANGED
@@ -33,10 +33,11 @@ warnings.filterwarnings("ignore")
 
 import decord
 import cv2
-from src.dataset.dataset import get_affine_transform, mean_face_lm5p_256
+from src.dataset.dataset import get_affine_transform, mean_face_lm5p_256, get_union_bbox, process_bbox, crop_resize_img
 
 BASE_DIR = '.'
 
+
 def main(config,args):
     if 'CUDA_VISIBLE_DEVICES' in os.environ:
         cuda_visible_devices = os.environ['CUDA_VISIBLE_DEVICES']
@@ -179,13 +180,33 @@ def main(config,args):
     drive_idx_list = list(range(drive_idx_start, drive_idx_start + T * step, step))
     assert len(drive_idx_list) == T
 
+    # Crop faces from the video for further processing
+    bbox_list = []
+    frame_interval = 5
+    for frame_count, drive_idx in enumerate(drive_idx_list):
+        if frame_count % frame_interval != 0:
+            continue
+        frame = cap[drive_idx].asnumpy()
+        _, _, bboxes_list = align_instance(frame[:,:,[2,1,0]], maxface=True)
+        if bboxes_list==[]:
+            continue
+        x1, y1, ww, hh = bboxes_list[0]
+        x2, y2 = x1 + ww, y1 + hh
+        bbox = [x1, y1, x2, y2]
+        bbox_list.append(bbox)
+    bbox = get_union_bbox(bbox_list)
+    bbox_s = process_bbox(bbox, expand_radio=0.4, height=frame.shape[0], width=frame.shape[1])
+
     imSameIDs = []
     vid_gt = []
     for i, drive_idx in enumerate(drive_idx_list):
         frame = cap[drive_idx].asnumpy()
         imSameID = Image.fromarray(frame)
-
+        imSameID = crop_resize_img(imSameID, bbox_s)
         imSameID = imSameID.resize((512,512))
+        if 1 in task_ids:
+            imSameID = imSameID.convert("L") # Convert to grayscale
+            imSameID = imSameID.convert("RGB")
         image_array = np.array(imSameID)
         if 2 in task_ids and args.mask_path is not None:
             image_array[white_positions] = [255, 255, 255] # mask for inpainting task
@@ -241,7 +262,7 @@ def main(config,args):
     ref_img = cv2.warpAffine(np.array(Image.fromarray(ref_img)), warp_mat, (height, width), flags=cv2.INTER_CUBIC)
     ref_img = to_tensor(ref_img).to("cuda").to(weight_dtype)
 
-    save_image(ref_img*0.5 + 0.5,f"{save_dir}/ref_img_align.png")
+    # save_image(ref_img*0.5 + 0.5,f"{save_dir}/ref_img_align.png")
 
     ref_img = F.interpolate(ref_img.unsqueeze(0)[:, :, 0:224, 16:240], size=[112, 112], mode='bilinear')
     _, id_feature_conv = net_arcface(ref_img)
@@ -269,8 +290,11 @@ def main(config,args):
 
     video = (video*0.5 + 0.5).clamp(0, 1)
     video = torch.cat([video.to(device="cuda")], dim=0).cpu()
-
-    save_videos_grid(video, f"{save_dir}/{video_name[:-4]}_{seed_input}.mp4", n_rows=1, fps=25)
+    save_videos_grid(video, f"{save_dir}/{video_name[:-4]}_{seed_input}_gen.mp4", n_rows=1, fps=25)
+
+    lq_frames = lq_frames.permute(1,0,2,3).unsqueeze(0)
+    lq_frames = (lq_frames * 0.5 + 0.5).clamp(0, 1).to(device="cuda").cpu()
+    save_videos_grid(lq_frames, f"{save_dir}/{video_name[:-4]}_{seed_input}_ori.mp4", n_rows=1, fps=25)
 
     if args.restore_frames:
         video = video.squeeze(0)
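A side note on the colorization branch added above: converting a frame to mode "L" and straight back to "RGB" keeps the 3-channel layout the pipeline expects while discarding the colour the model is asked to restore. A minimal standalone illustration (the input path is hypothetical):

```
from PIL import Image

# Hypothetical frame path, for illustration only.
frame = Image.open("frame.png")

# Mirrors the `if 1 in task_ids:` branch: drop colour, then restore a 3-channel layout.
gray_rgb = frame.convert("L").convert("RGB")
print(gray_rgb.mode)  # "RGB", but all three channels now hold the same grayscale values
```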
src/dataset/dataset.py CHANGED
@@ -48,3 +48,71 @@ def get_affine_transform(target_face_lm5p, mean_lm5p):
     mat_warp[1][2] = mat23[3]
 
     return mat_warp
+
+def get_union_bbox(bboxes):
+    bboxes = np.array(bboxes)
+    min_x = np.min(bboxes[:, 0])
+    min_y = np.min(bboxes[:, 1])
+    max_x = np.max(bboxes[:, 2])
+    max_y = np.max(bboxes[:, 3])
+    return np.array([min_x, min_y, max_x, max_y])
+
+
+def process_bbox(bbox, expand_radio, height, width):
+
+    def expand(bbox, ratio, height, width):
+
+        bbox_h = bbox[3] - bbox[1]
+        bbox_w = bbox[2] - bbox[0]
+
+        expand_x1 = max(bbox[0] - ratio * bbox_w, 0)
+        expand_y1 = max(bbox[1] - ratio * bbox_h, 0)
+        expand_x2 = min(bbox[2] + ratio * bbox_w, width)
+        expand_y2 = min(bbox[3] + ratio * bbox_h, height)
+
+        return [expand_x1,expand_y1,expand_x2,expand_y2]
+
+    def to_square(bbox_src, bbox_expend, height, width):
+
+        h = bbox_expend[3] - bbox_expend[1]
+        w = bbox_expend[2] - bbox_expend[0]
+        c_h = (bbox_expend[1] + bbox_expend[3]) / 2
+        c_w = (bbox_expend[0] + bbox_expend[2]) / 2
+
+        c = min(h, w) / 2
+
+        c_src_h = (bbox_src[1] + bbox_src[3]) / 2
+        c_src_w = (bbox_src[0] + bbox_src[2]) / 2
+
+        s_h, s_w = 0, 0
+        if w < h:
+            d = abs((h - w) / 2)
+            s_h = min(d, abs(c_src_h-c_h))
+            s_h = s_h if c_src_h > c_h else s_h * (-1)
+        else:
+            d = abs((h - w) / 2)
+            s_w = min(d, abs(c_src_w-c_w))
+            s_w = s_w if c_src_w > c_w else s_w * (-1)
+
+
+        c_h = (bbox_expend[1] + bbox_expend[3]) / 2 + s_h
+        c_w = (bbox_expend[0] + bbox_expend[2]) / 2 + s_w
+
+        square_x1 = c_w - c
+        square_y1 = c_h - c
+        square_x2 = c_w + c
+        square_y2 = c_h + c
+
+        return [round(square_x1), round(square_y1), round(square_x2), round(square_y2)]
+
+
+    bbox_expend = expand(bbox, expand_radio, height=height, width=width)
+    processed_bbox = to_square(bbox, bbox_expend, height=height, width=width)
+
+    return processed_bbox
+
+
+def crop_resize_img(img, bbox):
+    x1, y1, x2, y2 = bbox
+    img = img.crop((x1, y1, x2, y2))
+    return img
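The three helpers added above are consumed by the new face-crop loop in app.py and infer.py: per-frame detections are merged into one union box, padded and reduced to a square, and the same crop is then applied to every frame. A minimal usage sketch, assuming the helpers are importable from src.dataset.dataset; the detections and frame size are hypothetical:

```
from PIL import Image
from src.dataset.dataset import get_union_bbox, process_bbox, crop_resize_img

# Hypothetical per-frame face boxes as [x1, y1, x2, y2], e.g. from the YOLO face detector.
per_frame_boxes = [[100, 80, 220, 230], [110, 90, 230, 240], [95, 85, 215, 235]]

union = get_union_bbox(per_frame_boxes)                                 # one box covering the face in all sampled frames
square = process_bbox(union, expand_radio=0.4, height=480, width=640)   # pad by 40% and make it square

# Hypothetical 640x480 frame; the real code crops every decoded frame with this box, then resizes to 512x512.
frame = Image.new("RGB", (640, 480))
face_crop = crop_resize_img(frame, square).resize((512, 512))
print(union, square, face_crop.size)
```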