Dionyssos committed
Commit e83a997 · 1 Parent(s): a0ce150

DEBUG: streaming past key changes 47x during next-token calc

Files changed (3):
  1. audiocraft/lm.py +1 -1
  2. audiocraft/transformer.py +151 -127
  3. demo.py +2 -2
audiocraft/lm.py CHANGED
@@ -254,7 +254,7 @@ class LMModel(nn.Module):
         # so only 2 of the 4 self.linears() are used ?
         # Why is torch.stack on dim=1?
         logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]
-        print(f'{input_.shape=} {out.shape=} {cross_attention_input.shape=} {logits.shape=} FUSER LLM')
+        # print(f'{input_.shape=} {out.shape=} {cross_attention_input.shape=} {logits.shape=} FUSER LLM')
         # remove the prefix from the model outputs
         # if len(self.fuser.fuse2cond['prepend']) > 0:
         #     logits = logits[:, :, -S:]
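
Note on the comment above asking why torch.stack uses dim=1: a minimal sketch (the sizes below are illustrative assumptions, not the model config). Each of the K per-codebook linear heads maps the transformer output [B, S, D] to [B, S, card]; stacking those K tensors on dim=1 inserts the codebook axis right after the batch axis, which yields the [B, K, S, card] layout noted in the code.

    import torch
    import torch.nn as nn

    # Illustrative sizes only (assumptions, not the AudioGen config).
    B, S, D, card, K = 2, 5, 16, 1024, 4
    out = torch.randn(B, S, D)                                    # transformer output [B, S, D]
    linears = nn.ModuleList(nn.Linear(D, card) for _ in range(K))

    # Each linears[k](out) is [B, S, card]; stacking on dim=1 adds the codebook axis.
    logits = torch.stack([linears[k](out) for k in range(K)], dim=1)
    print(logits.shape)                                           # torch.Size([2, 4, 5, 1024])
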
audiocraft/transformer.py CHANGED
@@ -18,13 +18,7 @@ def set_efficient_attention_backend(backend: str = 'torch'):
 
 
 
-def _is_profiled():
-    # Return true if we are currently running with an xformers profiler activated.
-    try:
-        from xformers.profiler import profiler
-    except ImportError:
-        return False
-    return profiler._Profiler._CURRENT_PROFILER is not None
+
 
 
 def create_norm_fn(norm_type, dim, **kwargs):
@@ -69,35 +63,34 @@ class StreamingMultiheadAttention(nn.Module):
     def __init__(self,
                  embed_dim,
                  num_heads,
-                 dropout=0.0, bias: bool = True,
-                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
-                 memory_efficient: bool = False, attention_as_float32: bool = False,
+                 dropout=0.0,
+                 bias: bool = True,
+                 causal: bool = False,
+                 past_context: tp.Optional[int] = None,
+                 custom: bool = False,
+                 memory_efficient: bool = False,
+                 attention_as_float32: bool = False,
                  cross_attention: bool = False,
-                 qk_layer_norm: bool = False, kv_repeat: int = 1,
+                 qk_layer_norm: bool = False,
+                 kv_repeat: int = 1,
                  device=None, dtype=None):
         super().__init__()
         factory_kwargs = {'device': device, 'dtype': dtype}
         if past_context is not None:
             assert causal
-
         self.embed_dim = embed_dim
         self.causal = causal
         self.past_context = past_context
         self.memory_efficient = memory_efficient
         self.attention_as_float32 = attention_as_float32
-
         self.cross_attention = cross_attention
-
         self.num_heads = num_heads
         self.dropout = dropout
         self.kv_repeat = kv_repeat
         if cross_attention:
             assert not causal, "Causal cannot work with cross attention."
-
-
         if memory_efficient:
             _verify_xformers_memory_efficient_compat()
-
         self.custom = _is_custom(custom, memory_efficient)
         if self.custom:
             out_dim = embed_dim
@@ -116,18 +109,12 @@ class StreamingMultiheadAttention(nn.Module):
             if bias:
                 self.out_proj.bias.data.zero_()
         else:
-            assert not qk_layer_norm
-            assert kv_repeat == 1
-            self.mha = nn.MultiheadAttention(
-                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
-                **factory_kwargs)
+            print('mha ini else')
         self.qk_layer_norm = qk_layer_norm
         if qk_layer_norm:
-            assert self.custom
-            assert kv_repeat == 1
-            ln_dim = embed_dim
-            self.q_layer_norm = nn.LayerNorm(ln_dim)
-            self.k_layer_norm = nn.LayerNorm(ln_dim)
+            print('QK norm')
+
+
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         if not self.custom:
@@ -137,13 +124,7 @@ class StreamingMultiheadAttention(nn.Module):
                 if prefix + key in state_dict:
                     state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
-
-
-
 
-
-
 
     def forward(self,
                 query,
@@ -152,7 +133,53 @@ class StreamingMultiheadAttention(nn.Module):
                 need_weights=False,
                 attn_mask=None,
                 is_causal=False):
-
+        # 2=cond/uncond
+        # 24=heads
+        # 1=seqlen
+        # 64=channel
+        #
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 43
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+        # X
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 44
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+        # X
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 45
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+        # X
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 46
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+        # X
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
+        # 47
+        # ____________
+        # SELF
+        # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
+        # sa_ x.shape=torch.Size([2, 1, 1536])
+
+
+
+
         assert not is_causal, ("New param added in torch 2.0.1 not supported, "
                                "use the causal args in the constructor.")
         # print(f'{query.shape=} {key.shape=} {value.shape=} MHA')
@@ -167,9 +194,9 @@ class StreamingMultiheadAttention(nn.Module):
         custom_attn_mask = attn_mask is not None
 
         if self.custom:
-            # custom implementation
-            assert need_weights is False
-            assert key_padding_mask is None
+
+
+
             if self.cross_attention:
                 # print('\n\n\n\nCROSS\n\n\n\n')
 
@@ -178,9 +205,8 @@ class StreamingMultiheadAttention(nn.Module):
                 if self.in_proj_bias is None:
                     bias_q, bias_k, bias_v = None, None, None
                 else:
-                    bias_q = self.in_proj_bias[:dim]
-                    bias_k = self.in_proj_bias[dim: 2 * dim]
-                    bias_v = self.in_proj_bias[2 * dim:]
+                    print('no self proj bi')
+
                 q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                 # print(f'{q.shape=} TRANSF FORW who concaten')
                 # todo: when streaming, we could actually save k, v and check the shape actually match.
@@ -191,18 +217,9 @@ class StreamingMultiheadAttention(nn.Module):
                     k = self.k_layer_norm(k)
 
                 q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
-                # print(f'{q.shape=} {k.shape=} {v.shape=} after rearrange')
+
             else:
-                # print('\n\n\n\nSELF\n\n\n\n')
-                #
-                # 47x Transformers selfattn followed by crossattn
-                #
-                # self-attn is on history? previous key or is it on only the last token?
-
-                if not _is_profiled():
-                    # profiling breaks that property somehow.
-                    assert query is key, "specialized implementation"
-                    assert value is key, "specialized implementation"
+
                 projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                 if self.kv_repeat == 1:
                     if time_dim == 2:
@@ -217,66 +234,50 @@ class StreamingMultiheadAttention(nn.Module):
                     q, k, v = ops.unbind(packed, dim=2)
                     # print(f'{q.shape=} {v.shape=} @L331 transformer.py')  # packed is bs=2
                 else:
-                    embed_dim = self.embed_dim
-                    per_head_dim = (embed_dim // self.num_heads)
-                    kv_heads = self.num_heads // self.kv_repeat
-                    q = projected[:, :, :embed_dim]
-                    start = embed_dim
-                    end = start + per_head_dim * kv_heads
-                    k = projected[:, :, start: end]
-                    v = projected[:, :, end:]
-                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
-                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
-                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
+
+                    print("ELSE kv rp")
 
             if self.qk_layer_norm is True:
-                assert self.kv_repeat == 1
-                q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
-                q = self.q_layer_norm(q)
-                k = self.k_layer_norm(k)
-                q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
-
+
+                print('QL lay norm')
 
             if self.kv_repeat > 1:
-                #
+
                 print('Expand repear 2')
 
             if self.attention_as_float32:
-                q, k, v = [x.float() for x in [q, k, v]]
+                print('AS FLOAT32')
+
             if self.memory_efficient:
                 if custom_attn_mask:
-                    # When using a custom attn mask:
-                    # Move to query's device, repeat for each sample, remove align8 padding
-                    seq_len = query.shape[1]
-                    attn_mask = attn_mask.to(q.dtype)
-                    attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1))
-                    attn_mask = attn_mask[..., :seq_len, :seq_len]
+
+                    print('CUSTOM ATTN MSK')
+
 
                 p = self.dropout if self.training else 0
                 if _efficient_attention_backend == 'torch':
+
+                    # print(f'{q.shape=} {k.shape=} {v.shape=} 90')
+                    print(f'{x.sum()=} {q.sum()=} {k.sum()=} {v.sum()=} 90 variation of qkv during 47')
+                    # the k.sum(), v.sum() change over the 47 transfs; how is that possible if self._sa
+                    # has q-len = 1.
+                    #
+                    #
+
                     x = torch.nn.functional.scaled_dot_product_attention(
                         q, k, v, is_causal=attn_mask is not None, dropout_p=p)
                 else:
-                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
+
+                    print('MHA OPS')
+
+
            else:
-                # We include the dot product as float32, for consistency
-                # with the other implementations that include that step
-                # as part of the attention. Note that when using `autocast`,
-                # the einsums would be done as bfloat16, but the softmax
-                # would be done as bfloat16, so `attention_as_float32` will
-                # extend a bit the range of operations done in float32,
-                # although this should make no difference.
-                q = q / q.shape[-1] ** 0.5
-                key_layout = layout.replace('t', 'k')
-                query_layout = layout
 
-                pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
-                if attn_mask is not None:
-                    pre_w = pre_w + attn_mask
-                w = torch.softmax(pre_w, dim=-1)
-                w = F.dropout(w, self.dropout, training=self.training).to(v)
-                # Key and value have the same format.
-                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
+                print('CONSISTENCY ')
+
+
+
+
            x = x.to(dtype)
            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
@@ -313,8 +314,9 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
             'memory_efficient': memory_efficient,
             'attention_as_float32': attention_as_float32,
         }
-        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
-            causal=causal, past_context=past_context,
+        self.self_attn = StreamingMultiheadAttention(
+            causal=causal,
+            past_context=past_context,
             # rope=rope,
             qk_layer_norm=qk_layer_norm,
             kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
@@ -336,6 +338,17 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
 
         self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
         self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+
+    # ENVS....d4/lib/python3.10/site-packages/torch/nn/modules/transformer.py @TransformerEncoderLayer
+    def _sa_block(self, q, k, v):
+        x = self.self_attn(q,
+                           k,
+                           v,
+                           attn_mask=None,
+                           key_padding_mask=None,
+                           need_weights=False,
+                           is_causal=None)[0]
+        return self.dropout1(x)
 
     def _cross_attention_block(self,
                                src,
@@ -353,27 +366,37 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
                 cross_attention_src=None):
 
 
-        x = src
+
         if self.norm_first:
-            # print('selfattn', x.shape, src_mask, src_key_padding_mask)
-            x = x + self._sa_block(self.norm1(x),
-                                   src_mask,  # None
-                                   src_key_padding_mask  # None
-                                   )  # Internal nn
-            # print('crossattn', x.shape, cross_attention_src.shape)
+            print('selfattn')
+            history = self.norm1(src)
+            x = history[:, -1:, :]
+
+            # THIS IS COMPUTED with 1 timestep
+            # just before the call there is cat([past_k, k])
+            # Thus we just
+            x = x + self._sa_block(x,  # THIS should be square as the history is updated
+                                   # then the -1 item of history goes to the text x text
+                                   #
+                                   history,
+                                   history)
+            print('crossattn')
            if cross_attention_src is not None:
                x = x + self._cross_attention_block(
                    self.norm_cross(x),
                    cross_attention_src)
-                # selfattn torch.Size([2, 2, 1536]) None None NO 4D TOKEN!
-                # crossattn torch.Size([2, 2, 1536]) torch.Size([2, 4, 1536])
+
            else:
-                raise NotImplementedError  # all layers have a self & cross?
+                print('NOT IMPL')
+
+
            x = x + self._ff_block(self.norm2(x))
        else:
            print('NLAST')
-            # print('NT', x.shape)  # [1,2 ,1536]
+
        return x
+
+
 
 
 class StreamingTransformer(nn.Module):
@@ -422,6 +445,7 @@ class StreamingTransformer(nn.Module):
                 device=device, dtype=dtype, **kwargs))
 
         if self.checkpointing != 'none':
+            print('Checkpointing????????????')
             for layer in self.layers:
                 # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                 # backward hook inside of FSDP...
@@ -443,30 +467,30 @@ class StreamingTransformer(nn.Module):
 
         pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
         x = x + self.positional_scale * pos_emb
-        # UNTIL HERE BATCH=1
+
+
+
+        # 47x transformer layers for frozen history
+        # -> history is updated by self._sa() although its length is fixed
+        # -> the q that comes out of the text x text cross attn
+        #    is given as q to the next lay's self._sa() with updated history
+        # ->
+        # ->
        for _, lay in enumerate(self.layers):
-            # if _ < 2:
-            # L=0 [1,1,1536]
-            # L=1 [2,1,1536]
+            print(f'_________________\n{_}')
+            # 1 q = last_token x history x history
+            # 2 next_token = q x text x text
 
-            print(f'L={_} {args=} {kwargs["cross_attention_src"].shape=} {x.shape=} StreamTransf ForLoop')  # [2, 1, 1536] BATCH=2
-            # x = self._apply_layer(layer, x, *args, **kwargs)
-            # x = lay(x, **kwargs)
-            x = lay(x,
-                    cross_attention_src=kwargs["cross_attention_src"],
-                    src_mask=kwargs['src_mask'])
-            # concat old token to query oh not here is on lm generate
-            print('OUT OF Tall', x.shape)  # [1,2,1536] # why this gets filled with sequence 1,2...
-            # should be 1 query
+            # x preserves full history for self._sa(). After all transformers we return only last -1 tok
+            x, history = lay(
+                x,
+                history=history,  # only updated by self_attn (the cross sees only last token)
+                cross_attention_src=kwargs["cross_attention_src"],
+                src_mask=kwargs['src_mask']
+            )  # x : [bs, 24, 37, 64]
        return x
 
-    def make_optim_group(self):
-        group = {"params": list(self.parameters())}
-        if self.lr is not None:
-            group["lr"] = self.lr
-        if self.weight_decay is not None:
-            group["weight_decay"] = self.weight_decay
-        return group
+
 
 
 # special attention related function
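
Note on the debug question in the hunks above (how can k.sum() and v.sum() keep changing across the 47 self-attention calls of a single next-token step when the query length stays 1): a minimal sketch of a per-layer streaming K/V cache, offered as an assumption-level illustration rather than the audiocraft implementation. Batch 2 (cond/uncond), 24 heads, and head dim 64 are taken from the traces; the layer count is assumed. Every layer keeps its own cache, and each decoding step concatenates one new K/V timestep onto it (the cat([past_k, k]) mentioned in the comments), so the cached tensors differ per layer and grow per step even though q stays [2, 24, 1, 64].

    import torch
    import torch.nn.functional as F

    B, H, D = 2, 24, 64          # batch (cond/uncond), heads, head dim, as in the traces
    N_LAYERS = 48                # assumed layer count; the traces index up to 47

    # One K/V cache per layer, empty along the time axis to start with.
    caches = [{'k': torch.zeros(B, H, 0, D), 'v': torch.zeros(B, H, 0, D)}
              for _ in range(N_LAYERS)]

    def streaming_self_attn(layer, q, k_new, v_new):
        # Append the new timestep to this layer's cache, then attend with a length-1 query.
        c = caches[layer]
        c['k'] = torch.cat([c['k'], k_new], dim=2)
        c['v'] = torch.cat([c['v'], v_new], dim=2)
        return F.scaled_dot_product_attention(q, c['k'], c['v'])

    for step in range(3):                     # a few next-token steps
        for layer in range(N_LAYERS):
            q = torch.randn(B, H, 1, D)       # new token projection for this layer
            _ = streaming_self_attn(layer, q, q, q)
        # The cached K changes every step (and is different in every layer),
        # which is why k.sum() varies even though the query length is always 1.
        print(step, caches[0]['k'].shape, caches[0]['k'].sum().item())

If AudioGen's streaming attention follows this pattern, the changing sums in the printout would simply reflect the growing per-layer caches rather than any mutation of the length-1 query.
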
demo.py CHANGED
@@ -4,10 +4,10 @@ import numpy as np
 
 print('\n\n\n\n___________________')
 
-txt = 'dogs in street'
+txt = 'dogs in the street'
 
 sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
-sound_generator.set_generation_params(duration=1.24)  # why is generating so long at 14 seconds
+sound_generator.set_generation_params(duration=.74)  # why is generating so long at 14 seconds
 
 x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
 x /= np.abs(x).max() + 1e-7