Spaces:
Running
Running
sort search results by match score
Browse files
app.py
CHANGED
@@ -8,6 +8,8 @@ PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
|
|
8 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
9 |
INDEX_NAME = 'realvest-data-v2'
|
10 |
EMBEDDING_MODEL = "text-embedding-ada-002" # OpenAI's best embeddings as of Apr 2023
|
|
|
|
|
11 |
|
12 |
def test_pinecone(sleep_time: int=1):
|
13 |
MAX_TRIALS = 5
|
@@ -43,6 +45,13 @@ def query_pinecone(xq, top_k: int=3, include_metadata: bool=True, sleep_time: in
|
|
43 |
|
44 |
raise Exception(f'max trials {MAX_TRIALS} Exceeded!')
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
# initialize connection to pinecone (get API key at app.pinecone.io)
|
47 |
pinecone.init(
|
48 |
api_key=PINECONE_API_KEY,
|
@@ -53,6 +62,14 @@ stats = test_pinecone()
|
|
53 |
print(f"Pinecone DB stats: {stats}")
|
54 |
|
55 |
### Main
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
# Create a text input field
|
57 |
query = st.text_input("What are you looking for?")
|
58 |
|
@@ -62,23 +79,20 @@ if st.button('Submit'):
|
|
62 |
# ### call OpenAI text-embedding
|
63 |
res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
|
64 |
xq = res['data'][0]['embedding']
|
65 |
-
|
66 |
-
out = query_pinecone(xq, top_k=3, include_metadata=True)
|
67 |
-
|
68 |
-
### display
|
69 |
-
# print(f"{'*'*30}results #3: {out}")
|
70 |
-
# st.write("Matched results")
|
71 |
-
# for match in out['matches']:
|
72 |
-
# st.write( match['id'] )
|
73 |
|
74 |
### candidates
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
78 |
]
|
|
|
79 |
|
80 |
### query pids
|
81 |
-
pids_str = [f"'{pid}'"for pid in
|
82 |
query = f"""
|
83 |
SELECT productid, name, category, alternatename, url, logo, description
|
84 |
FROM main_products
|
@@ -86,17 +100,37 @@ if st.button('Submit'):
|
|
86 |
"""
|
87 |
|
88 |
results = query_postgresql_realvest(query)
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
|
92 |
-
col_icon, col_info = st.columns([1,
|
93 |
|
94 |
with col_icon:
|
95 |
st.image(result["logo"])
|
96 |
|
97 |
with col_info:
|
98 |
-
st.
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
-
|
102 |
-
|
|
|
8 |
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
|
9 |
INDEX_NAME = 'realvest-data-v2'
|
10 |
EMBEDDING_MODEL = "text-embedding-ada-002" # OpenAI's best embeddings as of Apr 2023
|
11 |
+
MAX_LENGTH_DESC = 200
|
12 |
+
MATCH_SCORE_THR = 0.0
|
13 |
|
14 |
def test_pinecone(sleep_time: int=1):
|
15 |
MAX_TRIALS = 5
|
|
|
45 |
|
46 |
raise Exception(f'max trials {MAX_TRIALS} Exceeded!')
|
47 |
|
48 |
+
def sort_dict_by_value(d: dict, ascending: bool = True) -> list:
    """
    Sort dictionary {k1: v1, k2: v2} by its value.

    Args:
        d: mapping whose values are mutually comparable.
        ascending: when True (default) the smallest value comes first;
            when False the largest value comes first.

    Returns:
        A new list of (key, value) tuples, e.g. [(k1, v1), (k2, v2)].
        The input dictionary is not modified. Entries with equal values
        keep their original insertion order in both directions, because
        the underlying sort is stable.
    """
    # `reverse` is the inverse of `ascending`: descending order is requested
    # by passing ascending=False.
    return sorted(d.items(), key=lambda item: item[1], reverse=not ascending)
|
54 |
+
|
55 |
# initialize connection to pinecone (get API key at app.pinecone.io)
|
56 |
pinecone.init(
|
57 |
api_key=PINECONE_API_KEY,
|
|
|
62 |
print(f"Pinecone DB stats: {stats}")
|
63 |
|
64 |
### Main
|
65 |
+
# st.set_page_config(layout="centered")
|
66 |
+
css='''
|
67 |
+
<style>
|
68 |
+
section.main > div {max-width:70rem}
|
69 |
+
</style>
|
70 |
+
'''
|
71 |
+
st.markdown(css, unsafe_allow_html=True)
|
72 |
+
|
73 |
# Create a text input field
|
74 |
query = st.text_input("What are you looking for?")
|
75 |
|
|
|
79 |
# ### call OpenAI text-embedding
|
80 |
res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
|
81 |
xq = res['data'][0]['embedding']
|
82 |
+
out = query_pinecone(xq, top_k=20, include_metadata=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
### candidates
|
85 |
+
metadata = {match['metadata']['product_id']: match['metadata'] for match in out['matches']}
|
86 |
+
match_score = {match['metadata']['product_id']: match['score'] for match in out['matches']}
|
87 |
+
above_thr_sorted = [
|
88 |
+
item
|
89 |
+
for item in sort_dict_by_value(match_score, ascending=False)
|
90 |
+
if item[1] > MATCH_SCORE_THR
|
91 |
]
|
92 |
+
pids = metadata.keys()
|
93 |
|
94 |
### query pids
|
95 |
+
pids_str = [f"'{pid}'" for pid, _ in above_thr_sorted]
|
96 |
query = f"""
|
97 |
SELECT productid, name, category, alternatename, url, logo, description
|
98 |
FROM main_products
|
|
|
100 |
"""
|
101 |
|
102 |
results = query_postgresql_realvest(query)
|
103 |
+
results = {
|
104 |
+
result['productid']: result
|
105 |
+
for result in results
|
106 |
+
}
|
107 |
+
|
108 |
+
st.header("Results")
|
109 |
+
st.divider()
|
110 |
+
|
111 |
+
# display matched results
|
112 |
+
for pid, match_score in above_thr_sorted:
|
113 |
+
|
114 |
+
if pid not in results:
|
115 |
+
continue
|
116 |
|
117 |
+
result = results[pid]
|
118 |
+
col_icon, col_info = st.columns([1, 3])
|
119 |
|
120 |
with col_icon:
|
121 |
st.image(result["logo"])
|
122 |
|
123 |
with col_info:
|
124 |
+
st.markdown(f"""match score: { round(100 * match_score, 2) }
|
125 |
+
<br>
|
126 |
+
**{result['name']}**
|
127 |
+
<br>
|
128 |
+
_Asking Price:_ {metadata[pid].get('asking_price', 'N/A')}
|
129 |
+
<br>
|
130 |
+
_Category:_ {metadata[pid].get('category', 'N/A')}
|
131 |
+
<br>
|
132 |
+
_Location:_ {metadata[pid].get('location', 'N/A')}
|
133 |
+
""", unsafe_allow_html=True)
|
134 |
|
135 |
+
st.markdown(f"""**_Description:_** {result['description'][:MAX_LENGTH_DESC]}...[more]({result['url']})
|
136 |
+
""")
|