import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.lora import LoRALinearLayer

from .functions import AttentionMLP


class FuseModule(nn.Module):
    """Fuse identity embeddings into selected (class-token) positions of a prompt.

    Each selected prompt token is concatenated with its row-aligned ID embedding
    and passed through ``mlp1``, with a residual back to the prompt token. The
    result goes through ``mlp2`` (which has its own internal residual) and a final
    LayerNorm. The fused vectors then replace the selected prompt tokens.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # mlp1 consumes [prompt_token, id_embed] concatenated on the last dim.
        self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
        self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def fuse_fn(self, prompt_embeds, id_embeds):
        """Fuse row-aligned prompt token embeddings with ID embeddings.

        Args:
            prompt_embeds: ``(..., embed_dim)`` prompt token embeddings.
            id_embeds: ``(..., embed_dim)`` ID embeddings with the same leading
                dims as ``prompt_embeds``.

        Returns:
            ``(..., embed_dim)`` fused embeddings.
        """
        stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
        stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
        stacked_id_embeds = self.mlp2(stacked_id_embeds)
        stacked_id_embeds = self.layer_norm(stacked_id_embeds)
        return stacked_id_embeds

    def forward(
        self,
        prompt_embeds,
        id_embeds,
        class_tokens_mask,
        valid_id_mask,
    ) -> torch.Tensor:
        """Replace the masked prompt tokens with fused ID embeddings.

        Args:
            prompt_embeds: ``(batch, seq_len, dim)`` prompt embeddings. This tensor is
                not modified; a new tensor is returned.
            id_embeds: ``(batch, num_inputs, tokens, dim)`` ID embeddings. They are
                cast to ``prompt_embeds.dtype``.
            class_tokens_mask: boolean mask with ``batch * seq_len`` elements that
                selects the prompt positions to replace.
            valid_id_mask: mask with ``batch * num_inputs`` elements that selects
                which ID inputs are used (presumably boolean; TODO confirm with
                callers).

        The k-th selected prompt position (in row-major order) receives the fused
        embedding of the k-th valid ID token.

        Returns:
            A new ``(batch, seq_len, dim)`` tensor.
        """
        id_embeds = id_embeds.to(prompt_embeds.dtype)
        batch_size = id_embeds.shape[0]
        seq_length = prompt_embeds.shape[1]

        # Keep only the valid ID inputs, then flatten all of their tokens into rows.
        flat_id_embeds = id_embeds.reshape(-1, id_embeds.shape[-2], id_embeds.shape[-1])
        valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
        valid_id_embeds = valid_id_embeds.reshape(-1, valid_id_embeds.shape[-1])

        flat_prompt_embeds = prompt_embeds.reshape(-1, prompt_embeds.shape[-1])
        flat_class_tokens_mask = class_tokens_mask.reshape(-1)

        image_token_embeds = flat_prompt_embeds[flat_class_tokens_mask]
        stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)

        assert flat_class_tokens_mask.sum() == stacked_id_embeds.shape[0], (
            f"{flat_class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
        )
        # Scatter out of place so the caller's prompt_embeds (which a view/reshape
        # may alias) is never overwritten, and autograd on leaf inputs stays valid.
        updated_prompt_embeds = flat_prompt_embeds.masked_scatter(
            flat_class_tokens_mask[:, None],
            stacked_id_embeds.to(flat_prompt_embeds.dtype),
        )
        return updated_prompt_embeds.view(batch_size, seq_length, -1)


class MLP(nn.Module):
    """Two-layer MLP block: LayerNorm -> Linear -> GELU -> Linear, with optional residual."""

    def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
        super().__init__()
        if use_residual:
            # The residual adds the input to the output, so widths must match.
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        if self.use_residual:
            x = x + residual
        return x


class FacialEncoder(nn.Module):
    """Project facial image embeddings and fuse them into the prompt embeddings."""

    def __init__(self):
        super().__init__()
        self.visual_projection = AttentionMLP()
        # 768 must equal both the prompt embedding width and the projection's output
        # width: FuseModule concatenates the two and adds the result back to the prompt.
        self.fuse_module = FuseModule(768)

    def forward(self, prompt_embeds, multi_image_embeds, class_tokens_mask, valid_id_mask):
        """Fuse projected facial embeddings into ``prompt_embeds``.

        Args:
            prompt_embeds: prompt embeddings; see ``FuseModule.forward``.
            multi_image_embeds: ``(batch, num_inputs, token_length, image_dim)``
                image embeddings.
            class_tokens_mask: mask of prompt positions to replace.
            valid_id_mask: mask of which of the ``num_inputs`` inputs are valid.

        Returns:
            Updated prompt embeddings from ``FuseModule``.
        """
        bs, num_inputs, token_length, image_dim = multi_image_embeds.shape
        multi_image_embeds_view = multi_image_embeds.reshape(
            bs * num_inputs, token_length, image_dim
        )
        id_embeds = self.visual_projection(multi_image_embeds_view)
        # NOTE(review): assumes the projection yields one token per input (observed
        # as torch.Size([5, 1, 768])); everything else is folded into the last dim.
        id_embeds = id_embeds.reshape(bs, num_inputs, 1, -1)
        # fuse_module replaces the class tokens in prompt_embeds (selected by
        # class_tokens_mask) with the fusion of id_embeds and those tokens.
        updated_prompt_embeds = self.fuse_module(
            prompt_embeds, id_embeds, class_tokens_mask, valid_id_mask
        )
        return updated_prompt_embeds


class Consistent_AttProcessor(nn.Module):
    """Attention processor whose q/k/v/out projections carry additive LoRA deltas.

    Each projection output is ``base(x) + lora_scale * lora(x)``.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        rank=4,
        network_alpha=None,
        lora_scale=1.0,
    ):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(
            cross_attention_dim or hidden_size, hidden_size, rank, network_alpha
        )
        self.to_v_lora = LoRALinearLayer(
            cross_attention_dim or hidden_size, hidden_size, rank, network_alpha
        )
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run LoRA-augmented attention through ``attn``'s projections.

        When ``encoder_hidden_states`` is None this is self-attention. 4-D
        ``hidden_states`` are flattened to ``(batch, h*w, channel)`` for attention
        and restored afterwards.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class Consistent_IPAttProcessor(nn.Module):
    """LoRA attention processor with an extra image-prompt attention branch.

    The last ``num_tokens`` tokens of ``encoder_hidden_states`` are treated as
    image-prompt tokens. The query attends to them through ``to_k_ip``/``to_v_ip``,
    and the result is added with weight ``scale``. The remaining tokens go through
    the LoRA-augmented k/v projections. All LoRA and IP projection parameters are
    created frozen (``requires_grad = False``).
    """

    def __init__(
        self,
        hidden_size,
        cross_attention_dim=None,
        rank=4,
        network_alpha=None,
        lora_scale=1.0,
        scale=1.0,
        num_tokens=4,
    ):
        super().__init__()

        self.rank = rank
        self.lora_scale = lora_scale
        self.num_tokens = num_tokens

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(
            cross_attention_dim or hidden_size, hidden_size, rank, network_alpha
        )
        self.to_v_lora = LoRALinearLayer(
            cross_attention_dim or hidden_size, hidden_size, rank, network_alpha
        )
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        for module in [
            self.to_q_lora,
            self.to_k_lora,
            self.to_v_lora,
            self.to_out_lora,
            self.to_k_ip,
            self.to_v_ip,
        ]:
            for param in module.parameters():
                param.requires_grad = False

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        scale=1.0,
        temb=None,
    ):
        """Run LoRA attention plus the image-prompt branch.

        ``scale`` is unused; the image-prompt branch is weighted by ``self.scale``.
        When ``encoder_hidden_states`` is None this is plain LoRA self-attention
        and the image-prompt branch is skipped, since there are no IP tokens.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # prepare_attention_mask repeats the mask per head along dim 0;
            # scaled_dot_product_attention needs (batch, heads, query_len, key_len).
            # NOTE(review): sequence_length above includes the IP tokens, so a
            # caller-supplied mask may not match the text-only key length — confirm.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)

        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split off the trailing num_tokens image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states