sagar007 committed
Commit 1f7ba92 · verified · 1 Parent(s): b77f66d

Update app.py

Files changed (1)
  1. app.py +36 -201
app.py CHANGED
@@ -1,209 +1,44 @@
 import gradio as gr
+from transformers import AutoTokenizer, pipeline
 import torch
-import spaces
-import torchaudio
-from whisperspeech.vq_stoks import RQBottleneckTransformer
-from encodec.utils import convert_audio
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
-from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
-from threading import Thread
-import logging
-import os
-from generate_audio import (
-    TTSProcessor,
-)
-import uuid
 
+model_name = "akjindal53244/Llama-3.1-Storm-8B"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+pipeline = pipeline(
+    "text-generation",
+    model=model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+)
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-vq_model = RQBottleneckTransformer.load_model(
-    "whisper-vq-stoks-medium-en+pl-fixed.model"
-).to(device)
-# tts = TTSProcessor('cpu')
-use_8bit = True
-llm_path = "akjindal53244/Llama-3.1-Storm-8B"
-tokenizer = AutoTokenizer.from_pretrained(llm_path)
-model_kwargs = {}
-if use_8bit:
-    model_kwargs["quantization_config"] = BitsAndBytesConfig(
-        load_in_8bit=True,
-        llm_int8_enable_fp32_cpu_offload=False,
-        llm_int8_has_fp16_weight=False,
-    )
-else:
-    model_kwargs["torch_dtype"] = torch.bfloat16
-model = AutoModelForCausalLM.from_pretrained(llm_path, **model_kwargs).to(device)
-
-@spaces.GPU
-def audio_to_sound_tokens_whisperspeech(audio_path):
-    vq_model.ensure_whisper('cuda')
-    wav, sr = torchaudio.load(audio_path)
-    if sr != 16000:
-        wav = torchaudio.functional.resample(wav, sr, 16000)
-    with torch.no_grad():
-        codes = vq_model.encode_audio(wav.to(device))
-        codes = codes[0].cpu().tolist()
-
-    result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
-    return f'<|sound_start|>{result}<|sound_end|>'
-
-@spaces.GPU
-def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
-    vq_model.ensure_whisper('cuda')
-    wav, sr = torchaudio.load(audio_path)
-    if sr != 16000:
-        wav = torchaudio.functional.resample(wav, sr, 16000)
-    with torch.no_grad():
-        codes = vq_model.encode_audio(wav.to(device))
-        codes = codes[0].cpu().tolist()
-
-    result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
-    return f'<|reserved_special_token_69|><|sound_start|>{result}<|sound_end|>'
-# print(tokenizer.encode("<|sound_0001|>", add_special_tokens=False))# return the audio tensor
-# print(tokenizer.eos_token)
-
-@spaces.GPU
-def text_to_audio_file(text):
-    # gen a random id for the audio file
-    id = str(uuid.uuid4())
-    temp_file = f"./user_audio/{id}_temp_audio.wav"
-    text = text
-    text_split = "_".join(text.lower().split(" "))
-    # remove the last character if it is a period
-    if text_split[-1] == ".":
-        text_split = text_split[:-1]
-    tts = TTSProcessor("cuda")
-    tts.convert_text_to_audio_file(text, temp_file)
-    # logging.info(f"Saving audio to {temp_file}")
-    # torchaudio.save(temp_file, audio.cpu(), sample_rate=24000)
-    print(f"Saved audio to {temp_file}")
-    return temp_file
-
-
-@spaces.GPU
-def process_input(audio_file=None):
-
-    for partial_message in process_audio(audio_file):
-        yield partial_message
-
-
-@spaces.GPU
-def process_transcribe_input(audio_file=None):
-
-    for partial_message in process_audio(audio_file, transcript=True):
-        yield partial_message
-
-class StopOnTokens(StoppingCriteria):
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        # encode </s> token
-        stop_ids = [tokenizer.eos_token_id, 128009] # Adjust this based on your model's tokenizer
-        for stop_id in stop_ids:
-            if input_ids[0][-1] == stop_id:
-                return True
-        return False
-
-@spaces.GPU
-def process_audio(audio_file, transcript=False):
-    if audio_file is None:
-        raise ValueError("No audio file provided")
-
-    logging.info(f"Audio file received: {audio_file}")
-    logging.info(f"Audio file type: {type(audio_file)}")
-
-    sound_tokens = audio_to_sound_tokens_whisperspeech_transcribe(audio_file) if transcript else audio_to_sound_tokens_whisperspeech(audio_file)
-    logging.info("Sound tokens generated successfully")
-    # logging.info(f"audio_file: {audio_file.name}")
+def generate_text(prompt, max_length, temperature):
     messages = [
-        {"role": "user", "content": sound_tokens},
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": prompt}
     ]
-
-    stop = StopOnTokens()
-    input_str = tokenizer.apply_chat_template(messages, tokenize=False)
-    input_ids = tokenizer.encode(input_str, return_tensors="pt")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=1024,
-        do_sample=False,
-        stopping_criteria=StoppingCriteriaList([stop])
-    )
-
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    partial_message = ""
-    for new_token in streamer:
-        partial_message += new_token
-        if tokenizer.eos_token in partial_message:
-            break
-        partial_message = partial_message.replace("assistant\n\n", "")
-        yield partial_message
-# def stop_generation():
-#     # This is a placeholder. Implement actual stopping logic here if needed.
-#     return "Generation stopped.", gr.Button.update(interactive=False)
-# take all the examples from the examples folder
-good_examples = []
-for file in os.listdir("./examples"):
-    if file.endswith(".wav"):
-        good_examples.append([f"./examples/{file}"])
-bad_examples = []
-for file in os.listdir("./bad_examples"):
-    if file.endswith(".wav"):
-        bad_examples.append([f"./bad_examples/{file}"])
-examples = []
-examples.extend(good_examples)
-examples.extend(bad_examples)
-with gr.Blocks() as iface:
-    gr.Markdown("# Llama3.1-S: checkpoint Aug 19, 2024")
-    gr.Markdown("Enter text to convert to audio, then submit the audio to generate text or Upload Audio")
-    gr.Markdown("Powered by [Homebrew Ltd](https://homebrew.ltd/) | [Read our blog post](https://homebrew.ltd/blog/llama3-just-got-ears)")
-
-    with gr.Row():
-        input_type = gr.Radio(["text", "audio"], label="Input Type", value="audio")
-        text_input = gr.Textbox(label="Text Input", visible=False)
-        audio_input = gr.Audio(label="Audio", type="filepath", visible=True)
-        # audio_output = gr.Audio(label="Converted Audio", type="filepath", visible=False)
-
-    convert_button = gr.Button("Make synthetic audio", visible=False)
-    submit_button = gr.Button("Chat with AI using audio")
-    transcrip_button = gr.Button("Make Model transcribe the audio")
-
-    text_output = gr.Textbox(label="Generated Text")
-
-    def update_visibility(input_type):
-        return (gr.update(visible=input_type == "text"),
-                gr.update(visible=input_type == "text"))
-    def convert_and_display(text):
-        audio_file = text_to_audio_file(text)
-        return audio_file
-    def process_example(file_path):
-        return update_visibility("audio")
-    input_type.change(
-        update_visibility,
-        inputs=[input_type],
-        outputs=[text_input, convert_button]
-    )
-
-    convert_button.click(
-        convert_and_display,
-        inputs=[text_input],
-        outputs=[audio_input]
-    )
-
-    submit_button.click(
-        process_input,
-        inputs=[audio_input],
-        outputs=[text_output]
-    )
-    transcrip_button.click(
-        process_transcribe_input,
-        inputs=[audio_input],
-        outputs=[text_output]
+    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+
+    outputs = pipeline(
+        formatted_prompt,
+        max_new_tokens=max_length,
+        do_sample=True,
+        temperature=temperature,
+        top_k=100,
+        top_p=0.95,
     )
 
-    gr.Examples(examples, inputs=[audio_input])
-iface.queue()
-iface.launch(share=True)
+    return outputs[0]["generated_text"]
+
+iface = gr.Interface(
+    fn=generate_text,
+    inputs=[
+        gr.Textbox(lines=5, label="Prompt"),
+        gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length"),
+        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
+    ],
+    outputs=gr.Textbox(lines=10, label="Generated Text"),
+    title="Llama-3.1-Storm-8B Text Generation",
+    description="Enter a prompt to generate text using the Llama-3.1-Storm-8B model.",
+)
+
+iface.launch()
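
Note on the new generate_text: the transformers text-generation pipeline returns the prompt together with the completion by default, so the Gradio output box will echo the chat-formatted prompt before the model's reply. Below is a minimal sketch, not part of this commit, showing the same setup with return_full_text=False so only the completion is returned (the generate_reply name is hypothetical).

# Sketch only: same model and chat template as app.py, but the pipeline is asked
# to return just the completion instead of prompt + completion.
import torch
from transformers import AutoTokenizer, pipeline

model_name = "akjindal53244/Llama-3.1-Storm-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
generator = pipeline(
    "text-generation",
    model=model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

def generate_reply(prompt, max_length=128, temperature=0.7):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    formatted = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    outputs = generator(
        formatted,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        return_full_text=False,  # drop the echoed prompt from the returned text
    )
    return outputs[0]["generated_text"]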