"""Streamlit front-end for semantic search over the Realvest product index.

The user's free-text query is embedded with OpenAI, matched against a Pinecone
index, and the matching products are enriched with rows fetched from
PostgreSQL before being rendered.
"""
from time import sleep

import streamlit as st
import openai
import pinecone

from postgres_db import query_postgresql_realvest

PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
INDEX_NAME = 'realvest-data-v2'
EMBEDDING_MODEL = "text-embedding-ada-002"  # OpenAI's best embeddings as of Apr 2023
MAX_LENGTH_DESC = 200  # number of description characters shown per result
MATCH_SCORE_THR = 0.0  # matches must score strictly above this to be shown
MAX_TRIALS = 5  # attempts per Pinecone call before giving up


def _retry_on_protocol_error(call, sleep_time, max_trials: int = MAX_TRIALS):
    """Run ``call(trial)`` until it succeeds or ``max_trials`` attempts fail.

    Only ``PineconeProtocolError`` is retried; any other exception propagates
    immediately.

    Args:
        call: callable taking the zero-based attempt number.
        sleep_time: seconds to wait between attempts.
        max_trials: maximum number of attempts.

    Returns:
        Whatever ``call`` returns on its first successful attempt.

    Raises:
        Exception: when every attempt raised ``PineconeProtocolError``; the
            last such error is attached as the cause.
    """
    last_err = None
    for trial in range(max_trials):
        try:
            return call(trial)
        except pinecone.core.exceptions.PineconeProtocolError as err:
            last_err = err
            print(f"Error, sleep! {err}")
            # No point in waiting after the final attempt: we are about to raise.
            if trial < max_trials - 1:
                sleep(sleep_time)
    raise Exception(f'max trials {max_trials} Exceeded!') from last_err


def test_pinecone(sleep_time: int = 1):
    """Return the index statistics, retrying on Pinecone protocol errors.

    Args:
        sleep_time: seconds to wait between attempts.

    Raises:
        Exception: when all attempts fail with ``PineconeProtocolError``.
    """
    def _describe(trial):
        stats = None
        print(f"BEFORE: trial: {trial}; stats: {stats}")
        stats = index.describe_index_stats()
        print(f"AFTER: trial: {trial}; stats: {stats}")
        return stats

    return _retry_on_protocol_error(_describe, sleep_time)


def query_pinecone(xq, top_k: int = 3, include_metadata: bool = True, sleep_time: int = 1):
    """Query the module-level Pinecone index, retrying on protocol errors.

    Args:
        xq: query embedding passed straight to ``index.query``.
        top_k: number of matches requested.
        include_metadata: whether matches should carry their metadata.
        sleep_time: seconds to wait between attempts.

    Returns:
        The response of ``index.query``.

    Raises:
        Exception: when all attempts fail with ``PineconeProtocolError``.
    """
    return _retry_on_protocol_error(
        lambda _trial: index.query(xq, top_k=top_k, include_metadata=include_metadata),
        sleep_time,
    )


def sort_dict_by_value(d: dict, ascending: bool = True):
    """
    Sort dictionary {k1: v1, k2: v2} by its value.
    The output is a sorted list of tuples [(k1, v1), (k2, v2)]
    """
    return sorted(d.items(), key=lambda x: x[1], reverse=not ascending)


def _fetch_products(pids):
    """Fetch the ``main_products`` rows for ``pids``, keyed by ``productid``.

    ``pids`` must be non-empty: an empty ``IN ()`` list is a syntax error in
    PostgreSQL.
    """
    # The ids are embedded as SQL string literals, so double any single quote
    # to keep the statement well-formed.
    # NOTE(review): prefer bound parameters if query_postgresql_realvest
    # supports them - its signature is not visible from this file.
    quoted = ["'" + str(pid).replace("'", "''") + "'" for pid in pids]
    sql = f"""
        SELECT productid, name, category, alternatename, url, logo, description
        FROM main_products
        WHERE productid in ({', '.join(quoted)});
    """
    rows = query_postgresql_realvest(sql)
    return {row['productid']: row for row in rows}


def _render_result(result, meta, score):
    """Render one matched product: logo on the left, details on the right."""
    col_icon, col_info = st.columns([1, 3])
    with col_icon:
        # Skip the image when the row has no logo value.
        if result["logo"]:
            st.image(result["logo"])
    with col_info:
        # NOTE(review): unsafe_allow_html=True suggests this text originally
        # carried HTML markup (e.g. line breaks) - confirm against the
        # upstream source. Values are interpolated unescaped.
        st.markdown(f"""match score: { round(100 * score, 2) }
**{result['name']}**
_Asking Price:_ {meta.get('asking_price', 'N/A')}
_Category:_ {meta.get('category', 'N/A')}
_Location:_ {meta.get('location', 'N/A')}
""", unsafe_allow_html=True)
        # A NULL description would otherwise fail on slicing.
        description = result['description'] or ''
        st.markdown(f"""**_Description:_** {description[:MAX_LENGTH_DESC]}...[more]({result['url']})
""")


def _run_search(text):
    """Embed ``text``, look up the nearest products and render them."""
    # ### call OpenAI text-embedding
    res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[text], api_key=OPENAI_API_KEY)
    xq = res['data'][0]['embedding']
    out = query_pinecone(xq, top_k=20, include_metadata=True)

    ### candidates
    metadata = {match['metadata']['product_id']: match['metadata'] for match in out['matches']}
    scores = {match['metadata']['product_id']: match['score'] for match in out['matches']}
    ranked = [
        item
        for item in sort_dict_by_value(scores, ascending=False)
        if item[1] > MATCH_SCORE_THR
    ]

    ### query pids
    # Only hit the database when there is something to look up.
    products = _fetch_products([pid for pid, _ in ranked]) if ranked else {}

    st.header("Results")
    st.divider()

    # Keep the Pinecone ranking, dropping ids the database does not know.
    shown = [(pid, score) for pid, score in ranked if pid in products]
    if not shown:
        st.info("No matching results found.")
        return

    # display matched results
    for pid, score in shown:
        _render_result(products[pid], metadata[pid], score)


# initialize connection to pinecone (get API key at app.pinecone.io)
pinecone.init(
    api_key=PINECONE_API_KEY,
    environment="us-central1-gcp"  # may be different, check at app.pinecone.io
)
index = pinecone.Index(INDEX_NAME)
stats = test_pinecone()
print(f"Pinecone DB stats: {stats}")

### Main
# st.set_page_config(layout="centered")
css = ''' '''
st.markdown(css, unsafe_allow_html=True)

# Create a text input field
query = st.text_input("What are you looking for?")

# Create a button
if st.button('Submit'):
    if query.strip():
        _run_search(query)
    else:
        # A blank query has nothing meaningful to embed.
        st.warning("Please enter a search query.")