Fabrice-TIERCELIN committed (verified)
Commit 67b8f7c · 1 Parent(s): dfd1126

' instead of "

Files changed (1)
  1. hyvideo/modules/attenion.py +212 -212
hyvideo/modules/attenion.py CHANGED
@@ -1,212 +1,212 @@
-import importlib.metadata
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-try:
-    import flash_attn
-    from flash_attn.flash_attn_interface import _flash_attn_forward
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func
-except ImportError:
-    flash_attn = None
-    flash_attn_varlen_func = None
-    _flash_attn_forward = None
-
-
-MEMORY_LAYOUT = {
-    "flash": (
-        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
-        lambda x: x,
-    ),
-    "torch": (
-        lambda x: x.transpose(1, 2),
-        lambda x: x.transpose(1, 2),
-    ),
-    "vanilla": (
-        lambda x: x.transpose(1, 2),
-        lambda x: x.transpose(1, 2),
-    ),
-}
-
-
-def get_cu_seqlens(text_mask, img_len):
-    """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
-
-    Args:
-        text_mask (torch.Tensor): the mask of text
-        img_len (int): the length of image
-
-    Returns:
-        torch.Tensor: the calculated cu_seqlens for flash attention
-    """
-    batch_size = text_mask.shape[0]
-    text_len = text_mask.sum(dim=1)
-    max_len = text_mask.shape[1] + img_len
-
-    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
-
-    for i in range(batch_size):
-        s = text_len[i] + img_len
-        s1 = i * max_len + s
-        s2 = (i + 1) * max_len
-        cu_seqlens[2 * i + 1] = s1
-        cu_seqlens[2 * i + 2] = s2
-
-    return cu_seqlens
-
-
-def attention(
-    q,
-    k,
-    v,
-    mode="torch",
-    drop_rate=0,
-    attn_mask=None,
-    causal=False,
-    cu_seqlens_q=None,
-    cu_seqlens_kv=None,
-    max_seqlen_q=None,
-    max_seqlen_kv=None,
-    batch_size=1,
-):
-    """
-    Perform QKV self attention.
-
-    Args:
-        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
-        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
-        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
-        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
-        drop_rate (float): Dropout rate in attention map. (default: 0)
-        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
-            (default: None)
-        causal (bool): Whether to use causal attention. (default: False)
-        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
-            used to index into q.
-        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
-            used to index into kv.
-        max_seqlen_q (int): The maximum sequence length in the batch of q.
-        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
-
-    Returns:
-        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
-    """
-    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
-    q = pre_attn_layout(q)
-    k = pre_attn_layout(k)
-    v = pre_attn_layout(v)
-
-    if mode == "torch":
-        if attn_mask is not None and attn_mask.dtype != torch.bool:
-            attn_mask = attn_mask.to(q.dtype)
-        x = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
-        )
-    elif mode == "flash":
-        x = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens_q,
-            cu_seqlens_kv,
-            max_seqlen_q,
-            max_seqlen_kv,
-        )
-        # x with shape [(bxs), a, d]
-        x = x.view(
-            batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
-        )  # reshape x to [b, s, a, d]
-    elif mode == "vanilla":
-        scale_factor = 1 / math.sqrt(q.size(-1))
-
-        b, a, s, _ = q.shape
-        s1 = k.size(2)
-        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
-        if causal:
-            # Only applied to self attention
-            assert (
-                attn_mask is None
-            ), "Causal mask and attn_mask cannot be used together"
-            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
-                diagonal=0
-            )
-            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
-            attn_bias.to(q.dtype)
-
-        if attn_mask is not None:
-            if attn_mask.dtype == torch.bool:
-                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
-            else:
-                attn_bias += attn_mask
-
-        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
-        attn = (q @ k.transpose(-2, -1)) * scale_factor
-        attn += attn_bias
-        attn = attn.softmax(dim=-1)
-        attn = torch.dropout(attn, p=drop_rate, train=True)
-        x = attn @ v
-    else:
-        raise NotImplementedError(f"Unsupported attention mode: {mode}")
-
-    x = post_attn_layout(x)
-    b, s, a, d = x.shape
-    out = x.reshape(b, s, -1)
-    return out
-
-
-def parallel_attention(
-    hybrid_seq_parallel_attn,
-    q,
-    k,
-    v,
-    img_q_len,
-    img_kv_len,
-    cu_seqlens_q,
-    cu_seqlens_kv
-):
-    attn1 = hybrid_seq_parallel_attn(
-        None,
-        q[:, :img_q_len, :, :],
-        k[:, :img_kv_len, :, :],
-        v[:, :img_kv_len, :, :],
-        dropout_p=0.0,
-        causal=False,
-        joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
-        joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
-        joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
-        joint_strategy="rear",
-    )
-    if flash_attn.__version__ >= '2.7.0':
-        attn2, *_ = _flash_attn_forward(
-            q[:,cu_seqlens_q[1]:],
-            k[:,cu_seqlens_kv[1]:],
-            v[:,cu_seqlens_kv[1]:],
-            dropout_p=0.0,
-            softmax_scale=q.shape[-1] ** (-0.5),
-            causal=False,
-            window_size_left=-1,
-            window_size_right=-1,
-            softcap=0.0,
-            alibi_slopes=None,
-            return_softmax=False,
-        )
-    else:
-        attn2, *_ = _flash_attn_forward(
-            q[:,cu_seqlens_q[1]:],
-            k[:,cu_seqlens_kv[1]:],
-            v[:,cu_seqlens_kv[1]:],
-            dropout_p=0.0,
-            softmax_scale=q.shape[-1] ** (-0.5),
-            causal=False,
-            window_size=(-1, -1),
-            softcap=0.0,
-            alibi_slopes=None,
-            return_softmax=False,
-        )
-    attn = torch.cat([attn1, attn2], dim=1)
-    b, s, a, d = attn.shape
-    attn = attn.reshape(b, s, -1)
-
-    return attn
+import importlib.metadata
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    import flash_attn
+    from flash_attn.flash_attn_interface import _flash_attn_forward
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func
+except ImportError:
+    flash_attn = None
+    flash_attn_varlen_func = None
+    _flash_attn_forward = None
+
+
+MEMORY_LAYOUT = {
+    "flash": (
+        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
+        lambda x: x,
+    ),
+    "torch": (
+        lambda x: x.transpose(1, 2),
+        lambda x: x.transpose(1, 2),
+    ),
+    "vanilla": (
+        lambda x: x.transpose(1, 2),
+        lambda x: x.transpose(1, 2),
+    ),
+}
+
+
+def get_cu_seqlens(text_mask, img_len):
+    """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
+
+    Args:
+        text_mask (torch.Tensor): the mask of text
+        img_len (int): the length of image
+
+    Returns:
+        torch.Tensor: the calculated cu_seqlens for flash attention
+    """
+    batch_size = text_mask.shape[0]
+    text_len = text_mask.sum(dim=1)
+    max_len = text_mask.shape[1] + img_len
+
+    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
+
+    for i in range(batch_size):
+        s = text_len[i] + img_len
+        s1 = i * max_len + s
+        s2 = (i + 1) * max_len
+        cu_seqlens[2 * i + 1] = s1
+        cu_seqlens[2 * i + 2] = s2
+
+    return cu_seqlens
+
+
+def attention(
+    q,
+    k,
+    v,
+    mode="torch",
+    drop_rate=0,
+    attn_mask=None,
+    causal=False,
+    cu_seqlens_q=None,
+    cu_seqlens_kv=None,
+    max_seqlen_q=None,
+    max_seqlen_kv=None,
+    batch_size=1,
+):
+    """
+    Perform QKV self attention.
+
+    Args:
+        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
+        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
+        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
+        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
+        drop_rate (float): Dropout rate in attention map. (default: 0)
+        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
+            (default: None)
+        causal (bool): Whether to use causal attention. (default: False)
+        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+            used to index into q.
+        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+            used to index into kv.
+        max_seqlen_q (int): The maximum sequence length in the batch of q.
+        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
+
+    Returns:
+        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
+    """
+    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
+    q = pre_attn_layout(q)
+    k = pre_attn_layout(k)
+    v = pre_attn_layout(v)
+
+    if mode == "torch":
+        if attn_mask is not None and attn_mask.dtype != torch.bool:
+            attn_mask = attn_mask.to(q.dtype)
+        x = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
+        )
+    elif mode == "flash":
+        x = flash_attn_varlen_func(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_kv,
+            max_seqlen_q,
+            max_seqlen_kv,
+        )
+        # x with shape [(bxs), a, d]
+        x = x.view(
+            batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
+        )  # reshape x to [b, s, a, d]
+    elif mode == "vanilla":
+        scale_factor = 1 / math.sqrt(q.size(-1))
+
+        b, a, s, _ = q.shape
+        s1 = k.size(2)
+        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
+        if causal:
+            # Only applied to self attention
+            assert (
+                attn_mask is None
+            ), "Causal mask and attn_mask cannot be used together"
+            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
+                diagonal=0
+            )
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias.to(q.dtype)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+
+        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
+        attn = (q @ k.transpose(-2, -1)) * scale_factor
+        attn += attn_bias
+        attn = attn.softmax(dim=-1)
+        attn = torch.dropout(attn, p=drop_rate, train=True)
+        x = attn @ v
+    else:
+        raise NotImplementedError(f"Unsupported attention mode: {mode}")
+
+    x = post_attn_layout(x)
+    b, s, a, d = x.shape
+    out = x.reshape(b, s, -1)
+    return out
+
+
+def parallel_attention(
+    hybrid_seq_parallel_attn,
+    q,
+    k,
+    v,
+    img_q_len,
+    img_kv_len,
+    cu_seqlens_q,
+    cu_seqlens_kv
+):
+    attn1 = hybrid_seq_parallel_attn(
+        None,
+        q[:, :img_q_len, :, :],
+        k[:, :img_kv_len, :, :],
+        v[:, :img_kv_len, :, :],
+        dropout_p=0.0,
+        causal=False,
+        joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
+        joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
+        joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
+        joint_strategy="rear",
+    )
+    if flash_attn.__version__ >= "2.7.0":
+        attn2, *_ = _flash_attn_forward(
+            q[:,cu_seqlens_q[1]:],
+            k[:,cu_seqlens_kv[1]:],
+            v[:,cu_seqlens_kv[1]:],
+            dropout_p=0.0,
+            softmax_scale=q.shape[-1] ** (-0.5),
+            causal=False,
+            window_size_left=-1,
+            window_size_right=-1,
+            softcap=0.0,
+            alibi_slopes=None,
+            return_softmax=False,
+        )
+    else:
+        attn2, *_ = _flash_attn_forward(
+            q[:,cu_seqlens_q[1]:],
+            k[:,cu_seqlens_kv[1]:],
+            v[:,cu_seqlens_kv[1]:],
+            dropout_p=0.0,
+            softmax_scale=q.shape[-1] ** (-0.5),
+            causal=False,
+            window_size=(-1, -1),
+            softcap=0.0,
+            alibi_slopes=None,
+            return_softmax=False,
+        )
+    attn = torch.cat([attn1, attn2], dim=1)
+    b, s, a, d = attn.shape
+    attn = attn.reshape(b, s, -1)
+
+    return attn
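
Review note (not part of the diff): the touched condition compares version strings with `>=`, which Python evaluates lexicographically; that happens to order the flash_attn releases referenced here correctly, but a hypothetical "2.10.0" would compare as older than "2.7.0". Below is a minimal sketch of a more robust check, assuming the common `packaging` package is available (the module already imports `importlib.metadata`, which pairs naturally with it); the helper name `flash_attn_is_at_least` is invented for illustration.

import importlib.metadata

from packaging.version import Version  # assumes the `packaging` package is installed


def flash_attn_is_at_least(minimum="2.7.0"):
    # True when an installed flash_attn is at least `minimum`, False when it is absent.
    try:
        installed = importlib.metadata.version("flash_attn")
    except importlib.metadata.PackageNotFoundError:
        return False
    return Version(installed) >= Version(minimum)


# Contrast with the plain string comparison used in the file:
print("2.10.0" >= "2.7.0")                    # False -- lexicographic, misleading
print(Version("2.10.0") >= Version("2.7.0"))  # True  -- semantic version order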
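
Also for review convenience, a small usage sketch of the unchanged `attention` helper in its "torch" mode, which exercises the `MEMORY_LAYOUT` transposes and the final `[b, s, a*d]` reshape without requiring flash_attn or a GPU; the tensor sizes below are arbitrary.

import torch

from hyvideo.modules.attenion import attention

# Arbitrary sizes: batch 2, sequence length 16, 8 heads, head dim 64.
b, s, a, d = 2, 16, 8, 64
q = torch.randn(b, s, a, d)
k = torch.randn(b, s, a, d)
v = torch.randn(b, s, a, d)

out = attention(q, k, v, mode="torch")  # SDPA path; runs on CPU
print(out.shape)  # torch.Size([2, 16, 512]), i.e. [b, s, a * d]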