File size: 21,541 Bytes
02cc20b
 
 
 
 
 
ad88a0b
 
02cc20b
 
 
ad88a0b
 
 
 
02cc20b
ad88a0b
02cc20b
 
ad88a0b
 
 
 
 
 
 
02cc20b
ad88a0b
 
02cc20b
 
 
 
 
ad88a0b
 
 
 
 
02cc20b
ad88a0b
02cc20b
ad88a0b
 
 
 
 
 
02cc20b
 
ad88a0b
02cc20b
 
 
 
 
 
 
 
 
ad88a0b
 
 
 
 
02cc20b
 
ad88a0b
 
 
 
 
 
 
02cc20b
 
 
 
ad88a0b
 
02cc20b
 
 
 
 
 
ad88a0b
 
02cc20b
 
 
 
 
 
 
 
 
ad88a0b
 
61fbdeb
 
02cc20b
 
 
 
 
 
 
ad88a0b
 
 
 
 
 
 
 
02cc20b
 
 
 
 
 
 
 
 
 
 
ad88a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
02cc20b
ad88a0b
 
02cc20b
 
 
 
 
 
 
 
 
 
 
 
 
 
ad88a0b
 
 
 
 
 
 
 
 
 
 
02cc20b
 
ad88a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02cc20b
ad88a0b
 
02cc20b
 
ad88a0b
 
 
 
 
 
 
 
 
02cc20b
ad88a0b
 
 
 
02cc20b
 
ad88a0b
 
02cc20b
ad88a0b
02cc20b
 
ad88a0b
02cc20b
 
ad88a0b
 
02cc20b
ad88a0b
02cc20b
 
ad88a0b
02cc20b
 
 
ad88a0b
02cc20b
ad88a0b
02cc20b
 
 
 
ad88a0b
02cc20b
a29cf91
abaeb15
 
ad88a0b
 
 
 
 
 
 
 
 
a29cf91
 
 
 
ad88a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02cc20b
ad88a0b
b0b5a77
ad88a0b
 
 
 
02cc20b
ad88a0b
 
02cc20b
a29cf91
 
 
f4d87fd
 
13d8b07
 
 
ad88a0b
a29cf91
02cc20b
ad88a0b
02cc20b
 
 
 
ad88a0b
 
 
13d8b07
 
 
 
 
 
 
 
 
ad88a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02cc20b
 
a29cf91
 
 
 
ad88a0b
 
b0b5a77
 
02cc20b
ad88a0b
 
a29cf91
 
 
02cc20b
ad88a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02cc20b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
import torch
import torch.nn as nn
from transformers import CLIPTextModel
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusion3Pipeline,
    #FluxPipeline,
    DDIMScheduler,
    AutoencoderKL,
)
from diffusers.loaders.single_file_utils import convert_ldm_unet_checkpoint
from adaface.util import UNetEnsemble
from adaface.face_id_to_ada_prompt import create_id2ada_prompt_encoder
from safetensors.torch import load_file as safetensors_load_file
import re, os
import numpy as np

class AdaFaceWrapper(nn.Module):
    def __init__(self, pipeline_name, base_model_path, adaface_encoder_types, 
                 adaface_ckpt_paths, adaface_encoder_cfg_scales=None, 
                 enabled_encoders=None,
                 subject_string='z', num_inference_steps=50, negative_prompt=None,
                 use_840k_vae=False, use_ds_text_encoder=False, 
                 main_unet_filepath=None, unet_types=None, extra_unet_dirpaths=None, unet_weights=None,
                 device='cuda', is_training=False):
        '''
        Store the configuration, build the pipeline via initialize_pipeline(), and
        register the subject placeholder tokens via extend_tokenizer_and_text_encoder().

        pipeline_name: "text2img", "img2img", "text2img3", or None
        (the "flux" branch of initialize_pipeline() is currently commented out).
        If None, this instance is used only as a face encoder, and the unet and vae are
        removed from the pipeline to release RAM.
        negative_prompt: if None, a built-in default negative prompt is used.
        '''
        super().__init__()

        # Fallback negative prompt, used only when the caller does not provide one.
        default_negative_prompt = (
            "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, "
            "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, "
            "mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, "
            "nude, naked, nsfw, topless, bare breasts"
        )

        # Pipeline / base model selection.
        self.pipeline_name       = pipeline_name
        self.base_model_path     = base_model_path
        self.num_inference_steps = num_inference_steps
        self.device              = device
        self.is_training         = is_training

        # AdaFace encoder configuration.
        self.adaface_encoder_types      = adaface_encoder_types
        self.adaface_ckpt_paths         = adaface_ckpt_paths
        self.adaface_encoder_cfg_scales = adaface_encoder_cfg_scales
        self.enabled_encoders           = enabled_encoders
        self.subject_string             = subject_string

        # Optional component replacements and UNet ensembling.
        self.use_840k_vae        = use_840k_vae
        self.use_ds_text_encoder = use_ds_text_encoder
        self.main_unet_filepath  = main_unet_filepath
        self.unet_types          = unet_types
        self.extra_unet_dirpaths = extra_unet_dirpaths
        self.unet_weights        = unet_weights

        self.negative_prompt = default_negative_prompt if negative_prompt is None else negative_prompt

        self.initialize_pipeline()
        # During inference, we never use static image suffix embeddings,
        # so num_id_vecs is the length of the adaface embeddings returned by each encoder.
        self.encoders_num_id_vecs = self.id2ada_prompt_encoder.encoders_num_id_vecs
        self.extend_tokenizer_and_text_encoder()

    def initialize_pipeline(self):
        '''
        Build the AdaFace prompt encoder and the diffusers pipeline.

        Sets self.id2ada_prompt_encoder and self.pipeline, and fills in
        self.base_model_path with a default when it was passed as None.

        Steps, in order:
          1. Create the id2ada prompt encoder and move it to self.device.
          2. Optionally load a replacement VAE / text encoder from fixed local paths.
          3. Choose the pipeline class from self.pipeline_name and load it in float16,
             from a single checkpoint file if base_model_path is a file, otherwise
             with from_pretrained().
          4. Optionally overwrite the UNet weights and/or wrap the UNet in a UNetEnsemble.
          5. Swap in the replacement VAE / text encoder; drop UNet and VAE when
             pipeline_name is None.
          6. Install a DDIM scheduler for every pipeline except "text2img3" / "flux".

        Raises:
            ValueError: if self.pipeline_name is not a supported pipeline name.
        '''
        self.id2ada_prompt_encoder = create_id2ada_prompt_encoder(self.adaface_encoder_types,
                                                                  self.adaface_ckpt_paths,
                                                                  self.adaface_encoder_cfg_scales,
                                                                  self.enabled_encoders)

        self.id2ada_prompt_encoder.to(self.device)
        print(f"adaface_encoder_cfg_scales: {self.adaface_encoder_cfg_scales}")

        if self.use_840k_vae:
            # The 840000-step vae model is slightly better in face details than the original vae model.
            # https://huggingface.co/stabilityai/sd-vae-ft-mse-original
            vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", 
                                                 torch_dtype=torch.float16)
        else:
            vae = None

        if self.use_ds_text_encoder:
            # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
            # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
            text_encoder = CLIPTextModel.from_pretrained("models/diffusers/ds_text_encoder", 
                                                         torch_dtype=torch.float16)
        else:
            text_encoder = None

        remove_unet = False

        if self.pipeline_name == "img2img":
            PipelineClass = StableDiffusionImg2ImgPipeline
        elif self.pipeline_name == "text2img":
            PipelineClass = StableDiffusionPipeline
        elif self.pipeline_name == "text2img3":
            PipelineClass = StableDiffusion3Pipeline
        #elif self.pipeline_name == "flux":
        #    PipelineClass = FluxPipeline
        # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
        elif self.pipeline_name is None:
            PipelineClass = StableDiffusionPipeline
            remove_unet = True
        else:
            raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")
        
        if self.base_model_path is None:
            base_model_path_dict = { 
                'text2img':  'models/sd15-dste8-vae.safetensors',
                'text2img3': 'stabilityai/stable-diffusion-3-medium-diffusers',
                'flux':      'black-forest-labs/FLUX.1-schnell',
            }
            # "img2img" and the encoder-only mode (None) are loaded with the same
            # Stable Diffusion pipeline family as "text2img", but have no entry of
            # their own. Fall back to the text2img default instead of raising KeyError.
            default_key = 'text2img' if self.pipeline_name in (None, 'img2img') else self.pipeline_name
            self.base_model_path = base_model_path_dict[default_key]

        if os.path.isfile(self.base_model_path):
            pipeline = PipelineClass.from_single_file(
                self.base_model_path, 
                torch_dtype=torch.float16
                )
        else:
            pipeline = PipelineClass.from_pretrained(
                    self.base_model_path,
                    torch_dtype=torch.float16,
                    safety_checker=None
                )
        
        if self.main_unet_filepath is not None:
            print(f"Replacing the UNet with the UNet from {self.main_unet_filepath}.")
            # load_state_dict() reports mismatched keys in its return value; print them
            # so that a wrong checkpoint does not go unnoticed.
            ret = pipeline.unet.load_state_dict(self.load_unet_from_file(self.main_unet_filepath, device='cpu'))
            if len(ret.missing_keys) > 0:
                print(f"Missing keys: {ret.missing_keys}")
            if len(ret.unexpected_keys) > 0:
                print(f"Unexpected keys: {ret.unexpected_keys}")

        if (self.unet_types is not None and len(self.unet_types) > 0) \
          or (self.extra_unet_dirpaths is not None and len(self.extra_unet_dirpaths) > 0):
            # The pipeline's own UNet is passed as the first member of the ensemble.
            unet_ensemble = UNetEnsemble([pipeline.unet], self.unet_types, self.extra_unet_dirpaths, self.unet_weights,
                                         device=self.device, torch_dtype=torch.float16)
            pipeline.unet = unet_ensemble

        print(f"Loaded pipeline from {self.base_model_path}.")
        
        if self.use_840k_vae:
            pipeline.vae = vae
            print("Replaced the VAE with the 840k-step VAE.")
            
        if self.use_ds_text_encoder:
            pipeline.text_encoder = text_encoder
            print("Replaced the text encoder with the DreamShaper text encoder.")

        if remove_unet:
            # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
            pipeline.unet = None
            pipeline.vae  = None
            print("Removed UNet and VAE from the pipeline.")

        if self.pipeline_name not in ["text2img3", "flux"]:
            noise_scheduler = DDIMScheduler(
                num_train_timesteps=1000,
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
            )
            pipeline.scheduler = noise_scheduler
        # Otherwise, pipeline.scheduler == FlowMatchEulerDiscreteScheduler
        self.pipeline = pipeline.to(self.device)

    def load_unet_from_file(self, unet_path, device=None):
        if os.path.isfile(unet_path):
            if unet_path.endswith(".safetensors"):
                unet_state_dict = safetensors_load_file(unet_path, device=device)
            else:
                unet_state_dict = torch.load(unet_path, map_location=device)

            key0 = list(unet_state_dict.keys())[0]
            if key0.startswith("model.diffusion_model"):
                key_prefix = ""
                is_ldm_unet = True
            elif key0.startswith("diffusion_model"):
                key_prefix = "model."
                is_ldm_unet = True
            else:
                is_ldm_unet = False

            if is_ldm_unet:
                unet_state_dict2 = {}
                for key, value in unet_state_dict.items():
                    key2 = key_prefix + key
                    unet_state_dict2[key2] = value
                print(f"LDM UNet detected. Convert to diffusers")
                ldm_unet_config = { 'layers_per_block': 2 }
                unet_state_dict = convert_ldm_unet_checkpoint(unet_state_dict2, ldm_unet_config)
        else:
            raise ValueError(f"UNet path {unet_path} is not a file.")
        return unet_state_dict
        
    def extend_tokenizer_and_text_encoder(self):
        if np.sum(self.encoders_num_id_vecs) < 1:
            raise ValueError(f"encoders_num_id_vecs has to be larger or equal to 1, but is {self.encoders_num_id_vecs}")

        tokenizer = self.pipeline.tokenizer
        # If adaface_encoder_types is ["arc2face", "consistentID"], then total_num_id_vecs = 20.
        # We add z_0_0, z_0_1, z_0_2, ..., z_0_15, z_1_0, z_1_1, z_1_2, z_1_3 to the tokenizer.
        self.all_placeholder_tokens = []
        self.placeholder_tokens_strs = []
        for i in range(len(self.adaface_encoder_types)):
            placeholder_tokens = []
            for j in range(self.encoders_num_id_vecs[i]):
                placeholder_tokens.append(f"{self.subject_string}_{i}_{j}")
                placeholder_tokens_str = " ".join(placeholder_tokens)

            self.all_placeholder_tokens.extend(placeholder_tokens)
            self.placeholder_tokens_strs.append(placeholder_tokens_str)

        self.all_placeholder_tokens_str = " ".join(self.placeholder_tokens_strs)

        # Add the new tokens to the tokenizer.
        num_added_tokens = tokenizer.add_tokens(self.all_placeholder_tokens)
        if num_added_tokens != np.sum(self.encoders_num_id_vecs):
            raise ValueError(
                f"The tokenizer already contains some of the tokens {self.all_placeholder_tokens_str}. Please pass a different"
                " `subject_string` that is not already in the tokenizer.")

        print(f"Added {num_added_tokens} tokens ({self.all_placeholder_tokens_str}) to the tokenizer.")
        
        # placeholder_token_ids: [49408, ..., 49423].
        self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.all_placeholder_tokens)
        #print("New tokens:", self.placeholder_token_ids)
        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        old_weight_shape = self.pipeline.text_encoder.get_input_embeddings().weight.shape
        self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
        new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
        print(f"Resized text encoder token embeddings from {old_weight_shape} to {new_weight.shape} on {new_weight.device}.")

    # Extend pipeline.text_encoder with the adaface subject embeddings.
    # subj_embs: [16, 768].
    def update_text_encoder_subj_embeddings(self, subj_embs):
        # Initialise the newly added placeholder token with the embeddings of the initializer token
        # token_embeds: [49412, 768]
        token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
        with torch.no_grad():
            for i, token_id in enumerate(self.placeholder_token_ids):
                token_embeds[token_id] = subj_embs[i]
            print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.all_placeholder_tokens_str}) in the text encoder.")

    def update_prompt(self, prompt, placeholder_tokens_pos='postpend'):
        if prompt is None:
            prompt = ""

        # Delete the subject_string from the prompt.
        re.sub(r'\b(a|an|the)\s+' + self.subject_string + r'\b,?', "", prompt)
        re.sub(r'\b' + self.subject_string + r'\b,?',              "", prompt)
        # Prevously, arc2face ada prompts work better if they are prepended to the prompt,
        # and consistentID ada prompts work better if they are appended to the prompt.
        # When we do joint training, seems both work better if they are appended to the prompt.
        # Therefore we simply appended all placeholder_tokens_str's to the prompt.
        # NOTE: Prepending them hurts compositional prompts.
        if placeholder_tokens_pos == 'prepend':
            prompt = self.all_placeholder_tokens_str + " " + prompt
        elif placeholder_tokens_pos == 'postpend':
            prompt = prompt + " " + self.all_placeholder_tokens_str

        return prompt

    # avg_at_stage: 'id_emb', 'img_prompt_emb', or None.
    # avg_at_stage == ada_prompt_emb usually produces the worst results.
    # id_emb is slightly better than img_prompt_emb, but sometimes img_prompt_emb is better.
    def prepare_adaface_embeddings(self, image_paths, face_id_embs=None, 
                                   avg_at_stage='id_emb', # id_emb, img_prompt_emb, ada_prompt_emb, or None.
                                   perturb_at_stage=None, # id_emb, img_prompt_emb, or None.
                                   perturb_std=0, update_text_encoder=True):

        all_adaface_subj_embs = \
            self.id2ada_prompt_encoder.generate_adaface_embeddings(\
                    image_paths, face_id_embs=face_id_embs, 
                    img_prompt_embs=None, 
                    avg_at_stage=avg_at_stage,
                    perturb_at_stage=perturb_at_stage,
                    perturb_std=perturb_std,
                    enable_static_img_suffix_embs=False)
        
        if all_adaface_subj_embs is None:
            return None
        
        # [1, 1, 16, 768] -> [16, 768]
        all_adaface_subj_embs = all_adaface_subj_embs.squeeze(0).squeeze(0)

        if update_text_encoder:
            self.update_text_encoder_subj_embeddings(all_adaface_subj_embs)
        return all_adaface_subj_embs

    def encode_prompt(self, prompt, negative_prompt=None, 
                      placeholder_tokens_pos='postpend',
                      device=None, verbose=False):
        if negative_prompt is None:
            negative_prompt = self.negative_prompt
        
        if device is None:
            device = self.device
        
        prompt = self.update_prompt(prompt, placeholder_tokens_pos=placeholder_tokens_pos)
        if verbose:
            print(f"Subject prompt: {prompt}")

        # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
        # So we manually move it to GPU here.
        self.pipeline.text_encoder.to(device)
        # pooled_prompt_embeds_, negative_pooled_prompt_embeds_ are used by text2img3 and flux.
        pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = None, None

        # Compatible with older versions of diffusers.
        if not hasattr(self.pipeline, "encode_prompt"):
            # prompt_embeds_, negative_prompt_embeds_: [77, 768] -> [1, 77, 768].
            prompt_embeds_, negative_prompt_embeds_ = \
                self.pipeline._encode_prompt(prompt, device=device, num_images_per_prompt=1,
                                             do_classifier_free_guidance=True, negative_prompt=negative_prompt)
            prompt_embeds_ = prompt_embeds_.unsqueeze(0)
            negative_prompt_embeds_ = negative_prompt_embeds_.unsqueeze(0)
        else:
            if self.pipeline_name in ["text2img3", "flux"]:
                # prompt_embeds_, negative_prompt_embeds_: [1, 333, 4096]
                # pooled_prompt_embeds_, negative_pooled_prompt_embeds_: [1, 2048]
                # CLIP Text Encoder prompt uses a maximum sequence length of 77.
                # T5 Text Encoder prompt uses a maximum sequence length of 256.
                # 333 = 256 + 77.
                prompt_t5 = prompt + "".join([", "] * 256)
                if self.pipeline_name == "text2img3":
                    prompt_embeds_, negative_prompt_embeds_, \
                    pooled_prompt_embeds_, negative_pooled_prompt_embeds_ = \
                        self.pipeline.encode_prompt(prompt, prompt, prompt_t5, device=device, 
                                                    num_images_per_prompt=1, 
                                                    do_classifier_free_guidance=True,
                                                    negative_prompt=negative_prompt)
                elif self.pipeline_name == "flux":
                    # prompt_embeds_: [1, 512, 4096]
                    # pooled_prompt_embeds_: [1, 768]
                    prompt_embeds_, pooled_prompt_embeds_, text_ids = \
                        self.pipeline.encode_prompt(prompt, prompt_t5, device=device, 
                                                    num_images_per_prompt=1)
                    negative_prompt_embeds_ = negative_pooled_prompt_embeds_ = None
                else:
                    breakpoint()
            else:
                # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
                prompt_embeds_, negative_prompt_embeds_ = \
                    self.pipeline.encode_prompt(prompt, device=device, 
                                                num_images_per_prompt=1, 
                                                do_classifier_free_guidance=True,
                                                negative_prompt=negative_prompt)
                
        return prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, negative_pooled_prompt_embeds_
    
    # ref_img_strength is used only in the img2img pipeline.
    def forward(self, noise, prompt, negative_prompt=None, 
                placeholder_tokens_pos='postpend',
                guidance_scale=6.0, out_image_count=4, 
                ref_img_strength=0.8, generator=None, verbose=False):
        noise = noise.to(device=self.device, dtype=torch.float16)

        if negative_prompt is None:
            negative_prompt = self.negative_prompt
        # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
        prompt_embeds_, negative_prompt_embeds_, pooled_prompt_embeds_, \
            negative_pooled_prompt_embeds_ = \
                self.encode_prompt(prompt, negative_prompt, 
                                   placeholder_tokens_pos=placeholder_tokens_pos,
                                   device=self.device, verbose=verbose)
        # Repeat the prompt embeddings for all images in the batch.
        prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
        if negative_prompt_embeds_ is not None:
            negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)

        if self.pipeline_name == "text2img3":
            pooled_prompt_embeds_           = pooled_prompt_embeds_.repeat(out_image_count, 1)
            negative_pooled_prompt_embeds_  = negative_pooled_prompt_embeds_.repeat(out_image_count, 1)

            # noise: [BS, 4, 64, 64]
            # When the pipeline is text2img, strength is ignored.
            images = self.pipeline(prompt_embeds=prompt_embeds_, 
                                   negative_prompt_embeds=negative_prompt_embeds_, 
                                   pooled_prompt_embeds=pooled_prompt_embeds_,
                                   negative_pooled_prompt_embeds=negative_pooled_prompt_embeds_,
                                   num_inference_steps=self.num_inference_steps, 
                                   guidance_scale=guidance_scale, 
                                   num_images_per_prompt=1,
                                   generator=generator).images
        elif self.pipeline_name == "flux":
            images = self.pipeline(prompt_embeds=prompt_embeds_, 
                                   pooled_prompt_embeds=pooled_prompt_embeds_, 
                                   num_inference_steps=4, 
                                   guidance_scale=guidance_scale, 
                                   num_images_per_prompt=1,
                                   generator=generator).images
        else:
            # When the pipeline is text2img, noise: [BS, 4, 64, 64], and strength is ignored.
            # When the pipeline is img2img,  noise is an initiali image of [BS, 3, 512, 512],
            # whose pixels are normalized to [0, 1].
            images = self.pipeline(image=noise,
                                   prompt_embeds=prompt_embeds_, 
                                   negative_prompt_embeds=negative_prompt_embeds_, 
                                   num_inference_steps=self.num_inference_steps, 
                                   guidance_scale=guidance_scale, 
                                   num_images_per_prompt=1,
                                   strength=ref_img_strength,
                                   generator=generator).images
        # images: [BS, 3, 512, 512]
        return images