QuietImpostor committed
Commit cd51b0f · verified · Parent(s): cc0fe39

Update app.py

Files changed (1):
  1. app.py +30 -40
app.py CHANGED
@@ -1,6 +1,8 @@
+import os
+os.environ['NUMPY_EXPERIMENTAL_ARRAY_FUNCTION'] = '0'
+
 import gradio as gr
 import torch
-import spaces
 import torchaudio
 from whisperspeech.vq_stoks import RQBottleneckTransformer
 from encodec.utils import convert_audio
@@ -8,13 +10,10 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 from threading import Thread
 import logging
-import os
-from generate_audio import (
-    TTSProcessor,
-)
+from generate_audio import TTSProcessor
 import uuid
 
-device = "cpu" # Change this to always use CPU
+device = "cpu"
 vq_model = RQBottleneckTransformer.load_model(
     "whisper-vq-stoks-medium-en+pl-fixed.model"
 ).to(device)
@@ -30,12 +29,11 @@ if use_8bit:
         llm_int8_has_fp16_weight=False,
     )
 else:
-    model_kwargs["torch_dtype"] = torch.float32 # Change this to use float32 on CPU
+    model_kwargs["torch_dtype"] = torch.float32
 model = AutoModelForCausalLM.from_pretrained(llm_path, **model_kwargs).to(device)
 
-@spaces.CPU # Change this to use CPU
 def audio_to_sound_tokens_whisperspeech(audio_path):
-    vq_model.ensure_whisper(device) # Change this to use the defined device
+    vq_model.ensure_whisper(device)
     wav, sr = torchaudio.load(audio_path)
     if sr != 16000:
         wav = torchaudio.functional.resample(wav, sr, 16000)
@@ -46,9 +44,8 @@ def audio_to_sound_tokens_whisperspeech(audio_path):
     result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
     return f'<|sound_start|>{result}<|sound_end|>'
 
-@spaces.CPU # Change this to use CPU
 def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
-    vq_model.ensure_whisper(device) # Change this to use the defined device
+    vq_model.ensure_whisper(device)
     wav, sr = torchaudio.load(audio_path)
     if sr != 16000:
         wav = torchaudio.functional.resample(wav, sr, 16000)
@@ -59,53 +56,50 @@ def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
     result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
     return f'<|reserved_special_token_69|><|sound_start|>{result}<|sound_end|>'
 
-@spaces.CPU # Change this to use CPU
 def text_to_audio_file(text):
     id = str(uuid.uuid4())
     temp_file = f"./user_audio/{id}_temp_audio.wav"
-    text = text
     text_split = "_".join(text.lower().split(" "))
     if text_split[-1] == ".":
         text_split = text_split[:-1]
-    tts = TTSProcessor(device) # Change this to use the defined device
+    tts = TTSProcessor(device)
     tts.convert_text_to_audio_file(text, temp_file)
     print(f"Saved audio to {temp_file}")
     return temp_file
 
+def run_on_cpu(func):
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+    return wrapper
 
-@spaces.CPU
+@run_on_cpu
 def process_input(audio_file=None):
-
     for partial_message in process_audio(audio_file):
         yield partial_message
-
-
-@spaces.CPU
+
+@run_on_cpu
 def process_transcribe_input(audio_file=None):
-
     for partial_message in process_audio(audio_file, transcript=True):
         yield partial_message
-
+
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        # encode </s> token
-        stop_ids = [tokenizer.eos_token_id, 128009] # Adjust this based on your model's tokenizer
+        stop_ids = [tokenizer.eos_token_id, 128009]
         for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
-
-@spaces.CPU
+
 def process_audio(audio_file, transcript=False):
     if audio_file is None:
-        raise ValueError("No audio file provided")
+        raise ValueError("No audio file provided")
 
     logging.info(f"Audio file received: {audio_file}")
     logging.info(f"Audio file type: {type(audio_file)}")
 
-    sound_tokens = audio_to_sound_tokens_whisperspeech_transcribe(audio_file) if transcript else audio_to_sound_tokens_whisperspeech(audio_file)
+    sound_tokens = audio_to_sound_tokens_whisperspeech_transcribe(audio_file) if transcript else audio_to_sound_tokens_whisperspeech(audio_file)
     logging.info("Sound tokens generated successfully")
-    # logging.info(f"audio_file: {audio_file.name}")
+
     messages = [
         {"role": "user", "content": sound_tokens},
     ]
@@ -115,7 +109,7 @@ def process_audio(audio_file, transcript=False):
     input_ids = tokenizer.encode(input_str, return_tensors="pt")
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(
         input_ids=input_ids,
         streamer=streamer,
@@ -134,10 +128,7 @@ def process_audio(audio_file, transcript=False):
            break
        partial_message = partial_message.replace("assistant\n\n", "")
        yield partial_message
-# def stop_generation():
-#     # This is a placeholder. Implement actual stopping logic here if needed.
-#     return "Generation stopped.", gr.Button.update(interactive=False)
-# take all the examples from the examples folder
+
 good_examples = []
 for file in os.listdir("./examples"):
     if file.endswith(".wav"):
@@ -149,6 +140,7 @@ for file in os.listdir("./bad_examples"):
 examples = []
 examples.extend(good_examples)
 examples.extend(bad_examples)
+
 with gr.Blocks() as iface:
     gr.Markdown("# Llama3.1-S: checkpoint Aug 19, 2024")
     gr.Markdown("Enter text to convert to audio, then submit the audio to generate text or Upload Audio")
@@ -158,8 +150,7 @@ with gr.Blocks() as iface:
     input_type = gr.Radio(["text", "audio"], label="Input Type", value="audio")
     text_input = gr.Textbox(label="Text Input", visible=False)
     audio_input = gr.Audio(label="Audio", type="filepath", visible=True)
-    # audio_output = gr.Audio(label="Converted Audio", type="filepath", visible=False)
-
+
    convert_button = gr.Button("Make synthetic audio", visible=False)
    submit_button = gr.Button("Chat with AI using audio")
    transcrip_button = gr.Button("Make Model transcribe the audio")
@@ -169,11 +160,11 @@
     def update_visibility(input_type):
         return (gr.update(visible=input_type == "text"),
                 gr.update(visible=input_type == "text"))
+
     def convert_and_display(text):
         audio_file = text_to_audio_file(text)
         return audio_file
-    def process_example(file_path):
-        return update_visibility("audio")
+
     input_type.change(
         update_visibility,
         inputs=[input_type],
@@ -198,7 +189,6 @@ with gr.Blocks() as iface:
     )
 
     gr.Examples(examples, inputs=[audio_input])
+
 iface.queue()
-iface.launch()
-# launch locally
-# iface.launch(server_name="0.0.0.0")
+iface.launch()