|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
import os
import subprocess
import sys

# Allow importing modules that live one directory above the working directory.
sys.path.append('..')

import torch
import numpy as np
from PIL import Image
from torch.nn import functional as F

import gradio as gr

import models_vit
from util.datasets import build_dataset
from engine_finetune import test_all

# dlib may be missing from the runtime image. Install it only when the import
# actually fails, and through the interpreter that is running this script so
# the package lands in the right environment (a bare shell `pip` may not).
try:
    import dlib
except ImportError:
    subprocess.run([sys.executable, '-m', 'pip', 'install', 'dlib'], check=False)
    import dlib  # still raises ImportError if the installation did not succeed

from huggingface_hub import hf_hub_download
|
|
|
|
|
# Absolute path of this script; the working folders are created next to it.
P = os.path.abspath(__file__)
# Use dirname() rather than slicing a fixed number of characters off the path,
# so the folders resolve correctly whatever this file is called.
_BASE_DIR = os.path.dirname(P)
FRAME_SAVE_PATH = os.path.join(_BASE_DIR, 'frame')
CKPT_SAVE_PATH = os.path.join(_BASE_DIR, 'checkpoints')

# Display label shown in the UI -> checkpoint file path.
CKPT_NAME = {
    'DfD Checkpoint_Fine-tuned on FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
    'FAS Checkpoint_Fine-tuned on MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
}
# Labels in display order; derived from the mapping so the two never drift apart.
CKPT_LIST = list(CKPT_NAME)

os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
|
|
|
|
|
def get_args_parser():
    """Build the argument parser holding the fine-tuning / evaluation options.

    The parser is created without the automatic ``-h/--help`` option.

    :return: a configured ``argparse.ArgumentParser``
    """
    parser = argparse.ArgumentParser(
        prog='MAE fine-tuning for image classification',
        add_help=False,
    )
    add = parser.add_argument  # short alias; every option below goes through it

    # --- batch / schedule length ---
    add('--batch_size', default=64, type=int,
        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
    add('--epochs', default=50, type=int)
    add('--accum_iter', default=1, type=int,
        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # --- model ---
    add('--model', default='vit_large_patch16', type=str, metavar='MODEL',
        help='Name of model to train')

    add('--input_size', default=224, type=int,
        help='images input size')
    add('--normalize_from_IMN', action='store_true',
        help='cal mean and std from imagenet, else from pretrain datasets')
    parser.set_defaults(normalize_from_IMN=True)
    add('--apply_simple_augment', action='store_true',
        help='apply simple data augment')

    add('--drop_path', type=float, default=0.1, metavar='PCT',
        help='Drop path rate (default: 0.1)')

    # --- optimizer ---
    add('--clip_grad', type=float, default=None, metavar='NORM',
        help='Clip gradient norm (default: None, no clipping)')
    add('--weight_decay', type=float, default=0.05,
        help='weight decay (default: 0.05)')

    add('--lr', type=float, default=None, metavar='LR',
        help='learning rate (absolute lr)')
    add('--blr', type=float, default=1e-3, metavar='LR',
        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    add('--layer_decay', type=float, default=0.75,
        help='layer-wise lr decay from ELECTRA/BEiT')

    add('--min_lr', type=float, default=1e-6, metavar='LR',
        help='lower lr bound for cyclic schedulers that hit 0')

    add('--warmup_epochs', type=int, default=5, metavar='N',
        help='epochs to warmup LR')

    # --- augmentation ---
    add('--color_jitter', type=float, default=None, metavar='PCT',
        help='Color jitter factor (enabled only when not using Auto/RandAug)')
    add('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
        help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)')
    add('--smoothing', type=float, default=0.1,
        help='Label smoothing (default: 0.1)')

    # --- random erase ---
    add('--reprob', type=float, default=0.25, metavar='PCT',
        help='Random erase prob (default: 0.25)')
    add('--remode', type=str, default='pixel',
        help='Random erase mode (default: "pixel")')
    add('--recount', type=int, default=1,
        help='Random erase count (default: 1)')
    add('--resplit', action='store_true', default=False,
        help='Do not random erase first (clean) augmentation split')

    # --- mixup / cutmix ---
    add('--mixup', type=float, default=0,
        help='mixup alpha, mixup enabled if > 0.')
    add('--cutmix', type=float, default=0,
        help='cutmix alpha, cutmix enabled if > 0.')
    add('--cutmix_minmax', type=float, nargs='+', default=None,
        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    add('--mixup_prob', type=float, default=1.0,
        help='Probability of performing mixup or cutmix when either/both is enabled')
    add('--mixup_switch_prob', type=float, default=0.5,
        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    add('--mixup_mode', type=str, default='batch',
        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # --- fine-tuning ---
    add('--finetune', default='',
        help='finetune from checkpoint')
    add('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
    # --cls_token writes False into the same destination as --global_pool.
    add('--cls_token', action='store_false', dest='global_pool',
        help='Use class token instead of global pool for classification')

    # --- dataset ---
    add('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
        help='dataset path')
    add('--nb_classes', default=1000, type=int,
        help='number of the classification types')

    # --- output / runtime ---
    add('--output_dir', default='',
        help='path where to save, empty for no saving')
    add('--log_dir', default='',
        help='path where to tensorboard log')
    add('--device', default='cuda',
        help='device to use for training / testing')
    add('--seed', default=0, type=int)
    add('--resume', default='',
        help='resume from checkpoint')

    add('--start_epoch', default=0, type=int, metavar='N',
        help='start epoch')
    add('--eval', action='store_true',
        help='Perform evaluation only')
    parser.set_defaults(eval=True)
    add('--dist_eval', action='store_true', default=False,
        help='Enabling distributed evaluation (recommended during training for faster monitor')
    add('--num_workers', default=10, type=int)
    add('--pin_mem', action='store_true',
        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    add('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # --- distributed ---
    add('--world_size', default=1, type=int,
        help='number of distributed processes')
    add('--local_rank', default=-1, type=int)
    add('--dist_on_itp', action='store_true')
    add('--dist_url', default='env://',
        help='url used to set up distributed training')

    return parser
|
|
|
|
|
# Parse the process command line once; `args` is the resulting namespace and is
# later mutated by the handlers below (resume path, data path, batch size).
args = get_args_parser().parse_args()
args.nb_classes = 2  # the classifier head is built with two output classes

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The demo always instantiates the 'vit_base_patch16' constructor exported by
# models_vit, independent of the --model option.
model = vars(models_vit)['vit_base_patch16'](
    num_classes=args.nb_classes,
    drop_path_rate=args.drop_path,
    global_pool=args.global_pool,
)
|
|
|
def load_model(ckpt):
    """Load the checkpoint selected in the dropdown into the global ``model``.

    ``ckpt`` is the display label chosen by the user. Labels without an entry
    in ``CKPT_NAME`` (the 'choose from here' / 'Continuously updating...'
    placeholders) are ignored. For real entries the checkpoint file is
    downloaded into ``CKPT_SAVE_PATH`` when it is not already there, then its
    ``'model'`` state dict is loaded.

    :param ckpt: label selected in the checkpoint dropdown
    :return: ``gr.update()`` (the dropdown itself is left unchanged)
    """
    # Translate the label into the checkpoint file path; placeholders have none.
    ckpt_file = CKPT_NAME.get(ckpt)
    if ckpt_file is None:
        return gr.update()

    args.resume = os.path.join(CKPT_SAVE_PATH, ckpt_file)
    if not os.path.isfile(args.resume):
        hf_hub_download(local_dir=CKPT_SAVE_PATH,
                        repo_id='Wolowolo/fsfm-3c',
                        filename=ckpt_file)

    # NOTE(review): torch.load unpickles the downloaded file; only point this
    # at a repository you trust.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    return gr.update()
|
|
|
|
|
def get_boundingbox(face, width, height, minsize=None, scale=1.3):
    """
    From FF++:
    https://github.com/ondyari/FaceForensics/blob/master/classification/detect_from_video.py
    Expects a dlib face to generate a quadratic bounding box.

    :param face: dlib face class (anything exposing left/top/right/bottom)
    :param width: frame width
    :param height: frame height
    :param minsize: set minimum bounding box size
    :param scale: bounding box size multiplier to get a bigger face region
        (defaults to 1.3, the value that used to be hard-coded)
    :return: x, y, bounding_box_size in opencv form
    """
    x1, y1 = face.left(), face.top()
    x2, y2 = face.right(), face.bottom()

    # Square side: the longer edge of the detection, enlarged by `scale`.
    size_bb = int(max(x2 - x1, y2 - y1) * scale)
    if minsize and size_bb < minsize:
        size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2

    # Top-left corner of the square, clamped so it stays inside the frame.
    x1 = max(int(center_x - size_bb // 2), 0)
    y1 = max(int(center_y - size_bb // 2), 0)

    # Shrink the square if it would run past the right or bottom border.
    size_bb = min(width - x1, size_bb)
    size_bb = min(height - y1, size_bb)

    return x1, y1, size_bb
|
|
|
|
|
def extract_face(frame):
    """Crop the first face dlib detects in ``frame``.

    :param frame: image object providing ``convert('RGB')`` (a PIL image in this file)
    :return: PIL image of the enlarged square face crop, or ``None`` when no
        face is detected
    """
    detector = dlib.get_frontal_face_detector()
    rgb = np.array(frame.convert('RGB'))

    # Second argument 1: the detector is run with one upsampling pass.
    detections = detector(rgb, 1)
    if not len(detections) > 0:
        return None

    # Only the first detection is used.
    frame_height, frame_width = rgb.shape[0], rgb.shape[1]
    left, top, side = get_boundingbox(detections[0], frame_width, frame_height)
    return Image.fromarray(rgb[top:top + side, left:left + side])
|
|
|
|
|
def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
    """Pick ``extract_frame_num`` evenly spaced frame indices.

    The indices run from the first frame (0) to the last one
    (``total_frame_num - 1``), both included. Indices repeat when more
    samples are requested than there are frames.

    :param total_frame_num: number of frames available
    :param extract_frame_num: number of indices to return
    :return: list of integer frame indices
    """
    last_index = total_frame_num - 1
    sampled = np.linspace(start=0, stop=last_index, num=extract_frame_num, dtype=int)
    return sampled.tolist()
|
|
|
|
|
import cv2 |
|
def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None, device='cpu'):
    """
    1) extract specific num of frames from videos in [1st(index 0) frame, last frame] with uniform sample interval
    2) extract face from frame with specific enlarge size

    Each face crop is resized to 224x224 and written to
    ``<dst_path>/0/frame_<index>.png``; that folder must already exist.
    Frames that cannot be decoded, or in which no face is found, are skipped.

    :param src_video: path of the video to read
    :param dst_path: directory containing the ``0`` sub-folder to write into
    :param num_frames: number of frames to sample; ``None`` processes every frame
    :param device: unused, kept for interface compatibility
    :return: None
    """
    video_capture = cv2.VideoCapture(src_video)
    try:
        total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # Nothing to sample (unreadable or empty video).
            return

        if num_frames is not None:
            frame_indices = get_frame_index_uniform_sample(total_frames, num_frames)
        else:
            frame_indices = range(total_frames)

        for frame_index in frame_indices:
            video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = video_capture.read()
            # Check the read result before touching `frame`: it is not a
            # usable image when decoding failed.
            if not ret:
                continue

            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img = extract_face(image)
            if img is None:
                continue

            img = img.resize((224, 224), Image.BICUBIC)
            save_img_name = f"frame_{frame_index}.png"
            img.save(os.path.join(dst_path, '0', save_img_name))
    finally:
        # Release the capture even when face extraction or saving raises.
        video_capture.release()
|
|
|
|
|
|
|
def FSFM3C_video_detection(video):
    """Run the loaded model on faces sampled from an uploaded video.

    32 frames are sampled uniformly, the face crops are written to a new
    numbered folder under ``FRAME_SAVE_PATH``, and that folder is evaluated
    with ``test_all``.

    :param video: path of the uploaded video (as provided by the video component)
    :return: the video-level prediction list returned by ``test_all``, or
        ``['Invalid Input']`` when no video was given or no face could be
        extracted (same convention as the image handler)
    """
    if not video:
        return ['Invalid Input']

    model.to(device)

    num_frames = 32

    # The working folder is named after the number of entries already present.
    num_files = len(os.listdir(FRAME_SAVE_PATH))
    frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
    face_dir = os.path.join(frame_path, '0')
    os.makedirs(face_dir, exist_ok=True)  # also creates frame_path
    extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames, device=device)

    # No face found in any sampled frame: there is nothing to evaluate.
    # NOTE(review): build_dataset presumably cannot work on an empty folder — confirm.
    if not os.listdir(face_dir):
        return ['Invalid Input']

    args.data_path = frame_path
    args.batch_size = 32
    dataset_val = build_dataset(is_train=False, args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    frame_preds_list, video_y_pred_list = test_all(data_loader_val, model, device)

    return video_y_pred_list
|
|
|
|
|
def FSFM3C_image_detection(image):
    """Run the loaded model on the face found in an uploaded image.

    The face crop is written as ``frame_0.png`` into a new numbered folder
    under ``FRAME_SAVE_PATH`` and that folder is evaluated with ``test_all``.

    :param image: uploaded image (PIL image, per the image component's ``type="pil"``)
    :return: the second list returned by ``test_all``, or ``['Invalid Input']``
        when no image was given or no face is detected
    """
    # Validate the input before creating any working folder, so rejected
    # requests do not leave empty folders behind.
    if image is None:
        return ['Invalid Input']
    img = extract_face(image)
    if img is None:
        return ['Invalid Input']
    img = img.resize((224, 224), Image.BICUBIC)

    model.to(device)

    # The working folder is named after the number of entries already present.
    num_files = len(os.listdir(FRAME_SAVE_PATH))
    frame_path = os.path.join(FRAME_SAVE_PATH, str(num_files))
    os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)  # also creates frame_path

    save_img_name = "frame_0.png"
    img.save(os.path.join(frame_path, '0', save_img_name))

    args.data_path = frame_path
    args.batch_size = 1
    dataset_val = build_dataset(is_train=False, args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    frame_preds_list, video_y_pred_list = test_all(data_loader_val, model, device)

    return video_y_pred_list
|
|
|
|
|
|
|
# Gradio page: header and release notes, a checkpoint selector, then
# side-by-side image and video detection panels.
with gr.Blocks() as demo:
    # ---- header ----
    gr.HTML("<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery and Spoofing (Deepfake/Diffusion/Presentation-attacks)</h1>")
    gr.Markdown("# ---Based on the fine-tuned model that is pre-trained from [FSFM-3C](https://fsfm-3c.github.io/)")

    # ---- release notes (fragments are concatenated into one markdown string) ----
    gr.Markdown("".join((
        "## Release <br>",
        "V1.0 [2024-12] (Current): <br>",
        "[1] Create this page with basic detectors (simply fine-tuned models that follow the paper implementation): <br> ",
        " - DfD Checkpoint_Fine-tuned on FF++: FSFM VIT-B fine-tuned on the FF++ (c23, train&val sets, 32 frames per video, 4 manipulations) dataset <br>",
        " - FAS Checkpoint_Fine-tuned on MCIO: FSFM VIT-B fine-tuned on the MCIO datasets (2 frames per video) <br> ",
        " Performance is limited because no any optimization of data, models, hyperparameters, etc. is done for downstream tasks",
    )))

    gr.Markdown("### TODO: We will soon update practical models, and optimized interfaces, and provide more functions such as visualizations, a unified detector, and multi-modal diagnosis.")

    gr.Markdown("> Please provide an <b>image</b> or a <b>video (<100s </b>, default to uniform sampling 32 frames)</b> for detection:")

    # ---- checkpoint selector ----
    with gr.Column():
        ckpt_select_dropdown = gr.Dropdown(
            label="Select the Model Checkpoint for Detection (🖱️ below)",
            # The real checkpoints are framed by two placeholder entries.
            choices=['choose from here'] + CKPT_LIST + ['Continuously updating...'],
            multiselect=False,
            value='choose from here',
            interactive=True,
        )

    # ---- detection panels ----
    with gr.Row(elem_classes="center-align"):
        with gr.Column(scale=5):
            gr.Markdown("## Image Detection")
            image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
            image_submit_btn = gr.Button("Submit")
            output_results_image = gr.Textbox(label="Detection Result")
        with gr.Column(scale=5):
            gr.Markdown("## Video Detection")
            video = gr.Video(label="Upload/Capture your video")
            video_submit_btn = gr.Button("Submit")
            output_results_video = gr.Textbox(label="Detection Result")

    # ---- event wiring ----
    image_submit_btn.click(fn=FSFM3C_image_detection, inputs=[image], outputs=[output_results_image])
    video_submit_btn.click(fn=FSFM3C_video_detection, inputs=[video], outputs=[output_results_video])
    # Changing the dropdown selection (re)loads the chosen checkpoint.
    ckpt_select_dropdown.change(fn=load_model, inputs=[ckpt_select_dropdown], outputs=[ckpt_select_dropdown])
|
|
|
|
|
if __name__ == "__main__":
    # Close any Gradio app still running in this process, then enable the
    # request queue and start serving the demo.
    gr.close_all()
    demo.queue().launch()