Mikerhinos aadnk commited on
Commit
9e8682f
·
0 Parent(s):

Duplicate from aadnk/whisper-webui

Browse files

Co-authored-by: Kristian Stangeland <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ flagged/
4
+ *.py[cod]
5
+ *$py.class
README.md ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Whisper Webui
3
+ emoji: ⚡
4
+ colorFrom: pink
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.3.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: aadnk/whisper-webui
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+ # Running Locally
17
+
18
+ To run this program locally, first install Python 3.9+ and Git. Then install Pytorch 10.1+ and all the other dependencies:
19
+ ```
20
+ pip install -r requirements.txt
21
+ ```
22
+
23
+ Finally, run the full version (no audio length restrictions) of the app with parallel CPU/GPU enabled:
24
+ ```
25
+ python app.py --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
26
+ ```
27
+
28
+ You can also run the CLI interface, which is similar to Whisper's own CLI but also supports the following additional arguments:
29
+ ```
30
+ python cli.py \
31
+ [--vad {none,silero-vad,silero-vad-skip-gaps,silero-vad-expand-into-gaps,periodic-vad}] \
32
+ [--vad_merge_window VAD_MERGE_WINDOW] \
33
+ [--vad_max_merge_size VAD_MAX_MERGE_SIZE] \
34
+ [--vad_padding VAD_PADDING] \
35
+ [--vad_prompt_window VAD_PROMPT_WINDOW]
36
+ [--vad_cpu_cores NUMBER_OF_CORES]
37
+ [--vad_parallel_devices COMMA_DELIMITED_DEVICES]
38
+ [--auto_parallel BOOLEAN]
39
+ ```
40
+ In addition, you may also use URL's in addition to file paths as input.
41
+ ```
42
+ python cli.py --model large --vad silero-vad --language Japanese "https://www.youtube.com/watch?v=4cICErqqRSM"
43
+ ```
44
+
45
+ ## Parallel Execution
46
+
47
+ You can also run both the Web-UI or the CLI on multiple GPUs in parallel, using the `vad_parallel_devices` option. This takes a comma-delimited list of
48
+ device IDs (0, 1, etc.) that Whisper should be distributed to and run on concurrently:
49
+ ```
50
+ python cli.py --model large --vad silero-vad --language Japanese \
51
+ --vad_parallel_devices 0,1 "https://www.youtube.com/watch?v=4cICErqqRSM"
52
+ ```
53
+
54
+ Note that this requires a VAD to function properly, otherwise only the first GPU will be used. Though you could use `period-vad` to avoid taking the hit
55
+ of running Silero-Vad, at a slight cost to accuracy.
56
+
57
+ This is achieved by creating N child processes (where N is the number of selected devices), where Whisper is run concurrently. In `app.py`, you can also
58
+ set the `vad_process_timeout` option. This configures the number of seconds until a process is killed due to inactivity, freeing RAM and video memory.
59
+ The default value is 30 minutes.
60
+
61
+ ```
62
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600
63
+ ```
64
+
65
+ To execute the Silero VAD itself in parallel, use the `vad_cpu_cores` option:
66
+ ```
67
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600 --vad_cpu_cores 4
68
+ ```
69
+
70
+ You may also use `vad_process_timeout` with a single device (`--vad_parallel_devices 0`), if you prefer to always free video memory after a period of time.
71
+
72
+ ### Auto Parallel
73
+
74
+ You can also set `auto_parallel` to `True`. This will set `vad_parallel_devices` to use all the GPU devices on the system, and `vad_cpu_cores` to be equal to the number of
75
+ cores (up to 8):
76
+ ```
77
+ python app.py --input_audio_max_duration -1 --auto_parallel True
78
+ ```
79
+
80
+ # Docker
81
+
82
+ To run it in Docker, first install Docker and optionally the NVIDIA Container Toolkit in order to use the GPU.
83
+ Then either use the GitLab hosted container below, or check out this repository and build an image:
84
+ ```
85
+ sudo docker build -t whisper-webui:1 .
86
+ ```
87
+
88
+ You can then start the WebUI with GPU support like so:
89
+ ```
90
+ sudo docker run -d --gpus=all -p 7860:7860 whisper-webui:1
91
+ ```
92
+
93
+ Leave out "--gpus=all" if you don't have access to a GPU with enough memory, and are fine with running it on the CPU only:
94
+ ```
95
+ sudo docker run -d -p 7860:7860 whisper-webui:1
96
+ ```
97
+
98
+ # GitLab Docker Registry
99
+
100
+ This Docker container is also hosted on GitLab:
101
+
102
+ ```
103
+ sudo docker run -d --gpus=all -p 7860:7860 registry.gitlab.com/aadnk/whisper-webui:latest
104
+ ```
105
+
106
+ ## Custom Arguments
107
+
108
+ You can also pass custom arguments to `app.py` in the Docker container, for instance to be able to use all the GPUs in parallel:
109
+ ```
110
+ sudo docker run -d --gpus all -p 7860:7860 \
111
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
112
+ --restart=on-failure:15 registry.gitlab.com/aadnk/whisper-webui:latest \
113
+ app.py --input_audio_max_duration -1 --server_name 0.0.0.0 --vad_parallel_devices 0,1 \
114
+ --default_vad silero-vad --default_model_name large
115
+ ```
116
+
117
+ You can also call `cli.py` the same way:
118
+ ```
119
+ sudo docker run --gpus all \
120
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
121
+ --mount type=bind,source=${PWD},target=/app/data \
122
+ registry.gitlab.com/aadnk/whisper-webui:latest \
123
+ cli.py --model large --vad_parallel_devices 0,1 --vad silero-vad \
124
+ --output_dir /app/data /app/data/YOUR-FILE-HERE.mp4
125
+ ```
126
+
127
+ ## Caching
128
+
129
+ Note that the models themselves are currently not included in the Docker images, and will be downloaded on the demand.
130
+ To avoid this, bind the directory /root/.cache/whisper to some directory on the host (for instance /home/administrator/.cache/whisper), where you can (optionally)
131
+ prepopulate the directory with the different Whisper models.
132
+ ```
133
+ sudo docker run -d --gpus=all -p 7860:7860 \
134
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
135
+ registry.gitlab.com/aadnk/whisper-webui:latest
136
+ ```
app-local.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ create_ui(-1)
app-network.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Run the app with no audio file restrictions, and make it available on the network
2
+ from app import create_ui
3
+ create_ui(-1, server_name="0.0.0.0")
app-shared.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ create_ui(-1, share=True)
app.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Iterator
3
+ import argparse
4
+
5
+ from io import StringIO
6
+ import os
7
+ import pathlib
8
+ import tempfile
9
+
10
+ import torch
11
+ from src.modelCache import ModelCache
12
+ from src.vadParallel import ParallelContext, ParallelTranscription
13
+
14
+ # External programs
15
+ import ffmpeg
16
+
17
+ # UI
18
+ import gradio as gr
19
+
20
+ from src.download import ExceededMaximumDuration, download_url
21
+ from src.utils import slugify, write_srt, write_vtt
22
+ from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
23
+ from src.whisperContainer import WhisperContainer
24
+
25
+ # Limitations (set to -1 to disable)
26
+ DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
27
+
28
+ # Whether or not to automatically delete all uploaded files, to save disk space
29
+ DELETE_UPLOADED_FILES = True
30
+
31
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
32
+ MAX_FILE_PREFIX_LENGTH = 17
33
+
34
+ # Limit auto_parallel to a certain number of CPUs (specify vad_cpu_cores to get a higher number)
35
+ MAX_AUTO_CPU_CORES = 8
36
+
37
+ LANGUAGES = [
38
+ "English", "Chinese", "German", "Spanish", "Russian", "Korean",
39
+ "French", "Japanese", "Portuguese", "Turkish", "Polish", "Catalan",
40
+ "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi",
41
+ "Finnish", "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay",
42
+ "Czech", "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
43
+ "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin",
44
+ "Maori", "Malayalam", "Welsh", "Slovak", "Telugu", "Persian",
45
+ "Latvian", "Bengali", "Serbian", "Azerbaijani", "Slovenian",
46
+ "Kannada", "Estonian", "Macedonian", "Breton", "Basque", "Icelandic",
47
+ "Armenian", "Nepali", "Mongolian", "Bosnian", "Kazakh", "Albanian",
48
+ "Swahili", "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
49
+ "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan", "Georgian",
50
+ "Belarusian", "Tajik", "Sindhi", "Gujarati", "Amharic", "Yiddish",
51
+ "Lao", "Uzbek", "Faroese", "Haitian Creole", "Pashto", "Turkmen",
52
+ "Nynorsk", "Maltese", "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan",
53
+ "Tagalog", "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
54
+ "Hausa", "Bashkir", "Javanese", "Sundanese"
55
+ ]
56
+
57
+ class WhisperTranscriber:
58
+ def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None, vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES):
59
+ self.model_cache = ModelCache()
60
+ self.parallel_device_list = None
61
+ self.gpu_parallel_context = None
62
+ self.cpu_parallel_context = None
63
+ self.vad_process_timeout = vad_process_timeout
64
+ self.vad_cpu_cores = vad_cpu_cores
65
+
66
+ self.vad_model = None
67
+ self.inputAudioMaxDuration = input_audio_max_duration
68
+ self.deleteUploadedFiles = delete_uploaded_files
69
+
70
+ def set_parallel_devices(self, vad_parallel_devices: str):
71
+ self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
72
+
73
+ def set_auto_parallel(self, auto_parallel: bool):
74
+ if auto_parallel:
75
+ if torch.cuda.is_available():
76
+ self.parallel_device_list = [ str(gpu_id) for gpu_id in range(torch.cuda.device_count())]
77
+
78
+ self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
79
+ print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
80
+
81
+ def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
82
+ try:
83
+ source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
84
+
85
+ try:
86
+ selectedLanguage = languageName.lower() if len(languageName) > 0 else None
87
+ selectedModel = modelName if modelName is not None else "base"
88
+
89
+ model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)
90
+
91
+ # Execute whisper
92
+ result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
93
+
94
+ # Write result
95
+ downloadDirectory = tempfile.mkdtemp()
96
+
97
+ filePrefix = slugify(sourceName, allow_unicode=True)
98
+ download, text, vtt = self.write_result(result, filePrefix, downloadDirectory)
99
+
100
+ return download, text, vtt
101
+
102
+ finally:
103
+ # Cleanup source
104
+ if self.deleteUploadedFiles:
105
+ print("Deleting source file " + source)
106
+ os.remove(source)
107
+
108
+ except ExceededMaximumDuration as e:
109
+ return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
110
+
111
+ def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
112
+ vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
113
+
114
+ initial_prompt = decodeOptions.pop('initial_prompt', None)
115
+
116
+ if ('task' in decodeOptions):
117
+ task = decodeOptions.pop('task')
118
+
119
+ # Callable for processing an audio file
120
+ whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
121
+
122
+ # The results
123
+ if (vad == 'silero-vad'):
124
+ # Silero VAD where non-speech gaps are transcribed
125
+ process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
126
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps)
127
+ elif (vad == 'silero-vad-skip-gaps'):
128
+ # Silero VAD where non-speech gaps are simply ignored
129
+ skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
130
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps)
131
+ elif (vad == 'silero-vad-expand-into-gaps'):
132
+ # Use Silero VAD where speech-segments are expanded into non-speech gaps
133
+ expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
134
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps)
135
+ elif (vad == 'periodic-vad'):
136
+ # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
137
+ # it may create a break in the middle of a sentence, causing some artifacts.
138
+ periodic_vad = VadPeriodicTranscription()
139
+ period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
140
+ result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
141
+
142
+ else:
143
+ if (self._has_parallel_devices()):
144
+ # Use a simple period transcription instead, as we need to use the parallel context
145
+ periodic_vad = VadPeriodicTranscription()
146
+ period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
147
+
148
+ result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
149
+ else:
150
+ # Default VAD
151
+ result = whisperCallable(audio_path, 0, None, None)
152
+
153
+ return result
154
+
155
+ def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig):
156
+ if (not self._has_parallel_devices()):
157
+ # No parallel devices, so just run the VAD and Whisper in sequence
158
+ return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
159
+
160
+ gpu_devices = self.parallel_device_list
161
+
162
+ if (gpu_devices is None or len(gpu_devices) == 0):
163
+ # No GPU devices specified, pass the current environment variable to the first GPU process. This may be NULL.
164
+ gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]
165
+
166
+ # Create parallel context if needed
167
+ if (self.gpu_parallel_context is None):
168
+ # Create a context wih processes and automatically clear the pool after 1 hour of inactivity
169
+ self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
170
+ # We also need a CPU context for the VAD
171
+ if (self.cpu_parallel_context is None):
172
+ self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)
173
+
174
+ parallel_vad = ParallelTranscription()
175
+ return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
176
+ config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
177
+ cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context)
178
+
179
+ def _has_parallel_devices(self):
180
+ return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
181
+
182
+ def _concat_prompt(self, prompt1, prompt2):
183
+ if (prompt1 is None):
184
+ return prompt2
185
+ elif (prompt2 is None):
186
+ return prompt1
187
+ else:
188
+ return prompt1 + " " + prompt2
189
+
190
+ def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
191
+ # Use Silero VAD
192
+ if (self.vad_model is None):
193
+ self.vad_model = VadSileroTranscription()
194
+
195
+ config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
196
+ max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
197
+ segment_padding_left=vadPadding, segment_padding_right=vadPadding,
198
+ max_prompt_window=vadPromptWindow)
199
+
200
+ return config
201
+
202
+ def write_result(self, result: dict, source_name: str, output_dir: str):
203
+ if not os.path.exists(output_dir):
204
+ os.makedirs(output_dir)
205
+
206
+ text = result["text"]
207
+ language = result["language"]
208
+ languageMaxLineWidth = self.__get_max_line_width(language)
209
+
210
+ print("Max line width " + str(languageMaxLineWidth))
211
+ vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
212
+ srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
213
+
214
+ output_files = []
215
+ output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
216
+ output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
217
+ output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
218
+
219
+ return output_files, text, vtt
220
+
221
+ def clear_cache(self):
222
+ self.model_cache.clear()
223
+ self.vad_model = None
224
+
225
+ def __get_source(self, urlData, uploadFile, microphoneData):
226
+ if urlData:
227
+ # Download from YouTube
228
+ source = download_url(urlData, self.inputAudioMaxDuration)[0]
229
+ else:
230
+ # File input
231
+ source = uploadFile if uploadFile is not None else microphoneData
232
+
233
+ if self.inputAudioMaxDuration > 0:
234
+ # Calculate audio length
235
+ audioDuration = ffmpeg.probe(source)["format"]["duration"]
236
+
237
+ if float(audioDuration) > self.inputAudioMaxDuration:
238
+ raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")
239
+
240
+ file_path = pathlib.Path(source)
241
+ sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix
242
+
243
+ return source, sourceName
244
+
245
+ def __get_max_line_width(self, language: str) -> int:
246
+ if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
247
+ # Chinese characters and kana are wider, so limit line length to 40 characters
248
+ return 40
249
+ else:
250
+ # TODO: Add more languages
251
+ # 80 latin characters should fit on a 1080p/720p screen
252
+ return 80
253
+
254
+ def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
255
+ segmentStream = StringIO()
256
+
257
+ if format == 'vtt':
258
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
259
+ elif format == 'srt':
260
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
261
+ else:
262
+ raise Exception("Unknown format " + format)
263
+
264
+ segmentStream.seek(0)
265
+ return segmentStream.read()
266
+
267
+ def __create_file(self, text: str, directory: str, fileName: str) -> str:
268
+ # Write the text to a file
269
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
270
+ file.write(text)
271
+
272
+ return file.name
273
+
274
+ def close(self):
275
+ self.clear_cache()
276
+
277
+ if (self.gpu_parallel_context is not None):
278
+ self.gpu_parallel_context.close()
279
+ if (self.cpu_parallel_context is not None):
280
+ self.cpu_parallel_context.close()
281
+
282
+
283
+ def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
284
+ default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None, vad_process_timeout: float = None, vad_cpu_cores: int = 1, auto_parallel: bool = False):
285
+ ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores)
286
+
287
+ # Specify a list of devices to use for parallel processing
288
+ ui.set_parallel_devices(vad_parallel_devices)
289
+ ui.set_auto_parallel(auto_parallel)
290
+
291
+ ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
292
+ ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
293
+ ui_description += " as well as speech translation and language identification. "
294
+
295
+ ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
296
+
297
+ if input_audio_max_duration > 0:
298
+ ui_description += "\n\n" + "Max audio file length: " + str(input_audio_max_duration) + " s"
299
+
300
+ ui_article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"
301
+
302
+ demo = gr.Interface(fn=ui.transcribe_webui, description=ui_description, article=ui_article, inputs=[
303
+ gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value=default_model_name, label="Model"),
304
+ gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
305
+ gr.Text(label="URL (YouTube, etc.)"),
306
+ gr.Audio(source="upload", type="filepath", label="Upload Audio"),
307
+ gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
308
+ gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
309
+ gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=default_vad, label="VAD"),
310
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
311
+ gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
312
+ gr.Number(label="VAD - Padding (s)", precision=None, value=1),
313
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
314
+ ], outputs=[
315
+ gr.File(label="Download"),
316
+ gr.Text(label="Transcription"),
317
+ gr.Text(label="Segments")
318
+ ])
319
+
320
+ demo.launch(share=share, server_name=server_name, server_port=server_port)
321
+
322
+ # Clean up
323
+ ui.close()
324
+
325
+ if __name__ == '__main__':
326
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
327
+ parser.add_argument("--input_audio_max_duration", type=int, default=DEFAULT_INPUT_AUDIO_MAX_DURATION, help="Maximum audio file length in seconds, or -1 for no limit.")
328
+ parser.add_argument("--share", type=bool, default=False, help="True to share the app on HuggingFace.")
329
+ parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
330
+ parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
331
+ parser.add_argument("--default_model_name", type=str, default="medium", help="The default model name.")
332
+ parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
333
+ parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
334
+ parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
335
+ parser.add_argument("--vad_process_timeout", type=float, default="1800", help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
336
+ parser.add_argument("--auto_parallel", type=bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")
337
+
338
+ args = parser.parse_args().__dict__
339
+ create_ui(**args)
cli.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ from urllib.parse import urlparse
5
+ import warnings
6
+ import numpy as np
7
+
8
+ import whisper
9
+
10
+ import torch
11
+ from app import LANGUAGES, WhisperTranscriber
12
+ from src.download import download_url
13
+
14
+ from src.utils import optional_float, optional_int, str2bool
15
+ from src.whisperContainer import WhisperContainer
16
+
17
+
18
+ def cli():
19
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20
+ parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
21
+ parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large"], help="name of the Whisper model to use")
22
+ parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
23
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
24
+ parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
25
+ parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
26
+
27
+ parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
28
+ parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), help="language spoken in the audio, specify None to perform language detection")
29
+
30
+ parser.add_argument("--vad", type=str, default="none", choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], help="The voice activity detection algorithm to use")
31
+ parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
32
+ parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
33
+ parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
34
+ parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
35
+ parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
36
+ parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
37
+ parser.add_argument("--auto_parallel", type=bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")
38
+
39
+ parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
40
+ parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
41
+ parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
42
+ parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
43
+ parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
44
+
45
+ parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
46
+ parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
47
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
48
+ parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
49
+
50
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
51
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
52
+ parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
53
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
54
+
55
+ args = parser.parse_args().__dict__
56
+ model_name: str = args.pop("model")
57
+ model_dir: str = args.pop("model_dir")
58
+ output_dir: str = args.pop("output_dir")
59
+ device: str = args.pop("device")
60
+ os.makedirs(output_dir, exist_ok=True)
61
+
62
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
63
+ warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
64
+ args["language"] = "en"
65
+
66
+ temperature = args.pop("temperature")
67
+ temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
68
+ if temperature_increment_on_fallback is not None:
69
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
70
+ else:
71
+ temperature = [temperature]
72
+
73
+ vad = args.pop("vad")
74
+ vad_merge_window = args.pop("vad_merge_window")
75
+ vad_max_merge_size = args.pop("vad_max_merge_size")
76
+ vad_padding = args.pop("vad_padding")
77
+ vad_prompt_window = args.pop("vad_prompt_window")
78
+ vad_cpu_cores = args.pop("vad_cpu_cores")
79
+ auto_parallel = args.pop("auto_parallel")
80
+
81
+ model = WhisperContainer(model_name, device=device, download_root=model_dir)
82
+ transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores)
83
+ transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
84
+ transcriber.set_auto_parallel(auto_parallel)
85
+
86
+ if (transcriber._has_parallel_devices()):
87
+ print("Using parallel devices:", transcriber.parallel_device_list)
88
+
89
+ for audio_path in args.pop("audio"):
90
+ sources = []
91
+
92
+ # Detect URL and download the audio
93
+ if (uri_validator(audio_path)):
94
+ # Download from YouTube/URL directly
95
+ for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
96
+ source_name = os.path.basename(source_path)
97
+ sources.append({ "path": source_path, "name": source_name })
98
+ else:
99
+ sources.append({ "path": audio_path, "name": os.path.basename(audio_path) })
100
+
101
+ for source in sources:
102
+ source_path = source["path"]
103
+ source_name = source["name"]
104
+
105
+ result = transcriber.transcribe_file(model, source_path, temperature=temperature,
106
+ vad=vad, vadMergeWindow=vad_merge_window, vadMaxMergeSize=vad_max_merge_size,
107
+ vadPadding=vad_padding, vadPromptWindow=vad_prompt_window, **args)
108
+
109
+ transcriber.write_result(result, source_name, output_dir)
110
+
111
+ transcriber.close()
112
+
113
+ def uri_validator(x):
114
+ try:
115
+ result = urlparse(x)
116
+ return all([result.scheme, result.netloc])
117
+ except:
118
+ return False
119
+
120
+ if __name__ == '__main__':
121
+ cli()
dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM huggingface/transformers-pytorch-gpu
2
+ EXPOSE 7860
3
+
4
+ ADD . /opt/whisper-webui/
5
+
6
+ # Latest version of transformers-pytorch-gpu seems to lack tk.
7
+ # Further, pip install fails, so we must upgrade pip first.
8
+ RUN apt-get -y install python3-tk
9
+ RUN python3 -m pip install --upgrade pip &&\
10
+ python3 -m pip install -r /opt/whisper-webui/requirements.txt
11
+
12
+ # Note: Models will be downloaded on demand to the directory /root/.cache/whisper.
13
+ # You can also bind this directory in the container to somewhere on the host.
14
+
15
+ # To be able to see logs in real time
16
+ ENV PYTHONUNBUFFERED=1
17
+
18
+ WORKDIR /opt/whisper-webui/
19
+ ENTRYPOINT ["python3"]
20
+ CMD ["app.py", "--input_audio_max_duration", "-1", "--server_name", "0.0.0.0", "--auto_parallel", "True"]
docs/options.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Options
2
+ To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
3
+ supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
4
+ in the file selector to select any file type, including video files) or use the microphone.
5
+
6
+ For longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option.
7
+
8
+ ## Model
9
+ Select the model that Whisper will use to transcribe the audio:
10
+
11
+ | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
12
+ |--------|------------|--------------------|--------------------|---------------|----------------|
13
+ | tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x |
14
+ | base | 74 M | base.en | base | ~1 GB | ~16x |
15
+ | small | 244 M | small.en | small | ~2 GB | ~6x |
16
+ | medium | 769 M | medium.en | medium | ~5 GB | ~2x |
17
+ | large | 1550 M | N/A | large | ~10 GB | 1x |
18
+
19
+ ## Language
20
+
21
+ Select the language, or leave it empty for Whisper to automatically detect it.
22
+
23
+ Note that if the selected language and the language in the audio differs, Whisper may start to translate the audio to the selected
24
+ language. For instance, if the audio is in English but you select Japaneese, the model may translate the audio to Japanese.
25
+
26
+ ## Inputs
27
+ The options "URL (YouTube, etc.)", "Upload Audio" or "Micriphone Input" allows you to send an audio input to the model.
28
+
29
+ Note that the UI will only process the first valid input - i.e. if you enter both an URL and upload an audio, it will only process
30
+ the URL.
31
+
32
+ ## Task
33
+ Select the task - either "transcribe" to transcribe the audio to text, or "translate" to translate it to English.
34
+
35
+ ## Vad
36
+ Using a VAD will improve the timing accuracy of each transcribed line, as well as prevent Whisper getting into an infinite
37
+ loop detecting the same sentence over and over again. The downside is that this may be at a cost to text accuracy, especially
38
+ with regards to unique words or names that appear in the audio. You can compensate for this by increasing the prompt window.
39
+
40
+ Note that English is very well handled by Whisper, and it's less susceptible to issues surrounding bad timings and infinite loops.
41
+ So you may only need to use a VAD for other languages, such as Japanese, or when the audio is very long.
42
+
43
+ * none
44
+ * Run whisper on the entire audio input
45
+ * silero-vad
46
+ * Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Whisper is also run
47
+ on the gaps between each speech section, by either expanding the section up to the max merge size, or running Whisper independently
48
+ on the non-speech section.
49
+ * silero-vad-expand-into-gaps
50
+ * Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Each spech section will be expanded
51
+ such that they cover any adjacent non-speech sections. For instance, if an audio file of one minute contains the speech sections
52
+ 00:00 - 00:10 (A) and 00:30 - 00:40 (B), the first section (A) will be expanded to 00:00 - 00:30, and (B) will be expanded to 00:30 - 00:60.
53
+ * silero-vad-skip-gaps
54
+ * As above, but sections that doesn't contain speech according to Silero will be skipped. This will be slightly faster, but
55
+ may cause dialogue to be skipped.
56
+ * periodic-vad
57
+ * Create sections of speech every 'VAD - Max Merge Size' seconds. This is very fast and simple, but will potentially break
58
+ a sentence or word in two.
59
+
60
+ ## VAD - Merge Window
61
+ If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged.
62
+
63
+ ## VAD - Max Merge Size (s)
64
+ Disables merging of adjacent speech sections if they are this number of seconds long.
65
+
66
+ ## VAD - Padding (s)
67
+ The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number
68
+ larger than zero ensures that Whisper is more likely to correctly transcribe a sentence in the beginning of
69
+ a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp
70
+ to each transcribed line. The default value is 1 second.
71
+
72
+ ## VAD - Prompt Window (s)
73
+ The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this
74
+ number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at
75
+ 10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds).
76
+
77
+ Note that detected lines in gaps between speech sections will not be included in the prompt
78
+ (if silero-vad or silero-vad-expand-into-gaps) is used.
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ git+https://github.com/openai/whisper.git
2
+ transformers
3
+ ffmpeg-python==0.2.0
4
+ gradio
5
+ yt-dlp
6
+ torchaudio
src/__init__.py ADDED
File without changes
src/download.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tempfile import mkdtemp
2
+ from typing import List
3
+ from yt_dlp import YoutubeDL
4
+
5
+ import yt_dlp
6
+ from yt_dlp.postprocessor import PostProcessor
7
+
8
+ class FilenameCollectorPP(PostProcessor):
9
+ def __init__(self):
10
+ super(FilenameCollectorPP, self).__init__(None)
11
+ self.filenames = []
12
+
13
+ def run(self, information):
14
+ self.filenames.append(information["filepath"])
15
+ return [], information
16
+
17
+ def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
18
+ try:
19
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
20
+ except yt_dlp.utils.DownloadError as e:
21
+ # In case of an OS error, try again with a different output template
22
+ if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
23
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
24
+ pass
25
+
26
+ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
27
+ # Create a temporary directory to store the downloaded files
28
+ if destinationDirectory is None:
29
+ destinationDirectory = mkdtemp()
30
+
31
+ ydl_opts = {
32
+ "format": "bestaudio/best",
33
+ 'paths': {
34
+ 'home': destinationDirectory
35
+ }
36
+ }
37
+ if (playlistItems):
38
+ ydl_opts['playlist_items'] = playlistItems
39
+
40
+ # Add output template if specified
41
+ if outputTemplate:
42
+ ydl_opts['outtmpl'] = outputTemplate
43
+
44
+ filename_collector = FilenameCollectorPP()
45
+
46
+ with YoutubeDL(ydl_opts) as ydl:
47
+ if maxDuration and maxDuration > 0:
48
+ info = ydl.extract_info(url, download=False)
49
+ duration = info['duration']
50
+
51
+ if duration >= maxDuration:
52
+ raise ExceededMaximumDuration(videoDuration=duration, maxDuration=maxDuration, message="Video is too long")
53
+
54
+ ydl.add_post_processor(filename_collector)
55
+ ydl.download([url])
56
+
57
+ if len(filename_collector.filenames) <= 0:
58
+ raise Exception("Cannot download " + url)
59
+
60
+ result = []
61
+
62
+ for filename in filename_collector.filenames:
63
+ result.append(filename)
64
+ print("Downloaded " + filename)
65
+
66
+ return result
67
+
68
+ class ExceededMaximumDuration(Exception):
69
+ def __init__(self, videoDuration, maxDuration, message):
70
+ self.videoDuration = videoDuration
71
+ self.maxDuration = maxDuration
72
+ super().__init__(message)
src/modelCache.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ModelCache:
2
+ def __init__(self):
3
+ self._cache = dict()
4
+
5
+ def get(self, model_key: str, model_factory):
6
+ result = self._cache.get(model_key)
7
+
8
+ if result is None:
9
+ result = model_factory()
10
+ self._cache[model_key] = result
11
+ return result
12
+
13
+ def clear(self):
14
+ self._cache.clear()
15
+
16
+ # A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
17
+ GLOBAL_MODEL_CACHE = ModelCache()
src/segments.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ import copy
4
+
5
+ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
6
+ result = []
7
+
8
+ if len(timestamps) == 0:
9
+ return result
10
+ if max_merge_size is None:
11
+ return timestamps
12
+
13
+ if padding_left is None:
14
+ padding_left = 0
15
+ if padding_right is None:
16
+ padding_right = 0
17
+
18
+ processed_time = 0
19
+ current_segment = None
20
+
21
+ for i in range(len(timestamps)):
22
+ next_segment = timestamps[i]
23
+
24
+ delta = next_segment['start'] - processed_time
25
+
26
+ # Note that segments can still be longer than the max merge size, they just won't be merged in that case
27
+ if current_segment is None or (merge_window is not None and delta > merge_window) \
28
+ or next_segment['end'] - current_segment['start'] > max_merge_size:
29
+ # Finish the current segment
30
+ if current_segment is not None:
31
+ # Add right padding
32
+ finish_padding = min(padding_right, delta / 2) if delta < padding_left + padding_right else padding_right
33
+ current_segment['end'] += finish_padding
34
+ delta -= finish_padding
35
+
36
+ result.append(current_segment)
37
+
38
+ # Start a new segment
39
+ current_segment = copy.deepcopy(next_segment)
40
+
41
+ # Pad the segment
42
+ current_segment['start'] = current_segment['start'] - min(padding_left, delta)
43
+ processed_time = current_segment['end']
44
+
45
+ else:
46
+ # Merge the segment
47
+ current_segment['end'] = next_segment['end']
48
+ processed_time = current_segment['end']
49
+
50
+ # Add the last segment
51
+ if current_segment is not None:
52
+ current_segment['end'] += padding_right
53
+ result.append(current_segment)
54
+
55
+ return result
src/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import textwrap
2
+ import unicodedata
3
+ import re
4
+
5
+ import zlib
6
+ from typing import Iterator, TextIO
7
+
8
+
9
+ def exact_div(x, y):
10
+ assert x % y == 0
11
+ return x // y
12
+
13
+
14
+ def str2bool(string):
15
+ str2val = {"True": True, "False": False}
16
+ if string in str2val:
17
+ return str2val[string]
18
+ else:
19
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
20
+
21
+
22
+ def optional_int(string):
23
+ return None if string == "None" else int(string)
24
+
25
+
26
+ def optional_float(string):
27
+ return None if string == "None" else float(string)
28
+
29
+
30
+ def compression_ratio(text) -> float:
31
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
32
+
33
+
34
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
35
+ assert seconds >= 0, "non-negative timestamp expected"
36
+ milliseconds = round(seconds * 1000.0)
37
+
38
+ hours = milliseconds // 3_600_000
39
+ milliseconds -= hours * 3_600_000
40
+
41
+ minutes = milliseconds // 60_000
42
+ milliseconds -= minutes * 60_000
43
+
44
+ seconds = milliseconds // 1_000
45
+ milliseconds -= seconds * 1_000
46
+
47
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
48
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
49
+
50
+
51
+ def write_txt(transcript: Iterator[dict], file: TextIO):
52
+ for segment in transcript:
53
+ print(segment['text'].strip(), file=file, flush=True)
54
+
55
+
56
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
57
+ print("WEBVTT\n", file=file)
58
+ for segment in transcript:
59
+ text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
60
+
61
+ print(
62
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
63
+ f"{text}\n",
64
+ file=file,
65
+ flush=True,
66
+ )
67
+
68
+
69
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
70
+ """
71
+ Write a transcript to a file in SRT format.
72
+ Example usage:
73
+ from pathlib import Path
74
+ from whisper.utils import write_srt
75
+ result = transcribe(model, audio_path, temperature=temperature, **args)
76
+ # save SRT
77
+ audio_basename = Path(audio_path).stem
78
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
79
+ write_srt(result["segments"], file=srt)
80
+ """
81
+ for i, segment in enumerate(transcript, start=1):
82
+ text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
83
+
84
+ # write srt lines
85
+ print(
86
+ f"{i}\n"
87
+ f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
88
+ f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
89
+ f"{text}\n",
90
+ file=file,
91
+ flush=True,
92
+ )
93
+
94
+ def process_text(text: str, maxLineWidth=None):
95
+ if (maxLineWidth is None or maxLineWidth < 0):
96
+ return text
97
+
98
+ lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
99
+ return '\n'.join(lines)
100
+
101
+ def slugify(value, allow_unicode=False):
102
+ """
103
+ Taken from https://github.com/django/django/blob/master/django/utils/text.py
104
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
105
+ dashes to single dashes. Remove characters that aren't alphanumerics,
106
+ underscores, or hyphens. Convert to lowercase. Also strip leading and
107
+ trailing whitespace, dashes, and underscores.
108
+ """
109
+ value = str(value)
110
+ if allow_unicode:
111
+ value = unicodedata.normalize('NFKC', value)
112
+ else:
113
+ value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
114
+ value = re.sub(r'[^\w\s-]', '', value.lower())
115
+ return re.sub(r'[-\s]+', '-', value).strip('-_')
src/vad.py ADDED
@@ -0,0 +1,527 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from collections import Counter, deque
3
+ import time
4
+
5
+ from typing import Any, Deque, Iterator, List, Dict
6
+
7
+ from pprint import pprint
8
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
9
+
10
+ from src.segments import merge_timestamps
11
+ from src.whisperContainer import WhisperCallback
12
+
13
+ # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
14
+ try:
15
+ import tensorflow as tf
16
+ except ModuleNotFoundError:
17
+ # Error handling
18
+ pass
19
+
20
+ import torch
21
+
22
+ import ffmpeg
23
+ import numpy as np
24
+
25
+ from src.utils import format_timestamp
26
+ from enum import Enum
27
+
28
+ class NonSpeechStrategy(Enum):
29
+ """
30
+ Ignore non-speech frames segments.
31
+ """
32
+ SKIP = 1
33
+ """
34
+ Just treat non-speech segments as speech.
35
+ """
36
+ CREATE_SEGMENT = 2
37
+ """
38
+ Expand speech segments into subsequent non-speech segments.
39
+ """
40
+ EXPAND_SEGMENT = 3
41
+
42
+ # Defaults for Silero
43
+ SPEECH_TRESHOLD = 0.3
44
+
45
+ # Minimum size of segments to process
46
+ MIN_SEGMENT_DURATION = 1
47
+
48
+ # The maximum time for texts from old segments to be used in the next segment
49
+ MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
50
+ PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
51
+
52
+ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
53
+
54
+ class TranscriptionConfig(ABC):
55
+ def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
56
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
57
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
58
+ self.non_speech_strategy = non_speech_strategy
59
+ self.segment_padding_left = segment_padding_left
60
+ self.segment_padding_right = segment_padding_right
61
+ self.max_silent_period = max_silent_period
62
+ self.max_merge_size = max_merge_size
63
+ self.max_prompt_window = max_prompt_window
64
+ self.initial_segment_index = initial_segment_index
65
+
66
+ class PeriodicTranscriptionConfig(TranscriptionConfig):
67
+ def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
68
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
69
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
70
+ super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
71
+ self.periodic_duration = periodic_duration
72
+
73
+ class AbstractTranscription(ABC):
74
+ def __init__(self, sampling_rate: int = 16000):
75
+ self.sampling_rate = sampling_rate
76
+
77
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
78
+ return load_audio(str, self.sampling_rate, start_time, duration)
79
+
80
+ @abstractmethod
81
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
82
+ """
83
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method.
84
+
85
+ Parameters
86
+ ----------
87
+ audio: str
88
+ The audio file.
89
+ config: TranscriptionConfig
90
+ The transcription configuration.
91
+
92
+ Returns
93
+ -------
94
+ A list of start and end timestamps, in fractional seconds.
95
+ """
96
+ return
97
+
98
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
99
+ """
100
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method,
101
+ after merging the given segments using the specified configuration.
102
+
103
+ Parameters
104
+ ----------
105
+ audio: str
106
+ The audio file.
107
+ config: TranscriptionConfig
108
+ The transcription configuration.
109
+
110
+ Returns
111
+ -------
112
+ A list of start and end timestamps, in fractional seconds.
113
+ """
114
+ merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
115
+ config.segment_padding_left, config.segment_padding_right)
116
+
117
+ if config.non_speech_strategy != NonSpeechStrategy.SKIP:
118
+ # Expand segments to include the gaps between them
119
+ if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
120
+ # When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
121
+ merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
122
+ elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
123
+ # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
124
+ merged = self.expand_gaps(merged, total_duration=total_duration)
125
+ else:
126
+ raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
127
+
128
+ print("Transcribing non-speech:")
129
+ pprint(merged)
130
+ return merged
131
+
132
+ def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig):
133
+ """
134
+ Transcribe the given audo file.
135
+
136
+ Parameters
137
+ ----------
138
+ audio: str
139
+ The audio file.
140
+ whisperCallable: WhisperCallback
141
+ A callback object to call to transcribe each segment.
142
+
143
+ Returns
144
+ -------
145
+ A list of start and end timestamps, in fractional seconds.
146
+ """
147
+
148
+ max_audio_duration = get_audio_duration(audio)
149
+ timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
150
+
151
+ # Get speech timestamps from full audio file
152
+ merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
153
+
154
+ # A deque of transcribed segments that is passed to the next segment as a prompt
155
+ prompt_window = deque()
156
+
157
+ print("Processing timestamps:")
158
+ pprint(merged)
159
+
160
+ result = {
161
+ 'text': "",
162
+ 'segments': [],
163
+ 'language': ""
164
+ }
165
+ languageCounter = Counter()
166
+ detected_language = None
167
+
168
+ segment_index = config.initial_segment_index
169
+
170
+ # For each time segment, run whisper
171
+ for segment in merged:
172
+ segment_index += 1
173
+ segment_start = segment['start']
174
+ segment_end = segment['end']
175
+ segment_expand_amount = segment.get('expand_amount', 0)
176
+ segment_gap = segment.get('gap', False)
177
+
178
+ segment_duration = segment_end - segment_start
179
+
180
+ if segment_duration < MIN_SEGMENT_DURATION:
181
+ continue;
182
+
183
+ # Audio to run on Whisper
184
+ segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
185
+ # Previous segments to use as a prompt
186
+ segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
187
+
188
+ # Detected language
189
+ detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
190
+
191
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
192
+ segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
193
+ segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language)
194
+
195
+ adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
196
+
197
+ # Propagate expand amount to the segments
198
+ if (segment_expand_amount > 0):
199
+ segment_without_expansion = segment_duration - segment_expand_amount
200
+
201
+ for adjusted_segment in adjusted_segments:
202
+ adjusted_segment_end = adjusted_segment['end']
203
+
204
+ # Add expand amount if the segment got expanded
205
+ if (adjusted_segment_end > segment_without_expansion):
206
+ adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
207
+
208
+ # Append to output
209
+ result['text'] += segment_result['text']
210
+ result['segments'].extend(adjusted_segments)
211
+
212
+ # Increment detected language
213
+ if not segment_gap:
214
+ languageCounter[segment_result['language']] += 1
215
+
216
+ # Update prompt window
217
+ self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
218
+
219
+ if detected_language is not None:
220
+ result['language'] = detected_language
221
+
222
+ return result
223
+
224
+ def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
225
+ if (config.max_prompt_window is not None and config.max_prompt_window > 0):
226
+ # Add segments to the current prompt window (unless it is a speech gap)
227
+ if not segment_gap:
228
+ for segment in adjusted_segments:
229
+ if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
230
+ prompt_window.append(segment)
231
+
232
+ while (len(prompt_window) > 0):
233
+ first_end_time = prompt_window[0].get('end', 0)
234
+ # Time expanded in the segments should be discounted from the prompt window
235
+ first_expand_time = prompt_window[0].get('expand_amount', 0)
236
+
237
+ if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
238
+ prompt_window.popleft()
239
+ else:
240
+ break
241
+
242
+ def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
243
+ result = []
244
+ last_end_time = 0
245
+
246
+ for segment in segments:
247
+ segment_start = float(segment['start'])
248
+ segment_end = float(segment['end'])
249
+
250
+ if (last_end_time != segment_start):
251
+ delta = segment_start - last_end_time
252
+
253
+ if (min_gap_length is None or delta >= min_gap_length):
254
+ result.append( { 'start': last_end_time, 'end': segment_start, 'gap': True } )
255
+
256
+ last_end_time = segment_end
257
+ result.append(segment)
258
+
259
+ # Also include total duration if specified
260
+ if (total_duration is not None and last_end_time < total_duration):
261
+ delta = total_duration - segment_start
262
+
263
+ if (min_gap_length is None or delta >= min_gap_length):
264
+ result.append( { 'start': last_end_time, 'end': total_duration, 'gap': True } )
265
+
266
+ return result
267
+
268
+ # Expand the end time of each segment to the start of the next segment
269
+ def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
270
+ result = []
271
+
272
+ if len(segments) == 0:
273
+ return result
274
+
275
+ # Add gap at the beginning if needed
276
+ if (segments[0]['start'] > 0):
277
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
278
+
279
+ for i in range(len(segments) - 1):
280
+ current_segment = segments[i]
281
+ next_segment = segments[i + 1]
282
+
283
+ delta = next_segment['start'] - current_segment['end']
284
+
285
+ # Expand if the gap actually exists
286
+ if (delta >= 0):
287
+ current_segment = current_segment.copy()
288
+ current_segment['expand_amount'] = delta
289
+ current_segment['end'] = next_segment['start']
290
+
291
+ result.append(current_segment)
292
+
293
+ # Add last segment
294
+ last_segment = segments[-1]
295
+ result.append(last_segment)
296
+
297
+ # Also include total duration if specified
298
+ if (total_duration is not None):
299
+ last_segment = result[-1]
300
+
301
+ if (last_segment['end'] < total_duration):
302
+ last_segment = last_segment.copy()
303
+ last_segment['end'] = total_duration
304
+ result[-1] = last_segment
305
+
306
+ return result
307
+
308
+ def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
309
+ result = []
310
+
311
+ if len(segments) == 0:
312
+ return result
313
+
314
+ # Add gap at the beginning if needed
315
+ if (segments[0]['start'] > 0):
316
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
317
+
318
+ for i in range(len(segments) - 1):
319
+ expanded = False
320
+ current_segment = segments[i]
321
+ next_segment = segments[i + 1]
322
+
323
+ delta = next_segment['start'] - current_segment['end']
324
+
325
+ if (max_expand_size is not None and delta <= max_expand_size):
326
+ # Just expand the current segment
327
+ current_segment = current_segment.copy()
328
+ current_segment['expand_amount'] = delta
329
+ current_segment['end'] = next_segment['start']
330
+ expanded = True
331
+
332
+ result.append(current_segment)
333
+
334
+ # Add a gap to the next segment if needed
335
+ if (delta >= 0 and not expanded):
336
+ result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True } )
337
+
338
+ # Add last segment
339
+ last_segment = segments[-1]
340
+ result.append(last_segment)
341
+
342
+ # Also include total duration if specified
343
+ if (total_duration is not None):
344
+ last_segment = result[-1]
345
+
346
+ delta = total_duration - last_segment['end']
347
+
348
+ if (delta > 0):
349
+ if (max_expand_size is not None and delta <= max_expand_size):
350
+ # Expand the last segment
351
+ last_segment = last_segment.copy()
352
+ last_segment['expand_amount'] = delta
353
+ last_segment['end'] = total_duration
354
+ result[-1] = last_segment
355
+ else:
356
+ result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True } )
357
+
358
+ return result
359
+
360
+ def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
361
+ result = []
362
+
363
+ for segment in segments:
364
+ segment_start = float(segment['start'])
365
+ segment_end = float(segment['end'])
366
+
367
+ # Filter segments?
368
+ if (max_source_time is not None):
369
+ if (segment_start > max_source_time):
370
+ continue
371
+ segment_end = min(max_source_time, segment_end)
372
+
373
+ new_segment = segment.copy()
374
+
375
+ # Add to start and end
376
+ new_segment['start'] = segment_start + adjust_seconds
377
+ new_segment['end'] = segment_end + adjust_seconds
378
+ result.append(new_segment)
379
+ return result
380
+
381
+ def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
382
+ result = []
383
+
384
+ for entry in timestamps:
385
+ start = entry['start']
386
+ end = entry['end']
387
+
388
+ result.append({
389
+ 'start': start * factor,
390
+ 'end': end * factor
391
+ })
392
+ return result
393
+
394
+
395
+ class VadSileroTranscription(AbstractTranscription):
396
+ def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
397
+ super().__init__(sampling_rate=sampling_rate)
398
+ self.model = None
399
+ self.cache = cache
400
+ self._initialize_model()
401
+
402
+ def _initialize_model(self):
403
+ if (self.cache is not None):
404
+ model_key = "VadSileroTranscription"
405
+ self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
406
+ print("Loaded Silerio model from cache.")
407
+ else:
408
+ self.model, self.get_speech_timestamps = self._create_model()
409
+ print("Created Silerio model")
410
+
411
+ def _create_model(self):
412
+ model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
413
+
414
+ # Silero does not benefit from multi-threading
415
+ torch.set_num_threads(1) # JIT
416
+ (get_speech_timestamps, _, _, _, _) = utils
417
+
418
+ return model, get_speech_timestamps
419
+
420
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
421
+ result = []
422
+
423
+ print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
424
+ perf_start_time = time.perf_counter()
425
+
426
+ # Divide procesisng of audio into chunks
427
+ chunk_start = start_time
428
+
429
+ while (chunk_start < end_time):
430
+ chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)
431
+
432
+ print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
433
+ wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
434
+
435
+ sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
436
+ seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
437
+ adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_start + chunk_duration)
438
+
439
+ #pprint(adjusted)
440
+
441
+ result.extend(adjusted)
442
+ chunk_start += chunk_duration
443
+
444
+ perf_end_time = time.perf_counter()
445
+ print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))
446
+
447
+ return result
448
+
449
+ def __getstate__(self):
450
+ # We only need the sampling rate
451
+ return { 'sampling_rate': self.sampling_rate }
452
+
453
+ def __setstate__(self, state):
454
+ self.sampling_rate = state['sampling_rate']
455
+ self.model = None
456
+ # Use the global cache
457
+ self.cache = GLOBAL_MODEL_CACHE
458
+ self._initialize_model()
459
+
460
+ # A very simple VAD that just marks every N seconds as speech
461
+ class VadPeriodicTranscription(AbstractTranscription):
462
+ def __init__(self, sampling_rate: int = 16000):
463
+ super().__init__(sampling_rate=sampling_rate)
464
+
465
+ def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
466
+ result = []
467
+
468
+ # Generate a timestamp every N seconds
469
+ start_timestamp = start_time
470
+
471
+ while (start_timestamp < end_time):
472
+ end_timestamp = min(start_timestamp + config.periodic_duration, end_time)
473
+ segment_duration = end_timestamp - start_timestamp
474
+
475
+ # Minimum duration is 1 second
476
+ if (segment_duration >= 1):
477
+ result.append( { 'start': start_timestamp, 'end': end_timestamp } )
478
+
479
+ start_timestamp = end_timestamp
480
+
481
+ return result
482
+
483
+ def get_audio_duration(file: str):
484
+ return float(ffmpeg.probe(file)["format"]["duration"])
485
+
486
+ def load_audio(file: str, sample_rate: int = 16000,
487
+ start_time: str = None, duration: str = None):
488
+ """
489
+ Open an audio file and read as mono waveform, resampling as necessary
490
+
491
+ Parameters
492
+ ----------
493
+ file: str
494
+ The audio file to open
495
+
496
+ sr: int
497
+ The sample rate to resample the audio if necessary
498
+
499
+ start_time: str
500
+ The start time, using the standard FFMPEG time duration syntax, or None to disable.
501
+
502
+ duration: str
503
+ The duration, using the standard FFMPEG time duration syntax, or None to disable.
504
+
505
+ Returns
506
+ -------
507
+ A NumPy array containing the audio waveform, in float32 dtype.
508
+ """
509
+ try:
510
+ inputArgs = {'threads': 0}
511
+
512
+ if (start_time is not None):
513
+ inputArgs['ss'] = start_time
514
+ if (duration is not None):
515
+ inputArgs['t'] = duration
516
+
517
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
518
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
519
+ out, _ = (
520
+ ffmpeg.input(file, **inputArgs)
521
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
522
+ .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
523
+ )
524
+ except ffmpeg.Error as e:
525
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
526
+
527
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
src/vadParallel.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ import threading
3
+ import time
4
+ from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
5
+ from src.whisperContainer import WhisperCallback
6
+
7
+ from multiprocessing import Pool
8
+
9
+ from typing import Any, Dict, List
10
+ import os
11
+
12
+
13
+ class ParallelContext:
14
+ def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
15
+ self.num_processes = num_processes
16
+ self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
17
+ self.lock = threading.Lock()
18
+
19
+ self.ref_count = 0
20
+ self.pool = None
21
+ self.cleanup_timer = None
22
+
23
+ def get_pool(self):
24
+ # Initialize pool lazily
25
+ if (self.pool is None):
26
+ context = multiprocessing.get_context('spawn')
27
+ self.pool = context.Pool(self.num_processes)
28
+
29
+ self.ref_count = self.ref_count + 1
30
+
31
+ if (self.auto_cleanup_timeout_seconds is not None):
32
+ self._stop_auto_cleanup()
33
+
34
+ return self.pool
35
+
36
+ def return_pool(self, pool):
37
+ if (self.pool == pool and self.ref_count > 0):
38
+ self.ref_count = self.ref_count - 1
39
+
40
+ if (self.ref_count == 0):
41
+ if (self.auto_cleanup_timeout_seconds is not None):
42
+ self._start_auto_cleanup()
43
+
44
+ def _start_auto_cleanup(self):
45
+ if (self.cleanup_timer is not None):
46
+ self.cleanup_timer.cancel()
47
+ self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
48
+ self.cleanup_timer.start()
49
+
50
+ print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")
51
+
52
+ def _stop_auto_cleanup(self):
53
+ if (self.cleanup_timer is not None):
54
+ self.cleanup_timer.cancel()
55
+ self.cleanup_timer = None
56
+
57
+ print("Stopped auto cleanup of pool")
58
+
59
+ def _execute_cleanup(self):
60
+ print("Executing cleanup of pool")
61
+
62
+ if (self.ref_count == 0):
63
+ self.close()
64
+
65
+ def close(self):
66
+ self._stop_auto_cleanup()
67
+
68
+ if (self.pool is not None):
69
+ print("Closing pool of " + str(self.num_processes) + " processes")
70
+ self.pool.close()
71
+ self.pool.join()
72
+ self.pool = None
73
+
74
+ class ParallelTranscriptionConfig(TranscriptionConfig):
75
+ def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
76
+ super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
77
+ self.device_id = device_id
78
+ self.override_timestamps = override_timestamps
79
+
80
+ class ParallelTranscription(AbstractTranscription):
81
+ # Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
82
+ # into smaller segments than 2 minute (min 6 seconds per CPU core)
83
+ MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60
84
+
85
+ def __init__(self, sampling_rate: int = 16000):
86
+ super().__init__(sampling_rate=sampling_rate)
87
+
88
+ def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
89
+ cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None):
90
+ total_duration = get_audio_duration(audio)
91
+
92
+ # First, get the timestamps for the original audio
93
+ if (cpu_device_count > 1):
94
+ merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
95
+ else:
96
+ merged = transcription.get_merged_timestamps(audio, config, total_duration)
97
+
98
+ # Split into a list for each device
99
+ # TODO: Split by time instead of by number of chunks
100
+ merged_split = list(self._split(merged, len(gpu_devices)))
101
+
102
+ # Parameters that will be passed to the transcribe function
103
+ parameters = []
104
+ segment_index = config.initial_segment_index
105
+
106
+ for i in range(len(merged_split)):
107
+ device_segment_list = list(merged_split[i])
108
+ device_id = gpu_devices[i]
109
+
110
+ if (len(device_segment_list) <= 0):
111
+ continue
112
+
113
+ print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
114
+
115
+ # Create a new config with the given device ID
116
+ device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
117
+ segment_index += len(device_segment_list)
118
+
119
+ parameters.append([audio, whisperCallable, device_config]);
120
+
121
+ merged = {
122
+ 'text': '',
123
+ 'segments': [],
124
+ 'language': None
125
+ }
126
+
127
+ created_context = False
128
+
129
+ perf_start_gpu = time.perf_counter()
130
+
131
+ # Spawn a separate process for each device
132
+ try:
133
+ if (gpu_parallel_context is None):
134
+ gpu_parallel_context = ParallelContext(len(gpu_devices))
135
+ created_context = True
136
+
137
+ # Get a pool of processes
138
+ pool = gpu_parallel_context.get_pool()
139
+
140
+ # Run the transcription in parallel
141
+ results = pool.starmap(self.transcribe, parameters)
142
+
143
+ for result in results:
144
+ # Merge the results
145
+ if (result['text'] is not None):
146
+ merged['text'] += result['text']
147
+ if (result['segments'] is not None):
148
+ merged['segments'].extend(result['segments'])
149
+ if (result['language'] is not None):
150
+ merged['language'] = result['language']
151
+
152
+ finally:
153
+ # Return the pool to the context
154
+ if (gpu_parallel_context is not None):
155
+ gpu_parallel_context.return_pool(pool)
156
+ # Always close the context if we created it
157
+ if (created_context):
158
+ gpu_parallel_context.close()
159
+
160
+ perf_end_gpu = time.perf_counter()
161
+ print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")
162
+
163
+ return merged
164
+
165
+ def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
166
+ cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
167
+ parameters = []
168
+
169
+ chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
170
+ chunk_start = 0
171
+ cpu_device_id = 0
172
+
173
+ perf_start_time = time.perf_counter()
174
+
175
+ # Create chunks that will be processed on the CPU
176
+ while (chunk_start < total_duration):
177
+ chunk_end = min(chunk_start + chunk_size, total_duration)
178
+
179
+ if (chunk_end - chunk_start < 1):
180
+ # No need to process chunks that are less than 1 second
181
+ break
182
+
183
+ print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
184
+ str(chunk_end) + " on CPU device " + str(cpu_device_id))
185
+ parameters.append([audio, config, chunk_start, chunk_end]);
186
+
187
+ cpu_device_id += 1
188
+ chunk_start = chunk_end
189
+
190
+ created_context = False
191
+
192
+ # Spawn a separate process for each device
193
+ try:
194
+ if (cpu_parallel_context is None):
195
+ cpu_parallel_context = ParallelContext(cpu_device_count)
196
+ created_context = True
197
+
198
+ # Get a pool of processes
199
+ pool = cpu_parallel_context.get_pool()
200
+
201
+ # Run the transcription in parallel. Note that transcription must be picklable.
202
+ results = pool.starmap(transcription.get_transcribe_timestamps, parameters)
203
+
204
+ timestamps = []
205
+
206
+ # Flatten the results
207
+ for result in results:
208
+ timestamps.extend(result)
209
+
210
+ merged = transcription.get_merged_timestamps(timestamps, config, total_duration)
211
+
212
+ perf_end_time = time.perf_counter()
213
+ print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
214
+ return merged
215
+
216
+ finally:
217
+ # Return the pool to the context
218
+ if (cpu_parallel_context is not None):
219
+ cpu_parallel_context.return_pool(pool)
220
+ # Always close the context if we created it
221
+ if (created_context):
222
+ cpu_parallel_context.close()
223
+
224
+ def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
225
+ return []
226
+
227
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
228
+ # Override timestamps that will be processed
229
+ if (config.override_timestamps is not None):
230
+ print("Using override timestamps of size " + str(len(config.override_timestamps)))
231
+ return config.override_timestamps
232
+ return super().get_merged_timestamps(timestamps, config, total_duration)
233
+
234
+ def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
235
+ # Override device ID the first time
236
+ if (os.environ.get("INITIALIZED", None) is None):
237
+ os.environ["INITIALIZED"] = "1"
238
+
239
+ # Note that this may be None if the user didn't specify a device. In that case, Whisper will
240
+ # just use the default GPU device.
241
+ if (config.device_id is not None):
242
+ print("Using device " + config.device_id)
243
+ os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
244
+
245
+ return super().transcribe(audio, whisperCallable, config)
246
+
247
+ def _split(self, a, n):
248
+ """Split a list into n approximately equal parts."""
249
+ k, m = divmod(len(a), n)
250
+ return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
251
+
src/whisperContainer.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # External programs
2
+ import whisper
3
+
4
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
5
+
6
+ class WhisperContainer:
7
+ def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
8
+ self.model_name = model_name
9
+ self.device = device
10
+ self.download_root = download_root
11
+ self.cache = cache
12
+
13
+ # Will be created on demand
14
+ self.model = None
15
+
16
+ def get_model(self):
17
+ if self.model is None:
18
+
19
+ if (self.cache is None):
20
+ self.model = self._create_model()
21
+ else:
22
+ model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
23
+ self.model = self.cache.get(model_key, self._create_model)
24
+ return self.model
25
+
26
+ def _create_model(self):
27
+ print("Loading whisper model " + self.model_name)
28
+ return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
29
+
30
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
31
+ """
32
+ Create a WhisperCallback object that can be used to transcript audio files.
33
+
34
+ Parameters
35
+ ----------
36
+ language: str
37
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
38
+ task: str
39
+ The task - either translate or transcribe.
40
+ initial_prompt: str
41
+ The initial prompt to use for the transcription.
42
+ decodeOptions: dict
43
+ Additional options to pass to the decoder. Must be pickleable.
44
+
45
+ Returns
46
+ -------
47
+ A WhisperCallback object.
48
+ """
49
+ return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
50
+
51
+ # This is required for multiprocessing
52
+ def __getstate__(self):
53
+ return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
54
+
55
+ def __setstate__(self, state):
56
+ self.model_name = state["model_name"]
57
+ self.device = state["device"]
58
+ self.download_root = state["download_root"]
59
+ self.model = None
60
+ # Depickled objects must use the global cache
61
+ self.cache = GLOBAL_MODEL_CACHE
62
+
63
+
64
+ class WhisperCallback:
65
+ def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
66
+ self.model_container = model_container
67
+ self.language = language
68
+ self.task = task
69
+ self.initial_prompt = initial_prompt
70
+ self.decodeOptions = decodeOptions
71
+
72
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
73
+ """
74
+ Peform the transcription of the given audio file or data.
75
+
76
+ Parameters
77
+ ----------
78
+ audio: Union[str, np.ndarray, torch.Tensor]
79
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
80
+ segment_index: int
81
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
82
+ task: str
83
+ The task - either translate or transcribe.
84
+ prompt: str
85
+ The prompt to use for the transcription.
86
+ detected_language: str
87
+ The detected language of the audio file.
88
+
89
+ Returns
90
+ -------
91
+ The result of the Whisper call.
92
+ """
93
+ model = self.model_container.get_model()
94
+
95
+ return model.transcribe(audio, \
96
+ language=self.language if self.language else detected_language, task=self.task, \
97
+ initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
98
+ **self.decodeOptions)
99
+
100
+ def _concat_prompt(self, prompt1, prompt2):
101
+ if (prompt1 is None):
102
+ return prompt2
103
+ elif (prompt2 is None):
104
+ return prompt1
105
+ else:
106
+ return prompt1 + " " + prompt2
tests/segments_test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import unittest
3
+
4
+ sys.path.append('../whisper-webui')
5
+
6
+ from src.segments import merge_timestamps
7
+
8
+ class TestSegments(unittest.TestCase):
9
+ def __init__(self, *args, **kwargs):
10
+ super(TestSegments, self).__init__(*args, **kwargs)
11
+
12
+ def test_merge_segments(self):
13
+ segments = [
14
+ {'start': 10.0, 'end': 20.0},
15
+ {'start': 22.0, 'end': 27.0},
16
+ {'start': 31.0, 'end': 35.0},
17
+ {'start': 45.0, 'end': 60.0},
18
+ {'start': 61.0, 'end': 65.0},
19
+ {'start': 68.0, 'end': 98.0},
20
+ {'start': 100.0, 'end': 102.0},
21
+ {'start': 110.0, 'end': 112.0}
22
+ ]
23
+
24
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
25
+
26
+ self.assertListEqual(result, [
27
+ {'start': 9.0, 'end': 36.0},
28
+ {'start': 44.0, 'end': 66.0},
29
+ {'start': 67.0, 'end': 99.0},
30
+ {'start': 99.0, 'end': 103.0},
31
+ {'start': 109.0, 'end': 113.0}
32
+ ])
33
+
34
+ def test_overlap_next(self):
35
+ segments = [
36
+ {'start': 5.0, 'end': 39.182},
37
+ {'start': 39.986, 'end': 40.814}
38
+ ]
39
+
40
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
41
+
42
+ self.assertListEqual(result, [
43
+ {'start': 4.0, 'end': 39.584},
44
+ {'start': 39.584, 'end': 41.814}
45
+ ])
46
+
47
+ if __name__ == '__main__':
48
+ unittest.main()
tests/vad_test.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ import unittest
3
+ import numpy as np
4
+ import sys
5
+
6
+ sys.path.append('../whisper-webui')
7
+
8
+ from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
9
+
10
+ class TestVad(unittest.TestCase):
11
+ def __init__(self, *args, **kwargs):
12
+ super(TestVad, self).__init__(*args, **kwargs)
13
+ self.transcribe_calls = []
14
+
15
+ def test_transcript(self):
16
+ mock = MockVadTranscription()
17
+
18
+ self.transcribe_calls.clear()
19
+ result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
20
+
21
+ self.assertListEqual(self.transcribe_calls, [
22
+ [30, 30],
23
+ [100, 100]
24
+ ])
25
+
26
+ self.assertListEqual(result['segments'],
27
+ [{'end': 50.0, 'start': 40.0, 'text': 'Hello world '},
28
+ {'end': 120.0, 'start': 110.0, 'text': 'Hello world '}]
29
+ )
30
+
31
+ def transcribe_segments(self, segment):
32
+ self.transcribe_calls.append(segment.tolist())
33
+
34
+ # Dummy text
35
+ return {
36
+ 'text': "Hello world ",
37
+ 'segments': [
38
+ {
39
+ "start": 10.0,
40
+ "end": 20.0,
41
+ "text": "Hello world "
42
+ }
43
+ ],
44
+ 'language': ""
45
+ }
46
+
47
+ class MockVadTranscription(AbstractTranscription):
48
+ def __init__(self):
49
+ super().__init__()
50
+
51
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
52
+ start_time_seconds = float(start_time.removesuffix("s"))
53
+ duration_seconds = float(duration.removesuffix("s"))
54
+
55
+ # For mocking, this just returns a simple numppy array
56
+ return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
57
+
58
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, duration: float):
59
+ result = []
60
+
61
+ result.append( { 'start': 30, 'end': 60 } )
62
+ result.append( { 'start': 100, 'end': 200 } )
63
+ return result
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()