chansung committed on
Commit
7e4123a
·
1 Parent(s): 9ea7354

major update

Browse files
Files changed (12) hide show
  1. app.py +92 -413
  2. background.py +93 -0
  3. constants/context.py +13 -0
  4. constants/js.py +109 -10
  5. constants/styles.py +97 -2
  6. gen/gemini.py +1 -1
  7. gen/gemini_chat.py +129 -0
  8. gen/openllm.py +178 -0
  9. init.py +100 -0
  10. requirements.txt +5 -1
  11. ui.py +264 -0
  12. utils.py +5 -12
app.py CHANGED
@@ -1,80 +1,26 @@
1
- import os
2
- import re
3
- import copy
4
- import datasets
5
- import pandas as pd
6
  import gradio as gr
7
 
8
- from collections import defaultdict
9
- from datetime import datetime, timedelta
10
- from datasets import Dataset
11
- from huggingface_hub import HfApi
12
- from huggingface_hub import create_repo
13
- from huggingface_hub.utils import HfHubHTTPError
14
-
15
- import utils
16
- from paper.download import (
17
- download_pdf_from_arxiv,
18
- get_papers_from_hf_daily_papers,
19
- get_papers_from_arxiv_ids
20
- )
21
- from paper.parser import extract_text_and_figures
22
- from gen.gemini import get_basic_qa, get_deep_qa
23
-
24
  from constants.styles import STYLE
25
- from constants.js import UPDATE_SEARCH_RESULTS, UPDATE_IF_TYPE
26
- from constants.utils import get_secrets
 
 
27
 
 
 
28
  from apscheduler.schedulers.background import BackgroundScheduler
29
 
30
- def count_nans(row):
31
- count = 0
32
-
33
- for _, (k, v) in enumerate(data.items()):
34
- if v is None:
35
- count = count + 1
36
 
37
- return count
38
-
39
- gemini_api_key, hf_token, dataset_repo_id, request_arxiv_repo_id = get_secrets()
40
-
41
- ds = datasets.load_dataset(dataset_repo_id)
42
- request_ds = datasets.load_dataset(request_arxiv_repo_id)
43
- requested_arxiv_ids = []
44
- for request_d in request_ds['train']:
45
- arxiv_ids = request_d['Requested arXiv IDs']
46
- requested_arxiv_ids = requested_arxiv_ids + arxiv_ids
47
- requested_arxiv_ids_df = pd.DataFrame({'Requested arXiv IDs': requested_arxiv_ids})
48
-
49
- title2qna = {}
50
- date2qna = {}
51
- date_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
52
-
53
- for data in ds["train"]:
54
- date = data["target_date"].strftime("%Y-%m-%d")
55
-
56
- if date in date2qna:
57
- papers = copy.deepcopy(date2qna[date])
58
- for paper in papers:
59
- if paper["title"] == data["title"]:
60
- if count_nans(paper) > count_nans(data):
61
- date2qna[date].remove(paper)
62
-
63
- date2qna[date].append(data)
64
- del papers
65
- else:
66
- date2qna[date] = [data]
67
-
68
- for date in date2qna:
69
- year, month, day = date.split("-")
70
- papers = date2qna[date]
71
- for paper in papers:
72
- title2qna[paper["title"]] = paper
73
- date_dict[year][month][day].append(paper)
74
-
75
- titles = title2qna.keys()
76
-
77
- sorted_dates = sorted(date2qna.keys())
78
 
79
  sorted_year = sorted(date_dict.keys())
80
  last_year = sorted_year[-1]
@@ -85,301 +31,28 @@ last_day = sorted_day[-1]
85
  last_papers = date_dict[last_year][last_month][last_day]
86
  selected_paper = last_papers[0]
87
 
88
- def filter_function(example, ids):
89
- ids_e = example['Requested arXiv IDs']
90
- for iid in ids:
91
- if iid in ids_e:
92
- ids_e.remove(iid)
93
- example['Requested arXiv IDs'] = ids_e
94
-
95
- print(example)
96
- return example
97
-
98
- def process_arxiv_ids(gemini_api, hf_repo_id, req_hf_repo_id, hf_token, how_many=10):
99
- arxiv_ids = []
100
-
101
- ds1 = datasets.load_dataset(req_hf_repo_id)
102
- for d in ds1['train']:
103
- req_arxiv_ids = d['Requested arXiv IDs']
104
- if len(req_arxiv_ids) > 0 and req_arxiv_ids[0] != "top":
105
- arxiv_ids = arxiv_ids + req_arxiv_ids
106
-
107
- arxiv_ids = arxiv_ids[:how_many]
108
-
109
- if arxiv_ids is not None and len(arxiv_ids) > 0:
110
- print(f"1. Get metadata for the papers [{arxiv_ids}]")
111
- papers = get_papers_from_arxiv_ids(arxiv_ids)
112
- print("...DONE")
113
-
114
- print("2. Generating QAs for the paper")
115
- for paper in papers:
116
- try:
117
- title = paper['title']
118
- target_date = paper['target_date']
119
- abstract = paper['paper']['summary']
120
- arxiv_id = paper['paper']['id']
121
- authors = paper['paper']['authors']
122
-
123
- print(f"...PROCESSING ON[{arxiv_id}, {title}]")
124
- print(f"......Downloading the paper PDF")
125
- filename = download_pdf_from_arxiv(arxiv_id)
126
- print(f"......DONE")
127
-
128
- print(f"......Extracting text and figures")
129
- texts, figures = extract_text_and_figures(filename)
130
- text =' '.join(texts)
131
- print(f"......DONE")
132
-
133
- print(f"......Generating the seed(basic) QAs")
134
- qnas = get_basic_qa(text, gemini_api_key=gemini_api, trucate=30000)
135
- qnas['title'] = title
136
- qnas['abstract'] = abstract
137
- qnas['authors'] = ','.join(authors)
138
- qnas['arxiv_id'] = arxiv_id
139
- qnas['target_date'] = target_date
140
- qnas['full_text'] = text
141
- print(f"......DONE")
142
-
143
- print(f"......Generating the follow-up QAs")
144
- qnas = get_deep_qa(text, qnas, gemini_api_key=gemini_api, trucate=30000)
145
- del qnas["qna"]
146
- print(f"......DONE")
147
-
148
- print(f"......Exporting to HF Dataset repo at [{hf_repo_id}]")
149
- utils.push_to_hf_hub(qnas, hf_repo_id, hf_token)
150
- print(f"......DONE")
151
-
152
- print(f"......Updating request arXiv HF Dataset repo at [{req_hf_repo_id}]")
153
- ds1 = ds1['train'].map(
154
- lambda example: filter_function(example, [arxiv_id])
155
- ).filter(
156
- lambda example: len(example['Requested arXiv IDs']) > 0
157
- )
158
- ds1.push_to_hub(req_hf_repo_id, token=hf_token)
159
-
160
- print(f"......DONE")
161
- except Exception as e:
162
- print(f".......failed due to exception {e}")
163
- continue
164
-
165
- HfApi(token=hf_token).restart_space(
166
- repo_id="chansung/paper_qa", token=hf_token
167
- )
168
-
169
- def push_to_hf_hub(
170
- df, repo_id, token, append=True
171
- ):
172
- exist = False
173
- ds = Dataset.from_pandas(df)
174
-
175
- try:
176
- create_repo(request_arxiv_repo_id, repo_type="dataset", token=hf_token)
177
- except HfHubHTTPError as e:
178
- exist = True
179
-
180
- if exist and append:
181
- existing_ds = datasets.load_dataset(repo_id)
182
- ds = datasets.concatenate_datasets([existing_ds['train'], ds])
183
-
184
- ds.push_to_hub(repo_id, token=token)
185
-
186
- def _filter_duplicate_arxiv_ids(arxiv_ids_to_be_added):
187
- ds1 = datasets.load_dataset("chansung/requested-arxiv-ids-3")
188
- ds2 = datasets.load_dataset("chansung/auto-paper-qa2")
189
-
190
- unique_arxiv_ids = set()
191
-
192
- for d in ds1['train']:
193
- arxiv_ids = d['Requested arXiv IDs']
194
- unique_arxiv_ids = set(list(unique_arxiv_ids) + arxiv_ids)
195
-
196
- for d in ds2['train']:
197
- arxiv_id = d['arxiv_id']
198
- unique_arxiv_ids.add(arxiv_id)
199
-
200
- return list(set(arxiv_ids_to_be_added) - unique_arxiv_ids)
201
-
202
- def _is_arxiv_id_valid(arxiv_id):
203
- pattern = r"^\d{4}\.\d{5}$"
204
- return bool(re.match(pattern, arxiv_id))
205
-
206
- def _get_valid_arxiv_ids(arxiv_ids_str):
207
- valid_arxiv_ids = []
208
- invalid_arxiv_ids = []
209
-
210
- for arxiv_id in arxiv_ids_str.split(","):
211
- arxiv_id = arxiv_id.strip()
212
- if _is_arxiv_id_valid(arxiv_id):
213
- valid_arxiv_ids.append(arxiv_id)
214
- else:
215
- invalid_arxiv_ids.append(arxiv_id)
216
-
217
- return valid_arxiv_ids, invalid_arxiv_ids
218
-
219
- def add_arxiv_ids_to_queue(queue, arxiv_ids_str):
220
- print(0)
221
- valid_arxiv_ids, invalid_arxiv_ids = _get_valid_arxiv_ids(arxiv_ids_str)
222
- print("01")
223
-
224
- if len(invalid_arxiv_ids) > 0:
225
- gr.Warning(f"found invalid arXiv ids as in {invalid_arxiv_ids}")
226
-
227
- if len(valid_arxiv_ids) > 0:
228
- valid_arxiv_ids = _filter_duplicate_arxiv_ids(valid_arxiv_ids)
229
-
230
- if len(valid_arxiv_ids) > 0:
231
- valid_arxiv_ids = [[arxiv_id] for arxiv_id in valid_arxiv_ids]
232
- gr.Warning(f"Processing on [{valid_arxiv_ids}]. Other requested arXiv IDs not found on this list should be already processed or being processed...")
233
- valid_arxiv_ids = pd.DataFrame({'Requested arXiv IDs': valid_arxiv_ids})
234
- queue = pd.concat([queue, valid_arxiv_ids])
235
- queue.reset_index(drop=True)
236
-
237
- push_to_hf_hub(valid_arxiv_ids, request_arxiv_repo_id, hf_token)
238
- else:
239
- gr.Warning(f"All requested arXiv IDs are already processed or being processed...")
240
- else:
241
- gr.Warning(f"No valid arXiv IDs found...")
242
-
243
- return (
244
- queue, gr.Textbox("")
245
- )
246
-
247
- def get_paper_by_year(y):
248
- m = sorted(date_dict[y].keys())
249
- last_m = m[-1]
250
- d = sorted(date_dict[y][last_m].keys())
251
- last_d = d[-1]
252
- papers = [paper["title"] for paper in date_dict[y][last_m][last_d]]
253
- papers = list(set(papers))
254
- return (
255
- gr.Dropdown(choices=m, value=last_m),
256
- gr.Dropdown(choices=d, value=last_d),
257
- gr.Dropdown(choices=papers, value=papers[0])
258
- )
259
-
260
- def get_paper_by_month(y, m):
261
- d = sorted(date_dict[y][m].keys())
262
- last_d = d[-1]
263
- papers = [paper["title"] for paper in date_dict[y][m][last_d]]
264
- papers = list(set(papers))
265
- return (
266
- gr.Dropdown(choices=d, value=last_d),
267
- gr.Dropdown(choices=papers, value=papers[0])
268
- )
269
 
270
- def get_paper_by_day(y, m, d):
271
- papers = [paper["title"] for paper in date_dict[y][m][d]]
272
- papers = list(set(papers))
273
- return gr.Dropdown(choices=papers, value=papers[0])
274
-
275
- def set_paper(y, m, d, paper_title):
276
- selected_paper = None
277
- for paper in date_dict[y][m][d]:
278
- if paper["title"] == paper_title:
279
- selected_paper = paper
280
- break
281
-
282
- return (
283
- gr.Markdown(f"# {selected_paper['title']}"),
284
- gr.Markdown(
285
- "[![arXiv](https://img.shields.io/badge/arXiv-%s-b31b1b.svg)](https://arxiv.org/abs/%s)" % (selected_paper['arxiv_id'], selected_paper['arxiv_id'])
286
- ),
287
- gr.Markdown(
288
- "[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/%s)" % selected_paper['arxiv_id']
289
- ),
290
- gr.Markdown(selected_paper["summary"]),
291
-
292
- gr.Markdown(f"### 🙋 {selected_paper['0_question']}"),
293
- gr.Markdown(f"↪ **(ELI5)** {selected_paper['0_answers:eli5']}"),
294
- gr.Markdown(f"↪ **(Technical)** {selected_paper['0_answers:expert']}"),
295
- gr.Markdown(f"### 🙋🙋 {selected_paper['0_additional_depth_q:follow up question']}"),
296
- gr.Markdown(f"↪ **(ELI5)** {selected_paper['0_additional_depth_q:answers:eli5']}"),
297
- gr.Markdown(f"↪ **(Technical)** {selected_paper['0_additional_depth_q:answers:expert']}"),
298
- gr.Markdown(f"### 🙋🙋 {selected_paper['0_additional_breath_q:follow up question']}"),
299
- gr.Markdown(f"↪ **(ELI5)** {selected_paper['0_additional_breath_q:answers:eli5']}"),
300
- gr.Markdown(f"↪ **(Technical)** {selected_paper['0_additional_breath_q:answers:expert']}"),
301
-
302
- gr.Markdown(f"### 🙋 {selected_paper['1_question']}"),
303
- gr.Markdown(f"↪ **(ELI5)** {selected_paper['1_answers:eli5']}"),
304
- gr.Markdown(f"↪ **(Technical)** {selected_paper['1_answers:expert']}"),
305
- gr.Markdown(f"### 🙋🙋 {selected_paper['1_additional_depth_q:follow up question']}"),
306
- gr.Markdown(f"↪ **(ELI5)** {selected_paper['1_additional_depth_q:answers:eli5']}"),
307
- gr.Markdown(f"↪ **(Technical)** {selected_paper['1_additional_depth_q:answers:expert']}"),
308
- gr.Markdown(f"### 🙋🙋 {selected_paper['1_additional_breath_q:follow up question']}"),
309
- gr.Markdown(f"↪ **(ELI5)** {selected_paper['1_additional_breath_q:answers:eli5']}"),
310
- gr.Markdown(f"↪ **(Technical)** {selected_paper['1_additional_breath_q:answers:expert']}"),
311
-
312
- gr.Markdown(f"### 🙋 {selected_paper['2_question']}"),
313
- gr.Markdown(f"↪ **(ELI5)** {selected_paper['2_answers:eli5']}"),
314
- gr.Markdown(f"↪ **(Technical)** {selected_paper['2_answers:expert']}"),
315
- gr.Markdown(f"### 🙋🙋 {selected_paper['2_additional_depth_q:follow up question']}"),
316
- gr.Markdown(f"↪ **(ELI5)** {selected_paper['2_additional_depth_q:answers:eli5']}"),
317
- gr.Markdown(f"↪ **(Technical)** {selected_paper['2_additional_depth_q:answers:expert']}"),
318
- gr.Markdown(f"### 🙋🙋 {selected_paper['2_additional_breath_q:follow up question']}"),
319
- gr.Markdown(f"↪ **(ELI5)** {selected_paper['2_additional_breath_q:answers:eli5']}"),
320
- gr.Markdown(f"↪ **(Technical)** {selected_paper['2_additional_breath_q:answers:expert']}"),
321
- )
322
 
323
- def change_exp_type(exp_type):
324
- if exp_type == "ELI5":
325
- return (
326
- gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False),
327
- gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False),
328
- gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False),
329
- )
330
- else:
331
- return (
332
- gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True),
333
- gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True),
334
- gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True), gr.Markdown(visible=False), gr.Markdown(visible=True),
335
- )
336
-
337
- def search(search_in, max_results=3):
338
- results = []
339
-
340
- for title in titles:
341
- if len(results) > 3:
342
- break
343
- else:
344
- if search_in in title:
345
- results.append(title)
346
-
347
- return (
348
- gr.Textbox(
349
- visible=True if len(results) > 0 else False,
350
- value=results[0] if len(results) > 0 else ""
351
- ),
352
- gr.Textbox(
353
- visible=True if len(results) > 1 else False,
354
- value=results[1] if len(results) > 1 else ""
355
- ),
356
- gr.Textbox(
357
- visible=True if len(results) > 2 else False,
358
- value=results[2] if len(results) > 2 else ""
359
- )
360
- )
361
 
362
- def set_date(title):
363
- for _, (year, months) in enumerate(date_dict.items()):
364
- for _, (month, days) in enumerate(months.items()):
365
- for _, (day, papers) in enumerate(days.items()):
366
- for paper in papers:
367
- if paper['title'] == title:
368
- return (
369
- gr.Dropdown(value=year),
370
- gr.Dropdown(choices=sorted(months), value=month),
371
- gr.Dropdown(choices=sorted(days), value=day),
372
- )
373
-
374
- def set_papers(y, m, d, title):
375
- papers = [paper["title"] for paper in date_dict[y][m][d]]
376
- papers = list(set(papers))
377
- return (
378
- gr.Dropdown(choices=papers, value=title),
379
- gr.Textbox("")
380
- )
381
-
382
- with gr.Blocks(css=STYLE, theme=gr.themes.Soft()) as demo:
383
  gr.Markdown("# Let's explore papers with auto generated Q&As")
384
 
385
  with gr.Column(elem_id="control-panel", elem_classes=["group"]):
@@ -410,25 +83,22 @@ with gr.Blocks(css=STYLE, theme=gr.themes.Soft()) as demo:
410
  search_r9 = gr.Button(visible=False, elem_id="search_r9", elem_classes=["no-radius"])
411
  search_r10 = gr.Button(visible=False, elem_id="search_r10", elem_classes=["no-radius"])
412
 
413
- conv_type = gr.Radio(choices=["Q&As", "Chat"], value="Q&As", interactive=True, visible=False, elem_classes=["conv-type"])
414
-
415
  with gr.Column(scale=7):
416
- title = gr.Markdown(f"# {selected_paper['title']}")
417
  # with gr.Row():
418
  with gr.Row():
419
  arxiv_link = gr.Markdown(
420
- "[![arXiv](https://img.shields.io/badge/arXiv-%s-b31b1b.svg)](https://arxiv.org/abs/%s)" % (selected_paper['arxiv_id'], selected_paper['arxiv_id'])
 
421
  )
422
  hf_paper_link = gr.Markdown(
423
- "[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/%s)" % selected_paper['arxiv_id']
 
424
  )
425
- gr.Button("Chat about the paper", interactive=False)
426
 
427
  summary = gr.Markdown(f"{selected_paper['summary']}", elem_classes=["small-font"])
428
 
429
- with gr.Column(elem_id="chat_block", visible=False):
430
- gr.Chatbot([("hello", "world"), ("how", "are you?")])
431
-
432
  with gr.Column(elem_id="qna_block", visible=True):
433
  with gr.Row():
434
  with gr.Column(scale=7):
@@ -489,7 +159,7 @@ with gr.Blocks(css=STYLE, theme=gr.themes.Soft()) as demo:
489
  headers=["Requested arXiv IDs"], col_count=(1, "fixed"),
490
  value=requested_arxiv_ids_df,
491
  datatype=["str"],
492
- interactive=False
493
  )
494
 
495
  arxiv_id_enter = gr.Textbox(placeholder="Enter comma separated arXiv IDs...", elem_classes=["textbox-no-label"])
@@ -508,72 +178,68 @@ with gr.Blocks(css=STYLE, theme=gr.themes.Soft()) as demo:
508
  search_r1.click(set_date, search_r1, [year_dd, month_dd, day_dd]).then(
509
  set_papers,
510
  inputs=[year_dd, month_dd, day_dd, search_r1],
511
- outputs=[papers_dd, search_in]
512
  )
513
 
514
  search_r2.click(set_date, search_r2, [year_dd, month_dd, day_dd]).then(
515
  set_papers,
516
  inputs=[year_dd, month_dd, day_dd, search_r2],
517
- outputs=[papers_dd, search_in]
518
  )
519
 
520
  search_r3.click(set_date, search_r3, [year_dd, month_dd, day_dd]).then(
521
  set_papers,
522
  inputs=[year_dd, month_dd, day_dd, search_r3],
523
- outputs=[papers_dd, search_in]
524
  )
525
 
526
  search_r4.click(set_date, search_r4, [year_dd, month_dd, day_dd]).then(
527
  set_papers,
528
  inputs=[year_dd, month_dd, day_dd, search_r4],
529
- outputs=[papers_dd, search_in]
530
  )
531
 
532
  search_r5.click(set_date, search_r5, [year_dd, month_dd, day_dd]).then(
533
  set_papers,
534
  inputs=[year_dd, month_dd, day_dd, search_r5],
535
- outputs=[papers_dd, search_in]
536
  )
537
 
538
  search_r6.click(set_date, search_r6, [year_dd, month_dd, day_dd]).then(
539
  set_papers,
540
  inputs=[year_dd, month_dd, day_dd, search_r6],
541
- outputs=[papers_dd, search_in]
542
  )
543
 
544
  search_r7.click(set_date, search_r7, [year_dd, month_dd, day_dd]).then(
545
  set_papers,
546
  inputs=[year_dd, month_dd, day_dd, search_r7],
547
- outputs=[papers_dd, search_in]
548
  )
549
 
550
  search_r8.click(set_date, search_r8, [year_dd, month_dd, day_dd]).then(
551
  set_papers,
552
  inputs=[year_dd, month_dd, day_dd, search_r8],
553
- outputs=[papers_dd, search_in]
554
  )
555
 
556
  search_r9.click(set_date, search_r9, [year_dd, month_dd, day_dd]).then(
557
  set_papers,
558
  inputs=[year_dd, month_dd, day_dd, search_r9],
559
- outputs=[papers_dd, search_in]
560
  )
561
 
562
  search_r10.click(set_date, search_r10, [year_dd, month_dd, day_dd]).then(
563
  set_papers,
564
  inputs=[year_dd, month_dd, day_dd, search_r10],
565
- outputs=[papers_dd, search_in]
566
  )
567
 
568
- year_dd.input(
569
- get_paper_by_year,
570
- inputs=[year_dd],
571
- outputs=[month_dd, day_dd, papers_dd]
572
- ).then(
573
- set_paper,
574
- [year_dd, month_dd, day_dd, papers_dd],
575
  [
576
- title, summary,
 
577
  basic_q_0, basic_q_eli5_0, basic_q_expert_0,
578
  depth_q_0, depth_q_eli5_0, depth_q_expert_0,
579
  breath_q_0, breath_q_eli5_0, breath_q_expert_0,
@@ -588,14 +254,10 @@ with gr.Blocks(css=STYLE, theme=gr.themes.Soft()) as demo:
588
  ]
589
  )
590
 
591
- month_dd.input(
592
- get_paper_by_month,
593
- inputs=[year_dd, month_dd],
594
- outputs=[day_dd, papers_dd]
595
- ).then(
596
- set_paper,
597
- [year_dd, month_dd, day_dd, papers_dd],
598
  [
 
599
  title, arxiv_link, hf_paper_link, summary,
600
  basic_q_0, basic_q_eli5_0, basic_q_expert_0,
601
  depth_q_0, depth_q_eli5_0, depth_q_expert_0,
@@ -611,14 +273,10 @@ with gr.Blocks(css=STYLE, theme=gr.themes.Soft()) as demo:
611
  ]
612
  )
613
 
614
- day_dd.input(
615
- get_paper_by_day,
616
- inputs=[year_dd, month_dd, day_dd],
617
- outputs=[papers_dd]
618
- ).then(
619
- set_paper,
620
- [year_dd, month_dd, day_dd, papers_dd],
621
  [
 
622
  title, arxiv_link, hf_paper_link, summary,
623
  basic_q_0, basic_q_eli5_0, basic_q_expert_0,
624
  depth_q_0, depth_q_eli5_0, depth_q_expert_0,
@@ -634,10 +292,9 @@ with gr.Blocks(css=STYLE, theme=gr.themes.Soft()) as demo:
634
  ]
635
  )
636
 
637
- papers_dd.change(
638
- set_paper,
639
- [year_dd, month_dd, day_dd, papers_dd],
640
  [
 
641
  title, arxiv_link, hf_paper_link, summary,
642
  basic_q_0, basic_q_eli5_0, basic_q_expert_0,
643
  depth_q_0, depth_q_eli5_0, depth_q_expert_0,
@@ -672,14 +329,35 @@ with gr.Blocks(css=STYLE, theme=gr.themes.Soft()) as demo:
672
  basic_q_eli5_2, basic_q_expert_2, depth_q_eli5_2, depth_q_expert_2, breath_q_eli5_2, breath_q_expert_2
673
  ]
674
  )
 
 
 
675
 
676
- conv_type.select(
677
- inputs=[conv_type],
678
- js=UPDATE_IF_TYPE,
679
- outputs=None,
680
- fn=None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681
  )
682
 
 
 
 
683
  start_date = datetime.now() + timedelta(minutes=1)
684
  scheduler = BackgroundScheduler()
685
  scheduler.add_job(
@@ -690,7 +368,8 @@ scheduler.add_job(
690
  gemini_api_key,
691
  dataset_repo_id,
692
  request_arxiv_repo_id,
693
- hf_token
 
694
  ],
695
  start_date=start_date
696
  )
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
+ from init import get_secrets, initialize_data, update_dataframe
4
+ from gen.openllm import GradioLLaMA2ChatPPManager, GradioMistralChatPPManager
5
+ from gen.gemini_chat import GradioGeminiChatPPManager
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from constants.styles import STYLE
7
+ from constants.js import (
8
+ UPDATE_SEARCH_RESULTS, OPEN_CHAT_IF,
9
+ CLOSE_CHAT_IF, UPDATE_CHAT_HISTORY
10
+ )
11
 
12
+ from datetime import datetime, timedelta
13
+ from background import process_arxiv_ids
14
  from apscheduler.schedulers.background import BackgroundScheduler
15
 
16
+ gemini_api_key, hf_token, dataset_repo_id, request_arxiv_repo_id, restart_repo_id = get_secrets()
17
+ titles, date_dict, requested_arxiv_ids_df, arxivid2data = initialize_data(dataset_repo_id, request_arxiv_repo_id)
 
 
 
 
18
 
19
+ from ui import (
20
+ get_paper_by_year, get_paper_by_month, get_paper_by_day,
21
+ set_papers, set_paper, set_date, change_exp_type, add_arxiv_ids_to_queue,
22
+ before_chat_begin, chat_stream, chat_reset
23
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  sorted_year = sorted(date_dict.keys())
26
  last_year = sorted_year[-1]
 
31
  last_papers = date_dict[last_year][last_month][last_day]
32
  selected_paper = last_papers[0]
33
 
34
+ with gr.Blocks(css=STYLE, theme=gr.themes.Soft()) as demo:
35
+ cur_arxiv_id = gr.Textbox(selected_paper['arxiv_id'], visible=False)
36
+ local_data = gr.JSON({}, visible=False)
37
+ chat_state = gr.State({
38
+ "ppmanager_type": GradioGeminiChatPPManager # GradioMistralChatPPManager # GradioLLaMA2ChatPPManager
39
+ })
40
+
41
+ with gr.Column(elem_id="chatbot-back"):
42
+ with gr.Column(elem_id="chatbot", elem_classes=["hover-opacity"]):
43
+ close = gr.Button("𝕏", elem_id="chatbot-right-button") #elem_id="chatbot-right-button")
44
+ chatbot = gr.Chatbot(
45
+ label="Gemini 1.0 Pro", show_label=True,
46
+ show_copy_button=True, show_share_button=True,
47
+ visible=True, elem_id="chatbot-inside"
48
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ with gr.Row(elem_id="chatbot-bottm"):
51
+ reset = gr.Button("🗑️ Reset")
52
+ regen = gr.Button("🔄 Regenerate", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ prompt_txtbox = gr.Textbox(placeholder="Ask anything.....", elem_id="chatbot-txtbox", elem_classes=["textbox-no-label"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  gr.Markdown("# Let's explore papers with auto generated Q&As")
57
 
58
  with gr.Column(elem_id="control-panel", elem_classes=["group"]):
 
83
  search_r9 = gr.Button(visible=False, elem_id="search_r9", elem_classes=["no-radius"])
84
  search_r10 = gr.Button(visible=False, elem_id="search_r10", elem_classes=["no-radius"])
85
 
 
 
86
  with gr.Column(scale=7):
87
+ title = gr.Markdown(f"# {selected_paper['title']}", elem_classes=["markdown-center"])
88
  # with gr.Row():
89
  with gr.Row():
90
  arxiv_link = gr.Markdown(
91
+ "[![arXiv](https://img.shields.io/badge/arXiv-%s-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/%s)" % (selected_paper['arxiv_id'], selected_paper['arxiv_id']),
92
+ elem_classes=["markdown-center"]
93
  )
94
  hf_paper_link = gr.Markdown(
95
+ "[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-lg.svg)](https://huggingface.co/papers/%s)" % selected_paper['arxiv_id'],
96
+ elem_classes=["markdown-center"]
97
  )
98
+ chat_button = gr.Button("💬 about paper", interactive=True, elem_id="chat-button")
99
 
100
  summary = gr.Markdown(f"{selected_paper['summary']}", elem_classes=["small-font"])
101
 
 
 
 
102
  with gr.Column(elem_id="qna_block", visible=True):
103
  with gr.Row():
104
  with gr.Column(scale=7):
 
159
  headers=["Requested arXiv IDs"], col_count=(1, "fixed"),
160
  value=requested_arxiv_ids_df,
161
  datatype=["str"],
162
+ interactive=False,
163
  )
164
 
165
  arxiv_id_enter = gr.Textbox(placeholder="Enter comma separated arXiv IDs...", elem_classes=["textbox-no-label"])
 
178
  search_r1.click(set_date, search_r1, [year_dd, month_dd, day_dd]).then(
179
  set_papers,
180
  inputs=[year_dd, month_dd, day_dd, search_r1],
181
+ outputs=[cur_arxiv_id, papers_dd, search_in]
182
  )
183
 
184
  search_r2.click(set_date, search_r2, [year_dd, month_dd, day_dd]).then(
185
  set_papers,
186
  inputs=[year_dd, month_dd, day_dd, search_r2],
187
+ outputs=[cur_arxiv_id, papers_dd, search_in]
188
  )
189
 
190
  search_r3.click(set_date, search_r3, [year_dd, month_dd, day_dd]).then(
191
  set_papers,
192
  inputs=[year_dd, month_dd, day_dd, search_r3],
193
+ outputs=[cur_arxiv_id, papers_dd, search_in]
194
  )
195
 
196
  search_r4.click(set_date, search_r4, [year_dd, month_dd, day_dd]).then(
197
  set_papers,
198
  inputs=[year_dd, month_dd, day_dd, search_r4],
199
+ outputs=[cur_arxiv_id, papers_dd, search_in]
200
  )
201
 
202
  search_r5.click(set_date, search_r5, [year_dd, month_dd, day_dd]).then(
203
  set_papers,
204
  inputs=[year_dd, month_dd, day_dd, search_r5],
205
+ outputs=[cur_arxiv_id, papers_dd, search_in]
206
  )
207
 
208
  search_r6.click(set_date, search_r6, [year_dd, month_dd, day_dd]).then(
209
  set_papers,
210
  inputs=[year_dd, month_dd, day_dd, search_r6],
211
+ outputs=[cur_arxiv_id, papers_dd, search_in]
212
  )
213
 
214
  search_r7.click(set_date, search_r7, [year_dd, month_dd, day_dd]).then(
215
  set_papers,
216
  inputs=[year_dd, month_dd, day_dd, search_r7],
217
+ outputs=[cur_arxiv_id, papers_dd, search_in]
218
  )
219
 
220
  search_r8.click(set_date, search_r8, [year_dd, month_dd, day_dd]).then(
221
  set_papers,
222
  inputs=[year_dd, month_dd, day_dd, search_r8],
223
+ outputs=[cur_arxiv_id, papers_dd, search_in]
224
  )
225
 
226
  search_r9.click(set_date, search_r9, [year_dd, month_dd, day_dd]).then(
227
  set_papers,
228
  inputs=[year_dd, month_dd, day_dd, search_r9],
229
+ outputs=[cur_arxiv_id, papers_dd, search_in]
230
  )
231
 
232
  search_r10.click(set_date, search_r10, [year_dd, month_dd, day_dd]).then(
233
  set_papers,
234
  inputs=[year_dd, month_dd, day_dd, search_r10],
235
+ outputs=[cur_arxiv_id, papers_dd, search_in]
236
  )
237
 
238
+ year_dd.input(get_paper_by_year, inputs=[year_dd], outputs=[month_dd, day_dd, papers_dd]).then(
239
+ set_paper, [year_dd, month_dd, day_dd, papers_dd],
 
 
 
 
 
240
  [
241
+ cur_arxiv_id,
242
+ title, arxiv_link, hf_paper_link, summary,
243
  basic_q_0, basic_q_eli5_0, basic_q_expert_0,
244
  depth_q_0, depth_q_eli5_0, depth_q_expert_0,
245
  breath_q_0, breath_q_eli5_0, breath_q_expert_0,
 
254
  ]
255
  )
256
 
257
+ month_dd.input(get_paper_by_month, inputs=[year_dd, month_dd], outputs=[day_dd, papers_dd]).then(
258
+ set_paper, [year_dd, month_dd, day_dd, papers_dd],
 
 
 
 
 
259
  [
260
+ cur_arxiv_id,
261
  title, arxiv_link, hf_paper_link, summary,
262
  basic_q_0, basic_q_eli5_0, basic_q_expert_0,
263
  depth_q_0, depth_q_eli5_0, depth_q_expert_0,
 
273
  ]
274
  )
275
 
276
+ day_dd.input(get_paper_by_day, inputs=[year_dd, month_dd, day_dd], outputs=[papers_dd]).then(
277
+ set_paper, [year_dd, month_dd, day_dd, papers_dd],
 
 
 
 
 
278
  [
279
+ cur_arxiv_id,
280
  title, arxiv_link, hf_paper_link, summary,
281
  basic_q_0, basic_q_eli5_0, basic_q_expert_0,
282
  depth_q_0, depth_q_eli5_0, depth_q_expert_0,
 
292
  ]
293
  )
294
 
295
+ papers_dd.change(set_paper, [year_dd, month_dd, day_dd, papers_dd],
 
 
296
  [
297
+ cur_arxiv_id,
298
  title, arxiv_link, hf_paper_link, summary,
299
  basic_q_0, basic_q_eli5_0, basic_q_expert_0,
300
  depth_q_0, depth_q_eli5_0, depth_q_expert_0,
 
329
  basic_q_eli5_2, basic_q_expert_2, depth_q_eli5_2, depth_q_expert_2, breath_q_eli5_2, breath_q_expert_2
330
  ]
331
  )
332
+
333
+ chat_button.click(None, [cur_arxiv_id], [local_data, chatbot], js=OPEN_CHAT_IF)
334
+ close.click(None, None, None,js=CLOSE_CHAT_IF)
335
 
336
+ prompt_txtbox.submit(
337
+ before_chat_begin, None, [close, reset, regen]
338
+ ).then(
339
+ chat_stream,
340
+ [cur_arxiv_id, local_data, prompt_txtbox, chat_state],
341
+ [prompt_txtbox, chatbot, local_data, close, reset, regen]
342
+ ).then(
343
+ None, [cur_arxiv_id, local_data], None,
344
+ js=UPDATE_CHAT_HISTORY
345
+ )
346
+
347
+ reset.click(
348
+ before_chat_begin, None, [close, reset, regen]
349
+ ).then(
350
+ chat_reset,
351
+ [local_data, chat_state],
352
+ [prompt_txtbox, chatbot, local_data, close, reset, regen]
353
+ ).then(
354
+ None, [cur_arxiv_id, local_data], None,
355
+ js=UPDATE_CHAT_HISTORY
356
  )
357
 
358
+ demo.load(lambda: update_dataframe(request_arxiv_repo_id), None, arxiv_queue, every=180)
359
+ # demo.load(None, None, [chatbot, local_data], js=GET_LOCAL_STORAGE % idx.value)
360
+
361
  start_date = datetime.now() + timedelta(minutes=1)
362
  scheduler = BackgroundScheduler()
363
  scheduler.add_job(
 
368
  gemini_api_key,
369
  dataset_repo_id,
370
  request_arxiv_repo_id,
371
+ hf_token,
372
+ restart_repo_id
373
  ],
374
  start_date=start_date
375
  )
background.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ import pandas as pd
3
+ from huggingface_hub import HfApi
4
+
5
+ from utils import push_to_hf_hub
6
+ from paper.download import download_pdf_from_arxiv
7
+ from paper.download import get_papers_from_arxiv_ids
8
+ from paper.parser import extract_text_and_figures
9
+ from gen.gemini import get_basic_qa, get_deep_qa
10
+
11
+ def _filter_function(example, ids):
12
+ ids_e = example['Requested arXiv IDs']
13
+ for iid in ids:
14
+ if iid in ids_e:
15
+ ids_e.remove(iid)
16
+ example['Requested arXiv IDs'] = ids_e
17
+
18
+ print(example)
19
+ return example
20
+
21
+ def process_arxiv_ids(gemini_api, hf_repo_id, req_hf_repo_id, hf_token, restart_repo_id, how_many=10):
22
+ arxiv_ids = []
23
+
24
+ ds1 = datasets.load_dataset(req_hf_repo_id)
25
+ for d in ds1['train']:
26
+ req_arxiv_ids = d['Requested arXiv IDs']
27
+ if len(req_arxiv_ids) > 0 and req_arxiv_ids[0] != "top":
28
+ arxiv_ids = arxiv_ids + req_arxiv_ids
29
+
30
+ arxiv_ids = arxiv_ids[:how_many]
31
+
32
+ if arxiv_ids is not None and len(arxiv_ids) > 0:
33
+ print(f"1. Get metadata for the papers [{arxiv_ids}]")
34
+ papers = get_papers_from_arxiv_ids(arxiv_ids)
35
+ print("...DONE")
36
+
37
+ print("2. Generating QAs for the paper")
38
+ for paper in papers:
39
+ try:
40
+ title = paper['title']
41
+ target_date = paper['target_date']
42
+ abstract = paper['paper']['summary']
43
+ arxiv_id = paper['paper']['id']
44
+ authors = paper['paper']['authors']
45
+
46
+ print(f"...PROCESSING ON[{arxiv_id}, {title}]")
47
+ print(f"......Downloading the paper PDF")
48
+ filename = download_pdf_from_arxiv(arxiv_id)
49
+ print(f"......DONE")
50
+
51
+ print(f"......Extracting text and figures")
52
+ texts, figures = extract_text_and_figures(filename)
53
+ text =' '.join(texts)
54
+ print(f"......DONE")
55
+
56
+ print(f"......Generating the seed(basic) QAs")
57
+ qnas = get_basic_qa(text, gemini_api_key=gemini_api, trucate=30000)
58
+ qnas['title'] = title
59
+ qnas['abstract'] = abstract
60
+ qnas['authors'] = ','.join(authors)
61
+ qnas['arxiv_id'] = arxiv_id
62
+ qnas['target_date'] = target_date
63
+ qnas['full_text'] = text
64
+ print(f"......DONE")
65
+
66
+ print(f"......Generating the follow-up QAs")
67
+ qnas = get_deep_qa(text, qnas, gemini_api_key=gemini_api, trucate=30000)
68
+ del qnas["qna"]
69
+ print(f"......DONE")
70
+
71
+ print(f"......Exporting to HF Dataset repo at [{hf_repo_id}]")
72
+ df = pd.DataFrame([qnas])
73
+ ds = datasets.Dataset.from_pandas(df)
74
+ ds = ds.cast_column("target_date", datasets.features.Value("timestamp[s]"))
75
+ push_to_hf_hub(ds, hf_repo_id, hf_token)
76
+ print(f"......DONE")
77
+
78
+ print(f"......Updating request arXiv HF Dataset repo at [{req_hf_repo_id}]")
79
+ ds1 = ds1['train'].map(
80
+ lambda example: _filter_function(example, [arxiv_id])
81
+ ).filter(
82
+ lambda example: len(example['Requested arXiv IDs']) > 0
83
+ )
84
+ ds1.push_to_hub(req_hf_repo_id, token=hf_token)
85
+
86
+ print(f"......DONE")
87
+ except Exception as e:
88
+ print(f".......failed due to exception {e}")
89
+ continue
90
+
91
+ HfApi(token=hf_token).restart_space(
92
+ repo_id=restart_repo_id, token=hf_token
93
+ )
constants/context.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Prompt template used as the global (system) context of a chat.
# It contains exactly one "%s" slot, placed after the dashed separator, that is
# meant to be filled with the reference text via %-formatting.
DEFAULT_GLOBAL_CTX = """
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information.

Based on the above statement, answer questions based on the text below.
------------------------------------------------------------------------
%s
"""

# Extra instruction describing the [INST] ... [/INST] turn markers.
# NOTE(review): no use of this constant is visible here - confirm where it is consumed.
placeholder = "In each conversation, question is placed after [INST] while your answer should be placed after [/INST]. By looking [INST] and [/INST], you must consider multi-turn conversations."
constants/js.py CHANGED
@@ -83,14 +83,113 @@ function search(searchIn, maxResults = 3) {{
83
  }}
84
  """
85
 
86
- UPDATE_IF_TYPE = f"""
87
- function chage_if_type(if_type) {{
88
- if (if_type == 'Q&As') {{
89
- document.getElementById('chat_block').style.display = 'none';
90
- document.getElementById('qna_block').style.display = 'block';
91
- }} else {{
92
- document.getElementById('chat_block').style.display = 'block';
93
- document.getElementById('qna_block').style.display = 'none';
94
- }}
95
- }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  """
 
83
  }}
84
  """
85
 
86
# JavaScript snippet: shows the "#chatbot-back" element and hides the
# "qna_block" element by setting their inline `display` styles.
UPDATE_IF_TYPE = """
function chage_if_type() {
    document.querySelector("#chatbot-back").style.display = 'block';
    document.getElementById('qna_block').style.display = 'none';
}
"""
92
+
93
+
94
+ # globalThis.setStorage = (key, value)=>{
95
+ # localStorage.setItem(key, JSON.stringify(value));
96
+ # }
97
+ # globalThis.getStorage = (key, value)=>{
98
+ # return JSON.parse(localStorage.getItem(key));
99
+ # }
100
+
101
# JavaScript snippet run when the chat overlay is opened for a paper.
# It makes sure localStorage['localData'] holds an entry for the given arXiv id
# ({ctx, pingpongs}), reveals the "#chatbot-back" overlay, and returns
# [entry, [[ping, pong], ...]].
OPEN_CHAT_IF = """
function (arXivId) {
    var localData = localStorage.getItem('localData');
    if (!localData) {
        localData = {}; // Initialize if it doesn't exist
    }
    else {
        localData = JSON.parse(localData);
    }

    if (!localData[arXivId]) {
        localData[arXivId] = { ctx: '', pingpongs: [] };
    }

    localStorage.setItem('localData', JSON.stringify(localData));

    document.querySelector("#chatbot-back").classList.add("visible");

    // Declared locally: an undeclared assignment would leak a global
    // (and throw a ReferenceError in strict mode).
    var pingpongs = [];
    localData[arXivId]['pingpongs'].forEach(element =>{
        pingpongs.push([element.ping, element.pong]);
    });

    return [localData[arXivId], pingpongs];
}
"""
127
+
128
# JavaScript snippet: hides the chat overlay by removing the "visible" class
# from "#chatbot-back" after a 100 ms timeout.
CLOSE_CHAT_IF = """
function close() {
    setTimeout(function() {
        document.querySelector("#chatbot-back").classList.remove("visible"); // Remove after a slight delay
    }, 100); // 100-millisecond delay
}
"""
135
+
136
# JavaScript snippet: stores `data` under the given arXiv id inside the JSON
# document kept in localStorage['localData'].
#
# localStorage only stores strings, so assigning a bare `{}` would persist the
# text "[object Object]" and make the following JSON.parse throw. The stored
# value is therefore parsed when present and replaced by an empty object
# otherwise.
UPDATE_CHAT_HISTORY = """
function (arXivId, data) {
    console.log(arXivId)
    console.log(data);

    var stored = localStorage.getItem('localData');
    var localData = stored ? JSON.parse(stored) : {};

    localData[arXivId] = data;
    console.log(localData[arXivId]);
    localStorage.setItem('localData', JSON.stringify(localData));
}
"""
151
+
152
+
153
# JavaScript snippet template. It installs globalThis.setStorage /
# globalThis.getStorage helpers around localStorage['localData'] and returns
# [pingpongs of the given paper, whole localData object].
#
# The template has exactly ONE "%s" slot (the arXiv id) so it can be filled
# with a single-argument %-format.
#
# Fixes over the previous version:
# * localStorage only stores strings, so `localStorage['localData'] = {}`
#   persisted "[object Object]" and broke the following JSON.parse.
# * The final lookup indexed the raw (unparsed) string, or an empty object,
#   which always threw; the stored JSON is now parsed and the paper entry is
#   created on demand, mirroring getStorage.
GET_LOCAL_STORAGE = """
function() {
    globalThis.setStorage = (arXivId, value) => {
        console.log(value);

        var stored = localStorage.getItem('localData');
        var localData = stored ? JSON.parse(stored) : {};

        localData[arXivId] = value;
        console.log(localData[arXivId]);
        localStorage.setItem('localData', JSON.stringify(localData));
    }

    globalThis.getStorage = (arXivId)=>{
        var localData = localStorage.getItem('localData');
        console.log(localData);
        if (!localData) {
            localData = {}; // Initialize if it doesn't exist
        }
        else {
            localData = JSON.parse(localData);
        }

        if (!localData[arXivId]) {
            localData[arXivId] = { ctx: '', pingpongs: [] };
        }

        localStorage.setItem('localData', JSON.stringify(localData));
        console.log(localData[arXivId]['pingpongs']);
        return [localData[arXivId], localData[arXivId]['pingpongs']];
    }

    var arXivId = '%s';
    var stored = localStorage.getItem('localData');
    var localData = stored ? JSON.parse(stored) : {};

    if (!localData[arXivId]) {
        localData[arXivId] = { ctx: '', pingpongs: [] };
    }

    localStorage.setItem('localData', JSON.stringify(localData));
    return [localData[arXivId]['pingpongs'], localData];
}
"""
constants/styles.py CHANGED
@@ -1,7 +1,7 @@
1
  STYLE = """
2
 
3
- @media only screen and (min-width: 700px) {
4
- .main {
5
  width: 70% !important;
6
  margin: 0 auto; /* Center the container */
7
  }
@@ -76,4 +76,99 @@ h3 {
76
  #control-panel {
77
  margin-bottom: 30px;
78
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  """
 
1
  STYLE = """
2
 
3
+ .main {
4
+ @media only screen and (min-width: 1000px) {
5
  width: 70% !important;
6
  margin: 0 auto; /* Center the container */
7
  }
 
76
  #control-panel {
77
  margin-bottom: 30px;
78
  }
79
+
80
+ #chatbot {
81
+ background-color: white;
82
+ border: 1px solid #ccc;
83
+ padding: 20px;
84
+ box-shadow: 0px 5px 5px rgba(0, 0, 0, 0.3);
85
+ border-radius: 30px;
86
+ height: 80%;
87
+ width: 80%;
88
+
89
+ position: fixed;
90
+ top: 50%;
91
+ left: 50%;
92
+ transform: translate(-50%, -50%);
93
+ z-index: 1000; /* Or a high enough value to stay on top */
94
+
95
+ @media (max-width: 768px) { /* Adjust this breakpoint as needed */
96
+ width: 95%;
97
+ }
98
+
99
+ @media (prefers-color-scheme: dark) {
100
+ background-color: dimgrey;
101
+ }
102
+ }
103
+
104
+ #chat-button {
105
+ border-radius: 40px;
106
+ padding: 0px;
107
+ margin: 0px;
108
+ margin-left: 30px;
109
+ margin-right: 30px;
110
+ font-size: 13pt !important;
111
+
112
+ @media only screen and (min-width: 500px) {
113
+ font-size: 10pt;
114
+ margin: 0 auto; /* Center the container */
115
+ }
116
+ }
117
+
118
+ #chatbot-inside {
119
+ height: 100% !important;
120
+ border-width: 1px !important;
121
+ border-color: lightgray !important;
122
+ }
123
+
124
+ #chatbot-txtbox {
125
+ padding-bottom: 25px;
126
+ }
127
+
128
+ #chatbot-bottm {
129
+ padding-left: 10px;
130
+ padding-right: 10px;
131
+ }
132
+
133
+ #chatbot-right-button {
134
+ float: right;
135
+ width: 20px;
136
+ font-size: 17pt;
137
+ }
138
+
139
+ #chatbot-info {
140
+ word-break: break-word;
141
+ }
142
+
143
+ #chatbot-back {
144
+ position: absolute; /* Stay in place even when scrolling */
145
+ z-index: 1000; /* Ensure it's on top of everything else */
146
+ width: 100%;
147
+ height: 100%;
148
+ left: 0px;
149
+ top: 0px;
150
+
151
+ opacity: 0;
152
+ visibility: hidden; /* Ensures the element is not interactive */
153
+ transition: opacity 0.5s ease, visibility 0s 0.5s; /* Transition for opacity and delay visibility */
154
+ }
155
+
156
+ #chatbot-back.visible {
157
+ opacity: 1;
158
+ visibility: visible; /* Now visible and interactive */
159
+ transition: opacity 0.5s ease; /* Smooth transition for opacity */
160
+ }
161
+
162
+ .hover-opacity {
163
+ opacity: 0.8; /* Normal opacity of the element */
164
+ transition: opacity 0.3s ease-in-out; /* Smooth opacity change */
165
+ }
166
+
167
+ .hover-opacity:hover {
168
+ opacity: 1; /* Full opacity on hover */
169
+ }
170
+
171
+ .markdown-center {
172
+ text-align: -webkit-center;
173
+ }
174
  """
gen/gemini.py CHANGED
@@ -69,7 +69,7 @@ def call_gemini(prompt="", API_KEY=None, given_text=None, given_image=None, gene
69
  response = model.generate_content(prompt_parts)
70
  return response.text
71
 
72
- def try_out(prompt, given_text, gemini_api_key, given_image=None, retry_num=5):
73
  qna_json = None
74
  cur_retry = 0
75
 
 
69
  response = model.generate_content(prompt_parts)
70
  return response.text
71
 
72
+ def try_out(prompt, given_text, gemini_api_key, given_image=None, retry_num=10):
73
  qna_json = None
74
  cur_retry = 0
75
 
gen/gemini_chat.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import asyncio
3
+ import google.generativeai as genai
4
+
5
+ from pingpong import PingPong
6
+ from pingpong.pingpong import PPManager
7
+ from pingpong.pingpong import PromptFmt
8
+ from pingpong.pingpong import UIFmt
9
+ from pingpong.gradio import GradioChatUIFmt
10
+
11
class GeminiChatPromptFmt(PromptFmt):
    """Formats a context and single exchanges as role/parts message dicts."""

    @classmethod
    def ctx(cls, context):
        """Return a ``system`` message for ``context``, or None if it is None or empty."""
        if context is not None and context != "":
            return {"role": "system", "parts": [context]}
        return None

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Return the messages for one exchange.

        The ``user`` message always comes first; a ``model`` message follows
        only when the (truncated) pong is non-empty. Both texts are cut to
        ``truncate_size`` characters (``None`` keeps them whole).
        """
        user_text = pingpong.ping[:truncate_size]
        messages = [{"role": "user", "parts": [user_text]}]

        if pingpong.pong is not None:
            model_text = pingpong.pong[:truncate_size]
            if model_text != "":
                messages.append({"role": "model", "parts": [model_text]})

        return messages
41
+
42
class GeminiChatPPManager(PPManager):
    """PPManager that renders its exchanges as a flat list of message dicts."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=GeminiChatPromptFmt, truncate_size: int=None):
        """Build the message list for exchanges ``from_idx`` .. ``to_idx``.

        ``to_idx`` of -1 (or anything past the end) means "up to the last
        exchange". The context text is prefixed to the ping of the first
        selected exchange as a ``SYSTEM: ...`` header; this is done on a deep
        copy, so the stored exchanges are left untouched.
        """
        total = len(self.pingpongs)
        if to_idx == -1 or to_idx >= total:
            to_idx = total

        window = copy.deepcopy(self.pingpongs)[from_idx:to_idx]
        system_msg = fmt.ctx(self.ctx)
        system_text = "" if system_msg is None else system_msg['parts'][0]

        messages = []
        for position, exchange in enumerate(window):
            if position == 0:
                exchange.ping = f"SYSTEM: {system_text} ----------- \n" + exchange.ping
            messages += fmt.prompt(exchange, truncate_size=truncate_size)

        return messages
58
+
59
class GradioGeminiChatPPManager(GeminiChatPPManager):
    """Gemini prompt manager that can also render its exchanges for the UI."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return ``fmt.ui(...)`` of every exchange in ``from_idx`` .. ``to_idx``.

        ``to_idx`` of -1 (or anything past the end) selects up to the last exchange.
        """
        total = len(self.pingpongs)
        if to_idx == -1 or to_idx >= total:
            to_idx = total

        return [fmt.ui(exchange) for exchange in self.pingpongs[from_idx:to_idx]]
70
+
71
def init(api_key):
    """Configure the ``google.generativeai`` client with ``api_key``."""
    genai.configure(api_key=api_key)
73
+
74
+ def _default_gen_text():
75
+ return {
76
+ "temperature": 0.9,
77
+ "top_p": 1,
78
+ "top_k": 1,
79
+ "max_output_tokens": 2048,
80
+ }
81
+
82
+ def _default_safety_settings():
83
+ return [
84
+ {
85
+ "category": "HARM_CATEGORY_HARASSMENT",
86
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
87
+ },
88
+ {
89
+ "category": "HARM_CATEGORY_HATE_SPEECH",
90
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
91
+ },
92
+ {
93
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
94
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
95
+ },
96
+ {
97
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
98
+ "threshold": "BLOCK_MEDIUM_AND_ABOVE"
99
+ },
100
+ ]
101
+
102
+ async def _word_generator(sentence):
103
+ for word in sentence.split():
104
+ yield word
105
+ delay = 0.03 + (len(word) * 0.005)
106
+ await asyncio.sleep(delay) # Simulate a short delay
107
+
108
async def gen_text(
    prompts,
    gen_config=None,
    safety_settings=None,
    stream=True
):
    """Send the last prompt to a Gemini chat and yield the reply word by word.

    Args:
        prompts: list of message dicts; all but the last are passed as chat
            history, and ``parts[0]`` of the last one is sent as the new
            message.
        gen_config: generation settings; ``None`` uses ``_default_gen_text()``.
        safety_settings: safety settings; ``None`` uses
            ``_default_safety_settings()``.
        stream: forwarded to ``send_message_async``.

    Yields:
        Each word of the response text followed by a single space.
    """
    # Defaults are built per call: evaluating them in the signature would
    # create one dict/list at definition time that every call shares.
    if gen_config is None:
        gen_config = _default_gen_text()
    if safety_settings is None:
        safety_settings = _default_safety_settings()

    model = genai.GenerativeModel(model_name="gemini-1.0-pro",
                                  generation_config=gen_config,
                                  safety_settings=safety_settings)

    user_prompt = prompts[-1]
    history = prompts[:-1]
    convo = model.start_chat(history=history)

    resps = await convo.send_message_async(
        user_prompt["parts"][0], stream=stream
    )

    async for resp in resps:
        async for word in _word_generator(resp.text):
            yield word + " "
129
+
gen/openllm.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import requests
4
+ import sseclient
5
+
6
+ from pingpong import PingPong
7
+ from pingpong.pingpong import PPManager
8
+ from pingpong.pingpong import PromptFmt
9
+ from pingpong.pingpong import UIFmt
10
+ from pingpong.gradio import GradioChatUIFmt
11
+
12
class MistralChatPromptFmt(PromptFmt):
    """Formats a context and single exchanges with ``[INST]`` markers."""

    @classmethod
    def ctx(cls, context):
        """Return ``context`` followed by a blank line, or "" if it is None or empty."""
        if context is None or context == "":
            return ""
        return f"{context}\n\n"

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Return the prompt text of one exchange.

        Ping and pong are cut to ``truncate_size`` characters (``None`` keeps
        them whole). The closing ``</s>`` is appended only to a non-empty
        answer: a missing or still-empty pong is the turn to be generated,
        and terminating it would present the answer as already finished.
        """
        ping = pingpong.ping[:truncate_size]
        answer = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
        if answer != "":
            answer += "</s>"
        return f"<s>[INST] {ping} [/INST] {answer}\n"
28
+
29
class MistralChatPPManager(PPManager):
    """PPManager that renders its exchanges as one prompt string."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=MistralChatPromptFmt, truncate_size: int=None):
        """Return the formatted context followed by exchanges ``from_idx`` .. ``to_idx``.

        ``to_idx`` of -1 (or anything past the end) selects up to the last exchange.
        """
        total = len(self.pingpongs)
        if to_idx == -1 or to_idx >= total:
            to_idx = total

        prompt_text = fmt.ctx(self.ctx)
        for exchange in self.pingpongs[from_idx:to_idx]:
            prompt_text += fmt.prompt(exchange, truncate_size=truncate_size)

        return prompt_text
40
+
41
class GradioMistralChatPPManager(MistralChatPPManager):
    """Mistral prompt manager that can also render its exchanges for the UI."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return ``fmt.ui(...)`` of every exchange in ``from_idx`` .. ``to_idx``.

        ``to_idx`` of -1 (or anything past the end) selects up to the last exchange.
        """
        total = len(self.pingpongs)
        if to_idx == -1 or to_idx >= total:
            to_idx = total

        return [fmt.ui(exchange) for exchange in self.pingpongs[from_idx:to_idx]]
52
+
53
+
54
class LLaMA2ChatPromptFmt(PromptFmt):
    """Formats a context and single exchanges with ``<<SYS>>`` / ``[INST]`` markers."""

    @classmethod
    def ctx(cls, context):
        """Wrap ``context`` in a ``<<SYS>>`` block, or return "" if it is None or empty."""
        if context is not None and context != "":
            return f"<<SYS>>\n{context}\n<</SYS>>\n"
        return ""

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Return ``[INST] ping [/INST] pong`` with both texts cut to ``truncate_size``.

        A pong of ``None`` is rendered as an empty string.
        """
        question = pingpong.ping[:truncate_size]
        if pingpong.pong is None:
            answer = ""
        else:
            answer = pingpong.pong[:truncate_size]
        return f"[INST] {question} [/INST] {answer}"
70
+
71
class LLaMA2ChatPPManager(PPManager):
    """PPManager that renders its exchanges as one prompt string."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=LLaMA2ChatPromptFmt, truncate_size: int=None):
        """Return the formatted context followed by exchanges ``from_idx`` .. ``to_idx``.

        ``to_idx`` of -1 (or anything past the end) selects up to the last exchange.
        """
        total = len(self.pingpongs)
        if to_idx == -1 or to_idx >= total:
            to_idx = total

        prompt_text = fmt.ctx(self.ctx)
        for exchange in self.pingpongs[from_idx:to_idx]:
            prompt_text += fmt.prompt(exchange, truncate_size=truncate_size)

        return prompt_text
82
+
83
class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
    """LLaMA2 prompt manager that can also render its exchanges for the UI."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return ``fmt.ui(...)`` of every exchange in ``from_idx`` .. ``to_idx``.

        ``to_idx`` of -1 (or anything past the end) selects up to the last exchange.
        """
        total = len(self.pingpongs)
        if to_idx == -1 or to_idx >= total:
            to_idx = total

        return [fmt.ui(exchange) for exchange in self.pingpongs[from_idx:to_idx]]
94
+
95
async def gen_text(
    prompt,
    hf_model='mistralai/Mistral-7B-Instruct-v0.2',
    hf_token=None,
    parameters=None
):
    """Stream generated tokens for ``prompt`` from the HF Inference API.

    Args:
        prompt: text sent as ``inputs``.
        hf_model: model id appended to the inference endpoint URL.
        hf_token: bearer token; required.
        parameters: generation parameters; ``None`` selects the defaults below.

    Yields:
        The ``token.text`` field of every server-sent event.

    Raises:
        ValueError: if ``hf_token`` is None (raised on first iteration).

    Errors while reading the event stream are printed and end the stream
    (best effort, as before).
    """
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    if parameters is None:
        parameters = {
            'max_new_tokens': 512,
            'do_sample': True,
            'return_full_text': False,
            'temperature': 1.0,
            'top_k': 50,
            # 'top_p': 1.0,
            'repetition_penalty': 1.2
        }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': True,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    # NOTE: `requests` is blocking; the event loop is stalled while waiting
    # for the server.
    r = requests.post(
        url,
        headers=headers,
        data=json.dumps(data),
        stream=True
    )

    try:
        client = sseclient.SSEClient(r)
        for event in client.events():
            yield json.loads(event.data)['token']['text']
    except Exception as e:
        print(e)
    finally:
        # A streamed response keeps its connection open until it is closed
        # explicitly; release it even when the consumer stops early.
        r.close()
142
+
143
def gen_text_none_stream(
    prompt,
    hf_model='meta-llama/Llama-2-70b-chat-hf',
    hf_token=None,
):
    """Return the generated text for ``prompt`` from the HF Inference API (no streaming).

    Args:
        prompt: text sent as ``inputs``.
        hf_model: model id appended to the inference endpoint URL.
        hf_token: bearer token; required.

    Returns:
        The ``generated_text`` field of the first element of the JSON response.

    Raises:
        ValueError: if ``hf_token`` is None (same check as ``gen_text``).
        requests.HTTPError: if the API answers with an error status, instead
            of failing later while indexing an error payload.
    """
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")

    parameters = {
        'max_new_tokens': 64,
        'do_sample': True,
        'return_full_text': False,
        'temperature': 0.7,
        'top_k': 10,
        # 'top_p': 1.0,
        'repetition_penalty': 1.2
    }

    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': False,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }

    r = requests.post(
        url,
        headers=headers,
        data=json.dumps(data),
    )
    r.raise_for_status()

    return json.loads(r.text)[0]["generated_text"]
init.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import datasets
4
+ import pandas as pd
5
+ from collections import defaultdict
6
+
7
+ from datetime import datetime, timedelta
8
+ from background import process_arxiv_ids
9
+ from apscheduler.schedulers.background import BackgroundScheduler
10
+
11
+ def _count_nans(row):
12
+ count = 0
13
+
14
+ for _, (k, v) in enumerate(row.items()):
15
+ if v is None:
16
+ count = count + 1
17
+
18
+ return count
19
+
20
+ def _initialize_requested_arxiv_ids(request_ds):
21
+ requested_arxiv_ids = []
22
+
23
+ for request_d in request_ds['train']:
24
+ arxiv_ids = request_d['Requested arXiv IDs']
25
+ requested_arxiv_ids = requested_arxiv_ids + arxiv_ids
26
+
27
+ requested_arxiv_ids_df = pd.DataFrame({'Requested arXiv IDs': requested_arxiv_ids})
28
+ return requested_arxiv_ids_df
29
+
30
def _initialize_paper_info(source_ds):
    """Index the papers of ``source_ds['train']`` by title, date and arXiv id.

    Rows are grouped by ``target_date`` (formatted as ``YYYY-MM-DD``). When a
    row arrives for a date that already holds a paper with the same title,
    the stored paper is dropped if it has more ``None`` fields than the new
    row; the new row is appended in either case.

    Returns:
        ``(titles, date_dict, arxivid2data)`` where ``titles`` is the key view
        of the title index, ``date_dict[year][month][day]`` is the list of
        papers of that day, and ``arxivid2data[arxiv_id]`` is
        ``{"idx": running index, "paper": row}``.
    """
    date2qna = {}

    for data in source_ds["train"]:
        date = data["target_date"].strftime("%Y-%m-%d")

        if date not in date2qna:
            date2qna[date] = [data]
            continue

        same_day = date2qna[date]
        # Iterate over a snapshot because matching entries are removed from
        # the live list; a shallow copy is enough since rows are not mutated.
        for paper in list(same_day):
            if paper["title"] == data["title"] and _count_nans(paper) > _count_nans(data):
                same_day.remove(paper)

        same_day.append(data)

    title2qna = {}
    arxivid2data = {}
    date_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    count = 0

    for date, papers in date2qna.items():
        year, month, day = date.split("-")
        for paper in papers:
            title2qna[paper["title"]] = paper
            arxivid2data[paper['arxiv_id']] = {"idx": count, "paper": paper}
            date_dict[year][month][day].append(paper)
            # The counter was never advanced before, so every paper got idx 0.
            count += 1

    titles = title2qna.keys()

    return titles, date_dict, arxivid2data
63
+
64
def initialize_data(source_data_repo_id, request_data_repo_id):
    """Load both datasets and build the module-level lookup structures.

    Sets the globals ``date_dict``, ``arxivid2data`` and
    ``requested_arxiv_ids_df`` and returns
    ``(titles, date_dict, requested_arxiv_ids_df, arxivid2data)``.
    """
    global date_dict, arxivid2data
    global requested_arxiv_ids_df

    # Both repositories are loaded before any processing starts.
    source_ds, request_ds = [
        datasets.load_dataset(repo_id)
        for repo_id in (source_data_repo_id, request_data_repo_id)
    ]

    paper_info = _initialize_paper_info(source_ds)
    titles, date_dict, arxivid2data = paper_info
    requested_arxiv_ids_df = _initialize_requested_arxiv_ids(request_ds)

    return titles, date_dict, requested_arxiv_ids_df, arxivid2data
77
+
78
def update_dataframe(request_data_repo_id):
    """Reload the request dataset and return its ids as a fresh DataFrame."""
    return _initialize_requested_arxiv_ids(
        datasets.load_dataset(request_data_repo_id)
    )
81
+
82
def get_secrets():
    """Read the configuration from environment variables.

    Stores the first four values in module globals (``gemini_api_key``,
    ``hf_token``, ``dataset_repo_id``, ``request_arxiv_repo_id``) and returns
    ``(gemini_api_key, hf_token, dataset_repo_id, request_arxiv_repo_id,
    restart_repo_id)``. Only ``RESTART_TARGET_SPACE_REPO_ID`` has a fallback
    (``"chansung/paper_qa"``); the others are ``None`` when unset.
    """
    global gemini_api_key, hf_token
    global request_arxiv_repo_id, dataset_repo_id

    env = os.getenv
    gemini_api_key = env("GEMINI_API_KEY")
    hf_token = env("HF_TOKEN")
    dataset_repo_id = env("SOURCE_DATA_REPO_ID")
    request_arxiv_repo_id = env("REQUEST_DATA_REPO_ID")
    restart_repo_id = env("RESTART_TARGET_SPACE_REPO_ID", "chansung/paper_qa")

    return gemini_api_key, hf_token, dataset_repo_id, request_arxiv_repo_id, restart_repo_id
requirements.txt CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  google-generativeai
2
  pypdf2
3
  PyMuPDF
@@ -6,4 +10,4 @@ requests
6
  toml
7
  datasets
8
  flatdict
9
- APScheduler
 
1
+ bingbong
2
+ sseclient-py
3
+ chromadb
4
+ pydantic-settings
5
  google-generativeai
6
  pypdf2
7
  PyMuPDF
 
10
  toml
11
  datasets
12
  flatdict
13
+ APScheduler
ui.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import copy
3
+ import json
4
+ import datasets
5
+ import gradio as gr
6
+ import pandas as pd
7
+
8
+ from pingpong import PingPong
9
+ from pingpong.context import CtxLastWindowStrategy
10
+
11
+ from gen.openllm import gen_text as open_llm_gen_text
12
+ from gen.gemini_chat import gen_text as gemini_gen_text
13
+ from gen.gemini_chat import init as gemini_init
14
+ from constants.context import DEFAULT_GLOBAL_CTX
15
+
16
+ from init import (
17
+ requested_arxiv_ids_df,
18
+ date_dict,
19
+ arxivid2data,
20
+ request_arxiv_repo_id,
21
+ hf_token,
22
+ gemini_api_key
23
+ )
24
+ from utils import push_to_hf_hub
25
+
26
def get_paper_by_year(year):
    """Return month/day/paper dropdowns for the latest day of ``year``.

    The latest month of the year and the latest day of that month are
    selected. Paper titles are de-duplicated in first-seen order, so the
    dropdown order and its default do not depend on set ordering, which
    varies between processes for strings.
    """
    by_month = date_dict[year]
    months = sorted(by_month.keys())
    last_month = months[-1]

    by_day = by_month[last_month]
    days = sorted(by_day.keys())
    last_day = days[-1]

    papers = list(dict.fromkeys(
        paper["title"] for paper in by_day[last_day]
    ))

    return (
        gr.Dropdown(choices=months, value=last_month),
        gr.Dropdown(choices=days, value=last_day),
        gr.Dropdown(choices=papers, value=papers[0])
    )
42
+
43
def get_paper_by_month(year, month):
    """Return day/paper dropdowns for the latest day of ``year``/``month``.

    Paper titles are de-duplicated in first-seen order, so the dropdown order
    and its default do not depend on set ordering.
    """
    by_day = date_dict[year][month]
    days = sorted(by_day.keys())
    last_day = days[-1]

    papers = list(dict.fromkeys(
        paper["title"] for paper in by_day[last_day]
    ))

    return (
        gr.Dropdown(choices=days, value=last_day),
        gr.Dropdown(choices=papers, value=papers[0])
    )
55
+
56
def get_paper_by_day(year, month, day):
    """Return a paper dropdown for the given date, first title selected.

    Titles are de-duplicated in first-seen order, so the dropdown order and
    its default do not depend on set ordering.
    """
    papers = list(dict.fromkeys(
        paper["title"] for paper in date_dict[year][month][day]
    ))
    return gr.Dropdown(choices=papers, value=papers[0])
61
+
62
def set_papers(year, month, day, title):
    """Return ``(arxiv_id, paper dropdown, cleared textbox)`` for a selection.

    ``arxiv_id`` is the id of the paper of that date whose title equals
    ``title`` (the last match wins), or ``None`` when no paper matches.
    Previously a miss left the variable unbound and raised
    ``UnboundLocalError``. Titles are de-duplicated in first-seen order.
    """
    arxiv_id = None
    papers = []

    for paper in date_dict[year][month][day]:
        papers.append(paper["title"])
        if paper["title"] == title:
            arxiv_id = paper["arxiv_id"]

    papers = list(dict.fromkeys(papers))

    return (
        arxiv_id,
        gr.Dropdown(choices=papers, value=title),
        gr.Textbox("")
    )
76
+
77
def set_paper(year, month, day, paper_title):
    """Build every UI value needed to display one paper.

    Returns a 32-tuple: the arXiv id, the title, two badge Markdowns, the
    summary, and then for each of the three seed questions (prefixes ``0_``,
    ``1_``, ``2_``) nine Markdowns: question / ELI5 / technical answer,
    followed by the same triple for the depth and the breadth follow-up.

    Raises:
        gr.Error: if no paper with ``paper_title`` exists on that date
            (previously this failed with a ``TypeError`` on ``None``).
    """
    selected_paper = next(
        (
            paper for paper in date_dict[year][month][day]
            if paper["title"] == paper_title
        ),
        None,
    )
    if selected_paper is None:
        raise gr.Error(f"Paper [{paper_title}] not found on {year}-{month}-{day}")

    arxiv_id = selected_paper['arxiv_id']

    # (markdown prefix, key suffix) in display order; the dataset keys really
    # are spelled "additional_breath_q".
    qna_layout = (
        ("### 🙋 ", "question"),
        ("↪ **(ELI5)** ", "answers:eli5"),
        ("↪ **(Technical)** ", "answers:expert"),
        ("### 🙋🙋 ", "additional_depth_q:follow up question"),
        ("↪ **(ELI5)** ", "additional_depth_q:answers:eli5"),
        ("↪ **(Technical)** ", "additional_depth_q:answers:expert"),
        ("### 🙋🙋 ", "additional_breath_q:follow up question"),
        ("↪ **(ELI5)** ", "additional_breath_q:answers:eli5"),
        ("↪ **(Technical)** ", "additional_breath_q:answers:expert"),
    )

    header = (
        arxiv_id,
        gr.Markdown(f"# {selected_paper['title']}"),
        gr.Markdown(
            "[![arXiv](https://img.shields.io/badge/arXiv-%s-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/%s)" % (arxiv_id, arxiv_id)
        ),
        gr.Markdown(
            "[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-lg.svg)](https://huggingface.co/papers/%s)" % arxiv_id
        ),
        gr.Markdown(selected_paper["summary"]),
    )

    qna_blocks = []
    for question_idx in range(3):
        for prefix, key_suffix in qna_layout:
            value = selected_paper[f"{question_idx}_{key_suffix}"]
            qna_blocks.append(gr.Markdown(f"{prefix}{value}"))

    return header + tuple(qna_blocks)
127
+
128
def set_date(title):
    """Return year/month/day dropdowns for the first date holding ``title``.

    ``date_dict`` is scanned in insertion order. The month and day dropdowns
    get the sorted choices of the matching year and month. Returns ``None``
    implicitly when no paper has that title.
    """
    for year, months in date_dict.items():
        for month, days in months.items():
            for day, papers in days.items():
                if any(paper['title'] == title for paper in papers):
                    return (
                        gr.Dropdown(value=year),
                        gr.Dropdown(choices=sorted(months), value=month),
                        gr.Dropdown(choices=sorted(days), value=day),
                    )
139
+
140
def change_exp_type(exp_type):
    """Return 18 Markdown visibility updates as 9 alternating pairs.

    In each pair the first component is visible when ``exp_type`` is
    ``"ELI5"`` and the second otherwise.
    """
    show_eli5 = exp_type == "ELI5"
    pair = (show_eli5, not show_eli5)

    return tuple(
        gr.Markdown(visible=flag)
        for _ in range(9)
        for flag in pair
    )
153
+
154
def _filter_duplicate_arxiv_ids(
    arxiv_ids_to_be_added,
    request_repo_id="chansung/requested-arxiv-ids-3",
    source_repo_id="chansung/auto-paper-qa2",
):
    """Drop ids that are already requested or already processed.

    Args:
        arxiv_ids_to_be_added: candidate arXiv ids.
        request_repo_id: dataset whose ``'Requested arXiv IDs'`` lists hold
            the pending requests.
        source_repo_id: dataset whose ``'arxiv_id'`` column holds the
            processed papers.

    Returns:
        The candidates found in neither dataset, de-duplicated and in their
        original order.
    """
    ds1 = datasets.load_dataset(request_repo_id)
    ds2 = datasets.load_dataset(source_repo_id)

    known_arxiv_ids = set()

    for d in ds1['train']:
        # update() grows the set in place instead of rebuilding it per row.
        known_arxiv_ids.update(d['Requested arXiv IDs'])

    for d in ds2['train']:
        known_arxiv_ids.add(d['arxiv_id'])

    return [
        arxiv_id
        for arxiv_id in dict.fromkeys(arxiv_ids_to_be_added)
        if arxiv_id not in known_arxiv_ids
    ]
169
+
170
+ def _is_arxiv_id_valid(arxiv_id):
171
+ pattern = r"^\d{4}\.\d{5}$"
172
+ return bool(re.match(pattern, arxiv_id))
173
+
174
def _get_valid_arxiv_ids(arxiv_ids_str):
    """Split a comma-separated string into valid and invalid arXiv ids.

    Entries are stripped of surrounding whitespace. Blank entries (empty
    input, doubled or trailing commas) are skipped rather than reported as
    invalid ids.

    Returns:
        ``(valid_arxiv_ids, invalid_arxiv_ids)``, both in input order.
    """
    valid_arxiv_ids = []
    invalid_arxiv_ids = []

    for raw_id in arxiv_ids_str.split(","):
        arxiv_id = raw_id.strip()
        if not arxiv_id:
            continue

        if _is_arxiv_id_valid(arxiv_id):
            valid_arxiv_ids.append(arxiv_id)
        else:
            invalid_arxiv_ids.append(arxiv_id)

    return valid_arxiv_ids, invalid_arxiv_ids
186
+
187
def add_arxiv_ids_to_queue(queue, arxiv_ids_str):
    """Validate requested arXiv ids, enqueue the new ones and publish them.

    Args:
        queue: DataFrame with a ``'Requested arXiv IDs'`` column.
        arxiv_ids_str: comma-separated arXiv ids typed by the user.

    Returns:
        ``(queue, cleared textbox)``. ``queue`` gains one row per new id
        (each stored as a one-element list) and is re-indexed from 0; it is
        returned unchanged when nothing new was requested.

    The new ids are also pushed to the request dataset repository. The user
    is informed through ``gr.Warning`` messages.
    """
    valid_arxiv_ids, invalid_arxiv_ids = _get_valid_arxiv_ids(arxiv_ids_str)

    if len(invalid_arxiv_ids) > 0:
        gr.Warning(f"found invalid arXiv ids as in {invalid_arxiv_ids}")

    if len(valid_arxiv_ids) == 0:
        gr.Warning(f"No valid arXiv IDs found...")
        return queue, gr.Textbox("")

    valid_arxiv_ids = _filter_duplicate_arxiv_ids(valid_arxiv_ids)

    if len(valid_arxiv_ids) == 0:
        gr.Warning(f"All requested arXiv IDs are already processed or being processed...")
        return queue, gr.Textbox("")

    valid_arxiv_ids = [[arxiv_id] for arxiv_id in valid_arxiv_ids]
    gr.Warning(f"Processing on [{valid_arxiv_ids}]. Other requested arXiv IDs not found on this list should be already processed or being processed...")
    valid_arxiv_ids = pd.DataFrame({'Requested arXiv IDs': valid_arxiv_ids})

    # ignore_index renumbers the rows; the former reset_index() call threw
    # its result away and left duplicate index labels behind.
    queue = pd.concat([queue, valid_arxiv_ids], ignore_index=True)

    ds = datasets.Dataset.from_pandas(valid_arxiv_ids)
    push_to_hf_hub(ds, request_arxiv_repo_id, hf_token)

    return queue, gr.Textbox("")
213
+
214
+ # Chat
215
+
216
def before_chat_begin():
    """Return three non-interactive Button updates."""
    return tuple(gr.Button(interactive=False) for _ in range(3))
222
+
223
def _build_prompts(ppmanager, global_context, win_size=3):
    """Build prompts from a copy of ``ppmanager`` using ``global_context``.

    The manager is deep-copied so its own ``ctx`` is left untouched; the copy
    is then passed to a ``CtxLastWindowStrategy(win_size)``.
    """
    snapshot = copy.deepcopy(ppmanager)
    snapshot.ctx = global_context
    strategy = CtxLastWindowStrategy(win_size)
    return strategy(snapshot)
228
+
229
async def chat_stream(idx, local_data, user_prompt, chat_state, ctx_num_lconv=3):
    """Stream a Gemini answer about paper ``idx`` into the chat UI.

    The conversation is rebuilt from ``local_data``, ``user_prompt`` is added
    as a new exchange with an empty answer, and the paper's full text
    (newlines replaced by spaces, first 30000 characters) fills the global
    context template.

    Yields:
        ``("", chat UI rows, serialized manager, update, update, update)``.
        The three updates are non-interactive while chunks arrive and
        interactive in the final yield.
    """
    def _control_updates(enabled):
        # One update per control, all sharing the same interactive flag.
        return tuple(gr.update(interactive=enabled) for _ in range(3))

    paper = arxivid2data[idx]['paper']
    manager_cls = chat_state["ppmanager_type"]
    ppm = manager_cls.from_json(json.dumps(local_data))
    ppm.add_pingpong(PingPong(user_prompt, ""))

    paper_text = paper["full_text"].replace("\n", " ")[:30000]
    prompt = _build_prompts(ppm, DEFAULT_GLOBAL_CTX % paper_text, ctx_num_lconv)
    print(prompt)

    # A commented-out call to open_llm_gen_text used to sit here as an
    # alternative to the Gemini call below.
    gemini_init(gemini_api_key)
    async for chunk in gemini_gen_text(prompt):
        ppm.append_pong(chunk)
        yield ("", ppm.build_uis(), str(ppm)) + _control_updates(False)

    yield ("", ppm.build_uis(), str(ppm)) + _control_updates(True)
259
+
260
def chat_reset(local_data, chat_state):
    """Clear the conversation rebuilt from ``local_data``.

    Returns ``("", chat UI rows, serialized manager, update, update, update)``
    where the three updates are interactive.
    """
    manager_cls = chat_state["ppmanager_type"]
    ppm = manager_cls.from_json(json.dumps(local_data))
    ppm.pingpongs = []

    cleared = ("", ppm.build_uis(), str(ppm))
    return cleared + tuple(gr.update(interactive=True) for _ in range(3))
utils.py CHANGED
@@ -1,28 +1,21 @@
1
- import pandas as pd
2
  import datasets
3
- from datasets import Dataset
 
4
  from huggingface_hub import create_repo
5
  from huggingface_hub.utils import HfHubHTTPError
6
 
7
  def push_to_hf_hub(
8
- qnas, repo_id, token, append=True
9
  ):
10
- print(1)
11
  exist = False
12
- df = pd.DataFrame([qnas])
13
- ds = Dataset.from_pandas(df)
14
- ds = ds.cast_column("target_date", datasets.features.Value("timestamp[s]"))
15
 
16
- print(2)
17
  try:
18
- create_repo(repo_id, repo_type="dataset", token=token)
19
  except HfHubHTTPError as e:
20
  exist = True
21
 
22
  if exist and append:
23
- print(3)
24
  existing_ds = datasets.load_dataset(repo_id)
25
  ds = datasets.concatenate_datasets([existing_ds['train'], ds])
26
 
27
- print(4)
28
- ds.push_to_hub(repo_id, token=token)
 
 
1
  import datasets
2
+ import pandas as pd
3
+
4
  from huggingface_hub import create_repo
5
  from huggingface_hub.utils import HfHubHTTPError
6
 
7
def push_to_hf_hub(
    ds, repo_id, hf_token, append=True
):
    """Push ``ds`` to the dataset repository ``repo_id``.

    The repository is created first. If creation fails with an
    ``HfHubHTTPError`` the repository is treated as already existing, and
    with ``append=True`` its current ``train`` split is put in front of
    ``ds`` before pushing.
    """
    try:
        create_repo(repo_id, repo_type="dataset", token=hf_token)
    except HfHubHTTPError:
        repo_already_exists = True
    else:
        repo_already_exists = False

    if repo_already_exists and append:
        current_train = datasets.load_dataset(repo_id)['train']
        ds = datasets.concatenate_datasets([current_train, ds])

    ds.push_to_hub(repo_id, token=hf_token)