# Hugging Face Spaces page header captured along with the source (not code):
#   Spaces:
#   Runtime error
#   Runtime error
import time | |
from datetime import datetime | |
import gradio as gr | |
import chat_client | |
# WebSocket endpoint of the public Petals generation API used by _generate().
CHAT_URL = "wss://chat.petals.dev/api/v2/generate"
# Alternative endpoint for a locally running server; uncomment to use it.
#CHAT_URL='ws://localhost:8000/api/v2/generate'
def generate(state, *args):
    """Drive _generate() while flagging in *state* that generation is active.

    The flag lets stop() interrupt the streaming loop from another event
    handler; it is cleared unconditionally when the generator finishes,
    errors out, or is abandoned by the UI.
    """
    state["generate"] = True
    try:
        for update in _generate(state, *args):
            yield update
    finally:
        # Clear the flag on every exit path (normal end, error, GC of the
        # generator) so a stale True can never block the next run.
        state["generate"] = False
def _generate(
    state,
    prompt,
    model,
    endseq,
    max_length,
    do_sample,
    top_k,
    top_p,
    temperature,
    add_stoptoken,
    copy_output,
):
    """Stream tokens from the Petals endpoint and yield UI updates.

    Opens a session on CHAT_URL for *model*, then yields
    (state, prompt, output-with-stats) triples — one per generated token —
    for Gradio to render. Stops when the server finishes, when
    state["generate"] is cleared by stop(), or near *max_length* tokens.
    Raises gr.Error (truncated to 500 chars) on any client failure.
    """
    start = time.time()
    cnt = 0  # tokens received so far; read by the stats() closure below

    def stats():
        # Produces inline stats for generation speed,
        # shown as sec/t or t/sec depending on which is >= 1.
        if cnt == 0:
            return "\u2026 | ? sec/t"  # ellipsis while waiting for the first token
        if cnt > time.time() - start:
            items_per_sec = cnt / (time.time() - start)
            return f" | {items_per_sec:.1f} t/sec"
        sec_per_item = (time.time() - start) / cnt
        return f" | {sec_per_item:.1f} sec/t"

    try:
        client = chat_client.ModelClient(CHAT_URL)
        client.open_session(model, max_length)
    except Exception as e:
        # Surface connection/session errors in the UI, truncated to the tail.
        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])

    if add_stoptoken:
        # bloomz models use </s> as EOS; the others were tuned with "\n\n".
        prompt += "</s>" if "bloomz" in model else "\n\n"

    # Translate checkbox labels to the actual stop sequences.
    seq = []
    for s in endseq:
        if s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")

    # Only one of top_k / top_p can be set; 0 means "disabled", and when
    # both are given, top_p wins.
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None
    if not temperature:
        temperature = 1.0

    prompt2 = prompt  # echo of the prompt, optionally extended with output
    output = ""

    # Render the prompt dialog immediately instead of
    # waiting for the generator to return its first result.
    yield [state, prompt2, stats()]

    try:
        # max_new_tokens=1 requests one token per iteration so the UI can
        # stream updates token by token.
        for out in client.generate(
            prompt,
            max_new_tokens=1,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop_sequences=seq,
        ):
            if not state["generate"]:
                # stop() cleared the flag — close the session and bail out.
                client.close_session()
                return
            cnt += 1
            output += out
            if copy_output:
                prompt2 += out
            yield state, prompt2, output + stats()
            # Stop a few tokens early to avoid generate() throwing when the
            # session limit is hit, which would surface as a UI error.
            if cnt >= max_length - 6:  # FIXME bulgarian constant
                break
        # Final result without the inline statistics suffix.
        yield state, prompt2, output
    except Exception as e:
        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])
def stop(state):
    """Request termination of a running generation loop.

    _generate() polls state["generate"] after every token and closes its
    session once the flag goes False.
    """
    state["generate"] = False
    return state
# ---------------------------------------------------------
# Defining Gradio layout.
# NOTE(review): the source arrived with indentation flattened; the nesting of
# the sampler controls (top_k/top_p/temperature inside the single Column) was
# reconstructed and should be confirmed against the rendered layout.
with gr.Blocks() as iface_prompt:
    gr.Markdown(
        """**Useful for testing raw prompts with zero,
        one or few-shot prompting.**"""
    )

    with gr.Row():
        # Model selector for the Petals swarm.
        model = gr.Radio(
            ["stabilityai/StableBeluga2", "meta-llama/Llama-2-70b-chat-hf", "bigscience/bloomz", "bigscience/bloom"], value="stabilityai/StableBeluga2", label="Use model"
        )

        # Additional ending sequence, at which generation shoud stop
        endseq = gr.CheckboxGroup(
            ["\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["</s>"],
            label="Extra end sequences",
        )

        # Maximum length of inference session
        max_length = gr.Radio(
            [64, 128, 256, 512, 1024, 2048],
            value=512,
            interactive=True,
            label="Max length",
        )

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")

            # Should the app append stop sequence at the end of prompt
            # or should it leave the prompt open?
            add_stoptoken = gr.Checkbox(
                value=True,
                interactive=True,
                label="Automatically add eos token to the prompt.",
            )

            # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
            top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
            top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
            # TODO num_beams

            # Generation temperature
            temperature = gr.Number(
                value=0.75, precision=2, interactive=True, label="Temperature"
            )

    prompt = gr.Textbox(lines=3, label="Prompt", placeholder="Prompt Here...")
    # Shared per-session state; generate()/stop() toggle the "generate" flag.
    state = gr.State({"generate": False})

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_stop = gr.Button("Stop")

        # Automatically copy the output at the end of prompt
        copy_output = gr.Checkbox(label="Output -> Prompt")

    output = gr.Textbox(lines=3, label="Output")

    # Define button actions: Generate streams UI updates from generate(),
    # Stop flips the state flag that the streaming loop polls.
    button_generate.click(
        generate,
        inputs=[
            state,
            prompt,
            model,
            endseq,
            max_length,
            do_sample,
            top_k,
            top_p,
            temperature,
            add_stoptoken,
            copy_output,
        ],
        outputs=[state, prompt, output],
    )
    button_stop.click(stop, inputs=[state], outputs=[state])

    # Clickable example prompts; values map onto the inputs listed below.
    examples = gr.Examples(
        inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
        examples=[
            [
                "The SQL command to extract all the users whose name starts with A is: ",
                "stabilityai/StableBeluga2",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "// Returns every other value in the list as a new list.\n"
                "def every_other(l):\n",
                "stabilityai/StableBeluga2",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "The Spanish translation of thank you for your help is: ",
                "stabilityai/StableBeluga2",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "A human talks to a powerful AI that follows the Human's instructions.\n"
                "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
                "Human: Hi!</s>\n"
                "AI: Hi! How can I help you?</s>\n"
                "Human: What's the capital of Portugal?</s>\n"
                "AI: ",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
            [
                "Here is a very polite and formal e-mail writing to staff that they are fired because of late delivery of the project and higher costs:\n"
                "Dear staff,\n"
                "it is with utmost ",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
            [
                "Lorem ipsum dolor sit amet, ",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
        ],
    )