File size: 2,200 Bytes
36b393d
50e5e08
36b393d
50e5e08
8442614
50e5e08
 
36b393d
 
 
 
 
 
 
 
 
7c5b993
 
 
 
 
 
36b393d
50e5e08
 
 
7c5b993
50e5e08
 
7c5b993
36b393d
 
7c5b993
 
50e5e08
 
 
 
36b393d
50e5e08
7c5b993
 
 
 
36b393d
7c5b993
36b393d
7c5b993
36b393d
 
 
7c5b993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36b393d
 
 
 
 
7c5b993
36b393d
 
7c5b993
 
36b393d
 
 
50e5e08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Resolve the QnQ GPT-2 checkpoint once, at import time, so every chat request
# reuses the same tokenizer/model objects instead of reloading them.
model_name = "karthikqnq/qnqgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)  # text <-> token ids
model = AutoModelForCausalLM.from_pretrained(model_name)  # causal LM used by generate()

def _format_history(history):
    """Render earlier chat turns as ``User: ...`` / ``Assistant: ...`` prompt lines.

    Accepts both Gradio history formats:

    * tuple style: a list of ``(user, assistant)`` pairs;
    * messages style: a list of ``{"role": ..., "content": ...}`` dicts.

    Returns the rendered lines joined into one string (``""`` for no history).
    """
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            # Messages-style entry. Tuple-unpacking a dict would silently
            # yield its keys ("role", "content"), so read the fields instead.
            content = turn.get("content")
            if not content:
                continue
            role = turn.get("role")
            if role == "user":
                lines.append(f"User: {content}\n")
            elif role == "assistant":
                lines.append(f"Assistant: {content}\n")
            continue

        user_msg, assistant_msg = turn
        # Pairs with an empty user slot are skipped entirely.
        if not user_msg:
            continue
        lines.append(f"User: {user_msg}\n")
        # A missing reply would otherwise be rendered as the text "None".
        if assistant_msg is not None:
            lines.append(f"Assistant: {assistant_msg}\n")
    return "".join(lines)


def _extract_reply(generated_text):
    """Keep only the assistant's first turn from freshly generated text.

    The model may continue the transcript on its own (``\\nUser: ...``);
    everything from the first such marker onward is discarded.
    """
    return generated_text.split("\nUser:", 1)[0].strip()


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's next reply for the Gradio ``ChatInterface``.

    Args:
        message: The latest user message.
        history: Earlier turns, either ``(user, assistant)`` pairs or
            ``{"role": ..., "content": ...}`` message dicts.
        system_message: Text placed at the top of the prompt (omitted if empty).
        max_tokens: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling probability mass passed to ``model.generate``.

    Returns:
        The generated reply text, stripped of surrounding whitespace.
    """
    header = f"{system_message}\n\n" if system_message else ""
    prompt = f"{header}{_format_history(history)}User: {message}\nAssistant: "

    # Tokens the model can attend to in total (prompt + generated tokens).
    context_window = getattr(model.config, "max_position_embeddings", None) or 1024
    max_new_tokens = max(1, int(max_tokens))

    # Truncate from the LEFT so the newest turn and the trailing "Assistant: "
    # cue survive. Reserve at least half the window for the prompt so a large
    # max_tokens value cannot wipe out the whole conversation.
    prompt_budget = max(context_window - max_new_tokens, context_window // 2)
    inputs = {
        name: tensor[:, -prompt_budget:]
        for name, tensor in tokenizer(prompt, return_tensors="pt").items()
    }
    prompt_len = inputs["input_ids"].shape[-1]

    # Never request more tokens than the position window can still hold.
    max_new_tokens = min(max_new_tokens, context_window - prompt_len)

    # GPT-2-style tokenizers usually define no pad token; fall back to EOS so
    # generate() does not have to guess (and warn).
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
        pad_token_id=pad_token_id,
    )

    # Decode only the newly generated tokens; the prompt is not echoed back.
    generated = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)
    return _extract_reply(generated)

# Extra controls rendered under the chat box. ChatInterface passes their
# current values to `respond` after `message` and `history`, in list order.
chat_controls = [
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]

# Chat UI wired to the generation callback.
demo = gr.ChatInterface(
    respond,
    additional_inputs=chat_controls,
    title="QnQ GPT-2 Chatbot",
    description="A chatbot powered by the QnQ GPT-2 model",
)

if __name__ == "__main__":
    # Serve the app only when executed as a script, not when imported.
    demo.launch()