|
from sentence_transformers import SentenceTransformer |
|
import os |
|
from pymilvus import connections, utility, Collection, DataType, FieldSchema, CollectionSchema |
|
import streamlit as st |
|
|
|
|
|
class TextVectorizer:
    """Extract sentence embeddings with a sentence-transformers style encoder.

    The encoder can be injected through ``model``. When it is omitted, the
    module-level ``sent_model`` is used, looked up at call time. This keeps
    ``TextVectorizer()`` working even when it is constructed before
    ``sent_model`` is defined.
    """

    def __init__(self, model=None):
        """Remember the encoder to use.

        Args:
            model: Any object exposing ``encode(x)`` (e.g. a
                ``SentenceTransformer``). ``None`` means "use the module-level
                ``sent_model``".
        """
        self.model = model

    def vectorize(self, x):
        """Return ``encode(x)`` from the configured encoder.

        Args:
            x: Input passed straight to the encoder's ``encode`` method.

        Returns:
            Whatever the encoder's ``encode`` returns for ``x``.
        """
        # Resolve the fallback lazily: the global may be defined after this
        # instance was created.
        encoder = self.model if self.model is not None else sent_model
        return encoder.encode(x)
|
|
|
def get_milvus_collection():
    """Connect to Milvus and return the configured collection after loading it.

    Settings are read from environment variables:
        URI: endpoint passed to ``connections.connect``.
        TOKEN: auth token passed to ``connections.connect``.
        COLLECTION_NAME: name of the collection to open (required).

    Returns:
        The ``Collection`` object, after ``load()`` has been called on it.

    Raises:
        RuntimeError: if ``COLLECTION_NAME`` is unset or empty.
    """
    # Validate before connecting so a misconfiguration fails fast with a clear
    # message instead of passing None to Collection().
    collection_name = os.environ.get("COLLECTION_NAME")
    if not collection_name:
        raise RuntimeError("COLLECTION_NAME environment variable is not set")

    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    print("Connected to DB")

    collection = Collection(name=collection_name)
    # Milvus requires a collection to be loaded before it can be searched.
    collection.load()
    return collection
|
|
|
def find_similar_news(text: str, top_n: int=5):
    """Search Milvus for articles similar to ``text`` and render them as HTML.

    The query is embedded with the module-level ``vectorizer``. The module-level
    ``collection`` is then searched on the ``article_embed`` field using the L2
    metric.

    Args:
        text: Query headline.
        top_n: Maximum number of hits requested (the search ``limit``).

    Returns:
        An HTML fragment: an ``<h4>`` heading followed by an ``<ol>`` with one
        ``<li><b>description</b> (<i>category</i>)</li>`` item per hit.
        Description and category are HTML-escaped.
    """
    import html  # local import: keeps this fix independent of the file header

    search_params = {"metric_type": "L2"}
    search_vec = vectorizer.vectorize(text)

    result = collection.search(
        [search_vec],
        anns_field='article_embed',
        param=search_params,
        limit=top_n,
        # NOTE(review): presumably selects a relaxed consistency level — confirm
        # against the pymilvus search docs.
        guarantee_timestamp=1,
        output_fields=['article_desc', 'article_category'],
    )

    items = []
    for hits in result:
        for hit in hits:
            # main() renders this HTML with unsafe_allow_html=True, so stored
            # text must be escaped to keep '&', '<' and '>' from becoming markup.
            # str() preserves the original f-string rendering of None as "None".
            desc = html.escape(str(hit.entity.get('article_desc')))
            category = html.escape(str(hit.entity.get('article_category')))
            items.append(f'<li><b>{desc}</b> (<i>{category}</i>)</li>')

    similar_txt = ''.join(items)
    return f"<h4>Similar News Articles</h4><ol>{similar_txt}</ol>"
|
|
|
|
|
|
|
# Streamlit re-executes this whole script on every widget interaction.
# st.cache_resource keeps the Milvus connection and the (large) sentence model
# alive across reruns instead of rebuilding them each time.
# NOTE(review): st.cache_resource needs a reasonably recent Streamlit (1.18+).


@st.cache_resource
def _load_milvus_collection():
    """Return the loaded Milvus collection, created once and then cached."""
    return get_milvus_collection()


@st.cache_resource
def _load_sentence_model():
    """Return the all-mpnet-base-v2 encoder, loaded once and then cached."""
    return SentenceTransformer('all-mpnet-base-v2')


vectorizer = TextVectorizer()

collection = _load_milvus_collection()

sent_model = _load_sentence_model()
|
|
|
def main():
    """Render the Streamlit page and show similar headlines when submitted.

    Displays a title and a description, then a text area for the headline and
    a slider choosing how many results to show (1-100, default 10). On
    "Submit" it calls ``find_similar_news`` and renders the returned HTML.
    """
    st.markdown("<h3>Find Similar News With Sentence Transformers (all-mpnet-base-v2)</h3>", unsafe_allow_html=True)

    desc = '''<p style="font-size: 13px;">
Embeddings of 300,000 news headlines are stored in Milvus vector database, used as a feature store.
Embeddings of the input headline are computed using sentence transformers (all-mpnet-base-v2).
Similar news headlines are retrieved from the vector database using Euclidean distance as similarity metric.
<span style="color: red;">This method (all-mpnet-base-v2) has the best performance compared to multi-qa-distilbert-cos-v1 fine-tuned using TSDAE
and extracting embeddings from fine-tuned DistilBERT classifier.</span>
</p>
'''
    st.markdown(desc, unsafe_allow_html=True)

    # Recent Streamlit releases reject st.text_area heights below 68px.
    news_txt = st.text_area("Paste the headline of a news article:", "", height=68)
    top_n = st.slider('Select number of similar articles to display', 1, 100, 10)

    if st.button("Submit"):
        # Without this guard an empty box would still trigger a vector search
        # on the embedding of an empty string.
        if not news_txt.strip():
            st.warning("Please paste a news headline before submitting.")
            return
        result = find_similar_news(news_txt, top_n)
        st.markdown(result, unsafe_allow_html=True)


# `streamlit run` executes the script as __main__, so the page is built here.
if __name__ == "__main__":
    main()