Peleck committed on
Commit
fa8453f
·
1 Parent(s): 7592833
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+
2
+ *.pth
3
+ *.onnx
4
+ *.pyc
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Swap-mukham WIP
3
- emoji: 🔥
4
- colorFrom: yellow
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 3.40.1
8
  app_file: app.py
 
1
  ---
2
  title: Swap-mukham WIP
3
+ emoji: 😊
4
+ colorFrom: blue
5
+ colorTo: black
6
  sdk: gradio
7
  sdk_version: 3.40.1
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,1134 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import shutil
5
+ import base64
6
+ import datetime
7
+ import argparse
8
+ import numpy as np
9
+ import gradio as gr
10
+ from tqdm import tqdm
11
+ import concurrent.futures
12
+
13
+ import threading
14
+ cv_reader_lock = threading.Lock()
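+ # NOTE: cv2.VideoCapture is not thread-safe; cv_reader_lock serializes frame seeks/reads across the worker threads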
15
+
16
+ ## ------------------------------ USER ARGS ------------------------------
17
+
18
+ parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
19
+ parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
20
+ parser.add_argument("--max_threads", type=int, help="Max num of threads to use", default=2)
21
+ parser.add_argument("--colab", action="store_true", help="Colab mode", default=False)
22
+ parser.add_argument("--cpu", action="store_true", help="Enable cpu mode", default=False)
23
+ parser.add_argument("--prefer_text_widget", action="store_true", help="Replaces target video widget with text widget", default=False)
24
+ user_args = parser.parse_args()
25
+
26
+ USE_CPU = True  # hard-coded to CPU; the --cpu flag above is effectively ignored
27
+
28
+ if not USE_CPU:
29
+ import torch
30
+
31
+ import default_paths as dp
32
+ import global_variables as gv
33
+
34
+ from swap_mukham import SwapMukham
35
+ from nsfw_checker import NSFWChecker
36
+
37
+ from face_parsing import mask_regions_to_list
38
+
39
+ from utils.device import get_device_and_provider, device_types_list
40
+ from utils.image import (
41
+ image_mask_overlay,
42
+ resize_image_by_resolution,
43
+ resolution_map,
44
+ fast_pil_encode,
45
+ fast_numpy_encode,
46
+ get_crf_for_resolution,
47
+ )
48
+ from utils.io import (
49
+ open_directory,
50
+ get_images_from_directory,
51
+ copy_files_to_directory,
52
+ create_directory,
53
+ get_single_video_frame,
54
+ ffmpeg_merge_frames,
55
+ ffmpeg_mux_audio,
56
+ add_datetime_to_filename,
57
+ )
58
+
59
+ gr.processing_utils.encode_pil_to_base64 = fast_pil_encode
60
+ gr.processing_utils.encode_array_to_base64 = fast_numpy_encode
61
+
62
+ gv.USE_COLAB = user_args.colab
63
+ gv.MAX_THREADS = user_args.max_threads
64
+ gv.DEFAULT_OUTPUT_PATH = user_args.out_dir
65
+
66
+ PREFER_TEXT_WIDGET = user_args.prefer_text_widget
67
+
68
+ WORKSPACE = None
69
+ OUTPUT_FILE = None
70
+
71
+ preferred_device = "cpu" if USE_CPU else "cuda"
72
+ DEVICE_LIST = device_types_list
73
+ DEVICE, PROVIDER, OPTIONS = get_device_and_provider(device=preferred_device)
74
+ SWAP_MUKHAM = SwapMukham(device=DEVICE)
75
+
76
+ IS_RUNNING = False
77
+ CURRENT_FRAME = None
78
+ COLLECTED_FACES = []
79
+ FOREGROUND_MASK_DICT = {}
80
+ NSFW_CACHE = {}
81
+
82
+
83
+ ## ------------------------------ MAIN PROCESS ------------------------------
84
+
85
+
86
+ def process(
87
+ test_mode,
88
+ target_type,
89
+ image_path,
90
+ video_path,
91
+ directory_path,
92
+ source_path,
93
+ use_foreground_mask,
94
+ img_fg_mask,
95
+ fg_mask_softness,
96
+ output_path,
97
+ output_name,
98
+ use_datetime_suffix,
99
+ sequence_output_format,
100
+ keep_output_sequence,
101
+ swap_condition,
102
+ age,
103
+ distance,
104
+ face_enhancer_name,
105
+ face_upscaler_opacity,
106
+ use_face_parsing,
107
+ parse_from_target,
108
+ mask_regions,
109
+ mask_blur_amount,
110
+ mask_erode_amount,
111
+ swap_iteration,
112
+ face_scale,
113
+ use_laplacian_blending,
114
+ crop_top,
115
+ crop_bott,
116
+ crop_left,
117
+ crop_right,
118
+ current_idx,
119
+ number_of_threads,
120
+ use_frame_selection,
121
+ frame_selection_ranges,
122
+ video_quality,
123
+ face_detection_condition,
124
+ face_detection_size,
125
+ face_detection_threshold,
126
+ averaging_method,
127
+ progress=gr.Progress(track_tqdm=True),
128
+ *specifics,
129
+ ):
130
+ global WORKSPACE
131
+ global OUTPUT_FILE
132
+ global PREVIEW
133
+ WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
134
+
135
+ global IS_RUNNING
136
+ IS_RUNNING = True
137
+
138
+ ## ------------------------------ GUI UPDATE FUNC ------------------------------
139
+ def ui_before():
140
+ return (
141
+ gr.update(visible=True, value=None),
142
+ gr.update(interactive=False),
143
+ gr.update(interactive=False),
144
+ gr.update(visible=False, value=None),
145
+ )
146
+
147
+ def ui_after():
148
+ return (
149
+ gr.update(visible=True, value=PREVIEW),
150
+ gr.update(interactive=True),
151
+ gr.update(interactive=True),
152
+ gr.update(visible=False, value=None),
153
+ )
154
+
155
+ def ui_after_vid():
156
+ return (
157
+ gr.update(visible=False),
158
+ gr.update(interactive=True),
159
+ gr.update(interactive=True),
160
+ gr.update(value=OUTPUT_FILE, visible=True),
161
+ )
162
+
163
+ if not test_mode:
164
+ yield ui_before() # resets ui preview
165
+ progress(0, desc="Processing")
166
+
167
+ start_time = time.time()
168
+ total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
169
+ get_finsh_text = (
170
+ lambda start_time: f"Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."
171
+ )
172
+
173
+ ## ------------------------------ PREPARE INPUTS ------------------------------
174
+
175
+ if use_datetime_suffix:
176
+ output_name = add_datetime_to_filename(output_name)
177
+
178
+ mask_regions = mask_regions_to_list(mask_regions)
179
+
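+ # *specifics packs the dynamically created widgets: NUM_OF_SRC_SPECIFIC source-file lists followed by NUM_OF_SRC_SPECIFIC specific-face images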
180
+ specifics = list(specifics)
181
+ half = len(specifics) // 2
182
+ if swap_condition == "specific face":
183
+ source_specifics = [
184
+ ([s.name for s in src] if src is not None else None, spc) for src, spc in zip(specifics[:half], specifics[half:])
185
+ ]
186
+ else:
187
+ source_paths = [i.name for i in source_path]
188
+ source_specifics = [(source_paths, None)]
189
+
190
+ if crop_top > crop_bott:
191
+ crop_top, crop_bott = crop_bott, crop_top
192
+ if crop_left > crop_right:
193
+ crop_left, crop_right = crop_right, crop_left
194
+ crop_mask = (crop_top, 511 - crop_bott, crop_left, 511 - crop_right)
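+ # the crop sliders run 0-511 over the aligned face crop; bottom/right are converted into margins measured from the opposite edge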
195
+
196
+ input_args = {
197
+ "similarity": distance,
198
+ "age": age,
199
+ "face_scale": face_scale,
200
+ "num_of_pass": swap_iteration,
201
+ "face_upscaler_opacity": face_upscaler_opacity,
202
+ "mask_crop_values": crop_mask,
203
+ "mask_erode_amount": mask_erode_amount,
204
+ "mask_blur_amount": mask_blur_amount,
205
+ "use_laplacian_blending": use_laplacian_blending,
206
+ "swap_condition": swap_condition,
207
+ "face_parse_regions": mask_regions,
208
+ "use_face_parsing": use_face_parsing,
209
+ "face_detection_size": [int(face_detection_size), int(face_detection_size)],
210
+ "face_detection_threshold": face_detection_threshold,
211
+ "face_detection_condition": face_detection_condition,
212
+ "parse_from_target": parse_from_target,
213
+ "averaging_method": averaging_method,
214
+ }
215
+
216
+ SWAP_MUKHAM.set_values(input_args)
217
+ if (
218
+ SWAP_MUKHAM.face_upscaler is None
219
+ or SWAP_MUKHAM.face_upscaler_name != face_enhancer_name
220
+ ):
221
+ SWAP_MUKHAM.load_face_upscaler(face_enhancer_name, device=DEVICE)
222
+ if SWAP_MUKHAM.face_parser is None and use_face_parsing:
223
+ SWAP_MUKHAM.load_face_parser(device=DEVICE)
224
+ SWAP_MUKHAM.analyse_source_faces(source_specifics)
225
+
226
+ mask = None
227
+ if use_foreground_mask and img_fg_mask is not None:
228
+ mask = img_fg_mask.get("mask", None)
229
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGRA2RGB)
230
+ if fg_mask_softness > 0:
231
+ mask = cv2.blur(mask, (int(fg_mask_softness), int(fg_mask_softness)))
232
+ mask = mask.astype("float32") / 255.0
233
+
234
+ def nsfw_assertion(is_nsfw):
235
+ if is_nsfw:
236
+ message = "NSFW content detected!"
237
+ gr.Info(message)
238
+ assert not is_nsfw, message
239
+
240
+ ## ------------------------------ IMAGE ------------------------------
241
+
242
+ if target_type == "Image" and not test_mode:
243
+ target = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
244
+
245
+ is_nsfw = SWAP_MUKHAM.nsfw_detector.check_image(target)
246
+ nsfw_assertion(is_nsfw)
247
+
248
+ output = SWAP_MUKHAM.process_frame(
249
+ [target, mask]
250
+ )
251
+ output_file = os.path.join(output_path, output_name + ".png")
252
+ cv2.imwrite(output_file, output)
253
+
254
+ PREVIEW = output
255
+ OUTPUT_FILE = output_file
256
+ WORKSPACE = output_path
257
+
258
+ gr.Info(get_finsh_text(start_time))
259
+ yield ui_after()
260
+
261
+ ## ------------------------------ VIDEO ------------------------------
262
+
263
+ elif target_type == "Video" and not test_mode:
264
+ video_path = video_path.replace('"', '').strip()
265
+
266
+ if video_path in NSFW_CACHE.keys():
267
+ nsfw_assertion(NSFW_CACHE.get(video_path))
268
+ else:
269
+ is_nsfw = SWAP_MUKHAM.nsfw_detector.check_video(video_path)
270
+ NSFW_CACHE[video_path] = is_nsfw
271
+ nsfw_assertion(is_nsfw)
272
+
273
+ temp_path = os.path.join(output_path, output_name)
274
+ os.makedirs(temp_path, exist_ok=True)
275
+
276
+ cap = cv2.VideoCapture(video_path)
277
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
278
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
279
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
280
+ fps = cap.get(cv2.CAP_PROP_FPS)
281
+
282
+ is_in_range = lambda idx: any([int(rng[0]) <= idx <= int(rng[1]) for rng in frame_selection_ranges]) if use_frame_selection else True
283
+
284
+ print("[ Swapping process started ]")
285
+
286
+ def swap_video_func(frame_index):
287
+ if IS_RUNNING:
288
+ with cv_reader_lock:
289
+ cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
290
+ valid_frame, frame = cap.read()
291
+
292
+ if valid_frame:
293
+ if is_in_range(frame_index):
294
+ mask = FOREGROUND_MASK_DICT.get(frame_index, None) if use_foreground_mask else None
295
+ output = SWAP_MUKHAM.process_frame([frame, mask])
296
+ else:
297
+ output = frame
298
+ frame_path = os.path.join(temp_path, f"frame_{frame_index}.{sequence_output_format}")
299
+ if sequence_output_format == "jpg":
300
+ cv2.imwrite(frame_path, output, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
301
+ else:
302
+ cv2.imwrite(frame_path, output)
303
+
304
+ with concurrent.futures.ThreadPoolExecutor(max_workers=number_of_threads) as executor:
305
+ futures = [executor.submit(swap_video_func, idx) for idx in range(total_frames)]
306
+
307
+ with tqdm(total=total_frames, desc="Processing") as pbar:
308
+ for future in concurrent.futures.as_completed(futures):
309
+ future.result()
310
+ pbar.update(1)
311
+
312
+ cap.release()
313
+
314
+ if IS_RUNNING:
315
+ print("[ Merging image sequence ]")
316
+ progress(0, desc="Merging image sequence")
317
+ WORKSPACE = output_path
318
+ out_without_audio = output_name + "_without_audio" + ".mp4"
319
+ destination = os.path.join(output_path, out_without_audio)
320
+ crf = get_crf_for_resolution(max(width,height), video_quality)
321
+ ret, destination = ffmpeg_merge_frames(
322
+ temp_path, f"frame_%d.{sequence_output_format}", destination, fps=fps, crf=crf, ffmpeg_path=dp.FFMPEG_PATH
323
+ )
324
+ OUTPUT_FILE = destination
325
+
326
+ if ret:
327
+ print("[ Merging audio ]")
328
+ progress(0, desc="Merging audio")
329
+ OUTPUT_FILE = destination
330
+ out_with_audio = out_without_audio.replace("_without_audio", "")
331
+ _ret, _destination = ffmpeg_mux_audio(
332
+ video_path, out_without_audio, out_with_audio, ffmpeg_path=dp.FFMPEG_PATH
333
+ )
334
+
335
+ if _ret:
336
+ OUTPUT_FILE = _destination
337
+ os.remove(out_without_audio)
338
+
339
+ if os.path.exists(temp_path) and not keep_output_sequence:
340
+ print("[ Removing temporary files ]")
341
+ progress(0, desc="Removing temporary files")
342
+ shutil.rmtree(temp_path)
343
+
344
+ finish_text = get_finsh_text(start_time)
345
+ print(f"[ {finish_text} ]")
346
+ gr.Info(finish_text)
347
+ yield ui_after_vid()
348
+
349
+ ## ------------------------------ DIRECTORY ------------------------------
350
+
351
+ elif target_type == "Directory" and not test_mode:
352
+ temp_path = os.path.join(output_path, output_name)
353
+ temp_path = create_directory(temp_path, remove_existing=True)
354
+
355
+ directory_path = directory_path.replace('"', '').strip()
356
+ image_paths = get_images_from_directory(directory_path)
357
+
358
+ is_nsfw = SWAP_MUKHAM.nsfw_detector.check_image_paths(image_paths)
359
+ nsfw_assertion(is_nsfw)
360
+
361
+ new_image_paths = copy_files_to_directory(image_paths, temp_path)
362
+
363
+ def swap_func(img_path):
364
+ if IS_RUNNING:
365
+ frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
366
+ output = SWAP_MUKHAM.process_frame([frame, None])
367
+ cv2.imwrite(img_path, output)
368
+
369
+ with concurrent.futures.ThreadPoolExecutor(max_workers=number_of_threads) as executor:
370
+ futures = [executor.submit(swap_func, img_path) for img_path in new_image_paths]
371
+
372
+ with tqdm(total=len(new_image_paths), desc="Processing") as pbar:
373
+ for future in concurrent.futures.as_completed(futures):
374
+ future.result()
375
+ pbar.update(1)
376
+
377
+ PREVIEW = cv2.imread(new_image_paths[-1])
378
+ WORKSPACE = temp_path
379
+ OUTPUT_FILE = new_image_paths[-1]
380
+
381
+ gr.Info(get_finsh_text(start_time))
382
+ yield ui_after()
383
+
384
+ ## ------------------------------ STREAM ------------------------------
385
+
386
+ elif target_type == "Stream" and not test_mode:
387
+ pass
388
+
389
+ ## ------------------------------ TEST ------------------------------
390
+
391
+ if test_mode and target_type == "Video":
392
+ mask = None
393
+ if use_foreground_mask:  # match the video path: only fetch a painted foreground mask when the toggle is enabled
394
+ mask = FOREGROUND_MASK_DICT.get(current_idx, None)
395
+ if CURRENT_FRAME is not None and isinstance(CURRENT_FRAME, np.ndarray):
396
+ PREVIEW = SWAP_MUKHAM.process_frame(
397
+ [CURRENT_FRAME[:, :, ::-1], mask]
398
+ )
399
+ gr.Info(get_finsh_text(start_time))
400
+ yield ui_after()
401
+
402
+
403
+ ## ------------------------------ GRADIO GUI ------------------------------
404
+
405
+ css = """
406
+
407
+ div.gradio-container{
408
+ max-width: unset !important;
409
+ }
410
+
411
+ footer{
412
+ display:none !important
413
+ }
414
+
415
+ #slider_row {
416
+ display: flex;
417
+ flex-wrap: wrap;
418
+ justify-content: space-between;
419
+ }
420
+
421
+ #refresh_slider {
422
+ flex: 0 1 20%;
423
+ display: flex;
424
+ align-items: center;
425
+ }
426
+
427
+ #frame_slider {
428
+ flex: 1 0 80%;
429
+ display: flex;
430
+ align-items: center;
431
+ }
432
+
433
+ """
434
+
435
+ WIDGET_PREVIEW_HEIGHT = 450
436
+
437
+ with gr.Blocks(css=css, theme=gr.themes.Default()) as interface:
438
+ gr.Markdown("# πŸ—Ώ Swap Mukham")
439
+ gr.Markdown("### Single image face swapper")
440
+ with gr.Row():
441
+ with gr.Row():
442
+ with gr.Column(scale=0.35):
443
+ with gr.Tabs():
444
+ with gr.TabItem("📄 Input"):
445
+ swap_condition = gr.Dropdown(
446
+ gv.FACE_DETECT_CONDITIONS,
447
+ info="Choose which face or faces in the target image to swap.",
448
+ multiselect=False,
449
+ show_label=False,
450
+ value=gv.FACE_DETECT_CONDITIONS[0],
451
+ interactive=True,
452
+ )
453
+ age = gr.Number(
454
+ value=25, label="Value", interactive=True, visible=False
455
+ )
456
+
457
+ ## ------------------------------ SOURCE IMAGE ------------------------------
458
+
459
+ source_image_input = gr.Files(
460
+ label="Source face", type="file", interactive=True,
461
+ )
462
+
463
+ ## ------------------------------ SOURCE SPECIFIC ------------------------------
464
+
465
+ with gr.Box(visible=False) as specific_face:
466
+ for i in range(gv.NUM_OF_SRC_SPECIFIC):
467
+ idx = i + 1
468
+ code = "\n"
469
+ code += f"with gr.Tab(label='{idx}'):"
470
+ code += "\n\twith gr.Row():"
471
+ code += f"\n\t\tsrc{idx} = gr.Files(interactive=True, type='file', label='Source Face {idx}')"
472
+ code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
473
+ exec(code)
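+ # the src1..srcN / trg1..trgN widgets are generated via exec so the tab count follows gv.NUM_OF_SRC_SPECIFIC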
474
+
475
+ ## ------------------------------ TARGET TYPE ------------------------------
476
+
477
+ with gr.Group():
478
+ target_type = gr.Radio(
479
+ ["Image", "Video", "Directory"],
480
+ label="Target Type",
481
+ value="Video",
482
+ )
483
+
484
+ ## ------------------------------ TARGET IMAGE ------------------------------
485
+
486
+ with gr.Box(visible=False) as input_image_group:
487
+ target_image_input = gr.Image(
488
+ label="Target Image",
489
+ interactive=True,
490
+ type="filepath",
491
+ height=200
492
+ )
493
+
494
+ ## ------------------------------ TARGET VIDEO ------------------------------
495
+
496
+ with gr.Box(visible=True) as input_video_group:
497
+ with gr.Column():
498
+ video_widget = gr.Text if PREFER_TEXT_WIDGET else gr.Video
499
+ video_input = video_widget(
500
+ label="Target Video", interactive=True,
501
+ )
502
+
503
+ ## ------------------------------ FRAME SELECTION ------------------------------
504
+
505
+ with gr.Accordion("Frame Selection", open=False):
506
+ use_frame_selection = gr.Checkbox(
507
+ label="Use frame selection", value=False, interactive=True,
508
+ )
509
+ frame_selection_ranges = gr.Numpy(
510
+ headers=["Start Frame", "End Frame"],
511
+ datatype=["number", "number"],
512
+ row_count=1,
513
+ col_count=(2, "fixed"),
514
+ interactive=True
515
+ )
516
+
517
+ ## ------------------------------ TARGET DIRECTORY ------------------------------
518
+
519
+ with gr.Box(visible=False) as input_directory_group:
520
+ directory_input = gr.Text(
521
+ label="Target Image Directory", interactive=True
522
+ )
523
+
524
+ ## ------------------------------ TAB MODEL ------------------------------
525
+
526
+ with gr.TabItem("🎚️ Model"):
527
+ with gr.Accordion("Detection", open=False):
528
+ face_detection_condition = gr.Dropdown(
529
+ gv.SINGLE_FACE_DETECT_CONDITIONS,
530
+ label="Condition",
531
+ value=gv.DETECT_CONDITION,
532
+ interactive=True,
533
+ info="This condition is only used when multiple faces are detected on source or specific image.",
534
+ )
535
+ face_detection_size = gr.Number(
536
+ label="Detection Size",
537
+ value=gv.DETECT_SIZE,
538
+ interactive=True,
539
+ )
540
+ face_detection_threshold = gr.Number(
541
+ label="Detection Threshold",
542
+ value=gv.DETECT_THRESHOLD,
543
+ interactive=True,
544
+ )
545
+ face_scale = gr.Slider(
546
+ label="Landmark Scale",
547
+ minimum=0,
548
+ maximum=2,
549
+ value=1,
550
+ interactive=True,
551
+ )
552
+ with gr.Accordion("Embedding/Recognition", open=True):
553
+ averaging_method = gr.Dropdown(
554
+ gv.AVERAGING_METHODS,
555
+ label="Averaging Method",
556
+ value=gv.AVERAGING_METHOD,
557
+ interactive=True,
558
+ )
559
+ distance_slider = gr.Slider(
560
+ minimum=0,
561
+ maximum=2,
562
+ value=0.65,
563
+ interactive=True,
564
+ label="Specific-Target Distance",
565
+ )
566
+ with gr.Accordion("Swapper", open=True):
567
+ with gr.Row():
568
+ swap_iteration = gr.Slider(
569
+ label="Swap Iteration",
570
+ minimum=1,
571
+ maximum=4,
572
+ value=1,
573
+ step=1,
574
+ interactive=True,
575
+ )
576
+
577
+ ## ------------------------------ TAB POST-PROCESS ------------------------------
578
+
579
+ with gr.TabItem("🪄 Post-Process"):
580
+ with gr.Row():
581
+ face_enhancer_name = gr.Dropdown(
582
+ gv.FACE_ENHANCER_LIST,
583
+ label="Face Enhancer",
584
+ value="NONE",
585
+ multiselect=False,
586
+ interactive=True,
587
+ )
588
+ face_upscaler_opacity = gr.Slider(
589
+ label="Opacity",
590
+ minimum=0,
591
+ maximum=1,
592
+ value=1,
593
+ step=0.001,
594
+ interactive=True,
595
+ )
596
+
597
+ with gr.Accordion("Face Mask", open=False):
598
+ with gr.Group():
599
+ with gr.Row():
600
+ use_face_parsing_mask = gr.Checkbox(
601
+ label="Enable Face Parsing",
602
+ value=False,
603
+ interactive=True,
604
+ )
605
+ parse_from_target = gr.Checkbox(
606
+ label="Parse from target",
607
+ value=False,
608
+ interactive=True,
609
+ )
610
+ mask_regions = gr.Dropdown(
611
+ gv.MASK_REGIONS,
612
+ value=gv.MASK_REGIONS_DEFAULT,
613
+ multiselect=True,
614
+ label="Include",
615
+ interactive=True,
616
+ )
617
+
618
+ with gr.Accordion("Crop Face Bounding-Box", open=False):
619
+ with gr.Group():
620
+ with gr.Row():
621
+ crop_top = gr.Slider(
622
+ label="Top",
623
+ minimum=0,
624
+ maximum=511,
625
+ value=0,
626
+ step=1,
627
+ interactive=True,
628
+ )
629
+ crop_bott = gr.Slider(
630
+ label="Bottom",
631
+ minimum=0,
632
+ maximum=511,
633
+ value=511,
634
+ step=1,
635
+ interactive=True,
636
+ )
637
+ with gr.Row():
638
+ crop_left = gr.Slider(
639
+ label="Left",
640
+ minimum=0,
641
+ maximum=511,
642
+ value=0,
643
+ step=1,
644
+ interactive=True,
645
+ )
646
+ crop_right = gr.Slider(
647
+ label="Right",
648
+ minimum=0,
649
+ maximum=511,
650
+ value=511,
651
+ step=1,
652
+ interactive=True,
653
+ )
654
+
655
+ with gr.Row():
656
+ mask_erode_amount = gr.Slider(
657
+ label="Mask Erode",
658
+ minimum=0,
659
+ maximum=1,
660
+ value=gv.MASK_ERODE_AMOUNT,
661
+ step=0.001,
662
+ interactive=True,
663
+ )
664
+
665
+ mask_blur_amount = gr.Slider(
666
+ label="Mask Blur",
667
+ minimum=0,
668
+ maximum=1,
669
+ value=gv.MASK_BLUR_AMOUNT,
670
+ step=0.001,
671
+ interactive=True,
672
+ )
673
+
674
+ use_laplacian_blending = gr.Checkbox(
675
+ label="Laplacian Blending",
676
+ value=True,
677
+ interactive=True,
678
+ )
679
+
680
+ ## ------------------------------ TAB OUTPUT ------------------------------
681
+
682
+ with gr.TabItem("📤 Output"):
683
+ output_directory = gr.Text(
684
+ label="Output Directory",
685
+ value=gv.DEFAULT_OUTPUT_PATH,
686
+ interactive=True,
687
+ )
688
+ with gr.Group():
689
+ output_name = gr.Text(
690
+ label="Output Name", value="Result", interactive=True
691
+ )
692
+ use_datetime_suffix = gr.Checkbox(
693
+ label="Suffix date-time", value=True, interactive=True
694
+ )
695
+ with gr.Accordion("Video settings", open=True):
696
+ with gr.Row():
697
+ sequence_output_format = gr.Dropdown(
698
+ ["jpg", "png"],
699
+ label="Sequence format",
700
+ value="jpg",
701
+ interactive=True,
702
+ )
703
+ video_quality = gr.Dropdown(
704
+ gv.VIDEO_QUALITY_LIST,
705
+ label="Quality",
706
+ value=gv.VIDEO_QUALITY,
707
+ interactive=True
708
+ )
709
+ keep_output_sequence = gr.Checkbox(
710
+ label="Keep output sequence", value=False, interactive=True
711
+ )
712
+
713
+ ## ------------------------------ TAB PERFORMANCE ------------------------------
714
+ with gr.TabItem("🛠️ Performance"):
715
+ preview_resolution = gr.Dropdown(
716
+ gv.RESOLUTIONS,
717
+ label="Preview Resolution",
718
+ value="Original",
719
+ interactive=True,
720
+ )
721
+ number_of_threads = gr.Number(
722
+ step=1,
723
+ interactive=True,
724
+ label="Max number of threads",
725
+ value=gv.MAX_THREADS,
726
+ minimum=1,
727
+ )
728
+ with gr.Box():
729
+ with gr.Column():
730
+ with gr.Row():
731
+ face_analyser_device = gr.Radio(
732
+ DEVICE_LIST,
733
+ label="Face detection & recognition",
734
+ value=DEVICE,
735
+ interactive=True,
736
+ )
737
+ face_analyser_device_submit = gr.Button("Apply")
738
+ with gr.Row():
739
+ face_swapper_device = gr.Radio(
740
+ DEVICE_LIST,
741
+ label="Face swapper",
742
+ value=DEVICE,
743
+ interactive=True,
744
+ )
745
+ face_swapper_device_submit = gr.Button("Apply")
746
+ with gr.Row():
747
+ face_parser_device = gr.Radio(
748
+ DEVICE_LIST,
749
+ label="Face parsing",
750
+ value=DEVICE,
751
+ interactive=True,
752
+ )
753
+ face_parser_device_submit = gr.Button("Apply")
754
+ with gr.Row():
755
+ face_upscaler_device = gr.Radio(
756
+ DEVICE_LIST,
757
+ label="Face upscaler",
758
+ value=DEVICE,
759
+ interactive=True,
760
+ )
761
+ face_upscaler_device_submit = gr.Button("Apply")
762
+
763
+ face_analyser_device_submit.click(
764
+ fn=lambda d: SWAP_MUKHAM.load_face_analyser(
765
+ device=d
766
+ ),
767
+ inputs=[face_analyser_device],
768
+ )
769
+ face_swapper_device_submit.click(
770
+ fn=lambda d: SWAP_MUKHAM.load_face_swapper(
771
+ device=d
772
+ ),
773
+ inputs=[face_swapper_device],
774
+ )
775
+ face_parser_device_submit.click(
776
+ fn=lambda d: SWAP_MUKHAM.load_face_parser(device=d),
777
+ inputs=[face_parser_device],
778
+ )
779
+ face_upscaler_device_submit.click(
780
+ fn=lambda n, d: SWAP_MUKHAM.load_face_upscaler(
781
+ n, device=d
782
+ ),
783
+ inputs=[face_enhancer_name, face_upscaler_device],
784
+ )
785
+
786
+ ## ------------------------------ SWAP, CANCEL, FRAME SLIDER ------------------------------
787
+
788
+ with gr.Column(scale=0.65):
789
+ with gr.Row():
790
+ swap_button = gr.Button("✨ Swap", variant="primary")
791
+ cancel_button = gr.Button("⛔ Cancel")
792
+ collect_faces = gr.Button("👨 Collect Faces")
793
+ test_swap = gr.Button("🧪 Test Swap")
794
+
795
+ with gr.Box() as frame_slider_box:
796
+ with gr.Row(elem_id="slider_row", equal_height=True):
797
+ set_slider_range_btn = gr.Button(
798
+ "Set Range", interactive=True, elem_id="refresh_slider"
799
+ )
800
+ frame_slider = gr.Slider(
801
+ label="Frame",
802
+ minimum=0,
803
+ maximum=1,
804
+ value=0,
805
+ step=1,
806
+ interactive=True,
807
+ elem_id="frame_slider",
808
+ )
809
+
810
+ ## ------------------------------ PREVIEW ------------------------------
811
+
812
+ with gr.Tabs():
813
+ with gr.TabItem("Preview"):
814
+
815
+ preview_image = gr.Image(
816
+ label="Preview", type="numpy", interactive=False, height=WIDGET_PREVIEW_HEIGHT,
817
+ )
818
+
819
+ preview_video = gr.Video(
820
+ label="Output", interactive=False, visible=False, height=WIDGET_PREVIEW_HEIGHT,
821
+ )
822
+ preview_enabled_text = gr.Markdown(
823
+ "Disable paint foreground to preview !", visible=False
824
+ )
825
+ with gr.Row():
826
+ output_directory_button = gr.Button(
827
+ "πŸ“‚", interactive=False, visible=not gv.USE_COLAB
828
+ )
829
+ output_video_button = gr.Button(
830
+ "🎬", interactive=False, visible=not gv.USE_COLAB
831
+ )
832
+
833
+ output_directory_button.click(
834
+ lambda: open_directory(path=WORKSPACE),
835
+ inputs=None,
836
+ outputs=None,
837
+ )
838
+ output_video_button.click(
839
+ lambda: open_directory(path=OUTPUT_FILE),
840
+ inputs=None,
841
+ outputs=None,
842
+ )
843
+
844
+ ## ------------------------------ FOREGROUND MASK ------------------------------
845
+
846
+ with gr.TabItem("Paint Foreground"):
847
+ with gr.Box() as fg_mask_group:
848
+ with gr.Row():
849
+ with gr.Row():
850
+ use_foreground_mask = gr.Checkbox(
851
+ label="Use foreground mask", value=False, interactive=True)
852
+ fg_mask_softness = gr.Slider(
853
+ label="Mask Softness",
854
+ minimum=0,
855
+ maximum=200,
856
+ value=1,
857
+ step=1,
858
+ interactive=True,
859
+ )
860
+ add_fg_mask_btn = gr.Button("Add", interactive=True)
861
+ del_fg_mask_btn = gr.Button("Del", interactive=True)
862
+ img_fg_mask = gr.Image(
863
+ label="Paint Mask",
864
+ tool="sketch",
865
+ interactive=True,
866
+ type="numpy",
867
+ height=WIDGET_PREVIEW_HEIGHT,
868
+ )
869
+
870
+ ## ------------------------------ COLLECT FACE ------------------------------
871
+
872
+ with gr.TabItem("Collected Faces"):
873
+ collected_faces = gr.Gallery(
874
+ label="Faces",
875
+ show_label=False,
876
+ elem_id="gallery",
877
+ columns=[6], rows=[6], object_fit="contain", height=WIDGET_PREVIEW_HEIGHT,
878
+ )
879
+
880
+ ## ------------------------------ FOOTER LINKS ------------------------------
881
+
882
+ with gr.Row(variant='panel'):
883
+ gr.HTML(
884
+ """
885
+ <div style="display: flex; flex-direction: row; justify-content: center;">
886
+ <h3 style="margin-right: 10px;"><a href="https://github.com/sponsors/harisreedhar" style="text-decoration: none;">🀝 Sponsor</a></h3>
887
+ <h3 style="margin-right: 10px;"><a href="https://github.com/harisreedhar/Swap-Mukham" style="text-decoration: none;">πŸ‘¨β€πŸ’» Source</a></h3>
888
+ <h3 style="margin-right: 10px;"><a href="https://github.com/harisreedhar/Swap-Mukham#disclaimer" style="text-decoration: none;">⚠️ Disclaimer</a></h3>
889
+ <h3 style="margin-right: 10px;"><a href="https://colab.research.google.com/github/harisreedhar/Swap-Mukham/blob/main/swap_mukham_colab.ipynb" style="text-decoration: none;">🌐 Colab</a></h3>
890
+ <h3><a href="https://github.com/harisreedhar/Swap-Mukham#acknowledgements" style="text-decoration: none;">πŸ€— Acknowledgements</a></h3>
891
+ </div>
892
+ """
893
+ )
894
+
895
+ ## ------------------------------ GRADIO EVENTS ------------------------------
896
+
897
+ def on_target_type_change(value):
898
+ visibility = {
899
+ "Image": (True, False, False, False, True, False, False, False),
900
+ "Video": (False, True, False, True, True, True, True, True),
901
+ "Directory": (False, False, True, False, False, False, False, False),
902
+ "Stream": (False, False, True, False, False, False, False, False),
903
+ }
904
+ return list(gr.update(visible=i) for i in visibility[value])
905
+
906
+ target_type.change(
907
+ on_target_type_change,
908
+ inputs=[target_type],
909
+ outputs=[
910
+ input_image_group,
911
+ input_video_group,
912
+ input_directory_group,
913
+ frame_slider_box,
914
+ fg_mask_group,
915
+ add_fg_mask_btn,
916
+ del_fg_mask_btn,
917
+ test_swap,
918
+ ],
919
+ )
920
+
921
+ target_image_input.change(
922
+ lambda inp: gr.update(value=inp),
923
+ inputs=[target_image_input],
924
+ outputs=[img_fg_mask]
925
+ )
926
+
927
+ def on_swap_condition_change(value):
928
+ visibility = {
929
+ "age less than": (True, False, True),
930
+ "age greater than": (True, False, True),
931
+ "specific face": (False, True, False),
932
+ }
933
+ return tuple(
934
+ gr.update(visible=i) for i in visibility.get(value, (False, False, True))
935
+ )
936
+
937
+ swap_condition.change(
938
+ on_swap_condition_change,
939
+ inputs=[swap_condition],
940
+ outputs=[age, specific_face, source_image_input],
941
+ )
942
+
943
+ def on_set_slider_range(video_path):
944
+ if video_path is None or not os.path.exists(video_path):
945
+ gr.Info("Check video path")
946
+ else:
947
+ try:
948
+ cap = cv2.VideoCapture(video_path)
949
+ fps = cap.get(cv2.CAP_PROP_FPS)
950
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
951
+ cap.release()
952
+ if total_frames > 0:
953
+ total_frames -= 1
954
+ return gr.Slider.update(
955
+ minimum=0, maximum=total_frames, value=0, interactive=True
956
+ )
957
+ gr.Info("Error fetching video")
958
+ except Exception:
959
+ gr.Info("Error fetching video")
960
+
961
+ set_slider_range_event = set_slider_range_btn.click(
962
+ on_set_slider_range,
963
+ inputs=[video_input],
964
+ outputs=[frame_slider],
965
+ )
966
+
967
+ def update_preview(video_path, frame_index, use_foreground_mask, resolution):
968
+ if not os.path.exists(video_path):
969
+ yield gr.update(value=None), gr.update(value=None), gr.update(visible=False)
970
+ else:
971
+ frame = get_single_video_frame(video_path, frame_index)
972
+ if frame is not None:
973
+ if use_foreground_mask:
974
+ overlayed_image = frame
975
+ if frame_index in FOREGROUND_MASK_DICT.keys():
976
+ mask = FOREGROUND_MASK_DICT.get(frame_index, None)
977
+ if mask is not None:
978
+ overlayed_image = image_mask_overlay(frame, mask)
979
+ yield gr.update(value=None), gr.update(value=None), gr.update(visible=False) # clear previous mask
980
+ frame = resize_image_by_resolution(frame, resolution)
981
+ yield gr.update(value=frame[:, :, ::-1]), gr.update(
982
+ value=overlayed_image[:, :, ::-1], visible=True
983
+ ), gr.update(visible=False)
984
+ else:
985
+ frame = resize_image_by_resolution(frame, resolution)
986
+ yield gr.update(value=frame[:, :, ::-1]), gr.update(value=None), gr.update(
987
+ visible=False
988
+ )
989
+
990
+ global CURRENT_FRAME
991
+ CURRENT_FRAME = frame
992
+
993
+ frame_slider_event = frame_slider.change(
994
+ fn=update_preview,
995
+ inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
996
+ outputs=[preview_image, img_fg_mask, preview_video],
997
+ show_progress=False,
998
+ )
999
+
1000
+ def add_foreground_mask(fg, frame_index, softness):
1001
+ if fg is not None:
1002
+ mask = fg.get("mask", None)
1003
+ if mask is not None:
1004
+ alpha_rgb = cv2.cvtColor(mask, cv2.COLOR_BGRA2RGB)
1005
+ alpha_rgb = cv2.blur(alpha_rgb, (softness, softness))
1006
+ FOREGROUND_MASK_DICT[frame_index] = alpha_rgb.astype("float32") / 255.0
1007
+ gr.Info(f"saved mask index {frame_index}")
1008
+
1009
+ add_foreground_mask_event = add_fg_mask_btn.click(
1010
+ fn=add_foreground_mask,
1011
+ inputs=[img_fg_mask, frame_slider, fg_mask_softness],
1012
+ ).then(
1013
+ fn=update_preview,
1014
+ inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
1015
+ outputs=[preview_image, img_fg_mask, preview_video],
1016
+ show_progress=False,
1017
+ )
1018
+
1019
+ def delete_foreground_mask(frame_index):
1020
+ if frame_index in FOREGROUND_MASK_DICT.keys():
1021
+ FOREGROUND_MASK_DICT.pop(frame_index)
1022
+ gr.Info(f"Deleted mask index {frame_index}")
1023
+
1024
+ del_custom_mask_event = del_fg_mask_btn.click(
1025
+ fn=delete_foreground_mask, inputs=[frame_slider]
1026
+ ).then(
1027
+ fn=update_preview,
1028
+ inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
1029
+ outputs=[preview_image, img_fg_mask, preview_video],
1030
+ show_progress=False,
1031
+ )
1032
+
1033
+ def get_collected_faces(image):
1034
+ if image is not None:
1035
+ gr.Info(f"Collecting faces...")
1036
+ faces = SWAP_MUKHAM.collect_heads(image)
1037
+ COLLECTED_FACES.extend(faces)
1038
+ yield COLLECTED_FACES
1039
+ gr.Info(f"Collected {len(faces)} faces")
1040
+
1041
+ collect_faces.click(get_collected_faces, inputs=[preview_image], outputs=[collected_faces])
1042
+
1043
+ src_specific_inputs = []
1044
+ gen_variable_txt = ",".join(
1045
+ [f"src{i+1}" for i in range(gv.NUM_OF_SRC_SPECIFIC)]
1046
+ + [f"trg{i+1}" for i in range(gv.NUM_OF_SRC_SPECIFIC)]
1047
+ )
1048
+ exec(f"src_specific_inputs = ({gen_variable_txt})")
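+ # gathers the dynamically generated src/trg widgets into one flat tuple that is appended to swap_inputs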
1049
+
1050
+ test_mode = gr.Checkbox(value=False, visible=False)
1051
+
1052
+ swap_inputs = [
1053
+ test_mode,
1054
+ target_type,
1055
+ target_image_input,
1056
+ video_input,
1057
+ directory_input,
1058
+ source_image_input,
1059
+ use_foreground_mask,
1060
+ img_fg_mask,
1061
+ fg_mask_softness,
1062
+ output_directory,
1063
+ output_name,
1064
+ use_datetime_suffix,
1065
+ sequence_output_format,
1066
+ keep_output_sequence,
1067
+ swap_condition,
1068
+ age,
1069
+ distance_slider,
1070
+ face_enhancer_name,
1071
+ face_upscaler_opacity,
1072
+ use_face_parsing_mask,
1073
+ parse_from_target,
1074
+ mask_regions,
1075
+ mask_blur_amount,
1076
+ mask_erode_amount,
1077
+ swap_iteration,
1078
+ face_scale,
1079
+ use_laplacian_blending,
1080
+ crop_top,
1081
+ crop_bott,
1082
+ crop_left,
1083
+ crop_right,
1084
+ frame_slider,
1085
+ number_of_threads,
1086
+ use_frame_selection,
1087
+ frame_selection_ranges,
1088
+ video_quality,
1089
+ face_detection_condition,
1090
+ face_detection_size,
1091
+ face_detection_threshold,
1092
+ averaging_method,
1093
+ *src_specific_inputs,
1094
+ ]
1095
+
1096
+ swap_outputs = [
1097
+ preview_image,
1098
+ output_directory_button,
1099
+ output_video_button,
1100
+ preview_video,
1101
+ ]
1102
+
1103
+ swap_event = swap_button.click(fn=process, inputs=swap_inputs, outputs=swap_outputs)
1104
+
1105
+ test_swap_settings = list(swap_inputs)  # copy the list so replacing test_mode below does not mutate swap_inputs
1106
+ test_swap_settings[0] = gr.Checkbox(value=True, visible=False)
1107
+
1108
+ test_swap_event = test_swap.click(
1109
+ fn=update_preview,
1110
+ inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
1111
+ outputs=[preview_image, img_fg_mask, preview_video],
1112
+ show_progress=False,
1113
+ ).then(
1114
+ fn=process, inputs=test_swap_settings, outputs=swap_outputs, show_progress=True
1115
+ )
1116
+
1117
+ def stop_running():
1118
+ global IS_RUNNING
1119
+ IS_RUNNING = False
1120
+ print("[ Process cancelled ]")
1121
+ gr.Info("Process cancelled")
1122
+
1123
+ cancel_button.click(
1124
+ fn=stop_running,
1125
+ inputs=None,
1126
+ cancels=[swap_event, set_slider_range_event, test_swap_event],
1127
+ show_progress=True,
1128
+ )
1129
+
1130
+ if __name__ == "__main__":
1131
+ if gv.USE_COLAB:
1132
+ print("Running in colab mode")
1133
+
1134
+ interface.queue(concurrency_count=2, max_size=20).launch(share=gv.USE_COLAB)
assets/images/loading.gif ADDED
assets/images/logo.png ADDED
assets/pretrained_models/readme.md ADDED
@@ -0,0 +1 @@
1
+
change_log.md ADDED
@@ -0,0 +1,5 @@
1
+ # Change-log
2
+
3
+ ## 30/07/2023
4
+ - Change the existing NSFW filter to Yahoo's open_nsfw
5
+ - Add CodeFormer support
default_paths.py ADDED
@@ -0,0 +1,18 @@
1
+ import os
2
+
3
+ FFMPEG_PATH = "./ffmpeg/ffmpeg" if os.path.exists("./ffmpeg/ffmpeg") else None
4
+
5
+ INSWAPPER_PATH = "./assets/pretrained_models/inswapper_128.onnx"
6
+ FACE_PARSER_PATH = "./assets/pretrained_models/faceparser.onnx"
7
+ ARCFACE_PATH = "./assets/pretrained_models/w600k_r50.onnx"
8
+ RETINAFACE_PATH = "./assets/pretrained_models/det_10g.onnx"
9
+ OPEN_NSFW_PATH = "./assets/pretrained_models/open-nsfw.onnx"
10
+ GENDERAGE_PATH = "./assets/pretrained_models/gender_age.onnx"
11
+
12
+ CODEFORMER_PATH = "./assets/pretrained_models/codeformer.onnx"
13
+ GFPGAN_V14_PATH = "./assets/pretrained_models/GFPGANv1.4.onnx"
14
+ GFPGAN_V13_PATH = "./assets/pretrained_models/GFPGANv1.3.onnx"
15
+ GFPGAN_V12_PATH = "./assets/pretrained_models/GFPGANv1.2.onnx"
16
+ GPEN_BFR_512_PATH = "./assets/pretrained_models/GPEN-BFR-512.onnx"
17
+ GPEN_BFR_256_PATH = "./assets/pretrained_models/GPEN-BFR-256.onnx"
18
+ RESTOREFORMER_PATH = "./assets/pretrained_models/restoreformer.onnx"
face_analyser.py ADDED
@@ -0,0 +1,168 @@
1
+ import os
2
+ import cv2
3
+ import threading
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import concurrent.futures
7
+ import default_paths as dp
8
+ from dataclasses import dataclass
9
+ from utils.arcface import ArcFace
10
+ from utils.gender_age import GenderAge
11
+ from utils.retinaface import RetinaFace
12
+
13
+ cache = {}
14
+
15
+ @dataclass
16
+ class Face:
17
+ bbox: np.ndarray
18
+ kps: np.ndarray
19
+ det_score: float
20
+ embedding: np.ndarray
21
+ gender: int
22
+ age: int
23
+
24
+ def __getitem__(self, key):
25
+ return getattr(self, key)
26
+
27
+ def __setitem__(self, key, value):
28
+ if hasattr(self, key):
29
+ setattr(self, key, value)
30
+ else:
31
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
32
+
33
+ single_face_detect_conditions = [
34
+ "best detection",
35
+ "left most",
36
+ "right most",
37
+ "top most",
38
+ "bottom most",
39
+ "middle",
40
+ "biggest",
41
+ "smallest",
42
+ ]
43
+
44
+ multi_face_detect_conditions = [
45
+ "all face",
46
+ "specific face",
47
+ "age less than",
48
+ "age greater than",
49
+ "all male",
50
+ "all female"
51
+ ]
52
+
53
+ face_detect_conditions = multi_face_detect_conditions + single_face_detect_conditions
54
+
55
+
56
+ def get_single_face(faces, method="best detection"):
57
+ total_faces = len(faces)
58
+
59
+ if total_faces == 0:
60
+ return None
61
+
62
+ if total_faces == 1:
63
+ return faces[0]
64
+
65
+ if method == "best detection":
66
+ return sorted(faces, key=lambda face: face["det_score"])[-1]
67
+ elif method == "left most":
68
+ return sorted(faces, key=lambda face: face["bbox"][0])[0]
69
+ elif method == "right most":
70
+ return sorted(faces, key=lambda face: face["bbox"][0])[-1]
71
+ elif method == "top most":
72
+ return sorted(faces, key=lambda face: face["bbox"][1])[0]
73
+ elif method == "bottom most":
74
+ return sorted(faces, key=lambda face: face["bbox"][1])[-1]
75
+ elif method == "middle":
76
+ return sorted(faces, key=lambda face: (
77
+ (face["bbox"][0] + face["bbox"][2]) / 2 - 0.5) ** 2 +
78
+ ((face["bbox"][1] + face["bbox"][3]) / 2 - 0.5) ** 2)[len(faces) // 2]
79
+ elif method == "biggest":
80
+ return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[-1]
81
+ elif method == "smallest":
82
+ return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[0]
83
+
84
+ def filter_face_by_age(faces, age, method="age less than"):
85
+ if method == "age less than":
86
+ return [face for face in faces if face["age"] < age]
87
+ elif method == "age greater than":
88
+ return [face for face in faces if face["age"] > age]
89
+ elif method == "age equals to":
90
+ return [face for face in faces if face["age"] == age]
91
+
92
+ def cosine_distance(a, b):
93
+ a = a / np.linalg.norm(a)  # copy instead of /=, so the caller's embedding is not modified in place
94
+ b = b / np.linalg.norm(b)
95
+ return 1 - np.dot(a, b)
96
+
97
+ def is_similar_face(face1, face2, threshold=0.6):
98
+ distance = cosine_distance(face1["embedding"], face2["embedding"])
99
+ return distance < threshold
100
+
101
+
102
+ class AnalyseFace:
103
+ def __init__(self, provider=["CPUExecutionProvider"], session_options=None):
104
+ self.detector = RetinaFace(model_file=dp.RETINAFACE_PATH, provider=provider, session_options=session_options)
105
+ self.recognizer = ArcFace(model_file=dp.ARCFACE_PATH, provider=provider, session_options=session_options)
106
+ self.gender_age = GenderAge(model_file=dp.GENDERAGE_PATH, provider=provider, session_options=session_options)
107
+ self.detect_condition = "best detection"
108
+ self.detection_size = (640, 640)
109
+ self.detection_threshold = 0.5
110
+
111
+ def analyser(self, img, skip_task=[]):
112
+ bboxes, kpss = self.detector.detect(img, input_size=self.detection_size, det_thresh=self.detection_threshold)
113
+ faces = []
114
+ for i in range(bboxes.shape[0]):
115
+ feat, gender, age = None, None, None
116
+ bbox = bboxes[i, 0:4]
117
+ det_score = bboxes[i, 4]
118
+ kps = None
119
+ if kpss is not None:
120
+ kps = kpss[i]
121
+ if 'embedding' not in skip_task:
122
+ feat = self.recognizer.get(img, kpss[i])
123
+ if 'gender_age' not in skip_task:
124
+ gender, age = self.gender_age.predict(img, kpss[i])
125
+ face = Face(bbox=bbox, kps=kps, det_score=det_score, embedding=feat, gender=gender, age=age)
126
+ faces.append(face)
127
+ return faces
128
+
129
+ def get_faces(self, image, scale=1., skip_task=[]):
130
+ if isinstance(image, str):
131
+ image = cv2.imread(image)
132
+
133
+ faces = self.analyser(image, skip_task=skip_task)
134
+
135
+ if scale != 1: # landmark-scale
136
+ for i, face in enumerate(faces):
137
+ landmark = face['kps']
138
+ center = np.mean(landmark, axis=0)
139
+ landmark = center + (landmark - center) * scale
140
+ faces[i]['kps'] = landmark
141
+
142
+ return faces
143
+
144
+ def get_face(self, image, scale=1., skip_task=[]):
145
+ faces = self.get_faces(image, scale=scale, skip_task=skip_task)
146
+ return get_single_face(faces, method=self.detect_condition)
147
+
148
+ def get_averaged_face(self, images, method="mean"):
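+ # detect one face per image and average (mean or median) their embeddings into a single identity vector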
149
+ if not isinstance(images, list):
150
+ images = [images]
151
+
152
+ face = self.get_face(images[0], scale=1., skip_task=[])
153
+
154
+ if len(images) > 1:
155
+ embeddings = [face['embedding']]
156
+
157
+ for image in images[1:]:
158
+ face = self.get_face(image, scale=1., skip_task=[])
159
+ embeddings.append(face['embedding'])
160
+
161
+ if method == "mean":
162
+ avg_embedding = np.mean(embeddings, axis=0)
163
+ elif method == "median":
164
+ avg_embedding = np.median(embeddings, axis=0)
165
+
166
+ face['embedding'] = avg_embedding
167
+
168
+ return face
face_parsing.py ADDED
@@ -0,0 +1,55 @@
1
+ import cv2
2
+ import onnxruntime
3
+ import numpy as np
4
+
5
+ mask_regions = {
6
+ "Background":0,
7
+ "Skin":1,
8
+ "L-Eyebrow":2,
9
+ "R-Eyebrow":3,
10
+ "L-Eye":4,
11
+ "R-Eye":5,
12
+ "Eye-G":6,
13
+ "L-Ear":7,
14
+ "R-Ear":8,
15
+ "Ear-R":9,
16
+ "Nose":10,
17
+ "Mouth":11,
18
+ "U-Lip":12,
19
+ "L-Lip":13,
20
+ "Neck":14,
21
+ "Neck-L":15,
22
+ "Cloth":16,
23
+ "Hair":17,
24
+ "Hat":18
25
+ }
26
+
27
+
28
+ class FaceParser:
29
+ def __init__(self, model_path=None, provider=['CPUExecutionProvider'], session_options=None):
30
+ self.session_options = session_options
31
+ if self.session_options is None:
32
+ self.session_options = onnxruntime.SessionOptions()
33
+ self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
34
+ self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
35
+ self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
36
+
37
+ def parse(self, img, regions=[1,2,3,4,5,10,11,12,13]):
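+ # resize to 512x512, normalize with ImageNet statistics, and return a binary float mask covering the requested region ids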
38
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
39
+ img = img.astype(np.float32)[:,:,::-1] / 255.0
40
+ img = (img - self.mean) / self.std
41
+ img = np.expand_dims(img.transpose((2, 0, 1)), axis=0).astype(np.float32)
42
+
43
+ out = self.session.run(None, {'input':img})[0]
44
+ out = out.squeeze(0).argmax(0)
45
+ out = np.isin(out, regions).astype('float32')
46
+
47
+ return out.clip(0, 1)
48
+
49
+
50
+ def mask_regions_to_list(values):
51
+ out_ids = []
52
+ for value in values:
53
+ if value in mask_regions.keys():
54
+ out_ids.append(mask_regions.get(value))
55
+ return out_ids
face_swapper.py ADDED
@@ -0,0 +1,43 @@
1
+ import time
2
+ import onnx
3
+ import cv2
4
+ import onnxruntime
5
+ import numpy as np
6
+ from onnx import numpy_helper
7
+ from numpy.linalg import norm as l2norm
8
+ from utils.face_alignment import norm_crop2
9
+
10
+
11
+ class Inswapper():
12
+ def __init__(self, model_file=None, provider=['CPUExecutionProvider'], session_options=None):
13
+ self.model_file = model_file
14
+ model = onnx.load(self.model_file)
15
+ graph = model.graph
16
+ self.emap = numpy_helper.to_array(graph.initializer[-1])
17
+
18
+ self.session_options = session_options
19
+ if self.session_options is None:
20
+ self.session_options = onnxruntime.SessionOptions()
21
+ self.session = onnxruntime.InferenceSession(self.model_file, sess_options=self.session_options, providers=provider)
22
+
23
+ def forward(self, frame, target, source, n_pass=1):
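+ # align the target face to 128x128, project the source embedding through the model's emap, and run the swap n_pass times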
24
+ trg, matrix = norm_crop2(frame, target['kps'], 128)
25
+
26
+ latent = source['embedding'].reshape((1, -1))
27
+ latent = np.dot(latent, self.emap)
28
+ latent /= np.linalg.norm(latent)
29
+
30
+ blob = trg.astype('float32') / 255
31
+ blob = blob[:, :, ::-1]
32
+ blob = np.expand_dims(blob, axis=0).transpose(0, 3, 1, 2)
33
+
34
+ for _ in range(max(int(n_pass),1)):
35
+ blob = self.session.run(['output'], {'target': blob, 'source': latent})[0]
36
+
37
+ out = blob[0].transpose((1, 2, 0))
38
+ out = (out * 255).clip(0,255)
39
+ out = out.astype('uint8')[:, :, ::-1]
40
+
41
+ del blob, latent
42
+
43
+ return trg, out, matrix
face_upscaler.py ADDED
@@ -0,0 +1,72 @@
1
+ import os
2
+ import cv2
3
+ import default_paths as dp
4
+ from upscaler.GPEN import GPEN
5
+ from upscaler.GFPGAN import GFPGAN
6
+ from upscaler.codeformer import CodeFormer
7
+ from upscaler.restoreformer import RestoreFormer
8
+
9
+ def gfpgan_runner(img, model):
10
+ img = model.enhance(img)
11
+ return img
12
+
13
+
14
+ def codeformer_runner(img, model):
15
+ img = model.enhance(img, w=0.9)
16
+ return img
17
+
18
+
19
+ def gpen_runner(img, model):
20
+ img = model.enhance(img)
21
+ return img
22
+
23
+
24
+ def restoreformer_runner(img, model):
25
+ img = model.enhance(img)
26
+ return img
27
+
28
+
29
+ supported_upscalers = {
30
+ "CodeFormer": (dp.CODEFORMER_PATH, codeformer_runner),
31
+ "GFPGANv1.4": (dp.GFPGAN_V14_PATH, gfpgan_runner),
32
+ "GFPGANv1.3": (dp.GFPGAN_V13_PATH, gfpgan_runner),
33
+ "GFPGANv1.2": (dp.GFPGAN_V12_PATH, gfpgan_runner),
34
+ "GPEN-BFR-512": (dp.GPEN_BFR_512_PATH, gpen_runner),
35
+ "GPEN-BFR-256": (dp.GPEN_BFR_256_PATH, gpen_runner),
36
+ "RestoreFormer": (dp.RESTOREFORMER_PATH, gpen_runner),
37
+ }
38
+
39
+ cv2_upscalers = ["LANCZOS4", "CUBIC", "NEAREST"]
40
+
41
+ def get_available_upscalers_names():
42
+ available = []
43
+ for name, data in supported_upscalers.items():
44
+ if os.path.exists(data[0]):
45
+ available.append(name)
46
+ return available
47
+
48
+
49
+ def load_face_upscaler(name='GFPGAN', provider=["CPUExecutionProvider"], session_options=None):
50
+ assert name in get_available_upscalers_names() + cv2_upscalers, f"Face upscaler {name} unavailable."
51
+ if name in supported_upscalers.keys():
52
+ model_path, model_runner = supported_upscalers.get(name)
53
+ if name == 'CodeFormer':
54
+ model = CodeFormer(model_path=model_path, provider=provider, session_options=session_options)
55
+ elif name.startswith('GFPGAN'):
56
+ model = GFPGAN(model_path=model_path, provider=provider, session_options=session_options)
57
+ elif name.startswith('GPEN'):
58
+ model = GPEN(model_path=model_path, provider=provider, session_options=session_options)
59
+ elif name == "RestoreFormer":
60
+ model = RestoreFormer(model_path=model_path, provider=provider, session_options=session_options)
61
+ elif name == 'LANCZOS4':
62
+ model = None
63
+ model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_LANCZOS4)
64
+ elif name == 'CUBIC':
65
+ model = None
66
+ model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_CUBIC)
67
+ elif name == 'NEAREST':
68
+ model = None
69
+ model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_NEAREST)
70
+ else:
71
+ model = None
72
+ return (model, model_runner)
global_variables.py ADDED
@@ -0,0 +1,36 @@
1
+ import os
2
+ from face_parsing import mask_regions
3
+ from utils.image import resolution_map
4
+ from face_upscaler import get_available_upscalers_names, cv2_upscalers
5
+ from face_analyser import single_face_detect_conditions, face_detect_conditions
6
+
7
+ DEFAULT_OUTPUT_PATH = os.getcwd()
8
+
9
+ MASK_BLUR_AMOUNT = 0.1
10
+ MASK_ERODE_AMOUNT = 0.15
11
+ MASK_REGIONS_DEFAULT = ["Skin", "R-Eyebrow", "L-Eyebrow", "L-Eye", "R-Eye", "Nose", "Mouth", "L-Lip", "U-Lip"]
12
+ MASK_REGIONS = list(mask_regions.keys())
13
+
14
+ NSFW_DETECTOR = None
15
+
16
+ FACE_ENHANCER_LIST = ["NONE"]
17
+ FACE_ENHANCER_LIST.extend(get_available_upscalers_names())
18
+ FACE_ENHANCER_LIST.extend(cv2_upscalers)
19
+
20
+ RESOLUTIONS = list(resolution_map.keys())
21
+
22
+ SINGLE_FACE_DETECT_CONDITIONS = single_face_detect_conditions
23
+ FACE_DETECT_CONDITIONS = face_detect_conditions
24
+ DETECT_CONDITION = "best detection"
25
+ DETECT_SIZE = 640
26
+ DETECT_THRESHOLD = 0.6
27
+
28
+ NUM_OF_SRC_SPECIFIC = 10
29
+
30
+ MAX_THREADS = 2
31
+
32
+ VIDEO_QUALITY_LIST = ["poor", "low", "medium", "high", "best"]
33
+ VIDEO_QUALITY = "high"
34
+
35
+ AVERAGING_METHODS = ["mean", "median"]
36
+ AVERAGING_METHOD = "mean"
nsfw_checker/LICENSE.md ADDED
@@ -0,0 +1,11 @@
1
+
2
+ Copyright 2016, Yahoo Inc.
3
+
4
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5
+
6
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7
+
8
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9
+
10
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
11
+
nsfw_checker/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . opennsfw import NSFWChecker
nsfw_checker/opennsfw.py ADDED
@@ -0,0 +1,65 @@
1
+ import cv2
2
+ import onnx
3
+ import onnxruntime
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ # https://github.com/yahoo/open_nsfw
8
+
9
+ def prepare_image(img):
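+ # open_nsfw expects a 224x224 BGR image with the Caffe-style channel means (104, 117, 123) subtracted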
10
+ img = cv2.resize(img, (224,224)).astype('float32')
11
+ img -= np.array([104, 117, 123], dtype=np.float32)
12
+ img = np.expand_dims(img, axis=0)
13
+ return img
14
+
15
+ class NSFWChecker:
16
+ def __init__(self, model_path=None, provider=["CPUExecutionProvider"], session_options=None):
17
+ model = onnx.load(model_path)
18
+ self.input_name = model.graph.input[0].name
19
+ self.session_options = session_options
20
+ if self.session_options is None:
21
+ self.session_options = onnxruntime.SessionOptions()
22
+ self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
23
+
24
+ def check_image(self, image, threshold=0.9):
25
+ if isinstance(image, str):
26
+ image = cv2.imread(image)
27
+ img = prepare_image(image)
28
+ score = self.session.run(None, {self.input_name:img})[0][0][1]
29
+ if score >= threshold:
30
+ return True
31
+ return False
32
+
33
+ def check_video(self, video_path, threshold=0.9, max_frames=100):
34
+ cap = cv2.VideoCapture(video_path)
35
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
36
+
37
+ max_frames = min(total_frames, max_frames)
38
+ indexes = np.arange(total_frames, dtype=int)
39
+ shuffled_indexes = np.random.permutation(indexes)[:max_frames]
40
+
41
+ for idx in tqdm(shuffled_indexes, desc="Checking"):
42
+ cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
43
+ valid_frame, frame = cap.read()
44
+ if valid_frame:
45
+ img = prepare_image(frame)
46
+ score = self.session.run(None, {self.input_name:img})[0][0][1]
47
+ if score >= threshold:
48
+ cap.release()
49
+ return True
50
+ cap.release()
51
+ return False
52
+
53
+ def check_image_paths(self, image_paths, threshold=0.9, max_frames=100):
54
+ total_frames = len(image_paths)
55
+ max_frames = min(total_frames, max_frames)
56
+ indexes = np.arange(total_frames, dtype=int)
57
+ shuffled_indexes = np.random.permutation(indexes)[:max_frames]
58
+
59
+ for idx in tqdm(shuffled_indexes, desc="Checking"):
60
+ frame = cv2.imread(image_paths[idx])
61
+ img = prepare_image(frame)
62
+ score = self.session.run(None, {self.input_name:img})[0][0][1]
63
+ if score >= threshold:
64
+ return True
65
+ return False
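
`NSFWChecker` loads the open_nsfw ONNX graph once and scores resized 224x224 BGR frames against a threshold. A minimal usage sketch follows; the model path mirrors the location the Colab cell further down downloads to, and the sample file names are placeholders.

```python
# Minimal sketch: gate an image and a video before swapping.
# Assumptions: the open-nsfw model sits at assets/pretrained_models/open-nsfw.onnx
# and "target.jpg"/"target.mp4" exist on disk.
from nsfw_checker import NSFWChecker

checker = NSFWChecker(
    model_path="assets/pretrained_models/open-nsfw.onnx",
    provider=["CPUExecutionProvider"],
)

if checker.check_image("target.jpg", threshold=0.9):
    print("NSFW image detected, skipping.")

if checker.check_video("target.mp4", threshold=0.9, max_frames=100):
    print("NSFW frames detected, skipping.")
```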
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ gradio>=3.40
2
+ numpy>=1.25.2
3
+ opencv-python>=4.7.0.72
4
+ opencv-python-headless>=4.7.0.72
5
+ onnx==1.14.0
6
+ onnxruntime==1.15.0
swap_mukham.py ADDED
@@ -0,0 +1,195 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ import default_paths as dp
5
+ from utils.device import get_device_and_provider
6
+ from utils.face_alignment import get_cropped_head
7
+ from utils.image import paste_to_whole, mix_two_image
8
+
9
+ from face_swapper import Inswapper
10
+ from face_parsing import FaceParser
11
+ from face_upscaler import get_available_upscalers_names, cv2_upscalers, load_face_upscaler
12
+ from face_analyser import AnalyseFace, single_face_detect_conditions, face_detect_conditions, get_single_face, is_similar_face
13
+
14
+ from nsfw_checker import NSFWChecker
15
+
16
+ get_device_name = lambda x: x.lower().replace("executionprovider", "")
17
+
18
+ class SwapMukham:
19
+ def __init__(self, device='cpu'):
20
+ self.load_nsfw_detector(device=device)
21
+ self.load_face_swapper(device=device)
22
+ self.load_face_analyser(device=device)
23
+ # self.load_face_parser(device=device)
24
+ # self.load_face_upscaler(device=device)
25
+
26
+ self.face_parser = None
27
+ self.face_upscaler = None
28
+ self.face_upscaler_name = ""
29
+
30
+ def set_values(self, args):
31
+ self.age = args.get('age', 0)
32
+ self.detect_condition = args.get('detect_condition', "left most")
33
+ self.similarity = args.get('similarity', 0.6)
34
+ self.swap_condition = args.get('swap_condition', 'left most')
35
+ self.face_scale = args.get('face_scale', 1.0)
36
+ self.num_of_pass = args.get('num_of_pass', 1)
37
+ self.mask_crop_values = args.get('mask_crop_values', (0,0,0,0))
38
+ self.mask_erode_amount = args.get('mask_erode_amount', 0.1)
39
+ self.mask_blur_amount = args.get('mask_blur_amount', 0.1)
40
+ self.use_laplacian_blending = args.get('use_laplacian_blending', False)
41
+ self.use_face_parsing = args.get('use_face_parsing', False)
42
+ self.face_parse_regions = args.get('face_parse_regions', [1,2,3,4,5,10,11,12,13])
43
+ self.face_upscaler_opacity = args.get('face_upscaler_opacity', 1.)
44
+ self.parse_from_target = args.get('parse_from_target', False)
45
+ self.averaging_method = args.get('averaging_method', 'mean')
46
+
47
+ self.analyser.detection_threshold = args.get('face_detection_threshold', 0.5)
48
+ self.analyser.detection_size = args.get('face_detection_size', (640, 640))
49
+ self.analyser.detect_condition = args.get('face_detection_condition', 'best detection')
50
+
51
+ def load_nsfw_detector(self, device='cpu'):
52
+ device, provider, options = get_device_and_provider(device=device)
53
+ self.nsfw_detector = NSFWChecker(model_path=dp.OPEN_NSFW_PATH, provider=provider, session_options=options)
54
+ _device = get_device_name(self.nsfw_detector.session.get_providers()[0])
55
+ print(f"[{_device}] NSFW detector model loaded.")
56
+
57
+ def load_face_swapper(self, device='cpu'):
58
+ device, provider, options = get_device_and_provider(device=device)
59
+ self.swapper = Inswapper(model_file=dp.INSWAPPER_PATH, provider=provider, session_options=options)
60
+ _device = get_device_name(self.swapper.session.get_providers()[0])
61
+ print(f"[{_device}] Face swapper model loaded.")
62
+
63
+ def load_face_analyser(self, device='cpu'):
64
+ device, provider, options = get_device_and_provider(device=device)
65
+ self.analyser = AnalyseFace(provider=provider, session_options=options)
66
+ _device_d = get_device_name(self.analyser.detector.session.get_providers()[0])
67
+ print(f"[{_device_d}] Face detection model loaded.")
68
+ _device_r = get_device_name(self.analyser.recognizer.session.get_providers()[0])
69
+ print(f"[{_device_r}] Face recognition model loaded.")
70
+ _device_g = get_device_name(self.analyser.gender_age.session.get_providers()[0])
71
+ print(f"[{_device_g}] Gender & Age detection model loaded.")
72
+
73
+ def load_face_parser(self, device='cpu'):
74
+ device, provider, options = get_device_and_provider(device=device)
75
+ self.face_parser = FaceParser(model_path=dp.FACE_PARSER_PATH, provider=provider, session_options=options)
76
+ _device = get_device_name(self.face_parser.session.get_providers()[0])
77
+ print(f"[{_device}] Face parsing model loaded.")
78
+
79
+ def load_face_upscaler(self, name, device='cpu'):
80
+ device, provider, options = get_device_and_provider(device=device)
81
+ if name in get_available_upscalers_names():
82
+ self.face_upscaler = load_face_upscaler(name=name, provider=provider, session_options=options)
83
+ self.face_upscaler_name = name
84
+ _device = get_device_name(self.face_upscaler[0].session.get_providers()[0])
85
+ print(f"[{_device}] Face upscaler model ({name}) loaded.")
86
+ else:
87
+ self.face_upscaler_name = ""
88
+ self.face_upscaler = None
89
+
90
+ def collect_heads(self, frame):
91
+ faces = self.analyser.get_faces(frame, skip_task=['embedding', 'gender_age'])
92
+ return [get_cropped_head(frame, face.kps) for face in faces if face["det_score"] > 0.5]
93
+
94
+ def analyse_source_faces(self, source_specific):
95
+ analysed_source_specific = []
96
+ for i, (source, specific) in enumerate(source_specific):
97
+ if source is not None:
98
+ analysed_source = self.analyser.get_averaged_face(source, method=self.averaging_method)
99
+ if specific is not None:
100
+ analysed_specific = self.analyser.get_face(specific)
101
+ else:
102
+ analysed_specific = None
103
+ analysed_source_specific.append((analysed_source, analysed_specific))
104
+ self.analysed_source_specific = analysed_source_specific
105
+
106
+ def process_frame(self, data):
107
+ frame, custom_mask = data
108
+
109
+ if len(frame.shape) == 2 or (len(frame.shape) == 3 and frame.shape[2] == 1):
110
+ frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
111
+
112
+ alpha = None
113
+ if frame.shape[2] == 4:
114
+ alpha = frame[:, :, 3]
115
+ frame = frame[:, :, :3]
116
+
117
+ _frame = frame.copy()
118
+ condition = self.swap_condition
119
+
120
+ skip_task = []
121
+ if condition != "specific face":
122
+ skip_task.append('embedding')
123
+ if condition not in ['age less than', 'age greater than', 'all male', 'all female']:
124
+ skip_task.append('gender_age')
125
+
126
+ analysed_target_faces = self.analyser.get_faces(frame, scale=self.face_scale, skip_task=skip_task)
127
+
128
+ for analysed_target in analysed_target_faces:
129
+ if (condition == "all face" or
130
+ (condition == "age less than" and analysed_target["age"] <= self.age) or
131
+ (condition == "age greater than" and analysed_target["age"] > self.age) or
132
+ (condition == "all male" and analysed_target["gender"] == 1) or
133
+ (condition == "all female" and analysed_target["gender"] == 0)):
134
+
135
+ trg_face = analysed_target
136
+ src_face = self.analysed_source_specific[0][0]
137
+ _frame = self.swap_face(_frame, trg_face, src_face)
138
+
139
+ elif condition == "specific face":
140
+ for analysed_source, analysed_specific in self.analysed_source_specific:
141
+ if is_similar_face(analysed_specific, analysed_target, threshold=self.similarity):
142
+ trg_face = analysed_target
143
+ src_face = analysed_source
144
+ _frame = self.swap_face(_frame, trg_face, src_face)
145
+
146
+ if condition in single_face_detect_conditions and len(analysed_target_faces) > 0:
147
+ analysed_target = get_single_face(analysed_target_faces, method=condition)
148
+ trg_face = analysed_target
149
+ src_face = self.analysed_source_specific[0][0]
150
+ _frame = self.swap_face(_frame, trg_face, src_face)
151
+
152
+ if custom_mask is not None:
153
+ _mask = cv2.resize(custom_mask, _frame.shape[:2][::-1])
154
+ _frame = _mask * frame.astype('float32') + (1 - _mask) * _frame.astype('float32')
155
+ _frame = _frame.clip(0,255).astype('uint8')
156
+
157
+ if alpha is not None:
158
+ _frame = np.dstack((_frame, alpha))
159
+
160
+ return _frame
161
+
162
+ def swap_face(self, frame, trg_face, src_face):
163
+ target_face, generated_face, matrix = self.swapper.forward(frame, trg_face, src_face, n_pass=self.num_of_pass)
164
+ upscaled_face, matrix = self.upscale_face(generated_face, matrix)
165
+ if self.parse_from_target:
166
+ mask = self.face_parsed_mask(target_face)
167
+ else:
168
+ mask = self.face_parsed_mask(upscaled_face)
169
+ result = paste_to_whole(
170
+ upscaled_face,
171
+ frame,
172
+ matrix,
173
+ mask=mask,
174
+ crop_mask=self.mask_crop_values,
175
+ blur_amount=self.mask_blur_amount,
176
+ erode_amount = self.mask_erode_amount
177
+ )
178
+ return result
179
+
180
+ def upscale_face(self, face, matrix):
181
+ face_size = face.shape[0]
182
+ _face = cv2.resize(face, (512,512))
183
+ if self.face_upscaler is not None:
184
+ model, runner = self.face_upscaler
185
+ face = runner(face, model)
186
+ upscaled_face = cv2.resize(face, (512,512))
187
+ upscaled_face = mix_two_image(_face, upscaled_face, self.face_upscaler_opacity)
188
+ return upscaled_face, matrix * (512/face_size)
189
+
190
+ def face_parsed_mask(self, face):
191
+ if self.face_parser is not None and self.use_face_parsing:
192
+ mask = self.face_parser.parse(face, regions=self.face_parse_regions)
193
+ else:
194
+ mask = None
195
+ return mask
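
`SwapMukham` wires the analyser, swapper, and optional parser/upscaler into a per-frame pipeline: `set_values` configures the run, `analyse_source_faces` caches source embeddings, and `process_frame` performs the swap. The sketch below is a rough single-image walkthrough, not the exact `app.py` flow; the shape passed to `analyse_source_faces` (a single source image per pair) and the file names are assumptions.

```python
# Rough single-image sketch of the SwapMukham flow (not the exact app.py code path).
# Assumptions: the models resolved via default_paths are present, and
# get_averaged_face accepts a single source image here; file names are placeholders.
import cv2
from swap_mukham import SwapMukham

sm = SwapMukham(device="cpu")
sm.set_values({
    "swap_condition": "all face",   # swap every detected face in the frame
    "face_scale": 1.0,
    "num_of_pass": 1,
    "averaging_method": "mean",
})

source = cv2.imread("source_face.jpg")
frame = cv2.imread("target_frame.jpg")

# (source, specific) pairs; "specific" is only consulted for the
# "specific face" condition, so it can stay None here.
sm.analyse_source_faces([(source, None)])

swapped = sm.process_frame((frame, None))  # second element is an optional custom mask
cv2.imwrite("swapped.jpg", swapped)
```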
swap_mukham_colab.ipynb ADDED
@@ -0,0 +1,183 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/harisreedhar/Swap-Mukham/blob/main/swap_mukham_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "bypvIQG5RHl9"
17
+ },
18
+ "source": [
19
+ "# 🗿 **Swap-Mukham**\n",
20
+ "*Face swap app based on insightface inswapper.*\n",
21
+ "- [Github](https://github.com/harisreedhar/Swap-Mukham)\n",
22
+ "- [Disclaimer](https://github.com/harisreedhar/Swap-Mukham#disclaimer)"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "metadata": {
28
+ "id": "csC_DX5zWLEU"
29
+ },
30
+ "source": [
31
+ "# Clone Repository"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {
38
+ "id": "klcx2cKDKX5x"
39
+ },
40
+ "outputs": [],
41
+ "source": [
42
+ "#@title\n",
43
+ "! git clone https://github.com/harisreedhar/Swap-Mukham"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "metadata": {
49
+ "id": "bebBDddfWTXf"
50
+ },
51
+ "source": [
52
+ "# Install Requirements"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "metadata": {
59
+ "id": "VgTpg7EsTN3o"
60
+ },
61
+ "outputs": [],
62
+ "source": [
63
+ "#@title\n",
64
+ "%cd Swap-Mukham/\n",
65
+ "print(\"Installing requirements...\")\n",
66
+ "!pip install -r requirements.txt -q\n",
67
+ "!pip install gdown\n",
68
+ "print(\"Installing requirements done.\")"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "metadata": {
74
+ "id": "T9L6tgD0Wats"
75
+ },
76
+ "source": [
77
+ "# Download Models"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {
84
+ "id": "17MZO9OvUQAk"
85
+ },
86
+ "outputs": [],
87
+ "source": [
88
+ "#@title\n",
89
+ "inswapper_model = \"https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx\" #@param {type:\"string\"}\n",
90
+ "gfpgan_model = \"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth\" #@param {type:\"string\"}\n",
91
+ "face_parser_model = \"https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812\" #@param {type:\"string\"}\n",
92
+ "real_esrgan_2x_model = \"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth\" #@param {type:\"string\"}\n",
93
+ "real_esrgan_4x_model = \"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth\" #@param {type:\"string\"}\n",
94
+ "real_esrgan_8x_model = \"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x8.pth\" #@param {type:\"string\"}\n",
95
+ "codeformer_model = \"https://huggingface.co/bluefoxcreation/Codeformer-ONNX/resolve/main/codeformer.onnx\" #@param {type:\"string\"}\n",
96
+ "nsfw_det_model = \"https://huggingface.co/bluefoxcreation/open-nsfw/resolve/main/open-nsfw.onnx\" #@param {type:\"string\"}\n",
97
+ "import gdown\n",
98
+ "import urllib.request\n",
99
+ "print(\"Downloading swapper model...\")\n",
100
+ "urllib.request.urlretrieve(inswapper_model, \"/content/Swap-Mukham/assets/pretrained_models/inswapper_128.onnx\")\n",
101
+ "print(\"Downloading gfpgan model...\")\n",
102
+ "urllib.request.urlretrieve(gfpgan_model, \"/content/Swap-Mukham/assets/pretrained_models/GFPGANv1.4.pth\")\n",
103
+ "print(\"Downloading face parsing model...\")\n",
104
+ "gdown.download(face_parser_model, \"/content/Swap-Mukham/assets/pretrained_models/79999_iter.pth\")\n",
105
+ "print(\"Downloading realesrgan 2x model...\")\n",
106
+ "urllib.request.urlretrieve(real_esrgan_2x_model, \"/content/Swap-Mukham/assets/pretrained_models/RealESRGAN_x2.pth\")\n",
107
+ "print(\"Downloading realesrgan 4x model...\")\n",
108
+ "urllib.request.urlretrieve(real_esrgan_4x_model, \"/content/Swap-Mukham/assets/pretrained_models/RealESRGAN_x4.pth\")\n",
109
+ "print(\"Downloading realesrgan 8x model...\")\n",
110
+ "urllib.request.urlretrieve(real_esrgan_8x_model, \"/content/Swap-Mukham/assets/pretrained_models/RealESRGAN_x8.pth\")\n",
111
+ "print(\"Downloading codeformer...\")\n",
112
+ "urllib.request.urlretrieve(codeformer_model, \"/content/Swap-Mukham/assets/pretrained_models/codeformer.onnx\")\n",
113
+ "print(\"Downloading NSFW detector model...\")\n",
114
+ "urllib.request.urlretrieve(nsfw_det_model, \"/content/Swap-Mukham/assets/pretrained_models/open-nsfw.onnx\")\n",
115
+ "print(\"Downloading models done.\")"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "markdown",
120
+ "metadata": {
121
+ "id": "uEcCUw0Co6bE"
122
+ },
123
+ "source": [
124
+ "# Mount Google drive (optional)"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": null,
130
+ "metadata": {
131
+ "id": "4KssYYippDMw"
132
+ },
133
+ "outputs": [],
134
+ "source": [
135
+ "from google.colab import auth, drive\n",
136
+ "auth.authenticate_user()\n",
137
+ "drive.mount('/content/drive')"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "metadata": {
143
+ "id": "-Tn68Ayqdrlk"
144
+ },
145
+ "source": [
146
+ "# Run App\n",
147
+ "\n"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": null,
153
+ "metadata": {
154
+ "id": "6dpBjbfVOrrc"
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "#@title\n",
159
+ "default_output_path = \"/content/Swap-Mukham\" #@param {type:\"string\"}\n",
160
+ "\n",
161
+ "command = f\"python app.py --cuda --colab --out_dir {default_output_path}\"\n",
162
+ "!{command}"
163
+ ]
164
+ }
165
+ ],
166
+ "metadata": {
167
+ "accelerator": "GPU",
168
+ "colab": {
169
+ "gpuType": "T4",
170
+ "include_colab_link": true,
171
+ "provenance": []
172
+ },
173
+ "kernelspec": {
174
+ "display_name": "Python 3",
175
+ "name": "python3"
176
+ },
177
+ "language_info": {
178
+ "name": "python"
179
+ }
180
+ },
181
+ "nbformat": 4,
182
+ "nbformat_minor": 0
183
+ }
upscaler/GFPGAN.py ADDED
@@ -0,0 +1,41 @@
1
+ import cv2
2
+ import torch
3
+ import onnxruntime
4
+ import numpy as np
5
+ import threading
6
+ import time
7
+
8
+ # gfpgan converted to onnx
9
+ # using https://github.com/xuanandsix/GFPGAN-onnxruntime-demo
10
+ # same inference code for GFPGANv1.2, GFPGANv1.3, GFPGANv1.4
11
+
12
+ lock = threading.Lock()
13
+
14
+ class GFPGAN:
15
+ def __init__(self, model_path="GFPGANv1.4.onnx", provider=["CPUExecutionProvider"], session_options=None):
16
+ self.session_options = session_options
17
+ if self.session_options is None:
18
+ self.session_options = onnxruntime.SessionOptions()
19
+ self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
20
+ self.resolution = self.session.get_inputs()[0].shape[-2:]
21
+
22
+ def preprocess(self, img):
23
+ img = cv2.resize(img, self.resolution, interpolation=cv2.INTER_LINEAR)
24
+ img = img.astype(np.float32)[:,:,::-1] / 255.0
25
+ img = img.transpose((2, 0, 1))
26
+ img = (img - 0.5) / 0.5
27
+ img = np.expand_dims(img, axis=0).astype(np.float32)
28
+ return img
29
+
30
+ def postprocess(self, img):
31
+ img = (img.transpose(1,2,0).clip(-1,1) + 1) * 0.5
32
+ img = (img * 255)[:,:,::-1]
33
+ img = img.clip(0, 255).astype('uint8')
34
+ return img
35
+
36
+ def enhance(self, img):
37
+ img = self.preprocess(img)
38
+ with lock:
39
+ output = self.session.run(None, {'input':img})[0][0]
40
+ output = self.postprocess(output)
41
+ return output
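
A short sketch of how this restorer is typically driven on a cropped face; the ONNX model path is an assumption (the app resolves the real location itself), and the output comes back at the model's own resolution.

```python
# Hedged sketch: restore a single cropped face with the GFPGAN ONNX model.
# The model path is an assumption; any GFPGANv1.2/1.3/1.4 ONNX export works here.
import cv2
from upscaler.GFPGAN import GFPGAN

restorer = GFPGAN(
    model_path="assets/pretrained_models/GFPGANv1.4.onnx",
    provider=["CPUExecutionProvider"],
)

face = cv2.imread("cropped_face.jpg")      # BGR crop of any size
restored = restorer.enhance(face)          # returned at the model resolution (e.g. 512x512)
cv2.imwrite("restored_face.jpg", restored)
```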
upscaler/GPEN.py ADDED
@@ -0,0 +1,37 @@
1
+ import cv2
2
+ import torch
3
+ import onnxruntime
4
+ import numpy as np
5
+ import threading
6
+ import time
7
+
8
+ lock = threading.Lock()
9
+
10
+ class GPEN:
11
+ def __init__(self, model_path="GPEN-BFR-512.onnx", provider=["CPUExecutionProvider"], session_options=None):
12
+ self.session_options = session_options
13
+ if self.session_options is None:
14
+ self.session_options = onnxruntime.SessionOptions()
15
+ self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
16
+ self.resolution = self.session.get_inputs()[0].shape[-2:]
17
+
18
+ def preprocess(self, img):
19
+ img = cv2.resize(img, self.resolution, interpolation=cv2.INTER_LINEAR)
20
+ img = img.astype(np.float32)[:,:,::-1] / 255.0
21
+ img = img.transpose((2, 0, 1))
22
+ img = (img - 0.5) / 0.5
23
+ img = np.expand_dims(img, axis=0).astype(np.float32)
24
+ return img
25
+
26
+ def postprocess(self, img):
27
+ img = (img.transpose(1,2,0).clip(-1,1) + 1) * 0.5
28
+ img = (img * 255)[:,:,::-1]
29
+ img = img.clip(0, 255).astype('uint8')
30
+ return img
31
+
32
+ def enhance(self, img):
33
+ img = self.preprocess(img)
34
+ with lock:
35
+ output = self.session.run(None, {'input':img})[0][0]
36
+ output = self.postprocess(output)
37
+ return output
upscaler/__init__.py ADDED
File without changes
upscaler/codeformer.py ADDED
@@ -0,0 +1,41 @@
1
+ import cv2
2
+ import torch
3
+ import onnxruntime
4
+ import numpy as np
5
+ import threading
6
+ import time
7
+
8
+ # codeformer converted to onnx
9
+ # using https://github.com/redthing1/CodeFormer
10
+
11
+ lock = threading.Lock()
12
+
13
+ class CodeFormer:
14
+ def __init__(self, model_path="codeformer.onnx", provider=["CPUExecutionProvider"], session_options=None):
15
+ self.session_options = session_options
16
+ if self.session_options is None:
17
+ self.session_options = onnxruntime.SessionOptions()
18
+ self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
19
+ self.resolution = self.session.get_inputs()[0].shape[-2:]
20
+
21
+ def preprocess(self, img, w):
22
+ img = cv2.resize(img, self.resolution, interpolation=cv2.INTER_LINEAR)
23
+ img = img.astype(np.float32)[:,:,::-1] / 255.0
24
+ img = img.transpose((2, 0, 1))
25
+ img = (img - 0.5) / 0.5
26
+ img = np.expand_dims(img, axis=0).astype(np.float32)
27
+ w = np.array([w], dtype=np.double)
28
+ return img, w
29
+
30
+ def postprocess(self, img):
31
+ img = (img.transpose(1,2,0).clip(-1,1) + 1) * 0.5
32
+ img = (img * 255)[:,:,::-1]
33
+ img = img.clip(0, 255).astype('uint8')
34
+ return img
35
+
36
+ def enhance(self, img, w=0.9):
37
+ img, w = self.preprocess(img, w)
38
+ with lock:
39
+ output = self.session.run(None, {'x':img, 'w':w})[0][0]
40
+ output = self.postprocess(output)
41
+ return output
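
CodeFormer takes an extra fidelity weight `w` alongside the image; in the upstream CodeFormer convention, values nearer 1.0 lean towards identity fidelity while lower values favour restoration quality. A small hedged sketch (the model path is an assumption):

```python
# Hedged sketch: CodeFormer restoration with an explicit fidelity weight.
# Assumption: the ONNX export lives at this path.
import cv2
from upscaler.codeformer import CodeFormer

cf = CodeFormer(model_path="assets/pretrained_models/codeformer.onnx",
                provider=["CPUExecutionProvider"])

face = cv2.imread("cropped_face.jpg")
restored = cf.enhance(face, w=0.7)   # w closer to 1.0 preserves identity more strongly
cv2.imwrite("restored_face.jpg", restored)
```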
upscaler/restoreformer.py ADDED
@@ -0,0 +1,37 @@
1
+ import cv2
2
+ import torch
3
+ import onnxruntime
4
+ import numpy as np
5
+ import threading
6
+ import time
7
+
8
+ lock = threading.Lock()
9
+
10
+ class RestoreFormer:
11
+ def __init__(self, model_path="restoreformer.onnx", provider=["CPUExecutionProvider"], session_options=None):
12
+ self.session_options = session_options
13
+ if self.session_options is None:
14
+ self.session_options = onnxruntime.SessionOptions()
15
+ self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
16
+ self.resolution = self.session.get_inputs()[0].shape[-2:]
17
+
18
+ def preprocess(self, img):
19
+ img = cv2.resize(img, self.resolution, interpolation=cv2.INTER_LINEAR)
20
+ img = img.astype(np.float32)[:,:,::-1] / 255.0
21
+ img = img.transpose((2, 0, 1))
22
+ img = (img - 0.5) / 0.5
23
+ img = np.expand_dims(img, axis=0).astype(np.float32)
24
+ return img
25
+
26
+ def postprocess(self, img):
27
+ img = (img.transpose(1,2,0).clip(-1,1) + 1) * 0.5
28
+ img = (img * 255)[:,:,::-1]
29
+ img = img.clip(0, 255).astype('uint8')
30
+ return img
31
+
32
+ def enhance(self, img):
33
+ img = self.preprocess(img)
34
+ with lock:
35
+ output = self.session.run(None, {'input':img})[0][0]
36
+ output = self.postprocess(output)
37
+ return output
utils/__init__.py ADDED
File without changes
utils/arcface.py ADDED
@@ -0,0 +1,89 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Organization : insightface.ai
3
+ # @Author : Jia Guo
4
+ # @Time : 2021-09-18
5
+ # @Function :
6
+
7
+
8
+ import os
9
+ import cv2
10
+ import onnx
11
+ import onnxruntime
12
+ import numpy as np
13
+ import default_paths as dp
14
+ from .face_alignment import norm_crop2
15
+
16
+
17
+ class ArcFace:
18
+ def __init__(self, model_file=None, provider=['CUDAExecutionProvider'], session_options=None):
19
+ assert model_file is not None
20
+ self.model_file = model_file
21
+ self.taskname = 'recognition'
22
+ find_sub = False
23
+ find_mul = False
24
+ model = onnx.load(self.model_file)
25
+ graph = model.graph
26
+ for nid, node in enumerate(graph.node[:8]):
27
+ #print(nid, node.name)
28
+ if node.name.startswith('Sub') or node.name.startswith('_minus'):
29
+ find_sub = True
30
+ if node.name.startswith('Mul') or node.name.startswith('_mul'):
31
+ find_mul = True
32
+ if find_sub and find_mul:
33
+ #mxnet arcface model
34
+ input_mean = 0.0
35
+ input_std = 1.0
36
+ else:
37
+ input_mean = 127.5
38
+ input_std = 127.5
39
+ self.input_mean = input_mean
40
+ self.input_std = input_std
41
+ #print('input mean and std:', self.input_mean, self.input_std)
42
+ self.session_options = session_options
43
+ if self.session_options is None:
44
+ self.session_options = onnxruntime.SessionOptions()
45
+ self.session = onnxruntime.InferenceSession(self.model_file, providers=provider, sess_options=self.session_options)
46
+ input_cfg = self.session.get_inputs()[0]
47
+ input_shape = input_cfg.shape
48
+ input_name = input_cfg.name
49
+ self.input_size = tuple(input_shape[2:4][::-1])
50
+ self.input_shape = input_shape
51
+ outputs = self.session.get_outputs()
52
+ output_names = []
53
+ for out in outputs:
54
+ output_names.append(out.name)
55
+ self.input_name = input_name
56
+ self.output_names = output_names
57
+ assert len(self.output_names)==1
58
+ self.output_shape = outputs[0].shape
59
+
60
+ def prepare(self, ctx_id, **kwargs):
61
+ if ctx_id<0:
62
+ self.session.set_providers(['CPUExecutionProvider'])
63
+
64
+ def get(self, img, kps):
65
+ aimg, matrix = norm_crop2(img, landmark=kps, image_size=self.input_size[0])
66
+ embedding = self.get_feat(aimg).flatten()
67
+ return embedding
68
+
69
+ def compute_sim(self, feat1, feat2):
70
+ from numpy.linalg import norm
71
+ feat1 = feat1.ravel()
72
+ feat2 = feat2.ravel()
73
+ sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
74
+ return sim
75
+
76
+ def get_feat(self, imgs):
77
+ if not isinstance(imgs, list):
78
+ imgs = [imgs]
79
+ input_size = self.input_size
80
+
81
+ blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
82
+ (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
83
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
84
+ return net_out
85
+
86
+ def forward(self, batch_data):
87
+ blob = (batch_data - self.input_mean) / self.input_std
88
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
89
+ return net_out
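
`compute_sim` is plain cosine similarity between two embeddings; elsewhere in the repo the result is compared against a user threshold (0.6 by default in `set_values`) to decide whether two faces match. A tiny standalone check with made-up vectors:

```python
# Standalone illustration of the cosine similarity used by compute_sim;
# the vectors are made up, real embeddings come from ArcFace.get(img, kps).
import numpy as np

a = np.array([1.0, 0.0, 1.0])
b = np.array([1.0, 0.0, 0.5])
sim = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(round(sim, 3))  # ~0.949; "same person" checks compare this against a threshold such as 0.6
```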
utils/device.py ADDED
@@ -0,0 +1,32 @@
1
+ import onnx
2
+ import onnxruntime
3
+
4
+ device_types_list = ["cpu", "cuda"]
5
+
6
+ available_providers = onnxruntime.get_available_providers()
7
+
8
+ def get_device_and_provider(device='cpu'):
9
+ options = onnxruntime.SessionOptions()
10
+ options.log_severity_level = 3
11
+ if device == 'cuda':
12
+ if "CUDAExecutionProvider" in available_providers:
13
+ provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"]
14
+ options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
15
+ else:
16
+ device = 'cpu'
17
+ provider = ["CPUExecutionProvider"]
18
+ else:
19
+ device = 'cpu'
20
+ provider = ["CPUExecutionProvider"]
21
+
22
+ return device, provider, options
23
+
24
+
25
+ data_type_bytes = {'uint8': 1, 'int8': 1, 'uint16': 2, 'int16': 2, 'float16': 2, 'float32': 4}
26
+
27
+
28
+ def estimate_max_batch_size(resolution, chunk_size=1024, data_type='float32', channels=3):
29
+ pixel_size = data_type_bytes.get(data_type, 1)
30
+ image_size = resolution[0] * resolution[1] * pixel_size * channels
31
+ number_of_batches = (chunk_size * 1024 * 1024) // image_size
32
+ return max(number_of_batches, 1)
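
Every loader in the project follows the same pattern: ask `get_device_and_provider` for a (device, provider, options) triple, then hand the provider list and session options to an ONNX Runtime session. A quick sketch:

```python
# Quick sketch: resolve the execution provider once and reuse it for every session.
from utils.device import get_device_and_provider

device, provider, options = get_device_and_provider(device="cuda")
print(device)    # "cuda" only if CUDAExecutionProvider is available, otherwise falls back to "cpu"
print(provider)  # e.g. [("CUDAExecutionProvider", {...}), "CPUExecutionProvider"] or ["CPUExecutionProvider"]
```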
utils/face_alignment.py ADDED
@@ -0,0 +1,73 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ def umeyama(src, dst, estimate_scale):
6
+ num = src.shape[0]
7
+ dim = src.shape[1]
8
+ src_mean = src.mean(axis=0)
9
+ dst_mean = dst.mean(axis=0)
10
+ src_demean = src - src_mean
11
+ dst_demean = dst - dst_mean
12
+ A = np.dot(dst_demean.T, src_demean) / num
13
+ d = np.ones((dim,), dtype=np.double)
14
+ if np.linalg.det(A) < 0:
15
+ d[dim - 1] = -1
16
+ T = np.eye(dim + 1, dtype=np.double)
17
+ U, S, V = np.linalg.svd(A)
18
+ rank = np.linalg.matrix_rank(A)
19
+ if rank == 0:
20
+ return np.nan * T
21
+ elif rank == dim - 1:
22
+ if np.linalg.det(U) * np.linalg.det(V) > 0:
23
+ T[:dim, :dim] = np.dot(U, V)
24
+ else:
25
+ s = d[dim - 1]
26
+ d[dim - 1] = -1
27
+ T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V))
28
+ d[dim - 1] = s
29
+ else:
30
+ T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V.T))
31
+ if estimate_scale:
32
+ scale = 1.0 / src_demean.var(axis=0).sum() * np.dot(S, d)
33
+ else:
34
+ scale = 1.0
35
+ T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T)
36
+ T[:dim, :dim] *= scale
37
+ return T
38
+
39
+
40
+ arcface_dst = np.array(
41
+ [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
42
+ [41.5493, 92.3655], [70.7299, 92.2041]],
43
+ dtype=np.float32)
44
+
45
+
46
+ def estimate_norm(lmk, image_size=112, mode='arcface'):
47
+ assert lmk.shape == (5, 2)
48
+ assert image_size % 112 == 0 or image_size % 128 == 0
49
+ if image_size % 112 == 0:
50
+ ratio = float(image_size) / 112.0
51
+ diff_x = 0
52
+ else:
53
+ ratio = float(image_size) / 128.0
54
+ diff_x = 8.0 * ratio
55
+ dst = arcface_dst * ratio
56
+ dst[:, 0] += diff_x
57
+ M = umeyama(lmk, dst, True)[0:2, :]
58
+ return M
59
+
60
+
61
+ def norm_crop2(img, landmark, image_size=112, mode='arcface'):
62
+ M = estimate_norm(landmark, image_size, mode)
63
+ warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0, borderMode=cv2.BORDER_REPLICATE)
64
+ return warped, M
65
+
66
+
67
+ def get_cropped_head(img, landmark, scale=1.4):
68
+ # it is ugly but works :D
69
+ center = np.mean(landmark, axis=0)
70
+ landmark = center + (landmark - center) * scale
71
+ M = estimate_norm(landmark, 128, mode='arcface')
72
+ warped = cv2.warpAffine(img, M/0.25, (512, 512), borderValue=0.0)
73
+ return warped
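
These helpers estimate a similarity transform from five landmarks (eyes, nose, mouth corners) to the ArcFace template and warp the face with it. A hedged sketch with made-up landmarks follows; real `kps` come from the RetinaFace detector further down.

```python
# Hedged sketch: align a face to the 112x112 ArcFace template from 5 landmarks.
# The frame and landmark values are made up purely to show the call shape.
import cv2
import numpy as np
from utils.face_alignment import norm_crop2, get_cropped_head

frame = np.zeros((480, 640, 3), dtype=np.uint8)          # stand-in frame
kps = np.array([[250., 210.], [330., 210.], [290., 260.],
                [262., 310.], [318., 310.]], dtype=np.float32)  # eyes, nose, mouth corners

aligned, M = norm_crop2(frame, kps, image_size=112)
head = get_cropped_head(frame, kps, scale=1.4)            # looser 512x512 crop for previews
print(aligned.shape, head.shape)                          # (112, 112, 3) (512, 512, 3)
```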
utils/gender_age.py ADDED
@@ -0,0 +1,25 @@
1
+ import cv2
2
+ import numpy as np
3
+ import onnxruntime
4
+ from .face_alignment import norm_crop2
5
+
6
+ class GenderAge:
7
+ def __init__(self, model_file=None, provider=['CPUExecutionProvider'], session_options=None):
8
+ self.model_file = model_file
9
+ self.session_options = session_options
10
+ if self.session_options is None:
11
+ self.session_options = onnxruntime.SessionOptions()
12
+ self.session = onnxruntime.InferenceSession(self.model_file, sess_options=self.session_options, providers=provider)
13
+
14
+ def predict(self, img, kps):
15
+ aimg, matrix = norm_crop2(img, kps, 128)
16
+
17
+ blob = cv2.resize(aimg, (62,62), interpolation=cv2.INTER_AREA)
18
+ blob = np.expand_dims(blob, axis=0).astype('float32')
19
+
20
+ _prob, _age = self.session.run(None, {'data':blob})
21
+ prob = _prob[0][0][0]
22
+ age = round(_age[0][0][0][0] * 100)
23
+ gender = np.argmax(prob)
24
+
25
+ return gender, age
utils/image.py ADDED
@@ -0,0 +1,252 @@
1
+ import cv2
2
+ import base64
3
+ import numpy as np
4
+
5
+
6
+ def laplacian_blending(A, B, m, num_levels=7):
7
+ assert A.shape == B.shape
8
+ assert B.shape == m.shape
9
+ height = m.shape[0]
10
+ width = m.shape[1]
11
+ size_list = np.array([4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192])
12
+ size = size_list[np.where(size_list > max(height, width))][0]
13
+ GA = np.zeros((size, size, 3), dtype=np.float32)
14
+ GA[:height, :width, :] = A
15
+ GB = np.zeros((size, size, 3), dtype=np.float32)
16
+ GB[:height, :width, :] = B
17
+ GM = np.zeros((size, size, 3), dtype=np.float32)
18
+ GM[:height, :width, :] = m
19
+ gpA = [GA]
20
+ gpB = [GB]
21
+ gpM = [GM]
22
+ for i in range(num_levels):
23
+ GA = cv2.pyrDown(GA)
24
+ GB = cv2.pyrDown(GB)
25
+ GM = cv2.pyrDown(GM)
26
+ gpA.append(np.float32(GA))
27
+ gpB.append(np.float32(GB))
28
+ gpM.append(np.float32(GM))
29
+ lpA = [gpA[num_levels-1]]
30
+ lpB = [gpB[num_levels-1]]
31
+ gpMr = [gpM[num_levels-1]]
32
+ for i in range(num_levels-1,0,-1):
33
+ LA = np.subtract(gpA[i-1], cv2.pyrUp(gpA[i]))
34
+ LB = np.subtract(gpB[i-1], cv2.pyrUp(gpB[i]))
35
+ lpA.append(LA)
36
+ lpB.append(LB)
37
+ gpMr.append(gpM[i-1])
38
+ LS = []
39
+ for la,lb,gm in zip(lpA,lpB,gpMr):
40
+ ls = la * gm + lb * (1.0 - gm)
41
+ LS.append(ls)
42
+ ls_ = LS[0]
43
+ for i in range(1,num_levels):
44
+ ls_ = cv2.pyrUp(ls_)
45
+ ls_ = cv2.add(ls_, LS[i])
46
+ ls_ = ls_[:height, :width, :]
47
+ #ls_ = (ls_ - np.min(ls_)) * (255.0 / (np.max(ls_) - np.min(ls_)))
48
+ return ls_.clip(0, 255)
49
+
50
+
51
+ def mask_crop(mask, crop):
52
+ top, bottom, left, right = crop
53
+ shape = mask.shape
54
+ top = int(top)
55
+ bottom = int(bottom)
56
+ if top + bottom < shape[0]: # top/bottom crop rows, so bound them by the mask height
57
+ if top > 0: mask[:top, :] = 0
58
+ if bottom > 0: mask[-bottom:, :] = 0
59
+
60
+ left = int(left)
61
+ right = int(right)
62
+ if left + right < shape[1]: # left/right crop columns, so bound them by the mask width
63
+ if left > 0: mask[:, :left] = 0
64
+ if right > 0: mask[:, -right:] = 0
65
+
66
+ return mask
67
+
68
+ def create_image_grid(images, size=128):
69
+ num_images = len(images)
70
+ num_cols = int(np.ceil(np.sqrt(num_images)))
71
+ num_rows = int(np.ceil(num_images / num_cols))
72
+ grid = np.zeros((num_rows * size, num_cols * size, 3), dtype=np.uint8)
73
+
74
+ for i, image in enumerate(images):
75
+ row_idx = (i // num_cols) * size
76
+ col_idx = (i % num_cols) * size
77
+ image = cv2.resize(image.copy(), (size,size))
78
+ if image.dtype != np.uint8:
79
+ image = (image.astype('float32') * 255).astype('uint8')
80
+ if image.ndim == 2:
81
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
82
+ grid[row_idx:row_idx + size, col_idx:col_idx + size] = image
83
+
84
+ return grid
85
+
86
+
87
+ def paste_to_whole(foreground, background, matrix, mask=None, crop_mask=(0,0,0,0), blur_amount=0.1, erode_amount = 0.15, blend_method='linear'):
88
+ inv_matrix = cv2.invertAffineTransform(matrix)
89
+ fg_shape = foreground.shape[:2]
90
+ bg_shape = (background.shape[1], background.shape[0])
91
+ foreground = cv2.warpAffine(foreground, inv_matrix, bg_shape, borderValue=0.0, borderMode=cv2.BORDER_REPLICATE)
92
+
93
+ if mask is None:
94
+ mask = np.full(fg_shape, 1., dtype=np.float32)
95
+ mask = mask_crop(mask, crop_mask)
96
+ mask = cv2.warpAffine(mask, inv_matrix, bg_shape, borderValue=0.0)
97
+ else:
98
+ assert fg_shape == mask.shape[:2], "foreground & mask shape mismatch!"
99
+ mask = mask_crop(mask, crop_mask).astype('float32')
100
+ mask = cv2.warpAffine(mask, inv_matrix, (background.shape[1], background.shape[0]), borderValue=0.0)
101
+
102
+ _mask = mask.copy()
103
+ _mask[_mask > 0.05] = 1.
104
+ non_zero_points = cv2.findNonZero(_mask)
105
+ _, _, w, h = cv2.boundingRect(non_zero_points)
106
+ mask_size = int(np.sqrt(w * h))
107
+
108
+ if erode_amount > 0:
109
+ kernel_size = max(int(mask_size * erode_amount), 1)
110
+ structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
111
+ mask = cv2.erode(mask, structuring_element)
112
+
113
+ if blur_amount > 0:
114
+ kernel_size = max(int(mask_size * blur_amount), 3)
115
+ if kernel_size % 2 == 0:
116
+ kernel_size += 1
117
+ mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)
118
+
119
+ mask = np.tile(np.expand_dims(mask, axis=-1), (1, 1, 3))
120
+
121
+ if blend_method == 'laplacian':
122
+ composite_image = laplacian_blending(foreground, background, mask.clip(0,1), num_levels=4)
123
+ else:
124
+ composite_image = mask * foreground + (1 - mask) * background
125
+
126
+ return composite_image.astype("uint8").clip(0, 255)
127
+
128
+
129
+ def image_mask_overlay(img, mask):
130
+ img = img.astype('float32') / 255.
131
+ img *= (mask + 0.25).clip(0, 1)
132
+ img = np.clip(img * 255., 0., 255.).astype('uint8')
133
+ return img
134
+
135
+
136
+ def resize_with_padding(img, expected_size=(640, 360), color=(0, 0, 0), max_flip=False):
137
+ original_height, original_width = img.shape[:2]
138
+
139
+ if max_flip and original_height > original_width:
140
+ expected_size = (expected_size[1], expected_size[0])
141
+
142
+ aspect_ratio = original_width / original_height
143
+ new_width = expected_size[0]
144
+ new_height = int(new_width / aspect_ratio)
145
+
146
+ if new_height > expected_size[1]:
147
+ new_height = expected_size[1]
148
+ new_width = int(new_height * aspect_ratio)
149
+
150
+ resized_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
151
+ canvas = cv2.copyMakeBorder(resized_img,
152
+ top=(expected_size[1] - new_height) // 2,
153
+ bottom=(expected_size[1] - new_height + 1) // 2,
154
+ left=(expected_size[0] - new_width) // 2,
155
+ right=(expected_size[0] - new_width + 1) // 2,
156
+ borderType=cv2.BORDER_CONSTANT, value=color)
157
+ return canvas
158
+
159
+
160
+ def create_image_grid(images, size=128):
161
+ num_images = len(images)
162
+ num_cols = int(np.ceil(np.sqrt(num_images)))
163
+ num_rows = int(np.ceil(num_images / num_cols))
164
+ grid = np.zeros((num_rows * size, num_cols * size, 3), dtype=np.uint8)
165
+
166
+ for i, image in enumerate(images):
167
+ row_idx = (i // num_cols) * size
168
+ col_idx = (i % num_cols) * size
169
+ image = cv2.resize(image.copy(), (size,size))
170
+ if image.dtype != np.uint8:
171
+ image = (image.astype('float32') * 255).astype('uint8')
172
+ if image.ndim == 2:
173
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
174
+ grid[row_idx:row_idx + size, col_idx:col_idx + size] = image
175
+
176
+ return grid
177
+
178
+
179
+ def image_to_html(img, size=(640, 360), extension="jpg"):
180
+ if img is not None:
181
+ img = resize_with_padding(img, expected_size=size)
182
+ buffer = cv2.imencode(f".{extension}", img)[1]
183
+ base64_data = base64.b64encode(buffer.tobytes())
184
+ imgbs64 = f"data:image/{extension};base64," + base64_data.decode("utf-8")
185
+ html = '<div style="display: flex; justify-content: center; align-items: center; width: 100%;">'
186
+ html += f'<img src={imgbs64} alt="No Preview" style="max-width: 100%; max-height: 100%;">'
187
+ html += '</div>'
188
+ return html
189
+ return None
190
+
191
+
192
+ def mix_two_image(a, b, opacity=1.):
193
+ a_dtype = a.dtype
194
+ b_dtype = b.dtype
195
+ a = a.astype('float32')
196
+ b = b.astype('float32')
197
+ a = cv2.resize(a, (b.shape[1], b.shape[0])) # cv2.resize expects (width, height)
198
+ opacity = min(max(opacity, 0.), 1.)
199
+ mixed_img = opacity * b + (1 - opacity) * a
200
+ return mixed_img.astype(a_dtype)
201
+
202
+ resolution_map = {
203
+ "Original": None,
204
+ "240p": (426, 240),
205
+ "360p": (640, 360),
206
+ "480p": (854, 480),
207
+ "720p": (1280, 720),
208
+ "1080p": (1920, 1080),
209
+ "1440p": (2560, 1440),
210
+ "2160p": (3840, 2160),
211
+ }
212
+
213
+ def resize_image_by_resolution(img, quality):
214
+ resolution = resolution_map.get(quality, None)
215
+ if resolution is None:
216
+ return img
217
+
218
+ h, w = img.shape[:2]
219
+ if h > w:
220
+ ratio = resolution[0] / h
221
+ else:
222
+ ratio = resolution[0] / w
223
+
224
+ new_h, new_w = int(h * ratio), int(w * ratio)
225
+ img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
226
+ return img
227
+
228
+ def fast_pil_encode(pil_image):
229
+ image_arr = np.asarray(pil_image)[:,:,::-1]
230
+ buffer = cv2.imencode('.jpg', image_arr)[1]
231
+ base64_data = base64.b64encode(buffer.tobytes())
232
+ return "data:image/jpg;base64," + base64_data.decode("utf-8")
233
+
234
+ def fast_numpy_encode(img_array):
235
+ buffer = cv2.imencode('.jpg', img_array)[1]
236
+ base64_data = base64.b64encode(buffer.tobytes())
237
+ return "data:image/jpg;base64," + base64_data.decode("utf-8")
238
+
239
+ crf_quality_by_resolution = {
240
+ 240: {"poor": 45, "low": 35, "medium": 28, "high": 23, "best": 20},
241
+ 360: {"poor": 35, "low": 28, "medium": 23, "high": 20, "best": 18},
242
+ 480: {"poor": 28, "low": 23, "medium": 20, "high": 18, "best": 16},
243
+ 720: {"poor": 23, "low": 20, "medium": 18, "high": 16, "best": 14},
244
+ 1080: {"poor": 20, "low": 18, "medium": 16, "high": 14, "best": 12},
245
+ 1440: {"poor": 18, "low": 16, "medium": 14, "high": 12, "best": 10},
246
+ 2160: {"poor": 16, "low": 14, "medium": 12, "high": 10, "best": 8}
247
+ }
248
+
249
+ def get_crf_for_resolution(resolution, quality):
250
+ available_resolutions = list(crf_quality_by_resolution.keys())
251
+ closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
252
+ return crf_quality_by_resolution[closest_resolution][quality]
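
The last two helpers tie the UI's resolution and quality labels to concrete encode settings: a frame is first downscaled to the chosen preset, then an x264 CRF is looked up for that height. A short sketch using a made-up 4K frame:

```python
# Short sketch: map a resolution label and a VIDEO_QUALITY value to a CRF.
# The 4K stand-in frame is made up; the resulting values follow the tables above.
import numpy as np
from utils.image import resize_image_by_resolution, get_crf_for_resolution

frame = np.zeros((2160, 3840, 3), dtype=np.uint8)   # stand-in 4K frame
small = resize_image_by_resolution(frame, "720p")
crf = get_crf_for_resolution(small.shape[0], "high")
print(small.shape, crf)   # (720, 1280, 3) 16
```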
utils/io.py ADDED
@@ -0,0 +1,194 @@
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import shutil
5
+ import subprocess
6
+ from datetime import datetime
7
+
8
+
9
+ image_extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
10
+
11
+ def get_images_from_directory(directory_path):
12
+ file_paths =[]
13
+ for file_path in glob.glob(os.path.join(directory_path, "*")):
14
+ if any(file_path.lower().endswith(ext) for ext in image_extensions):
15
+ file_paths.append(file_path)
16
+ file_paths.sort()
17
+ return file_paths
18
+
19
+
20
+ def open_directory(path=None):
21
+ if path is None:
22
+ return
23
+ try:
24
+ os.startfile(path)
25
+ except:
26
+ subprocess.Popen(["xdg-open", path])
27
+
28
+
29
+ def copy_files_to_directory(files, destination):
30
+ file_paths = []
31
+ for file_path in files:
32
+ new_file_path = shutil.copy(file_path, destination)
33
+ file_paths.append(new_file_path)
34
+ return file_paths
35
+
36
+
37
+ def create_directory(directory_path, remove_existing=True):
38
+ if os.path.exists(directory_path) and remove_existing:
39
+ shutil.rmtree(directory_path)
40
+
41
+ if not os.path.exists(directory_path):
42
+ os.mkdir(directory_path)
43
+ return directory_path
44
+ else:
45
+ counter = 1
46
+ while True:
47
+ new_directory_path = f"{directory_path}_{counter}"
48
+ if not os.path.exists(new_directory_path):
49
+ os.mkdir(new_directory_path)
50
+ return new_directory_path
51
+ counter += 1
52
+
53
+
54
+ def add_datetime_to_filename(filename):
55
+ current_datetime = datetime.now()
56
+ formatted_datetime = current_datetime.strftime("%Y%m%d_%H%M%S")
57
+ file_name, file_extension = os.path.splitext(filename)
58
+ new_filename = f"{file_name}_{formatted_datetime}{file_extension}"
59
+ return new_filename
60
+
61
+
62
+ def get_single_video_frame(video_path, frame_index):
63
+ cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
64
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
65
+ frame_index = min(int(frame_index), total_frames-1)
66
+ cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
67
+ valid_frame, frame = cap.read()
68
+ cap.release()
69
+ if valid_frame:
70
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
71
+ return frame
72
+ return None
73
+
74
+
75
+ def get_video_fps(video_path):
76
+ cap = cv2.VideoCapture(video_path)
77
+ fps = cap.get(cv2.CAP_PROP_FPS)
78
+ cap.release()
79
+ return fps
80
+
81
+
82
+ def ffmpeg_extract_frames(video_path, destination, remove_existing=True, fps=30, name='frame_%d.jpg', ffmpeg_path=None):
83
+ ffmpeg_path = 'ffmpeg' if ffmpeg_path is None else ffmpeg_path
84
+ destination = create_directory(destination, remove_existing=remove_existing)
85
+ cmd = [
86
+ ffmpeg_path,
87
+ '-loglevel', 'info',
88
+ '-hwaccel', 'auto',
89
+ '-i', video_path,
90
+ '-q:v', '3',
91
+ '-pix_fmt', 'rgb24',
92
+ '-vf', 'fps=' + str(fps),
93
+ '-y',
94
+ os.path.join(destination, name)
95
+ ]
96
+ process = subprocess.Popen(cmd)
97
+ process.communicate()
98
+ if process.returncode == 0:
99
+ return True, get_images_from_directory(destination)
100
+ else:
101
+ print(f"Error: Failed to extract video.")
102
+ return False, None
103
+
104
+
105
+ def ffmpeg_merge_frames(sequence_directory, pattern, destination, fps=30, crf=18, ffmpeg_path=None):
106
+ ffmpeg_path = 'ffmpeg' if ffmpeg_path is None else ffmpeg_path
107
+ cmd = [
108
+ ffmpeg_path,
109
+ '-loglevel', 'info',
110
+ '-hwaccel', 'auto',
111
+ '-r', str(fps),
112
+ # '-pattern_type', 'glob',
113
+ '-i', os.path.join(sequence_directory, pattern),
114
+ '-c:v', 'libx264',
115
+ '-crf', str(crf),
116
+ '-pix_fmt', 'yuv420p',
117
+ '-vf', 'colorspace=bt709:iall=bt601-6-625:fast=1',
118
+ '-y', destination
119
+ ]
120
+ process = subprocess.Popen(cmd)
121
+ process.communicate()
122
+ if process.returncode == 0:
123
+ return True, destination
124
+ else:
125
+ print(f"Error: Failed to merge image sequence.")
126
+ return False, None
127
+
128
+
129
+ def ffmpeg_replace_video_segments(main_video_path, sub_clips_info, output_path, ffmpeg_path="ffmpeg"):
130
+ ffmpeg_path = 'ffmpeg' if ffmpeg_path is None else ffmpeg_path
131
+ filter_complex = ""
132
+
133
+ filter_complex += f"[0:v]split=2[v0][main_end]; "
134
+ filter_complex += f"[1:v]split={len(sub_clips_info)}{', '.join([f'[v{index + 1}]' for index in range(len(sub_clips_info))])}; "
135
+
136
+ overlay_exprs = "".join([f"[v{index + 1}]" for index in range(len(sub_clips_info))])
137
+ overlay_filters = f"[main_end][{overlay_exprs}]overlay=eof_action=pass[vout]; "
138
+ filter_complex += overlay_filters
139
+
140
+ cmd = [
141
+ ffmpeg_path, '-i', main_video_path,
142
+ ]
143
+
144
+ for sub_clip_path, _, _ in sub_clips_info:
145
+ cmd.extend(['-i', sub_clip_path])
146
+
147
+ cmd.extend([
148
+ '-filter_complex', filter_complex,
149
+ '-map', '[vout]',
150
+ output_path
151
+ ])
152
+
153
+ subprocess.run(cmd)
154
+
155
+
156
+ def ffmpeg_mux_audio(source, target, output, ffmpeg_path=None):
157
+ ffmpeg_path = 'ffmpeg' if ffmpeg_path is None else ffmpeg_path
158
+ extracted_audio_path = os.path.join(os.path.dirname(output), 'extracted_audio.aac')
159
+ cmd1 = [
160
+ ffmpeg_path,
161
+ '-loglevel', 'info',
162
+ '-i', source,
163
+ '-vn',
164
+ '-c:a', 'aac',
165
+ '-y',
166
+ extracted_audio_path
167
+ ]
168
+ process = subprocess.Popen(cmd1)
169
+ process.communicate()
170
+ if process.returncode != 0:
171
+ print(f"Error: Failed to extract audio.")
172
+ return False, target
173
+
174
+ cmd2 = [
175
+ ffmpeg_path,
176
+ '-loglevel', 'info',
177
+ '-hwaccel', 'auto',
178
+ '-i', target,
179
+ '-i', extracted_audio_path,
180
+ '-c:v', 'copy',
181
+ '-map', '0:v:0',
182
+ '-map', '1:a:0',
183
+ '-y', output
184
+ ]
185
+ process = subprocess.Popen(cmd2)
186
+ process.communicate()
187
+ if process.returncode == 0:
188
+ if os.path.exists(extracted_audio_path):
189
+ os.remove(extracted_audio_path)
190
+ return True, output
191
+ else:
192
+ print(f"Error: Failed to mux audio.")
193
+ return False, None
194
+
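
Together these wrappers form the video path: extract frames with ffmpeg, process them one by one, merge the results back at the right CRF, then copy the original audio over. A hedged sketch of that loop; the paths are placeholders and ffmpeg must be on PATH.

```python
# Hedged sketch of the frame-level video pipeline these helpers support.
# Assumptions: "input.mp4" exists, ffmpeg is on PATH, and extracted frames are
# processed in place before merging.
from utils.io import (ffmpeg_extract_frames, ffmpeg_merge_frames,
                      ffmpeg_mux_audio, get_video_fps)

video = "input.mp4"
fps = get_video_fps(video)

ok, frame_paths = ffmpeg_extract_frames(video, "work_frames", fps=fps)
# ... run the per-frame swap (e.g. SwapMukham.process_frame) over frame_paths here ...
ok, merged = ffmpeg_merge_frames("work_frames", "frame_%d.jpg", "merged.mp4", fps=fps, crf=18)
ok, final = ffmpeg_mux_audio(video, merged, "output.mp4")
```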
utils/retinaface.py ADDED
@@ -0,0 +1,268 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Organization : insightface.ai
3
+ # @Author : Jia Guo
4
+ # @Time : 2021-09-18
5
+ # @Function :
6
+
7
+ from __future__ import division
8
+ import datetime
9
+ import numpy as np
10
+ import onnx
11
+ import onnxruntime
12
+ import os
13
+ import cv2
14
+ import sys
15
+ import default_paths as dp
16
+
17
+ def softmax(z):
18
+ assert len(z.shape) == 2
19
+ s = np.max(z, axis=1)
20
+ s = s[:, np.newaxis] # necessary step to do broadcasting
21
+ e_x = np.exp(z - s)
22
+ div = np.sum(e_x, axis=1)
23
+ div = div[:, np.newaxis] # ditto
24
+ return e_x / div
25
+
26
+ def distance2bbox(points, distance, max_shape=None):
27
+ """Decode distance prediction to bounding box.
28
+
29
+ Args:
30
+ points (Tensor): Shape (n, 2), [x, y].
31
+ distance (Tensor): Distance from the given point to 4
32
+ boundaries (left, top, right, bottom).
33
+ max_shape (tuple): Shape of the image.
34
+
35
+ Returns:
36
+ Tensor: Decoded bboxes.
37
+ """
38
+ x1 = points[:, 0] - distance[:, 0]
39
+ y1 = points[:, 1] - distance[:, 1]
40
+ x2 = points[:, 0] + distance[:, 2]
41
+ y2 = points[:, 1] + distance[:, 3]
42
+ if max_shape is not None:
43
+ x1 = x1.clamp(min=0, max=max_shape[1])
44
+ y1 = y1.clamp(min=0, max=max_shape[0])
45
+ x2 = x2.clamp(min=0, max=max_shape[1])
46
+ y2 = y2.clamp(min=0, max=max_shape[0])
47
+ return np.stack([x1, y1, x2, y2], axis=-1)
48
+
49
+ def distance2kps(points, distance, max_shape=None):
50
+ """Decode distance prediction to keypoints.
51
+
52
+ Args:
53
+ points (Tensor): Shape (n, 2), [x, y].
54
+ distance (Tensor): Distance from the given point to 4
55
+ boundaries (left, top, right, bottom).
56
+ max_shape (tuple): Shape of the image.
57
+
58
+ Returns:
59
+ Tensor: Decoded keypoints.
60
+ """
61
+ preds = []
62
+ for i in range(0, distance.shape[1], 2):
63
+ px = points[:, i%2] + distance[:, i]
64
+ py = points[:, i%2+1] + distance[:, i+1]
65
+ if max_shape is not None:
66
+ px = px.clamp(min=0, max=max_shape[1])
67
+ py = py.clamp(min=0, max=max_shape[0])
68
+ preds.append(px)
69
+ preds.append(py)
70
+ return np.stack(preds, axis=-1)
71
+
72
+ class RetinaFace:
73
+ def __init__(self, model_file=None, provider=["CPUExecutionProvider"], session_options=None):
74
+ self.model_file = model_file
75
+ self.session_options = session_options
76
+ if self.session_options is None:
77
+ self.session_options = onnxruntime.SessionOptions()
78
+ self.session = onnxruntime.InferenceSession(self.model_file, providers=provider, sess_options=self.session_options)
79
+ self.center_cache = {}
80
+ self.nms_thresh = 0.4
81
+ self.det_thresh = 0.5
82
+ self._init_vars()
83
+
84
+ def _init_vars(self):
85
+ input_cfg = self.session.get_inputs()[0]
86
+ input_shape = input_cfg.shape
87
+ #print(input_shape)
88
+ if isinstance(input_shape[2], str):
89
+ self.input_size = None
90
+ else:
91
+ self.input_size = tuple(input_shape[2:4][::-1])
92
+ #print('image_size:', self.image_size)
93
+ input_name = input_cfg.name
94
+ self.input_shape = input_shape
95
+ outputs = self.session.get_outputs()
96
+ output_names = []
97
+ for o in outputs:
98
+ output_names.append(o.name)
99
+ self.input_name = input_name
100
+ self.output_names = output_names
101
+ self.input_mean = 127.5
102
+ self.input_std = 128.0
103
+ #print(self.output_names)
104
+ #assert len(outputs)==10 or len(outputs)==15
105
+ self.use_kps = False
106
+ self._anchor_ratio = 1.0
107
+ self._num_anchors = 1
108
+ if len(outputs)==6:
109
+ self.fmc = 3
110
+ self._feat_stride_fpn = [8, 16, 32]
111
+ self._num_anchors = 2
112
+ elif len(outputs)==9:
113
+ self.fmc = 3
114
+ self._feat_stride_fpn = [8, 16, 32]
115
+ self._num_anchors = 2
116
+ self.use_kps = True
117
+ elif len(outputs)==10:
118
+ self.fmc = 5
119
+ self._feat_stride_fpn = [8, 16, 32, 64, 128]
120
+ self._num_anchors = 1
121
+ elif len(outputs)==15:
122
+ self.fmc = 5
123
+ self._feat_stride_fpn = [8, 16, 32, 64, 128]
124
+ self._num_anchors = 1
125
+ self.use_kps = True
126
+
127
+ def prepare(self, **kwargs):
128
+ nms_thresh = kwargs.get('nms_thresh', None)
129
+ if nms_thresh is not None:
130
+ self.nms_thresh = nms_thresh
131
+ det_thresh = kwargs.get('det_thresh', None)
132
+ if det_thresh is not None:
133
+ self.det_thresh = det_thresh
134
+ input_size = kwargs.get('input_size', None)
135
+ if input_size is not None:
136
+ if self.input_size is not None:
137
+ print('warning: det_size is already set in detection model, ignore')
138
+ else:
139
+ self.input_size = input_size
140
+
141
+ def forward(self, img, threshold):
142
+ scores_list = []
143
+ bboxes_list = []
144
+ kpss_list = []
145
+ input_size = tuple(img.shape[0:2][::-1])
146
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
147
+ net_outs = self.session.run(self.output_names, {self.input_name : blob})
148
+
149
+ input_height = blob.shape[2]
150
+ input_width = blob.shape[3]
151
+ fmc = self.fmc
152
+ for idx, stride in enumerate(self._feat_stride_fpn):
153
+ scores = net_outs[idx]
154
+ bbox_preds = net_outs[idx+fmc]
155
+ bbox_preds = bbox_preds * stride
156
+ if self.use_kps:
157
+ kps_preds = net_outs[idx+fmc*2] * stride
158
+ height = input_height // stride
159
+ width = input_width // stride
160
+ K = height * width
161
+ key = (height, width, stride)
162
+ if key in self.center_cache:
163
+ anchor_centers = self.center_cache[key]
164
+ else:
165
+ anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
166
+ anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
167
+ if self._num_anchors>1:
168
+ anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
169
+ if len(self.center_cache)<100:
170
+ self.center_cache[key] = anchor_centers
171
+
172
+ pos_inds = np.where(scores>=threshold)[0]
173
+ bboxes = distance2bbox(anchor_centers, bbox_preds)
174
+ pos_scores = scores[pos_inds]
175
+ pos_bboxes = bboxes[pos_inds]
176
+ scores_list.append(pos_scores)
177
+ bboxes_list.append(pos_bboxes)
178
+ if self.use_kps:
179
+ kpss = distance2kps(anchor_centers, kps_preds)
180
+ kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
181
+ pos_kpss = kpss[pos_inds]
182
+ kpss_list.append(pos_kpss)
183
+ return scores_list, bboxes_list, kpss_list
184
+
185
+ def detect(self, img, input_size = (640,640), max_num=0, metric='default', det_thresh=0.5):
186
+ assert input_size is not None or self.input_size is not None
187
+ input_size = self.input_size if input_size is None else input_size
188
+
189
+ im_ratio = float(img.shape[0]) / img.shape[1]
190
+ model_ratio = float(input_size[1]) / input_size[0]
191
+ if im_ratio>model_ratio:
192
+ new_height = input_size[1]
193
+ new_width = int(new_height / im_ratio)
194
+ else:
195
+ new_width = input_size[0]
196
+ new_height = int(new_width * im_ratio)
197
+ det_scale = float(new_height) / img.shape[0]
198
+ resized_img = cv2.resize(img, (new_width, new_height))
199
+ det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
200
+ det_img[:new_height, :new_width, :] = resized_img
201
+
202
+ scores_list, bboxes_list, kpss_list = self.forward(det_img, det_thresh)
203
+
204
+ scores = np.vstack(scores_list)
205
+ scores_ravel = scores.ravel()
206
+ order = scores_ravel.argsort()[::-1]
207
+ bboxes = np.vstack(bboxes_list) / det_scale
208
+ if self.use_kps:
209
+ kpss = np.vstack(kpss_list) / det_scale
210
+ pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
211
+ pre_det = pre_det[order, :]
212
+ keep = self.nms(pre_det)
213
+ det = pre_det[keep, :]
214
+ if self.use_kps:
215
+ kpss = kpss[order,:,:]
216
+ kpss = kpss[keep,:,:]
217
+ else:
218
+ kpss = None
219
+ if max_num > 0 and det.shape[0] > max_num:
220
+ area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
221
+ det[:, 1])
222
+ img_center = img.shape[0] // 2, img.shape[1] // 2
223
+ offsets = np.vstack([
224
+ (det[:, 0] + det[:, 2]) / 2 - img_center[1],
225
+ (det[:, 1] + det[:, 3]) / 2 - img_center[0]
226
+ ])
227
+ offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
228
+ if metric=='max':
229
+ values = area
230
+ else:
231
+ values = area - offset_dist_squared * 2.0 # some extra weight on the centering
232
+ bindex = np.argsort(
233
+ values)[::-1] # some extra weight on the centering
234
+ bindex = bindex[0:max_num]
235
+ det = det[bindex, :]
236
+ if kpss is not None:
237
+ kpss = kpss[bindex, :]
238
+ return det, kpss
239
+
240
+ def nms(self, dets):
241
+ thresh = self.nms_thresh
242
+ x1 = dets[:, 0]
243
+ y1 = dets[:, 1]
244
+ x2 = dets[:, 2]
245
+ y2 = dets[:, 3]
246
+ scores = dets[:, 4]
247
+
248
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
249
+ order = scores.argsort()[::-1]
250
+
251
+ keep = []
252
+ while order.size > 0:
253
+ i = order[0]
254
+ keep.append(i)
255
+ xx1 = np.maximum(x1[i], x1[order[1:]])
256
+ yy1 = np.maximum(y1[i], y1[order[1:]])
257
+ xx2 = np.minimum(x2[i], x2[order[1:]])
258
+ yy2 = np.minimum(y2[i], y2[order[1:]])
259
+
260
+ w = np.maximum(0.0, xx2 - xx1 + 1)
261
+ h = np.maximum(0.0, yy2 - yy1 + 1)
262
+ inter = w * h
263
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
264
+
265
+ inds = np.where(ovr <= thresh)[0]
266
+ order = order[inds + 1]
267
+
268
+ return keep
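
Finally, the detector, aligner, and recognizer compose as sketched below. The ONNX file names are assumptions (typical insightface buffalo_l exports); in the app they are resolved through `default_paths`, as in the modules above.

```python
# Hedged end-to-end sketch: detect -> align -> embed for the first face in a photo.
# Model file names are assumptions (standard insightface buffalo_l exports).
import cv2
from utils.retinaface import RetinaFace
from utils.arcface import ArcFace
from utils.face_alignment import norm_crop2

detector = RetinaFace(model_file="assets/pretrained_models/det_10g.onnx",
                      provider=["CPUExecutionProvider"])
recognizer = ArcFace(model_file="assets/pretrained_models/w600k_r50.onnx",
                     provider=["CPUExecutionProvider"])

img = cv2.imread("photo.jpg")
bboxes, kpss = detector.detect(img, input_size=(640, 640), det_thresh=0.5)

if kpss is not None and len(kpss) > 0:
    aligned, _ = norm_crop2(img, kpss[0], image_size=112)   # 112x112 ArcFace crop
    embedding = recognizer.get(img, kpss[0])                # identity vector for the first face
    print(bboxes[0], embedding.shape)
```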