sagar007 commited on
Commit
02a0e92
·
verified ·
1 Parent(s): 20d8ecd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -47
app.py CHANGED
@@ -1,50 +1,209 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, LlamaForCausalLM
4
  import spaces
5
- import subprocess
6
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
7
-
8
-
9
- # Initialize model and tokenizer
10
- model_id = 'akjindal53244/Llama-3.1-Storm-8B'
11
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
12
- model = LlamaForCausalLM.from_pretrained(
13
- model_id,
14
- torch_dtype=torch.float32,
15
- device_map="auto",
16
- low_cpu_mem_usage=True
17
- )
18
-
19
- # Function to format the prompt
20
- def format_prompt(messages):
21
- prompt = "<|begin_of_text|>"
22
- for message in messages:
23
- prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}<|eot_id|>"
24
- prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
25
- return prompt
26
-
27
- # Function to generate response
28
- @spaces.GPU(duration=300) # Increased duration due to potential slower processing
29
- def generate_response(message, history):
30
- messages = [{"role": "system", "content": "You are a helpful assistant."}]
31
- for human, assistant in history:
32
- messages.append({"role": "user", "content": human})
33
- messages.append({"role": "assistant", "content": assistant})
34
- messages.append({"role": "user", "content": message})
35
-
36
- prompt = format_prompt(messages)
37
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
38
- generated_ids = model.generate(input_ids, max_new_tokens=256, temperature=0.7, do_sample=True, eos_token_id=tokenizer.eos_token_id)
39
- response = tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
40
- return response.strip()
41
-
42
- # Create Gradio interface
43
- iface = gr.ChatInterface(
44
- generate_response,
45
- title="Llama-3.1-Storm-8B Chatbot",
46
- description="Chat with the Llama-3.1-Storm-8B model. Type your message and press Enter to send.",
47
- )
48
-
49
- # Launch the app
50
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
 
3
  import spaces
4
+ import torchaudio
5
+ from whisperspeech.vq_stoks import RQBottleneckTransformer
6
+ from encodec.utils import convert_audio
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
8
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
9
+ from threading import Thread
10
+ import logging
11
+ import os
12
+ from generate_audio import (
13
+ TTSProcessor,
14
+ )
15
+ import uuid
16
+
17
+
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ vq_model = RQBottleneckTransformer.load_model(
20
+ "whisper-vq-stoks-medium-en+pl-fixed.model"
21
+ ).to(device)
22
+ # tts = TTSProcessor('cpu')
23
+ use_8bit = False
24
+ llm_path = "akjindal53244/Llama-3.1-Storm-8B"
25
+ tokenizer = AutoTokenizer.from_pretrained(llm_path)
26
+ model_kwargs = {}
27
+ if use_8bit:
28
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(
29
+ load_in_8bit=True,
30
+ llm_int8_enable_fp32_cpu_offload=False,
31
+ llm_int8_has_fp16_weight=False,
32
+ )
33
+ else:
34
+ model_kwargs["torch_dtype"] = torch.bfloat16
35
+ model = AutoModelForCausalLM.from_pretrained(llm_path, **model_kwargs).to(device)
36
+
37
+ @spaces.GPU
38
+ def audio_to_sound_tokens_whisperspeech(audio_path):
39
+ vq_model.ensure_whisper('cuda')
40
+ wav, sr = torchaudio.load(audio_path)
41
+ if sr != 16000:
42
+ wav = torchaudio.functional.resample(wav, sr, 16000)
43
+ with torch.no_grad():
44
+ codes = vq_model.encode_audio(wav.to(device))
45
+ codes = codes[0].cpu().tolist()
46
+
47
+ result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
48
+ return f'<|sound_start|>{result}<|sound_end|>'
49
+
50
+ @spaces.GPU
51
+ def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
52
+ vq_model.ensure_whisper('cuda')
53
+ wav, sr = torchaudio.load(audio_path)
54
+ if sr != 16000:
55
+ wav = torchaudio.functional.resample(wav, sr, 16000)
56
+ with torch.no_grad():
57
+ codes = vq_model.encode_audio(wav.to(device))
58
+ codes = codes[0].cpu().tolist()
59
+
60
+ result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
61
+ return f'<|reserved_special_token_69|><|sound_start|>{result}<|sound_end|>'
62
+ # print(tokenizer.encode("<|sound_0001|>", add_special_tokens=False))# return the audio tensor
63
+ # print(tokenizer.eos_token)
64
+
65
+ @spaces.GPU
66
+ def text_to_audio_file(text):
67
+ # gen a random id for the audio file
68
+ id = str(uuid.uuid4())
69
+ temp_file = f"./user_audio/{id}_temp_audio.wav"
70
+ text = text
71
+ text_split = "_".join(text.lower().split(" "))
72
+ # remove the last character if it is a period
73
+ if text_split[-1] == ".":
74
+ text_split = text_split[:-1]
75
+ tts = TTSProcessor("cuda")
76
+ tts.convert_text_to_audio_file(text, temp_file)
77
+ # logging.info(f"Saving audio to {temp_file}")
78
+ # torchaudio.save(temp_file, audio.cpu(), sample_rate=24000)
79
+ print(f"Saved audio to {temp_file}")
80
+ return temp_file
81
+
82
+
83
+ @spaces.GPU
84
+ def process_input(audio_file=None):
85
+
86
+ for partial_message in process_audio(audio_file):
87
+ yield partial_message
88
+
89
+
90
+ @spaces.GPU
91
+ def process_transcribe_input(audio_file=None):
92
+
93
+ for partial_message in process_audio(audio_file, transcript=True):
94
+ yield partial_message
95
+
96
+ class StopOnTokens(StoppingCriteria):
97
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
98
+ # encode </s> token
99
+ stop_ids = [tokenizer.eos_token_id, 128009] # Adjust this based on your model's tokenizer
100
+ for stop_id in stop_ids:
101
+ if input_ids[0][-1] == stop_id:
102
+ return True
103
+ return False
104
+
105
+ @spaces.GPU
106
+ def process_audio(audio_file, transcript=False):
107
+ if audio_file is None:
108
+ raise ValueError("No audio file provided")
109
+
110
+ logging.info(f"Audio file received: {audio_file}")
111
+ logging.info(f"Audio file type: {type(audio_file)}")
112
+
113
+ sound_tokens = audio_to_sound_tokens_whisperspeech_transcribe(audio_file) if transcript else audio_to_sound_tokens_whisperspeech(audio_file)
114
+ logging.info("Sound tokens generated successfully")
115
+ # logging.info(f"audio_file: {audio_file.name}")
116
+ messages = [
117
+ {"role": "user", "content": sound_tokens},
118
+ ]
119
+
120
+ stop = StopOnTokens()
121
+ input_str = tokenizer.apply_chat_template(messages, tokenize=False)
122
+ input_ids = tokenizer.encode(input_str, return_tensors="pt")
123
+ input_ids = input_ids.to(model.device)
124
+
125
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
126
+ generation_kwargs = dict(
127
+ input_ids=input_ids,
128
+ streamer=streamer,
129
+ max_new_tokens=1024,
130
+ do_sample=False,
131
+ stopping_criteria=StoppingCriteriaList([stop])
132
+ )
133
+
134
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
135
+ thread.start()
136
+
137
+ partial_message = ""
138
+ for new_token in streamer:
139
+ partial_message += new_token
140
+ if tokenizer.eos_token in partial_message:
141
+ break
142
+ partial_message = partial_message.replace("assistant\n\n", "")
143
+ yield partial_message
144
+ # def stop_generation():
145
+ # # This is a placeholder. Implement actual stopping logic here if needed.
146
+ # return "Generation stopped.", gr.Button.update(interactive=False)
147
+ # take all the examples from the examples folder
148
+ good_examples = []
149
+ for file in os.listdir("./examples"):
150
+ if file.endswith(".wav"):
151
+ good_examples.append([f"./examples/{file}"])
152
+ bad_examples = []
153
+ for file in os.listdir("./bad_examples"):
154
+ if file.endswith(".wav"):
155
+ bad_examples.append([f"./bad_examples/{file}"])
156
+ examples = []
157
+ examples.extend(good_examples)
158
+ examples.extend(bad_examples)
159
+ with gr.Blocks() as iface:
160
+ gr.Markdown("# Llama3.1-S: checkpoint Aug 19, 2024")
161
+ gr.Markdown("Enter text to convert to audio, then submit the audio to generate text or Upload Audio")
162
+ gr.Markdown("Powered by [Homebrew Ltd](https://homebrew.ltd/) | [Read our blog post](https://homebrew.ltd/blog/llama3-just-got-ears)")
163
+
164
+ with gr.Row():
165
+ input_type = gr.Radio(["text", "audio"], label="Input Type", value="audio")
166
+ text_input = gr.Textbox(label="Text Input", visible=False)
167
+ audio_input = gr.Audio(label="Audio", type="filepath", visible=True)
168
+ # audio_output = gr.Audio(label="Converted Audio", type="filepath", visible=False)
169
+
170
+ convert_button = gr.Button("Make synthetic audio", visible=False)
171
+ submit_button = gr.Button("Chat with AI using audio")
172
+ transcrip_button = gr.Button("Make Model transcribe the audio")
173
+
174
+ text_output = gr.Textbox(label="Generated Text")
175
+
176
+ def update_visibility(input_type):
177
+ return (gr.update(visible=input_type == "text"),
178
+ gr.update(visible=input_type == "text"))
179
+ def convert_and_display(text):
180
+ audio_file = text_to_audio_file(text)
181
+ return audio_file
182
+ def process_example(file_path):
183
+ return update_visibility("audio")
184
+ input_type.change(
185
+ update_visibility,
186
+ inputs=[input_type],
187
+ outputs=[text_input, convert_button]
188
+ )
189
+
190
+ convert_button.click(
191
+ convert_and_display,
192
+ inputs=[text_input],
193
+ outputs=[audio_input]
194
+ )
195
+
196
+ submit_button.click(
197
+ process_input,
198
+ inputs=[audio_input],
199
+ outputs=[text_output]
200
+ )
201
+ transcrip_button.click(
202
+ process_transcribe_input,
203
+ inputs=[audio_input],
204
+ outputs=[text_output]
205
+ )
206
+
207
+ gr.Examples(examples, inputs=[audio_input])
208
+ iface.queue()
209
+ iface.launch(