import math
import re
import types

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from PIL import Image


def extract_first_sentence(text):
    """Return *text* up to and including its first '.', with surrounding whitespace stripped.

    If *text* contains no '.', the whole string (stripped) is returned.
    """
    end_index = text.find(".")
    if end_index == -1:
        return text.strip()
    return text[: end_index + 1].strip()


def remove_duplicate_keywords(text, keywords):
    """Tokenize *text* and blank out every repeat occurrence of each keyword.

    The text is split into word tokens and the punctuation marks ``. , ; ! ?``.
    Matching is case-insensitive. The first occurrence of each keyword is kept
    and later ones are replaced by an empty token. The tokens are then joined
    with single spaces, so punctuation ends up space-separated (``"nose ,"``).
    A removed token leaves a double space behind.
    """
    words = re.findall(r"\b\w+\b|[.,;!?]", text)
    for keyword in keywords:
        target = keyword.lower()
        already_seen = False
        for i, word in enumerate(words):
            if word.lower() != target:
                continue
            if already_seen:
                words[i] = ""
            already_seen = True
    return " ".join(words)


def insert_markers_to_prompt(text, parsing_mask_dict):
    """Tag facial-feature words in a caption and reorder the caption clause by clause.

    Example input text:
        'The person has one nose , two eyes , two ears , and a mouth .'

    Steps:
    1. Collapse repeated feature keywords (see ``remove_duplicate_keywords``).
    2. For every face-parsing part present in *parsing_mask_dict*, take the
       mapped feature word. The first occurrence of that word in the text gets
       a ``<|word|>`` marker appended. If the word does not occur, every
       parsing-mask key mapped to it is dropped from the returned dict.
    3. For each marker, in the order the parts were found, extract the clause
       containing it (bounded by ``,``, ``.`` or ``;``). The clause is removed
       from the working text and appended, followed by ``", "``.
    4. Replace every feature marker with the generic ``<|facial|>`` token.

    Args:
        text: caption string.
        parsing_mask_dict: mapping from parsing-part name (e.g. ``"Left_Eye"``)
            to its mask. This dict is not modified.

    Returns:
        ``(align_marked_text, align_parsing_mask_dict)``. For the example above:
        ``'The person has one nose <|facial|>, two ears <|facial|>, two eyes
        <|facial|>, and a mouth <|facial|>, '`` and a copy of
        *parsing_mask_dict* restricted to parts whose feature word was found.
    """
    keywords = ["face", "ears", "eyes", "nose", "mouth"]
    text = remove_duplicate_keywords(text, keywords)
    key_parsing_mask_markers = [
        "Nose", "Face", "Left_Ear", "Right_Ear",
        "Left_Eye", "Right_Eye", "Upper_Lip", "Lower_Lip",
    ]
    mapping = {
        "Face": "face",
        "Left_Ear": "ears",
        "Right_Ear": "ears",
        "Left_Eye": "eyes",
        "Right_Eye": "eyes",
        "Nose": "nose",
        "Upper_Lip": "mouth",
        "Lower_Lip": "mouth",
    }
    clause_delimiters = (",", ".", ";")

    # Unique feature words (and their markers), in key_parsing_mask_markers order.
    facial_features_align = []
    markers_align = []
    for key in key_parsing_mask_markers:
        if key not in parsing_mask_dict:
            continue
        mapped_key = mapping.get(key, key.lower())
        if mapped_key not in facial_features_align:
            facial_features_align.append(mapped_key)
            markers_align.append("<|" + mapped_key + "|>")

    # Copy so that dropping unmatched parts never mutates the caller's dict.
    align_parsing_mask_dict = dict(parsing_mask_dict)
    text_marked = text
    for feature, marker in zip(reversed(facial_features_align), reversed(markers_align)):
        pattern = rf"\b{re.escape(feature)}\b"
        text_marked_new = re.sub(pattern, f"{feature} {marker}", text_marked, count=1)
        if text_marked_new == text_marked:
            # Feature word is absent from the caption: its masks cannot be aligned.
            for key, value in mapping.items():
                if value == feature:
                    align_parsing_mask_dict.pop(key, None)
        text_marked = text_marked_new
    text_marked = text_marked.replace("\n", "")

    ordered_text = []
    for marker in markers_align:
        start_idx = text_marked.find(marker)
        if start_idx == -1:
            # Marker was never inserted (word missing). Without this guard,
            # -1 would be used as a slice start and grab the tail of the text.
            continue
        end_idx = start_idx + len(marker)
        # Grow the span outwards to the surrounding clause delimiters.
        while start_idx > 0 and text_marked[start_idx - 1] not in clause_delimiters:
            start_idx -= 1
        while end_idx < len(text_marked) and text_marked[end_idx] not in clause_delimiters:
            end_idx += 1
        context = text_marked[start_idx:end_idx].strip()
        ordered_text.append(context + ", ")
        # Remove the consumed clause so the next search only sees what is left.
        text_marked = text_marked[:start_idx] + text_marked[end_idx:]

    align_marked_text = "".join(ordered_text)
    replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"]
    for item in replace_list:
        align_marked_text = align_marked_text.replace(item, "<|facial|>")
    return align_marked_text, align_parsing_mask_dict


def _pad_or_truncate(values, length, pad_value):
    """Return the list *values* cut or right-padded with *pad_value* to exactly *length*."""
    if len(values) > length:
        return values[:length]
    return values + [pad_value] * (length - len(values))


def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer):
    """Tokenize *text*, strip the special marker tokens and record where they pointed.

    Each ``facial_token_id`` flags the position of the clean token just before
    it. Each ``image_token_id`` flags position ``clean_index + image_num - 1``,
    where ``image_num`` is the number of image tokens seen so far. The offset is
    kept as-is from the original.
    NOTE(review): a marker appearing before any clean token would write to
    index -1 (the last entry); confirm prompts never start with a marker.

    The ids are truncated or padded with ``tokenizer.pad_token_id`` to
    ``tokenizer.model_max_length``. The masks are truncated or padded with
    False to the same length.

    Returns:
        ``(clean_input_ids, image_mask, facial_mask)``. Each is a tensor of
        shape ``(1, tokenizer.model_max_length)``: the ids are ``torch.long``
        and the masks are ``torch.bool``.
    """
    input_ids = tokenizer.encode(text)
    image_noun_phrase_end_mask = [False] * len(input_ids)
    facial_noun_phrase_end_mask = [False] * len(input_ids)
    clean_input_ids = []
    image_num = 0

    for token_id in input_ids:
        clean_index = len(clean_input_ids)
        if token_id == image_token_id:
            image_noun_phrase_end_mask[clean_index + image_num - 1] = True
            image_num += 1
        elif token_id == facial_token_id:
            facial_noun_phrase_end_mask[clean_index - 1] = True
        else:
            clean_input_ids.append(token_id)

    max_len = tokenizer.model_max_length
    clean_input_ids = _pad_or_truncate(clean_input_ids, max_len, tokenizer.pad_token_id)
    image_noun_phrase_end_mask = _pad_or_truncate(image_noun_phrase_end_mask, max_len, False)
    facial_noun_phrase_end_mask = _pad_or_truncate(facial_noun_phrase_end_mask, max_len, False)

    clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long)
    image_noun_phrase_end_mask = torch.tensor(image_noun_phrase_end_mask, dtype=torch.bool)
    facial_noun_phrase_end_mask = torch.tensor(facial_noun_phrase_end_mask, dtype=torch.bool)
    return (
        clean_input_ids.unsqueeze(0),
        image_noun_phrase_end_mask.unsqueeze(0),
        facial_noun_phrase_end_mask.unsqueeze(0),
    )


def _pad_token_indices(token_mask, max_num):
    """Collect the column indices of True entries in *token_mask* and pad them to *max_num*.

    Padding uses index 0 with a False validity flag. The padding is allocated
    on the same device as the indices so ``torch.cat`` also works for GPU
    masks. Indices beyond *max_num* are not truncated.

    Returns:
        ``(indices, valid)``, each with a leading batch dimension of 1.
    """
    indices = torch.nonzero(token_mask, as_tuple=True)[1]
    valid = torch.ones_like(indices, dtype=torch.bool)
    num_missing = max_num - len(indices)
    if num_missing > 0:
        indices = torch.cat(
            [indices, torch.zeros(num_missing, dtype=torch.long, device=indices.device)]
        )
        valid = torch.cat(
            [valid, torch.zeros(num_missing, dtype=torch.bool, device=valid.device)]
        )
    return indices.unsqueeze(0), valid.unsqueeze(0)


def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5):
    """Convert the boolean token masks into padded index tensors plus validity masks.

    Index 1 of ``torch.nonzero`` is used, i.e. the masks are expected to be 2-D
    (presumably ``(1, seq_len)`` from ``tokenize_and_mask_noun_phrases_ends``).

    Returns:
        ``(image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask)``,
        each with a leading batch dimension of 1.
    """
    image_token_idx, image_token_idx_mask = _pad_token_indices(image_token_mask, max_num_objects)
    facial_token_idx, facial_token_idx_mask = _pad_token_indices(facial_token_mask, max_num_facials)
    return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask


def get_object_localization_loss_for_one_layer(
    cross_attention_scores,
    object_segmaps,
    object_token_idx,
    object_token_idx_mask,
    loss_fn,
):
    """Localization loss between one layer's attention maps and the object segmentation maps.

    Args:
        cross_attention_scores: ``(batch * heads, num_noise_latents, num_text_tokens)``.
            ``num_noise_latents`` must be a square number (a flattened square map).
        object_segmaps: ``(batch, max_num_objects, H, W)``. Resized to the
            latent resolution with antialiased bilinear interpolation.
        object_token_idx: ``(batch, max_num_objects)`` text-token index per object.
        object_token_idx_mask: ``(batch, max_num_objects)`` validity of each index.
        loss_fn: called as ``loss_fn(attn_prob, segmaps)``. Both arguments are
            ``(batch, heads, num_noise_latents, max_num_objects)`` and the result
            must be ``(batch, heads, max_num_objects)`` (as ``BalancedL1Loss`` returns).

    Returns:
        Scalar: the masked loss averaged over valid objects, then over batch and heads.
    """
    bxh, num_noise_latents, num_text_tokens = cross_attention_scores.shape
    b, max_num_objects, _, _ = object_segmaps.shape
    size = math.isqrt(num_noise_latents)

    object_segmaps = F.interpolate(object_segmaps, size=(size, size), mode="bilinear", antialias=True)
    object_segmaps = object_segmaps.view(b, max_num_objects, -1)

    num_heads = bxh // b
    cross_attention_scores = cross_attention_scores.view(b, num_heads, num_noise_latents, num_text_tokens)

    # Pick out each object's token column from the attention maps.
    object_token_attn_prob = torch.gather(
        cross_attention_scores,
        dim=3,
        index=object_token_idx.view(b, 1, 1, max_num_objects).expand(
            b, num_heads, num_noise_latents, max_num_objects
        ),
    )
    object_segmaps = (
        object_segmaps.permute(0, 2, 1)
        .unsqueeze(1)
        .expand(b, num_heads, num_noise_latents, max_num_objects)
    )
    loss = loss_fn(object_token_attn_prob, object_segmaps)

    # Zero out padded object slots and average over the valid ones only.
    loss = loss * object_token_idx_mask.view(b, 1, max_num_objects)
    object_token_cnt = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5
    loss = (loss.sum(dim=2) / object_token_cnt).mean()
    return loss


def get_object_localization_loss(
    cross_attention_scores,
    object_segmaps,
    image_token_idx,
    image_token_idx_mask,
    loss_fn,
):
    """Average ``get_object_localization_loss_for_one_layer`` over every stored layer.

    Args:
        cross_attention_scores: dict of layer name -> attention scores, e.g. as
            filled by ``unet_store_cross_attention_scores``.

    Raises:
        ValueError: if *cross_attention_scores* is empty. Previously this case
            surfaced as a bare ZeroDivisionError.
    """
    num_layers = len(cross_attention_scores)
    if num_layers == 0:
        raise ValueError(
            "cross_attention_scores is empty: no attention layer recorded any scores"
        )
    loss = 0
    for layer_scores in cross_attention_scores.values():
        loss += get_object_localization_loss_for_one_layer(
            layer_scores, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn
        )
    return loss / num_layers


def unet_store_cross_attention_scores(unet, attention_scores, layers=5, attn_name="attn1"):
    """Patch UNet attention modules so each call stores its attention probabilities.

    The central *layers* entries of the seven block names below are selected.
    Every diffusers ``Attention`` module whose qualified name contains
    *attn_name* and one of those block names gets its ``get_attention_scores``
    wrapped. The wrapper writes the result to ``attention_scores[module_name]``.

    The unpatched method is saved once as ``old_get_attention_scores``.
    Patching the same UNet again only re-targets the wrapper and no longer
    recurses into itself.

    NOTE(review): in diffusers ``attn1`` is self-attention and ``attn2`` is
    cross-attention (text keys). The localization loss gathers text-token
    columns, so ``attn_name="attn2"`` may be what is intended. The default is
    left as ``"attn1"`` for compatibility.
    NOTE(review): ``AttnProcessor2_0`` (scaled-dot-product attention) does not
    call ``get_attention_scores``. Confirm the modules use ``AttnProcessor``,
    otherwise nothing is recorded.

    Returns:
        The same *unet*, modified in place.
    """
    from diffusers.models.attention_processor import Attention

    UNET_LAYER_NAMES = [
        "down_blocks.0",
        "down_blocks.1",
        "down_blocks.2",
        "mid_block",
        "up_blocks.1",
        "up_blocks.2",
        "up_blocks.3",
    ]
    start_layer = (len(UNET_LAYER_NAMES) - layers) // 2
    end_layer = start_layer + layers
    applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]

    def make_new_get_attention_scores_fn(name):
        # Factory binds `name` per module (avoids the late-binding closure pitfall).
        def new_get_attention_scores(module, query, key, attention_mask=None):
            attention_probs = module.old_get_attention_scores(query, key, attention_mask)
            attention_scores[name] = attention_probs
            return attention_probs

        return new_get_attention_scores

    for name, module in unet.named_modules():
        if not (isinstance(module, Attention) and attn_name in name):
            continue
        if not any(layer in name for layer in applicable_layers):
            continue
        if not hasattr(module, "old_get_attention_scores"):
            # Keep the original method only once; re-saving the wrapper would recurse forever.
            module.old_get_attention_scores = module.get_attention_scores
        module.get_attention_scores = types.MethodType(
            make_new_get_attention_scores_fn(name), module
        )
    return unet


class BalancedL1Loss(nn.Module):
    """Mean attention on the background minus mean attention on the object.

    Means are taken over dim 2 of the inputs. The loss is lower when attention
    is concentrated inside the segmentation map.

    Args:
        threshold: stored but not used by ``forward``.
        normalize: if True, each attention map is first divided by its max over dim 2.
    """

    def __init__(self, threshold=1.0, normalize=False):
        super().__init__()
        self.threshold = threshold
        self.normalize = normalize

    def forward(self, object_token_attn_prob, object_segmaps):
        if self.normalize:
            object_token_attn_prob = object_token_attn_prob / (
                object_token_attn_prob.max(dim=2, keepdim=True)[0] + 1e-5
            )
        background_segmaps = 1 - object_segmaps
        background_segmaps_sum = background_segmaps.sum(dim=2) + 1e-5
        object_segmaps_sum = object_segmaps.sum(dim=2) + 1e-5

        background_loss = (object_token_attn_prob * background_segmaps).sum(dim=2) / background_segmaps_sum
        object_loss = (object_token_attn_prob * object_segmaps).sum(dim=2) / object_segmaps_sum
        return background_loss - object_loss


def apply_mask_to_raw_image(raw_image, mask_image):
    """Keep *raw_image* pixels where *mask_image* is set and paint the rest black.

    *mask_image* is resized to the raw image size first. It must be a mode
    accepted by ``Image.composite`` (``"1"``, ``"L"`` or ``"RGBA"``).
    """
    mask_image = mask_image.resize(raw_image.size)
    black = Image.new("RGB", raw_image.size, (0, 0, 0))
    return Image.composite(raw_image, black, mask_image)


# Face-parsing label table: integer mask value -> body-part name and visualization colour.
mapping_table = [
    {"Mask Value": 0, "Body Part": "Background", "RGB Color": [0, 0, 0]},
    {"Mask Value": 1, "Body Part": "Face", "RGB Color": [255, 0, 0]},
    {"Mask Value": 2, "Body Part": "Left_Eyebrow", "RGB Color": [255, 85, 0]},
    {"Mask Value": 3, "Body Part": "Right_Eyebrow", "RGB Color": [255, 170, 0]},
    {"Mask Value": 4, "Body Part": "Left_Eye", "RGB Color": [255, 0, 85]},
    {"Mask Value": 5, "Body Part": "Right_Eye", "RGB Color": [255, 0, 170]},
    {"Mask Value": 6, "Body Part": "Hair", "RGB Color": [0, 0, 255]},
    {"Mask Value": 7, "Body Part": "Left_Ear", "RGB Color": [85, 0, 255]},
    {"Mask Value": 8, "Body Part": "Right_Ear", "RGB Color": [170, 0, 255]},
    {"Mask Value": 9, "Body Part": "Mouth_External Contour", "RGB Color": [0, 255, 85]},
    {"Mask Value": 10, "Body Part": "Nose", "RGB Color": [0, 255, 0]},
    {"Mask Value": 11, "Body Part": "Mouth_Inner_Contour", "RGB Color": [0, 255, 170]},
    {"Mask Value": 12, "Body Part": "Upper_Lip", "RGB Color": [85, 255, 0]},
    {"Mask Value": 13, "Body Part": "Lower_Lip", "RGB Color": [170, 255, 0]},
    {"Mask Value": 14, "Body Part": "Neck", "RGB Color": [0, 85, 255]},
    {"Mask Value": 15, "Body Part": "Neck_Inner Contour", "RGB Color": [0, 170, 255]},
    {"Mask Value": 16, "Body Part": "Cloth", "RGB Color": [255, 255, 0]},
    {"Mask Value": 17, "Body Part": "Hat", "RGB Color": [255, 0, 255]},
    {"Mask Value": 18, "Body Part": "Earring", "RGB Color": [255, 85, 255]},
    {"Mask Value": 19, "Body Part": "Necklace", "RGB Color": [255, 255, 85]},
    {"Mask Value": 20, "Body Part": "Glasses", "RGB Color": [255, 170, 255]},
    {"Mask Value": 21, "Body Part": "Hand", "RGB Color": [255, 0, 255]},
    {"Mask Value": 22, "Body Part": "Wristband", "RGB Color": [0, 255, 255]},
    {"Mask Value": 23, "Body Part": "Clothes_Upper", "RGB Color": [85, 255, 255]},
    {"Mask Value": 24, "Body Part": "Clothes_Lower", "RGB Color": [170, 255, 255]},
]


def masks_for_unique_values(image_raw_mask):
    """Build one filled binary mask image per label value present in a parsing map.

    For every distinct value, the outer contours of its region are filled with
    255 on a zero canvas. The result is stored under the value's
    ``mapping_table`` body-part name; values not in the table are skipped.
    Value 0 additionally yields ``"WithoutBackground"``, the inverse of the
    filled background mask.

    NOTE(review): assumes a single-channel label image (as OpenCV contour
    functions require) — confirm against callers.

    Returns:
        dict mapping body-part name -> ``PIL.Image`` mask.
    """
    image_array = np.array(image_raw_mask)
    unique_values = np.unique(image_array)

    # Build the value -> name lookup once; setdefault keeps first-match semantics.
    body_part_by_value = {}
    for entry in mapping_table:
        body_part_by_value.setdefault(entry["Mask Value"], entry["Body Part"])

    masks_dict = {}
    for value in unique_values:
        binary_image = np.uint8(image_array == value) * 255
        contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        mask = np.zeros_like(image_array)
        for contour in contours:
            cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED)

        if value == 0:
            without_background = np.where(mask == 255, 0, 255).astype(mask.dtype)
            masks_dict["WithoutBackground"] = Image.fromarray(without_background)

        body_part = body_part_by_value.get(value)
        if body_part is None:
            continue
        masks_dict[body_part] = Image.fromarray(mask)
    return masks_dict


# FFN
def FeedForward(dim, mult=4):
    """Pre-norm MLP: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear back to dim (no biases)."""
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    """Split the last dimension into heads: ``(bs, length, width)`` -> ``(bs, heads, length, width // heads)``."""
    bs, length, width = x.shape
    x = x.view(bs, length, heads, -1)
    x = x.transpose(1, 2)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    """Multi-head attention in which learned latents attend to input features.

    Queries come from the latents. Keys and values come from the image
    features concatenated with the latents along the sequence dimension.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    # x -> kv, latents -> q
    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features, shape (b, n1, D).
            latents (torch.Tensor): latent features, shape (b, n2, D).

        Returns:
            torch.Tensor: updated latents, shape (b, n2, D).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # Apply dim_head**-0.25 to both q and k rather than dim_head**-0.5 once
        # (presumably to keep half-precision intermediates small).
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
        return self.to_out(out)


class FacePerceiverResampler(torch.nn.Module):
    """Stack of (PerceiverAttention + FeedForward) residual blocks refining given latents.

    Inputs of width ``embedding_dim`` are projected to ``dim``. The refined
    latents are projected to ``output_dim`` and layer-normalized.
    """

    def __init__(
        self,
        *,
        dim=768,
        depth=4,
        dim_head=64,
        heads=16,
        embedding_dim=1280,
        output_dim=768,
        ff_mult=4,
    ):
        super().__init__()

        self.proj_in = torch.nn.Linear(embedding_dim, dim)
        self.proj_out = torch.nn.Linear(dim, output_dim)
        self.norm_out = torch.nn.LayerNorm(output_dim)
        self.layers = torch.nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                torch.nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    # x -> kv, latents -> q
    def forward(self, latents, x):
        # Example shapes: latents (2, 4, 768), x (2, 257, 1280) -> proj_in -> (2, 257, 768).
        x = self.proj_in(x)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        latents = self.proj_out(latents)
        return self.norm_out(latents)


class ProjPlusModel(torch.nn.Module):
    """Project an ID embedding into ``num_tokens`` tokens and refine them with image features.

    The ID embedding goes through a 2-layer MLP to produce
    ``num_tokens x cross_attention_dim`` tokens, followed by LayerNorm. The
    tokens are then refined by a ``FacePerceiverResampler`` that attends to
    the image embeddings.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

        self.perceiver_resampler = FacePerceiverResampler(
            dim=cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=cross_attention_dim // 64,
            embedding_dim=clip_embeddings_dim,
            output_dim=cross_attention_dim,
            ff_mult=4,
        )

    def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
        """Return refined ID tokens of shape ``(-1, num_tokens, cross_attention_dim)``.

        If *shortcut* is True, ``scale * id_tokens`` is added to the output.
        """
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        # The ID tokens `x` are the resampler latents (queries); clip_embeds
        # supply the keys/values.
        out = self.perceiver_resampler(x, clip_embeds)
        if shortcut:
            out = scale * x + out
        return out


class AttentionMLP(nn.Module):
    """Perceiver-style pooling of a feature sequence into ``single_num_tokens`` learned latents.

    If ``apply_pos_emb`` is set, learned position embeddings (up to
    ``max_seq_len``) are added to the input. If ``num_latents_mean_pooled > 0``,
    latents derived from the mean-pooled projected sequence are prepended to
    the learned latents. The ``dtype`` argument is accepted but not used.
    """

    def __init__(
        self,
        dtype=torch.float16,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        single_num_tokens=1,
        embedding_dim=1280,
        output_dim=768,
        ff_mult=4,
        max_seq_len: int = 257 * 2,
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None

        self.single_num_tokens = single_num_tokens
        self.latents = nn.Parameter(torch.randn(1, self.single_num_tokens, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.to_latents_from_mean_pooled_seq = (
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
            if num_latents_mean_pooled > 0
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb
        # Example: x (5, 257, 1280) -> proj_in -> (5, 257, 1024).
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq is not None:
            meanpooled_seq = masked_mean(
                x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)
            )
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)


def masked_mean(t, *, dim, mask=None):
    """Mean of *t* over *dim*, counting only positions where the boolean *mask* is True.

    *mask* is expected to be 2-D ``(b, n)`` (it is rearranged to ``(b, n, 1)``).
    A row whose mask is entirely False yields 0.
    """
    if mask is None:
        return t.mean(dim=dim)

    denom = mask.sum(dim=dim, keepdim=True)
    mask = rearrange(mask, "b n -> b n 1")
    masked_t = t.masked_fill(~mask, 0.0)
    # denom is an integer count, so clamping at 1 only matters for all-False
    # rows, where the numerator is 0 anyway. This avoids a float bound on an
    # integer tensor.
    return masked_t.sum(dim=dim) / denom.clamp(min=1)