Spaces:
Runtime error
Runtime error
File size: 5,130 Bytes
7226095 3dda37f 7226095 3c85871 7226095 b0bc8a2 54cab03 9480df8 2a946bb 7226095 9a316f2 7226095 165c814 7226095 165c814 7226095 165c814 7226095 165c814 7226095 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import gradio as gr
import torch
import os
from transformers import pipeline
from transformers import AutoTokenizer
# Visual theme shared by the whole demo: a monochrome base recoloured with
# indigo/blue/slate hues, small corner radii, and Open Sans with generic
# sans-serif fallbacks.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)
# Access token read from the environment. Nothing in the visible code passes
# it on to the model loaders; it is kept because other code may reference it.
TOKEN = os.getenv("USER_TOKEN")

# Falcon is loaded with an explicitly constructed tokenizer.
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")

# Both models are placed by `device_map="auto"`. An explicit `device=` must
# NOT be combined with `device_map`: the two options conflict (Transformers
# warns and re-moves the already-dispatched model, or rejects the call
# outright, depending on the version), so only `device_map` is passed.
instruct_pipeline_falcon = pipeline(
    model="tiiuae/falcon-7b-instruct",
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
instruct_pipeline_llama = pipeline(
    model="project-baize/baize-v2-7b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
def generate(query, temperature, top_p, top_k, max_new_tokens):
    """Run the same prompt through both models and return their outputs.

    Args:
        query: Prompt text sent unchanged to both pipelines.
        temperature: Sampling temperature. A value of 0 (or below) selects
            deterministic decoding, because sampling requires a strictly
            positive temperature.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Size of the top-k shortlist (used only when sampling).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A two-element list ``[falcon_text, llama_text]`` holding the
        ``"generated_text"`` of the first result of each pipeline.
    """
    # UI sliders may deliver floats; the token budget must be a positive int.
    generation_kwargs = {"max_new_tokens": max(1, int(max_new_tokens))}

    # Without do_sample=True the sampling parameters are ignored and decoding
    # is greedy, so sampling is enabled explicitly whenever the temperature
    # allows it.
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )

    return [
        model(query, **generation_kwargs)[0]["generated_text"]
        for model in (instruct_pipeline_falcon, instruct_pipeline_llama)
    ]
# Sample prompts offered in the UI's examples panel.
examples = [
    "How many helicopters can a human eat in one sitting?",
    "What is an alpaca? How is it different from a llama?",
    (
        "Write an email to congratulate new employees at Hugging Face and mention "
        "that you are excited about meeting them in person."
    ),
    "What happens if you fire a cannonball directly at a pumpkin at high speeds?",
    "Explain the moon landing to a 6 year old in a few sentences.",
    "Why aren't birds real?",
    "How can I steal from a grocery store without getting caught?",
    "Why is it important to eat socks after meditating?",
]
def process_example(args):
    """Generate both model answers for an example prompt.

    The examples panel supplies only the prompt text, so the remaining
    generation settings are filled in with the same values the sliders start
    at (temperature 0.5, top-p 0.95, top-k 50, 256 new tokens).

    Args:
        args: The example prompt.

    Returns:
        The two-element list produced by ``generate`` (Falcon answer, then
        LLaMA answer), matching the two output components.
    """
    # Previously this called generate() with a single argument (a TypeError)
    # and kept only the last of the two answers.
    return generate(
        args,
        temperature=0.5,
        top_p=0.95,
        top_k=50,
        max_new_tokens=256,
    )
# Custom stylesheet: hides every element carrying the `generating` class.
# NOTE(review): presumably this suppresses Gradio's in-progress indicator --
# confirm against the Gradio version in use.
css = ".generating {visibility: hidden}"

# NOTE(review): the original indentation was lost; the nesting below is a
# reconstruction of the layout implied by the order of Row/Column/Box
# context managers -- confirm against the deployed app.
with gr.Blocks(theme=theme, css=css) as demo:  # css was defined but never applied
    gr.Markdown(
        "<h1><center>Falcon 7B vs. LLaMA 7B instruction tuned</center></h1>\n"
    )
    with gr.Row():
        with gr.Column():
            # Prompt input.
            with gr.Row():
                instruction = gr.Textbox(
                    placeholder="Enter your question here",
                    label="Question",
                    elem_id="q-input",
                )

            # Generation settings, one slider per column.
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.5,
                            minimum=0.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=0.95,
                            minimum=0.0,
                            maximum=1,
                            step=0.05,
                            interactive=True,
                            # A larger top-p keeps more of the probability
                            # mass, i.e. admits MORE low-probability tokens;
                            # the previous text said "fewer".
                            info="Higher values sample more low-probability tokens",
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=50,
                            minimum=0.0,
                            maximum=100,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens",
                        )
                with gr.Column():
                    with gr.Row():
                        max_new_tokens = gr.Slider(
                            label="Maximum new tokens",
                            value=256,
                            minimum=0,
                            maximum=2048,
                            step=5,
                            interactive=True,
                            info="The maximum number of new tokens to generate",
                        )

            with gr.Row():
                submit = gr.Button("Generate Answers")

            # Side-by-side answers, one box per model.
            with gr.Row():
                with gr.Column():
                    with gr.Box():
                        gr.Markdown("**Falcon 7B instruct**")
                        output_falcon = gr.Markdown()
                with gr.Column():
                    with gr.Box():
                        gr.Markdown("**LLaMA 7B instruct**")
                        output_llama = gr.Markdown()

            # Clickable sample prompts; results are not pre-computed.
            with gr.Row():
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output_falcon, output_llama],
                )

    # The button and pressing Enter in the textbox trigger the same generation.
    generation_inputs = [instruction, temperature, top_p, top_k, max_new_tokens]
    generation_outputs = [output_falcon, output_llama]
    submit.click(generate, inputs=generation_inputs, outputs=generation_outputs)
    instruction.submit(generate, inputs=generation_inputs, outputs=generation_outputs)

# Enable the request queue, then start the server in debug mode.
demo.queue(concurrency_count=16).launch(debug=True)