betheredge committed
Commit 275a08a · 1 Parent(s): bce600b

Upgrading with description and summarization.

Files changed (4)
  1. spaces_info.py +0 -10
  2. src/app.py +69 -19
  3. src/requirements.in +3 -1
  4. src/requirements.txt +107 -6
spaces_info.py CHANGED
@@ -1,10 +0,0 @@
-description = """Gradio Demo for exploring Speech Transcription.
-
-Upload an audio file or record yourself to see a transcription.
-The transcription passes through 3 models: transcription, punctuation, and capitalization.
-All output is given
-
-Tips:
-- Large files will take a while to process.
-- Live recordings is on the second tab.
-"""
src/app.py CHANGED
@@ -8,11 +8,27 @@ puntuation_model = PunctuationModel()
 # capitalization_model = ("KES/caribe-capitalise")
 # text = "My name is Clara and I live in Berkeley California Ist das eine Frage Frau Müller"
 # print(result)
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
+
+description = """
+# Gradio Demo for exploring Speech Transcription.
+
+Upload an audio file or record yourself to see a transcription.
+The transcription passes through 4 models: transcription, punctuation, capitalization, and summarization.
+All output is given
+
+Tips:
+- Large files will take a while to process.
+- Live recording is on the second tab.
+"""
+
 
 capitalise_tokenizer = AutoTokenizer.from_pretrained("KES/caribe-capitalise")
 capitalise_model = AutoModelForSeq2SeqLM.from_pretrained("KES/caribe-capitalise")
+spell_tokenizer = AutoTokenizer.from_pretrained("murali1996/bert-base-cased-spell-correction")
+spell_model = AutoModel.from_pretrained("murali1996/bert-base-cased-spell-correction")
 
+summarizer = pipeline("summarization")
 
 pipe = pipeline(
     model="facebook/wav2vec2-large-960h",
@@ -25,26 +41,53 @@ def translate(audio_file):
     text = x['text']
     return text
 
+
 def punctuation(text):
     punctuation = puntuation_model.restore_punctuation(text)
     return punctuation
 
+
 def capitalise(text):
     text = text.lower()
     inputs = capitalise_tokenizer("text:"+text, truncation=True, return_tensors='pt')
     # print(capitalization)
-    output = capitalise_model.generate(inputs['input_ids'], num_beams=4, max_length=512, early_stopping=True)
+    output = capitalise_model.generate(inputs['input_ids'], num_beams=4, max_length=4096, early_stopping=True)
+    # output = capitalise_model.generate(inputs['input_ids'], num_beams=4, max_length=1024, early_stopping=True)
     capitalised_text = capitalise_tokenizer.batch_decode(output, skip_special_tokens=True)
 
     result = ("".join(capitalised_text))
+    return result
+
+
+def spell_check(text):
+    text = text.lower()
+    inputs = spell_tokenizer(text, return_tensors='pt')
+    # print(capitalization)
+    output = spell_model.generate(inputs)
+    spell_text = spell_tokenizer.batch_decode(output, skip_special_tokens=True)
+
+    result = ("".join(spell_text))
 
     return result
 
+
+def summarize(text):
+    results = None
+    length = len(text)
+    while not results:
+        try:
+            results = summarizer(text[:length], min_length=10, max_length=128)
+        except IndexError:
+            print(f"shortening text: {length} -> {length//2}")
+            length = length // 2
+    return results[0]['summary_text']
+
 def all(file):
     trans_text = translate(file).lower()
     punct_text = punctuation(trans_text)
     cap_text = capitalise(punct_text)
-    return trans_text, punct_text, cap_text
+    sum_text = summarize(punct_text)
+    return trans_text, punct_text, cap_text, sum_text
 
 input = gr.Audio(type="filepath")
 live_in = gr.Audio(type="filepath", source="microphone")
@@ -54,21 +97,22 @@ live_in = gr.Audio(type="filepath", source="microphone")
 raw_output = gr.Text(label="Raw Output")
 puncuation_output = gr.Text(label="Punctuation Output")
 capitalization_output = gr.Text(label="Capitalization Output")
+sum_output = gr.Text(label="Summarized Output")
 
-translater = gr.Interface(
-    fn=translate,
-    inputs=input,
-    outputs=[raw_output])
+# translater = gr.Interface(
+#     fn=translate,
+#     inputs=input,
+#     outputs=raw_output)
 
-punctuation = gr.Interface(
-    fn=punctuation,
-    inputs=raw_output,
-    outputs=[puncuation_output])
+# punctuation = gr.Interface(
+#     fn=punctuation,
+#     inputs=raw_output,
+#     outputs=puncuation_output)
 
-capitalization = gr.Interface(
-    fn=capitalise,
-    inputs=puncuation_output,
-    outputs=[capitalization_output])
+# capitalization = gr.Interface(
+#     fn=capitalise,
+#     inputs=puncuation_output,
+#     outputs=capitalization_output]
 
 
 
@@ -76,11 +120,17 @@ capitalization = gr.Interface(
 live_demo = gr.Interface(
     fn=all,
     inputs=live_in,
-    outputs=[raw_output, puncuation_output, capitalization_output])
+    outputs=[raw_output, puncuation_output, capitalization_output, sum_output],
+    description=description)
 demo = gr.Interface(
     fn=all,
     inputs=input,
-    outputs=[raw_output, puncuation_output, capitalization_output])
+    outputs=[raw_output, puncuation_output, capitalization_output, sum_output],
+    description=description)
 
-# demo.launch(share=True)
-gr.TabbedInterface([demo, live_demo], tab_names=["Upload File", "Record Self"]).launch()
+# interface = gr.Series(
+#     gr.Textbox(value=description, show_label=False, interactive=False),
+#     gr.TabbedInterface([demo, live_demo], tab_names=["Upload File", "Record Self"])
+# )
+interface = gr.TabbedInterface([demo, live_demo], tab_names=["Upload File", "Record Self"])
+interface.launch()
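A quick usage sketch (not part of the commit) of the four-stage flow that src/app.py now wires together: transcription, punctuation, capitalization, and the new summarization step. It assumes the module-level interface.launch() call is skipped (for example, commented out) so that importing app does not block, and "sample.wav" is a placeholder audio path.

# Minimal sketch of exercising the pipeline from src/app.py directly,
# mirroring the call order inside all(); assumes app.py is importable
# and its interface.launch() line is not executed on import.
import app

raw = app.translate("sample.wav").lower()   # wav2vec2 speech-to-text
punctuated = app.punctuation(raw)           # restore punctuation
capitalised = app.capitalise(punctuated)    # seq2seq capitalization
summary = app.summarize(punctuated)         # new summarization stage

print(raw, punctuated, capitalised, summary, sep="\n")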
src/requirements.in CHANGED
@@ -13,4 +13,6 @@ numpy
 pandas
 icecream
 jupyter
-gradio
+gradio
+deepmultilingualpunctuation
+sentencepiece
src/requirements.txt CHANGED
@@ -9,9 +9,17 @@
 absl-py==1.2.0
     # via tensorboard
 aiohttp==3.8.1
-    # via fsspec
+    # via
+    #   fsspec
+    #   gradio
 aiosignal==1.2.0
     # via aiohttp
+analytics-python==1.4.0
+    # via gradio
+anyio==3.6.1
+    # via
+    #   httpcore
+    #   starlette
 argon2-cffi==21.3.0
     # via notebook
 argon2-cffi-bindings==21.2.0
@@ -28,6 +36,10 @@ attrs==21.4.0
     #   jsonschema
 backcall==0.2.0
     # via ipython
+backoff==1.10.0
+    # via analytics-python
+bcrypt==3.2.2
+    # via paramiko
 beautifulsoup4==4.11.1
     # via nbconvert
 bleach==5.0.1
@@ -37,23 +49,36 @@ build==0.8.0
 cachetools==5.2.0
     # via google-auth
 certifi==2022.6.15
-    # via requests
+    # via
+    #   httpcore
+    #   httpx
+    #   requests
 cffi==1.15.1
-    # via argon2-cffi-bindings
+    # via
+    #   argon2-cffi-bindings
+    #   bcrypt
+    #   cryptography
+    #   pynacl
 charset-normalizer==2.1.0
     # via
     #   aiohttp
     #   requests
 click==8.1.3
-    # via pip-tools
+    # via
+    #   pip-tools
+    #   uvicorn
 colorama==0.4.5
     # via icecream
+cryptography==37.0.4
+    # via paramiko
 cycler==0.11.0
     # via matplotlib
 debugpy==1.6.2
     # via ipykernel
 decorator==5.1.1
     # via ipython
+deepmultilingualpunctuation==1.0.1
+    # via -r requirements.in
 defusedxml==0.7.1
     # via nbconvert
 entrypoints==0.4
@@ -64,8 +89,12 @@ executing==0.8.3
     # via
     #   icecream
     #   stack-data
+fastapi==0.79.0
+    # via gradio
 fastjsonschema==2.16.1
     # via nbformat
+ffmpy==0.3.0
+    # via gradio
 filelock==3.7.1
     # via
     #   huggingface-hub
@@ -77,22 +106,37 @@ frozenlist==1.3.0
     #   aiohttp
     #   aiosignal
 fsspec[http]==2022.5.0
-    # via pytorch-lightning
+    # via
+    #   gradio
+    #   pytorch-lightning
 google-auth==2.9.1
     # via
     #   google-auth-oauthlib
     #   tensorboard
 google-auth-oauthlib==0.4.6
     # via tensorboard
+gradio==3.1.1
+    # via -r requirements.in
 grpcio==1.47.0
     # via tensorboard
+h11==0.12.0
+    # via
+    #   gradio
+    #   httpcore
+    #   uvicorn
+httpcore==0.15.0
+    # via httpx
+httpx==0.23.0
+    # via gradio
 huggingface-hub==0.8.1
     # via transformers
 icecream==2.1.3
     # via -r requirements.in
 idna==3.3
     # via
+    #   anyio
     #   requests
+    #   rfc3986
     #   yarl
 importlib-metadata==4.12.0
     # via markdown
@@ -121,6 +165,7 @@ jedi==0.18.1
     # via ipython
 jinja2==3.1.2
     # via
+    #   gradio
     #   nbconvert
     #   notebook
 jsonschema==4.7.2
@@ -149,8 +194,14 @@ jupyterlab-widgets==1.1.1
     # via ipywidgets
 kiwisolver==1.4.4
     # via matplotlib
+linkify-it-py==1.0.3
+    # via markdown-it-py
 markdown==3.4.1
     # via tensorboard
+markdown-it-py[linkify,plugins]==2.1.0
+    # via
+    #   gradio
+    #   mdit-py-plugins
 markupsafe==2.1.1
     # via
     #   jinja2
@@ -158,13 +209,20 @@ markupsafe==2.1.1
 matplotlib==3.5.2
     # via
     #   -r requirements.in
+    #   gradio
     #   seaborn
 matplotlib-inline==0.1.3
     # via
     #   ipykernel
     #   ipython
+mdit-py-plugins==0.3.0
+    # via markdown-it-py
+mdurl==0.1.1
+    # via markdown-it-py
 mistune==0.8.4
     # via nbconvert
+monotonic==1.6
+    # via analytics-python
 multidict==6.0.2
     # via
     #   aiohttp
@@ -193,6 +251,7 @@ notebook==6.4.12
 numpy==1.23.1
     # via
     #   -r requirements.in
+    #   gradio
     #   matplotlib
     #   pandas
     #   pytorch-lightning
@@ -204,6 +263,8 @@ numpy==1.23.1
     #   transformers
 oauthlib==3.2.0
     # via requests-oauthlib
+orjson==3.7.8
+    # via gradio
 packaging==21.3
     # via
     #   build
@@ -218,9 +279,12 @@ packaging==21.3
 pandas==1.4.3
     # via
     #   -r requirements.in
+    #   gradio
     #   seaborn
 pandocfilters==1.5.0
     # via nbconvert
+paramiko==2.11.0
+    # via gradio
 parso==0.8.3
     # via jedi
 pep517==0.12.0
@@ -231,6 +295,7 @@ pickleshare==0.7.5
     # via ipython
 pillow==9.2.0
     # via
+    #   gradio
     #   matplotlib
     #   torchvision
 pip-tools==6.8.0
@@ -261,8 +326,16 @@ pyasn1-modules==0.2.8
     # via google-auth
 pycparser==2.21
     # via cffi
+pycryptodome==3.15.0
+    # via gradio
+pydantic==1.9.1
+    # via
+    #   fastapi
+    #   gradio
 pydeprecate==0.3.2
     # via pytorch-lightning
+pydub==0.25.1
+    # via gradio
 pygments==2.12.0
     # via
     #   icecream
@@ -270,6 +343,8 @@ pygments==2.12.0
     #   jupyter-console
     #   nbconvert
     #   qtconsole
+pynacl==1.5.0
+    # via paramiko
 pyparsing==3.0.9
     # via
     #   matplotlib
@@ -278,9 +353,12 @@ pyrsistent==0.18.1
     # via jsonschema
 python-dateutil==2.8.2
     # via
+    #   analytics-python
     #   jupyter-client
     #   matplotlib
     #   pandas
+python-multipart==0.0.5
+    # via gradio
 pytorch-lightning==1.6.5
     # via -r requirements.in
 pytz==2022.1
@@ -304,7 +382,9 @@ regex==2022.7.9
     # via transformers
 requests==2.28.1
     # via
+    #   analytics-python
     #   fsspec
+    #   gradio
     #   huggingface-hub
     #   requests-oauthlib
     #   tensorboard
@@ -312,6 +392,8 @@ requests==2.28.1
     #   transformers
 requests-oauthlib==1.3.1
     # via google-auth-oauthlib
+rfc3986[idna2008]==1.5.0
+    # via httpx
 rsa==4.9
     # via google-auth
 scipy==1.8.1
@@ -322,15 +404,25 @@ send2trash==1.8.0
     # via notebook
 six==1.16.0
     # via
+    #   analytics-python
     #   asttokens
     #   bleach
     #   google-auth
     #   grpcio
+    #   paramiko
     #   python-dateutil
+    #   python-multipart
+sniffio==1.2.0
+    # via
+    #   anyio
+    #   httpcore
+    #   httpx
 soupsieve==2.3.2.post1
     # via beautifulsoup4
 stack-data==0.3.0
     # via ipython
+starlette==0.19.1
+    # via fastapi
 tensorboard==2.9.1
     # via pytorch-lightning
 tensorboard-data-server==0.6.1
@@ -350,6 +442,7 @@ tomli==2.0.1
 torch==1.12.0+cu116
     # via
     #   -r requirements.in
+    #   deepmultilingualpunctuation
     #   pytorch-lightning
     #   torchaudio
     #   torchmetrics
@@ -385,15 +478,23 @@ traitlets==5.3.0
     #   notebook
     #   qtconsole
 transformers==4.20.1
-    # via -r requirements.in
+    # via
+    #   -r requirements.in
+    #   deepmultilingualpunctuation
 typing-extensions==4.3.0
     # via
     #   huggingface-hub
+    #   pydantic
     #   pytorch-lightning
+    #   starlette
     #   torch
     #   torchvision
+uc-micro-py==1.0.1
+    # via linkify-it-py
 urllib3==1.26.10
     # via requests
+uvicorn==0.18.2
+    # via gradio
 wcwidth==0.2.5
     # via prompt-toolkit
 webencodings==0.5.1