shaocongma committed on
Commit d709aaf · 2 Parent(s): 780901a 4810aa4

Merge remote-tracking branch 'origin/main' into main

app.py CHANGED
@@ -6,18 +6,21 @@ from utils.file_operations import hash_name
 
 # note: App白屏bug:允许第三方cookie
 # todo:
-# 5. Use some simple method for simple tasks
-#     (including: writing abstract, conclusion, generate keywords, generate figures...)
-#     5.1 Use GPT 3.5 for abstract, conclusion, ... (or may not)
-#     5.2 Use local LLM to generate keywords, figures, ...
-#     5.3 Use embedding to find most related papers (find a paper dataset)
-# 6. get logs when the procedure is not completed.
+# 6. get logs when the procedure is not completed. *
 # 7. 自己的文件库; 更多的prompts
-# 11. distinguish citep and citet
+# 8. Decide on how to generate the main part of a paper * (Langchain/AutoGPT
+# 9. Load .bibtex file to generate a pre-defined references list. *
+# 1. 把paper改成纯JSON?
+# 2. 实现别的功能
+# 3. Check API Key GPT-4 Support.
+# 8. Re-build some components using `langchain`
+#    - in `references.py`, use PromptTemplates.format -> str
+#    - in `gpt_interation`, use LLM
 # future:
-# 8. Change prompts to langchain
 # 4. add auto_polishing function
 # 12. Change link to more appealing color  # after the website is built;
+# 1. Check if there are any duplicated citations
+# 2. Remove potential thebibliography and bibitem in .tex file
 
 openai_key = os.getenv("OPENAI_API_KEY")
 access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
@@ -109,19 +112,77 @@ with gr.Blocks(theme=theme) as demo:
 
     输入想要生成的论文名称(比如Playing Atari with Deep Reinforcement Learning), 点击Submit, 等待大概十分钟, 下载.zip格式的输出,在Overleaf上编译浏览.
     ''')
+
     with gr.Row():
         with gr.Column(scale=2):
            key = gr.Textbox(value=openai_key, lines=1, max_lines=1, label="OpenAI Key",
                             visible=not IS_OPENAI_API_KEY_AVAILABLE)
+
            # generator = gr.Dropdown(choices=["学术论文", "文献总结"], value="文献总结",
            #                         label="Selection", info="目前支持生成'学术论文'和'文献总结'.", interactive=True)
-           title = gr.Textbox(value="Playing Atari with Deep Reinforcement Learning", lines=1, max_lines=1,
-                              label="Title", info="论文标题")
-           description = gr.Textbox(lines=5, label="Description (Optional)", visible=False)
 
-           with gr.Row():
-               clear_button = gr.Button("Clear")
-               submit_button = gr.Button("Submit", variant="primary")
+           # 每个功能做一个tab
+           with gr.Tab("学术论文"):
+               title = gr.Textbox(value="Playing Atari with Deep Reinforcement Learning", lines=1, max_lines=1,
+                                  label="Title", info="论文标题")
+
+               with gr.Accordion("高级设置", open=False):
+                   description_pp = gr.Textbox(lines=5, label="Description (Optional)", visible=True,
+                                               info="对希望生成的论文的一些描述. 包括这篇论文的创新点, 主要贡献, 等.")
+
+                   interactive = False
+                   gr.Markdown('''
+                   ## 下面的功能我只做了UI, 还没来得及实现功能.
+                   ''')
+                   with gr.Row():
+                       with gr.Column():
+                           gr.Markdown('''
+                           Upload .bib file (Optional)
+
+                           通过上传.bib文件来控制GPT-4模型必须参考哪些文献.
+                           ''')
+                           bibtex_file = gr.File(label="Upload .bib file", file_types=["text"],
+                                                 interactive=interactive)
+                       with gr.Column():
+                           search_engine = gr.Dropdown(label="Search Engine",
+                                                       choices=["ArXiv", "Semantic Scholar", "Google Scholar", "None"],
+                                                       value= "Semantic Scholar",
+                                                       interactive=interactive,
+                                                       info="用于决定GPT-4用什么搜索引擎来搜索文献. 选择None的时候仅参考给定文献.")
+                           tldr = gr.Checkbox(value=True, label="TLDR;",
+                                              info="选择此筐表示将使用Semantic Scholar的TLDR作为文献的总结.",
+                                              interactive = interactive),
+                           use_cache = gr.Checkbox(label="总是重新生成",
+                                                   info="选择此筐表示将不会读取已经生成好的文章.",
+                                                   interactive = interactive)
+                   slider = gr.Slider(minimum=1, maximum=30, value=20, label="最大参考文献数目",
+                                      info="过多参考文献会超出Token数限制导致报错,这里限制最大参考文献数目.")
+
+               with gr.Row():
+                   clear_button_pp = gr.Button("Clear")
+                   submit_button_pp = gr.Button("Submit", variant="primary")
+
+           with gr.Tab("文献综述"):
+               gr.Markdown('''
+               <h1 style="text-align: center;">Coming soon!</h1>
+               ''')
+               # topic = gr.Textbox(value="Deep Reinforcement Learning", lines=1, max_lines=1,
+               #                    label="Topic", info="文献主题")
+               # with gr.Accordion("Advanced Setting"):
+               #     description_lr = gr.Textbox(lines=5, label="Description (Optional)", visible=True,
+               #                                 info="对希望生成的综述的一些描述. 包括这篇论文的创新点, 主要贡献, 等.")
+               # with gr.Row():
+               #     clear_button_lr = gr.Button("Clear")
+               #     submit_button_lr = gr.Button("Submit", variant="primary")
+           with gr.Tab("论文润色"):
+               gr.Markdown('''
+               <h1 style="text-align: center;">Coming soon!</h1>
+               ''')
+           with gr.Tab("帮我想想该写什么论文!"):
+               gr.Markdown('''
+               <h1 style="text-align: center;">Coming soon!</h1>
+               ''')
+
        with gr.Column(scale=1):
            style_mapping = {True: "color:white;background-color:green",
                             False: "color:white;background-color:red"}  # todo: to match website's style
@@ -133,8 +194,8 @@ with gr.Blocks(theme=theme) as demo:
                `OpenAI API`: <span style="{style_mapping[IS_OPENAI_API_KEY_AVAILABLE]}">{availability_mapping[IS_OPENAI_API_KEY_AVAILABLE]}</span>. `Cache`: <span style="{style_mapping[IS_CACHE_AVAILABLE]}">{availability_mapping[IS_CACHE_AVAILABLE]}</span>.''')
            file_output = gr.File(label="Output")
 
-    clear_button.click(fn=clear_inputs, inputs=[title, description], outputs=[title, description])
-    submit_button.click(fn=wrapped_generator, inputs=[title, description, key], outputs=file_output)
+    clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
+    submit_button_pp.click(fn=wrapped_generator, inputs=[title, description_pp, key], outputs=file_output)
 
 demo.queue(concurrency_count=1, max_size=5, api_open=False)
 demo.launch()
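
Editor's note: `clear_inputs` and `wrapped_generator` are defined elsewhere in app.py and are not shown in this diff. A minimal, hedged sketch of how the new per-tab wiring behaves, with illustrative stand-ins for those two callables (everything below is an assumption for demonstration, not part of the commit):

# Hypothetical sketch -- assumes gradio is installed; names outside the diff are stand-ins.
import gradio as gr

def clear_inputs(title, description):
    # Reset both textboxes to empty strings.
    return "", ""

def wrapped_generator(title, description, key):
    # Stand-in for the real generator, which would produce and return a .zip path.
    return None

with gr.Blocks() as demo:
    with gr.Tab("学术论文"):
        title = gr.Textbox(label="Title")
        description_pp = gr.Textbox(lines=5, label="Description (Optional)")
        key = gr.Textbox(label="OpenAI Key")
        with gr.Row():
            clear_button_pp = gr.Button("Clear")
            submit_button_pp = gr.Button("Submit", variant="primary")
    file_output = gr.File(label="Output")

    # Same wiring as the diff: each tab gets its own buttons, all writing to one File output.
    clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
    submit_button_pp.click(fn=wrapped_generator, inputs=[title, description_pp, key], outputs=file_output)

demo.launch()
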
auto_backgrounds.py CHANGED
@@ -91,7 +91,7 @@ def fake_generator(title, description="", template="ICLR2022", model="gpt-4"):
     return make_archive("sample-output.pdf", filename)
 
 
-def generate_draft(title, description="", template="ICLR2022", model="gpt-4", search_engine="ss", tldr=True, max_kw_refs=14):
+def generate_draft(title, description="", template="ICLR2022", model="gpt-4", search_engine="ss", tldr=True, max_kw_refs=10):
     paper, destination_folder, _ = _generation_setup(title, description, template, model, search_engine, tldr, max_kw_refs)
 
     # todo: `list_of_methods` failed to be generated; find a solution ...
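
The only change here lowers the default `max_kw_refs` from 14 to 10, which matches the new slider's note in app.py about too many references exceeding the token limit. A hedged usage sketch (assumes the repository's dependencies and an OPENAI_API_KEY are configured; the return value is assumed to mirror `fake_generator`, i.e. the path of the zipped draft):

# Illustrative only -- not part of the commit.
from auto_backgrounds import generate_draft

zip_path = generate_draft(
    "Playing Atari with Deep Reinforcement Learning",
    search_engine="ss",   # Semantic Scholar
    tldr=True,            # use TLDR summaries instead of full abstracts
)                         # max_kw_refs now defaults to 10
print(zip_path)
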
latex_templates/pre_refs.bib ADDED
@@ -0,0 +1,17 @@
+
+@article{1512.07669,
+  title = {Reinforcement Learning: Stochastic Approximation Algorithms for Markov
+           Decision Processes},
+  author = {Vikram Krishnamurthy},
+  journal={arXiv preprint arXiv:1512.07669},
+  year = {2015},
+  url = {http://arxiv.org/abs/1512.07669v1}
+}
+
+@article{1511.02377,
+  title = {The Value Functions of Markov Decision Processes},
+  author = {Ehud Lehrer , Eilon Solan , Omri N. Solan},
+  journal={arXiv preprint arXiv:1511.02377},
+  year = {2015},
+  url = {http://arxiv.org/abs/1511.02377v1}
+}
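
These two entries are what the new `load_papers_from_bibtex` helper in utils/references.py (below) consumes via `bibtexparser`. A quick sanity-check sketch (assumes `bibtexparser` is installed):

# Illustrative only -- not part of the commit.
import bibtexparser

with open("latex_templates/pre_refs.bib") as f:
    db = bibtexparser.load(f)

for entry in db.entries:
    # Each entry is a plain dict keyed by "ID", "title", "author", "journal", "year", "url".
    print(entry["ID"], "-", entry.get("title"))
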
utils/references.py CHANGED
@@ -1,18 +1,26 @@
-# Generate references
-# 1. select most correlated references from "references" dataset or Arxiv search engine.
-# 2. Generate bibtex from the selected papers. --> to_bibtex()
-# 3. Generate prompts from the selected papers: --> to_prompts()
-#    {"paper_id": "paper summary"}
-
+# Each `paper` is a dictionary containing:
+#       (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal
+#
+# Generate references:
+#   `Reference` class:
+#       1. Read a given .bib file to collect papers; use `search_paper_abstract` method to fill missing abstract.
+#       2. Given some keywords; use ArXiv or Semantic Scholar API to find papers.
+#       3. Generate bibtex from the selected papers. --> to_bibtex()
+#       4. Generate prompts from the selected papers: --> to_prompts()
+#               A sample prompt: {"paper_id": "paper summary"}
 
 import requests
 import re
+import bibtexparser
+from scholarly import scholarly
+from scholarly import ProxyGenerator
 
 
 ######################################################################################################################
 # Some basic tools
 ######################################################################################################################
 def remove_newlines(serie):
+    # This function is applied to the abstract of each paper to reduce the length of prompts.
     serie = serie.replace('\n', ' ')
     serie = serie.replace('\\n', ' ')
     serie = serie.replace('  ', ' ')
@@ -20,6 +28,47 @@ def remove_newlines(serie):
     return serie
 
 
+def search_paper_abstract(title):
+    pg = ProxyGenerator()
+    success = pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")
+    scholarly.use_proxy(pg)
+    # input the title of a paper, return its abstract
+    search_query = scholarly.search_pubs(title)
+    paper = next(search_query)
+    return remove_newlines(paper['bib']['abstract'])
+
+
+def load_papers_from_bibtex(bib_file_path):
+    with open(bib_file_path) as bibtex_file:
+        bib_database = bibtexparser.load(bibtex_file)
+    if len(bib_database.entries) == 0:
+        return []
+    else:
+        bib_papers = []
+        for bibitem in bib_database.entries:
+            paper_id = bibitem.get("ID")
+            title = bibitem.get("title")
+            if title is None:
+                continue
+            journal = bibitem.get("journal")
+            year = bibitem.get("year")
+            author = bibitem.get("author")
+            abstract = bibitem.get("abstract")
+            if abstract is None:
+                abstract = search_paper_abstract(title)
+            result = {
+                "paper_id": paper_id,
+                "title": title,
+                "link": "",
+                "abstract": abstract,
+                "authors": author,
+                "year": year,
+                "journal": journal
+            }
+            bib_papers.append(result)
+        return bib_papers
+
+
 ######################################################################################################################
 # Semantic Scholar (SS) API
 ######################################################################################################################
@@ -63,7 +112,11 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
         pattern = r'^\w+'
         words = re.findall(pattern, title)
         # return last_name + year_str + title.split(' ', 1)[0]
-        return last_name + year_str + words[0]
+        try:
+            output = last_name + year_str + words[0]
+        except IndexError:
+            output = last_name + year_str + title[:4]
+        return output
 
     def extract_author_info(raw_authors):
         authors = [author['name'] for author in raw_authors]
@@ -71,7 +124,7 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
         authors_str = " and ".join(authors)
         try:
             last_name = authors[0].split()[-1]
-        except:
+        except IndexError:
             last_name = "ma"
         # pattern = r'^\w+'
         # last_name = re.findall(pattern, authors[0])
@@ -79,7 +132,7 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
 
     def parse_search_results(search_results_ss):
         # turn the search result to a list of paper dictionary.
-        papers = []
+        papers_ss = []
        for raw_paper in search_results_ss:
            if raw_paper["abstract"] is None:
                continue
@@ -100,14 +153,14 @@ def _collect_papers_ss(keyword, counts=3, tldr=False):
             result = {
                 "paper_id": paper_id,
                 "title": title,
-                "abstract": abstract,  # todo: compare results with tldr
+                "abstract": abstract,
                 "link": link,
                 "authors": authors_str,
                 "year": year_str,
                 "journal": journal
             }
-            papers.append(result)
-        return papers
+            papers_ss.append(result)
+        return papers_ss
 
     raw_results = ss_search(keyword, limit=counts)
     if raw_results is not None:
@@ -192,13 +245,13 @@ def _collect_papers_arxiv(keyword, counts=3, tldr=False):
 # References Class
 ######################################################################################################################
 
-# Each `paper` is a dictionary containing (1) paper_id (2) title (3) authors (4) year (5) link (6) abstract (7) journal
 class References:
     def __init__(self, load_papers=""):
         if load_papers:
-            # todo: read a json file from the given path
-            #       this could be used to support pre-defined references
-            pass
+            # todo: (1) too large bibtex may make have issues on token limitations; may truncate to 5 or 10
+            #       (2) google scholar didn't give a full abstract for some papers ...
+            #       (3) may use langchain to support long input
+            self.papers = load_papers_from_bibtex(load_papers)
         else:
             self.papers = []
 
@@ -266,15 +319,20 @@
 
 
 if __name__ == "__main__":
-    refs = References()
-    keywords_dict = {
-        "Deep Q-Networks": 15,
-        "Policy Gradient Methods": 24,
-        "Actor-Critic Algorithms": 4,
-        "Model-Based Reinforcement Learning": 13,
-        "Exploration-Exploitation Trade-off": 7
-    }
-    refs.collect_papers(keywords_dict, method="ss", tldr=True)
-    for p in refs.papers:
-        print(p["paper_id"])
-    print(len(refs.papers))
+    # refs = References()
+    # keywords_dict = {
+    #     "Deep Q-Networks": 15,
+    #     "Policy Gradient Methods": 24,
+    #     "Actor-Critic Algorithms": 4,
+    #     "Model-Based Reinforcement Learning": 13,
+    #     "Exploration-Exploitation Trade-off": 7
+    # }
+    # refs.collect_papers(keywords_dict, method="ss", tldr=True)
+    # for p in refs.papers:
+    #     print(p["paper_id"])
+    # print(len(refs.papers))
+
+    bib = "D:\\Projects\\auto-draft\\latex_templates\\pre_refs.bib"
+    papers = load_papers_from_bibtex(bib)
+    for paper in papers:
+        print(paper)
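
With this change, passing a .bib path to the `References` constructor pre-populates `self.papers` instead of being a no-op. A short usage sketch based on the module's own header comment (a sketch only; `to_bibtex`/`to_prompts` are mentioned in that comment but their bodies are outside this diff):

# Illustrative only -- not part of the commit.
from utils.references import References, load_papers_from_bibtex

# Option 1: go through the class, as the new __main__ block does via the helper.
# Note: entries without an "abstract" field trigger a Google Scholar lookup through
# search_paper_abstract, which needs the scholarly proxy to be working.
refs = References(load_papers="latex_templates/pre_refs.bib")
print(len(refs.papers))  # 2 entries from pre_refs.bib

# Option 2: call the helper directly when only the parsed dictionaries are needed.
papers = load_papers_from_bibtex("latex_templates/pre_refs.bib")
for p in papers:
    print(p["paper_id"], p["title"])
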
utils/tex_processing.py CHANGED
@@ -2,16 +2,12 @@ import os
 
 def replace_title(save_to_path, title):
     # Define input and output file names
-    # input_file_name = save_to_path + "/template.tex"
-    # output_file_name = save_to_path + "/main.tex"
     input_file_name = os.path.join(save_to_path, "template.tex")
     output_file_name = os.path.join(save_to_path , "main.tex")
 
     # Open the input file and read its content
     with open(input_file_name, 'r') as infile:
         content = infile.read()
-
-    # Replace all occurrences of "asdfgh" with "hahaha"
     content = content.replace(r"\title{TITLE} ", f"\\title{{{title}}} ")
 
     # Open the output file and write the modified content
@@ -19,11 +15,14 @@ def replace_title(save_to_path, title):
         outfile.write(content)
 
 
-# return all string in \cite{...}.
+# return all string in \cite{...} \citet{...} or \citep{...}.
 
 # check if citations are in bibtex.
 
 
 # replace citations
 
-# sometimes the output may include thebibliography and bibitem . remove all of it.
+# sometimes the output may include thebibliography and bibitem . remove all of it.
+
+
+
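
The trailing comments sketch three follow-up tasks: collect every \cite/\citet/\citep key, check those keys against the bibtex database, and strip any thebibliography/bibitem block the model emits. A minimal regex-based sketch of what such helpers could look like (illustrative only; none of these functions exist in the repository yet):

# Illustrative only -- these helpers are still todos in tex_processing.py.
import re

def extract_citations(tex):
    # Return every key appearing in \cite{...}, \citet{...} or \citep{...} (comma-separated keys included).
    keys = []
    for group in re.findall(r"\\cite[tp]?\{([^}]*)\}", tex):
        keys.extend(k.strip() for k in group.split(",") if k.strip())
    return keys

def missing_citations(tex, bibtex_ids):
    # Keys cited in the .tex file but absent from the .bib database.
    known = set(bibtex_ids)
    return [k for k in extract_citations(tex) if k not in known]

def remove_inline_bibliography(tex):
    # Drop any {thebibliography} environment (and its \bibitem entries) that the model may have generated.
    return re.sub(r"\\begin\{thebibliography\}.*?\\end\{thebibliography\}", "", tex, flags=re.DOTALL)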