import gradio as gr
#from huggingface_hub import InferenceClient
import random
import spaces
import os
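# Pin the process to a single GPU and select the JAX backend; KERAS_BACKEND
# must be set before keras is imported below.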
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
os.environ["KERAS_BACKEND"] = "jax"
import keras_hub
import keras

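# Fine-tuned Gemma mental-health checkpoints hosted on the Hugging Face Hub.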
models = [
    "hf://tatihden/gemma_mental_health_2b_it_en",
    "hf://tatihden/gemma_mental_health_2b_en",
    "hf://tatihden/gemma_mental_health_7b_it_en"
]

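# Load every preset once at startup so switching models in the UI is instant;
# note the 7B checkpoint needs a correspondingly large GPU.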
clients = []
for model in models:
    clients.append(keras_hub.models.GemmaCausalLM.from_preset(model))

def format_prompt(message, history):
    """Render prior turns and the new message in Gemma's chat-turn format."""
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt


@spaces.GPU
def chat_inf(system_prompt, prompt, history, temp, tokens, top_p, seed, client_choice):
    # The Dropdown uses type='index', so client_choice is already a 0-based index.
    client = clients[int(client_choice)]
    history = history or []

    # TopKSampler does not accept a top_p argument; TopPSampler takes p plus an
    # optional k pre-filter, with temperature handled by the Sampler base class.
    sampler = keras_hub.samplers.TopPSampler(p=top_p, k=5, seed=seed, temperature=temp)
    client.compile(sampler=sampler)

    user_message = f"{system_prompt}, {prompt}" if system_prompt else prompt
    formatted_prompt = format_prompt(user_message, history)

    # generate() returns the completed text as a single string; max_length caps
    # the total sequence (prompt + response) in tokens.
    output = client.generate(formatted_prompt, max_length=tokens)
    # Strip the echoed prompt so only the model's reply is displayed.
    if output.startswith(formatted_prompt):
        output = output[len(formatted_prompt):]

    history.append((prompt, output))
    yield history


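# Returning None resets the Chatbot component.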
def clear_fn():
    return None


rand_val = random.randint(1, 1111111111111111)


def check_rand(inp, val):
    # Re-roll the seed when "Random Seed" is checked; otherwise keep the slider value.
    if inp:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
    return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))


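# Build the UI: model picker, chat window, sampling controls, and prompt inputs.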
with gr.Blocks(theme=gr.themes.Glass(), css=".gradio-container {background-color: rgb(134 239 172)}") as app:
    gr.HTML(
        """<center><h1 style='font-size:xx-large;'>CalmChat</h1></center>""")
    with gr.Group():
        with gr.Row():
            client_choice = gr.Dropdown(label="Models", type='index', choices=models, value=models[0],
                                        interactive=True)
    chat_b = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    rand = gr.Checkbox(label="Random Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
                    tokens = gr.Slider(label="Max length", value=6400, minimum=64, maximum=8000, step=64,
                                       interactive=True, visible=True,
                                       info="Maximum total tokens (prompt + response)")
            with gr.Column(scale=1):
                with gr.Group():
                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    #rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)

    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                sys_inp = gr.Textbox(label="System Prompt (optional)")
                inp = gr.Textbox(label="Prompt")
                with gr.Row():
                    btn = gr.Button("Chat")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.Button("Clear")

    # Input order must match chat_inf's signature:
    # (system_prompt, prompt, history, temp, tokens, top_p, seed, client_choice)
    chat_sub = inp.submit(check_rand, [rand, seed], seed).then(
        chat_inf, [sys_inp, inp, chat_b, temp, tokens, top_p, seed, client_choice], chat_b)
    go = btn.click(check_rand, [rand, seed], seed).then(
        chat_inf, [sys_inp, inp, chat_b, temp, tokens, top_p, seed, client_choice], chat_b)
    stop_btn.click(None, None, None, cancels=[go, chat_sub])
    clear_btn.click(clear_fn, None, [chat_b])
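# queue() enables generator-based chat updates and caps concurrent requests.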
app.queue(default_concurrency_limit=10).launch()