vpcom commited on
Commit
d0e883b
·
1 Parent(s): cbadafb

fix: top_k as input

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -132,7 +132,7 @@ def format_prompt(message, history, system_prompt):
132
 
133
  def generate(
134
  prompt, history, system_prompt,
135
- temperature=0.9, max_new_tokens=100, top_p=0.95,
136
  repetition_penalty=1.0, seed=42,
137
  ):
138
  global HISTORY
@@ -152,7 +152,7 @@ def generate(
152
  stop_sequences=stop_sequences,
153
  do_sample=True,
154
  #best_of=2,
155
- top_k=100,
156
  #typical_p=0.9,
157
  #seed=seed,
158
  )
@@ -208,13 +208,22 @@ additional_inputs=[
208
  ),
209
  gr.Slider(
210
  label="Top-p (nucleus sampling)",
211
- value=0.75,
212
  minimum=0.0,
213
  maximum=1,
214
  step=0.05,
215
  interactive=True,
216
  info="Higher values sample more low-probability tokens",
217
  ),
 
 
 
 
 
 
 
 
 
218
  gr.Slider(
219
  label="Repetition penalty",
220
  value=1.005,
@@ -621,7 +630,8 @@ def vote(data: gr.LikeData):
621
  "temperature": additional_inputs[1].value,
622
  "max_new_tokens": additional_inputs[2].value,
623
  "top_p": additional_inputs[3].value,
624
- "repetition_penalty": additional_inputs[4].value,
 
625
  "response": data.value,
626
  "label": data.liked,
627
  }, ensure_ascii=False
 
132
 
133
  def generate(
134
  prompt, history, system_prompt,
135
+ temperature=0.9, max_new_tokens=100, top_p=0.95, top_k=100,
136
  repetition_penalty=1.0, seed=42,
137
  ):
138
  global HISTORY
 
152
  stop_sequences=stop_sequences,
153
  do_sample=True,
154
  #best_of=2,
155
+ top_k=top_k,
156
  #typical_p=0.9,
157
  #seed=seed,
158
  )
 
208
  ),
209
  gr.Slider(
210
  label="Top-p (nucleus sampling)",
211
+ value=1.0,
212
  minimum=0.0,
213
  maximum=1,
214
  step=0.05,
215
  interactive=True,
216
  info="Higher values sample more low-probability tokens",
217
  ),
218
+ gr.Slider(
219
+ label="Top-k",
220
+ value=40,
221
+ minimum=0.0,
222
+ maximum=1000,
223
+ step=1,
224
+ interactive=True,
225
+ info="Higher values sample more low-probability tokens",
226
+ ),
227
  gr.Slider(
228
  label="Repetition penalty",
229
  value=1.005,
 
630
  "temperature": additional_inputs[1].value,
631
  "max_new_tokens": additional_inputs[2].value,
632
  "top_p": additional_inputs[3].value,
633
+ "top_k": additional_inputs[4].value,
634
+ "repetition_penalty": additional_inputs[5].value,
635
  "response": data.value,
636
  "label": data.liked,
637
  }, ensure_ascii=False