lewtun HF staff commited on
Commit
b7f7d63
·
1 Parent(s): 16b017f
Files changed (1) hide show
  1. app.py +18 -28
app.py CHANGED
@@ -2,14 +2,11 @@ import json
2
  import os
3
 
4
  import gradio as gr
5
- # import torch
6
- # from transformers import (AutoModelForCausalLM, AutoTokenizer,
7
- # TextIteratorStreamer, set_seed)
8
  from huggingface_hub import Repository
9
  from text_generation import Client
10
 
11
- # from threading import Thread
12
-
13
 
14
 
15
  theme = gr.themes.Monochrome(
@@ -19,30 +16,16 @@ theme = gr.themes.Monochrome(
19
  radius_size=gr.themes.sizes.radius_sm,
20
  font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
21
  )
22
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
23
- # os.environ["TOKENIZERS_PARALLELISM"] = "false"
24
  if HF_TOKEN:
25
  repo = Repository(
26
  local_dir="data", clone_from="trl-lib/stack-llama-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
27
  )
28
 
29
  client = Client(
30
- "https://api-inference.huggingface.co/models/trl-lib/llama-se-rl-merged",
31
  headers={"Authorization": f"Bearer {HF_TOKEN}"},
32
  )
33
 
34
- # device = "cuda" if torch.cuda.is_available() else "cpu"
35
- # model_id = "trl-lib/llama-se-rl-merged"
36
- # print(f"Loading model: {model_id}")
37
- # if device == "cpu":
38
- # model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, use_auth_token=HF_TOKEN)
39
- # else:
40
- # model = AutoModelForCausalLM.from_pretrained(
41
- # model_id, device_map="auto", load_in_8bit=True, use_auth_token=HF_TOKEN
42
- # )
43
-
44
- # tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
45
-
46
  PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer:"""
47
 
48
 
@@ -93,26 +76,33 @@ def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
93
 
94
 
95
  def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k=100):
96
- # set_seed(42)
97
  formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
98
 
99
  temperature = float(temperature)
100
  top_p = float(top_p)
101
 
102
- stream = client.generate_stream(
103
- formatted_instruction,
104
  temperature=temperature,
105
- truncate=999,
106
  max_new_tokens=max_new_tokens,
107
  top_p=top_p,
108
  top_k=top_k,
109
- # stop_sequences=["</s>"],
 
 
 
 
 
 
 
110
  )
111
 
112
  output = ""
113
  for response in stream:
114
  output += response.token.text
115
  yield output
 
 
 
116
 
117
  return output
118
 
@@ -143,9 +133,9 @@ def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k
143
  # # new_text = new_text.replace(tokenizer.eos_token, "")
144
  # output += new_text
145
  # yield output
146
- # if HF_TOKEN:
147
- # print("Pushing prompt and completion to the Hub")
148
- # save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
149
  # return output
150
 
151
 
 
2
  import os
3
 
4
  import gradio as gr
 
 
 
5
  from huggingface_hub import Repository
6
  from text_generation import Client
7
 
8
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
9
+ API_URL = os.environ.get("API_URL")
10
 
11
 
12
  theme = gr.themes.Monochrome(
 
16
  radius_size=gr.themes.sizes.radius_sm,
17
  font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
18
  )
 
 
19
  if HF_TOKEN:
20
  repo = Repository(
21
  local_dir="data", clone_from="trl-lib/stack-llama-prompts", use_auth_token=HF_TOKEN, repo_type="dataset"
22
  )
23
 
24
  client = Client(
25
+ API_URL,
26
  headers={"Authorization": f"Bearer {HF_TOKEN}"},
27
  )
28
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer:"""
30
 
31
 
 
76
 
77
 
78
  def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k=100):
 
79
  formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
80
 
81
  temperature = float(temperature)
82
  top_p = float(top_p)
83
 
84
+ generate_kwargs = dict(
 
85
  temperature=temperature,
 
86
  max_new_tokens=max_new_tokens,
87
  top_p=top_p,
88
  top_k=top_k,
89
+ do_sample=True,
90
+ truncate=999,
91
+ seed=42,
92
+ )
93
+
94
+ stream = client.generate_stream(
95
+ formatted_instruction,
96
+ **generate_kwargs,
97
  )
98
 
99
  output = ""
100
  for response in stream:
101
  output += response.token.text
102
  yield output
103
+ if HF_TOKEN:
104
+ print("Pushing prompt and completion to the Hub")
105
+ save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
106
 
107
  return output
108
 
 
133
  # # new_text = new_text.replace(tokenizer.eos_token, "")
134
  # output += new_text
135
  # yield output
136
+ if HF_TOKEN:
137
+ print("Pushing prompt and completion to the Hub")
138
+ save_inputs_and_outputs(formatted_instruction, output, generate_kwargs)
139
  # return output
140
 
141