import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from repeng import ControlVector, ControlModel
import gradio as gr

# Initialize model and tokenizer
mistral_path = "./models/mistral"  # Update this path as needed

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
# tokenizer = AutoTokenizer.from_pretrained("E:/language_models/models/mistral")
tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(
    mistral_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")

# Wrap the model so control vectors can be applied to layers -5 through -17
model = ControlModel(model, list(range(-5, -18, -1)))

# Generation settings
generation_settings = {
    "pad_token_id": tokenizer.eos_token_id,  # Silence the missing-pad-token warning
    "do_sample": False,  # Deterministic output
    "max_new_tokens": 256,
    "repetition_penalty": 1.1,  # Reduce repetition
}

# Tags for Mistral-Instruct prompt formatting
user_tag, asst_tag = "[INST]", "[/INST]"

# List available control vectors in the current directory
control_vector_files = [f for f in os.listdir(".") if f.endswith(".gguf")]
if not control_vector_files:
    raise FileNotFoundError("No .gguf control vector files found in the current directory.")


# Toggle slider visibility based on the paired checkbox state
def toggle_slider(checked):
    return gr.update(visible=checked)


# Generate the model's response for one user turn
def generate_response(system_prompt, user_message, *args):
    # Gradio passes the extra inputs positionally: all checkbox values,
    # then all slider values, then the conversation history (state).
    num_controls = len(control_vector_files)
    checkboxes = args[:num_controls]
    sliders = args[num_controls:2 * num_controls]
    history = list(args[-1] or [])

    # Reset any previously applied control vectors
    model.reset()

    # Apply the selected control vectors with their corresponding weights.
    # Note: each set_control call replaces the previous one, so if several
    # boxes are checked, the last selected vector is the one that takes effect.
    for i in range(num_controls):
        if checkboxes[i]:
            cv_file = control_vector_files[i]
            weight = sliders[i]
            try:
                control_vector = ControlVector.import_gguf(cv_file)
                model.set_control(control_vector, weight)
            except Exception as e:
                print(f"Failed to set control vector {cv_file}: {e}")

    # Format the prompt
    if system_prompt.strip():
        formatted_prompt = f"{system_prompt}\n{user_tag}{user_message}{asst_tag}"
    else:
        formatted_prompt = f"{user_tag}{user_message}{asst_tag}"

    # Tokenize the input
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Generate, then decode only the newly generated tokens (skip the prompt)
    output_ids = model.generate(**inputs, **generation_settings)
    response = tokenizer.decode(
        output_ids[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    ).strip()

    # Update the conversation history shown in the chatbot and stored in state
    history.append((user_message, response))
    return history, history


# Reset the conversation history
def reset_chat():
    return [], []


# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Language Model Interface")

    with gr.Row():
        with gr.Column(scale=1):
            # System prompt input
            system_prompt = gr.Textbox(
                label="System Prompt",
                lines=2,
                placeholder="Enter system-level instructions here...",
            )
            # User message input
            user_input = gr.Textbox(
                label="User Message",
                lines=2,
                placeholder="Type your message here...",
            )

            gr.Markdown("### 📊 Control Vectors")

            # Create a checkbox and a weight slider for each control vector file
            control_checks = []
            control_sliders = []
            for cv_file in control_vector_files:
                with gr.Row():
                    # Checkbox to select the control vector
                    checkbox = gr.Checkbox(label=cv_file, value=False)
                    control_checks.append(checkbox)

                    # Slider to adjust the control vector's weight
                    slider = gr.Slider(
                        minimum=-2.5,
                        maximum=2.5,
                        value=0.0,
                        step=0.1,
                        label=f"{cv_file} Weight",
                        visible=False,
                    )
                    control_sliders.append(slider)

                    # Show the slider only while its checkbox is ticked
                    checkbox.change(
                        toggle_slider,
                        inputs=checkbox,
                        outputs=slider,
                    )

            with gr.Row():
                # Submit and New Chat buttons
                submit_button = gr.Button("💬 Submit")
                new_chat_button = gr.Button("🆕 New Chat")

        with gr.Column(scale=2):
            # Chatbot to display the conversation
            chatbot = gr.Chatbot(label="🗨️ Conversation")

    # State to keep track of the conversation history
    state = gr.State([])

    # Wire up the buttons; both the chatbot and the state receive the updated history
    submit_button.click(
        generate_response,
        inputs=[system_prompt, user_input] + control_checks + control_sliders + [state],
        outputs=[chatbot, state],
    )
    new_chat_button.click(
        reset_chat,
        inputs=[],
        outputs=[chatbot, state],
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
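
# ----------------------------------------------------------------------
# For reference: a rough sketch of how a .gguf control vector file could be
# produced with repeng, so that the *.gguf scan above has something to list.
# This block is illustrative only and is not executed by the app; the persona
# prompts and the "happy.gguf" filename are invented, and it assumes repeng's
# DatasetEntry / ControlVector.train / export_gguf API as described in the
# repeng README.
#
#   from repeng import ControlVector, ControlModel, DatasetEntry
#
#   # Wrap the base (unwrapped) model over the same layer range as the app
#   train_model = ControlModel(base_model, list(range(-5, -18, -1)))
#
#   # Contrastive pairs: identical template, opposite personas
#   dataset = [
#       DatasetEntry(
#           positive=f"{user_tag} Act extremely happy and joyful. {asst_tag} I",
#           negative=f"{user_tag} Act extremely sad and depressed. {asst_tag} I",
#       ),
#       # ...more pairs with varied suffixes give a cleaner direction
#   ]
#
#   vector = ControlVector.train(train_model, tokenizer, dataset)
#   vector.export_gguf("happy.gguf")  # then picked up by control_vector_files
# ----------------------------------------------------------------------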