tolgacangoz committed (verified) · Commit b17a00f · 1 Parent(s): d470ded

Upload matryoshka.py

Files changed (1):
  1. scheduler/matryoshka.py (+21 -21)
scheduler/matryoshka.py CHANGED
@@ -1517,7 +1517,7 @@ class MatryoshkaTransformerBlock(nn.Module):
                 # **cross_attention_kwargs,
             )
 
-            attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
+            # attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
             attn_output_cond = self.proj_out(attn_output_cond)
             attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
             hidden_states = hidden_states + attn_output_cond
@@ -1635,11 +1635,30 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
         # key = key.permute(0, 2, 1)
         # value = value.permute(0, 2, 1)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        if self_attention_output is None:
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
         if attn.norm_q is not None:
             query = attn.norm_q(query)
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
         # hidden_states = self.attention(
@@ -1649,31 +1668,12 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
             # mask=attention_mask,
             # num_heads=attn.heads,
         # )
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        #query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        query = query.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        hidden_states = F.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-            dropout_p=attn.dropout,
-        )
 
         hidden_states = hidden_states.to(query.dtype)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, height * width, channel)
 
         if self_attention_output is not None:
             hidden_states = hidden_states + self_attention_output
-
-        if not attn.pre_only:
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
         if attn.residual_connection:
             hidden_states = hidden_states + residual
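For context on the first hunk: with the leading permute commented out, self.proj_out now receives the attention output in channels-last (batch, seq_len, channels) layout, and the remaining permute/reshape restores the channels-first spatial map. The snippet below is a minimal sketch of that layout round-trip, assuming proj_out is a channels-last linear projection; it is not the repo's MatryoshkaTransformerBlock, and all sizes are illustrative.

# Minimal sketch (assumed shapes, not the repo's MatryoshkaTransformerBlock):
# apply a channels-last nn.Linear projection to a sequence-form attention
# output, then permute back to channels-first and restore the spatial dims,
# mirroring the proj_out lines touched by the first hunk.
import torch
import torch.nn as nn

batch_size, channels, height, width = 2, 32, 8, 8  # illustrative sizes
spatial_dims = (height, width)
proj_out = nn.Linear(channels, channels)  # assumption: proj_out acts on the channel dim

# attention output in sequence form: (batch, seq_len, channels)
attn_output_cond = torch.randn(batch_size, height * width, channels)

attn_output_cond = proj_out(attn_output_cond)  # Linear acts on the last (channel) dim
attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
print(attn_output_cond.shape)  # torch.Size([2, 32, 8, 8])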
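The second and third hunks move the head splitting and the F.scaled_dot_product_attention call earlier in the processor and defer the (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim) merge until after the self-attention output has been added. Below is a minimal, self-contained sketch of that split/SDPA/merge pattern (PyTorch >= 2.0); the tensor sizes and names are assumptions for illustration, not the repo's MatryoshkaFusedAttnProcessor1_0_or_2_0.

# Minimal sketch of the head-split + scaled_dot_product_attention + merge
# pattern rearranged by this commit (illustrative sizes, standalone tensors).
import torch
import torch.nn.functional as F

batch_size, seq_len, heads, head_dim = 2, 16, 8, 64  # illustrative sizes
inner_dim = heads * head_dim

query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# (batch, seq, inner_dim) -> (batch, heads, seq, head_dim), the layout SDPA expects
query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)

# merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
print(hidden_states.shape)  # torch.Size([2, 16, 512])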