neobot committed on
Commit
7a1a853
·
1 Parent(s): 76b9148

sort search results by match score

Browse files
Files changed (1) hide show
  1. app.py +53 -19
app.py CHANGED
@@ -8,6 +8,8 @@ PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
8
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
9
  INDEX_NAME = 'realvest-data-v2'
10
  EMBEDDING_MODEL = "text-embedding-ada-002" # OpenAI's best embeddings as of Apr 2023
 
 
11
 
12
  def test_pinecone(sleep_time: int=1):
13
  MAX_TRIALS = 5
@@ -43,6 +45,13 @@ def query_pinecone(xq, top_k: int=3, include_metadata: bool=True, sleep_time: in
43
 
44
  raise Exception(f'max trials {MAX_TRIALS} Exceeded!')
45
 
 
 
 
 
 
 
 
46
  # initialize connection to pinecone (get API key at app.pinecone.io)
47
  pinecone.init(
48
  api_key=PINECONE_API_KEY,
@@ -53,6 +62,14 @@ stats = test_pinecone()
53
  print(f"Pinecone DB stats: {stats}")
54
 
55
  ### Main
 
 
 
 
 
 
 
 
56
  # Create a text input field
57
  query = st.text_input("What are you looking for?")
58
 
@@ -62,23 +79,20 @@ if st.button('Submit'):
62
  # ### call OpenAI text-embedding
63
  res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
64
  xq = res['data'][0]['embedding']
65
- # out = index.query(xq, top_k=3, include_metadata=True)
66
- out = query_pinecone(xq, top_k=3, include_metadata=True)
67
-
68
- ### display
69
- # print(f"{'*'*30}results #3: {out}")
70
- # st.write("Matched results")
71
- # for match in out['matches']:
72
- # st.write( match['id'] )
73
 
74
  ### candidates
75
- pids = [
76
- match['metadata']['product_id']
77
- for match in out['matches']
 
 
 
78
  ]
 
79
 
80
  ### query pids
81
- pids_str = [f"'{pid}'"for pid in pids]
82
  query = f"""
83
  SELECT productid, name, category, alternatename, url, logo, description
84
  FROM main_products
@@ -86,17 +100,37 @@ if st.button('Submit'):
86
  """
87
 
88
  results = query_postgresql_realvest(query)
89
- print(results)
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- for result in results:
92
- col_icon, col_info = st.columns([1, 5])
93
 
94
  with col_icon:
95
  st.image(result["logo"])
96
 
97
  with col_info:
98
- st.subheader(result['name'])
99
- st.write(result["description"])
 
 
 
 
 
 
 
 
100
 
101
- # st.write("---")
102
- # st.json(result)
 
8
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
9
  INDEX_NAME = 'realvest-data-v2'
10
  EMBEDDING_MODEL = "text-embedding-ada-002" # OpenAI's best embeddings as of Apr 2023
11
+ MAX_LENGTH_DESC = 200
12
+ MATCH_SCORE_THR = 0.0
13
 
14
  def test_pinecone(sleep_time: int=1):
15
  MAX_TRIALS = 5
 
45
 
46
  raise Exception(f'max trials {MAX_TRIALS} Exceeded!')
47
 
48
+ def sort_dict_by_value(d: dict, ascending: bool=True):
49
+ """
50
+ Sort dictionary {k1: v1, k2: v2} by its value. The output is
51
+ a sorted list of tuples [(k1, v1), (k2, v2)]
52
+ """
53
+ return sorted(d.items(), key=lambda x: x[1], reverse=not ascending)
54
+
55
  # initialize connection to pinecone (get API key at app.pinecone.io)
56
  pinecone.init(
57
  api_key=PINECONE_API_KEY,
 
62
  print(f"Pinecone DB stats: {stats}")
63
 
64
  ### Main
65
+ # st.set_page_config(layout="centered")
66
+ css='''
67
+ <style>
68
+ section.main > div {max-width:70rem}
69
+ </style>
70
+ '''
71
+ st.markdown(css, unsafe_allow_html=True)
72
+
73
  # Create a text input field
74
  query = st.text_input("What are you looking for?")
75
 
 
79
  # ### call OpenAI text-embedding
80
  res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
81
  xq = res['data'][0]['embedding']
82
+ out = query_pinecone(xq, top_k=20, include_metadata=True)
 
 
 
 
 
 
 
83
 
84
  ### candidates
85
+ metadata = {match['metadata']['product_id']: match['metadata'] for match in out['matches']}
86
+ match_score = {match['metadata']['product_id']: match['score'] for match in out['matches']}
87
+ above_thr_sorted = [
88
+ item
89
+ for item in sort_dict_by_value(match_score, ascending=False)
90
+ if item[1] > MATCH_SCORE_THR
91
  ]
92
+ pids = metadata.keys()
93
 
94
  ### query pids
95
+ pids_str = [f"'{pid}'" for pid, _ in above_thr_sorted]
96
  query = f"""
97
  SELECT productid, name, category, alternatename, url, logo, description
98
  FROM main_products
 
100
  """
101
 
102
  results = query_postgresql_realvest(query)
103
+ results = {
104
+ result['productid']: result
105
+ for result in results
106
+ }
107
+
108
+ st.header("Results")
109
+ st.divider()
110
+
111
+ # display matched results
112
+ for pid, match_score in above_thr_sorted:
113
+
114
+ if pid not in results:
115
+ continue
116
 
117
+ result = results[pid]
118
+ col_icon, col_info = st.columns([1, 3])
119
 
120
  with col_icon:
121
  st.image(result["logo"])
122
 
123
  with col_info:
124
+ st.markdown(f"""match score: { round(100 * match_score, 2) }
125
+ <br>
126
+ **{result['name']}**
127
+ <br>
128
+ _Asking Price:_ {metadata[pid].get('asking_price', 'N/A')}
129
+ <br>
130
+ _Category:_ {metadata[pid].get('category', 'N/A')}
131
+ <br>
132
+ _Location:_ {metadata[pid].get('location', 'N/A')}
133
+ """, unsafe_allow_html=True)
134
 
135
+ st.markdown(f"""**_Description:_** {result['description'][:MAX_LENGTH_DESC]}...[more]({result['url']})
136
+ """)