# Streamlit semantic-search page.
#
# Flow: embed the user's query with OpenAI, look up the nearest vectors in a
# Pinecone index, then fetch the matching product rows from PostgreSQL and
# render them.
from time import sleep

import openai
import pinecone
import streamlit as st

from postgres_db import query_postgresql_realvest

PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
INDEX_NAME = 'realvest-data-v2'
EMBEDDING_MODEL = "text-embedding-ada-002"  # OpenAI's best embeddings as of Apr 2023
MAX_TRIALS = 5  # attempts per Pinecone call before giving up


def test_pinecone(sleep_time: int = 1):
    """Fetch index statistics from Pinecone, retrying on protocol errors.

    Args:
        sleep_time: Seconds to wait after each failed attempt.

    Returns:
        Whatever ``index.describe_index_stats()`` returns on the first
        successful attempt.

    Raises:
        Exception: If all ``MAX_TRIALS`` attempts fail with
            ``PineconeProtocolError``; the last such error is chained as
            the cause.
    """
    stats = None
    last_err = None
    for trial in range(MAX_TRIALS):
        try:
            print(f"BEFORE: trial: {trial}; stats: {stats}")
            stats = index.describe_index_stats()
            print(f"AFTER: trial: {trial}; stats: {stats}")
            return stats
        except pinecone.core.exceptions.PineconeProtocolError as err:
            last_err = err
            print(f"Error, sleep! {err}")
            sleep(sleep_time)
    raise Exception(f'max trials {MAX_TRIALS} Exceeded!') from last_err


def query_pinecone(xq, top_k: int = 3, include_metadata: bool = True,
                   sleep_time: int = 1):
    """Query the Pinecone index, retrying on protocol errors.

    Args:
        xq: Query vector passed straight through to ``index.query``.
        top_k: Number of nearest matches to request.
        include_metadata: Whether to ask Pinecone to return match metadata.
        sleep_time: Seconds to wait after each failed attempt.

    Returns:
        Whatever ``index.query(...)`` returns on the first successful attempt.

    Raises:
        Exception: If all ``MAX_TRIALS`` attempts fail with
            ``PineconeProtocolError``; the last such error is chained as
            the cause.
    """
    last_err = None
    for _ in range(MAX_TRIALS):
        try:
            return index.query(xq, top_k=top_k, include_metadata=include_metadata)
        except pinecone.core.exceptions.PineconeProtocolError as err:
            last_err = err
            print(f"Error, sleep! {err}")
            sleep(sleep_time)
    raise Exception(f'max trials {MAX_TRIALS} Exceeded!') from last_err


def _sql_quote(value) -> str:
    """Return *value* as a single-quoted SQL string literal.

    Embedded single quotes are doubled so the value cannot terminate the
    literal early.
    """
    escaped = str(value).replace("'", "''")
    return f"'{escaped}'"


def _run_search(user_query: str) -> None:
    """Embed *user_query*, find nearby products and render them."""
    if not user_query.strip():
        # Nothing to embed; presumably the embeddings endpoint rejects empty
        # input, so stop before making any remote call.
        st.warning("Please enter a search query.")
        return

    ### call OpenAI text-embedding
    res = openai.Embedding.create(
        model=EMBEDDING_MODEL, input=[user_query], api_key=OPENAI_API_KEY
    )
    xq = res['data'][0]['embedding']

    out = query_pinecone(xq, top_k=3, include_metadata=True)

    ### candidates
    pids = [match['metadata']['product_id'] for match in out['matches']]
    if not pids:
        # An empty list would render as "IN ()", which is invalid SQL.
        st.info("No matching results found.")
        return

    ### query pids
    # NOTE(review): query_postgresql_realvest is only seen here taking a single
    # SQL string, so the ids are escaped and inlined. Switch to bound
    # parameters if that helper supports them.
    pids_str = [_sql_quote(pid) for pid in pids]
    sql = f"""
    SELECT productid, name, category, alternatename, url, logo, description
    FROM main_products
    WHERE productid in ({', '.join(pids_str)});
    """
    results = query_postgresql_realvest(sql)
    print(results)

    for result in results:
        st.write("---")
        st.json(result)


# initialize connection to pinecone (get API key at app.pinecone.io)
pinecone.init(
    api_key=PINECONE_API_KEY,
    environment="us-central1-gcp"  # may be different, check at app.pinecone.io
)
index = pinecone.Index(INDEX_NAME)
stats = test_pinecone()
print(f"Pinecone DB stats: {stats}")

### Main
# Create a text input field
query = st.text_input("What are you looking for?")

# Create a button
if st.button('Submit'):
    _run_search(query)