Spaces:
Running
Running
from time import sleep | |
import streamlit as st | |
import openai | |
import pinecone | |
from postgres_db import query_postgresql_realvest | |
# Credentials are read from Streamlit's secrets store.
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]

# Name of the Pinecone index this app queries.
INDEX_NAME = "realvest-data-v2"
# OpenAI embedding model; noted by the author as OpenAI's best as of Apr 2023.
EMBEDDING_MODEL = "text-embedding-ada-002"
def test_pinecone(sleep_time: int = 1):
    """Fetch Pinecone index statistics, retrying on protocol errors.

    Calls ``describe_index_stats()`` on the module-level ``index`` up to
    ``MAX_TRIALS`` times. Only ``PineconeProtocolError`` triggers a retry;
    any other exception propagates immediately.

    Args:
        sleep_time: Seconds to wait between a failed attempt and the next one.

    Returns:
        Whatever ``index.describe_index_stats()`` returns on the first
        successful call.

    Raises:
        Exception: If every attempt fails with ``PineconeProtocolError``.
            The last protocol error is attached as ``__cause__``.
    """
    MAX_TRIALS = 5
    stats = None
    last_err = None
    for trial in range(MAX_TRIALS):
        try:
            print(f"BEFORE: trial: {trial}; stats: {stats}")
            stats = index.describe_index_stats()
            print(f"AFTER: trial: {trial}; stats: {stats}")
            return stats
        except pinecone.core.exceptions.PineconeProtocolError as err:
            last_err = err
            print(f"Error, sleep! {err}")
            # Waiting after the final attempt would only delay the failure.
            if trial < MAX_TRIALS - 1:
                sleep(sleep_time)
    # Chain the last protocol error so the root cause stays in the traceback.
    raise Exception(f'max trials {MAX_TRIALS} Exceeded!') from last_err
def query_pinecone(xq, top_k: int = 3, include_metadata: bool = True, sleep_time: int = 1):
    """Query the Pinecone index, retrying on protocol errors.

    Forwards the arguments to ``query()`` on the module-level ``index`` and
    retries up to ``MAX_TRIALS`` times. Only ``PineconeProtocolError``
    triggers a retry; any other exception propagates immediately.

    Args:
        xq: Query vector, passed unchanged as the first argument of
            ``index.query``.
        top_k: Forwarded to ``index.query`` as ``top_k``.
        include_metadata: Forwarded to ``index.query`` as
            ``include_metadata``.
        sleep_time: Seconds to wait between a failed attempt and the next one.

    Returns:
        Whatever ``index.query(...)`` returns on the first successful call.

    Raises:
        Exception: If every attempt fails with ``PineconeProtocolError``.
            The last protocol error is attached as ``__cause__``.
    """
    MAX_TRIALS = 5
    last_err = None
    for trial in range(MAX_TRIALS):
        try:
            return index.query(xq, top_k=top_k, include_metadata=include_metadata)
        except pinecone.core.exceptions.PineconeProtocolError as err:
            last_err = err
            print(f"Error, sleep! {err}")
            # Waiting after the final attempt would only delay the failure.
            if trial < MAX_TRIALS - 1:
                sleep(sleep_time)
    # Chain the last protocol error so the root cause stays in the traceback.
    raise Exception(f'max trials {MAX_TRIALS} Exceeded!') from last_err
# Connect to Pinecone (the API key is issued at app.pinecone.io).
# The environment may differ per project; check it at app.pinecone.io.
pinecone.init(api_key=PINECONE_API_KEY, environment="us-central1-gcp")

# Module-level handle used by test_pinecone() and query_pinecone().
index = pinecone.Index(INDEX_NAME)

# Startup check: fetch the index statistics (with retries) and log them.
stats = test_pinecone()
print(f"Pinecone DB stats: {stats}")
### Main
# Free-text search box; its value is embedded and matched against Pinecone.
query = st.text_input("What are you looking for?")

submitted = st.button('Submit')
if submitted and not query.strip():
    # The embeddings endpoint rejects empty input, so stop before calling it.
    st.warning("Please enter a search query.")
elif submitted:
    # Embed the query text with OpenAI.
    res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
    xq = res['data'][0]['embedding']

    # Nearest-neighbour lookup in Pinecone (retries on protocol errors).
    out = query_pinecone(xq, top_k=3, include_metadata=True)

    # Candidate product ids taken from the metadata of each match.
    pids = [match['metadata']['product_id'] for match in out['matches']]

    if pids:
        # Double embedded single quotes so an id cannot break out of its SQL
        # string literal. NOTE(review): if query_postgresql_realvest accepts
        # bound parameters, prefer those over string building.
        quoted_pids = ["'" + str(pid).replace("'", "''") + "'" for pid in pids]
        sql = f"""
        SELECT productid, name, category, alternatename, url, logo, description
        FROM main_products
        WHERE productid in ({', '.join(quoted_pids)});
        """
        results = query_postgresql_realvest(sql)
    else:
        # An empty "IN ()" list is a SQL syntax error, so skip the database.
        results = []
        st.write("No matching products found.")

    print(results)
    for result in results:
        st.write("---")
        st.json(result)