File size: 1,657 Bytes
e9b0055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2534036
 
 
8c5c01a
 
 
2534036
e9b0055
 
2534036
e9b0055
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from classes import classes
import numpy as np
from sentence_transformers import SentenceTransformer, util
import streamlit as st


# Simple sentence transformer
model_checkpoint = 'sentence-transformers/paraphrase-distilroberta-base-v1'

# Predefined messages, kept as an array so they can be indexed with an index array later
classes_text = np.array(classes)

# Streamlit re-executes this script on every widget interaction, so loading the
# model and encoding the predefined messages is cached instead of being redone
# on each rerun. `st.cache_resource` is the current caching API; fall back to
# the legacy `st.cache` on Streamlit releases that do not provide it.
if hasattr(st, 'cache_resource'):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def _load_model_and_class_embeddings(checkpoint: str):
    """Load the sentence transformer and embed the predefined messages.

    Args:
        checkpoint: Name of the sentence-transformers checkpoint to load.

    Returns:
        A ``(model, embeddings)`` tuple, where ``embeddings`` is the result of
        encoding ``classes_text`` with the loaded model.
    """
    loaded_model = SentenceTransformer(checkpoint)
    embeddings = loaded_model.encode(classes_text, convert_to_numpy=True)
    return loaded_model, embeddings


# Predefined messages' embeddings (and the model that produced them)
model, classes_embeddings = _load_model_and_class_embeddings(model_checkpoint)
# Sanity check: one embedding per predefined message
assert classes_embeddings.shape[0] == len(classes)

# Function to compare the embedding of the human chat/text message with the embeddings of the 
# predefined messages
def convert(sentence_embedding: np.ndarray, class_embeddings: np.ndarray, top_n: int = 5) -> np.ndarray:
    """Return the indices of the predefined messages most similar to a message.

    Args:
        sentence_embedding: Embedding of the human chat/text message.
        class_embeddings: Embeddings of the predefined messages.
        top_n: Maximum number of indices to return.

    Returns:
        Indices into the flattened cosine-similarity scores, ordered from most
        to least similar. At most ``top_n`` indices are returned.

    Raises:
        ValueError: If ``top_n`` is negative.
    """
    # A negative value would otherwise be interpreted as a slice bound and
    # silently drop entries from the end of the ranking.
    if top_n < 0:
        raise ValueError(f"top_n must be non-negative, got {top_n}")
    # Nothing requested: skip the similarity computation entirely.
    if top_n == 0:
        return np.empty(0, dtype=np.intp)

    similarities = np.array(util.cos_sim(sentence_embedding, class_embeddings)).reshape(-1,)
    # argsort is ascending, so reverse it to put the highest similarity first.
    ranked_indices = np.argsort(similarities)[::-1]

    return ranked_indices[:top_n]

# Simple title and description for the app
st.title('JHG Chat Message Converter')
st.write('Converts human chat/text messages into predefined chat messages via a sentence transformer')

# Number of predictions to display
n_preds = st.slider("Number of predictions to display:", min_value=1, max_value=10, step=1)

# Text box to enter a chat/text message
text = st.text_area('Enter chat message')

# Only predict for real input: a whitespace-only message is truthy but carries
# no content, so it is skipped rather than embedded and matched.
if text and text.strip() and n_preds:
    # Use the sentence transformer and "convert" function to display predicted, predefined messages
    text_embedding = model.encode(text, convert_to_numpy=True)
    indices = convert(text_embedding, classes_embeddings, top_n=n_preds)
    predicted_classes = classes_text[indices]

    # Show the predictions, best match first
    for converted_message in predicted_classes:
        st.write(converted_message)