DEBUG: streaming past-key changes 47x during next-token calc
Files changed:
- audiocraft/lm.py +1 -1
- audiocraft/transformer.py +151 -127
- demo.py +2 -2
audiocraft/lm.py
CHANGED
@@ -254,7 +254,7 @@ class LMModel(nn.Module):
         # so only 2 x sel.flinear() of 4 are used ?
         # WHy torch.stack is in dim=1
         logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]
-        print(f'{input_.shape=} {out.shape=} {cross_attention_input.shape=} {logits.shape=} FUSER LLM')
+        # print(f'{input_.shape=} {out.shape=} {cross_attention_input.shape=} {logits.shape=} FUSER LLM')
         # remove the prefix from the model outputs
         # if len(self.fuser.fuse2cond['prepend']) > 0:
         #     logits = logits[:, :, -S:]
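
Note (not part of the commit): the comment above asks why torch.stack uses dim=1. Each of the K codebook heads in self.linears maps the transformer output [B, S, dim] to per-codebook logits [B, S, card], and stacking on dim=1 inserts the codebook axis right after the batch axis, giving [B, K, S, card]. A minimal shape sketch in plain PyTorch; the sizes (K=4, card=2048, dim=1536) are assumptions for illustration only.

import torch
import torch.nn as nn

B, S, dim, K, card = 2, 5, 1536, 4, 2048
out = torch.randn(B, S, dim)                                      # transformer output [B, S, dim]
linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(K)])

# each head produces [B, S, card]; stacking on dim=1 puts the codebook axis after batch
logits = torch.stack([linears[k](out) for k in range(K)], dim=1)
print(logits.shape)  # torch.Size([2, 4, 5, 2048]) == [B, K, S, card]
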
audiocraft/transformer.py
CHANGED
@@ -18,13 +18,7 @@ def set_efficient_attention_backend(backend: str = 'torch'):
 
 
 
-def _is_profiled():
-    # Return true if we are currently running with a xformers profiler activated.
-    try:
-        from xformers.profiler import profiler
-    except ImportError:
-        return False
-    return profiler._Profiler._CURRENT_PROFILER is not None
+
 
 
 def create_norm_fn(norm_type, dim, **kwargs):
@@ -69,35 +63,34 @@ class StreamingMultiheadAttention(nn.Module):
     def __init__(self,
                  embed_dim,
                  num_heads,
-                 dropout=0.0,
-
-
+                 dropout=0.0,
+                 bias: bool = True,
+                 causal: bool = False,
+                 past_context: tp.Optional[int] = None,
+                 custom: bool = False,
+                 memory_efficient: bool = False,
+                 attention_as_float32: bool = False,
                  cross_attention: bool = False,
-                 qk_layer_norm: bool = False,
+                 qk_layer_norm: bool = False,
+                 kv_repeat: int = 1,
                  device=None, dtype=None):
         super().__init__()
         factory_kwargs = {'device': device, 'dtype': dtype}
         if past_context is not None:
             assert causal
-
         self.embed_dim = embed_dim
         self.causal = causal
         self.past_context = past_context
         self.memory_efficient = memory_efficient
         self.attention_as_float32 = attention_as_float32
-
         self.cross_attention = cross_attention
-
         self.num_heads = num_heads
         self.dropout = dropout
         self.kv_repeat = kv_repeat
         if cross_attention:
             assert not causal, "Causal cannot work with cross attention."
-
-
         if memory_efficient:
             _verify_xformers_memory_efficient_compat()
-
         self.custom = _is_custom(custom, memory_efficient)
         if self.custom:
             out_dim = embed_dim
@@ -116,18 +109,12 @@ class StreamingMultiheadAttention(nn.Module):
             if bias:
                 self.out_proj.bias.data.zero_()
         else:
-
-            assert kv_repeat == 1
-            self.mha = nn.MultiheadAttention(
-                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
-                **factory_kwargs)
+            print('mha ini else')
         self.qk_layer_norm = qk_layer_norm
         if qk_layer_norm:
-
-
-
-            self.q_layer_norm = nn.LayerNorm(ln_dim)
-            self.k_layer_norm = nn.LayerNorm(ln_dim)
+            print('QK norm')
+
+
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         if not self.custom:
@@ -137,13 +124,7 @@ class StreamingMultiheadAttention(nn.Module):
                 if prefix + key in state_dict:
                     state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
-
-
 
-
-
-
     def forward(self,
                 query,
                 key,
@@ -152,7 +133,53 @@ class StreamingMultiheadAttention(nn.Module):
                 need_weights=False,
                 attn_mask=None,
                 is_causal=False):
-
+        # 2=cond/uncond
+        # 24=heads
+        # 1=seqlen
+        # 64=channel
+        #
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 43
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+        # X
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 44
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+        # X
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 45
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+        # X
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 46
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+        # X
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 47
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+
+
+
         assert not is_causal, ("New param added in torch 2.0.1 not supported, "
                                "use the causal args in the constructor.")
         # print(f'{query.shape=} {key.shape=} {value.shape=} MHA')
@@ -167,9 +194,9 @@ class StreamingMultiheadAttention(nn.Module):
         custom_attn_mask = attn_mask is not None
 
         if self.custom:
-
-
-
+
+
+
             if self.cross_attention:
                 # print('\n\n\n\nCROSS\n\n\n\n')
 
@@ -178,9 +205,8 @@ class StreamingMultiheadAttention(nn.Module):
                 if self.in_proj_bias is None:
                     bias_q, bias_k, bias_v = None, None, None
                 else:
-
-
-                    bias_v = self.in_proj_bias[2 * dim:]
+                    print('no self proj bi')
+
                 q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                 # print(f'{q.shape=} TRANSF FORW who concaten')
                 # todo: when streaming, we could actually save k, v and check the shape actually match.
@@ -191,18 +217,9 @@ class StreamingMultiheadAttention(nn.Module):
                     k = self.k_layer_norm(k)
 
                 q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
-
+
             else:
-
-                #
-                # 47x Transformers selfattn followed by crossattn
-                #
-                # self-attn is on history? previous key or is it on only the last token?
-
-                if not _is_profiled():
-                    # profiling breaks that propertysomehow.
-                    assert query is key, "specialized implementation"
-                    assert value is key, "specialized implementation"
+
                 projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                 if self.kv_repeat == 1:
                     if time_dim == 2:
@@ -217,66 +234,50 @@ class StreamingMultiheadAttention(nn.Module):
                     q, k, v = ops.unbind(packed, dim=2)
                     # print(f'{q.shape=} {v.shape=} @L331 trasnforemr.py') # packed is bs=2
                 else:
-
-
-                    kv_heads = self.num_heads // self.kv_repeat
-                    q = projected[:, :, :embed_dim]
-                    start = embed_dim
-                    end = start + per_head_dim * kv_heads
-                    k = projected[:, :, start: end]
-                    v = projected[:, :, end:]
-                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
-                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
-                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
+
+                    print("ELSE kv rp")
 
             if self.qk_layer_norm is True:
-
-
-                q = self.q_layer_norm(q)
-                k = self.k_layer_norm(k)
-                q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
-
+
+                print('QL lay norm')
 
             if self.kv_repeat > 1:
-
+
                 print('Expand repear 2')
 
             if self.attention_as_float32:
-
+                print('AS FLOAT32')
+
             if self.memory_efficient:
                 if custom_attn_mask:
-
-
-
-                    attn_mask = attn_mask.to(q.dtype)
-                    attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1))
-                    attn_mask = attn_mask[..., :seq_len, :seq_len]
+
+                    print('CUSTOM ATTN MSK')
+
 
                 p = self.dropout if self.training else 0
                 if _efficient_attention_backend == 'torch':
+
+                    # print(f'{q.shape=} {k.shape=} {v.shape=} 90')
+                    print(f'{x.sum()=} {q.sum()=} {k.sum()=} {v.sum()=} 90 variation of qkv during 47')
+                    # the k.sum(),v.sum() changes over the 47transfs how is that possible if self._sa
+                    # has q-len = 1.
+                    #
+                    #
+
                     x = torch.nn.functional.scaled_dot_product_attention(
                         q, k, v, is_causal=attn_mask is not None, dropout_p=p)
                 else:
-
+
+                    print('MHA OPS')
+
+
             else:
-                # We include the dot product as float32, for consistency
-                # with the other implementations that include that step
-                # as part of the attention. Note that when using `autocast`,
-                # the einsums would be done as bfloat16, but the softmax
-                # would be done as bfloat16, so `attention_as_float32` will
-                # extend a bit the range of operations done in float32,
-                # although this should make no difference.
-                q = q / q.shape[-1] ** 0.5
-                key_layout = layout.replace('t', 'k')
-                query_layout = layout
 
-
-
-
-
-
-                # Key and value have the same format.
-                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
+                print('CONSISTENCY ')
+
+
+
+
             x = x.to(dtype)
             x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
             x = self.out_proj(x)
@@ -313,8 +314,9 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
             'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
         }
-        self.self_attn
-        causal=causal,
+        self.self_attn=StreamingMultiheadAttention(
+            causal=causal,
+            past_context=past_context,
             # rope=rope,
             qk_layer_norm=qk_layer_norm,
             kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
@@ -336,6 +338,17 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
 
         self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
         self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+
+    # ENVS....d4/lib/python3.10/site-packages/torch/nn/modules/transformer.py @TransformerEncoderLayer
+    def _sa_block(self, q, k, v):
+        x = self.self_attn(q,
+                           k,
+                           v,
+                           attn_mask=None,
+                           key_padding_mask=None,
+                           need_weights=False,
+                           is_causal=None)[0]
+        return self.dropout1(x)
 
     def _cross_attention_block(self,
                                src,
@@ -353,27 +366,37 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
                 cross_attention_src=None):
 
 
-
+
         if self.norm_first:
-
-
-
-
-
-            #
+            print('selfattn')
+            history = self.norm1(src)
+            x = history[:, -1:, :]
+
+            # THIS IS COMPUTED with 1 timestep
+            # just before the call there is cat([past_k, k])
+            # Thus we just
+            x = x + self._sa_block(x,  # THIS should be square as the history is updated
+                                   # then the -1 item of history goes to the text x text
+                                   #
+                                   history,
+                                   history)
+            print('crossattn')
             if cross_attention_src is not None:
                 x = x + self._cross_attention_block(
                     self.norm_cross(x),
                     cross_attention_src)
-
-            # crossattn torch.Size([2, 2, 1536]) torch.Size([2, 4, 1536])
+
            else:
-
+                print('NOT IMPL')
+
+
            x = x + self._ff_block(self.norm2(x))
         else:
             print('NLAST')
-
+
         return x
+
+
 
 
 class StreamingTransformer(nn.Module):
@@ -422,6 +445,7 @@ class StreamingTransformer(nn.Module):
                 device=device, dtype=dtype, **kwargs))
 
         if self.checkpointing != 'none':
+            print('Checkpointing????????????')
             for layer in self.layers:
                 # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                 # backward hook inside of FSDP...
@@ -443,30 +467,30 @@ class StreamingTransformer(nn.Module):
 
         pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
         x = x + self.positional_scale * pos_emb
-
+
+
+
+        # 47x transformer layers for frozen history
+        # -> history is updated by self._sa() althought her length is fixed
+        # -> the q that comes out of the text x text cross attn
+        #    is given as q to the next lay's self._sa() with updated history
+        # ->
+        # ->
         for _, lay in enumerate(self.layers):
-
-
-
+            print(f'_________________\n{_}')
+            # 1 q = last_token x history x history
+            # 2 next_token = q x text x text
 
-
-
-
-
-
-
-
-            print('OUT OF Tall', x.shape)  # [1,2,1536]  # why this gets filled with sequence 1,2...
-            # should be 1 query
+            # x preserves full history for self._sa(). After all transformers we return only last -1 tok
+            x, history = lay(
+                x,
+                history=history,  # only updated by self_attn (the cross sees only last token)
+                cross_attention_src=kwargs["cross_attention_src"],
+                src_mask=kwargs['src_mask']
+            )  # x : [bs, 24, 37, 64]
         return x
 
-
-        group = {"params": list(self.parameters())}
-        if self.lr is not None:
-            group["lr"] = self.lr
-        if self.weight_decay is not None:
-            group["weight_decay"] = self.weight_decay
-        return group
+
 
 
 # special attention related function
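
Note (not part of the commit): the comments above ask how k.sum() and v.sum() can change across the 47 layers and across decoding steps when the self-attention query has length 1. A minimal sketch of streaming self-attention with a per-layer key/value cache, assuming the cat([past_k, k]) pattern the comments mention; this is not the audiocraft implementation, and the class and attribute names below are made up for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class StreamingSelfAttentionSketch(nn.Module):
    # One layer's self-attention with its own key/value cache (hypothetical sketch).
    def __init__(self, dim=64, heads=4):
        super().__init__()
        self.heads = heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)
        self.past_k = None  # [B, H, T_past, D], grows by one entry per step
        self.past_v = None

    def step(self, x_new):
        # x_new: [B, 1, dim] -- only the newest token is projected each step
        B, T, C = x_new.shape
        q, k, v = self.qkv(x_new).chunk(3, dim=-1)
        split = lambda t: t.view(B, T, self.heads, C // self.heads).transpose(1, 2)
        q, k, v = split(q), split(k), split(v)       # [B, H, 1, D]
        if self.past_k is not None:
            k = torch.cat([self.past_k, k], dim=2)   # cat([past_k, k]) as in the comments
            v = torch.cat([self.past_v, v], dim=2)
        self.past_k, self.past_v = k, v              # the cache changes every step, per layer
        y = F.scaled_dot_product_attention(q, k, v)  # q-len stays 1, kv-len keeps growing
        return self.out(y.transpose(1, 2).reshape(B, T, C))

layer = StreamingSelfAttentionSketch()
x = torch.randn(2, 1, 64)
for step in range(3):
    x = layer.step(x)
    print(step, layer.past_k.shape)  # kv length grows: 1, 2, 3

Because each layer projects its own, step-dependent input into k and v before appending to its cache, the cached keys differ from layer to layer and grow by one token per step, which is consistent with k.sum() and v.sum() changing even though the query is a single timestep.
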
demo.py
CHANGED
@@ -4,10 +4,10 @@ import numpy as np
 
 print('\n\n\n\n___________________')
 
-txt = 'dogs in street'
+txt = 'dogs in the street'
 
 sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
-sound_generator.set_generation_params(duration
+sound_generator.set_generation_params(duration=.74)  # why is generating so long at 14 seconds
 
 x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
 x /= np.abs(x).max() + 1e-7
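
Note (not part of the commit): a hedged sketch of how the "why is generating so long" question could be checked by timing the generate call, and how the normalized waveform could be written out. The soundfile dependency, the output filename, and the use of sound_generator.sample_rate are assumptions, not part of demo.py.

import time

import numpy as np
import soundfile as sf  # assumed extra dependency
from audiocraft.models import AudioGen

sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
sound_generator.set_generation_params(duration=.74)

t0 = time.time()
x = sound_generator.generate(['dogs in the street'])[0].detach().cpu().numpy()[0, :]
print(f'generation took {time.time() - t0:.1f}s')  # measures the wall-clock cost the comment asks about

x /= np.abs(x).max() + 1e-7                          # peak-normalize as in demo.py
sf.write('out.wav', x, sound_generator.sample_rate)  # 'out.wav' is an assumed filename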