change retriever
- src/content/agent.py +34 -19
- src/content/common.py +1 -41
- src/retrieval.py +103 -10
src/content/agent.py
CHANGED
@@ -4,12 +4,11 @@ import base64
 import streamlit as st

 from src.generation import MAX_AUDIO_LENGTH
-from src.retrieval import retrieve_relevant_docs
+from src.retrieval import STANDARD_QUERIES, retrieve_relevant_docs
 from src.utils import bytes_to_array, array_to_bytes
 from src.content.common import (
     MODEL_NAMES,
     AUDIO_SAMPLES_W_INSTRUCT,
-    STANDARD_QUERIES,
     DEFAULT_DIALOGUE_STATES,
     init_state_section,
     header_section,
@@ -20,21 +19,25 @@ from src.content.common import (

 LLM_PROMPT_TEMPLATE = """User asked a question about the audio clip.

-## User
+## User Question
 {user_question}

-{audio_information_prompt}Please reply
+{audio_information_prompt}Please reply to user's question with a friendly, accurate, and helpful answer."""
+

 AUDIO_INFO_TEMPLATE = """Here are some information about this audio clip.

 ## Audio Information
 {audio_information}

-
+However, the audio analysis may or may not contain relevant information to the user question, please only reply the user with the relevant information.

 """


+AUDIO_ANALYSIS_STATUS = "Analyzing audio..."
+
+
 def _update_audio(audio_bytes):
     origin_audio_array = bytes_to_array(audio_bytes)
     truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
@@ -46,7 +49,7 @@ def _update_audio(audio_bytes):

 @st.fragment
 def successful_example_section():
-    audio_sample_names = [
+    audio_sample_names = [name for name in AUDIO_SAMPLES_W_INSTRUCT.keys() if "Paralinguistic" in name]

     st.markdown(":fire: **Successful Tasks and Examples**")

@@ -61,6 +64,7 @@ def successful_example_section():
             on_select=True,
             ag_messages=[],
             ag_model_messages=[],
+            ag_visited_query_indices=[],
             disprompt=True
         ),
         key='select')
@@ -83,7 +87,12 @@ def audio_attach_dialogue():
         label="**Upload Audio:**",
         label_visibility="collapsed",
         type=['wav', 'mp3'],
-        on_change=lambda: st.session_state.update(
+        on_change=lambda: st.session_state.update(
+            on_upload=True,
+            ag_messages=[],
+            ag_model_messages=[],
+            ag_visited_query_indices=[]
+        ),
         key='upload'
     )

@@ -98,7 +107,12 @@ def audio_attach_dialogue():
     uploaded_file = st.audio_input(
         label="**Record Audio:**",
         label_visibility="collapsed",
-        on_change=lambda: st.session_state.update(
+        on_change=lambda: st.session_state.update(
+            on_record=True,
+            ag_messages=[],
+            ag_model_messages=[],
+            ag_visited_query_indices=[]
+        ),
         key='record'
     )

@@ -132,15 +146,16 @@ def bottom_input_section():


 def _prepare_final_prompt_with_ui(one_time_prompt):
-
-
-
+    with st.spinner("Searching appropriate querys..."):
+        relevant_query_indices = retrieve_relevant_docs(one_time_prompt)
+        if len(st.session_state.ag_messages) <= 2:
+            relevant_query_indices.append(0)

-
-
-
-
-
+    relevant_query_indices = list(
+        set(relevant_query_indices).difference(st.session_state.ag_visited_query_indices)
+    )
+
+    st.session_state.ag_visited_query_indices.extend(relevant_query_indices)

     if not relevant_query_indices:
         return LLM_PROMPT_TEMPLATE.format(
@@ -149,7 +164,7 @@ def _prepare_final_prompt_with_ui(one_time_prompt):
         )

     audio_info = []
-    with st.status(
+    with st.status(AUDIO_ANALYSIS_STATUS, expanded=True) as status:
         for idx in relevant_query_indices:
             error_msg, warnings, response = retrive_response_with_ui(
                 model_name=MODEL_NAMES["with_lora"]["vllm_name"],
@@ -194,7 +209,7 @@ def conversation_section():
             for warning_msg in message.get("warnings", []):
                 st.warning(warning_msg)
             if process := message.get("process", []):
-                with st.status(
+                with st.status(AUDIO_ANALYSIS_STATUS, expanded=True, state="complete"):
                     for proc in process:
                         if proc.get("error"):
                             st.error(proc["error"])
@@ -242,7 +257,7 @@ def conversation_section():

 def agent_page():
     init_state_section()
-    header_section(component_name="
+    header_section(component_name="Chatbot", icon="👥")

     with st.sidebar:
         sidebar_fragment()
src/content/common.py
CHANGED
@@ -302,46 +302,6 @@ AUDIO_SAMPLES_W_INSTRUCT = {
 }


-STANDARD_QUERIES = [
-    {
-        "query_text": "Please transcribe this speech.",
-        "doc_text": "Listen to a speech and write down exactly what is being said in text form. It's essentially converting spoken words into written words. Provide the exact transcription of the given audio. Record whatever the speaker has said into written text.",
-        "response_prefix_text": "The transcription of the speech is: ",
-        "ui_text": "speech trancription"
-    },
-    {
-        "query_text": "Please describe what happended in this audio",
-        "doc_text": "Text captions describing the sound events and environments in the audio clips, describing the events and actions happened in the audio.",
-        "response_prefix_text": "Events in this audio clip: ",
-        "ui_text": "audio caption"
-    },
-    {
-        "query_text": "May I know the gender of the speakers",
-        "doc_text": "Please identify the gender of the speaker. For instance, whether is the speaker male or female.",
-        "response_prefix_text": "By analyzing pitch, formants, harmonics, and prosody features, which reflect physiological and speech pattern differences between genders: ",
-        "ui_text": "gender recognition"
-    },
-    {
-        "query_text": "May I know the nationality of the speakers",
-        "doc_text": "Discover speakers' nationality, country, or the place he is coming from, from his/her accent, pronunciation patterns, and other language-specific speech features influenced by cultural and linguistic backgrounds.",
-        "response_prefix_text": "By analyzing accent, pronunciation patterns, intonation, rhythm, phoneme usage, and language-specific speech features influenced by cultural and linguistic backgrounds: ",
-        "ui_text": "accent recognition"
-    },
-    {
-        "query_text": "Can you guess which ethnic group this person is from based on their accent.",
-        "doc_text": "Discover speakers' ethnic group, home country, or the place he is coming from, from his/her accent, tone, and other vocal characteristics influenced by cultural, regional, and linguistic factors.",
-        "response_prefix_text": "By analyzing speech features like accent, tone, intonation, phoneme variations, and vocal characteristics influenced by cultural, regional, and linguistic factors: ",
-        "ui_text": "accent recognition"
-    },
-    {
-        "query_text": "What do you think the speakers are feeling.",
-        "doc_text": "What do you think the speakers are feeling. Please identify speakers' emotions by analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy, which reflect emotional states such as happiness, anger, sadness, or fear.",
-        "response_prefix_text": "By analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy: ",
-        "ui_text": "emotion recognition"
-    },
-]
-
-
 def init_state_section():
     st.set_page_config(page_title='MERaLiON-AudioLLM', page_icon = "🔥", layout='wide')

@@ -414,7 +374,7 @@ def header_section(component_name="Playground", icon="🤖"):
 def sidebar_fragment():
     with st.container(height=256, border=False):
         st.page_link("pages/playground.py", disabled=st.session_state.disprompt, label="🚀 Playground")
-        st.page_link("pages/agent.py", disabled=st.session_state.disprompt, label="👥
+        st.page_link("pages/agent.py", disabled=st.session_state.disprompt, label="👥 Chatbot")
         st.page_link("pages/voice_chat.py", disabled=st.session_state.disprompt, label="🗣️ Voice Chat (experimental)")

         st.divider()
src/retrieval.py
CHANGED
@@ -1,20 +1,113 @@
-from typing import
+from typing import List

 import numpy as np
 import streamlit as st
-from FlagEmbedding import
+from FlagEmbedding import BGEM3FlagModel
+
+
+STANDARD_QUERIES = [
+    {
+        "query_text": "Please transcribe this speech.",
+        "doc_text": "Listen to a speech and write down exactly what is being said in text form. It's essentially converting spoken words into written words. Provide the exact transcription of the given audio. Record whatever the speaker has said into written text.",
+        "response_prefix_text": "The transcription of the speech is: ",
+        "ui_text": "speech trancription"
+    },
+    {
+        "query_text": "Please describe what happended in this audio",
+        "doc_text": "Text captions describing the sound events and environments in the audio clips, describing the events and actions happened in the audio.",
+        "response_prefix_text": "Events in this audio clip: ",
+        "ui_text": "audio caption"
+    },
+    {
+        "query_text": "May I know the gender of the speakers",
+        "doc_text": "Identify the gender, male or female, based on pitch, formants, harmonics, and prosody features, and other speech pattern differences between genders.",
+        "response_prefix_text": "By analyzing pitch, formants, harmonics, and prosody features, which reflect physiological and speech pattern differences between genders: ",
+        "ui_text": "gender recognition"
+    },
+    {
+        "query_text": "May I know the nationality of the speakers",
+        "doc_text": "Discover speakers' nationality, country, or the place he is coming from, from his/her accent, pronunciation patterns, and other language-specific speech features influenced by cultural and linguistic backgrounds.",
+        "response_prefix_text": "By analyzing accent, pronunciation patterns, intonation, rhythm, phoneme usage, and language-specific speech features influenced by cultural and linguistic backgrounds: ",
+        "ui_text": "natinoality recognition"
+    },
+    {
+        "query_text": "Can you guess which ethnic group this person is from based on their accent.",
+        "doc_text": "Discover speakers' ethnic group, home country, or the place he is coming from, from his/her accent, tone, and other vocal characteristics influenced by cultural, regional, and linguistic factors.",
+        "response_prefix_text": "By analyzing speech features like accent, tone, intonation, phoneme variations, and vocal characteristics influenced by cultural, regional, and linguistic factors: ",
+        "ui_text": "ethnic group recognition"
+    },
+    {
+        "query_text": "What do you think the speakers are feeling.",
+        "doc_text": "What do you think the speakers are feeling. Please identify speakers' emotions by analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy, which reflect emotional states such as happiness, anger, sadness, or fear.",
+        "response_prefix_text": "By analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy: ",
+        "ui_text": "emotion recognition"
+    },
+]
+
+
+def _colbert_score(q_reps, p_reps):
+    """Compute colbert scores of input queries and passages.
+
+    Args:
+        q_reps (np.ndarray): Multi-vector embeddings for queries.
+        p_reps (np.ndarray): Multi-vector embeddings for passages/corpus.
+
+    Returns:
+        torch.Tensor: Computed colbert scores.
+    """
+    # q_reps, p_reps = torch.from_numpy(q_reps), torch.from_numpy(p_reps)
+    token_scores = np.einsum('in,jn->ij', q_reps, p_reps)
+    scores = token_scores.max(-1)
+    scores = np.sum(scores) / q_reps.shape[0]
+    return scores
+
+
+class QueryRetriever:
+    def __init__(self, docs):
+        self.model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
+        self.docs = docs
+        self.doc_vectors = self.model.encode(
+            [d["doc_text"] for d in self.docs],
+            return_sparse=True,
+            return_colbert_vecs=True
+        )
+        self.scorer_attrs = {
+            "lexical_weights": {
+                "method": self.model.compute_lexical_matching_score,
+                "weight": 0.2
+            },
+            "colbert_vecs": {
+                "method": _colbert_score,
+                "weight": 0.8
+            },
+        }
+
+    def get_relevant_doc_indices(self, prompt, normalize=False) -> np.ndarray:
+        scores = np.zeros(len(self.docs))
+
+        if not prompt:
+            return scores
+
+        prompt_vector = self.model.encode(
+            prompt,
+            return_sparse=True,
+            return_colbert_vecs=True
+        )
+
+        for scorer_name, scorer_attrs in self.scorer_attrs.items():
+            for i, doc_vec in enumerate(self.doc_vectors[scorer_name]):
+                scores[i] += scorer_attrs["method"](prompt_vector[scorer_name], doc_vec)
+
+        if normalize:
+            scores = scores / np.sum(scores)
+        return scores


 @st.cache_resource()
 def load_retriever():
-
-    reranker.compute_score([["test", "test"]], normalize=True)
-    return reranker
+    return QueryRetriever(docs=STANDARD_QUERIES)


-def retrieve_relevant_docs(user_question
-    scores = st.session_state.retriever.
-
-
-    selected_indices = np.where((np.array(scores) > 0.2) & (normalized_scores > 0.3))[0]
+def retrieve_relevant_docs(user_question: str) -> List[int]:
+    scores = st.session_state.retriever.get_relevant_doc_indices(user_question, normalize=True)
+    selected_indices = np.where(scores > 0.2)[0]
     return selected_indices.tolist()