import time
from datetime import datetime

import gradio as gr

import chat_client

CHAT_URL = "wss://chat.petals.dev/api/v2/generate"
# CHAT_URL = "ws://localhost:8000/api/v2/generate"


def generate(state, *args):
    # Record that we're in the generating loop
    state["generate"] = True
    try:
        yield from _generate(state, *args)
    finally:
        state["generate"] = False


def _generate(
    state,
    prompt,
    model,
    endseq,
    max_length,
    do_sample,
    top_k,
    top_p,
    temperature,
    add_stoptoken,
    copy_output,
):
    start = time.time()
    cnt = 0

    def stats():
        # Produce inline stats for generation speed,
        # shown as sec/t or t/sec depending on the speed
        if cnt == 0:
            return "\u2026 | ? sec/t"
        if cnt > time.time() - start:
            items_per_sec = cnt / (time.time() - start)
            return f" | {items_per_sec:.1f} t/sec"
        sec_per_item = (time.time() - start) / cnt
        return f" | {sec_per_item:.1f} sec/t"

    try:
        client = chat_client.ModelClient(CHAT_URL)
        client.open_session(model, max_length)
    except Exception as e:
        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])

    if add_stoptoken:
        prompt += "</s>" if "bloomz" in model else "\n\n"

    # Translate checkbox items to actual stop sequences
    seq = []
    for s in endseq:
        if s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")

    # Only one of top_k and top_p can be set
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None
    if not temperature:
        temperature = 1.0

    prompt2 = prompt
    output = ""

    # Render the prompt dialog immediately, without waiting
    # for the generator to return its first result
    yield state, prompt2, stats()

    try:
        for out in client.generate(
            prompt,
            max_new_tokens=1,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop_sequences=seq,
        ):
            if not state["generate"]:
                client.close_session()
                return

            cnt += 1
            output += out
            if copy_output:
                prompt2 += out
            yield state, prompt2, output + stats()

            # Stop early so generate() doesn't throw an exception,
            # which would break the UI.
            if cnt >= max_length - 6:  # FIXME: magic constant
                break

        # Print the final result without statistics
        yield state, prompt2, output
    except Exception as e:
        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])


def stop(state):
    """Stops generating."""
    state.update({"generate": False})
    return state


# ---------------------------------------------------------
# Define the Gradio layout

with gr.Blocks() as iface_prompt:
    gr.Markdown(
        """**Useful for testing raw prompts with zero, one or few-shot prompting.**"""
    )

    with gr.Row():
        model = gr.Radio(
            [
                "stabilityai/StableBeluga2",
                "meta-llama/Llama-2-70b-chat-hf",
                "bigscience/bloomz",
                "bigscience/bloom",
            ],
            value="stabilityai/StableBeluga2",
            label="Use model",
        )

        # Additional end sequences at which generation should stop
        endseq = gr.CheckboxGroup(
            ["\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["</s>"],
            label="Extra end sequences",
        )

        # Maximum length of inference session
        max_length = gr.Radio(
            [64, 128, 256, 512, 1024, 2048],
            value=512,
            interactive=True,
            label="Max length",
        )

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")

            # Should the app append the stop sequence at the end of the prompt,
            # or should it leave the prompt open?
            add_stoptoken = gr.Checkbox(
                value=True,
                interactive=True,
                label="Automatically add eos token to the prompt.",
            )

        # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # TODO num_beams

        # Generation temperature
        temperature = gr.Number(
            value=0.75, precision=2, interactive=True, label="Temperature"
        )

    prompt = gr.Textbox(lines=3, label="Prompt", placeholder="Prompt Here...")
    state = gr.State({"generate": False})

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_stop = gr.Button("Stop")

        # Automatically copy the output at the end of prompt
        copy_output = gr.Checkbox(label="Output -> Prompt")

    output = gr.Textbox(lines=3, label="Output")

    # Define button actions
    button_generate.click(
        generate,
        inputs=[
            state,
            prompt,
            model,
            endseq,
            max_length,
            do_sample,
            top_k,
            top_p,
            temperature,
            add_stoptoken,
            copy_output,
        ],
        outputs=[state, prompt, output],
    )
    button_stop.click(stop, inputs=[state], outputs=[state])

    examples = gr.Examples(
        inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
        examples=[
            [
                "The SQL command to extract all the users whose name starts with A is: ",
                "stabilityai/StableBeluga2",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "// Returns every other value in the list as a new list.\n"
                "def every_other(l):\n",
                "stabilityai/StableBeluga2",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "The Spanish translation of thank you for your help is: ",
                "stabilityai/StableBeluga2",
                False,
                0,
                0,
                1,
                False,
            ],
            [
                "A human talks to a powerful AI that follows the Human's instructions.\n"
                "AI is talkative, friendly, positive and provides detailed answers to any question.\n"
                "Human: Hi!\n"
                "AI: Hi! How can I help you?\n"
                "Human: What's the capital of Portugal?\n"
                "AI: ",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
            [
                "Here is a very polite and formal e-mail writing to staff that they are fired "
                "because of late delivery of the project and higher costs:\n"
                "Dear staff,\n"
                "it is with utmost ",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
            [
                "Lorem ipsum dolor sit amet, ",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
                False,
            ],
        ],
    )