Dionyssos committed
Commit 731cb10 · 1 Parent(s): e83a997

kv history @ each transf layer

Files changed (3):
  1. audiocraft/lm.py +20 -35
  2. audiocraft/transformer.py +212 -325
  3. demo.py +2 -2
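
The commit moves generation to per-layer key/value caching: each transformer layer's self-attention keeps its own k/v history across decoding steps, and the language model clears those caches once generation finishes. A minimal sketch of that pattern (illustrative names, not the exact audiocraft classes; note the committed code currently appends the unprojected input to the cache, whereas this sketch caches the projected k/v):

import torch
import torch.nn as nn

class CachedSelfAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)
        self.k_history = None  # [B, heads, T_seen, head_dim] once populated
        self.v_history = None

    def forward(self, x):
        # x: [B, 1, dim] -- only the newest token is fed per decoding step
        B, T, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        shape = lambda t: t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = shape(q), shape(k), shape(v)
        if self.k_history is None:          # first token: start the cache
            self.k_history, self.v_history = k, v
        else:                               # later tokens: append along time
            self.k_history = torch.cat([self.k_history, k], dim=2)
            self.v_history = torch.cat([self.v_history, v], dim=2)
        y = nn.functional.scaled_dot_product_attention(q, self.k_history, self.v_history)
        return self.out(y.transpose(1, 2).reshape(B, T, -1))

layers = nn.ModuleList([CachedSelfAttention(64, 4) for _ in range(3)])
for step in range(5):                       # one new token per step
    h = torch.randn(1, 1, 64)
    for lay in layers:                      # every layer keeps its own history
        h = lay(h)
for lay in layers:                          # reset after generation, as lm.py now does
    lay.k_history = None
    lay.v_history = None
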
audiocraft/lm.py CHANGED
@@ -147,7 +147,7 @@ class LMModel(nn.Module):
147
  super().__init__()
148
  self.cfg_coef = cfg_coef
149
 
150
- self.n_draw = 5
151
  self.condition_provider = condition_provider
152
  self.fuser = fuser
153
  self.card = card # 2048 ?
@@ -235,19 +235,16 @@ class LMModel(nn.Module):
235
  def forward(self,
236
  sequence,
237
  condition_tensors=None,
238
- stage = -1):
239
  B, K, S = sequence.shape # linears are n_q
240
-
241
- input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
242
-
243
-
244
  # input_, cross_attention_input = self.fuser(input_, condition_tensors)
245
  cross_attention_input = condition_tensors['description'][0]
246
- # print(f'{input_.shape=} {cross_attention_input.shape=} FUSER LLM')
247
-
248
 
249
- out = self.transformer(input_, cross_attention_src=cross_attention_input,
250
- src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))
 
 
251
  if self.out_norm:
252
  out = self.out_norm(out)
253
  # K = 2 because of llm producing 2 tokens?
@@ -323,38 +320,23 @@ class LMModel(nn.Module):
323
  ]
324
 
325
 
326
- for offset in range(1, _gen_sequence.shape[2]): # gen_sequence shape is [B, K, S]):
327
- # print(f'{_gen_sequence.shape=}') # [1,4,16]
328
- # starts from 1 not 0 thus uses the 0:1 as curr sequence
329
- # although this is empty contains -1 ?
330
-
331
-
332
 
333
 
334
- # ====================== SAMPLE NEXT TOK
335
- # next_token = self._sample_next_token(
336
- # _gen_sequence[..., :offset],
337
- # cfg_conditions) # [5, 4, 1]
338
- # --
339
- # def _sample_next_token(self,
340
- # sequence,
341
- # cfg_conditions):
342
- model = self if self._fsdp is None else self._fsdp
343
 
344
- logits = model(_gen_sequence[..., :offset],
345
- condition_tensors=cfg_conditions)
346
- # print(logits.shape, 'Next Logits') # [1, 4, 2, 2048] why 2 tokens on query
347
 
348
- # use cfg
349
- # logits = (3 * logits[1, :, :, :] - 2.4 * logits[0, :, :, :]).transpose(1,0)
 
350
 
351
- # or use 1 of logits
352
- logits = logits[0, :, 0:1, :] # [1,4,2048]
353
  next_token = utils.sample_top_k(logits, n_draw=self.n_draw) # [1,4,2048] logits
354
- # =================================
355
 
356
 
357
- _gen_sequence[:, :, offset] = next_token[0, :, 0] #gen_sequence.shape=torch.Size([1, 4, 39])
358
 
359
  duplicate_draw.append(next_token)
360
 
@@ -396,7 +378,10 @@ class LMModel(nn.Module):
396
  # <=> CODES out_codes.shape=torch.Size([1, 4, 35]) 30 2024
397
 
398
 
399
-
400
 
 
 
 
 
401
 
402
  return out_codes #
 
147
  super().__init__()
148
  self.cfg_coef = cfg_coef
149
 
150
+ self.n_draw = 8
151
  self.condition_provider = condition_provider
152
  self.fuser = fuser
153
  self.card = card # 2048 ?
 
235
  def forward(self,
236
  sequence,
237
  condition_tensors=None,
238
+ token_count=None):
239
  B, K, S = sequence.shape # linears are n_q
240
+ input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
 
 
 
241
  # input_, cross_attention_input = self.fuser(input_, condition_tensors)
242
  cross_attention_input = condition_tensors['description'][0]
 
 
243
 
244
+ print(f'{input_.shape=}')
245
+ out = self.transformer(input_,
246
+ cross_attention_src=cross_attention_input,
247
+ token_count=token_count)
248
  if self.out_norm:
249
  out = self.out_norm(out)
250
  # K = 2 because of llm producing 2 tokens?
 
320
  ]
321
 
322
 
323
+ for offset in range(1, _gen_sequence.shape[2]):
324
+
325
+
 
 
 
326
 
327
 
 
 
 
 
 
 
 
 
 
328
 
 
 
 
329
 
330
+ logits = self.forward(_gen_sequence[:, :, offset-1:offset], # bs/n_draw, 4, 1
331
+ condition_tensors=cfg_conditions,
332
+ token_count=offset)
333
 
334
+ # print(f'BEF {logits.shape=} BEF utils.SampleTop5') # AGREES 4 BEF logits.shape=torch.Size([1, 4, 1, 2048]) BEF utils.SampleTop5
 
335
  next_token = utils.sample_top_k(logits, n_draw=self.n_draw) # [1,4,2048] logits
336
+
337
 
338
 
339
+ _gen_sequence[:, :, offset] = next_token[0, :, 0] # next_token=[1,4,6] gen_seq=[1, 4, 39]
340
 
341
  duplicate_draw.append(next_token)
342
 
 
378
  # <=> CODES out_codes.shape=torch.Size([1, 4, 35]) 30 2024
379
 
380
 
 
381
 
382
+ # Clean Transformer MHA k_history v_history
383
+ for lay in self.transformer.layers:
384
+ lay.self_attn.k_history = None
385
+ lay.self_attn.v_history = None
386
 
387
  return out_codes #
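
In lm.py the decoding loop now feeds only the previous token to forward(), passes its absolute position as token_count, and relies on the per-layer caches for the rest of the context; after the loop every layer's k_history/v_history is reset. A simplified, hedged sketch of that loop (fake_lm stands in for LMModel.forward, and the inline top-k draw stands in for utils.sample_top_k):

import torch

card, K, S = 2048, 4, 16                      # codebooks K, target length S
gen_sequence = torch.full((1, K, S), card, dtype=torch.long)   # filler/special token

def fake_lm(seq, condition_tensors=None, token_count=None):
    # stand-in for LMModel.forward: logits for the single new step, [B, K, 1, card]
    return torch.randn(seq.shape[0], K, 1, card)

for offset in range(1, S):
    logits = fake_lm(gen_sequence[:, :, offset - 1:offset],     # only the previous token
                     condition_tensors=None,
                     token_count=offset)                        # absolute position
    probs = torch.softmax(logits[0, :, 0, :], dim=-1)           # [K, card]
    topv, topi = probs.topk(250, dim=-1)                        # restrict to top-k
    choice = torch.multinomial(topv / topv.sum(-1, keepdim=True), num_samples=1)
    gen_sequence[:, :, offset] = topi.gather(-1, choice)[:, 0]
# in the real model the per-layer caches are cleared here:
# for lay in self.transformer.layers:
#     lay.self_attn.k_history = None
#     lay.self_attn.v_history = None
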
audiocraft/transformer.py CHANGED
@@ -3,26 +3,36 @@ from einops import rearrange
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
 
6
  from xformers import ops
7
 
 
8
  _efficient_attention_backend: str = 'torch'
9
 
10
 
11
- def set_efficient_attention_backend(backend: str = 'torch'):
12
- # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
13
- global _efficient_attention_backend
14
- assert _efficient_attention_backend in ['xformers', 'torch']
15
- _efficient_attention_backend = backend
16
 
17
 
18
 
 
 
 
 
 
19
 
20
 
21
 
22
 
23
 
24
- def create_norm_fn(norm_type, dim, **kwargs):
 
25
 
 
 
 
 
 
 
 
26
  if norm_type == 'layer_norm':
27
  return nn.LayerNorm(dim, eps=1e-5, **kwargs)
28
  else:
@@ -48,11 +58,27 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
48
  adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
49
  max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
50
  phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
51
- # print('==============CONCAT 3 ============'
52
- return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
53
-
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
 
@@ -62,36 +88,37 @@ class StreamingMultiheadAttention(nn.Module):
62
 
63
  def __init__(self,
64
  embed_dim,
65
- num_heads,
66
- dropout=0.0,
67
- bias: bool = True,
68
- causal: bool = False,
69
- past_context: tp.Optional[int] = None,
70
- custom: bool = False,
71
- memory_efficient: bool = False,
72
- attention_as_float32: bool = False,
73
  cross_attention: bool = False,
74
- qk_layer_norm: bool = False,
75
  kv_repeat: int = 1,
76
  device=None, dtype=None):
77
  super().__init__()
78
  factory_kwargs = {'device': device, 'dtype': dtype}
79
  if past_context is not None:
80
  assert causal
 
81
  self.embed_dim = embed_dim
82
- self.causal = causal
83
- self.past_context = past_context
 
 
84
  self.memory_efficient = memory_efficient
85
  self.attention_as_float32 = attention_as_float32
 
86
  self.cross_attention = cross_attention
 
87
  self.num_heads = num_heads
88
  self.dropout = dropout
89
  self.kv_repeat = kv_repeat
90
- if cross_attention:
91
- assert not causal, "Causal cannot work with cross attention."
92
- if memory_efficient:
93
- _verify_xformers_memory_efficient_compat()
94
- self.custom = _is_custom(custom, memory_efficient)
 
 
95
  if self.custom:
96
  out_dim = embed_dim
97
  assert num_heads % kv_repeat == 0
@@ -109,12 +136,11 @@ class StreamingMultiheadAttention(nn.Module):
109
  if bias:
110
  self.out_proj.bias.data.zero_()
111
  else:
112
- print('mha ini else')
113
- self.qk_layer_norm = qk_layer_norm
114
- if qk_layer_norm:
115
- print('QK norm')
116
-
117
-
118
 
119
  def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
120
  if not self.custom:
@@ -124,185 +150,140 @@ class StreamingMultiheadAttention(nn.Module):
124
  if prefix + key in state_dict:
125
  state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
126
  super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
 
 
 
 
 
127
 
 
128
  def forward(self,
129
  query,
130
- key,
131
- value,
132
- key_padding_mask=None,
133
- need_weights=False,
134
- attn_mask=None,
135
- is_causal=False):
136
- # 2=cond/uncond
137
- # 24=heads
138
- # 1=seqlen
139
- # 64=channel
140
- #
141
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
142
- # 43
143
- # ____________
144
- # SELF
145
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
146
- # sa_ x.shape=torch.Size([2, 1, 1536])
147
-
148
- # X
149
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
150
- # 44
151
- # ____________
152
- # SELF
153
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
154
- # sa_ x.shape=torch.Size([2, 1, 1536])
155
-
156
- # X
157
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
158
- # 45
159
- # ____________
160
- # SELF
161
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
162
- # sa_ x.shape=torch.Size([2, 1, 1536])
163
-
164
- # X
165
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
166
- # 46
167
- # ____________
168
- # SELF
169
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
170
- # sa_ x.shape=torch.Size([2, 1, 1536])
171
-
172
- # X
173
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 7, 64]) v.shape=torch.Size([2, 24, 7, 64]) CROSSSattn
174
- # 47
175
- # ____________
176
- # SELF
177
- # q.shape=torch.Size([2, 24, 1, 64]) k.shape=torch.Size([2, 24, 25, 64]) v.shape=torch.Size([2, 24, 25, 64]) CROSSSattn
178
- # sa_ x.shape=torch.Size([2, 1, 1536])
179
-
180
-
181
-
182
-
183
- assert not is_causal, ("New param added in torch 2.0.1 not supported, "
184
- "use the causal args in the constructor.")
185
- # print(f'{query.shape=} {key.shape=} {value.shape=} MHA')
186
- time_dim = 2
187
- if time_dim == 2:
188
- layout = "b h t d"
189
- else:
190
- layout = "b t h d"
191
- dtype = query.dtype
192
 
193
 
194
- custom_attn_mask = attn_mask is not None
195
 
196
  if self.custom:
197
-
198
-
199
 
200
  if self.cross_attention:
201
- # print('\n\n\n\nCROSS\n\n\n\n')
202
-
203
-
204
  dim = self.in_proj_weight.shape[0] // 3
205
  if self.in_proj_bias is None:
206
  bias_q, bias_k, bias_v = None, None, None
207
  else:
208
- print('no self proj bi')
209
-
 
210
  q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
211
- # print(f'{q.shape=} TRANSF FORW who concaten')
212
  # todo: when streaming, we could actually save k, v and check the shape actually match.
213
  k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
214
  v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
215
- if self.qk_layer_norm is True:
216
- q = self.q_layer_norm(q)
217
- k = self.k_layer_norm(k)
218
 
219
  q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
220
-
221
  else:
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
224
  if self.kv_repeat == 1:
225
- if time_dim == 2:
226
- bound_layout = "b h p t d"
227
- else:
228
- bound_layout = "b t p h d"
229
-
230
  packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
231
-
232
-
233
- # print(f'{query.shape=} before unbind') # [2, 1, 4 , 2048] already bs=2
234
  q, k, v = ops.unbind(packed, dim=2)
235
- # print(f'{q.shape=} {v.shape=} @L331 trasnforemr.py') # packed is bs=2
236
- else:
237
-
238
- print("ELSE kv rp")
239
 
240
- if self.qk_layer_norm is True:
241
-
242
- print('QL lay norm')
243
 
244
- if self.kv_repeat > 1:
 
245
 
246
- print('Expand repear 2')
247
-
248
- if self.attention_as_float32:
249
- print('AS FLOAT32')
 
 
 
250
 
 
 
 
 
 
251
  if self.memory_efficient:
252
- if custom_attn_mask:
253
-
254
- print('CUSTOM ATTN MSK')
255
-
256
 
257
  p = self.dropout if self.training else 0
258
  if _efficient_attention_backend == 'torch':
259
-
260
- # print(f'{q.shape=} {k.shape=} {v.shape=} 90')
261
- print(f'{x.sum()=} {q.sum()=} {k.sum()=} {v.sum()=} 90 variation of qkv during 47')
262
- # the k.sum(),v.sum() changes over the 47transfs how is that possible if self._sa
263
- # has q-len = 1.
264
- #
265
- #
266
-
267
  x = torch.nn.functional.scaled_dot_product_attention(
268
- q, k, v, is_causal=attn_mask is not None, dropout_p=p)
269
- else:
270
-
271
- print('MHA OPS')
272
-
273
-
274
- else:
275
-
276
- print('CONSISTENCY ')
277
-
278
-
279
-
280
-
281
- x = x.to(dtype)
282
  x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
283
  x = self.out_proj(x)
284
- else:
285
- raise NotImplementedError
286
 
287
- return x, None
288
 
 
 
289
 
290
- class StreamingTransformerLayer(nn.TransformerEncoderLayer):
291
- def __init__(self,
292
- d_model,
293
- num_heads,
294
- dim_feedforward=2048,
295
- dropout=0.1,
296
- bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
297
- past_context: tp.Optional[int] = None, custom: bool = False,
298
- memory_efficient: bool = False, attention_as_float32: bool = False,
299
- qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
300
- cross_attention: bool = False,
301
- # rope=None,
302
  attention_dropout: tp.Optional[float] = None,
303
- kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
304
- super().__init__(d_model, num_heads, dim_feedforward, dropout,
305
- device=device, dtype=dtype, batch_first=True, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
306
  factory_kwargs = {'device': device, 'dtype': dtype}
307
  # Redefine self_attn to our streaming multi-head attention
308
  attn_kwargs: tp.Dict[str, tp.Any] = {
@@ -314,123 +295,84 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
314
  'memory_efficient': memory_efficient,
315
  'attention_as_float32': attention_as_float32,
316
  }
317
- self.self_attn=StreamingMultiheadAttention(
318
- causal=causal,
319
- past_context=past_context,
320
- # rope=rope,
321
- qk_layer_norm=qk_layer_norm,
322
- kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
323
  # Redefine feedforward layers to expose bias parameter
324
  self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
325
  self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
 
326
 
327
-
328
 
329
- self.cross_attention = None # default
330
  if cross_attention:
331
  self.cross_attention = StreamingMultiheadAttention(
332
- cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
333
- **attn_kwargs, **factory_kwargs)
334
- # Norm and dropout
 
335
  self.dropout_cross = nn.Dropout(dropout)
336
- # eps value matching that used in PyTorch reference implementation.
337
- self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
338
 
 
339
  self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
340
  self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
 
 
 
 
 
 
341
 
342
- # ENVS....d4/lib/python3.10/site-packages/torch/nn/modules/transformer.py @TransformerEncoderLayer
343
- def _sa_block(self, q, k, v):
344
- x = self.self_attn(q,
345
- k,
346
- v,
347
- attn_mask=None,
348
- key_padding_mask=None,
349
- need_weights=False,
350
- is_causal=None)[0]
351
- return self.dropout1(x)
352
-
353
- def _cross_attention_block(self,
354
- src,
355
- cross_attention_src):
356
 
357
- # queries are from src, keys and values from cross_attention_src.
358
- x = self.cross_attention(
359
- src, cross_attention_src, cross_attention_src, need_weights=False)[0]
360
- return self.dropout_cross(x) # type: ignore
361
-
362
- def forward(self,
363
- src,
364
- src_mask=None,
365
- src_key_padding_mask=None, # key = value = looooong I think I pass them inversed
366
- cross_attention_src=None):
367
-
368
-
369
 
370
- if self.norm_first:
371
- print('selfattn')
372
- history = self.norm1(src)
373
- x = history[:, -1:, :]
374
-
375
- # THIS IS COMPUTED with 1 timestep
376
- # just before the call there is cat([past_k, k])
377
- # Thus we just
378
- x = x + self._sa_block(x, # THIS should be square as the history is updated
379
- # then the -1 item of history goes to the text x text
380
- #
381
- history,
382
- history)
383
- print('crossattn')
384
- if cross_attention_src is not None:
385
- x = x + self._cross_attention_block(
386
- self.norm_cross(x),
387
- cross_attention_src)
388
-
389
- else:
390
- print('NOT IMPL')
391
-
392
-
393
- x = x + self._ff_block(self.norm2(x))
394
- else:
395
- print('NLAST')
396
 
 
397
  return x
398
-
399
-
400
 
401
 
402
  class StreamingTransformer(nn.Module):
403
- '''layer_class=<class 'audiocraft.transformer.StreamingTransformerLayer'> StrTrnsf'''
404
-
405
- def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
406
- dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
407
- causal: bool = False, past_context: tp.Optional[int] = None,
408
- custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
 
 
 
 
 
409
  cross_attention: bool = False,
410
- positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
411
- xpos=False,
412
- lr=None,
413
- weight_decay=None,
414
  layer_class=StreamingTransformerLayer,
415
- checkpointing='none',
416
  device=None,
417
- dtype=None,
418
- **kwargs):
419
  super().__init__()
420
  assert d_model % num_heads == 0
421
-
422
  self.positional_embedding = positional_embedding
423
  self.max_period = max_period
424
  self.positional_scale = positional_scale
425
- self.weight_decay = weight_decay
426
- self.lr = lr
 
 
427
 
428
- assert positional_embedding in ['sin', 'rope', 'sin_rope']
429
  self.checkpointing = checkpointing
430
 
431
- assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
432
- if self.checkpointing.startswith('xformers'):
433
- _verify_xformers_internal_compat()
434
 
435
  self.layers = nn.ModuleList()
436
  for idx in range(num_layers):
@@ -438,90 +380,35 @@ class StreamingTransformer(nn.Module):
438
  layer_class(
439
  d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
440
  dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
441
- causal=causal, past_context=past_context, custom=custom,
442
  memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
443
- cross_attention=cross_attention,
444
- # rope=self.rope,
445
  device=device, dtype=dtype, **kwargs))
446
 
447
  if self.checkpointing != 'none':
448
- print('Checkpointing????????????')
449
  for layer in self.layers:
450
  # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
451
  # backward hook inside of FSDP...
452
  layer._magma_checkpointed = True # type: ignore
453
 
454
-
455
 
456
  def forward(self, x: torch.Tensor, *args, **kwargs):
457
- # print(f'{x.shape=} StreamingTransf') # [1, 1, 1536] Always no batch==2 here
458
- # why is this called with time-len = 1? Shouldnt be called with context?
459
- B, T, C = x.shape
460
-
461
 
 
462
 
463
 
464
- if self.positional_embedding in ['sin',
465
- 'sin_rope']:
466
- positions = torch.arange(T, device=x.device).view(1, -1, 1)
467
 
 
 
468
  pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
469
  x = x + self.positional_scale * pos_emb
470
-
471
-
472
-
473
- # 47x transformer layers for frozen history
474
- # -> history is updated by self._sa() althought her length is fixed
475
- # -> the q that comes out of the text x text cross attn
476
- # is given as q to the next lay's self._sa() with updated history
477
- # ->
478
- # ->
479
- for _, lay in enumerate(self.layers):
480
- print(f'_________________\n{_}')
481
- # 1 q = last_token x history x history
482
- # 2 next_token = q x text x text
483
 
484
- # x preserves full history for self._sa(). After all transformers we return only last -1 tok
485
- x, history = lay(
486
- x,
487
- history=history, # only updated by self_attn (the cross sees only last token)
488
- cross_attention_src=kwargs["cross_attention_src"],
489
- src_mask=kwargs['src_mask']
490
- ) # x : [bs, 24, 37, 64]
491
- return x
492
-
493
-
494
-
495
-
496
- # special attention related function
497
-
498
- def _verify_xformers_memory_efficient_compat():
499
- try:
500
- from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa
501
- except ImportError:
502
- raise ImportError(
503
- "xformers is not installed. Please install it and try again.\n"
504
- "To install on AWS and Azure, run \n"
505
- "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
506
- "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n"
507
- "To install on FAIR Cluster, run \n"
508
- "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
509
- "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n")
510
-
511
-
512
- def _verify_xformers_internal_compat():
513
- try:
514
- from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa
515
- except ImportError:
516
- raise ImportError(
517
- "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
518
- "To install on AWS and Azure, run \n"
519
- "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
520
- "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n"
521
- "To install on FAIR Cluster, run \n"
522
- "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
523
- "pip install -U git+https://[email protected]/fairinternal/xformers.git#egg=xformers\n")
524
-
525
 
526
- def _is_custom(custom: bool, memory_efficient: bool):
527
- return custom or memory_efficient
 
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
6
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
7
  from xformers import ops
8
 
9
+
10
  _efficient_attention_backend: str = 'torch'
11
 
12
 
 
 
 
 
 
13
 
14
 
15
 
16
+ def _get_attention_time_dimension(memory_efficient: bool) -> int:
17
+ if _efficient_attention_backend == 'torch' and memory_efficient:
18
+ return 2
19
+ else:
20
+ return 1
21
 
22
 
23
 
24
 
25
 
26
+ def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
27
+ """Create normalization module for transformer encoder layer.
28
 
29
+ Args:
30
+ norm_type (str): Normalization method.
31
+ dim (int): Dimension of the normalized layer.
32
+ **kwargs (dict): Additional parameters for normalization layer.
33
+ Returns:
34
+ nn.Module: Normalization module.
35
+ """
36
  if norm_type == 'layer_norm':
37
  return nn.LayerNorm(dim, eps=1e-5, **kwargs)
38
  else:
 
58
  adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
59
  max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
60
  phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
61
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
 
 
62
 
63
 
64
+ def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
65
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
66
+ if n_rep == 1:
67
+ return x
68
+ if _efficient_attention_backend == 'torch' and memory_efficient:
69
+ bs, n_kv_heads, slen, head_dim = x.shape
70
+ return (
71
+ x[:, :, None, :, :]
72
+ .expand(bs, n_kv_heads, n_rep, slen, head_dim)
73
+ .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
74
+ )
75
+ else:
76
+ bs, slen, n_kv_heads, head_dim = x.shape
77
+ return (
78
+ x[:, :, :, None, :]
79
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
80
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
81
+ )
82
 
83
 
84
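
The new expand_repeated_kv helper mirrors torch.repeat_interleave over the kv-head dimension (it is not yet wired into forward() in this commit). A quick standalone check of the 'torch'-backend branch, assuming the [bs, n_kv_heads, slen, head_dim] layout:

import torch

bs, n_kv_heads, slen, head_dim, n_rep = 2, 4, 5, 8, 3
x = torch.randn(bs, n_kv_heads, slen, head_dim)
expanded = (
    x[:, :, None, :, :]
    .expand(bs, n_kv_heads, n_rep, slen, head_dim)
    .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
)
reference = torch.repeat_interleave(x, repeats=n_rep, dim=1)   # same head ordering
assert torch.equal(expanded, reference)
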
 
 
88
 
89
  def __init__(self,
90
  embed_dim,
91
+ num_heads, dropout: float = 0.0, bias: bool = True,
92
+ causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
93
+ memory_efficient: bool = False, attention_as_float32: bool = False,
 
 
 
 
 
94
  cross_attention: bool = False,
 
95
  kv_repeat: int = 1,
96
  device=None, dtype=None):
97
  super().__init__()
98
  factory_kwargs = {'device': device, 'dtype': dtype}
99
  if past_context is not None:
100
  assert causal
101
+
102
  self.embed_dim = embed_dim
103
+
104
+ self.k_history = None # k from the previous tokens seen in the current generation - only for self_attn
105
+ self.v_history = None # cleared in the LM after generation finishes - each of the 47 MHA layers keeps its own kv history
106
+
107
  self.memory_efficient = memory_efficient
108
  self.attention_as_float32 = attention_as_float32
109
+
110
  self.cross_attention = cross_attention
111
+
112
  self.num_heads = num_heads
113
  self.dropout = dropout
114
  self.kv_repeat = kv_repeat
115
+
116
+
117
+
118
+
119
+ self.custom = True #_is_custom(custom, memory_efficient)
120
+ if not self.custom:
121
+ print(f'{self.custom}')
122
  if self.custom:
123
  out_dim = embed_dim
124
  assert num_heads % kv_repeat == 0
 
136
  if bias:
137
  self.out_proj.bias.data.zero_()
138
  else:
139
+ assert kv_repeat == 1
140
+ self.mha = nn.MultiheadAttention(
141
+ embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
142
+ **factory_kwargs)
143
+
 
144
 
145
  def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
146
  if not self.custom:
 
150
  if prefix + key in state_dict:
151
  state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
152
  super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
153
+
154
+
155
+
156
+
157
+
158
 
159
+
160
  def forward(self,
161
  query,
162
+ key=None, # key/value are ignored unless self.cross_attention is set
163
+ value=None):
164
+
165
+
166
+ # time_dim = _get_attention_time_dimension(self.memory_efficient)
167
+ # if time_dim == 2:
168
+ layout = "b h t d"
169
+ # else:
170
+ # layout = "b t h d"
171
+ # dtype = query.dtype
172
+
173
+
174
 
175
 
176
+
177
 
178
  if self.custom:
 
 
179
 
180
  if self.cross_attention:
181
+ # Different queries, keys, values: we have to split the weights manually
182
+ # before applying the linear.
 
183
  dim = self.in_proj_weight.shape[0] // 3
184
  if self.in_proj_bias is None:
185
  bias_q, bias_k, bias_v = None, None, None
186
  else:
187
+ bias_q = self.in_proj_bias[:dim]
188
+ bias_k = self.in_proj_bias[dim: 2 * dim]
189
+ bias_v = self.in_proj_bias[2 * dim:]
190
  q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
 
191
  # todo: when streaming, we could actually save k, v and check the shape actually match.
192
  k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
193
  v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
 
 
 
194
 
195
  q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
 
196
  else:
197
 
198
+ # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
199
+ if self.k_history is not None:
200
+ #
201
+ # pk.shape=torch.Size([2, 24, 3, 64]) k.shape=torch.Size([2, 24, 1, 64]) CONCAT
202
+ # has to be 4D; batch 1 due to the single condition, 3 = seqlen so far,
203
+ # 24 heads, 64 = head dim
204
+ self.k_history = torch.cat([self.k_history, query], 2)
205
+ self.v_history = torch.cat([self.v_history, query], 2)
206
+ else:
207
+ # init on 1st token (for all 47 transf layers)
208
+ self.k_history = query
209
+ self.v_history = query
210
+
211
+
212
  projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
213
  if self.kv_repeat == 1:
214
+ # if time_dim == 2:
215
+ bound_layout = "b h p t d"
216
+ # else:
217
+ # bound_layout = "b t p h d"
 
218
  packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
 
 
 
219
  q, k, v = ops.unbind(packed, dim=2)
220
+
221
+
 
 
222
 
 
 
 
223
 
224
+ # KV COMPLETION ONLY ON SELF ATTENTION
225
+ #======================================================
226
 
227
+ # this layer's k,v history already holds the concatenation of all previous tokens
228
+ #
229
+ # the history stays on the layer's self_attn module and is reused for the next token
230
+ #
231
+ # do not clean it up after each transformer pass - it persists along the generated tokens
232
+ #
233
+ # (open question: why did the old _complete_kv cache not grow across the 47 layers while its sum changed?)
234
 
235
+ # k, v = self._complete_kv(k, v)
236
+ # print(k.sum(), v.sum(), k.shape, v.shape,'ATTNext')
237
+
238
+ if self.attention_as_float32:
239
+ q, k, v = [x.float() for x in [q, k, v]]
240
  if self.memory_efficient:
241
+ # print('EVER IN MEMORY EFFICIENT A')
242
+
 
 
243
 
244
  p = self.dropout if self.training else 0
245
  if _efficient_attention_backend == 'torch':
246
+ # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(), 'CROSSopen')
 
 
 
 
 
 
 
247
  x = torch.nn.functional.scaled_dot_product_attention(
248
+ q, k, v, is_causal=False, dropout_p=p
249
+ )
250
+
251
+ x = x.to(q.dtype)
 
 
 
 
 
 
 
 
 
 
252
  x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
253
  x = self.out_proj(x)
254
+ return x
 
255
 
 
256
 
257
+ class StreamingTransformerLayer(nn.Module): #nn.TransformerEncoderLayer):
258
+ # uses StreamingMultiheadAttention directly; no longer inherits from nn.TransformerEncoderLayer
259
 
260
+ def __init__(self,
261
+ d_model: int,
262
+ num_heads: int,
263
+ dim_feedforward: int = 2048,
264
+ dropout: float = 0.1,
265
+ bias_ff: bool = True,
266
+ bias_attn: bool = True,
267
+ custom: bool = False,
268
+ memory_efficient: bool = False,
269
+ attention_as_float32: bool = False,
270
+ cross_attention: bool = False,
 
271
  attention_dropout: tp.Optional[float] = None,
272
+ kv_repeat: int = 1,
273
+ norm: str = 'layer_norm',
274
+ device=None,
275
+ dtype=None,
276
+ **kwargs):
277
+
278
+
279
+ super().__init__() #d_model, num_heads, dim_feedforward, dropout,
280
+ #device=device, dtype=dtype, batch_first=True, **kwargs)
281
+ # print(kwargs['activation'], 'ACTIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII\n\n\n\n')
282
+ # -- EN Layer
283
+ # inherits nothing from nn.TransformerEncoderLayer any more (previously only the _sa_block helper was reused)
284
+
285
+ # -- EN layer
286
+
287
  factory_kwargs = {'device': device, 'dtype': dtype}
288
  # Redefine self_attn to our streaming multi-head attention
289
  attn_kwargs: tp.Dict[str, tp.Any] = {
 
295
  'memory_efficient': memory_efficient,
296
  'attention_as_float32': attention_as_float32,
297
  }
298
+ self.self_attn = StreamingMultiheadAttention(
299
+ kv_repeat=kv_repeat,
300
+ **attn_kwargs,
301
+ **factory_kwargs) # type: ignore
 
 
302
  # Redefine feedforward layers to expose bias parameter
303
  self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
304
  self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
305
+ # print('LAYER scale', layer_scale, '\n\n\n\n\n\n\n\n\n') # always
306
 
 
307
 
308
+ self.cross_attention= None
309
  if cross_attention:
310
  self.cross_attention = StreamingMultiheadAttention(
311
+ cross_attention=True,
312
+ **attn_kwargs,
313
+ **factory_kwargs)
314
+
315
  self.dropout_cross = nn.Dropout(dropout)
 
 
316
 
317
+ self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
318
  self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
319
  self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
320
+
321
+
322
+ def forward(self,
323
+ src,
324
+ cross_attention_src=None): # txtcond
325
+ '''T layer'''
326
 
327
+ x = src
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
+ x = x + self.self_attn(self.norm1(x))
 
 
 
 
 
 
 
 
 
 
 
330
 
331
+ if cross_attention_src is not None:
332
+ x = x + self.cross_attention(
333
+ query = self.norm_cross(x),
334
+ key = cross_attention_src,
335
+ value = cross_attention_src) # txtcondition
336
 
337
+ x = x + self.linear2(F.gelu(self.linear1( self.norm2(x) )))
338
  return x
 
 
339
 
340
 
341
  class StreamingTransformer(nn.Module):
342
+
343
+ def __init__(self, d_model: int,
344
+ num_heads: int,
345
+ num_layers: int,
346
+ dim_feedforward: int = 2048,
347
+ dropout: float = 0.1,
348
+ bias_ff: bool = True,
349
+ bias_attn: bool = True,
350
+ custom: bool = False,
351
+ memory_efficient: bool = False,
352
+ attention_as_float32: bool = False,
353
  cross_attention: bool = False,
354
+ positional_embedding: str = 'sin',
355
+ max_period: float = 10_000,
356
+ positional_scale: float = 1,
 
357
  layer_class=StreamingTransformerLayer,
358
+ checkpointing: str = 'none',
359
  device=None,
360
+ dtype=None, **kwargs):
 
361
  super().__init__()
362
  assert d_model % num_heads == 0
363
+
364
  self.positional_embedding = positional_embedding
365
  self.max_period = max_period
366
  self.positional_scale = positional_scale
367
+
368
+
369
+
370
+ # self._stream_off = 0 # the llm should reinitialize this at ery generate()
371
 
 
372
  self.checkpointing = checkpointing
373
 
374
+
375
+
 
376
 
377
  self.layers = nn.ModuleList()
378
  for idx in range(num_layers):
 
380
  layer_class(
381
  d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
382
  dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
383
+ custom=custom,
384
  memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
385
+ cross_attention=cross_attention,
 
386
  device=device, dtype=dtype, **kwargs))
387
 
388
  if self.checkpointing != 'none':
 
389
  for layer in self.layers:
390
  # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
391
  # backward hook inside of FSDP...
392
  layer._magma_checkpointed = True # type: ignore
393
 
394
+
395
 
396
  def forward(self, x: torch.Tensor, *args, **kwargs):
 
 
 
 
397
 
398
+ B, T, C = x.shape
399
 
400
 
401
+ if self.positional_embedding in ['sin', 'sin_rope']:
 
 
402
 
403
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
404
+ positions = positions + kwargs['token_count'] #offsets.view(-1, 1, 1)
405
  pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
406
  x = x + self.positional_scale * pos_emb
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
+
409
 
410
+ for j, lay in enumerate(self.layers):
411
+ print(f'_________________________{j}___________________')
412
+ x = lay(x, cross_attention_src=kwargs["cross_attention_src"]) # txt cond
413
+ # each layer (mha) keeps history of its own k,v for all tokens
414
+ return x
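
Because each forward pass now sees a single new token, StreamingTransformer offsets the positions by token_count before building the sinusoidal embedding; otherwise every step would be embedded at position 0. A small illustration using the same create_sin_embedding formula as above:

import torch

def create_sin_embedding(positions, dim, max_period=10_000, dtype=torch.float32):
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

B, T, C = 1, 1, 1536                   # one new token per step
x = torch.randn(B, T, C)
token_count = 7                        # absolute index of the token being decoded
positions = torch.arange(T).view(1, -1, 1) + token_count
pos_emb = create_sin_embedding(positions, C, dtype=x.dtype)
x = x + 1.0 * pos_emb                  # positional_scale defaults to 1
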
demo.py CHANGED
@@ -4,10 +4,10 @@ import numpy as np
4
 
5
  print('\n\n\n\n___________________')
6
 
7
- txt = 'dogs in the street'
8
 
9
  sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
10
- sound_generator.set_generation_params(duration=.74) # why is generating so long at 14 seconds
11
 
12
  x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
13
  x /= np.abs(x).max() + 1e-7
 
4
 
5
  print('\n\n\n\n___________________')
6
 
7
+ txt = 'dogs barging in the street'
8
 
9
  sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
10
+ sound_generator.set_generation_params(duration=.46) # why is generating so long at 14 seconds
11
 
12
  x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
13
  x /= np.abs(x).max() + 1e-7
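
For completeness, one way to run the updated demo and write the result to disk (the AudioGen import path and the soundfile writer are assumptions, since demo.py's imports are not part of this diff; audiogen-medium outputs 16 kHz audio):

import numpy as np
import soundfile as sf                      # assumed writer; audiocraft's audio_write would also work
from audiocraft.models import AudioGen      # assumed import path

txt = 'dogs barging in the street'
sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
sound_generator.set_generation_params(duration=.46)

x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
x /= np.abs(x).max() + 1e-7
sf.write('out.wav', x, 16000)               # audiogen-medium sample rate
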