Spaces:
Sleeping
Sleeping
File size: 2,200 Bytes
36b393d 50e5e08 36b393d 50e5e08 8442614 50e5e08 36b393d 7c5b993 36b393d 50e5e08 7c5b993 50e5e08 7c5b993 36b393d 7c5b993 50e5e08 36b393d 50e5e08 7c5b993 36b393d 7c5b993 36b393d 7c5b993 36b393d 7c5b993 36b393d 7c5b993 36b393d 7c5b993 36b393d 50e5e08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the model and tokenizer once at import time so every chat request
# reuses the same in-memory GPT-2 instance (a fine-tuned community checkpoint).
model_name = "karthikqnq/qnqgpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
# Construct the prompt from history and current message
prompt = system_message + "\n\n"
for user_msg, assistant_msg in history:
if user_msg:
prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
prompt += f"User: {message}\nAssistant: "
# Tokenize the input prompt
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
# Generate response
outputs = model.generate(
**inputs,
max_length=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
num_return_sequences=1
)
# Decode the output and extract only the assistant's response
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract the assistant's reply after "Assistant:"
try:
assistant_response = response.split("Assistant: ")[-1].strip()
except:
assistant_response = response
return assistant_response
# Wire the chat UI: the respond() callback plus four tunable generation
# controls exposed under the chat box (system prompt and sampling knobs).
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="QnQ GPT-2 Chatbot",
    description="A chatbot powered by the QnQ GPT-2 model",
)

# Launch the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|