# Spaces: Sleeping
# (Hugging Face Spaces status banner captured along with the file; not part of the program.)
from time import sleep | |
import streamlit as st | |
import openai | |
import pinecone | |
from postgres_db import query_postgresql_realvest | |
# Credentials are read from Streamlit's secrets store rather than hard-coded.
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]

# Name of the Pinecone index this app queries.
INDEX_NAME = 'realvest-data-v2'

# OpenAI's best embeddings as of Apr 2023
EMBEDDING_MODEL = "text-embedding-ada-002"

# Number of description characters shown per result before the "[more]" link.
MAX_LENGTH_DESC = 200

# A match is displayed only when its score is strictly greater than this value.
MATCH_SCORE_THR = 0.0
def test_pinecone(sleep_time: int = 1):
    """
    Check that the Pinecone index is reachable by fetching its stats.

    Calls ``index.describe_index_stats()`` on the module-level ``index``,
    retrying when the client raises ``PineconeProtocolError``.

    Args:
        sleep_time: Seconds to wait between two consecutive attempts.

    Returns:
        Whatever ``index.describe_index_stats()`` returns on the first
        successful attempt.

    Raises:
        Exception: If every one of the ``MAX_TRIALS`` attempts fails. The
            last protocol error is chained as the cause.
    """
    MAX_TRIALS = 5
    stats = None
    last_err = None
    for trial in range(MAX_TRIALS):
        try:
            print(f"BEFORE: trial: {trial}; stats: {stats}")
            stats = index.describe_index_stats()
            print(f"AFTER: trial: {trial}; stats: {stats}")
            return stats
        except pinecone.core.exceptions.PineconeProtocolError as err:
            last_err = err
            print(f"Error, sleep! {err}")
            # Only wait when another attempt follows; sleeping after the
            # final failure would just delay the error.
            if trial < MAX_TRIALS - 1:
                sleep(sleep_time)
    raise Exception(f'max trials {MAX_TRIALS} Exceeded!') from last_err
def query_pinecone(xq, top_k: int = 3, include_metadata: bool = True, sleep_time: int = 1):
    """
    Query the module-level Pinecone ``index``, retrying on protocol errors.

    Args:
        xq: Query vector, passed to ``index.query`` unchanged.
        top_k: Number of matches requested from the index.
        include_metadata: Forwarded to ``index.query``.
        sleep_time: Seconds to wait between two consecutive attempts.

    Returns:
        Whatever ``index.query`` returns on the first successful attempt.

    Raises:
        Exception: If every one of the ``MAX_TRIALS`` attempts fails. The
            last protocol error is chained as the cause.
    """
    MAX_TRIALS = 5
    last_err = None
    for trial in range(MAX_TRIALS):
        try:
            return index.query(xq, top_k=top_k, include_metadata=include_metadata)
        except pinecone.core.exceptions.PineconeProtocolError as err:
            last_err = err
            print(f"Error, sleep! {err}")
            # Only wait when another attempt follows; sleeping after the
            # final failure would just delay the error.
            if trial < MAX_TRIALS - 1:
                sleep(sleep_time)
    raise Exception(f'max trials {MAX_TRIALS} Exceeded!') from last_err
def sort_dict_by_value(d: dict, ascending: bool = True):
    """
    Order the entries of a dictionary by their values.

    Given {k1: v1, k2: v2}, returns a list of (key, value) tuples such as
    [(k1, v1), (k2, v2)], smallest value first when ``ascending`` is true
    and largest value first otherwise. Entries with equal values keep
    their original relative order.
    """
    def _value_of(pair):
        # Each pair is (key, value); order by the value.
        return pair[1]

    pairs = list(d.items())
    pairs.sort(key=_value_of, reverse=not ascending)
    return pairs
# Connect the Pinecone client (API keys are issued at app.pinecone.io).
pinecone.init(
    environment="us-central1-gcp",  # may be different, check at app.pinecone.io
    api_key=PINECONE_API_KEY,
)

# Module-level handle that test_pinecone() and query_pinecone() operate on.
index = pinecone.Index(INDEX_NAME)

# Probe the index once at startup; test_pinecone() raises if it stays unreachable.
stats = test_pinecone()
print(f"Pinecone DB stats: {stats}")
### Main
# st.set_page_config(layout="centered")
# Widen the main content container so the three-column result rows fit.
css = '''
<style>
section.main > div {max-width:70rem}
</style>
'''
st.markdown(css, unsafe_allow_html=True)

# Create a text input field
query = st.text_input("What are you looking for?")

# Create a button
submitted = st.button('Submit')

if submitted and not query.strip():
    # Do not send a blank string to the embeddings endpoint.
    st.warning("Please enter a search query.")
elif submitted:
    ### call OpenAI text-embedding
    res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
    xq = res['data'][0]['embedding']
    out = query_pinecone(xq, top_k=20, include_metadata=True)

    ### candidates, keyed by product id
    matches = out['matches']
    metadata = {match['metadata']['product_id']: match['metadata'] for match in matches}
    match_scores = {match['metadata']['product_id']: match['score'] for match in matches}

    # (pid, score) pairs, best score first, keeping only scores above the threshold.
    above_thr_sorted = [
        item
        for item in sort_dict_by_value(match_scores, ascending=False)
        if item[1] > MATCH_SCORE_THR
    ]

    ### query pids
    results = {}
    if above_thr_sorted:
        # Double any single quote so an id cannot terminate its SQL literal.
        # NOTE(review): a parameterized query would be preferable, but the
        # signature of query_postgresql_realvest is not visible here.
        pids_str = [
            "'" + str(pid).replace("'", "''") + "'"
            for pid, _ in above_thr_sorted
        ]
        sql_query = f"""
SELECT productid, name, category, alternatename, url, logo, description
FROM main_products
WHERE productid in ({', '.join(pids_str)});
"""
        rows = query_postgresql_realvest(sql_query)
        results = {row['productid']: row for row in rows}
    # With no candidates the query is skipped: "in ()" is not valid SQL.

    st.header("Results")
    st.divider()

    if not above_thr_sorted:
        st.info("No matching results found.")

    # display matched results
    for pid, score in above_thr_sorted:
        if pid not in results:
            # Present in Pinecone but not returned by the database query.
            continue
        result = results[pid]
        # Treat a missing description as empty so the slice below cannot fail.
        description = result['description'] or ''
        col_icon, col_info, col_compare = st.columns([2, 6, 1])
        with col_icon:
            # Skip the image when the row has no logo value.
            if result["logo"]:
                st.image(result["logo"])
        with col_info:
            st.markdown(f"""match score: { round(100 * score, 2) }
<br>
**{result['name']}**
<br>
_Asking Price:_ {metadata[pid].get('asking_price', 'N/A')}
<br>
_Category:_ {metadata[pid].get('category', 'N/A')}
<br>
_Location:_ {metadata[pid].get('location', 'N/A')}
""", unsafe_allow_html=True)
            st.markdown(f"""**_Description:_** {description[:MAX_LENGTH_DESC]}...[more]({result['url']})
""")
        with col_compare:
            st.checkbox('compare', key=pid)