YingxuHe committed
Commit 5a227f0 · Parent(s): aba3925

change retriever

Files changed (3)
  1. src/content/agent.py +34 -19
  2. src/content/common.py +1 -41
  3. src/retrieval.py +103 -10
src/content/agent.py CHANGED
@@ -4,12 +4,11 @@ import base64
 import streamlit as st
 
 from src.generation import MAX_AUDIO_LENGTH
-from src.retrieval import retrieve_relevant_docs
+from src.retrieval import STANDARD_QUERIES, retrieve_relevant_docs
 from src.utils import bytes_to_array, array_to_bytes
 from src.content.common import (
     MODEL_NAMES,
     AUDIO_SAMPLES_W_INSTRUCT,
-    STANDARD_QUERIES,
     DEFAULT_DIALOGUE_STATES,
     init_state_section,
     header_section,
@@ -20,21 +19,25 @@ from src.content.common import (
 
 LLM_PROMPT_TEMPLATE = """User asked a question about the audio clip.
 
-## User question
+## User Question
 {user_question}
 
-{audio_information_prompt}Please reply this user question with an friendly, accurate, and helpful answer."""
+{audio_information_prompt}Please reply to user's question with a friendly, accurate, and helpful answer."""
+
 
 AUDIO_INFO_TEMPLATE = """Here are some information about this audio clip.
 
 ## Audio Information
 {audio_information}
 
-This may or may not contain relevant information to the user question, please use with caution.
+However, the audio analysis may or may not contain relevant information to the user question, please only reply the user with the relevant information.
 
 """
 
 
+AUDIO_ANALYSIS_STATUS = "Analyzing audio..."
+
+
 def _update_audio(audio_bytes):
     origin_audio_array = bytes_to_array(audio_bytes)
     truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH*16000]
@@ -46,7 +49,7 @@ def _update_audio(audio_bytes):
 
 @st.fragment
 def successful_example_section():
-    audio_sample_names = [audio_sample_name for audio_sample_name in AUDIO_SAMPLES_W_INSTRUCT.keys()]
+    audio_sample_names = [name for name in AUDIO_SAMPLES_W_INSTRUCT.keys() if "Paralinguistic" in name]
 
     st.markdown(":fire: **Successful Tasks and Examples**")
 
@@ -61,6 +64,7 @@ def successful_example_section():
             on_select=True,
             ag_messages=[],
             ag_model_messages=[],
+            ag_visited_query_indices=[],
             disprompt=True
         ),
         key='select')
@@ -83,7 +87,12 @@ def audio_attach_dialogue():
         label="**Upload Audio:**",
         label_visibility="collapsed",
         type=['wav', 'mp3'],
-        on_change=lambda: st.session_state.update(on_upload=True, ag_messages=[], ag_model_messages=[]),
+        on_change=lambda: st.session_state.update(
+            on_upload=True,
+            ag_messages=[],
+            ag_model_messages=[],
+            ag_visited_query_indices=[]
+        ),
         key='upload'
     )
 
@@ -98,7 +107,12 @@ def audio_attach_dialogue():
     uploaded_file = st.audio_input(
         label="**Record Audio:**",
         label_visibility="collapsed",
-        on_change=lambda: st.session_state.update(on_record=True, ag_messages=[], ag_model_messages=[]),
+        on_change=lambda: st.session_state.update(
+            on_record=True,
+            ag_messages=[],
+            ag_model_messages=[],
+            ag_visited_query_indices=[]
+        ),
         key='record'
    )
 
@@ -132,15 +146,16 @@ def bottom_input_section():
 
 
 def _prepare_final_prompt_with_ui(one_time_prompt):
-    relevant_query_indices = retrieve_relevant_docs(one_time_prompt, STANDARD_QUERIES)
-    if len(st.session_state.ag_messages) <= 2:
-        relevant_query_indices.append(0)
+    with st.spinner("Searching appropriate querys..."):
+        relevant_query_indices = retrieve_relevant_docs(one_time_prompt)
+        if len(st.session_state.ag_messages) <= 2:
+            relevant_query_indices.append(0)
 
-    relevant_query_indices = list(
-        set(relevant_query_indices).difference(st.session_state.ag_visited_query_indices)
-    )
-
-    st.session_state.ag_visited_query_indices.extend(relevant_query_indices)
+        relevant_query_indices = list(
+            set(relevant_query_indices).difference(st.session_state.ag_visited_query_indices)
+        )
+
+        st.session_state.ag_visited_query_indices.extend(relevant_query_indices)
 
     if not relevant_query_indices:
         return LLM_PROMPT_TEMPLATE.format(
@@ -149,7 +164,7 @@ def _prepare_final_prompt_with_ui(one_time_prompt):
     )
 
     audio_info = []
-    with st.status("Thought process...", expanded=True) as status:
+    with st.status(AUDIO_ANALYSIS_STATUS, expanded=True) as status:
         for idx in relevant_query_indices:
             error_msg, warnings, response = retrive_response_with_ui(
                 model_name=MODEL_NAMES["with_lora"]["vllm_name"],
@@ -194,7 +209,7 @@ def conversation_section():
             for warning_msg in message.get("warnings", []):
                 st.warning(warning_msg)
             if process := message.get("process", []):
-                with st.status("Thought process...", expanded=True, state="complete"):
+                with st.status(AUDIO_ANALYSIS_STATUS, expanded=True, state="complete"):
                     for proc in process:
                         if proc.get("error"):
                             st.error(proc["error"])
@@ -242,7 +257,7 @@ def conversation_section():
 
 def agent_page():
     init_state_section()
-    header_section(component_name="Agent System", icon="👥")
+    header_section(component_name="Chatbot", icon="👥")
 
     with st.sidebar:
         sidebar_fragment()
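Below is a minimal sketch, not part of the commit, of how the two prompt templates in agent.py compose once an audio analysis comes back. The question and analysis strings are made-up examples, and the template constants are assumed to be importable from src/content/agent.py outside the running Streamlit app.

    # Hypothetical illustration: build the final LLM prompt the way
    # _prepare_final_prompt_with_ui does once audio analyses are available.
    from src.content.agent import LLM_PROMPT_TEMPLATE, AUDIO_INFO_TEMPLATE

    # Made-up analysis text standing in for one model response.
    audio_information = (
        "By analyzing vocal features like pitch, tone, volume, speech rate, "
        "rhythm, and spectral energy: the speaker sounds cheerful."
    )

    final_prompt = LLM_PROMPT_TEMPLATE.format(
        user_question="How does the speaker feel?",
        audio_information_prompt=AUDIO_INFO_TEMPLATE.format(
            audio_information=audio_information
        ),
    )
    print(final_prompt)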
src/content/common.py CHANGED
@@ -302,46 +302,6 @@ AUDIO_SAMPLES_W_INSTRUCT = {
 }
 
 
-STANDARD_QUERIES = [
-    {
-        "query_text": "Please transcribe this speech.",
-        "doc_text": "Listen to a speech and write down exactly what is being said in text form. It's essentially converting spoken words into written words. Provide the exact transcription of the given audio. Record whatever the speaker has said into written text.",
-        "response_prefix_text": "The transcription of the speech is: ",
-        "ui_text": "speech trancription"
-    },
-    {
-        "query_text": "Please describe what happended in this audio",
-        "doc_text": "Text captions describing the sound events and environments in the audio clips, describing the events and actions happened in the audio.",
-        "response_prefix_text": "Events in this audio clip: ",
-        "ui_text": "audio caption"
-    },
-    {
-        "query_text": "May I know the gender of the speakers",
-        "doc_text": "Please identify the gender of the speaker. For instance, whether is the speaker male or female.",
-        "response_prefix_text": "By analyzing pitch, formants, harmonics, and prosody features, which reflect physiological and speech pattern differences between genders: ",
-        "ui_text": "gender recognition"
-    },
-    {
-        "query_text": "May I know the nationality of the speakers",
-        "doc_text": "Discover speakers' nationality, country, or the place he is coming from, from his/her accent, pronunciation patterns, and other language-specific speech features influenced by cultural and linguistic backgrounds.",
-        "response_prefix_text": "By analyzing accent, pronunciation patterns, intonation, rhythm, phoneme usage, and language-specific speech features influenced by cultural and linguistic backgrounds: ",
-        "ui_text": "accent recognition"
-    },
-    {
-        "query_text": "Can you guess which ethnic group this person is from based on their accent.",
-        "doc_text": "Discover speakers' ethnic group, home country, or the place he is coming from, from his/her accent, tone, and other vocal characteristics influenced by cultural, regional, and linguistic factors.",
-        "response_prefix_text": "By analyzing speech features like accent, tone, intonation, phoneme variations, and vocal characteristics influenced by cultural, regional, and linguistic factors: ",
-        "ui_text": "accent recognition"
-    },
-    {
-        "query_text": "What do you think the speakers are feeling.",
-        "doc_text": "What do you think the speakers are feeling. Please identify speakers' emotions by analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy, which reflect emotional states such as happiness, anger, sadness, or fear.",
-        "response_prefix_text": "By analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy: ",
-        "ui_text": "emotion recognition"
-    },
-]
-
-
 def init_state_section():
     st.set_page_config(page_title='MERaLiON-AudioLLM', page_icon = "🔥", layout='wide')
 
@@ -414,7 +374,7 @@ def header_section(component_name="Playground", icon="🤖"):
 def sidebar_fragment():
     with st.container(height=256, border=False):
         st.page_link("pages/playground.py", disabled=st.session_state.disprompt, label="🚀 Playground")
-        st.page_link("pages/agent.py", disabled=st.session_state.disprompt, label="👥 Multi-Agent System")
+        st.page_link("pages/agent.py", disabled=st.session_state.disprompt, label="👥 Chatbot")
         st.page_link("pages/voice_chat.py", disabled=st.session_state.disprompt, label="🗣️ Voice Chat (experimental)")
 
     st.divider()
src/retrieval.py CHANGED
@@ -1,20 +1,113 @@
-from typing import Dict, List
+from typing import List
 
 import numpy as np
 import streamlit as st
-from FlagEmbedding import FlagReranker
+from FlagEmbedding import BGEM3FlagModel
+
+
+STANDARD_QUERIES = [
+    {
+        "query_text": "Please transcribe this speech.",
+        "doc_text": "Listen to a speech and write down exactly what is being said in text form. It's essentially converting spoken words into written words. Provide the exact transcription of the given audio. Record whatever the speaker has said into written text.",
+        "response_prefix_text": "The transcription of the speech is: ",
+        "ui_text": "speech trancription"
+    },
+    {
+        "query_text": "Please describe what happended in this audio",
+        "doc_text": "Text captions describing the sound events and environments in the audio clips, describing the events and actions happened in the audio.",
+        "response_prefix_text": "Events in this audio clip: ",
+        "ui_text": "audio caption"
+    },
+    {
+        "query_text": "May I know the gender of the speakers",
+        "doc_text": "Identify the gender, male or female, based on pitch, formants, harmonics, and prosody features, and other speech pattern differences between genders.",
+        "response_prefix_text": "By analyzing pitch, formants, harmonics, and prosody features, which reflect physiological and speech pattern differences between genders: ",
+        "ui_text": "gender recognition"
+    },
+    {
+        "query_text": "May I know the nationality of the speakers",
+        "doc_text": "Discover speakers' nationality, country, or the place he is coming from, from his/her accent, pronunciation patterns, and other language-specific speech features influenced by cultural and linguistic backgrounds.",
+        "response_prefix_text": "By analyzing accent, pronunciation patterns, intonation, rhythm, phoneme usage, and language-specific speech features influenced by cultural and linguistic backgrounds: ",
+        "ui_text": "natinoality recognition"
+    },
+    {
+        "query_text": "Can you guess which ethnic group this person is from based on their accent.",
+        "doc_text": "Discover speakers' ethnic group, home country, or the place he is coming from, from his/her accent, tone, and other vocal characteristics influenced by cultural, regional, and linguistic factors.",
+        "response_prefix_text": "By analyzing speech features like accent, tone, intonation, phoneme variations, and vocal characteristics influenced by cultural, regional, and linguistic factors: ",
+        "ui_text": "ethnic group recognition"
+    },
+    {
+        "query_text": "What do you think the speakers are feeling.",
+        "doc_text": "What do you think the speakers are feeling. Please identify speakers' emotions by analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy, which reflect emotional states such as happiness, anger, sadness, or fear.",
+        "response_prefix_text": "By analyzing vocal features like pitch, tone, volume, speech rate, rhythm, and spectral energy: ",
+        "ui_text": "emotion recognition"
+    },
+]
+
+
+def _colbert_score(q_reps, p_reps):
+    """Compute colbert scores of input queries and passages.
+
+    Args:
+        q_reps (np.ndarray): Multi-vector embeddings for queries.
+        p_reps (np.ndarray): Multi-vector embeddings for passages/corpus.
+
+    Returns:
+        torch.Tensor: Computed colbert scores.
+    """
+    # q_reps, p_reps = torch.from_numpy(q_reps), torch.from_numpy(p_reps)
+    token_scores = np.einsum('in,jn->ij', q_reps, p_reps)
+    scores = token_scores.max(-1)
+    scores = np.sum(scores) / q_reps.shape[0]
+    return scores
+
+class QueryRetriever:
+    def __init__(self, docs):
+        self.model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
+        self.docs = docs
+        self.doc_vectors = self.model.encode(
+            [d["doc_text"] for d in self.docs],
+            return_sparse=True,
+            return_colbert_vecs=True
+        )
+        self.scorer_attrs = {
+            "lexical_weights": {
+                "method": self.model.compute_lexical_matching_score,
+                "weight": 0.2
+            },
+            "colbert_vecs": {
+                "method": _colbert_score,
+                "weight": 0.8
+            },
+        }
+
+    def get_relevant_doc_indices(self, prompt, normalize=False) -> np.ndarray:
+        scores = np.zeros(len(self.docs))
+
+        if not prompt:
+            return scores
+
+        prompt_vector = self.model.encode(
+            prompt,
+            return_sparse=True,
+            return_colbert_vecs=True
+        )
+
+        for scorer_name, scorer_attrs in self.scorer_attrs.items():
+            for i, doc_vec in enumerate(self.doc_vectors[scorer_name]):
+                scores[i] += scorer_attrs["method"](prompt_vector[scorer_name], doc_vec)
+
+        if normalize:
+            scores = scores / np.sum(scores)
+        return scores
 
 
 @st.cache_resource()
 def load_retriever():
-    reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True)
-    reranker.compute_score([["test", "test"]], normalize=True)
-    return reranker
+    return QueryRetriever(docs=STANDARD_QUERIES)
 
 
-def retrieve_relevant_docs(user_question, docs: List[Dict]) -> List[int]:
-    scores = st.session_state.retriever.compute_score([[user_question, d["doc_text"]] for d in docs], normalize=True)
-    normalized_scores = np.array(scores) / np.sum(scores)
-
-    selected_indices = np.where((np.array(scores) > 0.2) & (normalized_scores > 0.3))[0]
+def retrieve_relevant_docs(user_question: str) -> List[int]:
+    scores = st.session_state.retriever.get_relevant_doc_indices(user_question, normalize=True)
+    selected_indices = np.where(scores > 0.2)[0]
     return selected_indices.tolist()
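The new retriever can also be exercised on its own. Below is a minimal sketch, not part of the commit, that builds QueryRetriever directly instead of going through Streamlit's session state and applies the same 0.2 threshold used by retrieve_relevant_docs. It assumes the repository root is on PYTHONPATH and that FlagEmbedding can download BAAI/bge-m3.

    # Hypothetical standalone usage of the new retriever.
    import numpy as np
    from src.retrieval import STANDARD_QUERIES, QueryRetriever

    retriever = QueryRetriever(docs=STANDARD_QUERIES)  # loads BAAI/bge-m3 on first use

    # Each standard query gets a lexical-matching score plus a ColBERT similarity
    # score; normalize=True rescales the sums to add up to 1 across the queries.
    scores = retriever.get_relevant_doc_indices(
        "Does the speaker sound angry or happy?", normalize=True
    )

    # Same selection rule as retrieve_relevant_docs: keep queries scoring above 0.2.
    for idx in np.where(scores > 0.2)[0]:
        print(STANDARD_QUERIES[idx]["ui_text"], float(scores[idx]))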