Fabrice-TIERCELIN
committed on
' instead of "
hyvideo/modules/attenion.py +212 -212
hyvideo/modules/attenion.py
CHANGED
@@ -1,212 +1,212 @@
-import importlib.metadata
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-try:
-    import flash_attn
-    from flash_attn.flash_attn_interface import _flash_attn_forward
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func
-except ImportError:
-    flash_attn = None
-    flash_attn_varlen_func = None
-    _flash_attn_forward = None
-
-
-MEMORY_LAYOUT = {
-    "flash": (
-        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
-        lambda x: x,
-    ),
-    "torch": (
-        lambda x: x.transpose(1, 2),
-        lambda x: x.transpose(1, 2),
-    ),
-    "vanilla": (
-        lambda x: x.transpose(1, 2),
-        lambda x: x.transpose(1, 2),
-    ),
-}
-
-
-def get_cu_seqlens(text_mask, img_len):
-    """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
-
-    Args:
-        text_mask (torch.Tensor): the mask of text
-        img_len (int): the length of image
-
-    Returns:
-        torch.Tensor: the calculated cu_seqlens for flash attention
-    """
-    batch_size = text_mask.shape[0]
-    text_len = text_mask.sum(dim=1)
-    max_len = text_mask.shape[1] + img_len
-
-    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
-
-    for i in range(batch_size):
-        s = text_len[i] + img_len
-        s1 = i * max_len + s
-        s2 = (i + 1) * max_len
-        cu_seqlens[2 * i + 1] = s1
-        cu_seqlens[2 * i + 2] = s2
-
-    return cu_seqlens
-
-
-def attention(
-    q,
-    k,
-    v,
-    mode="torch",
-    drop_rate=0,
-    attn_mask=None,
-    causal=False,
-    cu_seqlens_q=None,
-    cu_seqlens_kv=None,
-    max_seqlen_q=None,
-    max_seqlen_kv=None,
-    batch_size=1,
-):
-    """
-    Perform QKV self attention.
-
-    Args:
-        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
-        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
-        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
-        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
-        drop_rate (float): Dropout rate in attention map. (default: 0)
-        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
-            (default: None)
-        causal (bool): Whether to use causal attention. (default: False)
-        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
-            used to index into q.
-        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
-            used to index into kv.
-        max_seqlen_q (int): The maximum sequence length in the batch of q.
-        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
-
-    Returns:
-        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
-    """
-    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
-    q = pre_attn_layout(q)
-    k = pre_attn_layout(k)
-    v = pre_attn_layout(v)
-
-    if mode == "torch":
-        if attn_mask is not None and attn_mask.dtype != torch.bool:
-            attn_mask = attn_mask.to(q.dtype)
-        x = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
-        )
-    elif mode == "flash":
-        x = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens_q,
-            cu_seqlens_kv,
-            max_seqlen_q,
-            max_seqlen_kv,
-        )
-        # x with shape [(bxs), a, d]
-        x = x.view(
-            batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
-        )  # reshape x to [b, s, a, d]
-    elif mode == "vanilla":
-        scale_factor = 1 / math.sqrt(q.size(-1))
-
-        b, a, s, _ = q.shape
-        s1 = k.size(2)
-        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
-        if causal:
-            # Only applied to self attention
-            assert (
-                attn_mask is None
-            ), "Causal mask and attn_mask cannot be used together"
-            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
-                diagonal=0
-            )
-            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-            attn_bias.to(q.dtype)
-
-        if attn_mask is not None:
-            if attn_mask.dtype == torch.bool:
-                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
-            else:
-                attn_bias += attn_mask
-
-        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
-        attn = (q @ k.transpose(-2, -1)) * scale_factor
-        attn += attn_bias
-        attn = attn.softmax(dim=-1)
-        attn = torch.dropout(attn, p=drop_rate, train=True)
-        x = attn @ v
-    else:
-        raise NotImplementedError(f"Unsupported attention mode: {mode}")
-
-    x = post_attn_layout(x)
-    b, s, a, d = x.shape
-    out = x.reshape(b, s, -1)
-    return out
-
-
-def parallel_attention(
-    hybrid_seq_parallel_attn,
-    q,
-    k,
-    v,
-    img_q_len,
-    img_kv_len,
-    cu_seqlens_q,
-    cu_seqlens_kv
-):
-    attn1 = hybrid_seq_parallel_attn(
-        None,
-        q[:, :img_q_len, :, :],
-        k[:, :img_kv_len, :, :],
-        v[:, :img_kv_len, :, :],
-        dropout_p=0.0,
-        causal=False,
-        joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
-        joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
-        joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
-        joint_strategy="rear",
-    )
-    if flash_attn.__version__ >= '2.7.0':
-        attn2, *_ = _flash_attn_forward(
-            q[:,cu_seqlens_q[1]:],
-            k[:,cu_seqlens_kv[1]:],
-            v[:,cu_seqlens_kv[1]:],
-            dropout_p=0.0,
-            softmax_scale=q.shape[-1] ** (-0.5),
-            causal=False,
-            window_size_left=-1,
-            window_size_right=-1,
-            softcap=0.0,
-            alibi_slopes=None,
-            return_softmax=False,
-        )
-    else:
-        attn2, *_ = _flash_attn_forward(
-            q[:,cu_seqlens_q[1]:],
-            k[:,cu_seqlens_kv[1]:],
-            v[:,cu_seqlens_kv[1]:],
-            dropout_p=0.0,
-            softmax_scale=q.shape[-1] ** (-0.5),
-            causal=False,
-            window_size=(-1, -1),
-            softcap=0.0,
-            alibi_slopes=None,
-            return_softmax=False,
-        )
-    attn = torch.cat([attn1, attn2], dim=1)
-    b, s, a, d = attn.shape
-    attn = attn.reshape(b, s, -1)
-
-    return attn
+import importlib.metadata
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    import flash_attn
+    from flash_attn.flash_attn_interface import _flash_attn_forward
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func
+except ImportError:
+    flash_attn = None
+    flash_attn_varlen_func = None
+    _flash_attn_forward = None
+
+
+MEMORY_LAYOUT = {
+    "flash": (
+        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
+        lambda x: x,
+    ),
+    "torch": (
+        lambda x: x.transpose(1, 2),
+        lambda x: x.transpose(1, 2),
+    ),
+    "vanilla": (
+        lambda x: x.transpose(1, 2),
+        lambda x: x.transpose(1, 2),
+    ),
+}
+
+
+def get_cu_seqlens(text_mask, img_len):
+    """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
+
+    Args:
+        text_mask (torch.Tensor): the mask of text
+        img_len (int): the length of image
+
+    Returns:
+        torch.Tensor: the calculated cu_seqlens for flash attention
+    """
+    batch_size = text_mask.shape[0]
+    text_len = text_mask.sum(dim=1)
+    max_len = text_mask.shape[1] + img_len
+
+    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
+
+    for i in range(batch_size):
+        s = text_len[i] + img_len
+        s1 = i * max_len + s
+        s2 = (i + 1) * max_len
+        cu_seqlens[2 * i + 1] = s1
+        cu_seqlens[2 * i + 2] = s2
+
+    return cu_seqlens
+
+
+def attention(
+    q,
+    k,
+    v,
+    mode="torch",
+    drop_rate=0,
+    attn_mask=None,
+    causal=False,
+    cu_seqlens_q=None,
+    cu_seqlens_kv=None,
+    max_seqlen_q=None,
+    max_seqlen_kv=None,
+    batch_size=1,
+):
+    """
+    Perform QKV self attention.
+
+    Args:
+        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
+        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
+        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
+        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
+        drop_rate (float): Dropout rate in attention map. (default: 0)
+        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
+            (default: None)
+        causal (bool): Whether to use causal attention. (default: False)
+        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+            used to index into q.
+        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+            used to index into kv.
+        max_seqlen_q (int): The maximum sequence length in the batch of q.
+        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
+
+    Returns:
+        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
+    """
+    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
+    q = pre_attn_layout(q)
+    k = pre_attn_layout(k)
+    v = pre_attn_layout(v)
+
+    if mode == "torch":
+        if attn_mask is not None and attn_mask.dtype != torch.bool:
+            attn_mask = attn_mask.to(q.dtype)
+        x = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
+        )
+    elif mode == "flash":
+        x = flash_attn_varlen_func(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_kv,
+            max_seqlen_q,
+            max_seqlen_kv,
+        )
+        # x with shape [(bxs), a, d]
+        x = x.view(
+            batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
+        )  # reshape x to [b, s, a, d]
+    elif mode == "vanilla":
+        scale_factor = 1 / math.sqrt(q.size(-1))
+
+        b, a, s, _ = q.shape
+        s1 = k.size(2)
+        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
+        if causal:
+            # Only applied to self attention
+            assert (
+                attn_mask is None
+            ), "Causal mask and attn_mask cannot be used together"
+            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
+                diagonal=0
+            )
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias.to(q.dtype)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+
+        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
+        attn = (q @ k.transpose(-2, -1)) * scale_factor
+        attn += attn_bias
+        attn = attn.softmax(dim=-1)
+        attn = torch.dropout(attn, p=drop_rate, train=True)
+        x = attn @ v
+    else:
+        raise NotImplementedError(f"Unsupported attention mode: {mode}")
+
+    x = post_attn_layout(x)
+    b, s, a, d = x.shape
+    out = x.reshape(b, s, -1)
+    return out
+
+
+def parallel_attention(
+    hybrid_seq_parallel_attn,
+    q,
+    k,
+    v,
+    img_q_len,
+    img_kv_len,
+    cu_seqlens_q,
+    cu_seqlens_kv
+):
+    attn1 = hybrid_seq_parallel_attn(
+        None,
+        q[:, :img_q_len, :, :],
+        k[:, :img_kv_len, :, :],
+        v[:, :img_kv_len, :, :],
+        dropout_p=0.0,
+        causal=False,
+        joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
+        joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
+        joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
+        joint_strategy="rear",
+    )
+    if flash_attn.__version__ >= "2.7.0":
+        attn2, *_ = _flash_attn_forward(
+            q[:,cu_seqlens_q[1]:],
+            k[:,cu_seqlens_kv[1]:],
+            v[:,cu_seqlens_kv[1]:],
+            dropout_p=0.0,
+            softmax_scale=q.shape[-1] ** (-0.5),
+            causal=False,
+            window_size_left=-1,
+            window_size_right=-1,
+            softcap=0.0,
+            alibi_slopes=None,
+            return_softmax=False,
+        )
+    else:
+        attn2, *_ = _flash_attn_forward(
+            q[:,cu_seqlens_q[1]:],
+            k[:,cu_seqlens_kv[1]:],
+            v[:,cu_seqlens_kv[1]:],
+            dropout_p=0.0,
+            softmax_scale=q.shape[-1] ** (-0.5),
+            causal=False,
+            window_size=(-1, -1),
+            softcap=0.0,
+            alibi_slopes=None,
+            return_softmax=False,
+        )
+    attn = torch.cat([attn1, attn2], dim=1)
+    b, s, a, d = attn.shape
+    attn = attn.reshape(b, s, -1)
+
+    return attn
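
Note on the changed line: the only difference between the two sides is the quote style in the flash_attn.__version__ comparison inside parallel_attention, which still compares version strings lexicographically. As a hedged aside rather than part of this commit, a check built on parsed versions avoids the lexicographic pitfall (for example, "2.10.0" sorts before "2.7.0" as a string). The sketch below is an assumption-laden illustration: it relies on the third-party packaging library being installed and introduces a hypothetical helper name, flash_attn_at_least; importlib.metadata is already imported at the top of this file.

import importlib.metadata

from packaging.version import Version  # assumption: packaging is installed


def flash_attn_at_least(minimum="2.7.0"):
    # Hypothetical helper, not part of this commit: compare parsed versions
    # rather than raw strings, which sort lexicographically ("2.10.0" < "2.7.0").
    try:
        installed = importlib.metadata.version("flash_attn")
    except importlib.metadata.PackageNotFoundError:
        return False
    return Version(installed) >= Version(minimum)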
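
For reference, a minimal, illustrative sketch of calling the attention() helper defined in this file in its "torch" mode. The tensor shapes are made up for the example, and the import path simply mirrors the file path hyvideo/modules/attenion.py shown above.

import torch

from hyvideo.modules.attenion import attention  # path as listed in this commit

# Dummy inputs in the documented [batch, seq, heads, head_dim] layout.
b, s, a, d = 2, 16, 4, 64
q = torch.randn(b, s, a, d)
k = torch.randn(b, s, a, d)
v = torch.randn(b, s, a, d)

# "torch" mode dispatches to F.scaled_dot_product_attention and returns the
# heads flattened into the last dimension, i.e. shape [b, s, a * d].
out = attention(q, k, v, mode="torch")
print(out.shape)  # torch.Size([2, 16, 256])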