sc_ma commited on
Commit
1b82d4c
·
1 Parent(s): 1179bb0

Add functions to support cloud storage cache.

Browse files

Add functions to flatten the latex file (for future polishing).
Update README.md to include additonal codes from different license.

Files changed (6) hide show
  1. README.md +4 -1
  2. app.py +60 -20
  3. auto_backgrounds.py +22 -2
  4. latex-flatten.py +50 -0
  5. utils/gpt_interaction.py +7 -1
  6. utils/storage.py +50 -0
README.md CHANGED
@@ -44,5 +44,8 @@ Page 1 | Page 2
44
  :-------------------------:|:-------------------------:
45
  ![](assets/page1.png "Page-1") | ![](assets/page2.png "Page-2")
46
 
 
 
 
47
 
48
-
 
44
  :-------------------------:|:-------------------------:
45
  ![](assets/page1.png "Page-1") | ![](assets/page2.png "Page-2")
46
 
47
+ # License
48
+ This project is licensed under the MIT License.
49
+ Some parts of the code are under different licenses, as listed below:
50
 
51
+ * `latex-flatten.py`: Licensed under the Unlicense. Original source: [rekka/latex-flatten](https://github.com/rekka/latex-flatten).
app.py CHANGED
@@ -1,34 +1,65 @@
1
  import gradio as gr
2
- import openai
3
- from auto_backgrounds import generate_backgrounds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- # todo:  5. Add more functions in this demo.
6
 
7
  def clear_inputs(text1, text2):
8
- return ("", "")
9
 
10
- def wrapped_generate_backgrounds(title, description):
11
- if title == "Deep Reinforcement Learning":
12
- return "output.zip"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  else:
14
- return generate_backgrounds(title, description)
 
 
 
15
 
16
 
17
  with gr.Blocks() as demo:
18
  gr.Markdown('''
19
  # Auto-Draft: 文献整理辅助工具-限量免费使用
20
 
21
- 本Demo提供对[Auto-Draft](https://github.com/CCCBora/auto-draft)的auto_backgrounds功能的测试。通过输入一个领域的名称(比如Deep Reinforcement Learning)
22
- 即可自动对这个领域的相关文献进行归纳总结.
23
 
24
- 生成一篇论文,需要使用我GPT4的API,大概每篇15000 Tokens(大约0.5到0.8美元).
25
- 我为大家提供了30刀的额度上限,希望大家有明确需求再使用. 如果有更多需求,建议本地部署, 使用自己的API KEY!
26
-
27
- ***2023-04-26 Update***: 我本月的余额用完了, 感谢乐乐老师帮忙宣传, 也感觉大家的体验和反馈! 我会按照大家的意见对功能进行改进. 下个月开始仅会在Huggingface
28
- 的Organization里提供免费的试用, 欢迎有兴趣的同学通过下面的链接加入!
29
-
30
- [https://huggingface.co/organizations/auto-academic/share/HPjgazDSlkwLNCWKiAiZoYtXaJIatkWDYM](https://huggingface.co/organizations/auto-academic/share/HPjgazDSlkwLNCWKiAiZoYtXaJIatkWDYM)
31
 
 
32
 
33
  ## 用法
34
 
@@ -36,6 +67,7 @@ with gr.Blocks() as demo:
36
  ''')
37
  with gr.Row():
38
  with gr.Column():
 
39
  title = gr.Textbox(value="Deep Reinforcement Learning", lines=1, max_lines=1, label="Title")
40
  description = gr.Textbox(lines=5, label="Description (Optional)")
41
 
@@ -43,10 +75,18 @@ with gr.Blocks() as demo:
43
  clear_button = gr.Button("Clear")
44
  submit_button = gr.Button("Submit")
45
  with gr.Column():
46
- file_output = gr.File()
 
 
 
 
 
 
 
 
47
 
48
  clear_button.click(fn=clear_inputs, inputs=[title, description], outputs=[title, description])
49
- submit_button.click(fn=wrapped_generate_backgrounds, inputs=[title, description], outputs=file_output)
50
 
51
  demo.queue(concurrency_count=1, max_size=5, api_open=False)
52
- demo.launch()
 
1
  import gradio as gr
2
+ import os
3
+ from auto_backgrounds import generate_backgrounds, fake_generate_backgrounds
4
+
5
+ openai_key = os.getenv("OPENAI_API_KEY")
6
+ access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
7
+ secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
8
+ if access_key_id is None or secret_access_key is None:
9
+ print("Access keys are not provided. Outputs cannot be saved to AWS Cloud Storage.")
10
+ IS_CACHE_AVAILABLE = False
11
+ else:
12
+ IS_CACHE_AVAILABLE = True
13
+
14
+ if openai_key is None:
15
+ print("OPENAI_API_KEY is not found in environment variables. The output may not be generated.")
16
+ IS_OPENAI_API_KEY_AVAILABLE = False
17
+ else:
18
+ # todo: check if this key is available or not
19
+ IS_OPENAI_API_KEY_AVAILABLE = True
20
+
21
 
 
22
 
23
  def clear_inputs(text1, text2):
24
+ return "", ""
25
 
26
+
27
+ def wrapped_generate_backgrounds(title, description, openai_key = None, cache_mode = True):
28
+ # if `cache_mode` is True, then follow the following logic:
29
+ # check if "title"+"description" have been generated before
30
+ # if so, download from the cloud storage, return it
31
+ # if not, generate the result.
32
+ if cache_mode:
33
+ from utils.storage import list_all_files, hash_name, download_file, upload_file
34
+ # check if "title"+"description" have been generated before
35
+ file_name = hash_name(title, description) + ".zip"
36
+ file_list = list_all_files()
37
+ if file_name in file_list:
38
+ # download from the cloud storage, return it
39
+ download_file(file_name)
40
+ return file_name
41
+ else:
42
+ # generate the result.
43
+ # output = fake_generate_backgrounds(title, description, openai_key)
44
+ output = generate_backgrounds(title, description, openai_key) #todo: change the output of this function to hashed title
45
+ upload_file(file_name)
46
+ return output
47
  else:
48
+ # output = fake_generate_backgrounds(title, description, openai_key)
49
+ output = generate_backgrounds(title, description, openai_key) #todo: change the output of this function to hashed title
50
+ return output
51
+
52
 
53
 
54
  with gr.Blocks() as demo:
55
  gr.Markdown('''
56
  # Auto-Draft: 文献整理辅助工具-限量免费使用
57
 
58
+ 本Demo提供对[Auto-Draft](https://github.com/CCCBora/auto-draft)的auto_backgrounds功能的测试。通过输入一个领域的名称(比如Deep Reinforcement Learning),即可自动对这个领域的相关文献进行归纳总结.
 
59
 
60
+ ***2023-04-30 Update***: 如果有更多想法和建议欢迎加入群里交流, 群号: ***249738228***.
 
 
 
 
 
 
61
 
62
+ ***2023-04-26 Update***: 我本月的余额用完了, 感谢乐乐老师帮忙宣传, 也感觉大家的体验和反馈! 我会按照大家的意见对功能进行改进. 下个月会把Space的访问权限限制在Huggingface的Organization里, 欢迎有兴趣的同学通过下面的链接加入! [AUTO-ACADEMIC](https://huggingface.co/organizations/auto-academic/share/HPjgazDSlkwLNCWKiAiZoYtXaJIatkWDYM)
63
 
64
  ## 用法
65
 
 
67
  ''')
68
  with gr.Row():
69
  with gr.Column():
70
+ key = gr.Textbox(value=openai_key, lines=1, max_lines=1, label="OpenAI Key", visible=not IS_OPENAI_API_KEY_AVAILABLE)
71
  title = gr.Textbox(value="Deep Reinforcement Learning", lines=1, max_lines=1, label="Title")
72
  description = gr.Textbox(lines=5, label="Description (Optional)")
73
 
 
75
  clear_button = gr.Button("Clear")
76
  submit_button = gr.Button("Submit")
77
  with gr.Column():
78
+ style_mapping = {True: "color:white;background-color:green", False: "color:white;background-color:red"}
79
+ availablity_mapping = {True: "AVAILABLE", False: "NOT AVAILABLE"}
80
+ gr.Markdown(f'''## Huggingface Space Status
81
+ 当`OpenAI API`显示AVAILABLE的时候这个Space可以直接使用.
82
+ 当`OpenAI API`显示UNAVAILABLE的时候这个Space可以通过在左侧输入OPENAI KEY来使用.
83
+ `OpenAI API`: <span style="{style_mapping[IS_OPENAI_API_KEY_AVAILABLE]}">{availablity_mapping[IS_OPENAI_API_KEY_AVAILABLE]}</span>. `Cache`: <span style="{style_mapping[IS_CACHE_AVAILABLE]}">{availablity_mapping[IS_CACHE_AVAILABLE]}</span>.''')
84
+ file_output = gr.File(label="Output")
85
+
86
+
87
 
88
  clear_button.click(fn=clear_inputs, inputs=[title, description], outputs=[title, description])
89
+ submit_button.click(fn=wrapped_generate_backgrounds, inputs=[title, description, key], outputs=file_output)
90
 
91
  demo.queue(concurrency_count=1, max_size=5, api_open=False)
92
+ demo.launch()
auto_backgrounds.py CHANGED
@@ -13,6 +13,17 @@ TOTAL_PROMPTS_TOKENS = 0
13
  TOTAL_COMPLETION_TOKENS = 0
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
16
  def log_usage(usage, generating_target, print_out=True):
17
  global TOTAL_TOKENS
18
  global TOTAL_PROMPTS_TOKENS
@@ -42,7 +53,7 @@ def make_archive(source, destination):
42
  shutil.move('%s.%s'%(name,format), destination)
43
  return destination
44
 
45
- def pipeline(paper, section, save_to_path, model):
46
  """
47
  The main pipeline of generating a section.
48
  1. Generate prompts.
@@ -75,7 +86,7 @@ def pipeline(paper, section, save_to_path, model):
75
 
76
 
77
 
78
- def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-4"):
79
  paper = {}
80
  paper_body = {}
81
 
@@ -120,6 +131,15 @@ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-
120
  # shutil.make_archive("output.zip", 'zip', save_to_path)
121
  return make_archive(destination_folder, "output.zip")
122
 
 
 
 
 
 
 
 
 
 
123
  if __name__ == "__main__":
124
  title = "Reinforcement Learning"
125
  description = ""
 
13
  TOTAL_COMPLETION_TOKENS = 0
14
 
15
 
16
+ def hash_name(title, description):
17
+ '''
18
+ For same title and description, it should return the same value.
19
+ '''
20
+ name = title + description
21
+ name = name.lower()
22
+ md5 = hashlib.md5()
23
+ md5.update(name.encode('utf-8'))
24
+ hashed_string = md5.hexdigest()
25
+ return hashed_string
26
+
27
  def log_usage(usage, generating_target, print_out=True):
28
  global TOTAL_TOKENS
29
  global TOTAL_PROMPTS_TOKENS
 
53
  shutil.move('%s.%s'%(name,format), destination)
54
  return destination
55
 
56
+ def pipeline(paper, section, save_to_path, model, openai_key=None):
57
  """
58
  The main pipeline of generating a section.
59
  1. Generate prompts.
 
86
 
87
 
88
 
89
+ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-4", openai_key=None):
90
  paper = {}
91
  paper_body = {}
92
 
 
131
  # shutil.make_archive("output.zip", 'zip', save_to_path)
132
  return make_archive(destination_folder, "output.zip")
133
 
134
+
135
+ def fake_generate_backgrounds(title, description, openai_key = None):
136
+ """
137
+ This function is used to test the whole pipeline without calling OpenAI API.
138
+ """
139
+ filename = hash_name(title, description) + ".zip"
140
+ return make_archive("sample-output.pdf", filename)
141
+
142
+
143
  if __name__ == "__main__":
144
  title = "Reinforcement Learning"
145
  description = ""
latex-flatten.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # This script is taken from: https://github.com/rekka/latex-flatten
3
+
4
+ # A simple script for flattening LaTeX files by inlining included files.
5
+ #
6
+ # - Supports `\include` and `\input` commands.
7
+ # - Automatically adds extension `.tex` if the file does not have an extension.
8
+ # - Handles multiple include commands per line, comments.
9
+ # - Does not flatten recursively.
10
+
11
+ import re
12
+ import sys
13
+
14
+ if len(sys.argv)==3:
15
+ main_name = sys.argv[1]
16
+ output_name = sys.argv[2]
17
+ else:
18
+ sys.exit('USAGE: %s main.tex output.tex' %sys.argv[0])
19
+
20
+ main = open(main_name,'r')
21
+ output = open(output_name,'w')
22
+
23
+ for line in main.readlines():
24
+ s = re.split('%', line, 2)
25
+ tex = s[0]
26
+ if len(s) > 1:
27
+ comment = '%' + s[1]
28
+ else:
29
+ comment = ''
30
+
31
+ chunks = re.split(r'\\(?:input|include)\{[^}]+\}', tex)
32
+
33
+ if len(chunks) > 1:
34
+ for (c, t) in zip(chunks, re.finditer(r'\\(input|include)\{([^}]+)\}', tex)):
35
+ cmd_name = t.group(1)
36
+ include_name = t.group(2)
37
+ if '.' not in include_name: include_name = include_name + '.tex'
38
+ if c.strip(): output.write(c + '\n')
39
+ output.write('% BEGIN \\' + cmd_name + '{' + include_name + '}\n')
40
+ include = open(include_name, 'r')
41
+ output.write(include.read())
42
+ include.close()
43
+ output.write('% END \\' + cmd_name + '{' + include_name + '}\n')
44
+ tail = chunks[-1] + comment
45
+ if tail.strip(): output.write(tail)
46
+ else:
47
+ output.write(line)
48
+
49
+ output.close()
50
+ main.close()
utils/gpt_interaction.py CHANGED
@@ -5,6 +5,7 @@ import json
5
  import logging
6
  log = logging.getLogger(__name__)
7
 
 
8
  openai.api_key = os.environ['OPENAI_API_KEY']
9
 
10
  def extract_responses(assistant_message):
@@ -54,7 +55,12 @@ def extract_json(assistant_message, default_output=None):
54
  return dict.keys()
55
 
56
 
57
- def get_responses(user_message, model="gpt-4", temperature=0.4):
 
 
 
 
 
58
  conversation_history = [
59
  {"role": "system", "content": "You are an assistant in writing machine learning papers."}
60
  ]
 
5
  import logging
6
  log = logging.getLogger(__name__)
7
 
8
+ # todo: 将api_key通过函数传入; 需要改很多地方
9
  openai.api_key = os.environ['OPENAI_API_KEY']
10
 
11
  def extract_responses(assistant_message):
 
55
  return dict.keys()
56
 
57
 
58
+ def get_responses(user_message, model="gpt-4", temperature=0.4, openai_key = None):
59
+ if openai.api_key is None and openai_key is None:
60
+ raise ValueError("OpenAI API key must be provided.")
61
+ if openai_key is not None:
62
+ openai.api_key = openai_key
63
+
64
  conversation_history = [
65
  {"role": "system", "content": "You are an assistant in writing machine learning papers."}
66
  ]
utils/storage.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import boto3
3
+ import hashlib
4
+
5
+ access_key_id = os.environ['AWS_ACCESS_KEY_ID']
6
+ secret_access_key = os.environ['AWS_SECRET_ACCESS_KEY']
7
+ bucket_name = "hf-storage"
8
+
9
+ session = boto3.Session(
10
+ aws_access_key_id=access_key_id,
11
+ aws_secret_access_key=secret_access_key,
12
+ )
13
+
14
+ s3 = session.resource('s3')
15
+ bucket = s3.Bucket(bucket_name)
16
+
17
+ def upload_file(file_name, target_name=None):
18
+ if target_name is None:
19
+ target_name = file_name
20
+ try:
21
+ s3.meta.client.upload_file(Filename=file_name, Bucket=bucket_name, Key=target_name)
22
+ print(f"The file {file_name} has been uploaded!")
23
+ except:
24
+ print("Uploading failed!")
25
+
26
+ def list_all_files():
27
+ return [obj.key for obj in bucket.objects.all()]
28
+
29
+ def download_file(file_name):
30
+ ''' Download `file_name` from the bucket. todo:check existence before downloading!
31
+ Bucket (str) – The name of the bucket to download from.
32
+ Key (str) – The name of the key to download from.
33
+ Filename (str) – The path to the file to download to.
34
+ '''
35
+ try:
36
+ s3.meta.client.download_file(Bucket=bucket_name, Key=file_name, Filename=file_name)
37
+ print(f"The file {file_name} has been downloaded!")
38
+ except:
39
+ print("Uploading failed!")
40
+
41
+ def hash_name(title, description):
42
+ '''
43
+ For same title and description, it should return the same value.
44
+ '''
45
+ name = title + description
46
+ name = name.lower()
47
+ md5 = hashlib.md5()
48
+ md5.update(name.encode('utf-8'))
49
+ hashed_string = md5.hexdigest()
50
+ return hashed_string