domenicrosati committed
Commit f5555cd · 1 Parent(s): 2d39184

experiment with summarization

Files changed (1):
  1. app.py  +26 -8
app.py CHANGED

@@ -149,9 +149,10 @@ def init_models():
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
-    return question_answerer, reranker, stop, device # uqeryexp_model, queryexp_tokenizer
+    summarizer = pipeline("summarization")
+    return question_answerer, reranker, stop, device, summarizer
 
-qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
+qa_model, reranker, stop, device, summarizer = init_models() # queryexp_model, queryexp_tokenizer
 
 
 def clean_query(query, strict=True, clean=True):
@@ -212,6 +213,9 @@ st.markdown("""
 """, unsafe_allow_html=True)
 
 with st.expander("Settings (strictness, context limit, top hits)"):
+    use_mds = st.radio(
+        "Use multi-document summarization to summarize answer?",
+        ('yes', 'no'))
     support_all = st.radio(
         "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
         ('yes', 'no'))
@@ -267,6 +271,21 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
     return None
 
 
+def gen_summary(query, sorted_result):
+    doc_sep = '\n'
+    summary = summarizer(f'{query} '.join([f'{doc_sep}'.join(r['texts']) + r['context'] for r in sorted_result]))[0]['summary_text']
+    st.markdown(f"""
+    <div class="container-fluid">
+        <div class="row align-items-start">
+            <div class="col-md-12 col-sm-12">
+                <strong>Answer:</strong> {summary}
+            </div>
+        </div>
+    </div>
+    """, unsafe_allow_html=True)
+    st.markdown("<br /><br /><h5>Sources:</h5>", unsafe_allow_html=True)
+
+
 def run_query(query):
     # if use_query_exp == 'yes':
     #     query_exp = paraphrase(f"question2question: {query}")
@@ -275,10 +294,6 @@ def run_query(query):
     #     * {query_exp}
     #     """)
 
-    # address period in highlitht avoidability. Risk factors
-    # address poor tokenization Deletions involving chromosome region 4p16.3 cause WolfHirschhorn syndrome (WHS, OMIM 194190) [Battaglia et al, 2001].
-    # address highlight html
-
     # could also try fallback if there are no good answers by score...
     limit = top_hits_limit or 100
     context_limit = context_lim or 10
@@ -346,10 +361,13 @@ def run_query(query):
     else:
         threshold = (confidence_threshold or 10) / 100
 
-    sorted_result = filter(
+    sorted_result = list(filter(
         lambda x: x['score'] > threshold,
         sorted_result
-    )
+    ))
+
+    if use_mds == 'yes':
+        gen_summary(query, sorted_result)
 
     for r in sorted_result:
         ctx = remove_html(r["context"])
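For readers who want to try the summarization path from this commit outside of Streamlit, a minimal standalone sketch follows. The hit shape (a list of sentence strings under 'texts' plus a longer 'context' string) is inferred from how gen_summary() consumes sorted_result; the helper name summarize_hits and the sample hits are hypothetical, and pipeline("summarization") simply loads whatever default model transformers selects.

# Minimal sketch (not part of the commit) of the multi-document summarization step.
# Assumption: each hit mirrors app.py's sorted_result entries -- a list of sentence
# strings under 'texts' and a longer 'context' string.
from transformers import pipeline

summarizer = pipeline("summarization")  # default summarization model chosen by transformers


def summarize_hits(query, sorted_result, doc_sep="\n"):
    # Concatenate the evidence the same way gen_summary() does: join each hit's
    # sentences with doc_sep, append its context, then join hits with the query.
    joined = f"{query} ".join(
        doc_sep.join(hit["texts"]) + hit["context"] for hit in sorted_result
    )
    return summarizer(joined)[0]["summary_text"]


if __name__ == "__main__":
    # Hypothetical hits. Note that in app.py sorted_result is materialized with
    # list(filter(...)) so both gen_summary() and the per-result display loop can
    # iterate over it; a bare filter() iterator would be exhausted after one pass.
    hits = [
        {"texts": ["Deletions in chromosome region 4p16.3 cause Wolf-Hirschhorn syndrome."],
         "context": "Review of Wolf-Hirschhorn syndrome genetics."},
        {"texts": ["Phenotype severity correlates with deletion size."],
         "context": "Genotype-phenotype correlation study."},
    ]
    print(summarize_hits("What causes Wolf-Hirschhorn syndrome?", hits))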