PoTaTo721 commited on
Commit
69e8a46
·
1 Parent(s): 469209d

update to 1.2

Browse files
Files changed (45) hide show
  1. app.py +420 -129
  2. fish_speech/configs/base.yaml +1 -0
  3. fish_speech/configs/firefly_gan_vq.yaml +34 -0
  4. fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  5. fish_speech/configs/text2semantic_finetune.yaml +22 -18
  6. fish_speech/datasets/concat_repeat.py +53 -0
  7. fish_speech/datasets/semantic.py +496 -0
  8. fish_speech/datasets/vqgan.py +3 -1
  9. fish_speech/models/text2semantic/__init__.py +0 -3
  10. fish_speech/models/text2semantic/lit_module.py +22 -164
  11. fish_speech/models/text2semantic/llama.py +227 -70
  12. fish_speech/models/text2semantic/lora.py +92 -0
  13. fish_speech/models/vqgan/modules/firefly.py +88 -1
  14. fish_speech/models/vqgan/modules/fsq.py +1 -1
  15. fish_speech/text/__init__.py +2 -1
  16. fish_speech/text/chn_text_norm/.gitignore +114 -0
  17. fish_speech/text/chn_text_norm/README.md +36 -0
  18. fish_speech/text/chn_text_norm/__init__.py +0 -0
  19. fish_speech/text/chn_text_norm/basic_class.py +172 -0
  20. fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  21. fish_speech/text/chn_text_norm/basic_util.py +342 -0
  22. fish_speech/text/chn_text_norm/cardinal.py +32 -0
  23. fish_speech/text/chn_text_norm/date.py +75 -0
  24. fish_speech/text/chn_text_norm/digit.py +32 -0
  25. fish_speech/text/chn_text_norm/fraction.py +35 -0
  26. fish_speech/text/chn_text_norm/money.py +43 -0
  27. fish_speech/text/chn_text_norm/percentage.py +33 -0
  28. fish_speech/text/chn_text_norm/telephone.py +51 -0
  29. fish_speech/text/chn_text_norm/text.py +177 -0
  30. fish_speech/text/clean.py +1 -5
  31. fish_speech/text/spliter.py +130 -0
  32. fish_speech/utils/file.py +1 -1
  33. fish_speech/utils/rich_utils.py +7 -3
  34. fish_speech/utils/spectrogram.py +122 -0
  35. tools/api.py +482 -0
  36. tools/auto_rerank.py +159 -0
  37. tools/llama/build_dataset.py +169 -0
  38. tools/llama/eval_in_context.py +171 -0
  39. tools/llama/generate.py +119 -180
  40. tools/llama/merge_lora.py +95 -0
  41. tools/llama/quantize.py +46 -64
  42. tools/llama/rebuild_tokenizer.py +57 -0
  43. tools/vqgan/create_train_split.py +83 -0
  44. tools/vqgan/extract_vq.py +227 -0
  45. tools/vqgan/inference.py +29 -26
app.py CHANGED
@@ -5,7 +5,7 @@ import hydra
5
 
6
  # Download if not exists
7
  os.makedirs("checkpoints", exist_ok=True)
8
- snapshot_download(repo_id="fishaudio/fish-speech-1", local_dir="./checkpoints/fish-speech-1")
9
 
10
  print("All checkpoints downloaded")
11
 
@@ -23,6 +23,16 @@ from transformers import AutoTokenizer
23
 
24
  from tools.llama.generate import launch_thread_safe_queue
25
  from tools.vqgan.inference import load_model as load_vqgan_model
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Make einx happy
28
  os.environ["EINX_FILTER_TRACEBACK"] = "false"
@@ -30,8 +40,8 @@ os.environ["EINX_FILTER_TRACEBACK"] = "false"
30
 
31
  HEADER_MD = """# Fish Speech
32
 
33
- ## The demo in this space is version 1.0, Please check [Fish Audio](https://fish.audio) for the best model.
34
- ## 该 Demo 为 Fish Speech 1.0 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
35
 
36
  A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
37
  由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
@@ -39,14 +49,14 @@ A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https
39
  You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).
40
  你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.
41
 
42
- Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.
43
- 相关代码使用 BSD-3-Clause 许可证发布,权重使用 CC BY-NC-SA 4.0 许可证发布.
44
 
45
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
46
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
47
 
48
- The model running in this WebUI is Fish Speech V1 Medium SFT 4K.
49
- 在此 WebUI 中运行的模型是 Fish Speech V1 Medium SFT 4K.
50
  """
51
 
52
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -85,36 +95,27 @@ def inference(
85
  top_p,
86
  repetition_penalty,
87
  temperature,
88
- speaker,
89
  ):
90
  if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
91
- return None, f"Text is too long, please keep it under {args.max_gradio_length} characters."
92
-
93
- # Parse reference audio aka prompt
94
- prompt_tokens = None
95
- if enable_reference_audio and reference_audio is not None:
96
- # reference_audio_sr, reference_audio_content = reference_audio
97
- reference_audio_content, _ = librosa.load(
98
- reference_audio, sr=vqgan_model.sampling_rate, mono=True
99
- )
100
- audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
101
- None, None, :
102
- ]
103
-
104
- logger.info(
105
- f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
106
  )
107
 
108
- # VQ Encoder
109
- audio_lengths = torch.tensor(
110
- [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
111
- )
112
- prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
 
113
 
114
  # LLAMA Inference
115
  request = dict(
116
- tokenizer=llama_tokenizer,
117
- device=vqgan_model.device,
118
  max_new_tokens=max_new_tokens,
119
  text=text,
120
  top_p=top_p,
@@ -123,43 +124,246 @@ def inference(
123
  compile=args.compile,
124
  iterative_prompt=chunk_length > 0,
125
  chunk_length=chunk_length,
126
- max_length=args.max_length,
127
- speaker=speaker if speaker else None,
128
  prompt_tokens=prompt_tokens if enable_reference_audio else None,
129
  prompt_text=reference_text if enable_reference_audio else None,
130
  )
131
 
132
- payload = dict(
133
- response_queue=queue.Queue(),
134
- request=request,
 
 
 
135
  )
136
- llama_queue.put(payload)
137
 
138
- codes = []
 
 
 
 
139
  while True:
140
- result = payload["response_queue"].get()
141
- if result == "next":
142
- # TODO: handle next sentence
143
- continue
144
-
145
- if result == "done":
146
- if payload["success"] is False:
147
- return None, build_html_error_message(payload["response"])
148
  break
149
 
150
- codes.append(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- codes = torch.cat(codes, dim=1)
153
 
154
- # VQGAN Inference
155
- feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
156
- fake_audios = vqgan_model.decode(
157
- indices=codes[None], feature_lengths=feature_lengths, return_audios=True
158
- )[0, 0]
159
 
160
- fake_audios = fake_audios.float().cpu().numpy()
 
 
 
 
161
 
162
- return (vqgan_model.sampling_rate, fake_audios), None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
 
165
  def build_app():
@@ -170,95 +374,179 @@ def build_app():
170
  app.load(
171
  None,
172
  None,
173
- js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
 
174
  )
175
 
176
  # Inference
177
  with gr.Row():
178
  with gr.Column(scale=3):
179
  text = gr.Textbox(
180
- label="Input Text / 输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
 
 
 
 
 
 
 
181
  )
182
 
183
  with gr.Row():
184
- with gr.Tab(label="Advanced Config / 高级参数"):
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  chunk_length = gr.Slider(
186
- label="Iterative Prompt Length, 0 means off / 迭代提示长度,0 表示关闭",
187
  minimum=0,
188
- maximum=100,
189
- value=30,
190
  step=8,
191
  )
192
 
193
  max_new_tokens = gr.Slider(
194
- label="Maximum tokens per batch, 0 means no limit / 每批最大令牌数,0 表示无限制",
195
- minimum=128,
196
- maximum=512,
197
- value=512, # 0 means no limit
198
  step=8,
199
  )
200
 
201
  top_p = gr.Slider(
202
- label="Top-P", minimum=0, maximum=1, value=0.7, step=0.01
 
 
 
 
203
  )
204
 
205
  repetition_penalty = gr.Slider(
206
  label="Repetition Penalty",
207
- minimum=0,
208
- maximum=2,
209
- value=1.5,
210
  step=0.01,
211
  )
212
 
213
  temperature = gr.Slider(
214
  label="Temperature",
215
- minimum=0,
216
- maximum=2,
217
  value=0.7,
218
  step=0.01,
219
  )
220
 
221
- speaker = gr.Textbox(
222
- label="Speaker / 说话人",
223
- placeholder="Type name of the speaker / 输入说话人的名称",
224
- lines=1,
225
- )
226
-
227
- with gr.Tab(label="Reference Audio / 参考音频"):
228
  gr.Markdown(
229
- "5 to 10 seconds of reference audio, useful for specifying speaker. \n5 到 10 秒的参考音频,适用于指定音色。"
230
  )
231
 
232
  enable_reference_audio = gr.Checkbox(
233
- label="Enable Reference Audio / 启用参考音频",
234
  )
235
  reference_audio = gr.Audio(
236
- label="Reference Audio / 参考音频",
237
  type="filepath",
238
  )
239
- reference_text = gr.Textbox(
240
- label="Reference Text / 参考文本",
241
- placeholder="参考文本",
242
- lines=1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  )
244
 
245
  with gr.Column(scale=3):
246
- with gr.Row():
247
- error = gr.HTML(label="Error Message / 错误信息")
248
- with gr.Row():
249
- audio = gr.Audio(label="Generated Audio / 音频", type="numpy")
 
 
 
 
 
 
 
 
 
 
 
250
 
 
 
 
 
 
 
 
 
251
  with gr.Row():
252
  with gr.Column(scale=3):
253
  generate = gr.Button(
254
- value="\U0001F3A7 Generate / 合成", variant="primary"
 
 
 
 
255
  )
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  # # Submit
258
  generate.click(
259
- inference,
260
  [
261
- text,
262
  enable_reference_audio,
263
  reference_audio,
264
  reference_text,
@@ -267,12 +555,29 @@ def build_app():
267
  top_p,
268
  repetition_penalty,
269
  temperature,
270
- speaker,
 
271
  ],
272
- [audio, error],
273
  concurrency_limit=1,
274
  )
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  return app
277
 
278
 
@@ -281,74 +586,60 @@ def parse_args():
281
  parser.add_argument(
282
  "--llama-checkpoint-path",
283
  type=Path,
284
- default="checkpoints/text2semantic-sft-large-v1-4k.pth",
285
  )
286
  parser.add_argument(
287
- "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
288
- )
289
- parser.add_argument(
290
- "--vqgan-checkpoint-path",
291
  type=Path,
292
- default="checkpoints/vq-gan-group-fsq-2x1024.pth",
293
  )
294
- parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
295
- parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
296
  parser.add_argument("--device", type=str, default="cuda")
297
  parser.add_argument("--half", action="store_true")
298
- parser.add_argument("--max-length", type=int, default=2048)
299
  parser.add_argument("--compile", action="store_true")
300
  parser.add_argument("--max-gradio-length", type=int, default=0)
 
301
 
302
  return parser.parse_args()
303
 
304
 
305
  if __name__ == "__main__":
306
  args = parse_args()
307
-
308
  args.precision = torch.half if args.half else torch.bfloat16
309
- args.compile = True
310
- args.max_gradio_length = 1024
311
- args.tokenizer = "./checkpoints/fish-speech-1"
312
- args.llama_checkpoint_path = "./checkpoints/fish-speech-1/text2semantic-sft-medium-v1-4k.pth"
313
- args.llama_config_name = "dual_ar_2_codebook_medium"
314
- args.vqgan_checkpoint_path = "./checkpoints/fish-speech-1/vq-gan-group-fsq-2x1024.pth"
315
- args.vqgan_config_name = "vqgan_pretrain"
316
 
317
  logger.info("Loading Llama model...")
318
  llama_queue = launch_thread_safe_queue(
319
- config_name=args.llama_config_name,
320
  checkpoint_path=args.llama_checkpoint_path,
321
  device=args.device,
322
  precision=args.precision,
323
- max_length=args.max_length,
324
  compile=args.compile,
325
  )
326
- llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
327
  logger.info("Llama model loaded, loading VQ-GAN model...")
328
 
329
- vqgan_model = load_vqgan_model(
330
- config_name=args.vqgan_config_name,
331
- checkpoint_path=args.vqgan_checkpoint_path,
332
  device=args.device,
333
  )
334
 
335
- logger.info("VQ-GAN model loaded, warming up...")
336
 
337
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
338
- inference(
339
- text="Hello, world!",
340
- enable_reference_audio=False,
341
- reference_audio=None,
342
- reference_text="",
343
- max_new_tokens=0,
344
- chunk_length=0,
345
- top_p=0.7,
346
- repetition_penalty=1.5,
347
- temperature=0.7,
348
- speaker=None,
 
349
  )
350
 
351
  logger.info("Warming up done, launching the web UI...")
352
 
353
  app = build_app()
354
- app.launch(show_api=False)
 
5
 
6
  # Download if not exists
7
  os.makedirs("checkpoints", exist_ok=True)
8
+ snapshot_download(repo_id="fishaudio/fish-speech-1.2-sft", local_dir="./checkpoints/fish-speech-1.2-sft")
9
 
10
  print("All checkpoints downloaded")
11
 
 
23
 
24
  from tools.llama.generate import launch_thread_safe_queue
25
  from tools.vqgan.inference import load_model as load_vqgan_model
26
+ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
27
+ from tools.api import decode_vq_tokens, encode_reference
28
+ from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
29
+ from tools.llama.generate import (
30
+ GenerateRequest,
31
+ GenerateResponse,
32
+ WrappedGenerateResponse,
33
+ launch_thread_safe_queue,
34
+ )
35
+ from tools.vqgan.inference import load_model as load_decoder_model
36
 
37
  # Make einx happy
38
  os.environ["EINX_FILTER_TRACEBACK"] = "false"
 
40
 
41
  HEADER_MD = """# Fish Speech
42
 
43
+ ## The demo in this space is version 1.2, Please check [Fish Audio](https://fish.audio) for the best model.
44
+ ## 该 Demo 为 Fish Speech 1.2 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
45
 
46
  A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
47
  由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
 
49
  You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).
50
  你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.
51
 
52
+ Related code and weights are released under CC BY-NC-SA 4.0 License.
53
+ 相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
54
 
55
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
56
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
57
 
58
+ The model running in this WebUI is Fish Speech V1.2 Medium SFT.
59
+ 在此 WebUI 中运行的模型是 Fish Speech V1.2 Medium SFT.
60
  """
61
 
62
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
95
  top_p,
96
  repetition_penalty,
97
  temperature,
98
+ streaming=False,
99
  ):
100
  if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
101
+ return (
102
+ None,
103
+ None,
104
+ "Text is too long, please keep it under {} characters.".format(
105
+ args.max_gradio_length
106
+ ),
 
 
 
 
 
 
 
 
 
107
  )
108
 
109
+ # Parse reference audio aka prompt
110
+ prompt_tokens = encode_reference(
111
+ decoder_model=decoder_model,
112
+ reference_audio=reference_audio,
113
+ enable_reference_audio=enable_reference_audio,
114
+ )
115
 
116
  # LLAMA Inference
117
  request = dict(
118
+ device=decoder_model.device,
 
119
  max_new_tokens=max_new_tokens,
120
  text=text,
121
  top_p=top_p,
 
124
  compile=args.compile,
125
  iterative_prompt=chunk_length > 0,
126
  chunk_length=chunk_length,
127
+ max_length=2048,
 
128
  prompt_tokens=prompt_tokens if enable_reference_audio else None,
129
  prompt_text=reference_text if enable_reference_audio else None,
130
  )
131
 
132
+ response_queue = queue.Queue()
133
+ llama_queue.put(
134
+ GenerateRequest(
135
+ request=request,
136
+ response_queue=response_queue,
137
+ )
138
  )
 
139
 
140
+ if streaming:
141
+ yield wav_chunk_header(), None, None
142
+
143
+ segments = []
144
+
145
  while True:
146
+ result: WrappedGenerateResponse = response_queue.get()
147
+ if result.status == "error":
148
+ yield None, None, build_html_error_message(result.response)
 
 
 
 
 
149
  break
150
 
151
+ result: GenerateResponse = result.response
152
+ if result.action == "next":
153
+ break
154
+
155
+ with torch.autocast(
156
+ device_type=(
157
+ "cpu"
158
+ if decoder_model.device.type == "mps"
159
+ else decoder_model.device.type
160
+ ),
161
+ dtype=args.precision,
162
+ ):
163
+ fake_audios = decode_vq_tokens(
164
+ decoder_model=decoder_model,
165
+ codes=result.codes,
166
+ )
167
+
168
+ fake_audios = fake_audios.float().cpu().numpy()
169
+ segments.append(fake_audios)
170
+
171
+ if streaming:
172
+ yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
173
+
174
+ if len(segments) == 0:
175
+ return (
176
+ None,
177
+ None,
178
+ build_html_error_message(
179
+ "No audio generated, please check the input text."
180
+ ),
181
+ )
182
+
183
+ # No matter streaming or not, we need to return the final audio
184
+ audio = np.concatenate(segments, axis=0)
185
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
186
+
187
+ if torch.cuda.is_available():
188
+ torch.cuda.empty_cache()
189
+ gc.collect()
190
+
191
+
192
+ def inference_with_auto_rerank(
193
+ text,
194
+ enable_reference_audio,
195
+ reference_audio,
196
+ reference_text,
197
+ max_new_tokens,
198
+ chunk_length,
199
+ top_p,
200
+ repetition_penalty,
201
+ temperature,
202
+ use_auto_rerank,
203
+ streaming=False,
204
+ ):
205
+
206
+ max_attempts = 2 if use_auto_rerank else 1
207
+ best_wer = float("inf")
208
+ best_audio = None
209
+ best_sample_rate = None
210
+
211
+ for attempt in range(max_attempts):
212
+ audio_generator = inference(
213
+ text,
214
+ enable_reference_audio,
215
+ reference_audio,
216
+ reference_text,
217
+ max_new_tokens,
218
+ chunk_length,
219
+ top_p,
220
+ repetition_penalty,
221
+ temperature,
222
+ streaming=False,
223
+ )
224
+
225
+ # 获取音频数据
226
+ for _ in audio_generator:
227
+ pass
228
+ _, (sample_rate, audio), message = _
229
+
230
+ if audio is None:
231
+ return None, None, message
232
+
233
+ if not use_auto_rerank:
234
+ return None, (sample_rate, audio), None
235
+
236
+ asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
237
+ wer = calculate_wer(text, asr_result["text"])
238
+ if wer <= 0.3 and not asr_result["huge_gap"]:
239
+ return None, (sample_rate, audio), None
240
+
241
+ if wer < best_wer:
242
+ best_wer = wer
243
+ best_audio = audio
244
+ best_sample_rate = sample_rate
245
+
246
+ if attempt == max_attempts - 1:
247
+ break
248
+
249
+ return None, (best_sample_rate, best_audio), None
250
+
251
+
252
+ inference_stream = partial(inference, streaming=True)
253
+
254
+ n_audios = 4
255
+
256
+ global_audio_list = []
257
+ global_error_list = []
258
+
259
+
260
+ def inference_wrapper(
261
+ text,
262
+ enable_reference_audio,
263
+ reference_audio,
264
+ reference_text,
265
+ max_new_tokens,
266
+ chunk_length,
267
+ top_p,
268
+ repetition_penalty,
269
+ temperature,
270
+ batch_infer_num,
271
+ if_load_asr_model,
272
+ ):
273
+ audios = []
274
+ errors = []
275
+
276
+ for _ in range(batch_infer_num):
277
+ result = inference_with_auto_rerank(
278
+ text,
279
+ enable_reference_audio,
280
+ reference_audio,
281
+ reference_text,
282
+ max_new_tokens,
283
+ chunk_length,
284
+ top_p,
285
+ repetition_penalty,
286
+ temperature,
287
+ if_load_asr_model,
288
+ )
289
+
290
+ _, audio_data, error_message = result
291
+
292
+ audios.append(
293
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
294
+ )
295
+ errors.append(
296
+ gr.HTML(value=error_message if error_message else None, visible=True),
297
+ )
298
+
299
+ for _ in range(batch_infer_num, n_audios):
300
+ audios.append(
301
+ gr.Audio(value=None, visible=False),
302
+ )
303
+ errors.append(
304
+ gr.HTML(value=None, visible=False),
305
+ )
306
+
307
+ return None, *audios, *errors
308
+
309
+
310
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
311
+ buffer = io.BytesIO()
312
+
313
+ with wave.open(buffer, "wb") as wav_file:
314
+ wav_file.setnchannels(channels)
315
+ wav_file.setsampwidth(bit_depth // 8)
316
+ wav_file.setframerate(sample_rate)
317
+
318
+ wav_header_bytes = buffer.getvalue()
319
+ buffer.close()
320
+ return wav_header_bytes
321
+
322
+
323
+ def normalize_text(user_input, use_normalization):
324
+ if use_normalization:
325
+ return ChnNormedText(raw_text=user_input).normalize()
326
+ else:
327
+ return user_input
328
+
329
+
330
+ asr_model = None
331
 
 
332
 
333
+ def change_if_load_asr_model(if_load):
334
+ global asr_model
 
 
 
335
 
336
+ if if_load:
337
+ gr.Warning("Loading faster whisper model...")
338
+ if asr_model is None:
339
+ asr_model = load_model()
340
+ return gr.Checkbox(label="Unload faster whisper model", value=if_load)
341
 
342
+ if if_load is False:
343
+ gr.Warning("Unloading faster whisper model...")
344
+ del asr_model
345
+ asr_model = None
346
+ if torch.cuda.is_available():
347
+ torch.cuda.empty_cache()
348
+ gc.collect()
349
+ return gr.Checkbox(label="Load faster whisper model", value=if_load)
350
+
351
+
352
+ def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
353
+ if if_load and asr_model is not None:
354
+ if (
355
+ if_auto_label
356
+ and enable_ref
357
+ and ref_audio is not None
358
+ and ref_text.strip() == ""
359
+ ):
360
+ data, sample_rate = librosa.load(ref_audio)
361
+ res = batch_asr(asr_model, [data], sample_rate)[0]
362
+ ref_text = res["text"]
363
+ else:
364
+ gr.Warning("Whisper model not loaded!")
365
+
366
+ return gr.Textbox(value=ref_text)
367
 
368
 
369
  def build_app():
 
374
  app.load(
375
  None,
376
  None,
377
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
378
+ % args.theme,
379
  )
380
 
381
  # Inference
382
  with gr.Row():
383
  with gr.Column(scale=3):
384
  text = gr.Textbox(
385
+ label="Input Text", placeholder=TEXTBOX_PLACEHOLDER, lines=10
386
+ )
387
+ refined_text = gr.Textbox(
388
+ label="Realtime Transform Text",
389
+ placeholder=
390
+ "Normalization Result Preview (Currently Only Chinese)",
391
+ lines=5,
392
+ interactive=False,
393
  )
394
 
395
  with gr.Row():
396
+ if_refine_text = gr.Checkbox(
397
+ label="Text Normalization",
398
+ value=True,
399
+ scale=1,
400
+ )
401
+
402
+ if_load_asr_model = gr.Checkbox(
403
+ label="Load / Unload ASR model for auto-reranking",
404
+ value=False,
405
+ scale=3,
406
+ )
407
+
408
+ with gr.Row():
409
+ with gr.Tab(label="Advanced Config"):
410
  chunk_length = gr.Slider(
411
+ label="Iterative Prompt Length, 0 means off",
412
  minimum=0,
413
+ maximum=500,
414
+ value=100,
415
  step=8,
416
  )
417
 
418
  max_new_tokens = gr.Slider(
419
+ label="Maximum tokens per batch, 0 means no limit",
420
+ minimum=0,
421
+ maximum=2048,
422
+ value=1024, # 0 means no limit
423
  step=8,
424
  )
425
 
426
  top_p = gr.Slider(
427
+ label="Top-P",
428
+ minimum=0.6,
429
+ maximum=0.9,
430
+ value=0.7,
431
+ step=0.01,
432
  )
433
 
434
  repetition_penalty = gr.Slider(
435
  label="Repetition Penalty",
436
+ minimum=1,
437
+ maximum=1.5,
438
+ value=1.2,
439
  step=0.01,
440
  )
441
 
442
  temperature = gr.Slider(
443
  label="Temperature",
444
+ minimum=0.6,
445
+ maximum=0.9,
446
  value=0.7,
447
  step=0.01,
448
  )
449
 
450
+ with gr.Tab(label="Reference Audio"):
 
 
 
 
 
 
451
  gr.Markdown(
452
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
453
  )
454
 
455
  enable_reference_audio = gr.Checkbox(
456
+ label="Enable Reference Audio",
457
  )
458
  reference_audio = gr.Audio(
459
+ label="Reference Audio",
460
  type="filepath",
461
  )
462
+ with gr.Row():
463
+ if_auto_label = gr.Checkbox(
464
+ label="Auto Labeling",
465
+ min_width=100,
466
+ scale=0,
467
+ value=False,
468
+ )
469
+ reference_text = gr.Textbox(
470
+ label="Reference Text",
471
+ lines=1,
472
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
473
+ value="",
474
+ )
475
+ with gr.Tab(label="Batch Inference"):
476
+ batch_infer_num = gr.Slider(
477
+ label="Batch infer nums",
478
+ minimum=1,
479
+ maximum=n_audios,
480
+ step=1,
481
+ value=1,
482
  )
483
 
484
  with gr.Column(scale=3):
485
+ for _ in range(n_audios):
486
+ with gr.Row():
487
+ error = gr.HTML(
488
+ label="Error Message",
489
+ visible=True if _ == 0 else False,
490
+ )
491
+ global_error_list.append(error)
492
+ with gr.Row():
493
+ audio = gr.Audio(
494
+ label="Generated Audio",
495
+ type="numpy",
496
+ interactive=False,
497
+ visible=True if _ == 0 else False,
498
+ )
499
+ global_audio_list.append(audio)
500
 
501
+ with gr.Row():
502
+ stream_audio = gr.Audio(
503
+ label="Streaming Audio",
504
+ streaming=True,
505
+ autoplay=True,
506
+ interactive=False,
507
+ show_download_button=True,
508
+ )
509
  with gr.Row():
510
  with gr.Column(scale=3):
511
  generate = gr.Button(
512
+ value="\U0001F3A7 " + "Generate", variant="primary"
513
+ )
514
+ generate_stream = gr.Button(
515
+ value="\U0001F3A7 " + "Streaming Generate",
516
+ variant="primary",
517
  )
518
 
519
+ text.input(
520
+ fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
521
+ )
522
+
523
+ if_load_asr_model.change(
524
+ fn=change_if_load_asr_model,
525
+ inputs=[if_load_asr_model],
526
+ outputs=[if_load_asr_model],
527
+ )
528
+
529
+ if_auto_label.change(
530
+ fn=lambda: gr.Textbox(value=""),
531
+ inputs=[],
532
+ outputs=[reference_text],
533
+ ).then(
534
+ fn=change_if_auto_label,
535
+ inputs=[
536
+ if_load_asr_model,
537
+ if_auto_label,
538
+ enable_reference_audio,
539
+ reference_audio,
540
+ reference_text,
541
+ ],
542
+ outputs=[reference_text],
543
+ )
544
+
545
  # # Submit
546
  generate.click(
547
+ inference_wrapper,
548
  [
549
+ refined_text,
550
  enable_reference_audio,
551
  reference_audio,
552
  reference_text,
 
555
  top_p,
556
  repetition_penalty,
557
  temperature,
558
+ batch_infer_num,
559
+ if_load_asr_model,
560
  ],
561
+ [stream_audio, *global_audio_list, *global_error_list],
562
  concurrency_limit=1,
563
  )
564
 
565
+ generate_stream.click(
566
+ inference_stream,
567
+ [
568
+ refined_text,
569
+ enable_reference_audio,
570
+ reference_audio,
571
+ reference_text,
572
+ max_new_tokens,
573
+ chunk_length,
574
+ top_p,
575
+ repetition_penalty,
576
+ temperature,
577
+ ],
578
+ [stream_audio, global_audio_list[0], global_error_list[0]],
579
+ concurrency_limit=10,
580
+ )
581
  return app
582
 
583
 
 
586
  parser.add_argument(
587
  "--llama-checkpoint-path",
588
  type=Path,
589
+ default="checkpoints/fish-speech-1.2-sft",
590
  )
591
  parser.add_argument(
592
+ "--decoder-checkpoint-path",
 
 
 
593
  type=Path,
594
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
595
  )
596
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
 
597
  parser.add_argument("--device", type=str, default="cuda")
598
  parser.add_argument("--half", action="store_true")
 
599
  parser.add_argument("--compile", action="store_true")
600
  parser.add_argument("--max-gradio-length", type=int, default=0)
601
+ parser.add_argument("--theme", type=str, default="light")
602
 
603
  return parser.parse_args()
604
 
605
 
606
  if __name__ == "__main__":
607
  args = parse_args()
 
608
  args.precision = torch.half if args.half else torch.bfloat16
 
 
 
 
 
 
 
609
 
610
  logger.info("Loading Llama model...")
611
  llama_queue = launch_thread_safe_queue(
 
612
  checkpoint_path=args.llama_checkpoint_path,
613
  device=args.device,
614
  precision=args.precision,
 
615
  compile=args.compile,
616
  )
 
617
  logger.info("Llama model loaded, loading VQ-GAN model...")
618
 
619
+ decoder_model = load_decoder_model(
620
+ config_name=args.decoder_config_name,
621
+ checkpoint_path=args.decoder_checkpoint_path,
622
  device=args.device,
623
  )
624
 
625
+ logger.info("Decoder model loaded, warming up...")
626
 
627
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
628
+ list(
629
+ inference(
630
+ text="Hello, world!",
631
+ enable_reference_audio=False,
632
+ reference_audio=None,
633
+ reference_text="",
634
+ max_new_tokens=0,
635
+ chunk_length=100,
636
+ top_p=0.7,
637
+ repetition_penalty=1.2,
638
+ temperature=0.7,
639
+ )
640
  )
641
 
642
  logger.info("Warming up done, launching the web UI...")
643
 
644
  app = build_app()
645
+ app.launch(show_api=True)
fish_speech/configs/base.yaml CHANGED
@@ -17,6 +17,7 @@ trainer:
17
  devices: auto
18
  strategy:
19
  _target_: lightning.pytorch.strategies.DDPStrategy
 
20
 
21
  precision: bf16-mixed
22
 
 
17
  devices: auto
18
  strategy:
19
  _target_: lightning.pytorch.strategies.DDPStrategy
20
+ process_group_backend: nccl # This should be override when training on windows
21
 
22
  precision: bf16-mixed
23
 
fish_speech/configs/firefly_gan_vq.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
2
+ spec_transform:
3
+ _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
4
+ sample_rate: 44100
5
+ n_mels: 160
6
+ n_fft: 2048
7
+ hop_length: 512
8
+ win_length: 2048
9
+ backbone:
10
+ _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
11
+ input_channels: 160
12
+ depths: [3, 3, 9, 3]
13
+ dims: [128, 256, 384, 512]
14
+ drop_path_rate: 0.2
15
+ kernel_size: 7
16
+ head:
17
+ _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
18
+ hop_length: 512
19
+ upsample_rates: [8, 8, 2, 2, 2] # aka. strides
20
+ upsample_kernel_sizes: [16, 16, 4, 4, 4]
21
+ resblock_kernel_sizes: [3, 7, 11]
22
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
23
+ num_mels: 512
24
+ upsample_initial_channel: 512
25
+ use_template: false
26
+ pre_conv_kernel_size: 13
27
+ post_conv_kernel_size: 13
28
+ quantizer:
29
+ _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
30
+ input_dim: 512
31
+ n_groups: 4
32
+ n_codebooks: 1
33
+ levels: [8, 5, 5, 5]
34
+ downsample_factor: [2]
fish_speech/configs/lora/r_8_alpha_16.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ _target_: fish_speech.models.text2semantic.lora.LoraConfig
2
+ r: 8
3
+ lora_alpha: 16
4
+ lora_dropout: 0.01
fish_speech/configs/text2semantic_finetune.yaml CHANGED
@@ -1,18 +1,16 @@
1
  defaults:
2
  - base
3
- - [email protected]: dual_ar_2_codebook_small
4
  - _self_
5
 
6
  project: text2semantic_finetune_dual_ar
7
- max_length: 2048
8
- ckpt_path: checkpoints/text2semantic-medium-v1-2k.pth
9
- resume_weights_only: true
10
 
11
  # Lightning Trainer
12
  trainer:
13
  accumulate_grad_batches: 1
14
  gradient_clip_val: 1.0
15
- gradient_clip_algorithm: 'norm'
16
  max_steps: 1000
17
  precision: bf16-true
18
  limit_val_batches: 10
@@ -21,29 +19,31 @@ trainer:
21
  # Dataset Configuration
22
  tokenizer:
23
  _target_: transformers.AutoTokenizer.from_pretrained
24
- pretrained_model_name_or_path: fishaudio/fish-speech-1
25
 
26
  # Dataset Configuration
27
  train_dataset:
28
- _target_: fish_speech.datasets.text.AutoAugTextDataset
29
  proto_files:
30
  - data/protos
31
  tokenizer: ${tokenizer}
 
32
  max_length: ${max_length}
33
- num_codebooks: ${model.model.config.num_codebooks}
34
  use_speaker: false
 
35
 
36
  val_dataset:
37
- _target_: fish_speech.datasets.text.AutoAugTextDataset
38
  proto_files:
39
  - data/protos
40
  tokenizer: ${tokenizer}
 
41
  max_length: ${max_length}
42
- num_codebooks: ${model.model.config.num_codebooks}
43
  use_speaker: false
 
44
 
45
  data:
46
- _target_: fish_speech.datasets.text.TextDataModule
47
  train_dataset: ${train_dataset}
48
  val_dataset: ${val_dataset}
49
  num_workers: 4
@@ -53,13 +53,18 @@ data:
53
 
54
  # Model Configuration
55
  model:
56
- _target_: fish_speech.models.text2semantic.TextToSemantic
57
- model: {}
 
 
 
 
 
58
 
59
  optimizer:
60
  _target_: torch.optim.AdamW
61
  _partial_: true
62
- lr: 1e-5
63
  weight_decay: 0
64
  betas: [0.9, 0.95]
65
  eps: 1e-5
@@ -68,12 +73,11 @@ model:
68
  _target_: torch.optim.lr_scheduler.LambdaLR
69
  _partial_: true
70
  lr_lambda:
71
- _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
72
  _partial_: true
73
- num_warmup_steps: 100
74
- num_training_steps: ${trainer.max_steps}
75
 
76
  # Callbacks
77
  callbacks:
78
  model_checkpoint:
79
- every_n_train_steps: 100
 
1
  defaults:
2
  - base
 
3
  - _self_
4
 
5
  project: text2semantic_finetune_dual_ar
6
+ max_length: 4096
7
+ pretrained_ckpt_path: checkpoints/fish-speech-1.2-sft
 
8
 
9
  # Lightning Trainer
10
  trainer:
11
  accumulate_grad_batches: 1
12
  gradient_clip_val: 1.0
13
+ gradient_clip_algorithm: "norm"
14
  max_steps: 1000
15
  precision: bf16-true
16
  limit_val_batches: 10
 
19
  # Dataset Configuration
20
  tokenizer:
21
  _target_: transformers.AutoTokenizer.from_pretrained
22
+ pretrained_model_name_or_path: ${pretrained_ckpt_path}
23
 
24
  # Dataset Configuration
25
  train_dataset:
26
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
27
  proto_files:
28
  - data/protos
29
  tokenizer: ${tokenizer}
30
+ causal: true
31
  max_length: ${max_length}
 
32
  use_speaker: false
33
+ interactive_prob: 0.7
34
 
35
  val_dataset:
36
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
37
  proto_files:
38
  - data/protos
39
  tokenizer: ${tokenizer}
40
+ causal: true
41
  max_length: ${max_length}
 
42
  use_speaker: false
43
+ interactive_prob: 0.7
44
 
45
  data:
46
+ _target_: fish_speech.datasets.semantic.SemanticDataModule
47
  train_dataset: ${train_dataset}
48
  val_dataset: ${val_dataset}
49
  num_workers: 4
 
53
 
54
  # Model Configuration
55
  model:
56
+ _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
57
+ model:
58
+ _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
59
+ path: ${pretrained_ckpt_path}
60
+ load_weights: true
61
+ max_length: ${max_length}
62
+ lora_config: null
63
 
64
  optimizer:
65
  _target_: torch.optim.AdamW
66
  _partial_: true
67
+ lr: 1e-4
68
  weight_decay: 0
69
  betas: [0.9, 0.95]
70
  eps: 1e-5
 
73
  _target_: torch.optim.lr_scheduler.LambdaLR
74
  _partial_: true
75
  lr_lambda:
76
+ _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
77
  _partial_: true
78
+ num_warmup_steps: 10
 
79
 
80
  # Callbacks
81
  callbacks:
82
  model_checkpoint:
83
+ every_n_train_steps: ${trainer.val_check_interval}
fish_speech/datasets/concat_repeat.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bisect
2
+ import random
3
+ from typing import Iterable
4
+
5
+ from torch.utils.data import Dataset, IterableDataset
6
+
7
+
8
+ class ConcatRepeatDataset(Dataset):
9
+ datasets: list[Dataset]
10
+ cumulative_sizes: list[int]
11
+ repeats: list[int]
12
+
13
+ @staticmethod
14
+ def cumsum(sequence, repeats):
15
+ r, s = [], 0
16
+ for dataset, repeat in zip(sequence, repeats):
17
+ l = len(dataset) * repeat
18
+ r.append(l + s)
19
+ s += l
20
+ return r
21
+
22
+ def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
23
+ super().__init__()
24
+
25
+ self.datasets = list(datasets)
26
+ self.repeats = repeats
27
+
28
+ assert len(self.datasets) > 0, "datasets should not be an empty iterable"
29
+ assert len(self.datasets) == len(
30
+ repeats
31
+ ), "datasets and repeats should have the same length"
32
+
33
+ for d in self.datasets:
34
+ assert not isinstance(
35
+ d, IterableDataset
36
+ ), "ConcatRepeatDataset does not support IterableDataset"
37
+
38
+ self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
39
+
40
+ def __len__(self):
41
+ return self.cumulative_sizes[-1]
42
+
43
+ def __getitem__(self, idx):
44
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
45
+
46
+ if dataset_idx == 0:
47
+ sample_idx = idx
48
+ else:
49
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
50
+
51
+ dataset = self.datasets[dataset_idx]
52
+
53
+ return dataset[sample_idx % len(dataset)]
fish_speech/datasets/semantic.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass
3
+ from itertools import chain
4
+ from pathlib import Path
5
+ from random import Random
6
+ from typing import Optional, Union
7
+
8
+ import numpy as np
9
+ import pyarrow.parquet as pq
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from datasets.download.streaming_download_manager import xopen
13
+ from huggingface_hub import HfApi
14
+ from lightning import LightningDataModule
15
+ from torch.distributed import get_rank, get_world_size, is_initialized
16
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
17
+ from transformers import AutoTokenizer
18
+
19
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
20
+ from fish_speech.datasets.protos.text_data_pb2 import SampledData
21
+ from fish_speech.datasets.protos.text_data_stream import read_pb_stream
22
+ from fish_speech.text.clean import clean_text
23
+ from fish_speech.utils import RankedLogger
24
+ from fish_speech.utils.braceexpand import braceexpand
25
+
26
+ log = RankedLogger(__name__, rank_zero_only=True)
27
+
28
+
29
+ def split_by_rank_worker(files):
30
+ # We need to know the total number of devices
31
+ # to split the data properly
32
+
33
+ total_devices = 1
34
+ if is_initialized():
35
+ total_devices = get_world_size()
36
+
37
+ worker_info = get_worker_info()
38
+ if worker_info is not None:
39
+ total_devices *= worker_info.num_workers
40
+
41
+ if len(files) < total_devices:
42
+ # Repeat the files N times to match the number of devices
43
+ files = files * (total_devices // len(files) + 1)
44
+
45
+ # DDP
46
+ if is_initialized():
47
+ files = files[get_rank() :: get_world_size()]
48
+
49
+ # Split by worker
50
+ if worker_info is not None:
51
+ files = files[worker_info.id :: worker_info.num_workers]
52
+
53
+ return files
54
+
55
+
56
+ class AutoTextSemanticInstructionDataset(IterableDataset):
57
+ """
58
+ Auto Augment Dataset by Speaker
59
+
60
+ 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
61
+ 2. Automatically normalize the text
62
+
63
+ For interactive mode, we use the following format (multiple sequences):
64
+ <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
65
+
66
+ For non-interactive mode, we use the following format (one long sequence):
67
+ <s> [INST] text [/INST] ... </s>
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ proto_files: list[str],
73
+ seed: int = 42,
74
+ interactive_prob: float = 0.5,
75
+ max_length: int = 1024,
76
+ tokenizer: AutoTokenizer = None,
77
+ use_speaker: bool | float = True,
78
+ causal: bool = True,
79
+ num_codebooks: Optional[int] = None,
80
+ skip_text_prob: float = 0.0,
81
+ ):
82
+ """
83
+ Args:
84
+ proto_files: proto buf files if using local data
85
+ seed: random seed
86
+ interactive_prob: probability to use interactive mode
87
+ max_length: max length of the text
88
+ tokenizer: tokenizer
89
+ use_speaker: include speaker information in the prompt
90
+ causal: use causal sampling when using local data, disable will lead to random sampling
91
+ num_codebooks: number of codebooks, if None, it will be automatically detected
92
+ skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
93
+ """
94
+
95
+ super().__init__()
96
+
97
+ assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
98
+
99
+ self.seed = seed
100
+ self.max_length = max_length
101
+ self.tokenizer = tokenizer
102
+ self.interactive_prob = interactive_prob
103
+ self.use_speaker = use_speaker
104
+ self.proto_files = proto_files
105
+ self.causal = causal
106
+ self.num_codebooks = num_codebooks
107
+ self.skip_text_prob = skip_text_prob
108
+
109
+ self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
110
+ self.groups = None
111
+
112
+ def init_mock_data_server(self):
113
+ if self.groups is not None:
114
+ return
115
+
116
+ # Expand the proto files
117
+ expanded_proto_files = []
118
+ for filename in self.proto_files:
119
+ for i in braceexpand(filename):
120
+ i = Path(i)
121
+ if i.is_file():
122
+ expanded_proto_files.append(i)
123
+ elif i.is_dir():
124
+ expanded_proto_files.extend(i.rglob("*.proto"))
125
+ expanded_proto_files.extend(i.rglob("*.protos"))
126
+ else:
127
+ raise ValueError(f"{i} is not a file or directory")
128
+
129
+ expanded_proto_files = sorted(expanded_proto_files)
130
+ Random(self.seed).shuffle(expanded_proto_files)
131
+
132
+ self.groups = []
133
+ shard_proto_files = split_by_rank_worker(expanded_proto_files)
134
+ log.info(
135
+ f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
136
+ )
137
+
138
+ count = 0
139
+ for filename in shard_proto_files:
140
+ with open(filename, "rb") as f:
141
+ for text_data in read_pb_stream(f):
142
+ self.groups.append(text_data)
143
+ count += 1
144
+
145
+ log.info(f"Read total {count} groups of data")
146
+
147
+ # Shuffle the lines
148
+ Random(self.seed).shuffle(self.groups)
149
+ self.group_weights = [len(i.sentences) for i in self.groups]
150
+
151
+ def __iter__(self):
152
+ while True:
153
+ yield self.augment()
154
+
155
+ def tokenize_sentence(self, sentence: str):
156
+ sentence = clean_text(sentence)
157
+ tokens = self.tokenizer.encode(
158
+ f"{sentence}",
159
+ max_length=10**6,
160
+ add_special_tokens=False,
161
+ truncation=False,
162
+ )
163
+ return sentence, len(tokens)
164
+
165
+ def sample_data(self):
166
+ if self.groups is None:
167
+ self.init_mock_data_server()
168
+
169
+ # Shuffle unique lines, estimate that each sample is at least 20 tokens
170
+ num_samples = self.max_length // 20
171
+
172
+ # choice group based on their number of samples
173
+ group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
174
+
175
+ if self.causal:
176
+ # Sample in order
177
+ if num_samples >= len(group.sentences):
178
+ samples = group.sentences
179
+ else:
180
+ begin = random.randint(0, len(group.sentences) - num_samples)
181
+ samples = group.sentences[begin : begin + num_samples]
182
+ else:
183
+ samples = random.choices(
184
+ group.sentences, k=min(num_samples, len(group.sentences))
185
+ )
186
+
187
+ return SampledData(
188
+ source=group.source,
189
+ name=group.name,
190
+ samples=samples,
191
+ )
192
+
193
+ def augment(self):
194
+ final_text, final_semantic = [], []
195
+ response = self.sample_data()
196
+ if len(response.samples) == 0:
197
+ # Invalid group
198
+ return None
199
+
200
+ samples = list(response.samples)
201
+ idx = 0
202
+ use_interactive = random.random() < self.interactive_prob
203
+
204
+ if use_interactive is False:
205
+ # Random sample based on speaker using a truncated normal distribution
206
+ a = torch.tensor([0], dtype=torch.float32)
207
+ torch.nn.init.trunc_normal_(
208
+ a,
209
+ mean=self.max_length // 2,
210
+ std=self.max_length // 4,
211
+ a=10,
212
+ b=self.max_length,
213
+ )
214
+ remaining_tokens = a.long().item() - 4
215
+ else:
216
+ remaining_tokens = self.max_length
217
+
218
+ # Use speaker
219
+ if isinstance(self.use_speaker, float):
220
+ use_speaker = random.random() < self.use_speaker
221
+ else:
222
+ use_speaker = self.use_speaker
223
+
224
+ all_tokens, all_labels = [], []
225
+ while remaining_tokens > 0 and len(samples) > 0:
226
+ sentence = samples.pop(0)
227
+
228
+ text = random.choice(sentence.texts)
229
+ text, length = self.tokenize_sentence(text)
230
+ remaining_tokens -= length + len(sentence.semantics[0].values)
231
+
232
+ if use_interactive is False:
233
+ final_text.append(text)
234
+ final_semantic.append(sentence.semantics)
235
+ else:
236
+ # For interactive mode, we only apply speaker for the first sentence
237
+ # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
238
+ tokens, labels = self.pack_sentences(
239
+ sentences=[text],
240
+ semantics=[sentence.semantics],
241
+ speaker=response.name if use_speaker else None,
242
+ skip_text=random.random() < self.skip_text_prob,
243
+ )
244
+
245
+ all_tokens.append(tokens)
246
+ all_labels.append(labels)
247
+
248
+ idx += 1
249
+
250
+ if use_interactive is False:
251
+ tokens, labels = self.pack_sentences(
252
+ final_text,
253
+ semantics=final_semantic,
254
+ speaker=response.name if use_speaker else None,
255
+ )
256
+ all_tokens.append(tokens)
257
+ all_labels.append(labels)
258
+
259
+ tokens = torch.cat(all_tokens, dim=1)
260
+ labels = torch.cat(all_labels, dim=1)
261
+
262
+ # Verify that the length is correct
263
+ assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
264
+
265
+ data = {"tokens": tokens, "labels": labels}
266
+
267
+ return data
268
+
269
+ def pack_sentences(
270
+ self,
271
+ sentences: list[str],
272
+ semantics: list,
273
+ speaker: Optional[str] = None,
274
+ skip_text: bool = False,
275
+ ):
276
+ if speaker is None:
277
+ speaker = "assistant"
278
+
279
+ cated_sentences = " ".join(sentences)
280
+ if skip_text:
281
+ cated_sentences = "<|skip_text|>"
282
+
283
+ final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
284
+ final_text = final_text + f"<|im_start|>{speaker}\n"
285
+
286
+ encoded = self.tokenizer.encode(
287
+ final_text,
288
+ add_special_tokens=False,
289
+ truncation=False,
290
+ max_length=10**6,
291
+ )
292
+ semantic_length = sum([len(i[0].values) for i in semantics])
293
+ prompt_length = len(encoded)
294
+ num_codebooks = (
295
+ len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
296
+ )
297
+
298
+ # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
299
+ tokens = (
300
+ encoded
301
+ + [self.semantic_token_id] * semantic_length
302
+ + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
303
+ )
304
+
305
+ # Codebook bos/padding: 0, eos: 1
306
+ codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
307
+ for segment in semantics:
308
+ for book_idx, book in zip(range(num_codebooks), segment):
309
+ for j in book.values:
310
+ codes[book_idx].append(int(j) + 1)
311
+
312
+ for book in codes:
313
+ book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
314
+
315
+ tokens = [tokens] + codes
316
+
317
+ tokens = torch.tensor(tokens, dtype=torch.long)
318
+ labels = tokens.clone()
319
+
320
+ if skip_text:
321
+ # If text is not provided, the sentence is used for condition only, all labels are -100
322
+ torch.fill_(labels, -100)
323
+ return tokens, labels
324
+
325
+ # Mask out the <s> tokens for semantic, predict semantic tokens only
326
+ # Since we don't mask out the input tokens, the language modeling still works
327
+ labels[1:, :prompt_length] = -100
328
+
329
+ tokens = tokens[:, :-1]
330
+ labels = labels[:, 1:]
331
+
332
+ # Verify the padding is correct, and the last token is eos
333
+ assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
334
+ assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
335
+
336
+ return tokens, labels
337
+
338
+
339
+ @dataclass
340
+ class TextDataCollator:
341
+ tokenizer: AutoTokenizer
342
+ max_length: int = 1024
343
+
344
+ def __call__(self, examples):
345
+ if "negative_tokens" in examples:
346
+ positive_examples = []
347
+ negative_examples = []
348
+
349
+ for i in examples:
350
+ positive_examples.append(
351
+ {
352
+ "tokens": i["tokens"],
353
+ "labels": i["labels"],
354
+ }
355
+ )
356
+ negative_examples.append(
357
+ {
358
+ "tokens": i["negative_tokens"],
359
+ "labels": i["negative_labels"],
360
+ }
361
+ )
362
+
363
+ examples = positive_examples + negative_examples
364
+
365
+ return self.batchify(examples)
366
+
367
+ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
368
+ tokens, attention_masks, labels = [], [], []
369
+
370
+ # Calculate the max length
371
+ max_tokens_length = 0
372
+ for example in examples:
373
+ max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
374
+ max_tokens_length = min(max_tokens_length, self.max_length)
375
+
376
+ for example in examples:
377
+ _tokens = example[tokens_key][:, :max_tokens_length]
378
+ _labels = example[labels_key][:, :max_tokens_length]
379
+ _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
380
+ tokens_length = _tokens.size(1)
381
+ _attention_mask[:tokens_length] = False
382
+
383
+ assert tokens_length == _labels.size(
384
+ 1
385
+ ), f"{tokens_length} != {_labels.size(1)}"
386
+
387
+ if tokens_length < max_tokens_length:
388
+ _tokens = F.pad(
389
+ _tokens,
390
+ (0, max_tokens_length - tokens_length),
391
+ value=self.tokenizer.eos_token_id,
392
+ )
393
+ _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
394
+ _labels = F.pad(
395
+ _labels, (0, max_tokens_length - _labels.size(1)), value=-100
396
+ )
397
+
398
+ tokens.append(_tokens)
399
+ attention_masks.append(_attention_mask)
400
+ labels.append(_labels)
401
+
402
+ tokens = torch.stack(tokens, dim=0)
403
+ attention_masks = torch.stack(attention_masks, dim=0)
404
+ labels = torch.stack(labels, dim=0)
405
+
406
+ return {
407
+ "inputs": tokens,
408
+ "attention_masks": attention_masks,
409
+ "labels": labels,
410
+ }
411
+
412
+
413
+ class InterleaveDataset(IterableDataset):
414
+ def __init__(
415
+ self,
416
+ datasets: list[IterableDataset],
417
+ probabilities: list[float],
418
+ seed: int = 42,
419
+ ):
420
+ super().__init__()
421
+
422
+ self.datasets = datasets
423
+ self.probabilities = probabilities
424
+ self.seed = seed
425
+
426
+ def __iter__(self):
427
+ rng = np.random.default_rng(self.seed)
428
+ dataset_iterators = [iter(dataset) for dataset in self.datasets]
429
+
430
+ while True:
431
+ # Random choice one
432
+ dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
433
+ dataset_iterator = dataset_iterators[dataset_idx]
434
+
435
+ try:
436
+ yield next(dataset_iterator)
437
+ except StopIteration:
438
+ # Exhausted, create a new iterator
439
+ dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
440
+ yield next(dataset_iterators[dataset_idx])
441
+
442
+
443
+ class SemanticDataModule(LightningDataModule):
444
+ def __init__(
445
+ self,
446
+ train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
447
+ val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
448
+ batch_size: int = 32,
449
+ tokenizer: AutoTokenizer = None,
450
+ max_length: int = 1024,
451
+ num_workers: int = 4,
452
+ ):
453
+ super().__init__()
454
+
455
+ self.train_dataset = train_dataset
456
+ self.val_dataset = val_dataset
457
+ self.batch_size = batch_size
458
+ self.tokenizer = tokenizer
459
+ self.max_length = max_length
460
+ self.num_workers = num_workers
461
+
462
+ def train_dataloader(self):
463
+ return DataLoader(
464
+ self.train_dataset,
465
+ batch_size=self.batch_size,
466
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
467
+ num_workers=self.num_workers,
468
+ persistent_workers=True,
469
+ )
470
+
471
+ def val_dataloader(self):
472
+ return DataLoader(
473
+ self.val_dataset,
474
+ batch_size=self.batch_size,
475
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
476
+ num_workers=self.num_workers,
477
+ persistent_workers=True,
478
+ )
479
+
480
+
481
+ if __name__ == "__main__":
482
+ from tqdm import tqdm
483
+
484
+ ds = AutoTextSemanticInstructionDataset(
485
+ ["data/protos"],
486
+ tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
487
+ use_speaker=False,
488
+ interactive_prob=1.0,
489
+ skip_text_prob=0.5,
490
+ )
491
+
492
+ for i in ds:
493
+ print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
494
+ # i["labels"][0][i["labels"][0] == -100] = 0
495
+ # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
496
+ break
fish_speech/datasets/vqgan.py CHANGED
@@ -28,7 +28,7 @@ class VQGANDataset(Dataset):
28
 
29
  self.files = [
30
  root / line.strip()
31
- for line in filelist.read_text().splitlines()
32
  if line.strip()
33
  ]
34
  self.sample_rate = sample_rate
@@ -120,6 +120,7 @@ class VQGANDataModule(LightningDataModule):
120
  collate_fn=VQGANCollator(),
121
  num_workers=self.num_workers,
122
  shuffle=True,
 
123
  )
124
 
125
  def val_dataloader(self):
@@ -128,6 +129,7 @@ class VQGANDataModule(LightningDataModule):
128
  batch_size=self.val_batch_size,
129
  collate_fn=VQGANCollator(),
130
  num_workers=self.num_workers,
 
131
  )
132
 
133
 
 
28
 
29
  self.files = [
30
  root / line.strip()
31
+ for line in filelist.read_text(encoding="utf-8").splitlines()
32
  if line.strip()
33
  ]
34
  self.sample_rate = sample_rate
 
120
  collate_fn=VQGANCollator(),
121
  num_workers=self.num_workers,
122
  shuffle=True,
123
+ persistent_workers=True,
124
  )
125
 
126
  def val_dataloader(self):
 
129
  batch_size=self.val_batch_size,
130
  collate_fn=VQGANCollator(),
131
  num_workers=self.num_workers,
132
+ persistent_workers=True,
133
  )
134
 
135
 
fish_speech/models/text2semantic/__init__.py CHANGED
@@ -1,3 +0,0 @@
1
- from .lit_module import TextToSemantic
2
-
3
- __all__ = ["TextToSemantic"]
 
 
 
 
fish_speech/models/text2semantic/lit_module.py CHANGED
@@ -1,110 +1,40 @@
1
- from dataclasses import dataclass
2
  from typing import Any, Optional
3
 
4
  import lightning as L
5
- import loralib as lora
6
  import torch
7
  import torch.nn.functional as F
8
  from lightning.pytorch.utilities.types import OptimizerLRScheduler
9
 
10
  import fish_speech.utils as utils
 
11
  from fish_speech.models.text2semantic.llama import NaiveTransformer
12
 
13
  log = utils.RankedLogger(__name__, rank_zero_only=True)
14
 
15
 
16
- @dataclass
17
- class LoraConfig:
18
- r: int
19
- lora_alpha: float
20
- lora_dropout: float = 0.0
21
-
22
-
23
  class TextToSemantic(L.LightningModule):
24
  def __init__(
25
  self,
26
  model: NaiveTransformer,
27
  optimizer: Any,
28
  lr_scheduler: Any,
29
- lora_config: Optional[LoraConfig] = None,
30
- save_lora_only: bool = False,
31
- use_dpo: bool = False,
32
- dpo_beta: float = 0.2,
33
  ):
34
  super().__init__()
35
 
36
  self.model = model
37
  self.optimizer_builder = optimizer
38
  self.lr_scheduler_builder = lr_scheduler
39
- self.lora_config = lora_config
40
- self.save_lora_only = save_lora_only
41
- self.use_dpo = use_dpo # We don't support reference model yet
42
- self.dpo_beta = dpo_beta
43
-
44
- if self.lora_config is not None:
45
- self.setup_lora()
46
-
47
- def setup_lora(self):
48
- # Replace the embedding layer with a LoRA layer
49
- self.model.embeddings = lora.Embedding(
50
- num_embeddings=self.model.embeddings.num_embeddings,
51
- embedding_dim=self.model.embeddings.embedding_dim,
52
- padding_idx=self.model.embeddings.padding_idx,
53
- r=self.lora_config.r,
54
- lora_alpha=self.lora_config.lora_alpha,
55
- )
56
-
57
- # Replace output layer with a LoRA layer
58
- linears = [(self.model, "output")]
59
-
60
- # Replace all linear layers with LoRA layers
61
- for layer in self.model.layers:
62
- linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
63
- linears.extend(
64
- [
65
- (layer.feed_forward, "w1"),
66
- (layer.feed_forward, "w2"),
67
- (layer.feed_forward, "w3"),
68
- ]
69
- )
70
-
71
- if hasattr(self.model, "fast_layers"):
72
- # Dual-AR model
73
- linears.extend([(self.model, "fast_output")])
74
-
75
- for layer in self.model.fast_layers:
76
- linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
77
- linears.extend(
78
- [
79
- (layer.feed_forward, "w1"),
80
- (layer.feed_forward, "w2"),
81
- (layer.feed_forward, "w3"),
82
- ]
83
- )
84
-
85
- for module, layer in linears:
86
- updated_linear = lora.Linear(
87
- in_features=getattr(module, layer).in_features,
88
- out_features=getattr(module, layer).out_features,
89
- bias=getattr(module, layer).bias,
90
- r=self.lora_config.r,
91
- lora_alpha=self.lora_config.lora_alpha,
92
- lora_dropout=self.lora_config.lora_dropout,
93
- )
94
- setattr(module, layer, updated_linear)
95
-
96
- # Mark only the LoRA layers as trainable
97
- lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
98
 
99
  def forward(self, x):
100
  return self.model(x)
101
 
102
  def on_save_checkpoint(self, checkpoint):
103
- if self.lora_config is None or self.save_lora_only is False:
104
- return
105
-
106
  # Save only LoRA parameters
107
  state_dict = checkpoint["state_dict"]
 
 
 
 
108
  for name in list(state_dict.keys()):
109
  if "lora" not in name:
110
  state_dict.pop(name)
@@ -178,6 +108,11 @@ class TextToSemantic(L.LightningModule):
178
  def _step(self, batch, batch_idx, stage: str):
179
  is_train = stage == "train"
180
 
 
 
 
 
 
181
  # Do positive and negative samples in the same batch to speed up training
182
  labels = batch["labels"]
183
  outputs = self.model(
@@ -187,92 +122,22 @@ class TextToSemantic(L.LightningModule):
187
  token_logits = outputs.token_logits
188
  codebook_logits = outputs.codebook_logits
189
 
190
- if self.use_dpo:
191
- # Firtst half is positive, second half is negative
192
- token_logits, negative_token_logits = token_logits.chunk(2)
193
- codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
194
- labels, negative_labels = labels.chunk(2)
195
-
196
  # Generate labels
197
  base_loss = F.cross_entropy(
198
- token_logits.reshape(-1, token_logits.size(-1)),
199
  labels[:, 0].reshape(-1),
200
  ignore_index=-100,
201
  )
202
 
203
  codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
204
  semantic_loss = F.cross_entropy(
205
- codebook_logits.reshape(-1, codebook_logits.size(-1)),
206
  codebook_labels.reshape(-1),
207
  ignore_index=-100,
208
  )
209
 
210
  loss = base_loss + semantic_loss
211
 
212
- # If we use dpo
213
- if self.use_dpo:
214
- negative_codebook_labels = negative_labels[
215
- :, 1 : 1 + self.model.config.num_codebooks
216
- ].mT
217
-
218
- positive_codebook_logps = self.get_batch_logps(
219
- codebook_logits, codebook_labels
220
- )
221
- negative_codebook_logps = self.get_batch_logps(
222
- negative_codebook_logits, negative_codebook_labels
223
- )
224
-
225
- # TODO: implement the reference model, avoid screwing up the gradients
226
- dpo_loss = -F.logsigmoid(
227
- (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
228
- ).mean()
229
-
230
- chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
231
- rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
232
- reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
233
- chosen_rewards, rejected_rewards = (
234
- chosen_rewards.mean(),
235
- rejected_rewards.mean(),
236
- )
237
-
238
- loss = loss + dpo_loss
239
-
240
- self.log(
241
- f"{stage}/dpo_loss",
242
- dpo_loss,
243
- on_step=is_train,
244
- on_epoch=not is_train,
245
- prog_bar=False,
246
- logger=True,
247
- )
248
-
249
- self.log(
250
- f"{stage}/chosen_rewards",
251
- chosen_rewards,
252
- on_step=is_train,
253
- on_epoch=not is_train,
254
- prog_bar=False,
255
- logger=True,
256
- )
257
-
258
- self.log(
259
- f"{stage}/rejected_rewards",
260
- rejected_rewards,
261
- on_step=is_train,
262
- on_epoch=not is_train,
263
- prog_bar=False,
264
- logger=True,
265
- )
266
-
267
- self.log(
268
- f"{stage}/reward_accuracy",
269
- reward_accuracy,
270
- on_step=is_train,
271
- on_epoch=not is_train,
272
- prog_bar=False,
273
- logger=True,
274
- )
275
-
276
  self.log(
277
  f"{stage}/loss",
278
  loss,
@@ -280,6 +145,7 @@ class TextToSemantic(L.LightningModule):
280
  on_epoch=not is_train,
281
  prog_bar=True,
282
  logger=True,
 
283
  )
284
 
285
  self.log(
@@ -289,6 +155,7 @@ class TextToSemantic(L.LightningModule):
289
  on_epoch=not is_train,
290
  prog_bar=False,
291
  logger=True,
 
292
  )
293
 
294
  self.log(
@@ -298,6 +165,7 @@ class TextToSemantic(L.LightningModule):
298
  on_epoch=not is_train,
299
  prog_bar=False,
300
  logger=True,
 
301
  )
302
 
303
  # Top-5 accuracy
@@ -309,31 +177,21 @@ class TextToSemantic(L.LightningModule):
309
  on_epoch=not is_train,
310
  prog_bar=True,
311
  logger=True,
 
312
  )
313
 
314
- if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
315
- accuracy = self.get_accuracy(
316
- codebook_logits[:, :, : self.model.config.num_in_codebooks],
317
- codebook_labels[:, :, : self.model.config.num_in_codebooks],
318
- )
319
-
320
- self.log(
321
- f"{stage}/top_5_accuracy_in",
322
- accuracy,
323
- on_step=is_train,
324
- on_epoch=not is_train,
325
- prog_bar=True,
326
- logger=True,
327
- )
328
-
329
  return loss
330
 
331
  def get_accuracy(self, logits, labels):
 
 
 
 
332
  _, indices = logits.topk(5, dim=-1)
333
  correct = indices.eq(labels.unsqueeze(-1))
334
- correct[labels == -100] = 0
335
  correct = correct.sum()
336
- accuracy = correct / (labels != -100).sum()
337
 
338
  return accuracy
339
 
 
 
1
  from typing import Any, Optional
2
 
3
  import lightning as L
 
4
  import torch
5
  import torch.nn.functional as F
6
  from lightning.pytorch.utilities.types import OptimizerLRScheduler
7
 
8
  import fish_speech.utils as utils
9
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
10
  from fish_speech.models.text2semantic.llama import NaiveTransformer
11
 
12
  log = utils.RankedLogger(__name__, rank_zero_only=True)
13
 
14
 
 
 
 
 
 
 
 
15
  class TextToSemantic(L.LightningModule):
16
  def __init__(
17
  self,
18
  model: NaiveTransformer,
19
  optimizer: Any,
20
  lr_scheduler: Any,
 
 
 
 
21
  ):
22
  super().__init__()
23
 
24
  self.model = model
25
  self.optimizer_builder = optimizer
26
  self.lr_scheduler_builder = lr_scheduler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def forward(self, x):
29
  return self.model(x)
30
 
31
  def on_save_checkpoint(self, checkpoint):
 
 
 
32
  # Save only LoRA parameters
33
  state_dict = checkpoint["state_dict"]
34
+ use_lora = any("lora" in name for name in state_dict.keys())
35
+ if not use_lora:
36
+ return
37
+
38
  for name in list(state_dict.keys()):
39
  if "lora" not in name:
40
  state_dict.pop(name)
 
108
  def _step(self, batch, batch_idx, stage: str):
109
  is_train = stage == "train"
110
 
111
+ if is_train:
112
+ # Key part to make lora work
113
+ # Otherwise the parameters are merged, which lead to incorrect gradients
114
+ self.model.train()
115
+
116
  # Do positive and negative samples in the same batch to speed up training
117
  labels = batch["labels"]
118
  outputs = self.model(
 
122
  token_logits = outputs.token_logits
123
  codebook_logits = outputs.codebook_logits
124
 
 
 
 
 
 
 
125
  # Generate labels
126
  base_loss = F.cross_entropy(
127
+ token_logits.view(-1, token_logits.size(-1)),
128
  labels[:, 0].reshape(-1),
129
  ignore_index=-100,
130
  )
131
 
132
  codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
133
  semantic_loss = F.cross_entropy(
134
+ codebook_logits.view(-1, codebook_logits.size(-1)),
135
  codebook_labels.reshape(-1),
136
  ignore_index=-100,
137
  )
138
 
139
  loss = base_loss + semantic_loss
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  self.log(
142
  f"{stage}/loss",
143
  loss,
 
145
  on_epoch=not is_train,
146
  prog_bar=True,
147
  logger=True,
148
+ sync_dist=not is_train,
149
  )
150
 
151
  self.log(
 
155
  on_epoch=not is_train,
156
  prog_bar=False,
157
  logger=True,
158
+ sync_dist=not is_train,
159
  )
160
 
161
  self.log(
 
165
  on_epoch=not is_train,
166
  prog_bar=False,
167
  logger=True,
168
+ sync_dist=not is_train,
169
  )
170
 
171
  # Top-5 accuracy
 
177
  on_epoch=not is_train,
178
  prog_bar=True,
179
  logger=True,
180
+ sync_dist=not is_train,
181
  )
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  return loss
184
 
185
  def get_accuracy(self, logits, labels):
186
+ mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
187
+ if mask.sum() == 0:
188
+ return torch.tensor(0.0, device=logits.device)
189
+
190
  _, indices = logits.topk(5, dim=-1)
191
  correct = indices.eq(labels.unsqueeze(-1))
192
+ correct[~mask] = 0
193
  correct = correct.sum()
194
+ accuracy = correct / mask.sum()
195
 
196
  return accuracy
197
 
fish_speech/models/text2semantic/llama.py CHANGED
@@ -1,13 +1,25 @@
 
1
  import math
2
  from dataclasses import dataclass
 
3
  from typing import Optional
4
 
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
 
8
  from torch import Tensor
9
  from torch.nn import functional as F
 
10
  from torch.utils.checkpoint import checkpoint
 
 
 
 
 
 
 
 
11
 
12
 
13
  def find_multiple(n: int, k: int) -> int:
@@ -18,6 +30,8 @@ def find_multiple(n: int, k: int) -> int:
18
 
19
  @dataclass
20
  class BaseModelArgs:
 
 
21
  vocab_size: int = 32000
22
  n_layer: int = 32
23
  n_head: int = 32
@@ -29,16 +43,19 @@ class BaseModelArgs:
29
  norm_eps: float = 1e-5
30
  max_seq_len: int = 2048
31
  dropout: float = 0.0
 
 
32
 
33
  # Codebook configs
34
  codebook_size: int = 160
35
  num_codebooks: int = 4
36
- num_in_codebooks: Optional[int] = None
37
- codebook_padding_idx: int = 0
38
 
39
  # Gradient checkpointing
40
  use_gradient_checkpointing: bool = True
41
 
 
 
 
42
  def __post_init__(self):
43
  if self.n_local_heads == -1:
44
  self.n_local_heads = self.n_head
@@ -46,18 +63,41 @@ class BaseModelArgs:
46
  hidden_dim = 4 * self.dim
47
  n_hidden = int(2 * hidden_dim / 3)
48
  self.intermediate_size = find_multiple(n_hidden, 256)
49
- if self.num_in_codebooks is None:
50
- self.num_in_codebooks = self.num_codebooks
51
  self.head_dim = self.dim // self.n_head
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  @dataclass
55
  class NaiveModelArgs(BaseModelArgs):
56
- pass
57
 
58
 
59
  @dataclass
60
  class DualARModelArgs(BaseModelArgs):
 
61
  n_fast_layer: int = 4
62
 
63
 
@@ -95,24 +135,35 @@ class BaseTransformerForwardResult:
95
 
96
 
97
  class BaseTransformer(nn.Module):
98
- def __init__(self, config: BaseModelArgs) -> None:
 
 
99
  super().__init__()
100
  self.config = config
 
 
 
101
 
102
  # Slow transformer
103
  self.embeddings = nn.Embedding(
104
- config.vocab_size + config.codebook_size * config.num_in_codebooks,
 
 
 
 
105
  config.dim,
106
  )
107
  self.layers = nn.ModuleList(
108
  TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
109
  )
110
  self.norm = RMSNorm(config.dim, eps=config.norm_eps)
111
- self.output = nn.Linear(
112
- config.dim,
113
- config.vocab_size,
114
- bias=False,
115
- )
 
 
116
 
117
  self.register_buffer(
118
  "freqs_cis",
@@ -139,6 +190,9 @@ class BaseTransformer(nn.Module):
139
  self.max_batch_size = -1
140
  self.max_seq_len = -1
141
 
 
 
 
142
  def setup_caches(
143
  self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
144
  ):
@@ -161,11 +215,9 @@ class BaseTransformer(nn.Module):
161
 
162
  def embed(self, x: Tensor) -> Tensor:
163
  vocab_embeds = [self.embeddings(x[:, 0])]
164
- for i in range(self.config.num_in_codebooks):
165
- emb = self.embeddings(
166
- x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
167
- )
168
- emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
169
  vocab_embeds.append(emb)
170
 
171
  x = torch.stack(vocab_embeds, dim=3)
@@ -174,21 +226,23 @@ class BaseTransformer(nn.Module):
174
  return x
175
 
176
  def forward(
177
- self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
 
 
178
  ) -> BaseTransformerForwardResult:
179
- # x: (batch, num_codebooks + 1, seq_len)
180
  seq_len = inp.size(2)
181
 
182
  # Here we want to merge the embeddings of the codebooks
183
  x = self.embed(inp)
184
 
185
- mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
186
  freqs_cis = self.freqs_cis[:seq_len]
187
 
188
  # Not that the causal mask here follows the definition of scaled_dot_product_attention
189
  # That is, FALSE means masked out
190
  # To maintain consistency, key_padding_mask use TRUE to mask out
 
191
  if key_padding_mask is not None:
 
192
  mask = mask & key_padding_mask[:, None, None, :].logical_not()
193
 
194
  for layer in self.layers:
@@ -199,7 +253,11 @@ class BaseTransformer(nn.Module):
199
 
200
  # We got slow_out here
201
  slow_out = self.norm(x)
202
- token_logits = self.output(slow_out)
 
 
 
 
203
 
204
  return BaseTransformerForwardResult(
205
  logits=token_logits,
@@ -207,7 +265,10 @@ class BaseTransformer(nn.Module):
207
  )
208
 
209
  def forward_generate(
210
- self, x: Tensor, input_pos: Optional[Tensor] = None
 
 
 
211
  ) -> BaseTransformerForwardResult:
212
  # This is used for generation, optimized for torch compile
213
  assert (
@@ -225,22 +286,117 @@ class BaseTransformer(nn.Module):
225
  x = layer(x, freqs_cis, mask, input_pos=input_pos)
226
 
227
  # If prefill, we only calculate the logits of last token
228
- if x.size(1) > 1:
229
  x = x[:, -1:]
230
 
231
  # We got slow_out here
232
  slow_out = self.norm(x)
233
- token_logits = self.output(slow_out)
 
 
 
 
234
 
235
  return BaseTransformerForwardResult(
236
  logits=token_logits,
237
  hidden_states=x,
238
  )
239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  class NaiveTransformer(BaseTransformer):
242
- def __init__(self, config: NaiveModelArgs) -> None:
243
- super().__init__(config)
244
 
245
  self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
246
  self.codebook_output = nn.Linear(
@@ -249,6 +405,8 @@ class NaiveTransformer(BaseTransformer):
249
  bias=False,
250
  )
251
 
 
 
252
  def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
253
  token_logits = result.logits
254
  x = result.hidden_states
@@ -265,9 +423,14 @@ class NaiveTransformer(BaseTransformer):
265
  )
266
 
267
  def forward(
268
- self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
 
 
269
  ) -> TransformerForwardResult:
270
- result = super().forward(inp, key_padding_mask)
 
 
 
271
  return self.decode(result)
272
 
273
  def forward_generate(
@@ -278,13 +441,11 @@ class NaiveTransformer(BaseTransformer):
278
 
279
 
280
  class DualARTransformer(BaseTransformer):
281
- def __init__(self, config: DualARModelArgs) -> None:
282
- super().__init__(config)
283
 
284
  # Fast transformer
285
- self.fast_embeddings = nn.Embedding(
286
- config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
287
- )
288
 
289
  # The equivalent bs is so large that sdpa doesn't work
290
  self.fast_layers = nn.ModuleList(
@@ -297,6 +458,8 @@ class DualARTransformer(BaseTransformer):
297
  bias=False,
298
  )
299
 
 
 
300
  def setup_caches(
301
  self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
302
  ):
@@ -316,7 +479,9 @@ class DualARTransformer(BaseTransformer):
316
  )
317
 
318
  def forward(
319
- self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
 
 
320
  ) -> TransformerForwardResult:
321
  parent_result = super().forward(inp, key_padding_mask)
322
  token_logits = parent_result.logits
@@ -331,7 +496,7 @@ class DualARTransformer(BaseTransformer):
331
 
332
  # Drop the last token and rotate left
333
  codebooks = inp[:, 1:-1, 1:]
334
- codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
335
  codebook_embeddings = self.fast_embeddings(codebooks)
336
  x = torch.cat([x[:, None], codebook_embeddings], dim=1)
337
  b, s = x.size(0), x.size(2)
@@ -339,7 +504,12 @@ class DualARTransformer(BaseTransformer):
339
 
340
  # Remove padded part
341
  codebooks = rearrange(codebooks, "b n s -> (b s) n")
342
- codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
 
 
 
 
 
343
  x_bs, x_len = x.size(0), x.size(1)
344
  x = x[~codebook_mask]
345
 
@@ -422,7 +592,9 @@ class Attention(nn.Module):
422
 
423
  total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
424
  # key, query, value projections for all heads, but in a batch
425
- self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
 
 
426
  self.wo = nn.Linear(config.dim, config.dim, bias=False)
427
  self.kv_cache = None
428
 
@@ -469,13 +641,24 @@ class Attention(nn.Module):
469
  v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
470
 
471
  if self.use_sdpa:
472
- y = F.scaled_dot_product_attention(
473
- q,
474
- k,
475
- v,
476
- attn_mask=mask,
477
- dropout_p=self.dropout if self.training else 0.0,
478
- )
 
 
 
 
 
 
 
 
 
 
 
479
  else:
480
  y = self.eq_scaled_dot_product_attention(
481
  q,
@@ -567,29 +750,3 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
567
 
568
  x_out2 = x_out2.flatten(3)
569
  return x_out2.type_as(x)
570
-
571
-
572
- if __name__ == "__main__":
573
- args = DualARModelArgs(
574
- max_seq_len=4096,
575
- vocab_size=32312,
576
- n_layer=12,
577
- n_fast_layer=4,
578
- n_head=12,
579
- dim=768,
580
- rope_base=10000,
581
- norm_eps=1e-5,
582
- codebook_size=128,
583
- num_codebooks=4,
584
- )
585
-
586
- model = DualARTransformer(args)
587
- model = model.cuda().bfloat16()
588
- print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
589
-
590
- inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
591
- key_padding_mask = torch.zeros(2, 128).bool().cuda()
592
- key_padding_mask[0, 2:] = True
593
- x1 = model(inputs, key_padding_mask=key_padding_mask)
594
- print(x1.token_logits.shape)
595
- print(x1.codebook_logits.shape)
 
1
+ import json
2
  import math
3
  from dataclasses import dataclass
4
+ from pathlib import Path
5
  from typing import Optional
6
 
7
  import torch
8
  import torch.nn as nn
9
  from einops import rearrange
10
+ from loguru import logger
11
  from torch import Tensor
12
  from torch.nn import functional as F
13
+ from torch.nn.attention import SDPBackend, sdpa_kernel
14
  from torch.utils.checkpoint import checkpoint
15
+ from transformers import AutoTokenizer
16
+
17
+ from fish_speech.conversation import SEMANTIC_TOKEN
18
+ from fish_speech.utils import RankedLogger
19
+
20
+ from .lora import LoraConfig, setup_lora
21
+
22
+ log = RankedLogger(__name__, rank_zero_only=True)
23
 
24
 
25
  def find_multiple(n: int, k: int) -> int:
 
30
 
31
  @dataclass
32
  class BaseModelArgs:
33
+ model_type: str = "base"
34
+
35
  vocab_size: int = 32000
36
  n_layer: int = 32
37
  n_head: int = 32
 
43
  norm_eps: float = 1e-5
44
  max_seq_len: int = 2048
45
  dropout: float = 0.0
46
+ tie_word_embeddings: bool = True
47
+ attention_qkv_bias: bool = False
48
 
49
  # Codebook configs
50
  codebook_size: int = 160
51
  num_codebooks: int = 4
 
 
52
 
53
  # Gradient checkpointing
54
  use_gradient_checkpointing: bool = True
55
 
56
+ # Initialize the model
57
+ initializer_range: float = 0.02
58
+
59
  def __post_init__(self):
60
  if self.n_local_heads == -1:
61
  self.n_local_heads = self.n_head
 
63
  hidden_dim = 4 * self.dim
64
  n_hidden = int(2 * hidden_dim / 3)
65
  self.intermediate_size = find_multiple(n_hidden, 256)
 
 
66
  self.head_dim = self.dim // self.n_head
67
 
68
+ @staticmethod
69
+ def from_pretrained(path: str):
70
+ path = Path(path)
71
+
72
+ if path.is_dir():
73
+ path = path / "config.json"
74
+
75
+ with open(path, "r", encoding="utf-8") as f:
76
+ data = json.load(f)
77
+
78
+ match data["model_type"]:
79
+ case "naive":
80
+ cls = NaiveModelArgs
81
+ case "dual_ar":
82
+ cls = DualARModelArgs
83
+ case _:
84
+ raise ValueError(f"Unknown model type: {data['model_type']}")
85
+
86
+ return cls(**data)
87
+
88
+ def save(self, path: str):
89
+ with open(path, "w") as f:
90
+ json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
91
+
92
 
93
  @dataclass
94
  class NaiveModelArgs(BaseModelArgs):
95
+ model_type: str = "naive"
96
 
97
 
98
  @dataclass
99
  class DualARModelArgs(BaseModelArgs):
100
+ model_type: str = "dual_ar"
101
  n_fast_layer: int = 4
102
 
103
 
 
135
 
136
 
137
  class BaseTransformer(nn.Module):
138
+ def __init__(
139
+ self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
140
+ ) -> None:
141
  super().__init__()
142
  self.config = config
143
+ self.tokenizer = tokenizer
144
+
145
+ self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
146
 
147
  # Slow transformer
148
  self.embeddings = nn.Embedding(
149
+ config.vocab_size,
150
+ config.dim,
151
+ )
152
+ self.codebook_embeddings = nn.Embedding(
153
+ config.codebook_size * config.num_codebooks,
154
  config.dim,
155
  )
156
  self.layers = nn.ModuleList(
157
  TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
158
  )
159
  self.norm = RMSNorm(config.dim, eps=config.norm_eps)
160
+
161
+ if self.config.tie_word_embeddings is False:
162
+ self.output = nn.Linear(
163
+ config.dim,
164
+ config.vocab_size,
165
+ bias=False,
166
+ )
167
 
168
  self.register_buffer(
169
  "freqs_cis",
 
190
  self.max_batch_size = -1
191
  self.max_seq_len = -1
192
 
193
+ if init_weights:
194
+ self.apply(self._init_weights)
195
+
196
  def setup_caches(
197
  self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
198
  ):
 
215
 
216
  def embed(self, x: Tensor) -> Tensor:
217
  vocab_embeds = [self.embeddings(x[:, 0])]
218
+ for i in range(self.config.num_codebooks):
219
+ emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
220
+ emb[x[:, 0] != self.semantic_token_id] = 0
 
 
221
  vocab_embeds.append(emb)
222
 
223
  x = torch.stack(vocab_embeds, dim=3)
 
226
  return x
227
 
228
  def forward(
229
+ self,
230
+ inp: Tensor,
231
+ key_padding_mask: Optional[Tensor] = None,
232
  ) -> BaseTransformerForwardResult:
 
233
  seq_len = inp.size(2)
234
 
235
  # Here we want to merge the embeddings of the codebooks
236
  x = self.embed(inp)
237
 
 
238
  freqs_cis = self.freqs_cis[:seq_len]
239
 
240
  # Not that the causal mask here follows the definition of scaled_dot_product_attention
241
  # That is, FALSE means masked out
242
  # To maintain consistency, key_padding_mask use TRUE to mask out
243
+ mask = None
244
  if key_padding_mask is not None:
245
+ mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
246
  mask = mask & key_padding_mask[:, None, None, :].logical_not()
247
 
248
  for layer in self.layers:
 
253
 
254
  # We got slow_out here
255
  slow_out = self.norm(x)
256
+
257
+ if self.config.tie_word_embeddings:
258
+ token_logits = F.linear(slow_out, self.embeddings.weight)
259
+ else:
260
+ token_logits = self.output(slow_out)
261
 
262
  return BaseTransformerForwardResult(
263
  logits=token_logits,
 
265
  )
266
 
267
  def forward_generate(
268
+ self,
269
+ x: Tensor,
270
+ input_pos: Optional[Tensor] = None,
271
+ return_all: bool = False,
272
  ) -> BaseTransformerForwardResult:
273
  # This is used for generation, optimized for torch compile
274
  assert (
 
286
  x = layer(x, freqs_cis, mask, input_pos=input_pos)
287
 
288
  # If prefill, we only calculate the logits of last token
289
+ if x.size(1) > 1 and not return_all:
290
  x = x[:, -1:]
291
 
292
  # We got slow_out here
293
  slow_out = self.norm(x)
294
+
295
+ if self.config.tie_word_embeddings:
296
+ token_logits = F.linear(slow_out, self.embeddings.weight)
297
+ else:
298
+ token_logits = self.output(slow_out)
299
 
300
  return BaseTransformerForwardResult(
301
  logits=token_logits,
302
  hidden_states=x,
303
  )
304
 
305
+ def _init_weights(self, module):
306
+ std = self.config.initializer_range
307
+ if isinstance(module, nn.Linear):
308
+ module.weight.data.normal_(mean=0.0, std=std)
309
+ if module.bias is not None:
310
+ module.bias.data.zero_()
311
+ elif isinstance(module, nn.Embedding):
312
+ module.weight.data.normal_(mean=0.0, std=std)
313
+ if module.padding_idx is not None:
314
+ module.weight.data[module.padding_idx].zero_()
315
+
316
+ @staticmethod
317
+ def from_pretrained(
318
+ path: str,
319
+ load_weights: bool = False,
320
+ max_length: int | None = None,
321
+ lora_config: LoraConfig | None = None,
322
+ rope_base: int | None = None,
323
+ ) -> "BaseTransformer":
324
+ config = BaseModelArgs.from_pretrained(str(path))
325
+ if max_length is not None:
326
+ config.max_seq_len = max_length
327
+ log.info(f"Override max_seq_len to {max_length}")
328
+
329
+ if rope_base is not None:
330
+ config.rope_base = rope_base
331
+ log.info(f"Override rope_base to {rope_base}")
332
+
333
+ match config.model_type:
334
+ case "naive":
335
+ model_cls = NaiveTransformer
336
+ case "dual_ar":
337
+ model_cls = DualARTransformer
338
+ case _:
339
+ raise ValueError(f"Unknown model type: {config.model_type}")
340
+
341
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
342
+ log.info(f"Loading model from {path}, config: {config}")
343
+ model = model_cls(config, tokenizer=tokenizer)
344
+
345
+ if lora_config is not None:
346
+ setup_lora(model, lora_config)
347
+ log.info(f"LoRA setup: {lora_config}")
348
+
349
+ if load_weights is False:
350
+ log.info("Randomly initialized model")
351
+ else:
352
+
353
+ if "int8" in str(Path(path)):
354
+ logger.info("Using int8 weight-only quantization!")
355
+ from tools.llama.quantize import WeightOnlyInt8QuantHandler
356
+
357
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
358
+ model = simple_quantizer.convert_for_runtime()
359
+
360
+ if "int4" in str(Path(path)):
361
+ logger.info("Using int4 quantization!")
362
+ path_comps = path.name.split("-")
363
+ assert path_comps[-2].startswith("g")
364
+ groupsize = int(path_comps[-2][1:])
365
+ from tools.llama.quantize import WeightOnlyInt4QuantHandler
366
+
367
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
368
+ model = simple_quantizer.convert_for_runtime()
369
+
370
+ weights = torch.load(
371
+ Path(path) / "model.pth", map_location="cpu", mmap=True
372
+ )
373
+ err = model.load_state_dict(weights, strict=False, assign=True)
374
+ log.info(f"Loaded weights with error: {err}")
375
+
376
+ return model
377
+
378
+ def save_pretrained(self, path: str, drop_lora: bool = False):
379
+ path = Path(path)
380
+ path.mkdir(parents=True, exist_ok=True)
381
+
382
+ self.config.save(path / "config.json")
383
+ state_dict = self.state_dict()
384
+
385
+ if drop_lora:
386
+ for key in list(state_dict.keys()):
387
+ if "lora" not in key:
388
+ continue
389
+
390
+ state_dict.pop(key)
391
+ log.info(f"Drop LoRA parameter: {key}")
392
+
393
+ torch.save(state_dict, path / "model.pth")
394
+ self.tokenizer.save_pretrained(path)
395
+
396
 
397
  class NaiveTransformer(BaseTransformer):
398
+ def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
399
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
400
 
401
  self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
402
  self.codebook_output = nn.Linear(
 
405
  bias=False,
406
  )
407
 
408
+ self.apply(self._init_weights)
409
+
410
  def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
411
  token_logits = result.logits
412
  x = result.hidden_states
 
423
  )
424
 
425
  def forward(
426
+ self,
427
+ inp: Tensor,
428
+ key_padding_mask: Optional[Tensor] = None,
429
  ) -> TransformerForwardResult:
430
+ result = super().forward(
431
+ inp=inp,
432
+ key_padding_mask=key_padding_mask,
433
+ )
434
  return self.decode(result)
435
 
436
  def forward_generate(
 
441
 
442
 
443
  class DualARTransformer(BaseTransformer):
444
+ def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
445
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
446
 
447
  # Fast transformer
448
+ self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
 
 
449
 
450
  # The equivalent bs is so large that sdpa doesn't work
451
  self.fast_layers = nn.ModuleList(
 
458
  bias=False,
459
  )
460
 
461
+ self.apply(self._init_weights)
462
+
463
  def setup_caches(
464
  self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
465
  ):
 
479
  )
480
 
481
  def forward(
482
+ self,
483
+ inp: Tensor,
484
+ key_padding_mask: Optional[Tensor] = None,
485
  ) -> TransformerForwardResult:
486
  parent_result = super().forward(inp, key_padding_mask)
487
  token_logits = parent_result.logits
 
496
 
497
  # Drop the last token and rotate left
498
  codebooks = inp[:, 1:-1, 1:]
499
+ codebooks = F.pad(codebooks, (0, 1), value=0)
500
  codebook_embeddings = self.fast_embeddings(codebooks)
501
  x = torch.cat([x[:, None], codebook_embeddings], dim=1)
502
  b, s = x.size(0), x.size(2)
 
504
 
505
  # Remove padded part
506
  codebooks = rearrange(codebooks, "b n s -> (b s) n")
507
+ codebook_mask = (codebooks == 0).all(dim=-1)
508
+
509
+ if torch.all(codebook_mask):
510
+ # If all codebooks are padded, we keep first 8 to make sure the model runs
511
+ codebook_mask[:8] = False
512
+
513
  x_bs, x_len = x.size(0), x.size(1)
514
  x = x[~codebook_mask]
515
 
 
592
 
593
  total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
594
  # key, query, value projections for all heads, but in a batch
595
+ self.wqkv = nn.Linear(
596
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
597
+ )
598
  self.wo = nn.Linear(config.dim, config.dim, bias=False)
599
  self.kv_cache = None
600
 
 
641
  v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
642
 
643
  if self.use_sdpa:
644
+ if mask is None:
645
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
646
+ y = F.scaled_dot_product_attention(
647
+ q,
648
+ k,
649
+ v,
650
+ dropout_p=self.dropout if self.training else 0.0,
651
+ is_causal=True,
652
+ # No third party attn_mask here to use flash_attention
653
+ )
654
+ else:
655
+ y = F.scaled_dot_product_attention(
656
+ q,
657
+ k,
658
+ v,
659
+ attn_mask=mask,
660
+ dropout_p=self.dropout if self.training else 0.0,
661
+ )
662
  else:
663
  y = self.eq_scaled_dot_product_attention(
664
  q,
 
750
 
751
  x_out2 = x_out2.flatten(3)
752
  return x_out2.type_as(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/models/text2semantic/lora.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import loralib as lora
4
+
5
+
6
+ @dataclass
7
+ class LoraConfig:
8
+ r: int
9
+ lora_alpha: float
10
+ lora_dropout: float = 0.0
11
+
12
+
13
+ def setup_lora(model, lora_config):
14
+ # Replace the embedding layer with a LoRA layer
15
+ model.embeddings = lora.Embedding(
16
+ num_embeddings=model.embeddings.num_embeddings,
17
+ embedding_dim=model.embeddings.embedding_dim,
18
+ padding_idx=model.embeddings.padding_idx,
19
+ r=lora_config.r,
20
+ lora_alpha=lora_config.lora_alpha,
21
+ )
22
+
23
+ model.codebook_embeddings = lora.Embedding(
24
+ num_embeddings=model.codebook_embeddings.num_embeddings,
25
+ embedding_dim=model.codebook_embeddings.embedding_dim,
26
+ padding_idx=model.codebook_embeddings.padding_idx,
27
+ r=lora_config.r,
28
+ lora_alpha=lora_config.lora_alpha,
29
+ )
30
+
31
+ # Replace output layer with a LoRA layer
32
+ linears = [(model, "output")]
33
+
34
+ # Replace all linear layers with LoRA layers
35
+ for layer in model.layers:
36
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
37
+ linears.extend(
38
+ [
39
+ (layer.feed_forward, "w1"),
40
+ (layer.feed_forward, "w2"),
41
+ (layer.feed_forward, "w3"),
42
+ ]
43
+ )
44
+
45
+ if hasattr(model, "fast_layers"):
46
+ model.fast_embeddings = lora.Embedding(
47
+ num_embeddings=model.fast_embeddings.num_embeddings,
48
+ embedding_dim=model.fast_embeddings.embedding_dim,
49
+ padding_idx=model.fast_embeddings.padding_idx,
50
+ r=lora_config.r,
51
+ lora_alpha=lora_config.lora_alpha,
52
+ )
53
+
54
+ # Dual-AR model
55
+ linears.append((model, "fast_output"))
56
+
57
+ for layer in model.fast_layers:
58
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
59
+ linears.extend(
60
+ [
61
+ (layer.feed_forward, "w1"),
62
+ (layer.feed_forward, "w2"),
63
+ (layer.feed_forward, "w3"),
64
+ ]
65
+ )
66
+
67
+ for module, layer in linears:
68
+ updated_linear = lora.Linear(
69
+ in_features=getattr(module, layer).in_features,
70
+ out_features=getattr(module, layer).out_features,
71
+ bias=getattr(module, layer).bias,
72
+ r=lora_config.r,
73
+ lora_alpha=lora_config.lora_alpha,
74
+ lora_dropout=lora_config.lora_dropout,
75
+ )
76
+ setattr(module, layer, updated_linear)
77
+
78
+ # Mark only the LoRA layers as trainable
79
+ lora.mark_only_lora_as_trainable(model, bias="none")
80
+
81
+
82
+ def get_merged_state_dict(model):
83
+ # This line will merge the state dict of the model and the LoRA parameters
84
+ model.eval()
85
+
86
+ # Then we need to remove the LoRA parameters from the state dict
87
+ state_dict = model.state_dict()
88
+ for name in list(state_dict.keys()):
89
+ if "lora" in name:
90
+ state_dict.pop(name)
91
+
92
+ return state_dict
fish_speech/models/vqgan/modules/firefly.py CHANGED
@@ -1,5 +1,6 @@
1
  # A inference only version of the FireflyGAN model
2
 
 
3
  from functools import partial
4
  from math import prod
5
  from typing import Callable
@@ -13,6 +14,8 @@ from torch.nn.utils.parametrizations import weight_norm
13
  from torch.nn.utils.parametrize import remove_parametrizations
14
  from torch.utils.checkpoint import checkpoint
15
 
 
 
16
 
17
  def init_weights(m, mean=0.0, std=0.01):
18
  classname = m.__class__.__name__
@@ -474,6 +477,89 @@ class ConvNeXtEncoder(nn.Module):
474
  return self.norm(x)
475
 
476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  class FireflyBase(nn.Module):
478
  def __init__(self, ckpt_path: str = None, pretrained: bool = True):
479
  super().__init__()
@@ -500,11 +586,12 @@ class FireflyBase(nn.Module):
500
  )
501
 
502
  if ckpt_path is not None:
503
- self.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
504
  elif pretrained:
505
  state_dict = torch.hub.load_state_dict_from_url(
506
  "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
507
  map_location="cpu",
 
508
  )
509
 
510
  if "state_dict" in state_dict:
 
1
  # A inference only version of the FireflyGAN model
2
 
3
+ import math
4
  from functools import partial
5
  from math import prod
6
  from typing import Callable
 
14
  from torch.nn.utils.parametrize import remove_parametrizations
15
  from torch.utils.checkpoint import checkpoint
16
 
17
+ from fish_speech.models.vqgan.utils import sequence_mask
18
+
19
 
20
  def init_weights(m, mean=0.0, std=0.01):
21
  classname = m.__class__.__name__
 
477
  return self.norm(x)
478
 
479
 
480
+ class FireflyArchitecture(nn.Module):
481
+ def __init__(
482
+ self,
483
+ backbone: nn.Module,
484
+ head: nn.Module,
485
+ quantizer: nn.Module,
486
+ spec_transform: nn.Module,
487
+ ):
488
+ super().__init__()
489
+
490
+ self.backbone = backbone
491
+ self.head = head
492
+ self.quantizer = quantizer
493
+ self.spec_transform = spec_transform
494
+
495
+ def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
496
+ if self.spec_transform is not None:
497
+ x = self.spec_transform(x)
498
+
499
+ x = self.backbone(x)
500
+ if mask is not None:
501
+ x = x * mask
502
+
503
+ if self.quantizer is not None:
504
+ vq_result = self.quantizer(x)
505
+ x = vq_result.z
506
+
507
+ if mask is not None:
508
+ x = x * mask
509
+
510
+ x = self.head(x, template=template)
511
+
512
+ if x.ndim == 2:
513
+ x = x[:, None, :]
514
+
515
+ if self.vq is not None:
516
+ return x, vq_result
517
+
518
+ return x
519
+
520
+ def encode(self, audios, audio_lengths):
521
+ audios = audios.float()
522
+
523
+ mels = self.spec_transform(audios)
524
+ mel_lengths = audio_lengths // self.spec_transform.hop_length
525
+ mel_masks = sequence_mask(mel_lengths, mels.shape[2])
526
+ mel_masks_float_conv = mel_masks[:, None, :].float()
527
+ mels = mels * mel_masks_float_conv
528
+
529
+ # Encode
530
+ encoded_features = self.backbone(mels) * mel_masks_float_conv
531
+ feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
532
+
533
+ return self.quantizer.encode(encoded_features), feature_lengths
534
+
535
+ def decode(self, indices, feature_lengths) -> torch.Tensor:
536
+ factor = math.prod(self.quantizer.downsample_factor)
537
+ mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
538
+ mel_masks_float_conv = mel_masks[:, None, :].float()
539
+
540
+ audio_masks = sequence_mask(
541
+ feature_lengths * factor * self.spec_transform.hop_length,
542
+ indices.shape[2] * factor * self.spec_transform.hop_length,
543
+ )
544
+ audio_masks_float_conv = audio_masks[:, None, :].float()
545
+
546
+ z = self.quantizer.decode(indices) * mel_masks_float_conv
547
+ x = self.head(z) * audio_masks_float_conv
548
+
549
+ return x
550
+
551
+ def remove_parametrizations(self):
552
+ if hasattr(self.backbone, "remove_parametrizations"):
553
+ self.backbone.remove_parametrizations()
554
+
555
+ if hasattr(self.head, "remove_parametrizations"):
556
+ self.head.remove_parametrizations()
557
+
558
+ @property
559
+ def device(self):
560
+ return next(self.parameters()).device
561
+
562
+
563
  class FireflyBase(nn.Module):
564
  def __init__(self, ckpt_path: str = None, pretrained: bool = True):
565
  super().__init__()
 
586
  )
587
 
588
  if ckpt_path is not None:
589
+ state_dict = torch.load(ckpt_path, map_location="cpu")
590
  elif pretrained:
591
  state_dict = torch.hub.load_state_dict_from_url(
592
  "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
593
  map_location="cpu",
594
+ model_dir="checkpoints",
595
  )
596
 
597
  if "state_dict" in state_dict:
fish_speech/models/vqgan/modules/fsq.py CHANGED
@@ -20,7 +20,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
20
  def __init__(
21
  self,
22
  input_dim: int = 512,
23
- n_codebooks: int = 9,
24
  n_groups: int = 1,
25
  levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26
  downsample_factor: tuple[int] = (2, 2),
 
20
  def __init__(
21
  self,
22
  input_dim: int = 512,
23
+ n_codebooks: int = 1,
24
  n_groups: int = 1,
25
  levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26
  downsample_factor: tuple[int] = (2, 2),
fish_speech/text/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .clean import clean_text
 
2
 
3
- __all__ = ["clean_text"]
 
1
  from .clean import clean_text
2
+ from .spliter import split_text
3
 
4
+ __all__ = ["clean_text", "split_text"]
fish_speech/text/chn_text_norm/.gitignore ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
105
+
106
+ # JetBrains PyCharm
107
+ .idea
108
+
109
+ # Customize
110
+ references
111
+ url.txt
112
+
113
+ # Git
114
+ .git
fish_speech/text/chn_text_norm/README.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
2
+
3
+ # Chn Text Norm
4
+
5
+ this is a repository for chinese text normalization (no longer maintained).
6
+
7
+ ## Quick Start ##
8
+
9
+ ### Git Clone Repo ###
10
+
11
+ git clone this repo to the root directory of your project which need to use it.
12
+
13
+ cd /path/to/proj
14
+ git clone https://github.com/Joee1995/chn-text-norm.git
15
+
16
+ after that, your doc tree should be:
17
+ ```
18
+ proj # root of your project
19
+ |--- chn_text_norm # this chn-text-norm tool
20
+ |--- text.py
21
+ |--- ...
22
+ |--- text_normalize.py # your text normalization code
23
+ |--- ...
24
+ ```
25
+
26
+ ### How to Use ? ###
27
+
28
+ # text_normalize.py
29
+ from chn_text_norm.text import *
30
+
31
+ raw_text = 'your raw text'
32
+ text = Text(raw_text=raw_text).normalize()
33
+
34
+ ### How to add quantums ###
35
+
36
+ 打开test.py,然后你就知道怎么做了。
fish_speech/text/chn_text_norm/__init__.py ADDED
File without changes
fish_speech/text/chn_text_norm/basic_class.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """基本类
3
+ 中文字符类
4
+ 中文数字/数位类
5
+ 中文数字类
6
+ 中文数位类
7
+ 中文数字系统类
8
+ 中文数学符号类
9
+ *中文其他符号类
10
+ """
11
+
12
+ __author__ = "Zhiyang Zhou <[email protected]>"
13
+ __data__ = "2019-05-02"
14
+
15
+ from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
16
+
17
+
18
+ class ChineseChar(object):
19
+ """
20
+ 中文字符
21
+ 每个字符对应简体和繁体,
22
+ e.g. 简体 = '负', 繁体 = '負'
23
+ 转换时可转换为简体或繁体
24
+ """
25
+
26
+ def __init__(self, simplified, traditional):
27
+ self.simplified = simplified
28
+ self.traditional = traditional
29
+ self.__repr__ = self.__str__
30
+
31
+ def __str__(self):
32
+ return self.simplified or self.traditional or None
33
+
34
+ def __repr__(self):
35
+ return self.__str__()
36
+
37
+
38
+ class ChineseNumberUnit(ChineseChar):
39
+ """
40
+ 中文数字/数位字符
41
+ 每个字符除繁简体外还有一个额外的大写字符
42
+ e.g. '陆' 和 '陸'
43
+ """
44
+
45
+ def __init__(self, power, simplified, traditional, big_s, big_t):
46
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
47
+ self.power = power
48
+ self.big_s = big_s
49
+ self.big_t = big_t
50
+
51
+ def __str__(self):
52
+ return "10^{}".format(self.power)
53
+
54
+ @classmethod
55
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
56
+
57
+ if small_unit:
58
+ return ChineseNumberUnit(
59
+ power=index + 1,
60
+ simplified=value[0],
61
+ traditional=value[1],
62
+ big_s=value[1],
63
+ big_t=value[1],
64
+ )
65
+ elif numbering_type == NUMBERING_TYPES[0]:
66
+ return ChineseNumberUnit(
67
+ power=index + 8,
68
+ simplified=value[0],
69
+ traditional=value[1],
70
+ big_s=value[0],
71
+ big_t=value[1],
72
+ )
73
+ elif numbering_type == NUMBERING_TYPES[1]:
74
+ return ChineseNumberUnit(
75
+ power=(index + 2) * 4,
76
+ simplified=value[0],
77
+ traditional=value[1],
78
+ big_s=value[0],
79
+ big_t=value[1],
80
+ )
81
+ elif numbering_type == NUMBERING_TYPES[2]:
82
+ return ChineseNumberUnit(
83
+ power=pow(2, index + 3),
84
+ simplified=value[0],
85
+ traditional=value[1],
86
+ big_s=value[0],
87
+ big_t=value[1],
88
+ )
89
+ else:
90
+ raise ValueError(
91
+ "Counting type should be in {0} ({1} provided).".format(
92
+ NUMBERING_TYPES, numbering_type
93
+ )
94
+ )
95
+
96
+
97
+ class ChineseNumberDigit(ChineseChar):
98
+ """
99
+ 中文数字字符
100
+ """
101
+
102
+ def __init__(
103
+ self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
104
+ ):
105
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
106
+ self.value = value
107
+ self.big_s = big_s
108
+ self.big_t = big_t
109
+ self.alt_s = alt_s
110
+ self.alt_t = alt_t
111
+
112
+ def __str__(self):
113
+ return str(self.value)
114
+
115
+ @classmethod
116
+ def create(cls, i, v):
117
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
118
+
119
+
120
+ class ChineseMath(ChineseChar):
121
+ """
122
+ 中文数位字符
123
+ """
124
+
125
+ def __init__(self, simplified, traditional, symbol, expression=None):
126
+ super(ChineseMath, self).__init__(simplified, traditional)
127
+ self.symbol = symbol
128
+ self.expression = expression
129
+ self.big_s = simplified
130
+ self.big_t = traditional
131
+
132
+
133
+ CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
134
+
135
+
136
+ class NumberSystem(object):
137
+ """
138
+ 中文数字系统
139
+ """
140
+
141
+ pass
142
+
143
+
144
+ class MathSymbol(object):
145
+ """
146
+ 用于中文数字系统的数学符号 (繁/简体), e.g.
147
+ positive = ['正', '正']
148
+ negative = ['负', '負']
149
+ point = ['点', '點']
150
+ """
151
+
152
+ def __init__(self, positive, negative, point):
153
+ self.positive = positive
154
+ self.negative = negative
155
+ self.point = point
156
+
157
+ def __iter__(self):
158
+ for v in self.__dict__.values():
159
+ yield v
160
+
161
+
162
+ # class OtherSymbol(object):
163
+ # """
164
+ # 其他符号
165
+ # """
166
+ #
167
+ # def __init__(self, sil):
168
+ # self.sil = sil
169
+ #
170
+ # def __iter__(self):
171
+ # for v in self.__dict__.values():
172
+ # yield v
fish_speech/text/chn_text_norm/basic_constant.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """基本常量
3
+ 中文数字/数位/符号字符常量
4
+ """
5
+
6
+ __author__ = "Zhiyang Zhou <[email protected]>"
7
+ __data__ = "2019-05-02"
8
+
9
+ CHINESE_DIGIS = "零一二三四五六七八九"
10
+ BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
11
+ BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
12
+ SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
13
+ SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
14
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
15
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
16
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
17
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
18
+
19
+ ZERO_ALT = "〇"
20
+ ONE_ALT = "幺"
21
+ TWO_ALTS = ["两", "兩"]
22
+
23
+ POSITIVE = ["正", "正"]
24
+ NEGATIVE = ["负", "負"]
25
+ POINT = ["点", "點"]
26
+ # PLUS = [u'加', u'加']
27
+ # SIL = [u'杠', u'槓']
28
+
29
+ # 中文数字系统类型
30
+ NUMBERING_TYPES = ["low", "mid", "high"]
fish_speech/text/chn_text_norm/basic_util.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """基本方法
3
+ 创建中文数字系统 方法
4
+ 中文字符串 <=> 数字串 方法
5
+ 数字串 <=> 中文字符串 方法
6
+ """
7
+
8
+ __author__ = "Zhiyang Zhou <[email protected]>"
9
+ __data__ = "2019-05-02"
10
+
11
+ from fish_speech.text.chn_text_norm.basic_class import *
12
+ from fish_speech.text.chn_text_norm.basic_constant import *
13
+
14
+
15
+ def create_system(numbering_type=NUMBERING_TYPES[1]):
16
+ """
17
+ 根据数字系统类型返回创建相应的数字系统,默认为 mid
18
+ NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
19
+ low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
20
+ mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
21
+ high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
22
+ 返回对应的数字系统
23
+ """
24
+
25
+ # chinese number units of '亿' and larger
26
+ all_larger_units = zip(
27
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
28
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
29
+ )
30
+ larger_units = [
31
+ CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)
32
+ ]
33
+ # chinese number units of '十, 百, 千, 万'
34
+ all_smaller_units = zip(
35
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
36
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
37
+ )
38
+ smaller_units = [
39
+ CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)
40
+ ]
41
+ # digis
42
+ chinese_digis = zip(
43
+ CHINESE_DIGIS,
44
+ CHINESE_DIGIS,
45
+ BIG_CHINESE_DIGIS_SIMPLIFIED,
46
+ BIG_CHINESE_DIGIS_TRADITIONAL,
47
+ )
48
+ digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
49
+ digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
50
+ digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
51
+ digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
52
+
53
+ # symbols
54
+ positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
55
+ negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
56
+ point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
57
+ # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
58
+ system = NumberSystem()
59
+ system.units = smaller_units + larger_units
60
+ system.digits = digits
61
+ system.math = MathSymbol(positive_cn, negative_cn, point_cn)
62
+ # system.symbols = OtherSymbol(sil_cn)
63
+ return system
64
+
65
+
66
+ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
67
+
68
+ def get_symbol(char, system):
69
+ for u in system.units:
70
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
71
+ return u
72
+ for d in system.digits:
73
+ if char in [
74
+ d.traditional,
75
+ d.simplified,
76
+ d.big_s,
77
+ d.big_t,
78
+ d.alt_s,
79
+ d.alt_t,
80
+ ]:
81
+ return d
82
+ for m in system.math:
83
+ if char in [m.traditional, m.simplified]:
84
+ return m
85
+
86
+ def string2symbols(chinese_string, system):
87
+ int_string, dec_string = chinese_string, ""
88
+ for p in [system.math.point.simplified, system.math.point.traditional]:
89
+ if p in chinese_string:
90
+ int_string, dec_string = chinese_string.split(p)
91
+ break
92
+ return [get_symbol(c, system) for c in int_string], [
93
+ get_symbol(c, system) for c in dec_string
94
+ ]
95
+
96
+ def correct_symbols(integer_symbols, system):
97
+ """
98
+ 一百八 to 一百八十
99
+ 一亿一千三百万 to 一亿 一千万 三百万
100
+ """
101
+
102
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
103
+ if integer_symbols[0].power == 1:
104
+ integer_symbols = [system.digits[1]] + integer_symbols
105
+
106
+ if len(integer_symbols) > 1:
107
+ if isinstance(integer_symbols[-1], CND) and isinstance(
108
+ integer_symbols[-2], CNU
109
+ ):
110
+ integer_symbols.append(
111
+ CNU(integer_symbols[-2].power - 1, None, None, None, None)
112
+ )
113
+
114
+ result = []
115
+ unit_count = 0
116
+ for s in integer_symbols:
117
+ if isinstance(s, CND):
118
+ result.append(s)
119
+ unit_count = 0
120
+ elif isinstance(s, CNU):
121
+ current_unit = CNU(s.power, None, None, None, None)
122
+ unit_count += 1
123
+
124
+ if unit_count == 1:
125
+ result.append(current_unit)
126
+ elif unit_count > 1:
127
+ for i in range(len(result)):
128
+ if (
129
+ isinstance(result[-i - 1], CNU)
130
+ and result[-i - 1].power < current_unit.power
131
+ ):
132
+ result[-i - 1] = CNU(
133
+ result[-i - 1].power + current_unit.power,
134
+ None,
135
+ None,
136
+ None,
137
+ None,
138
+ )
139
+ return result
140
+
141
+ def compute_value(integer_symbols):
142
+ """
143
+ Compute the value.
144
+ When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
145
+ e.g. '两千万' = 2000 * 10000 not 2000 + 10000
146
+ """
147
+ value = [0]
148
+ last_power = 0
149
+ for s in integer_symbols:
150
+ if isinstance(s, CND):
151
+ value[-1] = s.value
152
+ elif isinstance(s, CNU):
153
+ value[-1] *= pow(10, s.power)
154
+ if s.power > last_power:
155
+ value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
156
+ last_power = s.power
157
+ value.append(0)
158
+ return sum(value)
159
+
160
+ system = create_system(numbering_type)
161
+ int_part, dec_part = string2symbols(chinese_string, system)
162
+ int_part = correct_symbols(int_part, system)
163
+ int_str = str(compute_value(int_part))
164
+ dec_str = "".join([str(d.value) for d in dec_part])
165
+ if dec_part:
166
+ return "{0}.{1}".format(int_str, dec_str)
167
+ else:
168
+ return int_str
169
+
170
+
171
+ def num2chn(
172
+ number_string,
173
+ numbering_type=NUMBERING_TYPES[1],
174
+ big=False,
175
+ traditional=False,
176
+ alt_zero=False,
177
+ alt_one=False,
178
+ alt_two=True,
179
+ use_zeros=True,
180
+ use_units=True,
181
+ ):
182
+
183
+ def get_value(value_string, use_zeros=True):
184
+
185
+ striped_string = value_string.lstrip("0")
186
+
187
+ # record nothing if all zeros
188
+ if not striped_string:
189
+ return []
190
+
191
+ # record one digits
192
+ elif len(striped_string) == 1:
193
+ if use_zeros and len(value_string) != len(striped_string):
194
+ return [system.digits[0], system.digits[int(striped_string)]]
195
+ else:
196
+ return [system.digits[int(striped_string)]]
197
+
198
+ # recursively record multiple digits
199
+ else:
200
+ result_unit = next(
201
+ u for u in reversed(system.units) if u.power < len(striped_string)
202
+ )
203
+ result_string = value_string[: -result_unit.power]
204
+ return (
205
+ get_value(result_string)
206
+ + [result_unit]
207
+ + get_value(striped_string[-result_unit.power :])
208
+ )
209
+
210
+ system = create_system(numbering_type)
211
+
212
+ int_dec = number_string.split(".")
213
+ if len(int_dec) == 1:
214
+ int_string = int_dec[0]
215
+ dec_string = ""
216
+ elif len(int_dec) == 2:
217
+ int_string = int_dec[0]
218
+ dec_string = int_dec[1]
219
+ else:
220
+ raise ValueError(
221
+ "invalid input num string with more than one dot: {}".format(number_string)
222
+ )
223
+
224
+ if use_units and len(int_string) > 1:
225
+ result_symbols = get_value(int_string)
226
+ else:
227
+ result_symbols = [system.digits[int(c)] for c in int_string]
228
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
229
+ if dec_string:
230
+ result_symbols += [system.math.point] + dec_symbols
231
+
232
+ if alt_two:
233
+ liang = CND(
234
+ 2,
235
+ system.digits[2].alt_s,
236
+ system.digits[2].alt_t,
237
+ system.digits[2].big_s,
238
+ system.digits[2].big_t,
239
+ )
240
+ for i, v in enumerate(result_symbols):
241
+ if isinstance(v, CND) and v.value == 2:
242
+ next_symbol = (
243
+ result_symbols[i + 1] if i < len(result_symbols) - 1 else None
244
+ )
245
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
246
+ if isinstance(next_symbol, CNU) and isinstance(
247
+ previous_symbol, (CNU, type(None))
248
+ ):
249
+ if next_symbol.power != 1 and (
250
+ (previous_symbol is None) or (previous_symbol.power != 1)
251
+ ):
252
+ result_symbols[i] = liang
253
+
254
+ # if big is True, '两' will not be used and `alt_two` has no impact on output
255
+ if big:
256
+ attr_name = "big_"
257
+ if traditional:
258
+ attr_name += "t"
259
+ else:
260
+ attr_name += "s"
261
+ else:
262
+ if traditional:
263
+ attr_name = "traditional"
264
+ else:
265
+ attr_name = "simplified"
266
+
267
+ result = "".join([getattr(s, attr_name) for s in result_symbols])
268
+
269
+ # if not use_zeros:
270
+ # result = result.strip(getattr(system.digits[0], attr_name))
271
+
272
+ if alt_zero:
273
+ result = result.replace(
274
+ getattr(system.digits[0], attr_name), system.digits[0].alt_s
275
+ )
276
+
277
+ if alt_one:
278
+ result = result.replace(
279
+ getattr(system.digits[1], attr_name), system.digits[1].alt_s
280
+ )
281
+
282
+ for i, p in enumerate(POINT):
283
+ if result.startswith(p):
284
+ return CHINESE_DIGIS[0] + result
285
+
286
+ # ^10, 11, .., 19
287
+ if (
288
+ len(result) >= 2
289
+ and result[1]
290
+ in [
291
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
292
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
293
+ ]
294
+ and result[0]
295
+ in [
296
+ CHINESE_DIGIS[1],
297
+ BIG_CHINESE_DIGIS_SIMPLIFIED[1],
298
+ BIG_CHINESE_DIGIS_TRADITIONAL[1],
299
+ ]
300
+ ):
301
+ result = result[1:]
302
+
303
+ return result
304
+
305
+
306
+ if __name__ == "__main__":
307
+
308
+ # 测试程序
309
+ all_chinese_number_string = (
310
+ CHINESE_DIGIS
311
+ + BIG_CHINESE_DIGIS_SIMPLIFIED
312
+ + BIG_CHINESE_DIGIS_TRADITIONAL
313
+ + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
314
+ + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
315
+ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
316
+ + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
317
+ + ZERO_ALT
318
+ + ONE_ALT
319
+ + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
320
+ )
321
+
322
+ print("num:", chn2num("一万零四百零三点八零五"))
323
+ print("num:", chn2num("一亿六点三"))
324
+ print("num:", chn2num("一亿零六点三"))
325
+ print("num:", chn2num("两千零一亿六点三"))
326
+ # print('num:', chn2num('一零零八六'))
327
+ print("txt:", num2chn("10260.03", alt_zero=True))
328
+ print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
329
+ print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
330
+ print(
331
+ "txt:",
332
+ num2chn(
333
+ "059523810880",
334
+ alt_one=True,
335
+ alt_two=False,
336
+ use_lzeros=True,
337
+ use_rzeros=True,
338
+ use_units=False,
339
+ ),
340
+ )
341
+
342
+ print(all_chinese_number_string)
fish_speech/text/chn_text_norm/cardinal.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """CARDINAL类 (包含小数DECIMAL类)
3
+ 纯数 <=> 中文字符串 方法
4
+ 中文字符串 <=> 纯数 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-03"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class Cardinal:
14
+ """
15
+ CARDINAL类
16
+ """
17
+
18
+ def __init__(self, cardinal=None, chntext=None):
19
+ self.cardinal = cardinal
20
+ self.chntext = chntext
21
+
22
+ def chntext2cardinal(self):
23
+ return chn2num(self.chntext)
24
+
25
+ def cardinal2chntext(self):
26
+ return num2chn(self.cardinal)
27
+
28
+
29
+ if __name__ == "__main__":
30
+
31
+ # 测试程序
32
+ print(Cardinal(cardinal="21357.230").cardinal2chntext())
fish_speech/text/chn_text_norm/date.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """DATE类
3
+ 日期 <=> 中文字符串 方法
4
+ 中文字符串 <=> 日期 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-07"
9
+
10
+ from fish_speech.text.chn_text_norm.cardinal import Cardinal
11
+ from fish_speech.text.chn_text_norm.digit import Digit
12
+
13
+
14
+ class Date:
15
+ """
16
+ DATE类
17
+ """
18
+
19
+ def __init__(self, date=None, chntext=None):
20
+ self.date = date
21
+ self.chntext = chntext
22
+
23
+ # def chntext2date(self):
24
+ # chntext = self.chntext
25
+ # try:
26
+ # year, other = chntext.strip().split('年', maxsplit=1)
27
+ # year = Digit(chntext=year).digit2chntext() + '年'
28
+ # except ValueError:
29
+ # other = chntext
30
+ # year = ''
31
+ # if other:
32
+ # try:
33
+ # month, day = other.strip().split('月', maxsplit=1)
34
+ # month = Cardinal(chntext=month).chntext2cardinal() + '月'
35
+ # except ValueError:
36
+ # day = chntext
37
+ # month = ''
38
+ # if day:
39
+ # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
40
+ # else:
41
+ # month = ''
42
+ # day = ''
43
+ # date = year + month + day
44
+ # self.date = date
45
+ # return self.date
46
+
47
+ def date2chntext(self):
48
+ date = self.date
49
+ try:
50
+ year, other = date.strip().split("年", maxsplit=1)
51
+ year = Digit(digit=year).digit2chntext() + "年"
52
+ except ValueError:
53
+ other = date
54
+ year = ""
55
+ if other:
56
+ try:
57
+ month, day = other.strip().split("月", maxsplit=1)
58
+ month = Cardinal(cardinal=month).cardinal2chntext() + "月"
59
+ except ValueError:
60
+ day = date
61
+ month = ""
62
+ if day:
63
+ day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
64
+ else:
65
+ month = ""
66
+ day = ""
67
+ chntext = year + month + day
68
+ self.chntext = chntext
69
+ return self.chntext
70
+
71
+
72
+ if __name__ == "__main__":
73
+
74
+ # 测试
75
+ print(Date(date="09年3月16日").date2chntext())
fish_speech/text/chn_text_norm/digit.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """DIGIT类
3
+ 数字串 <=> 中文字符串 方法
4
+ 中文字符串 <=> 数字串 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-03"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class Digit:
14
+ """
15
+ DIGIT类
16
+ """
17
+
18
+ def __init__(self, digit=None, chntext=None):
19
+ self.digit = digit
20
+ self.chntext = chntext
21
+
22
+ # def chntext2digit(self):
23
+ # return chn2num(self.chntext)
24
+
25
+ def digit2chntext(self):
26
+ return num2chn(self.digit, alt_two=False, use_units=False)
27
+
28
+
29
+ if __name__ == "__main__":
30
+
31
+ # 测试程序
32
+ print(Digit(digit="2016").digit2chntext())
fish_speech/text/chn_text_norm/fraction.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """FRACTION类
3
+ 分数 <=> 中文字符串 方法
4
+ 中文字符串 <=> 分数 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-03"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class Fraction:
14
+ """
15
+ FRACTION类
16
+ """
17
+
18
+ def __init__(self, fraction=None, chntext=None):
19
+ self.fraction = fraction
20
+ self.chntext = chntext
21
+
22
+ def chntext2fraction(self):
23
+ denominator, numerator = self.chntext.split("分之")
24
+ return chn2num(numerator) + "/" + chn2num(denominator)
25
+
26
+ def fraction2chntext(self):
27
+ numerator, denominator = self.fraction.split("/")
28
+ return num2chn(denominator) + "分之" + num2chn(numerator)
29
+
30
+
31
+ if __name__ == "__main__":
32
+
33
+ # 测试程序
34
+ print(Fraction(fraction="2135/7230").fraction2chntext())
35
+ print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
fish_speech/text/chn_text_norm/money.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """MONEY类
3
+ 金钱 <=> 中文字符串 方法
4
+ 中文字符串 <=> 金钱 方法
5
+ """
6
+ import re
7
+
8
+ __author__ = "Zhiyang Zhou <[email protected]>"
9
+ __data__ = "2019-05-08"
10
+
11
+ from fish_speech.text.chn_text_norm.cardinal import Cardinal
12
+
13
+
14
+ class Money:
15
+ """
16
+ MONEY类
17
+ """
18
+
19
+ def __init__(self, money=None, chntext=None):
20
+ self.money = money
21
+ self.chntext = chntext
22
+
23
+ # def chntext2money(self):
24
+ # return self.money
25
+
26
+ def money2chntext(self):
27
+ money = self.money
28
+ pattern = re.compile(r"(\d+(\.\d+)?)")
29
+ matchers = pattern.findall(money)
30
+ if matchers:
31
+ for matcher in matchers:
32
+ money = money.replace(
33
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()
34
+ )
35
+ self.chntext = money
36
+ return self.chntext
37
+
38
+
39
+ if __name__ == "__main__":
40
+
41
+ # 测试
42
+ print(Money(money="21.5万元").money2chntext())
43
+ print(Money(money="230块5毛").money2chntext())
fish_speech/text/chn_text_norm/percentage.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """PERCENTAGE类
3
+ 百分数 <=> 中文字符串 方法
4
+ 中文字符串 <=> 百分数 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-06"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class Percentage:
14
+ """
15
+ PERCENTAGE类
16
+ """
17
+
18
+ def __init__(self, percentage=None, chntext=None):
19
+ self.percentage = percentage
20
+ self.chntext = chntext
21
+
22
+ def chntext2percentage(self):
23
+ return chn2num(self.chntext.strip().strip("百分之")) + "%"
24
+
25
+ def percentage2chntext(self):
26
+ return "百分之" + num2chn(self.percentage.strip().strip("%"))
27
+
28
+
29
+ if __name__ == "__main__":
30
+
31
+ # 测试程序
32
+ print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
33
+ print(Percentage(percentage="65.3%").percentage2chntext())
fish_speech/text/chn_text_norm/telephone.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """TELEPHONE类
3
+ 电话号码 <=> 中文字符串 方法
4
+ 中文字符串 <=> 电话号码 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-03"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class TelePhone:
14
+ """
15
+ TELEPHONE类
16
+ """
17
+
18
+ def __init__(self, telephone=None, raw_chntext=None, chntext=None):
19
+ self.telephone = telephone
20
+ self.raw_chntext = raw_chntext
21
+ self.chntext = chntext
22
+
23
+ # def chntext2telephone(self):
24
+ # sil_parts = self.raw_chntext.split('<SIL>')
25
+ # self.telephone = '-'.join([
26
+ # str(chn2num(p)) for p in sil_parts
27
+ # ])
28
+ # return self.telephone
29
+
30
+ def telephone2chntext(self, fixed=False):
31
+
32
+ if fixed:
33
+ sil_parts = self.telephone.split("-")
34
+ self.raw_chntext = "<SIL>".join(
35
+ [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
36
+ )
37
+ self.chntext = self.raw_chntext.replace("<SIL>", "")
38
+ else:
39
+ sp_parts = self.telephone.strip("+").split()
40
+ self.raw_chntext = "<SP>".join(
41
+ [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
42
+ )
43
+ self.chntext = self.raw_chntext.replace("<SP>", "")
44
+ return self.chntext
45
+
46
+
47
+ if __name__ == "__main__":
48
+
49
+ # 测试程序
50
+ print(TelePhone(telephone="0595-23980880").telephone2chntext())
51
+ # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
fish_speech/text/chn_text_norm/text.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ TEXT类
4
+ """
5
+
6
+ __author__ = "Zhiyang Zhou <[email protected]>"
7
+ __data__ = "2019-05-03"
8
+
9
+ import re
10
+
11
+ from fish_speech.text.chn_text_norm.cardinal import Cardinal
12
+ from fish_speech.text.chn_text_norm.date import Date
13
+ from fish_speech.text.chn_text_norm.digit import Digit
14
+ from fish_speech.text.chn_text_norm.fraction import Fraction
15
+ from fish_speech.text.chn_text_norm.money import Money
16
+ from fish_speech.text.chn_text_norm.percentage import Percentage
17
+ from fish_speech.text.chn_text_norm.telephone import TelePhone
18
+
19
+ CURRENCY_NAMES = (
20
+ "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
21
+ "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
22
+ )
23
+ CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
24
+ COM_QUANTIFIERS = (
25
+ "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
26
+ "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
27
+ "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
28
+ "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
29
+ "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
30
+ "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
31
+ )
32
+
33
+
34
+ class Text:
35
+ """
36
+ Text类
37
+ """
38
+
39
+ def __init__(self, raw_text, norm_text=None):
40
+ self.raw_text = "^" + raw_text + "$"
41
+ self.norm_text = norm_text
42
+
43
+ def _particular(self):
44
+ text = self.norm_text
45
+ pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
46
+ matchers = pattern.findall(text)
47
+ if matchers:
48
+ # print('particular')
49
+ for matcher in matchers:
50
+ text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
51
+ self.norm_text = text
52
+ return self.norm_text
53
+
54
+ def normalize(self):
55
+ text = self.raw_text
56
+
57
+ # 规范化日期
58
+ pattern = re.compile(
59
+ r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
60
+ )
61
+ matchers = pattern.findall(text)
62
+ if matchers:
63
+ # print('date')
64
+ for matcher in matchers:
65
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
66
+
67
+ # 规范化金钱
68
+ pattern = re.compile(
69
+ r"\D+((\d+(\.\d+)?)[多余几]?"
70
+ + CURRENCY_UNITS
71
+ + "(\d"
72
+ + CURRENCY_UNITS
73
+ + "?)?)"
74
+ )
75
+ matchers = pattern.findall(text)
76
+ if matchers:
77
+ # print('money')
78
+ for matcher in matchers:
79
+ text = text.replace(
80
+ matcher[0], Money(money=matcher[0]).money2chntext(), 1
81
+ )
82
+
83
+ # 规范化固话/手机号码
84
+ # 手机
85
+ # http://www.jihaoba.com/news/show/13680
86
+ # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
87
+ # 联通:130、131、132、156、155、186、185、176
88
+ # 电信:133、153、189、180、181、177
89
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
90
+ matchers = pattern.findall(text)
91
+ if matchers:
92
+ # print('telephone')
93
+ for matcher in matchers:
94
+ text = text.replace(
95
+ matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
96
+ )
97
+ # 固话
98
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
99
+ matchers = pattern.findall(text)
100
+ if matchers:
101
+ # print('fixed telephone')
102
+ for matcher in matchers:
103
+ text = text.replace(
104
+ matcher[0],
105
+ TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
106
+ 1,
107
+ )
108
+
109
+ # 规范化分数
110
+ pattern = re.compile(r"(\d+/\d+)")
111
+ matchers = pattern.findall(text)
112
+ if matchers:
113
+ # print('fraction')
114
+ for matcher in matchers:
115
+ text = text.replace(
116
+ matcher, Fraction(fraction=matcher).fraction2chntext(), 1
117
+ )
118
+
119
+ # 规范化百分数
120
+ text = text.replace("%", "%")
121
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
122
+ matchers = pattern.findall(text)
123
+ if matchers:
124
+ # print('percentage')
125
+ for matcher in matchers:
126
+ text = text.replace(
127
+ matcher[0],
128
+ Percentage(percentage=matcher[0]).percentage2chntext(),
129
+ 1,
130
+ )
131
+
132
+ # 规范化纯数+量词
133
+ pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
134
+ matchers = pattern.findall(text)
135
+ if matchers:
136
+ # print('cardinal+quantifier')
137
+ for matcher in matchers:
138
+ text = text.replace(
139
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
140
+ )
141
+
142
+ # 规范化数字编号
143
+ pattern = re.compile(r"(\d{4,32})")
144
+ matchers = pattern.findall(text)
145
+ if matchers:
146
+ # print('digit')
147
+ for matcher in matchers:
148
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
149
+
150
+ # 规范化纯数
151
+ pattern = re.compile(r"(\d+(\.\d+)?)")
152
+ matchers = pattern.findall(text)
153
+ if matchers:
154
+ # print('cardinal')
155
+ for matcher in matchers:
156
+ text = text.replace(
157
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
158
+ )
159
+
160
+ self.norm_text = text
161
+ self._particular()
162
+
163
+ return self.norm_text.lstrip("^").rstrip("$")
164
+
165
+
166
+ if __name__ == "__main__":
167
+
168
+ # 测试程序
169
+ print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
170
+ print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
171
+ print(Text(raw_text="分数:32477/76391。").normalize())
172
+ print(Text(raw_text="百分数:80.03%。").normalize())
173
+ print(Text(raw_text="编号:31520181154418。").normalize())
174
+ print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
175
+ print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
176
+ print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
177
+ print(Text(raw_text="特殊:O2O或B2C。").normalize())
fish_speech/text/clean.py CHANGED
@@ -18,7 +18,6 @@ SYMBOLS_MAPPING = {
18
  "·": ",",
19
  "、": ",",
20
  "...": "…",
21
- "$": ".",
22
  "“": "'",
23
  "”": "'",
24
  "‘": "'",
@@ -62,12 +61,9 @@ REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
62
  def clean_text(text):
63
  # Clean the text
64
  text = text.strip()
65
- # Replace <p:(.*?)> with <PPP(.*?)PPP>
66
- text = re.sub(r"<p:(.*?)>", r"<PPP\1PPP>", text)
67
  # Replace all chinese symbols with their english counterparts
68
  text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
69
  text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
70
- # Replace <PPP(.*?)PPP> with <p:(.*?)>
71
- text = re.sub(r"<PPP(.*?)PPP>", r"<p:\1>", text)
72
 
73
  return text
 
18
  "·": ",",
19
  "、": ",",
20
  "...": "…",
 
21
  "“": "'",
22
  "”": "'",
23
  "‘": "'",
 
61
  def clean_text(text):
62
  # Clean the text
63
  text = text.strip()
64
+
 
65
  # Replace all chinese symbols with their english counterparts
66
  text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
67
  text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
 
 
68
 
69
  return text
fish_speech/text/spliter.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import string
3
+
4
+ from fish_speech.text.clean import clean_text
5
+
6
+
7
+ def utf_8_len(text):
8
+ return len(text.encode("utf-8"))
9
+
10
+
11
+ def break_text(texts, length, splits: set):
12
+ for text in texts:
13
+ if utf_8_len(text) <= length:
14
+ yield text
15
+ continue
16
+
17
+ curr = ""
18
+ for char in text:
19
+ curr += char
20
+
21
+ if char in splits:
22
+ yield curr
23
+ curr = ""
24
+
25
+ if curr:
26
+ yield curr
27
+
28
+
29
+ def break_text_by_length(texts, length):
30
+ for text in texts:
31
+ if utf_8_len(text) <= length:
32
+ yield text
33
+ continue
34
+
35
+ curr = ""
36
+ for char in text:
37
+ curr += char
38
+
39
+ if utf_8_len(curr) >= length:
40
+ yield curr
41
+ curr = ""
42
+
43
+ if curr:
44
+ yield curr
45
+
46
+
47
+ def add_cleaned(curr, segments):
48
+ curr = curr.strip()
49
+ if curr and not all(c.isspace() or c in string.punctuation for c in curr):
50
+ segments.append(curr)
51
+
52
+
53
+ def protect_float(text):
54
+ # Turns 3.14 into <3_f_14> to prevent splitting
55
+ return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
56
+
57
+
58
+ def unprotect_float(text):
59
+ # Turns <3_f_14> into 3.14
60
+ return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
61
+
62
+
63
+ def split_text(text, length):
64
+ text = clean_text(text)
65
+
66
+ # Break the text into pieces with following rules:
67
+ # 1. Split the text at ".", "!", "?" if text is NOT a float
68
+ # 2. If the text is longer than length, split at ","
69
+ # 3. If the text is still longer than length, split at " "
70
+ # 4. If the text is still longer than length, split at any character to length
71
+
72
+ texts = [text]
73
+ texts = map(protect_float, texts)
74
+ texts = break_text(texts, length, {".", "!", "?"})
75
+ texts = map(unprotect_float, texts)
76
+ texts = break_text(texts, length, {","})
77
+ texts = break_text(texts, length, {" "})
78
+ texts = list(break_text_by_length(texts, length))
79
+
80
+ # Then, merge the texts into segments with length <= length
81
+ segments = []
82
+ curr = ""
83
+
84
+ for text in texts:
85
+ if utf_8_len(curr) + utf_8_len(text) <= length:
86
+ curr += text
87
+ else:
88
+ add_cleaned(curr, segments)
89
+ curr = text
90
+
91
+ if curr:
92
+ add_cleaned(curr, segments)
93
+
94
+ return segments
95
+
96
+
97
+ if __name__ == "__main__":
98
+ # Test the split_text function
99
+
100
+ text = "This is a test sentence. This is another test sentence. And a third one."
101
+
102
+ assert split_text(text, 50) == [
103
+ "This is a test sentence.",
104
+ "This is another test sentence. And a third one.",
105
+ ]
106
+ assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
107
+ assert split_text(" ", 10) == []
108
+ assert split_text("a", 10) == ["a"]
109
+
110
+ text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
111
+ assert split_text(text, 50) == [
112
+ "This is a test sentence with only commas,",
113
+ "and no dots, and no exclamation marks,",
114
+ "and no question marks, and no newlines.",
115
+ ]
116
+
117
+ text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
118
+ # First half split at " ", second half split at ","
119
+ assert split_text(text, 50) == [
120
+ "This is a test sentence This is a test sentence",
121
+ "This is a test sentence. This is a test sentence,",
122
+ "This is a test sentence, This is a test sentence.",
123
+ ]
124
+
125
+ text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
126
+ assert split_text(text, 50) == [
127
+ "这是一段很长的中文文本,",
128
+ "而且没有句号,也没有感叹号,",
129
+ "也没有问号,也没有换行符.",
130
+ ]
fish_speech/utils/file.py CHANGED
@@ -44,7 +44,7 @@ def list_files(
44
  if not path.exists():
45
  raise FileNotFoundError(f"Directory {path} does not exist.")
46
 
47
- files = [file for ext in extensions for file in path.iglob(f"**/*{ext}")]
48
 
49
  if sort:
50
  files = natsorted(files)
 
44
  if not path.exists():
45
  raise FileNotFoundError(f"Directory {path} does not exist.")
46
 
47
+ files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
48
 
49
  if sort:
50
  files = natsorted(files)
fish_speech/utils/rich_utils.py CHANGED
@@ -43,9 +43,13 @@ def print_config_tree(
43
 
44
  # add fields from `print_order` to queue
45
  for field in print_order:
46
- queue.append(field) if field in cfg else log.warning(
47
- f"Field '{field}' not found in config. "
48
- + f"Skipping '{field}' config printing..."
 
 
 
 
49
  )
50
 
51
  # add all the other fields to queue (not specified in `print_order`)
 
43
 
44
  # add fields from `print_order` to queue
45
  for field in print_order:
46
+ (
47
+ queue.append(field)
48
+ if field in cfg
49
+ else log.warning(
50
+ f"Field '{field}' not found in config. "
51
+ + f"Skipping '{field}' config printing..."
52
+ )
53
  )
54
 
55
  # add all the other fields to queue (not specified in `print_order`)
fish_speech/utils/spectrogram.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio.functional as F
3
+ from torch import Tensor, nn
4
+ from torchaudio.transforms import MelScale
5
+
6
+
7
+ class LinearSpectrogram(nn.Module):
8
+ def __init__(
9
+ self,
10
+ n_fft=2048,
11
+ win_length=2048,
12
+ hop_length=512,
13
+ center=False,
14
+ mode="pow2_sqrt",
15
+ ):
16
+ super().__init__()
17
+
18
+ self.n_fft = n_fft
19
+ self.win_length = win_length
20
+ self.hop_length = hop_length
21
+ self.center = center
22
+ self.mode = mode
23
+
24
+ self.register_buffer("window", torch.hann_window(win_length), persistent=False)
25
+
26
+ def forward(self, y: Tensor) -> Tensor:
27
+ if y.ndim == 3:
28
+ y = y.squeeze(1)
29
+
30
+ y = torch.nn.functional.pad(
31
+ y.unsqueeze(1),
32
+ (
33
+ (self.win_length - self.hop_length) // 2,
34
+ (self.win_length - self.hop_length + 1) // 2,
35
+ ),
36
+ mode="reflect",
37
+ ).squeeze(1)
38
+
39
+ spec = torch.stft(
40
+ y,
41
+ self.n_fft,
42
+ hop_length=self.hop_length,
43
+ win_length=self.win_length,
44
+ window=self.window,
45
+ center=self.center,
46
+ pad_mode="reflect",
47
+ normalized=False,
48
+ onesided=True,
49
+ return_complex=True,
50
+ )
51
+
52
+ spec = torch.view_as_real(spec)
53
+
54
+ if self.mode == "pow2_sqrt":
55
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
56
+
57
+ return spec
58
+
59
+
60
+ class LogMelSpectrogram(nn.Module):
61
+ def __init__(
62
+ self,
63
+ sample_rate=44100,
64
+ n_fft=2048,
65
+ win_length=2048,
66
+ hop_length=512,
67
+ n_mels=128,
68
+ center=False,
69
+ f_min=0.0,
70
+ f_max=None,
71
+ ):
72
+ super().__init__()
73
+
74
+ self.sample_rate = sample_rate
75
+ self.n_fft = n_fft
76
+ self.win_length = win_length
77
+ self.hop_length = hop_length
78
+ self.center = center
79
+ self.n_mels = n_mels
80
+ self.f_min = f_min
81
+ self.f_max = f_max or float(sample_rate // 2)
82
+
83
+ self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
84
+
85
+ fb = F.melscale_fbanks(
86
+ n_freqs=self.n_fft // 2 + 1,
87
+ f_min=self.f_min,
88
+ f_max=self.f_max,
89
+ n_mels=self.n_mels,
90
+ sample_rate=self.sample_rate,
91
+ norm="slaney",
92
+ mel_scale="slaney",
93
+ )
94
+ self.register_buffer(
95
+ "fb",
96
+ fb,
97
+ persistent=False,
98
+ )
99
+
100
+ def compress(self, x: Tensor) -> Tensor:
101
+ return torch.log(torch.clamp(x, min=1e-5))
102
+
103
+ def decompress(self, x: Tensor) -> Tensor:
104
+ return torch.exp(x)
105
+
106
+ def apply_mel_scale(self, x: Tensor) -> Tensor:
107
+ return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
108
+
109
+ def forward(
110
+ self, x: Tensor, return_linear: bool = False, sample_rate: int = None
111
+ ) -> Tensor:
112
+ if sample_rate is not None and sample_rate != self.sample_rate:
113
+ x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
114
+
115
+ linear = self.spectrogram(x)
116
+ x = self.apply_mel_scale(linear)
117
+ x = self.compress(x)
118
+
119
+ if return_linear:
120
+ return x, self.compress(linear)
121
+
122
+ return x
tools/api.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import json
4
+ import queue
5
+ import random
6
+ import traceback
7
+ import wave
8
+ from argparse import ArgumentParser
9
+ from http import HTTPStatus
10
+ from pathlib import Path
11
+ from typing import Annotated, Literal, Optional
12
+
13
+ import librosa
14
+ import numpy as np
15
+ import pyrootutils
16
+ import soundfile as sf
17
+ import torch
18
+ from kui.asgi import (
19
+ Body,
20
+ HTTPException,
21
+ HttpView,
22
+ JSONResponse,
23
+ Kui,
24
+ OpenAPI,
25
+ StreamResponse,
26
+ )
27
+ from kui.asgi.routing import MultimethodRoutes
28
+ from loguru import logger
29
+ from pydantic import BaseModel, Field
30
+
31
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
32
+
33
+ # from fish_speech.models.vqgan.lit_module import VQGAN
34
+ from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
35
+ from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
36
+ from tools.llama.generate import (
37
+ GenerateRequest,
38
+ GenerateResponse,
39
+ WrappedGenerateResponse,
40
+ launch_thread_safe_queue,
41
+ )
42
+ from tools.vqgan.inference import load_model as load_decoder_model
43
+
44
+
45
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
46
+ buffer = io.BytesIO()
47
+
48
+ with wave.open(buffer, "wb") as wav_file:
49
+ wav_file.setnchannels(channels)
50
+ wav_file.setsampwidth(bit_depth // 8)
51
+ wav_file.setframerate(sample_rate)
52
+
53
+ wav_header_bytes = buffer.getvalue()
54
+ buffer.close()
55
+ return wav_header_bytes
56
+
57
+
58
+ # Define utils for web server
59
+ async def http_execption_handler(exc: HTTPException):
60
+ return JSONResponse(
61
+ dict(
62
+ statusCode=exc.status_code,
63
+ message=exc.content,
64
+ error=HTTPStatus(exc.status_code).phrase,
65
+ ),
66
+ exc.status_code,
67
+ exc.headers,
68
+ )
69
+
70
+
71
+ async def other_exception_handler(exc: "Exception"):
72
+ traceback.print_exc()
73
+
74
+ status = HTTPStatus.INTERNAL_SERVER_ERROR
75
+ return JSONResponse(
76
+ dict(statusCode=status, message=str(exc), error=status.phrase),
77
+ status,
78
+ )
79
+
80
+
81
+ def load_audio(reference_audio, sr):
82
+ if len(reference_audio) > 255 or not Path(reference_audio).exists():
83
+ try:
84
+ audio_data = base64.b64decode(reference_audio)
85
+ reference_audio = io.BytesIO(audio_data)
86
+ except base64.binascii.Error:
87
+ raise ValueError("Invalid path or base64 string")
88
+
89
+ audio, _ = librosa.load(reference_audio, sr=sr, mono=True)
90
+ return audio
91
+
92
+
93
+ def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
94
+ if enable_reference_audio and reference_audio is not None:
95
+ # Load audios, and prepare basic info here
96
+ reference_audio_content = load_audio(
97
+ reference_audio, decoder_model.spec_transform.sample_rate
98
+ )
99
+
100
+ audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
101
+ None, None, :
102
+ ]
103
+ audio_lengths = torch.tensor(
104
+ [audios.shape[2]], device=decoder_model.device, dtype=torch.long
105
+ )
106
+ logger.info(
107
+ f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
108
+ )
109
+
110
+ # VQ Encoder
111
+ if isinstance(decoder_model, FireflyArchitecture):
112
+ prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
113
+
114
+ logger.info(f"Encoded prompt: {prompt_tokens.shape}")
115
+ else:
116
+ prompt_tokens = None
117
+ logger.info("No reference audio provided")
118
+
119
+ return prompt_tokens
120
+
121
+
122
+ def decode_vq_tokens(
123
+ *,
124
+ decoder_model,
125
+ codes,
126
+ ):
127
+ feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
128
+ logger.info(f"VQ features: {codes.shape}")
129
+
130
+ if isinstance(decoder_model, FireflyArchitecture):
131
+ # VQGAN Inference
132
+ return decoder_model.decode(
133
+ indices=codes[None],
134
+ feature_lengths=feature_lengths,
135
+ ).squeeze()
136
+
137
+ raise ValueError(f"Unknown model type: {type(decoder_model)}")
138
+
139
+
140
+ routes = MultimethodRoutes(base_class=HttpView)
141
+
142
+
143
+ def get_random_paths(base_path, data, speaker, emotion):
144
+ if base_path and data and speaker and emotion and (Path(base_path).exists()):
145
+ if speaker in data and emotion in data[speaker]:
146
+ files = data[speaker][emotion]
147
+ lab_files = [f for f in files if f.endswith(".lab")]
148
+ wav_files = [f for f in files if f.endswith(".wav")]
149
+
150
+ if lab_files and wav_files:
151
+ selected_lab = random.choice(lab_files)
152
+ selected_wav = random.choice(wav_files)
153
+
154
+ lab_path = Path(base_path) / speaker / emotion / selected_lab
155
+ wav_path = Path(base_path) / speaker / emotion / selected_wav
156
+ if lab_path.exists() and wav_path.exists():
157
+ return lab_path, wav_path
158
+
159
+ return None, None
160
+
161
+
162
+ def load_json(json_file):
163
+ if not json_file:
164
+ logger.info("Not using a json file")
165
+ return None
166
+ try:
167
+ with open(json_file, "r", encoding="utf-8") as file:
168
+ data = json.load(file)
169
+ except FileNotFoundError:
170
+ logger.warning(f"ref json not found: {json_file}")
171
+ data = None
172
+ except Exception as e:
173
+ logger.warning(f"Loading json failed: {e}")
174
+ data = None
175
+ return data
176
+
177
+
178
+ class InvokeRequest(BaseModel):
179
+ text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
180
+ reference_text: Optional[str] = None
181
+ reference_audio: Optional[str] = None
182
+ max_new_tokens: int = 1024
183
+ chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
184
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
185
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
186
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
187
+ emotion: Optional[str] = None
188
+ format: Literal["wav", "mp3", "flac"] = "wav"
189
+ streaming: bool = False
190
+ ref_json: Optional[str] = "ref_data.json"
191
+ ref_base: Optional[str] = "ref_data"
192
+ speaker: Optional[str] = None
193
+
194
+
195
+ def get_content_type(audio_format):
196
+ if audio_format == "wav":
197
+ return "audio/wav"
198
+ elif audio_format == "flac":
199
+ return "audio/flac"
200
+ elif audio_format == "mp3":
201
+ return "audio/mpeg"
202
+ else:
203
+ return "application/octet-stream"
204
+
205
+
206
+ @torch.inference_mode()
207
+ def inference(req: InvokeRequest):
208
+ # Parse reference audio aka prompt
209
+ prompt_tokens = None
210
+
211
+ ref_data = load_json(req.ref_json)
212
+ ref_base = req.ref_base
213
+
214
+ lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
215
+
216
+ if lab_path and wav_path:
217
+ with open(lab_path, "r", encoding="utf-8") as lab_file:
218
+ ref_text = lab_file.read()
219
+ req.reference_audio = wav_path
220
+ req.reference_text = ref_text
221
+ logger.info("ref_path: " + str(wav_path))
222
+ logger.info("ref_text: " + ref_text)
223
+
224
+ # Parse reference audio aka prompt
225
+ prompt_tokens = encode_reference(
226
+ decoder_model=decoder_model,
227
+ reference_audio=req.reference_audio,
228
+ enable_reference_audio=req.reference_audio is not None,
229
+ )
230
+ logger.info(f"ref_text: {req.reference_text}")
231
+ # LLAMA Inference
232
+ request = dict(
233
+ device=decoder_model.device,
234
+ max_new_tokens=req.max_new_tokens,
235
+ text=req.text,
236
+ top_p=req.top_p,
237
+ repetition_penalty=req.repetition_penalty,
238
+ temperature=req.temperature,
239
+ compile=args.compile,
240
+ iterative_prompt=req.chunk_length > 0,
241
+ chunk_length=req.chunk_length,
242
+ max_length=2048,
243
+ prompt_tokens=prompt_tokens,
244
+ prompt_text=req.reference_text,
245
+ )
246
+
247
+ response_queue = queue.Queue()
248
+ llama_queue.put(
249
+ GenerateRequest(
250
+ request=request,
251
+ response_queue=response_queue,
252
+ )
253
+ )
254
+
255
+ if req.streaming:
256
+ yield wav_chunk_header()
257
+
258
+ segments = []
259
+ while True:
260
+ result: WrappedGenerateResponse = response_queue.get()
261
+ if result.status == "error":
262
+ raise result.response
263
+ break
264
+
265
+ result: GenerateResponse = result.response
266
+ if result.action == "next":
267
+ break
268
+
269
+ with torch.autocast(
270
+ device_type=decoder_model.device.type, dtype=args.precision
271
+ ):
272
+ fake_audios = decode_vq_tokens(
273
+ decoder_model=decoder_model,
274
+ codes=result.codes,
275
+ )
276
+
277
+ fake_audios = fake_audios.float().cpu().numpy()
278
+
279
+ if req.streaming:
280
+ yield (fake_audios * 32768).astype(np.int16).tobytes()
281
+ else:
282
+ segments.append(fake_audios)
283
+
284
+ if req.streaming:
285
+ return
286
+
287
+ if len(segments) == 0:
288
+ raise HTTPException(
289
+ HTTPStatus.INTERNAL_SERVER_ERROR,
290
+ content="No audio generated, please check the input text.",
291
+ )
292
+
293
+ fake_audios = np.concatenate(segments, axis=0)
294
+ yield fake_audios
295
+
296
+
297
+ def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
298
+ if not use_auto_rerank:
299
+ # 如果不使用 auto_rerank,直接调用原始的 inference 函数
300
+ return inference(req)
301
+
302
+ zh_model, en_model = load_model()
303
+ max_attempts = 5
304
+ best_wer = float("inf")
305
+ best_audio = None
306
+
307
+ for attempt in range(max_attempts):
308
+ # 调用原始的 inference 函数
309
+ audio_generator = inference(req)
310
+ fake_audios = next(audio_generator)
311
+
312
+ asr_result = batch_asr(
313
+ zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
314
+ )[0]
315
+ wer = calculate_wer(req.text, asr_result["text"])
316
+
317
+ if wer <= 0.1 and not asr_result["huge_gap"]:
318
+ return fake_audios
319
+
320
+ if wer < best_wer:
321
+ best_wer = wer
322
+ best_audio = fake_audios
323
+
324
+ if attempt == max_attempts - 1:
325
+ break
326
+
327
+ return best_audio
328
+
329
+
330
+ async def inference_async(req: InvokeRequest):
331
+ for chunk in inference(req):
332
+ yield chunk
333
+
334
+
335
+ async def buffer_to_async_generator(buffer):
336
+ yield buffer
337
+
338
+
339
+ @routes.http.post("/v1/invoke")
340
+ async def api_invoke_model(
341
+ req: Annotated[InvokeRequest, Body(exclusive=True)],
342
+ ):
343
+ """
344
+ Invoke model and generate audio
345
+ """
346
+
347
+ if args.max_text_length > 0 and len(req.text) > args.max_text_length:
348
+ raise HTTPException(
349
+ HTTPStatus.BAD_REQUEST,
350
+ content=f"Text is too long, max length is {args.max_text_length}",
351
+ )
352
+
353
+ if req.streaming and req.format != "wav":
354
+ raise HTTPException(
355
+ HTTPStatus.BAD_REQUEST,
356
+ content="Streaming only supports WAV format",
357
+ )
358
+
359
+ if req.streaming:
360
+ return StreamResponse(
361
+ iterable=inference_async(req),
362
+ headers={
363
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
364
+ },
365
+ content_type=get_content_type(req.format),
366
+ )
367
+ else:
368
+ fake_audios = next(inference(req))
369
+ buffer = io.BytesIO()
370
+ sf.write(
371
+ buffer,
372
+ fake_audios,
373
+ decoder_model.spec_transform.sample_rate,
374
+ format=req.format,
375
+ )
376
+
377
+ return StreamResponse(
378
+ iterable=buffer_to_async_generator(buffer.getvalue()),
379
+ headers={
380
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
381
+ },
382
+ content_type=get_content_type(req.format),
383
+ )
384
+
385
+
386
+ @routes.http.post("/v1/health")
387
+ async def api_health():
388
+ """
389
+ Health check
390
+ """
391
+
392
+ return JSONResponse({"status": "ok"})
393
+
394
+
395
+ def parse_args():
396
+ parser = ArgumentParser()
397
+ parser.add_argument(
398
+ "--llama-checkpoint-path",
399
+ type=str,
400
+ default="checkpoints/fish-speech-1.2-sft",
401
+ )
402
+ parser.add_argument(
403
+ "--decoder-checkpoint-path",
404
+ type=str,
405
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
406
+ )
407
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
408
+ parser.add_argument("--device", type=str, default="cuda")
409
+ parser.add_argument("--half", action="store_true")
410
+ parser.add_argument("--compile", action="store_true")
411
+ parser.add_argument("--max-text-length", type=int, default=0)
412
+ parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
413
+ parser.add_argument("--workers", type=int, default=1)
414
+ parser.add_argument("--use-auto-rerank", type=bool, default=True)
415
+
416
+ return parser.parse_args()
417
+
418
+
419
+ # Define Kui app
420
+ openapi = OpenAPI(
421
+ {
422
+ "title": "Fish Speech API",
423
+ },
424
+ ).routes
425
+
426
+ app = Kui(
427
+ routes=routes + openapi[1:], # Remove the default route
428
+ exception_handlers={
429
+ HTTPException: http_execption_handler,
430
+ Exception: other_exception_handler,
431
+ },
432
+ cors_config={},
433
+ )
434
+
435
+
436
+ if __name__ == "__main__":
437
+ import threading
438
+
439
+ import uvicorn
440
+
441
+ args = parse_args()
442
+ args.precision = torch.half if args.half else torch.bfloat16
443
+
444
+ logger.info("Loading Llama model...")
445
+ llama_queue = launch_thread_safe_queue(
446
+ checkpoint_path=args.llama_checkpoint_path,
447
+ device=args.device,
448
+ precision=args.precision,
449
+ compile=args.compile,
450
+ )
451
+ logger.info("Llama model loaded, loading VQ-GAN model...")
452
+
453
+ decoder_model = load_decoder_model(
454
+ config_name=args.decoder_config_name,
455
+ checkpoint_path=args.decoder_checkpoint_path,
456
+ device=args.device,
457
+ )
458
+
459
+ logger.info("VQ-GAN model loaded, warming up...")
460
+
461
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
462
+ list(
463
+ inference(
464
+ InvokeRequest(
465
+ text="Hello world.",
466
+ reference_text=None,
467
+ reference_audio=None,
468
+ max_new_tokens=0,
469
+ top_p=0.7,
470
+ repetition_penalty=1.2,
471
+ temperature=0.7,
472
+ emotion=None,
473
+ format="wav",
474
+ ref_base=None,
475
+ ref_json=None,
476
+ )
477
+ )
478
+ )
479
+
480
+ logger.info(f"Warming up done, starting server at http://{args.listen}")
481
+ host, port = args.listen.split(":")
482
+ uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
tools/auto_rerank.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["MODELSCOPE_CACHE"] = ".cache/"
4
+
5
+ import string
6
+ import time
7
+ from threading import Lock
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import opencc
12
+ import torch
13
+ from faster_whisper import WhisperModel
14
+
15
+ t2s_converter = opencc.OpenCC("t2s")
16
+
17
+
18
+ def load_model(*, device="cuda"):
19
+ model = WhisperModel(
20
+ "medium",
21
+ device=device,
22
+ compute_type="float16",
23
+ download_root="faster_whisper",
24
+ )
25
+ print("faster_whisper loaded!")
26
+ return model
27
+
28
+
29
+ @torch.no_grad()
30
+ def batch_asr_internal(model: WhisperModel, audios, sr):
31
+ resampled_audios = []
32
+ for audio in audios:
33
+
34
+ if isinstance(audio, np.ndarray):
35
+ audio = torch.from_numpy(audio).float()
36
+
37
+ if audio.dim() > 1:
38
+ audio = audio.squeeze()
39
+
40
+ assert audio.dim() == 1
41
+ audio_np = audio.numpy()
42
+ resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
43
+ resampled_audios.append(resampled_audio)
44
+
45
+ trans_results = []
46
+
47
+ for resampled_audio in resampled_audios:
48
+ segments, info = model.transcribe(
49
+ resampled_audio,
50
+ language=None,
51
+ beam_size=5,
52
+ initial_prompt="Punctuation is needed in any language.",
53
+ )
54
+ trans_results.append(list(segments))
55
+
56
+ results = []
57
+ for trans_res, audio in zip(trans_results, audios):
58
+
59
+ duration = len(audio) / sr * 1000
60
+ huge_gap = False
61
+ max_gap = 0.0
62
+
63
+ text = None
64
+ last_tr = None
65
+
66
+ for tr in trans_res:
67
+ delta = tr.text.strip()
68
+ if tr.id > 1:
69
+ max_gap = max(tr.start - last_tr.end, max_gap)
70
+ text += delta
71
+ else:
72
+ text = delta
73
+
74
+ last_tr = tr
75
+ if max_gap > 3.0:
76
+ huge_gap = True
77
+ break
78
+
79
+ sim_text = t2s_converter.convert(text)
80
+ results.append(
81
+ {
82
+ "text": sim_text,
83
+ "duration": duration,
84
+ "huge_gap": huge_gap,
85
+ }
86
+ )
87
+
88
+ return results
89
+
90
+
91
+ global_lock = Lock()
92
+
93
+
94
+ def batch_asr(model, audios, sr):
95
+ return batch_asr_internal(model, audios, sr)
96
+
97
+
98
+ def is_chinese(text):
99
+ return True
100
+
101
+
102
+ def calculate_wer(text1, text2, debug=False):
103
+ chars1 = remove_punctuation(text1)
104
+ chars2 = remove_punctuation(text2)
105
+
106
+ m, n = len(chars1), len(chars2)
107
+
108
+ if m > n:
109
+ chars1, chars2 = chars2, chars1
110
+ m, n = n, m
111
+
112
+ prev = list(range(m + 1)) # row 0 distance: [0, 1, 2, ...]
113
+ curr = [0] * (m + 1)
114
+
115
+ for j in range(1, n + 1):
116
+ curr[0] = j
117
+ for i in range(1, m + 1):
118
+ if chars1[i - 1] == chars2[j - 1]:
119
+ curr[i] = prev[i - 1]
120
+ else:
121
+ curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
122
+ prev, curr = curr, prev
123
+
124
+ edits = prev[m]
125
+ tot = max(len(chars1), len(chars2))
126
+ wer = edits / tot
127
+
128
+ if debug:
129
+ print(" gt: ", chars1)
130
+ print(" pred: ", chars2)
131
+ print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
132
+
133
+ return wer
134
+
135
+
136
+ def remove_punctuation(text):
137
+ chinese_punctuation = (
138
+ " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
139
+ '‛""„‟…‧﹏'
140
+ )
141
+ all_punctuation = string.punctuation + chinese_punctuation
142
+ translator = str.maketrans("", "", all_punctuation)
143
+ text_without_punctuation = text.translate(translator)
144
+ return text_without_punctuation
145
+
146
+
147
+ if __name__ == "__main__":
148
+ model = load_model()
149
+ audios = [
150
+ librosa.load("44100.wav", sr=44100)[0],
151
+ librosa.load("lengyue.wav", sr=44100)[0],
152
+ ]
153
+ print(np.array(audios[0]))
154
+ print(batch_asr(model, audios, 44100))
155
+
156
+ start_time = time.time()
157
+ for _ in range(10):
158
+ print(batch_asr(model, audios, 44100))
159
+ print("Time taken:", time.time() - start_time)
tools/llama/build_dataset.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import os
3
+ import re
4
+ from collections import defaultdict
5
+ from functools import partial
6
+ from multiprocessing import Pool
7
+ from pathlib import Path
8
+
9
+ import click
10
+ import numpy as np
11
+ from loguru import logger
12
+ from tqdm import tqdm
13
+
14
+ from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
15
+ from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
16
+ from fish_speech.utils.file import load_filelist
17
+
18
+ # To avoid CPU overload
19
+ os.environ["MKL_NUM_THREADS"] = "1"
20
+ os.environ["OMP_NUM_THREADS"] = "1"
21
+
22
+
23
+ def task_generator_folder(root: Path, text_extension: str):
24
+ files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
25
+ files = sorted(files)
26
+
27
+ grouped_files = defaultdict(list)
28
+ for file in tqdm(files, desc=f"Grouping {root}"):
29
+ p = str(file.parent)
30
+ speaker = file.parent.name
31
+
32
+ try:
33
+ if isinstance(text_extension, str):
34
+ texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
35
+ else:
36
+ texts = [
37
+ file.with_suffix(ext).read_text(encoding="utf-8")
38
+ for ext in text_extension
39
+ ]
40
+ except Exception as e:
41
+ logger.error(f"Failed to read text {file}: {e}")
42
+ continue
43
+
44
+ grouped_files[p].append((speaker, file, texts))
45
+
46
+ logger.info(
47
+ f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
48
+ )
49
+
50
+ for i in grouped_files.values():
51
+ subset = [(f, t) for _, f, t in i]
52
+ yield i[0][0], subset, "folder"
53
+
54
+
55
+ def task_generator_filelist(filelist):
56
+ grouped_files = defaultdict(list)
57
+ for filename, speaker, _, text in load_filelist(filelist):
58
+ grouped_files[speaker].append((Path(filename), [text]))
59
+
60
+ logger.info(f"Found {len(grouped_files)} groups in {filelist}")
61
+ for speaker, values in grouped_files.items():
62
+ yield speaker, values, "filelist"
63
+
64
+
65
+ def run_task(task):
66
+ name, subset, source = task
67
+
68
+ # Parse the files
69
+ sentences = []
70
+ for file, texts in subset:
71
+ np_file = file.with_suffix(".npy")
72
+ if np_file.exists() is False:
73
+ logger.warning(f"Can't find {np_file}")
74
+ continue
75
+
76
+ new_texts = []
77
+
78
+ for text in texts:
79
+ # Simple cleaning: replace { xxx } and < xxx > with space
80
+ text = re.sub(r"\{.*?\}", " ", text)
81
+ text = re.sub(r"<.*?>", " ", text)
82
+ text = re.sub(r"\s+", " ", text)
83
+ new_texts.append(text)
84
+
85
+ try:
86
+ semantics = np.load(np_file)
87
+ except Exception as e:
88
+ logger.error(f"Failed to parse {file}: {e}")
89
+ continue
90
+
91
+ if isinstance(semantics, np.ndarray):
92
+ semantics = semantics.tolist()
93
+
94
+ sentences.append(
95
+ Sentence(
96
+ texts=new_texts,
97
+ semantics=[Semantics(values=s) for s in semantics],
98
+ )
99
+ )
100
+
101
+ # Pack the sentences
102
+ return pack_pb_stream(
103
+ TextData(
104
+ source=source,
105
+ name=name,
106
+ sentences=sentences,
107
+ )
108
+ )
109
+
110
+
111
+ @click.command()
112
+ @click.option(
113
+ "--input",
114
+ type=click.Path(path_type=Path),
115
+ required=True,
116
+ help="A folder containing the dataset or a filelist",
117
+ multiple=True,
118
+ )
119
+ @click.option(
120
+ "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
121
+ )
122
+ @click.option("--num-workers", type=int, default=16)
123
+ @click.option("--text-extension", type=str, default=[".txt"], multiple=True)
124
+ @click.option(
125
+ "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
126
+ )
127
+ def main(input, output, num_workers, text_extension, shard_size):
128
+ generator_fns = []
129
+
130
+ for f in input:
131
+ assert f.exists(), f"{f} not found"
132
+
133
+ if f.is_dir():
134
+ generator_fn = task_generator_folder(f, text_extension)
135
+ else:
136
+ generator_fn = task_generator_filelist(f)
137
+
138
+ generator_fns.append(generator_fn)
139
+
140
+ generator_fn = itertools.chain(*generator_fns)
141
+ output.mkdir(parents=True, exist_ok=True)
142
+
143
+ dataset_fp = None
144
+ tar_idx = 0
145
+ written_size = 0
146
+
147
+ with Pool(num_workers) as p:
148
+ for result in tqdm(p.imap_unordered(run_task, generator_fn)):
149
+ if dataset_fp is None:
150
+ dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
151
+
152
+ dataset_fp.write(result)
153
+ written_size += len(result)
154
+
155
+ if written_size > shard_size * 1024 * 1024:
156
+ logger.info(f"Finished writing {tar_idx} shards to {output}")
157
+ dataset_fp.close()
158
+ dataset_fp = None
159
+ written_size = 0
160
+ tar_idx += 1
161
+
162
+ if dataset_fp is not None:
163
+ dataset_fp.close()
164
+
165
+ logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
166
+
167
+
168
+ if __name__ == "__main__":
169
+ main()
tools/llama/eval_in_context.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pyrootutils
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from matplotlib import pyplot as plt
5
+ from transformers import AutoTokenizer
6
+
7
+ # register eval resolver and root
8
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
9
+
10
+ from torch.utils.data import DataLoader
11
+
12
+ from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
13
+ from tools.llama.generate import load_model
14
+
15
+
16
+ def smooth(
17
+ scalars: list[float], weight: float
18
+ ) -> list[float]: # Weight between 0 and 1
19
+ last = scalars[0] # First value in the plot (first timestep)
20
+ smoothed = list()
21
+ for point in scalars:
22
+ smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
23
+ smoothed.append(smoothed_val) # Save it
24
+ last = smoothed_val # Anchor the last smoothed value
25
+
26
+ return smoothed
27
+
28
+
29
+ @torch.inference_mode()
30
+ def analyze_one_model(loader, config, weight, max_length):
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ model = load_model(
33
+ config,
34
+ weight,
35
+ device,
36
+ torch.bfloat16,
37
+ max_length,
38
+ compile=False,
39
+ )[0]
40
+
41
+ current_step = 0
42
+ model.eval()
43
+
44
+ semantic_loss_sum = torch.zeros(
45
+ max_length,
46
+ dtype=torch.float32,
47
+ device=device,
48
+ )
49
+ counter = torch.zeros(
50
+ max_length,
51
+ dtype=torch.long,
52
+ device=device,
53
+ )
54
+
55
+ for batch in loader:
56
+ batch = {k: v.to(device) for k, v in batch.items()}
57
+
58
+ labels = batch["labels"]
59
+ outputs = model(
60
+ inp=batch["inputs"],
61
+ key_padding_mask=batch["attention_masks"],
62
+ )
63
+
64
+ token_logits = outputs.token_logits
65
+ codebook_logits = outputs.codebook_logits
66
+
67
+ # Generate labels
68
+ base_loss = F.cross_entropy(
69
+ token_logits.reshape(-1, token_logits.size(-1)),
70
+ labels[:, 0].reshape(-1),
71
+ ignore_index=-100,
72
+ reduction="none",
73
+ )
74
+
75
+ codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
76
+ semantic_loss = F.cross_entropy(
77
+ codebook_logits.reshape(-1, codebook_logits.size(-1)),
78
+ codebook_labels.reshape(-1),
79
+ ignore_index=-100,
80
+ reduction="none",
81
+ )
82
+
83
+ base_loss = base_loss.reshape(labels[:, 0].shape)
84
+ semantic_loss = semantic_loss.reshape(codebook_labels.shape)
85
+
86
+ semantic_loss_frame = semantic_loss.mean(-1)
87
+ pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
88
+
89
+ for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
90
+ semantic_loss_sum[~pad] += loss_sample[~pad]
91
+ counter[~pad] += 1
92
+
93
+ current_step += 1
94
+ if current_step == 10:
95
+ break
96
+
97
+ semantic_loss = semantic_loss.cpu()
98
+ counter = counter.cpu()
99
+ xs, ys = [], []
100
+
101
+ for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
102
+ if count > 0:
103
+ xs.append(i)
104
+ ys.append((loss / count).item()) # for better loss visualization
105
+
106
+ smoothed_ys = smooth(ys, 0.95)
107
+
108
+ # Unload model
109
+ del model
110
+ torch.cuda.empty_cache()
111
+
112
+ return xs, ys, smoothed_ys
113
+
114
+
115
+ def main():
116
+ tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
117
+ max_length = 4096
118
+
119
+ ds = AutoAugTextDataset(
120
+ ["data/protos/sft/云天河"],
121
+ tokenizer=tokenizer,
122
+ use_speaker=False,
123
+ interactive_prob=1.0,
124
+ max_length=max_length,
125
+ )
126
+
127
+ loader = DataLoader(
128
+ ds,
129
+ batch_size=8,
130
+ collate_fn=TextDataCollator(tokenizer, max_length=max_length),
131
+ num_workers=0,
132
+ shuffle=False,
133
+ )
134
+
135
+ plt.figure(figsize=(10, 5), dpi=200)
136
+
137
+ plt.xlabel("Frame")
138
+ plt.ylabel("Loss")
139
+ plt.yscale("log")
140
+ plt.title("Semantic Loss")
141
+ plt.grid(which="both", axis="both")
142
+ plt.xlim(0, max_length)
143
+
144
+ tests = [
145
+ (
146
+ "pertrain-medium",
147
+ "dual_ar_2_codebook_medium",
148
+ "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
149
+ ),
150
+ (
151
+ "sft-medium",
152
+ "dual_ar_2_codebook_medium",
153
+ "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
154
+ ),
155
+ (
156
+ "sft-large",
157
+ "dual_ar_2_codebook_large",
158
+ "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
159
+ ),
160
+ ]
161
+
162
+ for name, config, weight in tests:
163
+ xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
164
+ plt.plot(xs, smoothed_ys, label=name)
165
+
166
+ plt.legend()
167
+ plt.savefig("semantic_loss.png")
168
+
169
+
170
+ if __name__ == "__main__":
171
+ main()
tools/llama/generate.py CHANGED
@@ -2,8 +2,9 @@ import os
2
  import queue
3
  import threading
4
  import time
 
5
  from pathlib import Path
6
- from typing import Optional, Tuple, Union
7
 
8
  import click
9
  import hydra
@@ -11,14 +12,11 @@ import numpy as np
11
  import torch
12
  import torch._dynamo.config
13
  import torch._inductor.config
14
- from hydra import compose, initialize
15
- from hydra.utils import instantiate
16
  from loguru import logger
17
  from tqdm import tqdm
18
- from transformers import AutoTokenizer
19
 
20
- from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
21
- from fish_speech.text.clean import clean_text
22
 
23
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
24
  torch._inductor.config.coordinate_descent_tuning = True
@@ -29,7 +27,11 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
29
  torch._inductor.config.fx_graph_cache = True
30
 
31
 
32
- from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
 
 
 
 
33
 
34
 
35
  def multinomial_sample_one_no_sync(
@@ -94,7 +96,9 @@ def decode_one_token_ar(
94
  codebooks = [
95
  sample(
96
  x.logits,
97
- previous_tokens=None, # Disable repetition penalty for the token codebook
 
 
98
  **sampling_kwargs,
99
  )[0]
100
  ]
@@ -159,7 +163,6 @@ def decode_n_tokens(
159
  cur_token: torch.Tensor,
160
  input_pos: torch.Tensor,
161
  num_new_tokens: int,
162
- eos_token_id: int = 2,
163
  im_end_id: int = 4,
164
  decode_one_token=decode_one_token_naive,
165
  **sampling_kwargs,
@@ -195,11 +198,7 @@ def decode_n_tokens(
195
  model.config.num_codebooks + 1, -1
196
  )
197
 
198
- if (
199
- cur_token[0, 0, -1] == eos_token_id
200
- or cur_token[0, 0, -1] == im_end_id
201
- or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
202
- ):
203
  break
204
 
205
  return previous_tokens[:, : i + 1]
@@ -212,7 +211,6 @@ def generate(
212
  model: NaiveTransformer,
213
  prompt: torch.Tensor,
214
  max_new_tokens: int,
215
- eos_token_id: int = 2,
216
  im_end_id: int = 4,
217
  decode_one_token=decode_one_token_naive,
218
  **sampling_kwargs,
@@ -253,6 +251,7 @@ def generate(
253
  if isinstance(model, NaiveTransformer)
254
  else decode_one_token_ar
255
  )
 
256
  next_token = prefill_decode(
257
  model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
258
  )
@@ -264,7 +263,6 @@ def generate(
264
  next_token.view(1, codebook_dim, -1),
265
  input_pos,
266
  max_new_tokens - 1,
267
- eos_token_id=eos_token_id,
268
  im_end_id=im_end_id,
269
  decode_one_token=decode_one_token,
270
  **sampling_kwargs,
@@ -279,22 +277,12 @@ def generate(
279
  def encode_tokens(
280
  tokenizer,
281
  string,
282
- bos=True,
283
  device="cuda",
284
  prompt_tokens=None,
285
- speaker=None,
286
  num_codebooks=4,
287
  ):
288
  string = clean_text(string)
289
-
290
- if speaker is None:
291
- speaker = "assistant"
292
-
293
- string = (
294
- f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
295
- )
296
- if bos:
297
- string = f"<|begin_of_sequence|>{string}"
298
 
299
  new_tokens = tokenizer.encode(
300
  string,
@@ -322,7 +310,7 @@ def encode_tokens(
322
  prompt_tokens = prompt_tokens[0]
323
 
324
  assert prompt_tokens.ndim == 2
325
- data = prompt_tokens + 2
326
 
327
  if prompt_tokens.shape[0] > num_codebooks:
328
  logger.warning(
@@ -330,13 +318,9 @@ def encode_tokens(
330
  )
331
  data = data[:num_codebooks]
332
 
333
- # Add eos token for each codebook
334
  data = torch.cat(
335
- (
336
- data,
337
- torch.ones((data.size(0), 1), dtype=torch.int, device=device)
338
- * CODEBOOK_EOS_TOKEN_ID,
339
- ),
340
  dim=1,
341
  )
342
 
@@ -354,49 +338,13 @@ def encode_tokens(
354
  return prompt
355
 
356
 
357
- def load_model(
358
- config_name, checkpoint_path, device, precision, max_length, compile=False
359
- ):
360
- hydra.core.global_hydra.GlobalHydra.instance().clear()
361
- with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
362
- cfg = compose(
363
- config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
364
- )
365
-
366
- model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)
367
-
368
- if "int8" in str(checkpoint_path):
369
- logger.info("Using int8 weight-only quantization!")
370
- from quantize import WeightOnlyInt8QuantHandler
371
-
372
- simple_quantizer = WeightOnlyInt8QuantHandler(model)
373
- model = simple_quantizer.convert_for_runtime()
374
-
375
- if "int4" in str(checkpoint_path):
376
- logger.info("Using int4 quantization!")
377
- path_comps = checkpoint_path.name.split(".")
378
- assert path_comps[-2].startswith("g")
379
- groupsize = int(path_comps[-2][1:])
380
- from quantize import WeightOnlyInt4QuantHandler
381
-
382
- simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
383
- model = simple_quantizer.convert_for_runtime()
384
-
385
- checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
386
- if "state_dict" in checkpoint:
387
- checkpoint = checkpoint["state_dict"]
388
-
389
- if any(k.startswith("model.") for k in checkpoint):
390
- checkpoint = {
391
- k.replace("model.", ""): v
392
- for k, v in checkpoint.items()
393
- if k.startswith("model.")
394
- }
395
-
396
- model.load_state_dict(checkpoint, assign=True)
397
 
398
  model = model.to(device=device, dtype=precision)
399
- logger.info("Restored model from checkpoint")
400
 
401
  if isinstance(model, DualARTransformer):
402
  decode_one_token = decode_one_token_ar
@@ -414,29 +362,16 @@ def load_model(
414
  return model.eval(), decode_one_token
415
 
416
 
417
- def split_text(text, min_length):
418
- text = clean_text(text)
419
- segments = []
420
- curr = ""
421
- for char in text:
422
- curr += char
423
- if char not in [".", ",", "!", "?"]:
424
- continue
425
-
426
- if len(curr) >= min_length:
427
- segments.append(curr)
428
- curr = ""
429
-
430
- if curr:
431
- segments.append(curr)
432
-
433
- return segments
434
 
435
 
436
  def generate_long(
437
  *,
438
  model,
439
- tokenizer: callable,
440
  device: str | torch.device,
441
  decode_one_token: callable,
442
  text: str,
@@ -448,42 +383,49 @@ def generate_long(
448
  compile: bool = False,
449
  iterative_prompt: bool = True,
450
  max_length: int = 2048,
451
- chunk_length: int = 30,
452
- speaker: Optional[str] = None,
453
- prompt_text: Optional[str] = None,
454
- prompt_tokens: Optional[torch.Tensor] = None,
455
- is_streaming: bool = False,
456
  ):
457
  assert 0 < top_p <= 1, "top_p must be in (0, 1]"
458
  assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
459
  assert 0 < temperature < 2, "temperature must be in (0, 2)"
460
 
 
 
 
 
 
 
 
 
 
461
  model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
 
462
  im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
463
 
464
- use_prompt = prompt_text is not None and prompt_tokens is not None
465
  encoded = []
466
  texts = split_text(text, chunk_length) if iterative_prompt else [text]
 
467
 
468
  if use_prompt:
469
- encoded_prompts = encode_tokens(
470
- tokenizer,
471
- prompt_text,
472
- prompt_tokens=prompt_tokens,
473
- bos=True,
474
- device=device,
475
- speaker=speaker,
476
- num_codebooks=model.config.num_codebooks,
477
- )
 
478
 
479
  for idx, text in enumerate(texts):
480
  encoded.append(
481
  encode_tokens(
482
  tokenizer,
483
  string=text,
484
- bos=idx == 0 and not use_prompt,
485
  device=device,
486
- speaker=speaker,
487
  num_codebooks=model.config.num_codebooks,
488
  )
489
  )
@@ -502,7 +444,6 @@ def generate_long(
502
  torch.cuda.synchronize()
503
 
504
  global_encoded = []
505
- all_codes = []
506
  seg_idx = 0
507
 
508
  while seg_idx < len(encoded):
@@ -519,7 +460,9 @@ def generate_long(
519
  count = 0
520
  for i, length in enumerate(lengths):
521
  count += length
522
- if count + length > max_length - 1024:
 
 
523
  break
524
 
525
  if i != 0 and i % 2 == 0:
@@ -532,7 +475,7 @@ def generate_long(
532
  partial_encoded = global_encoded
533
 
534
  if use_prompt:
535
- partial_encoded = [encoded_prompts] + partial_encoded
536
 
537
  cat_encoded = torch.cat(partial_encoded, dim=1)
538
  prompt_length = cat_encoded.size(1)
@@ -542,7 +485,6 @@ def generate_long(
542
  model=model,
543
  prompt=cat_encoded,
544
  max_new_tokens=max_new_tokens,
545
- eos_token_id=tokenizer.eos_token_id,
546
  im_end_id=im_end_id,
547
  decode_one_token=decode_one_token,
548
  temperature=temperature,
@@ -574,76 +516,66 @@ def generate_long(
574
 
575
  # Put the generated tokens
576
  # since there is <im_end> and <eos> tokens, we remove last 2 tokens
577
- codes = y[1:, prompt_length:-2].clone()
578
-
579
- codes = codes - 2
580
  assert (codes >= 0).all(), f"Negative code found"
581
 
582
  decoded = y[:, prompt_length:-1].clone()
583
- if decoded[0, -1] != im_end_id: # <im_end>
584
- val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
585
- decoded = torch.cat(
586
- (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
587
- )
588
-
589
  # But for global encoding, we should keep the <im_end> token
 
590
  global_encoded.append(decoded)
 
 
 
591
 
592
- if is_streaming:
593
- assert (codes >= 0).all(), f"Negative code found: {codes}"
594
- yield codes
595
- else:
596
- all_codes.append(codes)
597
 
598
- seg_idx += 1
599
 
600
- if is_streaming:
601
- # This indicates the end of the current sample
602
- yield "next"
603
- else:
604
- all_codes = torch.cat(all_codes, dim=1)
605
- assert (all_codes >= 0).all(), f"Negative code found: {codes}"
606
- yield all_codes
 
 
 
607
 
608
 
609
  def launch_thread_safe_queue(
610
- config_name,
611
  checkpoint_path,
612
  device,
613
  precision,
614
- max_length,
615
- compile=False,
616
  ):
617
  input_queue = queue.Queue()
618
  init_event = threading.Event()
619
 
620
  def worker():
621
  model, decode_one_token = load_model(
622
- config_name, checkpoint_path, device, precision, max_length, compile=compile
623
  )
624
  init_event.set()
625
 
626
  while True:
627
- item = input_queue.get()
628
  if item is None:
629
  break
630
 
631
- kwargs = item["request"]
632
- response_queue = item["response_queue"]
633
 
634
  try:
635
- item["success"] = True
636
  for chunk in generate_long(
637
  model=model, decode_one_token=decode_one_token, **kwargs
638
  ):
639
- response_queue.put(chunk)
640
-
641
- response_queue.put("done")
642
  except Exception as e:
643
- item["success"] = False
644
- item["response"] = e
645
-
646
- response_queue.put("done")
647
 
648
  threading.Thread(target=worker, daemon=True).start()
649
  init_event.wait()
@@ -657,57 +589,58 @@ def launch_thread_safe_queue(
657
  type=str,
658
  default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
659
  )
660
- @click.option("--prompt-text", type=str, default=None)
661
  @click.option(
662
- "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
 
 
 
663
  )
664
  @click.option("--num-samples", type=int, default=1)
665
  @click.option("--max-new-tokens", type=int, default=0)
666
  @click.option("--top-p", type=float, default=0.7)
667
- @click.option("--repetition-penalty", type=float, default=1.5)
668
  @click.option("--temperature", type=float, default=0.7)
669
  @click.option(
670
  "--checkpoint-path",
671
  type=click.Path(path_type=Path, exists=True),
672
- default="results/text2semantic_400m_finetune/step_000002000.pth",
673
  )
674
- @click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
675
- @click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
676
  @click.option("--compile/--no-compile", default=False)
677
  @click.option("--seed", type=int, default=42)
678
- @click.option("--speaker", type=str, default=None)
679
  @click.option("--half/--no-half", default=False)
680
  @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
681
- @click.option("--max-length", type=int, default=2048)
682
- @click.option("--chunk-length", type=int, default=30)
683
  def main(
684
  text: str,
685
- prompt_text: Optional[str],
686
- prompt_tokens: Optional[Path],
687
  num_samples: int,
688
  max_new_tokens: int,
689
  top_p: int,
690
  repetition_penalty: float,
691
  temperature: float,
692
  checkpoint_path: Path,
693
- config_name: str,
694
- tokenizer: str,
695
  compile: bool,
696
  seed: int,
697
- speaker: Optional[str],
698
  half: bool,
699
  iterative_prompt: bool,
700
- max_length: int,
701
  chunk_length: int,
702
  ) -> None:
703
- device = "cuda"
704
 
705
  precision = torch.half if half else torch.bfloat16
706
 
 
 
 
 
 
707
  logger.info("Loading model ...")
708
  t0 = time.time()
709
  model, decode_one_token = load_model(
710
- config_name, checkpoint_path, device, precision, max_length, compile=compile
711
  )
712
 
713
  if torch.cuda.is_available():
@@ -715,13 +648,9 @@ def main(
715
 
716
  logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
717
 
718
- prompt_tokens = (
719
- torch.from_numpy(np.load(prompt_tokens)).to(device)
720
- if prompt_tokens is not None
721
- else None
722
- )
723
 
724
- tokenizer = AutoTokenizer.from_pretrained(tokenizer)
725
  torch.manual_seed(seed)
726
 
727
  if torch.cuda.is_available():
@@ -737,19 +666,29 @@ def main(
737
  top_p=top_p,
738
  repetition_penalty=repetition_penalty,
739
  temperature=temperature,
740
- tokenizer=tokenizer,
741
  compile=compile,
742
- speaker=speaker,
743
  iterative_prompt=iterative_prompt,
744
- max_length=max_length,
745
  chunk_length=chunk_length,
746
  prompt_text=prompt_text,
747
  prompt_tokens=prompt_tokens,
748
  )
749
 
750
- for idx, codes in enumerate(generator):
751
- np.save(f"codes_{idx}.npy", codes.cpu().numpy())
752
- logger.info(f"Saved codes to codes_{idx}.npy")
 
 
 
 
 
 
 
 
 
 
 
 
 
753
 
754
 
755
  if __name__ == "__main__":
 
2
  import queue
3
  import threading
4
  import time
5
+ from dataclasses import dataclass
6
  from pathlib import Path
7
+ from typing import Literal, Optional, Tuple, Union
8
 
9
  import click
10
  import hydra
 
12
  import torch
13
  import torch._dynamo.config
14
  import torch._inductor.config
 
 
15
  from loguru import logger
16
  from tqdm import tqdm
 
17
 
18
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
19
+ from fish_speech.text import clean_text, split_text
20
 
21
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
22
  torch._inductor.config.coordinate_descent_tuning = True
 
27
  torch._inductor.config.fx_graph_cache = True
28
 
29
 
30
+ from fish_speech.models.text2semantic.llama import (
31
+ BaseTransformer,
32
+ DualARTransformer,
33
+ NaiveTransformer,
34
+ )
35
 
36
 
37
  def multinomial_sample_one_no_sync(
 
96
  codebooks = [
97
  sample(
98
  x.logits,
99
+ previous_tokens=(
100
+ previous_tokens[0] if previous_tokens is not None else None
101
+ ), # Disable repetition penalty for the token codebook
102
  **sampling_kwargs,
103
  )[0]
104
  ]
 
163
  cur_token: torch.Tensor,
164
  input_pos: torch.Tensor,
165
  num_new_tokens: int,
 
166
  im_end_id: int = 4,
167
  decode_one_token=decode_one_token_naive,
168
  **sampling_kwargs,
 
198
  model.config.num_codebooks + 1, -1
199
  )
200
 
201
+ if cur_token[0, 0, -1] == im_end_id:
 
 
 
 
202
  break
203
 
204
  return previous_tokens[:, : i + 1]
 
211
  model: NaiveTransformer,
212
  prompt: torch.Tensor,
213
  max_new_tokens: int,
 
214
  im_end_id: int = 4,
215
  decode_one_token=decode_one_token_naive,
216
  **sampling_kwargs,
 
251
  if isinstance(model, NaiveTransformer)
252
  else decode_one_token_ar
253
  )
254
+
255
  next_token = prefill_decode(
256
  model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
257
  )
 
263
  next_token.view(1, codebook_dim, -1),
264
  input_pos,
265
  max_new_tokens - 1,
 
266
  im_end_id=im_end_id,
267
  decode_one_token=decode_one_token,
268
  **sampling_kwargs,
 
277
  def encode_tokens(
278
  tokenizer,
279
  string,
 
280
  device="cuda",
281
  prompt_tokens=None,
 
282
  num_codebooks=4,
283
  ):
284
  string = clean_text(string)
285
+ string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
 
 
 
 
 
 
 
 
286
 
287
  new_tokens = tokenizer.encode(
288
  string,
 
310
  prompt_tokens = prompt_tokens[0]
311
 
312
  assert prompt_tokens.ndim == 2
313
+ data = prompt_tokens + 1
314
 
315
  if prompt_tokens.shape[0] > num_codebooks:
316
  logger.warning(
 
318
  )
319
  data = data[:num_codebooks]
320
 
321
+ # Add pad token for each codebook
322
  data = torch.cat(
323
+ (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
 
 
 
 
324
  dim=1,
325
  )
326
 
 
338
  return prompt
339
 
340
 
341
+ def load_model(checkpoint_path, device, precision, compile=False):
342
+ model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
343
+ checkpoint_path, load_weights=True
344
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
  model = model.to(device=device, dtype=precision)
347
+ logger.info(f"Restored model from checkpoint")
348
 
349
  if isinstance(model, DualARTransformer):
350
  decode_one_token = decode_one_token_ar
 
362
  return model.eval(), decode_one_token
363
 
364
 
365
+ @dataclass
366
+ class GenerateResponse:
367
+ action: Literal["sample", "next"]
368
+ codes: Optional[torch.Tensor] = None
369
+ text: Optional[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
 
372
  def generate_long(
373
  *,
374
  model,
 
375
  device: str | torch.device,
376
  decode_one_token: callable,
377
  text: str,
 
383
  compile: bool = False,
384
  iterative_prompt: bool = True,
385
  max_length: int = 2048,
386
+ chunk_length: int = 150,
387
+ prompt_text: Optional[str | list[str]] = None,
388
+ prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
 
 
389
  ):
390
  assert 0 < top_p <= 1, "top_p must be in (0, 1]"
391
  assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
392
  assert 0 < temperature < 2, "temperature must be in (0, 2)"
393
 
394
+ use_prompt = prompt_text is not None and prompt_tokens is not None
395
+ if use_prompt and isinstance(prompt_text, str):
396
+ prompt_text = [prompt_text]
397
+ prompt_tokens = [prompt_tokens]
398
+
399
+ assert use_prompt is False or len(prompt_text) == len(
400
+ prompt_tokens
401
+ ), "Prompt text and tokens must have the same length"
402
+
403
  model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
404
+ tokenizer = model.tokenizer
405
  im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
406
 
 
407
  encoded = []
408
  texts = split_text(text, chunk_length) if iterative_prompt else [text]
409
+ encoded_prompts = []
410
 
411
  if use_prompt:
412
+ for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
413
+ encoded_prompts.append(
414
+ encode_tokens(
415
+ tokenizer,
416
+ string=t,
417
+ device=device,
418
+ prompt_tokens=c,
419
+ num_codebooks=model.config.num_codebooks,
420
+ )
421
+ )
422
 
423
  for idx, text in enumerate(texts):
424
  encoded.append(
425
  encode_tokens(
426
  tokenizer,
427
  string=text,
 
428
  device=device,
 
429
  num_codebooks=model.config.num_codebooks,
430
  )
431
  )
 
444
  torch.cuda.synchronize()
445
 
446
  global_encoded = []
 
447
  seg_idx = 0
448
 
449
  while seg_idx < len(encoded):
 
460
  count = 0
461
  for i, length in enumerate(lengths):
462
  count += length
463
+ if count + length > max_length - 1024 - sum(
464
+ t.shape[1] for t in encoded_prompts
465
+ ):
466
  break
467
 
468
  if i != 0 and i % 2 == 0:
 
475
  partial_encoded = global_encoded
476
 
477
  if use_prompt:
478
+ partial_encoded = encoded_prompts + partial_encoded
479
 
480
  cat_encoded = torch.cat(partial_encoded, dim=1)
481
  prompt_length = cat_encoded.size(1)
 
485
  model=model,
486
  prompt=cat_encoded,
487
  max_new_tokens=max_new_tokens,
 
488
  im_end_id=im_end_id,
489
  decode_one_token=decode_one_token,
490
  temperature=temperature,
 
516
 
517
  # Put the generated tokens
518
  # since there is <im_end> and <eos> tokens, we remove last 2 tokens
519
+ codes = y[1:, prompt_length:-1].clone()
520
+ codes = codes - 1
 
521
  assert (codes >= 0).all(), f"Negative code found"
522
 
523
  decoded = y[:, prompt_length:-1].clone()
 
 
 
 
 
 
524
  # But for global encoding, we should keep the <im_end> token
525
+
526
  global_encoded.append(decoded)
527
+ assert (codes >= 0).all(), f"Negative code found: {codes}"
528
+ yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
529
+ seg_idx += 1
530
 
531
+ # This indicates the end of the current sample
532
+ yield GenerateResponse(action="next")
 
 
 
533
 
 
534
 
535
+ @dataclass
536
+ class WrappedGenerateResponse:
537
+ status: Literal["success", "error"]
538
+ response: Optional[GenerateResponse | Exception] = None
539
+
540
+
541
+ @dataclass
542
+ class GenerateRequest:
543
+ request: dict
544
+ response_queue: queue.Queue
545
 
546
 
547
  def launch_thread_safe_queue(
 
548
  checkpoint_path,
549
  device,
550
  precision,
551
+ compile: bool = False,
 
552
  ):
553
  input_queue = queue.Queue()
554
  init_event = threading.Event()
555
 
556
  def worker():
557
  model, decode_one_token = load_model(
558
+ checkpoint_path, device, precision, compile=compile
559
  )
560
  init_event.set()
561
 
562
  while True:
563
+ item: GenerateRequest | None = input_queue.get()
564
  if item is None:
565
  break
566
 
567
+ kwargs = item.request
568
+ response_queue = item.response_queue
569
 
570
  try:
 
571
  for chunk in generate_long(
572
  model=model, decode_one_token=decode_one_token, **kwargs
573
  ):
574
+ response_queue.put(
575
+ WrappedGenerateResponse(status="success", response=chunk)
576
+ )
577
  except Exception as e:
578
+ response_queue.put(WrappedGenerateResponse(status="error", response=e))
 
 
 
579
 
580
  threading.Thread(target=worker, daemon=True).start()
581
  init_event.wait()
 
589
  type=str,
590
  default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
591
  )
592
+ @click.option("--prompt-text", type=str, default=None, multiple=True)
593
  @click.option(
594
+ "--prompt-tokens",
595
+ type=click.Path(path_type=Path, exists=True),
596
+ default=None,
597
+ multiple=True,
598
  )
599
  @click.option("--num-samples", type=int, default=1)
600
  @click.option("--max-new-tokens", type=int, default=0)
601
  @click.option("--top-p", type=float, default=0.7)
602
+ @click.option("--repetition-penalty", type=float, default=1.2)
603
  @click.option("--temperature", type=float, default=0.7)
604
  @click.option(
605
  "--checkpoint-path",
606
  type=click.Path(path_type=Path, exists=True),
607
+ default="checkpoints/fish-speech-1.2-sft",
608
  )
609
+ @click.option("--device", type=str, default="cuda")
 
610
  @click.option("--compile/--no-compile", default=False)
611
  @click.option("--seed", type=int, default=42)
 
612
  @click.option("--half/--no-half", default=False)
613
  @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
614
+ @click.option("--chunk-length", type=int, default=100)
 
615
  def main(
616
  text: str,
617
+ prompt_text: Optional[list[str]],
618
+ prompt_tokens: Optional[list[Path]],
619
  num_samples: int,
620
  max_new_tokens: int,
621
  top_p: int,
622
  repetition_penalty: float,
623
  temperature: float,
624
  checkpoint_path: Path,
625
+ device: str,
 
626
  compile: bool,
627
  seed: int,
 
628
  half: bool,
629
  iterative_prompt: bool,
 
630
  chunk_length: int,
631
  ) -> None:
 
632
 
633
  precision = torch.half if half else torch.bfloat16
634
 
635
+ if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
636
+ raise ValueError(
637
+ f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
638
+ )
639
+
640
  logger.info("Loading model ...")
641
  t0 = time.time()
642
  model, decode_one_token = load_model(
643
+ checkpoint_path, device, precision, compile=compile
644
  )
645
 
646
  if torch.cuda.is_available():
 
648
 
649
  logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
650
 
651
+ if prompt_tokens is not None:
652
+ prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
 
 
 
653
 
 
654
  torch.manual_seed(seed)
655
 
656
  if torch.cuda.is_available():
 
666
  top_p=top_p,
667
  repetition_penalty=repetition_penalty,
668
  temperature=temperature,
 
669
  compile=compile,
 
670
  iterative_prompt=iterative_prompt,
 
671
  chunk_length=chunk_length,
672
  prompt_text=prompt_text,
673
  prompt_tokens=prompt_tokens,
674
  )
675
 
676
+ idx = 0
677
+ codes = []
678
+
679
+ for response in generator:
680
+ if response.action == "sample":
681
+ codes.append(response.codes)
682
+ logger.info(f"Sampled text: {response.text}")
683
+ elif response.action == "next":
684
+ if codes:
685
+ np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
686
+ logger.info(f"Saved codes to codes_{idx}.npy")
687
+ logger.info(f"Next sample")
688
+ codes = []
689
+ idx += 1
690
+ else:
691
+ logger.error(f"Error: {response}")
692
 
693
 
694
  if __name__ == "__main__":
tools/llama/merge_lora.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from copy import deepcopy
3
+ from pathlib import Path
4
+
5
+ import click
6
+ import hydra
7
+ import torch
8
+ from hydra import compose, initialize
9
+ from hydra.utils import instantiate
10
+ from loguru import logger
11
+
12
+ from fish_speech.models.text2semantic.llama import BaseTransformer
13
+ from fish_speech.models.text2semantic.lora import get_merged_state_dict
14
+
15
+
16
+ @click.command()
17
+ @click.option("--lora-config", type=str, default="r_8_alpha_16")
18
+ @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.2-sft")
19
+ @click.option("--lora-weight", type=str, required=True)
20
+ @click.option("--output", type=str, required=True)
21
+ def merge(lora_config, base_weight, lora_weight, output):
22
+ output = Path(output)
23
+ logger.info(
24
+ f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
25
+ )
26
+
27
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
28
+ cfg = compose(config_name=lora_config)
29
+
30
+ lora_config = instantiate(cfg)
31
+ logger.info(f"Loaded lora model with config {lora_config}")
32
+
33
+ llama_model = BaseTransformer.from_pretrained(
34
+ path=base_weight,
35
+ load_weights=True,
36
+ lora_config=lora_config,
37
+ )
38
+ logger.info(f"Loaded llama model")
39
+
40
+ llama_state_dict = llama_model.state_dict()
41
+ llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
42
+ llama_state_dict_copy = deepcopy(llama_state_dict)
43
+ lora_state_dict = torch.load(lora_weight, map_location="cpu")
44
+
45
+ if "state_dict" in llama_state_dict:
46
+ llama_state_dict = llama_state_dict["state_dict"]
47
+
48
+ if "state_dict" in lora_state_dict:
49
+ lora_state_dict = lora_state_dict["state_dict"]
50
+
51
+ # remove prefix model.
52
+ if any(k.startswith("model.") for k in llama_state_dict.keys()):
53
+ llama_state_dict = {
54
+ k.replace("model.", ""): v
55
+ for k, v in llama_state_dict.items()
56
+ if k.startswith("model.")
57
+ }
58
+ if any(k.startswith("model.") for k in lora_state_dict.keys()):
59
+ lora_state_dict = {
60
+ k.replace("model.", ""): v
61
+ for k, v in lora_state_dict.items()
62
+ if k.startswith("model.")
63
+ }
64
+
65
+ logger.info(f"Found {len(llama_state_dict)} keys in llama model")
66
+ logger.info(f"Found {len(lora_state_dict)} keys in lora model")
67
+
68
+ merged_state_dict = llama_state_dict | lora_state_dict
69
+ llama_model.load_state_dict(merged_state_dict, strict=True)
70
+ logger.info(f"Merged model loaded")
71
+
72
+ # Trigger eval mode to merge lora
73
+ llama_model.eval()
74
+ llama_model.save_pretrained(output, drop_lora=True)
75
+ logger.info(f"Saved merged model to {output}, validating")
76
+
77
+ new_state_dict = torch.load(output / "model.pth", map_location="cpu")
78
+ original_keys = set(llama_state_dict_copy.keys())
79
+ merged_keys = set(new_state_dict.keys())
80
+
81
+ assert original_keys == merged_keys, "Keys should be same"
82
+
83
+ for key in original_keys:
84
+ diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
85
+ if diff_l1 != 0:
86
+ break
87
+ else:
88
+ logger.error("Merged model is same as the original model")
89
+ exit(1)
90
+
91
+ logger.info("Merged model is different from the original model, check passed")
92
+
93
+
94
+ if __name__ == "__main__":
95
+ merge()
tools/llama/quantize.py CHANGED
@@ -1,16 +1,20 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
  # All rights reserved.
 
 
3
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
  import time
7
  from pathlib import Path
8
 
 
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
 
13
- from fish_speech.models.text2semantic.llama import ModelArgs, Transformer, find_multiple
 
14
 
15
  ##### Quantization Primitives ######
16
 
@@ -414,13 +418,26 @@ class WeightOnlyInt4Linear(torch.nn.Module):
414
  )
415
 
416
 
417
- def quantize(
418
- checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
419
- mode: str = "int8",
420
- # following arguments only available when setting int4 quantization.
421
- groupsize: int = 128,
422
- ) -> None:
423
- assert checkpoint_path.is_file(), checkpoint_path
 
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
  device = "cpu"
426
  precision = torch.bfloat16
@@ -428,31 +445,14 @@ def quantize(
428
  print("Loading model ...")
429
  t0 = time.time()
430
 
431
- with torch.device("meta"):
432
- model = Transformer(
433
- ModelArgs(
434
- max_seq_len=4096,
435
- vocab_size=36408,
436
- n_layer=24,
437
- n_head=16,
438
- dim=1024,
439
- rope_base=10000,
440
- norm_eps=1e-5,
441
- num_codebooks=4, # single codebook
442
- codebook_size=168, # codebook size 160 + 2 special tokens
443
- )
444
- )
445
-
446
- checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
447
- if "state_dict" in checkpoint:
448
- checkpoint = checkpoint["state_dict"]
449
- checkpoint = {
450
- k.replace("model.", ""): v
451
- for k, v in checkpoint.items()
452
- if k.startswith("model.")
453
- }
454
- model.load_state_dict(checkpoint, assign=True)
455
- model = model.to(dtype=precision, device=device)
456
 
457
  if mode == "int8":
458
  print(
@@ -461,10 +461,12 @@ def quantize(
461
  quant_handler = WeightOnlyInt8QuantHandler(model)
462
  quantized_state_dict = quant_handler.create_quantized_state_dict()
463
 
464
- dir_name = checkpoint_path.parent
465
- base_name = checkpoint_path.stem
466
- suffix = checkpoint_path.suffix
467
- quantize_path = dir_name / f"{base_name}.int8{suffix}"
 
 
468
 
469
  elif mode == "int4":
470
  print(
@@ -473,10 +475,12 @@ def quantize(
473
  quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
474
  quantized_state_dict = quant_handler.create_quantized_state_dict()
475
 
476
- dir_name = checkpoint_path.parent
477
- base_name = checkpoint_path.name
478
- suffix = checkpoint_path.suffix
479
- quantize_path = dir_name / f"{base_name}.int4.g{groupsize}{suffix}"
 
 
480
 
481
  else:
482
  raise ValueError(
@@ -490,26 +494,4 @@ def quantize(
490
 
491
 
492
  if __name__ == "__main__":
493
- import argparse
494
-
495
- parser = argparse.ArgumentParser(description="Quantize a model.")
496
- parser.add_argument(
497
- "--checkpoint_path",
498
- type=Path,
499
- default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
500
- help="Path to the model checkpoint to be quantized.",
501
- )
502
- parser.add_argument(
503
- "--mode",
504
- "-q",
505
- type=str,
506
- default="int8",
507
- choices=["int8", "int4"],
508
- help="type of quantization to perform",
509
- )
510
- parser.add_argument(
511
- "--groupsize", type=int, default=32, help="Group size for int4 quantization."
512
- )
513
-
514
- args = parser.parse_args()
515
- quantize(args.checkpoint_path, args.mode, args.groupsize)
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
  # All rights reserved.
3
+ import datetime
4
+ import shutil
5
 
6
  # This source code is licensed under the license found in the
7
  # LICENSE file in the root directory of this source tree.
8
  import time
9
  from pathlib import Path
10
 
11
+ import click
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
 
16
+ from fish_speech.models.text2semantic.llama import find_multiple
17
+ from tools.llama.generate import load_model
18
 
19
  ##### Quantization Primitives ######
20
 
 
418
  )
419
 
420
 
421
+ def generate_folder_name():
422
+ now = datetime.datetime.now()
423
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
424
+ return folder_name
425
+
426
+
427
+ @click.command()
428
+ @click.option(
429
+ "--checkpoint-path",
430
+ type=click.Path(path_type=Path, exists=True),
431
+ default="checkpoints/fish-speech-1.2-sft",
432
+ )
433
+ @click.option(
434
+ "--mode", type=str, default="int8", help="type of quantization to perform"
435
+ )
436
+ @click.option(
437
+ "--groupsize", type=int, default=128, help="Group size for int4 quantization."
438
+ )
439
+ @click.option("--timestamp", type=str, default="None", help="When to do quantization")
440
+ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
441
 
442
  device = "cpu"
443
  precision = torch.bfloat16
 
445
  print("Loading model ...")
446
  t0 = time.time()
447
 
448
+ model, _ = load_model(
449
+ checkpoint_path=checkpoint_path,
450
+ device=device,
451
+ precision=precision,
452
+ compile=False,
453
+ )
454
+ vq_model = "firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
455
+ now = timestamp if timestamp != "None" else generate_folder_name()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
  if mode == "int8":
458
  print(
 
461
  quant_handler = WeightOnlyInt8QuantHandler(model)
462
  quantized_state_dict = quant_handler.create_quantized_state_dict()
463
 
464
+ dir_name = checkpoint_path
465
+ dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
466
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
467
+ if (dst_name / vq_model).exists():
468
+ (dst_name / vq_model).unlink()
469
+ quantize_path = dst_name / "model.pth"
470
 
471
  elif mode == "int4":
472
  print(
 
475
  quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
476
  quantized_state_dict = quant_handler.create_quantized_state_dict()
477
 
478
+ dir_name = checkpoint_path
479
+ dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
480
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
481
+ if (dst_name / vq_model).exists():
482
+ (dst_name / vq_model).unlink()
483
+ quantize_path = dst_name / "model.pth"
484
 
485
  else:
486
  raise ValueError(
 
494
 
495
 
496
  if __name__ == "__main__":
497
+ quantize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/llama/rebuild_tokenizer.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
2
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
3
+
4
+ # Initialize a tokenizer
5
+ tokenizer = Tokenizer(models.BPE())
6
+
7
+ # Customize pre-tokenization and decoding
8
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
9
+ tokenizer.decoder = decoders.ByteLevel()
10
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
11
+
12
+ # Don't train the tokenizer
13
+ trainer = trainers.BpeTrainer(
14
+ vocab_size=0,
15
+ min_frequency=2,
16
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
17
+ special_tokens=[
18
+ "<|begin_of_sequence|>",
19
+ "<|end_of_sequence|>",
20
+ "<|im_start|>",
21
+ "<|im_sep|>", # system, user, assistant, etc.
22
+ "<|im_end|>",
23
+ "<|semantic|>", # audio features
24
+ "<|pad|>",
25
+ ],
26
+ )
27
+
28
+ # <|im_start|>user<|im_sep|>...<|im_end|>
29
+ # <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
30
+ tokenizer.train_from_iterator([], trainer=trainer)
31
+
32
+ print(len(tokenizer.get_vocab()))
33
+ x = tokenizer.encode(
34
+ "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
35
+ ).ids
36
+ print(x, len(x))
37
+ print(tokenizer.decode(x, skip_special_tokens=True))
38
+
39
+
40
+ tokenizer = PreTrainedTokenizerFast(
41
+ tokenizer_object=tokenizer,
42
+ pad_token="<|pad|>",
43
+ bos_token="<|begin_of_sequence|>",
44
+ eos_token="<|end_of_sequence|>",
45
+ )
46
+
47
+ # Try tokenizing a new sequence
48
+ sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
49
+ encoded = tokenizer(sequence).input_ids
50
+
51
+ print("Test encoding....")
52
+ print(f"\tSentence: {sequence}")
53
+ print(f"\tEncoded: {encoded}")
54
+ print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
55
+ print(f"\tDecoded: {tokenizer.decode(encoded)}")
56
+
57
+ tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
tools/vqgan/create_train_split.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from pathlib import Path
3
+ from random import Random
4
+
5
+ import click
6
+ from loguru import logger
7
+ from pydub import AudioSegment
8
+ from tqdm import tqdm
9
+
10
+ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
11
+
12
+
13
+ @click.command()
14
+ @click.argument("root", type=click.Path(exists=True, path_type=Path))
15
+ @click.option("--val-ratio", type=float, default=None)
16
+ @click.option("--val-count", type=int, default=None)
17
+ @click.option("--filelist", default=None, type=Path)
18
+ @click.option("--min-duration", default=None, type=float)
19
+ @click.option("--max-duration", default=None, type=float)
20
+ def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
21
+ if filelist:
22
+ files = [i[0] for i in load_filelist(filelist)]
23
+ else:
24
+ files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
25
+
26
+ if min_duration is None and max_duration is None:
27
+ filtered_files = list(map(str, [file.relative_to(root) for file in files]))
28
+ else:
29
+ filtered_files = []
30
+ for file in tqdm(files):
31
+ try:
32
+ audio = AudioSegment.from_file(str(file))
33
+ duration = len(audio) / 1000.0
34
+
35
+ if min_duration is not None and duration < min_duration:
36
+ logger.info(
37
+ f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
38
+ )
39
+ continue
40
+
41
+ if max_duration is not None and duration > max_duration:
42
+ logger.info(
43
+ f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
44
+ )
45
+ continue
46
+
47
+ filtered_files.append(str(file.relative_to(root)))
48
+ except Exception as e:
49
+ logger.info(f"Error processing {file}: {e}")
50
+
51
+ logger.info(
52
+ f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
53
+ )
54
+
55
+ Random(42).shuffle(filtered_files)
56
+
57
+ if val_count is None and val_ratio is None:
58
+ logger.info("Validation ratio and count not specified, using min(20%, 100)")
59
+ val_size = min(100, math.ceil(len(filtered_files) * 0.2))
60
+ elif val_count is not None and val_ratio is not None:
61
+ logger.error("Cannot specify both val_count and val_ratio")
62
+ return
63
+ elif val_count is not None:
64
+ if val_count < 1 or val_count > len(filtered_files):
65
+ logger.error("val_count must be between 1 and number of files")
66
+ return
67
+ val_size = val_count
68
+ else:
69
+ val_size = math.ceil(len(filtered_files) * val_ratio)
70
+
71
+ logger.info(f"Using {val_size} files for validation")
72
+
73
+ with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
74
+ f.write("\n".join(filtered_files[val_size:]))
75
+
76
+ with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
77
+ f.write("\n".join(filtered_files[:val_size]))
78
+
79
+ logger.info("Done")
80
+
81
+
82
+ if __name__ == "__main__":
83
+ main()
tools/vqgan/extract_vq.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess as sp
3
+ import sys
4
+ import time
5
+ from datetime import timedelta
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+ from random import Random
9
+
10
+ import click
11
+ import numpy as np
12
+ import torch
13
+ import torchaudio
14
+ from hydra import compose, initialize
15
+ from hydra.utils import instantiate
16
+ from lightning import LightningModule
17
+ from loguru import logger
18
+ from omegaconf import OmegaConf
19
+
20
+ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
21
+
22
+ # register eval resolver
23
+ OmegaConf.register_new_resolver("eval", eval)
24
+ # This file is used to convert the audio files to text files using the Whisper model.
25
+ # It's mainly used to generate the training data for the VQ model.
26
+
27
+
28
+ RANK = int(os.environ.get("SLURM_PROCID", 0))
29
+ WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
30
+
31
+ logger_format = (
32
+ "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
33
+ "<level>{level: <8}</level> | "
34
+ "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
35
+ "{extra[rank]} - <level>{message}</level>"
36
+ )
37
+ logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
38
+ logger.remove()
39
+ logger.add(sys.stderr, format=logger_format)
40
+
41
+
42
+ @lru_cache(maxsize=1)
43
+ def get_model(
44
+ config_name: str = "firefly_gan_vq",
45
+ checkpoint_path: str = "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
46
+ device: str | torch.device = "cuda",
47
+ ):
48
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
49
+ cfg = compose(config_name=config_name)
50
+
51
+ model = instantiate(cfg)
52
+ state_dict = torch.load(
53
+ checkpoint_path,
54
+ map_location=device,
55
+ )
56
+ if "state_dict" in state_dict:
57
+ state_dict = state_dict["state_dict"]
58
+
59
+ if any("generator" in k for k in state_dict):
60
+ state_dict = {
61
+ k.replace("generator.", ""): v
62
+ for k, v in state_dict.items()
63
+ if "generator." in k
64
+ }
65
+
66
+ model.load_state_dict(state_dict, strict=False)
67
+ model.eval()
68
+ model.to(device)
69
+
70
+ logger.info(f"Loaded model")
71
+ return model
72
+
73
+
74
+ @torch.inference_mode()
75
+ def process_batch(files: list[Path], model) -> float:
76
+ wavs = []
77
+ audio_lengths = []
78
+ new_files = []
79
+ max_length = total_time = 0
80
+
81
+ for file in files:
82
+ try:
83
+ wav, sr = torchaudio.load(
84
+ str(file), backend="sox" if sys.platform == "linux" else "soundfile"
85
+ ) # Need to install libsox-dev
86
+ except Exception as e:
87
+ logger.error(f"Error reading {file}: {e}")
88
+ continue
89
+
90
+ if wav.shape[0] > 1:
91
+ wav = wav.mean(dim=0, keepdim=True)
92
+
93
+ wav = torchaudio.functional.resample(
94
+ wav.cuda(), sr, model.spec_transform.sample_rate
95
+ )[0]
96
+ total_time += len(wav) / model.spec_transform.sample_rate
97
+ max_length = max(max_length, len(wav))
98
+
99
+ wavs.append(wav)
100
+ audio_lengths.append(len(wav))
101
+ new_files.append(file)
102
+
103
+ files = new_files
104
+
105
+ # Pad to max length
106
+ for i, wav in enumerate(wavs):
107
+ wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
108
+
109
+ audios = torch.stack(wavs, dim=0)[:, None]
110
+ audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
111
+
112
+ # Calculate lengths
113
+ indices, feature_lengths = model.encode(audios, audio_lengths)
114
+
115
+ # Save to disk
116
+ outputs = indices.cpu().numpy()
117
+
118
+ for file, length, feature, audio_length in zip(
119
+ files, feature_lengths, outputs, audio_lengths
120
+ ):
121
+ feature = feature[:, :length]
122
+
123
+ # (T,)
124
+ with open(file.with_suffix(".npy"), "wb") as f:
125
+ np.save(f, feature)
126
+
127
+ return total_time
128
+
129
+
130
+ @click.command()
131
+ @click.argument("folder")
132
+ @click.option("--num-workers", default=1)
133
+ @click.option("--config-name", default="firefly_gan_vq")
134
+ @click.option(
135
+ "--checkpoint-path",
136
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
137
+ )
138
+ @click.option("--batch-size", default=64)
139
+ @click.option("--filelist", default=None, type=Path)
140
+ def main(
141
+ folder: str,
142
+ num_workers: int,
143
+ config_name: str,
144
+ checkpoint_path: str,
145
+ batch_size: int,
146
+ filelist: Path,
147
+ ):
148
+ if num_workers > 1 and WORLD_SIZE != num_workers:
149
+ assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
150
+
151
+ logger.info(f"Spawning {num_workers} workers")
152
+
153
+ if torch.cuda.is_available():
154
+ visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
155
+ if visible_devices is None:
156
+ visible_devices = list(range(torch.cuda.device_count()))
157
+ else:
158
+ visible_devices = visible_devices.split(",")
159
+ else:
160
+ # Set to empty string to avoid using GPU
161
+ visible_devices = [""]
162
+
163
+ processes = []
164
+ for i in range(num_workers):
165
+ env = os.environ.copy()
166
+ env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
167
+ env["SLURM_PROCID"] = str(i)
168
+ env["SLURM_NTASKS"] = str(num_workers)
169
+
170
+ processes.append(
171
+ sp.Popen(
172
+ [sys.executable] + sys.argv.copy(),
173
+ env=env,
174
+ )
175
+ )
176
+
177
+ for p in processes:
178
+ p.wait()
179
+
180
+ logger.info(f"All workers finished")
181
+ return
182
+
183
+ # This is a worker
184
+ logger.info(f"Starting worker")
185
+ if filelist:
186
+ files = [i[0] for i in load_filelist(filelist)]
187
+ else:
188
+ files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
189
+
190
+ print(f"Found {len(files)} files")
191
+ files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
192
+
193
+ total_files = len(files)
194
+ files = files[RANK::WORLD_SIZE]
195
+ logger.info(f"Processing {len(files)}/{total_files} files")
196
+
197
+ # Batch processing
198
+ total_time = 0
199
+ begin_time = time.time()
200
+ processed_files = 0
201
+ model = get_model(config_name, checkpoint_path)
202
+
203
+ for n_batch, idx in enumerate(range(0, len(files), batch_size)):
204
+ batch = files[idx : idx + batch_size]
205
+ batch_time = process_batch(batch, model)
206
+
207
+ total_time += batch_time
208
+ processed_files += len(batch)
209
+
210
+ if (n_batch + 1) % 10 == 0:
211
+ eta = (
212
+ (time.time() - begin_time)
213
+ / processed_files
214
+ * (len(files) - processed_files)
215
+ )
216
+ logger.info(
217
+ f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
218
+ + f"ETA: {timedelta(seconds=round(eta))}s"
219
+ )
220
+
221
+ logger.info(
222
+ f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
223
+ )
224
+
225
+
226
+ if __name__ == "__main__":
227
+ main()
tools/vqgan/inference.py CHANGED
@@ -2,13 +2,12 @@ from pathlib import Path
2
 
3
  import click
4
  import hydra
5
- import librosa
6
  import numpy as np
7
  import soundfile as sf
8
  import torch
 
9
  from hydra import compose, initialize
10
  from hydra.utils import instantiate
11
- from lightning import LightningModule
12
  from loguru import logger
13
  from omegaconf import OmegaConf
14
 
@@ -23,20 +22,26 @@ def load_model(config_name, checkpoint_path, device="cuda"):
23
  with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
24
  cfg = compose(config_name=config_name)
25
 
26
- model: LightningModule = instantiate(cfg.model)
27
  state_dict = torch.load(
28
  checkpoint_path,
29
- map_location=model.device,
30
  )
31
-
32
  if "state_dict" in state_dict:
33
  state_dict = state_dict["state_dict"]
34
 
35
- model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
 
36
  model.eval()
37
  model.to(device)
38
- logger.info("Restored model from checkpoint")
39
 
 
40
  return model
41
 
42
 
@@ -51,11 +56,10 @@ def load_model(config_name, checkpoint_path, device="cuda"):
51
  @click.option(
52
  "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
53
  )
54
- @click.option("--config-name", "-cfg", default="vqgan_pretrain")
55
  @click.option(
56
  "--checkpoint-path",
57
- "-ckpt",
58
- default="checkpoints/vq-gan-group-fsq-2x1024.pth",
59
  )
60
  @click.option(
61
  "--device",
@@ -67,21 +71,22 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
67
 
68
  if input_path.suffix in AUDIO_EXTENSIONS:
69
  logger.info(f"Processing in-place reconstruction of {input_path}")
 
70
  # Load audio
71
- audio, _ = librosa.load(
72
- input_path,
73
- sr=model.sampling_rate,
74
- mono=True,
 
75
  )
76
- audios = torch.from_numpy(audio).to(model.device)[None, None, :]
 
77
  logger.info(
78
- f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
79
  )
80
 
81
  # VQ Encoder
82
- audio_lengths = torch.tensor(
83
- [audios.shape[2]], device=model.device, dtype=torch.long
84
- )
85
  indices = model.encode(audios, audio_lengths)[0][0]
86
 
87
  logger.info(f"Generated indices of shape {indices.shape}")
@@ -91,17 +96,15 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
91
  elif input_path.suffix == ".npy":
92
  logger.info(f"Processing precomputed indices from {input_path}")
93
  indices = np.load(input_path)
94
- indices = torch.from_numpy(indices).to(model.device).long()
95
  assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
96
  else:
97
  raise ValueError(f"Unknown input type: {input_path}")
98
 
99
  # Restore
100
- feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
101
- fake_audios = model.decode(
102
- indices=indices[None], feature_lengths=feature_lengths, return_audios=True
103
- )
104
- audio_time = fake_audios.shape[-1] / model.sampling_rate
105
 
106
  logger.info(
107
  f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
@@ -109,7 +112,7 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
109
 
110
  # Save audio
111
  fake_audio = fake_audios[0, 0].float().cpu().numpy()
112
- sf.write(output_path, fake_audio, model.sampling_rate)
113
  logger.info(f"Saved audio to {output_path}")
114
 
115
 
 
2
 
3
  import click
4
  import hydra
 
5
  import numpy as np
6
  import soundfile as sf
7
  import torch
8
+ import torchaudio
9
  from hydra import compose, initialize
10
  from hydra.utils import instantiate
 
11
  from loguru import logger
12
  from omegaconf import OmegaConf
13
 
 
22
  with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
23
  cfg = compose(config_name=config_name)
24
 
25
+ model = instantiate(cfg)
26
  state_dict = torch.load(
27
  checkpoint_path,
28
+ map_location=device,
29
  )
 
30
  if "state_dict" in state_dict:
31
  state_dict = state_dict["state_dict"]
32
 
33
+ if any("generator" in k for k in state_dict):
34
+ state_dict = {
35
+ k.replace("generator.", ""): v
36
+ for k, v in state_dict.items()
37
+ if "generator." in k
38
+ }
39
+
40
+ result = model.load_state_dict(state_dict, strict=False)
41
  model.eval()
42
  model.to(device)
 
43
 
44
+ logger.info(f"Loaded model: {result}")
45
  return model
46
 
47
 
 
56
  @click.option(
57
  "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
58
  )
59
+ @click.option("--config-name", default="firefly_gan_vq")
60
  @click.option(
61
  "--checkpoint-path",
62
+ default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
 
63
  )
64
  @click.option(
65
  "--device",
 
71
 
72
  if input_path.suffix in AUDIO_EXTENSIONS:
73
  logger.info(f"Processing in-place reconstruction of {input_path}")
74
+
75
  # Load audio
76
+ audio, sr = torchaudio.load(str(input_path))
77
+ if audio.shape[0] > 1:
78
+ audio = audio.mean(0, keepdim=True)
79
+ audio = torchaudio.functional.resample(
80
+ audio, sr, model.spec_transform.sample_rate
81
  )
82
+
83
+ audios = audio[None].to(device)
84
  logger.info(
85
+ f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
86
  )
87
 
88
  # VQ Encoder
89
+ audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
 
 
90
  indices = model.encode(audios, audio_lengths)[0][0]
91
 
92
  logger.info(f"Generated indices of shape {indices.shape}")
 
96
  elif input_path.suffix == ".npy":
97
  logger.info(f"Processing precomputed indices from {input_path}")
98
  indices = np.load(input_path)
99
+ indices = torch.from_numpy(indices).to(device).long()
100
  assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
101
  else:
102
  raise ValueError(f"Unknown input type: {input_path}")
103
 
104
  # Restore
105
+ feature_lengths = torch.tensor([indices.shape[1]], device=device)
106
+ fake_audios = model.decode(indices=indices[None], feature_lengths=feature_lengths)
107
+ audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
 
 
108
 
109
  logger.info(
110
  f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
 
112
 
113
  # Save audio
114
  fake_audio = fake_audios[0, 0].float().cpu().numpy()
115
+ sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
116
  logger.info(f"Saved audio to {output_path}")
117
 
118