Dionyssos committed
Commit 6cb7713 · 1 Parent(s): 1766442

apply null cond lm.forw()

Files changed (4):
  1. audiocraft/builders.py +1 -1
  2. audiocraft/lm.py +54 -46
  3. audiocraft/transformer.py +10 -7
  4. demo.py +2 -1
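The heart of this commit is that LMModel.forward() now builds the null (unconditional) text branch itself: each query is duplicated onto a zeroed copy of the text embedding and the two logit streams are mixed with classifier-free guidance before sampling. A minimal sketch of that mixing rule, assuming the hard-coded 3 / -2 weights correspond to a guidance scale w = 2 (the function and variable names below are illustrative, not from the repo):

import torch

def cfg_combine(cond_logits, uncond_logits, w=2.0):
    # Classifier-free guidance: (1 + w) * conditional - w * unconditional.
    # With w = 2 this reproduces the `3 * logits[:bs] - 2 * logits[bs:]` line in lm.py.
    return (1.0 + w) * cond_logits - w * uncond_logits

# toy shapes matching the diff's comments: [bs=3, n_q=4, 1, card=2048]
cond = torch.randn(3, 4, 1, 2048)
uncond = torch.randn(3, 4, 1, 2048)
print(cfg_combine(cond, uncond).shape)  # torch.Size([3, 4, 1, 2048])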
audiocraft/builders.py CHANGED
@@ -79,7 +79,7 @@ class AudioGen(nn.Module):
                             conditions=attributes,
                             max_gen_len=int(self.duration * self.frame_rate))  # [bs, 4, 37 * self.lm.n_draw]
         x = self.compression_model.decode(gen_tokens, None)  # [bs, 1, 11840]
-        print('______________\nGENTOk 5', gen_tokens.shape)
+        print('______________\nGENTOk 5', gen_tokens)
         print('GENAUD 5', x.sum())
         return x

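The shape comments in this hunk are consistent with a 50 Hz token rate over 16 kHz audio; those two numbers are assumptions here, not stated in the diff, but the arithmetic below reproduces both the 37 token frames per draw and the [bs, 1, 11840] decode:

# Arithmetic behind the shape comments above; frame_rate and sample_rate are assumed values.
duration = 0.74
frame_rate = 50          # assumed token frame rate of the codec
sample_rate = 16_000     # assumed codec sample rate
max_gen_len = int(duration * frame_rate)               # 37 token frames per draw
n_samples = max_gen_len * (sample_rate // frame_rate)  # 37 * 320 = 11840
print(max_gen_len, n_samples)                          # 37 11840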
audiocraft/lm.py CHANGED
@@ -9,8 +9,13 @@ from dataclasses import dataclass
 from functools import partial
 from torch import nn
 from audiocraft.activations import get_activation_fn
+import numpy as np
 
-
+def _shift(x):
+    n = x.shape[0]
+    i = np.random.randint(.24 * n, max(1, .74 * n))  # high must be at least 1; TBD do we have very short segments
+    x = torch.roll(x, i, dims=2)
+    return x
 
 
 
@@ -141,7 +146,7 @@ class LMModel(nn.Module):
                  dim: int = 128,
                  num_heads: int = 8,
                  hidden_scale: int = 4,
-                 norm: str = 'layer_norm',
+                 norm: str = 'layer_norm',
                  norm_first: bool = False,
                  emb_lr: tp.Optional[float] = None,
                  bias_proj: bool = True,
@@ -155,7 +160,7 @@ class LMModel(nn.Module):
         self.cfg_coef = cfg_coef
         self.condition_provider = condition_provider
         self.card = card  # 2048 ?
-        self.n_draw = 8   # replicate the generation of each text in the batch this many times
+        self.n_draw = 14  # replicate the generation of each text in the batch this many times
         embed_dim = self.card + 1
         self.n_q = n_q
         self.dim = dim
@@ -233,71 +238,73 @@ class LMModel(nn.Module):
     def special_token_id(self) -> int:
         return self.card
 
-    def sample_top_k(self, p, k=249):
-        bs, _, _, hidden = p.shape  # logits [3, 4, 1, 2048]
-
-        p = torch.softmax(p, dim=3)
-        top_k_value, i250 = torch.topk(p, k, dim=3)  # [3, 4, 1, k]
-        min_value_top_k = top_k_value[:, :, :, -1:]
-        p *= (p >= min_value_top_k).float()  # zero low probs
-        p.div_(p.sum(dim=-1, keepdim=True))  # renormalise on non-zero probs
-
-        # BRING THE n_q = 4 INTO THE BATCH
-        p = p.reshape(bs * self.n_q, hidden)
-        out = torch.multinomial(p,  # p=[bs, 2048], out=[bs, num_samples]
-                                num_samples=self.n_draw,
-                                replacement=False)  # [bs*4, self.n_draw]
-        return out.reshape(bs, self.n_q, self.n_draw).transpose(1, 2)  # [bs, self.n_draw, 4]
 
     def forward(self,
                 sequence,
                 condition_tensors=None,
                 token_count=None):
+        # takes bs=3, duplicates it onto the null condition to bs=6, splits the logits for CFG, returns bs=3
+
+        bs, _, _ = sequence.shape  # sequence [bs, n_draw, 4]
 
         input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
-        out = self.transformer(input_,
-                               cross_attention_src=condition_tensors['description'][0],
+        out = self.transformer(torch.cat([input_, input_], 0),
+                               cross_attention_src=condition_tensors,
                                token_count=token_count)
         if self.out_norm:
             out = self.out_norm(out)
 
-        logits = torch.stack([self.linears[k](out) for k in range(self.n_q)], dim=1)
-
-        return logits  # [bs, 4, 1, 2048]
+        logits = torch.stack([self.linears[k](out) for k in range(self.n_q)], dim=1)  # [2*bs, 4, 1, 2048]
+
+        logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :]  # [3, 4, 1, 2048]
+
+        # SAMPLE TOP K
+        k = 250
+        p = torch.softmax(logits, dim=3)
+        top_k_value, _ = torch.topk(p, k, dim=3)  # [3, 4, 1, k]
+        min_value_top_k = top_k_value[:, :, :, -1:]
+        p *= (p >= min_value_top_k).float()  # zero low probs
+        p.div_(p.sum(dim=-1, keepdim=True))  # renormalise on non-zero probs
+
+        # BRING THE n_q = 4 INTO THE BATCH
+        p = p.reshape(bs * self.n_q, 2048)
+        out = torch.multinomial(p,  # p=[bs, 2048], out=[bs, num_samples]
+                                num_samples=self.n_draw,
+                                replacement=True)  # [bs*4, self.n_draw]
+        return out.reshape(bs, self.n_q, self.n_draw).transpose(1, 2)  # [bs=3 not 6, self.n_draw, 4]
 
-    # GENERATE class revert_codebook_patterns()
     @torch.no_grad()
-    def generate(self,
-                 prompt=None,
-                 conditions=[],
+    def generate(self, conditions=[],
                  max_gen_len=256):
 
-        print(f'{prompt=} {conditions=}')
-        first_param = next(iter(self.parameters()))
-        device = first_param.device
-
         tokenized = self.condition_provider.tokenize(conditions)
 
-        # print(f'TOKENIZ, {tokenized.keys()=}, {tokenized=}')  # 'description'
-        # TOKENIZ {'description': {'input_ids': tensor([[3887, 16, 2815, 1],
-        #                                               [3887, 16, 2815, 1]], device='cuda:0'),
-        #          'attention_mask': tensor([[1, 1, 1, 1],
-        #                                    [1, 1, 1, 1]], device='cuda:0')}}
 
         cfg_conditions = self.condition_provider(tokenized)
 
-        # print(f'CFGcon, {cfg_conditions.keys()=}, {cfg_conditions["description"][0].shape=}')
-        # USE THIS ATTENTION MASK IF NOT SAME LEN;
-        bs, _7, _1536 = cfg_conditions['description'][0].shape  # [bs, textlen, 1536]
+        # NULL CONDITION
+        text_condition = cfg_conditions['description'][0]
+        bs, _, _ = text_condition.shape
+        text_condition = torch.cat(
+            [
+                text_condition,
+                torch.zeros_like(text_condition)
+            ], 0)
+
         pattern = self.pattern_provider.get_pattern(max_gen_len)
         gen_codes = torch.full((bs,
                                 self.n_q,
-                                max_gen_len), -1, dtype=torch.long, device=device)
-
+                                max_gen_len), -1, dtype=torch.long,
+                               device=text_condition.device)
+
         gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
         _, _, audiodur = gen_sequence.shape  # bs, 4, 7=audiodur
 
@@ -317,13 +324,13 @@ class LMModel(nn.Module):
 
         for offset in range(1, audiodur):
 
-            # pass only the 0-th draw through forward
-            logits = self.forward(gen_sequence[:, 0, :, offset-1:offset],
-                                  condition_tensors=cfg_conditions,
-                                  token_count=offset)  # [bs, 4, 1, 2048]
-
-            next_token = self.sample_top_k(logits)  # [bs, n_draw, 4]
+            # forward duplicates the query onto the null condition, applies CFG, and returns a de-duplicated token
+            next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],
+                                      condition_tensors=text_condition,
+                                      token_count=offset-1)  # [bs, n_draw, 4]
 
             # MASK is not all ones - it carries the 4 x audioduration PATTERN
             m = mask[:, :, :, offset]
@@ -346,6 +353,7 @@ class LMModel(nn.Module):
         out_codes = out_codes.reshape(bs, self.n_draw, 4, new_len)
         out_codes = out_codes.transpose(1, 2).reshape(bs, 4, self.n_draw * new_len)
         print(out_codes.shape, 'o')
+        out_codes = _shift(out_codes)
 
         # Clear the Transformer k/v history (a different history is kept by each of the 48 self-attn layers)
         for lay in self.transformer.layers:
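The top-k sampling that used to live in sample_top_k() is now inlined at the end of forward(), after the guidance mix. A self-contained sketch of that step, drawing n_draw samples per sequence with the n_q codebooks folded into the batch (the names below are illustrative):

import torch

def sample_top_k_multi(logits, k=250, n_draw=14):
    # logits: [bs, n_q, 1, card] after CFG; returns [bs, n_draw, n_q] token ids.
    bs, n_q, _, card = logits.shape
    p = torch.softmax(logits, dim=3)
    top_k_value, _ = torch.topk(p, k, dim=3)               # k-th largest prob per codebook
    min_value_top_k = top_k_value[:, :, :, -1:]
    p = p * (p >= min_value_top_k).float()                 # zero everything below the k-th prob
    p = p / p.sum(dim=-1, keepdim=True)                    # renormalise the survivors
    p = p.reshape(bs * n_q, card)                          # fold the n_q codebooks into the batch
    out = torch.multinomial(p, num_samples=n_draw,
                            replacement=True)              # [bs * n_q, n_draw]
    return out.reshape(bs, n_q, n_draw).transpose(1, 2)    # [bs, n_draw, n_q]

tokens = sample_top_k_multi(torch.randn(3, 4, 1, 2048))
print(tokens.shape)  # torch.Size([3, 14, 4])

Note that replacement changed from False to True, so the n_draw draws per codebook are now independent, which matters with n_draw raised from 8 to 14.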
audiocraft/transformer.py CHANGED
@@ -175,6 +175,7 @@ class StreamingMultiheadAttention(nn.Module):
             v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
 
             q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
+            print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(), 'CROSS A5')
         else:
             # 1st projection makes k, v (instantaneous)
             # 2nd cat
@@ -190,6 +191,7 @@ class StreamingMultiheadAttention(nn.Module):
             # bound_layout = "b t p h d"
             packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
             q, k, v = ops.unbind(packed, dim=2)
+
 
             if self.k_history is not None:
                 #
@@ -198,8 +200,10 @@ class StreamingMultiheadAttention(nn.Module):
                 # 24 heads, 64 dims per head
                 self.k_history = torch.cat([self.k_history, k], 2)
                 self.v_history = torch.cat([self.v_history, v], 2)
+
             else:
                 # init on the 1st token (for all 47 transformer layers)
+                print(f'else skip')
                 self.k_history = k
                 self.v_history = v
 
@@ -209,7 +213,7 @@ class StreamingMultiheadAttention(nn.Module):
 
 
             # KV COMPLETION ONLY ON SELF ATTENTION
-
+            print('KV5', self.k_history.sum(), self.v_history.sum(), self.k_history.shape, self.v_history.shape)
 
 
         if self.memory_efficient:
@@ -327,18 +331,17 @@ class StreamingTransformer(nn.Module):
                  attention_as_float32: bool = False,
                  cross_attention: bool = False,
                  positional_embedding: str = 'sin',
-                 max_period: float = 10_000,
-                 positional_scale: float = 1,
+                 max_period: float = 10_000,
                  layer_class=StreamingTransformerLayer,
                  checkpointing: str = 'none',
                  device=None,
-                 dtype=None, **kwargs):
+                 dtype=None,
+                 **kwargs):
         super().__init__()
         assert d_model % num_heads == 0
 
         self.positional_embedding = positional_embedding
         self.max_period = max_period
-        self.positional_scale = positional_scale
 
 
 
@@ -378,12 +381,12 @@ class StreamingTransformer(nn.Module):
             positions = torch.arange(T, device=x.device).view(1, -1, 1)
             positions = positions + kwargs['token_count']  # offsets.view(-1, 1, 1)
             pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
-            x = x + self.positional_scale * pos_emb
+            x = x + pos_emb
 
 
 
         for j, lay in enumerate(self.layers):
-            # print(f'_________________________{j}___________________')
+            print(f'5_________________________{j} {pos_emb.sum()=} {pos_emb.shape=} {x.shape=}___________________')
             x = lay(x, cross_attention_src=kwargs["cross_attention_src"])  # txt cond
             # each layer (mha) keeps a history of its own k, v for all tokens
         return x
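With positional_scale removed, the positional embedding is added unscaled, and positions are offset by token_count so the single streamed token receives its absolute index. A minimal stand-alone sketch of the sinusoidal embedding assumed here (the repo's own create_sin_embedding is the authoritative version):

import torch

def sin_embedding(positions, dim, max_period=10_000.0):
    # positions: [B, T, 1] absolute token indices (already offset by token_count).
    # Returns [B, T, dim] cos/sin embeddings, the usual transformer scheme.
    assert dim % 2 == 0
    half = dim // 2
    adim = torch.arange(half, device=positions.device).view(1, 1, -1)
    phase = positions / (max_period ** (adim / (half - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

# streaming usage: one token is fed per step, so its position is the running token_count
token_count, T, C = 5, 1, 1536
positions = (torch.arange(T).view(1, -1, 1) + token_count).float()
pos_emb = sin_embedding(positions, C)
print(pos_emb.shape)  # torch.Size([1, 1, 1536])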
demo.py CHANGED
@@ -1,7 +1,8 @@
 import audiofile
 import numpy as np
 from audiocraft import AudioGen
-text_list = ['dogs barging in the street', 'people po']
+text_list = ['dogs barging in the street',
+             'music']
 
 sound_generator = AudioGen(duration=.74,
                            device='cuda:0').to('cuda:0').eval()
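The diff ends here and the rest of demo.py is not shown. A hedged sketch of how the script might continue, assuming AudioGen exposes a generate(text_list) method returning a [bs, 1, samples] waveform (as the builders.py comments suggest) and a 16 kHz output rate; neither assumption is confirmed by this diff:

import audiofile
import numpy as np
from audiocraft import AudioGen

text_list = ['dogs barging in the street',
             'music']

sound_generator = AudioGen(duration=.74,
                           device='cuda:0').to('cuda:0').eval()

x = sound_generator.generate(text_list)   # hypothetical entry point, not shown in the diff
x = x.detach().cpu().numpy()[0, 0]        # first text, mono channel
audiofile.write('out.wav', np.atleast_2d(x), 16000)  # assumed 16 kHz sample rate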