neobot committed on
Commit
49df564
·
1 Parent(s): 6e7c49a

improve summary

Browse files
Files changed (1) hide show
  1. app.py +60 -20
app.py CHANGED
@@ -3,6 +3,7 @@ import streamlit as st
3
  import openai
4
  import pinecone
5
  from postgres_db import query_postgresql_realvest
 
6
 
7
  PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
8
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
@@ -11,18 +12,20 @@ EMBEDDING_MODEL = "text-embedding-ada-002" # OpenAI's best embeddings as of Apr
11
  MAX_LENGTH_DESC = 200
12
  MATCH_SCORE_THR = 0.0
13
  TOP_K = 20
 
 
14
 
15
 
16
 
17
 
18
 
19
- def query_pinecone(xq, top_k: int=3, include_metadata: bool=True, sleep_time: int=10):
20
  MAX_TRIALS = 5
21
  trial = 0
22
  out = None
23
  while (out is None) and (trial < MAX_TRIALS):
24
  try:
25
- out = st.session_state['index'].query(xq, top_k=top_k, include_metadata=include_metadata)
26
  return out
27
  except pinecone.core.exceptions.PineconeProtocolError as err:
28
  print(f"Error, sleep! {err}")
@@ -96,24 +99,35 @@ def summarize_products(products: list) -> str:
96
  summary = "{summary of all products}"
97
  """
98
  NEW_LINE = '\n'
99
- prompt = f"""
100
- Based on the product information below, please read and try to understand it.
101
- { f"{NEW_LINE*2}---{NEW_LINE*2}".join(products) }
102
- Please write a concise and insightful summary table (display as HTML) to compare the products for investors, which should inlcude but not limited to:
103
- - description
104
- - category
105
- - asking price
106
- - location
107
- - potential profit margin
 
 
 
 
 
 
 
 
 
 
 
108
  """
109
- print(f"prompt: {prompt}")
110
 
111
  openai.api_key = OPENAI_API_KEY
112
  completion = openai.ChatCompletion.create(
113
  model="gpt-4",
114
  messages=[
115
  {"role": "system", "content": "You are a helpful assistant."},
116
- {"role": "user", "content": prompt}
117
  ]
118
  )
119
 
@@ -166,8 +180,8 @@ if st.button('Search'):
166
  ### call OpenAI text-embedding
167
  res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
168
  xq = res['data'][0]['embedding']
169
- out = query_pinecone(xq, top_k=TOP_K, include_metadata=True)
170
-
171
  if (out is not None) and ('matches' in out):
172
  metadata = {match['metadata']['product_id']: match['metadata'] for match in out['matches'] if 'metadata' in match and match['metadata'] is not None}
173
 
@@ -259,14 +273,40 @@ if st.session_state['count_checked'] > 0:
259
  with summary_container.container():
260
  st.header('Summary')
261
  if st.button('Compare Products'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  products = []
263
- for key in st.session_state['checked_boxes']:
264
- # TODO: Need to pull all the document
265
- # TODO: Need to dedup the pid too
266
- pid = key.split('__')[-1]
267
  products.append(
268
- st.session_state['metadata'][pid].get('document')
269
  )
 
 
270
  with st.spinner('Summarizing...'):
271
  summary = summarize_products(products)
272
  st.markdown(summary.get("content"), unsafe_allow_html=True)
 
3
  import openai
4
  import pinecone
5
  from postgres_db import query_postgresql_realvest
6
+ import numpy as np
7
 
8
  PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
9
  OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
 
12
  MAX_LENGTH_DESC = 200
13
  MATCH_SCORE_THR = 0.0
14
  TOP_K = 20
15
+ EMBEDDING_VECTOR_DIM = 1536
16
+ ZERO_EMBEDDING_VECTOR = list(np.zeros(EMBEDDING_VECTOR_DIM))
17
 
18
 
19
 
20
 
21
 
22
+ def query_pinecone(vector=None, top_k: int=3, include_metadata: bool=True, metadata_filter: dict=None, sleep_time: int=10):
23
  MAX_TRIALS = 5
24
  trial = 0
25
  out = None
26
  while (out is None) and (trial < MAX_TRIALS):
27
  try:
28
+ out = st.session_state['index'].query(vector=vector, top_k=top_k, filter=metadata_filter, include_metadata=include_metadata)
29
  return out
30
  except pinecone.core.exceptions.PineconeProtocolError as err:
31
  print(f"Error, sleep! {err}")
 
99
  summary = "{summary of all products}"
100
  """
101
  NEW_LINE = '\n'
102
+
103
+ PROMPT_PRODUCTS_SUMMARY = f"""
104
+ You are a very sharp and helpful assistant to a group of commercial real estate investors.
105
+ You are about to write a summary comparison of a few products whose information are given below:
106
+
107
+ ----- DESCRIPTION of PRODUCTS -----
108
+
109
+ { f"{NEW_LINE*2}---{NEW_LINE*2}".join(products) }
110
+
111
+ -----------------------------------
112
+
113
+ Please write a concise and insightful summary table to compare the products for investors, which should include but not limited to:
114
+ - title
115
+ - product summary
116
+ - category
117
+ - asking price
118
+ - location
119
+ - potential profit margin
120
+
121
+ and display the resulting table in HTML.
122
  """
123
+ print(f"prompt: {PROMPT_PRODUCTS_SUMMARY}")
124
 
125
  openai.api_key = OPENAI_API_KEY
126
  completion = openai.ChatCompletion.create(
127
  model="gpt-4",
128
  messages=[
129
  {"role": "system", "content": "You are a helpful assistant."},
130
+ {"role": "user", "content": PROMPT_PRODUCTS_SUMMARY}
131
  ]
132
  )
133
 
 
180
  ### call OpenAI text-embedding
181
  res = openai.Embedding.create(model=EMBEDDING_MODEL, input=[query], api_key=OPENAI_API_KEY)
182
  xq = res['data'][0]['embedding']
183
+ out = query_pinecone(vector=xq, top_k=TOP_K, include_metadata=True)
184
+
185
  if (out is not None) and ('matches' in out):
186
  metadata = {match['metadata']['product_id']: match['metadata'] for match in out['matches'] if 'metadata' in match and match['metadata'] is not None}
187
 
 
273
  with summary_container.container():
274
  st.header('Summary')
275
  if st.button('Compare Products'):
276
+
277
+ # populate pids that are checked
278
+ relevant_pids = [key.split('__')[-1] for key in st.session_state['checked_boxes']]
279
+ relevant_pids = list(set(relevant_pids))
280
+
281
+ # get metadata from pinecone
282
+ metadata_filter = {
283
+ 'product_id': {"$in": relevant_pids}
284
+ }
285
+ results = query_pinecone(
286
+ vector=ZERO_EMBEDDING_VECTOR,
287
+ top_k=100,
288
+ include_metadata=True,
289
+ metadata_filter=metadata_filter
290
+ )
291
+
292
+ # organize document by product_id
293
+ documents = {}
294
+ for res in results['matches']:
295
+ pid, chunk_id = res['id'].split('-')
296
+ if pid not in documents:
297
+ documents[pid] = {}
298
+ if "chunk" not in documents[pid]:
299
+ documents[pid]['chunk'] = {}
300
+ documents[pid]['chunk'][chunk_id] = res['metadata']['document']
301
+
302
+ # concatenate documents
303
  products = []
304
+ for pid, doc in documents.items():
 
 
 
305
  products.append(
306
+ doc['chunk']['1'] + '\n\n' + doc['chunk']['2']
307
  )
308
+
309
+ # summarize
310
  with st.spinner('Summarizing...'):
311
  summary = summarize_products(products)
312
  st.markdown(summary.get("content"), unsafe_allow_html=True)