sanchit-gandhi commited on
Commit
bf2a8ff
·
1 Parent(s): f8dd558

revert short-form changes

Browse files
Files changed (4) hide show
  1. app.py +21 -55
  2. assets/example_1.wav +2 -2
  3. assets/example_2.wav +2 -2
  4. assets/example_3.wav +0 -3
app.py CHANGED
@@ -1,7 +1,6 @@
1
- from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, TextIteratorStreamer
2
  from transformers.utils import is_flash_attn_2_available
3
  from transformers.pipelines.audio_utils import ffmpeg_read
4
- from threading import Thread
5
  import torch
6
  import gradio as gr
7
  import time
@@ -26,7 +25,6 @@ if not use_flash_attention_2:
26
  distilled_model = distilled_model.to_bettertransformer()
27
 
28
  processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
29
- streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
30
 
31
  model.to(device)
32
  distilled_model.to(device)
@@ -58,6 +56,7 @@ distil_pipe = pipeline(
58
  )
59
  distil_pipe_forward = distil_pipe._forward
60
 
 
61
  def transcribe(inputs):
62
  if inputs is None:
63
  raise gr.Error("No audio file submitted! Please record or upload an audio file before submitting your request.")
@@ -74,65 +73,32 @@ def transcribe(inputs):
74
  f"Got an audio of length {round(audio_length_mins, 3)} minutes."
75
  )
76
 
77
- if audio_length_mins >= 0.5:
78
- inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
79
-
80
- def _forward_distil_time(*args, **kwargs):
81
- global distil_runtime_pipeline
82
- start_time = time.time()
83
- result = distil_pipe_forward(*args, **kwargs)
84
- distil_runtime_pipeline = time.time() - start_time
85
- distil_runtime_pipeline = round(distil_runtime_pipeline, 2)
86
- return result
87
-
88
- distil_pipe._forward = _forward_distil_time
89
- distil_text = distil_pipe(inputs.copy(), batch_size=BATCH_SIZE)["text"]
90
- yield distil_text, distil_runtime_pipeline, None, None
91
-
92
- def _forward_time(*args, **kwargs):
93
- global runtime_pipeline
94
- start_time = time.time()
95
- result = pipe_forward(*args, **kwargs)
96
- runtime_pipeline = time.time() - start_time
97
- runtime_pipeline = round(runtime_pipeline, 2)
98
- return result
99
-
100
- pipe._forward = _forward_time
101
- text = pipe(inputs, batch_size=BATCH_SIZE)["text"]
102
-
103
- yield distil_text, distil_runtime_pipeline, text, runtime_pipeline
104
-
105
- else:
106
- input_features = processor(inputs, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt").input_features
107
- input_features = input_features.to(device, dtype=torch_dtype)
108
 
109
- # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
110
- generation_kwargs = dict(input_features=input_features, streamer=streamer, max_new_tokens=128, language="en", task="transcribe")
111
- thread = Thread(target=distilled_model.generate, kwargs=generation_kwargs)
112
-
113
- thread.start()
114
  start_time = time.time()
115
- distil_text = ""
116
- for generated_text in streamer:
117
- distil_text += generated_text
118
- yield distil_text, None, None, None
119
-
120
  distil_runtime = time.time() - start_time
121
  distil_runtime = round(distil_runtime, 2)
122
- yield distil_text, distil_runtime, None, None
123
 
124
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
 
125
 
126
- thread.start()
 
127
  start_time = time.time()
128
- text = ""
129
- for generated_text in streamer:
130
- text += generated_text
131
- yield distil_text, distil_runtime, text, None
132
-
133
  runtime = time.time() - start_time
134
  runtime = round(runtime, 2)
135
- yield distil_text, distil_runtime, text, runtime
 
 
 
 
 
136
 
137
 
138
  if __name__ == "__main__":
@@ -158,7 +124,7 @@ if __name__ == "__main__":
158
  of the <a href="https://huggingface.co/openai/whisper-large-v2"> Whisper</a> model by OpenAI. Compared to Whisper,
159
  Distil-Whisper runs 6x faster with 50% fewer parameters, while performing to within 1% word error rate (WER) on
160
  out-of-distribution evaluation data.</p>
161
-
162
  <p>In this demo, we perform a speed comparison between Whisper and Distil-Whisper in order to test this claim.
163
  Both models use the <a href="https://huggingface.co/distil-whisper/distil-large-v2#long-form-transcription"> chunked long-form transcription algorithm</a>
164
  in 🤗 Transformers, as well as Flash Attention. To use Distil-Whisper yourself, check the code examples on the
@@ -181,7 +147,7 @@ if __name__ == "__main__":
181
  )
182
  gr.Markdown("## Examples")
183
  gr.Examples(
184
- [["./assets/example_1.wav"], ["./assets/example_2.wav"], ["./assets/example_3.wav"]],
185
  audio,
186
  outputs=[distil_transcription, distil_runtime, transcription, runtime],
187
  fn=transcribe,
 
1
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
2
  from transformers.utils import is_flash_attn_2_available
3
  from transformers.pipelines.audio_utils import ffmpeg_read
 
4
  import torch
5
  import gradio as gr
6
  import time
 
25
  distilled_model = distilled_model.to_bettertransformer()
26
 
27
  processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
 
28
 
29
  model.to(device)
30
  distilled_model.to(device)
 
56
  )
57
  distil_pipe_forward = distil_pipe._forward
58
 
59
+
60
  def transcribe(inputs):
61
  if inputs is None:
62
  raise gr.Error("No audio file submitted! Please record or upload an audio file before submitting your request.")
 
73
  f"Got an audio of length {round(audio_length_mins, 3)} minutes."
74
  )
75
 
76
+ inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ def _forward_distil_time(*args, **kwargs):
79
+ global distil_runtime
 
 
 
80
  start_time = time.time()
81
+ result = distil_pipe_forward(*args, **kwargs)
 
 
 
 
82
  distil_runtime = time.time() - start_time
83
  distil_runtime = round(distil_runtime, 2)
84
+ return result
85
 
86
+ distil_pipe._forward = _forward_distil_time
87
+ distil_text = distil_pipe(inputs.copy(), batch_size=BATCH_SIZE)["text"]
88
+ yield distil_text, distil_runtime, None, None, None
89
 
90
+ def _forward_time(*args, **kwargs):
91
+ global runtime
92
  start_time = time.time()
93
+ result = pipe_forward(*args, **kwargs)
 
 
 
 
94
  runtime = time.time() - start_time
95
  runtime = round(runtime, 2)
96
+ return result
97
+
98
+ pipe._forward = _forward_time
99
+ text = pipe(inputs, batch_size=BATCH_SIZE)["text"]
100
+
101
+ yield distil_text, distil_runtime, text, runtime
102
 
103
 
104
  if __name__ == "__main__":
 
124
  of the <a href="https://huggingface.co/openai/whisper-large-v2"> Whisper</a> model by OpenAI. Compared to Whisper,
125
  Distil-Whisper runs 6x faster with 50% fewer parameters, while performing to within 1% word error rate (WER) on
126
  out-of-distribution evaluation data.</p>
127
+
128
  <p>In this demo, we perform a speed comparison between Whisper and Distil-Whisper in order to test this claim.
129
  Both models use the <a href="https://huggingface.co/distil-whisper/distil-large-v2#long-form-transcription"> chunked long-form transcription algorithm</a>
130
  in 🤗 Transformers, as well as Flash Attention. To use Distil-Whisper yourself, check the code examples on the
 
147
  )
148
  gr.Markdown("## Examples")
149
  gr.Examples(
150
+ [["./assets/example_1.wav"], ["./assets/example_2.wav"]],
151
  audio,
152
  outputs=[distil_transcription, distil_runtime, transcription, runtime],
153
  fn=transcribe,
assets/example_1.wav CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d96fece5c0c24d039801e9e39e9985982ad63becdab6c1a141992aa6dd37a615
3
- size 802110
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e938b9f81dea096ec7d3752e90afca8d370f7a461d3a08e1a559f4440ed055d
3
+ size 1963810
assets/example_2.wav CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1e938b9f81dea096ec7d3752e90afca8d370f7a461d3a08e1a559f4440ed055d
3
- size 1963810
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81fc0857f7fe11416ede431db713a02fdb787bbc049802fe74c791f3b44e5bf4
3
+ size 1920044
assets/example_3.wav DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:81fc0857f7fe11416ede431db713a02fdb787bbc049802fe74c791f3b44e5bf4
3
- size 1920044