nickmuchi commited on
Commit
e694dea
Β·
1 Parent(s): 1ad3fab

Update pages/3_Earnings_Semantic_Search_πŸ”Ž_.py

Browse files
pages/3_Earnings_Semantic_Search_πŸ”Ž_.py CHANGED
@@ -8,24 +8,36 @@ st.markdown("## Earnings Semantic Search with SBert")
8
  def gen_sentiment(text):
9
  '''Generate sentiment of given text'''
10
  return sent_pipe(text)[0]['label']
 
 
11
 
12
  search_input = st.text_input(
13
  label='Enter Your Search Query',value= "What key challenges did the business face?", key='search')
14
 
 
 
15
  top_k = st.sidebar.slider("Number of Top Hits Generated",min_value=1,max_value=5,value=2)
16
 
17
  window_size = st.sidebar.slider("Number of Sentences Generated in Search Response",min_value=1,max_value=7,value=3)
18
 
19
- if search_input:
20
 
21
 
22
  if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
23
 
 
 
24
  ## Save to a dataframe for ease of visualization
25
  sen_df = st.session_state['sen_df']
26
 
27
  passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
28
 
 
 
 
 
 
 
29
  ##### Sematic Search #####
30
  # Encode the query using the bi-encoder and find potentially relevant passages
31
  corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
 
8
  def gen_sentiment(text):
9
  '''Generate sentiment of given text'''
10
  return sent_pipe(text)[0]['label']
11
+
12
+ bi_enc_options = ["multi-qa-mpnet-base-dot-v1","all-mpnet-base-v2","multi-qa-MiniLM-L6-cos-v1","neeva/query2query"]
13
 
14
  search_input = st.text_input(
15
  label='Enter Your Search Query',value= "What key challenges did the business face?", key='search')
16
 
17
+ sbert_model_name = st.sidebar.selectbox("Encoder Model", options=bi_enc_options, key='sbox')
18
+
19
  top_k = st.sidebar.slider("Number of Top Hits Generated",min_value=1,max_value=5,value=2)
20
 
21
  window_size = st.sidebar.slider("Number of Sentences Generated in Search Response",min_value=1,max_value=7,value=3)
22
 
23
+ if search_input and sbert_model_name:
24
 
25
 
26
  if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
27
 
28
+
29
+
30
  ## Save to a dataframe for ease of visualization
31
  sen_df = st.session_state['sen_df']
32
 
33
  passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
34
 
35
+ with st.spinner(
36
+ text=f"Loading {sbert_model_name} encoder and embedding text into vector space. This might take a few seconds depending on the length of text..."
37
+ ):
38
+ sbert = load_sbert(sbert_model_name)
39
+
40
+
41
  ##### Sematic Search #####
42
  # Encode the query using the bi-encoder and find potentially relevant passages
43
  corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)