avans06 aadnk commited on
Commit
b6ac700
·
0 Parent(s):

Duplicate from aadnk/whisper-webui

Browse files

Co-authored-by: Kristian Stangeland <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ .vscode/
4
+ flagged/
5
+ *.py[cod]
6
+ *$py.class
LICENSE.md ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ ==============
3
+
4
+ _Version 2.0, January 2004_
5
+ _&lt;<http://www.apache.org/licenses/>&gt;_
6
+
7
+ ### Terms and Conditions for use, reproduction, and distribution
8
+
9
+ #### 1. Definitions
10
+
11
+ “License” shall mean the terms and conditions for use, reproduction, and
12
+ distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ “Licensor” shall mean the copyright owner or entity authorized by the copyright
15
+ owner that is granting the License.
16
+
17
+ “Legal Entity” shall mean the union of the acting entity and all other entities
18
+ that control, are controlled by, or are under common control with that entity.
19
+ For the purposes of this definition, “control” means **(i)** the power, direct or
20
+ indirect, to cause the direction or management of such entity, whether by
21
+ contract or otherwise, or **(ii)** ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or **(iii)** beneficial ownership of such entity.
23
+
24
+ “You” (or “Your”) shall mean an individual or Legal Entity exercising
25
+ permissions granted by this License.
26
+
27
+ “Source” form shall mean the preferred form for making modifications, including
28
+ but not limited to software source code, documentation source, and configuration
29
+ files.
30
+
31
+ “Object” form shall mean any form resulting from mechanical transformation or
32
+ translation of a Source form, including but not limited to compiled object code,
33
+ generated documentation, and conversions to other media types.
34
+
35
+ “Work” shall mean the work of authorship, whether in Source or Object form, made
36
+ available under the License, as indicated by a copyright notice that is included
37
+ in or attached to the work (an example is provided in the Appendix below).
38
+
39
+ “Derivative Works” shall mean any work, whether in Source or Object form, that
40
+ is based on (or derived from) the Work and for which the editorial revisions,
41
+ annotations, elaborations, or other modifications represent, as a whole, an
42
+ original work of authorship. For the purposes of this License, Derivative Works
43
+ shall not include works that remain separable from, or merely link (or bind by
44
+ name) to the interfaces of, the Work and Derivative Works thereof.
45
+
46
+ “Contribution” shall mean any work of authorship, including the original version
47
+ of the Work and any modifications or additions to that Work or Derivative Works
48
+ thereof, that is intentionally submitted to Licensor for inclusion in the Work
49
+ by the copyright owner or by an individual or Legal Entity authorized to submit
50
+ on behalf of the copyright owner. For the purposes of this definition,
51
+ “submitted” means any form of electronic, verbal, or written communication sent
52
+ to the Licensor or its representatives, including but not limited to
53
+ communication on electronic mailing lists, source code control systems, and
54
+ issue tracking systems that are managed by, or on behalf of, the Licensor for
55
+ the purpose of discussing and improving the Work, but excluding communication
56
+ that is conspicuously marked or otherwise designated in writing by the copyright
57
+ owner as “Not a Contribution.”
58
+
59
+ “Contributor” shall mean Licensor and any individual or Legal Entity on behalf
60
+ of whom a Contribution has been received by Licensor and subsequently
61
+ incorporated within the Work.
62
+
63
+ #### 2. Grant of Copyright License
64
+
65
+ Subject to the terms and conditions of this License, each Contributor hereby
66
+ grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
67
+ irrevocable copyright license to reproduce, prepare Derivative Works of,
68
+ publicly display, publicly perform, sublicense, and distribute the Work and such
69
+ Derivative Works in Source or Object form.
70
+
71
+ #### 3. Grant of Patent License
72
+
73
+ Subject to the terms and conditions of this License, each Contributor hereby
74
+ grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
75
+ irrevocable (except as stated in this section) patent license to make, have
76
+ made, use, offer to sell, sell, import, and otherwise transfer the Work, where
77
+ such license applies only to those patent claims licensable by such Contributor
78
+ that are necessarily infringed by their Contribution(s) alone or by combination
79
+ of their Contribution(s) with the Work to which such Contribution(s) was
80
+ submitted. If You institute patent litigation against any entity (including a
81
+ cross-claim or counterclaim in a lawsuit) alleging that the Work or a
82
+ Contribution incorporated within the Work constitutes direct or contributory
83
+ patent infringement, then any patent licenses granted to You under this License
84
+ for that Work shall terminate as of the date such litigation is filed.
85
+
86
+ #### 4. Redistribution
87
+
88
+ You may reproduce and distribute copies of the Work or Derivative Works thereof
89
+ in any medium, with or without modifications, and in Source or Object form,
90
+ provided that You meet the following conditions:
91
+
92
+ * **(a)** You must give any other recipients of the Work or Derivative Works a copy of
93
+ this License; and
94
+ * **(b)** You must cause any modified files to carry prominent notices stating that You
95
+ changed the files; and
96
+ * **(c)** You must retain, in the Source form of any Derivative Works that You distribute,
97
+ all copyright, patent, trademark, and attribution notices from the Source form
98
+ of the Work, excluding those notices that do not pertain to any part of the
99
+ Derivative Works; and
100
+ * **(d)** If the Work includes a “NOTICE” text file as part of its distribution, then any
101
+ Derivative Works that You distribute must include a readable copy of the
102
+ attribution notices contained within such NOTICE file, excluding those notices
103
+ that do not pertain to any part of the Derivative Works, in at least one of the
104
+ following places: within a NOTICE text file distributed as part of the
105
+ Derivative Works; within the Source form or documentation, if provided along
106
+ with the Derivative Works; or, within a display generated by the Derivative
107
+ Works, if and wherever such third-party notices normally appear. The contents of
108
+ the NOTICE file are for informational purposes only and do not modify the
109
+ License. You may add Your own attribution notices within Derivative Works that
110
+ You distribute, alongside or as an addendum to the NOTICE text from the Work,
111
+ provided that such additional attribution notices cannot be construed as
112
+ modifying the License.
113
+
114
+ You may add Your own copyright statement to Your modifications and may provide
115
+ additional or different license terms and conditions for use, reproduction, or
116
+ distribution of Your modifications, or for any such Derivative Works as a whole,
117
+ provided Your use, reproduction, and distribution of the Work otherwise complies
118
+ with the conditions stated in this License.
119
+
120
+ #### 5. Submission of Contributions
121
+
122
+ Unless You explicitly state otherwise, any Contribution intentionally submitted
123
+ for inclusion in the Work by You to the Licensor shall be under the terms and
124
+ conditions of this License, without any additional terms or conditions.
125
+ Notwithstanding the above, nothing herein shall supersede or modify the terms of
126
+ any separate license agreement you may have executed with Licensor regarding
127
+ such Contributions.
128
+
129
+ #### 6. Trademarks
130
+
131
+ This License does not grant permission to use the trade names, trademarks,
132
+ service marks, or product names of the Licensor, except as required for
133
+ reasonable and customary use in describing the origin of the Work and
134
+ reproducing the content of the NOTICE file.
135
+
136
+ #### 7. Disclaimer of Warranty
137
+
138
+ Unless required by applicable law or agreed to in writing, Licensor provides the
139
+ Work (and each Contributor provides its Contributions) on an “AS IS” BASIS,
140
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
141
+ including, without limitation, any warranties or conditions of TITLE,
142
+ NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
143
+ solely responsible for determining the appropriateness of using or
144
+ redistributing the Work and assume any risks associated with Your exercise of
145
+ permissions under this License.
146
+
147
+ #### 8. Limitation of Liability
148
+
149
+ In no event and under no legal theory, whether in tort (including negligence),
150
+ contract, or otherwise, unless required by applicable law (such as deliberate
151
+ and grossly negligent acts) or agreed to in writing, shall any Contributor be
152
+ liable to You for damages, including any direct, indirect, special, incidental,
153
+ or consequential damages of any character arising as a result of this License or
154
+ out of the use or inability to use the Work (including but not limited to
155
+ damages for loss of goodwill, work stoppage, computer failure or malfunction, or
156
+ any and all other commercial damages or losses), even if such Contributor has
157
+ been advised of the possibility of such damages.
158
+
159
+ #### 9. Accepting Warranty or Additional Liability
160
+
161
+ While redistributing the Work or Derivative Works thereof, You may choose to
162
+ offer, and charge a fee for, acceptance of support, warranty, indemnity, or
163
+ other liability obligations and/or rights consistent with this License. However,
164
+ in accepting such obligations, You may act only on Your own behalf and on Your
165
+ sole responsibility, not on behalf of any other Contributor, and only if You
166
+ agree to indemnify, defend, and hold each Contributor harmless for any liability
167
+ incurred by, or claims asserted against, such Contributor by reason of your
168
+ accepting any such warranty or additional liability.
169
+
170
+ _END OF TERMS AND CONDITIONS_
171
+
172
+ ### APPENDIX: How to apply the Apache License to your work
173
+
174
+ To apply the Apache License to your work, attach the following boilerplate
175
+ notice, with the fields enclosed by brackets `[]` replaced with your own
176
+ identifying information. (Don't include the brackets!) The text should be
177
+ enclosed in the appropriate comment syntax for the file format. We also
178
+ recommend that a file or class name and description of purpose be included on
179
+ the same “printed page” as the copyright notice for easier identification within
180
+ third-party archives.
181
+
182
+ Copyright [yyyy] [name of copyright owner]
183
+
184
+ Licensed under the Apache License, Version 2.0 (the "License");
185
+ you may not use this file except in compliance with the License.
186
+ You may obtain a copy of the License at
187
+
188
+ http://www.apache.org/licenses/LICENSE-2.0
189
+
190
+ Unless required by applicable law or agreed to in writing, software
191
+ distributed under the License is distributed on an "AS IS" BASIS,
192
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
193
+ See the License for the specific language governing permissions and
194
+ limitations under the License.
195
+
README.md ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Whisper Webui
3
+ emoji: ⚡
4
+ colorFrom: pink
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.23.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: aadnk/whisper-webui
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+ # Running Locally
17
+
18
+ To run this program locally, first install Python 3.9+ and Git. Then install Pytorch 10.1+ and all the other dependencies:
19
+ ```
20
+ pip install -r requirements.txt
21
+ ```
22
+
23
+ You can find detailed instructions for how to install this on Windows 10/11 [here (PDF)](docs/windows/install_win10_win11.pdf).
24
+
25
+ Finally, run the full version (no audio length restrictions) of the app with parallel CPU/GPU enabled:
26
+ ```
27
+ python app.py --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
28
+ ```
29
+
30
+ You can also run the CLI interface, which is similar to Whisper's own CLI but also supports the following additional arguments:
31
+ ```
32
+ python cli.py \
33
+ [--vad {none,silero-vad,silero-vad-skip-gaps,silero-vad-expand-into-gaps,periodic-vad}] \
34
+ [--vad_merge_window VAD_MERGE_WINDOW] \
35
+ [--vad_max_merge_size VAD_MAX_MERGE_SIZE] \
36
+ [--vad_padding VAD_PADDING] \
37
+ [--vad_prompt_window VAD_PROMPT_WINDOW]
38
+ [--vad_cpu_cores NUMBER_OF_CORES]
39
+ [--vad_parallel_devices COMMA_DELIMITED_DEVICES]
40
+ [--auto_parallel BOOLEAN]
41
+ ```
42
+ In addition, you may also use URL's in addition to file paths as input.
43
+ ```
44
+ python cli.py --model large --vad silero-vad --language Japanese "https://www.youtube.com/watch?v=4cICErqqRSM"
45
+ ```
46
+
47
+ Rather than supplying arguments to `app.py` or `cli.py`, you can also use the configuration file [config.json5](config.json5). See that file for more information.
48
+ If you want to use a different configuration file, you can use the `WHISPER_WEBUI_CONFIG` environment variable to specify the path to another file.
49
+
50
+ ### Multiple Files
51
+
52
+ You can upload multiple files either through the "Upload files" option, or as a playlist on YouTube.
53
+ Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section.
54
+ When more than one file is processed, the UI will also generate a "All_Output" zip file containing all the text output files.
55
+
56
+ ## Diarization
57
+
58
+ To detect different speakers in the audio, you can use the [whisper-diarization](https://gitlab.com/aadnk/whisper-diarization) application.
59
+
60
+ Download the JSON file after running Whisper on an audio file, and then run app.py in the
61
+ whisper-diarization repository with the audio file and the JSON file as arguments.
62
+
63
+ ## Whisper Implementation
64
+
65
+ You can choose between using `whisper` or `faster-whisper`. [Faster Whisper](https://github.com/guillaumekln/faster-whisper) as a drop-in replacement for the
66
+ default Whisper which achieves up to a 4x speedup and 2x reduction in memory usage.
67
+
68
+ You can install the requirements for a specific Whisper implementation in `requirements-fasterWhisper.txt`
69
+ or `requirements-whisper.txt`:
70
+ ```
71
+ pip install -r requirements-fasterWhisper.txt
72
+ ```
73
+ And then run the App or the CLI with the `--whisper_implementation faster-whisper` flag:
74
+ ```
75
+ python app.py --whisper_implementation faster-whisper --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
76
+ ```
77
+ You can also select the whisper implementation in `config.json5`:
78
+ ```json5
79
+ {
80
+ "whisper_implementation": "faster-whisper"
81
+ }
82
+ ```
83
+ ### GPU Acceleration
84
+
85
+ In order to use GPU acceleration with Faster Whisper, both CUDA 11.2 and cuDNN 8 must be installed. You may want to install it in a virtual environment like Anaconda.
86
+
87
+ ## Google Colab
88
+
89
+ You can also run this Web UI directly on [Google Colab](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing), if you haven't got a GPU powerful enough to run the larger models.
90
+
91
+ See the [colab documentation](docs/colab.md) for more information.
92
+
93
+ ## Parallel Execution
94
+
95
+ You can also run both the Web-UI or the CLI on multiple GPUs in parallel, using the `vad_parallel_devices` option. This takes a comma-delimited list of
96
+ device IDs (0, 1, etc.) that Whisper should be distributed to and run on concurrently:
97
+ ```
98
+ python cli.py --model large --vad silero-vad --language Japanese \
99
+ --vad_parallel_devices 0,1 "https://www.youtube.com/watch?v=4cICErqqRSM"
100
+ ```
101
+
102
+ Note that this requires a VAD to function properly, otherwise only the first GPU will be used. Though you could use `period-vad` to avoid taking the hit
103
+ of running Silero-Vad, at a slight cost to accuracy.
104
+
105
+ This is achieved by creating N child processes (where N is the number of selected devices), where Whisper is run concurrently. In `app.py`, you can also
106
+ set the `vad_process_timeout` option. This configures the number of seconds until a process is killed due to inactivity, freeing RAM and video memory.
107
+ The default value is 30 minutes.
108
+
109
+ ```
110
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600
111
+ ```
112
+
113
+ To execute the Silero VAD itself in parallel, use the `vad_cpu_cores` option:
114
+ ```
115
+ python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600 --vad_cpu_cores 4
116
+ ```
117
+
118
+ You may also use `vad_process_timeout` with a single device (`--vad_parallel_devices 0`), if you prefer to always free video memory after a period of time.
119
+
120
+ ### Auto Parallel
121
+
122
+ You can also set `auto_parallel` to `True`. This will set `vad_parallel_devices` to use all the GPU devices on the system, and `vad_cpu_cores` to be equal to the number of
123
+ cores (up to 8):
124
+ ```
125
+ python app.py --input_audio_max_duration -1 --auto_parallel True
126
+ ```
127
+
128
+ # Docker
129
+
130
+ To run it in Docker, first install Docker and optionally the NVIDIA Container Toolkit in order to use the GPU.
131
+ Then either use the GitLab hosted container below, or check out this repository and build an image:
132
+ ```
133
+ sudo docker build -t whisper-webui:1 .
134
+ ```
135
+
136
+ You can then start the WebUI with GPU support like so:
137
+ ```
138
+ sudo docker run -d --gpus=all -p 7860:7860 whisper-webui:1
139
+ ```
140
+
141
+ Leave out "--gpus=all" if you don't have access to a GPU with enough memory, and are fine with running it on the CPU only:
142
+ ```
143
+ sudo docker run -d -p 7860:7860 whisper-webui:1
144
+ ```
145
+
146
+ # GitLab Docker Registry
147
+
148
+ This Docker container is also hosted on GitLab:
149
+
150
+ ```
151
+ sudo docker run -d --gpus=all -p 7860:7860 registry.gitlab.com/aadnk/whisper-webui:latest
152
+ ```
153
+
154
+ ## Custom Arguments
155
+
156
+ You can also pass custom arguments to `app.py` in the Docker container, for instance to be able to use all the GPUs in parallel (replace administrator with your user):
157
+ ```
158
+ sudo docker run -d --gpus all -p 7860:7860 \
159
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
160
+ --mount type=bind,source=/home/administrator/.cache/huggingface,target=/root/.cache/huggingface \
161
+ --restart=on-failure:15 registry.gitlab.com/aadnk/whisper-webui:latest \
162
+ app.py --input_audio_max_duration -1 --server_name 0.0.0.0 --auto_parallel True \
163
+ --default_vad silero-vad --default_model_name large
164
+ ```
165
+
166
+ You can also call `cli.py` the same way:
167
+ ```
168
+ sudo docker run --gpus all \
169
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
170
+ --mount type=bind,source=/home/administrator/.cache/huggingface,target=/root/.cache/huggingface \
171
+ --mount type=bind,source=${PWD},target=/app/data \
172
+ registry.gitlab.com/aadnk/whisper-webui:latest \
173
+ cli.py --model large --auto_parallel True --vad silero-vad \
174
+ --output_dir /app/data /app/data/YOUR-FILE-HERE.mp4
175
+ ```
176
+
177
+ ## Caching
178
+
179
+ Note that the models themselves are currently not included in the Docker images, and will be downloaded on the demand.
180
+ To avoid this, bind the directory /root/.cache/whisper to some directory on the host (for instance /home/administrator/.cache/whisper), where you can (optionally)
181
+ prepopulate the directory with the different Whisper models.
182
+ ```
183
+ sudo docker run -d --gpus=all -p 7860:7860 \
184
+ --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
185
+ registry.gitlab.com/aadnk/whisper-webui:latest
186
+ ```
app-local.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ from src.config import ApplicationConfig
4
+
5
+ create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1))
app-network.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Run the app with no audio file restrictions, and make it available on the network
2
+ from app import create_ui
3
+ from src.config import ApplicationConfig
4
+
5
+ create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1, server_name="0.0.0.0"))
app-shared.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Run the app with no audio file restrictions
2
+ from app import create_ui
3
+ from src.config import ApplicationConfig
4
+
5
+ create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1, share=True))
app.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import json
3
+ import math
4
+ from typing import Iterator, Union
5
+ import argparse
6
+
7
+ from io import StringIO
8
+ import os
9
+ import pathlib
10
+ import tempfile
11
+ import zipfile
12
+ import numpy as np
13
+
14
+ import torch
15
+
16
+ from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
17
+ from src.hooks.progressListener import ProgressListener
18
+ from src.hooks.subTaskProgressListener import SubTaskProgressListener
19
+ from src.hooks.whisperProgressHook import create_progress_listener_handle
20
+ from src.languages import get_language_names
21
+ from src.modelCache import ModelCache
22
+ from src.prompts.jsonPromptStrategy import JsonPromptStrategy
23
+ from src.prompts.prependPromptStrategy import PrependPromptStrategy
24
+ from src.source import get_audio_source_collection
25
+ from src.vadParallel import ParallelContext, ParallelTranscription
26
+
27
+ # External programs
28
+ import ffmpeg
29
+
30
+ # UI
31
+ import gradio as gr
32
+
33
+ from src.download import ExceededMaximumDuration, download_url
34
+ from src.utils import optional_int, slugify, write_srt, write_vtt
35
+ from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
36
+ from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
37
+ from src.whisper.whisperFactory import create_whisper_container
38
+
39
+ # Configure more application defaults in config.json5
40
+
41
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
42
+ MAX_FILE_PREFIX_LENGTH = 17
43
+
44
+ # Limit auto_parallel to a certain number of CPUs (specify vad_cpu_cores to get a higher number)
45
+ MAX_AUTO_CPU_CORES = 8
46
+
47
+ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
48
+
49
+ class VadOptions:
50
+ def __init__(self, vad: str = None, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
51
+ vadInitialPromptMode: Union[VadInitialPromptMode, str] = VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
52
+ self.vad = vad
53
+ self.vadMergeWindow = vadMergeWindow
54
+ self.vadMaxMergeSize = vadMaxMergeSize
55
+ self.vadPadding = vadPadding
56
+ self.vadPromptWindow = vadPromptWindow
57
+ self.vadInitialPromptMode = vadInitialPromptMode if isinstance(vadInitialPromptMode, VadInitialPromptMode) \
58
+ else VadInitialPromptMode.from_string(vadInitialPromptMode)
59
+
60
+ class WhisperTranscriber:
61
+ def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
62
+ vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
63
+ app_config: ApplicationConfig = None):
64
+ self.model_cache = ModelCache()
65
+ self.parallel_device_list = None
66
+ self.gpu_parallel_context = None
67
+ self.cpu_parallel_context = None
68
+ self.vad_process_timeout = vad_process_timeout
69
+ self.vad_cpu_cores = vad_cpu_cores
70
+
71
+ self.vad_model = None
72
+ self.inputAudioMaxDuration = input_audio_max_duration
73
+ self.deleteUploadedFiles = delete_uploaded_files
74
+ self.output_dir = output_dir
75
+
76
+ self.app_config = app_config
77
+
78
+ def set_parallel_devices(self, vad_parallel_devices: str):
79
+ self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
80
+
81
+ def set_auto_parallel(self, auto_parallel: bool):
82
+ if auto_parallel:
83
+ if torch.cuda.is_available():
84
+ self.parallel_device_list = [ str(gpu_id) for gpu_id in range(torch.cuda.device_count())]
85
+
86
+ self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
87
+ print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
88
+
89
+ # Entry function for the simple tab
90
+ def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
91
+ vad, vadMergeWindow, vadMaxMergeSize,
92
+ word_timestamps: bool = False, highlight_words: bool = False):
93
+ return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
94
+ vad, vadMergeWindow, vadMaxMergeSize,
95
+ word_timestamps, highlight_words)
96
+
97
+ # Entry function for the simple tab progress
98
+ def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
99
+ vad, vadMergeWindow, vadMaxMergeSize,
100
+ word_timestamps: bool = False, highlight_words: bool = False,
101
+ progress=gr.Progress()):
102
+
103
+ vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
104
+
105
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
106
+ word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
107
+
108
+ # Entry function for the full tab
109
+ def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
110
+ vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
111
+ # Word timestamps
112
+ word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
113
+ initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
114
+ condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
115
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
116
+
117
+ return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
118
+ vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
119
+ word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
120
+ initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
121
+ condition_on_previous_text, fp16, temperature_increment_on_fallback,
122
+ compression_ratio_threshold, logprob_threshold, no_speech_threshold)
123
+
124
+ # Entry function for the full tab with progress
125
+ def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
126
+ vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
127
+ # Word timestamps
128
+ word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
129
+ initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
130
+ condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
131
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
132
+ progress=gr.Progress()):
133
+
134
+ # Handle temperature_increment_on_fallback
135
+ if temperature_increment_on_fallback is not None:
136
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
137
+ else:
138
+ temperature = [temperature]
139
+
140
+ vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
141
+
142
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
143
+ initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
144
+ condition_on_previous_text=condition_on_previous_text, fp16=fp16,
145
+ compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
146
+ word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
147
+ progress=progress)
148
+
149
+ def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
150
+ vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
151
+ **decodeOptions: dict):
152
+ try:
153
+ sources = self.__get_source(urlData, multipleFiles, microphoneData)
154
+
155
+ try:
156
+ selectedLanguage = languageName.lower() if len(languageName) > 0 else None
157
+ selectedModel = modelName if modelName is not None else "base"
158
+
159
+ model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
160
+ model_name=selectedModel, compute_type=self.app_config.compute_type,
161
+ cache=self.model_cache, models=self.app_config.models)
162
+
163
+ # Result
164
+ download = []
165
+ zip_file_lookup = {}
166
+ text = ""
167
+ vtt = ""
168
+
169
+ # Write result
170
+ downloadDirectory = tempfile.mkdtemp()
171
+ source_index = 0
172
+
173
+ outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
174
+
175
+ # Progress
176
+ total_duration = sum([source.get_audio_duration() for source in sources])
177
+ current_progress = 0
178
+
179
+ # A listener that will report progress to Gradio
180
+ root_progress_listener = self._create_progress_listener(progress)
181
+
182
+ # Execute whisper
183
+ for source in sources:
184
+ source_prefix = ""
185
+ source_audio_duration = source.get_audio_duration()
186
+
187
+ if (len(sources) > 1):
188
+ # Prefix (minimum 2 digits)
189
+ source_index += 1
190
+ source_prefix = str(source_index).zfill(2) + "_"
191
+ print("Transcribing ", source.source_path)
192
+
193
+ scaled_progress_listener = SubTaskProgressListener(root_progress_listener,
194
+ base_task_total=total_duration,
195
+ sub_task_start=current_progress,
196
+ sub_task_total=source_audio_duration)
197
+
198
+ # Transcribe
199
+ result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
200
+ filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
201
+
202
+ # Update progress
203
+ current_progress += source_audio_duration
204
+
205
+ source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
206
+
207
+ if len(sources) > 1:
208
+ # Add new line separators
209
+ if (len(source_text) > 0):
210
+ source_text += os.linesep + os.linesep
211
+ if (len(source_vtt) > 0):
212
+ source_vtt += os.linesep + os.linesep
213
+
214
+ # Append file name to source text too
215
+ source_text = source.get_full_name() + ":" + os.linesep + source_text
216
+ source_vtt = source.get_full_name() + ":" + os.linesep + source_vtt
217
+
218
+ # Add to result
219
+ download.extend(source_download)
220
+ text += source_text
221
+ vtt += source_vtt
222
+
223
+ if (len(sources) > 1):
224
+ # Zip files support at least 260 characters, but we'll play it safe and use 200
225
+ zipFilePrefix = slugify(source_prefix + source.get_short_name(max_length=200), allow_unicode=True)
226
+
227
+ # File names in ZIP file can be longer
228
+ for source_download_file in source_download:
229
+ # Get file postfix (after last -)
230
+ filePostfix = os.path.basename(source_download_file).split("-")[-1]
231
+ zip_file_name = zipFilePrefix + "-" + filePostfix
232
+ zip_file_lookup[source_download_file] = zip_file_name
233
+
234
+ # Create zip file from all sources
235
+ if len(sources) > 1:
236
+ downloadAllPath = os.path.join(downloadDirectory, "All_Output-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
237
+
238
+ with zipfile.ZipFile(downloadAllPath, 'w', zipfile.ZIP_DEFLATED) as zip:
239
+ for download_file in download:
240
+ # Get file name from lookup
241
+ zip_file_name = zip_file_lookup.get(download_file, os.path.basename(download_file))
242
+ zip.write(download_file, arcname=zip_file_name)
243
+
244
+ download.insert(0, downloadAllPath)
245
+
246
+ return download, text, vtt
247
+
248
+ finally:
249
+ # Cleanup source
250
+ if self.deleteUploadedFiles:
251
+ for source in sources:
252
+ print("Deleting source file " + source.source_path)
253
+
254
+ try:
255
+ os.remove(source.source_path)
256
+ except Exception as e:
257
+ # Ignore error - it's just a cleanup
258
+ print("Error deleting source file " + source.source_path + ": " + str(e))
259
+
260
+ except ExceededMaximumDuration as e:
261
+ return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
262
+
263
+ def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None,
264
+ vadOptions: VadOptions = VadOptions(),
265
+ progressListener: ProgressListener = None, **decodeOptions: dict):
266
+
267
+ initial_prompt = decodeOptions.pop('initial_prompt', None)
268
+
269
+ if progressListener is None:
270
+ # Default progress listener
271
+ progressListener = ProgressListener()
272
+
273
+ if ('task' in decodeOptions):
274
+ task = decodeOptions.pop('task')
275
+
276
+ initial_prompt_mode = vadOptions.vadInitialPromptMode
277
+
278
+ # Set default initial prompt mode
279
+ if (initial_prompt_mode is None):
280
+ initial_prompt_mode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT
281
+
282
+ if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS or
283
+ initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
284
+ # Prepend initial prompt
285
+ prompt_strategy = PrependPromptStrategy(initial_prompt, initial_prompt_mode)
286
+ elif (vadOptions.vadInitialPromptMode == VadInitialPromptMode.JSON_PROMPT_MODE):
287
+ # Use a JSON format to specify the prompt for each segment
288
+ prompt_strategy = JsonPromptStrategy(initial_prompt)
289
+ else:
290
+ raise ValueError("Invalid vadInitialPromptMode: " + initial_prompt_mode)
291
+
292
+ # Callable for processing an audio file
293
+ whisperCallable = model.create_callback(language, task, prompt_strategy=prompt_strategy, **decodeOptions)
294
+
295
+ # The results
296
+ if (vadOptions.vad == 'silero-vad'):
297
+ # Silero VAD where non-speech gaps are transcribed
298
+ process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadOptions)
299
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
300
+ elif (vadOptions.vad == 'silero-vad-skip-gaps'):
301
+ # Silero VAD where non-speech gaps are simply ignored
302
+ skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadOptions)
303
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
304
+ elif (vadOptions.vad == 'silero-vad-expand-into-gaps'):
305
+ # Use Silero VAD where speech-segments are expanded into non-speech gaps
306
+ expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadOptions)
307
+ result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
308
+ elif (vadOptions.vad == 'periodic-vad'):
309
+ # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
310
+ # it may create a break in the middle of a sentence, causing some artifacts.
311
+ periodic_vad = VadPeriodicTranscription()
312
+ period_config = PeriodicTranscriptionConfig(periodic_duration=vadOptions.vadMaxMergeSize, max_prompt_window=vadOptions.vadPromptWindow)
313
+ result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
314
+
315
+ else:
316
+ if (self._has_parallel_devices()):
317
+ # Use a simple period transcription instead, as we need to use the parallel context
318
+ periodic_vad = VadPeriodicTranscription()
319
+ period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
320
+
321
+ result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
322
+ else:
323
+ # Default VAD
324
+ result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
325
+
326
+ return result
327
+
328
+ def _create_progress_listener(self, progress: gr.Progress):
329
+ if (progress is None):
330
+ # Dummy progress listener
331
+ return ProgressListener()
332
+
333
+ class ForwardingProgressListener(ProgressListener):
334
+ def __init__(self, progress: gr.Progress):
335
+ self.progress = progress
336
+
337
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
338
+ # From 0 to 1
339
+ self.progress(current / total)
340
+
341
+ def on_finished(self):
342
+ self.progress(1)
343
+
344
+ return ForwardingProgressListener(progress)
345
+
346
+ def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig,
347
+ progressListener: ProgressListener = None):
348
+ if (not self._has_parallel_devices()):
349
+ # No parallel devices, so just run the VAD and Whisper in sequence
350
+ return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
351
+
352
+ gpu_devices = self.parallel_device_list
353
+
354
+ if (gpu_devices is None or len(gpu_devices) == 0):
355
+ # No GPU devices specified, pass the current environment variable to the first GPU process. This may be NULL.
356
+ gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]
357
+
358
+ # Create parallel context if needed
359
+ if (self.gpu_parallel_context is None):
360
+ # Create a context wih processes and automatically clear the pool after 1 hour of inactivity
361
+ self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
362
+ # We also need a CPU context for the VAD
363
+ if (self.cpu_parallel_context is None):
364
+ self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)
365
+
366
+ parallel_vad = ParallelTranscription()
367
+ return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
368
+ config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
369
+ cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context,
370
+ progress_listener=progressListener)
371
+
372
+ def _has_parallel_devices(self):
373
+ return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
374
+
375
+ def _concat_prompt(self, prompt1, prompt2):
376
+ if (prompt1 is None):
377
+ return prompt2
378
+ elif (prompt2 is None):
379
+ return prompt1
380
+ else:
381
+ return prompt1 + " " + prompt2
382
+
383
+ def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadOptions: VadOptions):
384
+ # Use Silero VAD
385
+ if (self.vad_model is None):
386
+ self.vad_model = VadSileroTranscription()
387
+
388
+ config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
389
+ max_silent_period=vadOptions.vadMergeWindow, max_merge_size=vadOptions.vadMaxMergeSize,
390
+ segment_padding_left=vadOptions.vadPadding, segment_padding_right=vadOptions.vadPadding,
391
+ max_prompt_window=vadOptions.vadPromptWindow)
392
+
393
+ return config
394
+
395
+ def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
396
+ if not os.path.exists(output_dir):
397
+ os.makedirs(output_dir)
398
+
399
+ text = result["text"]
400
+ language = result["language"]
401
+ languageMaxLineWidth = self.__get_max_line_width(language)
402
+
403
+ print("Max line width " + str(languageMaxLineWidth))
404
+ vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
405
+ srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
406
+ json_result = json.dumps(result, indent=4, ensure_ascii=False)
407
+
408
+ output_files = []
409
+ output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
410
+ output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
411
+ output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
412
+ output_files.append(self.__create_file(json_result, output_dir, source_name + "-result.json"));
413
+
414
+ return output_files, text, vtt
415
+
416
+ def clear_cache(self):
417
+ self.model_cache.clear()
418
+ self.vad_model = None
419
+
420
+ def __get_source(self, urlData, multipleFiles, microphoneData):
421
+ return get_audio_source_collection(urlData, multipleFiles, microphoneData, self.inputAudioMaxDuration)
422
+
423
+ def __get_max_line_width(self, language: str) -> int:
424
+ if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
425
+ # Chinese characters and kana are wider, so limit line length to 40 characters
426
+ return 40
427
+ else:
428
+ # TODO: Add more languages
429
+ # 80 latin characters should fit on a 1080p/720p screen
430
+ return 80
431
+
432
+ def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int, highlight_words: bool = False) -> str:
433
+ segmentStream = StringIO()
434
+
435
+ if format == 'vtt':
436
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
437
+ elif format == 'srt':
438
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
439
+ else:
440
+ raise Exception("Unknown format " + format)
441
+
442
+ segmentStream.seek(0)
443
+ return segmentStream.read()
444
+
445
+ def __create_file(self, text: str, directory: str, fileName: str) -> str:
446
+ # Write the text to a file
447
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
448
+ file.write(text)
449
+
450
+ return file.name
451
+
452
+ def close(self):
453
+ print("Closing parallel contexts")
454
+ self.clear_cache()
455
+
456
+ if (self.gpu_parallel_context is not None):
457
+ self.gpu_parallel_context.close()
458
+ if (self.cpu_parallel_context is not None):
459
+ self.cpu_parallel_context.close()
460
+
461
+
462
+ def create_ui(app_config: ApplicationConfig):
463
+ ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
464
+ app_config.delete_uploaded_files, app_config.output_dir, app_config)
465
+
466
+ # Specify a list of devices to use for parallel processing
467
+ ui.set_parallel_devices(app_config.vad_parallel_devices)
468
+ ui.set_auto_parallel(app_config.auto_parallel)
469
+
470
+ is_whisper = False
471
+
472
+ if app_config.whisper_implementation == "whisper":
473
+ implementation_name = "Whisper"
474
+ is_whisper = True
475
+ elif app_config.whisper_implementation in ["faster-whisper", "faster_whisper"]:
476
+ implementation_name = "Faster Whisper"
477
+ else:
478
+ # Try to convert from camel-case to title-case
479
+ implementation_name = app_config.whisper_implementation.title().replace("_", " ").replace("-", " ")
480
+
481
+ ui_description = implementation_name + " is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
482
+ ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
483
+ ui_description += " as well as speech translation and language identification. "
484
+
485
+ ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
486
+
487
+ # Recommend faster-whisper
488
+ if is_whisper:
489
+ ui_description += "\n\n\n\nFor faster inference on GPU, try [faster-whisper](https://huggingface.co/spaces/aadnk/faster-whisper-webui)."
490
+
491
+ if app_config.input_audio_max_duration > 0:
492
+ ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
493
+
494
+ ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
495
+
496
+ whisper_models = app_config.get_model_names()
497
+
498
+ common_inputs = lambda : [
499
+ gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
500
+ gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
501
+ gr.Text(label="URL (YouTube, etc.)"),
502
+ gr.File(label="Upload Files", file_count="multiple"),
503
+ gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
504
+ gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
505
+ ]
506
+
507
+ common_vad_inputs = lambda : [
508
+ gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
509
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
510
+ gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
511
+ ]
512
+
513
+ common_word_timestamps_inputs = lambda : [
514
+ gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps),
515
+ gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
516
+ ]
517
+
518
+ is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
519
+
520
+ simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
521
+ description=ui_description, article=ui_article, inputs=[
522
+ *common_inputs(),
523
+ *common_vad_inputs(),
524
+ *common_word_timestamps_inputs(),
525
+ ], outputs=[
526
+ gr.File(label="Download"),
527
+ gr.Text(label="Transcription"),
528
+ gr.Text(label="Segments")
529
+ ])
530
+
531
+ full_description = ui_description + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
532
+
533
+ full_transcribe = gr.Interface(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
534
+ description=full_description, article=ui_article, inputs=[
535
+ *common_inputs(),
536
+
537
+ *common_vad_inputs(),
538
+ gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
539
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
540
+ gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode"),
541
+
542
+ *common_word_timestamps_inputs(),
543
+ gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
544
+ gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
545
+
546
+ gr.TextArea(label="Initial Prompt"),
547
+ gr.Number(label="Temperature", value=app_config.temperature),
548
+ gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
549
+ gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
550
+ gr.Number(label="Patience - Zero temperature", value=app_config.patience),
551
+ gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
552
+ gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
553
+ gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
554
+ gr.Checkbox(label="FP16", value=app_config.fp16),
555
+ gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
556
+ gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
557
+ gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
558
+ gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
559
+ ], outputs=[
560
+ gr.File(label="Download"),
561
+ gr.Text(label="Transcription"),
562
+ gr.Text(label="Segments")
563
+ ])
564
+
565
+ demo = gr.TabbedInterface([simple_transcribe, full_transcribe], tab_names=["Simple", "Full"])
566
+
567
+ # Queue up the demo
568
+ if is_queue_mode:
569
+ demo.queue(concurrency_count=app_config.queue_concurrency_count)
570
+ print("Queue mode enabled (concurrency count: " + str(app_config.queue_concurrency_count) + ")")
571
+ else:
572
+ print("Queue mode disabled - progress bars will not be shown.")
573
+
574
+ demo.launch(share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
575
+
576
+ # Clean up
577
+ ui.close()
578
+
579
+ if __name__ == '__main__':
580
+ default_app_config = ApplicationConfig.create_default()
581
+ whisper_models = default_app_config.get_model_names()
582
+
583
+ # Environment variable overrides
584
+ default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
585
+
586
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
587
+ parser.add_argument("--input_audio_max_duration", type=int, default=default_app_config.input_audio_max_duration, \
588
+ help="Maximum audio file length in seconds, or -1 for no limit.") # 600
589
+ parser.add_argument("--share", type=bool, default=default_app_config.share, \
590
+ help="True to share the app on HuggingFace.") # False
591
+ parser.add_argument("--server_name", type=str, default=default_app_config.server_name, \
592
+ help="The host or IP to bind to. If None, bind to localhost.") # None
593
+ parser.add_argument("--server_port", type=int, default=default_app_config.server_port, \
594
+ help="The port to bind to.") # 7860
595
+ parser.add_argument("--queue_concurrency_count", type=int, default=default_app_config.queue_concurrency_count, \
596
+ help="The number of concurrent requests to process.") # 1
597
+ parser.add_argument("--default_model_name", type=str, choices=whisper_models, default=default_app_config.default_model_name, \
598
+ help="The default model name.") # medium
599
+ parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
600
+ help="The default VAD.") # silero-vad
601
+ parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
602
+ help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
603
+ parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
604
+ help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
605
+ parser.add_argument("--vad_cpu_cores", type=int, default=default_app_config.vad_cpu_cores, \
606
+ help="The number of CPU cores to use for VAD pre-processing.") # 1
607
+ parser.add_argument("--vad_process_timeout", type=float, default=default_app_config.vad_process_timeout, \
608
+ help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.") # 1800
609
+ parser.add_argument("--auto_parallel", type=bool, default=default_app_config.auto_parallel, \
610
+ help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
611
+ parser.add_argument("--output_dir", "-o", type=str, default=default_app_config.output_dir, \
612
+ help="directory to save the outputs")
613
+ parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
614
+ help="the Whisper implementation to use")
615
+ parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
616
+ help="the compute type to use for inference")
617
+ parser.add_argument("--threads", type=optional_int, default=0,
618
+ help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
619
+
620
+ args = parser.parse_args().__dict__
621
+
622
+ updated_config = default_app_config.update(**args)
623
+
624
+ if (threads := args.pop("threads")) > 0:
625
+ torch.set_num_threads(threads)
626
+
627
+ create_ui(app_config=updated_config)
cli.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ from urllib.parse import urlparse
5
+ import warnings
6
+ import numpy as np
7
+
8
+ import torch
9
+ from app import VadOptions, WhisperTranscriber
10
+ from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
11
+ from src.download import download_url
12
+ from src.languages import get_language_names
13
+
14
+ from src.utils import optional_float, optional_int, str2bool
15
+ from src.whisper.whisperFactory import create_whisper_container
16
+
17
+ def cli():
18
+ app_config = ApplicationConfig.create_default()
19
+ whisper_models = app_config.get_model_names()
20
+
21
+ # For the CLI, we fallback to saving the output to the current directory
22
+ output_dir = app_config.output_dir if app_config.output_dir is not None else "."
23
+
24
+ # Environment variable overrides
25
+ default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", app_config.whisper_implementation)
26
+
27
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
28
+ parser.add_argument("audio", nargs="+", type=str, \
29
+ help="audio file(s) to transcribe")
30
+ parser.add_argument("--model", default=app_config.default_model_name, choices=whisper_models, \
31
+ help="name of the Whisper model to use") # medium
32
+ parser.add_argument("--model_dir", type=str, default=app_config.model_dir, \
33
+ help="the path to save model files; uses ~/.cache/whisper by default")
34
+ parser.add_argument("--device", default=app_config.device, \
35
+ help="device to use for PyTorch inference")
36
+ parser.add_argument("--output_dir", "-o", type=str, default=output_dir, \
37
+ help="directory to save the outputs")
38
+ parser.add_argument("--verbose", type=str2bool, default=app_config.verbose, \
39
+ help="whether to print out the progress and debug messages")
40
+ parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
41
+ help="the Whisper implementation to use")
42
+
43
+ parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
44
+ help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
45
+ parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(get_language_names()), \
46
+ help="language spoken in the audio, specify None to perform language detection")
47
+
48
+ parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
49
+ help="The voice activity detection algorithm to use") # silero-vad
50
+ parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
51
+ help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
52
+ parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
53
+ help="The window size (in seconds) to merge voice segments")
54
+ parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,\
55
+ help="The maximum size (in seconds) of a voice segment")
56
+ parser.add_argument("--vad_padding", type=optional_float, default=app_config.vad_padding, \
57
+ help="The padding (in seconds) to add to each voice segment")
58
+ parser.add_argument("--vad_prompt_window", type=optional_float, default=app_config.vad_prompt_window, \
59
+ help="The window size of the prompt to pass to Whisper")
60
+ parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
61
+ help="The number of CPU cores to use for VAD pre-processing.") # 1
62
+ parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
63
+ help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
64
+ parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
65
+ help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
66
+
67
+ parser.add_argument("--temperature", type=float, default=app_config.temperature, \
68
+ help="temperature to use for sampling")
69
+ parser.add_argument("--best_of", type=optional_int, default=app_config.best_of, \
70
+ help="number of candidates when sampling with non-zero temperature")
71
+ parser.add_argument("--beam_size", type=optional_int, default=app_config.beam_size, \
72
+ help="number of beams in beam search, only applicable when temperature is zero")
73
+ parser.add_argument("--patience", type=float, default=app_config.patience, \
74
+ help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
75
+ parser.add_argument("--length_penalty", type=float, default=app_config.length_penalty, \
76
+ help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
77
+
78
+ parser.add_argument("--suppress_tokens", type=str, default=app_config.suppress_tokens, \
79
+ help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
80
+ parser.add_argument("--initial_prompt", type=str, default=app_config.initial_prompt, \
81
+ help="optional text to provide as a prompt for the first window.")
82
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=app_config.condition_on_previous_text, \
83
+ help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
84
+ parser.add_argument("--fp16", type=str2bool, default=app_config.fp16, \
85
+ help="whether to perform inference in fp16; True by default")
86
+ parser.add_argument("--compute_type", type=str, default=app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
87
+ help="the compute type to use for inference")
88
+
89
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=app_config.temperature_increment_on_fallback, \
90
+ help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
91
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=app_config.compression_ratio_threshold, \
92
+ help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
93
+ parser.add_argument("--logprob_threshold", type=optional_float, default=app_config.logprob_threshold, \
94
+ help="if the average log probability is lower than this value, treat the decoding as failed")
95
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
96
+ help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
97
+
98
+ parser.add_argument("--word_timestamps", type=str2bool, default=app_config.word_timestamps,
99
+ help="(experimental) extract word-level timestamps and refine the results based on them")
100
+ parser.add_argument("--prepend_punctuations", type=str, default=app_config.prepend_punctuations,
101
+ help="if word_timestamps is True, merge these punctuation symbols with the next word")
102
+ parser.add_argument("--append_punctuations", type=str, default=app_config.append_punctuations,
103
+ help="if word_timestamps is True, merge these punctuation symbols with the previous word")
104
+ parser.add_argument("--highlight_words", type=str2bool, default=app_config.highlight_words,
105
+ help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
106
+ parser.add_argument("--threads", type=optional_int, default=0,
107
+ help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
108
+
109
+ args = parser.parse_args().__dict__
110
+ model_name: str = args.pop("model")
111
+ model_dir: str = args.pop("model_dir")
112
+ output_dir: str = args.pop("output_dir")
113
+ device: str = args.pop("device")
114
+ os.makedirs(output_dir, exist_ok=True)
115
+
116
+ if (threads := args.pop("threads")) > 0:
117
+ torch.set_num_threads(threads)
118
+
119
+ whisper_implementation = args.pop("whisper_implementation")
120
+ print(f"Using {whisper_implementation} for Whisper")
121
+
122
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
123
+ warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
124
+ args["language"] = "en"
125
+
126
+ temperature = args.pop("temperature")
127
+ temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
128
+ if temperature_increment_on_fallback is not None:
129
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
130
+ else:
131
+ temperature = [temperature]
132
+
133
+ vad = args.pop("vad")
134
+ vad_initial_prompt_mode = args.pop("vad_initial_prompt_mode")
135
+ vad_merge_window = args.pop("vad_merge_window")
136
+ vad_max_merge_size = args.pop("vad_max_merge_size")
137
+ vad_padding = args.pop("vad_padding")
138
+ vad_prompt_window = args.pop("vad_prompt_window")
139
+ vad_cpu_cores = args.pop("vad_cpu_cores")
140
+ auto_parallel = args.pop("auto_parallel")
141
+
142
+ compute_type = args.pop("compute_type")
143
+ highlight_words = args.pop("highlight_words")
144
+
145
+ transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
146
+ transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
147
+ transcriber.set_auto_parallel(auto_parallel)
148
+
149
+ model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
150
+ device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
151
+
152
+ if (transcriber._has_parallel_devices()):
153
+ print("Using parallel devices:", transcriber.parallel_device_list)
154
+
155
+ for audio_path in args.pop("audio"):
156
+ sources = []
157
+
158
+ # Detect URL and download the audio
159
+ if (uri_validator(audio_path)):
160
+ # Download from YouTube/URL directly
161
+ for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
162
+ source_name = os.path.basename(source_path)
163
+ sources.append({ "path": source_path, "name": source_name })
164
+ else:
165
+ sources.append({ "path": audio_path, "name": os.path.basename(audio_path) })
166
+
167
+ for source in sources:
168
+ source_path = source["path"]
169
+ source_name = source["name"]
170
+
171
+ vadOptions = VadOptions(vad, vad_merge_window, vad_max_merge_size, vad_padding, vad_prompt_window,
172
+ VadInitialPromptMode.from_string(vad_initial_prompt_mode))
173
+
174
+ result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
175
+
176
+ transcriber.write_result(result, source_name, output_dir, highlight_words)
177
+
178
+ transcriber.close()
179
+
180
+ def uri_validator(x):
181
+ try:
182
+ result = urlparse(x)
183
+ return all([result.scheme, result.netloc])
184
+ except:
185
+ return False
186
+
187
+ if __name__ == '__main__':
188
+ cli()
config.json5 ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "models": [
3
+ // Configuration for the built-in models. You can remove any of these
4
+ // if you don't want to use the default models.
5
+ {
6
+ "name": "tiny",
7
+ "url": "tiny"
8
+ },
9
+ {
10
+ "name": "base",
11
+ "url": "base"
12
+ },
13
+ {
14
+ "name": "small",
15
+ "url": "small"
16
+ },
17
+ {
18
+ "name": "medium",
19
+ "url": "medium"
20
+ },
21
+ {
22
+ "name": "large",
23
+ "url": "large"
24
+ },
25
+ {
26
+ "name": "large-v2",
27
+ "url": "large-v2"
28
+ },
29
+ // Uncomment to add custom Japanese models
30
+ //{
31
+ // "name": "whisper-large-v2-mix-jp",
32
+ // "url": "vumichien/whisper-large-v2-mix-jp",
33
+ // // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
34
+ // // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
35
+ // "type": "huggingface",
36
+ //},
37
+ //{
38
+ // "name": "local-model",
39
+ // "url": "path/to/local/model",
40
+ //},
41
+ //{
42
+ // "name": "remote-model",
43
+ // "url": "https://example.com/path/to/model",
44
+ //}
45
+ ],
46
+ // Configuration options that will be used if they are not specified in the command line arguments.
47
+
48
+ // * WEBUI options *
49
+
50
+ // Maximum audio file length in seconds, or -1 for no limit. Ignored by CLI.
51
+ "input_audio_max_duration": 600,
52
+ // True to share the app on HuggingFace.
53
+ "share": false,
54
+ // The host or IP to bind to. If None, bind to localhost.
55
+ "server_name": null,
56
+ // The port to bind to.
57
+ "server_port": 7860,
58
+ // The number of workers to use for the web server. Use -1 to disable queueing.
59
+ "queue_concurrency_count": 1,
60
+ // Whether or not to automatically delete all uploaded files, to save disk space
61
+ "delete_uploaded_files": true,
62
+
63
+ // * General options *
64
+
65
+ // The default implementation to use for Whisper. Can be "whisper" or "faster-whisper".
66
+ // Note that you must either install the requirements for faster-whisper (requirements-fasterWhisper.txt)
67
+ // or whisper (requirements.txt)
68
+ "whisper_implementation": "whisper",
69
+
70
+ // The default model name.
71
+ "default_model_name": "medium",
72
+ // The default VAD.
73
+ "default_vad": "silero-vad",
74
+ // A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
75
+ "vad_parallel_devices": "",
76
+ // The number of CPU cores to use for VAD pre-processing.
77
+ "vad_cpu_cores": 1,
78
+ // The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
79
+ "vad_process_timeout": 1800,
80
+ // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
81
+ "auto_parallel": false,
82
+ // Directory to save the outputs (CLI will use the current directory if not specified)
83
+ "output_dir": null,
84
+ // The path to save model files; uses ~/.cache/whisper by default
85
+ "model_dir": null,
86
+ // Device to use for PyTorch inference, or Null to use the default device
87
+ "device": null,
88
+ // Whether to print out the progress and debug messages
89
+ "verbose": true,
90
+ // Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')
91
+ "task": "transcribe",
92
+ // Language spoken in the audio, specify None to perform language detection
93
+ "language": null,
94
+ // The window size (in seconds) to merge voice segments
95
+ "vad_merge_window": 5,
96
+ // The maximum size (in seconds) of a voice segment
97
+ "vad_max_merge_size": 30,
98
+ // The padding (in seconds) to add to each voice segment
99
+ "vad_padding": 1,
100
+ // Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)
101
+ "vad_initial_prompt_mode": "prepend_first_segment",
102
+ // The window size of the prompt to pass to Whisper
103
+ "vad_prompt_window": 3,
104
+ // Temperature to use for sampling
105
+ "temperature": 0,
106
+ // Number of candidates when sampling with non-zero temperature
107
+ "best_of": 5,
108
+ // Number of beams in beam search, only applicable when temperature is zero
109
+ "beam_size": 5,
110
+ // Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
111
+ "patience": 1,
112
+ // Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
113
+ "length_penalty": null,
114
+ // Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
115
+ "suppress_tokens": "-1",
116
+ // Optional text to provide as a prompt for the first window
117
+ "initial_prompt": null,
118
+ // If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop
119
+ "condition_on_previous_text": true,
120
+ // Whether to perform inference in fp16; True by default
121
+ "fp16": true,
122
+ // The compute type used by faster-whisper. Can be "int8". "int16" or "float16".
123
+ "compute_type": "auto",
124
+ // Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
125
+ "temperature_increment_on_fallback": 0.2,
126
+ // If the gzip compression ratio is higher than this value, treat the decoding as failed
127
+ "compression_ratio_threshold": 2.4,
128
+ // If the average log probability is lower than this value, treat the decoding as failed
129
+ "logprob_threshold": -1.0,
130
+ // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
131
+ "no_speech_threshold": 0.6,
132
+
133
+ // (experimental) extract word-level timestamps and refine the results based on them
134
+ "word_timestamps": false,
135
+ // if word_timestamps is True, merge these punctuation symbols with the next word
136
+ "prepend_punctuations": "\"\'“¿([{-",
137
+ // if word_timestamps is True, merge these punctuation symbols with the previous word
138
+ "append_punctuations": "\"\'.。,,!!??::”)]}、",
139
+ // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
140
+ "highlight_words": false,
141
+ }
dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # docker build -t whisper-webui --build-arg WHISPER_IMPLEMENTATION=whisper .
2
+
3
+ FROM huggingface/transformers-pytorch-gpu
4
+ EXPOSE 7860
5
+
6
+ ARG WHISPER_IMPLEMENTATION=whisper
7
+ ENV WHISPER_IMPLEMENTATION=${WHISPER_IMPLEMENTATION}
8
+
9
+ ADD . /opt/whisper-webui/
10
+
11
+ # Latest version of transformers-pytorch-gpu seems to lack tk.
12
+ # Further, pip install fails, so we must upgrade pip first.
13
+ RUN apt-get -y install python3-tk
14
+ RUN python3 -m pip install --upgrade pip
15
+
16
+ RUN if [ "${WHISPER_IMPLEMENTATION}" = "whisper" ]; then \
17
+ python3 -m pip install -r /opt/whisper-webui/requirements-whisper.txt; \
18
+ else \
19
+ python3 -m pip install -r /opt/whisper-webui/requirements-fasterWhisper.txt; \
20
+ fi
21
+
22
+ # Note: Models will be downloaded on demand to the directory /root/.cache/whisper.
23
+ # You can also bind this directory in the container to somewhere on the host.
24
+
25
+ # To be able to see logs in real time
26
+ ENV PYTHONUNBUFFERED=1
27
+
28
+ WORKDIR /opt/whisper-webui/
29
+ ENTRYPOINT ["python3"]
30
+ CMD ["app.py", "--input_audio_max_duration", "-1", "--server_name", "0.0.0.0", "--auto_parallel", "True"]
docs/colab.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Running Whisper on Google Colab
2
+
3
+ If you don't have a decent GPU or any experience in running command-line applications, you might want to try this Google Colab instead:
4
+
5
+ * [Google Colab - Whisper WebUI GPU](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing)
6
+ * [Screenshots](https://imgur.com/a/ZfY6uBO)
7
+
8
+ The runtime (Runtime -> Change runtime type -> Hardware accelerator) should already be set top GPU. But if not, change it to GPU.
9
+
10
+ Then, sign in to Google if you haven't already. Next, click on "Connect" at the top right.
11
+
12
+ Under "Checking out WebUI from Git", click on the [play icon](https://imgur.com/a/81gOLyD) that appears in "[ ]" at the left. If you get a warning, click "Run anyway".
13
+
14
+ After this step has completed, it should be get a green check mark. Then move on to the next section under "Installing dependencies", and click in "[ ]" again. This might take approximately 30 seconds.
15
+
16
+ Once this has completed, scroll down to the "Run WebUI" section, and click on "[ ]". This will launch the WebUI in a shared link (expires in 72 hours). To open the UI, click on the link next to "Running on public URL", which will be something like https://12xxx.gradio.app/
17
+
18
+ The audio length in this version is not restricted, and it will run much faster as it is backed by a GPU. You can also run it using the "Large" model. Also note that it might take some time to start the model the first time, as it may need to download a 2.8 GB file on Google's servers.
19
+
20
+ Once you're done, you can close the WebUI session by clicking the animated close button under "Run WebUI". You can also do this if you encounter any errors and need to restart the UI. You should also go to "Manage Sessions" and terminate the session, otherwise you may end up using all your free compute credits.
docs/options.md ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard Options
2
+ To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
3
+ supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
4
+ in the file selector to select any file type, including video files) or use the microphone.
5
+
6
+ For longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option, especially if you are using the `large-v1` model. Note that `large-v2` is a lot more forgiving, but you may still want to use a VAD with a slightly higher "VAD - Max Merge Size (s)" (60 seconds or more).
7
+
8
+ ## Model
9
+ Select the model that Whisper will use to transcribe the audio:
10
+
11
+ | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
12
+ |-----------|------------|--------------------|--------------------|---------------|----------------|
13
+ | tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x |
14
+ | base | 74 M | base.en | base | ~1 GB | ~16x |
15
+ | small | 244 M | small.en | small | ~2 GB | ~6x |
16
+ | medium | 769 M | medium.en | medium | ~5 GB | ~2x |
17
+ | large | 1550 M | N/A | large | ~10 GB | 1x |
18
+ | large-v2 | 1550 M | N/A | large | ~10 GB | 1x |
19
+
20
+ ## Language
21
+
22
+ Select the language, or leave it empty for Whisper to automatically detect it.
23
+
24
+ Note that if the selected language and the language in the audio differs, Whisper may start to translate the audio to the selected
25
+ language. For instance, if the audio is in English but you select Japaneese, the model may translate the audio to Japanese.
26
+
27
+ ## Inputs
28
+ The options "URL (YouTube, etc.)", "Upload Files" or "Micriphone Input" allows you to send an audio input to the model.
29
+
30
+ ### Multiple Files
31
+ Note that the UI will only process either the given URL or the upload files (including microphone) - not both.
32
+
33
+ But you can upload multiple files either through the "Upload files" option, or as a playlist on YouTube. Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section. When more than one file is processed, the UI will also generate a "All_Output" zip file containing all the text output files.
34
+
35
+ ## Task
36
+ Select the task - either "transcribe" to transcribe the audio to text, or "translate" to translate it to English.
37
+
38
+ ## Vad
39
+ Using a VAD will improve the timing accuracy of each transcribed line, as well as prevent Whisper getting into an infinite
40
+ loop detecting the same sentence over and over again. The downside is that this may be at a cost to text accuracy, especially
41
+ with regards to unique words or names that appear in the audio. You can compensate for this by increasing the prompt window.
42
+
43
+ Note that English is very well handled by Whisper, and it's less susceptible to issues surrounding bad timings and infinite loops.
44
+ So you may only need to use a VAD for other languages, such as Japanese, or when the audio is very long.
45
+
46
+ * none
47
+ * Run whisper on the entire audio input
48
+ * silero-vad
49
+ * Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Whisper is also run
50
+ on the gaps between each speech section, by either expanding the section up to the max merge size, or running Whisper independently
51
+ on the non-speech section.
52
+ * silero-vad-expand-into-gaps
53
+ * Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Each spech section will be expanded
54
+ such that they cover any adjacent non-speech sections. For instance, if an audio file of one minute contains the speech sections
55
+ 00:00 - 00:10 (A) and 00:30 - 00:40 (B), the first section (A) will be expanded to 00:00 - 00:30, and (B) will be expanded to 00:30 - 00:60.
56
+ * silero-vad-skip-gaps
57
+ * As above, but sections that doesn't contain speech according to Silero will be skipped. This will be slightly faster, but
58
+ may cause dialogue to be skipped.
59
+ * periodic-vad
60
+ * Create sections of speech every 'VAD - Max Merge Size' seconds. This is very fast and simple, but will potentially break
61
+ a sentence or word in two.
62
+
63
+ ## VAD - Merge Window
64
+ If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged.
65
+
66
+ ## VAD - Max Merge Size (s)
67
+ Disables merging of adjacent speech sections if they are this number of seconds long.
68
+
69
+ ## VAD - Padding (s)
70
+ The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number
71
+ larger than zero ensures that Whisper is more likely to correctly transcribe a sentence in the beginning of
72
+ a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp
73
+ to each transcribed line. The default value is 1 second.
74
+
75
+ ## VAD - Prompt Window (s)
76
+ The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this
77
+ number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at
78
+ 10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds).
79
+
80
+ Note that detected lines in gaps between speech sections will not be included in the prompt
81
+ (if silero-vad or silero-vad-expand-into-gaps) is used.
82
+
83
+ # Command Line Options
84
+
85
+ Both `app.py` and `cli.py` also accept command line options, such as the ability to enable parallel execution on multiple
86
+ CPU/GPU cores, the default model name/VAD and so on. Consult the README in the root folder for more information.
87
+
88
+ # Additional Options
89
+
90
+ In addition to the above, there's also a "Full" options interface that allows you to set all the options available in the Whisper
91
+ model. The options are as follows:
92
+
93
+ ## Initial Prompt
94
+ Optional text to provide as a prompt for the first 30 seconds window. Whisper will attempt to use this as a starting point for the transcription, but you can
95
+ also get creative and specify a style or format for the output of the transcription.
96
+
97
+ For instance, if you use the prompt "hello how is it going always use lowercase no punctuation goodbye one two three start stop i you me they", Whisper will
98
+ be biased to output lower capital letters and no punctuation, and may also be biased to output the words in the prompt more often.
99
+
100
+ ## Temperature
101
+ The temperature to use when sampling. Default is 0 (zero). A higher temperature will result in more random output, while a lower temperature will be more deterministic.
102
+
103
+ ## Best Of - Non-zero temperature
104
+ The number of candidates to sample from when sampling with non-zero temperature. Default is 5.
105
+
106
+ ## Beam Size - Zero temperature
107
+ The number of beams to use in beam search when sampling with zero temperature. Default is 5.
108
+
109
+ ## Patience - Zero temperature
110
+ The patience value to use in beam search when sampling with zero temperature. As in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search.
111
+
112
+ ## Length Penalty - Any temperature
113
+ The token length penalty coefficient (alpha) to use when sampling with any temperature. As in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.
114
+
115
+ ## Suppress Tokens - Comma-separated list of token IDs
116
+ A comma-separated list of token IDs to suppress during sampling. The default value of "-1" will suppress most special characters except common punctuations.
117
+
118
+ ## Condition on previous text
119
+ If True, provide the previous output of the model as a prompt for the next window. Disabling this may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop.
120
+
121
+ ## FP16
122
+ Whether to perform inference in fp16. True by default.
123
+
124
+ ## Temperature increment on fallback
125
+ The temperature to increase when falling back when the decoding fails to meet either of the thresholds below. Default is 0.2.
126
+
127
+ ## Compression ratio threshold
128
+ If the gzip compression ratio is higher than this value, treat the decoding as failed. Default is 2.4.
129
+
130
+ ## Logprob threshold
131
+ If the average log probability is lower than this value, treat the decoding as failed. Default is -1.0.
132
+
133
+ ## No speech threshold
134
+ If the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence. Default is 0.6.
docs/windows/install_win10_win11.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b9f4ed547d6534411c17da1ea56707d2ec6e812611b1cbd3098756d5cbb8084
3
+ size 3378789
requirements-fasterWhisper.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ctranslate2
2
+ faster-whisper
3
+ ffmpeg-python==0.2.0
4
+ gradio==3.23.0
5
+ yt-dlp
6
+ json5
7
+ torch
8
+ torchaudio
9
+ more_itertools
requirements-whisper.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers
2
+ git+https://github.com/openai/whisper.git
3
+ transformers
4
+ ffmpeg-python==0.2.0
5
+ gradio==3.23.0
6
+ yt-dlp
7
+ torchaudio
8
+ altair
9
+ json5
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers
2
+ git+https://github.com/openai/whisper.git
3
+ transformers
4
+ ffmpeg-python==0.2.0
5
+ gradio==3.23.0
6
+ yt-dlp
7
+ torchaudio
8
+ altair
9
+ json5
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import urllib
3
+
4
+ import os
5
+ from typing import List
6
+ from urllib.parse import urlparse
7
+ import json5
8
+ import torch
9
+
10
+ from tqdm import tqdm
11
+
12
+ class ModelConfig:
13
+ def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
14
+ """
15
+ Initialize a model configuration.
16
+
17
+ name: Name of the model
18
+ url: URL to download the model from
19
+ path: Path to the model file. If not set, the model will be downloaded from the URL.
20
+ type: Type of model. Can be whisper or huggingface.
21
+ """
22
+ self.name = name
23
+ self.url = url
24
+ self.path = path
25
+ self.type = type
26
+
27
+ VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
28
+
29
+ class VadInitialPromptMode(Enum):
30
+ PREPEND_ALL_SEGMENTS = 1
31
+ PREPREND_FIRST_SEGMENT = 2
32
+ JSON_PROMPT_MODE = 3
33
+
34
+ @staticmethod
35
+ def from_string(s: str):
36
+ normalized = s.lower() if s is not None else None
37
+
38
+ if normalized == "prepend_all_segments":
39
+ return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
40
+ elif normalized == "prepend_first_segment":
41
+ return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
42
+ elif normalized == "json_prompt_mode":
43
+ return VadInitialPromptMode.JSON_PROMPT_MODE
44
+ elif normalized is not None and normalized != "":
45
+ raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
46
+ else:
47
+ return None
48
+
49
+ class ApplicationConfig:
50
+ def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
51
+ share: bool = False, server_name: str = None, server_port: int = 7860,
52
+ queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
53
+ whisper_implementation: str = "whisper",
54
+ default_model_name: str = "medium", default_vad: str = "silero-vad",
55
+ vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
56
+ auto_parallel: bool = False, output_dir: str = None,
57
+ model_dir: str = None, device: str = None,
58
+ verbose: bool = True, task: str = "transcribe", language: str = None,
59
+ vad_initial_prompt_mode: str = "prepend_first_segment ",
60
+ vad_merge_window: float = 5, vad_max_merge_size: float = 30,
61
+ vad_padding: float = 1, vad_prompt_window: float = 3,
62
+ temperature: float = 0, best_of: int = 5, beam_size: int = 5,
63
+ patience: float = None, length_penalty: float = None,
64
+ suppress_tokens: str = "-1", initial_prompt: str = None,
65
+ condition_on_previous_text: bool = True, fp16: bool = True,
66
+ compute_type: str = "float16",
67
+ temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
68
+ logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
69
+ # Word timestamp settings
70
+ word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
71
+ append_punctuations: str = "\"\'.。,,!!??::”)]}、",
72
+ highlight_words: bool = False):
73
+
74
+ self.models = models
75
+
76
+ # WebUI settings
77
+ self.input_audio_max_duration = input_audio_max_duration
78
+ self.share = share
79
+ self.server_name = server_name
80
+ self.server_port = server_port
81
+ self.queue_concurrency_count = queue_concurrency_count
82
+ self.delete_uploaded_files = delete_uploaded_files
83
+
84
+ self.whisper_implementation = whisper_implementation
85
+ self.default_model_name = default_model_name
86
+ self.default_vad = default_vad
87
+ self.vad_parallel_devices = vad_parallel_devices
88
+ self.vad_cpu_cores = vad_cpu_cores
89
+ self.vad_process_timeout = vad_process_timeout
90
+ self.auto_parallel = auto_parallel
91
+ self.output_dir = output_dir
92
+
93
+ self.model_dir = model_dir
94
+ self.device = device
95
+ self.verbose = verbose
96
+ self.task = task
97
+ self.language = language
98
+ self.vad_initial_prompt_mode = vad_initial_prompt_mode
99
+ self.vad_merge_window = vad_merge_window
100
+ self.vad_max_merge_size = vad_max_merge_size
101
+ self.vad_padding = vad_padding
102
+ self.vad_prompt_window = vad_prompt_window
103
+ self.temperature = temperature
104
+ self.best_of = best_of
105
+ self.beam_size = beam_size
106
+ self.patience = patience
107
+ self.length_penalty = length_penalty
108
+ self.suppress_tokens = suppress_tokens
109
+ self.initial_prompt = initial_prompt
110
+ self.condition_on_previous_text = condition_on_previous_text
111
+ self.fp16 = fp16
112
+ self.compute_type = compute_type
113
+ self.temperature_increment_on_fallback = temperature_increment_on_fallback
114
+ self.compression_ratio_threshold = compression_ratio_threshold
115
+ self.logprob_threshold = logprob_threshold
116
+ self.no_speech_threshold = no_speech_threshold
117
+
118
+ # Word timestamp settings
119
+ self.word_timestamps = word_timestamps
120
+ self.prepend_punctuations = prepend_punctuations
121
+ self.append_punctuations = append_punctuations
122
+ self.highlight_words = highlight_words
123
+
124
+ def get_model_names(self):
125
+ return [ x.name for x in self.models ]
126
+
127
+ def update(self, **new_values):
128
+ result = ApplicationConfig(**self.__dict__)
129
+
130
+ for key, value in new_values.items():
131
+ setattr(result, key, value)
132
+ return result
133
+
134
+ @staticmethod
135
+ def create_default(**kwargs):
136
+ app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
137
+
138
+ # Update with kwargs
139
+ if len(kwargs) > 0:
140
+ app_config = app_config.update(**kwargs)
141
+ return app_config
142
+
143
+ @staticmethod
144
+ def parse_file(config_path: str):
145
+ import json5
146
+
147
+ with open(config_path, "r", encoding="utf-8") as f:
148
+ # Load using json5
149
+ data = json5.load(f)
150
+ data_models = data.pop("models", [])
151
+
152
+ models = [ ModelConfig(**x) for x in data_models ]
153
+
154
+ return ApplicationConfig(models, **data)
src/conversion/hf_converter.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets
2
+
3
+ from copy import deepcopy
4
+ import torch
5
+
6
+ WHISPER_MAPPING = {
7
+ "layers": "blocks",
8
+ "fc1": "mlp.0",
9
+ "fc2": "mlp.2",
10
+ "final_layer_norm": "mlp_ln",
11
+ "layers": "blocks",
12
+ ".self_attn.q_proj": ".attn.query",
13
+ ".self_attn.k_proj": ".attn.key",
14
+ ".self_attn.v_proj": ".attn.value",
15
+ ".self_attn_layer_norm": ".attn_ln",
16
+ ".self_attn.out_proj": ".attn.out",
17
+ ".encoder_attn.q_proj": ".cross_attn.query",
18
+ ".encoder_attn.k_proj": ".cross_attn.key",
19
+ ".encoder_attn.v_proj": ".cross_attn.value",
20
+ ".encoder_attn_layer_norm": ".cross_attn_ln",
21
+ ".encoder_attn.out_proj": ".cross_attn.out",
22
+ "decoder.layer_norm.": "decoder.ln.",
23
+ "encoder.layer_norm.": "encoder.ln_post.",
24
+ "embed_tokens": "token_embedding",
25
+ "encoder.embed_positions.weight": "encoder.positional_embedding",
26
+ "decoder.embed_positions.weight": "decoder.positional_embedding",
27
+ "layer_norm": "ln_post",
28
+ }
29
+
30
+
31
+ def rename_keys(s_dict):
32
+ keys = list(s_dict.keys())
33
+ for key in keys:
34
+ new_key = key
35
+ for k, v in WHISPER_MAPPING.items():
36
+ if k in key:
37
+ new_key = new_key.replace(k, v)
38
+
39
+ print(f"{key} -> {new_key}")
40
+
41
+ s_dict[new_key] = s_dict.pop(key)
42
+ return s_dict
43
+
44
+
45
+ def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
46
+ from transformers import WhisperForConditionalGeneration
47
+ transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
48
+ config = transformer_model.config
49
+
50
+ # first build dims
51
+ dims = {
52
+ 'n_mels': config.num_mel_bins,
53
+ 'n_vocab': config.vocab_size,
54
+ 'n_audio_ctx': config.max_source_positions,
55
+ 'n_audio_state': config.d_model,
56
+ 'n_audio_head': config.encoder_attention_heads,
57
+ 'n_audio_layer': config.encoder_layers,
58
+ 'n_text_ctx': config.max_target_positions,
59
+ 'n_text_state': config.d_model,
60
+ 'n_text_head': config.decoder_attention_heads,
61
+ 'n_text_layer': config.decoder_layers
62
+ }
63
+
64
+ state_dict = deepcopy(transformer_model.model.state_dict())
65
+ state_dict = rename_keys(state_dict)
66
+
67
+ torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
src/download.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tempfile import mkdtemp
2
+ from typing import List
3
+ from yt_dlp import YoutubeDL
4
+
5
+ import yt_dlp
6
+ from yt_dlp.postprocessor import PostProcessor
7
+
8
+ class FilenameCollectorPP(PostProcessor):
9
+ def __init__(self):
10
+ super(FilenameCollectorPP, self).__init__(None)
11
+ self.filenames = []
12
+
13
+ def run(self, information):
14
+ self.filenames.append(information["filepath"])
15
+ return [], information
16
+
17
+ def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
18
+ try:
19
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
20
+ except yt_dlp.utils.DownloadError as e:
21
+ # In case of an OS error, try again with a different output template
22
+ if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
23
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
24
+ pass
25
+
26
+ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
27
+ # Create a temporary directory to store the downloaded files
28
+ if destinationDirectory is None:
29
+ destinationDirectory = mkdtemp()
30
+
31
+ ydl_opts = {
32
+ "format": "bestaudio/best",
33
+ 'paths': {
34
+ 'home': destinationDirectory
35
+ }
36
+ }
37
+ if (playlistItems):
38
+ ydl_opts['playlist_items'] = playlistItems
39
+
40
+ # Add output template if specified
41
+ if outputTemplate:
42
+ ydl_opts['outtmpl'] = outputTemplate
43
+
44
+ filename_collector = FilenameCollectorPP()
45
+
46
+ with YoutubeDL(ydl_opts) as ydl:
47
+ if maxDuration and maxDuration > 0:
48
+ info = ydl.extract_info(url, download=False)
49
+ entries = "entries" in info and info["entries"] or [info]
50
+
51
+ total_duration = 0
52
+
53
+ # Compute total duration
54
+ for entry in entries:
55
+ total_duration += float(entry["duration"])
56
+
57
+ if total_duration >= maxDuration:
58
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=maxDuration, message="Video is too long")
59
+
60
+ ydl.add_post_processor(filename_collector)
61
+ ydl.download([url])
62
+
63
+ if len(filename_collector.filenames) <= 0:
64
+ raise Exception("Cannot download " + url)
65
+
66
+ result = []
67
+
68
+ for filename in filename_collector.filenames:
69
+ result.append(filename)
70
+ print("Downloaded " + filename)
71
+
72
+ return result
73
+
74
+ class ExceededMaximumDuration(Exception):
75
+ def __init__(self, videoDuration, maxDuration, message):
76
+ self.videoDuration = videoDuration
77
+ self.maxDuration = maxDuration
78
+ super().__init__(message)
src/hooks/progressListener.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ class ProgressListener:
4
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
5
+ self.total = total
6
+
7
+ def on_finished(self):
8
+ pass
src/hooks/subTaskProgressListener.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.hooks.progressListener import ProgressListener
2
+
3
+ from typing import Union
4
+
5
+ class SubTaskProgressListener(ProgressListener):
6
+ """
7
+ A sub task listener that reports the progress of a sub task to a base task listener
8
+ Parameters
9
+ ----------
10
+ base_task_listener : ProgressListener
11
+ The base progress listener to accumulate overall progress in.
12
+ base_task_total : float
13
+ The maximum total progress that will be reported to the base progress listener.
14
+ sub_task_start : float
15
+ The starting progress of a sub task, in respect to the base progress listener.
16
+ sub_task_total : float
17
+ The total amount of progress a sub task will report to the base progress listener.
18
+ """
19
+ def __init__(
20
+ self,
21
+ base_task_listener: ProgressListener,
22
+ base_task_total: float,
23
+ sub_task_start: float,
24
+ sub_task_total: float,
25
+ ):
26
+ self.base_task_listener = base_task_listener
27
+ self.base_task_total = base_task_total
28
+ self.sub_task_start = sub_task_start
29
+ self.sub_task_total = sub_task_total
30
+
31
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
32
+ sub_task_progress_frac = current / total
33
+ sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
34
+ self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)
35
+
36
+ def on_finished(self):
37
+ self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
src/hooks/whisperProgressHook.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import threading
3
+ from typing import List, Union
4
+ import tqdm
5
+
6
+ from src.hooks.progressListener import ProgressListener
7
+
8
+ class ProgressListenerHandle:
9
+ def __init__(self, listener: ProgressListener):
10
+ self.listener = listener
11
+
12
+ def __enter__(self):
13
+ register_thread_local_progress_listener(self.listener)
14
+
15
+ def __exit__(self, exc_type, exc_val, exc_tb):
16
+ unregister_thread_local_progress_listener(self.listener)
17
+
18
+ if exc_type is None:
19
+ self.listener.on_finished()
20
+
21
+ class _CustomProgressBar(tqdm.tqdm):
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._current = self.n # Set the initial value
25
+
26
+ def update(self, n):
27
+ super().update(n)
28
+ # Because the progress bar might be disabled, we need to manually update the progress
29
+ self._current += n
30
+
31
+ # Inform listeners
32
+ listeners = _get_thread_local_listeners()
33
+
34
+ for listener in listeners:
35
+ listener.on_progress(self._current, self.total)
36
+
37
+ _thread_local = threading.local()
38
+
39
+ def _get_thread_local_listeners():
40
+ if not hasattr(_thread_local, 'listeners'):
41
+ _thread_local.listeners = []
42
+ return _thread_local.listeners
43
+
44
+ _hooked = False
45
+
46
+ def init_progress_hook():
47
+ global _hooked
48
+
49
+ if _hooked:
50
+ return
51
+
52
+ # Inject into tqdm.tqdm of Whisper, so we can see progress
53
+ import whisper.transcribe
54
+ transcribe_module = sys.modules['whisper.transcribe']
55
+ transcribe_module.tqdm.tqdm = _CustomProgressBar
56
+ _hooked = True
57
+
58
+ def register_thread_local_progress_listener(progress_listener: ProgressListener):
59
+ # This is a workaround for the fact that the progress bar is not exposed in the API
60
+ init_progress_hook()
61
+
62
+ listeners = _get_thread_local_listeners()
63
+ listeners.append(progress_listener)
64
+
65
+ def unregister_thread_local_progress_listener(progress_listener: ProgressListener):
66
+ listeners = _get_thread_local_listeners()
67
+
68
+ if progress_listener in listeners:
69
+ listeners.remove(progress_listener)
70
+
71
+ def create_progress_listener_handle(progress_listener: ProgressListener):
72
+ return ProgressListenerHandle(progress_listener)
73
+
74
+ # Example usage
75
+ if __name__ == '__main__':
76
+ class PrintingProgressListener:
77
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
78
+ print(f"Progress: {current}/{total}")
79
+
80
+ def on_finished(self):
81
+ print("Finished")
82
+
83
+ import whisper
84
+ model = whisper.load_model("medium")
85
+
86
+ with create_progress_listener_handle(PrintingProgressListener()) as listener:
87
+ # Set verbose to None to disable the progress bar, as we are using our own
88
+ result = model.transcribe("J:\\Dev\\OpenAI\\whisper\\tests\\Noriko\\out.mka", language="Japanese", fp16=False, verbose=None)
89
+ print(result)
90
+
91
+ print("Done")
src/languages.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Language():
2
+ def __init__(self, code, name):
3
+ self.code = code
4
+ self.name = name
5
+
6
+ def __str__(self):
7
+ return "Language(code={}, name={})".format(self.code, self.name)
8
+
9
+ LANGUAGES = [
10
+ Language('en', 'English'),
11
+ Language('zh', 'Chinese'),
12
+ Language('de', 'German'),
13
+ Language('es', 'Spanish'),
14
+ Language('ru', 'Russian'),
15
+ Language('ko', 'Korean'),
16
+ Language('fr', 'French'),
17
+ Language('ja', 'Japanese'),
18
+ Language('pt', 'Portuguese'),
19
+ Language('tr', 'Turkish'),
20
+ Language('pl', 'Polish'),
21
+ Language('ca', 'Catalan'),
22
+ Language('nl', 'Dutch'),
23
+ Language('ar', 'Arabic'),
24
+ Language('sv', 'Swedish'),
25
+ Language('it', 'Italian'),
26
+ Language('id', 'Indonesian'),
27
+ Language('hi', 'Hindi'),
28
+ Language('fi', 'Finnish'),
29
+ Language('vi', 'Vietnamese'),
30
+ Language('he', 'Hebrew'),
31
+ Language('uk', 'Ukrainian'),
32
+ Language('el', 'Greek'),
33
+ Language('ms', 'Malay'),
34
+ Language('cs', 'Czech'),
35
+ Language('ro', 'Romanian'),
36
+ Language('da', 'Danish'),
37
+ Language('hu', 'Hungarian'),
38
+ Language('ta', 'Tamil'),
39
+ Language('no', 'Norwegian'),
40
+ Language('th', 'Thai'),
41
+ Language('ur', 'Urdu'),
42
+ Language('hr', 'Croatian'),
43
+ Language('bg', 'Bulgarian'),
44
+ Language('lt', 'Lithuanian'),
45
+ Language('la', 'Latin'),
46
+ Language('mi', 'Maori'),
47
+ Language('ml', 'Malayalam'),
48
+ Language('cy', 'Welsh'),
49
+ Language('sk', 'Slovak'),
50
+ Language('te', 'Telugu'),
51
+ Language('fa', 'Persian'),
52
+ Language('lv', 'Latvian'),
53
+ Language('bn', 'Bengali'),
54
+ Language('sr', 'Serbian'),
55
+ Language('az', 'Azerbaijani'),
56
+ Language('sl', 'Slovenian'),
57
+ Language('kn', 'Kannada'),
58
+ Language('et', 'Estonian'),
59
+ Language('mk', 'Macedonian'),
60
+ Language('br', 'Breton'),
61
+ Language('eu', 'Basque'),
62
+ Language('is', 'Icelandic'),
63
+ Language('hy', 'Armenian'),
64
+ Language('ne', 'Nepali'),
65
+ Language('mn', 'Mongolian'),
66
+ Language('bs', 'Bosnian'),
67
+ Language('kk', 'Kazakh'),
68
+ Language('sq', 'Albanian'),
69
+ Language('sw', 'Swahili'),
70
+ Language('gl', 'Galician'),
71
+ Language('mr', 'Marathi'),
72
+ Language('pa', 'Punjabi'),
73
+ Language('si', 'Sinhala'),
74
+ Language('km', 'Khmer'),
75
+ Language('sn', 'Shona'),
76
+ Language('yo', 'Yoruba'),
77
+ Language('so', 'Somali'),
78
+ Language('af', 'Afrikaans'),
79
+ Language('oc', 'Occitan'),
80
+ Language('ka', 'Georgian'),
81
+ Language('be', 'Belarusian'),
82
+ Language('tg', 'Tajik'),
83
+ Language('sd', 'Sindhi'),
84
+ Language('gu', 'Gujarati'),
85
+ Language('am', 'Amharic'),
86
+ Language('yi', 'Yiddish'),
87
+ Language('lo', 'Lao'),
88
+ Language('uz', 'Uzbek'),
89
+ Language('fo', 'Faroese'),
90
+ Language('ht', 'Haitian creole'),
91
+ Language('ps', 'Pashto'),
92
+ Language('tk', 'Turkmen'),
93
+ Language('nn', 'Nynorsk'),
94
+ Language('mt', 'Maltese'),
95
+ Language('sa', 'Sanskrit'),
96
+ Language('lb', 'Luxembourgish'),
97
+ Language('my', 'Myanmar'),
98
+ Language('bo', 'Tibetan'),
99
+ Language('tl', 'Tagalog'),
100
+ Language('mg', 'Malagasy'),
101
+ Language('as', 'Assamese'),
102
+ Language('tt', 'Tatar'),
103
+ Language('haw', 'Hawaiian'),
104
+ Language('ln', 'Lingala'),
105
+ Language('ha', 'Hausa'),
106
+ Language('ba', 'Bashkir'),
107
+ Language('jw', 'Javanese'),
108
+ Language('su', 'Sundanese')
109
+ ]
110
+
111
+ _TO_LANGUAGE_CODE = {
112
+ **{language.code: language for language in LANGUAGES},
113
+ "burmese": "my",
114
+ "valencian": "ca",
115
+ "flemish": "nl",
116
+ "haitian": "ht",
117
+ "letzeburgesch": "lb",
118
+ "pushto": "ps",
119
+ "panjabi": "pa",
120
+ "moldavian": "ro",
121
+ "moldovan": "ro",
122
+ "sinhalese": "si",
123
+ "castilian": "es",
124
+ }
125
+
126
+ _FROM_LANGUAGE_NAME = {
127
+ **{language.name.lower(): language for language in LANGUAGES}
128
+ }
129
+
130
+ def get_language_from_code(language_code, default=None) -> Language:
131
+ """Return the language name from the language code."""
132
+ return _TO_LANGUAGE_CODE.get(language_code, default)
133
+
134
+ def get_language_from_name(language, default=None) -> Language:
135
+ """Return the language code from the language name."""
136
+ return _FROM_LANGUAGE_NAME.get(language.lower() if language else None, default)
137
+
138
+ def get_language_names():
139
+ """Return a list of language names."""
140
+ return [language.name for language in LANGUAGES]
141
+
142
+ if __name__ == "__main__":
143
+ # Test lookup
144
+ print(get_language_from_code('en'))
145
+ print(get_language_from_name('English'))
146
+
147
+ print(get_language_names())
src/modelCache.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ModelCache:
2
+ def __init__(self):
3
+ self._cache = dict()
4
+
5
+ def get(self, model_key: str, model_factory):
6
+ result = self._cache.get(model_key)
7
+
8
+ if result is None:
9
+ result = model_factory()
10
+ self._cache[model_key] = result
11
+ return result
12
+
13
+ def clear(self):
14
+ self._cache.clear()
15
+
16
+ # A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
17
+ GLOBAL_MODEL_CACHE = ModelCache()
src/prompts/abstractPromptStrategy.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+
4
+ class AbstractPromptStrategy:
5
+ """
6
+ Represents a strategy for generating prompts for a given audio segment.
7
+
8
+ Note that the strategy must be picklable, as it will be serialized and sent to the workers.
9
+ """
10
+
11
+ @abc.abstractmethod
12
+ def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
13
+ """
14
+ Retrieves the prompt for a given segment.
15
+
16
+ Parameters
17
+ ----------
18
+ segment_index: int
19
+ The index of the segment.
20
+ whisper_prompt: str
21
+ The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
22
+ detected_language: str
23
+ The language detected for the segment.
24
+ """
25
+ pass
26
+
27
+ @abc.abstractmethod
28
+ def on_segment_finished(self, segment_index: int, whisper_prompt: str, detected_language: str, result: dict):
29
+ """
30
+ Called when a segment has finished processing.
31
+
32
+ Parameters
33
+ ----------
34
+ segment_index: int
35
+ The index of the segment.
36
+ whisper_prompt: str
37
+ The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
38
+ detected_language: str
39
+ The language detected for the segment.
40
+ result: dict
41
+ The result of the segment. It has the following format:
42
+ {
43
+ "text": str,
44
+ "segments": [
45
+ {
46
+ "text": str,
47
+ "start": float,
48
+ "end": float,
49
+ "words": [words],
50
+ }
51
+ ],
52
+ "language": str,
53
+ }
54
+ """
55
+ pass
56
+
57
+ def _concat_prompt(self, prompt1, prompt2):
58
+ """
59
+ Concatenates two prompts.
60
+
61
+ Parameters
62
+ ----------
63
+ prompt1: str
64
+ The first prompt.
65
+ prompt2: str
66
+ The second prompt.
67
+ """
68
+ if (prompt1 is None):
69
+ return prompt2
70
+ elif (prompt2 is None):
71
+ return prompt1
72
+ else:
73
+ return prompt1 + " " + prompt2
src/prompts/jsonPromptStrategy.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Dict
3
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
4
+
5
+
6
+ class JsonPromptSegment():
7
+ def __init__(self, segment_index: int, prompt: str, format_prompt: bool = False):
8
+ self.prompt = prompt
9
+ self.segment_index = segment_index
10
+ self.format_prompt = format_prompt
11
+
12
+ class JsonPromptStrategy(AbstractPromptStrategy):
13
+ def __init__(self, initial_json_prompt: str):
14
+ """
15
+ Parameters
16
+ ----------
17
+ initial_json_prompt: str
18
+ The initial prompts for each segment in JSON form.
19
+
20
+ Format:
21
+ [
22
+ {"segment_index": 0, "prompt": "Hello, how are you?"},
23
+ {"segment_index": 1, "prompt": "I'm doing well, how are you?"},
24
+ {"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
25
+ ]
26
+
27
+ """
28
+ parsed_json = json.loads(initial_json_prompt)
29
+ self.segment_lookup: Dict[str, JsonPromptSegment] = dict()
30
+
31
+ for prompt_entry in parsed_json:
32
+ segment_index = prompt_entry["segment_index"]
33
+ prompt = prompt_entry["prompt"]
34
+ format_prompt = prompt_entry.get("format_prompt", False)
35
+ self.segment_lookup[str(segment_index)] = JsonPromptSegment(segment_index, prompt, format_prompt)
36
+
37
+ def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
38
+ # Lookup prompt
39
+ prompt = self.segment_lookup.get(str(segment_index), None)
40
+
41
+ if (prompt is None):
42
+ # No prompt found, return whisper prompt
43
+ print(f"Could not find prompt for segment {segment_index}, returning whisper prompt")
44
+ return whisper_prompt
45
+
46
+ if (prompt.format_prompt):
47
+ return prompt.prompt.format(whisper_prompt)
48
+ else:
49
+ return self._concat_prompt(prompt.prompt, whisper_prompt)
src/prompts/prependPromptStrategy.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.config import VadInitialPromptMode
2
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
3
+
4
+ class PrependPromptStrategy(AbstractPromptStrategy):
5
+ """
6
+ A simple prompt strategy that prepends a single prompt to all segments of audio, or prepends the prompt to the first segment of audio.
7
+ """
8
+ def __init__(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode):
9
+ """
10
+ Parameters
11
+ ----------
12
+ initial_prompt: str
13
+ The initial prompt to use for the transcription.
14
+ initial_prompt_mode: VadInitialPromptMode
15
+ The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
16
+ If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
17
+ """
18
+ self.initial_prompt = initial_prompt
19
+ self.initial_prompt_mode = initial_prompt_mode
20
+
21
+ # This is a simple prompt strategy, so we only support these two modes
22
+ if initial_prompt_mode not in [VadInitialPromptMode.PREPEND_ALL_SEGMENTS, VadInitialPromptMode.PREPREND_FIRST_SEGMENT]:
23
+ raise ValueError(f"Unsupported initial prompt mode {initial_prompt_mode}")
24
+
25
+ def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
26
+ if (self.initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
27
+ return self._concat_prompt(self.initial_prompt, whisper_prompt)
28
+ elif (self.initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
29
+ return self._concat_prompt(self.initial_prompt, whisper_prompt) if segment_index == 0 else whisper_prompt
30
+ else:
31
+ raise ValueError(f"Unknown initial prompt mode {self.initial_prompt_mode}")
src/segments.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ import copy
4
+
5
+ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
6
+ result = []
7
+
8
+ if len(timestamps) == 0:
9
+ return result
10
+ if max_merge_size is None:
11
+ return timestamps
12
+
13
+ if padding_left is None:
14
+ padding_left = 0
15
+ if padding_right is None:
16
+ padding_right = 0
17
+
18
+ processed_time = 0
19
+ current_segment = None
20
+
21
+ for i in range(len(timestamps)):
22
+ next_segment = timestamps[i]
23
+
24
+ delta = next_segment['start'] - processed_time
25
+
26
+ # Note that segments can still be longer than the max merge size, they just won't be merged in that case
27
+ if current_segment is None or (merge_window is not None and delta > merge_window) \
28
+ or next_segment['end'] - current_segment['start'] > max_merge_size:
29
+ # Finish the current segment
30
+ if current_segment is not None:
31
+ # Add right padding
32
+ finish_padding = min(padding_right, delta / 2) if delta < padding_left + padding_right else padding_right
33
+ current_segment['end'] += finish_padding
34
+ delta -= finish_padding
35
+
36
+ result.append(current_segment)
37
+
38
+ # Start a new segment
39
+ current_segment = copy.deepcopy(next_segment)
40
+
41
+ # Pad the segment
42
+ current_segment['start'] = current_segment['start'] - min(padding_left, delta)
43
+ processed_time = current_segment['end']
44
+
45
+ else:
46
+ # Merge the segment
47
+ current_segment['end'] = next_segment['end']
48
+ processed_time = current_segment['end']
49
+
50
+ # Add the last segment
51
+ if current_segment is not None:
52
+ current_segment['end'] += padding_right
53
+ result.append(current_segment)
54
+
55
+ return result
src/source.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
2
+ import os
3
+ import pathlib
4
+ from typing import List
5
+ import zipfile
6
+
7
+ import ffmpeg
8
+ from more_itertools import unzip
9
+
10
+ from src.download import ExceededMaximumDuration, download_url
11
+
12
+ MAX_FILE_PREFIX_LENGTH = 17
13
+
14
+ class AudioSource:
15
+ def __init__(self, source_path, source_name = None, audio_duration = None):
16
+ self.source_path = source_path
17
+ self.source_name = source_name
18
+ self._audio_duration = audio_duration
19
+
20
+ # Load source name if not provided
21
+ if (self.source_name is None):
22
+ file_path = pathlib.Path(self.source_path)
23
+ self.source_name = file_path.name
24
+
25
+ def get_audio_duration(self):
26
+ if self._audio_duration is None:
27
+ self._audio_duration = float(ffmpeg.probe(self.source_path)["format"]["duration"])
28
+
29
+ return self._audio_duration
30
+
31
+ def get_full_name(self):
32
+ return self.source_name
33
+
34
+ def get_short_name(self, max_length: int = MAX_FILE_PREFIX_LENGTH):
35
+ file_path = pathlib.Path(self.source_name)
36
+ short_name = file_path.stem[:max_length] + file_path.suffix
37
+
38
+ return short_name
39
+
40
+ def __str__(self) -> str:
41
+ return self.source_path
42
+
43
+ class AudioSourceCollection:
44
+ def __init__(self, sources: List[AudioSource]):
45
+ self.sources = sources
46
+
47
+ def __iter__(self):
48
+ return iter(self.sources)
49
+
50
+ def get_audio_source_collection(urlData: str, multipleFiles: List, microphoneData: str, input_audio_max_duration: float = -1) -> List[AudioSource]:
51
+ output: List[AudioSource] = []
52
+
53
+ if urlData:
54
+ # Download from YouTube. This could also be a playlist or a channel.
55
+ output.extend([ AudioSource(x) for x in download_url(urlData, input_audio_max_duration, playlistItems=None) ])
56
+ else:
57
+ # Add input files
58
+ if (multipleFiles is not None):
59
+ output.extend([ AudioSource(x.name) for x in multipleFiles ])
60
+ if (microphoneData is not None):
61
+ output.append(AudioSource(microphoneData))
62
+
63
+ total_duration = 0
64
+
65
+ # Calculate total audio length. We do this even if input_audio_max_duration
66
+ # is disabled to ensure that all the audio files are valid.
67
+ for source in output:
68
+ audioDuration = ffmpeg.probe(source.source_path)["format"]["duration"]
69
+ total_duration += float(audioDuration)
70
+
71
+ # Save audio duration
72
+ source._audio_duration = float(audioDuration)
73
+
74
+ # Ensure the total duration of the audio is not too long
75
+ if input_audio_max_duration > 0:
76
+ if float(total_duration) > input_audio_max_duration:
77
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
78
+
79
+ # Return a list of audio sources
80
+ return output
src/utils.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import textwrap
2
+ import unicodedata
3
+ import re
4
+
5
+ import zlib
6
+ from typing import Iterator, TextIO, Union
7
+ import tqdm
8
+
9
+ import urllib3
10
+
11
+
12
+ def exact_div(x, y):
13
+ assert x % y == 0
14
+ return x // y
15
+
16
+
17
+ def str2bool(string):
18
+ str2val = {"True": True, "False": False}
19
+ if string in str2val:
20
+ return str2val[string]
21
+ else:
22
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
23
+
24
+
25
+ def optional_int(string):
26
+ return None if string == "None" else int(string)
27
+
28
+
29
+ def optional_float(string):
30
+ return None if string == "None" else float(string)
31
+
32
+
33
+ def compression_ratio(text) -> float:
34
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
35
+
36
+
37
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
38
+ assert seconds >= 0, "non-negative timestamp expected"
39
+ milliseconds = round(seconds * 1000.0)
40
+
41
+ hours = milliseconds // 3_600_000
42
+ milliseconds -= hours * 3_600_000
43
+
44
+ minutes = milliseconds // 60_000
45
+ milliseconds -= minutes * 60_000
46
+
47
+ seconds = milliseconds // 1_000
48
+ milliseconds -= seconds * 1_000
49
+
50
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
51
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
52
+
53
+
54
+ def write_txt(transcript: Iterator[dict], file: TextIO):
55
+ for segment in transcript:
56
+ print(segment['text'].strip(), file=file, flush=True)
57
+
58
+
59
+ def write_vtt(transcript: Iterator[dict], file: TextIO,
60
+ maxLineWidth=None, highlight_words: bool = False):
61
+ iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
62
+
63
+ print("WEBVTT\n", file=file)
64
+
65
+ for segment in iterator:
66
+ text = segment['text'].replace('-->', '->')
67
+
68
+ print(
69
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
70
+ f"{text}\n",
71
+ file=file,
72
+ flush=True,
73
+ )
74
+
75
+ def write_srt(transcript: Iterator[dict], file: TextIO,
76
+ maxLineWidth=None, highlight_words: bool = False):
77
+ """
78
+ Write a transcript to a file in SRT format.
79
+ Example usage:
80
+ from pathlib import Path
81
+ from whisper.utils import write_srt
82
+ result = transcribe(model, audio_path, temperature=temperature, **args)
83
+ # save SRT
84
+ audio_basename = Path(audio_path).stem
85
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
86
+ write_srt(result["segments"], file=srt)
87
+ """
88
+ iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
89
+
90
+ for i, segment in enumerate(iterator, start=1):
91
+ text = segment['text'].replace('-->', '->')
92
+
93
+ # write srt lines
94
+ print(
95
+ f"{i}\n"
96
+ f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
97
+ f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
98
+ f"{text}\n",
99
+ file=file,
100
+ flush=True,
101
+ )
102
+
103
+ def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
104
+ for segment in transcript:
105
+ words = segment.get('words', [])
106
+
107
+ if len(words) == 0:
108
+ # Yield the segment as-is or processed
109
+ if maxLineWidth is None or maxLineWidth < 0:
110
+ yield segment
111
+ else:
112
+ yield {
113
+ 'start': segment['start'],
114
+ 'end': segment['end'],
115
+ 'text': process_text(segment['text'].strip(), maxLineWidth)
116
+ }
117
+ # We are done
118
+ continue
119
+
120
+ subtitle_start = segment['start']
121
+ subtitle_end = segment['end']
122
+
123
+ text_words = [ this_word["word"] for this_word in words ]
124
+ subtitle_text = __join_words(text_words, maxLineWidth)
125
+
126
+ # Iterate over the words in the segment
127
+ if highlight_words:
128
+ last = subtitle_start
129
+
130
+ for i, this_word in enumerate(words):
131
+ start = this_word['start']
132
+ end = this_word['end']
133
+
134
+ if last != start:
135
+ # Display the text up to this point
136
+ yield {
137
+ 'start': last,
138
+ 'end': start,
139
+ 'text': subtitle_text
140
+ }
141
+
142
+ # Display the text with the current word highlighted
143
+ yield {
144
+ 'start': start,
145
+ 'end': end,
146
+ 'text': __join_words(
147
+ [
148
+ {
149
+ "word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
150
+ if j == i
151
+ else word,
152
+ # The HTML tags <u> and </u> are not displayed,
153
+ # # so they should not be counted in the word length
154
+ "length": len(word)
155
+ } for j, word in enumerate(text_words)
156
+ ], maxLineWidth)
157
+ }
158
+ last = end
159
+
160
+ if last != subtitle_end:
161
+ # Display the last part of the text
162
+ yield {
163
+ 'start': last,
164
+ 'end': subtitle_end,
165
+ 'text': subtitle_text
166
+ }
167
+
168
+ # Just return the subtitle text
169
+ else:
170
+ yield {
171
+ 'start': subtitle_start,
172
+ 'end': subtitle_end,
173
+ 'text': subtitle_text
174
+ }
175
+
176
+ def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
177
+ if maxLineWidth is None or maxLineWidth < 0:
178
+ return " ".join(words)
179
+
180
+ lines = []
181
+ current_line = ""
182
+ current_length = 0
183
+
184
+ for entry in words:
185
+ # Either accept a string or a dict with a 'word' and 'length' field
186
+ if isinstance(entry, dict):
187
+ word = entry['word']
188
+ word_length = entry['length']
189
+ else:
190
+ word = entry
191
+ word_length = len(word)
192
+
193
+ if current_length > 0 and current_length + word_length > maxLineWidth:
194
+ lines.append(current_line)
195
+ current_line = ""
196
+ current_length = 0
197
+
198
+ current_length += word_length
199
+ # The word will be prefixed with a space by Whisper, so we don't need to add one here
200
+ current_line += word
201
+
202
+ if len(current_line) > 0:
203
+ lines.append(current_line)
204
+
205
+ return "\n".join(lines)
206
+
207
+ def process_text(text: str, maxLineWidth=None):
208
+ if (maxLineWidth is None or maxLineWidth < 0):
209
+ return text
210
+
211
+ lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
212
+ return '\n'.join(lines)
213
+
214
+ def slugify(value, allow_unicode=False):
215
+ """
216
+ Taken from https://github.com/django/django/blob/master/django/utils/text.py
217
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
218
+ dashes to single dashes. Remove characters that aren't alphanumerics,
219
+ underscores, or hyphens. Convert to lowercase. Also strip leading and
220
+ trailing whitespace, dashes, and underscores.
221
+ """
222
+ value = str(value)
223
+ if allow_unicode:
224
+ value = unicodedata.normalize('NFKC', value)
225
+ else:
226
+ value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
227
+ value = re.sub(r'[^\w\s-]', '', value.lower())
228
+ return re.sub(r'[-\s]+', '-', value).strip('-_')
229
+
230
+ def download_file(url: str, destination: str):
231
+ with urllib3.request.urlopen(url) as source, open(destination, "wb") as output:
232
+ with tqdm(
233
+ total=int(source.info().get("Content-Length")),
234
+ ncols=80,
235
+ unit="iB",
236
+ unit_scale=True,
237
+ unit_divisor=1024,
238
+ ) as loop:
239
+ while True:
240
+ buffer = source.read(8192)
241
+ if not buffer:
242
+ break
243
+
244
+ output.write(buffer)
245
+ loop.update(len(buffer))
src/vad.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from collections import Counter, deque
3
+ import time
4
+
5
+ from typing import Any, Deque, Iterator, List, Dict
6
+
7
+ from pprint import pprint
8
+ from src.hooks.progressListener import ProgressListener
9
+ from src.hooks.subTaskProgressListener import SubTaskProgressListener
10
+ from src.hooks.whisperProgressHook import create_progress_listener_handle
11
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
12
+
13
+ from src.segments import merge_timestamps
14
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
15
+
16
+ # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
17
+ try:
18
+ import tensorflow as tf
19
+ except ModuleNotFoundError:
20
+ # Error handling
21
+ pass
22
+
23
+ import torch
24
+
25
+ import ffmpeg
26
+ import numpy as np
27
+
28
+ from src.utils import format_timestamp
29
+ from enum import Enum
30
+
31
+ class NonSpeechStrategy(Enum):
32
+ """
33
+ Ignore non-speech frames segments.
34
+ """
35
+ SKIP = 1
36
+ """
37
+ Just treat non-speech segments as speech.
38
+ """
39
+ CREATE_SEGMENT = 2
40
+ """
41
+ Expand speech segments into subsequent non-speech segments.
42
+ """
43
+ EXPAND_SEGMENT = 3
44
+
45
+ # Defaults for Silero
46
+ SPEECH_TRESHOLD = 0.3
47
+
48
+ # Minimum size of segments to process
49
+ MIN_SEGMENT_DURATION = 1
50
+
51
+ # The maximum time for texts from old segments to be used in the next segment
52
+ MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
53
+ PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
54
+
55
+ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
56
+
57
+ class TranscriptionConfig(ABC):
58
+ def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
59
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
60
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
61
+ self.non_speech_strategy = non_speech_strategy
62
+ self.segment_padding_left = segment_padding_left
63
+ self.segment_padding_right = segment_padding_right
64
+ self.max_silent_period = max_silent_period
65
+ self.max_merge_size = max_merge_size
66
+ self.max_prompt_window = max_prompt_window
67
+ self.initial_segment_index = initial_segment_index
68
+
69
+ class PeriodicTranscriptionConfig(TranscriptionConfig):
70
+ def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
71
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
72
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
73
+ super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
74
+ self.periodic_duration = periodic_duration
75
+
76
+ class AbstractTranscription(ABC):
77
+ def __init__(self, sampling_rate: int = 16000):
78
+ self.sampling_rate = sampling_rate
79
+
80
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
81
+ return load_audio(str, self.sampling_rate, start_time, duration)
82
+
83
+ def is_transcribe_timestamps_fast(self):
84
+ """
85
+ Determine if get_transcribe_timestamps is fast enough to not need parallelization.
86
+ """
87
+ return False
88
+
89
+ @abstractmethod
90
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
91
+ """
92
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method.
93
+
94
+ Parameters
95
+ ----------
96
+ audio: str
97
+ The audio file.
98
+ config: TranscriptionConfig
99
+ The transcription configuration.
100
+
101
+ Returns
102
+ -------
103
+ A list of start and end timestamps, in fractional seconds.
104
+ """
105
+ return
106
+
107
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
108
+ """
109
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method,
110
+ after merging the given segments using the specified configuration.
111
+
112
+ Parameters
113
+ ----------
114
+ audio: str
115
+ The audio file.
116
+ config: TranscriptionConfig
117
+ The transcription configuration.
118
+
119
+ Returns
120
+ -------
121
+ A list of start and end timestamps, in fractional seconds.
122
+ """
123
+ merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
124
+ config.segment_padding_left, config.segment_padding_right)
125
+
126
+ if config.non_speech_strategy != NonSpeechStrategy.SKIP:
127
+ # Expand segments to include the gaps between them
128
+ if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
129
+ # When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
130
+ merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
131
+ elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
132
+ # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
133
+ merged = self.expand_gaps(merged, total_duration=total_duration)
134
+ else:
135
+ raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
136
+
137
+ print("Transcribing non-speech:")
138
+ pprint(merged)
139
+ return merged
140
+
141
+ def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
142
+ progressListener: ProgressListener = None):
143
+ """
144
+ Transcribe the given audo file.
145
+
146
+ Parameters
147
+ ----------
148
+ audio: str
149
+ The audio file.
150
+ whisperCallable: WhisperCallback
151
+ A callback object to call to transcribe each segment.
152
+
153
+ Returns
154
+ -------
155
+ A list of start and end timestamps, in fractional seconds.
156
+ """
157
+
158
+ try:
159
+ max_audio_duration = self.get_audio_duration(audio, config)
160
+ timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
161
+
162
+ # Get speech timestamps from full audio file
163
+ merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
164
+
165
+ # A deque of transcribed segments that is passed to the next segment as a prompt
166
+ prompt_window = deque()
167
+
168
+ print("Processing timestamps:")
169
+ pprint(merged)
170
+
171
+ result = {
172
+ 'text': "",
173
+ 'segments': [],
174
+ 'language': ""
175
+ }
176
+ languageCounter = Counter()
177
+ detected_language = None
178
+
179
+ segment_index = config.initial_segment_index
180
+
181
+ # Calculate progress
182
+ progress_start_offset = merged[0]['start'] if len(merged) > 0 else 0
183
+ progress_total_duration = sum([segment['end'] - segment['start'] for segment in merged])
184
+
185
+ # For each time segment, run whisper
186
+ for segment in merged:
187
+ segment_index += 1
188
+ segment_start = segment['start']
189
+ segment_end = segment['end']
190
+ segment_expand_amount = segment.get('expand_amount', 0)
191
+ segment_gap = segment.get('gap', False)
192
+
193
+ segment_duration = segment_end - segment_start
194
+
195
+ if segment_duration < MIN_SEGMENT_DURATION:
196
+ continue
197
+
198
+ # Audio to run on Whisper
199
+ segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
200
+ # Previous segments to use as a prompt
201
+ segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
202
+
203
+ # Detected language
204
+ detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
205
+
206
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
207
+ segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
208
+
209
+ perf_start_time = time.perf_counter()
210
+
211
+ scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=progress_total_duration,
212
+ sub_task_start=segment_start - progress_start_offset, sub_task_total=segment_duration)
213
+ segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
214
+
215
+ perf_end_time = time.perf_counter()
216
+ print("Whisper took {} seconds".format(perf_end_time - perf_start_time))
217
+
218
+ adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
219
+
220
+ # Propagate expand amount to the segments
221
+ if (segment_expand_amount > 0):
222
+ segment_without_expansion = segment_duration - segment_expand_amount
223
+
224
+ for adjusted_segment in adjusted_segments:
225
+ adjusted_segment_end = adjusted_segment['end']
226
+
227
+ # Add expand amount if the segment got expanded
228
+ if (adjusted_segment_end > segment_without_expansion):
229
+ adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
230
+
231
+ # Append to output
232
+ result['text'] += segment_result['text']
233
+ result['segments'].extend(adjusted_segments)
234
+
235
+ # Increment detected language
236
+ if not segment_gap:
237
+ languageCounter[segment_result['language']] += 1
238
+
239
+ # Update prompt window
240
+ self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
241
+
242
+ if detected_language is not None:
243
+ result['language'] = detected_language
244
+ finally:
245
+ # Notify progress listener that we are done
246
+ if progressListener is not None:
247
+ progressListener.on_finished()
248
+ return result
249
+
250
+ def get_audio_duration(self, audio: str, config: TranscriptionConfig):
251
+ return get_audio_duration(audio)
252
+
253
+ def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
254
+ if (config.max_prompt_window is not None and config.max_prompt_window > 0):
255
+ # Add segments to the current prompt window (unless it is a speech gap)
256
+ if not segment_gap:
257
+ for segment in adjusted_segments:
258
+ if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
259
+ prompt_window.append(segment)
260
+
261
+ while (len(prompt_window) > 0):
262
+ first_end_time = prompt_window[0].get('end', 0)
263
+ # Time expanded in the segments should be discounted from the prompt window
264
+ first_expand_time = prompt_window[0].get('expand_amount', 0)
265
+
266
+ if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
267
+ prompt_window.popleft()
268
+ else:
269
+ break
270
+
271
+ def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
272
+ result = []
273
+ last_end_time = 0
274
+
275
+ for segment in segments:
276
+ segment_start = float(segment['start'])
277
+ segment_end = float(segment['end'])
278
+
279
+ if (last_end_time != segment_start):
280
+ delta = segment_start - last_end_time
281
+
282
+ if (min_gap_length is None or delta >= min_gap_length):
283
+ result.append( { 'start': last_end_time, 'end': segment_start, 'gap': True } )
284
+
285
+ last_end_time = segment_end
286
+ result.append(segment)
287
+
288
+ # Also include total duration if specified
289
+ if (total_duration is not None and last_end_time < total_duration):
290
+ delta = total_duration - segment_start
291
+
292
+ if (min_gap_length is None or delta >= min_gap_length):
293
+ result.append( { 'start': last_end_time, 'end': total_duration, 'gap': True } )
294
+
295
+ return result
296
+
297
+ # Expand the end time of each segment to the start of the next segment
298
+ def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
299
+ result = []
300
+
301
+ if len(segments) == 0:
302
+ return result
303
+
304
+ # Add gap at the beginning if needed
305
+ if (segments[0]['start'] > 0):
306
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
307
+
308
+ for i in range(len(segments) - 1):
309
+ current_segment = segments[i]
310
+ next_segment = segments[i + 1]
311
+
312
+ delta = next_segment['start'] - current_segment['end']
313
+
314
+ # Expand if the gap actually exists
315
+ if (delta >= 0):
316
+ current_segment = current_segment.copy()
317
+ current_segment['expand_amount'] = delta
318
+ current_segment['end'] = next_segment['start']
319
+
320
+ result.append(current_segment)
321
+
322
+ # Add last segment
323
+ last_segment = segments[-1]
324
+ result.append(last_segment)
325
+
326
+ # Also include total duration if specified
327
+ if (total_duration is not None):
328
+ last_segment = result[-1]
329
+
330
+ if (last_segment['end'] < total_duration):
331
+ last_segment = last_segment.copy()
332
+ last_segment['end'] = total_duration
333
+ result[-1] = last_segment
334
+
335
+ return result
336
+
337
+ def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
338
+ result = []
339
+
340
+ if len(segments) == 0:
341
+ return result
342
+
343
+ # Add gap at the beginning if needed
344
+ if (segments[0]['start'] > 0):
345
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
346
+
347
+ for i in range(len(segments) - 1):
348
+ expanded = False
349
+ current_segment = segments[i]
350
+ next_segment = segments[i + 1]
351
+
352
+ delta = next_segment['start'] - current_segment['end']
353
+
354
+ if (max_expand_size is not None and delta <= max_expand_size):
355
+ # Just expand the current segment
356
+ current_segment = current_segment.copy()
357
+ current_segment['expand_amount'] = delta
358
+ current_segment['end'] = next_segment['start']
359
+ expanded = True
360
+
361
+ result.append(current_segment)
362
+
363
+ # Add a gap to the next segment if needed
364
+ if (delta >= 0 and not expanded):
365
+ result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True } )
366
+
367
+ # Add last segment
368
+ last_segment = segments[-1]
369
+ result.append(last_segment)
370
+
371
+ # Also include total duration if specified
372
+ if (total_duration is not None):
373
+ last_segment = result[-1]
374
+
375
+ delta = total_duration - last_segment['end']
376
+
377
+ if (delta > 0):
378
+ if (max_expand_size is not None and delta <= max_expand_size):
379
+ # Expand the last segment
380
+ last_segment = last_segment.copy()
381
+ last_segment['expand_amount'] = delta
382
+ last_segment['end'] = total_duration
383
+ result[-1] = last_segment
384
+ else:
385
+ result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True } )
386
+
387
+ return result
388
+
389
+ def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
390
+ result = []
391
+
392
+ for segment in segments:
393
+ segment_start = float(segment['start'])
394
+ segment_end = float(segment['end'])
395
+
396
+ # Filter segments?
397
+ if (max_source_time is not None):
398
+ if (segment_start > max_source_time):
399
+ continue
400
+ segment_end = min(max_source_time, segment_end)
401
+
402
+ new_segment = segment.copy()
403
+
404
+ # Add to start and end
405
+ new_segment['start'] = segment_start + adjust_seconds
406
+ new_segment['end'] = segment_end + adjust_seconds
407
+
408
+ # Handle words
409
+ if ('words' in new_segment):
410
+ for word in new_segment['words']:
411
+ # Adjust start and end
412
+ word['start'] = word['start'] + adjust_seconds
413
+ word['end'] = word['end'] + adjust_seconds
414
+
415
+ result.append(new_segment)
416
+ return result
417
+
418
+ def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
419
+ result = []
420
+
421
+ for entry in timestamps:
422
+ start = entry['start']
423
+ end = entry['end']
424
+
425
+ result.append({
426
+ 'start': start * factor,
427
+ 'end': end * factor
428
+ })
429
+ return result
430
+
431
+
432
+ class VadSileroTranscription(AbstractTranscription):
433
+ def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
434
+ super().__init__(sampling_rate=sampling_rate)
435
+ self.model = None
436
+ self.cache = cache
437
+ self._initialize_model()
438
+
439
+ def _initialize_model(self):
440
+ if (self.cache is not None):
441
+ model_key = "VadSileroTranscription"
442
+ self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
443
+ print("Loaded Silerio model from cache.")
444
+ else:
445
+ self.model, self.get_speech_timestamps = self._create_model()
446
+ print("Created Silerio model")
447
+
448
+ def _create_model(self):
449
+ model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
450
+
451
+ # Silero does not benefit from multi-threading
452
+ torch.set_num_threads(1) # JIT
453
+ (get_speech_timestamps, _, _, _, _) = utils
454
+
455
+ return model, get_speech_timestamps
456
+
457
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
458
+ result = []
459
+
460
+ print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
461
+ perf_start_time = time.perf_counter()
462
+
463
+ # Divide procesisng of audio into chunks
464
+ chunk_start = start_time
465
+
466
+ while (chunk_start < end_time):
467
+ chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)
468
+
469
+ print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
470
+ wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
471
+
472
+ sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
473
+ seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
474
+ adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_start + chunk_duration)
475
+
476
+ #pprint(adjusted)
477
+
478
+ result.extend(adjusted)
479
+ chunk_start += chunk_duration
480
+
481
+ perf_end_time = time.perf_counter()
482
+ print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))
483
+
484
+ return result
485
+
486
+ def __getstate__(self):
487
+ # We only need the sampling rate
488
+ return { 'sampling_rate': self.sampling_rate }
489
+
490
+ def __setstate__(self, state):
491
+ self.sampling_rate = state['sampling_rate']
492
+ self.model = None
493
+ # Use the global cache
494
+ self.cache = GLOBAL_MODEL_CACHE
495
+ self._initialize_model()
496
+
497
+ # A very simple VAD that just marks every N seconds as speech
498
+ class VadPeriodicTranscription(AbstractTranscription):
499
+ def __init__(self, sampling_rate: int = 16000):
500
+ super().__init__(sampling_rate=sampling_rate)
501
+
502
+ def is_transcribe_timestamps_fast(self):
503
+ # This is a very fast VAD - no need to parallelize it
504
+ return True
505
+
506
+ def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
507
+ result = []
508
+
509
+ # Generate a timestamp every N seconds
510
+ start_timestamp = start_time
511
+
512
+ while (start_timestamp < end_time):
513
+ end_timestamp = min(start_timestamp + config.periodic_duration, end_time)
514
+ segment_duration = end_timestamp - start_timestamp
515
+
516
+ # Minimum duration is 1 second
517
+ if (segment_duration >= 1):
518
+ result.append( { 'start': start_timestamp, 'end': end_timestamp } )
519
+
520
+ start_timestamp = end_timestamp
521
+
522
+ return result
523
+
524
+ def get_audio_duration(file: str):
525
+ return float(ffmpeg.probe(file)["format"]["duration"])
526
+
527
+ def load_audio(file: str, sample_rate: int = 16000,
528
+ start_time: str = None, duration: str = None):
529
+ """
530
+ Open an audio file and read as mono waveform, resampling as necessary
531
+
532
+ Parameters
533
+ ----------
534
+ file: str
535
+ The audio file to open
536
+
537
+ sr: int
538
+ The sample rate to resample the audio if necessary
539
+
540
+ start_time: str
541
+ The start time, using the standard FFMPEG time duration syntax, or None to disable.
542
+
543
+ duration: str
544
+ The duration, using the standard FFMPEG time duration syntax, or None to disable.
545
+
546
+ Returns
547
+ -------
548
+ A NumPy array containing the audio waveform, in float32 dtype.
549
+ """
550
+ try:
551
+ inputArgs = {'threads': 0}
552
+
553
+ if (start_time is not None):
554
+ inputArgs['ss'] = start_time
555
+ if (duration is not None):
556
+ inputArgs['t'] = duration
557
+
558
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
559
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
560
+ out, _ = (
561
+ ffmpeg.input(file, **inputArgs)
562
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
563
+ .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
564
+ )
565
+ except ffmpeg.Error as e:
566
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
567
+
568
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
src/vadParallel.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ from queue import Empty
3
+ import threading
4
+ import time
5
+ from src.hooks.progressListener import ProgressListener
6
+ from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
7
+
8
+ from multiprocessing import Pool, Queue
9
+
10
+ from typing import Any, Dict, List, Union
11
+ import os
12
+
13
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
14
+
15
+ class _ProgressListenerToQueue(ProgressListener):
16
+ def __init__(self, progress_queue: Queue):
17
+ self.progress_queue = progress_queue
18
+ self.progress_total = 0
19
+ self.prev_progress = 0
20
+
21
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
22
+ delta = current - self.prev_progress
23
+ self.prev_progress = current
24
+ self.progress_total = total
25
+ self.progress_queue.put(delta)
26
+
27
+ def on_finished(self):
28
+ if self.progress_total > self.prev_progress:
29
+ delta = self.progress_total - self.prev_progress
30
+ self.progress_queue.put(delta)
31
+ self.prev_progress = self.progress_total
32
+
33
+ class ParallelContext:
34
+ def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
35
+ self.num_processes = num_processes
36
+ self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
37
+ self.lock = threading.Lock()
38
+
39
+ self.ref_count = 0
40
+ self.pool = None
41
+ self.cleanup_timer = None
42
+
43
+ def get_pool(self):
44
+ # Initialize pool lazily
45
+ if (self.pool is None):
46
+ context = multiprocessing.get_context('spawn')
47
+ self.pool = context.Pool(self.num_processes)
48
+
49
+ self.ref_count = self.ref_count + 1
50
+
51
+ if (self.auto_cleanup_timeout_seconds is not None):
52
+ self._stop_auto_cleanup()
53
+
54
+ return self.pool
55
+
56
+ def return_pool(self, pool):
57
+ if (self.pool == pool and self.ref_count > 0):
58
+ self.ref_count = self.ref_count - 1
59
+
60
+ if (self.ref_count == 0):
61
+ if (self.auto_cleanup_timeout_seconds is not None):
62
+ self._start_auto_cleanup()
63
+
64
+ def _start_auto_cleanup(self):
65
+ if (self.cleanup_timer is not None):
66
+ self.cleanup_timer.cancel()
67
+ self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
68
+ self.cleanup_timer.start()
69
+
70
+ print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")
71
+
72
+ def _stop_auto_cleanup(self):
73
+ if (self.cleanup_timer is not None):
74
+ self.cleanup_timer.cancel()
75
+ self.cleanup_timer = None
76
+
77
+ print("Stopped auto cleanup of pool")
78
+
79
+ def _execute_cleanup(self):
80
+ print("Executing cleanup of pool")
81
+
82
+ if (self.ref_count == 0):
83
+ self.close()
84
+
85
+ def close(self):
86
+ self._stop_auto_cleanup()
87
+
88
+ if (self.pool is not None):
89
+ print("Closing pool of " + str(self.num_processes) + " processes")
90
+ self.pool.close()
91
+ self.pool.join()
92
+ self.pool = None
93
+
94
+ class ParallelTranscriptionConfig(TranscriptionConfig):
95
+ def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
96
+ super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
97
+ self.device_id = device_id
98
+ self.override_timestamps = override_timestamps
99
+
100
+ class ParallelTranscription(AbstractTranscription):
101
+ # Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
102
+ # into smaller segments than 2 minute (min 6 seconds per CPU core)
103
+ MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60
104
+
105
+ def __init__(self, sampling_rate: int = 16000):
106
+ super().__init__(sampling_rate=sampling_rate)
107
+
108
+ def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
109
+ cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
110
+ progress_listener: ProgressListener = None):
111
+ total_duration = get_audio_duration(audio)
112
+
113
+ # First, get the timestamps for the original audio
114
+ if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()):
115
+ merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
116
+ else:
117
+ timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
118
+ merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
119
+
120
+ # We must make sure the whisper model is downloaded
121
+ if (len(gpu_devices) > 1):
122
+ whisperCallable.model_container.ensure_downloaded()
123
+
124
+ # Split into a list for each device
125
+ # TODO: Split by time instead of by number of chunks
126
+ merged_split = list(self._split(merged, len(gpu_devices)))
127
+
128
+ # Parameters that will be passed to the transcribe function
129
+ parameters = []
130
+ segment_index = config.initial_segment_index
131
+
132
+ processing_manager = multiprocessing.Manager()
133
+ progress_queue = processing_manager.Queue()
134
+
135
+ for i in range(len(gpu_devices)):
136
+ # Note that device_segment_list can be empty. But we will still create a process for it,
137
+ # as otherwise we run the risk of assigning the same device to multiple processes.
138
+ device_segment_list = list(merged_split[i]) if i < len(merged_split) else []
139
+ device_id = gpu_devices[i]
140
+
141
+ print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
142
+
143
+ # Create a new config with the given device ID
144
+ device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
145
+ segment_index += len(device_segment_list)
146
+
147
+ progress_listener_to_queue = _ProgressListenerToQueue(progress_queue)
148
+ parameters.append([audio, whisperCallable, device_config, progress_listener_to_queue]);
149
+
150
+ merged = {
151
+ 'text': '',
152
+ 'segments': [],
153
+ 'language': None
154
+ }
155
+
156
+ created_context = False
157
+
158
+ perf_start_gpu = time.perf_counter()
159
+
160
+ # Spawn a separate process for each device
161
+ try:
162
+ if (gpu_parallel_context is None):
163
+ gpu_parallel_context = ParallelContext(len(gpu_devices))
164
+ created_context = True
165
+
166
+ # Get a pool of processes
167
+ pool = gpu_parallel_context.get_pool()
168
+
169
+ # Run the transcription in parallel
170
+ results_async = pool.starmap_async(self.transcribe, parameters)
171
+ total_progress = 0
172
+
173
+ while not results_async.ready():
174
+ try:
175
+ delta = progress_queue.get(timeout=5) # Set a timeout of 5 seconds
176
+ except Empty:
177
+ continue
178
+
179
+ total_progress += delta
180
+ if progress_listener is not None:
181
+ progress_listener.on_progress(total_progress, total_duration)
182
+
183
+ results = results_async.get()
184
+
185
+ # Call the finished callback
186
+ if progress_listener is not None:
187
+ progress_listener.on_finished()
188
+
189
+ for result in results:
190
+ # Merge the results
191
+ if (result['text'] is not None):
192
+ merged['text'] += result['text']
193
+ if (result['segments'] is not None):
194
+ merged['segments'].extend(result['segments'])
195
+ if (result['language'] is not None):
196
+ merged['language'] = result['language']
197
+
198
+ finally:
199
+ # Return the pool to the context
200
+ if (gpu_parallel_context is not None):
201
+ gpu_parallel_context.return_pool(pool)
202
+ # Always close the context if we created it
203
+ if (created_context):
204
+ gpu_parallel_context.close()
205
+
206
+ perf_end_gpu = time.perf_counter()
207
+ print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")
208
+
209
+ return merged
210
+
211
+ def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
212
+ cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
213
+ parameters = []
214
+
215
+ chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
216
+ chunk_start = 0
217
+ cpu_device_id = 0
218
+
219
+ perf_start_time = time.perf_counter()
220
+
221
+ # Create chunks that will be processed on the CPU
222
+ while (chunk_start < total_duration):
223
+ chunk_end = min(chunk_start + chunk_size, total_duration)
224
+
225
+ if (chunk_end - chunk_start < 1):
226
+ # No need to process chunks that are less than 1 second
227
+ break
228
+
229
+ print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
230
+ str(chunk_end) + " on CPU device " + str(cpu_device_id))
231
+ parameters.append([audio, config, chunk_start, chunk_end]);
232
+
233
+ cpu_device_id += 1
234
+ chunk_start = chunk_end
235
+
236
+ created_context = False
237
+
238
+ # Spawn a separate process for each device
239
+ try:
240
+ if (cpu_parallel_context is None):
241
+ cpu_parallel_context = ParallelContext(cpu_device_count)
242
+ created_context = True
243
+
244
+ # Get a pool of processes
245
+ pool = cpu_parallel_context.get_pool()
246
+
247
+ # Run the transcription in parallel. Note that transcription must be picklable.
248
+ results = pool.starmap(transcription.get_transcribe_timestamps, parameters)
249
+
250
+ timestamps = []
251
+
252
+ # Flatten the results
253
+ for result in results:
254
+ timestamps.extend(result)
255
+
256
+ merged = transcription.get_merged_timestamps(timestamps, config, total_duration)
257
+
258
+ perf_end_time = time.perf_counter()
259
+ print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
260
+ return merged
261
+
262
+ finally:
263
+ # Return the pool to the context
264
+ if (cpu_parallel_context is not None):
265
+ cpu_parallel_context.return_pool(pool)
266
+ # Always close the context if we created it
267
+ if (created_context):
268
+ cpu_parallel_context.close()
269
+
270
+ def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
271
+ return []
272
+
273
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
274
+ # Override timestamps that will be processed
275
+ if (config.override_timestamps is not None):
276
+ print("(get_merged_timestamps) Using override timestamps of size " + str(len(config.override_timestamps)))
277
+ return config.override_timestamps
278
+ return super().get_merged_timestamps(timestamps, config, total_duration)
279
+
280
+ def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: ParallelTranscriptionConfig,
281
+ progressListener: ProgressListener = None):
282
+ # Override device ID the first time
283
+ if (os.environ.get("INITIALIZED", None) is None):
284
+ os.environ["INITIALIZED"] = "1"
285
+
286
+ # Note that this may be None if the user didn't specify a device. In that case, Whisper will
287
+ # just use the default GPU device.
288
+ if (config.device_id is not None):
289
+ print("Using device " + config.device_id)
290
+ os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
291
+
292
+ return super().transcribe(audio, whisperCallable, config, progressListener)
293
+
294
+ def _split(self, a, n):
295
+ """Split a list into n approximately equal parts."""
296
+ k, m = divmod(len(a), n)
297
+ return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
298
+
src/whisper/abstractWhisperContainer.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import List
3
+
4
+ from src.config import ModelConfig, VadInitialPromptMode
5
+
6
+ from src.hooks.progressListener import ProgressListener
7
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
8
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
9
+
10
+ class AbstractWhisperCallback:
11
+ def __init__(self):
12
+ self.__prompt_mode_gpt = None
13
+
14
+ @abc.abstractmethod
15
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
16
+ """
17
+ Peform the transcription of the given audio file or data.
18
+
19
+ Parameters
20
+ ----------
21
+ audio: Union[str, np.ndarray, torch.Tensor]
22
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
23
+ segment_index: int
24
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
25
+ task: str
26
+ The task - either translate or transcribe.
27
+ progress_listener: ProgressListener
28
+ A callback to receive progress updates.
29
+ """
30
+ raise NotImplementedError()
31
+
32
+ class AbstractWhisperContainer:
33
+ def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
34
+ download_root: str = None,
35
+ cache: ModelCache = None, models: List[ModelConfig] = []):
36
+ self.model_name = model_name
37
+ self.device = device
38
+ self.compute_type = compute_type
39
+ self.download_root = download_root
40
+ self.cache = cache
41
+
42
+ # Will be created on demand
43
+ self.model = None
44
+
45
+ # List of known models
46
+ self.models = models
47
+
48
+ def get_model(self):
49
+ if self.model is None:
50
+
51
+ if (self.cache is None):
52
+ self.model = self._create_model()
53
+ else:
54
+ model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
55
+ self.model = self.cache.get(model_key, self._create_model)
56
+ return self.model
57
+
58
+ @abc.abstractmethod
59
+ def _create_model(self):
60
+ raise NotImplementedError()
61
+
62
+ def ensure_downloaded(self):
63
+ pass
64
+
65
+ @abc.abstractmethod
66
+ def create_callback(self, language: str = None, task: str = None,
67
+ prompt_strategy: AbstractPromptStrategy = None,
68
+ **decodeOptions: dict) -> AbstractWhisperCallback:
69
+ """
70
+ Create a WhisperCallback object that can be used to transcript audio files.
71
+
72
+ Parameters
73
+ ----------
74
+ language: str
75
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
76
+ task: str
77
+ The task - either translate or transcribe.
78
+ prompt_strategy: AbstractPromptStrategy
79
+ The prompt strategy to use for the transcription.
80
+ decodeOptions: dict
81
+ Additional options to pass to the decoder. Must be pickleable.
82
+
83
+ Returns
84
+ -------
85
+ A WhisperCallback object.
86
+ """
87
+ raise NotImplementedError()
88
+
89
+ # This is required for multiprocessing
90
+ def __getstate__(self):
91
+ return {
92
+ "model_name": self.model_name,
93
+ "device": self.device,
94
+ "download_root": self.download_root,
95
+ "models": self.models,
96
+ "compute_type": self.compute_type
97
+ }
98
+
99
+ def __setstate__(self, state):
100
+ self.model_name = state["model_name"]
101
+ self.device = state["device"]
102
+ self.download_root = state["download_root"]
103
+ self.models = state["models"]
104
+ self.compute_type = state["compute_type"]
105
+ self.model = None
106
+ # Depickled objects must use the global cache
107
+ self.cache = GLOBAL_MODEL_CACHE
src/whisper/fasterWhisperContainer.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Union
3
+
4
+ from faster_whisper import WhisperModel, download_model
5
+ from src.config import ModelConfig, VadInitialPromptMode
6
+ from src.hooks.progressListener import ProgressListener
7
+ from src.languages import get_language_from_name
8
+ from src.modelCache import ModelCache
9
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
10
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
11
+ from src.utils import format_timestamp
12
+
13
+ class FasterWhisperContainer(AbstractWhisperContainer):
14
+ def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
15
+ download_root: str = None,
16
+ cache: ModelCache = None, models: List[ModelConfig] = []):
17
+ super().__init__(model_name, device, compute_type, download_root, cache, models)
18
+
19
+ def ensure_downloaded(self):
20
+ """
21
+ Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
22
+ passing the container to a subprocess.
23
+ """
24
+ model_config = self._get_model_config()
25
+
26
+ if os.path.isdir(model_config.url):
27
+ model_config.path = model_config.url
28
+ else:
29
+ model_config.path = download_model(model_config.url, output_dir=self.download_root)
30
+
31
+ def _get_model_config(self) -> ModelConfig:
32
+ """
33
+ Get the model configuration for the model.
34
+ """
35
+ for model in self.models:
36
+ if model.name == self.model_name:
37
+ return model
38
+ return None
39
+
40
+ def _create_model(self):
41
+ print("Loading faster whisper model " + self.model_name + " for device " + str(self.device))
42
+ model_config = self._get_model_config()
43
+ model_url = model_config.url
44
+
45
+ if model_config.type == "whisper":
46
+ if model_url not in ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]:
47
+ raise Exception("FasterWhisperContainer does not yet support Whisper models. Use ct2-transformers-converter to convert the model to a faster-whisper model.")
48
+ if model_url == "large":
49
+ # large is an alias for large-v1
50
+ model_url = "large-v1"
51
+
52
+ device = self.device
53
+
54
+ if (device is None):
55
+ device = "auto"
56
+
57
+ model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
58
+ return model
59
+
60
+ def create_callback(self, language: str = None, task: str = None,
61
+ prompt_strategy: AbstractPromptStrategy = None,
62
+ **decodeOptions: dict) -> AbstractWhisperCallback:
63
+ """
64
+ Create a WhisperCallback object that can be used to transcript audio files.
65
+
66
+ Parameters
67
+ ----------
68
+ language: str
69
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
70
+ task: str
71
+ The task - either translate or transcribe.
72
+ prompt_strategy: AbstractPromptStrategy
73
+ The prompt strategy to use. If not specified, the prompt from Whisper will be used.
74
+ decodeOptions: dict
75
+ Additional options to pass to the decoder. Must be pickleable.
76
+
77
+ Returns
78
+ -------
79
+ A WhisperCallback object.
80
+ """
81
+ return FasterWhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
82
+
83
+ class FasterWhisperCallback(AbstractWhisperCallback):
84
+ def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
85
+ prompt_strategy: AbstractPromptStrategy = None,
86
+ **decodeOptions: dict):
87
+ self.model_container = model_container
88
+ self.language = language
89
+ self.task = task
90
+ self.prompt_strategy = prompt_strategy
91
+ self.decodeOptions = decodeOptions
92
+
93
+ self._printed_warning = False
94
+
95
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
96
+ """
97
+ Peform the transcription of the given audio file or data.
98
+
99
+ Parameters
100
+ ----------
101
+ audio: Union[str, np.ndarray, torch.Tensor]
102
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
103
+ segment_index: int
104
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
105
+ task: str
106
+ The task - either translate or transcribe.
107
+ progress_listener: ProgressListener
108
+ A callback to receive progress updates.
109
+ """
110
+ model: WhisperModel = self.model_container.get_model()
111
+ language_code = self._lookup_language_code(self.language) if self.language else None
112
+
113
+ # Copy decode options and remove options that are not supported by faster-whisper
114
+ decodeOptions = self.decodeOptions.copy()
115
+ verbose = decodeOptions.pop("verbose", None)
116
+
117
+ logprob_threshold = decodeOptions.pop("logprob_threshold", None)
118
+
119
+ patience = decodeOptions.pop("patience", None)
120
+ length_penalty = decodeOptions.pop("length_penalty", None)
121
+ suppress_tokens = decodeOptions.pop("suppress_tokens", None)
122
+
123
+ if (decodeOptions.pop("fp16", None) is not None):
124
+ if not self._printed_warning:
125
+ print("WARNING: fp16 option is ignored by faster-whisper - use compute_type instead.")
126
+ self._printed_warning = True
127
+
128
+ # Fix up decode options
129
+ if (logprob_threshold is not None):
130
+ decodeOptions["log_prob_threshold"] = logprob_threshold
131
+
132
+ decodeOptions["patience"] = float(patience) if patience is not None else 1.0
133
+ decodeOptions["length_penalty"] = float(length_penalty) if length_penalty is not None else 1.0
134
+
135
+ # See if supress_tokens is a string - if so, convert it to a list of ints
136
+ decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
137
+
138
+ initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
139
+ if self.prompt_strategy else prompt
140
+
141
+ segments_generator, info = model.transcribe(audio, \
142
+ language=language_code if language_code else detected_language, task=self.task, \
143
+ initial_prompt=initial_prompt, \
144
+ **decodeOptions
145
+ )
146
+
147
+ segments = []
148
+
149
+ for segment in segments_generator:
150
+ segments.append(segment)
151
+
152
+ if progress_listener is not None:
153
+ progress_listener.on_progress(segment.end, info.duration)
154
+ if verbose:
155
+ print("[{}->{}] {}".format(format_timestamp(segment.start, True), format_timestamp(segment.end, True),
156
+ segment.text))
157
+
158
+ text = " ".join([segment.text for segment in segments])
159
+
160
+ # Convert the segments to a format that is easier to serialize
161
+ whisper_segments = [{
162
+ "text": segment.text,
163
+ "start": segment.start,
164
+ "end": segment.end,
165
+
166
+ # Extra fields added by faster-whisper
167
+ "words": [{
168
+ "start": word.start,
169
+ "end": word.end,
170
+ "word": word.word,
171
+ "probability": word.probability
172
+ } for word in (segment.words if segment.words is not None else []) ]
173
+ } for segment in segments]
174
+
175
+ result = {
176
+ "segments": whisper_segments,
177
+ "text": text,
178
+ "language": info.language if info else None,
179
+
180
+ # Extra fields added by faster-whisper
181
+ "language_probability": info.language_probability if info else None,
182
+ "duration": info.duration if info else None
183
+ }
184
+
185
+ # If we have a prompt strategy, we need to increment the current prompt
186
+ if self.prompt_strategy:
187
+ self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
188
+
189
+ if progress_listener is not None:
190
+ progress_listener.on_finished()
191
+ return result
192
+
193
+ def _split_suppress_tokens(self, suppress_tokens: Union[str, List[int]]):
194
+ if (suppress_tokens is None):
195
+ return None
196
+ if (isinstance(suppress_tokens, list)):
197
+ return suppress_tokens
198
+
199
+ return [int(token) for token in suppress_tokens.split(",")]
200
+
201
+ def _lookup_language_code(self, language: str):
202
+ language = get_language_from_name(language)
203
+
204
+ if language is None:
205
+ raise ValueError("Invalid language: " + language)
206
+
207
+ return language.code
src/whisper/whisperContainer.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # External programs
2
+ import abc
3
+ import os
4
+ import sys
5
+ from typing import List
6
+ from urllib.parse import urlparse
7
+ import torch
8
+ import urllib3
9
+ from src.hooks.progressListener import ProgressListener
10
+
11
+ import whisper
12
+ from whisper import Whisper
13
+
14
+ from src.config import ModelConfig, VadInitialPromptMode
15
+ from src.hooks.whisperProgressHook import create_progress_listener_handle
16
+
17
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
18
+ from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
19
+ from src.utils import download_file
20
+ from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
21
+
22
+ class WhisperContainer(AbstractWhisperContainer):
23
+ def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
24
+ download_root: str = None,
25
+ cache: ModelCache = None, models: List[ModelConfig] = []):
26
+ if device is None:
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ super().__init__(model_name, device, compute_type, download_root, cache, models)
29
+
30
+ def ensure_downloaded(self):
31
+ """
32
+ Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
33
+ passing the container to a subprocess.
34
+ """
35
+ # Warning: Using private API here
36
+ try:
37
+ root_dir = self.download_root
38
+ model_config = self._get_model_config()
39
+
40
+ if root_dir is None:
41
+ root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
42
+
43
+ if self.model_name in whisper._MODELS:
44
+ whisper._download(whisper._MODELS[self.model_name], root_dir, False)
45
+ else:
46
+ # If the model is not in the official list, see if it needs to be downloaded
47
+ model_config.download_url(root_dir)
48
+ return True
49
+
50
+ except Exception as e:
51
+ # Given that the API is private, it could change at any time. We don't want to crash the program
52
+ print("Error pre-downloading model: " + str(e))
53
+ return False
54
+
55
+ def _get_model_config(self) -> ModelConfig:
56
+ """
57
+ Get the model configuration for the model.
58
+ """
59
+ for model in self.models:
60
+ if model.name == self.model_name:
61
+ return model
62
+ return None
63
+
64
+ def _create_model(self):
65
+ print("Loading whisper model " + self.model_name)
66
+ model_config = self._get_model_config()
67
+
68
+ # Note that the model will not be downloaded in the case of an official Whisper model
69
+ model_path = self._get_model_path(model_config, self.download_root)
70
+
71
+ return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
72
+
73
+ def create_callback(self, language: str = None, task: str = None,
74
+ prompt_strategy: AbstractPromptStrategy = None,
75
+ **decodeOptions: dict) -> AbstractWhisperCallback:
76
+ """
77
+ Create a WhisperCallback object that can be used to transcript audio files.
78
+
79
+ Parameters
80
+ ----------
81
+ language: str
82
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
83
+ task: str
84
+ The task - either translate or transcribe.
85
+ prompt_strategy: AbstractPromptStrategy
86
+ The prompt strategy to use. If not specified, the prompt from Whisper will be used.
87
+ decodeOptions: dict
88
+ Additional options to pass to the decoder. Must be pickleable.
89
+
90
+ Returns
91
+ -------
92
+ A WhisperCallback object.
93
+ """
94
+ return WhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
95
+
96
+ def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
97
+ from src.conversion.hf_converter import convert_hf_whisper
98
+ """
99
+ Download the model.
100
+
101
+ Parameters
102
+ ----------
103
+ model_config: ModelConfig
104
+ The model configuration.
105
+ """
106
+ # See if path is already set
107
+ if model_config.path is not None:
108
+ return model_config.path
109
+
110
+ if root_dir is None:
111
+ root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
112
+
113
+ model_type = model_config.type.lower() if model_config.type is not None else "whisper"
114
+
115
+ if model_type in ["huggingface", "hf"]:
116
+ model_config.path = model_config.url
117
+ destination_target = os.path.join(root_dir, model_config.name + ".pt")
118
+
119
+ # Convert from HuggingFace format to Whisper format
120
+ if os.path.exists(destination_target):
121
+ print(f"File {destination_target} already exists, skipping conversion")
122
+ else:
123
+ print("Saving HuggingFace model in Whisper format to " + destination_target)
124
+ convert_hf_whisper(model_config.url, destination_target)
125
+
126
+ model_config.path = destination_target
127
+
128
+ elif model_type in ["whisper", "w"]:
129
+ model_config.path = model_config.url
130
+
131
+ # See if URL is just a file
132
+ if model_config.url in whisper._MODELS:
133
+ # No need to download anything - Whisper will handle it
134
+ model_config.path = model_config.url
135
+ elif model_config.url.startswith("file://"):
136
+ # Get file path
137
+ model_config.path = urlparse(model_config.url).path
138
+ # See if it is an URL
139
+ elif model_config.url.startswith("http://") or model_config.url.startswith("https://"):
140
+ # Extension (or file name)
141
+ extension = os.path.splitext(model_config.url)[-1]
142
+ download_target = os.path.join(root_dir, model_config.name + extension)
143
+
144
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
145
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
146
+
147
+ if not os.path.isfile(download_target):
148
+ download_file(model_config.url, download_target)
149
+ else:
150
+ print(f"File {download_target} already exists, skipping download")
151
+
152
+ model_config.path = download_target
153
+ # Must be a local file
154
+ else:
155
+ model_config.path = model_config.url
156
+
157
+ else:
158
+ raise ValueError(f"Unknown model type {model_type}")
159
+
160
+ return model_config.path
161
+
162
+ class WhisperCallback(AbstractWhisperCallback):
163
+ def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None,
164
+ prompt_strategy: AbstractPromptStrategy = None,
165
+ **decodeOptions: dict):
166
+ self.model_container = model_container
167
+ self.language = language
168
+ self.task = task
169
+ self.prompt_strategy = prompt_strategy
170
+
171
+ self.decodeOptions = decodeOptions
172
+
173
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
174
+ """
175
+ Peform the transcription of the given audio file or data.
176
+
177
+ Parameters
178
+ ----------
179
+ audio: Union[str, np.ndarray, torch.Tensor]
180
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
181
+ segment_index: int
182
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
183
+ task: str
184
+ The task - either translate or transcribe.
185
+ progress_listener: ProgressListener
186
+ A callback to receive progress updates.
187
+ """
188
+ model = self.model_container.get_model()
189
+
190
+ if progress_listener is not None:
191
+ with create_progress_listener_handle(progress_listener):
192
+ return self._transcribe(model, audio, segment_index, prompt, detected_language)
193
+ else:
194
+ return self._transcribe(model, audio, segment_index, prompt, detected_language)
195
+
196
+ def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
197
+ decodeOptions = self.decodeOptions.copy()
198
+
199
+ # Add fp16
200
+ if self.model_container.compute_type in ["fp16", "float16"]:
201
+ decodeOptions["fp16"] = True
202
+
203
+ initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
204
+ if self.prompt_strategy else prompt
205
+
206
+ result = model.transcribe(audio, \
207
+ language=self.language if self.language else detected_language, task=self.task, \
208
+ initial_prompt=initial_prompt, \
209
+ **decodeOptions
210
+ )
211
+
212
+ # If we have a prompt strategy, we need to increment the current prompt
213
+ if self.prompt_strategy:
214
+ self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
215
+
216
+ return result
src/whisper/whisperFactory.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from src import modelCache
3
+ from src.config import ModelConfig
4
+ from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
5
+
6
+ def create_whisper_container(whisper_implementation: str,
7
+ model_name: str, device: str = None, compute_type: str = "float16",
8
+ download_root: str = None,
9
+ cache: modelCache = None, models: List[ModelConfig] = []) -> AbstractWhisperContainer:
10
+ print("Creating whisper container for " + whisper_implementation)
11
+
12
+ if (whisper_implementation == "whisper"):
13
+ from src.whisper.whisperContainer import WhisperContainer
14
+ return WhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
15
+ elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
16
+ from src.whisper.fasterWhisperContainer import FasterWhisperContainer
17
+ return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
18
+ else:
19
+ raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
tests/segments_test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import unittest
3
+
4
+ sys.path.append('../whisper-webui')
5
+
6
+ from src.segments import merge_timestamps
7
+
8
+ class TestSegments(unittest.TestCase):
9
+ def __init__(self, *args, **kwargs):
10
+ super(TestSegments, self).__init__(*args, **kwargs)
11
+
12
+ def test_merge_segments(self):
13
+ segments = [
14
+ {'start': 10.0, 'end': 20.0},
15
+ {'start': 22.0, 'end': 27.0},
16
+ {'start': 31.0, 'end': 35.0},
17
+ {'start': 45.0, 'end': 60.0},
18
+ {'start': 61.0, 'end': 65.0},
19
+ {'start': 68.0, 'end': 98.0},
20
+ {'start': 100.0, 'end': 102.0},
21
+ {'start': 110.0, 'end': 112.0}
22
+ ]
23
+
24
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
25
+
26
+ self.assertListEqual(result, [
27
+ {'start': 9.0, 'end': 36.0},
28
+ {'start': 44.0, 'end': 66.0},
29
+ {'start': 67.0, 'end': 99.0},
30
+ {'start': 99.0, 'end': 103.0},
31
+ {'start': 109.0, 'end': 113.0}
32
+ ])
33
+
34
+ def test_overlap_next(self):
35
+ segments = [
36
+ {'start': 5.0, 'end': 39.182},
37
+ {'start': 39.986, 'end': 40.814}
38
+ ]
39
+
40
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
41
+
42
+ self.assertListEqual(result, [
43
+ {'start': 4.0, 'end': 39.584},
44
+ {'start': 39.584, 'end': 41.814}
45
+ ])
46
+
47
+ if __name__ == '__main__':
48
+ unittest.main()
tests/vad_test.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ import unittest
3
+ import numpy as np
4
+ import sys
5
+
6
+ sys.path.append('../whisper-webui')
7
+
8
+ from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
9
+
10
+ class TestVad(unittest.TestCase):
11
+ def __init__(self, *args, **kwargs):
12
+ super(TestVad, self).__init__(*args, **kwargs)
13
+ self.transcribe_calls = []
14
+
15
+ def test_transcript(self):
16
+ mock = MockVadTranscription()
17
+
18
+ self.transcribe_calls.clear()
19
+ result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
20
+
21
+ self.assertListEqual(self.transcribe_calls, [
22
+ [30, 30],
23
+ [100, 100]
24
+ ])
25
+
26
+ self.assertListEqual(result['segments'],
27
+ [{'end': 50.0, 'start': 40.0, 'text': 'Hello world '},
28
+ {'end': 120.0, 'start': 110.0, 'text': 'Hello world '}]
29
+ )
30
+
31
+ def transcribe_segments(self, segment):
32
+ self.transcribe_calls.append(segment.tolist())
33
+
34
+ # Dummy text
35
+ return {
36
+ 'text': "Hello world ",
37
+ 'segments': [
38
+ {
39
+ "start": 10.0,
40
+ "end": 20.0,
41
+ "text": "Hello world "
42
+ }
43
+ ],
44
+ 'language': ""
45
+ }
46
+
47
+ class MockVadTranscription(AbstractTranscription):
48
+ def __init__(self):
49
+ super().__init__()
50
+
51
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
52
+ start_time_seconds = float(start_time.removesuffix("s"))
53
+ duration_seconds = float(duration.removesuffix("s"))
54
+
55
+ # For mocking, this just returns a simple numppy array
56
+ return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
57
+
58
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, duration: float):
59
+ result = []
60
+
61
+ result.append( { 'start': 30, 'end': 60 } )
62
+ result.append( { 'start': 100, 'end': 200 } )
63
+ return result
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()