tatihden committed
Commit 649a7fc · verified · 1 Parent(s): 0d2ec09

Update app.py

Files changed (1)
  1. app.py +21 -22
app.py CHANGED
@@ -1,14 +1,14 @@
  import gradio as gr
  import random
- from huggingface_hub import InferenceClient
+ #from huggingface_hub import InferenceClient
  import spaces
- #import os
+ import os


- #os.environ["KERAS_BACKEND"] = "tensorflow" #"jax" "torch"
- #os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
+ os.environ["KERAS_BACKEND"] = "tensorflow" #"jax" "torch"
+ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

- #import keras_hub
+ import keras_hub


  models = [
@@ -17,16 +17,16 @@ models = [
      "hf://tatihden/gemma_mental_health_7b_it_en"
  ]

- #clients = []
- #for model in models:
- #clients.append(keras_hub.models.GemmaCausalLM.from_preset(model))
+ clients = []
+ for model in models:
+     clients.append(keras_hub.models.GemmaCausalLM.from_preset(model))

- from huggingface_hub import InferenceClient
+ #from huggingface_hub import InferenceClient


- clients = []
- for model in models:
-     clients.append(InferenceClient(model))
+ #clients = []
+ #for model in models:
+ #clients.append(InferenceClient(model))

  @spaces.GPU
  def format_prompt(message, history):
@@ -48,21 +48,20 @@ def chat_inf(system_prompt, prompt, history, client_choice, seed, temp, tokens,
      hist_len = len(history)
      print(hist_len)

-     generate_kwargs = dict(
-         temperature=temp,
-         max_new_tokens=tokens,
-         top_p=top_p,
-         repetition_penalty=rep_p,
-         do_sample=True,
-         seed=seed,
-     )
+     #generate_kwargs = dict(
+     #temperature=temp,
+     #max_new_tokens=tokens,
+     #top_p=top_p,
+     #repetition_penalty=rep_p,
+     #do_sample=True,
+     #seed=seed,
+     #)
      formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True,
-                                     return_full_text=False)
+     stream = client.generate([formatted_prompt], max_length=tokens)
      output = ""

      for response in stream:
-         output += response.token.text
+         output = response.replace(formatted_prompt, "")
          yield [(prompt, output)]
      history.append((prompt, output))
      yield history
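
For reference, the path this commit moves to can be reduced to a standalone sketch: pin the backend, load a preset through keras_hub, and generate synchronously. The preset handle is taken from the diff; the prompt and max_length value are illustrative only, and running it assumes an accelerator with room for the 7B checkpoint.

    import os

    # Pin the backend before any Keras import, as the commit does.
    os.environ["KERAS_BACKEND"] = "tensorflow"

    import keras_hub

    # The commit builds one client per entry in `models`;
    # a single preset shows the shape of the call.
    gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
        "hf://tatihden/gemma_mental_health_7b_it_en"
    )

    prompt = "How can I manage exam stress?"  # illustrative input

    # generate() blocks and returns prompt + completion as one string
    # (list in, list out also works); there is no token stream, which is
    # why the reworked loop above yields the full response at once.
    full_text = gemma_lm.generate(prompt, max_length=256)
    print(full_text.replace(prompt, ""))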
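
The commented-out generate_kwargs block points at a real API gap: keras_hub's generate() takes no temperature, top_p, or seed keyword arguments. If those sliders should keep working, one possible re-wiring (an assumption about intent, not something this commit does) is to fold them into the sampler attached via compile():

    import keras_hub

    # Hypothetical stand-ins for the Gradio slider values.
    temp, top_p, seed = 0.9, 0.95, 42

    gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
        "hf://tatihden/gemma_mental_health_7b_it_en"
    )

    # Sampling settings live on the sampler, not on generate().
    gemma_lm.compile(
        sampler=keras_hub.samplers.TopPSampler(p=top_p, temperature=temp, seed=seed)
    )

    output = gemma_lm.generate("Hello", max_length=128)

Note the remaining mismatches: max_length caps prompt plus completion (unlike max_new_tokens), and keras_hub ships no stock repetition-penalty sampler, so rep_p has no direct counterpart.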