VanYsa committed
Commit 3cad491 · Parent: e56ba3c

added meta llm

Files changed (1):
  1. app.py  (+59 -36)
app.py CHANGED
@@ -11,6 +11,14 @@ import time
 
 from nemo.collections.asr.models import ASRModel
 
+from transformers import GemmaTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from threading import Thread
+
+# Set an environment variable
+HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+
 SAMPLE_RATE = 16000 # Hz
 MAX_AUDIO_SECONDS = 40 # wont try to transcribe if longer than this
 DESCRIPTION = '''
@@ -42,12 +50,13 @@ decoding_cfg.beam.beam_size = 1
 canary_model.change_decoding_strategy(decoding_cfg)
 
 ### LLM model
-pipeline = transformers.pipeline(
-    "text-generation",
-    model="meta-llama/Meta-Llama-3-8B-Instruct",
-    model_kwargs={"torch_dtype": torch.bfloat16},
-    device=device
-)
+# Load the tokenizer and model
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+llama3_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")  # to("cuda:0")
+terminators = [
+    tokenizer.eos_token_id,
+    tokenizer.convert_tokens_to_ids("<|eot_id|>")
+]
 
 def convert_audio(audio_filepath, tmpdir, utt_id):
     """
@@ -133,36 +142,50 @@ def bot(history,message):
         time.sleep(0.05)
         yield history
 
-def bot_response(message):
-    """
-    Generates a response from the LLM model.
-    max_new_tokens, temperature and top_p are set to 256, 0.6 and 0.9 respectively.
-    """
-    messages = [
-        {"role": "system", "content": "You are a helpful AI assistant."},
-        {"role": "user", "content": "What is an apple"},
-    ]
-
-    prompt = pipeline.tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-
-    terminators = [
-        pipeline.tokenizer.eos_token_id,
-        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
-    ]
-
-    outputs = pipeline(
-        prompt,
-        max_new_tokens=256,
-        eos_token_id=terminators,
-        do_sample=True,
-        temperature=0.6,
-        top_p=0.9,
-    )
-    return outputs[0]["generated_text"][len(prompt):]
+def bot_response(message: str,
+                 history: list,
+                 temperature: float,
+                 max_new_tokens: int
+                 ) -> str:  # type: ignore
+    """
+    Generate a streaming response using the llama3-8b model.
+    Args:
+        message (str): The input message.
+        history (list): The conversation history used by ChatInterface.
+        temperature (float): The temperature for generating the response.
+        max_new_tokens (int): The maximum number of new tokens to generate.
+    Returns:
+        str: The generated response.
+    """
+    conversation = []
+    for user, assistant in history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(llama3_model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        temperature=temperature,
+        eos_token_id=terminators,
+    )
+    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+    if temperature == 0:
+        generate_kwargs['do_sample'] = False
+
+    t = Thread(target=llama3_model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        #print(outputs)
+        yield "".join(outputs)
 
 with gr.Blocks(
     title="MyAlexa",
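
Note on the gated checkpoint: the new code reads HF_TOKEN from the environment but never passes it on to from_pretrained; recent huggingface_hub versions typically pick the HF_TOKEN variable up implicitly, but passing it explicitly makes the dependency on the gated Meta-Llama-3-8B-Instruct repository visible. A minimal sketch under that assumption, also carrying over the bfloat16 dtype from the removed pipeline call (the explicit token kwarg and torch_dtype choice are assumptions about how one might adapt this commit, not part of it):

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Pass the token explicitly so the gated repo resolves even without an implicit login,
# and load bfloat16 weights to roughly halve memory versus the default float32.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
llama3_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)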
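
The commit adds the streaming bot_response generator but the diff does not show its call site inside the Gradio UI. A minimal sketch of how such a generator is typically wired into a gr.Blocks chat layout; the chatbot/msg component names, the respond helper, and the fixed temperature and max_new_tokens values are illustrative assumptions, not part of this commit:

import gradio as gr

with gr.Blocks(title="MyAlexa") as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")

    def respond(message, chat_history):
        # Append an empty assistant turn, then stream bot_response's
        # progressively longer outputs into it.
        chat_history = chat_history + [[message, ""]]
        for partial in bot_response(message, chat_history[:-1],
                                    temperature=0.6, max_new_tokens=256):
            chat_history[-1][1] = partial
            yield "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

demo.launch()

Because respond is a generator, Gradio pushes each yielded (textbox, chatbot) pair to the browser as bot_response produces new text, which is what makes the reply appear token by token.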