MohamedRashad commited on
Commit
f0ac041
·
1 Parent(s): 6253bc5

Add generation configurations to chatbot interface

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -25,7 +25,7 @@ terminators = [
25
  ]
26
 
27
  @spaces.GPU(duration=120)
28
- def generate_both(system_prompt, input_text, base_chatbot, new_chatbot):
29
  base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
30
  new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
31
 
@@ -60,22 +60,24 @@ def generate_both(system_prompt, input_text, base_chatbot, new_chatbot):
60
  base_generation_kwargs = dict(
61
  input_ids=base_input_ids,
62
  streamer=base_text_streamer,
63
- max_new_tokens=2048,
64
  eos_token_id=terminators,
65
  pad_token_id=tokenizer.eos_token_id,
66
- do_sample=True,
67
- temperature=0.2,
68
- top_p=0.9,
 
69
  )
70
  new_generation_kwargs = dict(
71
  input_ids=new_input_ids,
72
  streamer=new_text_streamer,
73
- max_new_tokens=2048,
74
  eos_token_id=terminators,
75
  pad_token_id=tokenizer.eos_token_id,
76
- do_sample=True,
77
- temperature=0.2,
78
- top_p=0.9,
 
79
  )
80
 
81
  base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
@@ -111,16 +113,21 @@ with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
111
  gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
112
  system_prompt = gr.Textbox(lines=1, label="System Prompt", value="أنت متحدث لبق باللغة العربية!", rtl=True, text_align="right", show_copy_button=True)
113
  with gr.Row(variant="panel"):
114
- base_chatbot = gr.Chatbot(label=base_model_id, rtl=True, likeable=True, show_copy_button=True)
115
- new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True)
116
  with gr.Row(variant="panel"):
117
  with gr.Column(scale=1):
118
  submit_btn = gr.Button(value="Generate", variant="primary")
119
  clear_btn = gr.Button(value="Clear", variant="secondary")
120
  input_text = gr.Textbox(lines=1, label="", value="مرحبا", rtl=True, text_align="right", scale=3, show_copy_button=True)
121
-
122
- input_text.submit(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot], outputs=[base_chatbot, new_chatbot])
123
- submit_btn.click(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot], outputs=[base_chatbot, new_chatbot])
 
 
 
 
 
124
  clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])
125
 
126
  demo.launch()
 
25
  ]
26
 
27
  @spaces.GPU(duration=120)
28
+ def generate_both(system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
29
  base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
30
  new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
31
 
 
60
  base_generation_kwargs = dict(
61
  input_ids=base_input_ids,
62
  streamer=base_text_streamer,
63
+ max_new_tokens=max_new_tokens,
64
  eos_token_id=terminators,
65
  pad_token_id=tokenizer.eos_token_id,
66
+ do_sample=True if temperature > 0 else False,
67
+ temperature=temperature,
68
+ top_p=top_p,
69
+ repetition_penalty=repetition_penalty,
70
  )
71
  new_generation_kwargs = dict(
72
  input_ids=new_input_ids,
73
  streamer=new_text_streamer,
74
+ max_new_tokens=max_new_tokens,
75
  eos_token_id=terminators,
76
  pad_token_id=tokenizer.eos_token_id,
77
+ do_sample=True if temperature > 0 else False,
78
+ temperature=temperature,
79
+ top_p=top_p,
80
+ repetition_penalty=repetition_penalty,
81
  )
82
 
83
  base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
 
113
  gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
114
  system_prompt = gr.Textbox(lines=1, label="System Prompt", value="أنت متحدث لبق باللغة العربية!", rtl=True, text_align="right", show_copy_button=True)
115
  with gr.Row(variant="panel"):
116
+ base_chatbot = gr.Chatbot(label=base_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
117
+ new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
118
  with gr.Row(variant="panel"):
119
  with gr.Column(scale=1):
120
  submit_btn = gr.Button(value="Generate", variant="primary")
121
  clear_btn = gr.Button(value="Clear", variant="secondary")
122
  input_text = gr.Textbox(lines=1, label="", value="مرحبا", rtl=True, text_align="right", scale=3, show_copy_button=True)
123
+ with gr.Accordion(label="Generation Configurations", open=False):
124
+ max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=2048, label="Max New Tokens", step=128)
125
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature", step=0.01)
126
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top-p", step=0.01)
127
+ repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, label="Repetition Penalty", step=0.1)
128
+
129
+ input_text.submit(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
130
+ submit_btn.click(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
131
  clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])
132
 
133
  demo.launch()