lewtun HF staff committed on
Commit
ff6dd35
·
1 Parent(s): dadd4bc

Add logging

Browse files
Files changed (1) hide show
  1. app.py +26 -12
app.py CHANGED
@@ -4,7 +4,9 @@ from threading import Thread
4
  import gradio as gr
5
  import torch
6
  from transformers import (AutoModelForCausalLM, AutoTokenizer,
7
- TextIteratorStreamer)
 
 
8
 
9
  theme = gr.themes.Monochrome(
10
  primary_hue="indigo",
@@ -15,6 +17,10 @@ theme = gr.themes.Monochrome(
15
  )
16
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
17
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
 
 
18
 
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -31,8 +37,15 @@ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
31
 
32
  PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer:"""
33
 
 
 
 
 
 
 
34
 
35
  def generate(instruction, temperature=0.8, max_new_tokens=128, top_p=0.95, top_k=40):
 
36
  formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
37
 
38
  temperature = float(temperature)
@@ -64,13 +77,18 @@ def generate(instruction, temperature=0.8, max_new_tokens=128, top_p=0.95, top_k
64
  # new_text = new_text.replace(tokenizer.eos_token, "")
65
  output += new_text
66
  yield output
 
 
 
67
  return output
68
 
69
 
70
  examples = [
71
- "How do I create an array in C++ of length 5 which contains all even numbers between 1 and 10?",
 
72
  "How can I sort a list in Python?",
73
  "How can I write a Java function to generate the nth Fibonacci number?",
 
74
  ]
75
 
76
 
@@ -85,9 +103,11 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=".generating {visibilit
85
  gr.Markdown(
86
  """<h1><center>🦙🦙🦙 StackLLaMa 🦙🦙🦙</center></h1>
87
 
88
- StackLLaMa is a 7 billion parameter language model that has been trained on pairs of programming questions and answers from [Stack Exchange](https://stackexchange.com) using Reinforcement Learning from Human Feedback with the [TRL library](https://github.com/lvwerra/trl). For more details, check out our [blog post](https://huggingface.co/blog/stackllama).
 
 
89
 
90
- Type in the box below and click the button to generate answers to your most pressing coding questions 🔥!
91
  """
92
  )
93
  with gr.Row():
@@ -96,12 +116,6 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=".generating {visibilit
96
  with gr.Box():
97
  gr.Markdown("**Answer**")
98
  output = gr.Markdown()
99
- # output = gr.Textbox(
100
- # interactive=False,
101
- # lines=8,
102
- # label="Answer",
103
- # placeholder="Here will be the answer to your question",
104
- # )
105
  submit = gr.Button("Generate", variant="primary")
106
  gr.Examples(
107
  examples=examples,
@@ -123,7 +137,7 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=".generating {visibilit
123
  )
124
  max_new_tokens = gr.Slider(
125
  label="Max new tokens",
126
- value=64,
127
  minimum=0,
128
  maximum=2048,
129
  step=4,
@@ -153,4 +167,4 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=".generating {visibilit
153
  instruction.submit(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])
154
 
155
  demo.queue(concurrency_count=1)
156
- demo.launch(enable_queue=True)#, share=True)
 
4
  import gradio as gr
5
  import torch
6
  from transformers import (AutoModelForCausalLM, AutoTokenizer,
7
+ TextIteratorStreamer, set_seed)
8
+ from huggingface_hub import Repository
9
+ import json
10
 
11
  theme = gr.themes.Monochrome(
12
  primary_hue="indigo",
 
17
  )
18
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
19
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
20
+ if HF_TOKEN:
21
+ repo = Repository(
22
+ local_dir="data", clone_from="trl-lib/stack-llama-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
23
+ )
24
 
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
37
 
38
  PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer:"""
39
 
40
+ def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
41
+ with open(os.path.join("data", "prompts.jsonl"), "a") as f:
42
+ json.dump({"inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}, f, ensure_ascii=False)
43
+ f.write("\n")
44
+ commit_url = repo.push_to_hub()
45
+
46
 
47
  def generate(instruction, temperature=0.8, max_new_tokens=128, top_p=0.95, top_k=40):
48
+ set_seed(42)
49
  formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
50
 
51
  temperature = float(temperature)
 
77
  # new_text = new_text.replace(tokenizer.eos_token, "")
78
  output += new_text
79
  yield output
80
+ if HF_TOKEN:
81
+ print("Pushing prompt and completion to the Hub")
82
+ save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
83
  return output
84
 
85
 
86
  examples = [
87
+ "A llama is in my lawn. How do I get rid of him?",
88
+ "How do I create an array in C++ which contains all even numbers between 1 and 10?",
89
  "How can I sort a list in Python?",
90
  "How can I write a Java function to generate the nth Fibonacci number?",
91
+ "How many helicopters can a llama eat in one sitting?",
92
  ]
93
 
94
 
 
103
  gr.Markdown(
104
  """<h1><center>🦙🦙🦙 StackLLaMa 🦙🦙🦙</center></h1>
105
 
106
+ StackLLaMa is a 7 billion parameter language model that has been trained on pairs of questions and answers from [Stack Exchange](https://stackexchange.com) using Reinforcement Learning from Human Feedback with the [TRL library](https://github.com/lvwerra/trl). For more details, check out our [blog post](https://huggingface.co/blog/stackllama).
107
+
108
+ Type in the box below and click the button to generate answers to your most pressing questions 🔥!
109
 
110
+ **Note:** we are collecting your prompts and model completions for research purposes.
111
  """
112
  )
113
  with gr.Row():
 
116
  with gr.Box():
117
  gr.Markdown("**Answer**")
118
  output = gr.Markdown()
 
 
 
 
 
 
119
  submit = gr.Button("Generate", variant="primary")
120
  gr.Examples(
121
  examples=examples,
 
137
  )
138
  max_new_tokens = gr.Slider(
139
  label="Max new tokens",
140
+ value=128,
141
  minimum=0,
142
  maximum=2048,
143
  step=4,
 
167
  instruction.submit(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])
168
 
169
  demo.queue(concurrency_count=1)
170
+ demo.launch(enable_queue=True, share=True)