NeuraFusionAI committed on
Commit
d0c5b3a
·
1 Parent(s): 2412b00

Initial commit of WhisperFast project

Browse files
Files changed (5) hide show
  1. README.md +4 -7
  2. app.py +173 -0
  3. languages.py +147 -0
  4. requirements.txt +5 -0
  5. subtitle_manager.py +52 -0
README.md CHANGED
@@ -1,13 +1,10 @@
1
- ---
2
  title: WhisperFast
3
- emoji: 🌍
4
- colorFrom: green
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.42.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
  title: WhisperFast
2
+ emoji: 🗣️
3
+ colorFrom: blue
4
  colorTo: blue
5
  sdk: gradio
6
+ sdk_version: 4.36.0
7
  app_file: app.py
8
  pinned: false
9
  license: mit
10
+ short_description: Transcribe audio to subtitles instantly
 
 
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ import logging
4
+ import torch
5
+ from sys import platform
6
+ from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
7
+ from transformers.utils import is_flash_attn_2_available
8
+ from languages import get_language_names
9
+ from subtitle_manager import Subtitle
10
+
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+ last_model = None
14
+ pipe = None
15
+
16
def write_file(output_file, subtitle):
    """Persist subtitle text to ``output_file`` as UTF-8."""
    with open(output_file, mode="w", encoding="utf-8") as handle:
        handle.write(subtitle)
19
+
20
def create_pipe(model, flash):
    """Build an automatic-speech-recognition pipeline for a Whisper checkpoint.

    model: Hugging Face model id (e.g. "openai/whisper-small").
    flash: when True and flash-attn 2 is installed, use Flash Attention 2;
           otherwise fall back to PyTorch SDPA (requires torch>=2.1.1).
    Returns a transformers ASR pipeline placed on CUDA, MPS (macOS) or CPU,
    running in fp16 on CUDA and fp32 elsewhere.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
    elif platform == "darwin":
        device = "mps"
    else:
        device = "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model_id = model

    # Attention backend: "flash_attention_2" only when requested AND available,
    # otherwise "sdpa" (torch.nn.functional.scaled_dot_product_attention).
    attn_impl = "flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa"

    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attn_impl,
    )
    asr_model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    return pipeline(
        "automatic-speech-recognition",
        model=asr_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
57
+
58
def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    """Transcribe (or translate) every provided audio source and write
    .srt/.vtt/.txt subtitle files for each.

    modelName:      Whisper checkpoint id; the pipeline is rebuilt only when it changes.
    languageName:   source language name, or "Automatic Detection".
    urlData:        optional audio URL.
    multipleFiles:  optional list of uploaded file paths.
    microphoneData: optional microphone recording path.
    task:           "transcribe" or "translate".
    flash:          request Flash Attention 2 when available.
    chunk_length_s, batch_size: chunking parameters forwarded to the pipeline.
    progress:       gradio progress tracker (injected by the UI).

    Returns (list of all output file paths, vtt text of the last file,
    txt text of the last file); empty strings when no input was given.
    """
    global last_model
    global pipe

    progress(0, desc="Loading Audio..")
    logging.info(f"urlData:{urlData}")
    logging.info(f"multipleFiles:{multipleFiles}")
    logging.info(f"microphoneData:{microphoneData}")
    logging.info(f"task: {task}")
    logging.info(f"is_flash_attn_2_available: {is_flash_attn_2_available()}")
    logging.info(f"chunk_length_s: {chunk_length_s}")
    logging.info(f"batch_size: {batch_size}")

    # Build the pipeline lazily; rebuild only when the model selection changed.
    if last_model is None:  # fix: was `== None`
        logging.info("first model")
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    elif modelName != last_model:
        logging.info("new model")
        torch.cuda.empty_cache()  # release the previous model's GPU memory first
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    else:
        logging.info("Model not changed")
    last_model = modelName

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    # Collect every provided audio source into one work list.
    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)
    logging.info(files)

    # English-only checkpoints (".en") accept neither a language nor a task.
    generate_kwargs = {}
    if languageName != "Automatic Detection" and not modelName.endswith(".en"):
        generate_kwargs["language"] = languageName
    if not modelName.endswith(".en"):
        generate_kwargs["task"] = task

    files_out = []
    # Fix: vtt/txt were unbound when `files` was empty, raising
    # UnboundLocalError at the final return.
    vtt = ""
    txt = ""
    for file in progress.tqdm(files, desc="Working..."):
        start_time = time.time()
        logging.info(file)
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,  # e.g. 30
            batch_size=batch_size,          # e.g. 24
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.debug(outputs)
        # Fix: was logging.info(print(f"...")), which printed to stdout and
        # logged the literal string "None".
        logging.info(f"transcribe: {time.time() - start_time} sec.")

        file_out = file.split('/')[-1]
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        write_file(file_out + ".srt", srt)
        write_file(file_out + ".vtt", vtt)
        write_file(file_out + ".txt", txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]

    progress(1, desc="Completed!")

    return files_out, vtt, txt
130
+
131
+
132
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    description = "An opinionated CLI to transcribe Audio files w/ Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"
    # Fix: `article` was passed to gr.Interface below but never defined,
    # which raised NameError as soon as the app started.
    article = "Powered by 🤗 Transformers and the Whisper family of models."

    # Selectable checkpoints, smallest to largest, incl. distilled variants.
    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
        "openai/whisper-large",
        "openai/whisper-large-v1",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "distil-whisper/distil-large-v3", "xaviviro/whisper-large-v3-catalan-finetuned-v2",
    ]
    waveform_options = gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    )

    simple_transcribe = gr.Interface(fn=transcribe_webui_simple_progress,
        description=description,
        article=article,
        inputs=[
            gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model", info="Select whisper model", interactive = True,),
            gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language", info="Select audio voice language", interactive = True,),
            gr.Text(label="URL", info="(YouTube, etc.)", interactive = True),
            gr.File(label="Upload Files", file_count="multiple"),
            gr.Audio(sources=["upload", "microphone",], type="filepath", label="Input", waveform_options = waveform_options),
            gr.Dropdown(choices=["transcribe", "translate"], label="Task", value="transcribe", interactive = True),
            gr.Checkbox(label='Flash', info='Use Flash Attention 2'),
            gr.Number(label='chunk_length_s', value=30, interactive = True),
            gr.Number(label='batch_size', value=24, interactive = True)
        ], outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
            gr.Text(label="Segments")
        ]
    )

if __name__ == "__main__":
    demo.launch()
languages.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Language():
    """A spoken language, as an (ISO 639-1 code, English name) pair."""

    def __init__(self, code, name):
        self.code = code  # e.g. "en"
        self.name = name  # e.g. "English"

    def __str__(self):
        return "Language(code={}, name={})".format(self.code, self.name)

    # Fix: without __repr__, containers of Language (e.g. the LANGUAGES
    # list) printed as opaque object addresses; reuse the readable form.
    __repr__ = __str__
8
+
9
# (code, name) pairs for every language Whisper supports, in the original order.
_LANGUAGE_TABLE = [
    ("en", "English"), ("zh", "Chinese"), ("de", "German"), ("es", "Spanish"),
    ("ru", "Russian"), ("ko", "Korean"), ("fr", "French"), ("ja", "Japanese"),
    ("pt", "Portuguese"), ("tr", "Turkish"), ("pl", "Polish"), ("ca", "Catalan"),
    ("nl", "Dutch"), ("ar", "Arabic"), ("sv", "Swedish"), ("it", "Italian"),
    ("id", "Indonesian"), ("hi", "Hindi"), ("fi", "Finnish"), ("vi", "Vietnamese"),
    ("he", "Hebrew"), ("uk", "Ukrainian"), ("el", "Greek"), ("ms", "Malay"),
    ("cs", "Czech"), ("ro", "Romanian"), ("da", "Danish"), ("hu", "Hungarian"),
    ("ta", "Tamil"), ("no", "Norwegian"), ("th", "Thai"), ("ur", "Urdu"),
    ("hr", "Croatian"), ("bg", "Bulgarian"), ("lt", "Lithuanian"), ("la", "Latin"),
    ("mi", "Maori"), ("ml", "Malayalam"), ("cy", "Welsh"), ("sk", "Slovak"),
    ("te", "Telugu"), ("fa", "Persian"), ("lv", "Latvian"), ("bn", "Bengali"),
    ("sr", "Serbian"), ("az", "Azerbaijani"), ("sl", "Slovenian"), ("kn", "Kannada"),
    ("et", "Estonian"), ("mk", "Macedonian"), ("br", "Breton"), ("eu", "Basque"),
    ("is", "Icelandic"), ("hy", "Armenian"), ("ne", "Nepali"), ("mn", "Mongolian"),
    ("bs", "Bosnian"), ("kk", "Kazakh"), ("sq", "Albanian"), ("sw", "Swahili"),
    ("gl", "Galician"), ("mr", "Marathi"), ("pa", "Punjabi"), ("si", "Sinhala"),
    ("km", "Khmer"), ("sn", "Shona"), ("yo", "Yoruba"), ("so", "Somali"),
    ("af", "Afrikaans"), ("oc", "Occitan"), ("ka", "Georgian"), ("be", "Belarusian"),
    ("tg", "Tajik"), ("sd", "Sindhi"), ("gu", "Gujarati"), ("am", "Amharic"),
    ("yi", "Yiddish"), ("lo", "Lao"), ("uz", "Uzbek"), ("fo", "Faroese"),
    ("ht", "Haitian creole"), ("ps", "Pashto"), ("tk", "Turkmen"), ("nn", "Nynorsk"),
    ("mt", "Maltese"), ("sa", "Sanskrit"), ("lb", "Luxembourgish"), ("my", "Myanmar"),
    ("bo", "Tibetan"), ("tl", "Tagalog"), ("mg", "Malagasy"), ("as", "Assamese"),
    ("tt", "Tatar"), ("haw", "Hawaiian"), ("ln", "Lingala"), ("ha", "Hausa"),
    ("ba", "Bashkir"), ("jw", "Javanese"), ("su", "Sundanese"),
]

# One Language instance per supported language, preserving table order.
LANGUAGES = [Language(code, name) for code, name in _LANGUAGE_TABLE]
110
+
111
# Fast lookup: ISO code -> Language.
_BY_CODE = {language.code: language for language in LANGUAGES}

# Code/alias -> Language.  Fix: the historical-name aliases previously mapped
# to bare code strings ("burmese" -> "my") while the code keys mapped to
# Language objects, so get_language_from_code() could return either type,
# violating its `-> Language` annotation.  Aliases now resolve to the same
# Language object as their canonical code.
_TO_LANGUAGE_CODE = {
    **_BY_CODE,
    "burmese": _BY_CODE["my"],
    "valencian": _BY_CODE["ca"],
    "flemish": _BY_CODE["nl"],
    "haitian": _BY_CODE["ht"],
    "letzeburgesch": _BY_CODE["lb"],
    "pushto": _BY_CODE["ps"],
    "panjabi": _BY_CODE["pa"],
    "moldavian": _BY_CODE["ro"],
    "moldovan": _BY_CODE["ro"],
    "sinhalese": _BY_CODE["si"],
    "castilian": _BY_CODE["es"],
}

# Lower-cased English name -> Language.
_FROM_LANGUAGE_NAME = {language.name.lower(): language for language in LANGUAGES}
129
+
130
def get_language_from_code(language_code, default=None) -> Language:
    """Return the Language for ``language_code``, or ``default`` if unknown.

    Fix: the old docstring claimed this returned "the language name";
    it returns a Language object (or ``default``).
    """
    return _TO_LANGUAGE_CODE.get(language_code, default)
133
+
134
def get_language_from_name(language, default=None) -> Language:
    """Return the Language whose English name matches ``language``
    (case-insensitive), or ``default`` when unknown or ``language`` is falsy.

    Fix: the old docstring claimed this returned "the language code";
    it returns a Language object (or ``default``).
    """
    return _FROM_LANGUAGE_NAME.get(language.lower() if language else None, default)
137
+
138
def get_language_names():
    """Return the English names of all supported languages, in table order."""
    return [entry.name for entry in LANGUAGES]
141
+
142
if __name__ == "__main__":
    # Smoke-test the lookup helpers from the command line.
    print(get_language_from_code('en'))
    print(get_language_from_name('English'))

    # And dump every supported language name.
    print(get_language_names())
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+ torch>=2.1.1
3
+ gradio==4.16.0
4
+ transformers
5
+ accelerate
subtitle_manager.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
class Subtitle():
    """Render Whisper pipeline chunks as "srt", "vtt" or "txt" subtitles.

    ext selects the millisecond separator, the file header, and the
    per-segment template.  Segments are dicts with a "timestamp"
    (start, end) pair and a "text" string, as produced by the
    transformers ASR pipeline with return_timestamps=True.
    """

    def __init__(self, ext="srt"):
        # Per-format settings; "format" renders one (index, segment) pair.
        sub_dict = {
            "srt": {
                "coma": ",",   # srt separates milliseconds with a comma
                "header": "",
                "format": self._format_srt,
            },
            "vtt": {
                "coma": ".",   # vtt uses a dot
                "header": "WebVTT\n\n",
                "format": self._format_vtt,
            },
            "txt": {
                "coma": "",
                "header": "",
                "format": self._format_txt,
            },
        }

        self.ext = ext
        self.coma = sub_dict[ext]["coma"]
        self.header = sub_dict[ext]["header"]
        self.format = sub_dict[ext]["format"]

    def _timespan(self, segment):
        """Render "start --> end"; a missing end timestamp falls back to start."""
        start, end = segment['timestamp']
        if end is None:
            end = start
        return f"{self.timeformat(start)} --> {self.timeformat(end)}"

    def _format_srt(self, i, segment):
        # srt cue: 1-based index, timespan, text, blank separator line.
        return f"{i + 1}\n{self._timespan(segment)}\n{segment['text']}\n\n"

    def _format_vtt(self, i, segment):
        # vtt cue: timespan, text, blank separator line (no index).
        return f"{self._timespan(segment)}\n{segment['text']}\n\n"

    def _format_txt(self, i, segment):
        return f"{segment['text']}\n"

    def timeformat(self, time):
        """Format seconds (number) as HH:MM:SS<sep>mmm using this format's separator."""
        hours = time // 3600
        minutes = (time - hours * 3600) // 60
        seconds = time - hours * 3600 - minutes * 60
        milliseconds = (time - int(time)) * 1000
        return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}{self.coma}{int(milliseconds):03d}"

    def get_subtitle(self, segments):
        """Concatenate the formatted segments after this format's header.

        A single leading space is stripped from each segment's text.
        Fix: the stripping no longer mutates the caller's segment dicts,
        which previously made the first Subtitle instance to run alter the
        data seen by later ones.  Malformed segments are reported and
        skipped (best-effort, as before).
        """
        output = self.header
        for i, segment in enumerate(segments):
            text = segment['text']
            if text.startswith(' '):
                text = text[1:]
            try:
                # Render from a shallow copy so the input stays untouched.
                output += self.format(i, {**segment, 'text': text})
            except Exception as e:
                # Best-effort: report the bad segment and keep going.
                print(e, segment)

        return output

    def write_subtitle(self, segments, output_file):
        """Write the rendered subtitles to ``output_file`` + "." + ext (UTF-8)."""
        output_file += "." + self.ext
        subtitle = self.get_subtitle(segments)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(subtitle)