# petals-playground / prompt.py
# WIP: separators and examples still need to be updated from BLOOM to Llama 2-related models.
import time
from datetime import datetime
import gradio as gr
import chat_client
CHAT_URL = "wss://chat.petals.dev/api/v2/generate"
#CHAT_URL='ws://localhost:8000/api/v2/generate'
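
# chat_client.ModelClient (a local helper module) wraps the Petals websocket API
# at CHAT_URL: open_session(model, max_length) starts an inference session,
# generate(...) streams generated text chunk by chunk, and close_session()
# tears the session down (used below when the user presses Stop).
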
def generate(state, *args):
    # Mark that we are inside the generating loop
state["generate"] = True
try:
yield from _generate(state, *args)
finally:
state["generate"] = False
def _generate(
state,
prompt,
model,
endseq,
max_length,
do_sample,
top_k,
top_p,
temperature,
add_stoptoken,
copy_output,
):
start = time.time()
cnt = 0
def stats():
        # Produce inline stats on generation speed,
        # reported as tokens/sec when fast and sec/token when slow
if cnt == 0:
return "\u2026 | ? sec/t"
if cnt > time.time() - start:
items_per_sec = cnt / (time.time() - start)
return f" | {items_per_sec:.1f} t/sec"
sec_per_item = (time.time() - start) / cnt
return f" | {sec_per_item:.1f} sec/t"
try:
client = chat_client.ModelClient(CHAT_URL)
client.open_session(model, max_length)
except Exception as e:
print(datetime.now(), str(e)[-500:])
raise gr.Error(str(e)[-500:])
if add_stoptoken:
prompt += "</s>" if "bloomz" in model else "\n\n"
# Translate checkbox items to actual sequences
seq = []
for s in endseq:
if s == "\\n":
seq.append("\n")
elif s == "</s>":
seq.append("</s>")
elif s == "? (question mark)":
seq.append("?")
elif s == ". (dot)":
seq.append(".")
    # Only one of top_k and top_p can be set
if top_k == 0:
top_k = None
if top_p == 0:
top_p = None
if top_p and top_k:
top_k = None
if not temperature:
temperature = 1.0
prompt2 = prompt
output = ""
    # Render the prompt dialog immediately instead of
    # waiting for the generator to return its first result
yield [state, prompt2, stats()]
try:
for out in client.generate(
prompt,
max_new_tokens=1,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=seq,
):
if not state["generate"]:
client.close_session()
return
cnt += 1
output += out
if copy_output:
prompt2 += out
yield state, prompt2, output + stats()
            # Stop early so generate() doesn't raise when the session
            # hits max_length, which would surface as a UI error.
if cnt >= max_length - 6: # FIXME bulgarian constant
break
        # Yield the final result without the inline stats suffix
yield state, prompt2, output
except Exception as e:
print(datetime.now(), str(e)[-500:])
raise gr.Error(str(e)[-500:])
def stop(state):
"""Stops generating."""
state.update({"generate": False})
return state
# ---------------------------------------------------------
# Defining Gradio layout
with gr.Blocks() as iface_prompt:
gr.Markdown(
"""**Useful for testing raw prompts with zero,
one or few-shot prompting.**"""
)
with gr.Row():
model = gr.Radio(
["stabilityai/StableBeluga2", "meta-llama/Llama-2-70b-chat-hf", "bigscience/bloomz", "bigscience/bloom"], value="stabilityai/StableBeluga2", label="Use model"
)
        # Additional ending sequences at which generation should stop
endseq = gr.CheckboxGroup(
["\\n", "</s>", "? (question mark)", ". (dot)"],
value=["</s>"],
label="Extra end sequences",
)
# Maximum length of inference session
max_length = gr.Radio(
[64, 128, 256, 512, 1024, 2048],
value=512,
interactive=True,
label="Max length",
)
with gr.Row():
with gr.Column():
# Switch between sampling and greedy generation
do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
            # Should the app append the stop sequence to the end of the prompt,
            # or leave the prompt open?
add_stoptoken = gr.Checkbox(
value=True,
interactive=True,
label="Automatically add eos token to the prompt.",
)
# Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
# TODO num_beams
# Generation temperature
temperature = gr.Number(
value=0.75, precision=2, interactive=True, label="Temperature"
)
prompt = gr.Textbox(lines=3, label="Prompt", placeholder="Prompt Here...")
state = gr.State({"generate": False})
with gr.Row():
button_generate = gr.Button("Generate")
button_stop = gr.Button("Stop")
        # Automatically append the output to the end of the prompt
copy_output = gr.Checkbox(label="Output -> Prompt")
output = gr.Textbox(lines=3, label="Output")
# Define button actions
button_generate.click(
generate,
inputs=[
state,
prompt,
model,
endseq,
max_length,
do_sample,
top_k,
top_p,
temperature,
add_stoptoken,
copy_output,
],
outputs=[state, prompt, output],
)
button_stop.click(stop, inputs=[state], outputs=[state])
examples = gr.Examples(
inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
examples=[
[
"The SQL command to extract all the users whose name starts with A is: ",
"stabilityai/StableBeluga2",
False,
0,
0,
1,
False,
],
[
"// Returns every other value in the list as a new list.\n"
"def every_other(l):\n",
"stabilityai/StableBeluga2",
False,
0,
0,
1,
False,
],
[
"The Spanish translation of thank you for your help is: ",
"stabilityai/StableBeluga2",
False,
0,
0,
1,
False,
],
[
"A human talks to a powerful AI that follows the Human's instructions.\n"
"AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
"Human: Hi!</s>\n"
"AI: Hi! How can I help you?</s>\n"
"Human: What's the capital of Portugal?</s>\n"
"AI: ",
"stabilityai/StableBeluga2",
True,
0,
0.9,
0.75,
False,
],
[
"Here is a very polite and formal e-mail writing to staff that they are fired because of late delivery of the project and higher costs:\n"
"Dear staff,\n"
"it is with utmost ",
"stabilityai/StableBeluga2",
True,
0,
0.9,
0.75,
False,
],
[
"Lorem ipsum dolor sit amet, ",
"stabilityai/StableBeluga2",
True,
0,
0.9,
0.75,
False,
],
],
)
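
# This module only defines the `iface_prompt` Blocks layout; the full playground
# presumably mounts it alongside other tabs in a separate entry point.
# A minimal sketch (not part of the original app) for trying this tab on its own:
if __name__ == "__main__":
    # queue() lets the generator-based handler stream partial output to the UI
    iface_prompt.queue().launch()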