File size: 3,521 Bytes
411678e
31b6e92
411678e
 
 
3b52176
 
d98d50d
 
 
 
3b52176
cc23803
3b52176
 
 
 
 
e9eb13c
d98d50d
3b52176
d98d50d
ee6d004
741aa8b
 
 
 
 
 
 
 
 
 
 
 
ee6d004
741aa8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d98d50d
741aa8b
d98d50d
 
 
 
 
 
 
 
 
 
 
741aa8b
d98d50d
 
741aa8b
 
d98d50d
 
 
741aa8b
 
d98d50d
 
741aa8b
d98d50d
 
741aa8b
 
 
 
d98d50d
8eb51fc
 
 
 
ee6d004
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import streamlit as st
from functions import *

# Page chrome: browser-tab title and icon, sidebar header, and page heading.
# page_icon must be a real emoji (or an image); U+1F50E is the right-tilted
# magnifying glass.
st.set_page_config(page_title="Earnings Semantic Search", page_icon="🔎")
st.sidebar.header("Semantic Search")
st.markdown("## Earnings Semantic Search with SBert")

def gen_sentiment(text):
    """Return the sentiment label predicted for *text*.

    Runs ``sent_pipe`` (brought in via ``from functions import *``) on *text*
    and returns the ``'label'`` entry of its first prediction.
    """
    top_prediction = sent_pipe(text)[0]
    return top_prediction['label']

# Free-text query to run against the loaded earnings-call passages.
search_input = st.text_input(
    label='Enter Your Search Query',
    value="What challenges did the business face?",
    key='search',
)

# Sidebar: how many re-ranked hits to retrieve (1-5).
top_k = st.sidebar.slider(
    "Number of Top Hits Generated",
    min_value=1,
    max_value=5,
    value=2,
)

# Sidebar: passed as window_size to preprocess_plain_text; presumably the
# number of sentences grouped into each searchable passage -- confirm there.
window_size = st.sidebar.slider(
    "Number of Sentences Generated in Search Response",
    min_value=1,
    max_value=5,
    value=3,
)

def gen_annotated_text(df):
    """Convert scored search results into ``annotated_text`` tuples.

    Args:
        df: DataFrame with ``Text`` and ``Sentiment`` columns, one row per hit.

    Returns:
        List of ``(text, label, background_colour)`` tuples in row order:
        green for ``'Positive'``, red for ``'Negative'``, pale yellow for any
        other label.
    """
    colours = {'Positive': '#8fce00', 'Negative': '#f44336'}
    return [
        (row.Text, row.Sentiment, colours.get(row.Sentiment, '#fff2cc'))
        for row in df.itertuples(index=False)
    ]


def _search_and_rerank(query, passages, top_k):
    """Retrieve candidates with the bi-encoder, then re-rank them with the cross-encoder.

    Args:
        query: The user's search text.
        passages: List of passage strings to search over.
        top_k: Maximum number of candidates retrieved (and therefore returned).

    Returns:
        The ``util.semantic_search`` hit dicts for *query*, each extended with a
        ``'cross-score'`` entry, sorted by that score (best first). Returns an
        empty list when there are no passages.
    """
    if not passages:
        return []

    # Bi-encoder retrieval: embed all passages and the query, keep the top_k
    # passages by dot-product score.
    corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
    question_embedding = sbert.encode(query, convert_to_tensor=True).cpu()
    hits = util.semantic_search(
        question_embedding, corpus_embeddings, top_k=top_k, score_function=util.dot_score
    )[0]  # only one query was encoded, so take its hit list

    # Re-rank: score every (query, passage) pair jointly with the cross-encoder.
    cross_scores = cross_encoder.predict([[query, passages[hit['corpus_id']]] for hit in hits])
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score
    return sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)


def _show_results(hits, passages):
    """Render the best re-ranked hit and, if present, the runner-up.

    Each hit is displayed with its sentiment label and colour; *hits* must
    already be sorted best-first.
    """
    if not hits:
        st.write('No matching passages were found for this query.')
        return

    df = pd.DataFrame(
        [(hit['cross-score'], passages[hit['corpus_id']]) for hit in hits],
        columns=['Score', 'Text'],
    )
    df['Score'] = df['Score'].round(2)
    df['Sentiment'] = df.Text.apply(gen_sentiment)

    text_annotations = gen_annotated_text(df)

    with st.expander(label='Best Search Query Result', expanded=True):
        annotated_text(text_annotations[0])

    # top_k can be 1 (its slider minimum), or the transcript can yield a single
    # passage, so only show an alternative when one actually exists.
    if len(text_annotations) > 1:
        with st.expander(label='Alternative Search Query Result'):
            annotated_text(text_annotations[1])


if search_input:

    # NOTE(review): "sen_df" is required by this guard but never read on this page.
    if "sen_df" in st.session_state and "earnings_passages" in st.session_state:

        passages = preprocess_plain_text(st.session_state['earnings_passages'], window_size=window_size)
        hits = _search_and_rerank(search_input, passages, top_k)
        _show_results(hits, passages)

    else:

        st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')

else:

    st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')