tatihden committed on
Commit
4f32996
·
verified ·
1 Parent(s): 01a9918

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -24
app.py CHANGED
@@ -1,17 +1,19 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
- import random
 
 
4
  import spaces
5
 
6
  models = [
7
- "https://huggingface.co/tatihden/gemma_mental_health_7b_it_en",
8
- "https://huggingface.co/tatihden/gemma_mental_health_2b_en",
9
- "https://huggingface.co/tatihden/gemma_mental_health_2b_it_en"
10
  ]
11
 
12
  clients = []
13
  for model in models:
14
- clients.append(InferenceClient(model))
15
 
16
  @spaces.GPU
17
  def format_prompt(message, history):
@@ -25,7 +27,7 @@ def format_prompt(message, history):
25
 
26
 
27
  def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p):
28
- client = clients[int(client_choice) - 1]
29
  if not history:
30
  history = []
31
  hist_len = 0
@@ -42,7 +44,7 @@ def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens,
42
  seed=seed,
43
  )
44
  formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
45
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True,
46
  return_full_text=False)
47
  output = ""
48
 
@@ -67,27 +69,13 @@ def check_rand(inp, val):
67
  return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
68
 
69
 
70
- with gr.Blocks(css=".gradio-container {background-color: rgb(74 222 128)}",theme=gr.themes.Soft()) as app:
71
  gr.HTML(
72
- """<center><h1 style='font-size:xx-large;'>Google Gemma Models</h1></center>""")
73
  with gr.Group():
74
  with gr.Row():
75
  client_choice = gr.Dropdown(label="Models", type='index', choices=[c for c in models], value=models[0],
76
  interactive=True)
77
- chat_b = gr.Chatbot(height=500)
78
- with gr.Group():
79
- with gr.Row():
80
- with gr.Column(scale=1):
81
- with gr.Group():
82
- rand = gr.Checkbox(label="Random Seed", value=True)
83
- seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
84
- tokens = gr.Slider(label="Max new tokens", value=6400, minimum=0, maximum=8000, step=64,
85
- interactive=True, visible=True, info="The maximum number of tokens")
86
- with gr.Column(scale=1):
87
- with gr.Group():
88
- temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
89
- top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
90
- rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
91
 
92
  with gr.Group():
93
  with gr.Row():
@@ -99,6 +87,25 @@ with gr.Blocks(css=".gradio-container {background-color: rgb(74 222 128)}",theme
99
  stop_btn = gr.Button("Stop")
100
  clear_btn = gr.Button("Clear")
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  chat_sub = inp.submit(check_rand, [rand, seed], seed).then(chat_inf,
103
  [sys_inp, inp, chat_b, client_choice, seed, temp, tokens,
104
  top_p, rep_p], chat_b)
 
1
  import gradio as gr
2
+ #from huggingface_hub import InferenceClient
3
+ #import random
4
+ import tensorflow as tf
5
+ from tensorflow import keras
6
  import spaces
7
 
8
  models = [
9
+ "hf://tatihden/gemma_mental_health_7b_it_en",
10
+ "hf://tatihden/gemma_mental_health_2b_it_en",
11
+ "hf://tatihden/gemma_mental_health_2b_en"
12
  ]
13
 
14
  clients = []
15
  for model in models:
16
+ clients.append(keras.models.load_model(model))
17
 
18
  @spaces.GPU
19
  def format_prompt(message, history):
 
27
 
28
 
29
  def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_p, rep_p):
30
+ #client = clients[int(client_choice) - 1]
31
  if not history:
32
  history = []
33
  hist_len = 0
 
44
  seed=seed,
45
  )
46
  formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
47
+ stream = model.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True,
48
  return_full_text=False)
49
  output = ""
50
 
 
69
  return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
70
 
71
 
72
+ with gr.Blocks(css=".gradio-container {background-color: rgb(187 247 208)}",theme=gr.themes.Soft()) as app:
73
  gr.HTML(
74
+ """<center><h1 style='font-size:xx-large;'>CalmChat:A mental health conversational agent</h1></center>""")
75
  with gr.Group():
76
  with gr.Row():
77
  client_choice = gr.Dropdown(label="Models", type='index', choices=[c for c in models], value=models[0],
78
  interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  with gr.Group():
81
  with gr.Row():
 
87
  stop_btn = gr.Button("Stop")
88
  clear_btn = gr.Button("Clear")
89
 
90
+
91
+ with gr.Column(scale=1):
92
+ with gr.Group():
93
+ temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
94
+ top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
95
+ rep_p = gr.Slider(label="Repetition Penalty", step=0.1, minimum=0.1, maximum=2.0, value=1.0)
96
+
97
+
98
+ chat_b = gr.Chatbot(height=500)
99
+ with gr.Group():
100
+ with gr.Row():
101
+ with gr.Column(scale=1):
102
+ with gr.Group():
103
+ rand = gr.Checkbox(label="Random Seed", value=True)
104
+ seed = gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, step=1, value=rand_val)
105
+ tokens = gr.Slider(label="Max new tokens", value=6400, minimum=0, maximum=8000, step=64,
106
+ interactive=True, visible=True, info="The maximum number of tokens")
107
+
108
+
109
  chat_sub = inp.submit(check_rand, [rand, seed], seed).then(chat_inf,
110
  [sys_inp, inp, chat_b, client_choice, seed, temp, tokens,
111
  top_p, rep_p], chat_b)