Amanpreet Singh commited on
Commit
5485403
·
1 Parent(s): 6c95228

Add input validations, proper defaults and errors for parameters

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -19,6 +19,12 @@ Options:
19
 
20
  API_URL = "https://hfbloom.ngrok.io/generate"
21
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
 
 
 
 
 
 
22
 
23
  hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "huggingface/bloom_internal_prompts", organization="huggingface")
24
 
@@ -47,6 +53,12 @@ def query(payload):
47
  return json.loads(response.content.decode("utf-8"))
48
 
49
  def inference(input_sentence, max_length, sample_or_greedy, seed=42):
 
 
 
 
 
 
50
  if sample_or_greedy == "Sample":
51
  parameters = {"max_new_tokens": max_length,
52
  "top_p": 0.9,
@@ -77,9 +89,9 @@ gr.Interface(
77
  inference,
78
  [
79
  gr.inputs.Textbox(label="Input"),
80
- gr.inputs.Slider(1, 64, default=32, step=1, label="Tokens to generate"),
81
- gr.inputs.Radio(["Sample", "Greedy"], label="Sample or greedy"),
82
- gr.inputs.Radio(["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"], label="Sample other generations (only work in 'Sample' mode", type="index"),
83
  ],
84
  gr.outputs.Textbox(label="Output"),
85
  examples=examples,
 
19
 
20
  API_URL = "https://hfbloom.ngrok.io/generate"
21
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
22
+ MAX_TOKENS = 64
23
+ MAX_LENGTH_ERROR = "The input max length is more than the demo supports. Please select a correct value from the slider."
24
+ GENERATION_OPTIONS = ["Sample", "Greedy"]
25
+ GENERATION_OPTIONS_ERROR = "Please select one option from either 'Sample' or 'Greedy'."
26
+ SAMPLE_OPTIONS = ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"]
27
+
28
 
29
  hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "huggingface/bloom_internal_prompts", organization="huggingface")
30
 
 
53
  return json.loads(response.content.decode("utf-8"))
54
 
55
  def inference(input_sentence, max_length, sample_or_greedy, seed=42):
56
+ if max_length > MAX_TOKENS:
57
+ return MAX_LENGTH_ERROR
58
+
59
+ if len(sample_or_greedy) == 0 or sample_or_greedy not in GENERATION_OPTIONS:
60
+ return GENERATION_OPTIONS_ERROR
61
+
62
  if sample_or_greedy == "Sample":
63
  parameters = {"max_new_tokens": max_length,
64
  "top_p": 0.9,
 
89
  inference,
90
  [
91
  gr.inputs.Textbox(label="Input"),
92
+ gr.inputs.Slider(1, MAX_TOKENS, default=32, step=1, label="Tokens to generate"),
93
+ gr.inputs.Radio(GENERATION_OPTIONS, label="Sample or greedy", default="Sample"),
94
+ gr.inputs.Radio(SAMPLE_OPTIONS, label="Sample other generations (only work in 'Sample' mode", type="index", default="Sample 1"),
95
  ],
96
  gr.outputs.Textbox(label="Output"),
97
  examples=examples,