Spaces:
Running
on
L4
Running
on
L4
lengyue233
commited on
Update model to large sft
Browse files- app.py +53 -37
- tools/extract_model.py +0 -21
- tools/llama/build_dataset.py +0 -165
- tools/llama/generate.py +64 -8
- tools/llama/rebuild_tokenizer.py +0 -57
- tools/merge_asr_files.py +0 -55
- tools/vqgan/create_train_split.py +0 -54
- tools/vqgan/extract_vq.py +0 -213
- tools/whisper_asr.py +0 -113
app.py
CHANGED
@@ -1,34 +1,26 @@
|
|
1 |
import subprocess as sp
|
2 |
import os
|
|
|
3 |
|
4 |
# Download if not exists
|
5 |
os.makedirs("checkpoints", exist_ok=True)
|
6 |
-
|
7 |
-
if not os.path.exists("checkpoints/text2semantic-medium-v1-2k.pth"):
|
8 |
-
print("Downloading text2semantic-medium-v1-2k.pth")
|
9 |
-
sp.run(["wget", "-q", "-O", "checkpoints/text2semantic-medium-v1-2k.pth", os.environ["CKPT_SEMANTIC"]])
|
10 |
-
|
11 |
-
if not os.path.exists("checkpoints/vq-gan-group-fsq-2x1024.pth"):
|
12 |
-
print("Downloading vq-gan-group-fsq-2x1024.pth")
|
13 |
-
sp.run(["wget", "-q", "-O", "checkpoints/vq-gan-group-fsq-2x1024.pth", os.environ["CKPT_VQGAN"]])
|
14 |
|
15 |
print("All checkpoints downloaded")
|
16 |
|
17 |
import html
|
|
|
|
|
18 |
from argparse import ArgumentParser
|
19 |
-
from io import BytesIO
|
20 |
from pathlib import Path
|
21 |
|
22 |
import gradio as gr
|
23 |
import librosa
|
24 |
-
import spaces
|
25 |
import torch
|
26 |
from loguru import logger
|
27 |
-
from torchaudio import functional as AF
|
28 |
from transformers import AutoTokenizer
|
29 |
|
30 |
-
from tools.llama.generate import
|
31 |
-
from tools.llama.generate import load_model as load_llama_model
|
32 |
from tools.vqgan.inference import load_model as load_vqgan_model
|
33 |
|
34 |
# Make einx happy
|
@@ -52,16 +44,30 @@ We are not responsible for any misuse of the model, please consider your local l
|
|
52 |
|
53 |
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def build_html_error_message(error):
|
57 |
return f"""
|
58 |
-
<div style="color: red;
|
|
|
59 |
{html.escape(error)}
|
60 |
</div>
|
61 |
"""
|
62 |
|
63 |
|
64 |
-
@
|
|
|
65 |
def inference(
|
66 |
text,
|
67 |
enable_reference_audio,
|
@@ -73,13 +79,10 @@ def inference(
|
|
73 |
top_p,
|
74 |
repetition_penalty,
|
75 |
temperature,
|
76 |
-
speaker
|
77 |
):
|
78 |
-
if len(reference_text) > 100:
|
79 |
-
return None, "Ref text is too long, please keep it under 100 characters."
|
80 |
-
|
81 |
if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
82 |
-
return None, "Text is too long, please keep it under
|
83 |
|
84 |
# Parse reference audio aka prompt
|
85 |
prompt_tokens = None
|
@@ -103,11 +106,9 @@ def inference(
|
|
103 |
prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
|
104 |
|
105 |
# LLAMA Inference
|
106 |
-
|
107 |
-
model=llama_model,
|
108 |
tokenizer=llama_tokenizer,
|
109 |
device=vqgan_model.device,
|
110 |
-
decode_one_token=decode_one_token,
|
111 |
max_new_tokens=max_new_tokens,
|
112 |
text=text,
|
113 |
top_k=int(top_k) if top_k > 0 else None,
|
@@ -123,7 +124,18 @@ def inference(
|
|
123 |
prompt_text=reference_text if enable_reference_audio else None,
|
124 |
)
|
125 |
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
# VQGAN Inference
|
129 |
feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
|
@@ -151,9 +163,7 @@ def build_app():
|
|
151 |
with gr.Row():
|
152 |
with gr.Column(scale=3):
|
153 |
text = gr.Textbox(
|
154 |
-
label="Input Text / 输入文本",
|
155 |
-
placeholder=TEXTBOX_PLACEHOLDER,
|
156 |
-
lines=15,
|
157 |
)
|
158 |
|
159 |
with gr.Row():
|
@@ -198,11 +208,11 @@ def build_app():
|
|
198 |
step=0.01,
|
199 |
)
|
200 |
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
|
207 |
with gr.Tab(label="Reference Audio / 参考音频"):
|
208 |
gr.Markdown(
|
@@ -248,7 +258,7 @@ def build_app():
|
|
248 |
top_p,
|
249 |
repetition_penalty,
|
250 |
temperature,
|
251 |
-
|
252 |
],
|
253 |
[audio, error],
|
254 |
concurrency_limit=1,
|
@@ -262,10 +272,10 @@ def parse_args():
|
|
262 |
parser.add_argument(
|
263 |
"--llama-checkpoint-path",
|
264 |
type=Path,
|
265 |
-
default="checkpoints/text2semantic-
|
266 |
)
|
267 |
parser.add_argument(
|
268 |
-
"--llama-config-name", type=str, default="
|
269 |
)
|
270 |
parser.add_argument(
|
271 |
"--vqgan-checkpoint-path",
|
@@ -278,7 +288,7 @@ def parse_args():
|
|
278 |
parser.add_argument("--half", action="store_true")
|
279 |
parser.add_argument("--max-length", type=int, default=2048)
|
280 |
parser.add_argument("--compile", action="store_true")
|
281 |
-
parser.add_argument("--max-gradio-length", type=int, default=
|
282 |
|
283 |
return parser.parse_args()
|
284 |
|
@@ -288,9 +298,15 @@ if __name__ == "__main__":
|
|
288 |
|
289 |
args.precision = torch.half if args.half else torch.bfloat16
|
290 |
args.compile = True
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
|
292 |
logger.info("Loading Llama model...")
|
293 |
-
|
294 |
config_name=args.llama_config_name,
|
295 |
checkpoint_path=args.llama_checkpoint_path,
|
296 |
device=args.device,
|
|
|
1 |
import subprocess as sp
|
2 |
import os
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
|
5 |
# Download if not exists
|
6 |
os.makedirs("checkpoints", exist_ok=True)
|
7 |
+
hf_hub_download("fishaudio/fish-speech-1", "./checkpoints/fish-speech-1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
print("All checkpoints downloaded")
|
10 |
|
11 |
import html
|
12 |
+
import os
|
13 |
+
import threading
|
14 |
from argparse import ArgumentParser
|
|
|
15 |
from pathlib import Path
|
16 |
|
17 |
import gradio as gr
|
18 |
import librosa
|
|
|
19 |
import torch
|
20 |
from loguru import logger
|
|
|
21 |
from transformers import AutoTokenizer
|
22 |
|
23 |
+
from tools.llama.generate import launch_thread_safe_queue
|
|
|
24 |
from tools.vqgan.inference import load_model as load_vqgan_model
|
25 |
|
26 |
# Make einx happy
|
|
|
44 |
|
45 |
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
|
46 |
|
47 |
+
try:
|
48 |
+
import spaces
|
49 |
+
|
50 |
+
GPU_DECORATOR = spaces.GPU
|
51 |
+
except ImportError:
|
52 |
+
|
53 |
+
def GPU_DECORATOR(func):
|
54 |
+
def wrapper(*args, **kwargs):
|
55 |
+
return func(*args, **kwargs)
|
56 |
+
|
57 |
+
return wrapper
|
58 |
+
|
59 |
|
60 |
def build_html_error_message(error):
|
61 |
return f"""
|
62 |
+
<div style="color: red;
|
63 |
+
font-weight: bold;">
|
64 |
{html.escape(error)}
|
65 |
</div>
|
66 |
"""
|
67 |
|
68 |
|
69 |
+
@GPU_DECORATOR
|
70 |
+
@torch.inference_mode()
|
71 |
def inference(
|
72 |
text,
|
73 |
enable_reference_audio,
|
|
|
79 |
top_p,
|
80 |
repetition_penalty,
|
81 |
temperature,
|
82 |
+
speaker,
|
83 |
):
|
|
|
|
|
|
|
84 |
if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
85 |
+
return None, f"Text is too long, please keep it under {args.max_gradio_length} characters."
|
86 |
|
87 |
# Parse reference audio aka prompt
|
88 |
prompt_tokens = None
|
|
|
106 |
prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
|
107 |
|
108 |
# LLAMA Inference
|
109 |
+
request = dict(
|
|
|
110 |
tokenizer=llama_tokenizer,
|
111 |
device=vqgan_model.device,
|
|
|
112 |
max_new_tokens=max_new_tokens,
|
113 |
text=text,
|
114 |
top_k=int(top_k) if top_k > 0 else None,
|
|
|
124 |
prompt_text=reference_text if enable_reference_audio else None,
|
125 |
)
|
126 |
|
127 |
+
payload = dict(
|
128 |
+
event=threading.Event(),
|
129 |
+
request=request,
|
130 |
+
)
|
131 |
+
llama_queue.put(payload)
|
132 |
+
|
133 |
+
# Wait for the result
|
134 |
+
payload["event"].wait()
|
135 |
+
if payload["success"] is False:
|
136 |
+
raise payload["response"]
|
137 |
+
|
138 |
+
codes = payload["response"][0]
|
139 |
|
140 |
# VQGAN Inference
|
141 |
feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
|
|
|
163 |
with gr.Row():
|
164 |
with gr.Column(scale=3):
|
165 |
text = gr.Textbox(
|
166 |
+
label="Input Text / 输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
|
|
|
|
|
167 |
)
|
168 |
|
169 |
with gr.Row():
|
|
|
208 |
step=0.01,
|
209 |
)
|
210 |
|
211 |
+
speaker = gr.Textbox(
|
212 |
+
label="Speaker / 说话人",
|
213 |
+
placeholder="Type name of the speaker / 输入说话人的名称",
|
214 |
+
lines=1,
|
215 |
+
)
|
216 |
|
217 |
with gr.Tab(label="Reference Audio / 参考音频"):
|
218 |
gr.Markdown(
|
|
|
258 |
top_p,
|
259 |
repetition_penalty,
|
260 |
temperature,
|
261 |
+
speaker,
|
262 |
],
|
263 |
[audio, error],
|
264 |
concurrency_limit=1,
|
|
|
272 |
parser.add_argument(
|
273 |
"--llama-checkpoint-path",
|
274 |
type=Path,
|
275 |
+
default="checkpoints/text2semantic-sft-large-v1-4k.pth",
|
276 |
)
|
277 |
parser.add_argument(
|
278 |
+
"--llama-config-name", type=str, default="dual_ar_2_codebook_large"
|
279 |
)
|
280 |
parser.add_argument(
|
281 |
"--vqgan-checkpoint-path",
|
|
|
288 |
parser.add_argument("--half", action="store_true")
|
289 |
parser.add_argument("--max-length", type=int, default=2048)
|
290 |
parser.add_argument("--compile", action="store_true")
|
291 |
+
parser.add_argument("--max-gradio-length", type=int, default=0)
|
292 |
|
293 |
return parser.parse_args()
|
294 |
|
|
|
298 |
|
299 |
args.precision = torch.half if args.half else torch.bfloat16
|
300 |
args.compile = True
|
301 |
+
args.max_gradio_length = 1024
|
302 |
+
args.tokenizer = "./checkpoints/fish-speech-1"
|
303 |
+
args.llama_checkpoint_path = "./checkpoints/text2semantic-sft-large-v1-4k.pth"
|
304 |
+
args.llama_config_name = "dual_ar_2_codebook_large"
|
305 |
+
args.vqgan_checkpoint_path = "./checkpoints/vq-gan-group-fsq-2x1024.pth"
|
306 |
+
args.vqgan_config_name = "vqgan_pretrain"
|
307 |
|
308 |
logger.info("Loading Llama model...")
|
309 |
+
llama_queue = launch_thread_safe_queue(
|
310 |
config_name=args.llama_config_name,
|
311 |
checkpoint_path=args.llama_checkpoint_path,
|
312 |
device=args.device,
|
tools/extract_model.py
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
import click
|
2 |
-
import torch
|
3 |
-
from loguru import logger
|
4 |
-
|
5 |
-
|
6 |
-
@click.command()
|
7 |
-
@click.argument("model_path")
|
8 |
-
@click.argument("output_path")
|
9 |
-
def main(model_path, output_path):
|
10 |
-
if model_path == output_path:
|
11 |
-
logger.error("Model path and output path are the same")
|
12 |
-
return
|
13 |
-
|
14 |
-
logger.info(f"Loading model from {model_path}")
|
15 |
-
state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
|
16 |
-
torch.save(state_dict, output_path)
|
17 |
-
logger.info(f"Model saved to {output_path}")
|
18 |
-
|
19 |
-
|
20 |
-
if __name__ == "__main__":
|
21 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/llama/build_dataset.py
DELETED
@@ -1,165 +0,0 @@
|
|
1 |
-
import itertools
|
2 |
-
import os
|
3 |
-
import re
|
4 |
-
from collections import defaultdict
|
5 |
-
from functools import partial
|
6 |
-
from multiprocessing import Pool
|
7 |
-
from pathlib import Path
|
8 |
-
|
9 |
-
import click
|
10 |
-
import numpy as np
|
11 |
-
from loguru import logger
|
12 |
-
from tqdm import tqdm
|
13 |
-
|
14 |
-
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
|
15 |
-
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
|
16 |
-
from fish_speech.utils.file import load_filelist
|
17 |
-
|
18 |
-
# To avoid CPU overload
|
19 |
-
os.environ["MKL_NUM_THREADS"] = "1"
|
20 |
-
os.environ["OMP_NUM_THREADS"] = "1"
|
21 |
-
|
22 |
-
|
23 |
-
def task_generator_folder(root: Path, text_extension: str):
|
24 |
-
files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
|
25 |
-
files = sorted(files)
|
26 |
-
|
27 |
-
grouped_files = defaultdict(list)
|
28 |
-
for file in tqdm(files, desc=f"Grouping {root}"):
|
29 |
-
p = str(file.parent)
|
30 |
-
|
31 |
-
try:
|
32 |
-
if isinstance(text_extension, str):
|
33 |
-
texts = [file.with_suffix(text_extension).read_text()]
|
34 |
-
else:
|
35 |
-
texts = [file.with_suffix(ext).read_text() for ext in text_extension]
|
36 |
-
except Exception as e:
|
37 |
-
logger.error(f"Failed to read text {file}: {e}")
|
38 |
-
continue
|
39 |
-
|
40 |
-
grouped_files[p].append((file, texts))
|
41 |
-
|
42 |
-
logger.info(
|
43 |
-
f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
|
44 |
-
)
|
45 |
-
for name, subset in grouped_files.items():
|
46 |
-
yield name, subset, "folder"
|
47 |
-
|
48 |
-
|
49 |
-
def task_generator_filelist(filelist):
|
50 |
-
grouped_files = defaultdict(list)
|
51 |
-
for filename, speaker, _, text in load_filelist(filelist):
|
52 |
-
grouped_files[speaker].append((Path(filename), [text]))
|
53 |
-
|
54 |
-
logger.info(f"Found {len(grouped_files)} groups in {filelist}")
|
55 |
-
for speaker, values in grouped_files.items():
|
56 |
-
yield speaker, values, "filelist"
|
57 |
-
|
58 |
-
|
59 |
-
def run_task(task):
|
60 |
-
name, subset, source = task
|
61 |
-
|
62 |
-
# Parse the files
|
63 |
-
sentences = []
|
64 |
-
for file in subset:
|
65 |
-
file, texts = file
|
66 |
-
|
67 |
-
np_file = file.with_suffix(".npy")
|
68 |
-
if np_file.exists() is False:
|
69 |
-
logger.warning(f"Can't find {np_file}")
|
70 |
-
continue
|
71 |
-
|
72 |
-
new_texts = []
|
73 |
-
|
74 |
-
for text in texts:
|
75 |
-
# Simple cleaning: replace { xxx } and < xxx > with space
|
76 |
-
text = re.sub(r"\{.*?\}", " ", text)
|
77 |
-
text = re.sub(r"<.*?>", " ", text)
|
78 |
-
text = re.sub(r"\s+", " ", text)
|
79 |
-
new_texts.append(text)
|
80 |
-
|
81 |
-
try:
|
82 |
-
semantics = np.load(np_file)
|
83 |
-
except Exception as e:
|
84 |
-
logger.error(f"Failed to parse {file}: {e}")
|
85 |
-
continue
|
86 |
-
|
87 |
-
if isinstance(semantics, np.ndarray):
|
88 |
-
semantics = semantics.tolist()
|
89 |
-
|
90 |
-
sentences.append(
|
91 |
-
Sentence(
|
92 |
-
texts=new_texts,
|
93 |
-
semantics=[Semantics(values=s) for s in semantics],
|
94 |
-
)
|
95 |
-
)
|
96 |
-
|
97 |
-
# Pack the sentences
|
98 |
-
return pack_pb_stream(
|
99 |
-
TextData(
|
100 |
-
source=source,
|
101 |
-
name=name,
|
102 |
-
sentences=sentences,
|
103 |
-
)
|
104 |
-
)
|
105 |
-
|
106 |
-
|
107 |
-
@click.command()
|
108 |
-
@click.option(
|
109 |
-
"--input",
|
110 |
-
type=click.Path(path_type=Path),
|
111 |
-
required=True,
|
112 |
-
help="A folder containing the dataset or a filelist",
|
113 |
-
multiple=True,
|
114 |
-
)
|
115 |
-
@click.option(
|
116 |
-
"--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
|
117 |
-
)
|
118 |
-
@click.option("--num-workers", type=int, default=16)
|
119 |
-
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
|
120 |
-
@click.option(
|
121 |
-
"--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
|
122 |
-
)
|
123 |
-
def main(input, output, num_workers, text_extension, shard_size):
|
124 |
-
generator_fns = []
|
125 |
-
|
126 |
-
for f in input:
|
127 |
-
assert f.exists(), f"{f} not found"
|
128 |
-
|
129 |
-
if f.is_dir():
|
130 |
-
generator_fn = task_generator_folder(f, text_extension)
|
131 |
-
else:
|
132 |
-
generator_fn = task_generator_filelist(f)
|
133 |
-
|
134 |
-
generator_fns.append(generator_fn)
|
135 |
-
|
136 |
-
generator_fn = itertools.chain(*generator_fns)
|
137 |
-
output.mkdir(parents=True, exist_ok=True)
|
138 |
-
|
139 |
-
dataset_fp = None
|
140 |
-
tar_idx = 0
|
141 |
-
written_size = 0
|
142 |
-
|
143 |
-
with Pool(num_workers) as p:
|
144 |
-
for result in tqdm(p.imap_unordered(run_task, generator_fn)):
|
145 |
-
if dataset_fp is None:
|
146 |
-
dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
|
147 |
-
|
148 |
-
dataset_fp.write(result)
|
149 |
-
written_size += len(result)
|
150 |
-
|
151 |
-
if written_size > shard_size * 1024 * 1024:
|
152 |
-
logger.info(f"Finished writing {tar_idx} shards to {output}")
|
153 |
-
dataset_fp.close()
|
154 |
-
dataset_fp = None
|
155 |
-
written_size = 0
|
156 |
-
tar_idx += 1
|
157 |
-
|
158 |
-
if dataset_fp is not None:
|
159 |
-
dataset_fp.close()
|
160 |
-
|
161 |
-
logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
|
162 |
-
|
163 |
-
|
164 |
-
if __name__ == "__main__":
|
165 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/llama/generate.py
CHANGED
@@ -1,9 +1,12 @@
|
|
1 |
import os
|
|
|
|
|
2 |
import time
|
3 |
from pathlib import Path
|
4 |
from typing import Optional, Tuple, Union
|
5 |
|
6 |
import click
|
|
|
7 |
import numpy as np
|
8 |
import torch
|
9 |
import torch._dynamo.config
|
@@ -361,6 +364,7 @@ def encode_tokens(
|
|
361 |
def load_model(
|
362 |
config_name, checkpoint_path, device, precision, max_length, compile=False
|
363 |
):
|
|
|
364 |
with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
|
365 |
cfg = compose(
|
366 |
config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
|
@@ -456,6 +460,7 @@ def generate_long(
|
|
456 |
speaker: Optional[str] = None,
|
457 |
prompt_text: Optional[str] = None,
|
458 |
prompt_tokens: Optional[torch.Tensor] = None,
|
|
|
459 |
):
|
460 |
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
461 |
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
@@ -496,6 +501,10 @@ def generate_long(
|
|
496 |
all_codes = []
|
497 |
seg_idx = 0
|
498 |
|
|
|
|
|
|
|
|
|
499 |
while seg_idx < len(encoded):
|
500 |
logger.info(
|
501 |
f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
|
@@ -562,10 +571,7 @@ def generate_long(
|
|
562 |
codes = y[1:, prompt_length:-2].clone()
|
563 |
|
564 |
codes = codes - 2
|
565 |
-
|
566 |
-
global_encoded.pop()
|
567 |
-
logger.warning(f"Negative code found: {codes}, retrying ...")
|
568 |
-
continue
|
569 |
|
570 |
decoded = y[:, prompt_length:-1].clone()
|
571 |
if decoded[0, -1] != im_end_id: # <im_end>
|
@@ -576,13 +582,63 @@ def generate_long(
|
|
576 |
|
577 |
# But for global encoding, we should keep the <im_end> token
|
578 |
global_encoded.append(decoded)
|
579 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
580 |
seg_idx += 1
|
581 |
|
582 |
-
|
583 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
584 |
|
585 |
-
|
586 |
|
587 |
|
588 |
@click.command()
|
|
|
1 |
import os
|
2 |
+
import queue
|
3 |
+
import threading
|
4 |
import time
|
5 |
from pathlib import Path
|
6 |
from typing import Optional, Tuple, Union
|
7 |
|
8 |
import click
|
9 |
+
import hydra
|
10 |
import numpy as np
|
11 |
import torch
|
12 |
import torch._dynamo.config
|
|
|
364 |
def load_model(
|
365 |
config_name, checkpoint_path, device, precision, max_length, compile=False
|
366 |
):
|
367 |
+
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
368 |
with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
|
369 |
cfg = compose(
|
370 |
config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
|
|
|
460 |
speaker: Optional[str] = None,
|
461 |
prompt_text: Optional[str] = None,
|
462 |
prompt_tokens: Optional[torch.Tensor] = None,
|
463 |
+
is_streaming: bool = False,
|
464 |
):
|
465 |
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
466 |
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
|
|
501 |
all_codes = []
|
502 |
seg_idx = 0
|
503 |
|
504 |
+
if use_prompt:
|
505 |
+
seg_idx = 1
|
506 |
+
global_encoded.append(encoded[0])
|
507 |
+
|
508 |
while seg_idx < len(encoded):
|
509 |
logger.info(
|
510 |
f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
|
|
|
571 |
codes = y[1:, prompt_length:-2].clone()
|
572 |
|
573 |
codes = codes - 2
|
574 |
+
assert (codes >= 0).all(), f"Negative code found"
|
|
|
|
|
|
|
575 |
|
576 |
decoded = y[:, prompt_length:-1].clone()
|
577 |
if decoded[0, -1] != im_end_id: # <im_end>
|
|
|
582 |
|
583 |
# But for global encoding, we should keep the <im_end> token
|
584 |
global_encoded.append(decoded)
|
585 |
+
|
586 |
+
if is_streaming:
|
587 |
+
assert (codes >= 0).all(), f"Negative code found: {codes}"
|
588 |
+
yield codes
|
589 |
+
else:
|
590 |
+
all_codes.append(codes)
|
591 |
+
|
592 |
seg_idx += 1
|
593 |
|
594 |
+
if is_streaming:
|
595 |
+
# This indicates the end of the current sample
|
596 |
+
yield None
|
597 |
+
else:
|
598 |
+
all_codes = torch.cat(all_codes, dim=1)
|
599 |
+
assert (all_codes >= 0).all(), f"Negative code found: {codes}"
|
600 |
+
yield all_codes
|
601 |
+
|
602 |
+
|
603 |
+
def launch_thread_safe_queue(
|
604 |
+
config_name,
|
605 |
+
checkpoint_path,
|
606 |
+
device,
|
607 |
+
precision,
|
608 |
+
max_length,
|
609 |
+
compile=False,
|
610 |
+
):
|
611 |
+
input_queue = queue.Queue()
|
612 |
+
|
613 |
+
def worker():
|
614 |
+
model, decode_one_token = load_model(
|
615 |
+
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
616 |
+
)
|
617 |
+
|
618 |
+
while True:
|
619 |
+
item = input_queue.get()
|
620 |
+
if item is None:
|
621 |
+
break
|
622 |
+
|
623 |
+
kwargs = item["request"]
|
624 |
+
event = item["event"]
|
625 |
+
|
626 |
+
try:
|
627 |
+
item["success"] = True
|
628 |
+
item["response"] = list(
|
629 |
+
generate_long(
|
630 |
+
model=model, decode_one_token=decode_one_token, **kwargs
|
631 |
+
)
|
632 |
+
)
|
633 |
+
except Exception as e:
|
634 |
+
item["success"] = False
|
635 |
+
item["response"] = e
|
636 |
+
|
637 |
+
event.set()
|
638 |
+
|
639 |
+
threading.Thread(target=worker, daemon=True).start()
|
640 |
|
641 |
+
return input_queue
|
642 |
|
643 |
|
644 |
@click.command()
|
tools/llama/rebuild_tokenizer.py
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
|
2 |
-
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
3 |
-
|
4 |
-
# Initialize a tokenizer
|
5 |
-
tokenizer = Tokenizer(models.BPE())
|
6 |
-
|
7 |
-
# Customize pre-tokenization and decoding
|
8 |
-
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
9 |
-
tokenizer.decoder = decoders.ByteLevel()
|
10 |
-
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
|
11 |
-
|
12 |
-
# Don't train the tokenizer
|
13 |
-
trainer = trainers.BpeTrainer(
|
14 |
-
vocab_size=0,
|
15 |
-
min_frequency=2,
|
16 |
-
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
17 |
-
special_tokens=[
|
18 |
-
"<|begin_of_sequence|>",
|
19 |
-
"<|end_of_sequence|>",
|
20 |
-
"<|im_start|>",
|
21 |
-
"<|im_sep|>", # system, user, assistant, etc.
|
22 |
-
"<|im_end|>",
|
23 |
-
"<|semantic|>", # audio features
|
24 |
-
"<|pad|>",
|
25 |
-
],
|
26 |
-
)
|
27 |
-
|
28 |
-
# <|im_start|>user<|im_sep|>...<|im_end|>
|
29 |
-
# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
|
30 |
-
tokenizer.train_from_iterator([], trainer=trainer)
|
31 |
-
|
32 |
-
print(len(tokenizer.get_vocab()))
|
33 |
-
x = tokenizer.encode(
|
34 |
-
"Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
|
35 |
-
).ids
|
36 |
-
print(x, len(x))
|
37 |
-
print(tokenizer.decode(x, skip_special_tokens=True))
|
38 |
-
|
39 |
-
|
40 |
-
tokenizer = PreTrainedTokenizerFast(
|
41 |
-
tokenizer_object=tokenizer,
|
42 |
-
pad_token="<|pad|>",
|
43 |
-
bos_token="<|begin_of_sequence|>",
|
44 |
-
eos_token="<|end_of_sequence|>",
|
45 |
-
)
|
46 |
-
|
47 |
-
# Try tokenizing a new sequence
|
48 |
-
sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
|
49 |
-
encoded = tokenizer(sequence).input_ids
|
50 |
-
|
51 |
-
print("Test encoding....")
|
52 |
-
print(f"\tSentence: {sequence}")
|
53 |
-
print(f"\tEncoded: {encoded}")
|
54 |
-
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
|
55 |
-
print(f"\tDecoded: {tokenizer.decode(encoded)}")
|
56 |
-
|
57 |
-
tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/merge_asr_files.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from pathlib import Path
|
3 |
-
|
4 |
-
from pydub import AudioSegment
|
5 |
-
from tqdm import tqdm
|
6 |
-
|
7 |
-
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
|
8 |
-
|
9 |
-
|
10 |
-
def merge_and_delete_files(save_dir, original_files):
|
11 |
-
save_path = Path(save_dir)
|
12 |
-
audio_slice_files = list_files(
|
13 |
-
path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True
|
14 |
-
)
|
15 |
-
audio_files = {}
|
16 |
-
label_files = {}
|
17 |
-
for file_path in tqdm(audio_slice_files, desc="Merging audio files"):
|
18 |
-
rel_path = Path(file_path).relative_to(save_path)
|
19 |
-
(save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
|
20 |
-
if file_path.suffix == ".wav":
|
21 |
-
prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
|
22 |
-
if prefix == rel_path.parent / file_path.stem:
|
23 |
-
continue
|
24 |
-
audio = AudioSegment.from_wav(file_path)
|
25 |
-
if prefix in audio_files.keys():
|
26 |
-
audio_files[prefix] = audio_files[prefix] + audio
|
27 |
-
else:
|
28 |
-
audio_files[prefix] = audio
|
29 |
-
|
30 |
-
elif file_path.suffix == ".lab":
|
31 |
-
prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
|
32 |
-
if prefix == rel_path.parent / file_path.stem:
|
33 |
-
continue
|
34 |
-
with open(file_path, "r", encoding="utf-8") as f:
|
35 |
-
label = f.read()
|
36 |
-
if prefix in label_files.keys():
|
37 |
-
label_files[prefix] = label_files[prefix] + ", " + label
|
38 |
-
else:
|
39 |
-
label_files[prefix] = label
|
40 |
-
|
41 |
-
for prefix, audio in audio_files.items():
|
42 |
-
output_audio_path = save_path / f"{prefix}.wav"
|
43 |
-
audio.export(output_audio_path, format="wav")
|
44 |
-
|
45 |
-
for prefix, label in label_files.items():
|
46 |
-
output_label_path = save_path / f"{prefix}.lab"
|
47 |
-
with open(output_label_path, "w", encoding="utf-8") as f:
|
48 |
-
f.write(label)
|
49 |
-
|
50 |
-
for file_path in original_files:
|
51 |
-
os.remove(file_path)
|
52 |
-
|
53 |
-
|
54 |
-
if __name__ == "__main__":
|
55 |
-
merge_and_delete_files("/made/by/spicysama/laziman", [__file__])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/vqgan/create_train_split.py
DELETED
@@ -1,54 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
from pathlib import Path
|
3 |
-
from random import Random
|
4 |
-
|
5 |
-
import click
|
6 |
-
from loguru import logger
|
7 |
-
from tqdm import tqdm
|
8 |
-
|
9 |
-
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
10 |
-
|
11 |
-
|
12 |
-
@click.command()
|
13 |
-
@click.argument("root", type=click.Path(exists=True, path_type=Path))
|
14 |
-
@click.option("--val-ratio", type=float, default=None)
|
15 |
-
@click.option("--val-count", type=int, default=None)
|
16 |
-
@click.option("--filelist", default=None, type=Path)
|
17 |
-
def main(root, val_ratio, val_count, filelist):
|
18 |
-
if filelist:
|
19 |
-
files = [i[0] for i in load_filelist(filelist)]
|
20 |
-
else:
|
21 |
-
files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
|
22 |
-
|
23 |
-
logger.info(f"Found {len(files)} files")
|
24 |
-
files = [str(file.relative_to(root)) for file in tqdm(files)]
|
25 |
-
|
26 |
-
Random(42).shuffle(files)
|
27 |
-
|
28 |
-
if val_count is None and val_ratio is None:
|
29 |
-
logger.info("Validation ratio and count not specified, using min(20%, 100)")
|
30 |
-
val_size = min(100, math.ceil(len(files) * 0.2))
|
31 |
-
elif val_count is not None and val_ratio is not None:
|
32 |
-
logger.error("Cannot specify both val_count and val_ratio")
|
33 |
-
return
|
34 |
-
elif val_count is not None:
|
35 |
-
if val_count < 1 or val_count > len(files):
|
36 |
-
logger.error("val_count must be between 1 and number of files")
|
37 |
-
return
|
38 |
-
val_size = val_count
|
39 |
-
else:
|
40 |
-
val_size = math.ceil(len(files) * val_ratio)
|
41 |
-
|
42 |
-
logger.info(f"Using {val_size} files for validation")
|
43 |
-
|
44 |
-
with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
|
45 |
-
f.write("\n".join(files[val_size:]))
|
46 |
-
|
47 |
-
with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
|
48 |
-
f.write("\n".join(files[:val_size]))
|
49 |
-
|
50 |
-
logger.info("Done")
|
51 |
-
|
52 |
-
|
53 |
-
if __name__ == "__main__":
|
54 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/vqgan/extract_vq.py
DELETED
@@ -1,213 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import subprocess as sp
|
3 |
-
import sys
|
4 |
-
import time
|
5 |
-
from datetime import timedelta
|
6 |
-
from functools import lru_cache
|
7 |
-
from pathlib import Path
|
8 |
-
from random import Random
|
9 |
-
|
10 |
-
import click
|
11 |
-
import numpy as np
|
12 |
-
import torch
|
13 |
-
import torchaudio
|
14 |
-
from hydra import compose, initialize
|
15 |
-
from hydra.utils import instantiate
|
16 |
-
from lightning import LightningModule
|
17 |
-
from loguru import logger
|
18 |
-
from omegaconf import OmegaConf
|
19 |
-
|
20 |
-
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
21 |
-
|
22 |
-
# register eval resolver
|
23 |
-
OmegaConf.register_new_resolver("eval", eval)
|
24 |
-
# This file is used to convert the audio files to text files using the Whisper model.
|
25 |
-
# It's mainly used to generate the training data for the VQ model.
|
26 |
-
|
27 |
-
|
28 |
-
RANK = int(os.environ.get("SLURM_PROCID", 0))
|
29 |
-
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
|
30 |
-
|
31 |
-
logger_format = (
|
32 |
-
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
33 |
-
"<level>{level: <8}</level> | "
|
34 |
-
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
|
35 |
-
"{extra[rank]} - <level>{message}</level>"
|
36 |
-
)
|
37 |
-
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
|
38 |
-
logger.remove()
|
39 |
-
logger.add(sys.stderr, format=logger_format)
|
40 |
-
|
41 |
-
|
42 |
-
@lru_cache(maxsize=1)
|
43 |
-
def get_model(
|
44 |
-
config_name: str = "vqgan_pretrain",
|
45 |
-
checkpoint_path: str = "checkpoints/vqgan/step_000380000.ckpt",
|
46 |
-
):
|
47 |
-
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
48 |
-
cfg = compose(config_name=config_name)
|
49 |
-
|
50 |
-
model: LightningModule = instantiate(cfg.model)
|
51 |
-
state_dict = torch.load(
|
52 |
-
checkpoint_path,
|
53 |
-
map_location=model.device,
|
54 |
-
)
|
55 |
-
if "state_dict" in state_dict:
|
56 |
-
state_dict = state_dict["state_dict"]
|
57 |
-
|
58 |
-
model.load_state_dict(state_dict, strict=False)
|
59 |
-
model.eval()
|
60 |
-
model.cuda()
|
61 |
-
|
62 |
-
logger.info(f"Loaded model")
|
63 |
-
return model
|
64 |
-
|
65 |
-
|
66 |
-
@torch.inference_mode()
|
67 |
-
def process_batch(files: list[Path], model) -> float:
|
68 |
-
wavs = []
|
69 |
-
audio_lengths = []
|
70 |
-
new_files = []
|
71 |
-
max_length = total_time = 0
|
72 |
-
|
73 |
-
for file in files:
|
74 |
-
try:
|
75 |
-
wav, sr = torchaudio.load(
|
76 |
-
str(file), backend="sox"
|
77 |
-
) # Need to install libsox-dev
|
78 |
-
except Exception as e:
|
79 |
-
logger.error(f"Error reading {file}: {e}")
|
80 |
-
continue
|
81 |
-
|
82 |
-
if wav.shape[0] > 1:
|
83 |
-
wav = wav.mean(dim=0, keepdim=True)
|
84 |
-
|
85 |
-
wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
|
86 |
-
total_time += len(wav) / model.sampling_rate
|
87 |
-
max_length = max(max_length, len(wav))
|
88 |
-
|
89 |
-
wavs.append(wav)
|
90 |
-
audio_lengths.append(len(wav))
|
91 |
-
new_files.append(file)
|
92 |
-
|
93 |
-
files = new_files
|
94 |
-
|
95 |
-
# Pad to max length
|
96 |
-
for i, wav in enumerate(wavs):
|
97 |
-
wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
|
98 |
-
|
99 |
-
audios = torch.stack(wavs, dim=0)[:, None]
|
100 |
-
audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
|
101 |
-
|
102 |
-
# Calculate lengths
|
103 |
-
indices, feature_lengths = model.encode(audios, audio_lengths)
|
104 |
-
|
105 |
-
# Save to disk
|
106 |
-
outputs = indices.cpu().numpy()
|
107 |
-
|
108 |
-
for file, length, feature, audio_length in zip(
|
109 |
-
files, feature_lengths, outputs, audio_lengths
|
110 |
-
):
|
111 |
-
feature = feature[:, :length]
|
112 |
-
|
113 |
-
# (T,)
|
114 |
-
with open(file.with_suffix(".npy"), "wb") as f:
|
115 |
-
np.save(f, feature)
|
116 |
-
|
117 |
-
return total_time
|
118 |
-
|
119 |
-
|
120 |
-
@click.command()
|
121 |
-
@click.argument("folder")
|
122 |
-
@click.option("--num-workers", default=1)
|
123 |
-
@click.option("--config-name", default="vqgan_pretrain")
|
124 |
-
@click.option(
|
125 |
-
"--checkpoint-path",
|
126 |
-
default="checkpoints/vq-gan-group-fsq-8x1024-wn-20x768-30kh.pth",
|
127 |
-
)
|
128 |
-
@click.option("--batch-size", default=64)
|
129 |
-
@click.option("--filelist", default=None, type=Path)
|
130 |
-
def main(
|
131 |
-
folder: str,
|
132 |
-
num_workers: int,
|
133 |
-
config_name: str,
|
134 |
-
checkpoint_path: str,
|
135 |
-
batch_size: int,
|
136 |
-
filelist: Path,
|
137 |
-
):
|
138 |
-
if num_workers > 1 and WORLD_SIZE != num_workers:
|
139 |
-
assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
|
140 |
-
|
141 |
-
logger.info(f"Spawning {num_workers} workers")
|
142 |
-
|
143 |
-
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
144 |
-
if visible_devices is None:
|
145 |
-
visible_devices = list(range(torch.cuda.device_count()))
|
146 |
-
else:
|
147 |
-
visible_devices = visible_devices.split(",")
|
148 |
-
|
149 |
-
processes = []
|
150 |
-
for i in range(num_workers):
|
151 |
-
env = os.environ.copy()
|
152 |
-
env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
|
153 |
-
env["SLURM_PROCID"] = str(i)
|
154 |
-
env["SLURM_NTASKS"] = str(num_workers)
|
155 |
-
|
156 |
-
processes.append(
|
157 |
-
sp.Popen(
|
158 |
-
[sys.executable] + sys.argv.copy(),
|
159 |
-
env=env,
|
160 |
-
)
|
161 |
-
)
|
162 |
-
|
163 |
-
for p in processes:
|
164 |
-
p.wait()
|
165 |
-
|
166 |
-
logger.info(f"All workers finished")
|
167 |
-
return
|
168 |
-
|
169 |
-
# This is a worker
|
170 |
-
logger.info(f"Starting worker")
|
171 |
-
if filelist:
|
172 |
-
files = [i[0] for i in load_filelist(filelist)]
|
173 |
-
else:
|
174 |
-
files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
|
175 |
-
|
176 |
-
print(f"Found {len(files)} files")
|
177 |
-
# files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
|
178 |
-
|
179 |
-
total_files = len(files)
|
180 |
-
files = files[RANK::WORLD_SIZE]
|
181 |
-
logger.info(f"Processing {len(files)}/{total_files} files")
|
182 |
-
|
183 |
-
# Batch processing
|
184 |
-
total_time = 0
|
185 |
-
begin_time = time.time()
|
186 |
-
processed_files = 0
|
187 |
-
model = get_model(config_name, checkpoint_path)
|
188 |
-
|
189 |
-
for n_batch, idx in enumerate(range(0, len(files), batch_size)):
|
190 |
-
batch = files[idx : idx + batch_size]
|
191 |
-
batch_time = process_batch(batch, model)
|
192 |
-
|
193 |
-
total_time += batch_time
|
194 |
-
processed_files += len(batch)
|
195 |
-
|
196 |
-
if (n_batch + 1) % 10 == 0:
|
197 |
-
eta = (
|
198 |
-
(time.time() - begin_time)
|
199 |
-
/ processed_files
|
200 |
-
* (len(files) - processed_files)
|
201 |
-
)
|
202 |
-
logger.info(
|
203 |
-
f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
|
204 |
-
+ f"ETA: {timedelta(seconds=round(eta))}s"
|
205 |
-
)
|
206 |
-
|
207 |
-
logger.info(
|
208 |
-
f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
|
209 |
-
)
|
210 |
-
|
211 |
-
|
212 |
-
if __name__ == "__main__":
|
213 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/whisper_asr.py
DELETED
@@ -1,113 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Used to transcribe all audio files in one folder into another folder.
|
3 |
-
e.g.
|
4 |
-
Directory structure:
|
5 |
-
--pre_data_root
|
6 |
-
----SP_1
|
7 |
-
------01.wav
|
8 |
-
------02.wav
|
9 |
-
------......
|
10 |
-
----SP_2
|
11 |
-
------01.wav
|
12 |
-
------02.wav
|
13 |
-
------......
|
14 |
-
Use
|
15 |
-
python tools/whisper_asr.py --audio_dir pre_data_root/SP_1 --save_dir data/SP_1
|
16 |
-
to transcribe the first speaker.
|
17 |
-
|
18 |
-
Use
|
19 |
-
python tools/whisper_asr.py --audio_dir pre_data_root/SP_2 --save_dir data/SP_2
|
20 |
-
to transcribe the second speaker.
|
21 |
-
|
22 |
-
Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
|
23 |
-
"""
|
24 |
-
from pathlib import Path
|
25 |
-
|
26 |
-
import click
|
27 |
-
import librosa
|
28 |
-
import soundfile as sf
|
29 |
-
import whisper
|
30 |
-
from loguru import logger
|
31 |
-
from merge_asr_files import merge_and_delete_files
|
32 |
-
from tqdm import tqdm
|
33 |
-
|
34 |
-
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
|
35 |
-
|
36 |
-
|
37 |
-
@click.command()
|
38 |
-
@click.option("--model-size", default="large", help="Size of the Whisper model")
|
39 |
-
@click.option("--audio-dir", required=True, help="Directory containing audio files")
|
40 |
-
@click.option(
|
41 |
-
"--save-dir", required=True, help="Directory to save processed audio files"
|
42 |
-
)
|
43 |
-
@click.option(
|
44 |
-
"--sample-rate",
|
45 |
-
default=None,
|
46 |
-
type=int,
|
47 |
-
help="Output sample rate, default to input sample rate",
|
48 |
-
)
|
49 |
-
@click.option("--device", default="cuda", help="Device to use")
|
50 |
-
@click.option("--language", default="ZH", help="Language of the transcription")
|
51 |
-
def main(model_size, audio_dir, save_dir, sample_rate, device, language):
|
52 |
-
logger.info("Loading / Downloading OpenAI Whisper model...")
|
53 |
-
model = whisper.load_model(
|
54 |
-
name=model_size,
|
55 |
-
device=device,
|
56 |
-
download_root=str(Path(".cache/whisper").resolve()),
|
57 |
-
)
|
58 |
-
logger.info("Model loaded.")
|
59 |
-
|
60 |
-
save_path = Path(save_dir)
|
61 |
-
save_path.mkdir(parents=True, exist_ok=True)
|
62 |
-
original_files = []
|
63 |
-
audio_files = list_files(
|
64 |
-
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
|
65 |
-
)
|
66 |
-
for file_path in tqdm(audio_files, desc="Processing audio file"):
|
67 |
-
file_stem = file_path.stem
|
68 |
-
file_suffix = file_path.suffix
|
69 |
-
|
70 |
-
rel_path = Path(file_path).relative_to(audio_dir)
|
71 |
-
(save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
|
72 |
-
|
73 |
-
if (save_path / rel_path.parent / f"{rel_path.stem}.wav").exists() and (
|
74 |
-
save_path / rel_path.parent / f"{rel_path.stem}.lab"
|
75 |
-
).exists():
|
76 |
-
continue
|
77 |
-
|
78 |
-
audio, sr = librosa.load(file_path, sr=sample_rate, mono=False)
|
79 |
-
transcription = model.transcribe(str(file_path), language=language)
|
80 |
-
|
81 |
-
for segment in transcription.get("segments", []):
|
82 |
-
id, text, start, end = (
|
83 |
-
segment["id"],
|
84 |
-
segment["text"],
|
85 |
-
segment["start"],
|
86 |
-
segment["end"],
|
87 |
-
)
|
88 |
-
|
89 |
-
extract = audio[..., int(start * sr) : int(end * sr)]
|
90 |
-
audio_save_path = (
|
91 |
-
save_path / rel_path.parent / f"{file_stem}-{id}{file_suffix}"
|
92 |
-
)
|
93 |
-
sf.write(
|
94 |
-
audio_save_path,
|
95 |
-
extract,
|
96 |
-
samplerate=sr,
|
97 |
-
)
|
98 |
-
original_files.append(audio_save_path)
|
99 |
-
|
100 |
-
transcript_save_path = save_path / rel_path.parent / f"{file_stem}-{id}.lab"
|
101 |
-
with open(
|
102 |
-
transcript_save_path,
|
103 |
-
"w",
|
104 |
-
encoding="utf-8",
|
105 |
-
) as f:
|
106 |
-
f.write(text)
|
107 |
-
original_files.append(transcript_save_path)
|
108 |
-
|
109 |
-
merge_and_delete_files(save_dir, original_files)
|
110 |
-
|
111 |
-
|
112 |
-
if __name__ == "__main__":
|
113 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|