FSFM-3C
init
b32e831
raw
history blame
17.1 kB
# -*- coding: utf-8 -*-
# Author: Gaojian Wang@ZJUICSR
# --------------------------------------------------------
# This source code is licensed under the Attribution-NonCommercial 4.0 International License.
# You can find the license in the LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# pip uninstall nvidia_cublas_cu11
import sys
sys.path.append('..')
import os
os.system(f'pip install dlib')
import torch
import numpy as np
from PIL import Image
from torch.nn import functional as F
import gradio as gr
import models_vit
from util.datasets import build_dataset
import argparse
from engine_finetune import test_all
import dlib
from huggingface_hub import hf_hub_download
# Absolute path of this script. Working folders are resolved next to it with
# os.path.dirname() rather than by slicing a fixed-length filename off the end,
# so renaming the script does not silently relocate these folders.
P = os.path.abspath(__file__)
_APP_DIR = os.path.dirname(P)
FRAME_SAVE_PATH = os.path.join(_APP_DIR, 'frame')  # per-request face crops
CKPT_SAVE_PATH = os.path.join(_APP_DIR, 'checkpoints')  # downloaded weights
# Dropdown label -> relative checkpoint path (presumably a path inside the
# 'Wolowolo/fsfm-3c' Hugging Face repo — confirm against the repo layout).
CKPT_NAME = {
    'DfD Checkpoint_Fine-tuned on FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
    'FAS Checkpoint_Fine-tuned on MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
}
# Checkpoint labels shown in the UI, derived from CKPT_NAME so the two never drift apart.
CKPT_LIST = list(CKPT_NAME)
os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
def get_args_parser():
    """Build the argument parser shared with the MAE/FSFM fine-tuning scripts.

    Most options exist only for compatibility with the training code
    (``build_dataset`` / ``test_all`` read fields from the parsed namespace);
    the demo overrides the few it needs after parsing. ``add_help=False`` is
    kept so this parser can be used as a parent parser.

    Notable defaults: ``normalize_from_IMN``, ``global_pool``, ``eval`` and
    ``pin_mem`` all default to True via ``set_defaults``; ``--cls_token`` and
    ``--no_pin_mem`` switch the latter two off.

    Returns:
        argparse.ArgumentParser: the configured (unparsed) parser.
    """
    parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size')
    parser.add_argument('--normalize_from_IMN', action='store_true',
                        help='cal mean and std from imagenet, else from pretrain datasets')
    parser.set_defaults(normalize_from_IMN=True)
    parser.add_argument('--apply_simple_augment', action='store_true',
                        help='apply simple data augment')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # Optimizer parameters
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--layer_decay', type=float, default=0.75,
                        help='layer-wise lr decay from ELECTRA/BEiT')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR')

    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0.')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup_prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup_mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
    # --cls_token writes False into the same `global_pool` destination.
    parser.add_argument('--cls_token', action='store_false', dest='global_pool',
                        help='Use class token instead of global pool for classification')

    # Dataset parameters
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--nb_classes', default=1000, type=int,
                        help='number of the classification types')
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.set_defaults(eval=True)
    parser.add_argument('--dist_eval', action='store_true', default=False,
                        help='Enabling distributed evaluation (recommended during training for faster monitor)')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    return parser
# Parse the CLI once at import time; individual fields are overridden by the demo below.
args = get_args_parser().parse_args()
# Two-way head: must match the released checkpoints loaded with load_state_dict().
args.nb_classes = 2

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The demo's checkpoints are FSFM ViT-B fine-tunes, so the backbone is fixed to
# vit_base_patch16 here rather than taken from args.model.
model = models_vit.vit_base_patch16(
    num_classes=args.nb_classes,
    drop_path_rate=args.drop_path,
    global_pool=args.global_pool,
)
def load_model(ckpt):
    """Load the checkpoint selected in the dropdown into the global ``model``.

    ``ckpt`` is the dropdown's display label. Labels present in ``CKPT_NAME``
    are mapped to their relative checkpoint path. The file is fetched from the
    ``Wolowolo/fsfm-3c`` Hugging Face repo into ``CKPT_SAVE_PATH`` if it is not
    already on disk, and its ``'model'`` state dict is loaded. Any other value
    (the 'choose from here' placeholder, 'Continuously updating...') is ignored.

    Side effect: sets ``args.resume`` to the local checkpoint path.

    Returns:
        ``gr.update()`` — a no-op update for the dropdown component.
    """
    ckpt_file = CKPT_NAME.get(ckpt)
    if ckpt_file is None:
        # Placeholder / informational entries: nothing to load.
        return gr.update()
    args.resume = os.path.join(CKPT_SAVE_PATH, ckpt_file)
    if not os.path.isfile(args.resume):
        # With local_dir set, hf_hub_download places the file at
        # local_dir/filename, i.e. exactly at args.resume.
        hf_hub_download(local_dir=CKPT_SAVE_PATH,
                        repo_id='Wolowolo/fsfm-3c',
                        filename=ckpt_file)
    # NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted repo.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    return gr.update()
def get_boundingbox(face, width, height, minsize=None, scale=1.3):
    """Compute a square crop box around a dlib face detection (FF++ recipe).

    Adapted from FF++:
    https://github.com/ondyari/FaceForensics/blob/master/classification/detect_from_video.py

    :param face: dlib rectangle (anything with left()/top()/right()/bottom())
    :param width: frame width in pixels
    :param height: frame height in pixels
    :param minsize: optional minimum side length of the box
    :param scale: multiplier applied to the larger face side to enlarge the
        crop (default 1.3, the value previously hard-coded)
    :return: (x, y, size) — top-left corner and side length, clamped to the frame
    """
    x1, y1 = face.left(), face.top()
    x2, y2 = face.right(), face.bottom()
    size_bb = int(max(x2 - x1, y2 - y1) * scale)
    if minsize and size_bb < minsize:
        size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2

    # Centre the box on the face and clamp the top-left corner into the frame.
    x1 = max(int(center_x - size_bb // 2), 0)
    y1 = max(int(center_y - size_bb // 2), 0)
    # Shrink the box so it does not run past the right or bottom edge.
    size_bb = min(width - x1, height - y1, size_bb)
    return x1, y1, size_bb
def extract_face(frame):
    """Detect faces in a PIL image and return an enlarged square crop of the largest.

    The crop box comes from get_boundingbox() (FF++ style enlargement).

    Args:
        frame: PIL image (any mode; converted to RGB before detection).

    Returns:
        PIL.Image of the cropped face, or None when no face is detected.
        The crop is not resized here; callers resize it.
    """
    face_detector = dlib.get_frontal_face_detector()
    image = np.array(frame.convert('RGB'))
    # The second argument is dlib's upsampling count (see dlib docs).
    faces = face_detector(image, 1)
    if len(faces) == 0:
        return None
    # Only the largest detection is used (by box area), as the FF++ recipe intends.
    face = max(faces, key=lambda r: (r.right() - r.left()) * (r.bottom() - r.top()))
    x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
    cropped_face = image[y:y + size, x:x + size]
    return Image.fromarray(cropped_face)
def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
    """Return ``extract_frame_num`` evenly spaced frame indices in [0, total_frame_num - 1].

    Indices are truncated to int, so short videos may yield repeated indices.
    When ``total_frame_num`` is below 1 (e.g. OpenCV reported 0 or -1 because
    the frame count is unknown), an empty list is returned instead of
    indices that run into negative values.

    Args:
        total_frame_num: number of frames in the video (int or float).
        extract_frame_num: how many indices to produce.

    Returns:
        list[int] of frame indices, in increasing order.
    """
    if total_frame_num < 1:
        return []
    interval = np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int)
    return interval.tolist()
import cv2
def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None, device='cpu'):
    """Sample frames from a video and save 224x224 face crops.

    1) Pick ``num_frames`` frame indices uniformly over [first frame, last frame]
       (or every frame when ``num_frames`` is None).
    2) Crop the face in each frame with extract_face() and save it as
       ``<dst_path>/0/frame_<index>.png``.

    Frames that cannot be read, or in which no face is found, are skipped.

    Args:
        src_video: path/URL accepted by ``cv2.VideoCapture``.
        dst_path: output root; the ``0`` sub-folder must already exist.
        num_frames: number of frames to sample, or None for all frames.
        device: unused; kept for API compatibility.

    Returns:
        int: number of face crops written.
    """
    video_capture = cv2.VideoCapture(src_video)
    saved = 0
    try:
        total_frames = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)
        if num_frames is not None:
            frame_indices = get_frame_index_uniform_sample(total_frames, num_frames)
        else:
            frame_indices = range(int(total_frames))
        for frame_index in frame_indices:
            video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = video_capture.read()
            # Check the read result before touching `frame`: on failure it is None
            # and cvtColor would raise.
            if not ret or frame is None:
                continue
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img = extract_face(image)
            if img is None:
                continue
            img = img.resize((224, 224), Image.BICUBIC)
            img.save(os.path.join(dst_path, '0', f"frame_{frame_index}.png"))
            saved += 1
    finally:
        # Release the capture handle even if decoding or saving fails.
        video_capture.release()
    return saved
def _new_frame_dir():
    """Create a fresh, uniquely named work folder under FRAME_SAVE_PATH and return its path.

    The layout is ``<dir>/0/<images>``, which is what build_dataset() is given
    via ``args.data_path``. mkdtemp guarantees a new directory, so crops from
    an earlier request can never be mixed into this one.
    """
    import tempfile
    frame_path = tempfile.mkdtemp(dir=FRAME_SAVE_PATH)
    os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
    return frame_path


def _run_detector(frame_path, batch_size):
    """Run the loaded ``model`` over the crops in ``frame_path``; return test_all's video-level predictions."""
    # build_dataset reads its configuration from the shared `args` namespace.
    args.data_path = frame_path
    args.batch_size = batch_size
    dataset_val = build_dataset(is_train=False, args=args)
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=torch.utils.data.SequentialSampler(dataset_val),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )
    _frame_preds, video_y_pred_list = test_all(data_loader_val, model, device)
    return video_y_pred_list


def FSFM3C_video_detection(video):
    """Classify a video from face crops of 32 uniformly sampled frames.

    Returns test_all's video-level prediction list. Returns a one-element list
    with a message when no input is given, no checkpoint has been selected,
    or no face is found in any sampled frame.
    """
    if video is None:
        return ['Invalid Input']
    # args.resume defaults to '' and is set only by load_model().
    if not args.resume:
        return ['Please select a model checkpoint first']
    model.to(device)
    num_frames = 32
    frame_path = _new_frame_dir()
    extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames, device=device)
    # No sampled frame contained a face: an empty folder cannot form a dataset.
    if not os.listdir(os.path.join(frame_path, '0')):
        return ['Invalid Input']
    return _run_detector(frame_path, batch_size=32)


def FSFM3C_image_detection(image):
    """Classify a single PIL image from its (largest) face crop.

    Returns test_all's prediction list. Returns a one-element list with a
    message when no input is given, no checkpoint has been selected, or no
    face is detected.
    """
    if image is None:
        return ['Invalid Input']
    # args.resume defaults to '' and is set only by load_model().
    if not args.resume:
        return ['Please select a model checkpoint first']
    model.to(device)
    img = extract_face(image)
    if img is None:
        return ['Invalid Input']
    frame_path = _new_frame_dir()
    img = img.resize((224, 224), Image.BICUBIC)
    img.save(os.path.join(frame_path, '0', "frame_0.png"))
    return _run_detector(frame_path, batch_size=1)
# WebUI: checkpoint selector plus side-by-side image and video detection panels.
with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery and Spoofing (Deepfake/Diffusion/Presentation-attacks)</h1>")
    gr.Markdown("# ---Based on the fine-tuned model that is pre-trained from [FSFM-3C](https://fsfm-3c.github.io/)")
    gr.Markdown("## Release <br>"
                "V1.0 [2024-12] (Current): <br>"
                "[1] Create this page with basic detectors (simply fine-tuned models that follow the paper implementation): <br> "
                " - DfD Checkpoint_Fine-tuned on FF++: FSFM VIT-B fine-tuned on the FF++ (c23, train&val sets, 32 frames per video, 4 manipulations) dataset <br>"
                " - FAS Checkpoint_Fine-tuned on MCIO: FSFM VIT-B fine-tuned on the MCIO datasets (2 frames per video) <br> "
                " Performance is limited because no optimization of data, models, hyperparameters, etc. has been done for downstream tasks")
    gr.Markdown("### TODO: We will soon update practical models and optimized interfaces, and provide more functions such as visualizations, a unified detector, and multi-modal diagnosis.")
    gr.Markdown("> Please provide an <b>image</b> or a <b>video (<100s, default to uniform sampling 32 frames)</b> for detection:")

    with gr.Column():
        # The first entry is a placeholder; load_model() receives the selected label.
        ckpt_select_dropdown = gr.Dropdown(
            label="Select the Model Checkpoint for Detection (🖱️ below)",
            choices=['choose from here'] + CKPT_LIST + ['Continuously updating...'],
            multiselect=False,
            value='choose from here',
            interactive=True,
        )
    with gr.Row(elem_classes="center-align"):
        with gr.Column(scale=5):
            gr.Markdown("## Image Detection")
            image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
            image_submit_btn = gr.Button("Submit")
            output_results_image = gr.Textbox(label="Detection Result")
        with gr.Column(scale=5):
            gr.Markdown("## Video Detection")
            video = gr.Video(label="Upload/Capture your video")
            video_submit_btn = gr.Button("Submit")
            output_results_video = gr.Textbox(label="Detection Result")

    image_submit_btn.click(
        fn=FSFM3C_image_detection,
        inputs=[image],
        outputs=[output_results_image],
    )
    video_submit_btn.click(
        fn=FSFM3C_video_detection,
        inputs=[video],
        outputs=[output_results_video],
    )
    ckpt_select_dropdown.change(
        fn=load_model,
        inputs=[ckpt_select_dropdown],
        outputs=[ckpt_select_dropdown],
    )
def _launch_demo():
    """Close any Gradio apps left running, enable request queuing, and serve `demo`."""
    gr.close_all()
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    _launch_demo()