# Smart_LLM / app.py
import re
import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer
)
# Configuration Constants
MODEL_ID = "Daemontatox/AetherDrake"
DEFAULT_SYSTEM_PROMPT = """You are a Sentient Reasoning AI, expert at providing high-quality answers.
Your process involves these steps:
1. Initial Thought: Use the <Thinking> tag to reason step-by-step about any given request.
Example:
<Thinking>
Step 1: Understand the core request
Step 2: Analyze key components
Step 3: Formulate comprehensive response
</Thinking>
2. Self-Critique: Use <Critique> tags to evaluate your response:
<Critique>
- Accuracy: Verify facts and logic
- Clarity: Assess explanation clarity
- Completeness: Check all points addressed
- Improvements: Identify enhancement areas
</Critique>
3. Revision: Use <Revising> tags to refine your response:
<Revising>
Making identified improvements...
Enhancing clarity...
Adding examples...
</Revising>
4. Final Response: Present your polished answer in <Final> tags:
<Final>
Your complete, refined response goes here.
</Final>
Always organize your responses using these tags for clear reasoning structure."""
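# Note: format_text() below keys on these exact tag names when reflowing the
# streamed output, so the prompt and the formatter must stay in sync.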
# UI Configuration
TITLE = "<h1><center>AI Reasoning Assistant</center></h1>"
PLACEHOLDER = """
<center>
<p>Ask me anything! I'll think through it step by step.</p>
</center>
"""
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
.message-wrap {
overflow-x: auto;
white-space: pre-wrap !important;
}
.message-wrap p {
margin-bottom: 1em;
white-space: pre-wrap !important;
}
.message-wrap pre {
background-color: #f6f8fa;
border-radius: 3px;
padding: 16px;
overflow-x: auto;
}
.message-wrap code {
background-color: rgba(175,184,193,0.2);
border-radius: 3px;
padding: 0.2em 0.4em;
font-family: monospace;
}
.custom-tag {
color: #0066cc;
font-weight: bold;
}
"""
def initialize_model():
    """Load the model and tokenizer with 4-bit quantization."""
    # 4-bit weights with double quantization keep memory usage low;
    # computation runs in bfloat16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True
    )
    # Initialize tokenizer; fall back to the EOS token for padding.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Initialize model. bfloat16 matches the quantization compute dtype above.
    # Note: attn_implementation="flash_attention_2" requires the flash-attn
    # package to be installed.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
        quantization_config=quantization_config
    )
    return model, tokenizer
def format_text(text):
    """Put each reasoning tag on its own line and drop blank lines."""
    # Surround every opening/closing reasoning tag with newlines.
    tag_pattern = r'(</?(?:Thinking|Critique|Revising|Final)>)'
    formatted = re.sub(tag_pattern, r'\n\1\n', text)
    # Remove the extra blank lines this may introduce.
    formatted = '\n'.join(line for line in formatted.split('\n') if line.strip())
    return formatted
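# Example: format_text("<Thinking>Step 1</Thinking>") returns
# "<Thinking>\nStep 1\n</Thinking>", i.e. each tag ends up on its own line.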
@spaces.GPU()
def stream_chat(
message: str,
history: list,
system_prompt: str,
temperature: float = 0.2,
max_new_tokens: int = 8192,
top_p: float = 1.0,
top_k: int = 20,
penalty: float = 1.2,
):
"""Generate streaming chat responses with proper tag handling"""
# Format conversation context
conversation = [
{"role": "system", "content": system_prompt}
]
# Add conversation history
for prompt, answer in history:
conversation.extend([
{"role": "user", "content": prompt},
{"role": "assistant", "content": answer}
])
# Add current message
conversation.append({"role": "user", "content": message})
# Prepare input for model
input_ids = tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
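    # apply_chat_template renders the messages with the model's chat template;
    # add_generation_prompt=True appends the assistant-turn prefix so the
    # model continues as the assistant.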
# Configure streamer
streamer = TextIteratorStreamer(
tokenizer,
timeout=60.0,
skip_prompt=True,
skip_special_tokens=True
)
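    # The 60-second timeout makes iteration over the streamer raise instead
    # of hanging forever if generation stalls.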
    # Set generation parameters; fall back to greedy decoding at temperature 0.
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )
    # Generate in a background thread and stream partial output. generate()
    # already runs without gradients, and a torch.no_grad() context here would
    # not apply to the worker thread anyway (grad mode is thread-local).
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Reformat the whole buffer each step so tags stay on their own lines.
        yield format_text(buffer)
    thread.join()
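# Illustrative direct call (assumes main() has initialized model/tokenizer;
# in practice gr.ChatInterface below drives this function):
#   for partial in stream_chat("Explain backpropagation.", [], DEFAULT_SYSTEM_PROMPT):
#       print(partial)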
def create_examples():
"""Create example queries that demonstrate the system's capabilities"""
return [
["Explain how neural networks learn through backpropagation."],
["What are the key differences between classical and quantum computing?"],
["Analyze the environmental impact of renewable energy sources."],
["How does the human memory system work?"],
["Explain the concept of ethical AI and its importance."]
]
def main():
"""Main function to set up and launch the Gradio interface"""
# Initialize model and tokenizer
global model, tokenizer
model, tokenizer = initialize_model()
# Create chatbot interface
chatbot = gr.Chatbot(
height=600,
placeholder=PLACEHOLDER,
bubble_full_width=False,
show_copy_button=True
)
# Create interface
with gr.Blocks(css=CSS, theme="soft") as demo:
gr.HTML(TITLE)
gr.DuplicateButton(
value="Duplicate Space for private use",
elem_classes="duplicate-button"
)
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion(
label="⚙️ Advanced Settings",
open=False,
render=False
),
additional_inputs=[
gr.Textbox(
value=DEFAULT_SYSTEM_PROMPT,
label="System Prompt",
lines=5,
render=False,
),
gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.2,
label="Temperature",
render=False,
),
gr.Slider(
minimum=128,
maximum=32000,
step=128,
value=8192,
label="Max Tokens",
render=False,
),
gr.Slider(
minimum=0.1,
maximum=1.0,
step=0.1,
value=1.0,
label="Top-p",
render=False,
),
gr.Slider(
minimum=1,
maximum=100,
step=1,
value=20,
label="Top-k",
render=False,
),
gr.Slider(
minimum=1.0,
maximum=2.0,
step=0.1,
value=1.2,
label="Repetition Penalty",
render=False,
),
],
examples=create_examples(),
cache_examples=False,
)
return demo
if __name__ == "__main__":
demo = main()
demo.launch()