tolgacangoz committed
Upload matryoshka.py

scheduler/matryoshka.py CHANGED (+21 -21)
@@ -1517,7 +1517,7 @@ class MatryoshkaTransformerBlock(nn.Module):
                 # **cross_attention_kwargs,
             )

-            attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
+            # attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous()
             attn_output_cond = self.proj_out(attn_output_cond)
             attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
             hidden_states = hidden_states + attn_output_cond
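For reference, the lines kept around proj_out perform a channels-last to channels-first round trip: the attention output is projected as a (batch, seq_len, channels) token sequence, then permuted and reshaped back into a (batch, channels, *spatial_dims) feature map. Below is a minimal standalone sketch of that shape bookkeeping only, with made-up sizes and a plain nn.Linear standing in for self.proj_out (an assumption; the real module may differ). It is not the pipeline's actual code.

import torch
import torch.nn as nn

# Hypothetical sizes; only the shape bookkeeping is of interest here.
batch_size, channels, height, width = 2, 32, 8, 8
spatial_dims = (height, width)

# Attention output as a token sequence: (batch, seq_len, channels).
attn_output_cond = torch.randn(batch_size, height * width, channels)

proj_out = nn.Linear(channels, channels)  # stand-in for self.proj_out

# The linear projection acts on the trailing channel dimension.
attn_output_cond = proj_out(attn_output_cond)  # (batch, seq_len, channels)

# Back to a channels-first feature map, as in the hunk above.
attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)
print(attn_output_cond.shape)  # torch.Size([2, 32, 8, 8])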
@@ -1635,11 +1635,30 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
         # key = key.permute(0, 2, 1)
         # value = value.permute(0, 2, 1)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        if self_attention_output is None:
+            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
         if attn.norm_q is not None:
             query = attn.norm_q(query)
         if attn.norm_k is not None:
             key = attn.norm_k(key)

+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
         # hidden_states = self.attention(
@@ -1649,31 +1668,12 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
         # mask=attention_mask,
         # num_heads=attn.heads,
         # )
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        #query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        query = query.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        hidden_states = F.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-            dropout_p=attn.dropout,
-        )

         hidden_states = hidden_states.to(query.dtype)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, height * width, channel)

         if self_attention_output is not None:
             hidden_states = hidden_states + self_attention_output
-
-        if not attn.pre_only:
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

         if attn.residual_connection:
             hidden_states = hidden_states + residual
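For reference, the reworked processor path follows the usual PyTorch 2.x multi-head pattern: query/key/value are reshaped into per-head tensors, passed through F.scaled_dot_product_attention, and the heads are folded back afterwards, which is what the added transpose(1, 2).reshape(...) line in the hunk above does. A minimal sketch of that reshaping with hypothetical sizes follows (assumes torch >= 2.0; not the module's actual code).

import torch
import torch.nn.functional as F

# Hypothetical sizes; only the head reshaping around SDPA is illustrated.
batch_size, seq_len, heads, head_dim = 2, 64, 4, 32
inner_dim = heads * head_dim

query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# (batch, seq_len, inner_dim) -> (batch, heads, seq_len, head_dim)
query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# Output of SDPA is (batch, heads, seq_len, head_dim).
hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)

# Fold the heads back: (batch, seq_len, heads * head_dim).
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
print(hidden_states.shape)  # torch.Size([2, 64, 128])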