nickmuchi committed on
Commit
a957eeb
·
1 Parent(s): 648179d

Update pages/3_Earnings_Semantic_Search_🔎_.py

Browse files
pages/3_Earnings_Semantic_Search_🔎_.py CHANGED
@@ -20,79 +20,85 @@ top_k = st.sidebar.slider("Number of Top Hits Generated",min_value=1,max_value=5
20
 
21
  window_size = st.sidebar.slider("Number of Sentences Generated in Search Response",min_value=1,max_value=7,value=3)
22
 
23
- if search_input:
24
-
25
- if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
26
-
27
- ## Save to a dataframe for ease of visualization
28
- sen_df = st.session_state['sen_df']
29
-
30
- passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
31
-
32
- with st.spinner(
33
- text=f"Loading {sbert_model_name} encoder..."
34
- ):
35
- sbert = load_sbert(sbert_model_name)
36
-
37
-
38
- ##### Sematic Search #####
39
- # Encode the query using the bi-encoder and find potentially relevant passages
40
- corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
41
- question_embedding = sbert.encode(search_input, convert_to_tensor=True)
42
- question_embedding = question_embedding.cpu()
43
- hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k,score_function=util.dot_score)
44
- hits = hits[0] # Get the hits for the first query
45
-
46
- ##### Re-Ranking #####
47
- # Now, score all retrieved passages with the cross_encoder
48
- cross_inp = [[search_input, passages[hit['corpus_id']]] for hit in hits]
49
- cross_scores = cross_encoder.predict(cross_inp)
50
-
51
- # Sort results by the cross-encoder scores
52
- for idx in range(len(cross_scores)):
53
- hits[idx]['cross-score'] = cross_scores[idx]
54
-
55
- # Output of top-3 hits from re-ranker
56
- hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
57
 
58
- score='cross-score'
59
- df = pd.DataFrame([(hit[score],passages[hit['corpus_id']]) for hit in hits[0:int(top_k)]],columns=['Score','Text'])
60
- df['Score'] = round(df['Score'],2)
61
- df['Sentiment'] = df.Text.apply(gen_sentiment)
62
 
63
- def gen_annotated_text(df):
64
- '''Generate annotated text'''
 
 
65
 
66
- tag_list=[]
67
- for row in df.itertuples():
68
- label = row[3]
69
- text = row[2]
70
- if label == 'Positive':
71
- tag_list.append((text,label,'#8fce00'))
72
- elif label == 'Negative':
73
- tag_list.append((text,label,'#f44336'))
74
- else:
75
- tag_list.append((text,label,'#000000'))
76
 
77
- return tag_list
78
-
79
- text_annotations = gen_annotated_text(df)
80
-
81
- first, second = text_annotations[0], text_annotations[1]
82
-
83
-
84
- with st.expander(label='Best Search Query Result', expanded=True):
85
- annotated_text(first)
86
 
87
- with st.expander(label='Alternative Search Query Result'):
88
- annotated_text(second)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  else:
 
 
91
 
92
- st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
93
-
94
- else:
95
-
96
- st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
97
 
98
 
 
20
 
21
  window_size = st.sidebar.slider("Number of Sentences Generated in Search Response",min_value=1,max_value=7,value=3)
22
 
23
+ try:
24
+
25
+ if search_input:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
 
 
 
28
 
29
+ ## Save to a dataframe for ease of visualization
30
+ sen_df = st.session_state['sen_df']
31
+
32
+ passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
33
 
34
+ with st.spinner(
35
+ text=f"Loading {sbert_model_name} encoder..."
36
+ ):
37
+ sbert = load_sbert(sbert_model_name)
 
 
 
 
 
 
38
 
 
 
 
 
 
 
 
 
 
39
 
40
+ ##### Sematic Search #####
41
+ # Encode the query using the bi-encoder and find potentially relevant passages
42
+ corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
43
+ question_embedding = sbert.encode(search_input, convert_to_tensor=True)
44
+ question_embedding = question_embedding.cpu()
45
+ hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k,score_function=util.dot_score)
46
+ hits = hits[0] # Get the hits for the first query
47
+
48
+ ##### Re-Ranking #####
49
+ # Now, score all retrieved passages with the cross_encoder
50
+ cross_inp = [[search_input, passages[hit['corpus_id']]] for hit in hits]
51
+ cross_scores = cross_encoder.predict(cross_inp)
52
+
53
+ # Sort results by the cross-encoder scores
54
+ for idx in range(len(cross_scores)):
55
+ hits[idx]['cross-score'] = cross_scores[idx]
56
+
57
+ # Output of top-3 hits from re-ranker
58
+ hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
59
+
60
+ score='cross-score'
61
+ df = pd.DataFrame([(hit[score],passages[hit['corpus_id']]) for hit in hits[0:int(top_k)]],columns=['Score','Text'])
62
+ df['Score'] = round(df['Score'],2)
63
+ df['Sentiment'] = df.Text.apply(gen_sentiment)
64
+
65
+ def gen_annotated_text(df):
66
+ '''Generate annotated text'''
67
+
68
+ tag_list=[]
69
+ for row in df.itertuples():
70
+ label = row[3]
71
+ text = row[2]
72
+ if label == 'Positive':
73
+ tag_list.append((text,label,'#8fce00'))
74
+ elif label == 'Negative':
75
+ tag_list.append((text,label,'#f44336'))
76
+ else:
77
+ tag_list.append((text,label,'#000000'))
78
+
79
+ return tag_list
80
+
81
+ text_annotations = gen_annotated_text(df)
82
+
83
+ first, second = text_annotations[0], text_annotations[1]
84
+
85
+
86
+ with st.expander(label='Best Search Query Result', expanded=True):
87
+ annotated_text(first)
88
+
89
+ with st.expander(label='Alternative Search Query Result'):
90
+ annotated_text(second)
91
+
92
+ else:
93
+
94
+ st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
95
 
96
  else:
97
+
98
+ st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
99
 
100
+ except RuntimeError:
101
+
102
+ st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')
 
 
103
 
104