add agent
- pages/agent.py +3 -0
- requirements.txt +5 -1
- src/content/agent.py +240 -0
- src/content/common.py +78 -21
- src/content/playground.py +9 -11
- src/content/voice_chat.py +5 -4
- src/generation.py +39 -37
- src/logger.py +12 -15
- src/retrieval.py +20 -0
- style/small_window.css +6 -0
pages/agent.py
ADDED
@@ -0,0 +1,3 @@
+from src.content.agent import agent_page
+
+agent_page()
requirements.txt
CHANGED
@@ -2,4 +2,8 @@ librosa==0.10.2.post1
 streamlit==1.40.2
 openai==1.57.1
 streamlit_mic_recorder==0.0.8
-sshtunnel
+sshtunnel
+accelerate==1.3.0
+FlagEmbedding==1.3.3
+sentence-transformers==3.4.0
+sentencepiece==0.1.99
src/content/agent.py
ADDED
@@ -0,0 +1,240 @@
+import copy
+import base64
+
+import streamlit as st
+
+from src.generation import MAX_AUDIO_LENGTH
+from src.retrieval import retrieve_relevant_docs
+from src.utils import bytes_to_array, array_to_bytes
+from src.content.common import (
+    MODEL_NAMES,
+    AUDIO_SAMPLES_W_INSTRUCT,
+    STANDARD_QUERIES,
+    DEFAULT_DIALOGUE_STATES,
+    init_state_section,
+    header_section,
+    sidebar_fragment,
+    retrive_response_with_ui
+)
+
+
+LLM_PROMPT_TEMPLATE = """User asked a question about the audio clip.
+
+## User question
+{user_question}
+
+{audio_information_prompt}Please reply this user question with an friendly, accurate, and helpful answer."""
+
+AUDIO_INFO_TEMPLATE = """Here are some information about this audio clip.
+
+## Audio Information
+{audio_information}
+
+This may or may not contain relevant information to the user question, please use with caution.
+
+"""
+
+
+def _update_audio(audio_bytes):
+    origin_audio_array = bytes_to_array(audio_bytes)
+    truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
+    truncated_audio_bytes = array_to_bytes(truncated_audio_array)
+
+    st.session_state.ag_audio_array = origin_audio_array
+    st.session_state.ag_audio_base64 = base64.b64encode(truncated_audio_bytes).decode('utf-8')
+
+
+@st.fragment
+def successful_example_section():
+    audio_sample_names = [audio_sample_name for audio_sample_name in AUDIO_SAMPLES_W_INSTRUCT.keys()]
+
+    st.markdown(":fire: **Successful Tasks and Examples**")
+
+    sample_name = st.selectbox(
+        label="**Select Audio:**",
+        label_visibility="collapsed",
+        options=audio_sample_names,
+        format_func=lambda o: AUDIO_SAMPLES_W_INSTRUCT[o]["apperance"],
+        index=None,
+        placeholder="Select an audio sample:",
+        on_change=lambda: st.session_state.update(
+            on_select=True,
+            ag_messages=[],
+            ag_model_messages=[],
+            disprompt=True
+        ),
+        key='select')
+
+    if sample_name and st.session_state.on_select:
+        audio_bytes = open(f"audio_samples/{sample_name}.wav", "rb").read()
+        st.session_state.update(
+            on_select=False,
+            new_prompt=AUDIO_SAMPLES_W_INSTRUCT[sample_name]["instructions"][0]
+        )
+        _update_audio(audio_bytes)
+        st.rerun(scope="app")
+
+
+@st.dialog("Specify Audio")
+def audio_attach_dialogue():
+    st.markdown("**Upload**")
+
+    uploaded_file = st.file_uploader(
+        label="**Upload Audio:**",
+        label_visibility="collapsed",
+        type=['wav', 'mp3'],
+        on_change=lambda: st.session_state.update(on_upload=True, ag_messages=[], ag_model_messages=[]),
+        key='upload'
+    )
+
+    if uploaded_file and st.session_state.on_upload:
+        audio_bytes = uploaded_file.read()
+        _update_audio(audio_bytes)
+        st.session_state.on_upload = False
+        st.rerun()
+
+    st.markdown("**Record**")
+
+    uploaded_file = st.audio_input(
+        label="**Record Audio:**",
+        label_visibility="collapsed",
+        on_change=lambda: st.session_state.update(on_record=True, ag_messages=[], ag_model_messages=[]),
+        key='record'
+    )
+
+    if uploaded_file and st.session_state.on_record:
+        audio_bytes = uploaded_file.read()
+        _update_audio(audio_bytes)
+        st.session_state.on_record = False
+        st.rerun()
+
+
+def bottom_input_section():
+    bottom_cols = st.columns([0.03, 0.03, 0.94])
+    with bottom_cols[0]:
+        st.button(
+            'Clear',
+            disabled=st.session_state.disprompt,
+            on_click=lambda: st.session_state.update(copy.deepcopy(DEFAULT_DIALOGUE_STATES))
+        )
+
+    with bottom_cols[1]:
+        if st.button("\+ Audio", disabled=st.session_state.disprompt):
+            audio_attach_dialogue()
+
+    with bottom_cols[2]:
+        if chat_input := st.chat_input(
+            placeholder="Instruction...",
+            disabled=st.session_state.disprompt,
+            on_submit=lambda: st.session_state.update(disprompt=True)
+        ):
+            st.session_state.new_prompt = chat_input
+
+
+def conversation_section():
+    chat_message_container = st.container(height=480)
+    if st.session_state.ag_audio_array.size:
+        with chat_message_container.chat_message("user"):
+            st.audio(st.session_state.ag_audio_array, format="audio/wav", sample_rate=16000)
+
+    for message in st.session_state.ag_messages:
+        message_name = "assistant" if "assistant" in message["role"] else message["role"]
+
+        with chat_message_container.chat_message(name=message_name):
+            if message.get("error"):
+                st.error(message["error"])
+            for warning_msg in message.get("warnings", []):
+                st.warning(warning_msg)
+            if process := message.get("process", []):
+                with st.status("Thought process...", expanded=True, state="complete"):
+                    for proc in process:
+                        if proc.get("error"):
+                            st.error(proc["error"])
+                        for proc_warning_msg in proc.get("warnings", []):
+                            st.warning(proc_warning_msg)
+                        if proc.get("content"):
+                            st.write(proc["content"])
+            if message.get("content"):
+                st.write(message["content"])
+
+    with st._bottom:
+        bottom_input_section()
+
+    if one_time_prompt := st.session_state.new_prompt:
+        st.session_state.update(new_prompt="")
+
+        with chat_message_container.chat_message("user"):
+            st.write(one_time_prompt)
+        st.session_state.ag_messages.append({"role": "user", "content": one_time_prompt})
+
+        with chat_message_container.chat_message("assistant"):
+            assistant_message = {"role": "assistant", "process": []}
+            st.session_state.ag_messages.append(assistant_message)
+
+            relevant_query_indices = retrieve_relevant_docs(one_time_prompt, STANDARD_QUERIES)
+            if len(st.session_state.ag_messages) <= 2:
+                relevant_query_indices.append(0)
+
+            relevant_query_indices = list(set(relevant_query_indices).difference(st.session_state.ag_visited_query_indices))
+
+            audio_info = []
+            if relevant_query_indices:
+                with st.status("Thought process...", expanded=True) as status:
+                    for idx in relevant_query_indices:
+                        error_msg, warnings, response = retrive_response_with_ui(
+                            model_name=MODEL_NAMES["with_lora"]["vllm_name"],
+                            prompt=STANDARD_QUERIES[idx]["query_text"],
+                            array_audio=st.session_state.ag_audio_array,
+                            base64_audio=st.session_state.ag_audio_base64,
+                            prefix=f"**{STANDARD_QUERIES[idx]['ui_text']}** :speech_balloon: : ",
+                            stream=True
+                        )
+                        audio_info.append(STANDARD_QUERIES[idx]["response_prefix_text"] + response)
+
+                        assistant_message["process"].append({
+                            "error": error_msg,
+                            "warnings": warnings,
+                            "content": response
+                        })
+
+                    status.update(state="complete")
+
+            audio_information_prompt = ""
+            if audio_info:
+                audio_information_prompt = AUDIO_INFO_TEMPLATE.format(
+                    audio_information="\n".join(audio_info)
+                )
+
+            prompt = LLM_PROMPT_TEMPLATE.format(
+                user_question=one_time_prompt,
+                audio_information_prompt=audio_information_prompt
+            )
+
+            error_msg, warnings, response = retrive_response_with_ui(
+                model_name=MODEL_NAMES["wo_lora"]["vllm_name"],
+                prompt=prompt,
+                array_audio=st.session_state.ag_audio_array,
+                base64_audio="",
+                stream=True,
+                history=st.session_state.ag_model_messages
+            )
+
+            assistant_message.update({"error": error_msg, "warnings": warnings, "content": response})
+            st.session_state.ag_model_messages.extend([
+                {"role": "user", "content": prompt},
+                {"role": "assistant", "content": response}
+            ])
+
+        st.session_state.disprompt=False
+        st.rerun(scope="app")
+
+
+def agent_page():
+    init_state_section()
+    header_section(component_name="Agent System", icon="👥")
+
+    with st.sidebar:
+        sidebar_fragment()
+
+    successful_example_section()
+    conversation_section()
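For context, the new agent page runs a two-stage cascade: answers to the matched STANDARD_QUERIES are folded into LLM_PROMPT_TEMPLATE before the instruction-following model is called. A minimal sketch of that prompt assembly, using a hypothetical transcription in place of real model output:

# Sketch only: in the app, audio_info is filled by retrive_response_with_ui()
# calls against the selected audio clip; the string below is hypothetical.
from src.content.agent import AUDIO_INFO_TEMPLATE, LLM_PROMPT_TEMPLATE

audio_info = [
    "The transcription of the speech is: good morning, welcome to the briefing.",
]

audio_information_prompt = AUDIO_INFO_TEMPLATE.format(
    audio_information="\n".join(audio_info)
)
prompt = LLM_PROMPT_TEMPLATE.format(
    user_question="What is this clip about?",
    audio_information_prompt=audio_information_prompt,
)
print(prompt)  # the text sent to the "wo_lora" model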
src/content/common.py
CHANGED
@@ -1,10 +1,13 @@
 import copy
+import itertools
+from collections import OrderedDict
 
 import numpy as np
 import streamlit as st
 
 from src.tunnel import start_server
 from src.generation import FIXED_GENERATION_CONFIG, load_model, retrive_response
+from src.retrieval import load_retriever
 from src.logger import load_logger
 
 
@@ -15,6 +18,11 @@ DEFAULT_DIALOGUE_STATES = dict(
     vc_audio_base64='',
     vc_audio_array=np.array([]),
     vc_messages=[],
+    ag_audio_base64='',
+    ag_audio_array=np.array([]),
+    ag_visited_query_indices=[],
+    ag_messages=[],
+    ag_model_messages=[],
     disprompt = False,
     new_prompt = "",
     on_select=False,
@@ -24,17 +32,16 @@ DEFAULT_DIALOGUE_STATES = dict(
 )
 
 
+MODEL_NAMES = OrderedDict({
+    "with_lora": {
+        "vllm_name": "MERaLiON-lora",
+        "ui_name": "MERaLiON-AudioLLM (more accurate)"
+    },
+    "wo_lora": {
+        "vllm_name": "MERaLiON_local/MERaLiON-AudioLLM-Whisper-SEA-LION-wo-lora",
+        "ui_name": "MERaLiON-AudioLLM-instruction-following (more flexible)"
+    }
+})
 
 
 AUDIO_SAMPLES_W_INSTRUCT = {
@@ -295,6 +302,46 @@
 }
 
 
+STANDARD_QUERIES = [
+    {
+        "query_text": "Please transcribe this speech.",
+        "doc_text": "Listen to a speech and write down exactly what is being said in text form. It's essentially converting spoken words into written words. Provide the exact transcription of the given audio. Record whatever the speaker has said into written text.",
+        "response_prefix_text": "The transcription of the speech is: ",
+        "ui_text": "speech trancription"
+    },
+    {
+        "query_text": "Please describe what happended in this audio",
+        "doc_text": "Text captions describing the sound events and environments in the audio clips, describing the events and actions happened in the audio.",
+        "response_prefix_text": "Events in this audio clip: ",
+        "ui_text": "audio caption"
+    },
+    {
+        "query_text": "May I know the gender of the speakers",
+        "doc_text": "Please identify speaker gender by analyzing pitch, formants, harmonics, and prosody features, which reflect physiological and speech pattern differences between genders.",
+        "response_prefix_text": "By analyzing pitch, formants, harmonics, and prosody features, which reflect physiological and speech pattern differences between genders: ",
+        "ui_text": "gender recognition"
+    },
+    {
+        "query_text": "May I know the nationality of the speakers",
+        "doc_text": "Discover speakers' nationality, country, or the place he is coming from. Analyze speakers' accent, pronunciation patterns, intonation, rhythm, phoneme usage, and language-specific speech features influenced by cultural and linguistic backgrounds.",
+        "response_prefix_text": "By analyzing accent, pronunciation patterns, intonation, rhythm, phoneme usage, and language-specific speech features influenced by cultural and linguistic backgrounds: ",
+        "ui_text": "accent recognition"
+    },
+    {
+        "query_text": "Can you guess which ethnic group this person is from based on their accent.",
+        "doc_text": "Discover speakers' ethnic group, home country, or the place he is coming from, from speech features like accent, tone, intonation, phoneme variations, and vocal characteristics influenced by cultural, regional, and linguistic factors.",
+        "response_prefix_text": "By analyzing speech features like accent, tone, intonation, phoneme variations, and vocal characteristics influenced by cultural, regional, and linguistic factors: ",
+        "ui_text": "accent recognition"
+    },
+    {
+        "query_text": "What do you think the speakers are feeling.",
+        "doc_text": "What do you think the speakers are feeling. Please identify speakers' emotions by analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy, which reflect emotional states such as happiness, anger, sadness, or fear.",
+        "response_prefix_text": "By analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy: ",
+        "ui_text": "emotion recognition"
+    },
+]
+
+
 def init_state_section():
     st.set_page_config(page_title='MERaLiON-AudioLLM', page_icon = "🔥", layout='wide')
 
@@ -317,7 +364,10 @@ def init_state_section():
         st.session_state.server = start_server()
 
     if "client" not in st.session_state or 'model_name' not in st.session_state:
-        st.session_state.client
+        st.session_state.client = load_model()
+
+    if "retriever" not in st.session_state:
+        st.session_state.retriever = load_retriever()
 
     for key, value in FIXED_GENERATION_CONFIG.items():
         if key not in st.session_state:
@@ -364,8 +414,8 @@ def header_section(component_name="Playground", icon="🤖"):
 def sidebar_fragment():
     with st.container(height=256, border=False):
         st.page_link("pages/playground.py", disabled=st.session_state.disprompt, label="🚀 Playground")
+        st.page_link("pages/agent.py", disabled=st.session_state.disprompt, label="👥 Multi-Agent System")
         st.page_link("pages/voice_chat.py", disabled=st.session_state.disprompt, label="🗣️ Voice Chat (experimental)")
 
     st.divider()
 
@@ -376,9 +426,9 @@
     st.slider(label="Repetition Penalty", min_value=1.0, max_value=1.2, value=1.1, key="repetition_penalty")
 
 
-def retrive_response_with_ui(prompt, array_audio, base64_audio, stream):
+def retrive_response_with_ui(model_name, prompt, array_audio, base64_audio, prefix="", **kwargs):
     generation_params = dict(
-        model=
+        model=model_name,
         max_completion_tokens=st.session_state.max_completion_tokens,
         temperature=st.session_state.temperature,
         top_p=st.session_state.top_p,
@@ -390,26 +440,33 @@ def retrive_response_with_ui(prompt, array_audio, base64_audio, stream):
         seed=st.session_state.seed
     )
 
-    error_msg, warnings,
+    error_msg, warnings, response_obj = retrive_response(
         prompt,
         array_audio,
         base64_audio,
+        **generation_params,
+        **kwargs
     )
-    response = ""
 
     if error_msg:
        st.error(error_msg)
     for warning_msg in warnings:
        st.warning(warning_msg)
+
+    response = ""
+    if response_obj is not None:
+        if kwargs.get("stream", ""):
+            response_obj = itertools.chain([prefix], response_obj)
+            response = st.write_stream(response_obj)
+        else:
+            response = response_obj.choices[0].message.content
+            st.write(prefix+response)
 
     st.session_state.logger.register_query(
         session_id=st.session_state.session_id,
         base64_audio=base64_audio,
         text_input=prompt,
+        history=kwargs.get("history", []),
         params=generation_params,
         response=response,
         warnings=warnings,
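A small illustration of the streaming behaviour added to retrive_response_with_ui: when stream=True, the UI prefix is chained in front of the token generator before st.write_stream consumes it. This sketch uses a dummy generator rather than the OpenAI-compatible stream used in the app:

import itertools

def dummy_stream():
    # Stand-in for the streamed chunks returned by the client.
    for chunk in ["The ", "speaker ", "sounds ", "cheerful."]:
        yield chunk

prefix = "**emotion recognition** :speech_balloon: : "
# Same idea as in retrive_response_with_ui: the prefix is rendered first,
# then the model tokens follow.
stream_with_prefix = itertools.chain([prefix], dummy_stream())
print("".join(stream_with_prefix))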
src/content/playground.py
CHANGED
@@ -6,6 +6,7 @@ import streamlit as st
 from src.generation import MAX_AUDIO_LENGTH
 from src.utils import bytes_to_array, array_to_bytes
 from src.content.common import (
+    MODEL_NAMES,
     AUDIO_SAMPLES_W_INSTRUCT,
     DEFAULT_DIALOGUE_STATES,
     init_state_section,
@@ -104,19 +105,15 @@ def audio_attach_dialogue():
 
 @st.fragment
 def select_model_variants_fradment():
-    display_mapper = {
-        'MERaLiON-lora': "MERaLiON-AudioLLM (better transcription)",
-        'MERaLiON_local/MERaLiON-AudioLLM-Whisper-SEA-LION-wo-lora': "MERaLiON-AudioLLM-instruction-following (more flexible)"
-    }
+    display_mapper = {value["vllm_name"]: value["ui_name"] for value in MODEL_NAMES.values()}
 
     st.selectbox(
         label=":fire: Explore more MERaLiON-AudioLLM variants!",
-        options=['MERaLiON-lora', 'MERaLiON_local/MERaLiON-AudioLLM-Whisper-SEA-LION-wo-lora'],
+        options=[value["vllm_name"] for value in MODEL_NAMES.values()],
         index=0,
         format_func=lambda o: display_mapper[o],
-        key="
-        placeholder=":fire: Explore more
+        key="pg_model_name",
+        placeholder=":fire: Explore more MERaLiON-AudioLLM variants!",
         disabled=st.session_state.disprompt,
     )
 
@@ -196,9 +193,10 @@
         with st.chat_message("assistant"):
             with st.spinner("Thinking..."):
                 error_msg, warnings, response = retrive_response_with_ui(
-                    st.session_state.
+                    model_name=st.session_state.pg_model_name,
+                    prompt=one_time_prompt,
+                    array_audio=st.session_state.pg_audio_array,
+                    base64_audio=st.session_state.pg_audio_base64,
                     stream=True
                 )
 
src/content/voice_chat.py
CHANGED
@@ -7,6 +7,7 @@ import streamlit as st
 from src.generation import MAX_AUDIO_LENGTH
 from src.utils import bytes_to_array, array_to_bytes
 from src.content.common import (
+    MODEL_NAMES,
     DEFAULT_DIALOGUE_STATES,
     init_state_section,
     header_section,
@@ -122,9 +123,10 @@
         with st.chat_message("assistant"):
             with st.spinner("Thinking..."):
                 error_msg, warnings, response = retrive_response_with_ui(
+                    model_name=MODEL_NAMES["wo_lora"]["vllm_name"],
+                    prompt=one_time_prompt,
+                    array_audio=one_time_array,
+                    base64_audio=one_time_base64,
                     stream=True
                 )
 
@@ -141,7 +143,6 @@
 
 def voice_chat_page():
     init_state_section()
-    st.session_state.model_name = 'MERaLiON_local/MERaLiON-AudioLLM-Whisper-SEA-LION-wo-lora'
     header_section(component_name="Voice Chat")
 
     with st.sidebar:
src/generation.py
CHANGED
@@ -6,7 +6,7 @@ from typing import List
 import streamlit as st
 from openai import OpenAI, APIConnectionError
 
-from src.exceptions import
+from src.exceptions import TunnelNotRunningException
 
 
 local_port = int(os.getenv('LOCAL_PORT'))
@@ -33,48 +33,38 @@ def load_model():
         api_key=openai_api_key,
         base_url=openai_api_base,
     )
-
-    models = client.models.list()
-    model_name = models.data[0].id
-
-    return client
 
+    return client
 
 def _retrive_response(text_input: str, base64_audio_input: str, **kwargs):
     """
     Send request through OpenAI client.
     """
+    history = kwargs.pop("history", [])
+    if base64_audio_input:
+        content = [
+            {
+                "type": "text",
+                "text": f"Text instruction: {text_input}"
+            },
+            {
+                "type": "audio_url",
+                "audio_url": {
+                    "url": f"data:audio/ogg;base64,{base64_audio_input}"
+                },
+            },
+        ]
+    else:
+        content = text_input
+    return st.session_state.client.chat.completions.create(
+        messages=history + [{"role": "user", "content": content}],
         **kwargs
     )
 
 
-def _retry_retrive_response_throws_exception(
-    if not base64_audio_input:
-        raise NoAudioException("audio is empty.")
-
+def _retry_retrive_response_throws_exception(retry=3, **kwargs):
     try:
-        response_object = _retrive_response(
-            text_input=text_input,
-            base64_audio_input=base64_audio_input,
-            stream=stream,
-            **params
-        )
+        response_object = _retrive_response(**kwargs)
     except APIConnectionError as e:
         if not st.session_state.server.is_running():
             if retry == 0:
@@ -87,7 +77,7 @@ def _retry_retrive_response_throws_exception(text_input, base64_audio_input, par
         elif st.session_state.server.is_starting():
             time.sleep(2)
 
-            return _retry_retrive_response_throws_exception(
+            return _retry_retrive_response_throws_exception(retry-1, **kwargs)
         raise e
 
     return response_object
@@ -104,25 +94,37 @@ def _validate_input(text_input, array_audio_input) -> List[str]:
     if re.search(r'[\u4e00-\u9fff]+', text_input):
         warnings.append("NOTE: Please try to prompt in English for the best performance.")
 
+    if array_audio_input.shape[0] == 0:
+        warnings.append("NOTE: Please specify audio from examples or local files.")
+
     if array_audio_input.shape[0] / 16000 > 30.0:
         warnings.append((
-            "MERaLiON-AudioLLM is trained to process audio up to **30 seconds**."
+            "WARNING: MERaLiON-AudioLLM is trained to process audio up to **30 seconds**."
             f" Audio longer than **{MAX_AUDIO_LENGTH} seconds** will be truncated."
         ))
 
     return warnings
 
 
-def retrive_response(
+def retrive_response(
+    text_input,
+    array_audio_input,
+    base64_audio_input,
+    stream=True,
+    history=[],
+    **kwargs
+):
     warnings = _validate_input(text_input, array_audio_input)
 
     response_object, error_msg = None, ""
     try:
         response_object = _retry_retrive_response_throws_exception(
-            text_input,
+            text_input=text_input,
+            base64_audio_input=base64_audio_input,
+            stream=stream,
+            history=history,
+            **kwargs
         )
-    except NoAudioException:
-        error_msg = "Please specify audio first!"
     except TunnelNotRunningException:
         error_msg = "Internet connection cannot be established. Please contact the administrator."
     except Exception as e:
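The reworked _retrive_response packs the text instruction and the base64 audio into a single OpenAI-compatible user message. A hedged sketch of the payload it builds, with placeholder values standing in for the real audio bytes:

import base64

# Placeholder inputs; in the app the base64 string comes from _update_audio().
text_input = "Please transcribe this speech."
base64_audio_input = base64.b64encode(b"...raw wav bytes...").decode("utf-8")

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": f"Text instruction: {text_input}"},
        {
            "type": "audio_url",
            "audio_url": {"url": f"data:audio/ogg;base64,{base64_audio_input}"},
        },
    ],
}]
# This list, plus any prior history, is what
# st.session_state.client.chat.completions.create(messages=..., **kwargs) receives.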
src/logger.py
CHANGED
@@ -49,25 +49,22 @@ class Logger:
         session_id,
         base64_audio,
         text_input,
-        params,
         response,
-        error_msg
+        **kwargs
     ):
         new_query_id = self.query_increment
         current_time = get_current_strftime()
 
         with logger_lock:
+            current_query_data = {
                 "session_id": session_id,
                 "query_id": new_query_id,
                 "creation_time": current_time,
                 "text": text_input,
-                "params": params,
                 "response": response,
+            }
+            current_query_data.update(kwargs)
+            self.query_data.append(current_query_data)
 
         self.audio_data.append({
             "session_id": session_id,
@@ -98,13 +95,13 @@ class Logger:
             row_str = json.dumps(row, ensure_ascii=False)+"\n"
             buffer.write(row_str.encode("utf-8"))
 
-        api.upload_file(
-        )
+        # api.upload_file(
+        #     path_or_fileobj=buffer,
+        #     path_in_repo=f"{data_name}/{get_current_strftime()}.json",
+        #     repo_id=os.getenv("LOGGING_REPO_NAME"),
+        #     repo_type="dataset",
+        #     token=os.getenv('HF_TOKEN')
+        # )
 
         buffer.close()
 
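With the **kwargs signature, register_query can attach arbitrary fields (params, history, error_msg, ...) to a query record. A minimal sketch of a call, assuming a Logger instance named logger and placeholder values:

# Hypothetical call; the field names mirror those passed from retrive_response_with_ui.
logger.register_query(
    session_id="session-001",
    base64_audio="",
    text_input="Please transcribe this speech.",
    response="The transcription of the speech is: ...",
    history=[],
    params={"temperature": 0.7},
    warnings=[],
    error_msg="",
)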
src/retrieval.py
ADDED
@@ -0,0 +1,20 @@
+from typing import Dict, List
+
+import numpy as np
+import streamlit as st
+from FlagEmbedding import FlagReranker
+
+
+@st.cache_resource()
+def load_retriever():
+    reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
+    reranker.compute_score([["test", "test"]], normalize=True)
+    return reranker
+
+
+def retrieve_relevant_docs(user_question, docs: List[Dict]) -> List[int]:
+    scores = st.session_state.retriever.compute_score([[user_question, d["doc_text"]] for d in docs], normalize=True)
+    normalized_scores = np.array(scores) / np.sum(scores)
+
+    selected_indices = np.where((np.array(scores) > 0.02) & (normalized_scores > 0.3))[0]
+    return selected_indices.tolist()
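A usage sketch of the reranker-based routing: each STANDARD_QUERIES entry is scored against the user question, and only indices that clear both the absolute threshold (score > 0.02) and the relative threshold (more than 0.3 of the normalized score mass) are returned. It assumes a Streamlit session in which load_retriever() has already populated st.session_state.retriever:

from src.content.common import STANDARD_QUERIES
from src.retrieval import retrieve_relevant_docs

# Indices into STANDARD_QUERIES whose doc_text the reranker deems relevant.
indices = retrieve_relevant_docs("Who is speaking and how do they feel?", STANDARD_QUERIES)
for idx in indices:
    print(STANDARD_QUERIES[idx]["ui_text"])  # e.g. "gender recognition", "emotion recognition"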
style/small_window.css
CHANGED
@@ -15,4 +15,10 @@
 div[data-testid="stSidebarCollapsedControl"] button[data-testid="stBaseButton-headerNoPadding"]::after {
     content: "More Use Cases"
 }
+}
+
+@media (max-width: 916px) and (max-height: 958px) {
+    div[height="480"][data-testid="stVerticalBlockBorderWrapper"] {
+        height: 380px;
+    }
 }