# CodeMixt / app.py
# Author: acecalisto3 (Hugging Face Space)
# Commit 19627c4 ("Update app.py", verified), 12.1 kB
import threading
import time
import gradio as gr
import logging
import json
import re
import torch
import tempfile
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any, Union
from dataclasses import dataclass, field
from enum import Enum
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from PIL import Image
import black
# Configure logging: INFO and above is sent both to the console and to
# 'gradio_builder.log' in the working directory. (logging.basicConfig does
# nothing if the root logger already has handlers.)
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler(), logging.FileHandler('gradio_builder.log')],
)
logger = logging.getLogger(__name__)
# Configuration
@dataclass
class Config:
    """Runtime settings for the application; every field has a default.

    Defaults can be overridden by a JSON object whose keys match the field
    names (see :meth:`from_file`).
    """

    port: int = 7860  # server_port passed to Gradio's launch()
    debug: bool = False  # passed to launch(debug=...)
    share: bool = False  # passed to launch(share=...)
    model_name: str = "gpt2"  # causal language model used for generation
    embedding_model: str = "all-mpnet-base-v2"  # SentenceTransformer used for retrieval
    theme: str = "default"  # NOTE(review): not read by the UI code in this file
    max_code_length: int = 5000

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> "Config":
        """Build a Config from the JSON object stored at ``path``.

        Unknown keys are logged and ignored, so the remaining valid settings
        still apply (previously a single unknown key discarded the whole file).
        A missing file, unreadable or invalid JSON, or a top-level value that
        is not a JSON object all fall back to the defaults.

        Args:
            path: Location of the JSON config file.

        Returns:
            The loaded configuration, or ``Config()`` when nothing usable was found.
        """
        # Imported locally: only `dataclass` and `field` are imported at module level.
        from dataclasses import fields

        try:
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            # Running without a config file is a normal situation, not a warning.
            logger.info(f"No config file at {path}; using defaults.")
            return cls()
        except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
            logger.warning(f"Failed to load config from {path}: {e}. Using defaults.")
            return cls()

        if not isinstance(data, dict):
            logger.warning(f"Config file {path} must contain a JSON object; using defaults.")
            return cls()

        known = {fld.name for fld in fields(cls)}
        unknown = sorted(set(data) - known)
        if unknown:
            logger.warning(f"Ignoring unknown config keys in {path}: {', '.join(unknown)}")
        return cls(**{key: value for key, value in data.items() if key in known})
# Constants: file-system locations, all relative to the working directory.
CONFIG_PATH = Path("config.json")
DATABASE_PATH = Path("code_database.json")
MODEL_CACHE_DIR = Path("model_cache")
TEMPLATE_DIR = Path("templates")
TEMP_DIR = Path("temp")

# Create the writable directories at import time (no-op if they already exist).
for directory in (MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR):
    directory.mkdir(parents=True, exist_ok=True)
def _current_timestamp() -> str:
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return time.strftime("%Y-%m-%d %H:%M:%S")


@dataclass
class Template:
    """A saved code template.

    Field order is significant: it defines the positional constructor.
    """

    code: str  # the template's source code
    description: str  # free-text label (the UI stores the template name here)
    components: List[str] = field(default_factory=list)  # e.g. function/class names found in `code`
    created_at: str = field(default_factory=_current_timestamp)  # local time of creation
    tags: List[str] = field(default_factory=list)
class TemplateManager:
    """Load, save, look up and delete code templates stored as JSON files.

    Each template lives in ``<template_dir>/<name>.json``. ``name`` (the file
    stem) is the single key used everywhere: in :attr:`templates` and in
    ``save_template`` / ``get_template`` / ``delete_template``.
    """

    def __init__(self, template_dir: Path):
        self.template_dir = template_dir
        # Maps template name (JSON file stem) -> Template.
        self.templates: Dict[str, "Template"] = {}

    def _path_for(self, name: str) -> Path:
        """Return the JSON file path for template ``name``.

        Raises:
            ValueError: if ``name`` is empty or contains path components. Such
                names could otherwise read or write files outside ``template_dir``.
        """
        if not name or name in (".", "..") or Path(name).name != name:
            raise ValueError(f"Invalid template name: {name!r}")
        return self.template_dir / f"{name}.json"

    def load_templates(self) -> None:
        """Load every ``*.json`` file in ``template_dir`` into :attr:`templates`.

        Files that cannot be read, are not valid JSON, or do not match the
        ``Template`` fields are logged and skipped.
        """
        # Sorted so the resulting template order (e.g. in a dropdown) is stable.
        for file_path in sorted(self.template_dir.glob("*.json")):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    template_data = json.load(f)
                template = Template(**template_data)
            except (OSError, ValueError, TypeError) as e:
                logger.error(f"Error loading template from {file_path}: {e}")
                continue
            # Key by file stem (not by description) so save/get/delete, which
            # all address templates by stem, see the same keys after a reload.
            self.templates[file_path.stem] = template
            logger.info(f"Loaded template: {file_path.stem}")

    def save_template(self, name: str, template: "Template") -> bool:
        """Write ``template`` to ``<name>.json`` and register it under ``name``.

        Returns:
            True on success; False (after logging) if the name is invalid or
            the file cannot be written.
        """
        # Imported locally: only `dataclass` and `field` are imported at module
        # level, so the bare `dataclasses.asdict` used before raised NameError.
        from dataclasses import asdict

        try:
            file_path = self._path_for(name)
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(asdict(template), f, indent=2)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Error saving template {name!r}: {e}")
            return False
        self.templates[name] = template
        return True

    def get_template(self, name: str) -> str:
        """Return the code of template ``name``, or "" if no such template is loaded."""
        template = self.templates.get(name)
        return template.code if template else ""

    def delete_template(self, name: str) -> bool:
        """Delete ``<name>.json`` and forget the template.

        Returns:
            True on success; False (after logging) if the name is invalid or
            the file could not be removed. In that case the in-memory entry is kept.
        """
        try:
            self._path_for(name).unlink()
        except (OSError, ValueError) as e:
            logger.error(f"Error deleting template {name}: {e}")
            return False
        self.templates.pop(name, None)
        return True
class RAGSystem:
    """Retrieval-augmented code generator backed by a snippet database.

    Components:
      * a causal language model (``config.model_name``) behind a Hugging Face
        ``text-generation`` pipeline,
      * a SentenceTransformer (``config.embedding_model``) that embeds snippets,
      * a FAISS L2 index over the snippets persisted in ``DATABASE_PATH`` as
        ``{"codes": [...], "embeddings": [[...], ...]}``.

    Loading errors are logged, not raised. Any component that failed to load
    stays ``None``. :meth:`generate_code` then raises ``RuntimeError`` and
    :meth:`add_to_database` returns ``False``.
    """

    # Maximum number of stored snippets included as examples in a prompt.
    _RETRIEVAL_K = 3

    def __init__(self, config: Config):
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Initialised up front so the attributes exist even if loading fails below.
        self.tokenizer = None
        self.model = None
        self.pipe = None
        self.embedding_model = None
        self.code_embeddings = None
        self.index = None
        self.database = {'codes': [], 'embeddings': []}
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.model_name,
                cache_dir=MODEL_CACHE_DIR
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model_name,
                cache_dir=MODEL_CACHE_DIR
            ).to(self.device)
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.device
            )
            self.embedding_model = SentenceTransformer(config.embedding_model)
            self.load_database()
            logger.info("RAG system initialized successfully.")
        except Exception as e:
            logger.exception(f"Error initializing RAG system: {e}")

    # ------------------------------------------------------------------ storage

    def load_database(self) -> None:
        """Load snippets and embeddings from ``DATABASE_PATH`` and (re)build the index.

        Falls back to an empty database if the file is missing or unreadable.
        """
        if not DATABASE_PATH.exists():
            logger.info("Creating new database.")
            self._initialize_empty_database()
            return
        try:
            with open(DATABASE_PATH, 'r', encoding='utf-8') as f:
                self.database = json.load(f)
            # FAISS operates on float32 vectors; np.array of Python floats would be float64.
            self.code_embeddings = np.asarray(self.database['embeddings'], dtype=np.float32)
            logger.info(f"Loaded {len(self.database['codes'])} code snippets from database.")
            self._build_index()
        except Exception as e:
            logger.error(f"Error loading database: {e}")
            self._initialize_empty_database()

    def _initialize_empty_database(self) -> None:
        """Reset to an empty database and drop any existing index."""
        self.database = {'codes': [], 'embeddings': []}
        self.code_embeddings = np.array([], dtype=np.float32)
        self._build_index()

    def _build_index(self) -> None:
        """Rebuild the FAISS index from ``code_embeddings``.

        Leaves ``index`` as None when there are no embeddings or no embedding
        model. Resetting first means an emptied database never keeps a stale index.
        """
        self.index = None
        if (
            self.embedding_model is None
            or self.code_embeddings is None
            or len(self.code_embeddings) == 0
        ):
            return
        vectors = np.ascontiguousarray(self.code_embeddings, dtype=np.float32)
        self.index = faiss.IndexFlatL2(vectors.shape[1])
        self.index.add(vectors)
        logger.info(f"Built FAISS index with {len(vectors)} vectors")

    def add_to_database(self, code: str) -> bool:
        """Embed ``code``, append it to the database, re-index and persist to disk.

        Returns:
            True on success; False if there is no embedding model, ``code`` is
            blank, or embedding/persisting failed (the error is logged).
        """
        if self.embedding_model is None or not code or not code.strip():
            logger.warning("Not adding code to database: empty code or embedding model unavailable.")
            return False
        try:
            vector = np.asarray(self.embedding_model.encode([code]), dtype=np.float32)
            # Stack first: if the dimensions do not match, nothing below has been mutated yet.
            if self.code_embeddings is None or len(self.code_embeddings) == 0:
                self.code_embeddings = vector
            else:
                self.code_embeddings = np.vstack([self.code_embeddings, vector])
            self.database.setdefault('codes', []).append(code)
            self.database.setdefault('embeddings', []).append(vector[0].tolist())
            self._build_index()
            # Write to a temporary file and swap it in, so a crash mid-write
            # cannot corrupt the existing database file.
            tmp_path = DATABASE_PATH.with_name(DATABASE_PATH.name + ".tmp")
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.database, f)
            os.replace(tmp_path, DATABASE_PATH)
            return True
        except Exception as e:
            logger.error(f"Error adding code to database: {e}")
            return False

    # --------------------------------------------------------------- generation

    def _retrieve_similar(self, query: str, k: int) -> List[str]:
        """Return up to ``k`` stored snippets nearest (L2) to ``query``.

        Returns [] when there is no index, no embedding model, or the query is blank.
        """
        if self.index is None or self.embedding_model is None or not query.strip():
            return []
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []
        query_vector = np.asarray(self.embedding_model.encode([query]), dtype=np.float32)
        _, neighbor_ids = self.index.search(query_vector, k)
        codes = self.database.get('codes', [])
        # FAISS pads missing neighbours with -1; the bounds check also guards
        # against a codes/embeddings length mismatch in a hand-edited database.
        return [codes[i] for i in neighbor_ids[0] if 0 <= i < len(codes)]

    @staticmethod
    def _build_prompt(description: str, template_code: str, examples: List[str]) -> str:
        """Assemble the prompt. The task goes last so the model continues with code."""
        sections = [f"# Example:\n{example.strip()}\n" for example in examples]
        if template_code and template_code.strip():
            sections.append(f"# Template:\n{template_code.strip()}\n")
        sections.append(f"# Task: {description.strip()}\n# Python code:\n")
        return "\n".join(sections)

    def _fit_prompt(self, prompt: str, max_new_tokens: int) -> str:
        """Trim tokens from the start of ``prompt`` so prompt + generation fits the context.

        The end of the prompt (the task) is preserved. The prompt is returned
        unchanged if the tokenizer reports no meaningful length limit.
        """
        limit = getattr(self.tokenizer, "model_max_length", None)
        # Tokenizers without a known limit report a huge sentinel value.
        if not isinstance(limit, int) or limit > 1_000_000:
            return prompt
        budget = limit - max_new_tokens
        if budget <= 0:
            return prompt
        token_ids = self.tokenizer(prompt)["input_ids"]
        if len(token_ids) <= budget:
            return prompt
        return self.tokenizer.decode(token_ids[-budget:])

    def generate_code(self, description: str, template_code: str = "", max_new_tokens: int = 256) -> str:
        """Generate Python code for ``description``.

        The prompt contains up to ``_RETRIEVAL_K`` similar stored snippets, the
        optional ``template_code`` and the description.

        Returns:
            Only the model's continuation (not the prompt), truncated to
            ``config.max_code_length`` characters and stripped.

        Raises:
            RuntimeError: if the generation pipeline is not available.
        """
        if self.pipe is None:
            raise RuntimeError("Code generation model is not loaded; check the startup logs.")
        examples = self._retrieve_similar(description, self._RETRIEVAL_K)
        prompt = self._fit_prompt(
            self._build_prompt(description, template_code, examples), max_new_tokens
        )
        outputs = self.pipe(
            prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            return_full_text=False,
            # Some causal LMs (e.g. GPT-2) have no pad token; reuse EOS.
            pad_token_id=self.tokenizer.eos_token_id,
        )
        generated = outputs[0]["generated_text"] if outputs else ""
        return generated[: self.config.max_code_length].strip()
class GradioInterface:
    """Gradio UI for generating code and for saving and reusing code templates.

    Combines a TemplateManager (templates under ``TEMPLATE_DIR``) with a
    RAGSystem that performs the actual code generation.
    """

    # Page styles, injected via gr.Blocks(css=...). Raw <style> tags inside a
    # Markdown component may be sanitized away by Gradio.
    _CSS = """
.header {
    text-align: center;
    background-color: #f0f0f0;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.container {
    max-width: 1200px;
    margin: 0 auto;
}
"""
    _HEADER_HTML = """
<div class="header">
    <h1>Code Generation Interface</h1>
    <p>Generate and manage code templates easily</p>
</div>
"""

    # Line-anchored so "def"/"class" appearing mid-line (strings, comments,
    # identifiers such as "undef") do not match. The class pattern accepts
    # base-class lists and generics, e.g. ``class Foo(Base):``.
    _FUNCTION_PATTERN = re.compile(r"^\s*(?:async\s+)?def\s+(\w+)\s*\(", re.MULTILINE)
    _CLASS_PATTERN = re.compile(r"^\s*class\s+(\w+)\s*[:(\[]", re.MULTILINE)

    def __init__(self, config: "Config"):
        self.config = config
        self.template_manager = TemplateManager(TEMPLATE_DIR)
        self.template_manager.load_templates()
        self.rag_system = RAGSystem(config)

    def format_code(self, code: str) -> str:
        """Return ``code`` formatted by black, or unchanged if black cannot format it."""
        try:
            return black.format_str(code, mode=black.FileMode())
        except Exception as e:
            logger.warning(f"Code formatting failed: {e}")
            return code

    def _extract_components(self, code: str) -> List[str]:
        """Return the function and class names defined in ``code``.

        Function names come first, then class names. Duplicates are removed
        while keeping first-occurrence order, so the result is deterministic.
        """
        if not code:
            return []
        names = self._FUNCTION_PATTERN.findall(code) + self._CLASS_PATTERN.findall(code)
        return list(dict.fromkeys(names))

    def _build_interface(self) -> "gr.Blocks":
        """Create the Blocks layout and wire up its event handlers (without launching)."""
        with gr.Blocks(theme=gr.themes.Base(), css=self._CSS) as interface:
            gr.HTML(self._HEADER_HTML)
            with gr.Row():
                with gr.Column(scale=2):
                    description_input = gr.Textbox(
                        label="Description",
                        placeholder="Enter a description for the code you want to generate",
                        lines=3,
                    )
                    template_choice = gr.Dropdown(
                        label="Select Template",
                        choices=list(self.template_manager.templates.keys()),
                        value=None,
                    )
                    # Separate name field: the dropdown only lists existing
                    # templates, so it cannot name a new one.
                    template_name_input = gr.Textbox(
                        label="Template Name",
                        placeholder="Name to save the code under (defaults to the selected template)",
                    )
                    with gr.Row():
                        generate_button = gr.Button("Generate Code", variant="primary")
                        save_button = gr.Button("Save as Template", variant="secondary")
                        clear_button = gr.Button("Clear", variant="stop")
            with gr.Row():
                code_output = gr.Code(
                    label="Generated Code",
                    language="python",
                    interactive=True,
                )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
            )

            def generate_code_wrapper(description: str, template_choice: str) -> Tuple[str, str]:
                """Generate code for ``description``, optionally seeded with the selected template."""
                if not (description or "").strip():
                    return "", "Please provide a description"
                try:
                    template_code = self.template_manager.get_template(template_choice) if template_choice else ""
                    generated_code = self.rag_system.generate_code(description, template_code)
                    formatted_code = self.format_code(generated_code)
                except Exception as e:
                    logger.exception(f"Error in code generation: {str(e)}")
                    return "", f"Error: {str(e)}"
                if not formatted_code:
                    return "", "Failed to generate code. Please try again."
                return formatted_code, "Code generated successfully."

            def save_template_wrapper(
                code: str, name: str, description: str, selected_template: Optional[str]
            ) -> Tuple[str, str, Any]:
                """Save ``code`` as a template and refresh the dropdown's choices.

                ``name`` falls back to the currently selected template, which
                allows overwriting it.
                """
                name = (name or "").strip() or (selected_template or "")
                if not name or not code:
                    return code, "Template name and code are required.", gr.update()
                try:
                    template = Template(
                        code=code,
                        description=name,
                        components=self._extract_components(code),
                        # The description box doubles as a comma-separated tag list.
                        tags=[t.strip() for t in (description or "").split(',') if t.strip()],
                    )
                    saved = self.template_manager.save_template(name, template)
                except Exception as e:
                    logger.exception(f"Error saving template {name!r}")
                    return code, f"Error saving template: {e}", gr.update()
                if not saved:
                    return code, "Failed to save template.", gr.update()
                # Indexing for retrieval is best-effort: the template file is
                # already written, so an indexing failure must not be reported
                # to the user as a failed save.
                try:
                    self.rag_system.add_to_database(code)
                except Exception as e:
                    logger.warning(f"Template '{name}' saved but not added to the RAG database: {e}")
                # Assigning `.choices` on the component after launch does not
                # reach the browser; returning an update for the dropdown does.
                choices = list(self.template_manager.templates.keys())
                return code, f"Template '{name}' saved successfully.", gr.update(choices=choices, value=name)

            def clear_outputs() -> Tuple[str, str, str, str]:
                """Reset the text inputs and outputs."""
                return "", "", "", ""

            # Event handlers
            generate_button.click(
                fn=generate_code_wrapper,
                inputs=[description_input, template_choice],
                outputs=[code_output, status_output],
                api_name="generate_code",
                show_progress=True,
            )
            save_button.click(
                fn=save_template_wrapper,
                inputs=[code_output, template_name_input, description_input, template_choice],
                outputs=[code_output, status_output, template_choice],
            )
            clear_button.click(
                fn=clear_outputs,
                inputs=[],
                outputs=[description_input, template_name_input, code_output, status_output],
            )
        return interface

    def launch(self) -> None:
        """Build the UI and serve it with the port/share/debug settings from the config."""
        interface = self._build_interface()
        interface.launch(
            server_port=self.config.port,
            share=self.config.share,
            debug=self.config.debug,
        )
def main():
    """Application entry point: load ``CONFIG_PATH``, build the UI and serve it.

    Any exception is logged with its full traceback and re-raised, so the
    process still exits with an error.
    """
    logger.info("=== Application Startup ===")
    try:
        config = Config.from_file(CONFIG_PATH)
        interface = GradioInterface(config)
        interface.launch()
    except Exception as e:
        # logger.exception records the traceback in every configured handler
        # (including the log file), not just the message.
        logger.exception(f"Application error: {e}")
        raise
    finally:
        logger.info("=== Application Shutdown ===")


if __name__ == "__main__":
    main()