import os
import random

import gradio as gr


os.environ["KERAS_BACKEND"] =  "tensorflow"  #"jax" "torch"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

import keras_hub


# Presets prefixed with hf:// are Keras-format Gemma checkpoints that
# keras_hub pulls from the Hugging Face Hub at load time.
models = [
    "hf://tatihden/gemma_mental_health_2b_it_en",
    "hf://tatihden/gemma_mental_health_2b_en",
    "hf://tatihden/gemma_mental_health_7b_it_en"
]

# Instantiate one Gemma model per preset up front; each from_preset call
# downloads and builds a full checkpoint, so this trades memory for the
# ability to switch models without a reload.
clients = []
for model in models:
    clients.append(keras_hub.models.GemmaCausalLM.from_preset(model))

def format_prompt(message, history):
    """Render prior (user, model) turns plus the new message in Gemma's chat template."""
    prompt = ""
    if history:
        for user_prompt, bot_response in history:
            prompt += f"<start_of_turn>user\n{user_prompt}<end_of_turn>\n"
            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>\n"
    prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
    return prompt
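
# For illustration, format_prompt("How are you?", [("Hi", "Hello!")]) renders:
#   <start_of_turn>user
#   Hi<end_of_turn>
#   <start_of_turn>model
#   Hello!<end_of_turn>
#   <start_of_turn>user
#   How are you?<end_of_turn>
#   <start_of_turn>model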


def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p):
    # client_choice arrives as a 0-based index because the dropdown uses type='index'.
    client = clients[int(client_choice)]
    if not history:
        history = []

    # Only prepend the system prompt when one was actually entered.
    full_prompt = f"{system_prompt}, {prompt}" if system_prompt else prompt
    formatted_prompt = format_prompt(full_prompt, history)
    # keras_hub's generate() returns the completed string in a single call
    # (no token stream); max_length bounds prompt plus generated tokens.
    # seed/temp/top_p are not per-call arguments here -- see the sampler note below.
    output = client.generate(formatted_prompt, max_length=int(tokens))
    # generate() echoes the prompt by default, so strip it to keep only the reply.
    if output.startswith(formatted_prompt):
        output = output[len(formatted_prompt):]
    history.append((prompt, output))
    yield history
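
# Hedged sketch (not wired in): keras_hub takes sampling controls through
# compile() rather than per-call kwargs, so the UI's temperature/top-p/seed
# settings could be applied before generate() roughly like this. Parameter
# names are assumed from keras_hub.samplers; verify against the installed
# version before enabling.
#
#     client.compile(
#         sampler=keras_hub.samplers.TopPSampler(p=top_p, temperature=temp, seed=int(seed))
#     )
#
# keras_hub samplers expose no repetition-penalty knob, so rep_p is accepted
# by chat_inf but currently has no effect.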


def clear_fn():
    return None


rand_val = random.randint(1, 1111111111111111)


def check_rand(inp, val):
    # Re-roll the seed when "Random Seed" is checked; otherwise keep the user's value.
    if inp:
        return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=random.randint(1, 1111111111111111))
    return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))


with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {background-color: rgb(187 247 208)}") as app:
    gr.HTML(
        """<center><h1 style='font-size:xx-large;'>CalmChat: A Mental Health Conversational Agent</h1></center>""")
    with gr.Group():
        with gr.Row():
            client_choice = gr.Dropdown(label="Models", type='index', choices=models, value=models[0],
                                        interactive=True)
    chat_b = gr.Chatbot(height=500)
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    rand = gr.Checkbox(label="Random Seed", value=True)
                    seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
                    tokens = gr.Slider(label="Max tokens", value=6400, minimum=0, maximum=8000, step=64,
                                       interactive=True, visible=True,
                                       info="Maximum total length (prompt + reply) passed to generate()")
            with gr.Column(scale=1):
                with gr.Group():
                    temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
                    rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)

    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                sys_inp = gr.Textbox(label="System Prompt (optional)")
                inp = gr.Textbox(label="Prompt")
                with gr.Row():
                    btn = gr.Button("Chat")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.Button("Clear")

    chat_sub = inp.submit(check_rand, [rand, seed], seed).then(chat_inf,
                                                               [sys_inp, inp, chat_b, client_choice, seed, temp, tokens,
                                                                top_p, rep_p], chat_b)
    go = btn.click(check_rand, [rand, seed], seed).then(chat_inf,
                                                        [sys_inp, inp, chat_b, client_choice, seed, temp, tokens, top_p,
                                                         rep_p], chat_b)
    stop_btn.click(None, None, None, cancels=[go, chat_sub])
    clear_btn.click(clear_fn, None, [chat_b])
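# The queue backs both the generator-based chat output and the Stop button's
# cancels=[...]; default_concurrency_limit caps simultaneous chat requests.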
app.queue(default_concurrency_limit=10).launch()