Spaces:
Build error
Build error
Initial
Browse files- .gitignore +4 -0
- README.md +3 -3
- app.py +1134 -0
- assets/images/loading.gif +0 -0
- assets/images/logo.png +0 -0
- assets/pretrained_models/readme.md +1 -0
- change_log.md +5 -0
- default_paths.py +18 -0
- face_analyser.py +168 -0
- face_parsing.py +55 -0
- face_swapper.py +43 -0
- face_upscaler.py +72 -0
- global_variables.py +36 -0
- nsfw_checker/LICENSE.md +11 -0
- nsfw_checker/__init__.py +1 -0
- nsfw_checker/opennsfw.py +65 -0
- requirements.txt +6 -0
- swap_mukham.py +195 -0
- swap_mukham_colab.ipynb +183 -0
- upscaler/GFPGAN.py +41 -0
- upscaler/GPEN.py +37 -0
- upscaler/__init__.py +0 -0
- upscaler/codeformer.py +41 -0
- upscaler/restoreformer.py +37 -0
- utils/__init__.py +0 -0
- utils/arcface.py +89 -0
- utils/device.py +32 -0
- utils/face_alignment.py +73 -0
- utils/gender_age.py +25 -0
- utils/image.py +252 -0
- utils/io.py +194 -0
- utils/retinaface.py +268 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
*.pth
|
3 |
+
*.onnx
|
4 |
+
*.pyc
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: Swap-mukham WIP
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.40.1
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
title: Swap-mukham WIP
|
3 |
+
emoji: π
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: black
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.40.1
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,1134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import time
|
4 |
+
import shutil
|
5 |
+
import base64
|
6 |
+
import datetime
|
7 |
+
import argparse
|
8 |
+
import numpy as np
|
9 |
+
import gradio as gr
|
10 |
+
from tqdm import tqdm
|
11 |
+
import concurrent.futures
|
12 |
+
|
13 |
+
import threading
|
14 |
+
cv_reader_lock = threading.Lock()
|
15 |
+
|
16 |
+
## ------------------------------ USER ARGS ------------------------------
|
17 |
+
|
18 |
+
parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
|
19 |
+
parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
|
20 |
+
parser.add_argument("--max_threads", type=int, help="Max num of threads to use", default=2)
|
21 |
+
parser.add_argument("--colab", action="store_true", help="Colab mode", default=False)
|
22 |
+
parser.add_argument("--cpu", action="store_true", help="Enable cpu mode", default=False)
|
23 |
+
parser.add_argument("--prefer_text_widget", action="store_true", help="Replaces target video widget with text widget", default=False)
|
24 |
+
user_args = parser.parse_args()
|
25 |
+
|
26 |
+
USE_CPU = 1
|
27 |
+
|
28 |
+
if not USE_CPU:
|
29 |
+
import torch
|
30 |
+
|
31 |
+
import default_paths as dp
|
32 |
+
import global_variables as gv
|
33 |
+
|
34 |
+
from swap_mukham import SwapMukham
|
35 |
+
from nsfw_checker import NSFWChecker
|
36 |
+
|
37 |
+
from face_parsing import mask_regions_to_list
|
38 |
+
|
39 |
+
from utils.device import get_device_and_provider, device_types_list
|
40 |
+
from utils.image import (
|
41 |
+
image_mask_overlay,
|
42 |
+
resize_image_by_resolution,
|
43 |
+
resolution_map,
|
44 |
+
fast_pil_encode,
|
45 |
+
fast_numpy_encode,
|
46 |
+
get_crf_for_resolution,
|
47 |
+
)
|
48 |
+
from utils.io import (
|
49 |
+
open_directory,
|
50 |
+
get_images_from_directory,
|
51 |
+
copy_files_to_directory,
|
52 |
+
create_directory,
|
53 |
+
get_single_video_frame,
|
54 |
+
ffmpeg_merge_frames,
|
55 |
+
ffmpeg_mux_audio,
|
56 |
+
add_datetime_to_filename,
|
57 |
+
)
|
58 |
+
|
59 |
+
gr.processing_utils.encode_pil_to_base64 = fast_pil_encode
|
60 |
+
gr.processing_utils.encode_array_to_base64 = fast_numpy_encode
|
61 |
+
|
62 |
+
gv.USE_COLAB = user_args.colab
|
63 |
+
gv.MAX_THREADS = user_args.max_threads
|
64 |
+
gv.DEFAULT_OUTPUT_PATH = user_args.out_dir
|
65 |
+
|
66 |
+
PREFER_TEXT_WIDGET = user_args.prefer_text_widget
|
67 |
+
|
68 |
+
WORKSPACE = None
|
69 |
+
OUTPUT_FILE = None
|
70 |
+
|
71 |
+
preferred_device = "cpu" if USE_CPU else "cuda"
|
72 |
+
DEVICE_LIST = device_types_list
|
73 |
+
DEVICE, PROVIDER, OPTIONS = get_device_and_provider(device=preferred_device)
|
74 |
+
SWAP_MUKHAM = SwapMukham(device=DEVICE)
|
75 |
+
|
76 |
+
IS_RUNNING = False
|
77 |
+
CURRENT_FRAME = None
|
78 |
+
COLLECTED_FACES = []
|
79 |
+
FOREGROUND_MASK_DICT = {}
|
80 |
+
NSFW_CACHE = {}
|
81 |
+
|
82 |
+
|
83 |
+
## ------------------------------ MAIN PROCESS ------------------------------
|
84 |
+
|
85 |
+
|
86 |
+
def process(
|
87 |
+
test_mode,
|
88 |
+
target_type,
|
89 |
+
image_path,
|
90 |
+
video_path,
|
91 |
+
directory_path,
|
92 |
+
source_path,
|
93 |
+
use_foreground_mask,
|
94 |
+
img_fg_mask,
|
95 |
+
fg_mask_softness,
|
96 |
+
output_path,
|
97 |
+
output_name,
|
98 |
+
use_datetime_suffix,
|
99 |
+
sequence_output_format,
|
100 |
+
keep_output_sequence,
|
101 |
+
swap_condition,
|
102 |
+
age,
|
103 |
+
distance,
|
104 |
+
face_enhancer_name,
|
105 |
+
face_upscaler_opacity,
|
106 |
+
use_face_parsing,
|
107 |
+
parse_from_target,
|
108 |
+
mask_regions,
|
109 |
+
mask_blur_amount,
|
110 |
+
mask_erode_amount,
|
111 |
+
swap_iteration,
|
112 |
+
face_scale,
|
113 |
+
use_laplacian_blending,
|
114 |
+
crop_top,
|
115 |
+
crop_bott,
|
116 |
+
crop_left,
|
117 |
+
crop_right,
|
118 |
+
current_idx,
|
119 |
+
number_of_threads,
|
120 |
+
use_frame_selection,
|
121 |
+
frame_selection_ranges,
|
122 |
+
video_quality,
|
123 |
+
face_detection_condition,
|
124 |
+
face_detection_size,
|
125 |
+
face_detection_threshold,
|
126 |
+
averaging_method,
|
127 |
+
progress=gr.Progress(track_tqdm=True),
|
128 |
+
*specifics,
|
129 |
+
):
|
130 |
+
global WORKSPACE
|
131 |
+
global OUTPUT_FILE
|
132 |
+
global PREVIEW
|
133 |
+
WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
|
134 |
+
|
135 |
+
global IS_RUNNING
|
136 |
+
IS_RUNNING = True
|
137 |
+
|
138 |
+
## ------------------------------ GUI UPDATE FUNC ------------------------------
|
139 |
+
def ui_before():
|
140 |
+
return (
|
141 |
+
gr.update(visible=True, value=None),
|
142 |
+
gr.update(interactive=False),
|
143 |
+
gr.update(interactive=False),
|
144 |
+
gr.update(visible=False, value=None),
|
145 |
+
)
|
146 |
+
|
147 |
+
def ui_after():
|
148 |
+
return (
|
149 |
+
gr.update(visible=True, value=PREVIEW),
|
150 |
+
gr.update(interactive=True),
|
151 |
+
gr.update(interactive=True),
|
152 |
+
gr.update(visible=False, value=None),
|
153 |
+
)
|
154 |
+
|
155 |
+
def ui_after_vid():
|
156 |
+
return (
|
157 |
+
gr.update(visible=False),
|
158 |
+
gr.update(interactive=True),
|
159 |
+
gr.update(interactive=True),
|
160 |
+
gr.update(value=OUTPUT_FILE, visible=True),
|
161 |
+
)
|
162 |
+
|
163 |
+
if not test_mode:
|
164 |
+
yield ui_before() # resets ui preview
|
165 |
+
progress(0, desc="Processing")
|
166 |
+
|
167 |
+
start_time = time.time()
|
168 |
+
total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
|
169 |
+
get_finsh_text = (
|
170 |
+
lambda start_time: f"Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."
|
171 |
+
)
|
172 |
+
|
173 |
+
## ------------------------------ PREPARE INPUTS ------------------------------
|
174 |
+
|
175 |
+
if use_datetime_suffix:
|
176 |
+
output_name = add_datetime_to_filename(output_name)
|
177 |
+
|
178 |
+
mask_regions = mask_regions_to_list(mask_regions)
|
179 |
+
|
180 |
+
specifics = list(specifics)
|
181 |
+
half = len(specifics) // 2
|
182 |
+
if swap_condition == "specific face":
|
183 |
+
source_specifics = [
|
184 |
+
([s.name for s in src] if src is not None else None, spc) for src, spc in zip(specifics[:half], specifics[half:])
|
185 |
+
]
|
186 |
+
else:
|
187 |
+
source_paths = [i.name for i in source_path]
|
188 |
+
source_specifics = [(source_paths, None)]
|
189 |
+
|
190 |
+
if crop_top > crop_bott:
|
191 |
+
crop_top, crop_bott = crop_bott, crop_top
|
192 |
+
if crop_left > crop_right:
|
193 |
+
crop_left, crop_right = crop_right, crop_left
|
194 |
+
crop_mask = (crop_top, 511 - crop_bott, crop_left, 511 - crop_right)
|
195 |
+
|
196 |
+
input_args = {
|
197 |
+
"similarity": distance,
|
198 |
+
"age": age,
|
199 |
+
"face_scale": face_scale,
|
200 |
+
"num_of_pass": swap_iteration,
|
201 |
+
"face_upscaler_opacity": face_upscaler_opacity,
|
202 |
+
"mask_crop_values": crop_mask,
|
203 |
+
"mask_erode_amount": mask_erode_amount,
|
204 |
+
"mask_blur_amount": mask_blur_amount,
|
205 |
+
"use_laplacian_blending": use_laplacian_blending,
|
206 |
+
"swap_condition": swap_condition,
|
207 |
+
"face_parse_regions": mask_regions,
|
208 |
+
"use_face_parsing": use_face_parsing,
|
209 |
+
"face_detection_size": [int(face_detection_size), int(face_detection_size)],
|
210 |
+
"face_detection_threshold": face_detection_threshold,
|
211 |
+
"face_detection_condition": face_detection_condition,
|
212 |
+
"parse_from_target": parse_from_target,
|
213 |
+
"averaging_method": averaging_method,
|
214 |
+
}
|
215 |
+
|
216 |
+
SWAP_MUKHAM.set_values(input_args)
|
217 |
+
if (
|
218 |
+
SWAP_MUKHAM.face_upscaler is None
|
219 |
+
or SWAP_MUKHAM.face_upscaler_name != face_enhancer_name
|
220 |
+
):
|
221 |
+
SWAP_MUKHAM.load_face_upscaler(face_enhancer_name, device=DEVICE)
|
222 |
+
if SWAP_MUKHAM.face_parser is None and use_face_parsing:
|
223 |
+
SWAP_MUKHAM.load_face_parser(device=DEVICE)
|
224 |
+
SWAP_MUKHAM.analyse_source_faces(source_specifics)
|
225 |
+
|
226 |
+
mask = None
|
227 |
+
if use_foreground_mask and img_fg_mask is not None:
|
228 |
+
mask = img_fg_mask.get("mask", None)
|
229 |
+
mask = cv2.cvtColor(mask, cv2.COLOR_BGRA2RGB)
|
230 |
+
if fg_mask_softness > 0:
|
231 |
+
mask = cv2.blur(mask, (int(fg_mask_softness), int(fg_mask_softness)))
|
232 |
+
mask = mask.astype("float32") / 255.0
|
233 |
+
|
234 |
+
def nsfw_assertion(is_nsfw):
|
235 |
+
if is_nsfw:
|
236 |
+
message = "NSFW content detected !"
|
237 |
+
gr.Info(message)
|
238 |
+
assert not is_nsfw, message
|
239 |
+
|
240 |
+
## ------------------------------ IMAGE ------------------------------
|
241 |
+
|
242 |
+
if target_type == "Image" and not test_mode:
|
243 |
+
target = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
|
244 |
+
|
245 |
+
is_nsfw = SWAP_MUKHAM.nsfw_detector.check_image(target)
|
246 |
+
nsfw_assertion(is_nsfw)
|
247 |
+
|
248 |
+
output = SWAP_MUKHAM.process_frame(
|
249 |
+
[target, mask]
|
250 |
+
)
|
251 |
+
output_file = os.path.join(output_path, output_name + ".png")
|
252 |
+
cv2.imwrite(output_file, output)
|
253 |
+
|
254 |
+
PREVIEW = output
|
255 |
+
OUTPUT_FILE = output_file
|
256 |
+
WORKSPACE = output_path
|
257 |
+
|
258 |
+
gr.Info(get_finsh_text(start_time))
|
259 |
+
yield ui_after()
|
260 |
+
|
261 |
+
## ------------------------------ VIDEO ------------------------------
|
262 |
+
|
263 |
+
elif target_type == "Video" and not test_mode:
|
264 |
+
video_path = video_path.replace('"', '').strip()
|
265 |
+
|
266 |
+
if video_path in NSFW_CACHE.keys():
|
267 |
+
nsfw_assertion(NSFW_CACHE.get(video_path))
|
268 |
+
else:
|
269 |
+
is_nsfw = SWAP_MUKHAM.nsfw_detector.check_video(video_path)
|
270 |
+
NSFW_CACHE[video_path] = is_nsfw
|
271 |
+
nsfw_assertion(is_nsfw)
|
272 |
+
|
273 |
+
temp_path = os.path.join(output_path, output_name)
|
274 |
+
os.makedirs(temp_path, exist_ok=True)
|
275 |
+
|
276 |
+
cap = cv2.VideoCapture(video_path)
|
277 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
278 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
279 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
280 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
281 |
+
|
282 |
+
is_in_range = lambda idx: any([int(rng[0]) <= idx <= int(rng[1]) for rng in frame_selection_ranges]) if use_frame_selection else True
|
283 |
+
|
284 |
+
print("[ Swapping process started ]")
|
285 |
+
|
286 |
+
def swap_video_func(frame_index):
|
287 |
+
if IS_RUNNING:
|
288 |
+
with cv_reader_lock:
|
289 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
|
290 |
+
valid_frame, frame = cap.read()
|
291 |
+
|
292 |
+
if valid_frame:
|
293 |
+
if is_in_range(frame_index):
|
294 |
+
mask = FOREGROUND_MASK_DICT.get(frame_index, None) if use_foreground_mask else None
|
295 |
+
output = SWAP_MUKHAM.process_frame([frame, mask])
|
296 |
+
else:
|
297 |
+
output = frame
|
298 |
+
frame_path = os.path.join(temp_path, f"frame_{frame_index}.{sequence_output_format}")
|
299 |
+
if sequence_output_format == "jpg":
|
300 |
+
cv2.imwrite(frame_path, output, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
301 |
+
else:
|
302 |
+
cv2.imwrite(frame_path, output)
|
303 |
+
|
304 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=number_of_threads) as executor:
|
305 |
+
futures = [executor.submit(swap_video_func, idx) for idx in range(total_frames)]
|
306 |
+
|
307 |
+
with tqdm(total=total_frames, desc="Processing") as pbar:
|
308 |
+
for future in concurrent.futures.as_completed(futures):
|
309 |
+
future.result()
|
310 |
+
pbar.update(1)
|
311 |
+
|
312 |
+
cap.release()
|
313 |
+
|
314 |
+
if IS_RUNNING:
|
315 |
+
print("[ Merging image sequence ]")
|
316 |
+
progress(0, desc="Merging image sequence")
|
317 |
+
WORKSPACE = output_path
|
318 |
+
out_without_audio = output_name + "_without_audio" + ".mp4"
|
319 |
+
destination = os.path.join(output_path, out_without_audio)
|
320 |
+
crf = get_crf_for_resolution(max(width,height), video_quality)
|
321 |
+
ret, destination = ffmpeg_merge_frames(
|
322 |
+
temp_path, f"frame_%d.{sequence_output_format}", destination, fps=fps, crf=crf, ffmpeg_path=dp.FFMPEG_PATH
|
323 |
+
)
|
324 |
+
OUTPUT_FILE = destination
|
325 |
+
|
326 |
+
if ret:
|
327 |
+
print("[ Merging audio ]")
|
328 |
+
progress(0, desc="Merging audio")
|
329 |
+
OUTPUT_FILE = destination
|
330 |
+
out_with_audio = out_without_audio.replace("_without_audio", "")
|
331 |
+
_ret, _destination = ffmpeg_mux_audio(
|
332 |
+
video_path, out_without_audio, out_with_audio, ffmpeg_path=dp.FFMPEG_PATH
|
333 |
+
)
|
334 |
+
|
335 |
+
if _ret:
|
336 |
+
OUTPUT_FILE = _destination
|
337 |
+
os.remove(out_without_audio)
|
338 |
+
|
339 |
+
if os.path.exists(temp_path) and not keep_output_sequence:
|
340 |
+
print("[ Removing temporary files ]")
|
341 |
+
progress(0, desc="Removing temporary files")
|
342 |
+
shutil.rmtree(temp_path)
|
343 |
+
|
344 |
+
finish_text = get_finsh_text(start_time)
|
345 |
+
print(f"[ {finish_text} ]")
|
346 |
+
gr.Info(finish_text)
|
347 |
+
yield ui_after_vid()
|
348 |
+
|
349 |
+
## ------------------------------ DIRECTORY ------------------------------
|
350 |
+
|
351 |
+
elif target_type == "Directory" and not test_mode:
|
352 |
+
temp_path = os.path.join(output_path, output_name)
|
353 |
+
temp_path = create_directory(temp_path, remove_existing=True)
|
354 |
+
|
355 |
+
directory_path = directory_path.replace('"', '').strip()
|
356 |
+
image_paths = get_images_from_directory(directory_path)
|
357 |
+
|
358 |
+
is_nsfw = SWAP_MUKHAM.nsfw_detector.check_image_paths(image_paths)
|
359 |
+
nsfw_assertion(is_nsfw)
|
360 |
+
|
361 |
+
new_image_paths = copy_files_to_directory(image_paths, temp_path)
|
362 |
+
|
363 |
+
def swap_func(img_path):
|
364 |
+
if IS_RUNNING:
|
365 |
+
frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
366 |
+
output = SWAP_MUKHAM.process_frame([frame, None])
|
367 |
+
cv2.imwrite(img_path, output)
|
368 |
+
|
369 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=number_of_threads) as executor:
|
370 |
+
futures = [executor.submit(swap_func, img_path) for img_path in new_image_paths]
|
371 |
+
|
372 |
+
with tqdm(total=len(new_image_paths), desc="Processing") as pbar:
|
373 |
+
for future in concurrent.futures.as_completed(futures):
|
374 |
+
future.result()
|
375 |
+
pbar.update(1)
|
376 |
+
|
377 |
+
PREVIEW = cv2.imread(new_image_paths[-1])
|
378 |
+
WORKSPACE = temp_path
|
379 |
+
OUTPUT_FILE = new_image_paths[-1]
|
380 |
+
|
381 |
+
gr.Info(get_finsh_text(start_time))
|
382 |
+
yield ui_after()
|
383 |
+
|
384 |
+
## ------------------------------ STREAM ------------------------------
|
385 |
+
|
386 |
+
elif target_type == "Stream" and not test_mode:
|
387 |
+
pass
|
388 |
+
|
389 |
+
## ------------------------------ TEST ------------------------------
|
390 |
+
|
391 |
+
if test_mode and target_type == "Video":
|
392 |
+
mask = None
|
393 |
+
if use_face_parsing_mask:
|
394 |
+
mask = FOREGROUND_MASK_DICT.get(current_idx, None)
|
395 |
+
if CURRENT_FRAME is not None and isinstance(CURRENT_FRAME, np.ndarray):
|
396 |
+
PREVIEW = SWAP_MUKHAM.process_frame(
|
397 |
+
[CURRENT_FRAME[:, :, ::-1], mask]
|
398 |
+
)
|
399 |
+
gr.Info(get_finsh_text(start_time))
|
400 |
+
yield ui_after()
|
401 |
+
|
402 |
+
|
403 |
+
## ------------------------------ GRADIO GUI ------------------------------
|
404 |
+
|
405 |
+
css = """
|
406 |
+
|
407 |
+
div.gradio-container{
|
408 |
+
max-width: unset !important;
|
409 |
+
}
|
410 |
+
|
411 |
+
footer{
|
412 |
+
display:none !important
|
413 |
+
}
|
414 |
+
|
415 |
+
#slider_row {
|
416 |
+
display: flex;
|
417 |
+
flex-wrap: wrap;
|
418 |
+
justify-content: space-between;
|
419 |
+
}
|
420 |
+
|
421 |
+
#refresh_slider {
|
422 |
+
flex: 0 1 20%;
|
423 |
+
display: flex;
|
424 |
+
align-items: center;
|
425 |
+
}
|
426 |
+
|
427 |
+
#frame_slider {
|
428 |
+
flex: 1 0 80%;
|
429 |
+
display: flex;
|
430 |
+
align-items: center;
|
431 |
+
}
|
432 |
+
|
433 |
+
"""
|
434 |
+
|
435 |
+
WIDGET_PREVIEW_HEIGHT = 450
|
436 |
+
|
437 |
+
with gr.Blocks(css=css, theme=gr.themes.Default()) as interface:
|
438 |
+
gr.Markdown("# πΏ Swap Mukham")
|
439 |
+
gr.Markdown("### Single image face swapper")
|
440 |
+
with gr.Row():
|
441 |
+
with gr.Row():
|
442 |
+
with gr.Column(scale=0.35):
|
443 |
+
with gr.Tabs():
|
444 |
+
with gr.TabItem("π Input"):
|
445 |
+
swap_condition = gr.Dropdown(
|
446 |
+
gv.FACE_DETECT_CONDITIONS,
|
447 |
+
info="Choose which face or faces in the target image to swap.",
|
448 |
+
multiselect=False,
|
449 |
+
show_label=False,
|
450 |
+
value=gv.FACE_DETECT_CONDITIONS[0],
|
451 |
+
interactive=True,
|
452 |
+
)
|
453 |
+
age = gr.Number(
|
454 |
+
value=25, label="Value", interactive=True, visible=False
|
455 |
+
)
|
456 |
+
|
457 |
+
## ------------------------------ SOURCE IMAGE ------------------------------
|
458 |
+
|
459 |
+
source_image_input = gr.Files(
|
460 |
+
label="Source face", type="file", interactive=True,
|
461 |
+
)
|
462 |
+
|
463 |
+
## ------------------------------ SOURCE SPECIFIC ------------------------------
|
464 |
+
|
465 |
+
with gr.Box(visible=False) as specific_face:
|
466 |
+
for i in range(gv.NUM_OF_SRC_SPECIFIC):
|
467 |
+
idx = i + 1
|
468 |
+
code = "\n"
|
469 |
+
code += f"with gr.Tab(label='{idx}'):"
|
470 |
+
code += "\n\twith gr.Row():"
|
471 |
+
code += f"\n\t\tsrc{idx} = gr.Files(interactive=True, type='file', label='Source Face {idx}')"
|
472 |
+
code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
|
473 |
+
exec(code)
|
474 |
+
|
475 |
+
## ------------------------------ TARGET TYPE ------------------------------
|
476 |
+
|
477 |
+
with gr.Group():
|
478 |
+
target_type = gr.Radio(
|
479 |
+
["Image", "Video", "Directory"],
|
480 |
+
label="Target Type",
|
481 |
+
value="Video",
|
482 |
+
)
|
483 |
+
|
484 |
+
## ------------------------------ TARGET IMAGE ------------------------------
|
485 |
+
|
486 |
+
with gr.Box(visible=False) as input_image_group:
|
487 |
+
target_image_input = gr.Image(
|
488 |
+
label="Target Image",
|
489 |
+
interactive=True,
|
490 |
+
type="filepath",
|
491 |
+
height=200
|
492 |
+
)
|
493 |
+
|
494 |
+
## ------------------------------ TARGET VIDEO ------------------------------
|
495 |
+
|
496 |
+
with gr.Box(visible=True) as input_video_group:
|
497 |
+
with gr.Column():
|
498 |
+
video_widget = gr.Text if PREFER_TEXT_WIDGET else gr.Video
|
499 |
+
video_input = video_widget(
|
500 |
+
label="Target Video", interactive=True,
|
501 |
+
)
|
502 |
+
|
503 |
+
## ------------------------------ FRAME SELECTION ------------------------------
|
504 |
+
|
505 |
+
with gr.Accordion("Frame Selection", open=False):
|
506 |
+
use_frame_selection = gr.Checkbox(
|
507 |
+
label="Use frame selection", value=False, interactive=True,
|
508 |
+
)
|
509 |
+
frame_selection_ranges = gr.Numpy(
|
510 |
+
headers=["Start Frame", "End Frame"],
|
511 |
+
datatype=["number", "number"],
|
512 |
+
row_count=1,
|
513 |
+
col_count=(2, "fixed"),
|
514 |
+
interactive=True
|
515 |
+
)
|
516 |
+
|
517 |
+
## ------------------------------ TARGET DIRECTORY ------------------------------
|
518 |
+
|
519 |
+
with gr.Box(visible=False) as input_directory_group:
|
520 |
+
directory_input = gr.Text(
|
521 |
+
label="Target Image Directory", interactive=True
|
522 |
+
)
|
523 |
+
|
524 |
+
## ------------------------------ TAB MODEL ------------------------------
|
525 |
+
|
526 |
+
with gr.TabItem("ποΈ Model"):
|
527 |
+
with gr.Accordion("Detection", open=False):
|
528 |
+
face_detection_condition = gr.Dropdown(
|
529 |
+
gv.SINGLE_FACE_DETECT_CONDITIONS,
|
530 |
+
label="Condition",
|
531 |
+
value=gv.DETECT_CONDITION,
|
532 |
+
interactive=True,
|
533 |
+
info="This condition is only used when multiple faces are detected on source or specific image.",
|
534 |
+
)
|
535 |
+
face_detection_size = gr.Number(
|
536 |
+
label="Detection Size",
|
537 |
+
value=gv.DETECT_SIZE,
|
538 |
+
interactive=True,
|
539 |
+
)
|
540 |
+
face_detection_threshold = gr.Number(
|
541 |
+
label="Detection Threshold",
|
542 |
+
value=gv.DETECT_THRESHOLD,
|
543 |
+
interactive=True,
|
544 |
+
)
|
545 |
+
face_scale = gr.Slider(
|
546 |
+
label="Landmark Scale",
|
547 |
+
minimum=0,
|
548 |
+
maximum=2,
|
549 |
+
value=1,
|
550 |
+
interactive=True,
|
551 |
+
)
|
552 |
+
with gr.Accordion("Embedding/Recognition", open=True):
|
553 |
+
averaging_method = gr.Dropdown(
|
554 |
+
gv.AVERAGING_METHODS,
|
555 |
+
label="Averaging Method",
|
556 |
+
value=gv.AVERAGING_METHOD,
|
557 |
+
interactive=True,
|
558 |
+
)
|
559 |
+
distance_slider = gr.Slider(
|
560 |
+
minimum=0,
|
561 |
+
maximum=2,
|
562 |
+
value=0.65,
|
563 |
+
interactive=True,
|
564 |
+
label="Specific-Target Distance",
|
565 |
+
)
|
566 |
+
with gr.Accordion("Swapper", open=True):
|
567 |
+
with gr.Row():
|
568 |
+
swap_iteration = gr.Slider(
|
569 |
+
label="Swap Iteration",
|
570 |
+
minimum=1,
|
571 |
+
maximum=4,
|
572 |
+
value=1,
|
573 |
+
step=1,
|
574 |
+
interactive=True,
|
575 |
+
)
|
576 |
+
|
577 |
+
## ------------------------------ TAB POST-PROCESS ------------------------------
|
578 |
+
|
579 |
+
with gr.TabItem("πͺ Post-Process"):
|
580 |
+
with gr.Row():
|
581 |
+
face_enhancer_name = gr.Dropdown(
|
582 |
+
gv.FACE_ENHANCER_LIST,
|
583 |
+
label="Face Enhancer",
|
584 |
+
value="NONE",
|
585 |
+
multiselect=False,
|
586 |
+
interactive=True,
|
587 |
+
)
|
588 |
+
face_upscaler_opacity = gr.Slider(
|
589 |
+
label="Opacity",
|
590 |
+
minimum=0,
|
591 |
+
maximum=1,
|
592 |
+
value=1,
|
593 |
+
step=0.001,
|
594 |
+
interactive=True,
|
595 |
+
)
|
596 |
+
|
597 |
+
with gr.Accordion("Face Mask", open=False):
|
598 |
+
with gr.Group():
|
599 |
+
with gr.Row():
|
600 |
+
use_face_parsing_mask = gr.Checkbox(
|
601 |
+
label="Enable Face Parsing",
|
602 |
+
value=False,
|
603 |
+
interactive=True,
|
604 |
+
)
|
605 |
+
parse_from_target = gr.Checkbox(
|
606 |
+
label="Parse from target",
|
607 |
+
value=False,
|
608 |
+
interactive=True,
|
609 |
+
)
|
610 |
+
mask_regions = gr.Dropdown(
|
611 |
+
gv.MASK_REGIONS,
|
612 |
+
value=gv.MASK_REGIONS_DEFAULT,
|
613 |
+
multiselect=True,
|
614 |
+
label="Include",
|
615 |
+
interactive=True,
|
616 |
+
)
|
617 |
+
|
618 |
+
with gr.Accordion("Crop Face Bounding-Box", open=False):
|
619 |
+
with gr.Group():
|
620 |
+
with gr.Row():
|
621 |
+
crop_top = gr.Slider(
|
622 |
+
label="Top",
|
623 |
+
minimum=0,
|
624 |
+
maximum=511,
|
625 |
+
value=0,
|
626 |
+
step=1,
|
627 |
+
interactive=True,
|
628 |
+
)
|
629 |
+
crop_bott = gr.Slider(
|
630 |
+
label="Bottom",
|
631 |
+
minimum=0,
|
632 |
+
maximum=511,
|
633 |
+
value=511,
|
634 |
+
step=1,
|
635 |
+
interactive=True,
|
636 |
+
)
|
637 |
+
with gr.Row():
|
638 |
+
crop_left = gr.Slider(
|
639 |
+
label="Left",
|
640 |
+
minimum=0,
|
641 |
+
maximum=511,
|
642 |
+
value=0,
|
643 |
+
step=1,
|
644 |
+
interactive=True,
|
645 |
+
)
|
646 |
+
crop_right = gr.Slider(
|
647 |
+
label="Right",
|
648 |
+
minimum=0,
|
649 |
+
maximum=511,
|
650 |
+
value=511,
|
651 |
+
step=1,
|
652 |
+
interactive=True,
|
653 |
+
)
|
654 |
+
|
655 |
+
with gr.Row():
|
656 |
+
mask_erode_amount = gr.Slider(
|
657 |
+
label="Mask Erode",
|
658 |
+
minimum=0,
|
659 |
+
maximum=1,
|
660 |
+
value=gv.MASK_ERODE_AMOUNT,
|
661 |
+
step=0.001,
|
662 |
+
interactive=True,
|
663 |
+
)
|
664 |
+
|
665 |
+
mask_blur_amount = gr.Slider(
|
666 |
+
label="Mask Blur",
|
667 |
+
minimum=0,
|
668 |
+
maximum=1,
|
669 |
+
value=gv.MASK_BLUR_AMOUNT,
|
670 |
+
step=0.001,
|
671 |
+
interactive=True,
|
672 |
+
)
|
673 |
+
|
674 |
+
use_laplacian_blending = gr.Checkbox(
|
675 |
+
label="Laplacian Blending",
|
676 |
+
value=True,
|
677 |
+
interactive=True,
|
678 |
+
)
|
679 |
+
|
680 |
+
## ------------------------------ TAB OUTPUT ------------------------------
|
681 |
+
|
682 |
+
with gr.TabItem("π€ Output"):
|
683 |
+
output_directory = gr.Text(
|
684 |
+
label="Output Directory",
|
685 |
+
value=gv.DEFAULT_OUTPUT_PATH,
|
686 |
+
interactive=True,
|
687 |
+
)
|
688 |
+
with gr.Group():
|
689 |
+
output_name = gr.Text(
|
690 |
+
label="Output Name", value="Result", interactive=True
|
691 |
+
)
|
692 |
+
use_datetime_suffix = gr.Checkbox(
|
693 |
+
label="Suffix date-time", value=True, interactive=True
|
694 |
+
)
|
695 |
+
with gr.Accordion("Video settings", open=True):
|
696 |
+
with gr.Row():
|
697 |
+
sequence_output_format = gr.Dropdown(
|
698 |
+
["jpg", "png"],
|
699 |
+
label="Sequence format",
|
700 |
+
value="jpg",
|
701 |
+
interactive=True,
|
702 |
+
)
|
703 |
+
video_quality = gr.Dropdown(
|
704 |
+
gv.VIDEO_QUALITY_LIST,
|
705 |
+
label="Quality",
|
706 |
+
value=gv.VIDEO_QUALITY,
|
707 |
+
interactive=True
|
708 |
+
)
|
709 |
+
keep_output_sequence = gr.Checkbox(
|
710 |
+
label="Keep output sequence", value=False, interactive=True
|
711 |
+
)
|
712 |
+
|
713 |
+
## ------------------------------ TAB PERFORMANCE ------------------------------
|
714 |
+
with gr.TabItem("π οΈ Performance"):
|
715 |
+
preview_resolution = gr.Dropdown(
|
716 |
+
gv.RESOLUTIONS,
|
717 |
+
label="Preview Resolution",
|
718 |
+
value="Original",
|
719 |
+
interactive=True,
|
720 |
+
)
|
721 |
+
number_of_threads = gr.Number(
|
722 |
+
step=1,
|
723 |
+
interactive=True,
|
724 |
+
label="Max number of threads",
|
725 |
+
value=gv.MAX_THREADS,
|
726 |
+
minimum=1,
|
727 |
+
)
|
728 |
+
with gr.Box():
|
729 |
+
with gr.Column():
|
730 |
+
with gr.Row():
|
731 |
+
face_analyser_device = gr.Radio(
|
732 |
+
DEVICE_LIST,
|
733 |
+
label="Face detection & recognition",
|
734 |
+
value=DEVICE,
|
735 |
+
interactive=True,
|
736 |
+
)
|
737 |
+
face_analyser_device_submit = gr.Button("Apply")
|
738 |
+
with gr.Row():
|
739 |
+
face_swapper_device = gr.Radio(
|
740 |
+
DEVICE_LIST,
|
741 |
+
label="Face swapper",
|
742 |
+
value=DEVICE,
|
743 |
+
interactive=True,
|
744 |
+
)
|
745 |
+
face_swapper_device_submit = gr.Button("Apply")
|
746 |
+
with gr.Row():
|
747 |
+
face_parser_device = gr.Radio(
|
748 |
+
DEVICE_LIST,
|
749 |
+
label="Face parsing",
|
750 |
+
value=DEVICE,
|
751 |
+
interactive=True,
|
752 |
+
)
|
753 |
+
face_parser_device_submit = gr.Button("Apply")
|
754 |
+
with gr.Row():
|
755 |
+
face_upscaler_device = gr.Radio(
|
756 |
+
DEVICE_LIST,
|
757 |
+
label="Face upscaler",
|
758 |
+
value=DEVICE,
|
759 |
+
interactive=True,
|
760 |
+
)
|
761 |
+
face_upscaler_device_submit = gr.Button("Apply")
|
762 |
+
|
763 |
+
face_analyser_device_submit.click(
|
764 |
+
fn=lambda d: SWAP_MUKHAM.load_face_analyser(
|
765 |
+
device=d
|
766 |
+
),
|
767 |
+
inputs=[face_analyser_device],
|
768 |
+
)
|
769 |
+
face_swapper_device_submit.click(
|
770 |
+
fn=lambda d: SWAP_MUKHAM.load_face_swapper(
|
771 |
+
device=d
|
772 |
+
),
|
773 |
+
inputs=[face_swapper_device],
|
774 |
+
)
|
775 |
+
face_parser_device_submit.click(
|
776 |
+
fn=lambda d: SWAP_MUKHAM.load_face_parser(device=d),
|
777 |
+
inputs=[face_parser_device],
|
778 |
+
)
|
779 |
+
face_upscaler_device_submit.click(
|
780 |
+
fn=lambda n, d: SWAP_MUKHAM.load_face_upscaler(
|
781 |
+
n, device=d
|
782 |
+
),
|
783 |
+
inputs=[face_enhancer_name, face_upscaler_device],
|
784 |
+
)
|
785 |
+
|
786 |
+
## ------------------------------ SWAP, CANCEL, FRAME SLIDER ------------------------------
|
787 |
+
|
788 |
+
with gr.Column(scale=0.65):
|
789 |
+
with gr.Row():
|
790 |
+
swap_button = gr.Button("β¨ Swap", variant="primary")
|
791 |
+
cancel_button = gr.Button("β Cancel")
|
792 |
+
collect_faces = gr.Button("π¨ Collect Faces")
|
793 |
+
test_swap = gr.Button("π§ͺ Test Swap")
|
794 |
+
|
795 |
+
with gr.Box() as frame_slider_box:
|
796 |
+
with gr.Row(elem_id="slider_row", equal_height=True):
|
797 |
+
set_slider_range_btn = gr.Button(
|
798 |
+
"Set Range", interactive=True, elem_id="refresh_slider"
|
799 |
+
)
|
800 |
+
frame_slider = gr.Slider(
|
801 |
+
label="Frame",
|
802 |
+
minimum=0,
|
803 |
+
maximum=1,
|
804 |
+
value=0,
|
805 |
+
step=1,
|
806 |
+
interactive=True,
|
807 |
+
elem_id="frame_slider",
|
808 |
+
)
|
809 |
+
|
810 |
+
## ------------------------------ PREVIEW ------------------------------
|
811 |
+
|
812 |
+
with gr.Tabs():
|
813 |
+
with gr.TabItem("Preview"):
|
814 |
+
|
815 |
+
preview_image = gr.Image(
|
816 |
+
label="Preview", type="numpy", interactive=False, height=WIDGET_PREVIEW_HEIGHT,
|
817 |
+
)
|
818 |
+
|
819 |
+
preview_video = gr.Video(
|
820 |
+
label="Output", interactive=False, visible=False, height=WIDGET_PREVIEW_HEIGHT,
|
821 |
+
)
|
822 |
+
preview_enabled_text = gr.Markdown(
|
823 |
+
"Disable paint foreground to preview !", visible=False
|
824 |
+
)
|
825 |
+
with gr.Row():
|
826 |
+
output_directory_button = gr.Button(
|
827 |
+
"π", interactive=False, visible=not gv.USE_COLAB
|
828 |
+
)
|
829 |
+
output_video_button = gr.Button(
|
830 |
+
"π¬", interactive=False, visible=not gv.USE_COLAB
|
831 |
+
)
|
832 |
+
|
833 |
+
output_directory_button.click(
|
834 |
+
lambda: open_directory(path=WORKSPACE),
|
835 |
+
inputs=None,
|
836 |
+
outputs=None,
|
837 |
+
)
|
838 |
+
output_video_button.click(
|
839 |
+
lambda: open_directory(path=OUTPUT_FILE),
|
840 |
+
inputs=None,
|
841 |
+
outputs=None,
|
842 |
+
)
|
843 |
+
|
844 |
+
## ------------------------------ FOREGROUND MASK ------------------------------
|
845 |
+
|
846 |
+
with gr.TabItem("Paint Foreground"):
|
847 |
+
with gr.Box() as fg_mask_group:
|
848 |
+
with gr.Row():
|
849 |
+
with gr.Row():
|
850 |
+
use_foreground_mask = gr.Checkbox(
|
851 |
+
label="Use foreground mask", value=False, interactive=True)
|
852 |
+
fg_mask_softness = gr.Slider(
|
853 |
+
label="Mask Softness",
|
854 |
+
minimum=0,
|
855 |
+
maximum=200,
|
856 |
+
value=1,
|
857 |
+
step=1,
|
858 |
+
interactive=True,
|
859 |
+
)
|
860 |
+
add_fg_mask_btn = gr.Button("Add", interactive=True)
|
861 |
+
del_fg_mask_btn = gr.Button("Del", interactive=True)
|
862 |
+
img_fg_mask = gr.Image(
|
863 |
+
label="Paint Mask",
|
864 |
+
tool="sketch",
|
865 |
+
interactive=True,
|
866 |
+
type="numpy",
|
867 |
+
height=WIDGET_PREVIEW_HEIGHT,
|
868 |
+
)
|
869 |
+
|
870 |
+
## ------------------------------ COLLECT FACE ------------------------------
|
871 |
+
|
872 |
+
with gr.TabItem("Collected Faces"):
|
873 |
+
collected_faces = gr.Gallery(
|
874 |
+
label="Faces",
|
875 |
+
show_label=False,
|
876 |
+
elem_id="gallery",
|
877 |
+
columns=[6], rows=[6], object_fit="contain", height=WIDGET_PREVIEW_HEIGHT,
|
878 |
+
)
|
879 |
+
|
880 |
+
## ------------------------------ FOOTER LINKS ------------------------------
|
881 |
+
|
882 |
+
with gr.Row(variant='panel'):
|
883 |
+
gr.HTML(
|
884 |
+
"""
|
885 |
+
<div style="display: flex; flex-direction: row; justify-content: center;">
|
886 |
+
<h3 style="margin-right: 10px;"><a href="https://github.com/sponsors/harisreedhar" style="text-decoration: none;">π€ Sponsor</a></h3>
|
887 |
+
<h3 style="margin-right: 10px;"><a href="https://github.com/harisreedhar/Swap-Mukham" style="text-decoration: none;">π¨βπ» Source</a></h3>
|
888 |
+
<h3 style="margin-right: 10px;"><a href="https://github.com/harisreedhar/Swap-Mukham#disclaimer" style="text-decoration: none;">β οΈ Disclaimer</a></h3>
|
889 |
+
<h3 style="margin-right: 10px;"><a href="https://colab.research.google.com/github/harisreedhar/Swap-Mukham/blob/main/swap_mukham_colab.ipynb" style="text-decoration: none;">π Colab</a></h3>
|
890 |
+
<h3><a href="https://github.com/harisreedhar/Swap-Mukham#acknowledgements" style="text-decoration: none;">π€ Acknowledgements</a></h3>
|
891 |
+
</div>
|
892 |
+
"""
|
893 |
+
)
|
894 |
+
|
895 |
+
## ------------------------------ GRADIO EVENTS ------------------------------
|
896 |
+
|
897 |
+
def on_target_type_change(value):
|
898 |
+
visibility = {
|
899 |
+
"Image": (True, False, False, False, True, False, False, False),
|
900 |
+
"Video": (False, True, False, True, True, True, True, True),
|
901 |
+
"Directory": (False, False, True, False, False, False, False, False),
|
902 |
+
"Stream": (False, False, True, False, False, False, False, False),
|
903 |
+
}
|
904 |
+
return list(gr.update(visible=i) for i in visibility[value])
|
905 |
+
|
906 |
+
target_type.change(
|
907 |
+
on_target_type_change,
|
908 |
+
inputs=[target_type],
|
909 |
+
outputs=[
|
910 |
+
input_image_group,
|
911 |
+
input_video_group,
|
912 |
+
input_directory_group,
|
913 |
+
frame_slider_box,
|
914 |
+
fg_mask_group,
|
915 |
+
add_fg_mask_btn,
|
916 |
+
del_fg_mask_btn,
|
917 |
+
test_swap,
|
918 |
+
],
|
919 |
+
)
|
920 |
+
|
921 |
+
target_image_input.change(
|
922 |
+
lambda inp: gr.update(value=inp),
|
923 |
+
inputs=[target_image_input],
|
924 |
+
outputs=[img_fg_mask]
|
925 |
+
)
|
926 |
+
|
927 |
+
def on_swap_condition_change(value):
|
928 |
+
visibility = {
|
929 |
+
"age less than": (True, False, True),
|
930 |
+
"age greater than": (True, False, True),
|
931 |
+
"specific face": (False, True, False),
|
932 |
+
}
|
933 |
+
return tuple(
|
934 |
+
gr.update(visible=i) for i in visibility.get(value, (False, False, True))
|
935 |
+
)
|
936 |
+
|
937 |
+
swap_condition.change(
|
938 |
+
on_swap_condition_change,
|
939 |
+
inputs=[swap_condition],
|
940 |
+
outputs=[age, specific_face, source_image_input],
|
941 |
+
)
|
942 |
+
|
943 |
+
def on_set_slider_range(video_path):
|
944 |
+
if video_path is None or not os.path.exists(video_path):
|
945 |
+
gr.Info("Check video path")
|
946 |
+
else:
|
947 |
+
try:
|
948 |
+
cap = cv2.VideoCapture(video_path)
|
949 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
950 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
951 |
+
cap.release()
|
952 |
+
if total_frames > 0:
|
953 |
+
total_frames -= 1
|
954 |
+
return gr.Slider.update(
|
955 |
+
minimum=0, maximum=total_frames, value=0, interactive=True
|
956 |
+
)
|
957 |
+
gr.Info("Error fetching video")
|
958 |
+
except:
|
959 |
+
gr.Info("Error fetching video")
|
960 |
+
|
961 |
+
set_slider_range_event = set_slider_range_btn.click(
|
962 |
+
on_set_slider_range,
|
963 |
+
inputs=[video_input],
|
964 |
+
outputs=[frame_slider],
|
965 |
+
)
|
966 |
+
|
967 |
+
def update_preview(video_path, frame_index, use_foreground_mask, resolution):
|
968 |
+
if not os.path.exists(video_path):
|
969 |
+
yield gr.update(value=None), gr.update(value=None), gr.update(visible=False)
|
970 |
+
else:
|
971 |
+
frame = get_single_video_frame(video_path, frame_index)
|
972 |
+
if frame is not None:
|
973 |
+
if use_foreground_mask:
|
974 |
+
overlayed_image = frame
|
975 |
+
if frame_index in FOREGROUND_MASK_DICT.keys():
|
976 |
+
mask = FOREGROUND_MASK_DICT.get(frame_index, None)
|
977 |
+
if mask is not None:
|
978 |
+
overlayed_image = image_mask_overlay(frame, mask)
|
979 |
+
yield gr.update(value=None), gr.update(value=None), gr.update(visible=False) # clear previous mask
|
980 |
+
frame = resize_image_by_resolution(frame, resolution)
|
981 |
+
yield gr.update(value=frame[:, :, ::-1]), gr.update(
|
982 |
+
value=overlayed_image[:, :, ::-1], visible=True
|
983 |
+
), gr.update(visible=False)
|
984 |
+
else:
|
985 |
+
frame = resize_image_by_resolution(frame, resolution)
|
986 |
+
yield gr.update(value=frame[:, :, ::-1]), gr.update(value=None), gr.update(
|
987 |
+
visible=False
|
988 |
+
)
|
989 |
+
|
990 |
+
global CURRENT_FRAME
|
991 |
+
CURRENT_FRAME = frame
|
992 |
+
|
993 |
+
frame_slider_event = frame_slider.change(
|
994 |
+
fn=update_preview,
|
995 |
+
inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
|
996 |
+
outputs=[preview_image, img_fg_mask, preview_video],
|
997 |
+
show_progress=False,
|
998 |
+
)
|
999 |
+
|
1000 |
+
def add_foreground_mask(fg, frame_index, softness):
|
1001 |
+
if fg is not None:
|
1002 |
+
mask = fg.get("mask", None)
|
1003 |
+
if mask is not None:
|
1004 |
+
alpha_rgb = cv2.cvtColor(mask, cv2.COLOR_BGRA2RGB)
|
1005 |
+
alpha_rgb = cv2.blur(alpha_rgb, (softness, softness))
|
1006 |
+
FOREGROUND_MASK_DICT[frame_index] = alpha_rgb.astype("float32") / 255.0
|
1007 |
+
gr.Info(f"saved mask index {frame_index}")
|
1008 |
+
|
1009 |
+
add_foreground_mask_event = add_fg_mask_btn.click(
|
1010 |
+
fn=add_foreground_mask,
|
1011 |
+
inputs=[img_fg_mask, frame_slider, fg_mask_softness],
|
1012 |
+
).then(
|
1013 |
+
fn=update_preview,
|
1014 |
+
inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
|
1015 |
+
outputs=[preview_image, img_fg_mask, preview_video],
|
1016 |
+
show_progress=False,
|
1017 |
+
)
|
1018 |
+
|
1019 |
+
def delete_foreground_mask(frame_index):
|
1020 |
+
if frame_index in FOREGROUND_MASK_DICT.keys():
|
1021 |
+
FOREGROUND_MASK_DICT.pop(frame_index)
|
1022 |
+
gr.Info(f"Deleted mask index {frame_index}")
|
1023 |
+
|
1024 |
+
del_custom_mask_event = del_fg_mask_btn.click(
|
1025 |
+
fn=delete_foreground_mask, inputs=[frame_slider]
|
1026 |
+
).then(
|
1027 |
+
fn=update_preview,
|
1028 |
+
inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
|
1029 |
+
outputs=[preview_image, img_fg_mask, preview_video],
|
1030 |
+
show_progress=False,
|
1031 |
+
)
|
1032 |
+
|
1033 |
+
def get_collected_faces(image):
|
1034 |
+
if image is not None:
|
1035 |
+
gr.Info(f"Collecting faces...")
|
1036 |
+
faces = SWAP_MUKHAM.collect_heads(image)
|
1037 |
+
COLLECTED_FACES.extend(faces)
|
1038 |
+
yield COLLECTED_FACES
|
1039 |
+
gr.Info(f"Collected {len(faces)} faces")
|
1040 |
+
|
1041 |
+
collect_faces.click(get_collected_faces, inputs=[preview_image], outputs=[collected_faces])
|
1042 |
+
|
1043 |
+
src_specific_inputs = []
|
1044 |
+
gen_variable_txt = ",".join(
|
1045 |
+
[f"src{i+1}" for i in range(gv.NUM_OF_SRC_SPECIFIC)]
|
1046 |
+
+ [f"trg{i+1}" for i in range(gv.NUM_OF_SRC_SPECIFIC)]
|
1047 |
+
)
|
1048 |
+
exec(f"src_specific_inputs = ({gen_variable_txt})")
|
1049 |
+
|
1050 |
+
test_mode = gr.Checkbox(value=False, visible=False)
|
1051 |
+
|
1052 |
+
swap_inputs = [
|
1053 |
+
test_mode,
|
1054 |
+
target_type,
|
1055 |
+
target_image_input,
|
1056 |
+
video_input,
|
1057 |
+
directory_input,
|
1058 |
+
source_image_input,
|
1059 |
+
use_foreground_mask,
|
1060 |
+
img_fg_mask,
|
1061 |
+
fg_mask_softness,
|
1062 |
+
output_directory,
|
1063 |
+
output_name,
|
1064 |
+
use_datetime_suffix,
|
1065 |
+
sequence_output_format,
|
1066 |
+
keep_output_sequence,
|
1067 |
+
swap_condition,
|
1068 |
+
age,
|
1069 |
+
distance_slider,
|
1070 |
+
face_enhancer_name,
|
1071 |
+
face_upscaler_opacity,
|
1072 |
+
use_face_parsing_mask,
|
1073 |
+
parse_from_target,
|
1074 |
+
mask_regions,
|
1075 |
+
mask_blur_amount,
|
1076 |
+
mask_erode_amount,
|
1077 |
+
swap_iteration,
|
1078 |
+
face_scale,
|
1079 |
+
use_laplacian_blending,
|
1080 |
+
crop_top,
|
1081 |
+
crop_bott,
|
1082 |
+
crop_left,
|
1083 |
+
crop_right,
|
1084 |
+
frame_slider,
|
1085 |
+
number_of_threads,
|
1086 |
+
use_frame_selection,
|
1087 |
+
frame_selection_ranges,
|
1088 |
+
video_quality,
|
1089 |
+
face_detection_condition,
|
1090 |
+
face_detection_size,
|
1091 |
+
face_detection_threshold,
|
1092 |
+
averaging_method,
|
1093 |
+
*src_specific_inputs,
|
1094 |
+
]
|
1095 |
+
|
1096 |
+
swap_outputs = [
|
1097 |
+
preview_image,
|
1098 |
+
output_directory_button,
|
1099 |
+
output_video_button,
|
1100 |
+
preview_video,
|
1101 |
+
]
|
1102 |
+
|
1103 |
+
swap_event = swap_button.click(fn=process, inputs=swap_inputs, outputs=swap_outputs)
|
1104 |
+
|
1105 |
+
test_swap_settings = swap_inputs
|
1106 |
+
test_swap_settings[0] = gr.Checkbox(value=True, visible=False)
|
1107 |
+
|
1108 |
+
test_swap_event = test_swap.click(
|
1109 |
+
fn=update_preview,
|
1110 |
+
inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
|
1111 |
+
outputs=[preview_image, preview_video],
|
1112 |
+
show_progress=False,
|
1113 |
+
).then(
|
1114 |
+
fn=process, inputs=test_swap_settings, outputs=swap_outputs, show_progress=True
|
1115 |
+
)
|
1116 |
+
|
1117 |
+
def stop_running():
|
1118 |
+
global IS_RUNNING
|
1119 |
+
IS_RUNNING = False
|
1120 |
+
print("[ Process cancelled ]")
|
1121 |
+
gr.Info("Process cancelled")
|
1122 |
+
|
1123 |
+
cancel_button.click(
|
1124 |
+
fn=stop_running,
|
1125 |
+
inputs=None,
|
1126 |
+
cancels=[swap_event, set_slider_range_event, test_swap_event],
|
1127 |
+
show_progress=True,
|
1128 |
+
)
|
1129 |
+
|
1130 |
+
if __name__ == "__main__":
|
1131 |
+
if gv.USE_COLAB:
|
1132 |
+
print("Running in colab mode")
|
1133 |
+
|
1134 |
+
interface.queue(concurrency_count=2, max_size=20).launch(share=gv.USE_COLAB)
|
assets/images/loading.gif
ADDED
assets/images/logo.png
ADDED
assets/pretrained_models/readme.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
change_log.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Change-log
|
2 |
+
|
3 |
+
## 30/07/2023
|
4 |
+
- change existing nsfw filter to open-nsfw from yahoo
|
5 |
+
- Add codeformer support
|
default_paths.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
FFMPEG_PATH = "./ffmpeg/ffmpeg" if os.path.exists("./ffmpeg/ffmpeg") else None
|
4 |
+
|
5 |
+
INSWAPPER_PATH = "./assets/pretrained_models/inswapper_128.onnx"
|
6 |
+
FACE_PARSER_PATH = "./assets/pretrained_models/faceparser.onnx"
|
7 |
+
ARCFACE_PATH = "./assets/pretrained_models/w600k_r50.onnx"
|
8 |
+
RETINAFACE_PATH = "./assets/pretrained_models/det_10g.onnx"
|
9 |
+
OPEN_NSFW_PATH = "./assets/pretrained_models/open-nsfw.onnx"
|
10 |
+
GENDERAGE_PATH = "./assets/pretrained_models/gender_age.onnx"
|
11 |
+
|
12 |
+
CODEFORMER_PATH = "./assets/pretrained_models/codeformer.onnx"
|
13 |
+
GFPGAN_V14_PATH = "./assets/pretrained_models/GFPGANv1.4.onnx"
|
14 |
+
GFPGAN_V13_PATH = "./assets/pretrained_models/GFPGANv1.3.onnx"
|
15 |
+
GFPGAN_V12_PATH = "./assets/pretrained_models/GFPGANv1.2.onnx"
|
16 |
+
GPEN_BFR_512_PATH = "./assets/pretrained_models/GPEN-BFR-512.onnx"
|
17 |
+
GPEN_BFR_256_PATH = "./assets/pretrained_models/GPEN-BFR-256.onnx"
|
18 |
+
RESTOREFORMER_PATH = "./assets/pretrained_models/restoreformer.onnx"
|
face_analyser.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import threading
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
import concurrent.futures
|
7 |
+
import default_paths as dp
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from utils.arcface import ArcFace
|
10 |
+
from utils.gender_age import GenderAge
|
11 |
+
from utils.retinaface import RetinaFace
|
12 |
+
|
13 |
+
cache = {}
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class Face:
|
17 |
+
bbox: np.ndarray
|
18 |
+
kps: np.ndarray
|
19 |
+
det_score: float
|
20 |
+
embedding: np.ndarray
|
21 |
+
gender: int
|
22 |
+
age: int
|
23 |
+
|
24 |
+
def __getitem__(self, key):
|
25 |
+
return getattr(self, key)
|
26 |
+
|
27 |
+
def __setitem__(self, key, value):
|
28 |
+
if hasattr(self, key):
|
29 |
+
setattr(self, key, value)
|
30 |
+
else:
|
31 |
+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
|
32 |
+
|
33 |
+
single_face_detect_conditions = [
|
34 |
+
"best detection",
|
35 |
+
"left most",
|
36 |
+
"right most",
|
37 |
+
"top most",
|
38 |
+
"bottom most",
|
39 |
+
"middle",
|
40 |
+
"biggest",
|
41 |
+
"smallest",
|
42 |
+
]
|
43 |
+
|
44 |
+
multi_face_detect_conditions = [
|
45 |
+
"all face",
|
46 |
+
"specific face",
|
47 |
+
"age less than",
|
48 |
+
"age greater than",
|
49 |
+
"all male",
|
50 |
+
"all female"
|
51 |
+
]
|
52 |
+
|
53 |
+
face_detect_conditions = multi_face_detect_conditions + single_face_detect_conditions
|
54 |
+
|
55 |
+
|
56 |
+
def get_single_face(faces, method="best detection"):
|
57 |
+
total_faces = len(faces)
|
58 |
+
|
59 |
+
if total_faces == 0:
|
60 |
+
return None
|
61 |
+
|
62 |
+
if total_faces == 1:
|
63 |
+
return faces[0]
|
64 |
+
|
65 |
+
if method == "best detection":
|
66 |
+
return sorted(faces, key=lambda face: face["det_score"])[-1]
|
67 |
+
elif method == "left most":
|
68 |
+
return sorted(faces, key=lambda face: face["bbox"][0])[0]
|
69 |
+
elif method == "right most":
|
70 |
+
return sorted(faces, key=lambda face: face["bbox"][0])[-1]
|
71 |
+
elif method == "top most":
|
72 |
+
return sorted(faces, key=lambda face: face["bbox"][1])[0]
|
73 |
+
elif method == "bottom most":
|
74 |
+
return sorted(faces, key=lambda face: face["bbox"][1])[-1]
|
75 |
+
elif method == "middle":
|
76 |
+
return sorted(faces, key=lambda face: (
|
77 |
+
(face["bbox"][0] + face["bbox"][2]) / 2 - 0.5) ** 2 +
|
78 |
+
((face["bbox"][1] + face["bbox"][3]) / 2 - 0.5) ** 2)[len(faces) // 2]
|
79 |
+
elif method == "biggest":
|
80 |
+
return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[-1]
|
81 |
+
elif method == "smallest":
|
82 |
+
return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[0]
|
83 |
+
|
84 |
+
def filter_face_by_age(faces, age, method="age less than"):
|
85 |
+
if method == "age less than":
|
86 |
+
return [face for face in faces if face["age"] < age]
|
87 |
+
elif method == "age greater than":
|
88 |
+
return [face for face in faces if face["age"] > age]
|
89 |
+
elif method == "age equals to":
|
90 |
+
return [face for face in faces if face["age"] == age]
|
91 |
+
|
92 |
+
def cosine_distance(a, b):
|
93 |
+
a /= np.linalg.norm(a)
|
94 |
+
b /= np.linalg.norm(b)
|
95 |
+
return 1 - np.dot(a, b)
|
96 |
+
|
97 |
+
def is_similar_face(face1, face2, threshold=0.6):
|
98 |
+
distance = cosine_distance(face1["embedding"], face2["embedding"])
|
99 |
+
return distance < threshold
|
100 |
+
|
101 |
+
|
102 |
+
class AnalyseFace:
|
103 |
+
def __init__(self, provider=["CPUExecutionProvider"], session_options=None):
|
104 |
+
self.detector = RetinaFace(model_file=dp.RETINAFACE_PATH, provider=provider, session_options=session_options)
|
105 |
+
self.recognizer = ArcFace(model_file=dp.ARCFACE_PATH, provider=provider, session_options=session_options)
|
106 |
+
self.gender_age = GenderAge(model_file=dp.GENDERAGE_PATH, provider=provider, session_options=session_options)
|
107 |
+
self.detect_condition = "best detection"
|
108 |
+
self.detection_size = (640, 640)
|
109 |
+
self.detection_threshold = 0.5
|
110 |
+
|
111 |
+
def analyser(self, img, skip_task=[]):
|
112 |
+
bboxes, kpss = self.detector.detect(img, input_size=self.detection_size, det_thresh=self.detection_threshold)
|
113 |
+
faces = []
|
114 |
+
for i in range(bboxes.shape[0]):
|
115 |
+
feat, gender, age = None, None, None
|
116 |
+
bbox = bboxes[i, 0:4]
|
117 |
+
det_score = bboxes[i, 4]
|
118 |
+
kps = None
|
119 |
+
if kpss is not None:
|
120 |
+
kps = kpss[i]
|
121 |
+
if 'embedding' not in skip_task:
|
122 |
+
feat = self.recognizer.get(img, kpss[i])
|
123 |
+
if 'gender_age' not in skip_task:
|
124 |
+
gender, age = self.gender_age.predict(img, kpss[i])
|
125 |
+
face = Face(bbox=bbox, kps=kps, det_score=det_score, embedding=feat, gender=gender, age=age)
|
126 |
+
faces.append(face)
|
127 |
+
return faces
|
128 |
+
|
129 |
+
def get_faces(self, image, scale=1., skip_task=[]):
|
130 |
+
if isinstance(image, str):
|
131 |
+
image = cv2.imread(image)
|
132 |
+
|
133 |
+
faces = self.analyser(image, skip_task=skip_task)
|
134 |
+
|
135 |
+
if scale != 1: # landmark-scale
|
136 |
+
for i, face in enumerate(faces):
|
137 |
+
landmark = face['kps']
|
138 |
+
center = np.mean(landmark, axis=0)
|
139 |
+
landmark = center + (landmark - center) * scale
|
140 |
+
faces[i]['kps'] = landmark
|
141 |
+
|
142 |
+
return faces
|
143 |
+
|
144 |
+
def get_face(self, image, scale=1., skip_task=[]):
|
145 |
+
faces = self.get_faces(image, scale=scale, skip_task=skip_task)
|
146 |
+
return get_single_face(faces, method=self.detect_condition)
|
147 |
+
|
148 |
+
def get_averaged_face(self, images, method="mean"):
|
149 |
+
if not isinstance(images, list):
|
150 |
+
images = [images]
|
151 |
+
|
152 |
+
face = self.get_face(images[0], scale=1., skip_task=[])
|
153 |
+
|
154 |
+
if len(images) > 1:
|
155 |
+
embeddings = [face['embedding']]
|
156 |
+
|
157 |
+
for image in images[1:]:
|
158 |
+
face = self.get_face(image, scale=1., skip_task=[])
|
159 |
+
embeddings.append(face['embedding'])
|
160 |
+
|
161 |
+
if method == "mean":
|
162 |
+
avg_embedding = np.mean(embeddings, axis=0)
|
163 |
+
elif method == "median":
|
164 |
+
avg_embedding = np.median(embeddings, axis=0)
|
165 |
+
|
166 |
+
face['embedding'] = avg_embedding
|
167 |
+
|
168 |
+
return face
|
face_parsing.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import onnxruntime
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
mask_regions = {
|
6 |
+
"Background":0,
|
7 |
+
"Skin":1,
|
8 |
+
"L-Eyebrow":2,
|
9 |
+
"R-Eyebrow":3,
|
10 |
+
"L-Eye":4,
|
11 |
+
"R-Eye":5,
|
12 |
+
"Eye-G":6,
|
13 |
+
"L-Ear":7,
|
14 |
+
"R-Ear":8,
|
15 |
+
"Ear-R":9,
|
16 |
+
"Nose":10,
|
17 |
+
"Mouth":11,
|
18 |
+
"U-Lip":12,
|
19 |
+
"L-Lip":13,
|
20 |
+
"Neck":14,
|
21 |
+
"Neck-L":15,
|
22 |
+
"Cloth":16,
|
23 |
+
"Hair":17,
|
24 |
+
"Hat":18
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
class FaceParser:
|
29 |
+
def __init__(self, model_path=None, provider=['CPUExecutionProvider'], session_options=None):
|
30 |
+
self.session_options = session_options
|
31 |
+
if self.session_options is None:
|
32 |
+
self.session_options = onnxruntime.SessionOptions()
|
33 |
+
self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
|
34 |
+
self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
|
35 |
+
self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
|
36 |
+
|
37 |
+
def parse(self, img, regions=[1,2,3,4,5,10,11,12,13]):
|
38 |
+
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
|
39 |
+
img = img.astype(np.float32)[:,:,::-1] / 255.0
|
40 |
+
img = (img - self.mean) / self.std
|
41 |
+
img = np.expand_dims(img.transpose((2, 0, 1)), axis=0).astype(np.float32)
|
42 |
+
|
43 |
+
out = self.session.run(None, {'input':img})[0]
|
44 |
+
out = out.squeeze(0).argmax(0)
|
45 |
+
out = np.isin(out, regions).astype('float32')
|
46 |
+
|
47 |
+
return out.clip(0, 1)
|
48 |
+
|
49 |
+
|
50 |
+
def mask_regions_to_list(values):
|
51 |
+
out_ids = []
|
52 |
+
for value in values:
|
53 |
+
if value in mask_regions.keys():
|
54 |
+
out_ids.append(mask_regions.get(value))
|
55 |
+
return out_ids
|
face_swapper.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import onnx
|
3 |
+
import cv2
|
4 |
+
import onnxruntime
|
5 |
+
import numpy as np
|
6 |
+
from onnx import numpy_helper
|
7 |
+
from numpy.linalg import norm as l2norm
|
8 |
+
from utils.face_alignment import norm_crop2
|
9 |
+
|
10 |
+
|
11 |
+
class Inswapper():
|
12 |
+
def __init__(self, model_file=None, provider=['CPUExecutionProvider'], session_options=None):
|
13 |
+
self.model_file = model_file
|
14 |
+
model = onnx.load(self.model_file)
|
15 |
+
graph = model.graph
|
16 |
+
self.emap = numpy_helper.to_array(graph.initializer[-1])
|
17 |
+
|
18 |
+
self.session_options = session_options
|
19 |
+
if self.session_options is None:
|
20 |
+
self.session_options = onnxruntime.SessionOptions()
|
21 |
+
self.session = onnxruntime.InferenceSession(self.model_file, sess_options=self.session_options, providers=provider)
|
22 |
+
|
23 |
+
def forward(self, frame, target, source, n_pass=1):
|
24 |
+
trg, matrix = norm_crop2(frame, target['kps'], 128)
|
25 |
+
|
26 |
+
latent = source['embedding'].reshape((1, -1))
|
27 |
+
latent = np.dot(latent, self.emap)
|
28 |
+
latent /= np.linalg.norm(latent)
|
29 |
+
|
30 |
+
blob = trg.astype('float32') / 255
|
31 |
+
blob = blob[:, :, ::-1]
|
32 |
+
blob = np.expand_dims(blob, axis=0).transpose(0, 3, 1, 2)
|
33 |
+
|
34 |
+
for _ in range(max(int(n_pass),1)):
|
35 |
+
blob = self.session.run(['output'], {'target': blob, 'source': latent})[0]
|
36 |
+
|
37 |
+
out = blob[0].transpose((1, 2, 0))
|
38 |
+
out = (out * 255).clip(0,255)
|
39 |
+
out = out.astype('uint8')[:, :, ::-1]
|
40 |
+
|
41 |
+
del blob, latent
|
42 |
+
|
43 |
+
return trg, out, matrix
|
face_upscaler.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import default_paths as dp
|
4 |
+
from upscaler.GPEN import GPEN
|
5 |
+
from upscaler.GFPGAN import GFPGAN
|
6 |
+
from upscaler.codeformer import CodeFormer
|
7 |
+
from upscaler.restoreformer import RestoreFormer
|
8 |
+
|
9 |
+
def gfpgan_runner(img, model):
|
10 |
+
img = model.enhance(img)
|
11 |
+
return img
|
12 |
+
|
13 |
+
|
14 |
+
def codeformer_runner(img, model):
|
15 |
+
img = model.enhance(img, w=0.9)
|
16 |
+
return img
|
17 |
+
|
18 |
+
|
19 |
+
def gpen_runner(img, model):
|
20 |
+
img = model.enhance(img)
|
21 |
+
return img
|
22 |
+
|
23 |
+
|
24 |
+
def restoreformer_runner(img, model):
|
25 |
+
img = model.enhance(img)
|
26 |
+
return img
|
27 |
+
|
28 |
+
|
29 |
+
supported_upscalers = {
|
30 |
+
"CodeFormer": (dp.CODEFORMER_PATH, codeformer_runner),
|
31 |
+
"GFPGANv1.4": (dp.GFPGAN_V14_PATH, gfpgan_runner),
|
32 |
+
"GFPGANv1.3": (dp.GFPGAN_V13_PATH, gfpgan_runner),
|
33 |
+
"GFPGANv1.2": (dp.GFPGAN_V12_PATH, gfpgan_runner),
|
34 |
+
"GPEN-BFR-512": (dp.GPEN_BFR_512_PATH, gpen_runner),
|
35 |
+
"GPEN-BFR-256": (dp.GPEN_BFR_256_PATH, gpen_runner),
|
36 |
+
"RestoreFormer": (dp.RESTOREFORMER_PATH, gpen_runner),
|
37 |
+
}
|
38 |
+
|
39 |
+
cv2_upscalers = ["LANCZOS4", "CUBIC", "NEAREST"]
|
40 |
+
|
41 |
+
def get_available_upscalers_names():
|
42 |
+
available = []
|
43 |
+
for name, data in supported_upscalers.items():
|
44 |
+
if os.path.exists(data[0]):
|
45 |
+
available.append(name)
|
46 |
+
return available
|
47 |
+
|
48 |
+
|
49 |
+
def load_face_upscaler(name='GFPGAN', provider=["CPUExecutionProvider"], session_options=None):
|
50 |
+
assert name in get_available_upscalers_names() + cv2_upscalers, f"Face upscaler {name} unavailable."
|
51 |
+
if name in supported_upscalers.keys():
|
52 |
+
model_path, model_runner = supported_upscalers.get(name)
|
53 |
+
if name == 'CodeFormer':
|
54 |
+
model = CodeFormer(model_path=model_path, provider=provider, session_options=session_options)
|
55 |
+
elif name.startswith('GFPGAN'):
|
56 |
+
model = GFPGAN(model_path=model_path, provider=provider, session_options=session_options)
|
57 |
+
elif name.startswith('GPEN'):
|
58 |
+
model = GPEN(model_path=model_path, provider=provider, session_options=session_options)
|
59 |
+
elif name == "RestoreFormer":
|
60 |
+
model = RestoreFormer(model_path=model_path, provider=provider, session_options=session_options)
|
61 |
+
elif name == 'LANCZOS4':
|
62 |
+
model = None
|
63 |
+
model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_LANCZOS4)
|
64 |
+
elif name == 'CUBIC':
|
65 |
+
model = None
|
66 |
+
model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_CUBIC)
|
67 |
+
elif name == 'NEAREST':
|
68 |
+
model = None
|
69 |
+
model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_NEAREST)
|
70 |
+
else:
|
71 |
+
model = None
|
72 |
+
return (model, model_runner)
|
global_variables.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from face_parsing import mask_regions
|
3 |
+
from utils.image import resolution_map
|
4 |
+
from face_upscaler import get_available_upscalers_names, cv2_upscalers
|
5 |
+
from face_analyser import single_face_detect_conditions, face_detect_conditions
|
6 |
+
|
7 |
+
DEFAULT_OUTPUT_PATH = os.getcwd()
|
8 |
+
|
9 |
+
MASK_BLUR_AMOUNT = 0.1
|
10 |
+
MASK_ERODE_AMOUNT = 0.15
|
11 |
+
MASK_REGIONS_DEFAULT = ["Skin", "R-Eyebrow", "L-Eyebrow", "L-Eye", "R-Eye", "Nose", "Mouth", "L-Lip", "U-Lip"]
|
12 |
+
MASK_REGIONS = list(mask_regions.keys())
|
13 |
+
|
14 |
+
NSFW_DETECTOR = None
|
15 |
+
|
16 |
+
FACE_ENHANCER_LIST = ["NONE"]
|
17 |
+
FACE_ENHANCER_LIST.extend(get_available_upscalers_names())
|
18 |
+
FACE_ENHANCER_LIST.extend(cv2_upscalers)
|
19 |
+
|
20 |
+
RESOLUTIONS = list(resolution_map.keys())
|
21 |
+
|
22 |
+
SINGLE_FACE_DETECT_CONDITIONS = single_face_detect_conditions
|
23 |
+
FACE_DETECT_CONDITIONS = face_detect_conditions
|
24 |
+
DETECT_CONDITION = "best detection"
|
25 |
+
DETECT_SIZE = 640
|
26 |
+
DETECT_THRESHOLD = 0.6
|
27 |
+
|
28 |
+
NUM_OF_SRC_SPECIFIC = 10
|
29 |
+
|
30 |
+
MAX_THREADS = 2
|
31 |
+
|
32 |
+
VIDEO_QUALITY_LIST = ["poor", "low", "medium", "high", "best"]
|
33 |
+
VIDEO_QUALITY = "high"
|
34 |
+
|
35 |
+
AVERAGING_METHODS = ["mean", "median"]
|
36 |
+
AVERAGING_METHOD = "mean"
|
nsfw_checker/LICENSE.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Copyright 2016, Yahoo Inc.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
5 |
+
|
6 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
7 |
+
|
8 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
9 |
+
|
10 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
11 |
+
|
nsfw_checker/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . opennsfw import NSFWChecker
|
nsfw_checker/opennsfw.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import onnx
|
3 |
+
import onnxruntime
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
# https://github.com/yahoo/open_nsfw
|
8 |
+
|
9 |
+
def prepare_image(img):
|
10 |
+
img = cv2.resize(img, (224,224)).astype('float32')
|
11 |
+
img -= np.array([104, 117, 123], dtype=np.float32)
|
12 |
+
img = np.expand_dims(img, axis=0)
|
13 |
+
return img
|
14 |
+
|
15 |
+
class NSFWChecker:
|
16 |
+
def __init__(self, model_path=None, provider=["CPUExecutionProvider"], session_options=None):
|
17 |
+
model = onnx.load(model_path)
|
18 |
+
self.input_name = model.graph.input[0].name
|
19 |
+
self.session_options = session_options
|
20 |
+
if self.session_options == None:
|
21 |
+
self.session_options = onnxruntime.SessionOptions()
|
22 |
+
self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
|
23 |
+
|
24 |
+
def check_image(self, image, threshold=0.9):
|
25 |
+
if isinstance(image, str):
|
26 |
+
image = cv2.imread(image)
|
27 |
+
img = prepare_image(image)
|
28 |
+
score = self.session.run(None, {self.input_name:img})[0][0][1]
|
29 |
+
if score >= threshold:
|
30 |
+
return True
|
31 |
+
return False
|
32 |
+
|
33 |
+
def check_video(self, video_path, threshold=0.9, max_frames=100):
|
34 |
+
cap = cv2.VideoCapture(video_path)
|
35 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
36 |
+
|
37 |
+
max_frames = min(total_frames, max_frames)
|
38 |
+
indexes = np.arange(total_frames, dtype=int)
|
39 |
+
shuffled_indexes = np.random.permutation(indexes)[:max_frames]
|
40 |
+
|
41 |
+
for idx in tqdm(shuffled_indexes, desc="Checking"):
|
42 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
|
43 |
+
valid_frame, frame = cap.read()
|
44 |
+
if valid_frame:
|
45 |
+
img = prepare_image(frame)
|
46 |
+
score = self.session.run(None, {self.input_name:img})[0][0][1]
|
47 |
+
if score >= threshold:
|
48 |
+
cap.release()
|
49 |
+
return True
|
50 |
+
cap.release()
|
51 |
+
return False
|
52 |
+
|
53 |
+
def check_image_paths(self, image_paths, threshold=0.9, max_frames=100):
|
54 |
+
total_frames = len(image_paths)
|
55 |
+
max_frames = min(total_frames, max_frames)
|
56 |
+
indexes = np.arange(total_frames, dtype=int)
|
57 |
+
shuffled_indexes = np.random.permutation(indexes)[:max_frames]
|
58 |
+
|
59 |
+
for idx in tqdm(shuffled_indexes, desc="Checking"):
|
60 |
+
frame = cv2.imread(image_paths[idx])
|
61 |
+
img = prepare_image(frame)
|
62 |
+
score = self.session.run(None, {self.input_name:img})[0][0][1]
|
63 |
+
if score >= threshold:
|
64 |
+
return True
|
65 |
+
return False
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio>=3.40
|
2 |
+
numpy>=1.25.2
|
3 |
+
opencv-python>=4.7.0.72
|
4 |
+
opencv-python-headless>=4.7.0.72
|
5 |
+
onnx==1.14.0
|
6 |
+
onnxruntime==1.15.0
|
swap_mukham.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import default_paths as dp
|
5 |
+
from utils.device import get_device_and_provider
|
6 |
+
from utils.face_alignment import get_cropped_head
|
7 |
+
from utils.image import paste_to_whole, mix_two_image
|
8 |
+
|
9 |
+
from face_swapper import Inswapper
|
10 |
+
from face_parsing import FaceParser
|
11 |
+
from face_upscaler import get_available_upscalers_names, cv2_upscalers, load_face_upscaler
|
12 |
+
from face_analyser import AnalyseFace, single_face_detect_conditions, face_detect_conditions, get_single_face, is_similar_face
|
13 |
+
|
14 |
+
from nsfw_checker import NSFWChecker
|
15 |
+
|
16 |
+
get_device_name = lambda x: x.lower().replace("executionprovider", "")
|
17 |
+
|
18 |
+
class SwapMukham:
|
19 |
+
def __init__(self, device='cpu'):
|
20 |
+
self.load_nsfw_detector(device=device)
|
21 |
+
self.load_face_swapper(device=device)
|
22 |
+
self.load_face_analyser(device=device)
|
23 |
+
# self.load_face_parser(device=device)
|
24 |
+
# self.load_face_upscaler(device=device)
|
25 |
+
|
26 |
+
self.face_parser = None
|
27 |
+
self.face_upscaler = None
|
28 |
+
self.face_upscaler_name = ""
|
29 |
+
|
30 |
+
def set_values(self, args):
|
31 |
+
self.age = args.get('age', 0)
|
32 |
+
self.detect_condition = args.get('detect_condition', "left most")
|
33 |
+
self.similarity = args.get('similarity', 0.6)
|
34 |
+
self.swap_condition = args.get('swap_condition', 'left most')
|
35 |
+
self.face_scale = args.get('face_scale', 1.0)
|
36 |
+
self.num_of_pass = args.get('num_of_pass', 1)
|
37 |
+
self.mask_crop_values = args.get('mask_crop_values', (0,0,0,0))
|
38 |
+
self.mask_erode_amount = args.get('mask_erode_amount', 0.1)
|
39 |
+
self.mask_blur_amount = args.get('mask_blur_amount', 0.1)
|
40 |
+
self.use_laplacian_blending = args.get('use_laplacian_blending', False)
|
41 |
+
self.use_face_parsing = args.get('use_face_parsing', False)
|
42 |
+
self.face_parse_regions = args.get('face_parse_regions', [1,2,3,4,5,10,11,12,13])
|
43 |
+
self.face_upscaler_opacity = args.get('face_upscaler_opacity', 1.)
|
44 |
+
self.parse_from_target = args.get('parse_from_target', False)
|
45 |
+
self.averaging_method = args.get('averaging_method', 'mean')
|
46 |
+
|
47 |
+
self.analyser.detection_threshold = args.get('face_detection_threshold', 0.5)
|
48 |
+
self.analyser.detection_size = args.get('face_detection_size', (640, 640))
|
49 |
+
self.analyser.detect_condition = args.get('face_detection_condition', 'best detection')
|
50 |
+
|
51 |
+
def load_nsfw_detector(self, device='cpu'):
|
52 |
+
device, provider, options = get_device_and_provider(device=device)
|
53 |
+
self.nsfw_detector = NSFWChecker(model_path=dp.OPEN_NSFW_PATH, provider=provider, session_options=options)
|
54 |
+
_device = get_device_name(self.nsfw_detector.session.get_providers()[0])
|
55 |
+
print(f"[{_device}] NSFW detector model loaded.")
|
56 |
+
|
57 |
+
def load_face_swapper(self, device='cpu'):
|
58 |
+
device, provider, options = get_device_and_provider(device=device)
|
59 |
+
self.swapper = Inswapper(model_file=dp.INSWAPPER_PATH, provider=provider, session_options=options)
|
60 |
+
_device = get_device_name(self.swapper.session.get_providers()[0])
|
61 |
+
print(f"[{_device}] Face swapper model loaded.")
|
62 |
+
|
63 |
+
def load_face_analyser(self, device='cpu'):
|
64 |
+
device, provider, options = get_device_and_provider(device=device)
|
65 |
+
self.analyser = AnalyseFace(provider=provider, session_options=options)
|
66 |
+
_device_d = get_device_name(self.analyser.detector.session.get_providers()[0])
|
67 |
+
print(f"[{_device_d}] Face detection model loaded.")
|
68 |
+
_device_r = get_device_name(self.analyser.recognizer.session.get_providers()[0])
|
69 |
+
print(f"[{_device_r}] Face recognition model loaded.")
|
70 |
+
_device_g = get_device_name(self.analyser.gender_age.session.get_providers()[0])
|
71 |
+
print(f"[{_device_g}] Gender & Age detection model loaded.")
|
72 |
+
|
73 |
+
def load_face_parser(self, device='cpu'):
|
74 |
+
device, provider, options = get_device_and_provider(device=device)
|
75 |
+
self.face_parser = FaceParser(model_path=dp.FACE_PARSER_PATH, provider=provider, session_options=options)
|
76 |
+
_device = get_device_name(self.face_parser.session.get_providers()[0])
|
77 |
+
print(f"[{_device}] Face parsing model loaded.")
|
78 |
+
|
79 |
+
def load_face_upscaler(self, name, device='cpu'):
|
80 |
+
device, provider, options = get_device_and_provider(device=device)
|
81 |
+
if name in get_available_upscalers_names():
|
82 |
+
self.face_upscaler = load_face_upscaler(name=name, provider=provider, session_options=options)
|
83 |
+
self.face_upscaler_name = name
|
84 |
+
_device = get_device_name(self.face_upscaler[0].session.get_providers()[0])
|
85 |
+
print(f"[{_device}] Face upscaler model ({name}) loaded.")
|
86 |
+
else:
|
87 |
+
self.face_upscaler_name = ""
|
88 |
+
self.face_upscaler = None
|
89 |
+
|
90 |
+
def collect_heads(self, frame):
|
91 |
+
faces = self.analyser.get_faces(frame, skip_task=['embedding', 'gender_age'])
|
92 |
+
return [get_cropped_head(frame, face.kps) for face in faces if face["det_score"] > 0.5]
|
93 |
+
|
94 |
+
def analyse_source_faces(self, source_specific):
|
95 |
+
analysed_source_specific = []
|
96 |
+
for i, (source, specific) in enumerate(source_specific):
|
97 |
+
if source is not None:
|
98 |
+
analysed_source = self.analyser.get_averaged_face(source, method=self.averaging_method)
|
99 |
+
if specific is not None:
|
100 |
+
analysed_specific = self.analyser.get_face(specific)
|
101 |
+
else:
|
102 |
+
analysed_specific = None
|
103 |
+
analysed_source_specific.append((analysed_source, analysed_specific))
|
104 |
+
self.analysed_source_specific = analysed_source_specific
|
105 |
+
|
106 |
+
def process_frame(self, data):
|
107 |
+
frame, custom_mask = data
|
108 |
+
|
109 |
+
if len(frame.shape) == 2 or (len(frame.shape) == 3 and frame.shape[2] == 1):
|
110 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
|
111 |
+
|
112 |
+
alpha = None
|
113 |
+
if frame.shape[2] == 4:
|
114 |
+
alpha = frame[:, :, 3]
|
115 |
+
frame = frame[:, :, :3]
|
116 |
+
|
117 |
+
_frame = frame.copy()
|
118 |
+
condition = self.swap_condition
|
119 |
+
|
120 |
+
skip_task = []
|
121 |
+
if condition != "specific face":
|
122 |
+
skip_task.append('embedding')
|
123 |
+
if condition not in ['age less than', 'age greater than', 'all male', 'all female']:
|
124 |
+
skip_task.append('gender_age')
|
125 |
+
|
126 |
+
analysed_target_faces = self.analyser.get_faces(frame, scale=self.face_scale, skip_task=skip_task)
|
127 |
+
|
128 |
+
for analysed_target in analysed_target_faces:
|
129 |
+
if (condition == "all face" or
|
130 |
+
(condition == "age less than" and analysed_target["age"] <= self.age) or
|
131 |
+
(condition == "age greater than" and analysed_target["age"] > self.age) or
|
132 |
+
(condition == "all male" and analysed_target["gender"] == 1) or
|
133 |
+
(condition == "all female" and analysed_target["gender"] == 0)):
|
134 |
+
|
135 |
+
trg_face = analysed_target
|
136 |
+
src_face = self.analysed_source_specific[0][0]
|
137 |
+
_frame = self.swap_face(_frame, trg_face, src_face)
|
138 |
+
|
139 |
+
elif condition == "specific face":
|
140 |
+
for analysed_source, analysed_specific in self.analysed_source_specific:
|
141 |
+
if is_similar_face(analysed_specific, analysed_target, threshold=self.similarity):
|
142 |
+
trg_face = analysed_target
|
143 |
+
src_face = analysed_source
|
144 |
+
_frame = self.swap_face(_frame, trg_face, src_face)
|
145 |
+
|
146 |
+
if condition in single_face_detect_conditions and len(analysed_target_faces) > 0:
|
147 |
+
analysed_target = get_single_face(analysed_target_faces, method=condition)
|
148 |
+
trg_face = analysed_target
|
149 |
+
src_face = self.analysed_source_specific[0][0]
|
150 |
+
_frame = self.swap_face(_frame, trg_face, src_face)
|
151 |
+
|
152 |
+
if custom_mask is not None:
|
153 |
+
_mask = cv2.resize(custom_mask, _frame.shape[:2][::-1])
|
154 |
+
_frame = _mask * frame.astype('float32') + (1 - _mask) * _frame.astype('float32')
|
155 |
+
_frame = _frame.clip(0,255).astype('uint8')
|
156 |
+
|
157 |
+
if alpha is not None:
|
158 |
+
_frame = np.dstack((_frame, alpha))
|
159 |
+
|
160 |
+
return _frame
|
161 |
+
|
162 |
+
def swap_face(self, frame, trg_face, src_face):
|
163 |
+
target_face, generated_face, matrix = self.swapper.forward(frame, trg_face, src_face, n_pass=self.num_of_pass)
|
164 |
+
upscaled_face, matrix = self.upscale_face(generated_face, matrix)
|
165 |
+
if self.parse_from_target:
|
166 |
+
mask = self.face_parsed_mask(target_face)
|
167 |
+
else:
|
168 |
+
mask = self.face_parsed_mask(upscaled_face)
|
169 |
+
result = paste_to_whole(
|
170 |
+
upscaled_face,
|
171 |
+
frame,
|
172 |
+
matrix,
|
173 |
+
mask=mask,
|
174 |
+
crop_mask=self.mask_crop_values,
|
175 |
+
blur_amount=self.mask_blur_amount,
|
176 |
+
erode_amount = self.mask_erode_amount
|
177 |
+
)
|
178 |
+
return result
|
179 |
+
|
180 |
+
def upscale_face(self, face, matrix):
|
181 |
+
face_size = face.shape[0]
|
182 |
+
_face = cv2.resize(face, (512,512))
|
183 |
+
if self.face_upscaler is not None:
|
184 |
+
model, runner = self.face_upscaler
|
185 |
+
face = runner(face, model)
|
186 |
+
upscaled_face = cv2.resize(face, (512,512))
|
187 |
+
upscaled_face = mix_two_image(_face, upscaled_face, self.face_upscaler_opacity)
|
188 |
+
return upscaled_face, matrix * (512/face_size)
|
189 |
+
|
190 |
+
def face_parsed_mask(self, face):
|
191 |
+
if self.face_parser is not None and self.use_face_parsing:
|
192 |
+
mask = self.face_parser.parse(face, regions=self.face_parse_regions)
|
193 |
+
else:
|
194 |
+
mask = None
|
195 |
+
return mask
|
swap_mukham_colab.ipynb
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"colab_type": "text",
|
7 |
+
"id": "view-in-github"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/harisreedhar/Swap-Mukham/blob/main/swap_mukham_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"id": "bypvIQG5RHl9"
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"# πΏ **Swap-Mukham**\n",
|
20 |
+
"*Face swap app based on insightface inswapper.*\n",
|
21 |
+
"- [Github](https://github.com/harisreedhar/Swap-Mukham)\n",
|
22 |
+
"- [Disclaimer](https://github.com/harisreedhar/Swap-Mukham#disclaimer)"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "markdown",
|
27 |
+
"metadata": {
|
28 |
+
"id": "csC_DX5zWLEU"
|
29 |
+
},
|
30 |
+
"source": [
|
31 |
+
"# Clone Repository"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": null,
|
37 |
+
"metadata": {
|
38 |
+
"id": "klcx2cKDKX5x"
|
39 |
+
},
|
40 |
+
"outputs": [],
|
41 |
+
"source": [
|
42 |
+
"#@title\n",
|
43 |
+
"! git clone https://github.com/harisreedhar/Swap-Mukham"
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "markdown",
|
48 |
+
"metadata": {
|
49 |
+
"id": "bebBDddfWTXf"
|
50 |
+
},
|
51 |
+
"source": [
|
52 |
+
"# Install Requirements"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "code",
|
57 |
+
"execution_count": null,
|
58 |
+
"metadata": {
|
59 |
+
"id": "VgTpg7EsTN3o"
|
60 |
+
},
|
61 |
+
"outputs": [],
|
62 |
+
"source": [
|
63 |
+
"#@title\n",
|
64 |
+
"%cd Swap-Mukham/\n",
|
65 |
+
"print(\"Installing requirements...\")\n",
|
66 |
+
"!pip install -r requirements.txt -q\n",
|
67 |
+
"!pip install gdown\n",
|
68 |
+
"print(\"Installing requirements done.\")"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"cell_type": "markdown",
|
73 |
+
"metadata": {
|
74 |
+
"id": "T9L6tgD0Wats"
|
75 |
+
},
|
76 |
+
"source": [
|
77 |
+
"# Download Models"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": null,
|
83 |
+
"metadata": {
|
84 |
+
"id": "17MZO9OvUQAk"
|
85 |
+
},
|
86 |
+
"outputs": [],
|
87 |
+
"source": [
|
88 |
+
"#@title\n",
|
89 |
+
"inswapper_model = \"https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx\" #@param {type:\"string\"}\n",
|
90 |
+
"gfpgan_model = \"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth\" #@param {type:\"string\"}\n",
|
91 |
+
"face_parser_model = \"https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812\" #@param {type:\"string\"}\n",
|
92 |
+
"real_esrgan_2x_model = \"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth\" #@param {type:\"string\"}\n",
|
93 |
+
"real_esrgan_4x_model = \"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth\" #@param {type:\"string\"}\n",
|
94 |
+
"real_esrgan_8x_model = \"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x8.pth\" #@param {type:\"string\"}\n",
|
95 |
+
"codeformer_model = \"https://huggingface.co/bluefoxcreation/Codeformer-ONNX/resolve/main/codeformer.onnx\" #@param {type:\"string\"}\n",
|
96 |
+
"nsfw_det_model = \"https://huggingface.co/bluefoxcreation/open-nsfw/resolve/main/open-nsfw.onnx\" #@param {type:\"string\"}\n",
|
97 |
+
"import gdown\n",
|
98 |
+
"import urllib.request\n",
|
99 |
+
"print(\"Downloading swapper model...\")\n",
|
100 |
+
"urllib.request.urlretrieve(inswapper_model, \"/content/Swap-Mukham/assets/pretrained_models/inswapper_128.onnx\")\n",
|
101 |
+
"print(\"Downloading gfpgan model...\")\n",
|
102 |
+
"urllib.request.urlretrieve(gfpgan_model, \"/content/Swap-Mukham/assets/pretrained_models/GFPGANv1.4.pth\")\n",
|
103 |
+
"print(\"Downloading face parsing model...\")\n",
|
104 |
+
"gdown.download(face_parser_model, \"/content/Swap-Mukham/assets/pretrained_models/79999_iter.pth\")\n",
|
105 |
+
"print(\"Downloading realesrgan 2x model...\")\n",
|
106 |
+
"urllib.request.urlretrieve(real_esrgan_2x_model, \"/content/Swap-Mukham/assets/pretrained_models/RealESRGAN_x2.pth\")\n",
|
107 |
+
"print(\"Downloading realesrgan 4x model...\")\n",
|
108 |
+
"urllib.request.urlretrieve(real_esrgan_4x_model, \"/content/Swap-Mukham/assets/pretrained_models/RealESRGAN_x4.pth\")\n",
|
109 |
+
"print(\"Downloading realesrgan 8x model...\")\n",
|
110 |
+
"urllib.request.urlretrieve(real_esrgan_8x_model, \"/content/Swap-Mukham/assets/pretrained_models/RealESRGAN_x8.pth\")\n",
|
111 |
+
"print(\"Downloading codeformer...\")\n",
|
112 |
+
"urllib.request.urlretrieve(codeformer_model, \"/content/Swap-Mukham/assets/pretrained_models/codeformer.onnx\")\n",
|
113 |
+
"print(\"Downloading NSFW detector model...\")\n",
|
114 |
+
"urllib.request.urlretrieve(nsfw_det_model, \"/content/Swap-Mukham/assets/pretrained_models/open-nsfw.onnx\")\n",
|
115 |
+
"print(\"Downloading models done.\")"
|
116 |
+
]
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"cell_type": "markdown",
|
120 |
+
"metadata": {
|
121 |
+
"id": "uEcCUw0Co6bE"
|
122 |
+
},
|
123 |
+
"source": [
|
124 |
+
"# Mount Google drive (optional)"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"cell_type": "code",
|
129 |
+
"execution_count": null,
|
130 |
+
"metadata": {
|
131 |
+
"id": "4KssYYippDMw"
|
132 |
+
},
|
133 |
+
"outputs": [],
|
134 |
+
"source": [
|
135 |
+
"from google.colab import auth, drive\n",
|
136 |
+
"auth.authenticate_user()\n",
|
137 |
+
"drive.mount('/content/drive')"
|
138 |
+
]
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"cell_type": "markdown",
|
142 |
+
"metadata": {
|
143 |
+
"id": "-Tn68Ayqdrlk"
|
144 |
+
},
|
145 |
+
"source": [
|
146 |
+
"# Run App\n",
|
147 |
+
"\n"
|
148 |
+
]
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"cell_type": "code",
|
152 |
+
"execution_count": null,
|
153 |
+
"metadata": {
|
154 |
+
"id": "6dpBjbfVOrrc"
|
155 |
+
},
|
156 |
+
"outputs": [],
|
157 |
+
"source": [
|
158 |
+
"#@title\n",
|
159 |
+
"default_output_path = \"/content/Swap-Mukham\" #@param {type:\"string\"}\n",
|
160 |
+
"\n",
|
161 |
+
"command = f\"python app.py --cuda --colab --out_dir {default_output_path}\"\n",
|
162 |
+
"!{command}"
|
163 |
+
]
|
164 |
+
}
|
165 |
+
],
|
166 |
+
"metadata": {
|
167 |
+
"accelerator": "GPU",
|
168 |
+
"colab": {
|
169 |
+
"gpuType": "T4",
|
170 |
+
"include_colab_link": true,
|
171 |
+
"provenance": []
|
172 |
+
},
|
173 |
+
"kernelspec": {
|
174 |
+
"display_name": "Python 3",
|
175 |
+
"name": "python3"
|
176 |
+
},
|
177 |
+
"language_info": {
|
178 |
+
"name": "python"
|
179 |
+
}
|
180 |
+
},
|
181 |
+
"nbformat": 4,
|
182 |
+
"nbformat_minor": 0
|
183 |
+
}
|
upscaler/GFPGAN.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import onnxruntime
|
4 |
+
import numpy as np
|
5 |
+
import threading
|
6 |
+
import time
|
7 |
+
|
8 |
+
# gfpgan converted to onnx
|
9 |
+
# using https://github.com/xuanandsix/GFPGAN-onnxruntime-demo
|
10 |
+
# same inference code for GFPGANv1.2, GFPGANv1.3, GFPGANv1.4
|
11 |
+
|
12 |
+
lock = threading.Lock()
|
13 |
+
|
14 |
+
class GFPGAN:
|
15 |
+
def __init__(self, model_path="GFPGANv1.4.onnx", provider=["CPUExecutionProvider"], session_options=None):
|
16 |
+
self.session_options = session_options
|
17 |
+
if self.session_options is None:
|
18 |
+
self.session_options = onnxruntime.SessionOptions()
|
19 |
+
self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
|
20 |
+
self.resolution = self.session.get_inputs()[0].shape[-2:]
|
21 |
+
|
22 |
+
def preprocess(self, img):
|
23 |
+
img = cv2.resize(img, self.resolution, interpolation=cv2.INTER_LINEAR)
|
24 |
+
img = img.astype(np.float32)[:,:,::-1] / 255.0
|
25 |
+
img = img.transpose((2, 0, 1))
|
26 |
+
img = (img - 0.5) / 0.5
|
27 |
+
img = np.expand_dims(img, axis=0).astype(np.float32)
|
28 |
+
return img
|
29 |
+
|
30 |
+
def postprocess(self, img):
|
31 |
+
img = (img.transpose(1,2,0).clip(-1,1) + 1) * 0.5
|
32 |
+
img = (img * 255)[:,:,::-1]
|
33 |
+
img = img.clip(0, 255).astype('uint8')
|
34 |
+
return img
|
35 |
+
|
36 |
+
def enhance(self, img):
|
37 |
+
img = self.preprocess(img)
|
38 |
+
with lock:
|
39 |
+
output = self.session.run(None, {'input':img})[0][0]
|
40 |
+
output = self.postprocess(output)
|
41 |
+
return output
|
upscaler/GPEN.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import onnxruntime
|
4 |
+
import numpy as np
|
5 |
+
import threading
|
6 |
+
import time
|
7 |
+
|
8 |
+
lock = threading.Lock()
|
9 |
+
|
10 |
+
class GPEN:
|
11 |
+
def __init__(self, model_path="GPEN-BFR-512.onnx", provider=["CPUExecutionProvider"], session_options=None):
|
12 |
+
self.session_options = session_options
|
13 |
+
if self.session_options is None:
|
14 |
+
self.session_options = onnxruntime.SessionOptions()
|
15 |
+
self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
|
16 |
+
self.resolution = self.session.get_inputs()[0].shape[-2:]
|
17 |
+
|
18 |
+
def preprocess(self, img):
|
19 |
+
img = cv2.resize(img, self.resolution, interpolation=cv2.INTER_LINEAR)
|
20 |
+
img = img.astype(np.float32)[:,:,::-1] / 255.0
|
21 |
+
img = img.transpose((2, 0, 1))
|
22 |
+
img = (img - 0.5) / 0.5
|
23 |
+
img = np.expand_dims(img, axis=0).astype(np.float32)
|
24 |
+
return img
|
25 |
+
|
26 |
+
def postprocess(self, img):
|
27 |
+
img = (img.transpose(1,2,0).clip(-1,1) + 1) * 0.5
|
28 |
+
img = (img * 255)[:,:,::-1]
|
29 |
+
img = img.clip(0, 255).astype('uint8')
|
30 |
+
return img
|
31 |
+
|
32 |
+
def enhance(self, img):
|
33 |
+
img = self.preprocess(img)
|
34 |
+
with lock:
|
35 |
+
output = self.session.run(None, {'input':img})[0][0]
|
36 |
+
output = self.postprocess(output)
|
37 |
+
return output
|
upscaler/__init__.py
ADDED
File without changes
|
upscaler/codeformer.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import onnxruntime
|
4 |
+
import numpy as np
|
5 |
+
import threading
|
6 |
+
import time
|
7 |
+
|
8 |
+
# codeformer converted to onnx
|
9 |
+
# using https://github.com/redthing1/CodeFormer
|
10 |
+
|
11 |
+
lock = threading.Lock()
|
12 |
+
|
13 |
+
class CodeFormer:
|
14 |
+
def __init__(self, model_path="codeformer.onnx", provider=["CPUExecutionProvider"], session_options=None):
|
15 |
+
self.session_options = session_options
|
16 |
+
if self.session_options is None:
|
17 |
+
self.session_options = onnxruntime.SessionOptions()
|
18 |
+
self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
|
19 |
+
self.resolution = self.session.get_inputs()[0].shape[-2:]
|
20 |
+
|
21 |
+
def preprocess(self, img, w):
|
22 |
+
img = cv2.resize(img, self.resolution, interpolation=cv2.INTER_LINEAR)
|
23 |
+
img = img.astype(np.float32)[:,:,::-1] / 255.0
|
24 |
+
img = img.transpose((2, 0, 1))
|
25 |
+
img = (img - 0.5) / 0.5
|
26 |
+
img = np.expand_dims(img, axis=0).astype(np.float32)
|
27 |
+
w = np.array([w], dtype=np.double)
|
28 |
+
return img, w
|
29 |
+
|
30 |
+
def postprocess(self, img):
|
31 |
+
img = (img.transpose(1,2,0).clip(-1,1) + 1) * 0.5
|
32 |
+
img = (img * 255)[:,:,::-1]
|
33 |
+
img = img.clip(0, 255).astype('uint8')
|
34 |
+
return img
|
35 |
+
|
36 |
+
def enhance(self, img, w=0.9):
|
37 |
+
img, w = self.preprocess(img, w)
|
38 |
+
with lock:
|
39 |
+
output = self.session.run(None, {'x':img, 'w':w})[0][0]
|
40 |
+
output = self.postprocess(output)
|
41 |
+
return output
|
upscaler/restoreformer.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import onnxruntime
|
4 |
+
import numpy as np
|
5 |
+
import threading
|
6 |
+
import time
|
7 |
+
|
8 |
+
lock = threading.Lock()
|
9 |
+
|
10 |
+
class RestoreFormer:
|
11 |
+
def __init__(self, model_path="restoreformer.onnx", provider=["CPUExecutionProvider"], session_options=None):
|
12 |
+
self.session_options = session_options
|
13 |
+
if self.session_options is None:
|
14 |
+
self.session_options = onnxruntime.SessionOptions()
|
15 |
+
self.session = onnxruntime.InferenceSession(model_path, sess_options=self.session_options, providers=provider)
|
16 |
+
self.resolution = self.session.get_inputs()[0].shape[-2:]
|
17 |
+
|
18 |
+
def preprocess(self, img):
|
19 |
+
img = cv2.resize(img, self.resolution, interpolation=cv2.INTER_LINEAR)
|
20 |
+
img = img.astype(np.float32)[:,:,::-1] / 255.0
|
21 |
+
img = img.transpose((2, 0, 1))
|
22 |
+
img = (img - 0.5) / 0.5
|
23 |
+
img = np.expand_dims(img, axis=0).astype(np.float32)
|
24 |
+
return img
|
25 |
+
|
26 |
+
def postprocess(self, img):
|
27 |
+
img = (img.transpose(1,2,0).clip(-1,1) + 1) * 0.5
|
28 |
+
img = (img * 255)[:,:,::-1]
|
29 |
+
img = img.clip(0, 255).astype('uint8')
|
30 |
+
return img
|
31 |
+
|
32 |
+
def enhance(self, img):
|
33 |
+
img = self.preprocess(img)
|
34 |
+
with lock:
|
35 |
+
output = self.session.run(None, {'input':img})[0][0]
|
36 |
+
output = self.postprocess(output)
|
37 |
+
return output
|
utils/__init__.py
ADDED
File without changes
|
utils/arcface.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Organization : insightface.ai
|
3 |
+
# @Author : Jia Guo
|
4 |
+
# @Time : 2021-09-18
|
5 |
+
# @Function :
|
6 |
+
|
7 |
+
|
8 |
+
import os
|
9 |
+
import cv2
|
10 |
+
import onnx
|
11 |
+
import onnxruntime
|
12 |
+
import numpy as np
|
13 |
+
import default_paths as dp
|
14 |
+
from .face_alignment import norm_crop2
|
15 |
+
|
16 |
+
|
17 |
+
class ArcFace:
|
18 |
+
def __init__(self, model_file=None, provider=['CUDAExecutionProvider'], session_options=None):
|
19 |
+
assert model_file is not None
|
20 |
+
self.model_file = model_file
|
21 |
+
self.taskname = 'recognition'
|
22 |
+
find_sub = False
|
23 |
+
find_mul = False
|
24 |
+
model = onnx.load(self.model_file)
|
25 |
+
graph = model.graph
|
26 |
+
for nid, node in enumerate(graph.node[:8]):
|
27 |
+
#print(nid, node.name)
|
28 |
+
if node.name.startswith('Sub') or node.name.startswith('_minus'):
|
29 |
+
find_sub = True
|
30 |
+
if node.name.startswith('Mul') or node.name.startswith('_mul'):
|
31 |
+
find_mul = True
|
32 |
+
if find_sub and find_mul:
|
33 |
+
#mxnet arcface model
|
34 |
+
input_mean = 0.0
|
35 |
+
input_std = 1.0
|
36 |
+
else:
|
37 |
+
input_mean = 127.5
|
38 |
+
input_std = 127.5
|
39 |
+
self.input_mean = input_mean
|
40 |
+
self.input_std = input_std
|
41 |
+
#print('input mean and std:', self.input_mean, self.input_std)
|
42 |
+
self.session_options = session_options
|
43 |
+
if self.session_options is None:
|
44 |
+
self.session_options = onnxruntime.SessionOptions()
|
45 |
+
self.session = onnxruntime.InferenceSession(self.model_file, providers=provider, sess_options=self.session_options)
|
46 |
+
input_cfg = self.session.get_inputs()[0]
|
47 |
+
input_shape = input_cfg.shape
|
48 |
+
input_name = input_cfg.name
|
49 |
+
self.input_size = tuple(input_shape[2:4][::-1])
|
50 |
+
self.input_shape = input_shape
|
51 |
+
outputs = self.session.get_outputs()
|
52 |
+
output_names = []
|
53 |
+
for out in outputs:
|
54 |
+
output_names.append(out.name)
|
55 |
+
self.input_name = input_name
|
56 |
+
self.output_names = output_names
|
57 |
+
assert len(self.output_names)==1
|
58 |
+
self.output_shape = outputs[0].shape
|
59 |
+
|
60 |
+
def prepare(self, ctx_id, **kwargs):
|
61 |
+
if ctx_id<0:
|
62 |
+
self.session.set_providers(['CPUExecutionProvider'])
|
63 |
+
|
64 |
+
def get(self, img, kps):
|
65 |
+
aimg, matrix = norm_crop2(img, landmark=kps, image_size=self.input_size[0])
|
66 |
+
embedding = self.get_feat(aimg).flatten()
|
67 |
+
return embedding
|
68 |
+
|
69 |
+
def compute_sim(self, feat1, feat2):
|
70 |
+
from numpy.linalg import norm
|
71 |
+
feat1 = feat1.ravel()
|
72 |
+
feat2 = feat2.ravel()
|
73 |
+
sim = np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))
|
74 |
+
return sim
|
75 |
+
|
76 |
+
def get_feat(self, imgs):
|
77 |
+
if not isinstance(imgs, list):
|
78 |
+
imgs = [imgs]
|
79 |
+
input_size = self.input_size
|
80 |
+
|
81 |
+
blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
|
82 |
+
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
83 |
+
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
|
84 |
+
return net_out
|
85 |
+
|
86 |
+
def forward(self, batch_data):
|
87 |
+
blob = (batch_data - self.input_mean) / self.input_std
|
88 |
+
net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
|
89 |
+
return net_out
|
utils/device.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import onnx
|
2 |
+
import onnxruntime
|
3 |
+
|
4 |
+
device_types_list = ["cpu", "cuda"]
|
5 |
+
|
6 |
+
available_providers = onnxruntime.get_available_providers()
|
7 |
+
|
8 |
+
def get_device_and_provider(device='cpu'):
|
9 |
+
options = onnxruntime.SessionOptions()
|
10 |
+
options.log_severity_level = 3
|
11 |
+
if device == 'cuda':
|
12 |
+
if "CUDAExecutionProvider" in available_providers:
|
13 |
+
provider = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"]
|
14 |
+
options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
15 |
+
else:
|
16 |
+
device = 'cpu'
|
17 |
+
provider = ["CPUExecutionProvider"]
|
18 |
+
else:
|
19 |
+
device = 'cpu'
|
20 |
+
provider = ["CPUExecutionProvider"]
|
21 |
+
|
22 |
+
return device, provider, options
|
23 |
+
|
24 |
+
|
25 |
+
data_type_bytes = {'uint8': 1, 'int8': 1, 'uint16': 2, 'int16': 2, 'float16': 2, 'float32': 4}
|
26 |
+
|
27 |
+
|
28 |
+
def estimate_max_batch_size(resolution, chunk_size=1024, data_type='float32', channels=3):
|
29 |
+
pixel_size = data_type_bytes.get(data_type, 1)
|
30 |
+
image_size = resolution[0] * resolution[1] * pixel_size * channels
|
31 |
+
number_of_batches = (chunk_size * 1024 * 1024) // image_size
|
32 |
+
return max(number_of_batches, 1)
|
utils/face_alignment.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def umeyama(src, dst, estimate_scale):
|
6 |
+
num = src.shape[0]
|
7 |
+
dim = src.shape[1]
|
8 |
+
src_mean = src.mean(axis=0)
|
9 |
+
dst_mean = dst.mean(axis=0)
|
10 |
+
src_demean = src - src_mean
|
11 |
+
dst_demean = dst - dst_mean
|
12 |
+
A = np.dot(dst_demean.T, src_demean) / num
|
13 |
+
d = np.ones((dim,), dtype=np.double)
|
14 |
+
if np.linalg.det(A) < 0:
|
15 |
+
d[dim - 1] = -1
|
16 |
+
T = np.eye(dim + 1, dtype=np.double)
|
17 |
+
U, S, V = np.linalg.svd(A)
|
18 |
+
rank = np.linalg.matrix_rank(A)
|
19 |
+
if rank == 0:
|
20 |
+
return np.nan * T
|
21 |
+
elif rank == dim - 1:
|
22 |
+
if np.linalg.det(U) * np.linalg.det(V) > 0:
|
23 |
+
T[:dim, :dim] = np.dot(U, V)
|
24 |
+
else:
|
25 |
+
s = d[dim - 1]
|
26 |
+
d[dim - 1] = -1
|
27 |
+
T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V))
|
28 |
+
d[dim - 1] = s
|
29 |
+
else:
|
30 |
+
T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V.T))
|
31 |
+
if estimate_scale:
|
32 |
+
scale = 1.0 / src_demean.var(axis=0).sum() * np.dot(S, d)
|
33 |
+
else:
|
34 |
+
scale = 1.0
|
35 |
+
T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T)
|
36 |
+
T[:dim, :dim] *= scale
|
37 |
+
return T
|
38 |
+
|
39 |
+
|
40 |
+
arcface_dst = np.array(
|
41 |
+
[[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
|
42 |
+
[41.5493, 92.3655], [70.7299, 92.2041]],
|
43 |
+
dtype=np.float32)
|
44 |
+
|
45 |
+
|
46 |
+
def estimate_norm(lmk, image_size=112, mode='arcface'):
|
47 |
+
assert lmk.shape == (5, 2)
|
48 |
+
assert image_size % 112 == 0 or image_size % 128 == 0
|
49 |
+
if image_size % 112 == 0:
|
50 |
+
ratio = float(image_size) / 112.0
|
51 |
+
diff_x = 0
|
52 |
+
else:
|
53 |
+
ratio = float(image_size) / 128.0
|
54 |
+
diff_x = 8.0 * ratio
|
55 |
+
dst = arcface_dst * ratio
|
56 |
+
dst[:, 0] += diff_x
|
57 |
+
M = umeyama(lmk, dst, True)[0:2, :]
|
58 |
+
return M
|
59 |
+
|
60 |
+
|
61 |
+
def norm_crop2(img, landmark, image_size=112, mode='arcface'):
|
62 |
+
M = estimate_norm(landmark, image_size, mode)
|
63 |
+
warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0, borderMode=cv2.BORDER_REPLICATE)
|
64 |
+
return warped, M
|
65 |
+
|
66 |
+
|
67 |
+
def get_cropped_head(img, landmark, scale=1.4):
|
68 |
+
# it is ugly but works :D
|
69 |
+
center = np.mean(landmark, axis=0)
|
70 |
+
landmark = center + (landmark - center) * scale
|
71 |
+
M = estimate_norm(landmark, 128, mode='arcface')
|
72 |
+
warped = cv2.warpAffine(img, M/0.25, (512, 512), borderValue=0.0)
|
73 |
+
return warped
|
utils/gender_age.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import onnxruntime
|
4 |
+
from .face_alignment import norm_crop2
|
5 |
+
|
6 |
+
class GenderAge:
|
7 |
+
def __init__(self, model_file=None, provider=['CPUExecutionProvider'], session_options=None):
|
8 |
+
self.model_file = model_file
|
9 |
+
self.session_options = session_options
|
10 |
+
if self.session_options is None:
|
11 |
+
self.session_options = onnxruntime.SessionOptions()
|
12 |
+
self.session = onnxruntime.InferenceSession(self.model_file, sess_options=self.session_options, providers=provider)
|
13 |
+
|
14 |
+
def predict(self, img, kps):
|
15 |
+
aimg, matrix = norm_crop2(img, kps, 128)
|
16 |
+
|
17 |
+
blob = cv2.resize(aimg, (62,62), interpolation=cv2.INTER_AREA)
|
18 |
+
blob = np.expand_dims(blob, axis=0).astype('float32')
|
19 |
+
|
20 |
+
_prob, _age = self.session.run(None, {'data':blob})
|
21 |
+
prob = _prob[0][0][0]
|
22 |
+
age = round(_age[0][0][0][0] * 100)
|
23 |
+
gender = np.argmax(prob)
|
24 |
+
|
25 |
+
return gender, age
|
utils/image.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import base64
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def laplacian_blending(A, B, m, num_levels=7):
|
7 |
+
assert A.shape == B.shape
|
8 |
+
assert B.shape == m.shape
|
9 |
+
height = m.shape[0]
|
10 |
+
width = m.shape[1]
|
11 |
+
size_list = np.array([4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192])
|
12 |
+
size = size_list[np.where(size_list > max(height, width))][0]
|
13 |
+
GA = np.zeros((size, size, 3), dtype=np.float32)
|
14 |
+
GA[:height, :width, :] = A
|
15 |
+
GB = np.zeros((size, size, 3), dtype=np.float32)
|
16 |
+
GB[:height, :width, :] = B
|
17 |
+
GM = np.zeros((size, size, 3), dtype=np.float32)
|
18 |
+
GM[:height, :width, :] = m
|
19 |
+
gpA = [GA]
|
20 |
+
gpB = [GB]
|
21 |
+
gpM = [GM]
|
22 |
+
for i in range(num_levels):
|
23 |
+
GA = cv2.pyrDown(GA)
|
24 |
+
GB = cv2.pyrDown(GB)
|
25 |
+
GM = cv2.pyrDown(GM)
|
26 |
+
gpA.append(np.float32(GA))
|
27 |
+
gpB.append(np.float32(GB))
|
28 |
+
gpM.append(np.float32(GM))
|
29 |
+
lpA = [gpA[num_levels-1]]
|
30 |
+
lpB = [gpB[num_levels-1]]
|
31 |
+
gpMr = [gpM[num_levels-1]]
|
32 |
+
for i in range(num_levels-1,0,-1):
|
33 |
+
LA = np.subtract(gpA[i-1], cv2.pyrUp(gpA[i]))
|
34 |
+
LB = np.subtract(gpB[i-1], cv2.pyrUp(gpB[i]))
|
35 |
+
lpA.append(LA)
|
36 |
+
lpB.append(LB)
|
37 |
+
gpMr.append(gpM[i-1])
|
38 |
+
LS = []
|
39 |
+
for la,lb,gm in zip(lpA,lpB,gpMr):
|
40 |
+
ls = la * gm + lb * (1.0 - gm)
|
41 |
+
LS.append(ls)
|
42 |
+
ls_ = LS[0]
|
43 |
+
for i in range(1,num_levels):
|
44 |
+
ls_ = cv2.pyrUp(ls_)
|
45 |
+
ls_ = cv2.add(ls_, LS[i])
|
46 |
+
ls_ = ls_[:height, :width, :]
|
47 |
+
#ls_ = (ls_ - np.min(ls_)) * (255.0 / (np.max(ls_) - np.min(ls_)))
|
48 |
+
return ls_.clip(0, 255)
|
49 |
+
|
50 |
+
|
51 |
+
def mask_crop(mask, crop):
|
52 |
+
top, bottom, left, right = crop
|
53 |
+
shape = mask.shape
|
54 |
+
top = int(top)
|
55 |
+
bottom = int(bottom)
|
56 |
+
if top + bottom < shape[1]:
|
57 |
+
if top > 0: mask[:top, :] = 0
|
58 |
+
if bottom > 0: mask[-bottom:, :] = 0
|
59 |
+
|
60 |
+
left = int(left)
|
61 |
+
right = int(right)
|
62 |
+
if left + right < shape[0]:
|
63 |
+
if left > 0: mask[:, :left] = 0
|
64 |
+
if right > 0: mask[:, -right:] = 0
|
65 |
+
|
66 |
+
return mask
|
67 |
+
|
68 |
+
def create_image_grid(images, size=128):
|
69 |
+
num_images = len(images)
|
70 |
+
num_cols = int(np.ceil(np.sqrt(num_images)))
|
71 |
+
num_rows = int(np.ceil(num_images / num_cols))
|
72 |
+
grid = np.zeros((num_rows * size, num_cols * size, 3), dtype=np.uint8)
|
73 |
+
|
74 |
+
for i, image in enumerate(images):
|
75 |
+
row_idx = (i // num_cols) * size
|
76 |
+
col_idx = (i % num_cols) * size
|
77 |
+
image = cv2.resize(image.copy(), (size,size))
|
78 |
+
if image.dtype != np.uint8:
|
79 |
+
image = (image.astype('float32') * 255).astype('uint8')
|
80 |
+
if image.ndim == 2:
|
81 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
82 |
+
grid[row_idx:row_idx + size, col_idx:col_idx + size] = image
|
83 |
+
|
84 |
+
return grid
|
85 |
+
|
86 |
+
|
87 |
+
def paste_to_whole(foreground, background, matrix, mask=None, crop_mask=(0,0,0,0), blur_amount=0.1, erode_amount = 0.15, blend_method='linear'):
|
88 |
+
inv_matrix = cv2.invertAffineTransform(matrix)
|
89 |
+
fg_shape = foreground.shape[:2]
|
90 |
+
bg_shape = (background.shape[1], background.shape[0])
|
91 |
+
foreground = cv2.warpAffine(foreground, inv_matrix, bg_shape, borderValue=0.0, borderMode=cv2.BORDER_REPLICATE)
|
92 |
+
|
93 |
+
if mask is None:
|
94 |
+
mask = np.full(fg_shape, 1., dtype=np.float32)
|
95 |
+
mask = mask_crop(mask, crop_mask)
|
96 |
+
mask = cv2.warpAffine(mask, inv_matrix, bg_shape, borderValue=0.0)
|
97 |
+
else:
|
98 |
+
assert fg_shape == mask.shape[:2], "foreground & mask shape mismatch!"
|
99 |
+
mask = mask_crop(mask, crop_mask).astype('float32')
|
100 |
+
mask = cv2.warpAffine(mask, inv_matrix, (background.shape[1], background.shape[0]), borderValue=0.0)
|
101 |
+
|
102 |
+
_mask = mask.copy()
|
103 |
+
_mask[_mask > 0.05] = 1.
|
104 |
+
non_zero_points = cv2.findNonZero(_mask)
|
105 |
+
_, _, w, h = cv2.boundingRect(non_zero_points)
|
106 |
+
mask_size = int(np.sqrt(w * h))
|
107 |
+
|
108 |
+
if erode_amount > 0:
|
109 |
+
kernel_size = max(int(mask_size * erode_amount), 1)
|
110 |
+
structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
|
111 |
+
mask = cv2.erode(mask, structuring_element)
|
112 |
+
|
113 |
+
if blur_amount > 0:
|
114 |
+
kernel_size = max(int(mask_size * blur_amount), 3)
|
115 |
+
if kernel_size % 2 == 0:
|
116 |
+
kernel_size += 1
|
117 |
+
mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)
|
118 |
+
|
119 |
+
mask = np.tile(np.expand_dims(mask, axis=-1), (1, 1, 3))
|
120 |
+
|
121 |
+
if blend_method == 'laplacian':
|
122 |
+
composite_image = laplacian_blending(foreground, background, mask.clip(0,1), num_levels=4)
|
123 |
+
else:
|
124 |
+
composite_image = mask * foreground + (1 - mask) * background
|
125 |
+
|
126 |
+
return composite_image.astype("uint8").clip(0, 255)
|
127 |
+
|
128 |
+
|
129 |
+
def image_mask_overlay(img, mask):
|
130 |
+
img = img.astype('float32') / 255.
|
131 |
+
img *= (mask + 0.25).clip(0, 1)
|
132 |
+
img = np.clip(img * 255., 0., 255.).astype('uint8')
|
133 |
+
return img
|
134 |
+
|
135 |
+
|
136 |
+
def resize_with_padding(img, expected_size=(640, 360), color=(0, 0, 0), max_flip=False):
|
137 |
+
original_height, original_width = img.shape[:2]
|
138 |
+
|
139 |
+
if max_flip and original_height > original_width:
|
140 |
+
expected_size = (expected_size[1], expected_size[0])
|
141 |
+
|
142 |
+
aspect_ratio = original_width / original_height
|
143 |
+
new_width = expected_size[0]
|
144 |
+
new_height = int(new_width / aspect_ratio)
|
145 |
+
|
146 |
+
if new_height > expected_size[1]:
|
147 |
+
new_height = expected_size[1]
|
148 |
+
new_width = int(new_height * aspect_ratio)
|
149 |
+
|
150 |
+
resized_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
151 |
+
canvas = cv2.copyMakeBorder(resized_img,
|
152 |
+
top=(expected_size[1] - new_height) // 2,
|
153 |
+
bottom=(expected_size[1] - new_height + 1) // 2,
|
154 |
+
left=(expected_size[0] - new_width) // 2,
|
155 |
+
right=(expected_size[0] - new_width + 1) // 2,
|
156 |
+
borderType=cv2.BORDER_CONSTANT, value=color)
|
157 |
+
return canvas
|
158 |
+
|
159 |
+
|
160 |
+
def create_image_grid(images, size=128):
|
161 |
+
num_images = len(images)
|
162 |
+
num_cols = int(np.ceil(np.sqrt(num_images)))
|
163 |
+
num_rows = int(np.ceil(num_images / num_cols))
|
164 |
+
grid = np.zeros((num_rows * size, num_cols * size, 3), dtype=np.uint8)
|
165 |
+
|
166 |
+
for i, image in enumerate(images):
|
167 |
+
row_idx = (i // num_cols) * size
|
168 |
+
col_idx = (i % num_cols) * size
|
169 |
+
image = cv2.resize(image.copy(), (size,size))
|
170 |
+
if image.dtype != np.uint8:
|
171 |
+
image = (image.astype('float32') * 255).astype('uint8')
|
172 |
+
if image.ndim == 2:
|
173 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
174 |
+
grid[row_idx:row_idx + size, col_idx:col_idx + size] = image
|
175 |
+
|
176 |
+
return grid
|
177 |
+
|
178 |
+
|
179 |
+
def image_to_html(img, size=(640, 360), extension="jpg"):
|
180 |
+
if img is not None:
|
181 |
+
img = resize_with_padding(img, expected_size=size)
|
182 |
+
buffer = cv2.imencode(f".{extension}", img)[1]
|
183 |
+
base64_data = base64.b64encode(buffer.tobytes())
|
184 |
+
imgbs64 = f"data:image/{extension};base64," + base64_data.decode("utf-8")
|
185 |
+
html = '<div style="display: flex; justify-content: center; align-items: center; width: 100%;">'
|
186 |
+
html += f'<img src={imgbs64} alt="No Preview" style="max-width: 100%; max-height: 100%;">'
|
187 |
+
html += '</div>'
|
188 |
+
return html
|
189 |
+
return None
|
190 |
+
|
191 |
+
|
192 |
+
def mix_two_image(a, b, opacity=1.):
|
193 |
+
a_dtype = a.dtype
|
194 |
+
b_dtype = b.dtype
|
195 |
+
a = a.astype('float32')
|
196 |
+
b = b.astype('float32')
|
197 |
+
a = cv2.resize(a, (b.shape[0], b.shape[1]))
|
198 |
+
opacity = min(max(opacity, 0.), 1.)
|
199 |
+
mixed_img = opacity * b + (1 - opacity) * a
|
200 |
+
return mixed_img.astype(a_dtype)
|
201 |
+
|
202 |
+
resolution_map = {
|
203 |
+
"Original": None,
|
204 |
+
"240p": (426, 240),
|
205 |
+
"360p": (640, 360),
|
206 |
+
"480p": (854, 480),
|
207 |
+
"720p": (1280, 720),
|
208 |
+
"1080p": (1920, 1080),
|
209 |
+
"1440p": (2560, 1440),
|
210 |
+
"2160p": (3840, 2160),
|
211 |
+
}
|
212 |
+
|
213 |
+
def resize_image_by_resolution(img, quality):
|
214 |
+
resolution = resolution_map.get(quality, None)
|
215 |
+
if resolution is None:
|
216 |
+
return img
|
217 |
+
|
218 |
+
h, w = img.shape[:2]
|
219 |
+
if h > w:
|
220 |
+
ratio = resolution[0] / h
|
221 |
+
else:
|
222 |
+
ratio = resolution[0] / w
|
223 |
+
|
224 |
+
new_h, new_w = int(h * ratio), int(w * ratio)
|
225 |
+
img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
226 |
+
return img
|
227 |
+
|
228 |
+
def fast_pil_encode(pil_image):
|
229 |
+
image_arr = np.asarray(pil_image)[:,:,::-1]
|
230 |
+
buffer = cv2.imencode('.jpg', image_arr)[1]
|
231 |
+
base64_data = base64.b64encode(buffer.tobytes())
|
232 |
+
return "data:image/jpg;base64," + base64_data.decode("utf-8")
|
233 |
+
|
234 |
+
def fast_numpy_encode(img_array):
|
235 |
+
buffer = cv2.imencode('.jpg', img_array)[1]
|
236 |
+
base64_data = base64.b64encode(buffer.tobytes())
|
237 |
+
return "data:image/jpg;base64," + base64_data.decode("utf-8")
|
238 |
+
|
239 |
+
crf_quality_by_resolution = {
|
240 |
+
240: {"poor": 45, "low": 35, "medium": 28, "high": 23, "best": 20},
|
241 |
+
360: {"poor": 35, "low": 28, "medium": 23, "high": 20, "best": 18},
|
242 |
+
480: {"poor": 28, "low": 23, "medium": 20, "high": 18, "best": 16},
|
243 |
+
720: {"poor": 23, "low": 20, "medium": 18, "high": 16, "best": 14},
|
244 |
+
1080: {"poor": 20, "low": 18, "medium": 16, "high": 14, "best": 12},
|
245 |
+
1440: {"poor": 18, "low": 16, "medium": 14, "high": 12, "best": 10},
|
246 |
+
2160: {"poor": 16, "low": 14, "medium": 12, "high": 10, "best": 8}
|
247 |
+
}
|
248 |
+
|
249 |
+
def get_crf_for_resolution(resolution, quality):
|
250 |
+
available_resolutions = list(crf_quality_by_resolution.keys())
|
251 |
+
closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
|
252 |
+
return crf_quality_by_resolution[closest_resolution][quality]
|
utils/io.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import glob
|
4 |
+
import shutil
|
5 |
+
import subprocess
|
6 |
+
from datetime import datetime
|
7 |
+
|
8 |
+
|
9 |
+
image_extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
|
10 |
+
|
11 |
+
def get_images_from_directory(directory_path):
|
12 |
+
file_paths =[]
|
13 |
+
for file_path in glob.glob(os.path.join(directory_path, "*")):
|
14 |
+
if any(file_path.lower().endswith(ext) for ext in image_extensions):
|
15 |
+
file_paths.append(file_path)
|
16 |
+
file_paths.sort()
|
17 |
+
return file_paths
|
18 |
+
|
19 |
+
|
20 |
+
def open_directory(path=None):
|
21 |
+
if path is None:
|
22 |
+
return
|
23 |
+
try:
|
24 |
+
os.startfile(path)
|
25 |
+
except:
|
26 |
+
subprocess.Popen(["xdg-open", path])
|
27 |
+
|
28 |
+
|
29 |
+
def copy_files_to_directory(files, destination):
|
30 |
+
file_paths = []
|
31 |
+
for file_path in files:
|
32 |
+
new_file_path = shutil.copy(file_path, destination)
|
33 |
+
file_paths.append(new_file_path)
|
34 |
+
return file_paths
|
35 |
+
|
36 |
+
|
37 |
+
def create_directory(directory_path, remove_existing=True):
|
38 |
+
if os.path.exists(directory_path) and remove_existing:
|
39 |
+
shutil.rmtree(directory_path)
|
40 |
+
|
41 |
+
if not os.path.exists(directory_path):
|
42 |
+
os.mkdir(directory_path)
|
43 |
+
return directory_path
|
44 |
+
else:
|
45 |
+
counter = 1
|
46 |
+
while True:
|
47 |
+
new_directory_path = f"{directory_path}_{counter}"
|
48 |
+
if not os.path.exists(new_directory_path):
|
49 |
+
os.mkdir(new_directory_path)
|
50 |
+
return new_directory_path
|
51 |
+
counter += 1
|
52 |
+
|
53 |
+
|
54 |
+
def add_datetime_to_filename(filename):
|
55 |
+
current_datetime = datetime.now()
|
56 |
+
formatted_datetime = current_datetime.strftime("%Y%m%d_%H%M%S")
|
57 |
+
file_name, file_extension = os.path.splitext(filename)
|
58 |
+
new_filename = f"{file_name}_{formatted_datetime}{file_extension}"
|
59 |
+
return new_filename
|
60 |
+
|
61 |
+
|
62 |
+
def get_single_video_frame(video_path, frame_index):
|
63 |
+
cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
|
64 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
65 |
+
frame_index = min(int(frame_index), total_frames-1)
|
66 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
|
67 |
+
valid_frame, frame = cap.read()
|
68 |
+
cap.release()
|
69 |
+
if valid_frame:
|
70 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
71 |
+
return frame
|
72 |
+
return None
|
73 |
+
|
74 |
+
|
75 |
+
def get_video_fps(video_path):
|
76 |
+
cap = cv2.VideoCapture(video_path)
|
77 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
78 |
+
cap.release()
|
79 |
+
return fps
|
80 |
+
|
81 |
+
|
82 |
+
def ffmpeg_extract_frames(video_path, destination, remove_existing=True, fps=30, name='frame_%d.jpg', ffmpeg_path=None):
|
83 |
+
ffmpeg_path = 'ffmpeg' if ffmpeg_path is None else ffmpeg_path
|
84 |
+
destination = create_directory(destination, remove_existing=remove_existing)
|
85 |
+
cmd = [
|
86 |
+
ffmpeg_path,
|
87 |
+
'-loglevel', 'info',
|
88 |
+
'-hwaccel', 'auto',
|
89 |
+
'-i', video_path,
|
90 |
+
'-q:v', '3',
|
91 |
+
'-pix_fmt', 'rgb24',
|
92 |
+
'-vf', 'fps=' + str(fps),
|
93 |
+
'-y',
|
94 |
+
os.path.join(destination, name)
|
95 |
+
]
|
96 |
+
process = subprocess.Popen(cmd)
|
97 |
+
process.communicate()
|
98 |
+
if process.returncode == 0:
|
99 |
+
return True, get_images_from_directory(destination)
|
100 |
+
else:
|
101 |
+
print(f"Error: Failed to extract video.")
|
102 |
+
return False, None
|
103 |
+
|
104 |
+
|
105 |
+
def ffmpeg_merge_frames(sequence_directory, pattern, destination, fps=30, crf=18, ffmpeg_path=None):
|
106 |
+
ffmpeg_path = 'ffmpeg' if ffmpeg_path is None else ffmpeg_path
|
107 |
+
cmd = [
|
108 |
+
ffmpeg_path,
|
109 |
+
'-loglevel', 'info',
|
110 |
+
'-hwaccel', 'auto',
|
111 |
+
'-r', str(fps),
|
112 |
+
# '-pattern_type', 'glob',
|
113 |
+
'-i', os.path.join(sequence_directory, pattern),
|
114 |
+
'-c:v', 'libx264',
|
115 |
+
'-crf', str(crf),
|
116 |
+
'-pix_fmt', 'yuv420p',
|
117 |
+
'-vf', 'colorspace=bt709:iall=bt601-6-625:fast=1',
|
118 |
+
'-y', destination
|
119 |
+
]
|
120 |
+
process = subprocess.Popen(cmd)
|
121 |
+
process.communicate()
|
122 |
+
if process.returncode == 0:
|
123 |
+
return True, destination
|
124 |
+
else:
|
125 |
+
print(f"Error: Failed to merge image sequence.")
|
126 |
+
return False, None
|
127 |
+
|
128 |
+
|
129 |
+
def ffmpeg_replace_video_segments(main_video_path, sub_clips_info, output_path, ffmpeg_path="ffmpeg"):
|
130 |
+
ffmpeg_path = 'ffmpeg' if ffmpeg_path is None else ffmpeg_path
|
131 |
+
filter_complex = ""
|
132 |
+
|
133 |
+
filter_complex += f"[0:v]split=2[v0][main_end]; "
|
134 |
+
filter_complex += f"[1:v]split={len(sub_clips_info)}{', '.join([f'[v{index + 1}]' for index in range(len(sub_clips_info))])}; "
|
135 |
+
|
136 |
+
overlay_exprs = "".join([f"[v{index + 1}]" for index in range(len(sub_clips_info))])
|
137 |
+
overlay_filters = f"[main_end][{overlay_exprs}]overlay=eof_action=pass[vout]; "
|
138 |
+
filter_complex += overlay_filters
|
139 |
+
|
140 |
+
cmd = [
|
141 |
+
ffmpeg_path, '-i', main_video_path,
|
142 |
+
]
|
143 |
+
|
144 |
+
for sub_clip_path, _, _ in sub_clips_info:
|
145 |
+
cmd.extend(['-i', sub_clip_path])
|
146 |
+
|
147 |
+
cmd.extend([
|
148 |
+
'-filter_complex', filter_complex,
|
149 |
+
'-map', '[vout]',
|
150 |
+
output_path
|
151 |
+
])
|
152 |
+
|
153 |
+
subprocess.run(cmd)
|
154 |
+
|
155 |
+
|
156 |
+
def ffmpeg_mux_audio(source, target, output, ffmpeg_path=None):
|
157 |
+
ffmpeg_path = 'ffmpeg' if ffmpeg_path is None else ffmpeg_path
|
158 |
+
extracted_audio_path = os.path.join(os.path.dirname(output), 'extracted_audio.aac')
|
159 |
+
cmd1 = [
|
160 |
+
ffmpeg_path,
|
161 |
+
'-loglevel', 'info',
|
162 |
+
'-i', source,
|
163 |
+
'-vn',
|
164 |
+
'-c:a', 'aac',
|
165 |
+
'-y',
|
166 |
+
extracted_audio_path
|
167 |
+
]
|
168 |
+
process = subprocess.Popen(cmd1)
|
169 |
+
process.communicate()
|
170 |
+
if process.returncode != 0:
|
171 |
+
print(f"Error: Failed to extract audio.")
|
172 |
+
return False, target
|
173 |
+
|
174 |
+
cmd2 = [
|
175 |
+
ffmpeg_path,
|
176 |
+
'-loglevel', 'info',
|
177 |
+
'-hwaccel', 'auto',
|
178 |
+
'-i', target,
|
179 |
+
'-i', extracted_audio_path,
|
180 |
+
'-c:v', 'copy',
|
181 |
+
'-map', '0:v:0',
|
182 |
+
'-map', '1:a:0',
|
183 |
+
'-y', output
|
184 |
+
]
|
185 |
+
process = subprocess.Popen(cmd2)
|
186 |
+
process.communicate()
|
187 |
+
if process.returncode == 0:
|
188 |
+
if os.path.exists(extracted_audio_path):
|
189 |
+
os.remove(extracted_audio_path)
|
190 |
+
return True, output
|
191 |
+
else:
|
192 |
+
print(f"Error: Failed to mux audio.")
|
193 |
+
return False, None
|
194 |
+
|
utils/retinaface.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Organization : insightface.ai
|
3 |
+
# @Author : Jia Guo
|
4 |
+
# @Time : 2021-09-18
|
5 |
+
# @Function :
|
6 |
+
|
7 |
+
from __future__ import division
|
8 |
+
import datetime
|
9 |
+
import numpy as np
|
10 |
+
import onnx
|
11 |
+
import onnxruntime
|
12 |
+
import os
|
13 |
+
import cv2
|
14 |
+
import sys
|
15 |
+
import default_paths as dp
|
16 |
+
|
17 |
+
def softmax(z):
|
18 |
+
assert len(z.shape) == 2
|
19 |
+
s = np.max(z, axis=1)
|
20 |
+
s = s[:, np.newaxis] # necessary step to do broadcasting
|
21 |
+
e_x = np.exp(z - s)
|
22 |
+
div = np.sum(e_x, axis=1)
|
23 |
+
div = div[:, np.newaxis] # dito
|
24 |
+
return e_x / div
|
25 |
+
|
26 |
+
def distance2bbox(points, distance, max_shape=None):
|
27 |
+
"""Decode distance prediction to bounding box.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
points (Tensor): Shape (n, 2), [x, y].
|
31 |
+
distance (Tensor): Distance from the given point to 4
|
32 |
+
boundaries (left, top, right, bottom).
|
33 |
+
max_shape (tuple): Shape of the image.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Tensor: Decoded bboxes.
|
37 |
+
"""
|
38 |
+
x1 = points[:, 0] - distance[:, 0]
|
39 |
+
y1 = points[:, 1] - distance[:, 1]
|
40 |
+
x2 = points[:, 0] + distance[:, 2]
|
41 |
+
y2 = points[:, 1] + distance[:, 3]
|
42 |
+
if max_shape is not None:
|
43 |
+
x1 = x1.clamp(min=0, max=max_shape[1])
|
44 |
+
y1 = y1.clamp(min=0, max=max_shape[0])
|
45 |
+
x2 = x2.clamp(min=0, max=max_shape[1])
|
46 |
+
y2 = y2.clamp(min=0, max=max_shape[0])
|
47 |
+
return np.stack([x1, y1, x2, y2], axis=-1)
|
48 |
+
|
49 |
+
def distance2kps(points, distance, max_shape=None):
|
50 |
+
"""Decode distance prediction to bounding box.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
points (Tensor): Shape (n, 2), [x, y].
|
54 |
+
distance (Tensor): Distance from the given point to 4
|
55 |
+
boundaries (left, top, right, bottom).
|
56 |
+
max_shape (tuple): Shape of the image.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
Tensor: Decoded bboxes.
|
60 |
+
"""
|
61 |
+
preds = []
|
62 |
+
for i in range(0, distance.shape[1], 2):
|
63 |
+
px = points[:, i%2] + distance[:, i]
|
64 |
+
py = points[:, i%2+1] + distance[:, i+1]
|
65 |
+
if max_shape is not None:
|
66 |
+
px = px.clamp(min=0, max=max_shape[1])
|
67 |
+
py = py.clamp(min=0, max=max_shape[0])
|
68 |
+
preds.append(px)
|
69 |
+
preds.append(py)
|
70 |
+
return np.stack(preds, axis=-1)
|
71 |
+
|
72 |
+
class RetinaFace:
|
73 |
+
def __init__(self, model_file=None, provider=["CPUExecutionProvider"], session_options=None):
|
74 |
+
self.model_file = model_file
|
75 |
+
self.session_options = session_options
|
76 |
+
if self.session_options is None:
|
77 |
+
self.session_options = onnxruntime.SessionOptions()
|
78 |
+
self.session = onnxruntime.InferenceSession(self.model_file, providers=provider, sess_options=self.session_options)
|
79 |
+
self.center_cache = {}
|
80 |
+
self.nms_thresh = 0.4
|
81 |
+
self.det_thresh = 0.5
|
82 |
+
self._init_vars()
|
83 |
+
|
84 |
+
def _init_vars(self):
|
85 |
+
input_cfg = self.session.get_inputs()[0]
|
86 |
+
input_shape = input_cfg.shape
|
87 |
+
#print(input_shape)
|
88 |
+
if isinstance(input_shape[2], str):
|
89 |
+
self.input_size = None
|
90 |
+
else:
|
91 |
+
self.input_size = tuple(input_shape[2:4][::-1])
|
92 |
+
#print('image_size:', self.image_size)
|
93 |
+
input_name = input_cfg.name
|
94 |
+
self.input_shape = input_shape
|
95 |
+
outputs = self.session.get_outputs()
|
96 |
+
output_names = []
|
97 |
+
for o in outputs:
|
98 |
+
output_names.append(o.name)
|
99 |
+
self.input_name = input_name
|
100 |
+
self.output_names = output_names
|
101 |
+
self.input_mean = 127.5
|
102 |
+
self.input_std = 128.0
|
103 |
+
#print(self.output_names)
|
104 |
+
#assert len(outputs)==10 or len(outputs)==15
|
105 |
+
self.use_kps = False
|
106 |
+
self._anchor_ratio = 1.0
|
107 |
+
self._num_anchors = 1
|
108 |
+
if len(outputs)==6:
|
109 |
+
self.fmc = 3
|
110 |
+
self._feat_stride_fpn = [8, 16, 32]
|
111 |
+
self._num_anchors = 2
|
112 |
+
elif len(outputs)==9:
|
113 |
+
self.fmc = 3
|
114 |
+
self._feat_stride_fpn = [8, 16, 32]
|
115 |
+
self._num_anchors = 2
|
116 |
+
self.use_kps = True
|
117 |
+
elif len(outputs)==10:
|
118 |
+
self.fmc = 5
|
119 |
+
self._feat_stride_fpn = [8, 16, 32, 64, 128]
|
120 |
+
self._num_anchors = 1
|
121 |
+
elif len(outputs)==15:
|
122 |
+
self.fmc = 5
|
123 |
+
self._feat_stride_fpn = [8, 16, 32, 64, 128]
|
124 |
+
self._num_anchors = 1
|
125 |
+
self.use_kps = True
|
126 |
+
|
127 |
+
def prepare(self, **kwargs):
|
128 |
+
nms_thresh = kwargs.get('nms_thresh', None)
|
129 |
+
if nms_thresh is not None:
|
130 |
+
self.nms_thresh = nms_thresh
|
131 |
+
det_thresh = kwargs.get('det_thresh', None)
|
132 |
+
if det_thresh is not None:
|
133 |
+
self.det_thresh = det_thresh
|
134 |
+
input_size = kwargs.get('input_size', None)
|
135 |
+
if input_size is not None:
|
136 |
+
if self.input_size is not None:
|
137 |
+
print('warning: det_size is already set in detection model, ignore')
|
138 |
+
else:
|
139 |
+
self.input_size = input_size
|
140 |
+
|
141 |
+
def forward(self, img, threshold):
|
142 |
+
scores_list = []
|
143 |
+
bboxes_list = []
|
144 |
+
kpss_list = []
|
145 |
+
input_size = tuple(img.shape[0:2][::-1])
|
146 |
+
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
|
147 |
+
net_outs = self.session.run(self.output_names, {self.input_name : blob})
|
148 |
+
|
149 |
+
input_height = blob.shape[2]
|
150 |
+
input_width = blob.shape[3]
|
151 |
+
fmc = self.fmc
|
152 |
+
for idx, stride in enumerate(self._feat_stride_fpn):
|
153 |
+
scores = net_outs[idx]
|
154 |
+
bbox_preds = net_outs[idx+fmc]
|
155 |
+
bbox_preds = bbox_preds * stride
|
156 |
+
if self.use_kps:
|
157 |
+
kps_preds = net_outs[idx+fmc*2] * stride
|
158 |
+
height = input_height // stride
|
159 |
+
width = input_width // stride
|
160 |
+
K = height * width
|
161 |
+
key = (height, width, stride)
|
162 |
+
if key in self.center_cache:
|
163 |
+
anchor_centers = self.center_cache[key]
|
164 |
+
else:
|
165 |
+
anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
|
166 |
+
anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
|
167 |
+
if self._num_anchors>1:
|
168 |
+
anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
|
169 |
+
if len(self.center_cache)<100:
|
170 |
+
self.center_cache[key] = anchor_centers
|
171 |
+
|
172 |
+
pos_inds = np.where(scores>=threshold)[0]
|
173 |
+
bboxes = distance2bbox(anchor_centers, bbox_preds)
|
174 |
+
pos_scores = scores[pos_inds]
|
175 |
+
pos_bboxes = bboxes[pos_inds]
|
176 |
+
scores_list.append(pos_scores)
|
177 |
+
bboxes_list.append(pos_bboxes)
|
178 |
+
if self.use_kps:
|
179 |
+
kpss = distance2kps(anchor_centers, kps_preds)
|
180 |
+
kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
|
181 |
+
pos_kpss = kpss[pos_inds]
|
182 |
+
kpss_list.append(pos_kpss)
|
183 |
+
return scores_list, bboxes_list, kpss_list
|
184 |
+
|
185 |
+
def detect(self, img, input_size = (640,640), max_num=0, metric='default', det_thresh=0.5):
|
186 |
+
assert input_size is not None or self.input_size is not None
|
187 |
+
input_size = self.input_size if input_size is None else input_size
|
188 |
+
|
189 |
+
im_ratio = float(img.shape[0]) / img.shape[1]
|
190 |
+
model_ratio = float(input_size[1]) / input_size[0]
|
191 |
+
if im_ratio>model_ratio:
|
192 |
+
new_height = input_size[1]
|
193 |
+
new_width = int(new_height / im_ratio)
|
194 |
+
else:
|
195 |
+
new_width = input_size[0]
|
196 |
+
new_height = int(new_width * im_ratio)
|
197 |
+
det_scale = float(new_height) / img.shape[0]
|
198 |
+
resized_img = cv2.resize(img, (new_width, new_height))
|
199 |
+
det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
|
200 |
+
det_img[:new_height, :new_width, :] = resized_img
|
201 |
+
|
202 |
+
scores_list, bboxes_list, kpss_list = self.forward(det_img, det_thresh)
|
203 |
+
|
204 |
+
scores = np.vstack(scores_list)
|
205 |
+
scores_ravel = scores.ravel()
|
206 |
+
order = scores_ravel.argsort()[::-1]
|
207 |
+
bboxes = np.vstack(bboxes_list) / det_scale
|
208 |
+
if self.use_kps:
|
209 |
+
kpss = np.vstack(kpss_list) / det_scale
|
210 |
+
pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
|
211 |
+
pre_det = pre_det[order, :]
|
212 |
+
keep = self.nms(pre_det)
|
213 |
+
det = pre_det[keep, :]
|
214 |
+
if self.use_kps:
|
215 |
+
kpss = kpss[order,:,:]
|
216 |
+
kpss = kpss[keep,:,:]
|
217 |
+
else:
|
218 |
+
kpss = None
|
219 |
+
if max_num > 0 and det.shape[0] > max_num:
|
220 |
+
area = (det[:, 2] - det[:, 0]) * (det[:, 3] -
|
221 |
+
det[:, 1])
|
222 |
+
img_center = img.shape[0] // 2, img.shape[1] // 2
|
223 |
+
offsets = np.vstack([
|
224 |
+
(det[:, 0] + det[:, 2]) / 2 - img_center[1],
|
225 |
+
(det[:, 1] + det[:, 3]) / 2 - img_center[0]
|
226 |
+
])
|
227 |
+
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
|
228 |
+
if metric=='max':
|
229 |
+
values = area
|
230 |
+
else:
|
231 |
+
values = area - offset_dist_squared * 2.0 # some extra weight on the centering
|
232 |
+
bindex = np.argsort(
|
233 |
+
values)[::-1] # some extra weight on the centering
|
234 |
+
bindex = bindex[0:max_num]
|
235 |
+
det = det[bindex, :]
|
236 |
+
if kpss is not None:
|
237 |
+
kpss = kpss[bindex, :]
|
238 |
+
return det, kpss
|
239 |
+
|
240 |
+
def nms(self, dets):
|
241 |
+
thresh = self.nms_thresh
|
242 |
+
x1 = dets[:, 0]
|
243 |
+
y1 = dets[:, 1]
|
244 |
+
x2 = dets[:, 2]
|
245 |
+
y2 = dets[:, 3]
|
246 |
+
scores = dets[:, 4]
|
247 |
+
|
248 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
249 |
+
order = scores.argsort()[::-1]
|
250 |
+
|
251 |
+
keep = []
|
252 |
+
while order.size > 0:
|
253 |
+
i = order[0]
|
254 |
+
keep.append(i)
|
255 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
256 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
257 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
258 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
259 |
+
|
260 |
+
w = np.maximum(0.0, xx2 - xx1 + 1)
|
261 |
+
h = np.maximum(0.0, yy2 - yy1 + 1)
|
262 |
+
inter = w * h
|
263 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
264 |
+
|
265 |
+
inds = np.where(ovr <= thresh)[0]
|
266 |
+
order = order[inds + 1]
|
267 |
+
|
268 |
+
return keep
|