diff --git a/ConsistentID/.gitattributes b/ConsistentID/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..58f1377d7c533a58a9c9de1a4a43f3cdda09fca4 --- /dev/null +++ b/ConsistentID/.gitattributes @@ -0,0 +1,38 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +images/templates/3f8d901770014c1b8f7f261971f0e92.png filter=lfs diff=lfs merge=lfs -text +images/templates/75583964a834abe33b72f52b1a98e84.png filter=lfs diff=lfs merge=lfs -text +models/LLaVA/images/demo_cli.gif filter=lfs diff=lfs merge=lfs -text diff --git a/ConsistentID/.gitignore b/ConsistentID/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..03902be4a3bb461a24426f4af15a055de6c7d553 --- /dev/null +++ b/ConsistentID/.gitignore @@ -0,0 +1,5 @@ +__pycache__/* +__pycache__ +/*.png +models/insightface +models/Realistic_Vision* diff --git a/ConsistentID/LICENSE b/ConsistentID/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4de2ad6baf433ec9f6fc16246814237acd15c38f --- /dev/null +++ b/ConsistentID/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Jiehui Huang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/ConsistentID/README.md b/ConsistentID/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6edc705fd6ef175c3398fa5abebcc6f182880b51 --- /dev/null +++ b/ConsistentID/README.md @@ -0,0 +1,13 @@ +--- +title: ConsistentID +emoji: 🔥 +colorFrom: yellow +colorTo: yellow +sdk: gradio +sdk_version: 4.37.2 +app_file: app.py +pinned: false +license: apache-2.0 +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/ConsistentID/__init__.py b/ConsistentID/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ConsistentID/app.py b/ConsistentID/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ad923a26857e37af1d9ea93933e3193166dd990a --- /dev/null +++ b/ConsistentID/app.py @@ -0,0 +1,168 @@ +import gradio as gr +import torch +import os +import glob +import spaces +import numpy as np + +from PIL import Image +from diffusers.utils import load_image +from diffusers import EulerDiscreteScheduler +from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('--base_model_path', type=str, + default="models/Realistic_Vision_V4.0_noVAE") +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +device = f"cuda:{args.gpu}" + +### Load base model +pipe = ConsistentIDPipeline.from_pretrained( + args.base_model_path, + torch_dtype=torch.float16, +) + +### Load consistentID_model checkpoint +pipe.load_ConsistentID_model( + consistentID_weight_path="./models/ConsistentID-v1.bin", + bise_net_weight_path="./models/BiSeNet_pretrained_for_ConsistentID.pth", +) +pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to(device, torch.float16) + +@spaces.GPU +def process(selected_template_images, custom_image, prompt, + negative_prompt, prompt_selected, model_selected_tab, + prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set): + + # The gradio UI only supports one image at a time. 
+ if model_selected_tab==0: + subj_images = load_image(Image.open(selected_template_images)) + else: + subj_images = load_image(Image.fromarray(custom_image)) + + if prompt_selected_tab==0: + prompt = prompt_selected + negative_prompt = "" + + # hyper-parameters + num_steps = 50 + seed_set = torch.randint(0, 1000, (1,)).item() + # merge_steps = 30 + + if prompt == "": + prompt = "A man, in a forest" + prompt = "A man, with backpack, in a rainy tropical forest, adventuring, holding a flashlight, in mist, seeking animals" + prompt = "A person, in the snow, wearing a santa hat and a scarf, with a cottage behind" + else: + #prompt=Enhance_prompt(prompt, Image.new('RGB', (200, 200), color = 'white')) + print(prompt) + + if negative_prompt == "": + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry" + + # Extend prompt + #prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed" + #print(prompt) + + negative_prompt_group="((cross-eye)),((cross-eyed)),(((NSFW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))" + negative_prompt = negative_prompt + negative_prompt_group + + # seed = torch.randint(0, 1000, (1,)).item() + generator = torch.Generator(device=device).manual_seed(seed_set) + + images = pipe( + prompt=prompt, + width=width, + height=height, + input_subj_image_objs=subj_images, + negative_prompt=negative_prompt, + num_images_per_prompt=1, + num_inference_steps=num_steps, + guidance_scale=guidance_scale, + start_merge_step=merge_steps, + generator=generator, + ).images[0] + + return np.array(images) + +# Gather the preset template images +preset_template = glob.glob("./images/templates/*.png") +preset_template = preset_template + glob.glob("./images/templates/*.jpg") + +with gr.Blocks(title="ConsistentID Demo") as demo: + gr.Markdown("# ConsistentID Demo") + gr.Markdown("\ + Put the reference image to be redrawn into the box below (there is a small chance the reference fails to be recognized; if so, simply submit it again)") + gr.Markdown("\ + If you find our work interesting, please leave a star on GitHub for us!
\ https://github.com/JackAILab/ConsistentID") + with gr.Row(): + with gr.Column(): + model_selected_tab = gr.State(0) + with gr.TabItem("template images") as template_images_tab: + template_gallery_list = [(i, i) for i in preset_template] + gallery = gr.Gallery(template_gallery_list,columns=[4], rows=[2], object_fit="contain", height="auto",show_label=False) + + def select_function(evt: gr.SelectData): + return preset_template[evt.index] + + selected_template_images = gr.Text(show_label=False, visible=False, placeholder="Selected") + gallery.select(select_function, None, selected_template_images) + with gr.TabItem("Upload Image") as upload_image_tab: + custom_image = gr.Image(label="Upload Image") + + model_selected_tabs = [template_images_tab, upload_image_tab] + for i, tab in enumerate(model_selected_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[model_selected_tab]) + + with gr.Column(): + prompt_selected_tab = gr.State(0) + with gr.TabItem("template prompts") as template_prompts_tab: + prompt_selected = gr.Dropdown(value="A person, police officer, half body shot", elem_id='dropdown', choices=[ + "A woman in a wedding dress", + "A woman, queen, in a gorgeous palace", + "A man sitting at the beach with sunset", + "A person, police officer, half body shot", + "A man, sailor, in a boat above ocean", + "A woman wearing headphones, listening to music", + "A man, firefighter, half body shot"], label="prepared prompts") + + with gr.TabItem("custom prompt") as custom_prompt_tab: + prompt = gr.Textbox(label="prompt",placeholder="A man/woman wearing a santa hat") + nagetive_prompt = gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry") + + prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab] + for i, tab in enumerate(prompt_selected_tabs): + tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab]) + + guidance_scale = gr.Slider( + label="Guidance scale", + minimum=1.0, + maximum=10.0, + step=1.0, + value=5.0, + ) + + width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8) + height = gr.Slider(label="image height",minimum=256,maximum=768,value=512,step=8) + width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height]) + height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width]) + merge_steps = gr.Slider(label="step at which to start merging facial details (30 is recommended)",minimum=10,maximum=50,value=30,step=1) + seed_set = gr.Slider(label="random seed (different seeds give different results)",minimum=1,maximum=2147483647,value=2024,step=1) + + btn = gr.Button("Run") + with gr.Column(): + out = gr.Image(label="Output") + gr.Markdown(''' + N.B.:
+ - If the face occupies only a small portion of the image, errors become slightly more likely and identity similarity drops noticeably. + - Prefer prompts that say \"man\" or \"woman\" rather than \"person\"; using \"person\" may leave the model unsure whether the subject is male or female. + - Due to limited GPU memory on the demo server, there is an upper limit on the output resolution. We will support SDXL generation as soon as possible.

+ ''') + btn.click(fn=process, inputs=[selected_template_images, custom_image,prompt, nagetive_prompt, prompt_selected, + model_selected_tab, prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set], outputs=out) + +demo.launch(server_name='0.0.0.0', ssl_verify=False) diff --git a/ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png b/ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png new file mode 100644 index 0000000000000000000000000000000000000000..96322e923d0fe21a29aaf0c0ff81593c4dd8e45b --- /dev/null +++ b/ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fa9319750b9927075934c40a180766e75ff539711293581dae6bac5963b9d05 +size 2061666 diff --git a/ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png b/ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png new file mode 100644 index 0000000000000000000000000000000000000000..08c3c9ada9bb08ec63939621721cdb3e7b7c3b1e Binary files /dev/null and b/ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png differ diff --git a/ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png b/ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png new file mode 100644 index 0000000000000000000000000000000000000000..cb4c91b4840d6e4e62d0242083d66d7f071d66a7 --- /dev/null +++ b/ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:318c942eb3cc8a1f9320b2ea84a88cd95067785c07f8ae1dd18fe6c4cf8e8282 +size 7543309 diff --git a/ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg b/ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c0c4d207ca292c88f846aea31b218ef75a39bcd5 Binary files /dev/null and b/ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg differ diff --git a/ConsistentID/lib/BiSeNet/6.jpg b/ConsistentID/lib/BiSeNet/6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5408c6ecf94a1f260dac210a15e3be073bba6c1d Binary files /dev/null and b/ConsistentID/lib/BiSeNet/6.jpg differ diff --git a/ConsistentID/lib/BiSeNet/__init__.py b/ConsistentID/lib/BiSeNet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34c103210a7fa7fda0b895e183e4f3cbc831f92b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/__init__.py @@ -0,0 +1,2 @@ +#__init__.py +# from BiSeNet.model import * \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/evaluate.py b/ConsistentID/lib/BiSeNet/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..578d75c7e8b4dceeb20cc599ad9062b67311724e --- /dev/null +++ b/ConsistentID/lib/BiSeNet/evaluate.py @@ -0,0 +1,95 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +from logger import setup_logger +import BiSeNet +from face_dataset import FaceMask + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import torch.nn.functional as F +import torch.distributed as dist + +import os +import os.path as osp +import logging +import time +import numpy as np +from tqdm import tqdm +import math +from PIL import Image +import torchvision.transforms as transforms +import cv2 + +def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'): + # Colors for all 20 parts + part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], + [255, 
0, 85], [255, 0, 170], + [0, 255, 0], [85, 255, 0], [170, 255, 0], + [0, 255, 85], [0, 255, 170], + [0, 0, 255], [85, 0, 255], [170, 0, 255], + [0, 85, 255], [0, 170, 255], + [255, 255, 0], [255, 255, 85], [255, 255, 170], + [255, 0, 255], [255, 85, 255], [255, 170, 255], + [0, 255, 255], [85, 255, 255], [170, 255, 255]] + + im = np.array(im) + vis_im = im.copy().astype(np.uint8) + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 + + num_of_class = np.max(vis_parsing_anno) + + for pi in range(1, num_of_class + 1): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + # print(vis_parsing_anno_color.shape, vis_im.shape) + vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) + + # Save result or not + if save_im: + cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + + # return vis_im + +def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): + + if not os.path.exists(respth): + os.makedirs(respth) + + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + save_pth = osp.join('res/cp', cp) + net.load_state_dict(torch.load(save_pth)) + net.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + with torch.no_grad(): + for image_path in os.listdir(dspth): + img = Image.open(osp.join(dspth, image_path)) + image = img.resize((512, 512), Image.BILINEAR) + img = to_tensor(image) + img = torch.unsqueeze(img, 0) + img = img.cuda() + out = net(img)[0] + parsing = out.squeeze(0).cpu().numpy().argmax(0) + + vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path)) + + + + + + + +if __name__ == "__main__": + setup_logger('./res') + evaluate() diff --git a/ConsistentID/lib/BiSeNet/face_dataset.py b/ConsistentID/lib/BiSeNet/face_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ece7fb0afd127c7bf085c769540145838e270e --- /dev/null +++ b/ConsistentID/lib/BiSeNet/face_dataset.py @@ -0,0 +1,106 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +from torch.utils.data import Dataset +import torchvision.transforms as transforms + +import os.path as osp +import os +from PIL import Image +import numpy as np +import json +import cv2 + +from transform import * + + + +class FaceMask(Dataset): + def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs): + super(FaceMask, self).__init__(*args, **kwargs) + assert mode in ('train', 'val', 'test') + self.mode = mode + self.ignore_lb = 255 + self.rootpth = rootpth + + self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img')) + + # pre-processing + self.to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + self.trans_train = Compose([ + ColorJitter( + brightness=0.5, + contrast=0.5, + saturation=0.5), + HorizontalFlip(), + RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)), + RandomCrop(cropsize) + ]) + + def __getitem__(self, idx): + impth = self.imgs[idx] + img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth)) + img = img.resize((512, 512), 
Image.BILINEAR) + label = Image.open(osp.join(self.rootpth, 'mask', impth[:-3]+'png')).convert('P') + # print(np.unique(np.array(label))) + if self.mode == 'train': + im_lb = dict(im=img, lb=label) + im_lb = self.trans_train(im_lb) + img, label = im_lb['im'], im_lb['lb'] + img = self.to_tensor(img) + label = np.array(label).astype(np.int64)[np.newaxis, :] + return img, label + + def __len__(self): + return len(self.imgs) + + +if __name__ == "__main__": + face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img' + face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno' + mask_path = '/home/zll/data/CelebAMask-HQ/mask' + counter = 0 + total = 0 + for i in range(15): + # files = os.listdir(osp.join(face_sep_mask, str(i))) + + atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', + 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] + + for j in range(i*2000, (i+1)*2000): + + mask = np.zeros((512, 512)) + + for l, att in enumerate(atts, 1): + total += 1 + file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png']) + path = osp.join(face_sep_mask, str(i), file_name) + + if os.path.exists(path): + counter += 1 + sep_mask = np.array(Image.open(path).convert('P')) + # print(np.unique(sep_mask)) + + mask[sep_mask == 225] = l + cv2.imwrite('{}/{}.png'.format(mask_path, j), mask) + print(j) + + print(counter, total) + + + + + + + + + + + + + + diff --git a/ConsistentID/lib/BiSeNet/hair.png b/ConsistentID/lib/BiSeNet/hair.png new file mode 100644 index 0000000000000000000000000000000000000000..07d194f77af5ccbde364500dafc43b96ebfb5c8b Binary files /dev/null and b/ConsistentID/lib/BiSeNet/hair.png differ diff --git a/ConsistentID/lib/BiSeNet/logger.py b/ConsistentID/lib/BiSeNet/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f9ddcc2cae221b4dd881d02404e848b5396f7e --- /dev/null +++ b/ConsistentID/lib/BiSeNet/logger.py @@ -0,0 +1,23 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import os.path as osp +import time +import sys +import logging + +import torch.distributed as dist + + +def setup_logger(logpth): + logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) + logfile = osp.join(logpth, logfile) + FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' + log_level = logging.INFO + if dist.is_initialized() and not dist.get_rank()==0: + log_level = logging.ERROR + logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) + logging.root.addHandler(logging.StreamHandler()) + + diff --git a/ConsistentID/lib/BiSeNet/loss.py b/ConsistentID/lib/BiSeNet/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..62657de66de513995c87acb81108a35d941fe37f --- /dev/null +++ b/ConsistentID/lib/BiSeNet/loss.py @@ -0,0 +1,72 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class OhemCELoss(nn.Module): + def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs): + super(OhemCELoss, self).__init__() + self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda() + self.n_min = n_min + self.ignore_lb = ignore_lb + self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') + + def forward(self, logits, labels): + N, C, H, W = logits.size() + loss = self.criteria(logits, labels).view(-1) + loss, _ = torch.sort(loss, descending=True) + if loss[self.n_min] > self.thresh: + loss = loss[loss>self.thresh] + else: + loss = loss[:self.n_min] + return 
torch.mean(loss) + + +class SoftmaxFocalLoss(nn.Module): + def __init__(self, gamma, ignore_lb=255, *args, **kwargs): + super(SoftmaxFocalLoss, self).__init__() + self.gamma = gamma + self.nll = nn.NLLLoss(ignore_index=ignore_lb) + + def forward(self, logits, labels): + scores = F.softmax(logits, dim=1) + factor = torch.pow(1.-scores, self.gamma) + log_score = F.log_softmax(logits, dim=1) + log_score = factor * log_score + loss = self.nll(log_score, labels) + return loss + + +if __name__ == '__main__': + torch.manual_seed(15) + criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() + criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda() + net1 = nn.Sequential( + nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), + ) + net1.cuda() + net1.train() + net2 = nn.Sequential( + nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1), + ) + net2.cuda() + net2.train() + + with torch.no_grad(): + inten = torch.randn(16, 3, 20, 20).cuda() + lbs = torch.randint(0, 19, [16, 20, 20]).cuda() + lbs[1, :, :] = 255 + + logits1 = net1(inten) + logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear') + logits2 = net2(inten) + logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear') + + loss1 = criteria1(logits1, lbs) + loss2 = criteria2(logits2, lbs) + loss = loss1 + loss2 + print(loss.detach().cpu()) + loss.backward() diff --git a/ConsistentID/lib/BiSeNet/makeup.py b/ConsistentID/lib/BiSeNet/makeup.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8ceee9944f4f41e97027b2c1f57bbbad912036 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/makeup.py @@ -0,0 +1,129 @@ +import cv2 +import numpy as np +from skimage.filters import gaussian + + +def sharpen(img): + img = img * 1.0 + gauss_out = gaussian(img, sigma=5, multichannel=True) + + alpha = 1.5 + img_out = (img - gauss_out) * alpha + img + + img_out = img_out / 255.0 + + mask_1 = img_out < 0 + mask_2 = img_out > 1 + + img_out = img_out * (1 - mask_1) + img_out = img_out * (1 - mask_2) + mask_2 + img_out = np.clip(img_out, 0, 1) + img_out = img_out * 255 + return np.array(img_out, dtype=np.uint8) + + +def hair(image, parsing, part=17, color=[230, 50, 20]): + b, g, r = color #[10, 50, 250] # [10, 250, 10] + tar_color = np.zeros_like(image) + tar_color[:, :, 0] = b + tar_color[:, :, 1] = g + tar_color[:, :, 2] = r + + image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV) + + if part == 12 or part == 13: + image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2] + else: + image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1] + + changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR) + + if part == 17: + changed = sharpen(changed) + + changed[parsing != part] = image[parsing != part] + # changed = cv2.resize(changed, (512, 512)) + return changed + +# +# def lip(image, parsing, part=17, color=[230, 50, 20]): +# b, g, r = color #[10, 50, 250] # [10, 250, 10] +# tar_color = np.zeros_like(image) +# tar_color[:, :, 0] = b +# tar_color[:, :, 1] = g +# tar_color[:, :, 2] = r +# +# image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab) +# il, ia, ib = cv2.split(image_lab) +# +# tar_lab = cv2.cvtColor(tar_color, cv2.COLOR_BGR2Lab) +# tl, ta, tb = cv2.split(tar_lab) +# +# image_lab[:, :, 0] = np.clip(il - np.mean(il) + tl, 0, 100) +# image_lab[:, :, 1] = np.clip(ia - np.mean(ia) + ta, -127, 128) +# image_lab[:, :, 2] = np.clip(ib - np.mean(ib) + tb, -127, 128) +# +# +# changed = cv2.cvtColor(image_lab, cv2.COLOR_Lab2BGR) +# +# if part == 17: +# changed = sharpen(changed) +# +# changed[parsing != part] = 
image[parsing != part] +# # changed = cv2.resize(changed, (512, 512)) +# return changed + + +if __name__ == '__main__': + # 1 face + # 10 nose + # 11 teeth + # 12 upper lip + # 13 lower lip + # 17 hair + num = 116 + table = { + 'hair': 17, + 'upper_lip': 12, + 'lower_lip': 13 + } + image_path = '/home/zll/data/CelebAMask-HQ/test-img/{}.jpg'.format(num) + parsing_path = 'res/test_res/{}.png'.format(num) + + image = cv2.imread(image_path) + ori = image.copy() + parsing = np.array(cv2.imread(parsing_path, 0)) + parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST) + + parts = [table['hair'], table['upper_lip'], table['lower_lip']] + # colors = [[20, 20, 200], [100, 100, 230], [100, 100, 230]] + colors = [[100, 200, 100]] + for part, color in zip(parts, colors): + image = hair(image, parsing, part, color) + cv2.imwrite('res/makeup/116_ori.png', cv2.resize(ori, (512, 512))) + cv2.imwrite('res/makeup/116_2.png', cv2.resize(image, (512, 512))) + + cv2.imshow('image', cv2.resize(ori, (512, 512))) + cv2.imshow('color', cv2.resize(image, (512, 512))) + + # cv2.imshow('image', ori) + # cv2.imshow('color', image) + + cv2.waitKey(0) + cv2.destroyAllWindows() + + + + + + + + + + + + + + + diff --git a/ConsistentID/lib/BiSeNet/makeup/116_1.png b/ConsistentID/lib/BiSeNet/makeup/116_1.png new file mode 100644 index 0000000000000000000000000000000000000000..dc90bde07e39e824f82c3d055640088f36260d66 Binary files /dev/null and b/ConsistentID/lib/BiSeNet/makeup/116_1.png differ diff --git a/ConsistentID/lib/BiSeNet/makeup/116_3.png b/ConsistentID/lib/BiSeNet/makeup/116_3.png new file mode 100644 index 0000000000000000000000000000000000000000..4970ca1108621d784bcd40291867f6efcf8d2112 Binary files /dev/null and b/ConsistentID/lib/BiSeNet/makeup/116_3.png differ diff --git a/ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png b/ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png new file mode 100644 index 0000000000000000000000000000000000000000..dbd53f43018bd2f8086e24cd645ef6bee7d89812 Binary files /dev/null and b/ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png differ diff --git a/ConsistentID/lib/BiSeNet/makeup/116_ori.png b/ConsistentID/lib/BiSeNet/makeup/116_ori.png new file mode 100644 index 0000000000000000000000000000000000000000..1372e84c3d939d39fffc54c2ec3077dd5cd79f9a Binary files /dev/null and b/ConsistentID/lib/BiSeNet/makeup/116_ori.png differ diff --git a/ConsistentID/lib/BiSeNet/model.py b/ConsistentID/lib/BiSeNet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..54ecdcf2a50e553e259eb17883c8f2148960b4cc --- /dev/null +++ b/ConsistentID/lib/BiSeNet/model.py @@ -0,0 +1,282 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .resnet import Resnet18 +# from modules.bn import InPlaceABNSync as BatchNorm2d + + +class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, + out_chan, + kernel_size = ks, + stride = stride, + padding = padding, + bias = False) + self.bn = nn.BatchNorm2d(out_chan) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + +class BiSeNetOutput(nn.Module): + def __init__(self, in_chan, mid_chan, n_classes, 
*args, **kwargs): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class AttentionRefinementModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + self.init_weight() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class ContextPath(nn.Module): + def __init__(self, *args, **kwargs): + super(ContextPath, self).__init__() + self.resnet = Resnet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + H0, W0 = x.size()[2:] + feat8, feat16, feat32 = self.resnet(x) + H8, W8 = feat8.size()[2:] + H16, W16 = feat16.size()[2:] + H32, W32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +### This is not used, since I replace this with the resnet feature with the same 
size +class SpatialPath(nn.Module): + def __init__(self, *args, **kwargs): + super(SpatialPath, self).__init__() + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) + self.init_weight() + + def forward(self, x): + feat = self.conv1(x) + feat = self.conv2(feat) + feat = self.conv3(feat) + feat = self.conv_out(feat) + return feat + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, + out_chan//4, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.conv2 = nn.Conv2d(out_chan//4, + out_chan, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + self.init_weight() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class BiSeNet(nn.Module): + def __init__(self, n_classes, *args, **kwargs): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + ## here self.sp is deleted + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, n_classes) + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.init_weight() + + def forward(self, x): + H, W = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', 
align_corners=True) + return feat_out, feat_out16, feat_out32 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] + for name, child in self.named_children(): + child_wd_params, child_nowd_params = child.get_params() + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): + lr_mul_wd_params += child_wd_params + lr_mul_nowd_params += child_nowd_params + else: + wd_params += child_wd_params + nowd_params += child_nowd_params + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params + + +if __name__ == "__main__": + net = BiSeNet(19) + net.cuda() + net.eval() + in_ten = torch.randn(16, 3, 640, 480).cuda() + out, out16, out32 = net(in_ten) + print(out.shape) + + net.get_params() diff --git a/ConsistentID/lib/BiSeNet/modules/__init__.py b/ConsistentID/lib/BiSeNet/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a098dee5911f3613d320d23db37bc401cf57fa4 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/__init__.py @@ -0,0 +1,5 @@ +from .bn import ABN, InPlaceABN, InPlaceABNSync +from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE +from .misc import GlobalAvgPool2d, SingleGPU +from .residual import IdentityResidualBlock +from .dense import DenseModule diff --git a/ConsistentID/lib/BiSeNet/modules/bn.py b/ConsistentID/lib/BiSeNet/modules/bn.py new file mode 100644 index 0000000000000000000000000000000000000000..cd3928bccfd3f70233414d837876b323217864c8 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/bn.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +import torch.nn.functional as functional + +try: + from queue import Queue +except ImportError: + from Queue import Queue + +from .functions import * + + +class ABN(nn.Module): + """Activated Batch Normalization + + This gathers a `BatchNorm2d` and an activation function in a single module + """ + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): + """Creates an Activated Batch Normalization module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics as. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + activation : str + Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. + slope : float + Negative slope for the `leaky_relu` activation. 
+ """ + super(ABN, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.momentum = momentum + self.activation = activation + self.slope = slope + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.running_mean, 0) + nn.init.constant_(self.running_var, 1) + if self.affine: + nn.init.constant_(self.weight, 1) + nn.init.constant_(self.bias, 0) + + def forward(self, x): + x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + if self.activation == ACT_RELU: + return functional.relu(x, inplace=True) + elif self.activation == ACT_LEAKY_RELU: + return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) + elif self.activation == ACT_ELU: + return functional.elu(x, inplace=True) + else: + return x + + def __repr__(self): + rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ + ' affine={affine}, activation={activation}' + if self.activation == "leaky_relu": + rep += ', slope={slope})' + else: + rep += ')' + return rep.format(name=self.__class__.__name__, **self.__dict__) + + +class InPlaceABN(ABN): + """InPlace Activated Batch Normalization""" + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): + """Creates an InPlace Activated Batch Normalization module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics as. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + activation : str + Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. + slope : float + Negative slope for the `leaky_relu` activation. + """ + super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope) + + def forward(self, x): + return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.activation, self.slope) + + +class InPlaceABNSync(ABN): + """InPlace Activated Batch Normalization with cross-GPU synchronization + This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`. 
+ """ + + def forward(self, x): + return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.activation, self.slope) + + def __repr__(self): + rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ + ' affine={affine}, activation={activation}' + if self.activation == "leaky_relu": + rep += ', slope={slope})' + else: + rep += ')' + return rep.format(name=self.__class__.__name__, **self.__dict__) + + diff --git a/ConsistentID/lib/BiSeNet/modules/deeplab.py b/ConsistentID/lib/BiSeNet/modules/deeplab.py new file mode 100644 index 0000000000000000000000000000000000000000..fd25b78369b27ef02c183a0b17b9bf8354c5f7c3 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/deeplab.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +import torch.nn.functional as functional + +from models._util import try_index +from .bn import ABN + + +class DeeplabV3(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels=256, + dilations=(12, 24, 36), + norm_act=ABN, + pooling_size=None): + super(DeeplabV3, self).__init__() + self.pooling_size = pooling_size + + self.map_convs = nn.ModuleList([ + nn.Conv2d(in_channels, hidden_channels, 1, bias=False), + nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]), + nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]), + nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2]) + ]) + self.map_bn = norm_act(hidden_channels * 4) + + self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False) + self.global_pooling_bn = norm_act(hidden_channels) + + self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False) + self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False) + self.red_bn = norm_act(out_channels) + + self.reset_parameters(self.map_bn.activation, self.map_bn.slope) + + def reset_parameters(self, activation, slope): + gain = nn.init.calculate_gain(activation, slope) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight.data, gain) + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, ABN): + if hasattr(m, "weight") and m.weight is not None: + nn.init.constant_(m.weight, 1) + if hasattr(m, "bias") and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + # Map convolutions + out = torch.cat([m(x) for m in self.map_convs], dim=1) + out = self.map_bn(out) + out = self.red_conv(out) + + # Global pooling + pool = self._global_pooling(x) + pool = self.global_pooling_conv(pool) + pool = self.global_pooling_bn(pool) + pool = self.pool_red_conv(pool) + if self.training or self.pooling_size is None: + pool = pool.repeat(1, 1, x.size(2), x.size(3)) + + out += pool + out = self.red_bn(out) + return out + + def _global_pooling(self, x): + if self.training or self.pooling_size is None: + pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1) + pool = pool.view(x.size(0), x.size(1), 1, 1) + else: + pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]), + min(try_index(self.pooling_size, 1), x.shape[3])) + padding = ( + (pooling_size[1] - 1) // 2, + (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1, + (pooling_size[0] - 1) // 2, + (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1 + ) + + 
pool = functional.avg_pool2d(x, pooling_size, stride=1) + pool = functional.pad(pool, pad=padding, mode="replicate") + return pool diff --git a/ConsistentID/lib/BiSeNet/modules/dense.py b/ConsistentID/lib/BiSeNet/modules/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..9638d6e86d2ae838550fefa9002a984af52e6cc8 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/dense.py @@ -0,0 +1,42 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn + +from .bn import ABN + + +class DenseModule(nn.Module): + def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1): + super(DenseModule, self).__init__() + self.in_channels = in_channels + self.growth = growth + self.layers = layers + + self.convs1 = nn.ModuleList() + self.convs3 = nn.ModuleList() + for i in range(self.layers): + self.convs1.append(nn.Sequential(OrderedDict([ + ("bn", norm_act(in_channels)), + ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False)) + ]))) + self.convs3.append(nn.Sequential(OrderedDict([ + ("bn", norm_act(self.growth * bottleneck_factor)), + ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False, + dilation=dilation)) + ]))) + in_channels += self.growth + + @property + def out_channels(self): + return self.in_channels + self.growth * self.layers + + def forward(self, x): + inputs = [x] + for i in range(self.layers): + x = torch.cat(inputs, dim=1) + x = self.convs1[i](x) + x = self.convs3[i](x) + inputs += [x] + + return torch.cat(inputs, dim=1) diff --git a/ConsistentID/lib/BiSeNet/modules/functions.py b/ConsistentID/lib/BiSeNet/modules/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..093615ff4f383e95712c96b57286338ec3b28f3b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/functions.py @@ -0,0 +1,234 @@ +from os import path +import torch +import torch.distributed as dist +import torch.autograd as autograd +import torch.cuda.comm as comm +from torch.autograd.function import once_differentiable +from torch.utils.cpp_extension import load + +_src_path = path.join(path.dirname(path.abspath(__file__)), "src") +_backend = load(name="inplace_abn", + extra_cflags=["-O3"], + sources=[path.join(_src_path, f) for f in [ + "inplace_abn.cpp", + "inplace_abn_cpu.cpp", + "inplace_abn_cuda.cu", + "inplace_abn_cuda_half.cu" + ]], + extra_cuda_cflags=["--expt-extended-lambda"]) + +# Activation names +ACT_RELU = "relu" +ACT_LEAKY_RELU = "leaky_relu" +ACT_ELU = "elu" +ACT_NONE = "none" + + +def _check(fn, *args, **kwargs): + success = fn(*args, **kwargs) + if not success: + raise RuntimeError("CUDA Error encountered in {}".format(fn)) + + +def _broadcast_shape(x): + out_size = [] + for i, s in enumerate(x.size()): + if i != 1: + out_size.append(1) + else: + out_size.append(s) + return out_size + + +def _reduce(x): + if len(x.size()) == 2: + return x.sum(dim=0) + else: + n, c = x.size()[0:2] + return x.contiguous().view((n, c, -1)).sum(2).sum(0) + + +def _count_samples(x): + count = 1 + for i, s in enumerate(x.size()): + if i != 1: + count *= s + return count + + +def _act_forward(ctx, x): + if ctx.activation == ACT_LEAKY_RELU: + _backend.leaky_relu_forward(x, ctx.slope) + elif ctx.activation == ACT_ELU: + _backend.elu_forward(x) + elif ctx.activation == ACT_NONE: + pass + + +def _act_backward(ctx, x, dx): + if ctx.activation == ACT_LEAKY_RELU: + _backend.leaky_relu_backward(x, dx, ctx.slope) + elif ctx.activation == ACT_ELU: + 
_backend.elu_backward(x, dx) + elif ctx.activation == ACT_NONE: + pass + + +class InPlaceABN(autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): + # Save context + ctx.training = training + ctx.momentum = momentum + ctx.eps = eps + ctx.activation = activation + ctx.slope = slope + ctx.affine = weight is not None and bias is not None + + # Prepare inputs + count = _count_samples(x) + x = x.contiguous() + weight = weight.contiguous() if ctx.affine else x.new_empty(0) + bias = bias.contiguous() if ctx.affine else x.new_empty(0) + + if ctx.training: + mean, var = _backend.mean_var(x) + + # Update running stats + running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) + running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) + + # Mark in-place modified tensors + ctx.mark_dirty(x, running_mean, running_var) + else: + mean, var = running_mean.contiguous(), running_var.contiguous() + ctx.mark_dirty(x) + + # BN forward + activation + _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) + _act_forward(ctx, x) + + # Output + ctx.var = var + ctx.save_for_backward(x, var, weight, bias) + return x + + @staticmethod + @once_differentiable + def backward(ctx, dz): + z, var, weight, bias = ctx.saved_tensors + dz = dz.contiguous() + + # Undo activation + _act_backward(ctx, z, dz) + + if ctx.training: + edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) + else: + # TODO: implement simplified CUDA backward for inference mode + edz = dz.new_zeros(dz.size(1)) + eydz = dz.new_zeros(dz.size(1)) + + dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) + dweight = eydz * weight.sign() if ctx.affine else None + dbias = edz if ctx.affine else None + + return dx, dweight, dbias, None, None, None, None, None, None, None + +class InPlaceABNSync(autograd.Function): + @classmethod + def forward(cls, ctx, x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True): + # Save context + ctx.training = training + ctx.momentum = momentum + ctx.eps = eps + ctx.activation = activation + ctx.slope = slope + ctx.affine = weight is not None and bias is not None + + # Prepare inputs + ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1 + + #count = _count_samples(x) + batch_size = x.new_tensor([x.shape[0]],dtype=torch.long) + + x = x.contiguous() + weight = weight.contiguous() if ctx.affine else x.new_empty(0) + bias = bias.contiguous() if ctx.affine else x.new_empty(0) + + if ctx.training: + mean, var = _backend.mean_var(x) + if ctx.world_size>1: + # get global batch size + if equal_batches: + batch_size *= ctx.world_size + else: + dist.all_reduce(batch_size, dist.ReduceOp.SUM) + + ctx.factor = x.shape[0]/float(batch_size.item()) + + mean_all = mean.clone() * ctx.factor + dist.all_reduce(mean_all, dist.ReduceOp.SUM) + + var_all = (var + (mean - mean_all) ** 2) * ctx.factor + dist.all_reduce(var_all, dist.ReduceOp.SUM) + + mean = mean_all + var = var_all + + # Update running stats + running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) + count = batch_size.item() * x.view(x.shape[0],x.shape[1],-1).shape[-1] + running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1))) + + # Mark in-place modified tensors + ctx.mark_dirty(x, running_mean, running_var) + else: + mean, 
var = running_mean.contiguous(), running_var.contiguous() + ctx.mark_dirty(x) + + # BN forward + activation + _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) + _act_forward(ctx, x) + + # Output + ctx.var = var + ctx.save_for_backward(x, var, weight, bias) + return x + + @staticmethod + @once_differentiable + def backward(ctx, dz): + z, var, weight, bias = ctx.saved_tensors + dz = dz.contiguous() + + # Undo activation + _act_backward(ctx, z, dz) + + if ctx.training: + edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) + edz_local = edz.clone() + eydz_local = eydz.clone() + + if ctx.world_size>1: + edz *= ctx.factor + dist.all_reduce(edz, dist.ReduceOp.SUM) + + eydz *= ctx.factor + dist.all_reduce(eydz, dist.ReduceOp.SUM) + else: + edz_local = edz = dz.new_zeros(dz.size(1)) + eydz_local = eydz = dz.new_zeros(dz.size(1)) + + dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) + dweight = eydz_local * weight.sign() if ctx.affine else None + dbias = edz_local if ctx.affine else None + + return dx, dweight, dbias, None, None, None, None, None, None, None + +inplace_abn = InPlaceABN.apply +inplace_abn_sync = InPlaceABNSync.apply + +__all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"] diff --git a/ConsistentID/lib/BiSeNet/modules/misc.py b/ConsistentID/lib/BiSeNet/modules/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..3c50b69b38c950801baacba8b3684ffd23aef08b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/misc.py @@ -0,0 +1,21 @@ +import torch.nn as nn +import torch +import torch.distributed as dist + +class GlobalAvgPool2d(nn.Module): + def __init__(self): + """Global average pooling over the input's spatial dimensions""" + super(GlobalAvgPool2d, self).__init__() + + def forward(self, inputs): + in_size = inputs.size() + return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) + +class SingleGPU(nn.Module): + def __init__(self, module): + super(SingleGPU, self).__init__() + self.module=module + + def forward(self, input): + return self.module(input.cuda(non_blocking=True)) + diff --git a/ConsistentID/lib/BiSeNet/modules/residual.py b/ConsistentID/lib/BiSeNet/modules/residual.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d51ad274f3841813c1584a0ceb60ce58979d94 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/residual.py @@ -0,0 +1,88 @@ +from collections import OrderedDict + +import torch.nn as nn + +from .bn import ABN + + +class IdentityResidualBlock(nn.Module): + def __init__(self, + in_channels, + channels, + stride=1, + dilation=1, + groups=1, + norm_act=ABN, + dropout=None): + """Configurable identity-mapping residual block + + Parameters + ---------- + in_channels : int + Number of input channels. + channels : list of int + Number of channels in the internal feature maps. Can either have two or three elements: if three construct + a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then + `3 x 3` then `1 x 1` convolutions. + stride : int + Stride of the first `3 x 3` convolution + dilation : int + Dilation to apply to the `3 x 3` convolutions. + groups : int + Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with + bottleneck blocks. + norm_act : callable + Function to create normalization / activation Module. + dropout: callable + Function to create Dropout Module. 
+ """ + super(IdentityResidualBlock, self).__init__() + + # Check parameters for inconsistencies + if len(channels) != 2 and len(channels) != 3: + raise ValueError("channels must contain either two or three values") + if len(channels) == 2 and groups != 1: + raise ValueError("groups > 1 are only valid if len(channels) == 3") + + is_bottleneck = len(channels) == 3 + need_proj_conv = stride != 1 or in_channels != channels[-1] + + self.bn1 = norm_act(in_channels) + if not is_bottleneck: + layers = [ + ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False, + dilation=dilation)), + ("bn2", norm_act(channels[0])), + ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False, + dilation=dilation)) + ] + if dropout is not None: + layers = layers[0:2] + [("dropout", dropout())] + layers[2:] + else: + layers = [ + ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)), + ("bn2", norm_act(channels[0])), + ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False, + groups=groups, dilation=dilation)), + ("bn3", norm_act(channels[1])), + ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)) + ] + if dropout is not None: + layers = layers[0:4] + [("dropout", dropout())] + layers[4:] + self.convs = nn.Sequential(OrderedDict(layers)) + + if need_proj_conv: + self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) + + def forward(self, x): + if hasattr(self, "proj_conv"): + bn1 = self.bn1(x) + shortcut = self.proj_conv(bn1) + else: + shortcut = x.clone() + bn1 = self.bn1(x) + + out = self.convs(bn1) + out.add_(shortcut) + + return out diff --git a/ConsistentID/lib/BiSeNet/modules/src/checks.h b/ConsistentID/lib/BiSeNet/modules/src/checks.h new file mode 100644 index 0000000000000000000000000000000000000000..e761a6fe34d0789815b588eba7e3726026e0e868 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/checks.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT +#ifndef AT_CHECK +#define AT_CHECK AT_ASSERT +#endif + +#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor") +#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous") + +#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a6b1128cc20cbfc476134154e23e5869a92b856 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp @@ -0,0 +1,95 @@ +#include + +#include + +#include "inplace_abn.h" + +std::vector mean_var(at::Tensor x) { + if (x.is_cuda()) { + if (x.type().scalarType() == at::ScalarType::Half) { + return mean_var_cuda_h(x); + } else { + return mean_var_cuda(x); + } + } else { + return mean_var_cpu(x); + } +} + +at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + if (x.is_cuda()) { + if (x.type().scalarType() == at::ScalarType::Half) { + return forward_cuda_h(x, mean, var, weight, bias, affine, eps); + } else { + return forward_cuda(x, mean, var, weight, 
bias, affine, eps); + } + } else { + return forward_cpu(x, mean, var, weight, bias, affine, eps); + } +} + +std::vector edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + if (z.is_cuda()) { + if (z.type().scalarType() == at::ScalarType::Half) { + return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps); + } else { + return edz_eydz_cuda(z, dz, weight, bias, affine, eps); + } + } else { + return edz_eydz_cpu(z, dz, weight, bias, affine, eps); + } +} + +at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps) { + if (z.is_cuda()) { + if (z.type().scalarType() == at::ScalarType::Half) { + return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps); + } else { + return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps); + } + } else { + return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps); + } +} + +void leaky_relu_forward(at::Tensor z, float slope) { + at::leaky_relu_(z, slope); +} + +void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) { + if (z.is_cuda()) { + if (z.type().scalarType() == at::ScalarType::Half) { + return leaky_relu_backward_cuda_h(z, dz, slope); + } else { + return leaky_relu_backward_cuda(z, dz, slope); + } + } else { + return leaky_relu_backward_cpu(z, dz, slope); + } +} + +void elu_forward(at::Tensor z) { + at::elu_(z); +} + +void elu_backward(at::Tensor z, at::Tensor dz) { + if (z.is_cuda()) { + return elu_backward_cuda(z, dz); + } else { + return elu_backward_cpu(z, dz); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("mean_var", &mean_var, "Mean and variance computation"); + m.def("forward", &forward, "In-place forward computation"); + m.def("edz_eydz", &edz_eydz, "First part of backward computation"); + m.def("backward", &backward, "Second part of backward computation"); + m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation"); + m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion"); + m.def("elu_forward", &elu_forward, "Elu forward computation"); + m.def("elu_backward", &elu_backward, "Elu backward computation and inversion"); +} diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h new file mode 100644 index 0000000000000000000000000000000000000000..17afd1196449ecb6376f28961e54b55e1537492f --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h @@ -0,0 +1,88 @@ +#pragma once + +#include + +#include + +std::vector mean_var_cpu(at::Tensor x); +std::vector mean_var_cuda(at::Tensor x); +std::vector mean_var_cuda_h(at::Tensor x); + +at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps); +at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps); +at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps); + +std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps); +std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps); +std::vector edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps); + +at::Tensor 
backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps); +at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps); +at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps); + +void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope); +void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope); +void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope); + +void elu_backward_cpu(at::Tensor z, at::Tensor dz); +void elu_backward_cuda(at::Tensor z, at::Tensor dz); + +static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) { + num = x.size(0); + chn = x.size(1); + sp = 1; + for (int64_t i = 2; i < x.ndimension(); ++i) + sp *= x.size(i); +} + +/* + * Specialized CUDA reduction functions for BN + */ +#ifdef __CUDACC__ + +#include "utils/cuda.cuh" + +template +__device__ T reduce(Op op, int plane, int N, int S) { + T sum = (T)0; + for (int batch = 0; batch < N; ++batch) { + for (int x = threadIdx.x; x < S; x += blockDim.x) { + sum += op(batch, plane, x); + } + } + + // sum over NumThreads within a warp + sum = warpSum(sum); + + // 'transpose', and reduce within warp again + __shared__ T shared[32]; + __syncthreads(); + if (threadIdx.x % WARP_SIZE == 0) { + shared[threadIdx.x / WARP_SIZE] = sum; + } + if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { + // zero out the other entries in shared + shared[threadIdx.x] = (T)0; + } + __syncthreads(); + if (threadIdx.x / WARP_SIZE == 0) { + sum = warpSum(shared[threadIdx.x]); + if (threadIdx.x == 0) { + shared[0] = sum; + } + } + __syncthreads(); + + // Everyone picks it up, should be broadcast into the whole gradInput + return shared[0]; +} +#endif diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffc6d38c52ea31661b8dd438dc3fe1958f50b61e --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp @@ -0,0 +1,119 @@ +#include + +#include + +#include "utils/checks.h" +#include "inplace_abn.h" + +at::Tensor reduce_sum(at::Tensor x) { + if (x.ndimension() == 2) { + return x.sum(0); + } else { + auto x_view = x.view({x.size(0), x.size(1), -1}); + return x_view.sum(-1).sum(0); + } +} + +at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { + if (x.ndimension() == 2) { + return v; + } else { + std::vector broadcast_size = {1, -1}; + for (int64_t i = 2; i < x.ndimension(); ++i) + broadcast_size.push_back(1); + + return v.view(broadcast_size); + } +} + +int64_t count(at::Tensor x) { + int64_t count = x.size(0); + for (int64_t i = 2; i < x.ndimension(); ++i) + count *= x.size(i); + + return count; +} + +at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) { + if (affine) { + return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z); + } else { + return z; + } +} + +std::vector mean_var_cpu(at::Tensor x) { + auto num = count(x); + auto mean = reduce_sum(x) / num; + auto diff = x - broadcast_to(mean, x); + auto var = reduce_sum(diff.pow(2)) / num; + + return {mean, var}; +} + +at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, 
at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var); + auto mul = at::rsqrt(var + eps) * gamma; + + x.sub_(broadcast_to(mean, x)); + x.mul_(broadcast_to(mul, x)); + if (affine) x.add_(broadcast_to(bias, x)); + + return x; +} + +std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + auto edz = reduce_sum(dz); + auto y = invert_affine(z, weight, bias, affine, eps); + auto eydz = reduce_sum(y * dz); + + return {edz, eydz}; +} + +at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps) { + auto y = invert_affine(z, weight, bias, affine, eps); + auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps); + + auto num = count(z); + auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz); + return dx; +} + +void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) { + CHECK_CPU_INPUT(z); + CHECK_CPU_INPUT(dz); + + AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] { + int64_t count = z.numel(); + auto *_z = z.data(); + auto *_dz = dz.data(); + + for (int64_t i = 0; i < count; ++i) { + if (_z[i] < 0) { + _z[i] *= 1 / slope; + _dz[i] *= slope; + } + } + })); +} + +void elu_backward_cpu(at::Tensor z, at::Tensor dz) { + CHECK_CPU_INPUT(z); + CHECK_CPU_INPUT(dz); + + AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] { + int64_t count = z.numel(); + auto *_z = z.data(); + auto *_dz = dz.data(); + + for (int64_t i = 0; i < count; ++i) { + if (_z[i] < 0) { + _z[i] = log1p(_z[i]); + _dz[i] *= (_z[i] + 1.f); + } + } + })); +} diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..b157b06d47173d1645c6a40c89f564b737e84d43 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu @@ -0,0 +1,333 @@ +#include + +#include +#include + +#include + +#include "utils/checks.h" +#include "utils/cuda.cuh" +#include "inplace_abn.h" + +#include + +// Operations for reduce +template +struct SumOp { + __device__ SumOp(const T *t, int c, int s) + : tensor(t), chn(c), sp(s) {} + __device__ __forceinline__ T operator()(int batch, int plane, int n) { + return tensor[(batch * chn + plane) * sp + n]; + } + const T *tensor; + const int chn; + const int sp; +}; + +template +struct VarOp { + __device__ VarOp(T m, const T *t, int c, int s) + : mean(m), tensor(t), chn(c), sp(s) {} + __device__ __forceinline__ T operator()(int batch, int plane, int n) { + T val = tensor[(batch * chn + plane) * sp + n]; + return (val - mean) * (val - mean); + } + const T mean; + const T *tensor; + const int chn; + const int sp; +}; + +template +struct GradOp { + __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s) + : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {} + __device__ __forceinline__ Pair operator()(int batch, int plane, int n) { + T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight; + T _dz = dz[(batch * chn + plane) * sp + n]; + return Pair(_dz, _y * _dz); + } + const T weight; + const T bias; + const T *z; + const T *dz; + const int chn; + const int sp; +}; + +/*********** + * mean_var + ***********/ + +template +__global__ void mean_var_kernel(const T *x, T *mean, T 
*var, int num, int chn, int sp) { + int plane = blockIdx.x; + T norm = T(1) / T(num * sp); + + T _mean = reduce>(SumOp(x, chn, sp), plane, num, sp) * norm; + __syncthreads(); + T _var = reduce>(VarOp(_mean, x, chn, sp), plane, num, sp) * norm; + + if (threadIdx.x == 0) { + mean[plane] = _mean; + var[plane] = _var; + } +} + +std::vector mean_var_cuda(at::Tensor x) { + CHECK_CUDA_INPUT(x); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(x, num, chn, sp); + + // Prepare output tensors + auto mean = at::empty({chn}, x.options()); + auto var = at::empty({chn}, x.options()); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] { + mean_var_kernel<<>>( + x.data(), + mean.data(), + var.data(), + num, chn, sp); + })); + + return {mean, var}; +} + +/********** + * forward + **********/ + +template +__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias, + bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + T _mean = mean[plane]; + T _var = var[plane]; + T _weight = affine ? abs(weight[plane]) + eps : T(1); + T _bias = affine ? bias[plane] : T(0); + + T mul = rsqrt(_var + eps) * _weight; + + for (int batch = 0; batch < num; ++batch) { + for (int n = threadIdx.x; n < sp; n += blockDim.x) { + T _x = x[(batch * chn + plane) * sp + n]; + T _y = (_x - _mean) * mul + _bias; + + x[(batch * chn + plane) * sp + n] = _y; + } + } +} + +at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + CHECK_CUDA_INPUT(x); + CHECK_CUDA_INPUT(mean); + CHECK_CUDA_INPUT(var); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(x, num, chn, sp); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] { + forward_kernel<<>>( + x.data(), + mean.data(), + var.data(), + weight.data(), + bias.data(), + affine, eps, num, chn, sp); + })); + + return x; +} + +/*********** + * edz_eydz + ***********/ + +template +__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias, + T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + T _weight = affine ? abs(weight[plane]) + eps : 1.f; + T _bias = affine ? 
bias[plane] : 0.f; + + Pair res = reduce, GradOp>(GradOp(_weight, _bias, z, dz, chn, sp), plane, num, sp); + __syncthreads(); + + if (threadIdx.x == 0) { + edz[plane] = res.v1; + eydz[plane] = res.v2; + } +} + +std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(z, num, chn, sp); + + auto edz = at::empty({chn}, z.options()); + auto eydz = at::empty({chn}, z.options()); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] { + edz_eydz_kernel<<>>( + z.data(), + dz.data(), + weight.data(), + bias.data(), + edz.data(), + eydz.data(), + affine, eps, num, chn, sp); + })); + + return {edz, eydz}; +} + +/*********** + * backward + ***********/ + +template +__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz, + const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + T _weight = affine ? abs(weight[plane]) + eps : 1.f; + T _bias = affine ? bias[plane] : 0.f; + T _var = var[plane]; + T _edz = edz[plane]; + T _eydz = eydz[plane]; + + T _mul = _weight * rsqrt(_var + eps); + T count = T(num * sp); + + for (int batch = 0; batch < num; ++batch) { + for (int n = threadIdx.x; n < sp; n += blockDim.x) { + T _dz = dz[(batch * chn + plane) * sp + n]; + T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight; + + dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul; + } + } +} + +at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + CHECK_CUDA_INPUT(var); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + CHECK_CUDA_INPUT(edz); + CHECK_CUDA_INPUT(eydz); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(z, num, chn, sp); + + auto dx = at::zeros_like(z); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] { + backward_kernel<<>>( + z.data(), + dz.data(), + var.data(), + weight.data(), + bias.data(), + edz.data(), + eydz.data(), + dx.data(), + affine, eps, num, chn, sp); + })); + + return dx; +} + +/************** + * activations + **************/ + +template +inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) { + // Create thrust pointers + thrust::device_ptr th_z = thrust::device_pointer_cast(z); + thrust::device_ptr th_dz = thrust::device_pointer_cast(dz); + + auto stream = at::cuda::getCurrentCUDAStream(); + thrust::transform_if(thrust::cuda::par.on(stream), + th_dz, th_dz + count, th_z, th_dz, + [slope] __device__ (const T& dz) { return dz * slope; }, + [] __device__ (const T& z) { return z < 0; }); + thrust::transform_if(thrust::cuda::par.on(stream), + th_z, th_z + count, th_z, + [slope] __device__ (const T& z) { return z / slope; }, + [] __device__ (const T& z) { return z < 0; }); +} + +void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + + int64_t count = z.numel(); + + AT_DISPATCH_FLOATING_TYPES(z.type(), 
"leaky_relu_backward_cuda", ([&] { + leaky_relu_backward_impl(z.data(), dz.data(), slope, count); + })); +} + +template +inline void elu_backward_impl(T *z, T *dz, int64_t count) { + // Create thrust pointers + thrust::device_ptr th_z = thrust::device_pointer_cast(z); + thrust::device_ptr th_dz = thrust::device_pointer_cast(dz); + + auto stream = at::cuda::getCurrentCUDAStream(); + thrust::transform_if(thrust::cuda::par.on(stream), + th_dz, th_dz + count, th_z, th_z, th_dz, + [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); }, + [] __device__ (const T& z) { return z < 0; }); + thrust::transform_if(thrust::cuda::par.on(stream), + th_z, th_z + count, th_z, + [] __device__ (const T& z) { return log1p(z); }, + [] __device__ (const T& z) { return z < 0; }); +} + +void elu_backward_cuda(at::Tensor z, at::Tensor dz) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + + int64_t count = z.numel(); + + AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] { + elu_backward_impl(z.data(), dz.data(), count); + })); +} diff --git a/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu new file mode 100644 index 0000000000000000000000000000000000000000..bb63e73f9d90179e5bd5dae5579c4844da9c25e2 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu @@ -0,0 +1,275 @@ +#include + +#include + +#include + +#include "utils/checks.h" +#include "utils/cuda.cuh" +#include "inplace_abn.h" + +#include + +// Operations for reduce +struct SumOpH { + __device__ SumOpH(const half *t, int c, int s) + : tensor(t), chn(c), sp(s) {} + __device__ __forceinline__ float operator()(int batch, int plane, int n) { + return __half2float(tensor[(batch * chn + plane) * sp + n]); + } + const half *tensor; + const int chn; + const int sp; +}; + +struct VarOpH { + __device__ VarOpH(float m, const half *t, int c, int s) + : mean(m), tensor(t), chn(c), sp(s) {} + __device__ __forceinline__ float operator()(int batch, int plane, int n) { + const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]); + return (t - mean) * (t - mean); + } + const float mean; + const half *tensor; + const int chn; + const int sp; +}; + +struct GradOpH { + __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s) + : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {} + __device__ __forceinline__ Pair operator()(int batch, int plane, int n) { + float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight; + float _dz = __half2float(dz[(batch * chn + plane) * sp + n]); + return Pair(_dz, _y * _dz); + } + const float weight; + const float bias; + const half *z; + const half *dz; + const int chn; + const int sp; +}; + +/*********** + * mean_var + ***********/ + +__global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) { + int plane = blockIdx.x; + float norm = 1.f / static_cast(num * sp); + + float _mean = reduce(SumOpH(x, chn, sp), plane, num, sp) * norm; + __syncthreads(); + float _var = reduce(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm; + + if (threadIdx.x == 0) { + mean[plane] = _mean; + var[plane] = _var; + } +} + +std::vector mean_var_cuda_h(at::Tensor x) { + CHECK_CUDA_INPUT(x); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(x, num, chn, sp); + + // Prepare output tensors + auto mean = at::empty({chn},x.options().dtype(at::kFloat)); + auto var = at::empty({chn},x.options().dtype(at::kFloat)); + 
+ // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + mean_var_kernel_h<<>>( + reinterpret_cast(x.data()), + mean.data(), + var.data(), + num, chn, sp); + + return {mean, var}; +} + +/********** + * forward + **********/ + +__global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias, + bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + const float _mean = mean[plane]; + const float _var = var[plane]; + const float _weight = affine ? abs(weight[plane]) + eps : 1.f; + const float _bias = affine ? bias[plane] : 0.f; + + const float mul = rsqrt(_var + eps) * _weight; + + for (int batch = 0; batch < num; ++batch) { + for (int n = threadIdx.x; n < sp; n += blockDim.x) { + half *x_ptr = x + (batch * chn + plane) * sp + n; + float _x = __half2float(*x_ptr); + float _y = (_x - _mean) * mul + _bias; + + *x_ptr = __float2half(_y); + } + } +} + +at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + CHECK_CUDA_INPUT(x); + CHECK_CUDA_INPUT(mean); + CHECK_CUDA_INPUT(var); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(x, num, chn, sp); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + forward_kernel_h<<>>( + reinterpret_cast(x.data()), + mean.data(), + var.data(), + weight.data(), + bias.data(), + affine, eps, num, chn, sp); + + return x; +} + +__global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias, + float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + float _weight = affine ? abs(weight[plane]) + eps : 1.f; + float _bias = affine ? bias[plane] : 0.f; + + Pair res = reduce, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp); + __syncthreads(); + + if (threadIdx.x == 0) { + edz[plane] = res.v1; + eydz[plane] = res.v2; + } +} + +std::vector edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, + bool affine, float eps) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(z, num, chn, sp); + + auto edz = at::empty({chn},z.options().dtype(at::kFloat)); + auto eydz = at::empty({chn},z.options().dtype(at::kFloat)); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + edz_eydz_kernel_h<<>>( + reinterpret_cast(z.data()), + reinterpret_cast(dz.data()), + weight.data(), + bias.data(), + edz.data(), + eydz.data(), + affine, eps, num, chn, sp); + + return {edz, eydz}; +} + +__global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz, + const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) { + int plane = blockIdx.x; + + float _weight = affine ? abs(weight[plane]) + eps : 1.f; + float _bias = affine ? 
bias[plane] : 0.f; + float _var = var[plane]; + float _edz = edz[plane]; + float _eydz = eydz[plane]; + + float _mul = _weight * rsqrt(_var + eps); + float count = float(num * sp); + + for (int batch = 0; batch < num; ++batch) { + for (int n = threadIdx.x; n < sp; n += blockDim.x) { + float _dz = __half2float(dz[(batch * chn + plane) * sp + n]); + float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight; + + dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul); + } + } +} + +at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, + at::Tensor edz, at::Tensor eydz, bool affine, float eps) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + CHECK_CUDA_INPUT(var); + CHECK_CUDA_INPUT(weight); + CHECK_CUDA_INPUT(bias); + CHECK_CUDA_INPUT(edz); + CHECK_CUDA_INPUT(eydz); + + // Extract dimensions + int64_t num, chn, sp; + get_dims(z, num, chn, sp); + + auto dx = at::zeros_like(z); + + // Run kernel + dim3 blocks(chn); + dim3 threads(getNumThreads(sp)); + auto stream = at::cuda::getCurrentCUDAStream(); + backward_kernel_h<<>>( + reinterpret_cast(z.data()), + reinterpret_cast(dz.data()), + var.data(), + weight.data(), + bias.data(), + edz.data(), + eydz.data(), + reinterpret_cast(dx.data()), + affine, eps, num, chn, sp); + + return dx; +} + +__global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){ + float _z = __half2float(z[i]); + if (_z < 0) { + dz[i] = __float2half(__half2float(dz[i]) * slope); + z[i] = __float2half(_z / slope); + } + } +} + +void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) { + CHECK_CUDA_INPUT(z); + CHECK_CUDA_INPUT(dz); + + int64_t count = z.numel(); + dim3 threads(getNumThreads(count)); + dim3 blocks = (count + threads.x - 1) / threads.x; + auto stream = at::cuda::getCurrentCUDAStream(); + leaky_relu_backward_impl_h<<>>( + reinterpret_cast(z.data()), + reinterpret_cast(dz.data()), + slope, count); +} + diff --git a/ConsistentID/lib/BiSeNet/modules/src/utils/checks.h b/ConsistentID/lib/BiSeNet/modules/src/utils/checks.h new file mode 100644 index 0000000000000000000000000000000000000000..e761a6fe34d0789815b588eba7e3726026e0e868 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/utils/checks.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT +#ifndef AT_CHECK +#define AT_CHECK AT_ASSERT +#endif + +#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor") +#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous") + +#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) +#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/modules/src/utils/common.h b/ConsistentID/lib/BiSeNet/modules/src/utils/common.h new file mode 100644 index 0000000000000000000000000000000000000000..e8403eef8a233b75dd4bb353c16486fe1be2039a --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/utils/common.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +/* + * Functions to share code between CPU and GPU + */ + +#ifdef __CUDACC__ +// CUDA versions + +#define HOST_DEVICE __host__ __device__ +#define INLINE_HOST_DEVICE __host__ __device__ 
inline +#define FLOOR(x) floor(x) + +#if __CUDA_ARCH__ >= 600 +// Recent compute capabilities have block-level atomicAdd for all data types, so we use that +#define ACCUM(x,y) atomicAdd_block(&(x),(y)) +#else +// Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float +// and use the known atomicCAS-based implementation for double +template +__device__ inline data_t atomic_add(data_t *address, data_t val) { + return atomicAdd(address, val); +} + +template<> +__device__ inline double atomic_add(double *address, double val) { + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); +} + +#define ACCUM(x,y) atomic_add(&(x),(y)) +#endif // #if __CUDA_ARCH__ >= 600 + +#else +// CPU versions + +#define HOST_DEVICE +#define INLINE_HOST_DEVICE inline +#define FLOOR(x) std::floor(x) +#define ACCUM(x,y) (x) += (y) + +#endif // #ifdef __CUDACC__ \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh b/ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..60c0023835e02c5f7c539c28ac07b75b72df394b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh @@ -0,0 +1,71 @@ +#pragma once + +/* + * General settings and functions + */ +const int WARP_SIZE = 32; +const int MAX_BLOCK_SIZE = 1024; + +static int getNumThreads(int nElem) { + int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE}; + for (int i = 0; i < 6; ++i) { + if (nElem <= threadSizes[i]) { + return threadSizes[i]; + } + } + return MAX_BLOCK_SIZE; +} + +/* + * Reduction utilities + */ +template +__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { +#if CUDART_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } + +template +struct Pair { + T v1, v2; + __device__ Pair() {} + __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} + __device__ Pair(T v) : v1(v), v2(v) {} + __device__ Pair(int v) : v1(v), v2(v) {} + __device__ Pair &operator+=(const Pair &a) { + v1 += a.v1; + v2 += a.v2; + return *this; + } +}; + +template +static __device__ __forceinline__ T warpSum(T val) { +#if __CUDA_ARCH__ >= 300 + for (int i = 0; i < getMSB(WARP_SIZE); ++i) { + val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); + } +#else + __shared__ T values[MAX_BLOCK_SIZE]; + values[threadIdx.x] = val; + __threadfence_block(); + const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; + for (int i = 1; i < WARP_SIZE; i++) { + val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; + } +#endif + return val; +} + +template +static __device__ __forceinline__ Pair warpSum(Pair value) { + value.v1 = warpSum(value.v1); + value.v2 = warpSum(value.v2); + return value; +} \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/optimizer.py b/ConsistentID/lib/BiSeNet/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c99e0645164b22f1e743ee99daadadd26a1cd80 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/optimizer.py @@ -0,0 +1,69 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + 
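+# Optimizer wraps torch.optim.SGD with an exponential warmup from warmup_start_lr
+# to lr0 over the first warmup_steps iterations, followed by polynomial ("poly")
+# decay towards zero at max_iter. Parameter groups flagged with 'lr_mul' are
+# trained at 10x the current base learning rate.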
+import torch +import logging + +logger = logging.getLogger() + +class Optimizer(object): + def __init__(self, + model, + lr0, + momentum, + wd, + warmup_steps, + warmup_start_lr, + max_iter, + power, + *args, **kwargs): + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr + self.lr0 = lr0 + self.lr = self.lr0 + self.max_iter = float(max_iter) + self.power = power + self.it = 0 + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params() + param_list = [ + {'params': wd_params}, + {'params': nowd_params, 'weight_decay': 0}, + {'params': lr_mul_wd_params, 'lr_mul': True}, + {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}] + self.optim = torch.optim.SGD( + param_list, + lr = lr0, + momentum = momentum, + weight_decay = wd) + self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps) + + + def get_lr(self): + if self.it <= self.warmup_steps: + lr = self.warmup_start_lr*(self.warmup_factor**self.it) + else: + factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power + lr = self.lr0 * factor + return lr + + + def step(self): + self.lr = self.get_lr() + for pg in self.optim.param_groups: + if pg.get('lr_mul', False): + pg['lr'] = self.lr * 10 + else: + pg['lr'] = self.lr + if self.optim.defaults.get('lr_mul', False): + self.optim.defaults['lr'] = self.lr * 10 + else: + self.optim.defaults['lr'] = self.lr + self.it += 1 + self.optim.step() + if self.it == self.warmup_steps+2: + logger.info('==> warmup done, start to implement poly lr strategy') + + def zero_grad(self): + self.optim.zero_grad() + diff --git a/ConsistentID/lib/BiSeNet/prepropess_data.py b/ConsistentID/lib/BiSeNet/prepropess_data.py new file mode 100644 index 0000000000000000000000000000000000000000..ee7ed56dd8c0372d482e6a53f323da17043bd521 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/prepropess_data.py @@ -0,0 +1,38 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import os.path as osp +import os +import cv2 +from transform import * +from PIL import Image + +face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img' +face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno' +mask_path = '/home/zll/data/CelebAMask-HQ/mask' +counter = 0 +total = 0 +for i in range(15): + + atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', + 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] + + for j in range(i * 2000, (i + 1) * 2000): + + mask = np.zeros((512, 512)) + + for l, att in enumerate(atts, 1): + total += 1 + file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png']) + path = osp.join(face_sep_mask, str(i), file_name) + + if os.path.exists(path): + counter += 1 + sep_mask = np.array(Image.open(path).convert('P')) + # print(np.unique(sep_mask)) + + mask[sep_mask == 225] = l + cv2.imwrite('{}/{}.png'.format(mask_path, j), mask) + print(j) + +print(counter, total) \ No newline at end of file diff --git a/ConsistentID/lib/BiSeNet/resnet.py b/ConsistentID/lib/BiSeNet/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2bf95130e9815ba378cb6f73207068b81a04b9 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/resnet.py @@ -0,0 +1,109 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as modelzoo + +# from modules.bn import InPlaceABNSync as BatchNorm2d + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + + +def 
conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class Resnet18(nn.Module): + def __init__(self): + super(Resnet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + state_dict = modelzoo.load_url(resnet18_url) + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +if __name__ == "__main__": + net = Resnet18() + x = torch.randn(16, 3, 224, 224) + out = net(x) + print(out[0].size()) + print(out[1].size()) + print(out[2].size()) + net.get_params() diff --git a/ConsistentID/lib/BiSeNet/test.py b/ConsistentID/lib/BiSeNet/test.py new file mode 100644 index 0000000000000000000000000000000000000000..604a89f6e86a6a18581022620c413a43abece91b --- /dev/null +++ b/ConsistentID/lib/BiSeNet/test.py @@ -0,0 +1,90 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +from logger import setup_logger +import BiSeNet + +import torch + +import os +import os.path as osp +import numpy as np +from PIL import Image +import torchvision.transforms as transforms +import cv2 + +def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'): + # Colors for all 20 parts + part_colors = [[255, 0, 0], [255, 85, 0], [255, 
170, 0], + [255, 0, 85], [255, 0, 170], + [0, 255, 0], [85, 255, 0], [170, 255, 0], + [0, 255, 85], [0, 255, 170], + [0, 0, 255], [85, 0, 255], [170, 0, 255], + [0, 85, 255], [0, 170, 255], + [255, 255, 0], [255, 255, 85], [255, 255, 170], + [255, 0, 255], [255, 85, 255], [255, 170, 255], + [0, 255, 255], [85, 255, 255], [170, 255, 255]] + + im = np.array(im) + vis_im = im.copy().astype(np.uint8) + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 + + num_of_class = np.max(vis_parsing_anno) + + for pi in range(1, num_of_class + 1): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + # print(vis_parsing_anno_color.shape, vis_im.shape) + vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) + + # Save result or not + if save_im: + cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno) + cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + + # return vis_im + +def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): + + if not os.path.exists(respth): + os.makedirs(respth) + + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + save_pth = osp.join('res/cp', cp) + net.load_state_dict(torch.load(save_pth)) + net.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + with torch.no_grad(): + for image_path in os.listdir(dspth): + img = Image.open(osp.join(dspth, image_path)) + image = img.resize((512, 512), Image.BILINEAR) + img = to_tensor(image) + img = torch.unsqueeze(img, 0) + img = img.cuda() + out = net(img)[0] + parsing = out.squeeze(0).cpu().numpy().argmax(0) + # print(parsing) + print(np.unique(parsing)) + + vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path)) + + + + + + + +if __name__ == "__main__": + evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth') + + diff --git a/ConsistentID/lib/BiSeNet/train.py b/ConsistentID/lib/BiSeNet/train.py new file mode 100644 index 0000000000000000000000000000000000000000..ca481944086fc19f320f01c4f2c0f1ab7aef5a83 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/train.py @@ -0,0 +1,179 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +from logger import setup_logger +import BiSeNet +from face_dataset import FaceMask +from loss import OhemCELoss +from evaluate import evaluate +from optimizer import Optimizer +import cv2 +import numpy as np + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import torch.nn.functional as F +import torch.distributed as dist + +import os +import os.path as osp +import logging +import time +import datetime +import argparse + + +respth = './res' +if not osp.exists(respth): + os.makedirs(respth) +logger = logging.getLogger() + + +def parse_args(): + parse = argparse.ArgumentParser() + parse.add_argument( + '--local_rank', + dest = 'local_rank', + type = int, + default = -1, + ) + return parse.parse_args() + + +def train(): + args = parse_args() + torch.cuda.set_device(args.local_rank) + dist.init_process_group( + backend = 'nccl', + init_method = 'tcp://127.0.0.1:33241', + world_size = 
torch.cuda.device_count(), + rank=args.local_rank + ) + setup_logger(respth) + + # dataset + n_classes = 19 + n_img_per_gpu = 16 + n_workers = 8 + cropsize = [448, 448] + data_root = '/home/zll/data/CelebAMask-HQ/' + + ds = FaceMask(data_root, cropsize=cropsize, mode='train') + sampler = torch.utils.data.distributed.DistributedSampler(ds) + dl = DataLoader(ds, + batch_size = n_img_per_gpu, + shuffle = False, + sampler = sampler, + num_workers = n_workers, + pin_memory = True, + drop_last = True) + + # model + ignore_idx = -100 + net = BiSeNet(n_classes=n_classes) + net.cuda() + net.train() + net = nn.parallel.DistributedDataParallel(net, + device_ids = [args.local_rank, ], + output_device = args.local_rank + ) + score_thres = 0.7 + n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16 + LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) + Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) + Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) + + ## optimizer + momentum = 0.9 + weight_decay = 5e-4 + lr_start = 1e-2 + max_iter = 80000 + power = 0.9 + warmup_steps = 1000 + warmup_start_lr = 1e-5 + optim = Optimizer( + model = net.module, + lr0 = lr_start, + momentum = momentum, + wd = weight_decay, + warmup_steps = warmup_steps, + warmup_start_lr = warmup_start_lr, + max_iter = max_iter, + power = power) + + ## train loop + msg_iter = 50 + loss_avg = [] + st = glob_st = time.time() + diter = iter(dl) + epoch = 0 + for it in range(max_iter): + try: + im, lb = next(diter) + if not im.size()[0] == n_img_per_gpu: + raise StopIteration + except StopIteration: + epoch += 1 + sampler.set_epoch(epoch) + diter = iter(dl) + im, lb = next(diter) + im = im.cuda() + lb = lb.cuda() + H, W = im.size()[2:] + lb = torch.squeeze(lb, 1) + + optim.zero_grad() + out, out16, out32 = net(im) + lossp = LossP(out, lb) + loss2 = Loss2(out16, lb) + loss3 = Loss3(out32, lb) + loss = lossp + loss2 + loss3 + loss.backward() + optim.step() + + loss_avg.append(loss.item()) + + # print training log message + if (it+1) % msg_iter == 0: + loss_avg = sum(loss_avg) / len(loss_avg) + lr = optim.lr + ed = time.time() + t_intv, glob_t_intv = ed - st, ed - glob_st + eta = int((max_iter - it) * (glob_t_intv / it)) + eta = str(datetime.timedelta(seconds=eta)) + msg = ', '.join([ + 'it: {it}/{max_it}', + 'lr: {lr:4f}', + 'loss: {loss:.4f}', + 'eta: {eta}', + 'time: {time:.4f}', + ]).format( + it = it+1, + max_it = max_iter, + lr = lr, + loss = loss_avg, + time = t_intv, + eta = eta + ) + logger.info(msg) + loss_avg = [] + st = ed + if dist.get_rank() == 0: + if (it+1) % 5000 == 0: + state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() + if dist.get_rank() == 0: + torch.save(state, './res/cp/{}_iter.pth'.format(it)) + evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it)) + + # dump the final model + save_pth = osp.join(respth, 'model_final_diss.pth') + # net.cpu() + state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() + if dist.get_rank() == 0: + torch.save(state, save_pth) + logger.info('training done, model saved to: {}'.format(save_pth)) + + +if __name__ == "__main__": + train() diff --git a/ConsistentID/lib/BiSeNet/transform.py b/ConsistentID/lib/BiSeNet/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..9479ae356a151f5da8eedf288abeae7458739d24 --- /dev/null +++ b/ConsistentID/lib/BiSeNet/transform.py @@ -0,0 +1,129 @@ +#!/usr/bin/python +# -*- 
encoding: utf-8 -*- + + +from PIL import Image +import PIL.ImageEnhance as ImageEnhance +import random +import numpy as np + +class RandomCrop(object): + def __init__(self, size, *args, **kwargs): + self.size = size + + def __call__(self, im_lb): + im = im_lb['im'] + lb = im_lb['lb'] + assert im.size == lb.size + W, H = self.size + w, h = im.size + + if (W, H) == (w, h): return dict(im=im, lb=lb) + if w < W or h < H: + scale = float(W) / w if w < h else float(H) / h + w, h = int(scale * w + 1), int(scale * h + 1) + im = im.resize((w, h), Image.BILINEAR) + lb = lb.resize((w, h), Image.NEAREST) + sw, sh = random.random() * (w - W), random.random() * (h - H) + crop = int(sw), int(sh), int(sw) + W, int(sh) + H + return dict( + im = im.crop(crop), + lb = lb.crop(crop) + ) + + +class HorizontalFlip(object): + def __init__(self, p=0.5, *args, **kwargs): + self.p = p + + def __call__(self, im_lb): + if random.random() > self.p: + return im_lb + else: + im = im_lb['im'] + lb = im_lb['lb'] + + # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', + # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat'] + + flip_lb = np.array(lb) + flip_lb[lb == 2] = 3 + flip_lb[lb == 3] = 2 + flip_lb[lb == 4] = 5 + flip_lb[lb == 5] = 4 + flip_lb[lb == 7] = 8 + flip_lb[lb == 8] = 7 + flip_lb = Image.fromarray(flip_lb) + return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT), + lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT), + ) + + +class RandomScale(object): + def __init__(self, scales=(1, ), *args, **kwargs): + self.scales = scales + + def __call__(self, im_lb): + im = im_lb['im'] + lb = im_lb['lb'] + W, H = im.size + scale = random.choice(self.scales) + w, h = int(W * scale), int(H * scale) + return dict(im = im.resize((w, h), Image.BILINEAR), + lb = lb.resize((w, h), Image.NEAREST), + ) + + +class ColorJitter(object): + def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs): + if not brightness is None and brightness>0: + self.brightness = [max(1-brightness, 0), 1+brightness] + if not contrast is None and contrast>0: + self.contrast = [max(1-contrast, 0), 1+contrast] + if not saturation is None and saturation>0: + self.saturation = [max(1-saturation, 0), 1+saturation] + + def __call__(self, im_lb): + im = im_lb['im'] + lb = im_lb['lb'] + r_brightness = random.uniform(self.brightness[0], self.brightness[1]) + r_contrast = random.uniform(self.contrast[0], self.contrast[1]) + r_saturation = random.uniform(self.saturation[0], self.saturation[1]) + im = ImageEnhance.Brightness(im).enhance(r_brightness) + im = ImageEnhance.Contrast(im).enhance(r_contrast) + im = ImageEnhance.Color(im).enhance(r_saturation) + return dict(im = im, + lb = lb, + ) + + +class MultiScale(object): + def __init__(self, scales): + self.scales = scales + + def __call__(self, img): + W, H = img.size + sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales] + imgs = [] + [imgs.append(img.resize(size, Image.BILINEAR)) for size in sizes] + return imgs + + +class Compose(object): + def __init__(self, do_list): + self.do_list = do_list + + def __call__(self, im_lb): + for comp in self.do_list: + im_lb = comp(im_lb) + return im_lb + + + + +if __name__ == '__main__': + flip = HorizontalFlip(p = 1) + crop = RandomCrop((321, 321)) + rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0)) + img = Image.open('data/img.jpg') + lb = Image.open('data/label.png') diff --git a/ConsistentID/lib/attention.py 
b/ConsistentID/lib/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9124b3cabd0bc2cafba5b23cfa09bc6aa6261ca8 --- /dev/null +++ b/ConsistentID/lib/attention.py @@ -0,0 +1,287 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.lora import LoRALinearLayer +from .functions import AttentionMLP + +class FuseModule(nn.Module): + def __init__(self, embed_dim): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True) + self.layer_norm = nn.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + valid_id_mask, + ) -> torch.Tensor: + id_embeds = id_embeds.to(prompt_embeds.dtype) + batch_size, max_num_inputs = id_embeds.shape[:2] # 1,5 + seq_length = prompt_embeds.shape[1] # 77 + flat_id_embeds = id_embeds.view(-1, id_embeds.shape[-2], id_embeds.shape[-1]) + # flat_id_embeds torch.Size([5, 1, 768]) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + # valid_id_embeds torch.Size([4, 1, 768]) + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) # torch.Size([77, 768]) + class_tokens_mask = class_tokens_mask.view(-1) # torch.Size([77]) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) # torch.Size([4, 768]) + image_token_embeds = prompt_embeds[class_tokens_mask] # torch.Size([4, 768]) + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) # torch.Size([4, 768]) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + + return updated_prompt_embeds + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + +class FacialEncoder(nn.Module): + def __init__(self): + super().__init__() + self.visual_projection = AttentionMLP() + self.fuse_module = FuseModule(768) + + def forward(self, prompt_embeds, multi_image_embeds, class_tokens_mask, valid_id_mask): + bs, num_inputs, token_length, image_dim = multi_image_embeds.shape + multi_image_embeds_view = multi_image_embeds.view(bs * num_inputs, token_length, image_dim) + id_embeds = self.visual_projection(multi_image_embeds_view) # torch.Size([5, 1, 768]) + id_embeds = id_embeds.view(bs, num_inputs, 1, -1) + # fuse_module replaces the class tokens in prompt_embeds with the fused (id_embeds, prompt_embeds[class_tokens_mask]) + # whose indices are specified by class_tokens_mask. 
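+        # valid_id_mask marks which of the num_inputs facial-feature slots hold a
+        # valid id embedding; only those embeddings participate in the fusion.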
+ updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask, valid_id_mask) + return updated_prompt_embeds + +class Consistent_AttProcessor(nn.Module): + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class Consistent_IPAttProcessor(nn.Module): + + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + self.num_tokens = num_tokens + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim 
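+        # self.scale weights the image-prompt (IP) attention branch that is added
+        # on top of the regular text cross-attention output in __call__.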
+ self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + for module in [self.to_q_lora, self.to_k_lora, self.to_v_lora, self.to_out_lora, self.to_k_ip, self.to_v_ip]: + for param in module.parameters(): + param.requires_grad = False + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + scale=1.0, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git 
a/ConsistentID/lib/functions.py b/ConsistentID/lib/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..e343952cc891459c21ba3a773a2cbd8ad3cede64 --- /dev/null +++ b/ConsistentID/lib/functions.py @@ -0,0 +1,606 @@ +import numpy as np +import math +import types +import torch +import torch.nn as nn +import numpy as np +import cv2 +import re +import torch.nn.functional as F +from einops import rearrange +from einops.layers.torch import Rearrange +from PIL import Image + +def extract_first_sentence(text): + end_index = text.find('.') + if end_index != -1: + first_sentence = text[:end_index + 1] + return first_sentence.strip() + else: + return text.strip() + +import re +def remove_duplicate_keywords(text, keywords): + keyword_counts = {} + + words = re.findall(r'\b\w+\b|[.,;!?]', text) + + for keyword in keywords: + keyword_counts[keyword] = 0 + for i, word in enumerate(words): + if word.lower() == keyword.lower(): + keyword_counts[keyword] += 1 + if keyword_counts[keyword] > 1: + words[i] = "" + processed_text = " ".join(words) + + return processed_text + +# text: 'The person has one nose , two eyes , two ears , and a mouth .' +def insert_markers_to_prompt(text, parsing_mask_dict): + keywords = ["face", "ears", "eyes", "nose", "mouth"] + text = remove_duplicate_keywords(text, keywords) + key_parsing_mask_markers = ["Nose", "Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Upper_Lip", "Lower_Lip"] + mapping = { + "Face": "face", + "Left_Ear": "ears", + "Right_Ear": "ears", + "Left_Eye": "eyes", + "Right_Eye": "eyes", + "Nose": "nose", + "Upper_Lip": "mouth", + "Lower_Lip": "mouth", + } + facial_features_align = [] + markers_align = [] + for key in key_parsing_mask_markers: + if key in parsing_mask_dict: + mapped_key = mapping.get(key, key.lower()) + if mapped_key not in facial_features_align: + facial_features_align.append(mapped_key) + markers_align.append("<|" + mapped_key + "|>") + + text_marked = text + align_parsing_mask_dict = parsing_mask_dict + for feature, marker in zip(facial_features_align[::-1], markers_align[::-1]): + pattern = rf'\b{feature}\b' + text_marked_new = re.sub(pattern, f'{feature} {marker}', text_marked, count=1) + if text_marked == text_marked_new: + for key, value in mapping.items(): + if value == feature: + if key in align_parsing_mask_dict: + del align_parsing_mask_dict[key] + + text_marked = text_marked_new + + text_marked = text_marked.replace('\n', '') + + ordered_text = [] + text_none_makers = [] + facial_marked_count = 0 + skip_count = 0 + for marker in markers_align: + start_idx = text_marked.find(marker) + end_idx = start_idx + len(marker) + + while start_idx > 0 and text_marked[start_idx - 1] not in [",", ".", ";"]: + start_idx -= 1 + + while end_idx < len(text_marked) and text_marked[end_idx] not in [",", ".", ";"]: + end_idx += 1 + + context = text_marked[start_idx:end_idx].strip() + if context == "": + text_none_makers.append(text_marked[:end_idx]) + else: + if skip_count!=0: + skip_count -= 1 + continue + else: + ordered_text.append(context + ", ") + text_delete_makers = text_marked[:start_idx] + text_marked[end_idx:] + text_marked = text_delete_makers + facial_marked_count += 1 + + # ordered_text: ['The person has one nose <|nose|>, ', 'two ears <|ears|>, ', + # 'two eyes <|eyes|>, ', 'and a mouth <|mouth|>, '] + # align_parsing_mask_dict.keys(): ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip'] + align_marked_text = "".join(ordered_text) + replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"] + 
for item in replace_list: + align_marked_text = align_marked_text.replace(item, "<|facial|>") + + # align_marked_text: 'The person has one nose <|facial|>, two ears <|facial|>, two eyes <|facial|>, and a mouth <|facial|>, ' + return align_marked_text, align_parsing_mask_dict + +def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer): + input_ids = tokenizer.encode(text) + image_noun_phrase_end_mask = [False for _ in input_ids] + facial_noun_phrase_end_mask = [False for _ in input_ids] + clean_input_ids = [] + clean_index = 0 + image_num = 0 + + for i, id in enumerate(input_ids): + if id == image_token_id: + image_noun_phrase_end_mask[clean_index + image_num - 1] = True + image_num += 1 + elif id == facial_token_id: + facial_noun_phrase_end_mask[clean_index - 1] = True + else: + clean_input_ids.append(id) + clean_index += 1 + + max_len = tokenizer.model_max_length + + if len(clean_input_ids) > max_len: + clean_input_ids = clean_input_ids[:max_len] + else: + clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * ( + max_len - len(clean_input_ids) + ) + + if len(image_noun_phrase_end_mask) > max_len: + image_noun_phrase_end_mask = image_noun_phrase_end_mask[:max_len] + else: + image_noun_phrase_end_mask = image_noun_phrase_end_mask + [False] * ( + max_len - len(image_noun_phrase_end_mask) + ) + + if len(facial_noun_phrase_end_mask) > max_len: + facial_noun_phrase_end_mask = facial_noun_phrase_end_mask[:max_len] + else: + facial_noun_phrase_end_mask = facial_noun_phrase_end_mask + [False] * ( + max_len - len(facial_noun_phrase_end_mask) + ) + clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long) + image_noun_phrase_end_mask = torch.tensor(image_noun_phrase_end_mask, dtype=torch.bool) + facial_noun_phrase_end_mask = torch.tensor(facial_noun_phrase_end_mask, dtype=torch.bool) + + return clean_input_ids.unsqueeze(0), image_noun_phrase_end_mask.unsqueeze(0), facial_noun_phrase_end_mask.unsqueeze(0) + +def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5): + image_token_idx = torch.nonzero(image_token_mask, as_tuple=True)[1] + image_token_idx_mask = torch.ones_like(image_token_idx, dtype=torch.bool) + if len(image_token_idx) < max_num_objects: + image_token_idx = torch.cat( + [ + image_token_idx, + torch.zeros(max_num_objects - len(image_token_idx), dtype=torch.long), + ] + ) + image_token_idx_mask = torch.cat( + [ + image_token_idx_mask, + torch.zeros( + max_num_objects - len(image_token_idx_mask), + dtype=torch.bool, + ), + ] + ) + facial_token_idx = torch.nonzero(facial_token_mask, as_tuple=True)[1] + facial_token_idx_mask = torch.ones_like(facial_token_idx, dtype=torch.bool) + if len(facial_token_idx) < max_num_facials: + facial_token_idx = torch.cat( + [ + facial_token_idx, + torch.zeros(max_num_facials - len(facial_token_idx), dtype=torch.long), + ] + ) + facial_token_idx_mask = torch.cat( + [ + facial_token_idx_mask, + torch.zeros( + max_num_facials - len(facial_token_idx_mask), + dtype=torch.bool, + ), + ] + ) + image_token_idx = image_token_idx.unsqueeze(0) + image_token_idx_mask = image_token_idx_mask.unsqueeze(0) + + facial_token_idx = facial_token_idx.unsqueeze(0) + facial_token_idx_mask = facial_token_idx_mask.unsqueeze(0) + + return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask + +def get_object_localization_loss_for_one_layer( + cross_attention_scores, + object_segmaps, + object_token_idx, + object_token_idx_mask, + loss_fn, +): + bxh, 
num_noise_latents, num_text_tokens = cross_attention_scores.shape + b, max_num_objects, _, _ = object_segmaps.shape + size = int(num_noise_latents**0.5) + + object_segmaps = F.interpolate(object_segmaps, size=(size, size), mode="bilinear", antialias=True) + + object_segmaps = object_segmaps.view( + b, max_num_objects, -1 + ) + + num_heads = bxh // b + cross_attention_scores = cross_attention_scores.view(b, num_heads, num_noise_latents, num_text_tokens) + + + object_token_attn_prob = torch.gather( + cross_attention_scores, + dim=3, + index=object_token_idx.view(b, 1, 1, max_num_objects).expand( + b, num_heads, num_noise_latents, max_num_objects + ), + ) + object_segmaps = ( + object_segmaps.permute(0, 2, 1) + .unsqueeze(1) + .expand(b, num_heads, num_noise_latents, max_num_objects) + ) + loss = loss_fn(object_token_attn_prob, object_segmaps) + + loss = loss * object_token_idx_mask.view(b, 1, max_num_objects) + object_token_cnt = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5 + loss = (loss.sum(dim=2) / object_token_cnt).mean() + + return loss + + +def get_object_localization_loss( + cross_attention_scores, + object_segmaps, + image_token_idx, + image_token_idx_mask, + loss_fn, +): + num_layers = len(cross_attention_scores) + loss = 0 + for k, v in cross_attention_scores.items(): + layer_loss = get_object_localization_loss_for_one_layer( + v, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn + ) + loss += layer_loss + return loss / num_layers + +def unet_store_cross_attention_scores(unet, attention_scores, layers=5): + from diffusers.models.attention_processor import Attention + + UNET_LAYER_NAMES = [ + "down_blocks.0", + "down_blocks.1", + "down_blocks.2", + "mid_block", + "up_blocks.1", + "up_blocks.2", + "up_blocks.3", + ] + + start_layer = (len(UNET_LAYER_NAMES) - layers) // 2 + end_layer = start_layer + layers + applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer] + + def make_new_get_attention_scores_fn(name): + def new_get_attention_scores(module, query, key, attention_mask=None): + attention_probs = module.old_get_attention_scores( + query, key, attention_mask + ) + attention_scores[name] = attention_probs + return attention_probs + + return new_get_attention_scores + + for name, module in unet.named_modules(): + if isinstance(module, Attention) and "attn1" in name: + if not any(layer in name for layer in applicable_layers): + continue + + module.old_get_attention_scores = module.get_attention_scores + module.get_attention_scores = types.MethodType( + make_new_get_attention_scores_fn(name), module + ) + return unet + +class BalancedL1Loss(nn.Module): + def __init__(self, threshold=1.0, normalize=False): + super().__init__() + self.threshold = threshold + self.normalize = normalize + + def forward(self, object_token_attn_prob, object_segmaps): + if self.normalize: + object_token_attn_prob = object_token_attn_prob / ( + object_token_attn_prob.max(dim=2, keepdim=True)[0] + 1e-5 + ) + background_segmaps = 1 - object_segmaps + background_segmaps_sum = background_segmaps.sum(dim=2) + 1e-5 + object_segmaps_sum = object_segmaps.sum(dim=2) + 1e-5 + + background_loss = (object_token_attn_prob * background_segmaps).sum( + dim=2 + ) / background_segmaps_sum + + object_loss = (object_token_attn_prob * object_segmaps).sum( + dim=2 + ) / object_segmaps_sum + + return background_loss - object_loss + +def apply_mask_to_raw_image(raw_image, mask_image): + mask_image = mask_image.resize(raw_image.size) + mask_raw_image = Image.composite(raw_image, Image.new('RGB', 
raw_image.size, (0, 0, 0)), mask_image) + return mask_raw_image + +mapping_table = [ + {"Mask Value": 0, "Body Part": "Background", "RGB Color": [0, 0, 0]}, + {"Mask Value": 1, "Body Part": "Face", "RGB Color": [255, 0, 0]}, + {"Mask Value": 2, "Body Part": "Left_Eyebrow", "RGB Color": [255, 85, 0]}, + {"Mask Value": 3, "Body Part": "Right_Eyebrow", "RGB Color": [255, 170, 0]}, + {"Mask Value": 4, "Body Part": "Left_Eye", "RGB Color": [255, 0, 85]}, + {"Mask Value": 5, "Body Part": "Right_Eye", "RGB Color": [255, 0, 170]}, + {"Mask Value": 6, "Body Part": "Hair", "RGB Color": [0, 0, 255]}, + {"Mask Value": 7, "Body Part": "Left_Ear", "RGB Color": [85, 0, 255]}, + {"Mask Value": 8, "Body Part": "Right_Ear", "RGB Color": [170, 0, 255]}, + {"Mask Value": 9, "Body Part": "Mouth_External Contour", "RGB Color": [0, 255, 85]}, + {"Mask Value": 10, "Body Part": "Nose", "RGB Color": [0, 255, 0]}, + {"Mask Value": 11, "Body Part": "Mouth_Inner_Contour", "RGB Color": [0, 255, 170]}, + {"Mask Value": 12, "Body Part": "Upper_Lip", "RGB Color": [85, 255, 0]}, + {"Mask Value": 13, "Body Part": "Lower_Lip", "RGB Color": [170, 255, 0]}, + {"Mask Value": 14, "Body Part": "Neck", "RGB Color": [0, 85, 255]}, + {"Mask Value": 15, "Body Part": "Neck_Inner Contour", "RGB Color": [0, 170, 255]}, + {"Mask Value": 16, "Body Part": "Cloth", "RGB Color": [255, 255, 0]}, + {"Mask Value": 17, "Body Part": "Hat", "RGB Color": [255, 0, 255]}, + {"Mask Value": 18, "Body Part": "Earring", "RGB Color": [255, 85, 255]}, + {"Mask Value": 19, "Body Part": "Necklace", "RGB Color": [255, 255, 85]}, + {"Mask Value": 20, "Body Part": "Glasses", "RGB Color": [255, 170, 255]}, + {"Mask Value": 21, "Body Part": "Hand", "RGB Color": [255, 0, 255]}, + {"Mask Value": 22, "Body Part": "Wristband", "RGB Color": [0, 255, 255]}, + {"Mask Value": 23, "Body Part": "Clothes_Upper", "RGB Color": [85, 255, 255]}, + {"Mask Value": 24, "Body Part": "Clothes_Lower", "RGB Color": [170, 255, 255]} +] + + +def masks_for_unique_values(image_raw_mask): + + image_array = np.array(image_raw_mask) + unique_values, counts = np.unique(image_array, return_counts=True) + masks_dict = {} + for value in unique_values: + binary_image = np.uint8(image_array == value) * 255 + contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + mask = np.zeros_like(image_array) + for contour in contours: + cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED) + + if value == 0: + body_part="WithoutBackground" + mask2 = np.where(mask == 255, 0, 255).astype(mask.dtype) + masks_dict[body_part] = Image.fromarray(mask2) + + body_part = next((entry["Body Part"] for entry in mapping_table if entry["Mask Value"] == value), f"Unknown_{value}") + if body_part.startswith("Unknown_"): + continue + + masks_dict[body_part] = Image.fromarray(mask) + + return masks_dict +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + x = x.view(bs, length, heads, -1) + x = x.transpose(1, 2) + x = x.reshape(bs, heads, length, -1) + return x + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + 
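Shape-wise, reshape_tensor above just splits the head dimension out, and inside PerceiverAttention the latents act as the queries while the concatenated (x, latents) sequence supplies the keys and values. A quick shape check, assuming the import path from this repo's layout and the 12 heads of size 64 that ProjPlusModel configures further below:

import torch
from ConsistentID.lib.functions import reshape_tensor

q = torch.randn(2, 4, 12 * 64)            # (batch, num_latents, heads * dim_head)
print(reshape_tensor(q, heads=12).shape)  # torch.Size([2, 12, 4, 64])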
self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # x -> kv, latents -> q + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + # x -> kv, latents -> q + def forward(self, latents, x): # latents.torch.Size([2, 4, 768]) x.torch.Size([2, 257, 1280]) + x = self.proj_in(x) # x.torch.Size([2, 257, 768]) + for attn, ff in self.layers: + # x -> kv, latents -> q + latents = attn(x, latents) + latents # latents.torch.Size([2, 4, 768]) + latents = ff(latents) + latents # latents.torch.Size([2, 4, 768]) + latents = self.proj_out(latents) + return self.norm_out(latents) + +class ProjPlusModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=cross_attention_dim // 64, + embedding_dim=clip_embeddings_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): + + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + # id_embeds -> x -> kv, clip_embeds -> q + out = self.perceiver_resampler(x, clip_embeds) + if shortcut: + out = scale * x + out + return out + +class AttentionMLP(nn.Module): + def __init__( + self, + dtype=torch.float16, + dim=1024, + depth=8, + dim_head=64, + heads=16, + single_num_tokens=1, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + max_seq_len: int = 257*2, + apply_pos_emb: bool = False, + num_latents_mean_pooled: int = 0, + ): + super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None + + self.single_num_tokens = single_num_tokens + self.latents = nn.Parameter(torch.randn(1, 
self.single_num_tokens, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.to_latents_from_mean_pooled_seq = ( + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * num_latents_mean_pooled), + Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), + ) + if num_latents_mean_pooled > 0 + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + if self.pos_emb is not None: + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device=device)) + x = x + pos_emb + # x torch.Size([5, 257, 1280]) + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) # torch.Size([5, 257, 1024]) + + if self.to_latents_from_mean_pooled_seq: + meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim=-2) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +def masked_mean(t, *, dim, mask=None): + if mask is None: + return t.mean(dim=dim) + + denom = mask.sum(dim=dim, keepdim=True) + mask = rearrange(mask, "b n -> b n 1") + masked_t = t.masked_fill(~mask, 0.0) + + return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) + + diff --git a/ConsistentID/lib/pipeline_ConsistentID.py b/ConsistentID/lib/pipeline_ConsistentID.py new file mode 100644 index 0000000000000000000000000000000000000000..129c7bb21a5367d673299e5bb7a5501333a6d4e7 --- /dev/null +++ b/ConsistentID/lib/pipeline_ConsistentID.py @@ -0,0 +1,605 @@ +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +import cv2 +import PIL +import numpy as np +from PIL import Image +import torch +from torchvision import transforms +from insightface.app import FaceAnalysis +### insight-face installation can be found at https://github.com/deepinsight/insightface +from safetensors import safe_open +from huggingface_hub.utils import validate_hf_hub_args +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline +from .functions import insert_markers_to_prompt, masks_for_unique_values, apply_mask_to_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx +from .functions import ProjPlusModel, masks_for_unique_values +from .attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder +from easydict import EasyDict as edict +from huggingface_hub import hf_hub_download +### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file +### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812 +### Thanks for the open source of face-parsing model. 
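As a quick sanity check of the masked_mean helper defined at the end of functions.py above (import path assumed from this repo's layout): it averages only the unmasked positions along the given dimension.

import torch
from ConsistentID.lib.functions import masked_mean

t = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
mask = torch.tensor([[True, True, False],
                     [True, True, True]])
out = masked_mean(t, dim=1, mask=mask)
# Sample 0 averages only its first two rows; sample 1 averages all three.
assert torch.allclose(out[0], t[0, :2].mean(dim=0))
assert torch.allclose(out[1], t[1].mean(dim=0))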
+from .BiSeNet.model import BiSeNet +import os + +PipelineImageInput = Union[ + PIL.Image.Image, + torch.FloatTensor, + List[PIL.Image.Image], + List[torch.FloatTensor], +] + +### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location. +class ConsistentIDPipeline(StableDiffusionPipeline): + # to() should be only called after all modules are loaded. + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(torch_device, dtype=dtype) + self.bise_net.to(torch_device, dtype=dtype) + self.clip_encoder.to(torch_device, dtype=dtype) + self.image_proj_model.to(torch_device, dtype=dtype) + self.FacialEncoder.to(torch_device, dtype=dtype) + # If the unet is not released, the ip_layers should be moved to the specified device and dtype. + if not isinstance(self.unet, edict): + self.ip_layers.to(torch_device, dtype=dtype) + return self + + @validate_hf_hub_args + def load_ConsistentID_model( + self, + consistentID_weight_path: str, + bise_net_weight_path: str, + trigger_word_facial: str = '<|facial|>', + # A CLIP ViT-H/14 model trained with the LAION-2B English subset of LAION-5B using OpenCLIP. + # output dim: 1280. + image_encoder_path: str = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K', + torch_dtype = torch.float16, + num_tokens = 4, + lora_rank= 128, + **kwargs, + ): + self.lora_rank = lora_rank + self.torch_dtype = torch_dtype + self.num_tokens = num_tokens + self.set_ip_adapter() + self.image_encoder_path = image_encoder_path + self.clip_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path) + self.clip_preprocessor = CLIPImageProcessor() + self.id_image_processor = CLIPImageProcessor() + self.crop_size = 512 + + # face_app: FaceAnalysis object + self.face_app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CPUExecutionProvider']) + # The original det_size=(640, 640) is too large and face_app often fails to detect faces. + self.face_app.prepare(ctx_id=0, det_size=(512, 512)) + + if not os.path.exists(consistentID_weight_path): + ### Download pretrained models + hf_hub_download(repo_id="JackAILab/ConsistentID", repo_type="model", + filename=os.path.basename(consistentID_weight_path), + local_dir=os.path.dirname(consistentID_weight_path)) + if not os.path.exists(bise_net_weight_path): + hf_hub_download(repo_id="JackAILab/ConsistentID", + filename=os.path.basename(bise_net_weight_path), + local_dir=os.path.dirname(bise_net_weight_path)) + + bise_net = BiSeNet(n_classes = 19) + bise_net.load_state_dict(torch.load(bise_net_weight_path, map_location="cpu")) + bise_net.eval() + self.bise_net = bise_net + + # Colors for all 20 parts + self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], + [255, 0, 85], [255, 0, 170], + [0, 255, 0], [85, 255, 0], [170, 255, 0], + [0, 255, 85], [0, 255, 170], + [0, 0, 255], [85, 0, 255], [170, 0, 255], + [0, 85, 255], [0, 170, 255], + [255, 255, 0], [255, 255, 85], [255, 255, 170], + [255, 0, 255], [255, 85, 255], [255, 170, 255], + [0, 255, 255], [85, 255, 255], [170, 255, 255]] + + # image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings. 
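For orientation, the ProjPlusModel instantiated just below maps a 512-d face ID embedding, conditioned on the 257 penultimate OpenCLIP ViT-H tokens, to 4 prompt tokens of width 768. A standalone shape check; the import path is assumed from this repo's layout, and the numbers assume SD 1.5 (cross_attention_dim 768) and OpenCLIP ViT-H (hidden size 1280), matching the call below:

import torch
from ConsistentID.lib.functions import ProjPlusModel

proj = ProjPlusModel(cross_attention_dim=768, id_embeddings_dim=512,
                     clip_embeddings_dim=1280, num_tokens=4)
faceid_embeds = torch.randn(1, 512)        # ArcFace-style ID embedding
clip_embeds = torch.randn(1, 257, 1280)    # penultimate OpenCLIP ViT-H hidden states
out = proj(faceid_embeds, clip_embeds)
print(out.shape)                           # torch.Size([1, 4, 768])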
+ self.image_proj_model = ProjPlusModel( + cross_attention_dim=self.unet.config.cross_attention_dim, + id_embeddings_dim=512, + clip_embeddings_dim=self.clip_encoder.config.hidden_size, + num_tokens=self.num_tokens, # 4 - inspirsed by IPAdapter and Midjourney + ) + self.FacialEncoder = FacialEncoder() + + if consistentID_weight_path.endswith(".safetensors"): + state_dict = {"id_encoder": {}, "lora_weights": {}} + with safe_open(consistentID_weight_path, framework="pt", device="cpu") as f: + ### TODO safetensors add + for key in f.keys(): + if key.startswith("FacialEncoder."): + state_dict["FacialEncoder"][key.replace("FacialEncoder.", "")] = f.get_tensor(key) + elif key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(consistentID_weight_path, map_location="cpu") + + self.trigger_word_facial = trigger_word_facial + + self.FacialEncoder.load_state_dict(state_dict["FacialEncoder"], strict=True) + self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True) + self.ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values()) + self.ip_layers.load_state_dict(state_dict["adapter_modules"], strict=True) + print(f"Successfully loaded weights from checkpoint") + + # Add trigger word token + if self.tokenizer is not None: + self.tokenizer.add_tokens([self.trigger_word_facial], special_tokens=True) + + def set_ip_adapter(self): + unet = self.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = Consistent_AttProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + ) + else: + attn_procs[name] = Consistent_IPAttProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens, + ) + + unet.set_attn_processor(attn_procs) + + @torch.inference_mode() + # parsed_image_parts2 is a batched tensor of parsed_image_parts with bs=1. It only contains the facial areas of one input image. + # clip_encoder maps image parts to image-space diffusion prompts. + # Then the facial class token embeddings are replaced with the fused (multi_facial_embeds, prompt_embeds[class_tokens_mask]). 
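In set_ip_adapter above, each attention processor's hidden_size is read off the UNet config from the layer name. For the stock SD 1.5 UNet (block_out_channels = [320, 640, 1280, 1280]) the lookup works out as in this small sketch; the example layer names are illustrative:

# Illustrative reimplementation of the hidden_size lookup in set_ip_adapter,
# assuming SD 1.5's block_out_channels.
block_out_channels = [320, 640, 1280, 1280]

def hidden_size_for(name: str) -> int:
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        return list(reversed(block_out_channels))[int(name[len("up_blocks.")])]
    if name.startswith("down_blocks"):
        return block_out_channels[int(name[len("down_blocks.")])]
    raise ValueError(name)

print(hidden_size_for("down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor"))  # 320
print(hidden_size_for("mid_block.attentions.0.transformer_blocks.0.attn1.processor"))      # 1280
print(hidden_size_for("up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor"))    # 320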
+ def extract_local_facial_embeds(self, prompt_embeds, uncond_prompt_embeds, parsed_image_parts2, + facial_token_masks, valid_facial_token_idx_mask, calc_uncond=True): + + hidden_states = [] + uncond_hidden_states = [] + for parsed_image_parts in parsed_image_parts2: + hidden_state = self.clip_encoder(parsed_image_parts.to(self.device, dtype=self.torch_dtype), output_hidden_states=True).hidden_states[-2] + uncond_hidden_state = self.clip_encoder(torch.zeros_like(parsed_image_parts, dtype=self.torch_dtype).to(self.device), output_hidden_states=True).hidden_states[-2] + hidden_states.append(hidden_state) + uncond_hidden_states.append(uncond_hidden_state) + multi_facial_embeds = torch.stack(hidden_states) + uncond_multi_facial_embeds = torch.stack(uncond_hidden_states) + + # conditional prompt. + # FacialEncoder maps multi_facial_embeds to facial ID embeddings, and replaces the class tokens in prompt_embeds + # with the fused (facial ID embeddings, prompt_embeds[class_tokens_mask]). + # multi_facial_embeds: [1, 5, 257, 1280]. + facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask) + + if not calc_uncond: + return facial_prompt_embeds, None + # unconditional prompt. + uncond_facial_prompt_embeds = self.FacialEncoder(uncond_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask) + + return facial_prompt_embeds, uncond_facial_prompt_embeds + + @torch.inference_mode() + # Extrat OpenCLIP embeddings from the input image and map them to face prompt embeddings. + def extract_global_id_embeds(self, face_image_obj, s_scale=1.0, shortcut=False): + clip_image_ts = self.clip_preprocessor(images=face_image_obj, return_tensors="pt").pixel_values + clip_image_ts = clip_image_ts.to(self.device, dtype=self.torch_dtype) + clip_image_embeds = self.clip_encoder(clip_image_ts, output_hidden_states=True).hidden_states[-2] + uncond_clip_image_embeds = self.clip_encoder(torch.zeros_like(clip_image_ts), output_hidden_states=True).hidden_states[-2] + + faceid_embeds = self.extract_faceid(face_image_obj) + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + # image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings. + # clip_image_embeds are used as queries to transform faceid_embeds. 
+ # faceid_embeds -> kv, clip_image_embeds -> q + global_id_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale) + uncond_global_id_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale) + + return global_id_embeds, uncond_global_id_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, Consistent_IPAttProcessor): + attn_processor.scale = scale + + @torch.inference_mode() + def extract_faceid(self, face_image_obj): + faceid_image = np.array(face_image_obj) + faces = self.face_app.get(faceid_image) + if faces==[]: + faceid_embeds = torch.zeros_like(torch.empty((1, 512))) + else: + faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) + + return faceid_embeds + + @torch.inference_mode() + def parse_face_mask(self, raw_image_refer): + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + to_pil = transforms.ToPILImage() + + with torch.no_grad(): + image = raw_image_refer.resize((512, 512), Image.BILINEAR) + image_resize_PIL = image + img = to_tensor(image) + img = torch.unsqueeze(img, 0) + img = img.to(self.device, dtype=self.torch_dtype) + out = self.bise_net(img)[0] + parsing_anno = out.squeeze(0).cpu().numpy().argmax(0) + + im = np.array(image_resize_PIL) + vis_im = im.copy().astype(np.uint8) + stride=1 + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 + + num_of_class = np.max(vis_parsing_anno) + + for pi in range(1, num_of_class + 1): # num_of_class=17 pi=1~16 + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi] + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0) + + return vis_parsing_anno_color, vis_parsing_anno + + @torch.inference_mode() + def extract_facemask(self, input_image_obj): + vis_parsing_anno_color, vis_parsing_anno = self.parse_face_mask(input_image_obj) + parsing_mask_list = masks_for_unique_values(vis_parsing_anno) + + key_parsing_mask_dict = {} + key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"] + processed_keys = set() + for key, mask_image in parsing_mask_list.items(): + if key in key_list: + if "_" in key: + prefix = key.split("_")[1] + if prefix in processed_keys: + continue + else: + key_parsing_mask_dict[key] = mask_image + processed_keys.add(prefix) + + key_parsing_mask_dict[key] = mask_image + + return key_parsing_mask_dict, vis_parsing_anno_color + + def augment_prompt_with_trigger_word( + self, + prompt: str, + face_caption: str, + key_parsing_mask_dict = None, + facial_token = "<|facial|>", + max_num_facials = 5, + num_id_images: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + # face_caption_align: 'The person has one nose <|facial|>, two ears <|facial|>, two eyes <|facial|>, and a mouth <|facial|>, ' + face_caption_align, key_parsing_mask_dict_align = insert_markers_to_prompt(face_caption, key_parsing_mask_dict) + + prompt_face = prompt + " Detail: 
" + face_caption_align + + max_text_length=330 + if len(self.tokenizer(prompt_face, max_length=self.tokenizer.model_max_length, + padding="max_length", truncation=False, return_tensors="pt").input_ids[0]) != 77: + # Put face_caption_align at the beginning of the prompt, so that the original prompt is truncated, + # but the face_caption_align is well kept. + prompt_face = "Detail: " + face_caption_align + " Caption:" + prompt + + # Remove "<|facial|>" from prompt_face. + # augmented_prompt: 'A person, police officer, half body shot Detail: + # The person has one nose , two ears , two eyes , and a mouth , ' + augmented_prompt = prompt_face.replace("<|facial|>", "") + tokenizer = self.tokenizer + facial_token_id = tokenizer.convert_tokens_to_ids(facial_token) + image_token_id = None + + # image_token_id: the token id of "<|image|>". Disabled, as it's set to None. + # facial_token_id: the token id of "<|facial|>". + clean_input_id, image_token_mask, facial_token_mask = \ + tokenize_and_mask_noun_phrases_ends(prompt_face, image_token_id, facial_token_id, tokenizer) + + image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = \ + prepare_image_token_idx(image_token_mask, facial_token_mask, num_id_images, max_num_facials) + + return augmented_prompt, clean_input_id, key_parsing_mask_dict_align, facial_token_mask, facial_token_idx, facial_token_idx_mask + + @torch.inference_mode() + def extract_parsed_image_parts(self, input_image_obj, key_parsing_mask_dict, image_size=512, max_num_facials=5): + facial_masks = [] + parsed_image_parts = [] + key_masked_raw_images_dict = {} + transform_mask = transforms.Compose([transforms.CenterCrop(size=image_size), transforms.ToTensor(),]) + clip_preprocessor = CLIPImageProcessor() + + num_facial_part = len(key_parsing_mask_dict) + + for key in key_parsing_mask_dict: + key_mask=key_parsing_mask_dict[key] + facial_masks.append(transform_mask(key_mask)) + key_masked_raw_image = apply_mask_to_raw_image(input_image_obj, key_mask) + key_masked_raw_images_dict[key] = key_masked_raw_image + # clip_preprocessor normalizes key_masked_raw_image, so that (masked) zero pixels become non-zero. + # It also resizes the image to 224x224. + parsed_image_part = clip_preprocessor(images=key_masked_raw_image, return_tensors="pt").pixel_values + parsed_image_parts.append(parsed_image_part) + + padding_ficial_clip_image = torch.zeros_like(torch.zeros([1, 3, 224, 224])) + padding_ficial_mask = torch.zeros_like(torch.zeros([1, image_size, image_size])) + + if num_facial_part < max_num_facials: + parsed_image_parts += [ torch.zeros_like(padding_ficial_clip_image) for _ in range(max_num_facials - num_facial_part) ] + facial_masks += [ torch.zeros_like(padding_ficial_mask) for _ in range(max_num_facials - num_facial_part) ] + + parsed_image_parts = torch.stack(parsed_image_parts, dim=1).squeeze(0) + facial_masks = torch.stack(facial_masks, dim=0).squeeze(dim=1) + + return parsed_image_parts, facial_masks, key_masked_raw_images_dict + + # Release the unet/vae/text_encoder to save memory. + def release_components(self, released_components=["unet", "vae", "text_encoder"]): + if "unet" in released_components: + unet = edict() + # Only keep the config and in_channels attributes that are used in the pipeline. + unet.config = self.unet.config + self.unet = unet + + if "vae" in released_components: + self.vae = None + if "text_encoder" in released_components: + self.text_encoder = None + + # input_subj_image_obj: an Image object. 
+ def extract_double_id_prompt_embeds(self, prompt, negative_prompt, input_subj_image_obj, device, calc_uncond=True): + face_caption = "The person has one nose, two eyes, two ears, and a mouth." + key_parsing_mask_dict, vis_parsing_anno_color = self.extract_facemask(input_subj_image_obj) + + augmented_prompt, clean_input_id, key_parsing_mask_dict_align, \ + facial_token_mask, facial_token_idx, facial_token_idx_mask \ + = self.augment_prompt_with_trigger_word( + prompt = prompt, + face_caption = face_caption, + key_parsing_mask_dict=key_parsing_mask_dict, + device=device, + max_num_facials = 5, + num_id_images = 1 + ) + + text_embeds, uncond_text_embeds = self.encode_prompt( + augmented_prompt, + device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=calc_uncond, + negative_prompt=negative_prompt, + ) + + # 5. Prepare the input ID images + # global_id_embeds: [1, 4, 768] + # extract_global_id_embeds() extrats OpenCLIP embeddings from the input image and map them to global face prompt embeddings. + global_id_embeds, uncond_global_id_embeds = \ + self.extract_global_id_embeds(face_image_obj=input_subj_image_obj, s_scale=1.0, shortcut=False) + + # parsed_image_parts: [5, 3, 224, 224]. 5 parts, each part is a 3-channel 224x224 image (resized by CLIP Preprocessor). + parsed_image_parts, facial_masks, key_masked_raw_images_dict = \ + self.extract_parsed_image_parts(input_subj_image_obj, key_parsing_mask_dict_align, image_size=512, max_num_facials=5) + parsed_image_parts2 = parsed_image_parts.unsqueeze(0).to(device, dtype=self.torch_dtype) + facial_token_mask = facial_token_mask.to(device) + facial_token_idx_mask = facial_token_idx_mask.to(device) + + # key_masked_raw_images_dict: ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip'] + # for key in key_masked_raw_images_dict: + # key_masked_raw_images_dict[key].save(f"{key}.png") + + # 6. Get the update text embedding + # parsed_image_parts2: the facial areas of the input image + # extract_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds + # with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask. + # parsed_image_parts2: [1, 5, 3, 224, 224] + text_local_id_embeds, uncond_text_local_id_embeds = \ + self.extract_local_facial_embeds(text_embeds, uncond_text_embeds, \ + parsed_image_parts2, facial_token_mask, facial_token_idx_mask, + calc_uncond=calc_uncond) + + # text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768] + # text_local_id_embeds: [1, 77, 768], only differs with text_embeds on 4 ID embeddings, and is identical + # to text_embeds on the rest 73 tokens. + text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1) + text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1) + + if calc_uncond: + uncond_text_global_id_embeds = torch.cat([uncond_text_local_id_embeds, uncond_global_id_embeds], dim=1) + coarse_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_global_id_embeds], dim=0) + fine_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_local_global_id_embeds], dim=0) + else: + coarse_prompt_embeds = text_global_id_embeds + fine_prompt_embeds = text_local_global_id_embeds + + # fine_prompt_embeds: the conditional part is + # (text_global_id_embeds + text_local_global_id_embeds) / 2. 
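Since coarse_prompt_embeds and fine_prompt_embeds share the same unconditional half, the averaging just below only blends their conditional halves, which is what the comment above describes. A tiny numeric illustration (81 = 77 text tokens + 4 global ID tokens):

import torch

uncond = torch.zeros(1, 81, 768)
cond_global = torch.full((1, 81, 768), 2.0)        # stands in for text_global_id_embeds
cond_local_global = torch.full((1, 81, 768), 4.0)  # stands in for text_local_global_id_embeds
coarse = torch.cat([uncond, cond_global], dim=0)
fine = torch.cat([uncond, cond_local_global], dim=0)
merged = (coarse + fine) / 2
assert torch.all(merged[0] == 0.0)   # unconditional half unchanged
assert torch.all(merged[1] == 3.0)   # conditional half averaged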
+ fine_prompt_embeds = (coarse_prompt_embeds + fine_prompt_embeds) / 2 + + return coarse_prompt_embeds, fine_prompt_embeds + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + original_size: Optional[Tuple[int, int]] = None, + target_size: Optional[Tuple[int, int]] = None, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + input_subj_image_objs: PipelineImageInput = None, + start_merge_step: int = 0, + ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + do_classifier_free_guidance = guidance_scale >= 1.0 + assert do_classifier_free_guidance + + if input_subj_image_objs is not None: + if not isinstance(input_subj_image_objs, list): + input_subj_image_objs = [input_subj_image_objs] + + # 3. Encode input prompt + coarse_prompt_embeds, fine_prompt_embeds = \ + self.extract_double_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device) + else: + # Replace the coarse_prompt_embeds and fine_prompt_embeds with the input prompt_embeds. + # This is used when prompt_embeds are computed in advance. + cfg_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + coarse_prompt_embeds = cfg_prompt_embeds + fine_prompt_embeds = cfg_prompt_embeds + + # 7. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 8. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + self.dtype, + device, + generator, + latents, + ) + + # {'eta': 0.0, 'generator': None}. eta is 0 for DDIM. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + cross_attention_kwargs = {} + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + # DDIM doesn't scale latent_model_input. 
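Calling scale_model_input below keeps the loop scheduler-agnostic: as the comment notes, DDIM leaves the latents untouched, while sigma-based schedulers such as Euler rescale them. A small sketch of that difference, based on the behavior observed in recent diffusers releases:

import torch
from diffusers import DDIMScheduler, EulerDiscreteScheduler

sample = torch.randn(1, 4, 64, 64)

ddim = DDIMScheduler()
ddim.set_timesteps(50)
assert torch.equal(ddim.scale_model_input(sample, ddim.timesteps[0]), sample)  # identity for DDIM

euler = EulerDiscreteScheduler()
euler.set_timesteps(50)
scaled = euler.scale_model_input(sample, euler.timesteps[0])
# Euler divides by sqrt(sigma^2 + 1), so `scaled` differs from `sample`.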
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if i <= start_merge_step: + current_prompt_embeds = coarse_prompt_embeds + else: + current_prompt_embeds = fine_prompt_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=current_prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + else: + assert 0, "Not Implemented" + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or \ + ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + elif output_type == "pil": + # 9.1 Post-processing + image = self.decode_latents(latents) + # 9.3 Convert to PIL + image = self.numpy_to_pil(image) + else: + # 9.1 Post-processing + image = self.decode_latents(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=None + ) + + + + + + + + diff --git a/ConsistentID/models/BiSeNet_pretrained_for_ConsistentID.pth b/ConsistentID/models/BiSeNet_pretrained_for_ConsistentID.pth new file mode 100644 index 0000000000000000000000000000000000000000..ca57f3257ca7715bc340d065764bc249d985c287 --- /dev/null +++ b/ConsistentID/models/BiSeNet_pretrained_for_ConsistentID.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567 +size 53289463 diff --git a/ConsistentID/models/ConsistentID-v1.bin b/ConsistentID/models/ConsistentID-v1.bin new file mode 100644 index 0000000000000000000000000000000000000000..cb4022f6ff8830c29609aa344c162ff749c27063 --- /dev/null +++ b/ConsistentID/models/ConsistentID-v1.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48cd9faab558c09565dfb4a355976ea44501fb496c11e3ced722286a8453765b +size 669123998 diff --git a/ConsistentID/requirements.txt b/ConsistentID/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5ba980a578d3d2963080eac9b431efd05f6ff674 --- /dev/null +++ b/ConsistentID/requirements.txt @@ -0,0 +1,15 @@ +accelerate +safetensors +einops +onnxruntime-gpu +omegaconf +peft +opencv-python +insightface +diffusers +torch +torchvision +transformers +spaces +huggingface-hub +sentencepiece \ No newline at end of file diff --git a/README.md b/README.md index 1dcd5cd1b215dae333a568306803ef0f0c2cb80d..828052a32471c0d089ea14a790d19c75ef6d6813 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ emoji: 🎨 colorFrom: yellow colorTo: green sdk: gradio -sdk_version: 4.36.1 +sdk_version: 4.40.0 app_file: app.py pinned: true --- \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/adaface/adaface-infer.py b/adaface/adaface-infer.py index 
5f3c344aa901751e9070320b494a52711f33f47e..f68c2334582b0c8a7d8ededb0c8836b8b6b943bc 100644 --- a/adaface/adaface-infer.py +++ b/adaface/adaface-infer.py @@ -40,7 +40,7 @@ def seed_everything(seed): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument("--base_model_path", type=str, default='runwayml/stable-diffusion-v1-5', + parser.add_argument("--base_model_path", type=str, default='models/sd15-dste8-vae.safetensors', help="Type of checkpoints to use (default: SD 1.5)") parser.add_argument("--embman_ckpt", type=str, required=True, help="Path to the checkpoint of the embedding manager") diff --git a/adaface/adaface_infer.py b/adaface/adaface_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..43fde7b8ced86a4058146089697a3ae83960069c --- /dev/null +++ b/adaface/adaface_infer.py @@ -0,0 +1,155 @@ +from adaface.adaface_wrapper import AdaFaceWrapper +import torch +#import torch.nn.functional as F +from PIL import Image +import numpy as np +import os, argparse, glob, re + +def save_images(images, num_images_per_row, subject_name, prompt, perturb_std, save_dir = "samples-ada"): + if num_images_per_row > len(images): + num_images_per_row = len(images) + + os.makedirs(save_dir, exist_ok=True) + + num_columns = int(np.ceil(len(images) / num_images_per_row)) + # Save 4 images as a grid image in save_dir + grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns)) + for i, image in enumerate(images): + image = image.resize((512, 512)) + grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row))) + + prompt_sig = prompt.replace(" ", "_").replace(",", "_") + grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}.png") + if os.path.exists(grid_filepath): + grid_count = 2 + grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}-{grid_count}.png') + while os.path.exists(grid_filepath): + grid_count += 1 + grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-perturb{perturb_std:.02f}-{grid_count}.png') + + grid_image.save(grid_filepath) + print(f"Saved to {grid_filepath}") + +def seed_everything(seed): + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PL_GLOBAL_SEED"] = str(seed) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--pipeline", type=str, default="text2img", + choices=["text2img", "img2img", "text2img3", "flux"], + help="Type of pipeline to use (default: txt2img)") + parser.add_argument("--base_model_path", type=str, default=None, + help="Type of checkpoints to use (default: None, using the official model)") + parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+", + default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt']) + parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["arc2face"], + choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders") + # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face). 
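A hypothetical sketch of the defaulting rule the comment above describes; the actual logic lives inside AdaFaceWrapper, which is not shown here, so the helper and constant names below are illustrative only:

# Hypothetical helper, not part of adaface_wrapper.py.
DEFAULT_ENCODER_CFG_SCALES = {"consistentID": 6.0, "arc2face": 1.0}

def fill_default_cfg_scales(encoder_types, cfg_scales=None):
    if cfg_scales is None:
        return [DEFAULT_ENCODER_CFG_SCALES[t] for t in encoder_types]
    assert len(cfg_scales) == len(encoder_types)
    return list(cfg_scales)

print(fill_default_cfg_scales(["arc2face", "consistentID"]))  # [1.0, 6.0]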
+ parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None, + help="CFG scales of output embeddings of the ID2Ada prompt encoders") + parser.add_argument("--main_unet_filepath", type=str, default=None, + help="Path to the checkpoint of the main UNet model, if you want to replace the default UNet within --base_model_path") + parser.add_argument("--extra_unet_dirpaths", type=str, nargs="*", + default=['models/ensemble/rv4-unet', 'models/ensemble/ar18-unet'], + help="Extra paths to the checkpoints of the UNet models") + parser.add_argument('--unet_weights', type=float, nargs="+", default=[4, 2, 1], + help="Weights for the UNet models") + parser.add_argument("--subject", type=str) + parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use") + parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate") + parser.add_argument("--prompt", type=str, default="a woman z in superman costume") + parser.add_argument("--noise", dest='perturb_std', type=float, default=0) + parser.add_argument("--randface", action="store_true") + parser.add_argument("--scale", dest='guidance_scale', type=float, default=4, + help="Guidance scale for the diffusion model") + parser.add_argument("--id_cfg_scale", type=float, default=6, + help="CFG scale when generating the identity embeddings") + + parser.add_argument("--subject_string", + type=str, default="z", + help="Subject placeholder string used in prompts to denote the concept.") + parser.add_argument("--num_images_per_row", type=int, default=4, + help="Number of images to display in a row in the output grid image.") + parser.add_argument("--num_inference_steps", type=int, default=50, + help="Number of DDIM inference steps") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on") + parser.add_argument("--seed", type=int, default=42, + help="the seed (for reproducible sampling). Set to -1 to disable.") + args = parser.parse_args() + + return args + +if __name__ == "__main__": + args = parse_args() + if args.seed != -1: + seed_everything(args.seed) + + if re.match(r"^\d+$", args.device): + args.device = f"cuda:{args.device}" + print(f"Using device {args.device}") + + if args.pipeline not in ["text2img", "img2img"]: + args.extra_unet_dirpaths = None + args.unet_weights = None + + adaface = AdaFaceWrapper(args.pipeline, args.base_model_path, + args.adaface_encoder_types, args.adaface_ckpt_paths, + args.adaface_encoder_cfg_scales, + args.subject_string, args.num_inference_steps, + unet_types=None, + main_unet_filepath=args.main_unet_filepath, + extra_unet_dirpaths=args.extra_unet_dirpaths, + unet_weights=args.unet_weights, device=args.device) + + if not args.randface: + image_folder = args.subject + if image_folder.endswith("/"): + image_folder = image_folder[:-1] + + if os.path.isfile(image_folder): + # Get the second to the last part of the path + subject_name = os.path.basename(os.path.dirname(image_folder)) + image_paths = [image_folder] + + else: + subject_name = os.path.basename(image_folder) + image_types = ["*.jpg", "*.png", "*.jpeg"] + alltype_image_paths = [] + for image_type in image_types: + # glob returns the full path. 
+ image_paths = glob.glob(os.path.join(image_folder, image_type)) + if len(image_paths) > 0: + alltype_image_paths.extend(image_paths) + + # Filter out images of "*_mask.png" + alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path] + + # image_paths contain at most args.example_image_count full image paths. + if args.example_image_count > 0: + image_paths = alltype_image_paths[:args.example_image_count] + else: + image_paths = alltype_image_paths + else: + subject_name = None + image_paths = None + image_folder = None + + subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name + rand_init_id_embs = torch.randn(1, 512) + + init_id_embs = rand_init_id_embs if args.randface else None + noise = torch.randn(args.out_image_count, 4, 64, 64).cuda() + # args.perturb_std: the *relative* std of the noise added to the face embeddings. + # A noise level of 0.08 could change gender, but 0.06 is usually safe. + # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call). + adaface_subj_embs = \ + adaface.prepare_adaface_embeddings(image_paths, init_id_embs, + perturb_at_stage='img_prompt_emb', + perturb_std=args.perturb_std, update_text_encoder=True) + images = adaface(noise, args.prompt, None, args.guidance_scale, args.out_image_count, verbose=True) + save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.perturb_std) diff --git a/adaface/adaface_translate.py b/adaface/adaface_translate.py new file mode 100644 index 0000000000000000000000000000000000000000..006b0c2fabe203ea6d1385b537f2b48c99db18f3 --- /dev/null +++ b/adaface/adaface_translate.py @@ -0,0 +1,223 @@ +from adaface.adaface_wrapper import AdaFaceWrapper +import torch +#import torch.nn.functional as F +from PIL import Image +import numpy as np +import os, argparse, glob, re, shutil + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + +def seed_everything(seed): + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PL_GLOBAL_SEED"] = str(seed) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors', + help="Path to the UNet checkpoint (default: RealisticVision 4.0)") + parser.add_argument('--adaface_ckpt_paths', type=str, nargs="+", + default=['models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt']) + parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["arc2face"], + choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders") + # If adaface_encoder_cfg_scales is not specified, the weights will be set to 6.0 (consistentID) and 1.0 (arc2face). 
+ parser.add_argument('--adaface_encoder_cfg_scales', type=float, nargs="+", default=None, + help="CFG scales of output embeddings of the ID2Ada prompt encoders") + parser.add_argument('--extra_unet_dirpaths', type=str, nargs="*", + default=['models/ensemble/rv4-unet', 'models/ensemble/ar18-unet'], + help="Extra paths to the checkpoints of the UNet models") + parser.add_argument('--unet_weights', type=float, nargs="+", default=[4, 2, 1], + help="Weights for the UNet models") + parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images") + # If True, the input folder contains images of mixed subjects. + # If False, the input folder contains multiple subfolders, each of which contains images of the same subject. + parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?", + help="Whether the input folder contains images of mixed subjects") + parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject") + parser.add_argument("--trans_subject_count", type=int, default=-1, help="Number of example images to be translated") + parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images") + parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image") + parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder") + parser.add_argument("--noise", dest='perturb_std', type=float, default=0) + parser.add_argument("--scale", dest='guidance_scale', type=float, default=4, + help="Guidance scale for the diffusion model") + parser.add_argument("--ref_img_strength", type=float, default=0.8, + help="Strength of the reference image in the output image.") + parser.add_argument("--subject_string", + type=str, default="z", + help="Subject placeholder string used in prompts to denote the concept.") + parser.add_argument("--prompt", type=str, default="a person z") + parser.add_argument("--num_images_per_row", type=int, default=4, + help="Number of images to display in a row in the output grid image.") + parser.add_argument("--num_inference_steps", type=int, default=50, + help="Number of DDIM inference steps") + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.") + parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on") + parser.add_argument("--seed", type=int, default=42, + help="the seed (for reproducible sampling). 
Set to -1 to disable.") + args = parser.parse_args() + + return args + +if __name__ == "__main__": + args = parse_args() + if args.seed != -1: + seed_everything(args.seed) + +# screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py +# --adaface_ckpt_paths logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt +# --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /path/to/VGGface2_HQ_masks/ +# --is_mix_subj_folder 0 --out_folder /path/to/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2 + if args.num_gpus > 1: + from accelerate import PartialState + distributed_state = PartialState() + args.device = distributed_state.device + process_index = distributed_state.process_index + elif re.match(r"^\d+$", args.device): + args.device = f"cuda:{args.device}" + distributed_state = None + process_index = 0 + + adaface = AdaFaceWrapper("img2img", args.base_model_path, + args.adaface_encoder_types, args.adaface_ckpt_paths, + args.adaface_encoder_cfg_scales, + args.subject_string, args.num_inference_steps, + unet_types=None, + extra_unet_dirpaths=args.extra_unet_dirpaths, unet_weights=args.unet_weights, + device=args.device) + + in_folder = args.in_folder + if os.path.isfile(in_folder): + subject_folders = [ os.path.dirname(in_folder) ] + images_by_subject = [[in_folder]] + else: + if not args.is_mix_subj_folder: + in_folders = [in_folder] + else: + in_folders = [ os.path.join(in_folder, subfolder) for subfolder in sorted(os.listdir(in_folder)) ] + + images_by_subject = [] + subject_folders = [] + for in_folder in in_folders: + image_types = ["*.jpg", "*.png", "*.jpeg"] + alltype_image_paths = [] + for image_type in image_types: + # glob returns the full path. + image_paths = glob.glob(os.path.join(in_folder, image_type)) + if len(image_paths) > 0: + alltype_image_paths.extend(image_paths) + + # Filter out images of "*_mask.png" + alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path] + alltype_image_paths = sorted(alltype_image_paths) + + if not args.is_mix_subj_folder: + # image_paths contain at most args.max_images_per_subject full image paths. + if args.max_images_per_subject > 0: + image_paths = alltype_image_paths[:args.max_images_per_subject] + else: + image_paths = alltype_image_paths + + images_by_subject.append(image_paths) + subject_folders.append(in_folder) + else: + # Each image in the folder is treated as an individual subject. + images_by_subject.extend([[image_path] for image_path in alltype_image_paths]) + subject_folders.extend([in_folder] * len(alltype_image_paths)) + + if args.trans_subject_count > 0 and len(subject_folders) >= args.trans_subject_count: + break + + if args.trans_subject_count > 0: + images_by_subject = images_by_subject[:args.trans_subject_count] + subject_folders = subject_folders[:args.trans_subject_count] + + out_image_count = 0 + out_mask_count = 0 + if not args.out_folder.endswith("/"): + args.out_folder += "/" + + if args.num_gpus > 1: + # Split the subjects across the GPUs. 
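+        # Round-robin sharding: process i keeps every num_gpus-th subject, starting at index i.
+        # E.g. with num_gpus=2, process 0 gets subjects 0, 2, 4, ... and process 1 gets 1, 3, 5, ...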
+ subject_folders = subject_folders[process_index::args.num_gpus] + images_by_subject = images_by_subject[process_index::args.num_gpus] + #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject)) + + for (subject_folder, image_paths) in zip(subject_folders, images_by_subject): + # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image. + # Otherwise, we use the folder name as the signature of the images. + images_sig = subject_folder if not args.is_mix_subj_folder else os.path.basename(image_paths[0]) + + print(f"Translating {images_sig}...") + with torch.no_grad(): + adaface_subj_embs = \ + adaface.prepare_adaface_embeddings(image_paths, None, + perturb_at_stage='img_prompt_emb', + perturb_std=args.perturb_std, + update_text_encoder=True) + + # Replace the first occurrence of "in_folder" with "out_folder" in the path of the subject_folder. + subject_out_folder = subject_folder.replace(args.in_folder, args.out_folder, 1) + if not os.path.exists(subject_out_folder): + os.makedirs(subject_out_folder) + print(f"Output images will be saved to {subject_out_folder}") + + in_images = [] + for image_path in image_paths: + image = Image.open(image_path).convert("RGB").resize((512, 512)) + # [512, 512, 3] -> [3, 512, 512]. + image = np.array(image).transpose(2, 0, 1) + # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU. + image = torch.tensor(image).unsqueeze(0).float().cuda() + in_images.append(image) + + # Put all input images of the subject into a batch. This assumes max_images_per_subject is small. + # NOTE: For simplicity, we do not check overly large batch sizes. + in_images = torch.cat(in_images, dim=0) + # in_images: [5, 3, 512, 512]. + # Normalize the pixel values to [0, 1]. + in_images = in_images / 255.0 + num_out_images = len(in_images) * args.out_count_per_input_image + + with torch.no_grad(): + # args.perturb_std: the *relative* std of the noise added to the face embeddings. + # A noise level of 0.08 could change gender, but 0.06 is usually safe. + # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly. + # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images. + out_images = adaface(in_images, args.prompt, None, args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength) + + for img_i, img in enumerate(out_images): + # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ... 
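+            # Illustration of the index math below: with 3 input images and out_count_per_input_image=2,
+            # out_images is ordered [s0, s1, s2, s0, s1, s2], so img_i=4 gives subj_i = 4 % 3 = 1
+            # (second input image) and copy_i = 4 // 3 = 1 (its second copy).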
+ subj_i = img_i % len(in_images) + copy_i = img_i // len(in_images) + image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i])) + if copy_i == 0: + img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}")) + else: + img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}")) + + if args.copy_masks: + mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png") + if os.path.exists(mask_path): + if copy_i == 0: + shutil.copy(mask_path, subject_out_folder) + else: + mask_filename_stem = image_filename_stem + shutil.copy(mask_path, os.path.join(subject_out_folder, f"{mask_filename_stem}_{copy_i}_mask.png")) + + out_mask_count += 1 + + out_image_count += len(out_images) + + print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}") diff --git a/adaface/adaface_wrapper.py b/adaface/adaface_wrapper.py index 47f156c691673e07ba0e30a97e369a02729042a7..8f305ce39c17c9a4ff5affa75315b08e6a6bb1c0 100644 --- a/adaface/adaface_wrapper.py +++ b/adaface/adaface_wrapper.py @@ -4,39 +4,55 @@ from transformers import CLIPTextModel from diffusers import ( StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, - UNet2DConditionModel, + StableDiffusion3Pipeline, + #FluxPipeline, DDIMScheduler, AutoencoderKL, ) -from insightface.app import FaceAnalysis -from adaface.arc2face_models import CLIPTextModelWrapper -from adaface.util import get_arc2face_id_prompt_embs +from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint +from adaface.util import UNetEnsemble +from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder +from safetensors.torch import load_file as safetensors_load_file import re, os +import numpy as np import sys +# Monkey patch the missing ldm module in the old arc2face adaface checkpoint. sys.modules['ldm'] = sys.modules['adaface'] +sys.modules['ldm.modules'] = sys.modules['adaface'] class AdaFaceWrapper(nn.Module): - def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device, - subject_string='z', num_vectors=16, - num_inference_steps=50, negative_prompt=None, - use_840k_vae=False, use_ds_text_encoder=False, is_training=False): + def __init__(self, pipeline_name, base_model_path, adaface_encoder_types, + adaface_ckpt_paths, adaface_encoder_cfg_scales=None, + enabled_encoders=None, + subject_string='z', num_inference_steps=50, negative_prompt=None, + use_840k_vae=False, use_ds_text_encoder=False, + main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights=None, + device='cuda', is_training=False): ''' - pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are + pipeline_name: "text2img", "img2img", "text2img3", "flux", or None. + If None, it's used only as a face encoder, and the unet and vae are removed from the pipeline to release RAM. 
''' super().__init__() self.pipeline_name = pipeline_name self.base_model_path = base_model_path - self.adaface_ckpt_path = adaface_ckpt_path - self.use_840k_vae = use_840k_vae - self.use_ds_text_encoder = use_ds_text_encoder + self.adaface_encoder_types = adaface_encoder_types + + self.adaface_ckpt_paths = adaface_ckpt_paths + self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales + self.enabled_encoders = enabled_encoders self.subject_string = subject_string - self.num_vectors = num_vectors + self.num_inference_steps = num_inference_steps + self.use_840k_vae = use_840k_vae + self.use_ds_text_encoder = use_ds_text_encoder + self.main_unet_filepath = main_unet_filepath + self.unet_types = unet_types + self.extra_unet_dirpaths = extra_unet_dirpaths + self.unet_weights = unet_weights self.device = device self.is_training = is_training - self.initialize_pipeline() - self.extend_tokenizer_and_text_encoder() + if negative_prompt is None: self.negative_prompt = \ "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \ @@ -46,45 +62,34 @@ class AdaFaceWrapper(nn.Module): else: self.negative_prompt = negative_prompt - def load_subj_basis_generator(self, adaface_ckpt_path): - ckpt = torch.load(adaface_ckpt_path, map_location='cpu') - string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"] - if self.subject_string not in string_to_subj_basis_generator_dict: - print(f"Subject '{self.subject_string}' not found in the embedding manager.") - breakpoint() - - self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string] - # In the original ckpt, num_out_layers is 16 for layerwise embeddings. - # But we don't do layerwise embeddings here, so we set it to 1. - self.subj_basis_generator.num_out_layers = 1 - print(f"Loaded subject basis generator for '{self.subject_string}'.") - print(repr(self.subj_basis_generator)) - self.subj_basis_generator.to(self.device) - if self.is_training: - self.subj_basis_generator.train() - else: - self.subj_basis_generator.eval() + self.initialize_pipeline() + # During inference, we never use static image suffix embeddings. + # So num_id_vecs is the length of the returned adaface embeddings for each encoder. + self.encoders_num_id_vecs = self.id2ada_prompt_encoder.encoders_num_id_vecs + self.extend_tokenizer_and_text_encoder() def initialize_pipeline(self): - self.load_subj_basis_generator(self.adaface_ckpt_path) - # arc2face_text_encoder maps the face analysis embedding to 16 face embeddings - # in the UNet image space. - arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained( - 'models/arc2face', subfolder="encoder", torch_dtype=torch.float16 - ) - self.arc2face_text_encoder = arc2face_text_encoder.to(self.device) + self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types, + self.adaface_ckpt_paths, + self.adaface_encoder_cfg_scales, + self.enabled_encoders) + + self.id2ada_prompt_encoder.to(self.device) + print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}") if self.use_840k_vae: # The 840000-step vae model is slightly better in face details than the original vae model. 
# https://huggingface.co/stabilityai/sd-vae-ft-mse-original - vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", torch_dtype=torch.float16) + vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", + torch_dtype=torch.float16) else: vae = None if self.use_ds_text_encoder: # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder. # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder - text_encoder = CLIPTextModel.from_pretrained("models/diffusers/ds_text_encoder", torch_dtype=torch.float16) + text_encoder = CLIPTextModel.from_pretrained("models/diffusers/ds_text_encoder", + torch_dtype=torch.float16) else: text_encoder = None @@ -94,6 +99,10 @@ class AdaFaceWrapper(nn.Module): PipelineClass = StableDiffusionImg2ImgPipeline elif self.pipeline_name == "text2img": PipelineClass = StableDiffusionPipeline + elif self.pipeline_name == "text2img3": + PipelineClass = StableDiffusion3Pipeline + elif self.pipeline_name == "flux": + PipelineClass = FluxPipeline # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images. elif self.pipeline_name is None: PipelineClass = StableDiffusionPipeline @@ -101,6 +110,14 @@ class AdaFaceWrapper(nn.Module): else: raise ValueError(f"Unknown pipeline name: {self.pipeline_name}") + if self.base_model_path is None: + base_model_path_dict = { + 'text2img': 'models/sd15-dste8-vae.safetensors', + 'text2img3': 'stabilityai/stable-diffusion-3-medium-diffusers', + 'flux': 'black-forest-labs/FLUX.1-schnell', + } + self.base_model_path = base_model_path_dict[self.pipeline_name] + if os.path.isfile(self.base_model_path): pipeline = PipelineClass.from_single_file( self.base_model_path, @@ -112,8 +129,23 @@ class AdaFaceWrapper(nn.Module): torch_dtype=torch.float16, safety_checker=None ) - print(f"Loaded pipeline from {self.base_model_path}.") + + if self.main_unet_filepath is not None: + print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.") + ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu')) + if len(ret.missing_keys) > 0: + print(f"Missing keys: {ret.missing_keys}") + if len(ret.unexpected_keys) > 0: + print(f"Unexpected keys: {ret.unexpected_keys}") + + if (self.unet_types is not None and len(self.unet_types) > 0) \ + or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0): + unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.unet_weights, + device=self.device, torch_dtype=torch.float16) + pipeline.unet = unet_ensemble + print(f"Loaded pipeline from {self.base_model_path}.") + if self.use_840k_vae: pipeline.vae = vae print("Replaced the VAE with the 840k-step VAE.") @@ -128,133 +160,139 @@ class AdaFaceWrapper(nn.Module): pipeline.vae = None print("Removed UNet and VAE from the pipeline.") - noise_scheduler = DDIMScheduler( - num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - steps_offset=1, - ) - - pipeline.scheduler = noise_scheduler + if self.pipeline_name not in ["text2img3", "flux"]: + noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + pipeline.scheduler = noise_scheduler 
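+            # (The betas above are the standard SD 1.x "scaled_linear" training schedule; a different
+            # sampler, e.g. EulerDiscreteScheduler.from_config(pipeline.scheduler.config), could be
+            # swapped in here instead. DDIM is simply this wrapper's default choice.)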
+ # Otherwise, pipeline.scheduler == FlowMatchEulerDiscreteScheduler self.pipeline = pipeline.to(self.device) - # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2. - # Note there's a second "model" in the path. - self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) - self.face_app.prepare(ctx_id=0, det_size=(512, 512)) - # Patch the missing tokenizer in the subj_basis_generator. - if not hasattr(self.subj_basis_generator, 'clip_tokenizer'): - self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer - print("Patched the missing tokenizer in the subj_basis_generator.") + def load_unet_from_file(self, unet_path, device=None): + if os.path.isfile(unet_path): + if unet_path.endswith(".safetensors"): + unet_state_dict = safetensors_load_file(unet_path, device=device) + else: + unet_state_dict = torch.load(unet_path, map_location=device) + + key0 = list(unet_state_dict.keys())[0] + if key0.startswith("model.diffusion_model"): + key_prefix = "" + is_ldm_unet = True + elif key0.startswith("diffusion_model"): + key_prefix = "model." + is_ldm_unet = True + else: + is_ldm_unet = False + + if is_ldm_unet: + unet_state_dict2 = {} + for key, value in unet_state_dict.items(): + key2 = key_prefix + key + unet_state_dict2[key2] = value + print(f"LDM UNet detected. Convert to diffusers") + ldm_unet_config = { 'layers_per_block': 2 } + unet_state_dict = convert_ldm_unet_checkpoint(unet_state_dict2, ldm_unet_config) + else: + raise ValueError(f"UNet path {unet_path} is not a file.") + return unet_state_dict + def extend_tokenizer_and_text_encoder(self): - if self.num_vectors < 1: - raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}") + if np.sum(self.encoders_num_id_vecs) < 1: + raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}") tokenizer = self.pipeline.tokenizer - # Add z0, z1, z2, ..., z15. - self.placeholder_tokens = [] - for i in range(0, self.num_vectors): - self.placeholder_tokens.append(f"{self.subject_string}_{i}") + # If adaface_encoder_types is ["arc2face", "consistentID"], then total_num_id_vecs = 20. + # We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer. + self.all_placeholder_tokens = [] + self.placeholder_tokens_strs = [] + for i in range(len(self.adaface_encoder_types)): + placeholder_tokens = [] + for j in range(self.encoders_num_id_vecs[i]): + placeholder_tokens.append(f"{self.subject_string}_{i}_{j}") + placeholder_tokens_str = " ".join(placeholder_tokens) - self.placeholder_tokens_str = " ".join(self.placeholder_tokens) + self.all_placeholder_tokens.extend(placeholder_tokens) + self.placeholder_tokens_strs.append(placeholder_tokens_str) + + self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs) # Add the new tokens to the tokenizer. - num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens) - if num_added_tokens != self.num_vectors: + num_added_tokens = tokenizer.add_tokens(self.all_placeholder_tokens) + if num_added_tokens != np.sum(self.encoders_num_id_vecs): raise ValueError( - f"The tokenizer already contains the token {self.subject_string}. Please pass a different" + f"The tokenizer already contains some of the tokens {self.all_placeholder_tokens_str}. 
Please pass a different" " `subject_string` that is not already in the tokenizer.") - print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.") + print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.") # placeholder_token_ids: [49408, ..., 49423]. - self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens) - # print(self.placeholder_token_ids) + self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens) + #print("New tokens:", self.placeholder_token_ids) # Resize the token embeddings as we are adding new special tokens to the tokenizer - old_weight = self.pipeline.text_encoder.get_input_embeddings().weight + old_weight_shape = self.pipeline.text_encoder.get_input_embeddings().weight.shape self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer)) new_weight = self.pipeline.text_encoder.get_input_embeddings().weight - print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.") + print(f"Resized text encoder token embeddings from {old_weight_shape} to {new_weight.shape} on {new_weight.device}.") # Extend pipeline.text_encoder with the adaface subject emeddings. # subj_embs: [16, 768]. - def update_text_encoder_subj_embs(self, subj_embs): + def update_text_encoder_subj_embeddings(self, subj_embs): # Initialise the newly added placeholder token with the embeddings of the initializer token + # token_embeds: [49412, 768] token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data with torch.no_grad(): for i, token_id in enumerate(self.placeholder_token_ids): token_embeds[token_id] = subj_embs[i] - print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.") + print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.all_placeholder_tokens_str}) in the text encoder.") def update_prompt(self, prompt): if prompt is None: prompt = "" - - # If the placeholder tokens are already in the prompt, then return the prompt as is. - if self.placeholder_tokens_str in prompt: - return prompt - - # If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt. - if re.search(r'\b' + self.subject_string + r'\b', prompt) is None: - print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.") - comp_prompt = self.placeholder_tokens_str + " " + prompt - else: - # Replace the subject string 'z' with the placeholder tokens. - comp_prompt = re.sub(r'\b' + self.subject_string + r'\b', self.placeholder_tokens_str, prompt) - return comp_prompt - - # image_paths: a list of image paths. image_folder: the parent folder name. - def generate_adaface_embeddings(self, image_paths, image_folder=None, - pre_face_embs=None, gen_rand_face=False, - out_id_embs_scale=1., noise_level=0, update_text_encoder=True): - # faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512). - # If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated by id_batch_size times. - # Otherwise, faceid_embeds is a batch of random embeddings, each instance is different. - # The same applies to id_prompt_emb. - # faceid_embeds is in the face analysis embeddings. id_prompt_emb is in the image prompt space. - # Here id_batch_size = 1, so - # faceid_embeds: [1, 512]. NOT used later. - # id_prompt_emb: [1, 16, 768]. 
- # NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings. - # arc2face prompt template: "photo of a id person" - # ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings. - face_image_count, faceid_embeds, id_prompt_emb \ - = get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder, - extract_faceid_embeds=not gen_rand_face, - pre_face_embs=pre_face_embs, - # image_folder is passed only for logging purpose. - # image_paths contains the paths of the images. - image_folder=image_folder, image_paths=image_paths, - images_np=None, - id_batch_size=1, - device=self.device, - # input_max_length == 22: only keep the first 22 tokens, - # including 3 template tokens and 16 ID tokens, and BOS and EOS tokens. - # The results are indistinguishable from input_max_length=77. - input_max_length=22, - noise_level=noise_level, - return_core_id_embs=True, - gen_neg_prompt=False, - verbose=True) + + # Delete the subject_string from the prompt. + re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt) + re.sub(r'\b' + self.subject_string + r'\b,?', "", prompt) + # Prevously, arc2face ada prompts work better if they are prepended to the prompt, + # and consistentID ada prompts work better if they are appended to the prompt. + # When we do joint training, seems both work better if they are appended to the prompt. + # Therefore we simply appended all placeholder_tokens_str's to the prompt. + # NOTE: Prepending them hurts compositional prompts. + prompt = prompt + " " + self.all_placeholder_tokens_str + + return prompt + + # avg_at_stage: 'id_emb', 'img_prompt_emb', or None. + # avg_at_stage == ada_prompt_emb usually produces the worst results. + # id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better. + def prepare_adaface_embeddings(self, image_paths, face_id_embs=None, + avg_at_stage='id_emb', # id_emb, img_prompt_emb, ada_prompt_emb, or None. + perturb_at_stage=None, # id_emb, img_prompt_emb, or None. + perturb_std=0, update_text_encoder=True): + + all_adaface_subj_embs = \ + self.id2ada_prompt_encoder.generate_adaface_embeddings(\ + image_paths, face_id_embs=face_id_embs, + img_prompt_embs=None, + avg_at_stage=avg_at_stage, + perturb_at_stage=perturb_at_stage, + perturb_std=perturb_std, + enable_static_img_suffix_embs=False) - if face_image_count == 0: + if all_adaface_subj_embs is None: return None - - # adaface_subj_embs: [1, 1, 16, 768]. - # adaface_prompt_embs: [1, 77, 768] (not used). 
- adaface_subj_embs, adaface_prompt_embs = \ - self.subj_basis_generator(id_prompt_emb, None, None, - out_id_embs_scale=out_id_embs_scale, - is_face=True, is_training=False, - adaface_prompt_embs_inf_type='full_half_pad') - # adaface_subj_embs: [16, 768] - adaface_subj_embs = adaface_subj_embs.squeeze() + + # [1, 1, 16, 768] -> [16, 768] + all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0) + if update_text_encoder: - self.update_text_encoder_subj_embs(adaface_subj_embs) - return adaface_subj_embs + self.update_text_encoder_subj_embeddings(all_adaface_subj_embs) + return all_adaface_subj_embs def encode_prompt(self, prompt, negative_prompt=None, device=None, verbose=False): if negative_prompt is None: @@ -262,14 +300,17 @@ class AdaFaceWrapper(nn.Module): if device is None: device = self.device - + prompt = self.update_prompt(prompt) if verbose: - print(f"Prompt: {prompt}") + print(f"Subject prompt: {prompt}") # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device). # So we manually move it to GPU here. self.pipeline.text_encoder.to(device) + # pooled_prompt_embeds_, negative_pooled_prompt_embeds_ are used by text2img3 and flux. + pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = None, None + # Compatible with older versions of diffusers. if not hasattr(self.pipeline, "encode_prompt"): # prompt_embeds_, negative_prompt_embeds_: [77, 768] -> [1, 77, 768]. @@ -279,37 +320,88 @@ class AdaFaceWrapper(nn.Module): prompt_embeds_ = prompt_embeds_.unsqueeze(0) negative_prompt_embeds_ = negative_prompt_embeds_.unsqueeze(0) else: - # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768] - prompt_embeds_, negative_prompt_embeds_ = \ - self.pipeline.encode_prompt(prompt, device=device, - num_images_per_prompt=1, - do_classifier_free_guidance=True, - negative_prompt=negative_prompt) - - return prompt_embeds_, negative_prompt_embeds_ + if self.pipeline_name in ["text2img3", "flux"]: + # prompt_embeds_, negative_prompt_embeds_: [1, 333, 4096] + # pooled_prompt_embeds_, negative_pooled_prompt_embeds_: [1, 2048] + # CLIP Text Encoder prompt uses a maximum sequence length of 77. + # T5 Text Encoder prompt uses a maximum sequence length of 256. + # 333 = 256 + 77. + prompt_t5 = prompt + "".join([", "] * 256) + if self.pipeline_name == "text2img3": + prompt_embeds_, negative_prompt_embeds_, \ + pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \ + self.pipeline.encode_prompt(prompt, prompt, prompt_t5, device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt) + elif self.pipeline_name == "flux": + # prompt_embeds_: [1, 512, 4096] + # pooled_prompt_embeds_: [1, 768] + prompt_embeds_, pooled_prompt_embeds_, text_ids = \ + self.pipeline.encode_prompt(prompt, prompt_t5, device=device, + num_images_per_prompt=1) + negative_prompt_embeds_ = negative_pooled_prompt_embeds_ = None + else: + breakpoint() + else: + # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768] + prompt_embeds_, negative_prompt_embeds_ = \ + self.pipeline.encode_prompt(prompt, device=device, + num_images_per_prompt=1, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt) + + return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_ # ref_img_strength is used only in the img2img pipeline. 
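+    # Illustration (assuming adaface_encoder_types == ["arc2face", "consistentID"], i.e. 16 + 4 ID vectors):
+    # update_prompt() is meant to strip the subject string from e.g. "portrait of a z" (note that, as written,
+    # its re.sub() results are not assigned back to `prompt`) and then append the placeholder tokens,
+    # so encode_prompt() effectively sees "portrait of ... z_0_0 z_0_1 ... z_0_15 z_1_0 z_1_1 z_1_2 z_1_3".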
- def forward(self, noise, prompt, negative_prompt=None, guidance_scale=4.0, + def forward(self, noise, prompt, negative_prompt=None, guidance_scale=6.0, out_image_count=4, ref_img_strength=0.8, generator=None, verbose=False): + noise = noise.to(device=self.device, dtype=torch.float16) + if negative_prompt is None: negative_prompt = self.negative_prompt # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768] - prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, negative_prompt, device=self.device, verbose=verbose) + prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \ + negative_pooled_prompt_embeds_ = \ + self.encode_prompt(prompt, negative_prompt, device=self.device, verbose=verbose) # Repeat the prompt embeddings for all images in the batch. - prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1) - negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1) - noise = noise.to(self.device).to(torch.float16) - - # noise: [BS, 4, 64, 64] - # When the pipeline is text2img, strength is ignored. - images = self.pipeline(image=noise, - prompt_embeds=prompt_embeds_, - negative_prompt_embeds=negative_prompt_embeds_, - num_inference_steps=self.num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=1, - strength=ref_img_strength, - generator=generator).images + prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1) + if negative_prompt_embeds_ is not None: + negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1) + + if self.pipeline_name == "text2img3": + pooled_prompt_embeds_ = pooled_prompt_embeds_.repeat(out_image_count, 1) + negative_pooled_prompt_embeds_ = negative_pooled_prompt_embeds_.repeat(out_image_count, 1) + + # noise: [BS, 4, 64, 64] + # When the pipeline is text2img, strength is ignored. + images = self.pipeline(prompt_embeds=prompt_embeds_, + negative_prompt_embeds=negative_prompt_embeds_, + pooled_prompt_embeds=pooled_prompt_embeds_, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds_, + num_inference_steps=self.num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=1, + generator=generator).images + elif self.pipeline_name == "flux": + images = self.pipeline(prompt_embeds=prompt_embeds_, + pooled_prompt_embeds=pooled_prompt_embeds_, + num_inference_steps=4, + guidance_scale=guidance_scale, + num_images_per_prompt=1, + generator=generator).images + else: + # When the pipeline is text2img, noise: [BS, 4, 64, 64], and strength is ignored. + # When the pipeline is img2img, noise is an initiali image of [BS, 3, 512, 512], + # whose pixels are normalized to [0, 1]. 
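+            # (adaface_translate.py prepares exactly such inputs: PIL images resized to 512x512,
+            # stacked into a [BS, 3, 512, 512] float tensor and divided by 255.)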
+ images = self.pipeline(image=noise, + prompt_embeds=prompt_embeds_, + negative_prompt_embeds=negative_prompt_embeds_, + num_inference_steps=self.num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=1, + strength=ref_img_strength, + generator=generator).images # images: [BS, 3, 512, 512] return images \ No newline at end of file diff --git a/adaface/arc2face_models.py b/adaface/arc2face_models.py index 57be3e604fb39328d75ba7a34b19f5ebe9e586c8..4331630af34c93e857b939a9c93f0f69ce93ba71 100644 --- a/adaface/arc2face_models.py +++ b/adaface/arc2face_models.py @@ -2,14 +2,49 @@ import torch import torch.nn as nn from transformers import CLIPTextModel from transformers.models.clip.modeling_clip import CLIPAttention -from typing import Any, Callable, Dict, Optional, Tuple, Union, List +from typing import Optional, Tuple, Union from transformers.modeling_outputs import BaseModelOutputWithPooling from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from diffusers import ( + StableDiffusionPipeline, + UNet2DConditionModel, + DDIMScheduler, +) # from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask _make_causal_mask = AttentionMaskConverter._make_causal_mask _expand_mask = AttentionMaskConverter._expand_mask -from adaface.util import add_noise_to_tensor +from .util import perturb_tensor + +def create_arc2face_pipeline(base_model_path="models/ensemble/sd15-dste8-vae.safetensors", + dtype=torch.float16, unet_only=False): + unet = UNet2DConditionModel.from_pretrained( + 'models/arc2face', subfolder="arc2face", torch_dtype=dtype + ) + if unet_only: + return unet + + text_encoder = CLIPTextModelWrapper.from_pretrained( + 'models/arc2face', subfolder="encoder", torch_dtype=dtype + ) + noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + pipeline = StableDiffusionPipeline.from_single_file( + base_model_path, + text_encoder=text_encoder, + unet=unet, + torch_dtype=dtype, + safety_checker=None + ) + pipeline.scheduler = noise_scheduler + return pipeline # Extend CLIPAttention by using multiple k_proj and v_proj in each head. # To avoid too much increase of computation, we don't extend q_proj. @@ -43,9 +78,16 @@ class CLIPAttentionMKV(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def extend_weights(self, clip_attn_layer, layer_idx, multiplier, noise_std=0.1, - noise_std_is_relative=True, keep_norm=False, verbose=False): + # clip_attn_layer is usually self. + def extend_weights(self, clip_attn_layer, layer_idx, multiplier, perturb_std=0.2, + perturb_std_is_relative=True, perturb_keep_norm=False, verbose=False): + ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape) + ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0] + ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape) + ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0] + self.multiplier *= multiplier + # q_proj and out_proj are the same as the original CLIPAttention. 
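+        # k_proj and v_proj, by contrast, are repeated `multiplier` times along dim 0 below
+        # (e.g. with a 768-dim CLIP text encoder and multiplier=2, [768, 768] -> [1536, 768]);
+        # only the extra copies are perturbed, so the copies can diverge during training.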
self.q_proj.weight.data = clip_attn_layer.q_proj.weight.data.clone() self.q_proj.bias.data = clip_attn_layer.q_proj.bias.data.clone() @@ -55,34 +97,50 @@ class CLIPAttentionMKV(nn.Module): # bias doesn't need noise perturbation, as after the weights are noised, # different copies of the weight/bias will receive different gradients, # making the bias terms diverge and identifiable after training. - self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier) self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.repeat(multiplier) + self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier) - self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1) self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.repeat(multiplier, 1) + self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1) + + # Correct the out_features attribute of k_proj and v_proj. + self.k_proj.out_features = self.k_proj.weight.shape[0] + self.v_proj.out_features = self.v_proj.weight.shape[0] - if noise_std > 0: - ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape) - ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0] + if perturb_std > 0: # Adding noise to the extra copies of the weights (keep the first copy unchanged). self.v_proj.weight.data[ORIG_V_SHAPE_D0:] = \ - add_noise_to_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:], - noise_std, noise_std_is_relative, keep_norm) + perturb_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:], + perturb_std, perturb_std_is_relative, perturb_keep_norm, verbose=verbose) if verbose: NEW_V_SHAPE = list(self.v_proj.weight.shape) NOISED_V_SHAPE = list(self.v_proj.weight.data[ORIG_V_SHAPE_D0:].shape) - print(f"Layer {layer_idx}: {NOISED_V_SHAPE} in {NEW_V_SHAPE} of v_proj is added with {noise_std} noise") + print(f"Layer {layer_idx}: {NOISED_V_SHAPE} in {NEW_V_SHAPE} of v_proj is added with {perturb_std} noise") - ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape) - ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0] # Adding noise to the extra copies of the weights. self.k_proj.weight.data[ORIG_K_SHAPE_D0:] = \ - add_noise_to_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:], - noise_std, noise_std_is_relative, keep_norm) + perturb_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:], + perturb_std, perturb_std_is_relative, perturb_keep_norm, verbose=verbose) if verbose: NEW_K_SHAPE = list(self.k_proj.weight.shape) NOISED_K_SHAPE = list(self.k_proj.weight.data[ORIG_K_SHAPE_D0:].shape) - print(f"Layer {layer_idx}: {NOISED_K_SHAPE} in {NEW_K_SHAPE} of k_proj is added with {noise_std} noise") + print(f"Layer {layer_idx}: {NOISED_K_SHAPE} in {NEW_K_SHAPE} of k_proj is added with {perturb_std} noise") + + def squeeze_weights(self, clip_attn_layer, divisor): + if self.multiplier % divisor != 0: + breakpoint() + self.multiplier //= divisor + + self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.reshape(divisor, -1).mean(dim=0) + self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.reshape(divisor, -1).mean(dim=0) + + self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.reshape(divisor, -1, self.k_proj.weight.shape[1]).mean(dim=0) + self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.reshape(divisor, -1, self.v_proj.weight.shape[1]).mean(dim=0) + + # Correct the out_features attribute of k_proj and v_proj. 
+ self.k_proj.out_features = self.k_proj.weight.shape[0] + self.v_proj.out_features = self.v_proj.weight.shape[0] + def forward( self, @@ -109,7 +167,7 @@ class CLIPAttentionMKV(nn.Module): src_len = key_states.size(1) # src_len0 is the original src_len without the multiplier. src_len0 = src_len // self.multiplier - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2).contiguous()) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( @@ -165,7 +223,7 @@ class CLIPAttentionMKV(nn.Module): ) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) attn_output = self.out_proj(attn_output) @@ -279,25 +337,46 @@ class CLIPTextModelWrapper(CLIPTextModel): attentions=encoder_outputs.attentions, ) - # Applied to layers [begin_layer_idx, end_layer_idx) in the encoder. + # Applied to all attention layers in the encoder, if the corresponding multiplier is not 1. # The layer indexed by end_layer_idx is not included. # If both layer indices are -1, then apply to all layers (0-11). - def extend_clip_attention_MKV_multiplier(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1): + def extend_clip_attention_MKV_multiplier(self, prompt2token_proj_attention_multipliers, perturb_std=0.1): num_extended_layers = 0 for layer_idx, layer in enumerate(self.text_model.encoder.layers): - if begin_layer_idx >= 0 and layer_idx < begin_layer_idx: + multiplier = prompt2token_proj_attention_multipliers[layer_idx] + if multiplier == 1: continue - if end_layer_idx >= 0 and layer_idx >= end_layer_idx: - break # This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV. if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)): breakpoint() old_attn_layer = layer.self_attn if not isinstance(old_attn_layer, CLIPAttentionMKV): layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1) - layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, noise_std, verbose=True) + # Extends the v_proj and k_proj weights in the self_attn layer. + layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, perturb_std, verbose=True) num_extended_layers += 1 return num_extended_layers - \ No newline at end of file + + # Applied to layers [begin_layer_idx, end_layer_idx) in the encoder. + # The layer indexed by end_layer_idx is not included. + # If both layer indices are -1, then apply to all layers (0-11). + def squeeze_clip_attention_MKV_divisor(self, prompt2token_proj_attention_divisors): + num_squeezed_layers = 0 + + for layer_idx, layer in enumerate(self.text_model.encoder.layers): + divisor = prompt2token_proj_attention_divisors[layer_idx] + if divisor == 1: + continue + # This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV. + if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)): + breakpoint() + old_attn_layer = layer.self_attn + if not isinstance(old_attn_layer, CLIPAttentionMKV): + layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1) + # Squeeze the k_proj and v_proj weights in the self_attn layer. 
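+            # E.g. divisor=2 averages the two stacked copies of k_proj/v_proj back into one,
+            # roughly undoing extend_weights(multiplier=2) up to the noise added to the extra copies.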
+ layer.self_attn.squeeze_weights(old_attn_layer, divisor) + num_squeezed_layers += 1 + + return num_squeezed_layers diff --git a/adaface/face_id_to_ada_prompt.py b/adaface/face_id_to_ada_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..2e32e695724d2219c2f5c3eeabfa85a9762934a3 --- /dev/null +++ b/adaface/face_id_to_ada_prompt.py @@ -0,0 +1,1147 @@ +import torch +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPImageProcessor +from .arc2face_models import CLIPTextModelWrapper +from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline +from .util import perturb_tensor, pad_image_obj_to_square, \ + calc_stats, patch_clip_image_encoder_with_mask, CLIPVisionModelWithMask +from adaface.subj_basis_generator import SubjBasisGenerator +import torch.nn.functional as F +import numpy as np +import cv2 +from PIL import Image +from insightface.app import FaceAnalysis +import os +from omegaconf.listconfig import ListConfig + +# adaface_encoder_types can be a list of one or more encoder types. +# adaface_ckpt_paths can be one or a list of ckpt paths. +# adaface_encoder_cfg_scales is None, or a list of scales for the adaface encoder types. +def create_id2ada_prompt_encoder(adaface_encoder_types, adaface_ckpt_paths=None, + adaface_encoder_cfg_scales=None, enabled_encoders=None, + *args, **kwargs): + if len(adaface_encoder_types) == 1: + adaface_encoder_type = adaface_encoder_types[0] + adaface_ckpt_path = adaface_ckpt_paths[0] if adaface_ckpt_paths is not None else None + if adaface_encoder_type == 'arc2face': + id2ada_prompt_encoder = \ + Arc2Face_ID2AdaPrompt(adaface_ckpt_path=adaface_ckpt_path, + *args, **kwargs) + elif adaface_encoder_type == 'consistentID': + id2ada_prompt_encoder = \ + ConsistentID_ID2AdaPrompt(pipe=None, + adaface_ckpt_path=adaface_ckpt_path, + *args, **kwargs) + else: + id2ada_prompt_encoder = Joint_FaceID2AdaPrompt(adaface_encoder_types, adaface_ckpt_paths, + adaface_encoder_cfg_scales, enabled_encoders, + *args, **kwargs) + + return id2ada_prompt_encoder + +class FaceID2AdaPrompt(nn.Module): + # To be initialized in derived classes. + def __init__(self, *args, **kwargs): + super().__init__() + # Initialize model components. + # These components of ConsistentID_ID2AdaPrompt will be shared with the teacher model. + # So we don't initialize them in the ctor(), but borrow them from the teacher model. + # These components of Arc2Face_ID2AdaPrompt will be initialized in its ctor(). + self.clip_image_encoder = None + self.clip_preprocessor = None + self.face_app = None + self.text_to_image_prompt_encoder = None + self.tokenizer = None + self.dtype = kwargs.get('dtype', torch.float16) + + # Load Img2Ada SubjectBasisGenerator. + self.subject_string = kwargs.get('subject_string', 'z') + self.adaface_ckpt_path = kwargs.get('adaface_ckpt_path', None) + self.subj_basis_generator = None + # -1: use the default scale for the adaface encoder type. + # i.e., 6 for arc2face and 1 for consistentID. + self.out_id_embs_cfg_scale = kwargs.get('out_id_embs_cfg_scale', -1) + self.is_training = kwargs.get('is_training', False) + # extend_prompt2token_proj_attention_multiplier is an integer >= 1. + # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers. 
+ self.extend_prompt2token_proj_attention_multiplier = kwargs.get('extend_prompt2token_proj_attention_multiplier', 1) + self.prompt2token_proj_ext_attention_perturb_ratio = kwargs.get('prompt2token_proj_ext_attention_perturb_ratio', 0.1) + + # Set model behavior configurations. + self.gen_neg_img_prompt = False + self.clip_neg_features = None + + self.use_clip_embs = False + self.do_contrast_clip_embs_on_bg_features = False + # num_id_vecs is the output embeddings of the ID2ImgPrompt module. + # If there's no static image suffix embeddings, then num_id_vecs is also + # the number of ada embeddings returned by the subject basis generator. + # num_id_vecs will be set in each derived class. + self.num_static_img_suffix_embs = kwargs.get('num_static_img_suffix_embs', 0) + print(f'{self.name} Adaface uses {self.num_id_vecs} ID image embeddings and {self.num_static_img_suffix_embs} fixed image embeddings as input.') + + self.id_img_prompt_max_length = 77 + self.face_id_dim = 512 + # clip_embedding_dim: by default it's the OpenAI CLIP embedding dim. + # Could be overridden by derived classes. + self.clip_embedding_dim = 1024 + self.output_dim = 768 + + def get_id2img_learnable_modules(self): + raise NotImplementedError + + def load_id2img_learnable_modules(self, id2img_learnable_modules_state_dict_list): + id2img_prompt_encoder_learnable_modules = self.get_id2img_learnable_modules() + for module, state_dict in zip(id2img_prompt_encoder_learnable_modules, id2img_learnable_modules_state_dict_list): + module.load_state_dict(state_dict) + print(f'{len(id2img_prompt_encoder_learnable_modules)} ID2ImgPrompt encoder modules loaded.') + + # init_subj_basis_generator() can only be called after the derived class is initialized, + # when self.num_id_vecs, self.num_static_img_suffix_embs and self.clip_embedding_dim have been set. + def init_subj_basis_generator(self): + self.subj_basis_generator = \ + SubjBasisGenerator(num_id_vecs = self.num_id_vecs, + num_static_img_suffix_embs = self.num_static_img_suffix_embs, + bg_image_embedding_dim = self.clip_embedding_dim, + output_dim = self.output_dim, + placeholder_is_bg = False, + prompt2token_proj_grad_scale = 1, + bg_prompt_translator_has_to_out_proj=False) + + def load_adaface_ckpt(self, adaface_ckpt_path): + ckpt = torch.load(adaface_ckpt_path, map_location='cpu') + string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"] + if self.subject_string not in string_to_subj_basis_generator_dict: + print(f"Subject '{self.subject_string}' not found in the embedding manager.") + breakpoint() + + ckpt_subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string] + ckpt_subj_basis_generator.N_ID = self.num_id_vecs + # Since we directly use the subject basis generator object from the ckpt, + # fixing the number of static image suffix embeddings is much simpler. + # Otherwise if we want to load the subject basis generator from its state_dict, + # things are more complicated, see embedding manager's load(). + ckpt_subj_basis_generator.N_SFX = self.num_static_img_suffix_embs + # obj_proj_in and pos_embs are for non-faces. So they are useless for human faces. + ckpt_subj_basis_generator.obj_proj_in = None + ckpt_subj_basis_generator.pos_embs = None + # Handle differences in num_static_img_suffix_embs between the current model and the ckpt. + ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.num_static_img_suffix_embs, img_prompt_dim=self.output_dim) + # Fix missing variables in old ckpt. 
+ ckpt_subj_basis_generator.patch_old_subj_basis_generator_ckpt() + + self.subj_basis_generator.extend_prompt2token_proj_attention(\ + ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0) + ret = self.subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict(), strict=False) + print(f"{adaface_ckpt_path}: subject basis generator loaded for '{self.name}'.") + print(repr(ckpt_subj_basis_generator)) + + if ret is not None and len(ret.missing_keys) > 0: + print(f"Missing keys: {ret.missing_keys}") + if ret is not None and len(ret.unexpected_keys) > 0: + print(f"Unexpected keys: {ret.unexpected_keys}") + + # extend_prompt2token_proj_attention_multiplier is an integer >= 1. + # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers. + # If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict, + # extend subj_basis_generator again. + if self.extend_prompt2token_proj_attention_multiplier > 1: + # During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt. + # During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1. + # During inference, prompt2token_proj_ext_attention_perturb_ratio == 0. + self.subj_basis_generator.extend_prompt2token_proj_attention(\ + None, -1, -1, self.extend_prompt2token_proj_attention_multiplier, + perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio) + + self.subj_basis_generator.freeze_prompt2token_proj() + + @torch.no_grad() + def get_clip_neg_features(self, BS): + if self.clip_neg_features is None: + # neg_pixel_values: [1, 3, 224, 224]. clip_neg_features is invariant to the actual image. + neg_pixel_values = torch.zeros([1, 3, 224, 224], device=self.clip_image_encoder.device, dtype=self.dtype) + # Precompute CLIP negative features for the negative image prompt. + self.clip_neg_features = self.clip_image_encoder(neg_pixel_values, attn_mask=None, output_hidden_states=True).hidden_states[-2] + + clip_neg_features = self.clip_neg_features.repeat(BS, 1, 1) + return clip_neg_features + + # image_objs: a list of np array / tensor / Image objects of different sizes [Hi, Wi]. + # If image_objs is a list of tensors, then each tensor should be [3, Hi, Wi]. + # If image_objs is None, then image_paths should be provided, + # and image_objs will be loaded from image_paths. + # fg_masks: None, or a list of [Hi, Wi]. + def extract_init_id_embeds_from_images(self, image_objs, image_paths, fg_masks=None, + size=(512, 512), calc_avg=False, + skip_non_faces=True, return_clip_embs=None, + do_contrast_clip_embs_on_bg_features=None, + verbose=False): + # If return_clip_embs or do_contrast_clip_embs_on_bg_features is not provided, + # then use their default values. + if return_clip_embs is None: + return_clip_embs = self.use_clip_embs + if do_contrast_clip_embs_on_bg_features is None: + do_contrast_clip_embs_on_bg_features = self.do_contrast_clip_embs_on_bg_features + + # clip_image_encoder should be already put on GPU. + # So its .device is the device of its parameters. 
+ device = self.clip_image_encoder.device + + image_pixel_values = [] + all_id_embs = [] + faceless_img_count = 0 + + if image_objs is None and image_paths is not None: + image_objs = [] + for image_path in image_paths: + image_obj = Image.open(image_path) + image_objs.append(image_obj) + print(f'Loaded {len(image_objs)} images from {image_paths[0]}...') + + # image_objs could be a batch of images that have been collated into a tensor or np array. + # image_objs can also be a list of images. + # The code below that processes them one by one can be applied in both cases. + # If image_objs are a collated batch, processing them one by one will not add much overhead. + for idx, image_obj in enumerate(image_objs): + if return_clip_embs: + # input to clip_preprocessor: an image or a batch of images, each being PIL.Image.Image, numpy.ndarray, + # torch.Tensor, tf.Tensor or jax.ndarray. + # Different sizes of images are standardized to the same size 224*224. + clip_image_pixel_values = self.clip_preprocessor(images=image_obj, return_tensors="pt").pixel_values + image_pixel_values.append(clip_image_pixel_values) + + # Convert tensor to numpy array. + if isinstance(image_obj, torch.Tensor): + image_obj = image_obj.cpu().numpy().transpose(1, 2, 0) + if isinstance(image_obj, np.ndarray): + image_obj = Image.fromarray(image_obj) + # Resize image_obj to (512, 512). The scheme is Image.NEAREST, to be consistent with + # PersonalizedBase dataset class. + image_obj, _, _ = pad_image_obj_to_square(image_obj) + image_np = np.array(image_obj.resize(size, Image.NEAREST)) + face_info = self.face_app.get(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)) + if len(face_info) > 0: + face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face + # id_emb: [512,] + id_emb = torch.from_numpy(face_info.normed_embedding) + else: + faceless_img_count += 1 + print(f'No face detected in {image_paths[idx]}.', end=' ') + if not skip_non_faces: + print('Replace with random face embedding.') + # During training, use a random tensor as the face embedding. + id_emb = torch.randn(512) + else: + print(f'Skip.') + continue + + all_id_embs.append(id_emb) + + if verbose: + print(f'{len(all_id_embs)} face images identified, {faceless_img_count} faceless images.') + + # No face is detected in the input images. + if len(all_id_embs) == 0: + return faceless_img_count, None, None + + # all_id_embs: [BS, 512]. + all_id_embs = torch.stack(all_id_embs, dim=0).to(device=device, dtype=torch.float16) + + if return_clip_embs: + # image_pixel_values: [BS, 3, 224, 224] + image_pixel_values = torch.cat(image_pixel_values, dim=0) + image_pixel_values = image_pixel_values.to(device=device, dtype=torch.float16) + + if fg_masks is not None: + assert len(fg_masks) == len(image_objs) + # fg_masks is a list of masks. + if isinstance(fg_masks, (list, tuple)): + fg_masks2 = [] + for fg_mask in fg_masks: + # fg_mask: [Hi, Wi] + # BUG: clip_preprocessor will do central crop on images. But fg_mask is not central cropped. + # If the ref image is not square, then the fg_mask will not match the image. + # TODO: crop fg_mask and images to square before calling extract_init_id_embeds_from_images(). 
+ # fg_mask2: [Hi, Wi] -> [1, 1, 224, 224] + fg_mask2 = torch.tensor(fg_mask, device=device, dtype=torch.float16).unsqueeze(0).unsqueeze(0) + fg_mask2 = F.interpolate(fg_mask2, size=image_pixel_values.shape[-2:], mode='bilinear', align_corners=False) + fg_masks2.append(fg_mask2) + # fg_masks2: [BS, 224, 224] + fg_masks2 = torch.cat(fg_masks2, dim=0).squeeze(1) + else: + # fg_masks is a collated batch of masks. + # The actual size doesn't matter, + # as fg_mask2 will be resized to the same size as image features + # (much smaller than image_pixel_values). + fg_masks2 = fg_masks.to(device=device, dtype=torch.float16).unsqueeze(1) + # F.interpolate() always return a copy, even if scale_factor=1. So we don't need to clone fg_masks2. + fg_masks2 = F.interpolate(fg_masks2, size=image_pixel_values.shape[-2:], mode='bilinear', align_corners=False) + fg_masks2 = fg_masks2.squeeze(1) + else: + # fg_mask2: [BS, 224, 224]. + fg_masks2 = torch.ones_like(image_pixel_values[:, 0, :, :], device=device, dtype=torch.float16) + + clip_neg_features = self.get_clip_neg_features(BS=image_pixel_values.shape[0]) + + with torch.no_grad(): + # image_fg_features: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds). + image_fg_dict = self.clip_image_encoder(image_pixel_values, attn_mask=fg_masks2, output_hidden_states=True) + # attn_mask: [BS, 1, 257] + image_fg_features = image_fg_dict.hidden_states[-2] + if image_fg_dict.attn_mask is not None: + image_fg_features = image_fg_features * image_fg_dict.attn_mask + + # A negative mask is used to extract the background features. + # If fg_masks is None, then fg_masks2 is all ones, and bg masks is all zeros. + # Therefore, all pixels are masked. The extracted image_bg_features will be + # meaningless in this case. + image_bg_dict = self.clip_image_encoder(image_pixel_values, attn_mask=1-fg_masks2, output_hidden_states=True) + image_bg_features = image_bg_dict.hidden_states[-2] + # Subtract the feature bias (null features) from the bg features, to highlight the useful bg features. + if do_contrast_clip_embs_on_bg_features: + image_bg_features = image_bg_features - clip_neg_features + if image_bg_dict.attn_mask is not None: + image_bg_features = image_bg_features * image_bg_dict.attn_mask + + # clip_fgbg_features: [BS, 514, 1280]. 514 = 257*2. + # all_id_embs: [BS, 512]. + clip_fgbg_features = torch.cat([image_fg_features, image_bg_features], dim=1) + else: + clip_fgbg_features = None + clip_neg_features = None + + if calc_avg: + if return_clip_embs: + # clip_fgbg_features: [BS, 514, 1280] -> [1, 514, 1280]. + # all_id_embs: [BS, 512] -> [1, 512]. + clip_fgbg_features = clip_fgbg_features.mean(dim=0, keepdim=True) + clip_neg_features = clip_neg_features.mean(dim=0, keepdim=True) + + debug = False + if debug and all_id_embs is not None: + print(image_paths) + calc_stats('all_id_embs', all_id_embs) + # Compute pairwise similarities of the embeddings. + all_id_embs = F.normalize(all_id_embs, p=2, dim=1) + pairwise_sim = torch.matmul(all_id_embs, all_id_embs.t()) + print('pairwise_sim:', pairwise_sim) + top_dir = os.path.dirname(image_paths[0]) + mean_emb_path = os.path.join(top_dir, "mean_emb.pt") + if os.path.exists(mean_emb_path): + mean_emb = torch.load(mean_emb_path) + sim_to_mean = torch.matmul(all_id_embs, mean_emb.t()) + print('sim_to_mean:', sim_to_mean) + + if all_id_embs is not None: + id_embs = all_id_embs.mean(dim=0, keepdim=True) + # Without normalization, id_embs.norm(dim=1) is ~0.9. So normalization doesn't have much effect. 
+ id_embs = F.normalize(id_embs, p=2, dim=-1) + # id_embs is None only if insightface_app is None, i.e., disabled by the user. + else: + # Don't do average of all_id_embs. + id_embs = all_id_embs + + return faceless_img_count, id_embs, clip_fgbg_features + + # This function should be implemented in derived classes. + # We don't plan to fine-tune the ID2ImgPrompt module. So disable the gradient computation. + def map_init_id_to_img_prompt_embs(self, init_id_embs, + clip_features=None, + called_for_neg_img_prompt=False): + raise NotImplementedError + + # If init_id_embs/pre_clip_features is provided, then use the provided face embeddings. + # Otherwise, if image_paths/image_objs are provided, extract face embeddings from the images. + # Otherwise, we generate random face embeddings [id_batch_size, 512]. + def get_img_prompt_embs(self, init_id_embs, pre_clip_features, image_paths, image_objs, + id_batch_size, + skip_non_faces=True, + avg_at_stage=None, # id_emb, img_prompt_emb, or None. + perturb_at_stage=None, # id_emb, img_prompt_emb, or None. + perturb_std=0.0, + verbose=False): + face_image_count = 0 + device = self.clip_image_encoder.device + clip_neg_features = self.get_clip_neg_features(BS=id_batch_size) + + if init_id_embs is None: + # Input images are not provided. Generate random face embeddings. + if image_paths is None and image_objs is None: + faceid_embeds_from_images = False + # Use random face embeddings as faceid_embeds. [BS, 512]. + faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16) + # Since it's a batch of random IDs, the CLIP features are all zeros as a placeholder. + # Only ConsistentID_ID2AdaPrompt will use clip_fgbg_features and clip_neg_features. + # Experiments show that using random clip features yields much better images than using zeros. + clip_fgbg_features = torch.randn(id_batch_size, 514, 1280).to(device=device, dtype=torch.float16) \ + if self.use_clip_embs else None + else: + # Extract face ID embeddings and CLIP features from the images. + faceid_embeds_from_images = True + faceless_img_count, faceid_embeds, clip_fgbg_features \ + = self.extract_init_id_embeds_from_images( \ + image_objs, image_paths=image_paths, size=(512, 512), + calc_avg=(avg_at_stage == 'id_emb'), + skip_non_faces=skip_non_faces, + verbose=verbose) + + if image_paths is not None: + face_image_count = len(image_paths) - faceless_img_count + else: + face_image_count = len(image_objs) - faceless_img_count + else: + faceid_embeds_from_images = False + # Use the provided init_id_embs as faceid_embeds. + faceid_embeds = init_id_embs + if pre_clip_features is not None: + clip_fgbg_features = pre_clip_features + else: + clip_fgbg_features = None + + if faceid_embeds.shape[0] == 1: + faceid_embeds = faceid_embeds.repeat(id_batch_size, 1) + if clip_fgbg_features is not None: + clip_fgbg_features = clip_fgbg_features.repeat(id_batch_size, 1, 1) + + # If skip_non_faces, then faceid_embeds won't be None. + # Otherwise, if faceid_embeds_from_images, and no face images are detected, + # then we return Nones. + if faceid_embeds is None: + return face_image_count, None, None, None + + if perturb_at_stage == 'id_emb' and perturb_std > 0: + # If id_batch_size > 1, after adding noises, the id_batch_size embeddings will be different. 
+ faceid_embeds = perturb_tensor(faceid_embeds, perturb_std, perturb_std_is_relative=True, keep_norm=True) + if self.name == 'consistentID' or self.name == 'jointIDs': + clip_fgbg_features = perturb_tensor(clip_fgbg_features, perturb_std, perturb_std_is_relative=True, keep_norm=True) + + faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1) + + # pos_prompt_embs, neg_prompt_embs: [BS, 77, 768] or [BS, 22, 768]. + with torch.no_grad(): + pos_prompt_embs = \ + self.map_init_id_to_img_prompt_embs(faceid_embeds, clip_fgbg_features, + called_for_neg_img_prompt=False) + + if avg_at_stage == 'img_prompt_emb': + pos_prompt_embs = pos_prompt_embs.mean(dim=0, keepdim=True) + faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True) + if clip_fgbg_features is not None: + clip_fgbg_features = clip_fgbg_features.mean(dim=0, keepdim=True) + + if perturb_at_stage == 'img_prompt_emb' and perturb_std > 0: + # NOTE: for simplicity, pos_prompt_embs and pos_core_prompt_emb are perturbed independently. + # This could cause inconsistency between pos_prompt_embs and pos_core_prompt_emb. + # But in practice, unless we use both pos_prompt_embs and pos_core_prompt_emb + # this is not an issue. But we rarely use pos_prompt_embs and pos_core_prompt_emb together. + pos_prompt_embs = perturb_tensor(pos_prompt_embs, perturb_std, perturb_std_is_relative=True, keep_norm=True) + + # If faceid_embeds_from_images, and the prompt embeddings are already averaged, then + # we assume all images are from the same subject, and the batch dim of faceid_embeds is 1. + # So we need to repeat faceid_embeds. + if faceid_embeds_from_images and avg_at_stage is not None: + faceid_embeds = faceid_embeds.repeat(id_batch_size, 1) + pos_prompt_embs = pos_prompt_embs.repeat(id_batch_size, 1, 1) + if clip_fgbg_features is not None: + clip_fgbg_features = clip_fgbg_features.repeat(id_batch_size, 1, 1) + + if self.gen_neg_img_prompt: + # Never perturb the negative prompt embeddings. + with torch.no_grad(): + neg_prompt_embs = \ + self.map_init_id_to_img_prompt_embs(torch.zeros_like(faceid_embeds), + clip_neg_features, + called_for_neg_img_prompt=True) + + return face_image_count, faceid_embeds, pos_prompt_embs, neg_prompt_embs + else: + return face_image_count, faceid_embeds, pos_prompt_embs, None + + # get_batched_img_prompt_embs() is a wrapper of get_img_prompt_embs() + # which is convenient for batched training. + # NOTE: get_batched_img_prompt_embs() should only be called during training. + # It is a wrapper of get_img_prompt_embs() which is convenient for batched training. + # If init_id_embs is None, generate random face embeddings [BS, 512]. + # Returns faceid_embeds, id2img_prompt_emb. + def get_batched_img_prompt_embs(self, batch_size, init_id_embs, pre_clip_features): + # pos_prompt_embs, neg_prompt_embs are generated without gradient computation. + # So we don't need to worry that the teacher model weights are updated. + return self.get_img_prompt_embs(init_id_embs=init_id_embs, + pre_clip_features=pre_clip_features, + image_paths=None, + image_objs=None, + id_batch_size=batch_size, + # During training, don't skip non-face images. Instead, + # setting skip_non_faces=False will replace them by random face embeddings. + skip_non_faces=False, + # We always assume the instances belong to different subjects. + # So never average the embeddings across instances. + avg_at_stage=None, + verbose=False) + + # If img_prompt_embs is provided, we use it directly. + # Otherwise, if face_id_embs is provided, we use it to generate img_prompt_embs. 
+ # Otherwise, if image_paths is provided, we extract face_id_embs from the images. + # image_paths: a list of image paths. image_folder: the parent folder name. + # avg_at_stage: 'id_emb', 'img_prompt_emb', or None. + # avg_at_stage == ada_prompt_emb usually produces the worst results. + # avg_at_stage == id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better. + # p_dropout and return_zero_embs_for_dropped_encoders are only used by Joint_FaceID2AdaPrompt. + def generate_adaface_embeddings(self, image_paths, face_id_embs=None, img_prompt_embs=None, + p_dropout=0, + return_zero_embs_for_dropped_encoders=True, + avg_at_stage='id_emb', # id_emb, img_prompt_emb, or None. + perturb_at_stage=None, # id_emb, img_prompt_emb, or None. + perturb_std=0, enable_static_img_suffix_embs=False): + if (avg_at_stage is None) or avg_at_stage.lower() == 'none': + img_prompt_avg_at_stage = None + else: + img_prompt_avg_at_stage = avg_at_stage + + if img_prompt_embs is None: + # Do averaging. So id_batch_size becomes 1 after averaging. + if img_prompt_avg_at_stage is not None: + id_batch_size = 1 + else: + if face_id_embs is not None: + id_batch_size = face_id_embs.shape[0] + elif image_paths is not None: + id_batch_size = len(image_paths) + else: + id_batch_size = 1 + + # faceid_embeds: [BS, 512] is a batch of extracted face analysis embeddings. NOT used later. + # NOTE: If face_id_embs, image_paths and image_objs are all None, + # then get_img_prompt_embs() generates random faceid_embeds/img_prompt_embs, + # and each instance is different. + # Otherwise, if face_id_embs is provided, it's used. + # If not, image_paths/image_objs are used to extract face embeddings. + # img_prompt_embs is in the image prompt space. + # img_prompt_embs: [BS, 16/4, 768]. + face_image_count, faceid_embeds, img_prompt_embs, neg_img_prompt_embs \ + = self.get_img_prompt_embs(\ + init_id_embs=face_id_embs, + pre_clip_features=None, + # image_folder is passed only for logging purpose. + # image_paths contains the paths of the images. + image_paths=image_paths, image_objs=None, + id_batch_size=id_batch_size, + perturb_at_stage=perturb_at_stage, + perturb_std=perturb_std, + avg_at_stage=img_prompt_avg_at_stage, + verbose=True) + + if face_image_count == 0: + return None + + # No matter whether avg_at_stage is id_emb or img_prompt_emb, we average img_prompt_embs. + elif avg_at_stage is not None and avg_at_stage.lower() != 'none': + # img_prompt_embs: [BS, 16/4, 768] -> [1, 16/4, 768]. + img_prompt_embs = img_prompt_embs.mean(dim=0, keepdim=True) + + # adaface_subj_embs: [BS, 16/4, 768]. + adaface_subj_embs = \ + self.subj_basis_generator(img_prompt_embs, clip_features=None, raw_id_embs=None, + out_id_embs_cfg_scale=self.out_id_embs_cfg_scale, + is_face=True, + enable_static_img_suffix_embs=enable_static_img_suffix_embs) + # During training, img_prompt_avg_at_stage is None, and BS >= 1. + # During inference, img_prompt_avg_at_stage is 'id_emb' or 'img_prompt_emb', and BS == 1. 
+ if img_prompt_avg_at_stage is not None: + # adaface_subj_embs: [1, 16, 768] -> [16, 768] + adaface_subj_embs = adaface_subj_embs.squeeze(0) + + return adaface_subj_embs + +class Arc2Face_ID2AdaPrompt(FaceID2AdaPrompt): + def __init__(self, *args, **kwargs): + self.name = 'arc2face' + self.num_id_vecs = 16 + + super().__init__(*args, **kwargs) + + self.clip_image_encoder = CLIPVisionModelWithMask.from_pretrained('openai/clip-vit-large-patch14') + self.clip_preprocessor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14') + self.clip_image_encoder.eval() + if self.dtype == torch.float16: + self.clip_image_encoder.half() + print(f'CLIP image encoder loaded.') + + ''' + {'landmark_3d_68': , + 'landmark_2d_106': , + 'detection': , + 'genderage': , + 'recognition': } + ''' + # Use the same model as ID2AdaPrompt does. + # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2. + # Note there's a second "model" in the path. + # Note DON'T use CUDAExecutionProvider, as it will hang DDP training. + # Seems when loading insightface onto the GPU, it will only reside on the first GPU. + # Then the process on the second GPU has issue to communicate with insightface on the first GPU, causing hanging. + self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', + providers=['CPUExecutionProvider']) + self.face_app.prepare(ctx_id=0, det_size=(512, 512)) + print(f'Face encoder loaded on CPU.') + + self.text_to_image_prompt_encoder = CLIPTextModelWrapper.from_pretrained( + 'models/arc2face', subfolder="encoder", + torch_dtype=self.dtype + ) + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + if self.out_id_embs_cfg_scale == -1: + self.out_id_embs_cfg_scale = 1 + #### Arc2Face pipeline specific configs #### + self.gen_neg_img_prompt = False + # bg CLIP features are used by the bg subject basis generator. + self.use_clip_embs = True + self.do_contrast_clip_embs_on_bg_features = True + # self.num_static_img_suffix_embs is initialized in the parent class. + self.id_img_prompt_max_length = 22 + self.clip_embedding_dim = 1024 + + self.init_subj_basis_generator() + if self.adaface_ckpt_path is not None: + self.load_adaface_ckpt(self.adaface_ckpt_path) + + print(f"{self.name} ada prompt encoder initialized, " + f"ID vecs: {self.num_id_vecs}, static suffix: {self.num_static_img_suffix_embs}.") + + # Arc2Face_ID2AdaPrompt never uses clip_features or called_for_neg_img_prompt. + def map_init_id_to_img_prompt_embs(self, init_id_embs, + clip_features=None, + called_for_neg_img_prompt=False): + + ''' + self.text_to_image_prompt_encoder: arc2face_models.py:CLIPTextModelWrapper instance. + init_id_embs: (N, 512) normalized Face ID embeddings. + ''' + + # arcface_token_id: 1014 + arcface_token_id = self.tokenizer.encode("id", add_special_tokens=False)[0] + + # This step should be quite fast, and there's no need to cache the input_ids. + input_ids = self.tokenizer( + "photo of a id person", + truncation=True, + padding="max_length", + # In Arc2Face_ID2AdaPrompt, id_img_prompt_max_length is 22. + # Arc2Face's image prompt is meanlingless in tokens other than ID tokens. + max_length=self.id_img_prompt_max_length, + return_tensors="pt", + ).input_ids.to(init_id_embs.device) + # input_ids: [1, 22] or [3, 22] (during training). + input_ids = input_ids.repeat(len(init_id_embs), 1) + init_id_embs = init_id_embs.to(self.dtype) + # face_embs_padded: [1, 512] -> [1, 768]. 
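+        # For reference, "photo of a id person" tokenizes roughly as
+        # [BOS, photo, of, a, id, person, EOS, EOS, ...], so the "id" placeholder sits at position 4.
+        # Its token embedding is replaced below by the zero-padded ArcFace embedding; after the causal
+        # CLIP text encoder pass, output positions 4:20 carry the subject's identity (see the return below).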
+ face_embs_padded = F.pad(init_id_embs, (0, self.text_to_image_prompt_encoder.config.hidden_size - init_id_embs.shape[-1]), "constant", 0) + # self.text_to_image_prompt_encoder(input_ids=input_ids, ...) is called twice. The first is only to get the token embeddings (the shallowest mapping). + # The second call does the ordinary CLIP text encoding pass. + token_embs = self.text_to_image_prompt_encoder(input_ids=input_ids, return_token_embs=True) + token_embs[input_ids==arcface_token_id] = face_embs_padded + + prompt_embeds = self.text_to_image_prompt_encoder( + input_ids=input_ids, + input_token_embs=token_embs, + return_token_embs=False + )[0] + + # Restore the original dtype of prompt_embeds: float16 -> float32. + prompt_embeds = prompt_embeds.to(self.dtype) + + # token 4: 'id' in "photo of a id person". + # 4:20 are the most important 16 embeddings that contain the subject's identity. + # [N, 22, 768] -> [N, 16, 768] + return prompt_embeds[:, 4:20] + + def get_id2img_learnable_modules(self): + return [ self.text_to_image_prompt_encoder ] + +# ConsistentID_ID2AdaPrompt is just a wrapper of ConsistentIDPipeline, so it's not an nn.Module. +class ConsistentID_ID2AdaPrompt(FaceID2AdaPrompt): + def __init__(self, pipe=None, base_model_path="models/sd15-dste8-vae.safetensors", + *args, **kwargs): + self.name = 'consistentID' + self.num_id_vecs = 4 + + super().__init__(*args, **kwargs) + if pipe is None: + # The base_model_path is kind of arbitrary, as the UNet and VAE in the model + # are not used and will be released soon. + # Only the consistentID modules and bise_net are used. + assert base_model_path is not None, "base_model_path should be provided." + pipe = ConsistentIDPipeline.from_single_file( + base_model_path, + ) + pipe.load_ConsistentID_model(consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin", + bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth") + pipe.to(dtype=self.dtype) + # Since the passed-in pipe is None, this should be called during inference, + # when the teacher ConsistentIDPipeline is not initialized. + # Therefore, we release VAE, UNet and text_encoder to save memory. + pipe.release_components(["unet", "vae"]) + + # Otherwise, we share the pipeline with the teacher. + # So we don't release the components. + self.pipe = pipe + self.face_app = pipe.face_app + # ConsistentID uses 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'. + self.clip_image_encoder = patch_clip_image_encoder_with_mask(pipe.clip_encoder) + self.clip_preprocessor = pipe.clip_preprocessor + self.text_to_image_prompt_encoder = pipe.text_encoder + self.tokenizer = pipe.tokenizer + self.image_proj_model = pipe.image_proj_model + + self.clip_image_encoder.eval() + self.image_proj_model.eval() + if self.dtype == torch.float16: + self.clip_image_encoder.half() + self.image_proj_model.half() + + if self.out_id_embs_cfg_scale == -1: + self.out_id_embs_cfg_scale = 6 + #### ConsistentID pipeline specific configs #### + # self.num_static_img_suffix_embs is initialized in the parent class. 
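+        # Unlike Arc2Face (gen_neg_img_prompt = False), ConsistentID also builds a negative image prompt:
+        # the ID embedding is zeroed and the "null" CLIP features are mapped through the image projector
+        # (see map_init_id_to_img_prompt_embs with called_for_neg_img_prompt=True), and the background
+        # CLIP features are contrasted against the null features to highlight the useful bg features.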
+ self.gen_neg_img_prompt = True + self.use_clip_embs = True + self.do_contrast_clip_embs_on_bg_features = True + self.clip_embedding_dim = 1280 + self.s_scale = 1.0 + self.shortcut = False + + self.init_subj_basis_generator() + if self.adaface_ckpt_path is not None: + self.load_adaface_ckpt(self.adaface_ckpt_path) + + print(f"{self.name} ada prompt encoder initialized, " + f"ID vecs: {self.num_id_vecs}, static suffix: {self.num_static_img_suffix_embs}.") + + def map_init_id_to_img_prompt_embs(self, init_id_embs, + clip_features=None, + called_for_neg_img_prompt=False): + assert init_id_embs is not None, "init_id_embs should be provided." + + init_id_embs = init_id_embs.to(self.dtype) + clip_features = clip_features.to(self.dtype) + + if not called_for_neg_img_prompt: + # clip_features: [BS, 514, 1280]. + # clip_features is provided when the function is called within + # ConsistentID_ID2AdaPrompt:extract_init_id_embeds_from_images(), which is + # image_fg_features and image_bg_features concatenated at dim=1. + # Therefore, we split clip_image_double_embeds into image_fg_features and image_bg_features. + # image_bg_features is not used in ConsistentID_ID2AdaPrompt. + image_fg_features, image_bg_features = clip_features.chunk(2, dim=1) + # clip_image_embeds: [BS, 257, 1280]. + clip_image_embeds = image_fg_features + else: + # clip_features is the negative image features. So we don't need to split it. + clip_image_embeds = clip_features + init_id_embs = torch.zeros_like(init_id_embs) + + faceid_embeds = init_id_embs + # image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings. + # clip_image_embeds are used as queries to transform faceid_embeds. + # faceid_embeds -> kv, clip_image_embeds -> q + if faceid_embeds.shape[0] != clip_image_embeds.shape[0]: + breakpoint() + + try: + global_id_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=self.shortcut, scale=self.s_scale) + except: + breakpoint() + + return global_id_embeds + + def get_id2img_learnable_modules(self): + return [ self.image_proj_model ] + +# A wrapper for combining multiple FaceID2AdaPrompt instances. +class Joint_FaceID2AdaPrompt(FaceID2AdaPrompt): + def __init__(self, adaface_encoder_types, adaface_ckpt_paths, + out_id_embs_cfg_scales=None, enabled_encoders=None, + *args, **kwargs): + self.name = 'jointIDs' + assert len(adaface_encoder_types) > 0, "adaface_encoder_types should not be empty." + adaface_encoder_types2num_id_vecs = { 'arc2face': 16, 'consistentID': 4 } + self.encoders_num_id_vecs = [ adaface_encoder_types2num_id_vecs[encoder_type] \ + for encoder_type in adaface_encoder_types ] + self.num_id_vecs = sum(self.encoders_num_id_vecs) + super().__init__(*args, **kwargs) + + self.num_sub_encoders = len(adaface_encoder_types) + self.id2ada_prompt_encoders = nn.ModuleList() + self.encoders_num_static_img_suffix_embs = [] + + # TODO: apply adaface_encoder_cfg_scales to influence the final prompt embeddings. + # Now they are just placeholders. + if out_id_embs_cfg_scales is None: + # -1: use the default scale for the adaface encoder type. + # i.e., 6 for arc2face and 1 for consistentID. + self.out_id_embs_cfg_scales = [-1] * self.num_sub_encoders + else: + # Do not normalize the weights, and just use them as is. + self.out_id_embs_cfg_scales = out_id_embs_cfg_scales + + # Note we don't pass the adaface_ckpt_paths to the base class, but instead, + # we load them once and for all in self.load_adaface_ckpt(). 
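+        # Illustrative layout, assuming the typical ['arc2face', 'consistentID'] combination:
+        # num_id_vecs = 16 + 4 = 20 and clip_embedding_dim = 1024 + 1280 = 2304; the per-encoder
+        # ID embeddings and CLIP features are concatenated along their channel dims accordingly.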
+ for i, encoder_type in enumerate(adaface_encoder_types): + kwargs['out_id_embs_cfg_scale'] = self.out_id_embs_cfg_scales[i] + if encoder_type == 'arc2face': + encoder = Arc2Face_ID2AdaPrompt(*args, **kwargs) + elif encoder_type == 'consistentID': + encoder = ConsistentID_ID2AdaPrompt(*args, **kwargs) + else: + breakpoint() + self.id2ada_prompt_encoders.append(encoder) + self.encoders_num_static_img_suffix_embs.append(encoder.num_static_img_suffix_embs) + + self.num_static_img_suffix_embs = sum(self.encoders_num_static_img_suffix_embs) + # No need to set gen_neg_img_prompt, as we don't access it in this class, but rather + # in the derived classes. + # self.gen_neg_img_prompt = True + # self.use_clip_embs = True + # self.do_contrast_clip_embs_on_bg_features = True + self.face_id_dims = [encoder.face_id_dim for encoder in self.id2ada_prompt_encoders] + self.face_id_dim = sum(self.face_id_dims) + # Different adaface encoders may have different clip_embedding_dim. + # clip_embedding_dim is only used for bg subject basis generator. + # Here we use the joint clip embeddings of both OpenAI CLIP and laion CLIP. + # Therefore, the clip_embedding_dim is the sum of the clip_embedding_dims of all adaface encoders. + self.clip_embedding_dims = [encoder.clip_embedding_dim for encoder in self.id2ada_prompt_encoders] + self.clip_embedding_dim = sum(self.clip_embedding_dims) + # The ctors of the derived classes have already initialized encoder.subj_basis_generator. + # If subj_basis_generator expansion params are specified, they are equally applied to all adaface encoders. + # This self.subj_basis_generator is not meant to be called as self.subj_basis_generator(), but instead, + # it's used as a unified interface to save/load the subj_basis_generator of all adaface encoders. + self.subj_basis_generator = \ + nn.ModuleList( [encoder.subj_basis_generator for encoder \ + in self.id2ada_prompt_encoders] ) + + if adaface_ckpt_paths is not None: + self.load_adaface_ckpt(adaface_ckpt_paths) + + print(f"{self.name} ada prompt encoder initialized with {self.num_sub_encoders} sub-encoders. " + f"ID vecs: {self.num_id_vecs}, static suffix embs: {self.num_static_img_suffix_embs}.") + + if enabled_encoders is not None: + self.are_encoders_enabled = \ + torch.tensor([True if encoder_type in enabled_encoders else False \ + for encoder_type in adaface_encoder_types]) + if not self.are_encoders_enabled.any(): + print(f"All encoders are disabled, which shoudn't happen.") + breakpoint() + if self.are_encoders_enabled.sum() < self.num_sub_encoders: + disabled_encoders = [ encoder_type for i, encoder_type in enumerate(adaface_encoder_types) \ + if not self.are_encoders_enabled[i] ] + print(f"{len(disabled_encoders)} encoders are disabled: {disabled_encoders}.") + else: + self.are_encoders_enabled = \ + torch.tensor([True] * self.num_sub_encoders) + + def load_adaface_ckpt(self, adaface_ckpt_paths): + # If only one adaface ckpt path is provided, then we assume it's the ckpt of the Joint_FaceID2AdaPrompt, + # so we dereference the list to get the actual path and load the subj_basis_generators of all adaface encoders. + if isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)): + if len(adaface_ckpt_paths) == 1 and self.num_sub_encoders > 1: + adaface_ckpt_paths = adaface_ckpt_paths[0] + + if isinstance(adaface_ckpt_paths, str): + # This is only applicable to newest ckpts of Joint_FaceID2AdaPrompt, where + # the ckpt_subj_basis_generator is an nn.ModuleList of multiple subj_basis_generators. 
+ # Therefore, no need to patch missing variables. + ckpt = torch.load(adaface_ckpt_paths, map_location='cpu') + string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"] + if self.subject_string not in string_to_subj_basis_generator_dict: + print(f"Subject '{self.subject_string}' not found in the embedding manager.") + breakpoint() + + ckpt_subj_basis_generators = string_to_subj_basis_generator_dict[self.subject_string] + for i, subj_basis_generator in enumerate(self.subj_basis_generator): + ckpt_subj_basis_generator = ckpt_subj_basis_generators[i] + # Handle differences in num_static_img_suffix_embs between the current model and the ckpt. + ckpt_subj_basis_generator.initialize_static_img_suffix_embs(self.encoders_num_static_img_suffix_embs[i], + img_prompt_dim=self.output_dim) + + subj_basis_generator.extend_prompt2token_proj_attention(\ + ckpt_subj_basis_generator.prompt2token_proj_attention_multipliers, -1, -1, 1, perturb_std=0) + subj_basis_generator.load_state_dict(ckpt_subj_basis_generator.state_dict()) + + # extend_prompt2token_proj_attention_multiplier is an integer >= 1. + # TODO: extend_prompt2token_proj_attention_multiplier should be a list of integers. + # If extend_prompt2token_proj_attention_multiplier > 1, then after loading state_dict, + # extend subj_basis_generator again. + if self.extend_prompt2token_proj_attention_multiplier > 1: + # During this extension, the added noise does change the extra copies of attention weights, since they are not in the ckpt. + # During training, prompt2token_proj_ext_attention_perturb_ratio == 0.1. + # During inference, prompt2token_proj_ext_attention_perturb_ratio == 0. + subj_basis_generator.extend_prompt2token_proj_attention(\ + None, -1, -1, self.extend_prompt2token_proj_attention_multiplier, + perturb_std=self.prompt2token_proj_ext_attention_perturb_ratio) + + subj_basis_generator.freeze_prompt2token_proj() + + print(f"{adaface_ckpt_paths}: {len(self.subj_basis_generator)} subj_basis_generators loaded for {self.name}.") + + elif isinstance(adaface_ckpt_paths, (list, tuple, ListConfig)): + for i, ckpt_path in enumerate(adaface_ckpt_paths): + self.id2ada_prompt_encoders[i].load_adaface_ckpt(ckpt_path) + else: + breakpoint() + + def extract_init_id_embeds_from_images(self, *args, **kwargs): + total_faceless_img_count = 0 + all_id_embs = [] + all_clip_fgbg_features = [] + id_embs_shape = None + clip_fgbg_features_shape = None + # clip_image_encoder should be already put on GPU. + # So its .device is the device of its parameters. + device = self.id2ada_prompt_encoders[0].clip_image_encoder.device + + for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders): + faceless_img_count, id_embs, clip_fgbg_features = \ + id2ada_prompt_encoder.extract_init_id_embeds_from_images(*args, **kwargs) + total_faceless_img_count += faceless_img_count + # id_embs: [BS, 512] or [1, 512] (if calc_avg == True), or None. + # id_embs has the same shape across all id2ada_prompt_encoders. + all_id_embs.append(id_embs) + # clip_fgbg_features: [BS, 514, 1280/1024] or [1, 514, 1280/1024] (if calc_avg == True), or None. + # clip_fgbg_features has the same shape except for the last dimension across all id2ada_prompt_encoders. 
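+            # (514 = 257 fg + 257 bg CLIP tokens; the last dim is 1024 for arc2face and 1280 for
+            # consistentID, which is why the per-encoder features are concatenated along dim=2 below.)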
+ all_clip_fgbg_features.append(clip_fgbg_features) + if id_embs is not None: + id_embs_shape = id_embs.shape + if clip_fgbg_features is not None: + clip_fgbg_features_shape = clip_fgbg_features.shape + + num_extracted_id_embs = 0 + for i in range(len(all_id_embs)): + if all_id_embs[i] is not None: + # As calc_avg is the same for all id2ada_prompt_encoders, + # each id_embs and clip_fgbg_features should have the same shape, if they are not None. + if all_id_embs[i].shape != id_embs_shape: + print("Inconsistent ID embedding shapes.") + breakpoint() + else: + num_extracted_id_embs += 1 + else: + all_id_embs[i] = torch.zeros(id_embs_shape, dtype=torch.float16, device=device) + + clip_fgbg_features_shape2 = torch.Size(clip_fgbg_features_shape[:-1] + (self.clip_embedding_dims[i],)) + if all_clip_fgbg_features[i] is not None: + if all_clip_fgbg_features[i].shape != clip_fgbg_features_shape2: + print("Inconsistent clip features shapes.") + breakpoint() + else: + all_clip_fgbg_features[i] = torch.zeros(clip_fgbg_features_shape2, + dtype=torch.float16, device=device) + + # If at least one face encoder detects faces, then return the embeddings. + # Otherwise return None embeddings. + # It's possible that some face encoders detect faces, while others don't, + # since different face encoders use different face detection models. + if num_extracted_id_embs == 0: + return 0, None, None + + all_id_embs = torch.cat(all_id_embs, dim=1) + # clip_fgbg_features: [BS, 514, 1280] or [BS, 514, 1024]. So we concatenate them along dim=2. + all_clip_fgbg_features = torch.cat(all_clip_fgbg_features, dim=2) + return total_faceless_img_count, all_id_embs, all_clip_fgbg_features + + # init_id_embs, clip_features are never None. + def map_init_id_to_img_prompt_embs(self, init_id_embs, + clip_features=None, + called_for_neg_img_prompt=False): + if init_id_embs is None or clip_features is None: + breakpoint() + + # each id_embs and clip_fgbg_features should have the same shape. + # If some of them were None, they have been replaced by zero embeddings. + all_init_id_embs = init_id_embs.split(self.face_id_dims, dim=1) + all_clip_features = clip_features.split(self.clip_embedding_dims, dim=2) + all_img_prompt_embs = [] + + for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders): + img_prompt_embs = id2ada_prompt_encoder.map_init_id_to_img_prompt_embs( + all_init_id_embs[i], clip_features=all_clip_features[i], + called_for_neg_img_prompt=called_for_neg_img_prompt, + ) + all_img_prompt_embs.append(img_prompt_embs) + + all_img_prompt_embs = torch.cat(all_img_prompt_embs, dim=1) + return all_img_prompt_embs + + # If init_id_embs/pre_clip_features is provided, then use the provided face embeddings. + # Otherwise, if image_paths/image_objs are provided, extract face embeddings from the images. + # Otherwise, we generate random face embeddings [id_batch_size, 512]. + def get_img_prompt_embs(self, init_id_embs, pre_clip_features, *args, **kwargs): + face_image_counts = [] + all_faceid_embeds = [] + all_pos_prompt_embs = [] + all_neg_prompt_embs = [] + faceid_embeds_shape = None + # clip_image_encoder should be already put on GPU. + # So its .device is the device of its parameters. + device = self.id2ada_prompt_encoders[0].clip_image_encoder.device + + # init_id_embs, pre_clip_features could be None. If they are None, + # we split them into individual vectors for each id2ada_prompt_encoder. 
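+        # (More precisely: when they are provided, they are split per sub-encoder below;
+        # when they are None, each sub-encoder simply receives None.)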
+ if init_id_embs is not None: + all_init_id_embs = init_id_embs.split(self.face_id_dims, dim=1) + else: + all_init_id_embs = [None] * self.num_sub_encoders + if pre_clip_features is not None: + all_pre_clip_features = pre_clip_features.split(self.clip_embedding_dims, dim=2) + else: + all_pre_clip_features = [None] * self.num_sub_encoders + + faceid_embeds_shape = None + for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders): + face_image_count, faceid_embeds, pos_prompt_embs, neg_prompt_embs = \ + id2ada_prompt_encoder.get_img_prompt_embs(all_init_id_embs[i], all_pre_clip_features[i], + *args, **kwargs) + face_image_counts.append(face_image_count) + all_faceid_embeds.append(faceid_embeds) + all_pos_prompt_embs.append(pos_prompt_embs) + all_neg_prompt_embs.append(neg_prompt_embs) + # all faceid_embeds have the same shape across all id2ada_prompt_encoders. + # But pos_prompt_embs and neg_prompt_embs may have different number of ID embeddings. + if faceid_embeds is not None: + faceid_embeds_shape = faceid_embeds.shape + + if faceid_embeds_shape is None: + return 0, None, None, None + + # We take the maximum face_image_count among all adaface encoders. + face_image_count = max(face_image_counts) + BS = faceid_embeds.shape[0] + + for i in range(len(all_faceid_embeds)): + if all_faceid_embeds[i] is not None: + if all_faceid_embeds[i].shape != faceid_embeds_shape: + print("Inconsistent face embedding shapes.") + breakpoint() + else: + all_faceid_embeds[i] = torch.zeros(faceid_embeds_shape, dtype=torch.float16, device=device) + + N_ID = self.encoders_num_id_vecs[i] + if all_pos_prompt_embs[i] is None: + # Both pos_prompt_embs and neg_prompt_embs have N_ID == num_id_vecs embeddings. + all_pos_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device) + if all_neg_prompt_embs[i] is None: + all_neg_prompt_embs[i] = torch.zeros((BS, N_ID, 768), dtype=torch.float16, device=device) + + all_faceid_embeds = torch.cat(all_faceid_embeds, dim=1) + all_pos_prompt_embs = torch.cat(all_pos_prompt_embs, dim=1) + all_neg_prompt_embs = torch.cat(all_neg_prompt_embs, dim=1) + + return face_image_count, all_faceid_embeds, all_pos_prompt_embs, all_neg_prompt_embs + + # We don't need to implement get_batched_img_prompt_embs() since the interface + # is fully compatible with FaceID2AdaPrompt.get_batched_img_prompt_embs(). + + def generate_adaface_embeddings(self, image_paths, face_id_embs=None, + img_prompt_embs=None, p_dropout=0, + return_zero_embs_for_dropped_encoders=True, + *args, **kwargs): + # clip_image_encoder should be already put on GPU. + # So its .device is the device of its parameters. + device = self.id2ada_prompt_encoders[0].clip_image_encoder.device + is_emb_averaged = kwargs.get('avg_at_stage', None) is not None + BS = -1 + + if face_id_embs is not None: + BS = face_id_embs.shape[0] + all_face_id_embs = face_id_embs.split(self.face_id_dims, dim=1) + else: + all_face_id_embs = [None] * self.num_sub_encoders + if img_prompt_embs is not None: + BS = img_prompt_embs.shape[0] if BS == -1 else BS + if img_prompt_embs.shape[1] != self.num_id_vecs: + breakpoint() + all_img_prompt_embs = img_prompt_embs.split(self.encoders_num_id_vecs, dim=1) + else: + all_img_prompt_embs = [None] * self.num_sub_encoders + if image_paths is not None: + BS = len(image_paths) if BS == -1 else BS + if BS == -1: + breakpoint() + + # During training, p_dropout is 0.1. During inference, p_dropout is 0. 
+ # When there are two sub-encoders, the prob of one encoder being dropped is + # p_dropout * 2 - p_dropout^2 = 0.18. + if p_dropout > 0: + # self.are_encoders_enabled is a global mask. + # are_encoders_enabled is a local mask for each batch. + are_encoders_enabled = torch.rand(self.num_sub_encoders) < p_dropout + are_encoders_enabled = are_encoders_enabled & self.are_encoders_enabled + # We should at least enable one encoder. + if not are_encoders_enabled.any(): + # Randomly enable an encoder with self.are_encoders_enabled[i] == True. + enabled_indices = torch.nonzero(self.are_encoders_enabled).squeeze(1) + sel_idx = torch.randint(0, len(enabled_indices), (1,)).item() + are_encoders_enabled[enabled_indices[sel_idx]] = True + else: + are_encoders_enabled = self.are_encoders_enabled + + all_adaface_subj_embs = [] + num_available_id_vecs = 0 + + for i, id2ada_prompt_encoder in enumerate(self.id2ada_prompt_encoders): + if not are_encoders_enabled[i]: + adaface_subj_embs = None + print(f"Encoder {id2ada_prompt_encoder.name} is dropped.") + else: + # ddpm.embedding_manager.train() -> id2ada_prompt_encoder.train() -> each sub-enconder's train(). + # -> each sub-enconder's subj_basis_generator.train(). + # Therefore grad for the following call is enabled. + adaface_subj_embs = \ + id2ada_prompt_encoder.generate_adaface_embeddings(image_paths, + all_face_id_embs[i], + all_img_prompt_embs[i], + *args, **kwargs) + + # adaface_subj_embs: [16, 768] or [4, 768]. + N_ID = self.encoders_num_id_vecs[i] + if adaface_subj_embs is None: + if not return_zero_embs_for_dropped_encoders: + continue + else: + subj_emb_shape = (N_ID, 768) if is_emb_averaged else (BS, N_ID, 768) + # adaface_subj_embs is zero-filled. So N_ID is not counted as available subject embeddings. + adaface_subj_embs = torch.zeros(subj_emb_shape, dtype=torch.float16, device=device) + all_adaface_subj_embs.append(adaface_subj_embs) + else: + all_adaface_subj_embs.append(adaface_subj_embs) + num_available_id_vecs += N_ID + + # No faces are found in the images, so return None embeddings. + # We don't want to return an all-zero embedding, which is useless. + if num_available_id_vecs == 0: + return None + + # If id2ada_prompt_encoders are ["arc2face", "consistentID"], then + # during inference, we average across the batch dim. + # all_adaface_subj_embs[0]: [4, 768]. all_adaface_subj_embs[1]: [16, 768]. + # all_adaface_subj_embs: [20, 768]. + # during training, we don't average across the batch dim. + # all_adaface_subj_embs[0]: [BS, 4, 768]. all_adaface_subj_embs[1]: [BS, 16, 768]. + # all_adaface_subj_embs: [BS, 20, 768]. + all_adaface_subj_embs = torch.cat(all_adaface_subj_embs, dim=-2) + return all_adaface_subj_embs + + +''' +# For ip-adapter distillation on objects. Strictly speaking, it's not face-to-image prompts, but +# CLIP/DINO visual features to image prompts. 
+class Objects_Vis2ImgPrompt(nn.Module): + def __init__(self): + self.dino_encoder = ViTModel.from_pretrained('facebook/dino-vits16') + self.dino_encoder.eval() + self.dino_encoder.half() + self.dino_preprocess = ViTFeatureExtractor.from_pretrained('facebook/dino-vits16') + print(f'DINO encoder loaded.') + +''' diff --git a/adaface/subj_basis_generator.py b/adaface/subj_basis_generator.py index bbd55f1ea2ec8e7630914249f31ea0a17c115e77..d8944461435e46d03a953078b85430fdaeafb87e 100644 --- a/adaface/subj_basis_generator.py +++ b/adaface/subj_basis_generator.py @@ -7,27 +7,20 @@ import math import torch from torch import nn -import torch.nn.functional as F from einops import rearrange from einops.layers.torch import Rearrange -from transformers import CLIPVisionModel, CLIPTokenizer +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig -import numpy as np from torch import einsum -from dataclasses import dataclass -from typing import Optional, Tuple -from transformers.utils import ModelOutput -from adaface.util import arc2face_inverse_face_prompt_embs, gen_gradient_scaler +from adaface.util import gen_gradient_scaler from adaface.arc2face_models import CLIPTextModelWrapper -import sys -sys.modules['ldm'] = sys.modules['adaface'] def reshape_tensor(x, num_heads): bs, length, width = x.shape # (bs, length, width) --> (bs, length, n_heads, dim_per_head) x = x.view(bs, length, num_heads, -1) # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) + x = x.transpose(1, 2).contiguous() # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) x = x.reshape(bs, num_heads, length, -1) return x @@ -314,13 +307,13 @@ class CrossAttention(nn.Module): if self.q_aware_to_v: # v: [6, 64, 17, 128]. # v is query-specific, so there's an extra dim for the query. - v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h) + v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h).contiguous() # Each v is for a query group with 512/64 = 8 queries. # So each v is repeated 8 times to match the number of queries. # v: [6, 64, 17, 128] -> [6, 512, 17, 128]. v = v.repeat(1, self.v_repeat, 1, 1) else: - v = rearrange(v, 'b n (h d) -> (b h) n d', h=h) + v = rearrange(v, 'b n (h d) -> (b h) n d', h=h).contiguous() if attn_mat is None: scale = q.size(-1) ** -0.25 @@ -344,7 +337,7 @@ class CrossAttention(nn.Module): out = einsum('b i j, b j d -> b i d', attn, v) # [6, 32, 128] -> [1, 32, 768]. - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h).contiguous() if self.out_has_skip: out = self.to_out(out) + out @@ -356,91 +349,290 @@ class CrossAttention(nn.Module): else: return out -class SubjBasisGenerator(nn.Module): +class ImgPrompt2TextPrompt(nn.Module): + def __init__(self, placeholder_is_bg, num_id_vecs, dtype=torch.float32, *args, **kwargs): + super().__init__() + self.N_ID = num_id_vecs + # If not placeholder_is_bg, then N_SFX will be updated in initialize_text_components(). + self.N_SFX = 0 + + if not placeholder_is_bg: + self.initialize_text_components(*args, **kwargs) + + # prompt2token_proj: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**. + # prompt2token_proj is with the same architecture as the original arc2face text encoder, + # but retrained to do inverse mapping. + # To be initialized in the subclass. 
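+        # Overall data flow (as used by SubjBasisGenerator below): ID-to-image-prompt embeddings
+        # [BS, num_id_vecs, 768] in the image prompt space are mapped by inverse_img_prompt_embs()
+        # through prompt2token_proj into the text prompt space, returning either the full [BS, 77, 768]
+        # sequence or only the core [BS, num_id_vecs, 768] identity embeddings.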
+ self.prompt2token_proj = None + self.dtype = dtype + + def initialize_static_img_suffix_embs(self, num_static_img_suffix_embs, img_prompt_dim=768): + self.N_SFX = num_static_img_suffix_embs + # We always take the first num_static_img_suffix_embs embeddings out of static_img_suffix_embs. + # So it's OK that static_img_suffix_embs is larger than required number num_static_img_suffix_embs. + # This holds even if num_static_img_suffix_embs is 0. + if hasattr(self, 'static_img_suffix_embs') and self.static_img_suffix_embs is not None: + if self.static_img_suffix_embs.shape[1] == self.N_SFX: + print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs ({self.N_SFX} required). Skip initialization.") + elif self.static_img_suffix_embs.shape[1] < self.N_SFX: + print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (< {self.N_SFX} required). Reinitialize.") + self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim)) + elif self.N_SFX > 0: + # self.static_img_suffix_embs.shape[1] > self.N_SFX > 0. + print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (> {self.N_SFX} required). Truncate.") + self.static_img_suffix_embs = nn.Parameter(self.static_img_suffix_embs[:, :self.N_SFX]) + else: + # self.static_img_suffix_embs.shape[1] > self.N_SFX == 0. + print(f"static_img_suffix_embs had been initialized to be {self.static_img_suffix_embs.shape[1]} vecs (0 required). Erase.") + self.static_img_suffix_embs = None + else: + if self.N_SFX > 0: + # Either static_img_suffix_embs does not exist or is None, + # or it's initialized but has fewer than num_static_img_suffix_embs embeddings (this situation should be very rare, + # so we don't consider to reuse and extend a shorter static_img_suffix_embs). + # So we reinitialize it. + self.static_img_suffix_embs = nn.Parameter(torch.randn(1, self.N_SFX, img_prompt_dim)) + else: + # If static_img_suffix_embs had been initialized, then it will be set to None, i.e., erased from the SubjBasisGenerator instance. + self.static_img_suffix_embs = None + + # Implement a separate initialization function, so that it can be called from SubjBasisGenerator + # after the SubjBasisGenerator is initialized. This can be used to fix old SubjBasisGenerator + # ckpts which were not subclassed from ImgPrompt2TextPrompt. + def initialize_text_components(self, max_prompt_length=77, num_id_vecs=16, + num_static_img_suffix_embs=0, img_prompt_dim=768): + self.initialize_static_img_suffix_embs(num_static_img_suffix_embs, img_prompt_dim) + self.max_prompt_length = max_prompt_length + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + # clip_text_embeddings: CLIPTextEmbeddings instance. + clip_text_embeddings = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").text_model.embeddings + # clip_text_embeddings() and clip_text_embeddings.token_embedding() differ in that + # clip_text_embeddings() adds positional embeddings, while clip_text_embeddings.token_embedding() doesn't. + # Adding positional embeddings seems to help somewhat. + # pad_tokens: pad_token_id 49407 repeated 77 times. + # pad_token_id is the EOS token. But BOS is 49406. + pad_tokens = torch.tensor([self.tokenizer.pad_token_id]).repeat(self.max_prompt_length) + # pad_embeddings: [77, 768]. + # pad_embeddings is still on CPU. But should be moved to GPU automatically. 
+ # Note: detach pad_embeddings from the computation graph, otherwise + # deepcopy() in embedding_manager.py:make_frozen_copy_of_subj_basis_generators() will fail. + self.pad_embeddings = clip_text_embeddings(pad_tokens)[0].detach() + + # image prompt space -> text prompt space. + # return_emb_types: a list of strings, each string is among + # ['full', 'core', 'full_pad', 'full_half_pad']. + def inverse_img_prompt_embs(self, face_prompt_embs, list_extra_words, + return_emb_types, hidden_state_layer_weights=None, + enable_static_img_suffix_embs=False): + + ''' + face_prompt_embs: (BS, self.N_ID, 768), in the image prompt space. + Only the core embeddings, no paddings. + list_extra_words: None or [s_1, ..., s_BS], each s_i is a list of extra words to be added to the prompt. + ''' + if list_extra_words is not None: + if len(list_extra_words) != len(face_prompt_embs): + if len(face_prompt_embs) > 1: + print("Warn: list_extra_words has different length as face_prompt_embs.") + if len(list_extra_words) == 1: + list_extra_words = list_extra_words * len(face_prompt_embs) + else: + breakpoint() + else: + # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_comp_prompt_distillation. + # But list_extra_words always corresponds to the actual batch size. So we only take the first element. + list_extra_words = list_extra_words[:1] + + for extra_words in list_extra_words: + assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words." + # 16 or 4 ", " are placeholders for face_prompt_embs. + prompt_templates = [ "photo of a " + ", " * self.N_ID + list_extra_words[i] for i in range(len(list_extra_words)) ] + else: + # 16 or 4 ", " are placeholders for face_prompt_embs. + # No extra words are added to the prompt. So we add 2 more ", " to the template to keep + # the number of tokens roughly the same as when extra words are added. + prompt_templates = [ "photo of a " + ", " * (self.N_ID + 2) for _ in range(len(face_prompt_embs)) ] + + # This step should be quite fast, and there's no need to cache the input_ids. + # input_ids: [BS, 77]. + input_ids = self.tokenizer( + prompt_templates, + truncation=True, + padding="max_length", + max_length=self.max_prompt_length, + return_tensors="pt", + ).input_ids.to(face_prompt_embs.device) + + face_prompt_embs_orig_dtype = face_prompt_embs.dtype + face_prompt_embs = face_prompt_embs.to(self.dtype) + + ID_END = 4 + self.N_ID + PAD_BEGIN = ID_END + self.N_SFX + 2 + + # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping). + token_embs = self.prompt2token_proj(input_ids=input_ids, return_token_embs=True) + # token 4: first ", " in the template prompt. + # Replace embeddings of 16 or 4 placeholder ", " with face_prompt_embs. + token_embs[:, 4:ID_END] = face_prompt_embs + # Only when do_unet_distill == True, we append the static image suffix embeddings. + # Otherwise, static image suffix embeddings are ignored, + # and token_embs[:, ID_END:ID_END+self.N_SFX] are the filler embeddings of the + # extra ", " in the template prompt. + if enable_static_img_suffix_embs and self.N_SFX > 0: + # Put the static image suffix embeddings right after face_prompt_embs. + token_embs[:, ID_END:ID_END+self.N_SFX] = self.static_img_suffix_embs[:, :self.N_SFX] + + # This call does the ordinary CLIP text encoding pass. 
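+        # hidden_state_layer_weights (when not None) presumably lets prompt2token_proj blend several of
+        # its last hidden layers (the 'per-layer' learnable weights owned by SubjBasisGenerator) rather
+        # than returning only the final layer's output.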
+ prompt_embeds = self.prompt2token_proj( + input_ids=input_ids, + input_token_embs=token_embs, + hidden_state_layer_weights=hidden_state_layer_weights, + return_token_embs=False + )[0] + + # Restore the original dtype of prompt_embeds: float16 -> float32. + prompt_embeds = prompt_embeds.to(face_prompt_embs_orig_dtype) + # token 4: first ", " in the template prompt. + # When N_ID == 16, + # prompt_embeds 4:20 are the most important 16 embeddings that contain the subject's identity. + # 20:22 are embeddings of the (at most) two extra words. + # [N, 77, 768] -> [N, 16, 768] + if enable_static_img_suffix_embs: + core_prompt_embs = prompt_embeds[:, 4:ID_END+self.N_SFX] + else: + core_prompt_embs = prompt_embeds[:, 4:ID_END] + + if list_extra_words is not None: + # [N, 16, 768] -> [N, 18, 768] + extra_words_embs = prompt_embeds[:, ID_END+self.N_SFX:PAD_BEGIN] + core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1) + + returned_prompt_embs = [] + for emb_type in return_emb_types: + if emb_type == 'full': + returned_prompt_embs.append(prompt_embeds) + elif emb_type == 'full_half_pad': + prompt_embeds2 = prompt_embeds.clone() + # PAD_BEGIN is 22 or 10. Also exclude the last EOS token. + # So we subtract max_prompt_length by (PAD_BEGIN + 1). + PADS = self.max_prompt_length - PAD_BEGIN - 1 + if PADS >= 2: + # Fill half of the remaining embeddings with pad embeddings. + prompt_embeds2[:, PAD_BEGIN:PAD_BEGIN+PADS//2] = self.pad_embeddings[PAD_BEGIN:PAD_BEGIN+PADS//2] + returned_prompt_embs.append(prompt_embeds2) + elif emb_type == 'full_pad': + prompt_embeds2 = prompt_embeds.clone() + # Replace the PAD_BEGIN-th to the second last embeddings with pad embeddings. + # Skip replacing the last embedding, which might has special roles. + # (Although all padding tokens are the same EOS, the last token might acquire special semantics + # due to its special position.) + prompt_embeds2[:, PAD_BEGIN:-1] = self.pad_embeddings[PAD_BEGIN:-1] + returned_prompt_embs.append(prompt_embeds2) + elif emb_type == 'full_zeroed_extra': + prompt_embeds2 = prompt_embeds.clone() + # Only add two pad embeddings. The remaining embeddings are set to 0. + # Make the positional embeddings align with the actual positions. + prompt_embeds2[:, 22:24] = self.pad_embeddings[22:24] + prompt_embeds2[:, 24:-1] = 0 + returned_prompt_embs.append(prompt_embeds2) + elif emb_type == 'core': + returned_prompt_embs.append(core_prompt_embs) + else: + breakpoint() + + return returned_prompt_embs + +class SubjBasisGenerator(ImgPrompt2TextPrompt): def __init__( self, - # number of cross-attention heads. Half of the number of heads 12 of OpenAI clip-vit-large-patch14: + # number of cross-attention heads of the bg prompt translator. + # Taken as a half of the number of heads 12 of OpenAI clip-vit-large-patch14: # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json - num_heads=6, - num_id_vecs={ 'subj': 77, 'bg': 257 }, # number of identity vectors. 18: 16 face tokens + 2 extra tokens. 257: 257 CLIP tokens. - num_out_embs_per_layer=4, # num_out_embs. subj: 16. bg: 4. - num_out_layers=16, # number of layers of output embeddings. - image_embedding_dim=768, # CLIP image feature dimension, as per config.json above. - # DINO vits16 has 6 attention heads: - # https://huggingface.co/facebook/dino-vits16/blob/main/config.json - dino_embedding_dim=384, # DINO object feature dimension for objects. - output_dim=768, # CLIP text embedding input dimension. 
- placeholder_is_bg: bool = False, # Whether the placeholder is for the image background. + num_bg_encoder_heads=6, + # number of subject input identity vectors (only when the subject is not face), + # or number of background input identity vectors (no matter the subject is face or not). + # 257: 257 CLIP tokens. + num_nonface_in_id_vecs={ 'subj': 77, 'bg': 257 }, + num_id_vecs=16, # num_id_vecs: subj: 16. bg: 4. + num_static_img_suffix_embs: int = 0, # Number of extra static learnable image embeddings appended to translated ID embeddings. + bg_image_embedding_dim=1024, # CLIP image hidden layer feature dimension, as per config.json above. + obj_embedding_dim=384, # DINO object feature dimension for objects. + output_dim=768, # CLIP text embedding input dimension. + placeholder_is_bg: bool = False, # Whether the placeholder is for the image background tokens. prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj. - zs_extra_words_scale: float = 0.5, # Scale for extra words in the prompt2token_proj. - learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer. - bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection. + learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer. + bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection. ): - super().__init__() - self.placeholder_is_bg = placeholder_is_bg - self.num_out_layers = num_out_layers - self.num_out_embs_per_layer = num_out_embs_per_layer - # subj: 64, bg: 32. - self.num_out_embs = num_out_layers * num_out_embs_per_layer - self.output_dim = output_dim - # num_id_vecs should be the number of core ID embs, 16. + # If not placeholder_is_bg, then it calls initialize_text_components() in the superclass. + super().__init__(placeholder_is_bg=placeholder_is_bg, num_id_vecs=num_id_vecs, max_prompt_length=77, + num_static_img_suffix_embs=num_static_img_suffix_embs, img_prompt_dim=output_dim) + + self.placeholder_is_bg = placeholder_is_bg + self.num_out_embs = self.N_ID + self.N_SFX + self.output_dim = output_dim + # num_nonface_in_id_vecs should be the number of core ID embs, 16. # However, in such case, pos_embs is not used. So it doesn't matter if it's wrongly set. - self.num_id_vecs = num_id_vecs['bg'] if placeholder_is_bg else num_id_vecs['subj'] - self.pos_embs = nn.Parameter(torch.randn(1, self.num_id_vecs, output_dim)) - self.pos_embs_ln = nn.LayerNorm(output_dim) - self.zs_extra_words_scale = zs_extra_words_scale - self.output_scale = output_dim ** -0.5 - self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.num_nonface_in_id_vecs = num_nonface_in_id_vecs['bg'] if placeholder_is_bg else num_nonface_in_id_vecs['subj'] + self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj if not self.placeholder_is_bg: # [1, 384] -> [1, 16, 768]. # TODO: use CLIPTextModelWrapper as obj_proj_in. - self.obj_proj_in = ExpandEmbs(dino_embedding_dim, output_dim, expansion_ratio=self.num_id_vecs) + self.obj_proj_in = ExpandEmbs(obj_embedding_dim, output_dim, expansion_ratio=self.num_nonface_in_id_vecs) - # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings). + # ** prompt2token_proj does the actual job: ** + # it is the inverse projection that maps from faceid2img_prompt_embs to adaface_prompt_embs. + # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings) or [1, 16, 768] (without paddings). 
# If self.placeholder_is_bg: prompt2token_proj is set to None. - self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14') - self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale - self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale) + # Use an attention dropout of 0.2 to increase robustness. + clip_dropout_config = None #CLIPTextConfig.from_pretrained('openai/clip-vit-large-patch14', attention_dropout=0.05, dropout=0.05) + self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14', + config=clip_dropout_config) + self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale + self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale) print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.") - # Freeze prompt2token_proj if prompt2token_proj_grad_scale is 0. - # Set requires_grad to False for all parameters in prompt2token_proj, to save memory taken by the optimizer. - if prompt2token_proj_grad_scale == 0: - self.freeze_prompt2token_proj() + # If prompt2token_proj_grad_scale is 0, freeze all params in prompt2token_proj. + # Otherwise, only freeze token and positional embeddings of the original CLIPTextModel. + self.freeze_prompt2token_proj() - self.prompt2token_proj_attention_multiplier = -1 + # These multipliers are relative to the original CLIPTextModel. + self.prompt2token_proj_attention_multipliers = [1] * 12 self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu') - self.pad_embeddings = None self.bg_proj_in = None + self.pos_embs = self.pos_embs_ln = self.latent_queries = self.latent_queries_ln = None else: # For background placeholders, face and object embeddings are not used as they are foreground. self.obj_proj_in = None - self.prompt2token_proj = None - print("Bg prompt2token_proj is set to None.") self.bg_proj_in = nn.Sequential( - nn.Linear(image_embedding_dim, output_dim, bias=False), + nn.Linear(bg_image_embedding_dim, output_dim, bias=False), nn.LayerNorm(output_dim), ) + self.pos_embs = nn.Parameter(torch.zeros(1, self.num_nonface_in_id_vecs, output_dim)) + self.pos_embs_ln = nn.LayerNorm(output_dim) self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim)) self.latent_queries_ln = nn.LayerNorm(output_dim) - self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj identity_to_v = False v_has_skip = not identity_to_v # True identity_to_out = not bg_prompt_translator_has_to_out_proj # True out_has_skip = not identity_to_out # False + # prompt_translator maps the clip image features (of the background) to the prompt embedding space. + # It is only used during training when placeholder_is_bg is True. # prompt_translator has a to_v projection with skip connection, and doesn't have a to_out projection. - # dim=768, num_heads=6. + # dim=768, num_bg_encoder_heads=6. self.prompt_translator = \ - CrossAttention(input_dim=output_dim, num_heads=num_heads, p_dropout=0.05, - identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v, - q_aware_to_v=False, v_has_skip=v_has_skip, - num_q=0, # When not q_aware_to_v, num_q is not referenced. 
- identity_to_out=identity_to_out, - out_has_skip=out_has_skip) + CrossAttention(input_dim=output_dim, num_heads=num_bg_encoder_heads, p_dropout=0.05, + identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v, + q_aware_to_v=False, v_has_skip=v_has_skip, + num_q=0, # When not q_aware_to_v, num_q is not referenced. + identity_to_out=identity_to_out, + out_has_skip=out_has_skip) + + self.output_scale = output_dim ** -0.5 + ''' prompt_translator: CLIPEncoder # https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566 @@ -464,70 +656,51 @@ class SubjBasisGenerator(nn.Module): print(repr(self)) - # raw_id_embs: ArcFace embeddings for faces (not used since we have arc2face_id_embs), - # or DINO embeddings for objects. - # arc2face_id_embs: [BS, 16, 768], the core identity embeddings generated by Arc2Face. - def forward(self, arc2face_id_embs, clip_features=None, raw_id_embs=None, out_id_embs_scale=1.0, - is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad'): + # raw_id_embs: only used when the subject is non-faces. In that case it's DINO embeddings. + # Otherwise, raw_id_embs is not used. + # faceid2img_prompt_embs: [BS, 16, 768], the core ID prompt embeddings generated by ID2ImgPrompt. + def forward(self, faceid2img_prompt_embs, clip_features=None, raw_id_embs=None, out_id_embs_cfg_scale=1.0, + is_face=True, enable_static_img_suffix_embs=False): if not self.placeholder_is_bg: - BS = arc2face_id_embs.shape[0] + BS = faceid2img_prompt_embs.shape[0] else: - # If bg, then arc2face_id_embs is set to None, but clip_features is not None. + # If bg, then faceid2img_prompt_embs is set to None, but clip_features is not None. BS = clip_features.shape[0] - - adaface_prompt_embs = None - if not hasattr(self, 'clip_tokenizer'): - self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - + clip_features = clip_features.to(self.dtype) + # No need to use raw_id_embs if placeholder_is_bg. if not self.placeholder_is_bg: if is_face: - assert arc2face_id_embs is not None - # arc2face_embs has been projected to the (modified) prompt embedding space - # by arc2face_forward_face_embs. This prompt embedding space is modified because Arc2Face finetuned - # the text encoder and the U-Net. + assert faceid2img_prompt_embs is not None + # id2img_embs has been projected to the (modified) prompt embedding space + # by ID2AdaPrompt::map_init_id_to_img_prompt_embs(). This prompt embedding space is modified because + # the ID2ImgPrompt module (at least when it's arc2face) may have finetuned the + # text encoder and the U-Net. # in embedding_manager: [BS, 16, 768] -> [BS, 77, 768]. - # arc2face_id_embs is part of arc2face_embs: [BS, 77, 768] -> [BS, 16, 768]. + # faceid2img_prompt_embs is part of id2img_embs: [BS, 77, 768] -> [BS, 16, 768]. # adaface_prompt_embs is projected to the prompt embedding spaces. This is the # original U-Net prompt embedding space. # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]] hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights) - # return_emb_types: a list of strings, each string is among - # ['full', 'core', 'full_pad', 'full_half_pad', 'full_zeroed_extra', 'b_core_e']. - # Using b_core_e is more computationally efficient than using full_zeroed_extra. - # But there is an unknow BUG that causes crash when using b_core_e. 
- if is_training: - return_emb_types = ['full_pad', 'core'] - else: - # adaface_prompt_embs_inf_type: default is full_half_pad, same as training. - return_emb_types = [adaface_prompt_embs_inf_type, 'core'] - - if self.pad_embeddings is None: - self.generate_pad_embeddings() - else: - self.pad_embeddings = self.pad_embeddings.to(arc2face_id_embs.device) + # faceid2img_prompt_embs -> ada_id_embs: image prompt space -> text prompt space. with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0): - # If list_extra_words is not None, then core_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens - # and (at most) two extra words in full_prompt_embs, without BOS and EOS. - # If list_extra_words is None, then core_id_embs: [BS, 16, 768], the 16 identity tokens in full_prompt_embs. + # If list_extra_words is not None, then ada_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens + # and (at most) two extra words in adaface_prompt_embs, without BOS and EOS. + # If list_extra_words is None, then ada_id_embs: [BS, 16, 768], the 16 identity tokens in adaface_prompt_embs. # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]] - # zs_extra_words_scale is only effective when list_extra_words is not None. - # adaface_prompt_embs: [BS, 77, 768], core_id_embs: [BS, 16, 768]. - adaface_prompt_embs, core_id_embs = \ - arc2face_inverse_face_prompt_embs(self.clip_tokenizer, - self.prompt2token_proj, - arc2face_id_embs, - list_extra_words=None, - return_emb_types=return_emb_types, - pad_embeddings=self.pad_embeddings, - hidden_state_layer_weights=hidden_state_layer_weights, - input_max_length=77, zs_extra_words_scale=self.zs_extra_words_scale) - # Reduce the update rate to prompt2token_proj. - adaface_prompt_embs = self.prompt2token_proj_grad_scaler(adaface_prompt_embs) - core_id_embs = self.prompt2token_proj_grad_scaler(core_id_embs) + # ada_id_embs: [BS, 16, 768]. + # return_emb_types: a list of strings, each string is among + # ['full', 'core', 'full_pad', 'full_half_pad']. + ada_id_embs, = \ + self.inverse_img_prompt_embs(faceid2img_prompt_embs, + list_extra_words=None, + return_emb_types=['core'], + hidden_state_layer_weights=hidden_state_layer_weights, + enable_static_img_suffix_embs=enable_static_img_suffix_embs) + ada_id_embs = self.prompt2token_proj_grad_scaler(ada_id_embs) elif raw_id_embs is not None: # id_embs: [BS, 384] -> [BS, 18, 768]. # obj_proj_in is expected to project the DINO object features to @@ -550,21 +723,21 @@ class SubjBasisGenerator(nn.Module): # to bg prompt embeddings. with torch.set_grad_enabled(self.training): id_embs_out = self.prompt_translator(latent_queries, id_embs) - # [BS, 64, 768] -> [BS, 16, 4, 768] - id_embs_out = id_embs_out.reshape(BS, self.num_out_layers, -1, self.output_dim) - adaface_subj_embs = id_embs_out * self.output_scale # * 0.036 + + adaface_out_embs = id_embs_out * self.output_scale # * 0.036 else: - # adaface_subj_embs: [BS, 16, 768] -> [BS, 1, 16, 768] -> [BS, 16, 16, 768] - adaface_subj_embs = core_id_embs.unsqueeze(1).repeat(1, self.num_out_layers, 1, 1) - - # If out_id_embs_scale < 1, adaface_subj_embs is a mix of adaface_subj_embs and pad_embeddings. - if out_id_embs_scale != 1: - # pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768]. 
- pad_embeddings = self.pad_embeddings[4:4+self.num_out_embs_per_layer].unsqueeze(0).unsqueeze(0) - adaface_subj_embs = adaface_subj_embs * out_id_embs_scale \ - + pad_embeddings * (1 - out_id_embs_scale) - - return adaface_subj_embs, adaface_prompt_embs + adaface_out_embs = ada_id_embs + # If out_id_embs_cfg_scale < 1, adaface_out_embs is a mix of adaface_out_embs and pad_embeddings. + if out_id_embs_cfg_scale != 1: + # pad_embeddings: [77, 768] -> [16, 768] -> [1, 16, 768]. + # NOTE: Never do cfg on static image suffix embeddings. + # So we take self.N_ID embeddings, instead of self.N_ID + self.N_SFX, + # even if enable_static_img_suffix_embs=True. + pad_embeddings = self.pad_embeddings[4:4+self.N_ID].unsqueeze(0).to(ada_id_embs.device) + adaface_out_embs[:, :self.N_ID] = ada_id_embs[:, :self.N_ID] * out_id_embs_cfg_scale \ + + pad_embeddings * (1 - out_id_embs_cfg_scale) + + return adaface_out_embs def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device): if learnable_hidden_state_weights_scheme == 'none': @@ -579,180 +752,117 @@ class SubjBasisGenerator(nn.Module): # hidden_state_layer_weights: [3, 1]. self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device), requires_grad=True) + # A gradient scaler of 5 makes the gradients on hidden_state_layer_weights 5 times larger. self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5) print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.") else: breakpoint() - def generate_pad_embeddings(self): - # clip_embeddings: CLIPTextEmbeddings instance. pad_embeddings is generated after - # prompt2token_proj is loaded from the finetuned weight. It seems such pad embeddings perform - # slightly better than the original pad embeddings. - clip_embeddings = self.prompt2token_proj.text_model.embeddings - # clip_embeddings() and clip_embeddings.token_embedding() differ in that - # clip_embeddings() adds positional embeddings, while clip_embeddings.token_embedding() doesn't. - # Adding positional embeddings seems to help somewhat. - # pad_tokens: pad_token_id 49407 repeated 77 times. - # pad_token_id is the EOS token. But BOS is 49406. - pad_tokens = torch.tensor([self.clip_tokenizer.pad_token_id]).to(clip_embeddings.token_embedding.weight.device).repeat(77) - # pad_embeddings: [77, 768]. - pad_embeddings = clip_embeddings(pad_tokens)[0] - # We don't allow face recon to influence the pad embeddings. - # Otherwise, face identity will leak into the pad embeddings. - self.pad_embeddings = pad_embeddings.detach() - - def extend_prompt2token_proj_attention(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1): - if multiplier > 1: - num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(begin_layer_idx, end_layer_idx, multiplier, noise_std) - self.prompt2token_proj_attention_multiplier = multiplier - print(f"{num_extended_layers} layers in prompt2token_proj_attention are x{multiplier}") - + def extend_prompt2token_proj_attention(self, prompt2token_proj_attention_multipliers=None, + begin_layer_idx=-1, end_layer_idx=-1, multiplier=1, perturb_std=0.1): + if begin_layer_idx == -1: + begin_layer_idx = 0 + if end_layer_idx == -1: + end_layer_idx = 11 + + if prompt2token_proj_attention_multipliers is None and multiplier == 1: + print("prompt2token_proj_attention_multipliers are all 1. 
No extension is done.") + return + + elif prompt2token_proj_attention_multipliers is None: + # prompt2token_proj_attention_multipliers are relative to the current prompt2token_proj. + prompt2token_proj_attention_multipliers = [1] * 12 + for i in range(begin_layer_idx, end_layer_idx+1): + prompt2token_proj_attention_multipliers[i] = multiplier + # Otherwise, use the given prompt2token_proj_attention_multipliers. + + num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(prompt2token_proj_attention_multipliers, perturb_std) + # Update prompt2token_proj_attention_multipliers (relative to the original CLIPTextModel). + for i in range(begin_layer_idx, end_layer_idx+1): + self.prompt2token_proj_attention_multipliers[i] *= prompt2token_proj_attention_multipliers[i] + + print(f"{num_extended_layers} layers in prompt2token_proj_attention are extended by {prompt2token_proj_attention_multipliers}") + return num_extended_layers + + def squeeze_prompt2token_proj_attention(self, prompt2token_proj_attention_divisors=None, + begin_layer_idx=-1, end_layer_idx=-1, divisor=1): + if begin_layer_idx == -1: + begin_layer_idx = 0 + if end_layer_idx == -1: + end_layer_idx = 11 + + if prompt2token_proj_attention_divisors is None and divisor == 1: + print("prompt2token_proj_attention_divisors are all 1. No squeezing is done.") + return + elif prompt2token_proj_attention_divisors is None: + prompt2token_proj_attention_divisors = [1] * 12 + for i in range(begin_layer_idx, end_layer_idx+1): + prompt2token_proj_attention_divisors[i] = divisor + # Otherwise, use the given prompt2token_proj_attention_divisors. + + num_squeezed_layers = self.prompt2token_proj.squeeze_clip_attention_MKV_divisor(prompt2token_proj_attention_divisors) + # Update prompt2token_proj_attention_multipliers (relative to the original CLIPTextModel). + for i in range(begin_layer_idx, end_layer_idx+1): + self.prompt2token_proj_attention_multipliers[i] //= prompt2token_proj_attention_divisors[i] + + print(f"{num_squeezed_layers} layers in prompt2token_proj_attention are squeezed by {prompt2token_proj_attention_divisors}") + return num_squeezed_layers + def freeze_prompt2token_proj(self): + # Only applicable to fg basis generator. + if self.placeholder_is_bg: + return # If bg, then prompt2token_proj is set to None. Therefore no need to freeze it. # Then we don't have to check whether it's for subj or bg. + if self.prompt2token_proj_grad_scale == 0: + frozen_components_name = 'all' + frozen_param_set = self.prompt2token_proj.named_parameters() + else: + frozen_components_name = 'token_pos_embeddings' + frozen_param_set = self.prompt2token_proj.text_model.embeddings.named_parameters() + if self.prompt2token_proj is not None: frozen_param_names = [] - for param_name, param in self.prompt2token_proj.named_parameters(): + for param_name, param in frozen_param_set: if param.requires_grad: param.requires_grad = False frozen_param_names.append(param_name) # If param is already frozen, then no need to freeze it again. - print(f"{len(frozen_param_names)} params in Subj prompt2token_proj is frozen.") + print(f"{frozen_components_name} {len(frozen_param_names)} params in Subj prompt2token_proj is frozen.") #print(f"Frozen parameters:\n{frozen_param_names}") - def __repr__(self): - type_sig = 'subj' if not self.placeholder_is_bg else 'bg' + def patch_old_subj_basis_generator_ckpt(self): # Fix compatability with the previous version. 
if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'): self.bg_prompt_translator_has_to_out_proj = False if not hasattr(self, 'num_out_embs'): self.num_out_embs = -1 - return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \ - f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}" - -@dataclass -class BaseModelOutputWithPooling2(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) after further processing - through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns - the classification token after processing through a linear layer and a tanh activation function. The linear - layer weights are trained from the next sentence prediction (classification) objective during pretraining. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - attn_mask: Optional[torch.FloatTensor] = None - -# Revised from CLIPVisionTransformer to support attention mask. -# self: a CLIPVisionTransformer instance. -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821 -# pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224] -# attn_mask: B*H*W attention mask. -def CLIPVisionTransformer_forward(self, pixel_values = None, attn_mask=None, - output_attentions = None, - output_hidden_states = None, return_dict = None): - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - # Visual tokens are flattended in embeddings(). - # self.embeddings: CLIPVisionEmbeddings. - # hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds). - # 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False). 
- hidden_states = self.embeddings(pixel_values) - hidden_states = self.pre_layrnorm(hidden_states) - - if attn_mask is not None: - # feat_edge_size: 16. - feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int) - # attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16]. - attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest') - # Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256]. - attn_mask = attn_mask.flatten(2) - # Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257]. - # This 1 corresponds to class_embeds, which is always attended to. - attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1) - attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1) + if hasattr(self, 'num_id_vecs') and not hasattr(self, 'N_ID'): + self.N_ID = self.num_id_vecs + if not hasattr(self, 'num_nonface_in_id_vecs') and hasattr(self, 'N_ID'): + self.num_nonface_in_id_vecs = self.N_ID + if not hasattr(self, 'dtype'): + self.dtype = torch.float32 + + if self.placeholder_is_bg: + if not hasattr(self, 'pos_embs') or self.pos_embs is None: + self.pos_embs = nn.Parameter(torch.zeros(1, self.num_nonface_in_id_vecs, self.output_dim)) + if not hasattr(self, 'latent_queries') or self.latent_queries is None: + self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, self.output_dim)) + # Background encoder doesn't require initializing text components. else: - attn_mask_pairs = None - - # encoder: CLIPEncoder. - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - # New feature: (***The official documentation is wrong***) - # attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*): - # Mask to avoid performing attention on pairs of token. Mask values selected in `[0, 1]`: - # - 1 for pairs that are **not masked**, - # - 0 for pairs that are **masked**. - # attention_mask is eventually used by CLIPEncoderLayer: - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370 - attention_mask=attn_mask_pairs, - output_attentions=output_attentions, # False - output_hidden_states=output_hidden_states, # True - return_dict=return_dict, # True - ) + self.initialize_hidden_state_layer_weights('per-layer', 'cpu') + if not hasattr(self, 'prompt2token_proj_attention_multipliers'): + # Please manually set prompt2token_proj_attention_multipliers in the ckpt. + breakpoint() - # last_hidden_state: [BS, 257, 1280] - last_hidden_state = encoder_outputs[0] - pooled_output = last_hidden_state[:, 0, :] - pooled_output = self.post_layernorm(pooled_output) - - # return_dict is True. - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPooling2( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - # Newly added: return resized flattened attention mask. 
- # [BS, 1, 257] -> [BS, 257, 1] - attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None - ) + self.initialize_text_components(max_prompt_length=77, num_id_vecs=self.N_ID, + num_static_img_suffix_embs=self.N_SFX, + img_prompt_dim=self.output_dim) + def __repr__(self): + type_sig = 'subj' if not self.placeholder_is_bg else 'bg' -class CLIPVisionModelWithMask(CLIPVisionModel): - def __init__(self, config): - super().__init__(config) - # Replace vision_model.forward() with the new one that supports mask. - self.vision_model.forward = CLIPVisionTransformer_forward.__get__(self.vision_model) - - def forward(self, pixel_values = None, attn_mask = None, output_attentions = None, - output_hidden_states = None, return_dict = None): - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - return self.vision_model( - pixel_values=pixel_values, - attn_mask=attn_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \ + f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}" diff --git a/adaface/test_img_prompt_model.py b/adaface/test_img_prompt_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1d017494655f9480615c95595e26bb124fa2b3b5 --- /dev/null +++ b/adaface/test_img_prompt_model.py @@ -0,0 +1,199 @@ +import torch +from PIL import Image +import os, argparse, glob +import numpy as np +from .face_id_to_ada_prompt import create_id2ada_prompt_encoder +from .util import create_consistentid_pipeline +from .arc2face_models import create_arc2face_pipeline +from transformers import CLIPTextModel + +def save_images(images, subject_name, id2img_prompt_encoder_type, + prompt, perturb_std, save_dir = "samples-ada"): + os.makedirs(save_dir, exist_ok=True) + # Save 4 images as a grid image in save_dir + grid_image = Image.new('RGB', (512 * 2, 512 * 2)) + for i, image in enumerate(images): + image = image.resize((512, 512)) + grid_image.paste(image, (512 * (i % 2), 512 * (i // 2))) + + prompt_sig = prompt.replace(" ", "_").replace(",", "_") + grid_filepath = os.path.join(save_dir, + "-".join([subject_name, id2img_prompt_encoder_type, + prompt_sig, f"perturb{perturb_std:.02f}.png"])) + + if os.path.exists(grid_filepath): + grid_count = 2 + grid_filepath = os.path.join(save_dir, + "-".join([ subject_name, id2img_prompt_encoder_type, + prompt_sig, f"perturb{perturb_std:.02f}", str(grid_count) ]) + ".png") + while os.path.exists(grid_filepath): + grid_count += 1 + grid_filepath = os.path.join(save_dir, + "-".join([ subject_name, id2img_prompt_encoder_type, + prompt_sig, f"perturb{perturb_std:.02f}", str(grid_count) ]) + ".png") + + grid_image.save(grid_filepath) + print(f"Saved to {grid_filepath}") + +def seed_everything(seed): + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ["PL_GLOBAL_SEED"] = str(seed) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # --base_model_path models/Realistic_Vision_V4.0_noVAE + parser.add_argument("--base_model_path", type=str, default="models/sar/sar.safetensors") + parser.add_argument("--id2img_prompt_encoder_type", type=str, + choices=["arc2face", "consistentID"], + help="Types of the ID2Img prompt encoder") + parser.add_argument("--subject", type=str, 
default="subjects-celebrity/taylorswift") + parser.add_argument("--example_image_count", type=int, default=5, help="Number of example images to use") + parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate") + parser.add_argument("--init_img", type=str, default=None) + parser.add_argument("--prompt", type=str, default="portrait photo of a person in superman costume") + parser.add_argument("--use_core_only", action="store_true") + parser.add_argument("--truncate_prompt_at", type=int, default=-1, + help="Truncate the prompt to this length") + parser.add_argument("--randface", action="store_true") + parser.add_argument("--seed", type=int, default=-1) + parser.add_argument("--perturb_std", type=float, default=1) + + args = parser.parse_args() + if args.seed > 0: + seed_everything(args.seed) + + if args.id2img_prompt_encoder_type == "arc2face": + pipeline = create_arc2face_pipeline(args.base_model_path) + use_teacher_neg = False + elif args.id2img_prompt_encoder_type == "consistentID": + pipeline = create_consistentid_pipeline(args.base_model_path) + use_teacher_neg = True + + pipeline = pipeline.to('cuda', torch.float16) + + # When the second argument, adaface_ckpt_path = None, create_id2ada_prompt_encoder() + # returns an id2ada_prompt_encoder object, with .subj_basis_generator uninitialized. + # But it doesn't matter, as we don't use the subj_basis_generator to generate ada embeddings. + id2img_prompt_encoder = create_id2ada_prompt_encoder([args.id2img_prompt_encoder_type], + num_static_img_suffix_embs=0) + id2img_prompt_encoder.to('cuda') + + if not args.randface: + image_folder = args.subject + if image_folder.endswith("/"): + image_folder = image_folder[:-1] + + if os.path.isfile(image_folder): + # Get the second to the last part of the path + subject_name = os.path.basename(os.path.dirname(image_folder)) + image_paths = [image_folder] + + else: + subject_name = os.path.basename(image_folder) + image_types = ["*.jpg", "*.png", "*.jpeg"] + alltype_image_paths = [] + for image_type in image_types: + # glob returns the full path. + image_paths = glob.glob(os.path.join(image_folder, image_type)) + if len(image_paths) > 0: + alltype_image_paths.extend(image_paths) + # image_paths contain at most args.example_image_count full image paths. + image_paths = alltype_image_paths[:args.example_image_count] + else: + subject_name = None + image_paths = None + image_folder = None + + subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name + id_batch_size = args.out_image_count + + text_encoder = pipeline.text_encoder + orig_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to("cuda") + + noise = torch.randn(args.out_image_count, 4, 64, 64, device='cuda', dtype=torch.float16) + if args.randface: + init_id_embs = torch.randn(1, 512, device='cuda', dtype=torch.float16) + if args.id2img_prompt_encoder_type == "arc2face": + pre_clip_features = None + elif args.id2img_prompt_encoder_type == "consistentID": + # For ConsistentID, random clip features are much better than zero clip features. + rand_clip_fgbg_features = torch.randn(1, 514, 1280, device='cuda', dtype=torch.float16) + pre_clip_features = rand_clip_fgbg_features + else: + breakpoint() + else: + init_id_embs = None + pre_clip_features = None + + # perturb_std is the *relative* std of the noise added to the face ID embeddings. + # For Arc2Face, a perturb_std of 0.08 could change gender, but 0.06 is usually safe. 
+ # For ConsistentID, the image prompt embeddings are extremely robust to noise, + # and the perturb_std can be set to 0.5, only leading to a slight change in the result images. + # Seems ConsistentID mainly relies on CLIP features, instead of the face ID embeddings. + for perturb_std in (args.perturb_std, 0): + # id_prompt_emb is in the image prompt space. + # neg_id_prompt_emb is used in ConsistentID only. + face_image_count, faceid_embeds, id_prompt_emb, neg_id_prompt_emb \ + = id2img_prompt_encoder.get_img_prompt_embs( \ + init_id_embs=init_id_embs, + pre_clip_features=pre_clip_features, + image_paths=image_paths, + image_objs=None, + id_batch_size=id_batch_size, + perturb_at_stage='img_prompt_emb', + perturb_std=perturb_std, + avg_at_stage='id_emb', + verbose=True) + + pipeline.text_encoder = orig_text_encoder + + comp_prompt = args.prompt + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + # prompt_embeds_, negative_prompt_embeds_: [4, 77, 768] + prompt_embeds_, negative_prompt_embeds_ = \ + pipeline.encode_prompt(comp_prompt, device='cuda', num_images_per_prompt=args.out_image_count, + do_classifier_free_guidance=True, negative_prompt=negative_prompt) + #pipeline.text_encoder = text_encoder + # Postpend the id prompt embeddings to the prompt embeddings. + # For arc2face, id_prompt_emb can be either pre- or post-pended. + # But for ConsistentID, id_prompt_emb has to be **post-pended**. Otherwise, the result images are blank. + + full_negative_prompt_embeds_ = negative_prompt_embeds_ + if args.truncate_prompt_at >= 0: + prompt_embeds_ = prompt_embeds_[:, :args.truncate_prompt_at] + negative_prompt_embeds_ = negative_prompt_embeds_[:, :args.truncate_prompt_at] + + prompt_embeds_ = torch.cat([prompt_embeds_, id_prompt_emb], dim=1) + M = id_prompt_emb.shape[1] + + if (not use_teacher_neg) or neg_id_prompt_emb is None: + # For arc2face, neg_id_prompt_emb is None. So we concatenate the last M negative prompt embeddings, + # to make the negative prompt embeddings have the same length as the prompt embeddings. + negative_prompt_embeds_ = torch.cat([negative_prompt_embeds_, full_negative_prompt_embeds_[:, -M:]], dim=1) + else: + # NOTE: For ConsistentID, neg_id_prompt_emb has to be present in the negative prompt embeddings. + # Otherwise, the result images are cartoonish. 
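+ # Appending neg_id_prompt_emb also keeps the negative embeddings the same length as
+ # prompt_embeds_ (neg_id_prompt_emb is assumed to span the same M tokens as id_prompt_emb).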
+ negative_prompt_embeds_ = torch.cat([negative_prompt_embeds_, neg_id_prompt_emb], dim=1) + + if args.use_core_only: + prompt_embeds_ = id_prompt_emb + if (not use_teacher_neg) or neg_id_prompt_emb is None: + negative_prompt_embeds_ = full_negative_prompt_embeds_[:, :M] + else: + negative_prompt_embeds_ = neg_id_prompt_emb + + for guidance_scale in [6]: + images = pipeline(latents=noise, + prompt_embeds=prompt_embeds_, + negative_prompt_embeds=negative_prompt_embeds_, + num_inference_steps=50, + guidance_scale=guidance_scale, + num_images_per_prompt=1).images + + save_images(images, subject_name, args.id2img_prompt_encoder_type, + f"guide{guidance_scale}", perturb_std) diff --git a/adaface/unet_teachers.py b/adaface/unet_teachers.py new file mode 100644 index 0000000000000000000000000000000000000000..3190658e3ca93a68b3551cfa155ca6edf7214e4b --- /dev/null +++ b/adaface/unet_teachers.py @@ -0,0 +1,218 @@ +import torch +import numpy as np +import pytorch_lightning as pl +from diffusers import UNet2DConditionModel +from adaface.util import UNetEnsemble, create_consistentid_pipeline +from diffusers import UNet2DConditionModel +from omegaconf.listconfig import ListConfig + +def create_unet_teacher(teacher_type, device='cpu', **kwargs): + # If teacher_type is a list with only one element, we dereference it. + if isinstance(teacher_type, (tuple, list, ListConfig)) and len(teacher_type) == 1: + teacher_type = teacher_type[0] + + if teacher_type == "arc2face": + return Arc2FaceTeacher(**kwargs) + elif teacher_type == "unet_ensemble": + # unet, extra_unet_dirpaths and unet_weights are passed in kwargs. + # Even if we distill from unet_ensemble, we still need to load arc2face for generating + # arc2face embeddings. + # The first (optional) ctor param of UNetEnsembleTeacher is an instantiated unet, + # in our case, the ddpm unet. Ideally we should reuse it to save GPU RAM. + # However, since the __call__ method of the ddpm unet takes different formats of params, + # for simplicity, we still use the diffusers unet. + # unet_teacher is put on CPU first, then moved to GPU when DDPM is moved to GPU. + return UNetEnsembleTeacher(device=device, **kwargs) + elif teacher_type == "consistentID": + return ConsistentIDTeacher(**kwargs) + elif teacher_type == "simple_unet": + return SimpleUNetTeacher(**kwargs) + # Since we've dereferenced the list if it has only one element, + # this holding implies the list has more than one element. Therefore it's UNetEnsembleTeacher. + elif isinstance(teacher_type, (tuple, list, ListConfig)): + # teacher_type is a list of teacher types. So it's UNetEnsembleTeacher. + return UNetEnsembleTeacher(unet_types=teacher_type, device=device, **kwargs) + else: + raise NotImplementedError(f"Teacher type {teacher_type} not implemented.") + +class UNetTeacher(pl.LightningModule): + def __init__(self, **kwargs): + super().__init__() + self.name = None + # self.unet will be initialized in the child class. + self.unet = None + self.p_uses_cfg = kwargs.get("p_uses_cfg", 0) + # self.cfg_scale will be randomly sampled from cfg_scale_range. + self.cfg_scale_range = kwargs.get("cfg_scale_range", [1.3, 2]) + # Initialize cfg_scale to 1. It will be randomly sampled during forward pass. + self.cfg_scale = 1 + if self.p_uses_cfg > 0: + print(f"Using CFG with probability {self.p_uses_cfg} and scale range {self.cfg_scale_range}.") + else: + print(f"Never using CFG.") + + # Passing in ddpm_model to use its q_sample and predict_start_from_noise methods. 
+ # We don't implement the two functions here, because they involve a few tensors + # to be initialized, which will unnecessarily complicate the code. + # noise: the initial noise for the first iteration. + # t: the initial t. We will sample additional (num_denoising_steps - 1) smaller t. + # uses_same_t: when sampling t, use the same t for all instances. + def forward(self, ddpm_model, x_start, noise, t, teacher_context, + num_denoising_steps=1, uses_same_t=False): + assert num_denoising_steps <= 10 + + if self.p_uses_cfg > 0: + self.uses_cfg = np.random.rand() < self.p_uses_cfg + if self.uses_cfg: + # Randomly sample a cfg_scale from cfg_scale_range. + self.cfg_scale = np.random.uniform(*self.cfg_scale_range) + if self.cfg_scale == 1: + self.uses_cfg = False + + if self.uses_cfg: + print(f"Teacher samples CFG scale {self.cfg_scale:.1f}.") + else: + self.cfg_scale = 1 + print("Teacher does not use CFG.") + + # If p_uses_cfg > 0, we always pass both pos_context and neg_context to the teacher. + # But the neg_context is only used when self.uses_cfg is True and cfg_scale > 1. + # So we manually split the teacher_context into pos_context and neg_context, and only keep pos_context. + if self.name == 'unet_ensemble': + teacher_pos_contexts = [] + # teacher_context is a list of teacher contexts. + for teacher_context_i in teacher_context: + pos_context, neg_context = torch.chunk(teacher_context_i, 2, dim=0) + if pos_context.shape[0] != x_start.shape[0]: + breakpoint() + teacher_pos_contexts.append(pos_context) + teacher_context = teacher_pos_contexts + else: + pos_context, neg_context = torch.chunk(teacher_context, 2, dim=0) + if pos_context.shape[0] != x_start.shape[0]: + breakpoint() + teacher_context = pos_context + else: + # p_uses_cfg = 0. Never use CFG. + self.uses_cfg = False + # In this case, the student only passes pos_context to the teacher, + # so no need to split teacher_context into pos_context and neg_context. + # self.cfg_scale will be accessed by the student, + # so we need to make sure it is always set correctly, + # in case someday we want to switch from CFG to non-CFG during runtime. + self.cfg_scale = 1 + + if self.name == 'unet_ensemble': + # teacher_context is a list of teacher contexts. + for teacher_context_i in teacher_context: + if teacher_context_i.shape[0] != x_start.shape[0] * (1 + self.uses_cfg): + breakpoint() + else: + if teacher_context.shape[0] != x_start.shape[0] * (1 + self.uses_cfg): + breakpoint() + + # Initially, x_starts only contains the original x_start. + x_starts = [ x_start ] + noises = [ noise ] + ts = [ t ] + noise_preds = [] + + with torch.autocast(device_type='cuda', dtype=torch.float16): + for i in range(num_denoising_steps): + x_start = x_starts[i] + t = ts[i] + noise = noises[i] + # sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise + x_noisy = ddpm_model.q_sample(x_start, t, noise) + + if self.uses_cfg: + x_noisy2 = x_noisy.repeat(2, 1, 1, 1) + t2 = t.repeat(2) + else: + x_noisy2 = x_noisy + t2 = t + + # If do_arc2face_distill, then pos_context is [BS=6, 21, 768]. 
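+ # When uses_cfg, x_noisy2/t2 are batch-doubled and teacher_context stacks [pos_context, neg_context],
+ # so a single unet call below produces both the conditional and unconditional predictions,
+ # which are then split and recombined with cfg_scale.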
+ noise_pred = self.unet(sample=x_noisy2, timestep=t2, encoder_hidden_states=teacher_context, + return_dict=False)[0] + if self.uses_cfg and self.cfg_scale > 1: + pos_noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0) + noise_pred = pos_noise_pred * self.cfg_scale - neg_noise_pred * (self.cfg_scale - 1) + + # sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise + pred_x0 = ddpm_model.predict_start_from_noise(x_noisy, t, noise_pred) + noise_preds.append(noise_pred) + + # The predicted x0 is used as the x_start for the next denoising step. + x_starts.append(pred_x0) + + # Sample an earlier timestep for the next denoising step. + if i < num_denoising_steps - 1: + # NOTE: rand_like() samples from U(0, 1), not like randn_like(). + relative_ts = torch.rand_like(t.float()) + # Make sure at the middle step (i = sqrt(num_denoising_steps - 1), the timestep + # is between 50% and 70% of the current timestep. So if num_denoising_steps = 5, + # we take timesteps within [0.5^0.66, 0.7^0.66] = [0.63, 0.79] of the current timestep. + # If num_denoising_steps = 4, we take timesteps within [0.5^0.72, 0.7^0.72] = [0.61, 0.77] + # of the current timestep. + t_lb = t * np.power(0.5, np.power(num_denoising_steps - 1, -0.3)) + t_ub = t * np.power(0.7, np.power(num_denoising_steps - 1, -0.3)) + earlier_timesteps = (t_ub - t_lb) * relative_ts + t_lb + earlier_timesteps = earlier_timesteps.long() + + if uses_same_t: + # If uses_same_t, we use the same earlier_timesteps for all instances. + earlier_timesteps = earlier_timesteps[0].repeat(x_start.shape[0]) + + # earlier_timesteps = ts[i+1] < ts[i]. + ts.append(earlier_timesteps) + + noise = torch.randn_like(pred_x0) + noises.append(noise) + + return noise_preds, x_starts, noises, ts + +class Arc2FaceTeacher(UNetTeacher): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.name = "arc2face" + self.unet = UNet2DConditionModel.from_pretrained( + #"runwayml/stable-diffusion-v1-5", subfolder="unet" + 'models/arc2face', subfolder="arc2face", torch_dtype=torch.float16 + ) + # Disable CFG. Even if p_uses_cfg > 0, the randomly drawn cfg_scale is still 1, + # so the CFG is effectively disabled. + self.cfg_scale_range = [1, 1] + +class UNetEnsembleTeacher(UNetTeacher): + # unet_weights are not model weights, but scalar weights for individual unets. + def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', **kwargs): + super().__init__(**kwargs) + self.name = "unet_ensemble" + self.unet = UNetEnsemble(unets, unet_types, extra_unet_dirpaths, unet_weights, device) + +class ConsistentIDTeacher(UNetTeacher): + def __init__(self, base_model_path="models/ensemble/sd15-dste8-vae.safetensors", **kwargs): + super().__init__(**kwargs) + self.name = "consistentID" + ### Load base model + # In contrast to Arc2FaceTeacher or UNetEnsembleTeacher, ConsistentIDPipeline is not a torch.nn.Module. + # We couldn't initialize the ConsistentIDPipeline to CPU first and wait it to be automatically moved to GPU. + # Instead, we have to initialize it to GPU directly. + pipe = create_consistentid_pipeline(base_model_path) + # Compatible with the UNetTeacher interface. + self.unet = pipe.unet + # Release VAE and text_encoder to save memory. UNet is still needed for denoising + # (the unet is implemented in diffusers in fp16, so probably faster than the LDM unet). + pipe.release_components(["vae", "text_encoder"]) + +# We use the default cfg_scale_range=[1.3, 2] for SimpleUNetTeacher. 
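+# A usage sketch (assuming the model paths above exist locally). Teachers are normally built
+# through create_unet_teacher() rather than instantiated directly, e.g.:
+#     teacher = create_unet_teacher("simple_unet", p_uses_cfg=0.5, cfg_scale_range=[1.3, 2])
+#     noise_preds, x_starts, noises, ts = teacher(ddpm_model, x_start, noise, t, teacher_context,
+#                                                 num_denoising_steps=3)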
+# Note p_uses_cfg=0.5 will also be passed in in kwargs. +class SimpleUNetTeacher(UNetTeacher): + def __init__(self, unet_dirpath='models/ensemble/sd15-unet', + torch_dtype=torch.float16, **kwargs): + super().__init__(**kwargs) + self.name = "simple_unet" + self.unet = UNet2DConditionModel.from_pretrained( + unet_dirpath, torch_dtype=torch_dtype + ) diff --git a/adaface/util.py b/adaface/util.py index 7beb81d2183b2ed687e4fd0ebe4797aa14116f37..62caad8a719bbe1ea9a96e131587af3f4a5e2caa 100644 --- a/adaface/util.py +++ b/adaface/util.py @@ -1,18 +1,44 @@ import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np from PIL import Image -import cv2 +from diffusers import UNet2DConditionModel +from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput +from transformers import CLIPVisionModel +from dataclasses import dataclass +from typing import Optional, Tuple +from transformers.utils import ModelOutput +import numpy as np +import argparse +from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline +from diffusers import ( + UNet2DConditionModel, + DDIMScheduler, +) + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") -# add_noise_to_tensor() adds a fixed amount of noise to the tensor. -def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=False, - std_dim=-1, norm_dim=-1): - if noise_std_is_relative: +# perturb_tensor() adds a fixed amount of noise to the tensor. +def perturb_tensor(ts, perturb_std, perturb_std_is_relative=True, keep_norm=False, + std_dim=-1, norm_dim=-1, verbose=True): + orig_ts = ts + if perturb_std_is_relative: ts_std_mean = ts.std(dim=std_dim).mean().detach() - noise_std *= ts_std_mean - noise = torch.randn_like(ts) * noise_std + perturb_std *= ts_std_mean + # ts_std_mean: 50~80 for unnormalized images, perturb_std: 2.5-4 for 0.05 noise. + if verbose: + print(f"ts_std_mean: {ts_std_mean:.03f}, perturb_std: {perturb_std:.03f}") + + noise = torch.randn_like(ts) * perturb_std if keep_norm: orig_norm = ts.norm(dim=norm_dim, keepdim=True) ts = ts + noise @@ -20,9 +46,32 @@ def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=Fal ts = ts * orig_norm / (new_norm + 1e-8) else: ts = ts + noise + + if verbose: + print(f"Correlations between new and original tensors: {F.cosine_similarity(ts.flatten(), orig_ts.flatten(), dim=0).item():.03f}") return ts +def perturb_np_array(np_array, perturb_std, perturb_std_is_relative=True, std_dim=-1): + ts = torch.from_numpy(np_array).to(dtype=torch.float32) + ts = perturb_tensor(ts, perturb_std, perturb_std_is_relative, std_dim=std_dim) + return ts.numpy().astype(np_array.dtype) + +def calc_stats(emb_name, embeddings, mean_dim=0): + print("%s:" %emb_name) + repeat_count = [1] * embeddings.ndim + repeat_count[mean_dim] = embeddings.shape[mean_dim] + # Average across the mean_dim dim. + # Make emb_mean the same size as embeddings, as required by F.l1_loss. + emb_mean = embeddings.mean(mean_dim, keepdim=True).repeat(repeat_count) + l1_loss = F.l1_loss(embeddings, emb_mean) + # F.l2_loss doesn't take sqrt. So the loss is very small. + # Compute it manually. 
+ l2_loss = ((embeddings - emb_mean) ** 2).mean().sqrt() + norms = torch.norm(embeddings, dim=1).detach().cpu().numpy() + print("L1: %.4f, L2: %.4f" %(l1_loss.item(), l2_loss.item())) + print("Norms: min: %.4f, max: %.4f, mean: %.4f, std: %.4f" %(norms.min(), norms.max(), norms.mean(), norms.std())) + # Revised from RevGrad, by removing the grad negation. class ScaleGrad(torch.autograd.Function): @@ -71,272 +120,270 @@ def gen_gradient_scaler(alpha, debug=False): # Don't use lambda function here, otherwise the object can't be pickled. return torch.detach -#@torch.autocast(device_type="cuda") -# In AdaFaceWrapper, input_max_length is 22. -def arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs, - input_max_length=77, return_full_and_core_embs=True): - - ''' - arc2face_text_encoder: arc2face_models.py CLIPTextModelWrapper instance. - face_embs: (N, 512) normalized ArcFace embeddings. - return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings. - If False, return only the core embeddings. - - ''' - - # arcface_token_id: 1014 - arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0] - - # This step should be quite fast, and there's no need to cache the input_ids. - input_ids = tokenizer( - "photo of a id person", - truncation=True, - padding="max_length", - max_length=input_max_length, #tokenizer.model_max_length, - return_tensors="pt", - ).input_ids.to(face_embs.device) - # input_ids: [1, 77] or [3, 77] (during training). - input_ids = input_ids.repeat(len(face_embs), 1) - face_embs_dtype = face_embs.dtype - face_embs = face_embs.to(arc2face_text_encoder.dtype) - # face_embs_padded: [1, 512] -> [1, 768]. - face_embs_padded = F.pad(face_embs, (0, arc2face_text_encoder.config.hidden_size - face_embs.shape[-1]), "constant", 0) - # arc2face_text_encoder(input_ids=input_ids, ...) is called twice. The first is only to get the token embeddings (the shallowest mapping). - # The second call does the ordinary CLIP text encoding pass. - token_embs = arc2face_text_encoder(input_ids=input_ids, return_token_embs=True) - token_embs[input_ids==arcface_token_id] = face_embs_padded - - prompt_embeds = arc2face_text_encoder( - input_ids=input_ids, - input_token_embs=token_embs, - return_token_embs=False - )[0] - - # Restore the original dtype of prompt_embeds: float16 -> float32. - prompt_embeds = prompt_embeds.to(face_embs_dtype) - - if return_full_and_core_embs: - # token 4: 'id' in "photo of a id person". - # 4:20 are the most important 16 embeddings that contain the subject's identity. - # [N, 77, 768] -> [N, 16, 768] - return prompt_embeds, prompt_embeds[:, 4:20] +def pad_image_obj_to_square(image_obj, new_size=-1): + # Remove alpha channel if it exists. + if image_obj.mode == 'RGBA': + image_obj = image_obj.convert('RGB') + + # Pad input to be width == height + width, height = orig_size = image_obj.size + new_width, new_height = max(width, height), max(width, height) + + if width != height: + if width > height: + pads = (0, (width - height) // 2) + elif height > width: + pads = ((height - width) // 2, 0) + square_image_obj = Image.new("RGB", (new_width, new_height)) + # pads indicates the upper left corner to paste the input. 
+ square_image_obj.paste(image_obj, pads) + #square_image_obj = square_image_obj.resize((512, 512)) + print(f"{width}x{height} -> {new_width}x{new_height} -> {square_image_obj.size}") + long_short_ratio = max(width, height) / min(width, height) else: - # [N, 16, 768] - return prompt_embeds[:, 4:20] - -def get_b_core_e_embeddings(prompt_embeds, length=22): - b_core_e_embs = torch.cat([ prompt_embeds[:, :length], prompt_embeds[:, [-1]] ], dim=1) - return b_core_e_embs - -# return_emb_types: a list of strings, each string is among ['full', 'core', 'full_zeroed_extra', 'b_core_e']. -def arc2face_inverse_face_prompt_embs(clip_tokenizer, inverse_text_encoder, face_prompt_embs, list_extra_words, - return_emb_types, pad_embeddings, hidden_state_layer_weights=None, - input_max_length=77, zs_extra_words_scale=0.5): - - ''' - inverse_text_encoder: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**. - inverse_text_encoder is NOT the original arc2face text encoder, but retrained to do inverse mapping. - face_prompt_embs: (BS, 16, 768). Only the core embeddings, no paddings. - list_extra_words: [s_1, ..., s_BS], each s_i is a list of extra words to be added to the prompt. - return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings. - If False, return only the core embeddings. - ''' - - if list_extra_words is not None: - if len(list_extra_words) != len(face_prompt_embs): - if len(face_prompt_embs) > 1: - print("Warn: list_extra_words has different length as face_prompt_embs.") - if len(list_extra_words) == 1: - list_extra_words = list_extra_words * len(face_prompt_embs) + square_image_obj = image_obj + pads = (0, 0) + long_short_ratio = 1 + + if new_size > 0: + # Resize the shorter edge to 512. + square_image_obj = square_image_obj.resize([int(new_size * long_short_ratio), int(new_size * long_short_ratio)]) + + return square_image_obj, pads, orig_size + +class UNetEnsemble(nn.Module): + # The first unet is the unet already loaded in a pipeline. + def __init__(self, unets, unet_types, extra_unet_dirpaths, unet_weights=None, device='cuda', torch_dtype=torch.float16): + super().__init__() + + self.unets = nn.ModuleList() + if unets is not None: + self.unets += [ unet.to(device) for unet in unets ] + + if unet_types is not None: + for unet_type in unet_types: + if unet_type == "arc2face": + from adaface.arc2face_models import create_arc2face_pipeline + unet = create_arc2face_pipeline(unet_only=True) + elif unet_type == "consistentID": + unet = create_consistentid_pipeline(unet_only=True) else: breakpoint() - else: - # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_mix_prompt_distillation. - # But list_extra_words always corresponds to the actual batch size. So we only take the first element. - list_extra_words = list_extra_words[:1] - - for extra_words in list_extra_words: - assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words." - # 16 ", " are placeholders for face_prompt_embs. - prompt_templates = [ "photo of a " + ", " * 16 + list_extra_words[i] for i in range(len(list_extra_words)) ] - else: - # 16 ", " are placeholders for face_prompt_embs. - # No extra words are added to the prompt. - prompt_templates = [ "photo of a " + ", " * 16 for _ in range(len(face_prompt_embs)) ] - - # This step should be quite fast, and there's no need to cache the input_ids. - # input_ids: [BS, 77]. 
- input_ids = clip_tokenizer( - prompt_templates, - truncation=True, - padding="max_length", - max_length=input_max_length, - return_tensors="pt", - ).input_ids.to(face_prompt_embs.device) - - face_prompt_embs_dtype = face_prompt_embs.dtype - face_prompt_embs = face_prompt_embs.to(inverse_text_encoder.dtype) - - # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping). - token_embs = inverse_text_encoder(input_ids=input_ids, return_token_embs=True) - # token 4: first ", " in the template prompt. - # Replace embeddings of 16 placeholder ", " with face_prompt_embs. - token_embs[:, 4:20] = face_prompt_embs - - # This call does the ordinary CLIP text encoding pass. - prompt_embeds = inverse_text_encoder( - input_ids=input_ids, - input_token_embs=token_embs, - hidden_state_layer_weights=hidden_state_layer_weights, - return_token_embs=False - )[0] - - # Restore the original dtype of prompt_embeds: float16 -> float32. - prompt_embeds = prompt_embeds.to(face_prompt_embs_dtype) - # token 4: first ", " in the template prompt. - # 4:20 are the most important 16 embeddings that contain the subject's identity. - # 20:22 are embeddings of the (at most) two extra words. - # [N, 77, 768] -> [N, 16, 768] - core_prompt_embs = prompt_embeds[:, 4:20] - if list_extra_words is not None: - # [N, 16, 768] -> [N, 18, 768] - extra_words_embs = prompt_embeds[:, 20:22] * zs_extra_words_scale - core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1) - - return_prompts = [] - for emb_type in return_emb_types: - if emb_type == 'full': - return_prompts.append(prompt_embeds) - elif emb_type == 'full_half_pad': - prompt_embeds2 = prompt_embeds.clone() - PADS = prompt_embeds2.shape[1] - 23 - if PADS >= 2: - # Fill half of the remaining embeddings with pad embeddings. - prompt_embeds2[:, 22:22+PADS//2] = pad_embeddings[22:22+PADS//2] - return_prompts.append(prompt_embeds2) - elif emb_type == 'full_pad': - prompt_embeds2 = prompt_embeds.clone() - # Fill the 22nd to the second last embeddings with pad embeddings. - prompt_embeds2[:, 22:-1] = pad_embeddings[22:-1] - return_prompts.append(prompt_embeds2) - elif emb_type == 'core': - return_prompts.append(core_prompt_embs) - elif emb_type == 'full_zeroed_extra': - prompt_embeds2 = prompt_embeds.clone() - # Only add two pad embeddings. The remaining embeddings are set to 0. - # Make the positional embeddings align with the actual positions. - prompt_embeds2[:, 22:24] = pad_embeddings[22:24] - prompt_embeds2[:, 24:-1] = 0 - return_prompts.append(prompt_embeds2) - elif emb_type == 'b_core_e': - # The first 22 embeddings, plus the last EOS embedding. - b_core_e_embs = get_b_core_e_embeddings(prompt_embeds, length=22) - return_prompts.append(b_core_e_embs) - else: + self.unets.append(unet.to(device=device)) + + if extra_unet_dirpaths is not None: + for unet_path in extra_unet_dirpaths: + unet = UNet2DConditionModel.from_pretrained(unet_path, torch_dtype=torch_dtype) + self.unets.append(unet.to(device=device)) + + if unet_weights is None: + unet_weights = [1.] 
* len(self.unets) + elif len(self.unets) < len(unet_weights): + unet_weights = unet_weights[:len(self.unets)] + elif len(self.unets) > len(unet_weights): breakpoint() + + unet_weights = torch.tensor(unet_weights, dtype=torch_dtype) + unet_weights = unet_weights / unet_weights.sum() + self.unet_weights = nn.Parameter(unet_weights, requires_grad=False) + + print(f"UNetEnsemble: {len(self.unets)} UNets loaded with weights: {self.unet_weights.data.cpu().numpy()}") + # Set these fields to be compatible with diffusers. + self.dtype = self.unets[0].dtype + self.device = self.unets[0].device + self.config = self.unets[0].config + + def forward(self, *args, **kwargs): + return_dict = kwargs.get('return_dict', True) + teacher_contexts = kwargs.pop('encoder_hidden_states', None) + # Only one teacher_context is provided. That means all unets will use the same teacher_context. + # We repeat the teacher_contexts to match the number of unets. + if not isinstance(teacher_contexts, (list, tuple)): + teacher_contexts = [teacher_contexts] + if len(teacher_contexts) == 1 and len(self.unets) > 1: + teacher_contexts = teacher_contexts * len(self.unets) + + samples = [] + + for unet, teacher_context in zip(self.unets, teacher_contexts): + sample = unet(encoder_hidden_states=teacher_context, *args, **kwargs) + if not return_dict: + sample = sample[0] + else: + sample = sample.sample - return return_prompts - -# if pre_face_embs is None, generate random face embeddings [BS, 512]. -# image_folder is passed only for logging purpose. image_paths contains the paths of the images. -def get_arc2face_id_prompt_embs(face_app, clip_tokenizer, arc2face_text_encoder, - extract_faceid_embeds, pre_face_embs, - image_folder, image_paths, images_np, - id_batch_size, device, - input_max_length=77, noise_level=0.0, - return_core_id_embs=False, - gen_neg_prompt=False, verbose=False): - face_image_count = 0 - - if extract_faceid_embeds: - faceid_embeds = [] - if image_paths is not None: - images_np = [] - for image_path in image_paths: - image_np = np.array(Image.open(image_path)) - images_np.append(image_np) - - for i, image_np in enumerate(images_np): - image_obj = Image.fromarray(image_np).resize((512, 512), Image.NEAREST) - # Remove alpha channel if it exists. - if image_obj.mode == 'RGBA': - image_obj = image_obj.convert('RGB') - # This seems NOT a bug. The input image should be in BGR format, as per - # https://github.com/deepinsight/insightface/issues/524 - image_np = cv2.cvtColor(np.array(image_obj), cv2.COLOR_RGB2BGR) - image_np = np.array(image_obj) - - face_infos = face_app.get(image_np) - if verbose and image_paths is not None: - print(image_paths[i], len(face_infos)) - # Assume all images belong to the same subject. Therefore, we can skip the images with no face detected. - if len(face_infos) == 0: - continue - # only use the maximum face - face_info = sorted(face_infos, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] - # Each faceid_embed: [1, 512] - faceid_embeds.append(torch.from_numpy(face_info.normed_embedding).unsqueeze(0)) - face_image_count += 1 + samples.append(sample) - if verbose: - if image_folder is not None: - print(f"Extracted ID embeddings from {face_image_count} images in {image_folder}") - else: - print(f"Extracted ID embeddings from {face_image_count} images") + samples = torch.stack(samples, dim=0) + unet_weights = self.unet_weights.reshape(-1, *([1] * (samples.ndim - 1))) + sample = (samples * unet_weights).sum(dim=0) - if len(faceid_embeds) == 0: - print("No face detected. 
Use a random face instead.") - faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16) + if not return_dict: + return (sample,) else: - # faceid_embeds: [10, 512] - faceid_embeds = torch.cat(faceid_embeds, dim=0) - # faceid_embeds: [10, 512] -> [1, 512]. - # and the resulted prompt embeddings are the same. - faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True).to(device=device, dtype=torch.float16) - else: - # Random face embeddings. faceid_embeds: [BS, 512]. - if pre_face_embs is None: - faceid_embeds = torch.randn(id_batch_size, 512) + return UNet2DConditionOutput(sample=sample) + +def create_consistentid_pipeline(base_model_path="models/ensemble/sd15-dste8-vae.safetensors", + dtype=torch.float16, unet_only=False): + pipe = ConsistentIDPipeline.from_single_file( + base_model_path, + torch_dtype=dtype, + ) + # consistentID specific modules are still in fp32. Will be converted to fp16 + # later with .to(device, torch_dtype) by the caller. + pipe.load_ConsistentID_model( + consistentID_weight_path="./models/ConsistentID/ConsistentID-v1.bin", + bise_net_weight_path="./models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth", + ) + # We load the pipeline first, then use the unet in the pipeline. + # Since the pipeline initialization will load LoRA into the unet, + # now we have the unet with LoRA loaded. + if unet_only: + # We release text_encoder and VAE to save memory. + pipe.release_components(["text_encoder", "vae"]) + return pipe.unet + + noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + pipe.scheduler = noise_scheduler + + return pipe + +@dataclass +class BaseModelOutputWithPooling2(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + attn_mask: Optional[torch.FloatTensor] = None + +# Revised from CLIPVisionTransformer to support attention mask. +# self: a CLIPVisionTransformer instance. +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821 +# pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224] +# attn_mask: B*H*W attention mask. +def CLIPVisionTransformer_forward_with_mask(self, pixel_values = None, attn_mask=None, + output_attentions = None, + output_hidden_states = None, return_dict = None): + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Visual tokens are flattended in embeddings(). + # self.embeddings: CLIPVisionEmbeddings. + # hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds). + # 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False). + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + if attn_mask is not None: + # feat_edge_size: 16. + feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int) + # attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16]. + attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest') + # Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256]. + attn_mask = attn_mask.flatten(2) + # Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257]. + # This 1 corresponds to class_embeds, which is always attended to. + attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1) + attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1) else: - faceid_embeds = pre_face_embs - if pre_face_embs.shape[0] == 1: - faceid_embeds = faceid_embeds.repeat(id_batch_size, 1) - - faceid_embeds = faceid_embeds.to(device=device, dtype=torch.float16) - - if noise_level > 0: - # If id_batch_size > 1, after adding noises, the id_batch_size embeddings will be different. - faceid_embeds = add_noise_to_tensor(faceid_embeds, noise_level, noise_std_is_relative=True, keep_norm=True) - - faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1) - - # arc2face_pos_prompt_emb, arc2face_neg_prompt_emb: [BS, 77, 768] - with torch.no_grad(): - arc2face_pos_prompt_emb, arc2face_pos_core_prompt_emb = \ - arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder, - faceid_embeds, input_max_length=input_max_length, - return_full_and_core_embs=True) - if return_core_id_embs: - arc2face_pos_prompt_emb = arc2face_pos_core_prompt_emb - # If extract_faceid_embeds, we assume all images are from the same subject, and the batch dim of faceid_embeds is 1. - # So we need to repeat faceid_embeds. 
- if extract_faceid_embeds: - faceid_embeds = faceid_embeds.repeat(id_batch_size, 1) - arc2face_pos_prompt_emb = arc2face_pos_prompt_emb.repeat(id_batch_size, 1, 1) - - if gen_neg_prompt: - with torch.no_grad(): - arc2face_neg_prompt_emb, arc2face_neg_core_prompt_emb = \ - arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder, - torch.zeros_like(faceid_embeds), - input_max_length=input_max_length, - return_full_and_core_embs=True) - if return_core_id_embs: - arc2face_neg_prompt_emb = arc2face_neg_core_prompt_emb - - #if extract_faceid_embeds: - # arc2face_neg_prompt_emb = arc2face_neg_prompt_emb.repeat(id_batch_size, 1, 1) - return face_image_count, faceid_embeds, arc2face_pos_prompt_emb, arc2face_neg_prompt_emb - else: - return face_image_count, faceid_embeds, arc2face_pos_prompt_emb - \ No newline at end of file + attn_mask_pairs = None + + # encoder: CLIPEncoder. + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + # New feature: (***The official documentation is wrong***) + # attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*): + # Mask to avoid performing attention on pairs of token. Mask values selected in `[0, 1]`: + # - 1 for pairs that are **not masked**, + # - 0 for pairs that are **masked**. + # attention_mask is eventually used by CLIPEncoderLayer: + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370 + attention_mask=attn_mask_pairs, + output_attentions=output_attentions, # False + output_hidden_states=output_hidden_states, # True + return_dict=return_dict, # True + ) + + # last_hidden_state: [BS, 257, 1280] + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + # return_dict is True. + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling2( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + # Newly added: return resized flattened attention mask. + # [BS, 1, 257] -> [BS, 257, 1] + attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None + ) + +def CLIPVisionModel_forward_with_mask(self, pixel_values = None, attn_mask = None, output_attentions = None, + output_hidden_states = None, return_dict = None): + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + attn_mask=attn_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + +# patch_clip_image_encoder_with_mask() is applicable to both CLIPVisionModel and CLIPVisionModelWithProjection. +def patch_clip_image_encoder_with_mask(clip_image_encoder): + clip_image_encoder.vision_model.forward = CLIPVisionTransformer_forward_with_mask.__get__(clip_image_encoder.vision_model) + clip_image_encoder.forward = CLIPVisionModel_forward_with_mask.__get__(clip_image_encoder) + return clip_image_encoder + +class CLIPVisionModelWithMask(CLIPVisionModel): + def __init__(self, config): + super().__init__(config) + # Replace vision_model.forward() with the new one that supports mask. 
+ patch_clip_image_encoder_with_mask(self) + diff --git a/animatediff/models/unet.py b/animatediff/models/unet.py index 3772bed3b63f811e9d7e267e8eef3842fbf0517b..ff2c7b32b42a3785e3a1debb9969711062b89d9a 100644 --- a/animatediff/models/unet.py +++ b/animatediff/models/unet.py @@ -520,9 +520,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): raise RuntimeError(f"{model_file} does not exist") state_dict = torch.load(model_file, map_location="cpu") + # pretrained_model_path is the pretrained 3D unet, so it doesn't contain motion module weights. + # There are 714 missing keys and 0 unexpected keys. m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") - + params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()] print(f"### Motion Module Parameters: {sum(params) / 1e6} M") diff --git a/animatediff/pipelines/pipeline_animation.py b/animatediff/pipelines/pipeline_animation.py index 20e611561227ac059cde34b92e0f6b3d37d65a02..bc9a34a3e9d7f8b458fb6ed289764eaabf973958 100644 --- a/animatediff/pipelines/pipeline_animation.py +++ b/animatediff/pipelines/pipeline_animation.py @@ -1,7 +1,7 @@ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py import inspect -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Union, Tuple from dataclasses import dataclass import PIL.Image @@ -488,7 +488,7 @@ class AnimationPipeline(DiffusionPipeline): width = width or self.unet.config.sample_size * self.vae_scale_factor if isinstance(prompt_embeds, (list, tuple)): - prompt_embeds_begin, prompt_embeds_end, adaface_anneal_steps = prompt_embeds + prompt_embeds_begin, prompt_embeds_end, id_animator_anneal_steps = prompt_embeds prompt_embeds = prompt_embeds_begin do_prompt_embeds_annealing = True else: @@ -595,9 +595,11 @@ class AnimationPipeline(DiffusionPipeline): ) if do_prompt_embeds_annealing: - # i: 0 to num_inference_steps. Anneal the first adaface_anneal_steps steps. - # If adaface_anneal_steps == 0, then anneal_factor is always 1. - anneal_factor = i / adaface_anneal_steps if i < adaface_anneal_steps else 1 + # i: 0 to num_inference_steps. Anneal the first id_animator_anneal_steps steps, + # by linearly moving from prompt_embeds_begin to prompt_embeds_end. + # If id_animator_anneal_steps == 0, then anneal_factor is always 1, i.e., + # always uses prompt_embeds_end. + anneal_factor = i / id_animator_anneal_steps if i < id_animator_anneal_steps else 1 prompt_embeds_annealed = prompt_embeds_begin + anneal_factor * (prompt_embeds_end - prompt_embeds_begin) text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds_annealed]) @@ -640,7 +642,9 @@ class AnimationPipeline(DiffusionPipeline): prompt: Union[str, List[str]], video_length: Optional[int], init_image: Union[PIL.Image.Image, torch.Tensor], - init_image_strength: float = 1.0, + # init_image_strength could be a number to indicate the init_image_strength for all frames, + # or a tuple of two numbers to indicate the init_image_strength for the first frame and the final frame. 
+ init_image_strength: Union[float, Tuple[float, float]], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, diff --git a/app.py b/app.py index ea83ea2d6ae0250c2776846f6b1d1286801cce0b..16d4493d19e1bb1fb00c8867229cc0f7abe43049 100644 --- a/app.py +++ b/app.py @@ -4,6 +4,7 @@ css = ''' .gradio-container {width: 85% !important} ''' from animatediff.utils.util import save_videos_grid +from adaface.adaface_wrapper import AdaFaceWrapper import random from infer import load_model @@ -20,31 +21,43 @@ import torch import argparse # From command line read command adaface_ckpt_path parser = argparse.ArgumentParser() +parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"], + choices=["arc2face", "consistentID"], help="Type(s) of the ID2Ada prompt encoders") parser.add_argument('--adaface_ckpt_path', type=str, - default='models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt') + default='models/adaface/VGGface2_HQ_masks2024-10-05T09-28-53_zero3-ada-28000.pt') # Don't use 'sd15' for base_model_type; it just generates messy videos. -parser.add_argument('--base_model_type', type=str, default='sar') -parser.add_argument('--adaface_base_model_type', type=str, default='sar') +parser.add_argument('--base_model_type', type=str, default='rv51', + choices=["sar", "rv51"]) parser.add_argument('--gpu', type=int, default=None) parser.add_argument('--ip', type=str, default="0.0.0.0") args = parser.parse_args() +base_model_type_to_path = { + "sd15": "models/sd15-dste8-vae.safetensors", # LDM format. Needs to be converted. + "sar": "models/sar/sar.safetensors", # LDM format. Needs to be converted. + "rv51": "models/rv51/realisticVisionV51_v51VAE.safetensors" +} + def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: if randomize_seed: seed = random.randint(0, MAX_SEED) return seed # model = load_model() -# This FaceAnalysis uses a different model from what AdaFace uses, but it's fine. -# This is just to crop the face areas from the uploaded images. +# This FaceAnalysis is just to crop the face areas from the uploaded images, +# and is independent of the adaface FaceAnalysis apps. 
app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) app.prepare(ctx_id=0, det_size=(320, 320)) device = "cuda" if args.gpu is None else f"cuda:{args.gpu}" -id_animator, adaface = load_model(base_model_type=args.base_model_type, - adaface_base_model_type=args.adaface_base_model_type, - adaface_ckpt_path=args.adaface_ckpt_path, - device=device) +id_animator = load_model(base_model_type=args.base_model_type, device=device) + +base_model_path = base_model_type_to_path[args.base_model_type] + +adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path, + adaface_encoder_types=args.adaface_encoder_types, + adaface_ckpt_paths=[args.adaface_ckpt_path], device=device) + basedir = os.getcwd() savedir = os.path.join(basedir,'samples') os.makedirs(savedir, exist_ok=True) @@ -68,7 +81,7 @@ def get_clicked_image(data: gr.SelectData): return data.index @spaces.GPU -def gen_init_images(uploaded_image_paths, prompt, adaface_id_cfg_scale, out_image_count=3): +def gen_init_images(uploaded_image_paths, prompt, out_image_count=3): if uploaded_image_paths is None: print("No image uploaded") return None, None, None @@ -76,11 +89,9 @@ def gen_init_images(uploaded_image_paths, prompt, adaface_id_cfg_scale, out_imag # [('/tmp/gradio/249981e66a7c665aaaf1c7eaeb24949af4366c88/jensen huang.jpg', None)] # Extract the file paths. uploaded_image_paths = [path[0] for path in uploaded_image_paths] - # gen_init_images() uses a larger adaface_id_cfg_scale to generate more authentic faces. - adaface_id_cfg_scale_ = min(6, adaface_id_cfg_scale * 2) adaface_subj_embs = \ - adaface.generate_adaface_embeddings(image_folder=None, image_paths=uploaded_image_paths, - out_id_embs_scale=adaface_id_cfg_scale_, update_text_encoder=True) + adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None, + update_text_encoder=True) if adaface_subj_embs is None: raise gr.Error(f"Failed to detect any faces! 
Please try with other images") @@ -105,9 +116,10 @@ def gen_init_images(uploaded_image_paths, prompt, adaface_id_cfg_scale, out_imag @spaces.GPU(duration=90) def generate_image(image_container, uploaded_image_paths, init_img_file_paths, init_img_selected_idx, init_image_strength, init_image_final_weight, - prompt, negative_prompt, num_steps, video_length, guidance_scale, seed, attn_scale, image_embed_scale, - is_adaface_enabled, adaface_ckpt_path, adaface_id_cfg_scale, adaface_power_scale, - adaface_anneal_steps, progress=gr.Progress(track_tqdm=True)): + prompt, negative_prompt, num_steps, video_length, guidance_scale, seed, + attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale, + is_adaface_enabled, adaface_ckpt_path, adaface_power_scale, + id_animator_anneal_steps, progress=gr.Progress(track_tqdm=True)): if prompt is None: prompt = "" @@ -129,16 +141,22 @@ def generate_image(image_container, uploaded_image_paths, init_img_file_paths, i if adaface is None or not is_adaface_enabled: adaface_prompt_embeds = None + image_embed_cfg_scales = (1, 1) else: - if adaface_ckpt_path != args.adaface_ckpt_path: + if (adaface_ckpt_path is not None and adaface_ckpt_path.strip() != '') \ + and (adaface_ckpt_path != args.adaface_ckpt_path): # Reload the embedding manager - adaface.load_subj_basis_generator(adaface_ckpt_path) + adaface.id2ada_prompt_encoder.load_adaface_ckpt(adaface_ckpt_path) with torch.no_grad(): - adaface.generate_adaface_embeddings(image_folder=None, image_paths=uploaded_image_paths, - out_id_embs_scale=adaface_id_cfg_scale, update_text_encoder=True) + adaface_subj_embs = \ + adaface.prepare_adaface_embeddings(image_paths=uploaded_image_paths, face_id_embs=None, + update_text_encoder=True) + # adaface_prompt_embeds: [1, 77, 768]. - adaface_prompt_embeds, _ = adaface.encode_prompt(prompt, verbose=True) + adaface_prompt_embeds, _, _, _ = adaface.encode_prompt(prompt, verbose=True) + + image_embed_cfg_scales = (image_embed_cfg_begin_scale, image_embed_cfg_end_scale) # init_img_file_paths is a list of image paths. If not chose, init_img_file_paths is None. if init_img_file_paths is not None: @@ -152,23 +170,23 @@ def generate_image(image_container, uploaded_image_paths, init_img_file_paths, i init_image = None sample = id_animator.generate(prompt_img_lists, - init_image = init_image, - init_image_strength = (init_image_strength, init_image_final_weight), - prompt = prompt, - negative_prompt = negative_prompt, - adaface_embeds = adaface_prompt_embeds, - # adaface_scale is not so useful, and when it's set >= 2, weird artifacts appear. + init_image = init_image, + init_image_strength = (init_image_strength, init_image_final_weight), + prompt = prompt, + negative_prompt = negative_prompt, + adaface_prompt_embeds = adaface_prompt_embeds, + # adaface_power_scale is not so useful, and when it's set >= 2, weird artifacts appear. # Here it's limited to 0.7~1.3. 
- adaface_scale = adaface_power_scale, - num_inference_steps = num_steps, - adaface_anneal_steps = adaface_anneal_steps, - seed=seed, - guidance_scale = guidance_scale, - width = 512, - height = 512, - video_length = video_length, - attn_scale = attn_scale, - image_embed_scale = image_embed_scale, + adaface_power_scale = adaface_power_scale, + num_inference_steps = num_steps, + id_animator_anneal_steps = id_animator_anneal_steps, + seed = seed, + guidance_scale = guidance_scale, + width = 512, + height = 512, + video_length = video_length, + attn_scale = attn_scale, + image_embed_cfg_scales = image_embed_cfg_scales, ) save_sample_path = os.path.join(savedir, f"{random_name}.mp4") @@ -179,57 +197,6 @@ def validate_prompt(prompt): if not prompt: raise gr.Error("Prompt cannot be blank") -examples = [ - [ - "demo/ann.png", - ["demo/ann.png" ], - "A young girl with a passion for reading, curled up with a book in a cozy nook near a window", - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck,", - 30, - 8, 8290,1,16 - ], - [ - "demo/lecun.png", - ["demo/lecun.png" ], - "Iron Man soars through the clouds, his repulsors blazing", - "worst quality, low quality, jpeg artifacts, ugly, duplicate, blurry, long neck", - 30, - 8, 4993,0.7,16 - ], - [ - "demo/mix.png", - ["demo/lecun.png","demo/ann.png"], - "A musician playing a guitar, fingers deftly moving across the strings, producing a soulful melody", - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", - 30, - 8, 1897,0.9,16 - ], - [ - "demo/zendaya.png", - ["demo/zendaya.png" ], - "A woman on a serene beach at sunset, the sky ablaze with hues of orange and purple.", - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", - 30, - 8, 5992,1,16 - ], - [ - "demo/qianlong.png", - ["demo/qianlong.png" ], - "A chef in a white apron, complete with a toqueblanche, garnishing a gourmet dish", - "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, 
cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, UnrealisticDream", - 30, - 8, 1844,0.8,16 - ], - [ - "demo/augustus.png", - ["demo/augustus.png" ], - "A man with dyed pink and purple hair, styledin a high ponytail", - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", - 30, - 8, 870,0.7,16 - ] -] - with gr.Blocks(css=css) as demo: gr.Markdown( """ @@ -238,7 +205,7 @@ with gr.Blocks(css=css) as demo: ) gr.Markdown( """ -Official demo for our NeurIPS 2024 submission AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization.
+Official demo for our working paper AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization.
❗️**Tips**❗️ - You can upload one or more subject images for generating ID-specific video. @@ -267,7 +234,7 @@ with gr.Blocks(css=css) as demo: remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=files, size="sm") init_img_files = gr.File( - label="[Optional] Select 1 image for initialization, or generate 3 images with the button below and select 1", + label="[Optional] Generate 3 images and select 1 image", file_types=["image"], file_count="multiple" ) @@ -278,13 +245,36 @@ with gr.Blocks(css=css) as demo: # placeholder is just hint, not the real value. So we use "value='0'" instead of "placeholder='0'". init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False) + with gr.Column(visible=True) as init_gen_button_column: + gen_init = gr.Button(value="Generate 3 new init images") + with gr.Column(visible=False) as init_clear_button_column: + remove_init_and_reupload = gr.ClearButton(value="Upload an old init image", components=init_img_files, size="sm") + + prompt = gr.Dropdown(label="Prompt", + info="Try something like 'man/woman walking on the beach'.", + value="((best quality)), ((masterpiece)), ((realistic)), highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin", + allow_custom_value=True, + filterable=False, + choices=[ + "((best quality)), ((masterpiece)), ((realistic)), highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin", + "walking on the beach, sunset, orange sky, eye level shot", + "in a white apron and chef hat, garnishing a gourmet dish, full body view, long shot", + "dancing pose among folks in a park, waving hands", + "in iron man costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot", + "jedi wielding a lightsaber, star wars, full body view, eye level shot", + "playing guitar on a boat, ocean waves", + "with a passion for reading, curled up with a book in a cozy nook near a window", + "running pose in a park, eye level shot", + "in superman costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot" + ]) + init_image_strength = gr.Slider( label="Init Image Strength", info="How much the init image should influence each frame. 0: no influence (scenes are more dynamic), 3: strongest influence (scenes are more static).", minimum=0, - maximum=3, + maximum=1.5, step=0.25, - value=1.5, + value=1, ) init_image_final_weight = gr.Slider( label="Final Weight of the Init Image", @@ -295,65 +285,51 @@ with gr.Blocks(css=css) as demo: value=0.1, ) - with gr.Column(visible=False) as init_clear_button_column: - remove_init_and_reupload = gr.ClearButton(value="Remove and upload new init image", components=init_img_files, size="sm") - with gr.Column(visible=True) as init_gen_button_column: - gen_init = gr.Button(value="Generate 3 new init images") - - prompt = gr.Dropdown(label="Prompt", - info="Try something like 'man/woman walking on the beach'. 
If the face is not in focus, try adding 'face portrait of' at the beginning.", - value=None, - allow_custom_value=True, - filterable=False, - choices=[ - "woman ((best quality)), ((masterpiece)), ((realistic)), long highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin", - "woman walking on the beach, sunset, orange sky, eye level shot", - "woman in a white apron and chef hat, garnishing a gourmet dish, full body view, long shot", - "woman dancing pose among folks in a park, waving hands", - "woman in iron man costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot", - "woman jedi wielding a lightsaber, star wars, full body view, eye level shot", - "woman playing guitar on a boat, ocean waves", - "woman with a passion for reading, curled up with a book in a cozy nook near a window", - "woman running pose in a park, eye level shot", - "woman in superman costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot" - ]) - - image_embed_scale = gr.Slider( - label="ID-Animator Image Embedding Scale", - info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)", - minimum=0, - maximum=2, - step=0.1, - value=0.8, - ) - attn_scale = gr.Slider( - label="Attention Processor Scale", - info="The scale of the ID embeddings on the attention (the higher, the more focus on the face, less on the background)" , - minimum=0, - maximum=2, - step=0.1, - value=0.8, - ) - adaface_id_cfg_scale = gr.Slider( - label="AdaFace CFG Scale", - info="The CFG scale of the AdaFace ID embeddings (influencing fine facial features)", - minimum=0.5, - maximum=6, - step=0.25, - value=1.5, - ) + guidance_scale = gr.Slider( + label="Guidance scale", + minimum=1.0, + maximum=8.0, + step=0.5, + value=6, + ) + + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=985, + ) + randomize_seed = gr.Checkbox( + label="Randomize seed", + value=True, + info="Uncheck for reproducible results") + + negative_prompt = gr.Textbox( + label="Negative Prompt", + placeholder="low quality", + value="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, bare breasts, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, long neck, UnrealisticDream", + ) + num_steps = gr.Slider( + label="Number of sampling steps", + minimum=30, + maximum=80, + step=1, + value=50, + ) submit = gr.Button("Generate Video") with gr.Accordion(open=False, label="Advanced Options"): video_length = gr.Slider( label="video_length", - info="Do not change, otherwise the video will be messy", + info="Do not change; any values other than 16 will mess up the output video", minimum=16, maximum=21, step=1, value=16, interactive=False, + visible=False, ) is_adaface_enabled = gr.Checkbox(label="Enable AdaFace", info="Enable AdaFace for better face details. 
If unchecked, it falls back to ID-Animator (https://huggingface.co/spaces/ID-Animator/ID-Animator).", @@ -372,44 +348,41 @@ with gr.Blocks(css=css) as demo: step=0.1, value=1, ) - - # adaface_anneal_steps is no longer necessary, but we keep it here for future use. - adaface_anneal_steps = gr.Slider( - label="AdaFace Anneal Steps", - minimum=0, - maximum=2, - step=1, - value=0, - visible=False, - ) - - negative_prompt = gr.Textbox( - label="Negative Prompt", - placeholder="low quality", - value="face portrait, (deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, bare breasts, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, long neck, UnrealisticDream", - ) - num_steps = gr.Slider( - label="Number of sampling steps", - minimum=25, - maximum=100, - step=1, - value=40, - ) - guidance_scale = gr.Slider( - label="Guidance scale (usually you don't need to change)", - minimum=1.0, - maximum=10.0, - step=0.5, - value=4, - ) - seed = gr.Slider( - label="Seed", + + image_embed_cfg_begin_scale = gr.Slider( + label="ID-Animator Image Embedding Initial Scale", + info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)", + minimum=0.3, + maximum=1.5, + step=0.1, + value=1, + ) + image_embed_cfg_end_scale = gr.Slider( + label="ID-Animator Image Embedding Final Scale", + info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)", + minimum=0.0, + maximum=0.6, + step=0.1, + value=0.5, + ) + + id_animator_anneal_steps = gr.Slider( + label="ID-Animator Scale Anneal Steps", minimum=0, - maximum=MAX_SEED, + maximum=40, step=1, - value=985, + value=20, + visible=True, ) - randomize_seed = gr.Checkbox(label="Randomize seed", value=False, info="Uncheck for reproducible results") + + attn_scale = gr.Slider( + label="ID-Animator Attention Processor Scale", + info="The scale of the ID embeddings on the attention (the higher, the more focus on the face, less on the background)" , + minimum=0, + maximum=2, + step=0.1, + value=1, + ) with gr.Column(): result_video = gr.Video(label="Generated Animation", interactive=False) @@ -420,7 +393,7 @@ with gr.Blocks(css=css) as demo: init_img_files.upload(fn=swap_to_gallery, inputs=init_img_files, outputs=[uploaded_init_img_gallery, init_clear_button_column, init_img_files]) remove_init_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_init_img_gallery, init_clear_button_column, init_img_files, init_img_selected_idx]) - gen_init.click(fn=gen_init_images, inputs=[uploaded_files_gallery, prompt, adaface_id_cfg_scale], + gen_init.click(fn=gen_init_images, inputs=[uploaded_files_gallery, prompt], outputs=[uploaded_init_img_gallery, init_img_files, init_clear_button_column]) uploaded_init_img_gallery.select(fn=get_clicked_image, inputs=None, outputs=init_img_selected_idx) @@ -435,15 +408,9 @@ with gr.Blocks(css=css) as demo: fn=generate_image, inputs=[image_container, files, init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight, prompt, negative_prompt, num_steps, video_length, guidance_scale, - seed, attn_scale, image_embed_scale, - is_adaface_enabled, adaface_ckpt_path, 
adaface_id_cfg_scale, adaface_power_scale, adaface_anneal_steps], + seed, attn_scale, image_embed_cfg_begin_scale, image_embed_cfg_end_scale, + is_adaface_enabled, adaface_ckpt_path, adaface_power_scale, id_animator_anneal_steps], outputs=[result_video] ) - gr.Examples( fn=generate_image, examples=[], #examples, - inputs=[image_container, files, init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight, - prompt, negative_prompt, num_steps, video_length, guidance_scale, - seed, attn_scale, image_embed_scale, - is_adaface_enabled, adaface_ckpt_path, adaface_id_cfg_scale, adaface_power_scale, adaface_anneal_steps], - outputs=[result_video], cache_examples=True ) demo.launch(share=True, server_name=args.ip, ssl_verify=False) diff --git a/faceadapter/face_adapter.py b/faceadapter/face_adapter.py index d6349aff213afff7a0a2f2ab5adb1289b1c16be0..3a873aa6e837ad6f709dedf1167e72297232fae7 100644 --- a/faceadapter/face_adapter.py +++ b/faceadapter/face_adapter.py @@ -250,18 +250,18 @@ class FaceAdapterPlusForVideoLora(FaceAdapterLora): clip_image_embeds=None, prompt=None, negative_prompt=None, - adaface_embeds=None, - adaface_scale=1.0, + adaface_prompt_embeds=None, + adaface_power_scale=1.0, attn_scale=1.0, num_samples=1, seed=None, guidance_scale=4, num_inference_steps=30, - adaface_anneal_steps=0, + id_animator_anneal_steps=0, width=512, height=512, video_length=16, - image_embed_scale=1, + image_embed_cfg_scales=[0.8, 0.3], controlnet_images: torch.FloatTensor = None, controlnet_image_index: list = [0], **kwargs, @@ -278,6 +278,7 @@ class FaceAdapterPlusForVideoLora(FaceAdapterLora): prompt = [prompt] * num_prompts if not isinstance(negative_prompt, List): negative_prompt = [negative_prompt] * num_prompts + num_prompt_img = len(pil_image) total_image_prompt_embeds = 0 for i in range(num_prompt_img): @@ -289,10 +290,11 @@ class FaceAdapterPlusForVideoLora(FaceAdapterLora): image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) total_image_prompt_embeds += image_prompt_embeds - total_image_prompt_embeds /= num_prompt_img - image_prompt_embeds = total_image_prompt_embeds + + image_prompt_embeds = total_image_prompt_embeds / num_prompt_img uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + with torch.inference_mode(): # if do_classifier_free_guidance, # duplicate unconditional embeddings for each generation per prompt, using mps friendly method. @@ -305,21 +307,24 @@ class FaceAdapterPlusForVideoLora(FaceAdapterLora): negative_prompt=negative_prompt, ) - if adaface_embeds is not None: - prompt_embeds0_ = prompt_embeds_ - # self.torch_type == torch.float16. adaface_embeds is torch.float32. - prompt_embeds_ = adaface_embeds.repeat(num_samples, 1, 1).to(dtype=self.torch_type) * adaface_scale - # Scale down ID-Animator's face embeddings, so that they don't dominate the generation. + if adaface_prompt_embeds is not None: + # self.torch_type == torch.float16. adaface_prompt_embeds is torch.float32. + prompt_embeds_ = adaface_prompt_embeds.repeat(num_samples, 1, 1).to(dtype=self.torch_type) \ + * adaface_power_scale # Note to balance image_prompt_embeds with uncond_image_prompt_embeds after scaling. 
- image_prompt_embeds = image_prompt_embeds * image_embed_scale + uncond_image_prompt_embeds * (1 - image_embed_scale) - # We still need uncond_image_prompt_embeds, otherwise the output is blank. - prompt_embeds_end = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) - prompt_embeds_begin = torch.cat([prompt_embeds0_, torch.zeros_like(image_prompt_embeds)], dim=1) - prompt_embeds = (prompt_embeds_begin, prompt_embeds_end, adaface_anneal_steps) + image_prompt_embeds_begin = image_prompt_embeds * image_embed_cfg_scales[0] + uncond_image_prompt_embeds * (1 - image_embed_cfg_scales[0]) + image_prompt_embeds_end = image_prompt_embeds * image_embed_cfg_scales[1] + uncond_image_prompt_embeds * (1 - image_embed_cfg_scales[1]) + # Disable annealing by setting prompt_embeds_begin the same as prompt_embeds_end. + prompt_embeds_begin = torch.cat([prompt_embeds_, image_prompt_embeds_begin], dim=1) + prompt_embeds_end = torch.cat([prompt_embeds_, image_prompt_embeds_end], dim=1) + # Scale down ID-Animator's face embeddings from prompt_embeds_begin to prompt_embeds_end, + # so that they don't dominate the generation. + prompt_embeds = (prompt_embeds_begin, prompt_embeds_end, id_animator_anneal_steps) else: + # The conventional ID-Animator's way to apply face embeddings. prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) - # prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + # We still need uncond_image_prompt_embeds, otherwise the output is blank. negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) generator = get_generator(seed, self.device) diff --git a/infer.py b/infer.py index 13203a16257d5cf1ca61ad6adce90fe982fa4587..8f6df47af508baab5f8747edd0820092a10f3139 100644 --- a/infer.py +++ b/infer.py @@ -8,38 +8,20 @@ from animatediff.utils.util import load_weights from safetensors import safe_open from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint from faceadapter.face_adapter import FaceAdapterPlusForVideoLora -from adaface.adaface_wrapper import AdaFaceWrapper -def load_adaface(base_model_path, adaface_ckpt_path, device="cuda"): - # base_model_path is only used for initialization, not really used in the inference. - adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=base_model_path, - adaface_ckpt_path=adaface_ckpt_path, device=device) - return adaface +base_model_type_to_path = { + "sd15": "models/sd15-dste8-vae.safetensors", # LDM format. Needs to be converted. + "sar": "models/sar/sar.safetensors", # LDM format. Needs to be converted. 
+ "rv51": "models/rv51/realisticVisionV51_v51VAE.safetensors" +} -def load_model(base_model_type="sar", adaface_base_model_type="sar", - adaface_ckpt_path=None, device="cuda"): +def load_model(base_model_type="rv51", device="cuda"): inference_config = "inference-v2.yaml" sd_version = "animatediff/sd" id_ckpt = "models/animator.ckpt" image_encoder_path = "models/image_encoder" - base_model_type_to_path = { - "rv40": "models/realisticvision/realisticVisionV40_v40VAE.safetensors", - "rv60": "models/realisticvision/realisticVisionV60B1_v51VAE.safetensors", - "sd15": "models/stable-diffusion-v-1-5/v1-5-pruned.safetensors", - "sd15_adaface": "models/stable-diffusion-v-1-5/v1-5-dste8-vae.ckpt", - "toonyou": "models/toonyou/toonyou_beta6.safetensors", - "epv5": "models/epic_realism/epicrealism_pureEvolutionV5.safetensors", - "ar181": "models/absolutereality/absolutereality_v181.safetensors", - "ar16": "models/absolutereality/ar-v1-6.safetensors", - "sar": "models/sar/sar.safetensors", - } - base_model_path = base_model_type_to_path[base_model_type] - if adaface_base_model_type + "_adaface" in base_model_type_to_path: - adaface_base_model_path = base_model_type_to_path[adaface_base_model_type + "_adaface"] - else: - adaface_base_model_path = base_model_type_to_path[adaface_base_model_type] motion_module_path="models/v3_sd15_mm.ckpt" motion_lora_path = "models/v3_sd15_adapter.ckpt" @@ -120,11 +102,5 @@ def load_model(base_model_type="sar", adaface_base_model_type="sar", id_animator = FaceAdapterPlusForVideoLora(pipeline, image_encoder_path, id_ckpt, num_tokens=16, device=torch.device(device), torch_type=torch.float16) - if adaface_ckpt_path is not None: - adaface = load_adaface(adaface_base_model_path, #dreambooth_model_path, - adaface_ckpt_path, device) - else: - adaface = None - - return id_animator, adaface + return id_animator diff --git a/models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth b/models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth new file mode 100644 index 0000000000000000000000000000000000000000..ca57f3257ca7715bc340d065764bc249d985c287 --- /dev/null +++ b/models/ConsistentID/BiSeNet_pretrained_for_ConsistentID.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567 +size 53289463 diff --git a/models/ConsistentID/ConsistentID-v1.bin b/models/ConsistentID/ConsistentID-v1.bin new file mode 100644 index 0000000000000000000000000000000000000000..cb4022f6ff8830c29609aa344c162ff749c27063 --- /dev/null +++ b/models/ConsistentID/ConsistentID-v1.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:48cd9faab558c09565dfb4a355976ea44501fb496c11e3ced722286a8453765b +size 669123998 diff --git a/models/adaface/VGGface2_HQ_masks2024-10-05T09-28-53_zero3-ada-28000.pt b/models/adaface/VGGface2_HQ_masks2024-10-05T09-28-53_zero3-ada-28000.pt new file mode 100644 index 0000000000000000000000000000000000000000..8d06ccf3404b8f78f008d7242128c580a777654a --- /dev/null +++ b/models/adaface/VGGface2_HQ_masks2024-10-05T09-28-53_zero3-ada-28000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f6959ba41eb8cc8fcc738ba5ecc751de3acc0d1180e3af2272b7b52b04c6ae8 +size 1814922042 diff --git a/models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt b/models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt deleted file mode 100644 index 6011c1f026181050a8f83589fef8840d95815f88..0000000000000000000000000000000000000000 --- 
a/models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4aa1eb9ff3e364ea1b9db6dfff0c281ff3b57864d7ccc4c64d5f29ed752484f3 -size 821700521 diff --git a/models/rv51/realisticVisionV51_v51VAE.safetensors b/models/rv51/realisticVisionV51_v51VAE.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..be041649019837e489d917b5371c46a6ba343ebf --- /dev/null +++ b/models/rv51/realisticVisionV51_v51VAE.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15012c538f503ce2ebfc2c8547b268c75ccdaff7a281db55399940ff1d70e21d +size 2132625894 diff --git a/models/sar/sar.safetensors b/models/sar/sar.safetensors index 3f4a4034e2849e802b8ec46fd3857381d1e8069c..387a99ea3c10ecf5f17c5cb3503d53f562edb8e4 100644 --- a/models/sar/sar.safetensors +++ b/models/sar/sar.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:35a5d7615850879ffecce7b1e463ae0317c95fe784dd9b179793b58531a9e3ab -size 2299982596 +oid sha256:cb1e0b365337bd67b59a2c2b151699d2b60f16b3d762297d8b93ecc5b48d6c94 +size 2132650982 diff --git a/models/sd15-dste8-vae.safetensors b/models/sd15-dste8-vae.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..ac289e49808266684a28df13fa967f4c1d62b74c --- /dev/null +++ b/models/sd15-dste8-vae.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb52b6406b954dc898b784d4e91840c53f51aed80c8509b556102b5bbf7da2cf +size 2132650982 diff --git a/requirements.txt b/requirements.txt index 3962ee921758b1e54bab7501805d3c47795cf02e..6d47e2e4339562595cdb933aa35ffa61f8e620e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +diffusers==0.29.2 torch torchvision imageio @@ -6,11 +7,9 @@ accelerate einops gradio transformers -wandb insightface omegaconf opencv-python -diffusers onnx>=1.16.0 onnxruntime safetensors
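
The main behavioral change in this patch is the ID-Animator embedding anneal: faceadapter/face_adapter.py now builds two prompt-embedding variants (using image_embed_cfg_scales[0] and image_embed_cfg_scales[1]), and animatediff/pipelines/pipeline_animation.py linearly interpolates between them over the first id_animator_anneal_steps denoising steps. Below is a minimal sketch of that schedule, with hypothetical helper names and tensor shapes; the real pipeline concatenates the AdaFace text embeddings with the ID-Animator image-prompt embeddings in the same way.

import torch

def blend_image_embeds(image_embeds, uncond_image_embeds, cfg_scale):
    # Mirrors image_prompt_embeds * scale + uncond_image_prompt_embeds * (1 - scale);
    # scales below 1 soften the ID-Animator image embedding toward its unconditional counterpart.
    return image_embeds * cfg_scale + uncond_image_embeds * (1 - cfg_scale)

def annealed_prompt_embeds(begin, end, step, anneal_steps):
    # Linear interpolation over the first anneal_steps denoising steps.
    # If anneal_steps == 0, the factor is always 1, i.e. `end` is used throughout.
    anneal_factor = step / anneal_steps if step < anneal_steps else 1
    return begin + anneal_factor * (end - begin)

# Hypothetical shapes: [batch, tokens, dim].
text_embeds         = torch.randn(1, 77, 768)
image_embeds        = torch.randn(1, 16, 768)
uncond_image_embeds = torch.randn(1, 16, 768)

# Begin scale 1.0 (strong coarse ID/pose guidance) and end scale 0.5, matching the UI defaults.
embeds_begin = torch.cat([text_embeds, blend_image_embeds(image_embeds, uncond_image_embeds, 1.0)], dim=1)
embeds_end   = torch.cat([text_embeds, blend_image_embeds(image_embeds, uncond_image_embeds, 0.5)], dim=1)

for i in range(50):
    embeds_i = annealed_prompt_embeds(embeds_begin, embeds_end, i, anneal_steps=20)
    # embeds_i would be fed to the UNet at denoising step i.

With this schedule, early steps lean on the ID-Animator image embedding for coarse facial features and pose, while later steps scale it down so it does not dominate the AdaFace-driven details.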
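
The CLIP vision-tower patch (CLIPVisionTransformer_forward_with_mask) reduces a pixel-space face mask to a pairwise token mask before handing it to the CLIP encoder. A rough sketch of that reduction, assuming a [BS, 512, 512] input mask and 16x16 patch tokens plus one class token (make_token_pair_mask is a hypothetical helper name, not part of the patch):

import torch
import torch.nn.functional as F

def make_token_pair_mask(attn_mask, num_patch_tokens=256):
    # attn_mask: [BS, H, W] binary mask in pixel space.
    feat_edge_size = int(num_patch_tokens ** 0.5)   # 16 patch tokens per side
    m = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
    m = m.flatten(2)                                # [BS, 1, 256]
    # Prepend 1 for the class token, which is always attended to.
    m = torch.cat([torch.ones_like(m[:, :, :1]), m], dim=-1)   # [BS, 1, 257]
    # Outer product -> [BS, 1, 257, 257], passed to the encoder as attention_mask
    # (1 = pair not masked, 0 = pair masked).
    return torch.matmul(m.transpose(-1, -2), m).unsqueeze(1)

pair_mask = make_token_pair_mask(torch.ones(2, 512, 512))
assert pair_mask.shape == (2, 1, 257, 257)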