# Hugging Face Spaces app (Space status when captured: Sleeping)
import gradio as gr
import torch
import transformers

# Model setup: the instruction-tuned checkpoint is loaded once, at import time.
model_id = "GoToCompany/gemma2-9b-cpt-sahabatai-v1-instruct"

# NOTE: the module-level name `pipeline` holds the constructed pipeline object
# (not the `transformers.pipeline` factory); the chat code below calls it.
pipeline = transformers.pipeline(
    task="text-generation",
    model=model_id,
    device_map="auto",  # let transformers place the weights on available devices
    model_kwargs=dict(torch_dtype=torch.bfloat16),  # load weights in bfloat16
)
def _stop_token_ids(tokenizer, turn_end_tokens=("<end_of_turn>", "<|eot_id|>")):
    """Return the token ids that should stop generation for this tokenizer.

    The tokenizer's EOS id is always included (when defined). Gemma-2 chat
    templates close every turn with ``<end_of_turn>``, so that token must also
    stop generation or the model keeps writing past its reply. ``<|eot_id|>``
    is the Llama-3 equivalent and is kept only in case the tokenizer defines it.

    Args:
        tokenizer: A Hugging Face tokenizer exposing ``eos_token_id``,
            ``unk_token_id`` and ``convert_tokens_to_ids``.
        turn_end_tokens: Candidate end-of-turn token strings to look up.

    Returns:
        list[int]: Unique stop-token ids, EOS first.
    """
    eos_id = tokenizer.eos_token_id
    ids = [] if eos_id is None else [eos_id]
    for token in turn_end_tokens:
        token_id = tokenizer.convert_tokens_to_ids(token)
        # Tokens missing from the vocabulary come back as the unk id (or None);
        # passing those to generate() as stop ids would be meaningless.
        if token_id is None or token_id == tokenizer.unk_token_id:
            continue
        if token_id not in ids:
            ids.append(token_id)
    return ids


terminators = _stop_token_ids(pipeline.tokenizer)
# Chatbot Functionality
_FALLBACK_REPLY = "I'm sorry, I couldn't generate a response."


def _assistant_text(outputs):
    """Extract the newly generated assistant text from a pipeline result."""
    if not outputs:
        return _FALLBACK_REPLY
    generated = outputs[0]["generated_text"]
    # For chat-format input the text-generation pipeline returns the whole
    # conversation, with the new assistant message appended as the last item.
    if isinstance(generated, list):
        generated = generated[-1]["content"] if generated else ""
    return generated if generated and generated.strip() else _FALLBACK_REPLY


def chatbot(messages, *, max_new_tokens=256, generate=None, eos_token_id=None):
    """Generate the assistant's reply to a chat conversation.

    Args:
        messages (list): Chat messages as dicts with ``"role"`` ('user' or
            'assistant') and ``"content"`` keys. The list is not modified.
        max_new_tokens (int): Maximum number of tokens to generate.
        generate (callable, optional): Text-generation callable with the
            transformers pipeline call signature. Defaults to the module-level
            ``pipeline``.
        eos_token_id (list, optional): Stop-token ids. Defaults to the
            module-level ``terminators``.

    Returns:
        list: A new list holding ``messages`` followed by
        ``{"role": "assistant", "content": <reply text>}``. If the model
        produced nothing, the reply text is a fixed apology message.
    """
    if generate is None:
        generate = pipeline
    if eos_token_id is None:
        eos_token_id = terminators
    outputs = generate(
        messages,
        max_new_tokens=max_new_tokens,
        eos_token_id=eos_token_id,
    )
    reply = _assistant_text(outputs)
    return [*messages, {"role": "assistant", "content": reply}]
# Gradio Chat Interface
def respond(chat_history, user_message):
    """Handle one chat submission: query the model and extend the chat display.

    Args:
        chat_history: Current value of the ``gr.Chatbot`` component in its
            default pair format, i.e. a list of ``[user_text, assistant_text]``
            entries (may be ``None`` or empty before the first turn).
        user_message (str): Text typed into the message box.

    Returns:
        tuple: ``(updated_history, "")``. The empty string clears the textbox.
    """
    history = [list(turn) for turn in (chat_history or [])]
    # Ignore blank submissions instead of sending an empty turn to the model.
    if not user_message or not user_message.strip():
        return history, ""

    # Rebuild the role-tagged conversation the model expects from the
    # displayed [user, assistant] pairs.
    conversation = []
    for user_text, assistant_text in history:
        if user_text is not None:
            conversation.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            conversation.append({"role": "assistant", "content": assistant_text})
    conversation.append({"role": "user", "content": user_message})

    reply = chatbot(conversation)[-1]["content"]
    history.append([user_message, reply])
    return history, ""


with gr.Blocks() as demo:
    gr.Markdown("# 🤖 Gemma2 Chatbot")
    gr.Markdown("A chatbot that understands Javanese and Sundanese, powered by `GoToCompany/gemma2`.")
    chat_history = gr.Chatbot(label="Gemma2 Chatbot")
    user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    send_button = gr.Button("Send")
    send_button.click(respond, inputs=[chat_history, user_input], outputs=[chat_history, user_input])
    # Pressing Enter in the textbox behaves the same as clicking "Send".
    user_input.submit(respond, inputs=[chat_history, user_input], outputs=[chat_history, user_input])

# Launch the app
demo.launch()