nickmuchi commited on
Commit
8e77d9f
Β·
1 Parent(s): 805e19a

Update pages/3_Earnings_Semantic_Search_πŸ”Ž_.py

Browse files
pages/3_Earnings_Semantic_Search_πŸ”Ž_.py CHANGED
@@ -10,6 +10,8 @@ st.sidebar.header("Semantic Search")
10
 
11
  st.markdown("Earnings Semantic Search with LangChain, OpenAI & SBert")
12
 
 
 
13
  st.markdown(
14
  """
15
  <style>
@@ -59,7 +61,8 @@ st.markdown(
59
  )
60
 
61
  bi_enc_dict = {'mpnet-base-v2':"all-mpnet-base-v2",
62
- 'instructor-base': 'hkunlp/instructor-base'}
 
63
 
64
  search_input = st.text_input(
65
  label='Enter Your Search Query',value= "What key challenges did the business face?", key='search')
@@ -73,69 +76,89 @@ overlap_size = 50
73
 
74
  try:
75
 
76
- if search_input:
77
-
78
- if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
79
 
80
- ## Save to a dataframe for ease of visualization
81
- sen_df = st.session_state['sen_df']
82
-
83
- title = st.session_state['title']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- earnings_text = st.session_state['earnings_passages']
86
 
87
- print(f'earnings_to_be_embedded:{earnings_text}')
 
88
 
89
- st.session_state.eval_set = generate_eval(
90
- earnings_text, 10, 3000)
 
 
 
 
91
 
92
- # Display the question-answer pairs in the sidebar with smaller text
93
- for i, qa_pair in enumerate(st.session_state.eval_set):
94
- st.sidebar.markdown(
95
- f"""
96
- <div class="css-card">
97
- <span class="card-tag">Question {i + 1}</span>
98
- <p style="font-size: 12px;">{qa_pair['question']}</p>
99
- <p style="font-size: 12px;">{qa_pair['answer']}</p>
100
- </div>
101
- """,
102
- unsafe_allow_html=True,
103
  )
104
 
105
- embedding_model = bi_enc_dict[sbert_model_name]
106
-
107
- with st.spinner(
108
- text=f"Loading {embedding_model} embedding model and Generating Response..."
109
- ):
110
 
111
- docsearch = process_corpus(earnings_text,title, embedding_model)
112
-
113
- result = embed_text(search_input,docsearch)
114
-
115
-
116
- references = [doc.page_content for doc in result['source_documents']]
117
-
118
- answer = result['answer']
119
-
120
- sentiment_label = gen_sentiment(answer)
121
 
122
- ##### Sematic Search #####
123
-
124
- df = pd.DataFrame.from_dict({'Text':[answer],'Sentiment':[sentiment_label]})
125
-
126
-
127
- text_annotations = gen_annotated_text(df)[0]
128
-
129
- with st.expander(label='Query Result', expanded=True):
130
- annotated_text(text_annotations)
131
 
132
- with st.expander(label='References from Corpus used to Generate Result'):
133
- for ref in references:
134
- st.write(ref)
135
 
136
- else:
 
 
 
137
 
138
- st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
 
 
 
 
 
 
139
 
140
  else:
141
 
 
10
 
11
  st.markdown("Earnings Semantic Search with LangChain, OpenAI & SBert")
12
 
13
+ starter_message = "Ask me anything about the Earnings Call!"
14
+
15
  st.markdown(
16
  """
17
  <style>
 
61
  )
62
 
63
  bi_enc_dict = {'mpnet-base-v2':"all-mpnet-base-v2",
64
+ 'instructor-base': 'hkunlp/instructor-base',
65
+ 'FlagEmbedding': 'BAAI/bge-base-en'}
66
 
67
  search_input = st.text_input(
68
  label='Enter Your Search Query',value= "What key challenges did the business face?", key='search')
 
76
 
77
  try:
78
 
 
 
 
79
 
80
+ if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
81
+
82
+ ## Save to a dataframe for ease of visualization
83
+ sen_df = st.session_state['sen_df']
84
+
85
+ title = st.session_state['title']
86
+
87
+ earnings_text = st.session_state['earnings_passages']
88
+
89
+ print(f'earnings_to_be_embedded:{earnings_text}')
90
+
91
+ st.session_state.eval_set = generate_eval(
92
+ earnings_text, 10, 3000)
93
+
94
+ # Display the question-answer pairs in the sidebar with smaller text
95
+ for i, qa_pair in enumerate(st.session_state.eval_set):
96
+ st.sidebar.markdown(
97
+ f"""
98
+ <div class="css-card">
99
+ <span class="card-tag">Question {i + 1}</span>
100
+ <p style="font-size: 12px;">{qa_pair['question']}</p>
101
+ <p style="font-size: 12px;">{qa_pair['answer']}</p>
102
+ </div>
103
+ """,
104
+ unsafe_allow_html=True,
105
+ )
106
+
107
+ embedding_model = bi_enc_dict[sbert_model_name]
108
+
109
+ with st.spinner(
110
+ text=f"Loading {embedding_model} embedding model and Generating Response..."
111
+ ):
112
+
113
+ docsearch = create_vectorstore(earnings_text,title, embedding_model)
114
 
115
+ memory, agent_executor = create_memory_and_agent(search_input,docsearch)
116
 
117
+ if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
118
+ st.session_state["messages"] = [AIMessage(content=starter_message)]
119
 
120
+ for msg in st.session_state.messages:
121
+ if isinstance(msg, AIMessage):
122
+ st.chat_message("assistant").write(msg.content)
123
+ elif isinstance(msg, HumanMessage):
124
+ st.chat_message("user").write(msg.content)
125
+ memory.chat_memory.add_message(msg)
126
 
127
+ if user_question := st.chat_input(placeholder=starter_message):
128
+ st.chat_message("user").write(user_question)
129
+
130
+ with st.chat_message("assistant"):
131
+
132
+ st_callback = StreamlitCallbackHandler(st.container())
133
+
134
+ response = agent_executor(
135
+ {"input": user_question, "history": st.session_state.messages},
136
+ callbacks=[st_callback],
137
+ include_run_info=True,
138
  )
139
 
140
+ answer = response["output"]
 
 
 
 
141
 
142
+ st.session_state.messages.append(AIMessage(content=answer))
 
 
 
 
 
 
 
 
 
143
 
144
+ st.write(answer)
145
+
146
+ memory.save_context({"input": user_question}, response)
 
 
 
 
 
 
147
 
148
+ st.session_state["messages"] = memory.buffer
 
 
149
 
150
+ run_id = response["__run"].run_id
151
+
152
+ col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
153
+
154
 
155
+ with st.expander(label='Query Result with Sentiment Tag', expanded=True):
156
+
157
+ sentiment_label = gen_sentiment(answer)
158
+ df = pd.DataFrame.from_dict({'Text':[answer],'Sentiment':[sentiment_label]})
159
+ text_annotations = gen_annotated_text(df)[0]
160
+ annotated_text(text_annotations)
161
+
162
 
163
  else:
164