Spaces:
Running
Running
connect to postgresql
Browse files- app.py +64 -6
- postgres_db.py +54 -0
app.py
CHANGED
@@ -1,19 +1,55 @@
|
|
|
|
1 |
import streamlit as st
|
2 |
import openai
|
3 |
import pinecone
|
|
|
4 |
|
5 |
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
|
6 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
7 |
INDEX_NAME = 'realvest-data-v2'
|
8 |
EMBEDDING_MODEL = "text-embedding-ada-002" # OpenAI's best embeddings as of Apr 2023
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
# initialize connection to pinecone (get API key at app.pinecone.io)
|
11 |
pinecone.init(
|
12 |
api_key=PINECONE_API_KEY,
|
13 |
environment="us-central1-gcp" # may be different, check at app.pinecone.io
|
14 |
)
|
15 |
index = pinecone.Index(INDEX_NAME)
|
16 |
-
stats =
|
17 |
print(f"Pinecone DB stats: {stats}")
|
18 |
|
19 |
### Main
|
@@ -26,10 +62,32 @@ if st.button('Submit'):
|
|
26 |
# ### call OpenAI text-embedding
|
27 |
res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
|
28 |
xq = res['data'][0]['embedding']
|
29 |
-
out = index.query(xq, top_k=3, include_metadata=True)
|
|
|
30 |
|
31 |
### display
|
32 |
-
print(f"{'*'*30}results #3: {out}")
|
33 |
-
st.write("Matched results")
|
34 |
-
for match in out['matches']:
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from time import sleep
|
2 |
import streamlit as st
|
3 |
import openai
|
4 |
import pinecone
|
5 |
+
from postgres_db import query_postgresql_realvest
|
6 |
|
7 |
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
|
8 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
9 |
INDEX_NAME = 'realvest-data-v2'
|
10 |
EMBEDDING_MODEL = "text-embedding-ada-002" # OpenAI's best embeddings as of Apr 2023
|
11 |
|
12 |
+
def test_pinecone(sleep_time: int=1):
|
13 |
+
MAX_TRIALS = 5
|
14 |
+
trial = 0
|
15 |
+
stats = None
|
16 |
+
while (stats is None) and (trial < MAX_TRIALS):
|
17 |
+
try:
|
18 |
+
print(f"BEFORE: trial: {trial}; stats: {stats}")
|
19 |
+
stats = index.describe_index_stats()
|
20 |
+
print(f"AFTER: trial: {trial}; stats: {stats}")
|
21 |
+
return stats
|
22 |
+
except pinecone.core.exceptions.PineconeProtocolError as err:
|
23 |
+
print(f"Error, sleep! {err}")
|
24 |
+
sleep(sleep_time)
|
25 |
+
trial = trial + 1
|
26 |
+
|
27 |
+
raise Exception(f'max trials {MAX_TRIALS} Exceeded!')
|
28 |
+
|
29 |
+
def query_pinecone(xq, top_k: int=3, include_metadata: bool=True, sleep_time: int=1):
|
30 |
+
MAX_TRIALS = 5
|
31 |
+
trial = 0
|
32 |
+
out = None
|
33 |
+
while (out is None) and (trial < MAX_TRIALS):
|
34 |
+
try:
|
35 |
+
# print(f"BEFORE: trial: {trial}; stats: {out}")
|
36 |
+
out = index.query(xq, top_k=top_k, include_metadata=include_metadata)
|
37 |
+
# print(f"AFTER: trial: {trial}; stats: {out}")
|
38 |
+
return out
|
39 |
+
except pinecone.core.exceptions.PineconeProtocolError as err:
|
40 |
+
print(f"Error, sleep! {err}")
|
41 |
+
sleep(sleep_time)
|
42 |
+
trial = trial + 1
|
43 |
+
|
44 |
+
raise Exception(f'max trials {MAX_TRIALS} Exceeded!')
|
45 |
+
|
46 |
# initialize connection to pinecone (get API key at app.pinecone.io)
|
47 |
pinecone.init(
|
48 |
api_key=PINECONE_API_KEY,
|
49 |
environment="us-central1-gcp" # may be different, check at app.pinecone.io
|
50 |
)
|
51 |
index = pinecone.Index(INDEX_NAME)
|
52 |
+
stats = test_pinecone()
|
53 |
print(f"Pinecone DB stats: {stats}")
|
54 |
|
55 |
### Main
|
|
|
62 |
# ### call OpenAI text-embedding
|
63 |
res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
|
64 |
xq = res['data'][0]['embedding']
|
65 |
+
# out = index.query(xq, top_k=3, include_metadata=True)
|
66 |
+
out = query_pinecone(xq, top_k=3, include_metadata=True)
|
67 |
|
68 |
### display
|
69 |
+
# print(f"{'*'*30}results #3: {out}")
|
70 |
+
# st.write("Matched results")
|
71 |
+
# for match in out['matches']:
|
72 |
+
# st.write( match['id'] )
|
73 |
+
|
74 |
+
### candidates
|
75 |
+
pids = [
|
76 |
+
match['metadata']['product_id']
|
77 |
+
for match in out['matches']
|
78 |
+
]
|
79 |
+
|
80 |
+
### query pids
|
81 |
+
pids_str = [f"'{pid}'"for pid in pids]
|
82 |
+
query = f"""
|
83 |
+
SELECT productid, name, category, alternatename, url, logo, description
|
84 |
+
FROM main_products
|
85 |
+
WHERE productid in ({', '.join(pids_str)});
|
86 |
+
"""
|
87 |
+
|
88 |
+
results = query_postgresql_realvest(query)
|
89 |
+
print(results)
|
90 |
+
|
91 |
+
for result in results:
|
92 |
+
st.write("---")
|
93 |
+
st.json(result)
|
postgres_db.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import psycopg2
|
2 |
+
|
3 |
+
def query_postgresql(
|
4 |
+
query: str,
|
5 |
+
database: str,
|
6 |
+
user: str,
|
7 |
+
password: str,
|
8 |
+
host: str,
|
9 |
+
port: str,
|
10 |
+
named_columns: bool=True
|
11 |
+
):
|
12 |
+
|
13 |
+
conn = psycopg2.connect(
|
14 |
+
database=database,
|
15 |
+
user=user,
|
16 |
+
password=password,
|
17 |
+
host=host,
|
18 |
+
port=port
|
19 |
+
)
|
20 |
+
|
21 |
+
cur = conn.cursor()
|
22 |
+
cur.execute(query)
|
23 |
+
rows = cur.fetchall()
|
24 |
+
|
25 |
+
if named_columns:
|
26 |
+
column_names = [desc[0] for desc in cur.description]
|
27 |
+
return [ dict(zip(column_names, r)) for r in rows ]
|
28 |
+
|
29 |
+
return rows
|
30 |
+
|
31 |
+
def query_postgresql_realvest(query: str, named_columns: bool=True):
|
32 |
+
import streamlit as st
|
33 |
+
POSTGRESQL_REALVEST_USER = st.secrets["POSTGRESQL_REALVEST_USER"]
|
34 |
+
POSTGRESQL_REALVEST_PSWD = st.secrets["POSTGRESQL_REALVEST_PSWD"]
|
35 |
+
|
36 |
+
return query_postgresql(
|
37 |
+
query,
|
38 |
+
database="realvest",
|
39 |
+
user=POSTGRESQL_REALVEST_USER,
|
40 |
+
password=POSTGRESQL_REALVEST_PSWD,
|
41 |
+
host="realvest.cdb5lmqrlgu5.us-east-2.rds.amazonaws.com",
|
42 |
+
port="5432",
|
43 |
+
named_columns=named_columns
|
44 |
+
)
|
45 |
+
|
46 |
+
if __name__ == "__main__":
|
47 |
+
query = """
|
48 |
+
SELECT *
|
49 |
+
FROM main_products
|
50 |
+
WHERE productid in ('2093075');
|
51 |
+
"""
|
52 |
+
|
53 |
+
results = query_postgresql_realvest(query)
|
54 |
+
print(results)
|