import os
import cv2
import time
import shutil
import base64
import datetime
import argparse
import numpy as np
import gradio as gr
from tqdm import tqdm
import concurrent.futures
import threading
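# Serializes seek/read access to the shared cv2.VideoCapture below;
# OpenCV capture objects are not thread-safe.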
cv_reader_lock = threading.Lock()
## ------------------------------ USER ARGS ------------------------------
parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
parser.add_argument("--max_threads", type=int, help="Max num of threads to use", default=2)
parser.add_argument("--colab", action="store_true", help="Colab mode", default=False)
parser.add_argument("--cpu", action="store_true", help="Enable cpu mode", default=False)
parser.add_argument("--prefer_text_widget", action="store_true", help="Replaces target video widget with text widget", default=False)
user_args = parser.parse_args()
USE_CPU = True  # hardcoded: forces CPU mode regardless of the --cpu flag parsed above
if not USE_CPU:
import torch
import default_paths as dp
import global_variables as gv
from swap_mukham import SwapMukham
from nsfw_checker import NSFWChecker
from face_parsing import mask_regions_to_list
from utils.device import get_device_and_provider, device_types_list
from utils.image import (
image_mask_overlay,
resize_image_by_resolution,
resolution_map,
fast_pil_encode,
fast_numpy_encode,
get_crf_for_resolution,
)
from utils.io import (
open_directory,
get_images_from_directory,
copy_files_to_directory,
create_directory,
get_single_video_frame,
ffmpeg_merge_frames,
ffmpeg_mux_audio,
add_datetime_to_filename,
)
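# Swap gradio's default base64 encoders for the fast_* variants from utils.image,
# reducing the overhead of pushing preview frames to the browser.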
gr.processing_utils.encode_pil_to_base64 = fast_pil_encode
gr.processing_utils.encode_array_to_base64 = fast_numpy_encode
gv.USE_COLAB = user_args.colab
gv.MAX_THREADS = user_args.max_threads
gv.DEFAULT_OUTPUT_PATH = user_args.out_dir
PREFER_TEXT_WIDGET = user_args.prefer_text_widget
WORKSPACE = None
OUTPUT_FILE = None
preferred_device = "cpu" if USE_CPU else "cuda"
DEVICE_LIST = device_types_list
DEVICE, PROVIDER, OPTIONS = get_device_and_provider(device=preferred_device)
SWAP_MUKHAM = SwapMukham(device=DEVICE)
IS_RUNNING = False
CURRENT_FRAME = None
PREVIEW = None
COLLECTED_FACES = []
FOREGROUND_MASK_DICT = {}
NSFW_CACHE = {}
## ------------------------------ MAIN PROCESS ------------------------------
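# Generator driving the whole swap pipeline: it branches on target_type
# (Image / Video / Directory / Stream, plus a test-preview mode) and yields
# gradio UI updates before, during, and after processing.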
def process(
test_mode,
target_type,
image_path,
video_path,
directory_path,
source_path,
use_foreground_mask,
img_fg_mask,
fg_mask_softness,
output_path,
output_name,
use_datetime_suffix,
sequence_output_format,
keep_output_sequence,
swap_condition,
age,
distance,
face_enhancer_name,
face_upscaler_opacity,
use_face_parsing,
parse_from_target,
mask_regions,
mask_blur_amount,
mask_erode_amount,
swap_iteration,
face_scale,
use_laplacian_blending,
crop_top,
crop_bott,
crop_left,
crop_right,
current_idx,
number_of_threads,
use_frame_selection,
frame_selection_ranges,
video_quality,
face_detection_condition,
face_detection_size,
face_detection_threshold,
averaging_method,
progress=gr.Progress(track_tqdm=True),
*specifics,
):
global WORKSPACE
global OUTPUT_FILE
global PREVIEW
WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
global IS_RUNNING
IS_RUNNING = True
## ------------------------------ GUI UPDATE FUNC ------------------------------
def ui_before():
return (
gr.update(visible=True, value=None),
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(visible=False, value=None),
)
def ui_after():
return (
gr.update(visible=True, value=PREVIEW),
gr.update(interactive=True),
gr.update(interactive=True),
gr.update(visible=False, value=None),
)
def ui_after_vid():
return (
gr.update(visible=False),
gr.update(interactive=True),
gr.update(interactive=True),
gr.update(value=OUTPUT_FILE, visible=True),
)
if not test_mode:
yield ui_before() # resets ui preview
progress(0, desc="Processing")
start_time = time.time()
total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
get_finish_text = (
    lambda start_time: f"Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."
)
## ------------------------------ PREPARE INPUTS ------------------------------
if use_datetime_suffix:
output_name = add_datetime_to_filename(output_name)
mask_regions = mask_regions_to_list(mask_regions)
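# `specifics` arrives as a flat tuple: NUM_OF_SRC_SPECIFIC source-file inputs
# followed by NUM_OF_SRC_SPECIFIC specific-target images, so splitting it in
# half pairs each source with its specific face.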
specifics = list(specifics)
half = len(specifics) // 2
if swap_condition == "specific face":
source_specifics = [
([s.name for s in src] if src is not None else None, spc) for src, spc in zip(specifics[:half], specifics[half:])
]
else:
source_paths = [i.name for i in source_path]
source_specifics = [(source_paths, None)]
if crop_top > crop_bott:
crop_top, crop_bott = crop_bott, crop_top
if crop_left > crop_right:
crop_left, crop_right = crop_right, crop_left
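# Crop values index into the aligned face crop (sliders span 0-511, i.e. a
# 512 px space); bottom/right are flipped so a slider value of 511 means no
# cropping from that edge.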
crop_mask = (crop_top, 511 - crop_bott, crop_left, 511 - crop_right)
input_args = {
"similarity": distance,
"age": age,
"face_scale": face_scale,
"num_of_pass": swap_iteration,
"face_upscaler_opacity": face_upscaler_opacity,
"mask_crop_values": crop_mask,
"mask_erode_amount": mask_erode_amount,
"mask_blur_amount": mask_blur_amount,
"use_laplacian_blending": use_laplacian_blending,
"swap_condition": swap_condition,
"face_parse_regions": mask_regions,
"use_face_parsing": use_face_parsing,
"face_detection_size": [int(face_detection_size), int(face_detection_size)],
"face_detection_threshold": face_detection_threshold,
"face_detection_condition": face_detection_condition,
"parse_from_target": parse_from_target,
"averaging_method": averaging_method,
}
SWAP_MUKHAM.set_values(input_args)
if (
SWAP_MUKHAM.face_upscaler is None
or SWAP_MUKHAM.face_upscaler_name != face_enhancer_name
):
SWAP_MUKHAM.load_face_upscaler(face_enhancer_name, device=DEVICE)
if SWAP_MUKHAM.face_parser is None and use_face_parsing:
SWAP_MUKHAM.load_face_parser(device=DEVICE)
SWAP_MUKHAM.analyse_source_faces(source_specifics)
mask = None
if use_foreground_mask and img_fg_mask is not None:
    mask = img_fg_mask.get("mask", None)
    if mask is not None:  # the sketch tool may return no painted mask
        mask = cv2.cvtColor(mask, cv2.COLOR_BGRA2RGB)
        if fg_mask_softness > 0:
            mask = cv2.blur(mask, (int(fg_mask_softness), int(fg_mask_softness)))
        mask = mask.astype("float32") / 255.0
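# Shows the message in the UI, then the assert aborts the whole job.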
def nsfw_assertion(is_nsfw):
if is_nsfw:
message = "NSFW content detected!"
gr.Info(message)
assert not is_nsfw, message
## ------------------------------ IMAGE ------------------------------
if target_type == "Image" and not test_mode:
target = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
is_nsfw = SWAP_MUKHAM.nsfw_detector.check_image(target)
nsfw_assertion(is_nsfw)
output = SWAP_MUKHAM.process_frame(
[target, mask]
)
output_file = os.path.join(output_path, output_name + ".png")
cv2.imwrite(output_file, output)
PREVIEW = output
OUTPUT_FILE = output_file
WORKSPACE = output_path
gr.Info(get_finish_text(start_time))
yield ui_after()
## ------------------------------ VIDEO ------------------------------
elif target_type == "Video" and not test_mode:
video_path = video_path.replace('"', '').strip()
if video_path in NSFW_CACHE:
nsfw_assertion(NSFW_CACHE.get(video_path))
else:
is_nsfw = SWAP_MUKHAM.nsfw_detector.check_video(video_path)
NSFW_CACHE[video_path] = is_nsfw
nsfw_assertion(is_nsfw)
temp_path = os.path.join(output_path, output_name)
os.makedirs(temp_path, exist_ok=True)
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
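# Frame-selection rows are inclusive [start, end] ranges, e.g. [[0, 100], [250, 400]]
# swaps only frames 0-100 and 250-400; frames outside every range pass through unswapped.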
is_in_range = lambda idx: any([int(rng[0]) <= idx <= int(rng[1]) for rng in frame_selection_ranges]) if use_frame_selection else True
print("[ Swapping process started ]")
def swap_video_func(frame_index):
if IS_RUNNING:
with cv_reader_lock:
cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_index))
valid_frame, frame = cap.read()
if valid_frame:
if is_in_range(frame_index):
mask = FOREGROUND_MASK_DICT.get(frame_index, None) if use_foreground_mask else None
output = SWAP_MUKHAM.process_frame([frame, mask])
else:
output = frame
frame_path = os.path.join(temp_path, f"frame_{frame_index}.{sequence_output_format}")
if sequence_output_format == "jpg":
cv2.imwrite(frame_path, output, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
else:
cv2.imwrite(frame_path, output)
with concurrent.futures.ThreadPoolExecutor(max_workers=number_of_threads) as executor:
futures = [executor.submit(swap_video_func, idx) for idx in range(total_frames)]
with tqdm(total=total_frames, desc="Processing") as pbar:
for future in concurrent.futures.as_completed(futures):
future.result()
pbar.update(1)
cap.release()
if IS_RUNNING:
print("[ Merging image sequence ]")
progress(0, desc="Merging image sequence")
WORKSPACE = output_path
out_without_audio = output_name + "_without_audio" + ".mp4"
destination = os.path.join(output_path, out_without_audio)
crf = get_crf_for_resolution(max(width,height), video_quality)
ret, destination = ffmpeg_merge_frames(
temp_path, f"frame_%d.{sequence_output_format}", destination, fps=fps, crf=crf, ffmpeg_path=dp.FFMPEG_PATH
)
OUTPUT_FILE = destination
if ret:
    print("[ Merging audio ]")
    progress(0, desc="Merging audio")
    # use absolute paths so muxing and cleanup don't depend on the current working directory
    out_with_audio = os.path.join(output_path, output_name + ".mp4")
    _ret, _destination = ffmpeg_mux_audio(
        video_path, destination, out_with_audio, ffmpeg_path=dp.FFMPEG_PATH
    )
    if _ret:
        OUTPUT_FILE = _destination
        os.remove(destination)
if os.path.exists(temp_path) and not keep_output_sequence:
print("[ Removing temporary files ]")
progress(0, desc="Removing temporary files")
shutil.rmtree(temp_path)
finish_text = get_finish_text(start_time)
print(f"[ {finish_text} ]")
gr.Info(finish_text)
yield ui_after_vid()
## ------------------------------ DIRECTORY ------------------------------
elif target_type == "Directory" and not test_mode:
temp_path = os.path.join(output_path, output_name)
temp_path = create_directory(temp_path, remove_existing=True)
directory_path = directory_path.replace('"', '').strip()
image_paths = get_images_from_directory(directory_path)
is_nsfw = SWAP_MUKHAM.nsfw_detector.check_image_paths(image_paths)
nsfw_assertion(is_nsfw)
new_image_paths = copy_files_to_directory(image_paths, temp_path)
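# Images are first copied into the output directory, then swapped in place,
# so the originals are never modified.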
def swap_func(img_path):
if IS_RUNNING:
frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
output = SWAP_MUKHAM.process_frame([frame, None])
cv2.imwrite(img_path, output)
with concurrent.futures.ThreadPoolExecutor(max_workers=number_of_threads) as executor:
futures = [executor.submit(swap_func, img_path) for img_path in new_image_paths]
with tqdm(total=len(new_image_paths), desc="Processing") as pbar:
for future in concurrent.futures.as_completed(futures):
future.result()
pbar.update(1)
PREVIEW = cv2.imread(new_image_paths[-1])
WORKSPACE = temp_path
OUTPUT_FILE = new_image_paths[-1]
gr.Info(get_finish_text(start_time))
yield ui_after()
## ------------------------------ STREAM ------------------------------
elif target_type == "Stream" and not test_mode:
pass
## ------------------------------ TEST ------------------------------
if test_mode and target_type == "Video":
mask = None
if use_foreground_mask:
mask = FOREGROUND_MASK_DICT.get(current_idx, None)
if CURRENT_FRAME is not None and isinstance(CURRENT_FRAME, np.ndarray):
PREVIEW = SWAP_MUKHAM.process_frame(
[CURRENT_FRAME[:, :, ::-1], mask]
)
gr.Info(get_finish_text(start_time))
yield ui_after()
## ------------------------------ GRADIO GUI ------------------------------
css = """
div.gradio-container{
max-width: unset !important;
}
footer{
display:none !important
}
#slider_row {
display: flex;
flex-wrap: wrap;
justify-content: space-between;
}
#refresh_slider {
flex: 0 1 20%;
display: flex;
align-items: center;
}
#frame_slider {
flex: 1 0 80%;
display: flex;
align-items: center;
}
"""
WIDGET_PREVIEW_HEIGHT = 450
with gr.Blocks(css=css, theme=gr.themes.Default()) as interface:
gr.Markdown("# 🗿 Swap Mukham")
gr.Markdown("### Single image face swapper")
with gr.Row():
with gr.Row():
with gr.Column(scale=0.35):
with gr.Tabs():
with gr.TabItem("📄 Input"):
swap_condition = gr.Dropdown(
gv.FACE_DETECT_CONDITIONS,
info="Choose which face or faces in the target image to swap.",
multiselect=False,
show_label=False,
value=gv.FACE_DETECT_CONDITIONS[0],
interactive=True,
)
age = gr.Number(
value=25, label="Value", interactive=True, visible=False
)
## ------------------------------ SOURCE IMAGE ------------------------------
source_image_input = gr.Files(
label="Source face", type="file", interactive=True,
)
## ------------------------------ SOURCE SPECIFIC ------------------------------
with gr.Box(visible=False) as specific_face:
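# The gr.Files/gr.Image pairs are generated with exec so the number of
# source/specific tabs follows gv.NUM_OF_SRC_SPECIFIC; the resulting names
# src{i}/trg{i} are collected again via exec near the bottom of the file.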
for i in range(gv.NUM_OF_SRC_SPECIFIC):
idx = i + 1
code = "\n"
code += f"with gr.Tab(label='{idx}'):"
code += "\n\twith gr.Row():"
code += f"\n\t\tsrc{idx} = gr.Files(interactive=True, type='file', label='Source Face {idx}')"
code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
exec(code)
## ------------------------------ TARGET TYPE ------------------------------
with gr.Group():
target_type = gr.Radio(
["Image", "Video", "Directory"],
label="Target Type",
value="Video",
)
## ------------------------------ TARGET IMAGE ------------------------------
with gr.Box(visible=False) as input_image_group:
target_image_input = gr.Image(
label="Target Image",
interactive=True,
type="filepath",
height=200
)
## ------------------------------ TARGET VIDEO ------------------------------
with gr.Box(visible=True) as input_video_group:
with gr.Column():
video_widget = gr.Text if PREFER_TEXT_WIDGET else gr.Video
video_input = video_widget(
label="Target Video", interactive=True,
)
## ------------------------------ FRAME SELECTION ------------------------------
with gr.Accordion("Frame Selection", open=False):
use_frame_selection = gr.Checkbox(
label="Use frame selection", value=False, interactive=True,
)
frame_selection_ranges = gr.Numpy(
headers=["Start Frame", "End Frame"],
datatype=["number", "number"],
row_count=1,
col_count=(2, "fixed"),
interactive=True
)
## ------------------------------ TARGET DIRECTORY ------------------------------
with gr.Box(visible=False) as input_directory_group:
directory_input = gr.Text(
label="Target Image Directory", interactive=True
)
## ------------------------------ TAB MODEL ------------------------------
with gr.TabItem("🎚️ Model"):
with gr.Accordion("Detection", open=False):
face_detection_condition = gr.Dropdown(
gv.SINGLE_FACE_DETECT_CONDITIONS,
label="Condition",
value=gv.DETECT_CONDITION,
interactive=True,
info="This condition is only used when multiple faces are detected on source or specific image.",
)
face_detection_size = gr.Number(
label="Detection Size",
value=gv.DETECT_SIZE,
interactive=True,
)
face_detection_threshold = gr.Number(
label="Detection Threshold",
value=gv.DETECT_THRESHOLD,
interactive=True,
)
face_scale = gr.Slider(
label="Landmark Scale",
minimum=0,
maximum=2,
value=1,
interactive=True,
)
with gr.Accordion("Embedding/Recognition", open=True):
averaging_method = gr.Dropdown(
gv.AVERAGING_METHODS,
label="Averaging Method",
value=gv.AVERAGING_METHOD,
interactive=True,
)
distance_slider = gr.Slider(
minimum=0,
maximum=2,
value=0.65,
interactive=True,
label="Specific-Target Distance",
)
with gr.Accordion("Swapper", open=True):
with gr.Row():
swap_iteration = gr.Slider(
label="Swap Iteration",
minimum=1,
maximum=4,
value=1,
step=1,
interactive=True,
)
## ------------------------------ TAB POST-PROCESS ------------------------------
with gr.TabItem("🪄 Post-Process"):
with gr.Row():
face_enhancer_name = gr.Dropdown(
gv.FACE_ENHANCER_LIST,
label="Face Enhancer",
value="NONE",
multiselect=False,
interactive=True,
)
face_upscaler_opacity = gr.Slider(
label="Opacity",
minimum=0,
maximum=1,
value=1,
step=0.001,
interactive=True,
)
with gr.Accordion("Face Mask", open=False):
with gr.Group():
with gr.Row():
use_face_parsing_mask = gr.Checkbox(
label="Enable Face Parsing",
value=False,
interactive=True,
)
parse_from_target = gr.Checkbox(
label="Parse from target",
value=False,
interactive=True,
)
mask_regions = gr.Dropdown(
gv.MASK_REGIONS,
value=gv.MASK_REGIONS_DEFAULT,
multiselect=True,
label="Include",
interactive=True,
)
with gr.Accordion("Crop Face Bounding-Box", open=False):
with gr.Group():
with gr.Row():
crop_top = gr.Slider(
label="Top",
minimum=0,
maximum=511,
value=0,
step=1,
interactive=True,
)
crop_bott = gr.Slider(
label="Bottom",
minimum=0,
maximum=511,
value=511,
step=1,
interactive=True,
)
with gr.Row():
crop_left = gr.Slider(
label="Left",
minimum=0,
maximum=511,
value=0,
step=1,
interactive=True,
)
crop_right = gr.Slider(
label="Right",
minimum=0,
maximum=511,
value=511,
step=1,
interactive=True,
)
with gr.Row():
mask_erode_amount = gr.Slider(
label="Mask Erode",
minimum=0,
maximum=1,
value=gv.MASK_ERODE_AMOUNT,
step=0.001,
interactive=True,
)
mask_blur_amount = gr.Slider(
label="Mask Blur",
minimum=0,
maximum=1,
value=gv.MASK_BLUR_AMOUNT,
step=0.001,
interactive=True,
)
use_laplacian_blending = gr.Checkbox(
label="Laplacian Blending",
value=True,
interactive=True,
)
## ------------------------------ TAB OUTPUT ------------------------------
with gr.TabItem("📤 Output"):
output_directory = gr.Text(
label="Output Directory",
value=gv.DEFAULT_OUTPUT_PATH,
interactive=True,
)
with gr.Group():
output_name = gr.Text(
label="Output Name", value="Result", interactive=True
)
use_datetime_suffix = gr.Checkbox(
label="Suffix date-time", value=True, interactive=True
)
with gr.Accordion("Video settings", open=True):
with gr.Row():
sequence_output_format = gr.Dropdown(
["jpg", "png"],
label="Sequence format",
value="jpg",
interactive=True,
)
video_quality = gr.Dropdown(
gv.VIDEO_QUALITY_LIST,
label="Quality",
value=gv.VIDEO_QUALITY,
interactive=True
)
keep_output_sequence = gr.Checkbox(
label="Keep output sequence", value=False, interactive=True
)
## ------------------------------ TAB PERFORMANCE ------------------------------
with gr.TabItem("🛠️ Performance"):
preview_resolution = gr.Dropdown(
gv.RESOLUTIONS,
label="Preview Resolution",
value="Original",
interactive=True,
)
number_of_threads = gr.Number(
step=1,
interactive=True,
label="Max number of threads",
value=gv.MAX_THREADS,
minimum=1,
)
with gr.Box():
with gr.Column():
with gr.Row():
face_analyser_device = gr.Radio(
DEVICE_LIST,
label="Face detection & recognition",
value=DEVICE,
interactive=True,
)
face_analyser_device_submit = gr.Button("Apply")
with gr.Row():
face_swapper_device = gr.Radio(
DEVICE_LIST,
label="Face swapper",
value=DEVICE,
interactive=True,
)
face_swapper_device_submit = gr.Button("Apply")
with gr.Row():
face_parser_device = gr.Radio(
DEVICE_LIST,
label="Face parsing",
value=DEVICE,
interactive=True,
)
face_parser_device_submit = gr.Button("Apply")
with gr.Row():
face_upscaler_device = gr.Radio(
DEVICE_LIST,
label="Face upscaler",
value=DEVICE,
interactive=True,
)
face_upscaler_device_submit = gr.Button("Apply")
face_analyser_device_submit.click(
fn=lambda d: SWAP_MUKHAM.load_face_analyser(
device=d
),
inputs=[face_analyser_device],
)
face_swapper_device_submit.click(
fn=lambda d: SWAP_MUKHAM.load_face_swapper(
device=d
),
inputs=[face_swapper_device],
)
face_parser_device_submit.click(
fn=lambda d: SWAP_MUKHAM.load_face_parser(device=d),
inputs=[face_parser_device],
)
face_upscaler_device_submit.click(
fn=lambda n, d: SWAP_MUKHAM.load_face_upscaler(
n, device=d
),
inputs=[face_enhancer_name, face_upscaler_device],
)
## ------------------------------ SWAP, CANCEL, FRAME SLIDER ------------------------------
with gr.Column(scale=0.65):
with gr.Row():
swap_button = gr.Button("✨ Swap", variant="primary")
cancel_button = gr.Button("⛔ Cancel")
collect_faces = gr.Button("👨 Collect Faces")
test_swap = gr.Button("🧪 Test Swap")
with gr.Box() as frame_slider_box:
with gr.Row(elem_id="slider_row", equal_height=True):
set_slider_range_btn = gr.Button(
"Set Range", interactive=True, elem_id="refresh_slider"
)
frame_slider = gr.Slider(
label="Frame",
minimum=0,
maximum=1,
value=0,
step=1,
interactive=True,
elem_id="frame_slider",
)
## ------------------------------ PREVIEW ------------------------------
with gr.Tabs():
with gr.TabItem("Preview"):
preview_image = gr.Image(
label="Preview", type="numpy", interactive=False, height=WIDGET_PREVIEW_HEIGHT,
)
preview_video = gr.Video(
label="Output", interactive=False, visible=False, height=WIDGET_PREVIEW_HEIGHT,
)
preview_enabled_text = gr.Markdown(
"Disable paint foreground to preview !", visible=False
)
with gr.Row():
output_directory_button = gr.Button(
"📂", interactive=False, visible=not gv.USE_COLAB
)
output_video_button = gr.Button(
"🎬", interactive=False, visible=not gv.USE_COLAB
)
output_directory_button.click(
lambda: open_directory(path=WORKSPACE),
inputs=None,
outputs=None,
)
output_video_button.click(
lambda: open_directory(path=OUTPUT_FILE),
inputs=None,
outputs=None,
)
## ------------------------------ FOREGROUND MASK ------------------------------
with gr.TabItem("Paint Foreground"):
with gr.Box() as fg_mask_group:
with gr.Row():
with gr.Row():
use_foreground_mask = gr.Checkbox(
label="Use foreground mask", value=False, interactive=True)
fg_mask_softness = gr.Slider(
label="Mask Softness",
minimum=0,
maximum=200,
value=1,
step=1,
interactive=True,
)
add_fg_mask_btn = gr.Button("Add", interactive=True)
del_fg_mask_btn = gr.Button("Del", interactive=True)
img_fg_mask = gr.Image(
label="Paint Mask",
tool="sketch",
interactive=True,
type="numpy",
height=WIDGET_PREVIEW_HEIGHT,
)
## ------------------------------ COLLECT FACE ------------------------------
with gr.TabItem("Collected Faces"):
collected_faces = gr.Gallery(
label="Faces",
show_label=False,
elem_id="gallery",
columns=[6], rows=[6], object_fit="contain", height=WIDGET_PREVIEW_HEIGHT,
)
## ------------------------------ FOOTER LINKS ------------------------------
with gr.Row(variant='panel'):
gr.HTML(
"""
<div style="display: flex; flex-direction: row; justify-content: center;">
<h3 style="margin-right: 10px;"><a href="https://github.com/sponsors/harisreedhar" style="text-decoration: none;">🤝 Sponsor</a></h3>
<h3 style="margin-right: 10px;"><a href="https://github.com/harisreedhar/Swap-Mukham" style="text-decoration: none;">👨‍💻 Source</a></h3>
<h3 style="margin-right: 10px;"><a href="https://github.com/harisreedhar/Swap-Mukham#disclaimer" style="text-decoration: none;">⚠️ Disclaimer</a></h3>
<h3 style="margin-right: 10px;"><a href="https://colab.research.google.com/github/harisreedhar/Swap-Mukham/blob/main/swap_mukham_colab.ipynb" style="text-decoration: none;">🌐 Colab</a></h3>
<h3><a href="https://github.com/harisreedhar/Swap-Mukham#acknowledgements" style="text-decoration: none;">🤗 Acknowledgements</a></h3>
</div>
"""
)
## ------------------------------ GRADIO EVENTS ------------------------------
def on_target_type_change(value):
visibility = {
"Image": (True, False, False, False, True, False, False, False),
"Video": (False, True, False, True, True, True, True, True),
"Directory": (False, False, True, False, False, False, False, False),
"Stream": (False, False, True, False, False, False, False, False),
}
return list(gr.update(visible=i) for i in visibility[value])
target_type.change(
on_target_type_change,
inputs=[target_type],
outputs=[
input_image_group,
input_video_group,
input_directory_group,
frame_slider_box,
fg_mask_group,
add_fg_mask_btn,
del_fg_mask_btn,
test_swap,
],
)
target_image_input.change(
lambda inp: gr.update(value=inp),
inputs=[target_image_input],
outputs=[img_fg_mask]
)
def on_swap_condition_change(value):
visibility = {
"age less than": (True, False, True),
"age greater than": (True, False, True),
"specific face": (False, True, False),
}
return tuple(
gr.update(visible=i) for i in visibility.get(value, (False, False, True))
)
swap_condition.change(
on_swap_condition_change,
inputs=[swap_condition],
outputs=[age, specific_face, source_image_input],
)
def on_set_slider_range(video_path):
    if video_path is None or not os.path.exists(video_path):
        gr.Info("Check video path")
        return gr.update()
    try:
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        if total_frames > 0:
            # the slider is 0-indexed, so the last frame is total_frames - 1
            return gr.Slider.update(
                minimum=0, maximum=total_frames - 1, value=0, interactive=True
            )
    except Exception:
        pass
    gr.Info("Error fetching video")
    return gr.update()
set_slider_range_event = set_slider_range_btn.click(
on_set_slider_range,
inputs=[video_input],
outputs=[frame_slider],
)
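# Yields updates for (preview image, foreground-mask editor, output video):
# fetches the requested frame, optionally overlays a stored foreground mask,
# and caches the raw frame in CURRENT_FRAME for test swaps.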
def update_preview(video_path, frame_index, use_foreground_mask, resolution):
    global CURRENT_FRAME
    if video_path is None or not os.path.exists(video_path):
yield gr.update(value=None), gr.update(value=None), gr.update(visible=False)
else:
frame = get_single_video_frame(video_path, frame_index)
if frame is not None:
if use_foreground_mask:
overlayed_image = frame
if frame_index in FOREGROUND_MASK_DICT:
mask = FOREGROUND_MASK_DICT.get(frame_index, None)
if mask is not None:
overlayed_image = image_mask_overlay(frame, mask)
yield gr.update(value=None), gr.update(value=None), gr.update(visible=False) # clear previous mask
frame = resize_image_by_resolution(frame, resolution)
yield gr.update(value=frame[:, :, ::-1]), gr.update(
value=overlayed_image[:, :, ::-1], visible=True
), gr.update(visible=False)
else:
frame = resize_image_by_resolution(frame, resolution)
yield gr.update(value=frame[:, :, ::-1]), gr.update(value=None), gr.update(
visible=False
)
CURRENT_FRAME = frame
frame_slider_event = frame_slider.change(
fn=update_preview,
inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
outputs=[preview_image, img_fg_mask, preview_video],
show_progress=False,
)
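# Stores the painted mask (optionally softened) keyed by frame index so the
# video pass can pick it up from FOREGROUND_MASK_DICT.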
def add_foreground_mask(fg, frame_index, softness):
if fg is not None:
mask = fg.get("mask", None)
if mask is not None:
alpha_rgb = cv2.cvtColor(mask, cv2.COLOR_BGRA2RGB)
if softness > 0:  # cv2.blur requires a positive kernel size
    alpha_rgb = cv2.blur(alpha_rgb, (int(softness), int(softness)))
FOREGROUND_MASK_DICT[frame_index] = alpha_rgb.astype("float32") / 255.0
gr.Info(f"Saved mask index {frame_index}")
add_foreground_mask_event = add_fg_mask_btn.click(
fn=add_foreground_mask,
inputs=[img_fg_mask, frame_slider, fg_mask_softness],
).then(
fn=update_preview,
inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
outputs=[preview_image, img_fg_mask, preview_video],
show_progress=False,
)
def delete_foreground_mask(frame_index):
if frame_index in FOREGROUND_MASK_DICT:
FOREGROUND_MASK_DICT.pop(frame_index)
gr.Info(f"Deleted mask index {frame_index}")
del_custom_mask_event = del_fg_mask_btn.click(
fn=delete_foreground_mask, inputs=[frame_slider]
).then(
fn=update_preview,
inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
outputs=[preview_image, img_fg_mask, preview_video],
show_progress=False,
)
def get_collected_faces(image):
if image is not None:
gr.Info(f"Collecting faces...")
faces = SWAP_MUKHAM.collect_heads(image)
COLLECTED_FACES.extend(faces)
yield COLLECTED_FACES
gr.Info(f"Collected {len(faces)} faces")
collect_faces.click(get_collected_faces, inputs=[preview_image], outputs=[collected_faces])
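# Gather the exec-generated src{i}/trg{i} components back into a flat tuple
# (sources first, then specific-target images) matching process()'s *specifics.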
src_specific_inputs = []
gen_variable_txt = ",".join(
[f"src{i+1}" for i in range(gv.NUM_OF_SRC_SPECIFIC)]
+ [f"trg{i+1}" for i in range(gv.NUM_OF_SRC_SPECIFIC)]
)
exec(f"src_specific_inputs = ({gen_variable_txt})")
test_mode = gr.Checkbox(value=False, visible=False)
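# The order of swap_inputs must mirror process()'s positional parameters exactly.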
swap_inputs = [
test_mode,
target_type,
target_image_input,
video_input,
directory_input,
source_image_input,
use_foreground_mask,
img_fg_mask,
fg_mask_softness,
output_directory,
output_name,
use_datetime_suffix,
sequence_output_format,
keep_output_sequence,
swap_condition,
age,
distance_slider,
face_enhancer_name,
face_upscaler_opacity,
use_face_parsing_mask,
parse_from_target,
mask_regions,
mask_blur_amount,
mask_erode_amount,
swap_iteration,
face_scale,
use_laplacian_blending,
crop_top,
crop_bott,
crop_left,
crop_right,
frame_slider,
number_of_threads,
use_frame_selection,
frame_selection_ranges,
video_quality,
face_detection_condition,
face_detection_size,
face_detection_threshold,
averaging_method,
*src_specific_inputs,
]
swap_outputs = [
preview_image,
output_directory_button,
output_video_button,
preview_video,
]
swap_event = swap_button.click(fn=process, inputs=swap_inputs, outputs=swap_outputs)
test_swap_settings = swap_inputs.copy()  # copy so replacing test_mode below doesn't affect the Swap button's inputs
test_swap_settings[0] = gr.Checkbox(value=True, visible=False)
test_swap_event = test_swap.click(
fn=update_preview,
inputs=[video_input, frame_slider, use_foreground_mask, preview_resolution],
outputs=[preview_image, img_fg_mask, preview_video],
show_progress=False,
).then(
fn=process, inputs=test_swap_settings, outputs=swap_outputs, show_progress=True
)
def stop_running():
global IS_RUNNING
IS_RUNNING = False
print("[ Process cancelled ]")
gr.Info("Process cancelled")
cancel_button.click(
fn=stop_running,
inputs=None,
cancels=[swap_event, set_slider_range_event, test_swap_event],
show_progress=True,
)
if __name__ == "__main__":
if gv.USE_COLAB:
print("Running in colab mode")
interface.queue(concurrency_count=2, max_size=20).launch(share=gv.USE_COLAB)