import gradio as gr
#from huggingface_hub import InferenceClient
import random
import spaces
import os

# Pin the process to the first GPU and select the JAX backend.
# KERAS_BACKEND must be set before keras/keras_hub are imported.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
os.environ["KERAS_BACKEND"] = "jax"

import keras_hub
import keras
# Gemma checkpoints fine-tuned for mental-health conversation, pulled from the Hub.
models = [
    "hf://tatihden/gemma_mental_health_2b_it_en",
    "hf://tatihden/gemma_mental_health_2b_en",
    "hf://tatihden/gemma_mental_health_7b_it_en"
]

# Load every preset once at startup so the dropdown can switch models instantly.
clients = []
for model in models:
    clients.append(keras_hub.models.GemmaCausalLM.from_preset(model))
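# Note: keeping all three checkpoints resident assumes the Space has enough
# accelerator memory for them; a lazier variant could call from_preset inside
# chat_inf for the selected model only, trading memory for switch latency.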
def format_prompt(message, history):
    """Render the chat history into Gemma's turn-based prompt template."""
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt
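# For a one-turn history [("Hello", "Hi there!")] and the message "How are you?",
# format_prompt produces:
#   <start_of_turn>user
#   Hello<end_of_turn>
#   <start_of_turn>model
#   Hi there!<end_of_turn>
#   <start_of_turn>user
#   How are you?<end_of_turn>
#   <start_of_turn>model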
# The GPU is only needed for generation, so the ZeroGPU decorator goes here.
@spaces.GPU
def chat_inf(system_prompt, prompt, history, client_choice, temp, tokens, top_p, seed):
    # The dropdown uses type='index', so client_choice is already a 0-based index.
    client = clients[int(client_choice)]
    if not history:
        history = []
    # TopKSampler does not accept top_p, so nucleus sampling is done with
    # TopPSampler; k=5 prunes candidates before the top-p cut, preserving the
    # original k=5 setting.
    sampler = keras_hub.samplers.TopPSampler(p=top_p, k=5, seed=seed, temperature=temp)
    client.compile(sampler=sampler)
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    # generate() returns the prompt plus the completion, and max_length caps the
    # total sequence length rather than only the new tokens.
    output = client.generate(formatted_prompt, max_length=tokens)
    if output.startswith(formatted_prompt):
        output = output[len(formatted_prompt):]
    history.append((prompt, output))
    yield history
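# Any other keras_hub sampler can be swapped into client.compile() the same
# way, e.g. keras_hub.samplers.GreedySampler() for deterministic output or
# keras_hub.samplers.RandomSampler(temperature=temp) for pure sampling.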
def clear_fn():
    # Returning None resets the Chatbot component.
    return None

rand_val = random.randint(1, 1111111111111111)

def check_rand(inp, val):
    """Return the Seed slider: re-randomized when 'Random Seed' is checked, else pinned to the typed value."""
    if inp:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
    return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {background-color: rgb(134 239 172)}") as app:
    gr.HTML("""<center><h1 style='font-size:xx-large;'>CalmChat</h1></center>""")
    with gr.Group():
        with gr.Row():
            # type='index' makes the dropdown emit the 0-based position of the
            # selected entry, which chat_inf uses to index into `clients`.
            client_choice = gr.Dropdown(label="Models", type='index', choices=models, value=models[0],
                                        interactive=True)
        chat_b = gr.Chatbot(height=500)
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group():
                        rand = gr.Checkbox(label="Random Seed", value=True)
                        seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
                        tokens = gr.Slider(label="Max new tokens", value=6400, minimum=0, maximum=8000, step=64,
                                           interactive=True, visible=True,
                                           info="Maximum total length (prompt + response) in tokens")
                with gr.Column(scale=1):
                    with gr.Group():
                        temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                        top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                        #rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                sys_inp = gr.Textbox(label="System Prompt (optional)")
                inp = gr.Textbox(label="Prompt")
                with gr.Row():
                    btn = gr.Button("Chat")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.Button("Clear")

    # Re-seed first (if requested), then run inference. `seed` is included in
    # the inputs so chat_inf receives the value the user actually sees.
    chat_sub = inp.submit(check_rand, [rand, seed], seed).then(
        chat_inf, [sys_inp, inp, chat_b, client_choice, temp, tokens, top_p, seed], chat_b)
    go = btn.click(check_rand, [rand, seed], seed).then(
        chat_inf, [sys_inp, inp, chat_b, client_choice, temp, tokens, top_p, seed], chat_b)
    stop_btn.click(None, None, None, cancels=[go, chat_sub])
    clear_btn.click(clear_fn, None, [chat_b])

app.queue(default_concurrency_limit=10).launch()
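# A plausible local run, assuming these packages cover the imports above
# (package names are my assumption, not pinned by the Space):
#   pip install gradio spaces keras keras-hub "jax[cuda12]"
#   python app.py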