Remsky committed on
Commit 3c8cbc9 · 1 Parent(s): b2ae79f

Added Multi-Voice, GPU Timeout, etc

Files changed (6):
  1. README.md +3 -1
  2. app.py +151 -91
  3. lib/file_utils.py +48 -42
  4. lib/ui_content.py +1 -1
  5. the_time_machine_hgwells.txt +0 -19
  6. tts_model.py +130 -173
README.md CHANGED
@@ -42,4 +42,6 @@ Main dependencies:
 - Transformers 4.47.1
 - HuggingFace Hub ≥0.25.1
 
-For a complete list, see requirements.txt.
+For a complete list, see requirements.txt.
+
+
app.py CHANGED
@@ -4,6 +4,8 @@ import spaces
 import time
 import matplotlib.pyplot as plt
 import numpy as np
+import torch
+import os
 from tts_model import TTSModel
 from lib import format_audio_output
 from lib.ui_content import header_html, demo_text_info
@@ -14,106 +16,78 @@ os.environ["HF_HOME"] = "/data/.huggingface"
 # Create TTS model instance
 model = TTSModel()
 
-@spaces.GPU(duration=10)  # Quick initialization
 def initialize_model():
     """Initialize model and get voices"""
     if model.model is None:
         if not model.initialize():
             raise gr.Error("Failed to initialize model")
-    return model.list_voices()
+
+    voices = model.list_voices()
+    if not voices:
+        raise gr.Error("No voices found. Please check the voices directory.")
+
+    return gr.update(choices=voices, value=[voices[0]] if voices else None)
 
-# Get initial voice list
-voice_list = initialize_model()
+def update_progress(chunk_num, total_chunks, tokens_per_sec, rtf, progress_state, start_time, gpu_timeout, progress):
+    # Calculate time metrics
+    elapsed = time.time() - start_time
+    gpu_time_left = max(0, gpu_timeout - elapsed)
+
+    # Calculate chunk time more accurately
+    prev_total_time = sum(progress_state["chunk_times"]) if progress_state["chunk_times"] else 0
+    chunk_time = elapsed - prev_total_time
+
+    # Validate metrics before adding to state
+    if chunk_time > 0 and tokens_per_sec >= 0:
+        # Update progress state with validated metrics
+        progress_state["progress"] = chunk_num / total_chunks
+        progress_state["total_chunks"] = total_chunks
+        progress_state["gpu_time_left"] = gpu_time_left
+        progress_state["tokens_per_sec"].append(float(tokens_per_sec))
+        progress_state["rtf"].append(float(rtf))
+        progress_state["chunk_times"].append(chunk_time)
+
+    # Only update progress display during processing
+    progress(progress_state["progress"], desc=f"Processing chunk {chunk_num}/{total_chunks} | GPU Time Left: {int(gpu_time_left)}s")
 
-@spaces.GPU(duration=120)  # Allow 2 minutes for processing
-def generate_speech_from_ui(text, voice_name, speed, progress=gr.Progress(track_tqdm=False)):
+def generate_speech_from_ui(text, voice_names, speed, gpu_timeout, progress=gr.Progress(track_tqdm=False)):
     """Handle text-to-speech generation from the Gradio UI"""
     try:
+        if not text or not voice_names:
+            raise gr.Error("Please enter text and select at least one voice")
+
         start_time = time.time()
-        gpu_timeout = 120  # seconds
 
-        # Create progress state
+        # Create progress state with explicit type initialization
         progress_state = {
             "progress": 0.0,
-            "tokens_per_sec": [],
-            "rtf": [],
-            "chunk_times": [],
-            "gpu_time_left": gpu_timeout,
+            "tokens_per_sec": [],  # Initialize as empty list
+            "rtf": [],  # Initialize as empty list
+            "chunk_times": [],  # Initialize as empty list
+            "gpu_time_left": float(gpu_timeout),  # Ensure float
             "total_chunks": 0
         }
 
-        def update_progress(chunk_num, total_chunks, tokens_per_sec, rtf):
-            progress_state["progress"] = chunk_num / total_chunks
-            progress_state["tokens_per_sec"].append(tokens_per_sec)
-            progress_state["rtf"].append(rtf)
-
-            # Update GPU time remaining
-            elapsed = time.time() - start_time
-            gpu_time_left = max(0, gpu_timeout - elapsed)
-            progress_state["gpu_time_left"] = gpu_time_left
-            progress_state["total_chunks"] = total_chunks
-
-            # Track individual chunk processing time
-            chunk_time = elapsed - (sum(progress_state["chunk_times"]) if progress_state["chunk_times"] else 0)
-            progress_state["chunk_times"].append(chunk_time)
-
-            # Only update progress display during processing
-            progress(progress_state["progress"], desc=f"Processing chunk {chunk_num}/{total_chunks} | GPU Time Left: {int(gpu_time_left)}s")
+        # Handle single or multiple voices
+        if isinstance(voice_names, str):
+            voice_names = [voice_names]
 
-        # Generate speech with progress tracking
-        audio_array, duration = model.generate_speech(
+        # Generate speech with progress tracking using combined voice
+        audio_array, duration, metrics = model.generate_speech(
             text,
-            voice_name,
+            voice_names,
             speed,
-            progress_callback=update_progress
+            gpu_timeout=gpu_timeout,
+            progress_callback=update_progress,
+            progress_state=progress_state,
+            progress=progress
         )
 
         # Format output for Gradio
        audio_output, duration_text = format_audio_output(audio_array)
 
-        # Calculate final metrics
-        total_time = time.time() - start_time
-        total_duration = len(audio_array) / 24000  # audio duration in seconds
-        rtf = total_time / total_duration if total_duration > 0 else 0
-        mean_tokens_per_sec = np.mean(progress_state["tokens_per_sec"])
-
-        # Create plot of tokens per second with median line
-        fig, ax = plt.subplots(figsize=(10, 5))
-        fig.patch.set_facecolor('black')
-        ax.set_facecolor('black')
-        chunk_nums = list(range(1, len(progress_state["tokens_per_sec"]) + 1))
-
-        # Plot bars for tokens per second
-        ax.bar(chunk_nums, progress_state["tokens_per_sec"], color='#ff2a6d', alpha=0.8)
-
-        # Add median line
-        median_tps = np.median(progress_state["tokens_per_sec"])
-        ax.axhline(y=median_tps, color='#05d9e8', linestyle='--', label=f'Median: {median_tps:.1f} tokens/sec')
-
-        # Style improvements
-        ax.set_xlabel('Chunk Number', fontsize=24, labelpad=20)
-        ax.set_ylabel('Tokens per Second', fontsize=24, labelpad=20)
-        ax.set_title('Processing Speed by Chunk', fontsize=28, pad=30)
-
-        # Increase tick label size
-        ax.tick_params(axis='both', which='major', labelsize=20)
-
-        # Remove gridlines
-        ax.grid(False)
-
-        # Style legend and position it in bottom left
-        ax.legend(fontsize=20, facecolor='black', edgecolor='#05d9e8', loc='lower left')
-
-        plt.tight_layout()
-
-        # Prepare final metrics display including audio duration and real-time speed
-        metrics_text = (
-            f"Median Processing Speed: {np.median(progress_state['tokens_per_sec']):.1f} tokens/sec\n" +
-            f"Real-time Factor: {rtf:.3f}\n" +
-            f"Real Time Generation Speed: {int(1/rtf)}x \n" +
-            f"Processing Time: {int(total_time)}s\n" +
-            f"Output Audio Duration: {total_duration:.2f}s"
-        )
+        # Create plot and metrics text outside GPU context
+        fig, metrics_text = create_performance_plot(metrics, voice_names)
 
         return (
             audio_output,
@@ -123,6 +97,70 @@ def generate_speech_from_ui(text, voice_name, speed, progress=gr.Progress(track_
     except Exception as e:
         raise gr.Error(f"Generation failed: {str(e)}")
 
+def create_performance_plot(metrics, voice_names):
+    """Create performance plot and metrics text from generation metrics"""
+    # Clean and process the data
+    tokens_per_sec = np.array(metrics["tokens_per_sec"])
+    rtf_values = np.array(metrics["rtf"])
+
+    # Calculate statistics using cleaned data
+    median_tps = float(np.median(tokens_per_sec))
+    mean_tps = float(np.mean(tokens_per_sec))
+    std_tps = float(np.std(tokens_per_sec))
+
+    # Set y-axis limits based on data range
+    y_min = max(0, np.min(tokens_per_sec) * 0.9)
+    y_max = np.max(tokens_per_sec) * 1.1
+
+    # Create plot
+    fig, ax = plt.subplots(figsize=(10, 5))
+    fig.patch.set_facecolor('black')
+    ax.set_facecolor('black')
+
+    # Plot data points
+    chunk_nums = list(range(1, len(tokens_per_sec) + 1))
+    ax.bar(chunk_nums, tokens_per_sec, color='#ff2a6d', alpha=0.6)
+
+    # Set y-axis limits with padding
+    padding = 0.1 * (y_max - y_min)
+    ax.set_ylim(max(0, y_min - padding), y_max + padding)
+
+    # Add median line
+    ax.axhline(y=median_tps, color='#05d9e8', linestyle='--',
+               label=f'Median: {median_tps:.1f} tokens/sec')
+
+    # Style improvements
+    ax.set_xlabel('Chunk Number', fontsize=24, labelpad=20, color='white')
+    ax.set_ylabel('Tokens per Second', fontsize=24, labelpad=20, color='white')
+    ax.set_title('Processing Speed by Chunk', fontsize=28, pad=30, color='white')
+    ax.tick_params(axis='both', which='major', labelsize=20, colors='white')
+    ax.spines['bottom'].set_color('white')
+    ax.spines['top'].set_color('white')
+    ax.spines['left'].set_color('white')
+    ax.spines['right'].set_color('white')
+    ax.grid(False)
+    ax.legend(fontsize=20, facecolor='black', edgecolor='#05d9e8', loc='lower left',
+              labelcolor='white')
+
+    plt.tight_layout()
+
+    # Calculate average RTF from individual chunk RTFs
+    rtf = np.mean(rtf_values)
+
+    # Prepare metrics text
+    metrics_text = (
+        f"Median Speed: {median_tps:.1f} tokens/sec (o200k_base)\n" +
+        f"Real-time Factor: {rtf:.3f}\n" +
+        f"Real Time Speed: {int(1/rtf)}x\n" +
+        f"Processing Time: {int(metrics['total_time'])}s\n" +
+        f"Total Tokens: {metrics['total_tokens']} (o200k_base)\n" +
+        f"Voices: {', '.join(voice_names)}"
+    )
+
+    return fig, metrics_text
+
 # Create Gradio interface
 with gr.Blocks(title="Kokoro TTS Demo", css="""
 .equal-height {
@@ -135,12 +173,15 @@ with gr.Blocks(title="Kokoro TTS Demo", css="""
 
     with gr.Row():
         # Column 1: Text Input
+        with open("the_time_machine_hgwells.txt") as f:
+            text = f.readlines()[:200]
+            text = "".join(text)
         with gr.Column(elem_classes="equal-height"):
             text_input = gr.TextArea(
                 label="Text to speak",
                 placeholder="Enter text here or upload a .txt file",
                 lines=10,
-                value=open("the_time_machine_hgwells.txt").read()[:1000]
+                value=text
             )
 
         # Column 2: Controls
@@ -166,17 +207,17 @@ with gr.Blocks(title="Kokoro TTS Demo", css="""
             )
 
            with gr.Group():
-                default_voice = 'af_sky' if 'af_sky' in voice_list \
-                    else voice_list[0] \
-                    if voice_list else \
-                    None
-
                 voice_dropdown = gr.Dropdown(
-                    label="Voice",
-                    choices=voice_list,
-                    value=default_voice,
-                    allow_custom_value=True
+                    label="Voice(s)",
+                    choices=[],  # Start empty, will be populated after initialization
+                    value=None,
+                    allow_custom_value=True,
+                    multiselect=True
                 )
+
+                # Add refresh button to manually update voice list
+                refresh_btn = gr.Button("🔄 Refresh Voices", size="sm")
+
                 speed_slider = gr.Slider(
                     label="Speed",
                     minimum=0.5,
@@ -184,6 +225,14 @@ with gr.Blocks(title="Kokoro TTS Demo", css="""
                     value=1.0,
                     step=0.1
                 )
+                gpu_timeout_slider = gr.Slider(
+                    label="GPU Timeout (seconds)",
+                    minimum=15,
+                    maximum=120,
+                    value=60,
+                    step=1,
+                    info="Maximum time allowed for GPU processing"
+                )
             submit_btn = gr.Button("Generate Speech", variant="primary")
 
         # Column 3: Output
@@ -198,7 +247,7 @@ with gr.Blocks(title="Kokoro TTS Demo", css="""
             metrics_text = gr.Textbox(
                 label="Performance Summary",
                 interactive=False,
-                lines=4
+                lines=5
             )
             metrics_plot = gr.Plot(
                 label="Processing Metrics",
@@ -206,10 +255,15 @@ with gr.Blocks(title="Kokoro TTS Demo", css="""
                format="png"  # Explicitly set format to PNG which is supported by matplotlib
             )
 
-    # Set up event handler
+    # Set up event handlers
+    refresh_btn.click(
+        fn=initialize_model,
+        outputs=[voice_dropdown]
+    )
+
     submit_btn.click(
         fn=generate_speech_from_ui,
-        inputs=[text_input, voice_dropdown, speed_slider],
+        inputs=[text_input, voice_dropdown, speed_slider, gpu_timeout_slider],
         outputs=[audio_output, metrics_plot, metrics_text],
         show_progress=True
     )
@@ -218,6 +272,12 @@ with gr.Blocks(title="Kokoro TTS Demo", css="""
     with gr.Row():
         with gr.Column():
             gr.Markdown(demo_text_info)
+
+    # Initialize voices on load
+    demo.load(
+        fn=initialize_model,
+        outputs=[voice_dropdown]
+    )
 
 # Launch the app
 if __name__ == "__main__":
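
Note on the multi-voice feature wired up above: generate_speech now accepts a list of voice names and blends them into a single voicepack by averaging (see tts_model.py below). A minimal sketch of that idea, assuming each voice file is a .pt tensor of matching shape; the paths here are illustrative, not from the repo:

    import torch

    voice_paths = ["voices/af_sky.pt", "voices/af_bella.pt"]  # hypothetical paths
    voicepacks = [torch.load(p, weights_only=True) for p in voice_paths]

    # Equal-weight blend: element-wise mean over the stacked voice tensors.
    blended = torch.mean(torch.stack(voicepacks), dim=0)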
lib/file_utils.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import importlib.util
 import sys
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, snapshot_download
 from typing import List, Optional
 
 def load_module_from_file(module_name: str, file_path: str):
@@ -35,19 +35,39 @@ def ensure_dir(path: str) -> None:
     """Ensure directory exists, create if it doesn't"""
     os.makedirs(path, exist_ok=True)
 
+def find_voice_directory(start_path: str) -> str:
+    """Recursively search for directory containing .pt files that don't have 'kokoro' in the name"""
+    for root, dirs, files in os.walk(start_path):
+        pt_files = [f for f in files if f.endswith('.pt') and 'kokoro' not in f.lower()]
+        if pt_files:
+            return root
+    return ""
+
 def list_voice_files(voices_dir: str) -> List[str]:
     """List available voice files in directory"""
     voices = []
     try:
-        if not os.path.exists(voices_dir):
-            print(f"Voices directory does not exist: {voices_dir}")
-            return voices
+        # First try the standard locations
+        if os.path.exists(os.path.join(voices_dir, 'voices')):
+            voice_path = os.path.join(voices_dir, 'voices')
+        else:
+            voice_path = voices_dir
 
-        files = os.listdir(voices_dir)
+        # If no voices found, try recursive search
+        if not os.path.exists(voice_path) or not any(f.endswith('.pt') for f in os.listdir(voice_path)):
+            found_dir = find_voice_directory(os.path.dirname(voices_dir))
+            if found_dir:
+                voice_path = found_dir
+                print(f"Found voices in: {voice_path}")
+            else:
+                print("No voice directory found")
+                return voices
+
+        files = os.listdir(voice_path)
         print(f"Found {len(files)} files in voices directory")
 
         for file in files:
-            if file.endswith(".pt"):
+            if file.endswith(".pt") and 'kokoro' not in file.lower():
                 voice_name = file[:-3]  # Remove .pt extension
                 print(f"Found voice: {voice_name}")
                 voices.append(voice_name)
@@ -62,40 +82,26 @@ def list_voice_files(voices_dir: str) -> List[str]:
 
     return sorted(voices)
 
-def download_voice_files(repo_id: str, voices: List[str], voices_dir: str) -> None:
-    """Download voice files from Hugging Face Hub"""
-    ensure_dir(voices_dir)
-
-    for voice in voices:
-        try:
-            voice_path = os.path.join(voices_dir, voice)
-            print(f"Attempting to download voice {voice} to {voice_path}")
-
-            try:
-                downloaded_path = hf_hub_download(
-                    repo_id=repo_id,
-                    filename=f"voices/{voice}",
-                    local_dir=voices_dir,
-                    local_dir_use_symlinks=False,
-                    force_filename=voice
-                )
-                print(f"Download completed to: {downloaded_path}")
-
-                if not os.path.exists(voice_path):
-                    print(f"Warning: File not found at expected path {voice_path}")
-                    print(f"Checking download location: {downloaded_path}")
-                    if os.path.exists(downloaded_path):
-                        print(f"Moving file from {downloaded_path} to {voice_path}")
-                        os.rename(downloaded_path, voice_path)
-                else:
-                    print(f"Verified voice file exists: {voice_path}")
-
-            except Exception as e:
-                print(f"Error downloading voice {voice}: {str(e)}")
-                import traceback
-                traceback.print_exc()
-
-        except Exception as e:
-            print(f"Error downloading voice {voice}: {str(e)}")
-            import traceback
-            traceback.print_exc()
+def download_voice_files(repo_id: str, directory: str, local_dir: str) -> None:
+    """Download voice files from Hugging Face Hub
+
+    Args:
+        repo_id: The Hugging Face repository ID
+        directory: The directory in the repo to download (e.g. "voices")
+        local_dir: Local directory to save files to
+    """
+    ensure_dir(local_dir)
+    try:
+        print(f"Downloading voice files from {repo_id}/{directory} to {local_dir}")
+        downloaded_path = snapshot_download(
+            repo_id=repo_id,
+            repo_type="model",
+            local_dir=local_dir,
+            allow_patterns=[f"{directory}/*"],
+            local_dir_use_symlinks=False
+        )
+        print(f"Download completed to: {downloaded_path}")
+    except Exception as e:
+        print(f"Error downloading voice files: {str(e)}")
+        import traceback
+        traceback.print_exc()
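
Note: the rewritten download_voice_files fetches the whole voices/ folder in one snapshot_download call filtered by allow_patterns, instead of one hf_hub_download per file. A usage sketch under the repo layout this commit assumes (the hexgrad/Kokoro-82M repo with a "voices" subdirectory):

    from lib.file_utils import download_voice_files, list_voice_files

    download_voice_files("hexgrad/Kokoro-82M", "voices", "voices")
    print(list_voice_files("voices"))  # e.g. ['af_bella', 'af_sky', ...]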
lib/ui_content.py CHANGED
@@ -13,7 +13,7 @@ header_html = """
 
 <div style="text-align: center; margin-bottom: 1rem;">
     <h1 style="font-size: 1.75rem; font-weight: bold; color: #ffffff; margin-bottom: 0.5rem;">Kokoro TTS Demo</h1>
-    <p style="color: #d1d5db;">Convert text to natural-sounding speech using various voices.</p>
+    <p style="color: #d1d5db;">Rapidly convert text to natural speech using various and blended voices.</p>
 </div>
 
 <div style="display: flex; gap: 1rem;">
the_time_machine_hgwells.txt CHANGED
@@ -1,22 +1,3 @@
-The Time Traveller (for so it will be convenient to speak of him) was
-expounding a recondite matter to us. His pale grey eyes shone and
-twinkled, and his usually pale face was flushed and animated. The fire
-burnt brightly, and the soft radiance of the incandescent lights in the
-lilies of silver caught the bubbles that flashed and passed in our
-glasses. Our chairs, being his patents, embraced and caressed us rather
-than submitted to be sat upon, and there was that luxurious
-after-dinner atmosphere, when thought runs gracefully free of the
-trammels of precision. And he put it to us in this way—marking the
-points with a lean forefinger—as we sat and lazily admired his
-earnestness over this new paradox (as we thought it) and his fecundity.
-
-“You must follow me carefully. I shall have to controvert one or two
-ideas that are almost universally accepted. The geometry, for instance,
-they taught you at school is founded on a misconception.”
-
-“Is not that rather a large thing to expect us to begin upon?” said
-Filby, an argumentative person with red hair.
-
 “I do not mean to ask you to accept anything without reasonable ground
 for it. You will soon admit as much as I need from you. You know of
 course that a mathematical line, a line of thickness _nil_, has no real
tts_model.py CHANGED
@@ -16,6 +16,7 @@ from lib import (
     ensure_dir,
     concatenate_audio_chunks
 )
+import spaces
 
 class TTSModel:
     """GPU-accelerated TTS model manager"""
@@ -25,6 +26,7 @@ class TTSModel:
         self.voices_dir = "voices"
         self.model_repo = "hexgrad/Kokoro-82M"
         ensure_dir(self.voices_dir)
+        self.model_path = None
 
         # Load required modules
         py_modules = ["istftnet", "plbert", "models", "kokoro"]
@@ -48,14 +50,14 @@
                 self.model_repo,
                 ["kokoro-v0_19.pth", "config.json"]
             )
-            model_path = model_files[0]  # kokoro-v0_19.pth
-
-            # Build model directly on GPU
-            with torch.cuda.device(0):
-                torch.cuda.set_device(0)
-                self.model = self.build_model(model_path, 'cuda')
-                self._model_on_gpu = True
+            self.model_path = model_files[0]  # kokoro-v0_19.pth
+
+            # Download voice files
+            download_voice_files(self.model_repo, "voices", self.voices_dir)
+
+            # Get list of available voices
+            available_voices = self.list_voices()
 
             print("Model initialization complete")
             return True
@@ -66,7 +68,7 @@
     def ensure_voice_downloaded(self, voice_name: str) -> bool:
         """Ensure specific voice is downloaded"""
         try:
-            voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt")
+            voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
             if not os.path.exists(voice_path):
                 print(f"Downloading voice {voice_name}.pt...")
                 download_voice_files(self.model_repo, [f"{voice_name}.pt"], self.voices_dir)
@@ -77,43 +79,58 @@
 
     def list_voices(self) -> List[str]:
         """List available voices"""
-        return [
-            "af_bella", "af_nicole", "af_sarah", "af_sky", "af",
-            "am_adam", "am_michael", "bf_emma", "bf_isabella",
-            "bm_george", "bm_lewis"
-        ]
+        voices = []
+        voices_subdir = os.path.join(self.voices_dir, "voices")
+        if os.path.exists(voices_subdir):
+            for file in os.listdir(voices_subdir):
+                if file.endswith(".pt"):
+                    voice_name = file[:-3]
+                    voices.append(voice_name)
+        return voices
 
-    def _ensure_model_on_gpu(self) -> None:
-        """Ensure model is on GPU and stays there"""
-        if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu:
-            print("Moving model to GPU...")
-            with torch.cuda.device(0):
-                torch.cuda.set_device(0)
-                if hasattr(self.model, 'to'):
-                    self.model.to('cuda')
-                else:
-                    for name in self.model:
-                        if isinstance(self.model[name], torch.Tensor):
-                            self.model[name] = self.model[name].cuda()
-            self._model_on_gpu = True
+    # def _ensure_model_on_gpu(self) -> None:
+    #     """Ensure model is on GPU and stays there"""
+    #     if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu:
+    #         print("Moving model to GPU...")
+    #         with torch.cuda.device(0):
+    #             torch.cuda.set_device(0)
+    #             if hasattr(self.model, 'to'):
+    #                 self.model.to('cuda')
+    #             else:
+    #                 for name in self.model:
+    #                     if isinstance(self.model[name], torch.Tensor):
+    #                         self.model[name] = self.model[name].cuda()
+    #         self._model_on_gpu = True
 
     def _generate_audio(self, text: str, voicepack: torch.Tensor, lang: str, speed: float) -> np.ndarray:
         """GPU-accelerated audio generation"""
         try:
             with torch.cuda.device(0):
                 torch.cuda.set_device(0)
-
-                # Move everything to GPU in a single context
-                if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu:
-                    print("Moving model to GPU...")
-                    if hasattr(self.model, 'to'):
-                        self.model.to('cuda')
-                    else:
-                        for name in self.model:
-                            if isinstance(self.model[name], torch.Tensor):
-                                self.model[name] = self.model[name].cuda()
-                    self._model_on_gpu = True
-
+                try:
+                    # Build model if needed
+                    if self.model is None:
+                        print("Building model...")
+                        device = torch.device('cuda')
+                        self.model = self.build_model(self.model_path, device=device)
+                        if self.model is None:
+                            raise ValueError("Failed to build model")
+                        print("Model built successfully")
+
+                    # Move model to GPU if needed
+                    if not hasattr(self.model, '_on_gpu'):
+                        print("Moving model to GPU...")
+                        if hasattr(self.model, 'to'):
+                            self.model = self.model.to('cuda')
+                        else:
+                            for name in self.model:
+                                if isinstance(self.model[name], torch.Tensor):
+                                    self.model[name] = self.model[name].cuda()
+                        self.model._on_gpu = True
+                except Exception as e:
+                    print(f"Error building model: {str(e)}")
+                    print("Attempting to continue")
+                    raise e
                 # Move voicepack to GPU
                 voicepack = voicepack.cuda()
 
@@ -131,59 +148,73 @@
         except Exception as e:
             print(f"Error in audio generation: {str(e)}")
             raise e
-
-    def generate_speech(self, text: str, voice_name: str, speed: float = 1.0, progress_callback=None) -> Tuple[np.ndarray, float]:
+
+    @spaces.GPU(duration=None)  # Duration will be set by the UI
+    def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float]:
         """Generate speech from text. Returns (audio_array, duration)
 
         Args:
             text: Input text to convert to speech
             voice_name: Name of voice to use
             speed: Speech speed multiplier
-            progress_callback: Optional callback function(chunk_num, total_chunks, tokens_per_sec, rtf)
+            progress_callback: Optional callback function(chunk_num, total_chunks, tokens_per_sec, rtf, progress_state, start_time, gpu_timeout, progress)
+            progress_state: Dictionary tracking generation progress metrics
+            progress: Progress callback from Gradio
         """
         try:
-            if not text or not voice_name:
-                raise ValueError("Text and voice name are required")
-
             start_time = time.time()
-
-            # Count tokens and normalize text
-            total_tokens = count_tokens(text)
-            text = normalize_text(text)
-            if not text:
-                raise ValueError("Text is empty after normalization")
-
-            # Load voice and process within GPU context
             with torch.cuda.device(0):
                 torch.cuda.set_device(0)
+                if not text or not voice_names:
+                    raise ValueError("Text and voice name are required")
+                # Build model directly on GPU
 
-                voice_path = os.path.join(self.voices_dir, f"{voice_name}.pt")
-
-                # Ensure voice is downloaded and load directly to GPU
-                if not self.ensure_voice_downloaded(voice_name):
-                    raise ValueError(f"Failed to download voice: {voice_name}")
-                voicepack = torch.load(voice_path, map_location='cuda', weights_only=True)
-
-                # Break text into chunks for better memory management
-                chunks = chunk_text(text)
-                print(f"Processing {len(chunks)} chunks...")
-
-                # Ensure model is initialized and on GPU
+                # Build model if needed
                 if self.model is None:
-                    print("Model not initialized, reinitializing...")
-                    if not self.initialize():
-                        raise ValueError("Failed to initialize model")
+                    print("Building model...")
+                    self.model = self.build_model(self.model_path, device='cuda')
+                    if self.model is None:
+                        raise ValueError("Failed to build model")
+                    print("Model built successfully")
 
                 # Move model to GPU if needed
-                if not hasattr(self, '_model_on_gpu') or not self._model_on_gpu:
+                if not hasattr(self.model, '_on_gpu'):
                     print("Moving model to GPU...")
                     if hasattr(self.model, 'to'):
-                        self.model.to('cuda')
+                        self.model = self.model.to('cuda')
                     else:
                         for name in self.model:
                            if isinstance(self.model[name], torch.Tensor):
                                self.model[name] = self.model[name].cuda()
-                    self._model_on_gpu = True
+                    self.model._on_gpu = True
+
+                t_voices = []
+                if isinstance(voice_names, list) and len(voice_names) > 1:
+                    for voice in voice_names:
+                        try:
+                            voice_path = os.path.join(self.voices_dir, "voices", f"{voice}.pt")
+                            voicepack = torch.load(voice_path, weights_only=True)
+                            t_voices.append(voicepack)
+                        except Exception as e:
+                            print(f"Warning: Failed to load voice {voice}: {str(e)}")
+
+                    # Combine voices by taking mean
+                    voicepack = torch.mean(torch.stack(t_voices), dim=0)
+                    voice_name = "_".join(voice_names)
+                else:
+                    voice_name = voice_names[0]
+                    voice_path = os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")
+                    voicepack = torch.load(voice_path, weights_only=True)
+
+                # Count tokens and normalize text
+                total_tokens = count_tokens(text)
+                text = normalize_text(text)
+                if not text:
+                    raise ValueError("Text is empty after normalization")
+
+                # Break text into chunks for better memory management
+                chunks = chunk_text(text)
+                print(f"Processing {len(chunks)} chunks...")
 
                 # Process all chunks within same GPU context
                 audio_chunks = []
@@ -202,11 +233,13 @@
                     )
                     chunk_time = time.time() - chunk_start
 
-                    # Update metrics
+                    # Calculate per-chunk metrics
                    chunk_tokens = count_tokens(chunk)
+                    chunk_tokens_per_sec = chunk_tokens / chunk_time
+
+                    # Update totals for overall stats
                    total_processed_tokens += chunk_tokens
                    total_processed_time += chunk_time
-                    current_tokens_per_sec = total_processed_tokens / total_processed_time
 
                     # Calculate processing speed metrics
                     chunk_duration = len(chunk_audio) / 24000  # audio duration in seconds
@@ -216,7 +249,7 @@
                     chunk_times.append(chunk_time)
                     chunk_sizes.append(len(chunk))
                     print(f"Chunk {i+1}/{len(chunks)} processed in {chunk_time:.2f}s")
-                    print(f"Current tokens/sec: {current_tokens_per_sec:.2f}")
+                    print(f"Current tokens/sec: {chunk_tokens_per_sec:.2f}")
                     print(f"Real-time factor: {rtf:.2f}x")
                     print(f"{times_faster:.1f}x faster than real-time")
 
@@ -224,109 +257,33 @@
 
                     # Call progress callback if provided
                     if progress_callback:
-                        progress_callback(i + 1, len(chunks), current_tokens_per_sec, rtf)
+                        progress_callback(
+                            i + 1,  # chunk_num
+                            len(chunks),  # total_chunks
+                            chunk_tokens_per_sec,  # Pass per-chunk rate instead of cumulative
+                            rtf,
+                            progress_state,  # Added
+                            start_time,  # Added
+                            gpu_timeout,  # Use the timeout value from UI
+                            progress  # Added
+                        )
 
             # Concatenate audio chunks
             audio = concatenate_audio_chunks(audio_chunks)
-
-            def setup_plot(fig, ax, title):
-                """Configure plot styling"""
-                # Improve grid
-                ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
-
-                # Set title and labels with better fonts and more padding
-                ax.set_title(title, pad=40, fontsize=16, fontweight="bold", color="#ffffff")
-                ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
-                ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
-
-                # Improve tick labels
-                ax.tick_params(labelsize=12, colors="#ffffff")
-
-                # Style spines
-                for spine in ax.spines.values():
-                    spine.set_color("#ffffff")
-                    spine.set_alpha(0.3)
-                    spine.set_linewidth(0.5)
-
-                # Set background colors
-                ax.set_facecolor("#1a1a2e")
-                fig.patch.set_facecolor("#1a1a2e")
-
-                return fig, ax
-
-            # Set dark style
-            plt.style.use("dark_background")
-
-            # Create figure with subplots
-            fig = plt.figure(figsize=(18, 16))
-            fig.patch.set_facecolor("#1a1a2e")
-
-            # Create subplot grid
-            gs = plt.GridSpec(2, 1, left=0.15, right=0.85, top=0.9, bottom=0.15, hspace=0.4)
-
-            # Processing times plot
-            ax1 = plt.subplot(gs[0])
-            chunks_x = list(range(1, len(chunks) + 1))
-            bars = ax1.bar(chunks_x, chunk_times, color='#ff2a6d', alpha=0.8)
-
-            # Add statistics lines
-            mean_time = mean(chunk_times)
-            median_time = median(chunk_times)
-            std_time = stdev(chunk_times) if len(chunk_times) > 1 else 0
-
-            ax1.axhline(y=mean_time, color='#05d9e8', linestyle='--',
-                        label=f'Mean: {mean_time:.2f}s')
-            ax1.axhline(y=median_time, color='#d1f7ff', linestyle=':',
-                        label=f'Median: {median_time:.2f}s')
-
-            # Add ±1 std dev range
-            if len(chunk_times) > 1:
-                ax1.axhspan(mean_time - std_time, mean_time + std_time,
-                            color='#8c1eff', alpha=0.2, label='±1 Std Dev')
-
-            # Add value labels on top of bars
-            for bar in bars:
-                height = bar.get_height()
-                ax1.text(bar.get_x() + bar.get_width() / 2.0,
-                         height,
-                         f'{height:.2f}s',
-                         ha='center',
-                         va='bottom',
-                         color='white',
-                         fontsize=10)
-
-            ax1.set_xlabel('Chunk Number')
-            ax1.set_ylabel('Processing Time (seconds)')
-            setup_plot(fig, ax1, 'Chunk Processing Times')
-            ax1.legend(facecolor="#1a1a2e", edgecolor="#ffffff")
-
-            # Chunk sizes plot
-            ax2 = plt.subplot(gs[1])
-            ax2.plot(chunks_x, chunk_sizes, color='#ff9e00', marker='o', linewidth=2)
-            ax2.set_xlabel('Chunk Number')
-            ax2.set_ylabel('Chunk Size (chars)')
-            setup_plot(fig, ax2, 'Chunk Sizes')
-
-            # Save plot
-            plt.savefig('chunk_times.png', format='png')
-            plt.close()
-
-            # Calculate metrics
-            total_time = time.time() - start_time
-            tokens_per_second = total_tokens / total_time
-
-            print(f"\nProcessing Metrics:")
-            print(f"Total tokens: {total_tokens}")
-            print(f"Total time: {total_time:.2f}s")
-            print(f"Tokens per second: {tokens_per_second:.2f}")
-            print(f"Mean chunk time: {mean_time:.2f}s")
-            print(f"Median chunk time: {median_time:.2f}s")
-            if len(chunk_times) > 1:
-                print(f"Std dev: {std_time:.2f}s")
-            print(f"\nChunk time plot saved as 'chunk_times.png'")
-
-            return audio, len(audio) / 24000  # Return audio array and duration
-
+
+            # Return audio and metrics
+            return (
+                audio,  # Audio array
+                len(audio) / 24000,  # Duration
+                {
+                    "chunk_times": chunk_times,
+                    "chunk_sizes": chunk_sizes,
+                    "tokens_per_sec": [float(x) for x in progress_state["tokens_per_sec"]],
+                    "rtf": [float(x) for x in progress_state["rtf"]],
+                    "total_tokens": total_tokens,
+                    "total_time": time.time() - start_time
+                }
+            )
         except Exception as e:
            print(f"Error generating speech: {str(e)}")
            raise
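
Note on the per-chunk metrics above: tokens/sec is now computed per chunk rather than cumulatively, and the real-time factor compares processing time against audio duration at the 24 kHz output rate used throughout this commit. A worked sketch with illustrative numbers:

    chunk_tokens = 120            # tokens in this chunk (illustrative)
    chunk_time = 0.8              # seconds spent generating it
    chunk_audio_samples = 96000   # samples produced

    chunk_tokens_per_sec = chunk_tokens / chunk_time   # 150.0
    chunk_duration = chunk_audio_samples / 24000       # 4.0 s of audio
    rtf = chunk_time / chunk_duration                  # 0.2
    times_faster = 1 / rtf                             # 5.0x faster than real time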