shaocongma commited on
Commit
c160ff7
·
1 Parent(s): 72c76c9

Add references generation.

Browse files
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import openai
4
  from auto_backgrounds import generate_backgrounds, generate_draft
5
  from utils.file_operations import hash_name
 
6
 
7
  # note: App白屏bug:允许第三方cookie
8
  # todo:
@@ -48,6 +49,9 @@ else:
48
  def clear_inputs(*args):
49
  return "", ""
50
 
 
 
 
51
 
52
  def wrapped_generator(paper_title, paper_description, openai_api_key=None,
53
  paper_template="ICLR2022", tldr=True, max_num_refs=50, selected_sections=None, bib_refs=None, model="gpt-4",
@@ -91,6 +95,11 @@ def wrapped_generator(paper_title, paper_description, openai_api_key=None,
91
  return output
92
 
93
 
 
 
 
 
 
94
  theme = gr.themes.Default(font=gr.themes.GoogleFont("Questrial"))
95
  # .set(
96
  # background_fill_primary='#E5E4E2',
@@ -105,6 +114,14 @@ ACADEMIC_PAPER = """## 一键生成论文初稿
105
  3. 在右侧下载.zip格式的输出,在Overleaf上编译浏览.
106
  """
107
 
 
 
 
 
 
 
 
 
108
  with gr.Blocks(theme=theme) as demo:
109
  gr.Markdown('''
110
  # Auto-Draft: 文献整理辅助工具
@@ -176,23 +193,22 @@ with gr.Blocks(theme=theme) as demo:
176
  clear_button_pp = gr.Button("Clear")
177
  submit_button_pp = gr.Button("Submit", variant="primary")
178
 
179
- with gr.Tab("文献综述"):
180
- gr.Markdown('''
181
- <h1 style="text-align: center;">Coming soon!</h1>
182
- ''')
183
- # topic = gr.Textbox(value="Deep Reinforcement Learning", lines=1, max_lines=1,
184
- # label="Topic", info="文献主题")
185
- # with gr.Accordion("Advanced Setting"):
186
- # description_lr = gr.Textbox(lines=5, label="Description (Optional)", visible=True,
187
- # info="对希望生成的综述的一些描述. 包括这篇论文的创新点, 主要贡献, 等.")
188
- # with gr.Row():
189
- # clear_button_lr = gr.Button("Clear")
190
- # submit_button_lr = gr.Button("Submit", variant="primary")
191
- with gr.Tab("论文润色"):
192
  gr.Markdown('''
193
  <h1 style="text-align: center;">Coming soon!</h1>
194
  ''')
195
- with gr.Tab("帮我想想该写什么论文!"):
196
  gr.Markdown('''
197
  <h1 style="text-align: center;">Coming soon!</h1>
198
  ''')
@@ -207,13 +223,16 @@ with gr.Blocks(theme=theme) as demo:
207
  当`Cache`显示AVAILABLE的时候, 所有的输入和输出会被备份到我的云储存中. 显示NOT AVAILABLE的时候不影响实际使用.
208
  `OpenAI API`: <span style="{style_mapping[IS_OPENAI_API_KEY_AVAILABLE]}">{availability_mapping[IS_OPENAI_API_KEY_AVAILABLE]}</span>. `Cache`: <span style="{style_mapping[IS_CACHE_AVAILABLE]}">{availability_mapping[IS_CACHE_AVAILABLE]}</span>.''')
209
  file_output = gr.File(label="Output")
 
210
 
211
  clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
212
- # submit_button_pp.click(fn=wrapped_generator,
213
- # inputs=[title, description_pp, key, template, tldr, slider, sections, bibtex_file], outputs=file_output)
214
  submit_button_pp.click(fn=wrapped_generator,
215
  inputs=[title, description_pp, key, template, tldr_checkbox, slider, sections, bibtex_file,
216
  model_selection], outputs=file_output)
217
 
 
 
 
 
218
  demo.queue(concurrency_count=1, max_size=5, api_open=False)
219
  demo.launch()
 
3
  import openai
4
  from auto_backgrounds import generate_backgrounds, generate_draft
5
  from utils.file_operations import hash_name
6
+ from references_generator import generate_top_k_references
7
 
8
  # note: App白屏bug:允许第三方cookie
9
  # todo:
 
49
  def clear_inputs(*args):
50
  return "", ""
51
 
52
+ def clear_inputs_refs(*args):
53
+ return "", 5
54
+
55
 
56
  def wrapped_generator(paper_title, paper_description, openai_api_key=None,
57
  paper_template="ICLR2022", tldr=True, max_num_refs=50, selected_sections=None, bib_refs=None, model="gpt-4",
 
95
  return output
96
 
97
 
98
+ def wrapped_references_generator(paper_title, num_refs):
99
+ return generate_top_k_references(paper_title, top_k=num_refs)
100
+
101
+
102
+
103
  theme = gr.themes.Default(font=gr.themes.GoogleFont("Questrial"))
104
  # .set(
105
  # background_fill_primary='#E5E4E2',
 
114
  3. 在右侧下载.zip格式的输出,在Overleaf上编译浏览.
115
  """
116
 
117
+
118
+ REFERENCES = """## 一键搜索相关论文
119
+
120
+ 1. 在Title文本框中输入想要搜索文献的论文(比如Playing Atari with Deep Reinforcement Learning).
121
+ 2. 点击Submit. 等待大概十分钟.
122
+ 3. 在右侧JSON处会显示相关文献.
123
+ """
124
+
125
  with gr.Blocks(theme=theme) as demo:
126
  gr.Markdown('''
127
  # Auto-Draft: 文献整理辅助工具
 
193
  clear_button_pp = gr.Button("Clear")
194
  submit_button_pp = gr.Button("Submit", variant="primary")
195
 
196
+ with gr.Tab("文献搜索 (NEW!)"):
197
+ gr.Markdown(REFERENCES)
198
+
199
+ title_refs = gr.Textbox(value="Playing Atari with Deep Reinforcement Learning", lines=1, max_lines=1,
200
+ label="Title", info="论文标题")
201
+ slider_refs = gr.Slider(minimum=1, maximum=100, value=5, step=1,
202
+ interactive=True, label="最相关的参考文献数目")
203
+ with gr.Row():
204
+ clear_button_refs = gr.Button("Clear")
205
+ submit_button_refs = gr.Button("Submit", variant="primary")
206
+
207
+ with gr.Tab("文献综述 (Coming soon!)"):
 
208
  gr.Markdown('''
209
  <h1 style="text-align: center;">Coming soon!</h1>
210
  ''')
211
+ with gr.Tab("Github文档 (Coming soon!)"):
212
  gr.Markdown('''
213
  <h1 style="text-align: center;">Coming soon!</h1>
214
  ''')
 
223
  当`Cache`显示AVAILABLE的时候, 所有的输入和输出会被备份到我的云储存中. 显示NOT AVAILABLE的时候不影响实际使用.
224
  `OpenAI API`: <span style="{style_mapping[IS_OPENAI_API_KEY_AVAILABLE]}">{availability_mapping[IS_OPENAI_API_KEY_AVAILABLE]}</span>. `Cache`: <span style="{style_mapping[IS_CACHE_AVAILABLE]}">{availability_mapping[IS_CACHE_AVAILABLE]}</span>.''')
225
  file_output = gr.File(label="Output")
226
+ json_output = gr.JSON(label="References")
227
 
228
  clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
 
 
229
  submit_button_pp.click(fn=wrapped_generator,
230
  inputs=[title, description_pp, key, template, tldr_checkbox, slider, sections, bibtex_file,
231
  model_selection], outputs=file_output)
232
 
233
+ clear_button_refs.click(fn=clear_inputs_refs, inputs=[title_refs, slider_refs], outputs=[title_refs, slider_refs])
234
+ submit_button_refs.click(fn=wrapped_references_generator,
235
+ inputs=[title_refs, slider_refs], outputs=json_output)
236
+
237
  demo.queue(concurrency_count=1, max_size=5, api_open=False)
238
  demo.launch()
latex_templates/ICLR2022/fig.png CHANGED
latex_templates/ICLR2022/template.tex CHANGED
@@ -7,7 +7,7 @@
7
  \usepackage{hyperref}
8
  \usepackage{url}
9
  \usepackage{algorithm}
10
- \usepackage{algorithmic}
11
 
12
  \title{TITLE}
13
  \author{GPT-4}
 
7
  \usepackage{hyperref}
8
  \usepackage{url}
9
  \usepackage{algorithm}
10
+ \usepackage{algpseudocode}
11
 
12
  \title{TITLE}
13
  \author{GPT-4}
references_generator.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import json
3
+ from utils.references import References
4
+ from section_generator import section_generation_bg, keywords_generation, figures_generation, section_generation
5
+ import itertools
6
+ from gradio_client import Client
7
+
8
+ def generate_raw_references(title, description="",
9
+ bib_refs=None, tldr=False, max_kw_refs=10, save_to="ref.bib"):
10
+ # load pre-provided references
11
+ ref = References(title, bib_refs)
12
+
13
+ # generate multiple keywords for searching
14
+ input_dict = {"title": title, "description": description}
15
+ keywords, usage = keywords_generation(input_dict)
16
+ keywords = list(keywords)
17
+ comb_keywords = list(itertools.combinations(keywords, 2))
18
+ for comb_keyword in comb_keywords:
19
+ keywords.append(" ".join(comb_keyword))
20
+ keywords = {keyword:max_kw_refs for keyword in keywords}
21
+ print(f"keywords: {keywords}\n\n")
22
+
23
+ ref.collect_papers(keywords, tldr=tldr)
24
+ paper_json = ref.to_json()
25
+
26
+ with open(save_to, "w") as f:
27
+ json.dump(paper_json, f)
28
+
29
+ return save_to, paper_json
30
+
31
+ def generate_top_k_references(title, description="",
32
+ bib_refs=None, tldr=False, max_kw_refs=10, save_to="ref.bib", top_k=5):
33
+ json_path, json_content = generate_raw_references(title, description, bib_refs, tldr, max_kw_refs, save_to)
34
+
35
+ client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
36
+ result = client.predict(
37
+ title, # str in 'Title' Textbox component
38
+ json_path, # str (filepath or URL to file) in 'Papers JSON (as string)' File component
39
+ top_k, # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
40
+ api_name="/get_k_relevant_papers"
41
+ )
42
+ with open(result) as f:
43
+ result = json.load(f)
44
+ return result
45
+
46
+ if __name__ == "__main__":
47
+ import openai
48
+ openai.api_key = os.getenv("OPENAI_API_KEY")
49
+
50
+ title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
51
+ description = ""
52
+ save_to = "paper.json"
53
+ save_to, paper_json = generate_raw_references(title, description, save_to=save_to)
54
+
55
+ print("`paper.json` has been generated. Now evaluating its similarity...")
56
+
57
+ k = 5
58
+ client = Client("https://shaocongma-evaluate-specter-embeddings.hf.space/")
59
+ result = client.predict(
60
+ title, # str in 'Title' Textbox component
61
+ save_to, # str (filepath or URL to file) in 'Papers JSON (as string)' File component
62
+ k, # int | float (numeric value between 1 and 50) in 'Top-k Relevant Papers' Slider component
63
+ api_name="/get_k_relevant_papers"
64
+ )
65
+
66
+ with open(result) as f:
67
+ result = json.load(f)
68
+
69
+ print(result)
70
+
71
+ save_to = "paper2.json"
72
+ with open(save_to, "w") as f:
73
+ json.dump(result, f)
section_generator.py CHANGED
@@ -90,7 +90,7 @@ def keywords_generation(input_dict):
90
  attempts_count = 0
91
  while attempts_count < max_attempts:
92
  try:
93
- keywords, usage= get_gpt_responses(KEYWORDS_SYSTEM.format(min_refs_num=3, max_refs_num=5), title,
94
  model="gpt-3.5-turbo", temperature=0.4)
95
  print(keywords)
96
  output = json.loads(keywords)
 
90
  attempts_count = 0
91
  while attempts_count < max_attempts:
92
  try:
93
+ keywords, usage= get_gpt_responses(KEYWORDS_SYSTEM.format(min_refs_num=1, max_refs_num=10), title,
94
  model="gpt-3.5-turbo", temperature=0.4)
95
  print(keywords)
96
  output = json.loads(keywords)