mrfakename
committed on
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
src/f5_tts/train/datasets/prepare_csv_wavs.py
CHANGED
@@ -54,8 +54,7 @@ def prepare_csv_wavs_dir(input_dir):
|
|
54 |
|
55 |
def get_audio_duration(audio_path):
|
56 |
audio, sample_rate = torchaudio.load(audio_path)
|
57 |
-
|
58 |
-
return audio.shape[1] / (sample_rate * num_channels)
|
59 |
|
60 |
|
61 |
def read_audio_text_pairs(csv_file_path):
|
|
|
54 |
|
55 |
def get_audio_duration(audio_path):
|
56 |
audio, sample_rate = torchaudio.load(audio_path)
|
57 |
+
return audio.shape[1] / sample_rate
|
|
|
58 |
|
59 |
|
60 |
def read_audio_text_pairs(csv_file_path):
|
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -172,10 +172,9 @@ def load_settings(project_name):
|
|
172 |
|
173 |
# Load metadata
|
174 |
def get_audio_duration(audio_path):
|
175 |
-
"""Calculate the duration of an audio file."""
|
176 |
audio, sample_rate = torchaudio.load(audio_path)
|
177 |
-
|
178 |
-
return audio.shape[1] / (sample_rate * num_channels)
|
179 |
|
180 |
|
181 |
def clear_text(text):
|
@@ -383,13 +382,17 @@ def start_training(
|
|
383 |
stream=False,
|
384 |
logger="wandb",
|
385 |
):
|
386 |
-
global training_process, tts_api, stop_signal
|
387 |
|
388 |
-
if tts_api is not None:
|
389 |
-
|
|
|
|
|
|
|
390 |
gc.collect()
|
391 |
torch.cuda.empty_cache()
|
392 |
tts_api = None
|
|
|
393 |
|
394 |
path_project = os.path.join(path_data, dataset_name)
|
395 |
|
@@ -1557,7 +1560,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
|
|
1557 |
last_per_steps = gr.Number(label="Last per Steps", value=100)
|
1558 |
|
1559 |
with gr.Row():
|
1560 |
-
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "
|
1561 |
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
|
1562 |
start_button = gr.Button("Start Training")
|
1563 |
stop_button = gr.Button("Stop Training", interactive=False)
|
|
|
172 |
|
173 |
# Load metadata
|
174 |
def get_audio_duration(audio_path):
|
175 |
+
"""Calculate the duration mono of an audio file."""
|
176 |
audio, sample_rate = torchaudio.load(audio_path)
|
177 |
+
return audio.shape[1] / sample_rate
|
|
|
178 |
|
179 |
|
180 |
def clear_text(text):
|
|
|
382 |
stream=False,
|
383 |
logger="wandb",
|
384 |
):
|
385 |
+
global training_process, tts_api, stop_signal, pipe
|
386 |
|
387 |
+
if tts_api is not None or pipe is not None:
|
388 |
+
if tts_api is not None:
|
389 |
+
del tts_api
|
390 |
+
if pipe is not None:
|
391 |
+
del pipe
|
392 |
gc.collect()
|
393 |
torch.cuda.empty_cache()
|
394 |
tts_api = None
|
395 |
+
pipe = None
|
396 |
|
397 |
path_project = os.path.join(path_data, dataset_name)
|
398 |
|
|
|
1560 |
last_per_steps = gr.Number(label="Last per Steps", value=100)
|
1561 |
|
1562 |
with gr.Row():
|
1563 |
+
mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
|
1564 |
cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
|
1565 |
start_button = gr.Button("Start Training")
|
1566 |
stop_button = gr.Button("Stop Training", interactive=False)
|