File size: 5,602 Bytes
f655011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel
import gradio as gr

# Initialize model and tokenizer.
# The tokenizer is pulled from the HF hub while the weights load from a local
# directory; both must correspond to the same Mistral-7B-Instruct checkpoint.
mistral_path = "./models/mistral"  # Update this path as needed

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
#tokenizer = AutoTokenizer.from_pretrained("E:/language_models/models/mistral")
tokenizer.pad_token_id = 0  # NOTE(review): generation_settings below pads with eos_token_id instead — confirm which is intended

model = AutoModelForCausalLM.from_pretrained(
    mistral_path,
    torch_dtype=torch.float16,  # half precision so the 7B model fits on a single GPU
    trust_remote_code=True,
    use_safetensors=True
)
model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")
# Wrap with repeng's ControlModel so control vectors can be injected into
# decoder layers -5 through -17 (counted from the end of the layer stack).
model = ControlModel(model, list(range(-5, -18, -1)))

# Generation settings shared by every request (greedy decoding, capped length).
generation_settings = {
    "pad_token_id": tokenizer.eos_token_id,  # Silence warning
    "do_sample": False,                      # Deterministic output
    "max_new_tokens": 256,
    "repetition_penalty": 1.1,              # Reduce repetition
}

# Tags for prompt formatting (Mistral instruct-style chat markers)
user_tag, asst_tag = "[INST]", "[/INST]"

# Discover available control vectors: every .gguf file in the working directory.
control_vector_files = [f for f in os.listdir('.') if f.endswith('.gguf')]

# The UI is built per control vector, so an empty list makes the app useless;
# fail fast with a clear error instead of launching a broken interface.
if not control_vector_files:
    raise FileNotFoundError("No .gguf control vector files found in the current directory.")

# Function to toggle slider visibility based on checkbox state
def toggle_slider(checked):
    """Return a Gradio update making the paired weight slider visible iff *checked* is truthy."""
    slider_visibility = gr.update(visible=checked)
    return slider_visibility

# Function to generate the model's response
def generate_response(system_prompt, user_message, *args):
    """Generate a model response with the selected control vectors applied.

    Gradio passes all wired inputs positionally, in the order the click
    handler lists them: [system_prompt, user_input] + control_checks +
    control_sliders + [state].  So after the two text fields come ALL
    checkbox values, then ALL slider values, then the history list.

    Args:
        system_prompt: Optional system-level instruction text ("" to skip).
        user_message: The user's chat message.
        *args: len(control_vector_files) checkbox bools, followed by the
            same number of slider floats, followed by the history list.

    Returns:
        The updated history list of (user_message, response) tuples.
    """
    num_controls = len(control_vector_files)
    # BUGFIX: the original sliced args[0::2]/args[1::2] as if checkboxes and
    # sliders were interleaved, and declared `history` keyword-only after
    # *args (never supplied by Gradio's positional call -> TypeError).
    checkboxes = args[:num_controls]
    sliders = args[num_controls:2 * num_controls]
    history = args[-1] if len(args) > 2 * num_controls else []

    # Drop any control vectors applied on a previous turn.
    model.reset()

    # Sum the weighted vectors and apply them in one call.  repeng's
    # set_control() REPLACES the active vector, so the original per-file
    # loop silently kept only the last selected vector.
    combined = None
    for selected, weight, cv_file in zip(checkboxes, sliders, control_vector_files):
        if not selected:
            continue
        try:
            weighted = ControlVector.import_gguf(cv_file) * weight
            combined = weighted if combined is None else combined + weighted
        except Exception as e:
            # Best-effort: a bad/missing vector file shouldn't kill the chat.
            print(f"Failed to set control vector {cv_file}: {e}")
    if combined is not None:
        model.set_control(combined)

    # Mistral instruct prompt format: optional system text, then [INST]...[/INST].
    if system_prompt.strip():
        formatted_prompt = f"{system_prompt}\n{user_tag}{user_message}{asst_tag}"
    else:
        formatted_prompt = f"{user_tag}{user_message}{asst_tag}"

    # Tokenize on the model's device, generate greedily, decode the full output.
    input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**input_ids, **generation_settings)
    response = tokenizer.decode(output_ids.squeeze(), skip_special_tokens=True)

    # Append this exchange without mutating the caller's state list in place.
    history = list(history or [])
    history.append((user_message, response))
    return history

def reset_chat():
    """Start a fresh conversation: return an empty history for the chat display."""
    empty_history = list()
    return empty_history

# Build the Gradio interface: controls column on the left, chat on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Language Model Interface")
    
    with gr.Row():
        with gr.Column(scale=1):
            # System Prompt Input (optional; prepended to the formatted prompt)
            system_prompt = gr.Textbox(
                label="System Prompt",
                lines=2,
                placeholder="Enter system-level instructions here..."
            )
            
            # User Message Input
            user_input = gr.Textbox(
                label="User Message",
                lines=2,
                placeholder="Type your message here..."
            )
            
            gr.Markdown("### 📊 Control Vectors")
            
            # One (checkbox, slider) pair per discovered .gguf control vector.
            control_checks = []
            control_sliders = []
            for cv_file in control_vector_files:
                with gr.Row():
                    # Checkbox to select the control vector
                    checkbox = gr.Checkbox(label=cv_file, value=False)
                    control_checks.append(checkbox)
                    
                    # Slider to adjust the control vector's weight
                    slider = gr.Slider(
                        minimum=-2.5,
                        maximum=2.5,
                        value=0.0,
                        step=0.1,
                        label=f"{cv_file} Weight",
                        visible=False  # revealed by the paired checkbox's change event
                    )
                    control_sliders.append(slider)
                    
                    # Link the checkbox to toggle slider visibility
                    checkbox.change(
                        toggle_slider,
                        inputs=checkbox,
                        outputs=slider
                    )
            
            with gr.Row():
                # Submit and New Chat buttons
                submit_button = gr.Button("💬 Submit")
                new_chat_button = gr.Button("🆕 New Chat")
        
        with gr.Column(scale=2):
            # Chatbot to display conversation
            chatbot = gr.Chatbot(label="🗨️ Conversation")
    
    # State to keep track of conversation history.
    # NOTE(review): `state` is passed as an input below but never appears in
    # any `outputs` list, so it is never written back — the history handed to
    # generate_response is always the initial []. Confirm whether the click
    # outputs should be [chatbot, state].
    state = gr.State([])
    
    # Define button actions.
    # NOTE(review): inputs arrive positionally as [prompt, message, all
    # checkbox values, all slider values, state] — checkboxes and sliders are
    # NOT interleaved per control vector.
    submit_button.click(
        generate_response,
        inputs=[system_prompt, user_input] + control_checks + control_sliders + [state],
        outputs=[chatbot]
    )
    
    # Clears the chat display only; see the NOTE above about `state`.
    new_chat_button.click(
        reset_chat,
        inputs=[],
        outputs=[chatbot]
    )

# Launch the Gradio app (blocks until the server is stopped); guarded so that
# importing this module builds the UI without starting the server.
if __name__ == "__main__":
    demo.launch()  # default host/port unless overridden via arguments/env