Spaces:
Runtime error
Runtime error
Commit
·
9e8682f
0
Parent(s):
Duplicate from aadnk/whisper-webui
Browse filesCo-authored-by: Kristian Stangeland <[email protected]>
- .gitattributes +31 -0
- .gitignore +5 -0
- README.md +136 -0
- app-local.py +3 -0
- app-network.py +3 -0
- app-shared.py +3 -0
- app.py +339 -0
- cli.py +121 -0
- dockerfile +20 -0
- docs/options.md +78 -0
- requirements.txt +6 -0
- src/__init__.py +0 -0
- src/download.py +72 -0
- src/modelCache.py +17 -0
- src/segments.py +55 -0
- src/utils.py +115 -0
- src/vad.py +527 -0
- src/vadParallel.py +251 -0
- src/whisperContainer.py +106 -0
- tests/segments_test.py +48 -0
- tests/vad_test.py +66 -0
.gitattributes
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
23 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
flagged/
|
4 |
+
*.py[cod]
|
5 |
+
*$py.class
|
README.md
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Whisper Webui
|
3 |
+
emoji: ⚡
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: purple
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.3.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
duplicated_from: aadnk/whisper-webui
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
15 |
+
|
16 |
+
# Running Locally
|
17 |
+
|
18 |
+
To run this program locally, first install Python 3.9+ and Git. Then install PyTorch 1.10.1+ and all the other dependencies:
|
19 |
+
```
|
20 |
+
pip install -r requirements.txt
|
21 |
+
```
|
22 |
+
|
23 |
+
Finally, run the full version (no audio length restrictions) of the app with parallel CPU/GPU enabled:
|
24 |
+
```
|
25 |
+
python app.py --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
|
26 |
+
```
|
27 |
+
|
28 |
+
You can also run the CLI interface, which is similar to Whisper's own CLI but also supports the following additional arguments:
|
29 |
+
```
|
30 |
+
python cli.py \
|
31 |
+
[--vad {none,silero-vad,silero-vad-skip-gaps,silero-vad-expand-into-gaps,periodic-vad}] \
|
32 |
+
[--vad_merge_window VAD_MERGE_WINDOW] \
|
33 |
+
[--vad_max_merge_size VAD_MAX_MERGE_SIZE] \
|
34 |
+
[--vad_padding VAD_PADDING] \
|
35 |
+
[--vad_prompt_window VAD_PROMPT_WINDOW] \
|
36 |
+
[--vad_cpu_cores NUMBER_OF_CORES] \
|
37 |
+
[--vad_parallel_devices COMMA_DELIMITED_DEVICES] \
|
38 |
+
[--auto_parallel BOOLEAN]
|
39 |
+
```
|
40 |
+
You may also use URLs in addition to file paths as input.
|
41 |
+
```
|
42 |
+
python cli.py --model large --vad silero-vad --language Japanese "https://www.youtube.com/watch?v=4cICErqqRSM"
|
43 |
+
```
|
44 |
+
|
45 |
+
## Parallel Execution
|
46 |
+
|
47 |
+
You can also run either the Web-UI or the CLI on multiple GPUs in parallel, using the `vad_parallel_devices` option. This takes a comma-delimited list of
|
48 |
+
device IDs (0, 1, etc.) that Whisper should be distributed to and run on concurrently:
|
49 |
+
```
|
50 |
+
python cli.py --model large --vad silero-vad --language Japanese \
|
51 |
+
--vad_parallel_devices 0,1 "https://www.youtube.com/watch?v=4cICErqqRSM"
|
52 |
+
```
|
53 |
+
|
54 |
+
Note that this requires a VAD to function properly, otherwise only the first GPU will be used. Though you could use `periodic-vad` to avoid taking the hit
|
55 |
+
of running Silero-Vad, at a slight cost to accuracy.
|
56 |
+
|
57 |
+
This is achieved by creating N child processes (where N is the number of selected devices), where Whisper is run concurrently. In `app.py`, you can also
|
58 |
+
set the `vad_process_timeout` option. This configures the number of seconds until a process is killed due to inactivity, freeing RAM and video memory.
|
59 |
+
The default value is 30 minutes.
|
60 |
+
|
61 |
+
```
|
62 |
+
python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600
|
63 |
+
```
|
64 |
+
|
65 |
+
To execute the Silero VAD itself in parallel, use the `vad_cpu_cores` option:
|
66 |
+
```
|
67 |
+
python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600 --vad_cpu_cores 4
|
68 |
+
```
|
69 |
+
|
70 |
+
You may also use `vad_process_timeout` with a single device (`--vad_parallel_devices 0`), if you prefer to always free video memory after a period of time.
|
71 |
+
|
72 |
+
### Auto Parallel
|
73 |
+
|
74 |
+
You can also set `auto_parallel` to `True`. This will set `vad_parallel_devices` to use all the GPU devices on the system, and `vad_cpu_cores` to be equal to the number of
|
75 |
+
cores (up to 8):
|
76 |
+
```
|
77 |
+
python app.py --input_audio_max_duration -1 --auto_parallel True
|
78 |
+
```
|
79 |
+
|
80 |
+
# Docker
|
81 |
+
|
82 |
+
To run it in Docker, first install Docker and optionally the NVIDIA Container Toolkit in order to use the GPU.
|
83 |
+
Then either use the GitLab hosted container below, or check out this repository and build an image:
|
84 |
+
```
|
85 |
+
sudo docker build -t whisper-webui:1 .
|
86 |
+
```
|
87 |
+
|
88 |
+
You can then start the WebUI with GPU support like so:
|
89 |
+
```
|
90 |
+
sudo docker run -d --gpus=all -p 7860:7860 whisper-webui:1
|
91 |
+
```
|
92 |
+
|
93 |
+
Leave out "--gpus=all" if you don't have access to a GPU with enough memory, and are fine with running it on the CPU only:
|
94 |
+
```
|
95 |
+
sudo docker run -d -p 7860:7860 whisper-webui:1
|
96 |
+
```
|
97 |
+
|
98 |
+
# GitLab Docker Registry
|
99 |
+
|
100 |
+
This Docker container is also hosted on GitLab:
|
101 |
+
|
102 |
+
```
|
103 |
+
sudo docker run -d --gpus=all -p 7860:7860 registry.gitlab.com/aadnk/whisper-webui:latest
|
104 |
+
```
|
105 |
+
|
106 |
+
## Custom Arguments
|
107 |
+
|
108 |
+
You can also pass custom arguments to `app.py` in the Docker container, for instance to be able to use all the GPUs in parallel:
|
109 |
+
```
|
110 |
+
sudo docker run -d --gpus all -p 7860:7860 \
|
111 |
+
--mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
|
112 |
+
--restart=on-failure:15 registry.gitlab.com/aadnk/whisper-webui:latest \
|
113 |
+
app.py --input_audio_max_duration -1 --server_name 0.0.0.0 --vad_parallel_devices 0,1 \
|
114 |
+
--default_vad silero-vad --default_model_name large
|
115 |
+
```
|
116 |
+
|
117 |
+
You can also call `cli.py` the same way:
|
118 |
+
```
|
119 |
+
sudo docker run --gpus all \
|
120 |
+
--mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
|
121 |
+
--mount type=bind,source=${PWD},target=/app/data \
|
122 |
+
registry.gitlab.com/aadnk/whisper-webui:latest \
|
123 |
+
cli.py --model large --vad_parallel_devices 0,1 --vad silero-vad \
|
124 |
+
--output_dir /app/data /app/data/YOUR-FILE-HERE.mp4
|
125 |
+
```
|
126 |
+
|
127 |
+
## Caching
|
128 |
+
|
129 |
+
Note that the models themselves are currently not included in the Docker images, and will be downloaded on demand.
|
130 |
+
To avoid this, bind the directory /root/.cache/whisper to some directory on the host (for instance /home/administrator/.cache/whisper), where you can (optionally)
|
131 |
+
prepopulate the directory with the different Whisper models.
|
132 |
+
```
|
133 |
+
sudo docker run -d --gpus=all -p 7860:7860 \
|
134 |
+
--mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
|
135 |
+
registry.gitlab.com/aadnk/whisper-webui:latest
|
136 |
+
```
|
app-local.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Run the app with no audio file restrictions
|
2 |
+
from app import create_ui
|
3 |
+
create_ui(-1)  # -1 disables the maximum input audio duration limit
|
app-network.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Run the app with no audio file restrictions, and make it available on the network
|
2 |
+
from app import create_ui
|
3 |
+
create_ui(-1, server_name="0.0.0.0")  # -1 disables the audio duration limit; server_name is the host/IP the server binds to
|
app-shared.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Run the app with no audio file restrictions, and launch it with sharing enabled (share=True)
|
2 |
+
from app import create_ui
|
3 |
+
create_ui(-1, share=True)  # -1 disables the audio duration limit; share is forwarded to the Gradio launch() call
|
app.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Iterator
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
from io import StringIO
|
6 |
+
import os
|
7 |
+
import pathlib
|
8 |
+
import tempfile
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from src.modelCache import ModelCache
|
12 |
+
from src.vadParallel import ParallelContext, ParallelTranscription
|
13 |
+
|
14 |
+
# External programs
|
15 |
+
import ffmpeg
|
16 |
+
|
17 |
+
# UI
|
18 |
+
import gradio as gr
|
19 |
+
|
20 |
+
from src.download import ExceededMaximumDuration, download_url
|
21 |
+
from src.utils import slugify, write_srt, write_vtt
|
22 |
+
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
23 |
+
from src.whisperContainer import WhisperContainer
|
24 |
+
|
25 |
+
# Limitations (set to -1 to disable)
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600  # seconds

# Whether or not to automatically delete all uploaded files, to save disk space
DELETE_UPLOADED_FILES = True

# Gradio seems to truncate files without keeping the extension, so we need to
# truncate the file prefix ourselves
MAX_FILE_PREFIX_LENGTH = 17

# Limit auto_parallel to a certain number of CPUs (specify vad_cpu_cores to get a higher number)
MAX_AUTO_CPU_CORES = 8

# Language names offered in the UI drop-down and accepted by the CLI's --language option
LANGUAGES = [
    "English", "Chinese", "German", "Spanish", "Russian", "Korean",
    "French", "Japanese", "Portuguese", "Turkish", "Polish", "Catalan",
    "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi",
    "Finnish", "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay",
    "Czech", "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
    "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin",
    "Maori", "Malayalam", "Welsh", "Slovak", "Telugu", "Persian",
    "Latvian", "Bengali", "Serbian", "Azerbaijani", "Slovenian",
    "Kannada", "Estonian", "Macedonian", "Breton", "Basque", "Icelandic",
    "Armenian", "Nepali", "Mongolian", "Bosnian", "Kazakh", "Albanian",
    "Swahili", "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
    "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan", "Georgian",
    "Belarusian", "Tajik", "Sindhi", "Gujarati", "Amharic", "Yiddish",
    "Lao", "Uzbek", "Faroese", "Haitian Creole", "Pashto", "Turkmen",
    "Nynorsk", "Maltese", "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan",
    "Tagalog", "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
    "Hausa", "Bashkir", "Javanese", "Sundanese",
]
|
56 |
+
|
57 |
+
class WhisperTranscriber:
    """Runs Whisper on an audio source, optionally via a VAD and in parallel processes.

    The instance keeps a model cache, a lazily created Silero VAD model and
    the parallel contexts (process pools) used by `process_vad`. Call `close()`
    when done to release them.
    """

    def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None, vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES):
        """Create a transcriber.

        Args:
            input_audio_max_duration: Maximum accepted audio length in seconds; values <= 0 disable the check.
            vad_process_timeout: Passed to the parallel contexts as their auto-cleanup timeout (seconds).
            vad_cpu_cores: Number of CPU processes to use for the VAD.
            delete_uploaded_files: Whether the source file is removed after a web UI transcription.
        """
        self.model_cache = ModelCache()
        self.parallel_device_list = None
        self.gpu_parallel_context = None
        self.cpu_parallel_context = None
        self.vad_process_timeout = vad_process_timeout
        self.vad_cpu_cores = vad_cpu_cores

        self.vad_model = None
        self.inputAudioMaxDuration = input_audio_max_duration
        self.deleteUploadedFiles = delete_uploaded_files

    def set_parallel_devices(self, vad_parallel_devices: str):
        """Set the device list from a comma-delimited string; an empty/None value clears it."""
        self.parallel_device_list = [device.strip() for device in vad_parallel_devices.split(",")] if vad_parallel_devices else None

    def set_auto_parallel(self, auto_parallel: bool):
        """If enabled, use every CUDA device and up to MAX_AUTO_CPU_CORES CPU cores."""
        if auto_parallel:
            if torch.cuda.is_available():
                self.parallel_device_list = [str(gpu_id) for gpu_id in range(torch.cuda.device_count())]

            # os.cpu_count() returns None when the count cannot be determined; fall back to one core
            self.vad_cpu_cores = min(os.cpu_count() or 1, MAX_AUTO_CPU_CORES)
            print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")

    def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
        """Entry point for the web UI: resolve the source, transcribe it and write the result files.

        Returns:
            A tuple (list of output file paths, transcript text, VTT subtitles). If the source
            exceeds the maximum duration, returns ([], error message, "[ERROR]") instead.
        """
        try:
            source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)

            try:
                # The language drop-down may yield None or an empty string when nothing is selected
                selectedLanguage = languageName.lower() if languageName else None
                selectedModel = modelName if modelName is not None else "base"

                model = WhisperContainer(model_name=selectedModel, cache=self.model_cache)

                # Execute whisper
                result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)

                # Write result
                downloadDirectory = tempfile.mkdtemp()

                filePrefix = slugify(sourceName, allow_unicode=True)
                download, text, vtt = self.write_result(result, filePrefix, downloadDirectory)

                return download, text, vtt

            finally:
                # Cleanup source
                if self.deleteUploadedFiles:
                    print("Deleting source file " + source)
                    os.remove(source)

        except ExceededMaximumDuration as e:
            return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"

    def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
                        vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
        """Transcribe a single audio file using the selected VAD strategy.

        Args:
            model: Container used to create the Whisper callback.
            audio_path: Path of the audio file.
            language: Language passed to the Whisper callback (may be None).
            task: Task passed to the Whisper callback; overridden by a 'task' entry in decodeOptions.
            vad: One of 'silero-vad', 'silero-vad-skip-gaps', 'silero-vad-expand-into-gaps',
                'periodic-vad'; any other value means no VAD.
            vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow: VAD configuration values.
            **decodeOptions: Extra options forwarded to the Whisper callback ('initial_prompt' and
                'task' are extracted first).

        Returns:
            The result produced by the VAD transcription or by the Whisper callback.
        """
        initial_prompt = decodeOptions.pop('initial_prompt', None)

        if 'task' in decodeOptions:
            task = decodeOptions.pop('task')

        # Callable for processing an audio file
        whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)

        # How each Silero mode treats the non-speech gaps between detected speech segments
        silero_strategies = {
            'silero-vad': NonSpeechStrategy.CREATE_SEGMENT,  # gaps are transcribed
            'silero-vad-skip-gaps': NonSpeechStrategy.SKIP,  # gaps are ignored
            'silero-vad-expand-into-gaps': NonSpeechStrategy.EXPAND_SEGMENT,  # speech segments expand into gaps
        }

        if vad in silero_strategies:
            # Create the config first: this lazily instantiates self.vad_model
            silero_config = self._create_silero_config(silero_strategies[vad], vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
            result = self.process_vad(audio_path, whisperCallable, self.vad_model, silero_config)
        elif vad == 'periodic-vad':
            # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
            # it may create a break in the middle of a sentence, causing some artifacts.
            periodic_vad = VadPeriodicTranscription()
            period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
            result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
        elif self._has_parallel_devices():
            # Use a simple period transcription instead, as we need to use the parallel context
            periodic_vad = VadPeriodicTranscription()
            period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)

            result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
        else:
            # Default VAD
            result = whisperCallable(audio_path, 0, None, None)

        return result

    def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig):
        """Run the given VAD transcription, in parallel processes when parallel devices are configured."""
        if not self._has_parallel_devices():
            # No parallel devices, so just run the VAD and Whisper in sequence
            return vadModel.transcribe(audio_path, whisperCallable, vadConfig)

        gpu_devices = self.parallel_device_list

        if gpu_devices is None or len(gpu_devices) == 0:
            # No GPU devices specified, pass the current environment variable to the first GPU process. This may be None.
            gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]

        # Create parallel context if needed
        if self.gpu_parallel_context is None:
            # Create a context with processes; its auto-cleanup timeout is self.vad_process_timeout
            self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
        # We also need a CPU context for the VAD
        if self.cpu_parallel_context is None:
            self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)

        parallel_vad = ParallelTranscription()
        return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
                                                config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
                                                cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context)

    def _has_parallel_devices(self):
        """Return True if a device list is set or more than one VAD CPU core is configured."""
        return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1

    def _concat_prompt(self, prompt1, prompt2):
        """Join two optional prompts with a space; a None prompt yields the other one unchanged."""
        if prompt1 is None:
            return prompt2
        elif prompt2 is None:
            return prompt1
        else:
            return prompt1 + " " + prompt2

    def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
        """Build a TranscriptionConfig for Silero VAD, creating the shared VAD model on first use."""
        # Use Silero VAD
        if self.vad_model is None:
            self.vad_model = VadSileroTranscription()

        config = TranscriptionConfig(non_speech_strategy=non_speech_strategy,
                                     max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
                                     segment_padding_left=vadPadding, segment_padding_right=vadPadding,
                                     max_prompt_window=vadPromptWindow)

        return config

    def write_result(self, result: dict, source_name: str, output_dir: str):
        """Write the SRT, VTT and plain-text outputs for a result into output_dir.

        Returns:
            A tuple (list of written file paths, transcript text, VTT subtitles).
        """
        # exist_ok avoids the race between checking for and creating the directory
        os.makedirs(output_dir, exist_ok=True)

        text = result["text"]
        language = result["language"]
        languageMaxLineWidth = self.__get_max_line_width(language)

        print("Max line width " + str(languageMaxLineWidth))
        vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
        srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)

        output_files = [
            self.__create_file(srt, output_dir, source_name + "-subs.srt"),
            self.__create_file(vtt, output_dir, source_name + "-subs.vtt"),
            self.__create_file(text, output_dir, source_name + "-transcript.txt"),
        ]

        return output_files, text, vtt

    def clear_cache(self):
        """Clear the model cache and drop the VAD model."""
        self.model_cache.clear()
        self.vad_model = None

    def __get_source(self, urlData, uploadFile, microphoneData):
        """Resolve the audio source and a shortened display name for it.

        A URL takes precedence, then the uploaded file, then the microphone recording.

        Returns:
            A tuple (source path, source name truncated to MAX_FILE_PREFIX_LENGTH plus suffix).

        Raises:
            ValueError: If no source was provided at all.
            ExceededMaximumDuration: If a local file is longer than the configured maximum.
        """
        if urlData:
            # Download from YouTube
            source = download_url(urlData, self.inputAudioMaxDuration)[0]
        else:
            # File input
            source = uploadFile if uploadFile is not None else microphoneData

            if source is None:
                # Fail with a clear message instead of an obscure error from ffmpeg/pathlib
                raise ValueError("No audio source was provided: enter a URL, upload a file or record from the microphone")

            if self.inputAudioMaxDuration > 0:
                # Calculate audio length
                audioDuration = ffmpeg.probe(source)["format"]["duration"]

                if float(audioDuration) > self.inputAudioMaxDuration:
                    raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")

        file_path = pathlib.Path(source)
        sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix

        return source, sourceName

    def __get_max_line_width(self, language: str) -> int:
        """Return the maximum subtitle line width (in characters) for the given language."""
        if language and language.lower() in ["japanese", "ja", "chinese", "zh"]:
            # Chinese characters and kana are wider, so limit line length to 40 characters
            return 40
        else:
            # TODO: Add more languages
            # 80 latin characters should fit on a 1080p/720p screen
            return 80

    def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
        """Render the segments as 'vtt' or 'srt' subtitles and return them as a string.

        Raises:
            Exception: If format is neither 'vtt' nor 'srt'.
        """
        segmentStream = StringIO()

        if format == 'vtt':
            write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        elif format == 'srt':
            write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        else:
            raise Exception("Unknown format " + format)

        segmentStream.seek(0)
        return segmentStream.read()

    def __create_file(self, text: str, directory: str, fileName: str) -> str:
        """Write text to directory/fileName as UTF-8 and return the path of the written file."""
        # Write the text to a file
        with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
            file.write(text)

        return file.name

    def close(self):
        """Clear the caches and close the GPU/CPU parallel contexts, if they were created."""
        self.clear_cache()

        if self.gpu_parallel_context is not None:
            self.gpu_parallel_context.close()
        if self.cpu_parallel_context is not None:
            self.cpu_parallel_context.close()
|
281 |
+
|
282 |
+
|
283 |
+
def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
              default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None, vad_process_timeout: float = None, vad_cpu_cores: int = 1, auto_parallel: bool = False):
    """Build the Gradio interface around a WhisperTranscriber and launch it.

    Args:
        input_audio_max_duration: Maximum audio length in seconds; values <= 0 disable the limit.
        share: Forwarded to the Gradio launch() call.
        server_name: Host or IP to bind to, forwarded to launch().
        server_port: Port to bind to, forwarded to launch().
        default_model_name: Model pre-selected in the "Model" drop-down.
        default_vad: VAD pre-selected in the "VAD" drop-down.
        vad_parallel_devices: Comma-delimited device list for parallel processing.
        vad_process_timeout: Auto-cleanup timeout (seconds) for the parallel processes.
        vad_cpu_cores: Number of CPU cores used for the VAD.
        auto_parallel: If True, let the transcriber pick the devices and CPU cores itself.
    """
    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores)

    try:
        # Specify a list of devices to use for parallel processing
        ui.set_parallel_devices(vad_parallel_devices)
        ui.set_auto_parallel(auto_parallel)

        ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
        ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
        ui_description += " as well as speech translation and language identification. "

        ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."

        if input_audio_max_duration > 0:
            ui_description += "\n\n" + "Max audio file length: " + str(input_audio_max_duration) + " s"

        ui_article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"

        # The order of the inputs must match the parameter order of transcribe_webui
        demo = gr.Interface(fn=ui.transcribe_webui, description=ui_description, article=ui_article, inputs=[
            gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value=default_model_name, label="Model"),
            gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
            gr.Text(label="URL (YouTube, etc.)"),
            gr.Audio(source="upload", type="filepath", label="Upload Audio"),
            gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
            gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
            gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=default_vad, label="VAD"),
            gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
            gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
            gr.Number(label="VAD - Padding (s)", precision=None, value=1),
            gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
        ], outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
            gr.Text(label="Segments")
        ])

        demo.launch(share=share, server_name=server_name, server_port=server_port)
    finally:
        # Clean up, even if building or launching the interface failed
        ui.close()
|
324 |
+
|
325 |
+
if __name__ == '__main__':
    # Command-line entry point: parse the options and forward them to create_ui.
    # str2bool/optional_float are the same argparse converters that cli.py imports from src.utils.
    from src.utils import optional_float, str2bool

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input_audio_max_duration", type=int, default=DEFAULT_INPUT_AUDIO_MAX_DURATION, help="Maximum audio file length in seconds, or -1 for no limit.")
    # NOTE: type=bool would turn any non-empty string (including "False") into True
    parser.add_argument("--share", type=str2bool, default=False, help="True to share the app on HuggingFace.")
    parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
    parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
    parser.add_argument("--default_model_name", type=str, default="medium", help="The default model name.")
    parser.add_argument("--default_vad", type=str, default="silero-vad", help="The default VAD.")
    parser.add_argument("--vad_parallel_devices", type=str, default="", help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
    parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
    # optional_float so that the documented "None" value can actually be passed on the command line
    parser.add_argument("--vad_process_timeout", type=optional_float, default=1800, help="The number of seconds before inactive processes are terminated. Use 0 to close processes immediately, or None for no timeout.")
    parser.add_argument("--auto_parallel", type=str2bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")

    args = vars(parser.parse_args())
    create_ui(**args)
|
cli.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
from urllib.parse import urlparse
|
5 |
+
import warnings
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import whisper
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from app import LANGUAGES, WhisperTranscriber
|
12 |
+
from src.download import download_url
|
13 |
+
|
14 |
+
from src.utils import optional_float, optional_int, str2bool
|
15 |
+
from src.whisperContainer import WhisperContainer
|
16 |
+
|
17 |
+
|
18 |
+
def cli():
    """Command-line entry point: transcribe every given audio file or URL.

    Parses the command-line options, loads the requested Whisper model,
    optionally downloads each URL argument, runs the transcription (with the
    selected VAD settings) and writes the results to the output directory.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large"], help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), help="language spoken in the audio, specify None to perform language detection")

    parser.add_argument("--vad", type=str, default="none", choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], help="The voice activity detection algorithm to use")
    parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
    parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
    parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
    parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
    parser.add_argument("--vad_cpu_cores", type=int, default=1, help="The number of CPU cores to use for VAD pre-processing.")
    parser.add_argument("--vad_parallel_devices", type=str, default="", help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
    # NOTE: type=bool would treat every non-empty string (even "False") as True,
    # so the flag is parsed with str2bool like the other boolean options.
    parser.add_argument("--auto_parallel", type=str2bool, default=False, help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")

    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")

    # Options consumed here are popped; whatever remains in args is forwarded
    # verbatim to transcribe_file() below.
    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    device: str = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
        args["language"] = "en"

    # Build the temperature fallback schedule: start value up to 1.0 in fixed steps.
    temperature = args.pop("temperature")
    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
    if temperature_increment_on_fallback is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
    else:
        temperature = [temperature]

    vad = args.pop("vad")
    vad_merge_window = args.pop("vad_merge_window")
    vad_max_merge_size = args.pop("vad_max_merge_size")
    vad_padding = args.pop("vad_padding")
    vad_prompt_window = args.pop("vad_prompt_window")
    vad_cpu_cores = args.pop("vad_cpu_cores")
    auto_parallel = args.pop("auto_parallel")

    model = WhisperContainer(model_name, device=device, download_root=model_dir)
    transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores)
    transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
    transcriber.set_auto_parallel(auto_parallel)

    if (transcriber._has_parallel_devices()):
        print("Using parallel devices:", transcriber.parallel_device_list)

    for audio_path in args.pop("audio"):
        sources = []

        # Detect URL and download the audio
        if (uri_validator(audio_path)):
            # Download from YouTube/URL directly
            for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
                source_name = os.path.basename(source_path)
                sources.append({ "path": source_path, "name": source_name })
        else:
            sources.append({ "path": audio_path, "name": os.path.basename(audio_path) })

        for source in sources:
            source_path = source["path"]
            source_name = source["name"]

            result = transcriber.transcribe_file(model, source_path, temperature=temperature,
                                                 vad=vad, vadMergeWindow=vad_merge_window, vadMaxMergeSize=vad_max_merge_size,
                                                 vadPadding=vad_padding, vadPromptWindow=vad_prompt_window, **args)

            transcriber.write_result(result, source_name, output_dir)

    transcriber.close()
|
112 |
+
|
113 |
+
def uri_validator(x):
    """Return True when *x* parses as a URL with both a scheme and a network location.

    Anything that cannot be parsed (malformed URL, non-string input) is
    reported as False instead of raising.
    """
    try:
        result = urlparse(x)
    except (AttributeError, TypeError, ValueError):
        # Only parsing failures are treated as "not a URL"; a bare except would
        # also hide KeyboardInterrupt/SystemExit.
        return False
    return all([result.scheme, result.netloc])
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
cli()
|
dockerfile
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM huggingface/transformers-pytorch-gpu
|
2 |
+
EXPOSE 7860
|
3 |
+
|
4 |
+
ADD . /opt/whisper-webui/
|
5 |
+
|
6 |
+
# Latest version of transformers-pytorch-gpu seems to lack tk.
|
7 |
+
# Further, pip install fails, so we must upgrade pip first.
|
8 |
+
RUN apt-get -y install python3-tk
|
9 |
+
RUN python3 -m pip install --upgrade pip &&\
|
10 |
+
python3 -m pip install -r /opt/whisper-webui/requirements.txt
|
11 |
+
|
12 |
+
# Note: Models will be downloaded on demand to the directory /root/.cache/whisper.
|
13 |
+
# You can also bind this directory in the container to somewhere on the host.
|
14 |
+
|
15 |
+
# To be able to see logs in real time
|
16 |
+
ENV PYTHONUNBUFFERED=1
|
17 |
+
|
18 |
+
WORKDIR /opt/whisper-webui/
|
19 |
+
ENTRYPOINT ["python3"]
|
20 |
+
CMD ["app.py", "--input_audio_max_duration", "-1", "--server_name", "0.0.0.0", "--auto_parallel", "True"]
|
docs/options.md
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Options
|
2 |
+
To transcribe or translate an audio file, you can either copy a URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
|
3 |
+
supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
|
4 |
+
in the file selector to select any file type, including video files) or use the microphone.
|
5 |
+
|
6 |
+
For longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option.
|
7 |
+
|
8 |
+
## Model
|
9 |
+
Select the model that Whisper will use to transcribe the audio:
|
10 |
+
|
11 |
+
| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
|
12 |
+
|--------|------------|--------------------|--------------------|---------------|----------------|
|
13 |
+
| tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x |
|
14 |
+
| base | 74 M | base.en | base | ~1 GB | ~16x |
|
15 |
+
| small | 244 M | small.en | small | ~2 GB | ~6x |
|
16 |
+
| medium | 769 M | medium.en | medium | ~5 GB | ~2x |
|
17 |
+
| large | 1550 M | N/A | large | ~10 GB | 1x |
|
18 |
+
|
19 |
+
## Language
|
20 |
+
|
21 |
+
Select the language, or leave it empty for Whisper to automatically detect it.
|
22 |
+
|
23 |
+
Note that if the selected language and the language in the audio differ, Whisper may start to translate the audio to the selected
|
24 |
+
language. For instance, if the audio is in English but you select Japanese, the model may translate the audio to Japanese.
|
25 |
+
|
26 |
+
## Inputs
|
27 |
+
The options "URL (YouTube, etc.)", "Upload Audio" or "Microphone Input" allow you to send an audio input to the model.
|
28 |
+
|
29 |
+
Note that the UI will only process the first valid input - i.e. if you enter both a URL and upload an audio file, it will only process
|
30 |
+
the URL.
|
31 |
+
|
32 |
+
## Task
|
33 |
+
Select the task - either "transcribe" to transcribe the audio to text, or "translate" to translate it to English.
|
34 |
+
|
35 |
+
## Vad
|
36 |
+
Using a VAD will improve the timing accuracy of each transcribed line, as well as prevent Whisper getting into an infinite
|
37 |
+
loop detecting the same sentence over and over again. The downside is that this may be at a cost to text accuracy, especially
|
38 |
+
with regards to unique words or names that appear in the audio. You can compensate for this by increasing the prompt window.
|
39 |
+
|
40 |
+
Note that English is very well handled by Whisper, and it's less susceptible to issues surrounding bad timings and infinite loops.
|
41 |
+
So you may only need to use a VAD for other languages, such as Japanese, or when the audio is very long.
|
42 |
+
|
43 |
+
* none
|
44 |
+
* Run whisper on the entire audio input
|
45 |
+
* silero-vad
|
46 |
+
* Use Silero VAD to detect sections that contain speech, and run Whisper independently on each section. Whisper is also run
|
47 |
+
on the gaps between each speech section, by either expanding the section up to the max merge size, or running Whisper independently
|
48 |
+
on the non-speech section.
|
49 |
+
* silero-vad-expand-into-gaps
|
50 |
+
* Use Silero VAD to detect sections that contain speech, and run Whisper independently on each section. Each speech section will be expanded
|
51 |
+
such that they cover any adjacent non-speech sections. For instance, if an audio file of one minute contains the speech sections
|
52 |
+
00:00 - 00:10 (A) and 00:30 - 00:40 (B), the first section (A) will be expanded to 00:00 - 00:30, and (B) will be expanded to 00:30 - 01:00.
|
53 |
+
* silero-vad-skip-gaps
|
54 |
+
* As above, but sections that don't contain speech according to Silero will be skipped. This will be slightly faster, but
|
55 |
+
may cause dialogue to be skipped.
|
56 |
+
* periodic-vad
|
57 |
+
* Create sections of speech every 'VAD - Max Merge Size' seconds. This is very fast and simple, but will potentially break
|
58 |
+
a sentence or word in two.
|
59 |
+
|
60 |
+
## VAD - Merge Window
|
61 |
+
If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged.
|
62 |
+
|
63 |
+
## VAD - Max Merge Size (s)
|
64 |
+
Disables merging of adjacent speech sections if they are this number of seconds long.
|
65 |
+
|
66 |
+
## VAD - Padding (s)
|
67 |
+
The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number
|
68 |
+
larger than zero ensures that Whisper is more likely to correctly transcribe a sentence in the beginning of
|
69 |
+
a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp
|
70 |
+
to each transcribed line. The default value is 1 second.
|
71 |
+
|
72 |
+
## VAD - Prompt Window (s)
|
73 |
+
The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this
|
74 |
+
number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at
|
75 |
+
10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds).
|
76 |
+
|
77 |
+
Note that detected lines in gaps between speech sections will not be included in the prompt
|
78 |
+
(if silero-vad or silero-vad-expand-into-gaps is used).
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/openai/whisper.git
|
2 |
+
transformers
|
3 |
+
ffmpeg-python==0.2.0
|
4 |
+
gradio
|
5 |
+
yt-dlp
|
6 |
+
torchaudio
|
src/__init__.py
ADDED
File without changes
|
src/download.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tempfile import mkdtemp
|
2 |
+
from typing import List
|
3 |
+
from yt_dlp import YoutubeDL
|
4 |
+
|
5 |
+
import yt_dlp
|
6 |
+
from yt_dlp.postprocessor import PostProcessor
|
7 |
+
|
8 |
+
class FilenameCollectorPP(PostProcessor):
    """Post-processor that remembers the "filepath" of every item passed to run()."""

    def __init__(self):
        super().__init__(None)
        # Paths collected so far, in the order run() was invoked.
        self.filenames = []

    def run(self, information):
        """Record the item's file path and hand the information back unchanged."""
        downloaded_path = information["filepath"]
        self.filenames.append(downloaded_path)
        return [], information
|
16 |
+
|
17 |
+
def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
    """Download the media behind *url* and return the downloaded file paths.

    If the first attempt fails because the generated file name is too long,
    the download is retried once with a shortened output template, using the
    same destination directory and playlist selection.

    Raises
    ------
    yt_dlp.utils.DownloadError
        For any download failure other than the "file name too long" case.
    """
    try:
        return _perform_download(url, maxDuration=maxDuration, outputTemplate=None,
                                 destinationDirectory=destinationDirectory, playlistItems=playlistItems)
    except yt_dlp.utils.DownloadError as e:
        # In case of an OS error, try again with a different output template
        if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
            return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s",
                                     destinationDirectory=destinationDirectory, playlistItems=playlistItems)
        # Swallowing the error would return None, which breaks callers that
        # iterate over the result; surface the original failure instead.
        raise
|
25 |
+
|
26 |
+
def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
    """Run a single yt-dlp download of *url* and return the list of downloaded file paths.

    A temporary directory is created when no destination directory is given.
    When maxDuration is a positive number, the media duration is looked up
    first and ExceededMaximumDuration is raised if it reaches that limit.
    """
    # Fall back to a fresh temporary directory for the downloaded files
    target_directory = mkdtemp() if destinationDirectory is None else destinationDirectory

    ydl_opts = {
        "format": "bestaudio/best",
        "paths": {"home": target_directory},
    }
    if playlistItems:
        ydl_opts["playlist_items"] = playlistItems
    if outputTemplate:
        # Only override yt-dlp's output template when one was requested
        ydl_opts["outtmpl"] = outputTemplate

    collector = FilenameCollectorPP()

    with YoutubeDL(ydl_opts) as ydl:
        if maxDuration and maxDuration > 0:
            # Query the metadata only, so the limit is checked before downloading
            duration = ydl.extract_info(url, download=False)['duration']

            if duration >= maxDuration:
                raise ExceededMaximumDuration(videoDuration=duration, maxDuration=maxDuration, message="Video is too long")

        ydl.add_post_processor(collector)
        ydl.download([url])

    if len(collector.filenames) <= 0:
        raise Exception("Cannot download " + url)

    downloaded = []
    for path in collector.filenames:
        downloaded.append(path)
        print("Downloaded " + path)

    return downloaded
|
67 |
+
|
68 |
+
class ExceededMaximumDuration(Exception):
    """Error carrying the offending video duration and the maximum that was allowed."""

    def __init__(self, videoDuration, maxDuration, message):
        super().__init__(message)
        self.videoDuration = videoDuration
        self.maxDuration = maxDuration
|
src/modelCache.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ModelCache:
    """Keyed store that builds each entry on first request via a factory callable."""

    def __init__(self):
        self._cache = {}

    def get(self, model_key: str, model_factory):
        """Return the entry stored under model_key, creating it with model_factory() if absent.

        An entry whose stored value is None counts as absent, so the factory
        is invoked again on the next lookup.
        """
        cached = self._cache.get(model_key)
        if cached is not None:
            return cached

        created = model_factory()
        self._cache[model_key] = created
        return created

    def clear(self):
        """Forget every stored entry."""
        self._cache.clear()
|
15 |
+
|
16 |
+
# A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
|
17 |
+
GLOBAL_MODEL_CACHE = ModelCache()
|
src/segments.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List
|
2 |
+
|
3 |
+
import copy
|
4 |
+
|
5 |
+
def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
    """Merge neighbouring timestamp segments and pad the resulting segments.

    Each input item is a dict with 'start' and 'end' keys. A segment is merged
    into the current one unless the gap since the previously processed end
    exceeds merge_window, or the merged span would exceed max_merge_size.
    Finished segments receive left/right padding; the padding between two
    close segments is limited by the gap between them. The input dicts are
    not modified (segments are deep-copied before being changed).

    Returns a new list of segments; an empty list for empty input, or the
    input list itself (untouched, unpadded) when max_merge_size is None.
    """
    merged = []

    if len(timestamps) == 0:
        return merged
    if max_merge_size is None:
        return timestamps

    left_pad = 0 if padding_left is None else padding_left
    right_pad = 0 if padding_right is None else padding_right

    last_end = 0
    active = None

    for candidate in timestamps:
        gap = candidate['start'] - last_end

        # Note that segments can still be longer than the max merge size, they just won't be merged in that case
        if active is not None and (merge_window is None or not (gap > merge_window)) \
                and not (candidate['end'] - active['start'] > max_merge_size):
            # Close enough and small enough: extend the active segment
            active['end'] = candidate['end']
            last_end = active['end']
            continue

        if active is not None:
            # Close the active segment; when the gap cannot hold both paddings,
            # the right padding takes at most half of the gap
            if gap < left_pad + right_pad:
                tail_padding = min(right_pad, gap / 2)
            else:
                tail_padding = right_pad
            active['end'] += tail_padding
            gap -= tail_padding

            merged.append(active)

        # Open a new segment, padded to the left by at most the remaining gap
        active = copy.deepcopy(candidate)
        active['start'] = active['start'] - min(left_pad, gap)
        last_end = active['end']

    # Flush the final segment
    if active is not None:
        active['end'] += right_pad
        merged.append(active)

    return merged
|
src/utils.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import textwrap
|
2 |
+
import unicodedata
|
3 |
+
import re
|
4 |
+
|
5 |
+
import zlib
|
6 |
+
from typing import Iterator, TextIO
|
7 |
+
|
8 |
+
|
9 |
+
def exact_div(x, y):
    """Return x // y, asserting that y divides x without remainder."""
    remainder = x % y
    assert remainder == 0
    return x // y
|
12 |
+
|
13 |
+
|
14 |
+
def str2bool(string):
    """Convert the exact strings "True"/"False" to booleans; raise ValueError otherwise."""
    str2val = {"True": True, "False": False}
    parsed = str2val.get(string)
    if parsed is None:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
    return parsed
|
20 |
+
|
21 |
+
|
22 |
+
def optional_int(string):
    """Parse an int, mapping the literal string "None" to None."""
    if string == "None":
        return None
    return int(string)
|
24 |
+
|
25 |
+
|
26 |
+
def optional_float(string):
    """Parse a float, mapping the literal string "None" to None."""
    if string == "None":
        return None
    return float(string)
|
28 |
+
|
29 |
+
|
30 |
+
def compression_ratio(text) -> float:
    """Return len(text) divided by the byte length of its zlib-compressed UTF-8 encoding."""
    compressed = zlib.compress(text.encode("utf-8"))
    return len(text) / len(compressed)
|
32 |
+
|
33 |
+
|
34 |
+
def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
    """Format a non-negative number of seconds as [HH:]MM:SS<sep>mmm.

    The hours field is emitted only when it is non-zero or when
    always_include_hours is set; fractionalSeperator is placed between the
    seconds and the three-digit milliseconds.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    # Work in whole milliseconds so rounding carries into the larger units
    remaining_ms = round(seconds * 1000.0)
    hours, remaining_ms = divmod(remaining_ms, 3_600_000)
    minutes, remaining_ms = divmod(remaining_ms, 60_000)
    whole_seconds, millis = divmod(remaining_ms, 1_000)

    if always_include_hours or hours > 0:
        hours_marker = f"{hours:02d}:"
    else:
        hours_marker = ""
    return f"{hours_marker}{minutes:02d}:{whole_seconds:02d}{fractionalSeperator}{millis:03d}"
|
49 |
+
|
50 |
+
|
51 |
+
def write_txt(transcript: Iterator[dict], file: TextIO):
    """Write the stripped 'text' of every segment to *file*, one segment per line."""
    for segment in transcript:
        line = segment['text'].strip()
        print(line, file=file, flush=True)
|
54 |
+
|
55 |
+
|
56 |
+
def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """Write the segments to *file* as WebVTT cues.

    Emits the "WEBVTT" header, then one cue per segment: a
    "start --> end" line followed by the segment text (passed through
    process_text with maxLineWidth) and a blank line. Any "-->" inside the
    text is replaced by "->" so it cannot be mistaken for a cue timing line.
    """
    print("WEBVTT\n", file=file)

    for segment in transcript:
        text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
        cue_start = format_timestamp(segment['start'])
        cue_end = format_timestamp(segment['end'])

        print(f"{cue_start} --> {cue_end}\n{text}\n", file=file, flush=True)
|
67 |
+
|
68 |
+
|
69 |
+
def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """
    Write the segments to *file* as SubRip (SRT) entries.

    Every segment becomes a 1-based counter line, an
    "HH:MM:SS,mmm --> HH:MM:SS,mmm" timing line, the stripped segment text
    (passed through process_text with maxLineWidth, with "-->" replaced by
    "->") and a blank line.

    Example usage:
        with open("audio.srt", "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for index, segment in enumerate(transcript, start=1):
        text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
        cue_start = format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')
        cue_end = format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')

        # write srt lines
        print(f"{index}\n{cue_start} --> {cue_end}\n{text}\n", file=file, flush=True)
|
93 |
+
|
94 |
+
def process_text(text: str, maxLineWidth=None):
    """Wrap *text* to at most maxLineWidth characters per line.

    The text is returned unchanged when maxLineWidth is None or negative.
    """
    wrapping_disabled = maxLineWidth is None or maxLineWidth < 0
    if wrapping_disabled:
        return text

    return '\n'.join(textwrap.wrap(text, width=maxLineWidth, tabsize=4))
|
100 |
+
|
101 |
+
def slugify(value, allow_unicode=False):
    """
    Turn *value* into a lowercase, dash-separated slug.

    Taken from https://github.com/django/django/blob/master/django/utils/text.py
    The value is converted to ASCII unless 'allow_unicode' is True. Characters
    other than alphanumerics, underscores, whitespace and hyphens are removed,
    runs of whitespace/hyphens collapse into a single dash, and leading or
    trailing dashes and underscores are stripped.
    """
    text = str(value)
    if allow_unicode:
        text = unicodedata.normalize('NFKC', text)
    else:
        decomposed = unicodedata.normalize('NFKD', text)
        text = decomposed.encode('ascii', 'ignore').decode('ascii')

    cleaned = re.sub(r'[^\w\s-]', '', text.lower())
    dashed = re.sub(r'[-\s]+', '-', cleaned)
    return dashed.strip('-_')
|
src/vad.py
ADDED
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from collections import Counter, deque
|
3 |
+
import time
|
4 |
+
|
5 |
+
from typing import Any, Deque, Iterator, List, Dict
|
6 |
+
|
7 |
+
from pprint import pprint
|
8 |
+
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
9 |
+
|
10 |
+
from src.segments import merge_timestamps
|
11 |
+
from src.whisperContainer import WhisperCallback
|
12 |
+
|
13 |
+
# Workaround for https://github.com/tensorflow/tensorflow/issues/48797
|
14 |
+
try:
|
15 |
+
import tensorflow as tf
|
16 |
+
except ModuleNotFoundError:
|
17 |
+
# Error handling
|
18 |
+
pass
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
import ffmpeg
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
from src.utils import format_timestamp
|
26 |
+
from enum import Enum
|
27 |
+
|
28 |
+
class NonSpeechStrategy(Enum):
    """Strategy for handling the non-speech segments of the audio."""

    # Ignore non-speech frames segments.
    SKIP = 1
    # Just treat non-speech segments as speech.
    CREATE_SEGMENT = 2
    # Expand speech segments into subsequent non-speech segments.
    EXPAND_SEGMENT = 3
|
41 |
+
|
42 |
+
# Defaults for Silero
|
43 |
+
SPEECH_TRESHOLD = 0.3
|
44 |
+
|
45 |
+
# Minimum size of segments to process
|
46 |
+
MIN_SEGMENT_DURATION = 1
|
47 |
+
|
48 |
+
# The maximum time for texts from old segments to be used in the next segment
|
49 |
+
MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
|
50 |
+
PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
|
51 |
+
|
52 |
+
VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
|
53 |
+
|
54 |
+
class TranscriptionConfig(ABC):
    """Plain holder for the settings that control how audio is split before transcription."""

    def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
                 segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
                 max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
        # Treatment of the non-speech parts of the audio
        self.non_speech_strategy = non_speech_strategy

        # Padding applied to the left/right of each segment
        self.segment_padding_left = segment_padding_left
        self.segment_padding_right = segment_padding_right

        # Merge limits and prompt window
        self.max_silent_period = max_silent_period
        self.max_merge_size = max_merge_size
        self.max_prompt_window = max_prompt_window

        self.initial_segment_index = initial_segment_index
|
65 |
+
|
66 |
+
class PeriodicTranscriptionConfig(TranscriptionConfig):
    """TranscriptionConfig that additionally stores a periodic_duration value."""

    def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
                 segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
                 max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
        super().__init__(
            non_speech_strategy=non_speech_strategy,
            segment_padding_left=segment_padding_left,
            segment_padding_right=segment_padding_right,
            max_silent_period=max_silent_period,
            max_merge_size=max_merge_size,
            max_prompt_window=max_prompt_window,
            initial_segment_index=initial_segment_index,
        )
        self.periodic_duration = periodic_duration
|
72 |
+
|
73 |
+
class AbstractTranscription(ABC):
    """
    Base class for transcription that is driven by voice activity detection (VAD).

    Subclasses implement get_transcribe_timestamps to decide which sections of the audio
    should be transcribed. This class merges those sections, invokes the Whisper callback
    on each of them and combines the individual results into one.
    """

    def __init__(self, sampling_rate: int = 16000):
        self.sampling_rate = sampling_rate

    def get_audio_segment(self, str, start_time: str = None, duration: str = None):
        """
        Load (a part of) an audio file using the sampling rate of this instance.

        Parameters
        ----------
        str: str
            The audio file.
        start_time: str
            The start time passed on to load_audio, or None to start at the beginning.
        duration: str
            The duration passed on to load_audio, or None for no limit.

        Returns
        -------
        Whatever load_audio returns for the requested section.
        """
        # NOTE: the first parameter shadows the builtin `str`. The name is kept so that
        # callers passing it by keyword keep working.
        return load_audio(str, self.sampling_rate, start_time, duration)

    @abstractmethod
    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
        """
        Get the start and end timestamps of the sections that should be transcribed by this VAD method.

        Parameters
        ----------
        audio: str
            The audio file.
        config: TranscriptionConfig
            The transcription configuration.
        start_time: float
            Start of the range of the audio file to inspect, in seconds.
        end_time: float
            End of the range of the audio file to inspect, in seconds.

        Returns
        -------
        A list of start and end timestamps, in fractional seconds.
        """
        return

    def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
        """
        Get the start and end timestamps of the sections that should be transcribed by this VAD method,
        after merging the given segments using the specified configuration.

        Parameters
        ----------
        timestamps: List[Dict[str, Any]]
            The sections returned by get_transcribe_timestamps.
        config: TranscriptionConfig
            The transcription configuration.
        total_duration: float
            The total duration of the audio, in seconds. Used when the gaps are filled or expanded.

        Returns
        -------
        A list of start and end timestamps, in fractional seconds.

        Raises
        ------
        Exception
            If config.non_speech_strategy is not a known strategy.
        """
        merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
                                  config.segment_padding_left, config.segment_padding_right)

        strategy = config.non_speech_strategy

        if strategy == NonSpeechStrategy.SKIP:
            # Non-speech is ignored, so there is nothing more to do
            return merged

        # Expand segments to include the gaps between them
        if strategy == NonSpeechStrategy.CREATE_SEGMENT:
            # When we have a prompt window, we create speech segments between each segment if we exceed the merge size
            merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
        elif strategy == NonSpeechStrategy.EXPAND_SEGMENT:
            # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
            merged = self.expand_gaps(merged, total_duration=total_duration)
        else:
            raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))

        print("Transcribing non-speech:")
        pprint(merged)
        return merged

    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig):
        """
        Transcribe the given audio file.

        Parameters
        ----------
        audio: str
            The audio file.
        whisperCallable: WhisperCallback
            A callback object to call to transcribe each segment.
        config: TranscriptionConfig
            The transcription configuration.

        Returns
        -------
        A dictionary with the keys 'text' (the concatenated text), 'segments' (the transcribed
        segments, with timestamps relative to the start of the audio file) and 'language'
        (the most common language among the speech sections, or "" if there is none).
        """
        max_audio_duration = get_audio_duration(audio)
        timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)

        # Merge the detected sections (and handle the non-speech gaps) according to the configuration
        merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)

        # A deque of transcribed segments that is passed to the next segment as a prompt
        prompt_window = deque()

        print("Processing timestamps:")
        pprint(merged)

        result = {
            'text': "",
            'segments': [],
            'language': ""
        }
        languageCounter = Counter()

        segment_index = config.initial_segment_index

        # For each time segment, run whisper
        for segment in merged:
            segment_index += 1
            segment_start = segment['start']
            segment_end = segment['end']
            segment_expand_amount = segment.get('expand_amount', 0)
            segment_gap = segment.get('gap', False)

            segment_duration = segment_end - segment_start

            if segment_duration < MIN_SEGMENT_DURATION:
                continue

            # Audio to run on Whisper
            segment_audio = self.get_audio_segment(audio, start_time=str(segment_start), duration=str(segment_duration))
            # Previous segments to use as a prompt
            segment_prompt = ' '.join(entry['text'] for entry in prompt_window) if prompt_window else None

            # Most common language among the speech sections processed so far
            detected_language = languageCounter.most_common(1)[0][0] if languageCounter else None

            print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
                  segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
            segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language)

            adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)

            # Propagate expand amount to the segments
            if segment_expand_amount > 0:
                # The adjusted segments use timestamps relative to the start of the audio file,
                # so they must be compared against the absolute end of the unexpanded section.
                unexpanded_end = segment_end - segment_expand_amount

                for adjusted_segment in adjusted_segments:
                    overshoot = adjusted_segment['end'] - unexpanded_end

                    # Add expand amount if the segment reaches into the expanded part
                    if overshoot > 0:
                        adjusted_segment["expand_amount"] = overshoot

            # Append to output
            result['text'] += segment_result['text']
            result['segments'].extend(adjusted_segments)

            # Increment detected language
            if not segment_gap:
                languageCounter[segment_result['language']] += 1

            # Update prompt window
            self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)

        # Determine the language after the loop, so that the last section (or the only one) is counted as well
        final_language = languageCounter.most_common(1)[0][0] if languageCounter else None

        if final_language is not None:
            result['language'] = final_language

        return result

    def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
        """
        Add the newly transcribed segments to the prompt window and drop the entries that are too old.

        Does nothing unless config.max_prompt_window is a positive number.
        """
        if config.max_prompt_window is None or config.max_prompt_window <= 0:
            return

        # Add segments to the current prompt window (unless it is a speech gap)
        if not segment_gap:
            for segment in adjusted_segments:
                if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
                    prompt_window.append(segment)

        oldest_allowed = segment_end - config.max_prompt_window

        while prompt_window:
            first_end_time = prompt_window[0].get('end', 0)
            # Time expanded in the segments should be discounted from the prompt window
            first_expand_time = prompt_window[0].get('expand_amount', 0)

            if first_end_time - first_expand_time >= oldest_allowed:
                break
            prompt_window.popleft()

    def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
        """
        Return the given segments with a gap entry ('gap': True) inserted for each stretch not covered by a segment.

        Parameters
        ----------
        segments: Iterator[dict]
            The segments, each with a 'start' and an 'end'.
        min_gap_length: float
            Gaps shorter than this are not included. None includes every gap.
        total_duration: float
            If not None, a gap from the end of the last segment to this time is considered as well.

        Returns
        -------
        A new list with the original segments and the gap entries, in order.
        """
        result = []
        last_end_time = 0

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            if last_end_time != segment_start:
                delta = segment_start - last_end_time

                if min_gap_length is None or delta >= min_gap_length:
                    result.append({'start': last_end_time, 'end': segment_start, 'gap': True})

            last_end_time = segment_end
            result.append(segment)

        # Also include total duration if specified
        if total_duration is not None and last_end_time < total_duration:
            # The trailing gap starts where the last segment ended (or at 0 if there were no segments)
            delta = total_duration - last_end_time

            if min_gap_length is None or delta >= min_gap_length:
                result.append({'start': last_end_time, 'end': total_duration, 'gap': True})

        return result

    def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
        """
        Expand the end time of each segment to the start of the next segment.

        A gap entry is added in front of the first segment if it does not start at 0, and the
        last segment is extended to total_duration (if given). The input segments are not modified.
        An empty input gives an empty result.
        """
        result = []

        if len(segments) == 0:
            return result

        # Add gap at the beginning if needed
        if segments[0]['start'] > 0:
            result.append({'start': 0, 'end': segments[0]['start'], 'gap': True})

        for current_segment, next_segment in zip(segments, segments[1:]):
            delta = next_segment['start'] - current_segment['end']

            # Expand if the gap actually exists
            if delta >= 0:
                current_segment = current_segment.copy()
                current_segment['expand_amount'] = delta
                current_segment['end'] = next_segment['start']

            result.append(current_segment)

        # Add last segment
        last_segment = segments[-1]

        # Also include total duration if specified
        if total_duration is not None and last_segment['end'] < total_duration:
            last_segment = last_segment.copy()
            last_segment['end'] = total_duration

        result.append(last_segment)
        return result

    def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
        """
        Cover the stretches between the segments, either by expanding a segment or by adding a gap entry.

        A segment is expanded to the start of the next segment (or to total_duration for the last
        segment) when the distance is at most max_expand_size. Otherwise a separate gap entry
        ('gap': True) is added. With max_expand_size set to None, segments are never expanded.
        The input segments are not modified. An empty input gives an empty result.
        """
        result = []

        if len(segments) == 0:
            return result

        # Add gap at the beginning if needed
        if segments[0]['start'] > 0:
            result.append({'start': 0, 'end': segments[0]['start'], 'gap': True})

        for current_segment, next_segment in zip(segments, segments[1:]):
            delta = next_segment['start'] - current_segment['end']
            expanded = max_expand_size is not None and delta <= max_expand_size

            if expanded:
                # Just expand the current segment
                current_segment = current_segment.copy()
                current_segment['expand_amount'] = delta
                current_segment['end'] = next_segment['start']

            result.append(current_segment)

            # Add a gap to the next segment if needed
            if delta >= 0 and not expanded:
                result.append({'start': current_segment['end'], 'end': next_segment['start'], 'gap': True})

        # Add last segment
        last_segment = segments[-1]
        result.append(last_segment)

        # Also include total duration if specified
        if total_duration is not None:
            delta = total_duration - last_segment['end']

            if delta > 0:
                if max_expand_size is not None and delta <= max_expand_size:
                    # Expand the last segment
                    last_segment = last_segment.copy()
                    last_segment['expand_amount'] = delta
                    last_segment['end'] = total_duration
                    result[-1] = last_segment
                else:
                    result.append({'start': last_segment['end'], 'end': total_duration, 'gap': True})

        return result

    def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
        """
        Return copies of the given segments with adjust_seconds added to their 'start' and 'end'.

        Parameters
        ----------
        segments: Iterator[dict]
            The segments to shift, each with a 'start' and an 'end'.
        adjust_seconds: float
            The amount to add to the start and end of every segment.
        max_source_time: float
            If not None, segments that start after this time are dropped and the end of the
            remaining segments is limited to it. The limit applies to the timestamps as given,
            i.e. before adjust_seconds is added.

        Returns
        -------
        A new list of adjusted segment copies.
        """
        result = []

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            # Filter segments?
            if max_source_time is not None:
                if segment_start > max_source_time:
                    continue
                segment_end = min(max_source_time, segment_end)

            new_segment = segment.copy()

            # Add to start and end
            new_segment['start'] = segment_start + adjust_seconds
            new_segment['end'] = segment_end + adjust_seconds
            result.append(new_segment)
        return result

    def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
        """
        Return new entries with the 'start' and 'end' of each timestamp multiplied by factor.

        Only 'start' and 'end' are carried over to the result.
        """
        return [
            {'start': entry['start'] * factor, 'end': entry['end'] * factor}
            for entry in timestamps
        ]
|
393 |
+
|
394 |
+
|
395 |
+
class VadSileroTranscription(AbstractTranscription):
    """
    VAD based on the Silero model, which is loaded through torch.hub.

    Parameters
    ----------
    sampling_rate: int
        The sampling rate used when loading the audio and when calling the model.
    cache: ModelCache
        Optional cache the model is fetched from (and stored in). If None, a new model is created.
    """

    def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
        super().__init__(sampling_rate=sampling_rate)
        self.model = None
        self.cache = cache
        self._initialize_model()

    def _initialize_model(self):
        """Set self.model and self.get_speech_timestamps, using the cache if there is one."""
        if self.cache is not None:
            model_key = "VadSileroTranscription"
            self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
            print("Loaded Silerio model from cache.")
        else:
            self.model, self.get_speech_timestamps = self._create_model()
            print("Created Silerio model")

    def _create_model(self):
        """Load the Silero VAD model and return it together with its get_speech_timestamps utility."""
        model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')

        # Silero does not benefit from multi-threading
        torch.set_num_threads(1) # JIT
        # Only the first of the five utilities is needed
        (get_speech_timestamps, _, _, _, _) = utils

        return model, get_speech_timestamps

    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
        """
        Detect the speech sections between start_time and end_time of the given audio file.

        The audio is processed in chunks of at most VAD_MAX_PROCESSING_CHUNK seconds.

        Parameters
        ----------
        audio: str
            The audio file.
        config: TranscriptionConfig
            The transcription configuration (not used by this implementation).
        start_time: float
            Start of the range to process, in seconds.
        end_time: float
            End of the range to process, in seconds.

        Returns
        -------
        A list of start and end timestamps, in fractional seconds relative to the start of the audio file.
        """
        result = []

        print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
        perf_start_time = time.perf_counter()

        # Divide processing of audio into chunks
        chunk_start = start_time

        while chunk_start < end_time:
            chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)

            print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
            wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))

            sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
            # Convert from sample offsets to seconds
            seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
            # The timestamps are still relative to the chunk at this point, and adjust_timestamp applies
            # max_source_time before shifting them - so the limit is the chunk duration, not its absolute end.
            adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_duration)

            result.extend(adjusted)
            chunk_start += chunk_duration

        perf_end_time = time.perf_counter()
        print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))

        return result

    def __getstate__(self):
        # We only need the sampling rate
        return { 'sampling_rate': self.sampling_rate }

    def __setstate__(self, state):
        self.sampling_rate = state['sampling_rate']
        self.model = None
        # Use the global cache
        self.cache = GLOBAL_MODEL_CACHE
        self._initialize_model()
|
459 |
+
|
460 |
+
# A very simple VAD that just marks every N seconds as speech
class VadPeriodicTranscription(AbstractTranscription):
    """VAD replacement that marks consecutive sections of config.periodic_duration seconds as speech."""

    def __init__(self, sampling_rate: int = 16000):
        super().__init__(sampling_rate=sampling_rate)

    def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
        """
        Generate back-to-back sections of config.periodic_duration seconds between start_time and end_time.

        Parameters
        ----------
        audio: str
            The audio file (not read by this implementation).
        config: PeriodicTranscriptionConfig
            The transcription configuration, which provides periodic_duration.
        start_time: float
            Start of the range, in seconds.
        end_time: float
            End of the range, in seconds.

        Returns
        -------
        A list of start and end timestamps, in fractional seconds. Sections shorter than one second are left out.

        Raises
        ------
        ValueError
            If config.periodic_duration is not positive (the sections would never advance).
        """
        result = []

        # Generate a timestamp every N seconds
        start_timestamp = start_time

        while start_timestamp < end_time:
            end_timestamp = min(start_timestamp + config.periodic_duration, end_time)

            if end_timestamp <= start_timestamp:
                # A zero or negative period would make this loop run forever
                raise ValueError("periodic_duration must be positive, got " + str(config.periodic_duration))

            segment_duration = end_timestamp - start_timestamp

            # Minimum duration is 1 second
            if segment_duration >= 1:
                result.append({'start': start_timestamp, 'end': end_timestamp})

            start_timestamp = end_timestamp

        return result
|
482 |
+
|
483 |
+
def get_audio_duration(file: str):
    """Return the duration of the given media file as a float, as reported by ffmpeg.probe."""
    probe_result = ffmpeg.probe(file)
    duration_text = probe_result["format"]["duration"]
    return float(duration_text)
|
485 |
+
|
486 |
+
def load_audio(file: str, sample_rate: int = 16000,
               start_time: str = None, duration: str = None):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sample_rate: int
        The sample rate to resample the audio if necessary

    start_time: str
        The start time, using the standard FFMPEG time duration syntax, or None to disable.

    duration: str
        The duration, using the standard FFMPEG time duration syntax, or None to disable.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError
        If ffmpeg fails to decode the file. The message contains the stderr output of ffmpeg.
    """
    inputArgs = {'threads': 0}

    # Only restrict the input when a start and/or duration was requested
    if start_time is not None:
        inputArgs['ss'] = start_time
    if duration is not None:
        inputArgs['t'] = duration

    try:
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        out, _ = (
            ffmpeg.input(file, **inputArgs)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # Convert the signed 16-bit samples to floats in the range [-1, 1)
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
src/vadParallel.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import multiprocessing
|
2 |
+
import threading
|
3 |
+
import time
|
4 |
+
from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
|
5 |
+
from src.whisperContainer import WhisperCallback
|
6 |
+
|
7 |
+
from multiprocessing import Pool
|
8 |
+
|
9 |
+
from typing import Any, Dict, List
|
10 |
+
import os
|
11 |
+
|
12 |
+
|
13 |
+
class ParallelContext:
    """
    Holds a lazily created, reference counted multiprocessing pool.

    Parameters
    ----------
    num_processes: int
        The number of processes passed to the pool when it is created.
    auto_cleanup_timeout_seconds: float
        If not None, the pool is closed this many seconds after the last user has returned it,
        unless it is requested again in the meantime.
    """

    def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
        self.num_processes = num_processes
        self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
        # NOTE(review): this lock is created but not used by any method in this class
        self.lock = threading.Lock()

        self.ref_count = 0
        self.pool = None
        self.cleanup_timer = None

    def get_pool(self):
        """Return the pool (creating it on first use) and register one more user of it."""
        # Initialize pool lazily
        if self.pool is None:
            spawn_context = multiprocessing.get_context('spawn')
            self.pool = spawn_context.Pool(self.num_processes)

        self.ref_count += 1

        # The pool is in use again, so a pending cleanup must not run
        if self.auto_cleanup_timeout_seconds is not None:
            self._stop_auto_cleanup()

        return self.pool

    def return_pool(self, pool):
        """Unregister one user of the pool, scheduling the cleanup when no users are left."""
        if not (self.pool == pool and self.ref_count > 0):
            return

        self.ref_count -= 1

        if self.ref_count == 0 and self.auto_cleanup_timeout_seconds is not None:
            self._start_auto_cleanup()

    def _start_auto_cleanup(self):
        """(Re)start the timer that closes the pool after the configured timeout."""
        if self.cleanup_timer is not None:
            self.cleanup_timer.cancel()

        timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
        self.cleanup_timer = timer
        timer.start()

        print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")

    def _stop_auto_cleanup(self):
        """Cancel the pending cleanup timer, if there is one."""
        if self.cleanup_timer is None:
            return

        self.cleanup_timer.cancel()
        self.cleanup_timer = None

        print("Stopped auto cleanup of pool")

    def _execute_cleanup(self):
        """Timer callback: close the pool if nobody is using it."""
        print("Executing cleanup of pool")

        if self.ref_count == 0:
            self.close()

    def close(self):
        """Cancel any pending cleanup, then close the pool and wait for its processes to finish."""
        self._stop_auto_cleanup()

        if self.pool is not None:
            print("Closing pool of " + str(self.num_processes) + " processes")
            self.pool.close()
            self.pool.join()
            self.pool = None
|
73 |
+
|
74 |
+
class ParallelTranscriptionConfig(TranscriptionConfig):
    """
    Configuration handed to a worker process of the parallel transcription.

    Parameters
    ----------
    device_id: str
        The device the worker should use, or None to leave the device selection untouched.
    override_timestamps
        The sections this worker should transcribe. If not None, they replace the merged timestamps.
    initial_segment_index
        The starting value of the segment counter for this worker.
    copy: TranscriptionConfig
        The configuration to take the remaining settings from. If None, the defaults of
        TranscriptionConfig are used.
    """

    def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
        if copy is None:
            # Nothing to copy from - fall back to the defaults instead of failing on None
            super().__init__(initial_segment_index=initial_segment_index)
        else:
            super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right,
                             copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
        self.device_id = device_id
        self.override_timestamps = override_timestamps
|
79 |
+
|
80 |
+
class ParallelTranscription(AbstractTranscription):
    """
    Runs the VAD and the transcription of another AbstractTranscription in parallel.

    The VAD can be spread over several CPU processes, and the transcription over one
    process per GPU device.
    """

    # Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
    # into smaller segments than 2 minute (min 6 seconds per CPU core)
    MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60

    def __init__(self, sampling_rate: int = 16000):
        super().__init__(sampling_rate=sampling_rate)

    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
                            cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None):
        """
        Transcribe the given audio file, distributing the work over several processes.

        Parameters
        ----------
        transcription: AbstractTranscription
            The VAD method that determines (and merges) the sections to transcribe.
        audio: str
            The audio file.
        whisperCallable: WhisperCallback
            A callback object to call to transcribe each segment.
        config: TranscriptionConfig
            The transcription configuration.
        cpu_device_count: int
            The number of processes to run the VAD on. With 1 or less, the VAD runs in this process.
        gpu_devices: List[str]
            The devices to transcribe on. The sections are split into one part per device.
        cpu_parallel_context: ParallelContext
            Optional context providing the process pool for the VAD. If None, a temporary one is created.
        gpu_parallel_context: ParallelContext
            Optional context providing the process pool for the transcription. If None, a temporary one is created.

        Returns
        -------
        A dictionary with the keys 'text', 'segments' and 'language', merged from the results of all devices.
        """
        total_duration = get_audio_duration(audio)

        # First, get the timestamps for the original audio
        if cpu_device_count > 1:
            merged_timestamps = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
        else:
            # get_merged_timestamps expects the detected sections, not the audio file itself
            timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
            merged_timestamps = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)

        # Split into a list for each device
        # TODO: Split by time instead of by number of chunks
        merged_split = list(self._split(merged_timestamps, len(gpu_devices)))

        # Parameters that will be passed to the transcribe function
        parameters = []
        segment_index = config.initial_segment_index

        for i, (device_segments, device_id) in enumerate(zip(merged_split, gpu_devices)):
            device_segment_list = list(device_segments)

            if len(device_segment_list) <= 0:
                continue

            print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")

            # Create a new config with the given device ID
            device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
            segment_index += len(device_segment_list)

            parameters.append([audio, whisperCallable, device_config])

        merged = {
            'text': '',
            'segments': [],
            'language': None
        }

        created_context = False
        pool = None

        perf_start_gpu = time.perf_counter()

        # Spawn a separate process for each device
        try:
            if gpu_parallel_context is None:
                gpu_parallel_context = ParallelContext(len(gpu_devices))
                created_context = True

            # Get a pool of processes
            pool = gpu_parallel_context.get_pool()

            # Run the transcription in parallel
            results = pool.starmap(self.transcribe, parameters)

            for result in results:
                # Merge the results
                if result['text'] is not None:
                    merged['text'] += result['text']
                if result['segments'] is not None:
                    merged['segments'].extend(result['segments'])

                language = result['language']

                # Do not let an empty language replace one that was already detected
                if language is not None and (language or not merged['language']):
                    merged['language'] = language

        finally:
            # Return the pool to the context (it is None if the pool could not be acquired)
            if gpu_parallel_context is not None and pool is not None:
                gpu_parallel_context.return_pool(pool)
            # Always close the context if we created it
            if created_context:
                gpu_parallel_context.close()

        perf_end_gpu = time.perf_counter()
        print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")

        return merged

    def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
                                        cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
        """
        Run the VAD of the given transcription on chunks of the audio in parallel, and merge the result.

        The audio is divided into cpu_device_count chunks, but no chunk is made shorter than
        MIN_CPU_CHUNK_SIZE_SECONDS. A final chunk of less than one second is not processed.

        Returns
        -------
        The merged timestamps, as returned by transcription.get_merged_timestamps.
        """
        parameters = []

        chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
        chunk_start = 0
        cpu_device_id = 0

        perf_start_time = time.perf_counter()

        # Create chunks that will be processed on the CPU
        while chunk_start < total_duration:
            chunk_end = min(chunk_start + chunk_size, total_duration)

            if chunk_end - chunk_start < 1:
                # No need to process chunks that are less than 1 second
                break

            print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
                    str(chunk_end) + " on CPU device " + str(cpu_device_id))
            parameters.append([audio, config, chunk_start, chunk_end])

            cpu_device_id += 1
            chunk_start = chunk_end

        created_context = False
        pool = None

        # Spawn a separate process for each device
        try:
            if cpu_parallel_context is None:
                cpu_parallel_context = ParallelContext(cpu_device_count)
                created_context = True

            # Get a pool of processes
            pool = cpu_parallel_context.get_pool()

            # Run the transcription in parallel. Note that transcription must be picklable.
            results = pool.starmap(transcription.get_transcribe_timestamps, parameters)

            # Flatten the results
            timestamps = []

            for result in results:
                timestamps.extend(result)

            merged = transcription.get_merged_timestamps(timestamps, config, total_duration)

            perf_end_time = time.perf_counter()
            print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
            return merged

        finally:
            # Return the pool to the context (it is None if the pool could not be acquired)
            if cpu_parallel_context is not None and pool is not None:
                cpu_parallel_context.return_pool(pool)
            # Always close the context if we created it
            if created_context:
                cpu_parallel_context.close()

    def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
        """Return no timestamps: the sections to process are provided through config.override_timestamps instead."""
        return []

    def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
        """Return config.override_timestamps if set, and the regular merged timestamps otherwise."""
        # Override timestamps that will be processed
        if config.override_timestamps is not None:
            print("Using override timestamps of size " + str(len(config.override_timestamps)))
            return config.override_timestamps
        return super().get_merged_timestamps(timestamps, config, total_duration)

    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
        """
        Transcribe the sections assigned to this worker.

        The first time this runs in a process, CUDA_VISIBLE_DEVICES is set to config.device_id
        (if there is one). The INITIALIZED environment variable marks that this has been done.
        """
        # Override device ID the first time
        if os.environ.get("INITIALIZED", None) is None:
            os.environ["INITIALIZED"] = "1"

            # Note that this may be None if the user didn't specify a device. In that case, Whisper will
            # just use the default GPU device.
            if config.device_id is not None:
                print("Using device " + config.device_id)
                os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id

        return super().transcribe(audio, whisperCallable, config)

    def _split(self, a, n):
        """Split a list into n approximately equal parts."""
        # The first m parts get one element more than the others. Note that n must not be 0.
        k, m = divmod(len(a), n)
        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
|
251 |
+
|
src/whisperContainer.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# External programs
|
2 |
+
import whisper
|
3 |
+
|
4 |
+
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
5 |
+
|
6 |
+
class WhisperContainer:
    """
    Holds the settings of a Whisper model and loads the model on demand.

    Parameters
    ----------
    model_name: str
        The name of the Whisper model to load.
    device: str
        The device passed to whisper.load_model, or None.
    download_root: str
        The download root passed to whisper.load_model, or None.
    cache: ModelCache
        Optional cache the model is fetched from (and stored in). If None, the model is loaded directly.
    """

    def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
        self.model_name = model_name
        self.device = device
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand
        self.model = None

    def get_model(self):
        """Return the Whisper model, loading it (or fetching it from the cache) on first use."""
        if self.model is not None:
            return self.model

        if self.cache is None:
            self.model = self._create_model()
        else:
            # The key contains the model name and the device (empty if there is no device)
            model_key = "WhisperContainer." + self.model_name + ":" + (self.device or '')
            self.model = self.cache.get(model_key, self._create_model)

        return self.model

    def _create_model(self):
        """Load the Whisper model using the settings of this container."""
        print("Loading whisper model " + self.model_name)
        return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)

    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        """
        Create a WhisperCallback object that can be used to transcribe audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)

    # This is required for multiprocessing
    def __getstate__(self):
        # Neither the loaded model nor the cache is part of the pickled state
        return {
            "model_name": self.model_name,
            "device": self.device,
            "download_root": self.download_root,
        }

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]

        # The model is loaded again on demand
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE
|
62 |
+
|
63 |
+
|
64 |
+
class WhisperCallback:
    """
    Binds a model container to a fixed set of transcription options, so that
    each audio segment can be transcribed with a single call to invoke().
    """
    def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        self.model_container = model_container
        self.language = language
        self.task = task
        self.initial_prompt = initial_prompt
        self.decodeOptions = decodeOptions

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The index of the segment being transcribed. The initial prompt given to the
            constructor is only combined with the prompt when this is 0.
        prompt: str
            The prompt to use for the transcription.
        detected_language: str
            The detected language of the audio file. Only used when no language
            was given to the constructor.

        Returns
        -------
        The result of the Whisper call.
        """
        model = self.model_container.get_model()

        # An explicitly configured language takes precedence over the detected one
        language = self.language if self.language else detected_language

        if segment_index == 0:
            initial_prompt = self._concat_prompt(self.initial_prompt, prompt)
        else:
            initial_prompt = prompt

        return model.transcribe(audio,
            language=language, task=self.task,
            initial_prompt=initial_prompt,
            **self.decodeOptions)

    def _concat_prompt(self, prompt1, prompt2):
        """Join two optional prompts with a space, returning the other one when either is None."""
        if prompt1 is None:
            return prompt2
        if prompt2 is None:
            return prompt1
        return prompt1 + " " + prompt2
|
tests/segments_test.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import unittest
|
3 |
+
|
4 |
+
sys.path.append('../whisper-webui')
|
5 |
+
|
6 |
+
from src.segments import merge_timestamps
|
7 |
+
|
8 |
+
class TestSegments(unittest.TestCase):
    """Tests for merge_timestamps from src.segments."""

    def __init__(self, *args, **kwargs):
        super(TestSegments, self).__init__(*args, **kwargs)

    def test_merge_segments(self):
        """Check the merged, padded output for a list of eight well-separated segments."""
        segments = [
            {'start': 10.0, 'end': 20.0},
            {'start': 22.0, 'end': 27.0},
            {'start': 31.0, 'end': 35.0},
            {'start': 45.0, 'end': 60.0},
            {'start': 61.0, 'end': 65.0},
            {'start': 68.0, 'end': 98.0},
            {'start': 100.0, 'end': 102.0},
            {'start': 110.0, 'end': 112.0}
        ]

        result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)

        self.assertListEqual(result, [
            {'start': 9.0, 'end': 36.0},
            {'start': 44.0, 'end': 66.0},
            {'start': 67.0, 'end': 99.0},
            {'start': 99.0, 'end': 103.0},
            {'start': 109.0, 'end': 113.0}
        ])

    def test_overlap_next(self):
        """Check the output for two segments whose padded ranges would overlap."""
        segments = [
            {'start': 5.0, 'end': 39.182},
            {'start': 39.986, 'end': 40.814}
        ]

        result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)

        # The expected boundary 39.584 is the midpoint of the 39.182 - 39.986 gap
        self.assertListEqual(result, [
            {'start': 4.0, 'end': 39.584},
            {'start': 39.584, 'end': 41.814}
        ])
|
46 |
+
|
47 |
+
# Allow the segment tests to be run directly as a script
if __name__ == '__main__':
    unittest.main()
|
tests/vad_test.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pprint
|
2 |
+
import unittest
|
3 |
+
import numpy as np
|
4 |
+
import sys
|
5 |
+
|
6 |
+
sys.path.append('../whisper-webui')
|
7 |
+
|
8 |
+
from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
|
9 |
+
|
10 |
+
class TestVad(unittest.TestCase):
    """Tests the VAD transcription flow against a mock transcription class."""

    def __init__(self, *args, **kwargs):
        super(TestVad, self).__init__(*args, **kwargs)
        # Records the audio segment passed to each transcribe_segments call
        self.transcribe_calls = []

    def test_transcript(self):
        """Check which segments are transcribed and how the result timestamps are reported."""
        mock = MockVadTranscription()

        self.transcribe_calls.clear()
        # NOTE(review): transcribe is called here with a bare callable and no config,
        # while other callers pass (audio, whisperCallable, config) - confirm this
        # still matches AbstractTranscription.transcribe.
        result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))

        # Each recorded call is the mock audio array: [start time, duration]
        self.assertListEqual(self.transcribe_calls, [
            [30, 30],
            [100, 100]
        ])

        self.assertListEqual(result['segments'],
            [{'end': 50.0, 'start': 40.0, 'text': 'Hello world '},
            {'end': 120.0, 'start': 110.0, 'text': 'Hello world '}]
        )

    def transcribe_segments(self, segment):
        """Record the given segment and return a fixed, dummy transcription result."""
        self.transcribe_calls.append(segment.tolist())

        # Dummy text
        return {
            'text': "Hello world ",
            'segments': [
                {
                    "start": 10.0,
                    "end": 20.0,
                    "text": "Hello world "
                }
            ],
            'language': ""
        }
|
46 |
+
|
47 |
+
class MockVadTranscription(AbstractTranscription):
    """Transcription stub that returns fixed timestamps and fake audio data."""

    def __init__(self):
        super().__init__()

    # NOTE: the first parameter is named `str`, which shadows the builtin; the
    # name is kept as-is so the method signature stays unchanged.
    def get_audio_segment(self, str, start_time: str = None, duration: str = None):
        """Return a two-element array [start time, duration], both in seconds, instead of real audio."""
        # Both values are strings that may carry an "s" suffix, e.g. "30s"
        start_time_seconds = float(start_time.removesuffix("s"))
        duration_seconds = float(duration.removesuffix("s"))

        # For mocking, this just returns a simple numpy array
        return np.array([start_time_seconds, duration_seconds], dtype=np.float64)

    def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, duration: float):
        """Return two fixed speech sections, ignoring all arguments."""
        result = []

        result.append( { 'start': 30, 'end': 60 } )
        result.append( { 'start': 100, 'end': 200 } )
        return result
|
64 |
+
|
65 |
+
# Allow the VAD tests to be run directly as a script
if __name__ == '__main__':
    unittest.main()
|