YingxuHe committed on
Commit aea1886 · 1 Parent(s): a8aa9d2
pages/agent.py ADDED
@@ -0,0 +1,3 @@
+from src.content.agent import agent_page
+
+agent_page()
requirements.txt CHANGED
@@ -2,4 +2,8 @@ librosa==0.10.2.post1
 streamlit==1.40.2
 openai==1.57.1
 streamlit_mic_recorder==0.0.8
-sshtunnel
+sshtunnel
+accelerate==1.3.0
+FlagEmbedding==1.3.3
+sentence-transformers==3.4.0
+sentencepiece==0.1.99
src/content/agent.py ADDED
@@ -0,0 +1,240 @@
+import copy
+import base64
+
+import streamlit as st
+
+from src.generation import MAX_AUDIO_LENGTH
+from src.retrieval import retrieve_relevant_docs
+from src.utils import bytes_to_array, array_to_bytes
+from src.content.common import (
+    MODEL_NAMES,
+    AUDIO_SAMPLES_W_INSTRUCT,
+    STANDARD_QUERIES,
+    DEFAULT_DIALOGUE_STATES,
+    init_state_section,
+    header_section,
+    sidebar_fragment,
+    retrive_response_with_ui
+)
+
+
+LLM_PROMPT_TEMPLATE = """User asked a question about the audio clip.
+
+## User question
+{user_question}
+
+{audio_information_prompt}Please reply to this user question with a friendly, accurate, and helpful answer."""
+
+AUDIO_INFO_TEMPLATE = """Here is some information about this audio clip.
+
+## Audio Information
+{audio_information}
+
+This may or may not contain relevant information to the user question, please use with caution.
+
+"""
+
+
+def _update_audio(audio_bytes):
+    origin_audio_array = bytes_to_array(audio_bytes)
+    truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
+    truncated_audio_bytes = array_to_bytes(truncated_audio_array)
+
+    st.session_state.ag_audio_array = origin_audio_array
+    st.session_state.ag_audio_base64 = base64.b64encode(truncated_audio_bytes).decode('utf-8')
+
+
+@st.fragment
+def successful_example_section():
+    audio_sample_names = [audio_sample_name for audio_sample_name in AUDIO_SAMPLES_W_INSTRUCT.keys()]
+
+    st.markdown(":fire: **Successful Tasks and Examples**")
+
+    sample_name = st.selectbox(
+        label="**Select Audio:**",
+        label_visibility="collapsed",
+        options=audio_sample_names,
+        format_func=lambda o: AUDIO_SAMPLES_W_INSTRUCT[o]["apperance"],
+        index=None,
+        placeholder="Select an audio sample:",
+        on_change=lambda: st.session_state.update(
+            on_select=True,
+            ag_messages=[],
+            ag_model_messages=[],
+            disprompt=True
+        ),
+        key='select')
+
+    if sample_name and st.session_state.on_select:
+        audio_bytes = open(f"audio_samples/{sample_name}.wav", "rb").read()
+        st.session_state.update(
+            on_select=False,
+            new_prompt=AUDIO_SAMPLES_W_INSTRUCT[sample_name]["instructions"][0]
+        )
+        _update_audio(audio_bytes)
+        st.rerun(scope="app")
+
+
+@st.dialog("Specify Audio")
+def audio_attach_dialogue():
+    st.markdown("**Upload**")
+
+    uploaded_file = st.file_uploader(
+        label="**Upload Audio:**",
+        label_visibility="collapsed",
+        type=['wav', 'mp3'],
+        on_change=lambda: st.session_state.update(on_upload=True, ag_messages=[], ag_model_messages=[]),
+        key='upload'
+    )
+
+    if uploaded_file and st.session_state.on_upload:
+        audio_bytes = uploaded_file.read()
+        _update_audio(audio_bytes)
+        st.session_state.on_upload = False
+        st.rerun()
+
+    st.markdown("**Record**")
+
+    uploaded_file = st.audio_input(
+        label="**Record Audio:**",
+        label_visibility="collapsed",
+        on_change=lambda: st.session_state.update(on_record=True, ag_messages=[], ag_model_messages=[]),
+        key='record'
+    )
+
+    if uploaded_file and st.session_state.on_record:
+        audio_bytes = uploaded_file.read()
+        _update_audio(audio_bytes)
+        st.session_state.on_record = False
+        st.rerun()
+
+
+def bottom_input_section():
+    bottom_cols = st.columns([0.03, 0.03, 0.94])
+    with bottom_cols[0]:
+        st.button(
+            'Clear',
+            disabled=st.session_state.disprompt,
+            on_click=lambda: st.session_state.update(copy.deepcopy(DEFAULT_DIALOGUE_STATES))
+        )
+
+    with bottom_cols[1]:
+        if st.button("\+ Audio", disabled=st.session_state.disprompt):
+            audio_attach_dialogue()
+
+    with bottom_cols[2]:
+        if chat_input := st.chat_input(
+            placeholder="Instruction...",
+            disabled=st.session_state.disprompt,
+            on_submit=lambda: st.session_state.update(disprompt=True)
+        ):
+            st.session_state.new_prompt = chat_input
+
+
+def conversation_section():
+    chat_message_container = st.container(height=480)
+    if st.session_state.ag_audio_array.size:
+        with chat_message_container.chat_message("user"):
+            st.audio(st.session_state.ag_audio_array, format="audio/wav", sample_rate=16000)
+
+    for message in st.session_state.ag_messages:
+        message_name = "assistant" if "assistant" in message["role"] else message["role"]
+
+        with chat_message_container.chat_message(name=message_name):
+            if message.get("error"):
+                st.error(message["error"])
+            for warning_msg in message.get("warnings", []):
+                st.warning(warning_msg)
+            if process := message.get("process", []):
+                with st.status("Thought process...", expanded=True, state="complete"):
+                    for proc in process:
+                        if proc.get("error"):
+                            st.error(proc["error"])
+                        for proc_warning_msg in proc.get("warnings", []):
+                            st.warning(proc_warning_msg)
+                        if proc.get("content"):
+                            st.write(proc["content"])
+            if message.get("content"):
+                st.write(message["content"])
+
+    with st._bottom:
+        bottom_input_section()
+
+    if one_time_prompt := st.session_state.new_prompt:
+        st.session_state.update(new_prompt="")
+
+        with chat_message_container.chat_message("user"):
+            st.write(one_time_prompt)
+        st.session_state.ag_messages.append({"role": "user", "content": one_time_prompt})
+
+        with chat_message_container.chat_message("assistant"):
+            assistant_message = {"role": "assistant", "process": []}
+            st.session_state.ag_messages.append(assistant_message)
+
+            relevant_query_indices = retrieve_relevant_docs(one_time_prompt, STANDARD_QUERIES)
+            if len(st.session_state.ag_messages) <= 2:
+                relevant_query_indices.append(0)
+
+            relevant_query_indices = list(set(relevant_query_indices).difference(st.session_state.ag_visited_query_indices))
+
+            audio_info = []
+            if relevant_query_indices:
+                with st.status("Thought process...", expanded=True) as status:
+                    for idx in relevant_query_indices:
+                        error_msg, warnings, response = retrive_response_with_ui(
+                            model_name=MODEL_NAMES["with_lora"]["vllm_name"],
+                            prompt=STANDARD_QUERIES[idx]["query_text"],
+                            array_audio=st.session_state.ag_audio_array,
+                            base64_audio=st.session_state.ag_audio_base64,
+                            prefix=f"**{STANDARD_QUERIES[idx]['ui_text']}** :speech_balloon: : ",
+                            stream=True
+                        )
+                        audio_info.append(STANDARD_QUERIES[idx]["response_prefix_text"] + response)
+
+                        assistant_message["process"].append({
+                            "error": error_msg,
+                            "warnings": warnings,
+                            "content": response
+                        })
+
+                    status.update(state="complete")
+
+            audio_information_prompt = ""
+            if audio_info:
+                audio_information_prompt = AUDIO_INFO_TEMPLATE.format(
+                    audio_information="\n".join(audio_info)
+                )
+
+            prompt = LLM_PROMPT_TEMPLATE.format(
+                user_question=one_time_prompt,
+                audio_information_prompt=audio_information_prompt
+            )
+
+            error_msg, warnings, response = retrive_response_with_ui(
+                model_name=MODEL_NAMES["wo_lora"]["vllm_name"],
+                prompt=prompt,
+                array_audio=st.session_state.ag_audio_array,
+                base64_audio="",
+                stream=True,
+                history=st.session_state.ag_model_messages
+            )
+
+            assistant_message.update({"error": error_msg, "warnings": warnings, "content": response})
+            st.session_state.ag_model_messages.extend([
+                {"role": "user", "content": prompt},
+                {"role": "assistant", "content": response}
+            ])
+
+        st.session_state.disprompt=False
+        st.rerun(scope="app")
+
+
+def agent_page():
+    init_state_section()
+    header_section(component_name="Agent System", icon="👥")
+
+    with st.sidebar:
+        sidebar_fragment()
+
+    successful_example_section()
+    conversation_section()
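
Illustrative note (not part of the commit): the agent flow above boils down to "ask the LoRA model a few standard audio questions, then hand the collected answers to the text-only model". Below is a minimal offline sketch of that prompt composition; ask_audio_model and build_agent_prompt are hypothetical stand-ins for retrive_response_with_ui and the Streamlit session handling, while the two templates are copied from the file above.

    # Offline sketch of the agent prompt composition (illustrative only).
    AUDIO_INFO_TEMPLATE = """Here is some information about this audio clip.

    ## Audio Information
    {audio_information}

    This may or may not contain relevant information to the user question, please use with caution.

    """

    LLM_PROMPT_TEMPLATE = """User asked a question about the audio clip.

    ## User question
    {user_question}

    {audio_information_prompt}Please reply to this user question with a friendly, accurate, and helpful answer."""


    def ask_audio_model(query_text):
        # Placeholder for the per-query call to the MERaLiON-lora model.
        return f"<model answer to: {query_text}>"


    def build_agent_prompt(user_question, relevant_queries):
        audio_info = [
            q["response_prefix_text"] + ask_audio_model(q["query_text"])
            for q in relevant_queries
        ]
        audio_information_prompt = ""
        if audio_info:
            audio_information_prompt = AUDIO_INFO_TEMPLATE.format(
                audio_information="\n".join(audio_info)
            )
        return LLM_PROMPT_TEMPLATE.format(
            user_question=user_question,
            audio_information_prompt=audio_information_prompt,
        )

For instance, passing the emotion-recognition entry of STANDARD_QUERIES as the only relevant query would produce a prompt whose Audio Information section starts with that entry's response_prefix_text.
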
src/content/common.py CHANGED
@@ -1,10 +1,13 @@
 import copy
+import itertools
+from collections import OrderedDict
 
 import numpy as np
 import streamlit as st
 
 from src.tunnel import start_server
 from src.generation import FIXED_GENERATION_CONFIG, load_model, retrive_response
+from src.retrieval import load_retriever
 from src.logger import load_logger
 
 
@@ -15,6 +18,11 @@ DEFAULT_DIALOGUE_STATES = dict(
     vc_audio_base64='',
     vc_audio_array=np.array([]),
     vc_messages=[],
+    ag_audio_base64='',
+    ag_audio_array=np.array([]),
+    ag_visited_query_indices=[],
+    ag_messages=[],
+    ag_model_messages=[],
     disprompt = False,
     new_prompt = "",
     on_select=False,
@@ -24,17 +32,16 @@ DEFAULT_DIALOGUE_STATES = dict(
 )
 
 
-DEFAULT_VOICE_CHAT_STATES = dict(
-    audio_base64='',
-    audio_array=np.array([]),
-    disprompt = False,
-    new_prompt = "",
-    messages=[],
-    on_select=False,
-    on_upload=False,
-    on_record=False,
-    on_select_quick_action=False
-)
+MODEL_NAMES = OrderedDict({
+    "with_lora": {
+        "vllm_name": "MERaLiON-lora",
+        "ui_name": "MERaLiON-AudioLLM (more accurate)"
+    },
+    "wo_lora": {
+        "vllm_name": "MERaLiON_local/MERaLiON-AudioLLM-Whisper-SEA-LION-wo-lora",
+        "ui_name": "MERaLiON-AudioLLM-instruction-following (more flexible)"
+    }
+})
 
 
 AUDIO_SAMPLES_W_INSTRUCT = {
@@ -295,6 +302,46 @@ AUDIO_SAMPLES_W_INSTRUCT = {
 }
 
 
+STANDARD_QUERIES = [
+    {
+        "query_text": "Please transcribe this speech.",
+        "doc_text": "Listen to a speech and write down exactly what is being said in text form. It's essentially converting spoken words into written words. Provide the exact transcription of the given audio. Record whatever the speaker has said into written text.",
+        "response_prefix_text": "The transcription of the speech is: ",
+        "ui_text": "speech transcription"
+    },
+    {
+        "query_text": "Please describe what happened in this audio",
+        "doc_text": "Text captions describing the sound events and environments in the audio clips, describing the events and actions happened in the audio.",
+        "response_prefix_text": "Events in this audio clip: ",
+        "ui_text": "audio caption"
+    },
+    {
+        "query_text": "May I know the gender of the speakers",
+        "doc_text": "Please identify speaker gender by analyzing pitch, formants, harmonics, and prosody features, which reflect physiological and speech pattern differences between genders.",
+        "response_prefix_text": "By analyzing pitch, formants, harmonics, and prosody features, which reflect physiological and speech pattern differences between genders: ",
+        "ui_text": "gender recognition"
+    },
+    {
+        "query_text": "May I know the nationality of the speakers",
+        "doc_text": "Discover speakers' nationality, country, or the place he is coming from. Analyze speakers' accent, pronunciation patterns, intonation, rhythm, phoneme usage, and language-specific speech features influenced by cultural and linguistic backgrounds.",
+        "response_prefix_text": "By analyzing accent, pronunciation patterns, intonation, rhythm, phoneme usage, and language-specific speech features influenced by cultural and linguistic backgrounds: ",
+        "ui_text": "accent recognition"
+    },
+    {
+        "query_text": "Can you guess which ethnic group this person is from based on their accent.",
+        "doc_text": "Discover speakers' ethnic group, home country, or the place he is coming from, from speech features like accent, tone, intonation, phoneme variations, and vocal characteristics influenced by cultural, regional, and linguistic factors.",
+        "response_prefix_text": "By analyzing speech features like accent, tone, intonation, phoneme variations, and vocal characteristics influenced by cultural, regional, and linguistic factors: ",
+        "ui_text": "accent recognition"
+    },
+    {
+        "query_text": "What do you think the speakers are feeling.",
+        "doc_text": "What do you think the speakers are feeling. Please identify speakers' emotions by analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy, which reflect emotional states such as happiness, anger, sadness, or fear.",
+        "response_prefix_text": "By analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy: ",
+        "ui_text": "emotion recognition"
+    },
+]
+
+
 def init_state_section():
     st.set_page_config(page_title='MERaLiON-AudioLLM', page_icon = "🔥", layout='wide')
 
@@ -317,7 +364,10 @@ def init_state_section():
         st.session_state.server = start_server()
 
     if "client" not in st.session_state or 'model_name' not in st.session_state:
-        st.session_state.client, _ = load_model()
+        st.session_state.client = load_model()
+
+    if "retriever" not in st.session_state:
+        st.session_state.retriever = load_retriever()
 
     for key, value in FIXED_GENERATION_CONFIG.items():
         if key not in st.session_state:
@@ -364,8 +414,8 @@ def header_section(component_name="Playground", icon="🤖"):
 def sidebar_fragment():
     with st.container(height=256, border=False):
         st.page_link("pages/playground.py", disabled=st.session_state.disprompt, label="🚀 Playground")
+        st.page_link("pages/agent.py", disabled=st.session_state.disprompt, label="👥 Multi-Agent System")
         st.page_link("pages/voice_chat.py", disabled=st.session_state.disprompt, label="🗣️ Voice Chat (experimental)")
-
 
     st.divider()
 
@@ -376,9 +426,9 @@ def sidebar_fragment():
     st.slider(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty")
 
 
-def retrive_response_with_ui(prompt, array_audio, base64_audio, stream):
+def retrive_response_with_ui(model_name, prompt, array_audio, base64_audio, prefix="", **kwargs):
     generation_params = dict(
-        model=st.session_state.model_name,
+        model=model_name,
         max_completion_tokens=st.session_state.max_completion_tokens,
         temperature=st.session_state.temperature,
         top_p=st.session_state.top_p,
@@ -390,26 +440,33 @@ def retrive_response_with_ui(prompt, array_audio, base64_audio, stream):
         seed=st.session_state.seed
     )
 
-    error_msg, warnings, stream = retrive_response(
+    error_msg, warnings, response_obj = retrive_response(
         prompt,
         array_audio,
         base64_audio,
-        params=generation_params,
-        stream=True
+        **generation_params,
+        **kwargs
    )
-    response = ""
 
     if error_msg:
        st.error(error_msg)
    for warning_msg in warnings:
        st.warning(warning_msg)
-    if stream:
-        response = st.write_stream(stream)
+
+    response = ""
+    if response_obj is not None:
+        if kwargs.get("stream", ""):
+            response_obj = itertools.chain([prefix], response_obj)
+            response = st.write_stream(response_obj)
+        else:
+            response = response_obj.choices[0].message.content
+            st.write(prefix+response)
 
    st.session_state.logger.register_query(
        session_id=st.session_state.session_id,
        base64_audio=base64_audio,
        text_input=prompt,
+        history=kwargs.get("history", []),
        params=generation_params,
        response=response,
        warnings=warnings,
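
A small aside on the itertools.chain pattern introduced in retrive_response_with_ui above: st.write_stream consumes an iterable of text chunks, so prepending a literal label to the token stream is enough to tag each intermediate answer. A tiny self-contained sketch, with a plain generator standing in for the vLLM stream:

    import itertools

    def fake_token_stream():
        # Stand-in for the streamed chat-completion chunks.
        for tok in ["The ", "speaker ", "sounds ", "cheerful."]:
            yield tok

    prefix = "**emotion recognition** :speech_balloon: : "
    labelled_stream = itertools.chain([prefix], fake_token_stream())

    print("".join(labelled_stream))
    # **emotion recognition** :speech_balloon: : The speaker sounds cheerful.
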
src/content/playground.py CHANGED
@@ -6,6 +6,7 @@ import streamlit as st
 from src.generation import MAX_AUDIO_LENGTH
 from src.utils import bytes_to_array, array_to_bytes
 from src.content.common import (
+    MODEL_NAMES,
     AUDIO_SAMPLES_W_INSTRUCT,
     DEFAULT_DIALOGUE_STATES,
     init_state_section,
@@ -104,19 +105,15 @@ def audio_attach_dialogue():
 
 @st.fragment
 def select_model_variants_fradment():
-    display_mapper = {
-        'MERaLiON-lora': "MERaLiON-AudioLLM (better transcription)",
-        'MERaLiON_local/MERaLiON-AudioLLM-Whisper-SEA-LION-wo-lora': "MERaLiON-AudioLLM-instruction-following (more flexible)"
-    }
+    display_mapper = {value["vllm_name"]: value["ui_name"] for value in MODEL_NAMES.values()}
 
     st.selectbox(
         label=":fire: Explore more MERaLiON-AudioLLM variants!",
-        # label_visibility="collapsed",
-        options=['MERaLiON-lora', 'MERaLiON_local/MERaLiON-AudioLLM-Whisper-SEA-LION-wo-lora'],
+        options=[value["vllm_name"] for value in MODEL_NAMES.values()],
         index=0,
         format_func=lambda o: display_mapper[o],
-        key="model_name",
-        placeholder=":fire: Explore more model variants!",
+        key="pg_model_name",
+        placeholder=":fire: Explore more MERaLiON-AudioLLM variants!",
         disabled=st.session_state.disprompt,
     )
 
@@ -196,9 +193,10 @@
         with st.chat_message("assistant"):
             with st.spinner("Thinking..."):
                 error_msg, warnings, response = retrive_response_with_ui(
-                    one_time_prompt,
-                    st.session_state.pg_audio_array,
-                    st.session_state.pg_audio_base64,
+                    model_name=st.session_state.pg_model_name,
+                    prompt=one_time_prompt,
+                    array_audio=st.session_state.pg_audio_array,
+                    base64_audio=st.session_state.pg_audio_base64,
                     stream=True
                 )
 
src/content/voice_chat.py CHANGED
@@ -7,6 +7,7 @@ import streamlit as st
 from src.generation import MAX_AUDIO_LENGTH
 from src.utils import bytes_to_array, array_to_bytes
 from src.content.common import (
+    MODEL_NAMES,
     DEFAULT_DIALOGUE_STATES,
     init_state_section,
     header_section,
@@ -122,9 +123,10 @@
         with st.chat_message("assistant"):
             with st.spinner("Thinking..."):
                 error_msg, warnings, response = retrive_response_with_ui(
-                    one_time_prompt,
-                    one_time_array,
-                    one_time_base64,
+                    model_name=MODEL_NAMES["wo_lora"]["vllm_name"],
+                    prompt=one_time_prompt,
+                    array_audio=one_time_array,
+                    base64_audio=one_time_base64,
                     stream=True
                 )
 
@@ -141,7 +143,6 @@
 
 def voice_chat_page():
     init_state_section()
-    st.session_state.model_name = 'MERaLiON_local/MERaLiON-AudioLLM-Whisper-SEA-LION-wo-lora'
     header_section(component_name="Voice Chat")
 
     with st.sidebar:
src/generation.py CHANGED
@@ -6,7 +6,7 @@ from typing import List
 import streamlit as st
 from openai import OpenAI, APIConnectionError
 
-from src.exceptions import NoAudioException, TunnelNotRunningException
+from src.exceptions import TunnelNotRunningException
 
 
 local_port = int(os.getenv('LOCAL_PORT'))
@@ -33,48 +33,38 @@ def load_model():
         api_key=openai_api_key,
         base_url=openai_api_base,
     )
-
-    models = client.models.list()
-    model_name = models.data[0].id
 
-    return client, model_name
+    return client
 
 def _retrive_response(text_input: str, base64_audio_input: str, **kwargs):
     """
     Send request through OpenAI client.
     """
-    return st.session_state.client.chat.completions.create(
-        messages=[{
-            "role":
-            "user",
-            "content": [
-                {
-                    "type": "text",
-                    "text": f"Text instruction: {text_input}"
+    history = kwargs.pop("history", [])
+    if base64_audio_input:
+        content = [
+            {
+                "type": "text",
+                "text": f"Text instruction: {text_input}"
+            },
+            {
+                "type": "audio_url",
+                "audio_url": {
+                    "url": f"data:audio/ogg;base64,{base64_audio_input}"
                 },
-                {
-                    "type": "audio_url",
-                    "audio_url": {
-                        "url": f"data:audio/ogg;base64,{base64_audio_input}"
-                    },
-                },
-            ],
-        }],
+            },
+        ]
+    else:
+        content = text_input
+    return st.session_state.client.chat.completions.create(
+        messages=history + [{"role": "user", "content": content}],
         **kwargs
     )
 
 
-def _retry_retrive_response_throws_exception(text_input, base64_audio_input, params, stream=False, retry=3):
-    if not base64_audio_input:
-        raise NoAudioException("audio is empty.")
-
+def _retry_retrive_response_throws_exception(retry=3, **kwargs):
     try:
-        response_object = _retrive_response(
-            text_input=text_input,
-            base64_audio_input=base64_audio_input,
-            stream=stream,
-            **params
-        )
+        response_object = _retrive_response(**kwargs)
     except APIConnectionError as e:
         if not st.session_state.server.is_running():
             if retry == 0:
@@ -87,7 +77,7 @@ def _retry_retrive_response_throws_exception(text_input, base64_audio_input, par
         elif st.session_state.server.is_starting():
             time.sleep(2)
 
-            return _retry_retrive_response_throws_exception(text_input, retry-1)
+            return _retry_retrive_response_throws_exception(retry-1, **kwargs)
         raise e
 
     return response_object
@@ -104,25 +94,37 @@ def _validate_input(text_input, array_audio_input) -> List[str]:
     if re.search(r'[\u4e00-\u9fff]+', text_input):
         warnings.append("NOTE: Please try to prompt in English for the best performance.")
 
+    if array_audio_input.shape[0] == 0:
+        warnings.append("NOTE: Please specify audio from examples or local files.")
+
     if array_audio_input.shape[0] / 16000 > 30.0:
         warnings.append((
-            "MERaLiON-AudioLLM is trained to process audio up to **30 seconds**."
+            "WARNING: MERaLiON-AudioLLM is trained to process audio up to **30 seconds**."
             f" Audio longer than **{MAX_AUDIO_LENGTH} seconds** will be truncated."
         ))
 
     return warnings
 
 
-def retrive_response(text_input, array_audio_input, base64_audio_input, params, stream=False):
+def retrive_response(
+    text_input,
+    array_audio_input,
+    base64_audio_input,
+    stream=True,
+    history=[],
+    **kwargs
+):
     warnings = _validate_input(text_input, array_audio_input)
 
     response_object, error_msg = None, ""
     try:
         response_object = _retry_retrive_response_throws_exception(
-            text_input, base64_audio_input, params, stream
+            text_input=text_input,
+            base64_audio_input=base64_audio_input,
+            stream=stream,
+            history=history,
+            **kwargs
         )
-    except NoAudioException:
-        error_msg = "Please specify audio first!"
     except TunnelNotRunningException:
        error_msg = "Internet connection cannot be established. Please contact the administrator."
    except Exception as e:
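
For orientation, this is roughly the request payload the reworked _retrive_response sends through the OpenAI-compatible client: with audio it uses a text part plus an audio_url data URI, otherwise a plain string, and prior turns are prepended via history. The sketch below is illustrative only; build_user_content is a hypothetical helper, not part of the commit.

    import base64

    def build_user_content(text_input, audio_bytes=None):
        # Mirrors the branching added to _retrive_response above: with audio we
        # send a text part plus an audio_url data URI, without audio plain text.
        if audio_bytes:
            base64_audio = base64.b64encode(audio_bytes).decode("utf-8")
            return [
                {"type": "text", "text": f"Text instruction: {text_input}"},
                {
                    "type": "audio_url",
                    "audio_url": {"url": f"data:audio/ogg;base64,{base64_audio}"},
                },
            ]
        return text_input

    # A text-only follow-up keeps the earlier turns via `history`.
    history = [
        {"role": "user", "content": "Please transcribe this speech."},
        {"role": "assistant", "content": "The transcription of the speech is: ..."},
    ]
    messages = history + [{"role": "user", "content": build_user_content("Summarise it.")}]
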
src/logger.py CHANGED
@@ -49,25 +49,22 @@ class Logger:
         session_id,
         base64_audio,
         text_input,
-        params,
         response,
-        warnings,
-        error_msg
+        **kwargs
     ):
         new_query_id = self.query_increment
         current_time = get_current_strftime()
 
         with logger_lock:
-            self.query_data.append({
+            current_query_data = {
                 "session_id": session_id,
                 "query_id": new_query_id,
                 "creation_time": current_time,
                 "text": text_input,
-                "params": params,
                 "response": response,
-                "warnings": warnings,
-                "error": error_msg,
-            })
+            }
+            current_query_data.update(kwargs)
+            self.query_data.append(current_query_data)
 
             self.audio_data.append({
                 "session_id": session_id,
@@ -98,13 +95,13 @@ class Logger:
             row_str = json.dumps(row, ensure_ascii=False)+"\n"
             buffer.write(row_str.encode("utf-8"))
 
-        api.upload_file(
-            path_or_fileobj=buffer,
-            path_in_repo=f"{data_name}/{get_current_strftime()}.json",
-            repo_id=os.getenv("LOGGING_REPO_NAME"),
-            repo_type="dataset",
-            token=os.getenv('HF_TOKEN')
-        )
+        # api.upload_file(
+        #     path_or_fileobj=buffer,
+        #     path_in_repo=f"{data_name}/{get_current_strftime()}.json",
+        #     repo_id=os.getenv("LOGGING_REPO_NAME"),
+        #     repo_type="dataset",
+        #     token=os.getenv('HF_TOKEN')
+        # )
 
         buffer.close()
 
src/retrieval.py ADDED
@@ -0,0 +1,20 @@
+from typing import Dict, List
+
+import numpy as np
+import streamlit as st
+from FlagEmbedding import FlagReranker
+
+
+@st.cache_resource()
+def load_retriever():
+    reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
+    reranker.compute_score([["test", "test"]], normalize=True)
+    return reranker
+
+
+def retrieve_relevant_docs(user_question, docs: List[Dict]) -> List[int]:
+    scores = st.session_state.retriever.compute_score([[user_question, d["doc_text"]] for d in docs], normalize=True)
+    normalized_scores = np.array(scores) / np.sum(scores)
+
+    selected_indices = np.where((np.array(scores) > 0.02) & (normalized_scores > 0.3))[0]
+    return selected_indices.tolist()
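
A minimal, Streamlit-free sketch of how this retriever gates the standard queries, assuming FlagEmbedding is installed and the BAAI/bge-reranker-v2-m3 weights can be downloaded; the document strings are shortened versions of the STANDARD_QUERIES entries in src/content/common.py.

    import numpy as np
    from FlagEmbedding import FlagReranker

    # Shortened stand-ins for two STANDARD_QUERIES doc_texts.
    docs = [
        "Listen to a speech and write down exactly what is being said in text form.",
        "Please identify speaker gender by analyzing pitch, formants, harmonics, and prosody features.",
    ]

    reranker = FlagReranker("BAAI/bge-reranker-v2-m3", use_fp16=True)  # downloads the model on first use
    question = "Who is speaking, a man or a woman?"

    scores = reranker.compute_score([[question, d] for d in docs], normalize=True)
    normalized = np.array(scores) / np.sum(scores)

    # Same thresholds as retrieve_relevant_docs: an absolute floor plus a relative share.
    selected = np.where((np.array(scores) > 0.02) & (normalized > 0.3))[0].tolist()
    print(selected)
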
style/small_window.css CHANGED
@@ -15,4 +15,10 @@
 div[data-testid="stSidebarCollapsedControl"] button[data-testid="stBaseButton-headerNoPadding"]::after {
     content: "More Use Cases"
 }
+}
+
+@media (max-width: 916px) and (max-height: 958px) {
+    div[height="480"][data-testid="stVerticalBlockBorderWrapper"] {
+        height: 380px;
+    }
 }