import threading import time import gradio as gr import logging import json import re import torch import tempfile import os from pathlib import Path from typing import Dict, List, Tuple, Optional, Any, Union from dataclasses import dataclass, field from enum import Enum from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline from sentence_transformers import SentenceTransformer import faiss import numpy as np from PIL import Image import black # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.StreamHandler(), logging.FileHandler('gradio_builder.log') ] ) logger = logging.getLogger(__name__) # Configuration @dataclass class Config: port: int = 7860 debug: bool = False share: bool = False model_name: str = "gpt2" embedding_model: str = "all-mpnet-base-v2" theme: str = "default" max_code_length: int = 5000 @classmethod def from_file(cls, path: str) -> "Config": try: with open(path) as f: return cls(**json.load(f)) except Exception as e: logger.warning(f"Failed to load config from {path}: {e}. Using defaults.") return cls() # Constants CONFIG_PATH = Path("config.json") MODEL_CACHE_DIR = Path("model_cache") TEMPLATE_DIR = Path("templates") TEMP_DIR = Path("temp") DATABASE_PATH = Path("code_database.json") # Ensure directories exist for directory in [MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR]: directory.mkdir(exist_ok=True, parents=True) @dataclass class Template: code: str description: str components: List[str] = field(default_factory=list) created_at: str = field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S")) tags: List[str] = field(default_factory=list) class TemplateManager: def __init__(self, template_dir: Path): self.template_dir = template_dir self.templates: Dict[str, Template] = {} def load_templates(self) -> None: for file_path in self.template_dir.glob("*.json"): try: with open(file_path, 'r') as f: template_data = json.load(f) template = Template(**template_data) self.templates[template_data['description']] = template logger.info(f"Loaded template: {file_path.stem}") except Exception as e: logger.error(f"Error loading template from {file_path}: {e}") def save_template(self, name: str, template: Template) -> bool: file_path = self.template_dir / f"{name}.json" try: with open(file_path, 'w') as f: json.dump(dataclasses.asdict(template), f, indent=2) self.templates[name] = template return True except Exception as e: logger.error(f"Error saving template to {file_path}: {e}") return False def get_template(self, name: str) -> Optional[str]: template = self.templates.get(name) return template.code if template else "" def delete_template(self, name: str) -> bool: file_path = self.template_dir / f"{name}.json" try: file_path.unlink() self.templates.pop(name, None) return True except Exception as e: logger.error(f"Error deleting template {name}: {e}") return False class RAGSystem: def __init__(self, config: Config): self.config = config self.device = "cuda" if torch.cuda.is_available() else "cpu" self.embedding_model = None self.code_embeddings = None self.index = None self.database = {'codes': [], 'embeddings': []} self.pipe = None try: self.tokenizer = AutoTokenizer.from_pretrained( config.model_name, cache_dir=MODEL_CACHE_DIR ) self.model = AutoModelForCausalLM.from_pretrained( config.model_name, cache_dir=MODEL_CACHE_DIR ).to(self.device) self.pipe = pipeline( "text-generation", model=self.model, tokenizer=self.tokenizer, device=self.device ) self.embedding_model = SentenceTransformer(config.embedding_model) self.load_database() logger.info("RAG system initialized successfully.") except Exception as e: logger.error(f"Error initializing RAG system: {e}") def load_database(self) -> None: if DATABASE_PATH.exists(): try: with open(DATABASE_PATH, 'r', encoding='utf-8') as f: self.database = json.load(f) self.code_embeddings = np.array(self.database['embeddings']) logger.info(f"Loaded {len(self.database['codes'])} code snippets from database.") self._build_index() except Exception as e: logger.error(f"Error loading database: {e}") self._initialize_empty_database() else: logger.info("Creating new database.") self._initialize_empty_database() def _initialize_empty_database(self) -> None: self.database = {'codes': [], 'embeddings': []} self.code_embeddings = np.array([]) self._build_index() def _build_index(self) -> None: if len(self.code_embeddings) > 0 and self.embedding_model: dim = self.code_embeddings.shape[1] self.index = faiss.IndexFlatL2(dim) self.index.add(self.code_embeddings) logger.info(f"Built FAISS index with {len(self.code_embeddings)} vectors") class GradioInterface: def __init__(self, config: Config): self.config = config self.template_manager = TemplateManager(TEMPLATE_DIR) self.template_manager.load_templates() self.rag_system = RAGSystem(config) def format_code(self, code: str) -> str: try: return black.format_str(code, mode=black.FileMode()) except Exception as e: logger.warning(f"Code formatting failed: {e}") return code def _extract_components(self, code: str) -> List[str]: components = [] try: function_matches = re.findall(r'def (\w+)\(', code) class_matches = re.findall(r'class (\w+):', code) components.extend(function_matches) components.extend(class_matches) except Exception as e: logger.error(f"Error extracting components: {e}") return list(set(components)) def launch(self) -> None: with gr.Blocks(theme=gr.themes.Base()) as interface: # Custom CSS gr.Markdown( """
Generate and manage code templates easily