tatihden committed on
Commit
1a3d32b
·
verified ·
1 Parent(s): 5f0442a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -32
app.py CHANGED
@@ -1,39 +1,23 @@
1
- import os
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
- import re
5
- import gradio as gr
6
- import spaces
7
-
8
- model_id = "tatihden/gemma_mental_health_2b_it_en"
9
- dtype = torch.bfloat16
10
 
11
@spaces.GPU
def gemma_chat(message, history):
    """Generate one reply from the fine-tuned Gemma model for a gr.ChatInterface.

    Parameters
    ----------
    message : str
        The latest user utterance.
    history : list
        Prior turns supplied by gr.ChatInterface (unused here: each call is a
        single-turn prompt).

    Returns
    -------
    str
        The model's generated reply text.
    """
    # NOTE(review): tokenizer and model are reloaded on every call, which is
    # slow — kept inside the function because @spaces.GPU requires the GPU
    # work to happen within the decorated callable.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        hidden_activation="gelu_pytorch_tanh",
        device_map="cuda",
        torch_dtype=dtype,
    )

    # Single-turn chat: wrap the message in the chat-template format.
    chat = [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=2048)

    response = tokenizer.decode(outputs[0])

    # Bug fix: re.split() returns a list, and the original returned that list
    # directly — gr.ChatInterface expects a string. Take the text after the
    # final "model" turn marker emitted by the Gemma chat template.
    response_cleaned = re.split("model", response)[-1]

    return response_cleaned


gr.ChatInterface(gemma_chat).launch()
 
1
+ from transformers import pipeline
2
+ from transformers import Conversation
 
 
 
 
 
 
 
3
 
4
+ import gradio as gr
 
 
 
 
 
 
 
 
5
 
6
# Conversational pipeline around the fine-tuned Gemma mental-health model.
# Bug fix: transformers accepts plain Hub repo ids ("user/repo") — the
# "hf://" URI scheme is not a valid model identifier for pipeline().
chatbot = pipeline(model="tatihden/gemma_mental_health_2b_it_en")

# Module-level transcript shared by all callers (a single global conversation).
message_list = []
response_list = []
 
 
 
 
10
 
11
def mini_chatbot(message, history):
    """Chat callback for gr.ChatInterface: answer `message` given `history`.

    Parameters
    ----------
    message : str
        The latest user utterance.
    history : list
        Prior turns as (user, bot) pairs supplied by gr.ChatInterface.

    Returns
    -------
    str
        The model's newest reply.
    """
    # Bug fix: the original built the Conversation from the module-level
    # `message_list` / `response_list`, which were never appended to, so the
    # model always saw an empty history. Use the history gradio passes in.
    past_inputs = [user_turn for user_turn, _ in history]
    past_responses = [bot_turn for _, bot_turn in history]

    # NOTE(review): `Conversation` has been removed from recent transformers
    # releases — confirm the pinned version still provides it.
    conversation = Conversation(
        text=message,
        past_user_inputs=past_inputs,
        generated_responses=past_responses,
    )
    conversation = chatbot(conversation)

    return conversation.generated_responses[-1]
 
18
 
19
# Build the chat UI around the callback and start serving it.
demo_chatbot = gr.ChatInterface(
    mini_chatbot,
    title="CalmChat",
    description="Enter text to start chatting.",
)
demo_chatbot.launch()