tatihden commited on
Commit
974b72c
·
verified ·
1 Parent(s): ea63197

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import gradio as gr
2
- #from huggingface_hub import InferenceClient
3
  import random
4
  import spaces
5
  import os
6
 
7
  os.environ["KERAS_BACKEND"] = "jax"
 
 
8
  import keras_hub
9
  import keras
10
 
@@ -29,9 +30,8 @@ def format_prompt(message, history):
29
  return prompt
30
 
31
 
32
- def chat_inf(system_prompt, prompt, history, tokens, client_choice, temp, top_p, seed):
33
  client = clients[int(client_choice) - 1]
34
-
35
  if not history:
36
  history = []
37
  hist_len = 0
@@ -39,7 +39,7 @@ def chat_inf(system_prompt, prompt, history, tokens, client_choice, temp, top_p,
39
  hist_len = len(history)
40
  print(hist_len)
41
 
42
- sampler = keras_hub.samplers.TopPSampler(p=top_p, seed=seed, temperature=temp)
43
  client.compile(sampler=sampler)
44
 
45
  formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
@@ -47,7 +47,7 @@ def chat_inf(system_prompt, prompt, history, tokens, client_choice, temp, top_p,
47
  output = ""
48
 
49
  for response in stream:
50
- output += response.str(tokens)
51
  yield [(prompt, output)]
52
  history.append((prompt, output))
53
  yield history
@@ -67,14 +67,14 @@ def check_rand(inp, val):
67
  return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
68
 
69
 
70
- with gr.Blocks(theme=gr.themes.Glass(),css=".gradio-container {background-color: rgb(134 239 172)}") as app:
71
  gr.HTML(
72
- """<center><h1 style='font-size:xx-large;'>CalmChat</h1></center>""")
73
  with gr.Group():
74
  with gr.Row():
75
  client_choice = gr.Dropdown(label="Models", type='index', choices=[c for c in models], value=models[0],
76
  interactive=True)
77
- chat_b = gr.Chatbot(height=500)
78
  with gr.Group():
79
  with gr.Row():
80
  with gr.Column(scale=1):
@@ -86,8 +86,8 @@ with gr.Blocks(theme=gr.themes.Glass(),css=".gradio-container {background-color:
86
  with gr.Column(scale=1):
87
  with gr.Group():
88
  temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
 
89
  top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
90
-
91
 
92
  with gr.Group():
93
  with gr.Row():
@@ -99,9 +99,12 @@ with gr.Blocks(theme=gr.themes.Glass(),css=".gradio-container {background-color:
99
  stop_btn = gr.Button("Stop")
100
  clear_btn = gr.Button("Clear")
101
 
102
- chat_sub = inp.submit(check_rand, [rand, seed],seed).then(chat_inf,[sys_inp, inp, chat_b, client_choice, tokens], chat_b, temp, top_p, seed)
103
- go = btn.click(check_rand, [rand, seed], seed).then(chat_inf,[sys_inp, inp, chat_b, client_choice, tokens], chat_b, temp, top_p, seed)
104
-
 
 
 
105
  stop_btn.click(None, None, None, cancels=[go, chat_sub])
106
  clear_btn.click(clear_fn, None, [chat_b])
107
  app.queue(default_concurrency_limit=10).launch()
 
1
  import gradio as gr
 
2
  import random
3
  import spaces
4
  import os
5
 
6
  os.environ["KERAS_BACKEND"] = "jax"
7
+ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
8
+
9
  import keras_hub
10
  import keras
11
 
 
30
  return prompt
31
 
32
 
33
+ def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens, top_k, top_p):
34
  client = clients[int(client_choice) - 1]
 
35
  if not history:
36
  history = []
37
  hist_len = 0
 
39
  hist_len = len(history)
40
  print(hist_len)
41
 
42
+ sampler = keras_nlp.samplers.TopKSampler(k=top_k, seed=seed, temperature=temp)
43
  client.compile(sampler=sampler)
44
 
45
  formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
 
47
  output = ""
48
 
49
  for response in stream:
50
+ output += response.token.text
51
  yield [(prompt, output)]
52
  history.append((prompt, output))
53
  yield history
 
67
  return gr.Slider(label="Seed", minimum=1, maximum=1111111111111111, value=int(val))
68
 
69
 
70
+ with gr.Blocks(css=".gradio-container {background-color: rgb(187 247 208)}", gr.themes.Monochrome()) as app:
71
  gr.HTML(
72
+ """<center><h1 style='font-size:xx-large;'>CalmChat: A Mental Health Conversational Agent</h1></center>""")
73
  with gr.Group():
74
  with gr.Row():
75
  client_choice = gr.Dropdown(label="Models", type='index', choices=[c for c in models], value=models[0],
76
  interactive=True)
77
+ chat_b = gr.Chatbot(height=300)
78
  with gr.Group():
79
  with gr.Row():
80
  with gr.Column(scale=1):
 
86
  with gr.Column(scale=1):
87
  with gr.Group():
88
  temp = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
89
+ top_k = gr.Slider(label="Top-K", step=0.5, minimum=1, maximum=10, value=5)
90
  top_p = gr.Slider(label="Top-P", step=0.01, minimum=0.01, maximum=1.0, value=0.9)
 
91
 
92
  with gr.Group():
93
  with gr.Row():
 
99
  stop_btn = gr.Button("Stop")
100
  clear_btn = gr.Button("Clear")
101
 
102
+ chat_sub = inp.submit(check_rand, [rand, seed], seed).then(chat_inf,
103
+ [sys_inp, inp, chat_b, client_choice, seed, temp, tokens,
104
+ top_p, top_k], chat_b)
105
+ go = btn.click(check_rand, [rand, seed], seed).then(chat_inf,
106
+ [sys_inp, inp, chat_b, client_choice, seed, temp, tokens, top_p,
107
+ top_k], chat_b)
108
  stop_btn.click(None, None, None, cancels=[go, chat_sub])
109
  clear_btn.click(clear_fn, None, [chat_b])
110
  app.queue(default_concurrency_limit=10).launch()