import os

# The Keras backend must be selected before keras_hub (and keras) are imported.
os.environ["KERAS_BACKEND"] = "tensorflow"  # alternatives: "jax", "torch"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import random

import gradio as gr
import keras_hub

# from huggingface_hub import InferenceClient
# import spaces
# Gemma checkpoints for mental-health chat, pulled from the Hugging Face Hub
# as keras_hub presets.
models = [
    "hf://tatihden/gemma_mental_health_2b_it_en",
    "hf://tatihden/gemma_mental_health_2b_en",
    "hf://tatihden/gemma_mental_health_7b_it_en",
]

clients = []
for model in models:
    clients.append(keras_hub.models.GemmaCausalLM.from_preset(model))
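# Loading all three presets eagerly keeps model switching instant but holds
# every checkpoint in memory at once. A lazy, cached loader is one alternative
# (a sketch, not what this Space currently does):
#
#     import functools
#
#     @functools.lru_cache(maxsize=None)
#     def get_client(preset):
#         return keras_hub.models.GemmaCausalLM.from_preset(preset)
#
# chat_inf would then call get_client(models[int(client_choice)]) on demand.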
# Alternative: query the models remotely instead of loading them locally.
# from huggingface_hub import InferenceClient
# clients = []
# for model in models:
#     clients.append(InferenceClient(model))

# @spaces.GPU
def format_prompt(message, history):
    """Render the chat history into Gemma's turn-based prompt format."""
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt
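# For example, format_prompt("How are you?", [("Hi", "Hello!")]) produces:
#
#     <start_of_turn>user
#     Hi<end_of_turn>
#     <start_of_turn>model
#     Hello!<end_of_turn>
#     <start_of_turn>user
#     How are you?<end_of_turn>
#     <start_of_turn>model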
def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p):
    # The Dropdown uses type='index', so client_choice is already a 0-based index.
    client = clients[int(client_choice)]
    if not history:
        history = []

    # seed/temp/top_p/rep_p come from the UI but are not yet applied to
    # generation; see the sampler sketch below this function. max_length caps
    # the total (prompt + reply) length, and strip_prompt drops the echoed
    # prompt so only the model's reply is shown.
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    output = client.generate(formatted_prompt, max_length=tokens, strip_prompt=True)
    history.append((prompt, output))
    yield history
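# The temperature/top-p/seed sliders are collected but not yet wired into
# generation: keras_hub models take sampling settings via compile(), not as
# per-call generate() kwargs. A sketch of one way to honor them (assumes the
# keras_hub.samplers API; re-compiling on every request adds some overhead):
#
#     client.compile(
#         sampler=keras_hub.samplers.TopPSampler(p=top_p, temperature=temp, seed=seed)
#     )
#     output = client.generate(formatted_prompt, max_length=tokens)
#
# There is no direct repetition-penalty equivalent among the built-in
# samplers, so rep_p would remain unused.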
def clear_fn():
    # Returning None resets the Chatbot component.
    return None

rand_val = random.randint(1, 1111111111111111)

def check_rand(inp, val):
    # Re-roll the seed when "Random Seed" is checked; otherwise keep the user's value.
    if inp:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111,
                         value=random.randint(1, 1111111111111111))
    return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
with gr.Blocks(theme=gr.themes.Soft(),
               css=".gradio-container {background-color: rgb(187 247 208)}") as app:
    gr.HTML(
        """<center><h1 style='font-size:xx-large;'>CalmChat: A Mental Health Conversational Agent</h1></center>""")
    with gr.Group():
        with gr.Row():
            client_choice = gr.Dropdown(label="Models", type='index', choices=models,
                                        value=models[0], interactive=True)
    chat_b = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    rand = gr.Checkbox(label="Random Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111,
                                     step=1, value=rand_val)
                    tokens = gr.Slider(label="Max new tokens", value=6400, minimum=0,
                                       maximum=8000, step=64, interactive=True, visible=True,
                                       info="The maximum number of new tokens to generate")
            with gr.Column(scale=1):
                with gr.Group():
                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                sys_inp = gr.Textbox(label="System Prompt (optional)")
                inp = gr.Textbox(label="Prompt")
                with gr.Row():
                    btn = gr.Button("Chat")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.Button("Clear")

    # Enter in the prompt box and the Chat button both resolve the seed first,
    # then run generation and update the chatbot.
    chat_sub = inp.submit(check_rand, [rand, seed], seed).then(
        chat_inf,
        [sys_inp, inp, chat_b, client_choice, seed, temp, tokens, top_p, rep_p],
        chat_b)
    go = btn.click(check_rand, [rand, seed], seed).then(
        chat_inf,
        [sys_inp, inp, chat_b, client_choice, seed, temp, tokens, top_p, rep_p],
        chat_b)
    stop_btn.click(None, None, None, cancels=[go, chat_sub])
    clear_btn.click(clear_fn, None, [chat_b])

app.queue(default_concurrency_limit=10).launch()
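# Note: default_concurrency_limit=10 lets up to ten requests run each event
# handler at once, but the Gemma clients above are shared objects; on a single
# accelerator a lower limit (or a lock around generate()) may be the safer choice.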