Ailyth commited on
Commit
516fd45
·
1 Parent(s): 9a035cf

0308-022448-Synchronize_GitHub_update_improve_inference_speed

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. AR/__pycache__/__init__.cpython-310.pyc +0 -0
  2. AR/data/bucket_sampler.py +2 -1
  3. AR/data/data_module.py +4 -2
  4. AR/data/dataset.py +2 -1
  5. AR/models/__pycache__/__init__.cpython-310.pyc +0 -0
  6. AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc +0 -0
  7. AR/models/__pycache__/t2s_model.cpython-310.pyc +0 -0
  8. AR/models/__pycache__/utils.cpython-310.pyc +0 -0
  9. AR/models/t2s_lightning_module.py +4 -3
  10. AR/models/t2s_lightning_module_onnx.py +2 -1
  11. AR/models/t2s_model.py +165 -44
  12. AR/models/t2s_model_onnx.py +2 -1
  13. AR/models/utils.py +72 -3
  14. AR/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  15. AR/modules/__pycache__/activation.cpython-310.pyc +0 -0
  16. AR/modules/__pycache__/embedding.cpython-310.pyc +0 -0
  17. AR/modules/__pycache__/lr_schedulers.cpython-310.pyc +0 -0
  18. AR/modules/__pycache__/optim.cpython-310.pyc +0 -0
  19. AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc +0 -0
  20. AR/modules/__pycache__/scaling.cpython-310.pyc +0 -0
  21. AR/modules/__pycache__/transformer.cpython-310.pyc +0 -0
  22. AR/modules/lr_schedulers.py +2 -1
  23. AR/modules/patched_mha_with_cache.py +4 -2
  24. AR/modules/scaling.py +1 -1
  25. AR/text_processing/phonemizer.py +2 -1
  26. AR/text_processing/symbols.py +2 -1
  27. MODELS/21/1.mp3 +0 -0
  28. MODELS/21/11.mp3 +0 -0
  29. MODELS/21/191.mp3 +0 -0
  30. MODELS/21/21.ckpt +0 -3
  31. MODELS/21/21.pth +0 -3
  32. MODELS/21/s1.mp3 +0 -0
  33. MODELS/21/s2.mp3 +0 -0
  34. MODELS/21/s3.mp3 +0 -0
  35. MODELS/22/22.ckpt +0 -3
  36. MODELS/22/22.pth +0 -3
  37. MODELS/22/passion.mp3 +0 -0
  38. MODELS/22/s1.mp3 +0 -0
  39. MODELS/22/s2.mp3 +0 -0
  40. MODELS/22/s3.mp3 +0 -0
  41. MODELS/22/slow_calm.mp3 +0 -0
  42. MODELS/22/speed.mp3 +0 -0
  43. MODELS/31/1.mp3 +0 -0
  44. MODELS/31/148.mp3 +0 -0
  45. MODELS/31/31.ckpt +0 -3
  46. MODELS/31/31.pth +0 -3
  47. MODELS/31/96.mp3 +0 -0
  48. MODELS/31/s1.mp3 +0 -0
  49. MODELS/31/s2.mp3 +0 -0
  50. MODELS/31/s3.mp3 +0 -0
AR/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/AR/__pycache__/__init__.cpython-310.pyc and b/AR/__pycache__/__init__.cpython-310.pyc differ
 
AR/data/bucket_sampler.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py
 
2
  import itertools
3
  import math
4
  import random
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  import itertools
4
  import math
5
  import random
AR/data/data_module.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
 
2
  from pytorch_lightning import LightningDataModule
3
  from AR.data.bucket_sampler import DistributedBucketSampler
4
  from AR.data.dataset import Text2SemanticDataset
@@ -41,7 +42,8 @@ class Text2SemanticDataModule(LightningDataModule):
41
  # pad_val=self.config['data']['pad_val'])
42
 
43
  def train_dataloader(self):
44
- batch_size = max(min(self.config["train"]["batch_size"],len(self._train_dataset)//4),1)#防止不保存
 
45
  sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
46
  return DataLoader(
47
  self._train_dataset,
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  from pytorch_lightning import LightningDataModule
4
  from AR.data.bucket_sampler import DistributedBucketSampler
5
  from AR.data.dataset import Text2SemanticDataset
 
42
  # pad_val=self.config['data']['pad_val'])
43
 
44
  def train_dataloader(self):
45
+ batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
46
+ batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
47
  sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
48
  return DataLoader(
49
  self._train_dataset,
AR/data/dataset.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py
 
2
  import pdb
3
  import sys
4
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  import pdb
4
  import sys
5
 
AR/models/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/__init__.cpython-310.pyc and b/AR/models/__pycache__/__init__.cpython-310.pyc differ
 
AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc and b/AR/models/__pycache__/t2s_lightning_module.cpython-310.pyc differ
 
AR/models/__pycache__/t2s_model.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/t2s_model.cpython-310.pyc and b/AR/models/__pycache__/t2s_model.cpython-310.pyc differ
 
AR/models/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/AR/models/__pycache__/utils.cpython-310.pyc and b/AR/models/__pycache__/utils.cpython-310.pyc differ
 
AR/models/t2s_lightning_module.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
 
2
  import os, sys
3
 
4
  now_dir = os.getcwd()
@@ -11,7 +12,6 @@ from AR.models.t2s_model import Text2SemanticDecoder
11
  from AR.modules.lr_schedulers import WarmupCosineLRSchedule
12
  from AR.modules.optim import ScaledAdam
13
 
14
-
15
  class Text2SemanticLightningModule(LightningModule):
16
  def __init__(self, config, output_dir, is_train=True):
17
  super().__init__()
@@ -35,7 +35,8 @@ class Text2SemanticLightningModule(LightningModule):
35
  def training_step(self, batch: Dict, batch_idx: int):
36
  opt = self.optimizers()
37
  scheduler = self.lr_schedulers()
38
- loss, acc = self.model.forward(
 
39
  batch["phoneme_ids"],
40
  batch["phoneme_ids_len"],
41
  batch["semantic_ids"],
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  import os, sys
4
 
5
  now_dir = os.getcwd()
 
12
  from AR.modules.lr_schedulers import WarmupCosineLRSchedule
13
  from AR.modules.optim import ScaledAdam
14
 
 
15
  class Text2SemanticLightningModule(LightningModule):
16
  def __init__(self, config, output_dir, is_train=True):
17
  super().__init__()
 
35
  def training_step(self, batch: Dict, batch_idx: int):
36
  opt = self.optimizers()
37
  scheduler = self.lr_schedulers()
38
+ forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
39
+ loss, acc = forward(
40
  batch["phoneme_ids"],
41
  batch["phoneme_ids_len"],
42
  batch["semantic_ids"],
AR/models/t2s_lightning_module_onnx.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
 
2
  import os, sys
3
 
4
  now_dir = os.getcwd()
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  import os, sys
4
 
5
  now_dir = os.getcwd()
AR/models/t2s_model.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
 
2
  import torch
3
  from tqdm import tqdm
4
 
@@ -8,6 +9,9 @@ from AR.models.utils import (
8
  sample,
9
  logits_to_probs,
10
  multinomial_sample_one_no_sync,
 
 
 
11
  )
12
  from AR.modules.embedding import SinePositionalEmbedding
13
  from AR.modules.embedding import TokenEmbedding
@@ -85,11 +89,104 @@ class Text2SemanticDecoder(nn.Module):
85
  ignore_index=self.EOS,
86
  )
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def forward(self, x, x_lens, y, y_lens, bert_feature):
89
  """
90
  x: phoneme_ids
91
  y: semantic_ids
92
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  x = self.ar_text_embedding(x)
94
  x = x + self.bert_proj(bert_feature.transpose(1, 2))
95
  x = self.ar_text_position(x)
@@ -231,6 +328,7 @@ class Text2SemanticDecoder(nn.Module):
231
  prompts, ####参考音频token
232
  bert_feature,
233
  top_k: int = -100,
 
234
  early_stop_num: int = -1,
235
  temperature: float = 1.0,
236
  ):
@@ -240,7 +338,7 @@ class Text2SemanticDecoder(nn.Module):
240
 
241
  # AR Decoder
242
  y = prompts
243
- prefix_len = y.shape[1]
244
  x_len = x.shape[1]
245
  x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
246
  stop = False
@@ -256,47 +354,41 @@ class Text2SemanticDecoder(nn.Module):
256
  "first_infer": 1,
257
  "stage": 0,
258
  }
259
- for idx in tqdm(range(1500)):
260
- if cache["first_infer"] == 1:
261
- y_emb = self.ar_audio_embedding(y)
262
- else:
263
- y_emb = torch.cat(
264
- [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
265
- )
266
- cache["y_emb"] = y_emb
267
  y_pos = self.ar_audio_position(y_emb)
268
- # x 和逐渐增长的 y 一起输入给模型
269
- if cache["first_infer"] == 1:
270
- xy_pos = torch.concat([x, y_pos], dim=1)
271
- else:
272
- xy_pos = y_pos[:, -1:]
273
- y_len = y_pos.shape[1]
274
- ###以下3个不做缓存
275
- if cache["first_infer"] == 1:
276
- x_attn_mask_pad = F.pad(
 
 
 
 
277
  x_attn_mask,
278
  (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
279
  value=True,
280
  )
281
- y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
282
- torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
283
- (x_len, 0),
284
- value=False,
285
- )
286
- xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
287
- y.device
288
- )
289
- else:
290
- ###最右边一列(是错的)
291
- # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
292
- # xy_attn_mask[:,-1]=False
293
- ###最下面一行(是对的)
294
- xy_attn_mask = torch.zeros(
295
- (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
296
- )
297
- # pdb.set_trace()
298
- ###缓存重头戏
299
- # print(1111,xy_pos.shape,xy_attn_mask.shape,x_len,y_len)
300
  xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
301
  logits = self.ar_predict_layer(
302
  xy_dec[:, -1]
@@ -305,8 +397,12 @@ class Text2SemanticDecoder(nn.Module):
305
  if(idx==0):###第一次跑不能EOS否则没有了
306
  logits = logits[:, :-1] ###刨除1024终止符号的概率
307
  samples = sample(
308
- logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
309
  )[0].unsqueeze(0)
 
 
 
 
310
  if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
311
  print("use early stop num:", early_stop_num)
312
  stop = True
@@ -315,13 +411,38 @@ class Text2SemanticDecoder(nn.Module):
315
  # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
316
  stop = True
317
  if stop:
318
- if prompts.shape[1] == y.shape[1]:
 
 
 
319
  y = torch.concat([y, torch.zeros_like(samples)], dim=1)
320
  print("bad zero prediction")
321
  print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
322
  break
323
- # 本次生成的 semantic_ids 和之前的 y 构成新的 y
324
- # print(samples.shape)#[1,1]#第一个1是bs
325
- y = torch.concat([y, samples], dim=1)
326
  cache["first_infer"] = 0
327
- return y, idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  import torch
4
  from tqdm import tqdm
5
 
 
9
  sample,
10
  logits_to_probs,
11
  multinomial_sample_one_no_sync,
12
+ dpo_loss,
13
+ make_reject_y,
14
+ get_batch_logps
15
  )
16
  from AR.modules.embedding import SinePositionalEmbedding
17
  from AR.modules.embedding import TokenEmbedding
 
89
  ignore_index=self.EOS,
90
  )
91
 
92
+ def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
93
+ x = self.ar_text_embedding(x)
94
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
95
+ x = self.ar_text_position(x)
96
+ x_mask = make_pad_mask(x_lens)
97
+
98
+ y_mask = make_pad_mask(y_lens)
99
+ y_mask_int = y_mask.type(torch.int64)
100
+ codes = y.type(torch.int64) * (1 - y_mask_int)
101
+
102
+ # Training
103
+ # AR Decoder
104
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
105
+ x_len = x_lens.max()
106
+ y_len = y_lens.max()
107
+ y_emb = self.ar_audio_embedding(y)
108
+ y_pos = self.ar_audio_position(y_emb)
109
+
110
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
111
+
112
+ ar_xy_padding_mask = xy_padding_mask
113
+
114
+ x_attn_mask = F.pad(
115
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
116
+ (0, y_len),
117
+ value=True,
118
+ )
119
+
120
+ y_attn_mask = F.pad(
121
+ torch.triu(
122
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
123
+ diagonal=1,
124
+ ),
125
+ (x_len, 0),
126
+ value=False,
127
+ )
128
+
129
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
130
+ bsz, src_len = x.shape[0], x_len + y_len
131
+ _xy_padding_mask = (
132
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
133
+ .expand(-1, self.num_head, -1, -1)
134
+ .reshape(bsz * self.num_head, 1, src_len)
135
+ )
136
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
137
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
138
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
139
+ xy_attn_mask = new_attn_mask
140
+ # x 和完整的 y 一次性输入模型
141
+ xy_pos = torch.concat([x, y_pos], dim=1)
142
+
143
+ return xy_pos, xy_attn_mask, targets
144
+
145
  def forward(self, x, x_lens, y, y_lens, bert_feature):
146
  """
147
  x: phoneme_ids
148
  y: semantic_ids
149
  """
150
+
151
+ reject_y, reject_y_lens = make_reject_y(y, y_lens)
152
+
153
+ xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
154
+
155
+ xy_dec, _ = self.h(
156
+ (xy_pos, None),
157
+ mask=xy_attn_mask,
158
+ )
159
+ x_len = x_lens.max()
160
+ logits = self.ar_predict_layer(xy_dec[:, x_len:])
161
+
162
+ ###### DPO #############
163
+ reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
164
+
165
+ reject_xy_dec, _ = self.h(
166
+ (reject_xy_pos, None),
167
+ mask=reject_xy_attn_mask,
168
+ )
169
+ x_len = x_lens.max()
170
+ reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
171
+
172
+ # loss
173
+ # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
174
+
175
+ loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
176
+ acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
177
+
178
+ A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
179
+ loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
180
+
181
+ loss = loss_1 + loss_2
182
+
183
+ return loss, acc
184
+
185
+ def forward_old(self, x, x_lens, y, y_lens, bert_feature):
186
+ """
187
+ x: phoneme_ids
188
+ y: semantic_ids
189
+ """
190
  x = self.ar_text_embedding(x)
191
  x = x + self.bert_proj(bert_feature.transpose(1, 2))
192
  x = self.ar_text_position(x)
 
328
  prompts, ####参考音频token
329
  bert_feature,
330
  top_k: int = -100,
331
+ top_p: int = 100,
332
  early_stop_num: int = -1,
333
  temperature: float = 1.0,
334
  ):
 
338
 
339
  # AR Decoder
340
  y = prompts
341
+
342
  x_len = x.shape[1]
343
  x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
344
  stop = False
 
354
  "first_infer": 1,
355
  "stage": 0,
356
  }
357
+ ################### first step ##########################
358
+ if y is not None:
359
+ y_emb = self.ar_audio_embedding(y)
360
+ y_len = y_emb.shape[1]
361
+ prefix_len = y.shape[1]
 
 
 
362
  y_pos = self.ar_audio_position(y_emb)
363
+ xy_pos = torch.concat([x, y_pos], dim=1)
364
+ cache["y_emb"] = y_emb
365
+ ref_free = False
366
+ else:
367
+ y_emb = None
368
+ y_len = 0
369
+ prefix_len = 0
370
+ y_pos = None
371
+ xy_pos = x
372
+ y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
373
+ ref_free = True
374
+
375
+ x_attn_mask_pad = F.pad(
376
  x_attn_mask,
377
  (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
378
  value=True,
379
  )
380
+ y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
381
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
382
+ (x_len, 0),
383
+ value=False,
384
+ )
385
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
386
+ x.device
387
+ )
388
+
389
+
390
+ for idx in tqdm(range(1500)):
391
+
 
 
 
 
 
 
 
392
  xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
393
  logits = self.ar_predict_layer(
394
  xy_dec[:, -1]
 
397
  if(idx==0):###第一次跑不能EOS否则没有了
398
  logits = logits[:, :-1] ###刨除1024终止符号的概率
399
  samples = sample(
400
+ logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
401
  )[0].unsqueeze(0)
402
+ # 本次生成的 semantic_ids 和之前的 y 构成新的 y
403
+ # print(samples.shape)#[1,1]#第一个1是bs
404
+ y = torch.concat([y, samples], dim=1)
405
+
406
  if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
407
  print("use early stop num:", early_stop_num)
408
  stop = True
 
411
  # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
412
  stop = True
413
  if stop:
414
+ # if prompts.shape[1] == y.shape[1]:
415
+ # y = torch.concat([y, torch.zeros_like(samples)], dim=1)
416
+ # print("bad zero prediction")
417
+ if y.shape[1]==0:
418
  y = torch.concat([y, torch.zeros_like(samples)], dim=1)
419
  print("bad zero prediction")
420
  print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
421
  break
422
+
423
+ ####################### update next step ###################################
 
424
  cache["first_infer"] = 0
425
+ if cache["y_emb"] is not None:
426
+ y_emb = torch.cat(
427
+ [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], dim = 1
428
+ )
429
+ cache["y_emb"] = y_emb
430
+ y_pos = self.ar_audio_position(y_emb)
431
+ xy_pos = y_pos[:, -1:]
432
+ else:
433
+ y_emb = self.ar_audio_embedding(y[:, -1:])
434
+ cache["y_emb"] = y_emb
435
+ y_pos = self.ar_audio_position(y_emb)
436
+ xy_pos = y_pos
437
+ y_len = y_pos.shape[1]
438
+
439
+ ###最右边一列(是错的)
440
+ # xy_attn_mask=torch.ones((1, x_len+y_len), dtype=torch.bool,device=xy_pos.device)
441
+ # xy_attn_mask[:,-1]=False
442
+ ###最下面一行(是对的)
443
+ xy_attn_mask = torch.zeros(
444
+ (1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
445
+ )
446
+ if ref_free:
447
+ return y[:, :-1], 0
448
+ return y[:, :-1], idx-1
AR/models/t2s_model_onnx.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
 
2
  import torch
3
  from tqdm import tqdm
4
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  import torch
4
  from tqdm import tqdm
5
 
AR/models/utils.py CHANGED
@@ -1,7 +1,8 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
 
2
  import torch
3
  import torch.nn.functional as F
4
-
5
 
6
  def sequence_mask(length, max_length=None):
7
  if max_length is None:
@@ -114,7 +115,8 @@ def logits_to_probs(
114
  top_p: Optional[int] = None,
115
  repetition_penalty: float = 1.0,
116
  ):
117
- previous_tokens = previous_tokens.squeeze()
 
118
  # print(logits.shape,previous_tokens.shape)
119
  # pdb.set_trace()
120
  if previous_tokens is not None and repetition_penalty != 1.0:
@@ -158,3 +160,70 @@ def sample(
158
  )
159
  idx_next = multinomial_sample_one_no_sync(probs)
160
  return idx_next, probs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  import torch
4
  import torch.nn.functional as F
5
+ from typing import Tuple
6
 
7
  def sequence_mask(length, max_length=None):
8
  if max_length is None:
 
115
  top_p: Optional[int] = None,
116
  repetition_penalty: float = 1.0,
117
  ):
118
+ if previous_tokens is not None:
119
+ previous_tokens = previous_tokens.squeeze()
120
  # print(logits.shape,previous_tokens.shape)
121
  # pdb.set_trace()
122
  if previous_tokens is not None and repetition_penalty != 1.0:
 
160
  )
161
  idx_next = multinomial_sample_one_no_sync(probs)
162
  return idx_next, probs
163
+
164
+ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
165
+ policy_rejected_logps: torch.FloatTensor,
166
+ reference_chosen_logps: torch.FloatTensor,
167
+ reference_rejected_logps: torch.FloatTensor,
168
+ beta: float,
169
+ reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
170
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
171
+ ref_logratios = reference_chosen_logps - reference_rejected_logps
172
+
173
+ if reference_free:
174
+ ref_logratios = 0
175
+
176
+ logits = pi_logratios - ref_logratios
177
+
178
+ losses = -F.logsigmoid(beta * logits)
179
+ chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
180
+ rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
181
+
182
+ return losses.mean(), chosen_rewards, rejected_rewards
183
+
184
+ def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
185
+
186
+ # dummy token; we'll ignore the losses on these tokens later
187
+
188
+ per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
189
+ per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
190
+
191
+ return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
192
+
193
+ def make_reject_y(y_o, y_lens):
194
+ def repeat_P(y):
195
+ range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
196
+ pre = y[:range_idx[0]]
197
+ shf = y[range_idx[1]:]
198
+ range_text = y[range_idx[0]:range_idx[1]]
199
+ new_y = torch.cat([pre, range_text, range_text, shf])
200
+ return new_y
201
+ def lost_P(y):
202
+ range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
203
+ pre = y[:range_idx[0]]
204
+ shf = y[range_idx[1]:]
205
+ range_text = y[range_idx[0]:range_idx[1]]
206
+ new_y = torch.cat([pre, shf])
207
+ return new_y
208
+ bs = len(y_lens)
209
+ reject_y = []
210
+ reject_y_lens = []
211
+ for b in range(bs):
212
+ process_item_idx = torch.randint(0, 1, size=(1, ))[0]
213
+ if process_item_idx == 0:
214
+ new_y = repeat_P(y_o[b])
215
+ reject_y.append(new_y)
216
+ reject_y_lens.append(len(new_y))
217
+ elif process_item_idx==1:
218
+ new_y = lost_P(y_o[b])
219
+ reject_y.append(new_y)
220
+ reject_y_lens.append(len(new_y))
221
+ max_length = max(reject_y_lens)
222
+ for b in range(bs):
223
+ pad_length = max_length - reject_y_lens[b]
224
+ reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
225
+
226
+ reject_y = torch.stack(reject_y, dim = 0)
227
+ reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
228
+
229
+ return reject_y, reject_y_lens
AR/modules/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/__init__.cpython-310.pyc and b/AR/modules/__pycache__/__init__.cpython-310.pyc differ
 
AR/modules/__pycache__/activation.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/activation.cpython-310.pyc and b/AR/modules/__pycache__/activation.cpython-310.pyc differ
 
AR/modules/__pycache__/embedding.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/embedding.cpython-310.pyc and b/AR/modules/__pycache__/embedding.cpython-310.pyc differ
 
AR/modules/__pycache__/lr_schedulers.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc and b/AR/modules/__pycache__/lr_schedulers.cpython-310.pyc differ
 
AR/modules/__pycache__/optim.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/optim.cpython-310.pyc and b/AR/modules/__pycache__/optim.cpython-310.pyc differ
 
AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc and b/AR/modules/__pycache__/patched_mha_with_cache.cpython-310.pyc differ
 
AR/modules/__pycache__/scaling.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/scaling.cpython-310.pyc and b/AR/modules/__pycache__/scaling.cpython-310.pyc differ
 
AR/modules/__pycache__/transformer.cpython-310.pyc CHANGED
Binary files a/AR/modules/__pycache__/transformer.cpython-310.pyc and b/AR/modules/__pycache__/transformer.cpython-310.pyc differ
 
AR/modules/lr_schedulers.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py
 
2
  import math
3
 
4
  import torch
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  import math
4
 
5
  import torch
AR/modules/patched_mha_with_cache.py CHANGED
@@ -5,8 +5,8 @@ from torch.nn.functional import (
5
  _none_or_dtype,
6
  _in_projection_packed,
7
  )
8
-
9
- # import torch
10
  # Tensor = torch.Tensor
11
  # from typing import Callable, List, Optional, Tuple, Union
12
 
@@ -448,9 +448,11 @@ def multi_head_attention_forward_patched(
448
  k = k.view(bsz, num_heads, src_len, head_dim)
449
  v = v.view(bsz, num_heads, src_len, head_dim)
450
 
 
451
  attn_output = scaled_dot_product_attention(
452
  q, k, v, attn_mask, dropout_p, is_causal
453
  )
 
454
  attn_output = (
455
  attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
456
  )
 
5
  _none_or_dtype,
6
  _in_projection_packed,
7
  )
8
+ from torch.nn import functional as F
9
+ import torch
10
  # Tensor = torch.Tensor
11
  # from typing import Callable, List, Optional, Tuple, Union
12
 
 
448
  k = k.view(bsz, num_heads, src_len, head_dim)
449
  v = v.view(bsz, num_heads, src_len, head_dim)
450
 
451
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
452
  attn_output = scaled_dot_product_attention(
453
  q, k, v, attn_mask, dropout_p, is_causal
454
  )
455
+
456
  attn_output = (
457
  attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
458
  )
AR/modules/scaling.py CHANGED
@@ -13,7 +13,7 @@
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
- #import logging
17
  import math
18
  import random
19
  from typing import Optional
 
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
+ import logging
17
  import math
18
  import random
19
  from typing import Optional
AR/text_processing/phonemizer.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/phonemizer.py
 
2
  import itertools
3
  import re
4
  from typing import Dict
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  import itertools
4
  import re
5
  from typing import Dict
AR/text_processing/symbols.py CHANGED
@@ -1,4 +1,5 @@
1
- # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py
 
2
  PAD = "_"
3
  PUNCTUATION = ';:,.!?¡¿—…"«»“” '
4
  LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
 
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
  PAD = "_"
4
  PUNCTUATION = ';:,.!?¡¿—…"«»“” '
5
  LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
MODELS/21/1.mp3 DELETED
Binary file (30.9 kB)
 
MODELS/21/11.mp3 DELETED
Binary file (28 kB)
 
MODELS/21/191.mp3 DELETED
Binary file (29.5 kB)
 
MODELS/21/21.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c4b29bb398a9dbed95c50489a2633f90a01c0c4ae1e4432f5d37d388401f9887
3
- size 155077753
 
 
 
 
MODELS/21/21.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:bfb359648e858765e9c1e3f7d51869aec9f607d18efd90d059cb83f1a7988141
3
- size 84927748
 
 
 
 
MODELS/21/s1.mp3 DELETED
Binary file (29 kB)
 
MODELS/21/s2.mp3 DELETED
Binary file (29 kB)
 
MODELS/21/s3.mp3 DELETED
Binary file (28.5 kB)
 
MODELS/22/22.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c3632e3d1876f7a8e86850f346338c5e2390d09f382891277acf77a4e1a65a25
3
- size 155083315
 
 
 
 
MODELS/22/22.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3dfe7fe2765b179db75d8e12bd2b32e1f8d624dcee9a3fecdecbc94904757c29
3
- size 84927982
 
 
 
 
MODELS/22/passion.mp3 DELETED
Binary file (131 kB)
 
MODELS/22/s1.mp3 DELETED
Binary file (26.8 kB)
 
MODELS/22/s2.mp3 DELETED
Binary file (33.1 kB)
 
MODELS/22/s3.mp3 DELETED
Binary file (30.2 kB)
 
MODELS/22/slow_calm.mp3 DELETED
Binary file (79.2 kB)
 
MODELS/22/speed.mp3 DELETED
Binary file (122 kB)
 
MODELS/31/1.mp3 DELETED
Binary file (111 kB)
 
MODELS/31/148.mp3 DELETED
Binary file (86.8 kB)
 
MODELS/31/31.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:532d92b5b2a1550ed1151aa2d0a801a2fc390fc7b87a7d0278ca7af4cad50c7f
3
- size 155084485
 
 
 
 
MODELS/31/31.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e3d128cd00c3853ebe375dd5aeccd979c55a7e8d036cc41843507e2191ccd6d3
3
- size 84929396
 
 
 
 
MODELS/31/96.mp3 DELETED
Binary file (83.4 kB)
 
MODELS/31/s1.mp3 DELETED
Binary file (32.2 kB)
 
MODELS/31/s2.mp3 DELETED
Binary file (43 kB)
 
MODELS/31/s3.mp3 DELETED
Binary file (39.1 kB)