acecalisto3 commited on
Commit
7e568ab
·
verified ·
1 Parent(s): d5d02d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -172
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import threading
3
  import time
4
  import gradio as gr
@@ -9,22 +8,17 @@ import torch
9
  import tempfile
10
  import subprocess
11
  import ast
 
 
12
  from pathlib import Path
13
  from typing import Dict, List, Tuple, Optional, Any, Union
14
  from dataclasses import dataclass, field
15
  from enum import Enum
16
- from transformers import (
17
- AutoTokenizer,
18
- AutoModelForCausalLM,
19
- pipeline,
20
- AutoProcessor,
21
- AutoModel
22
- )
23
  from sentence_transformers import SentenceTransformer
24
  import faiss
25
  import numpy as np
26
  from PIL import Image
27
- from transformers import BlipForConditionalGeneration
28
 
29
  # Configure logging
30
  logging.basicConfig(
@@ -42,14 +36,16 @@ DEFAULT_PORT = 7860
42
  MODEL_CACHE_DIR = Path("model_cache")
43
  TEMPLATE_DIR = Path("templates")
44
  TEMP_DIR = Path("temp")
 
 
45
 
46
  # Ensure directories exist
47
  for directory in [MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR]:
48
- directory.mkdir(exist_ok=True)
 
49
 
50
  @dataclass
51
  class Template:
52
- """Template data structure"""
53
  code: str
54
  description: str
55
  components: List[str]
@@ -57,192 +53,169 @@ class Template:
57
  version: str = "1.0"
58
 
59
  class TemplateManager:
60
- def __init__(self, template_dir: Path):
61
- self.template_dir = template_dir
62
- self.templates: Dict[str, Any] = {}
63
 
64
- def _get_builtin_templates(self) -> Dict[str, Any]:
65
- # Implement this method to return built-in templates
66
- # For now, we'll return an empty dict
67
- return {}
68
 
69
- def load_templates(self):
70
- """Load all templates from directory"""
71
  try:
72
- # Load built-in templates
73
- self.templates.update(self._get_builtin_templates())
74
-
75
- # Load custom templates
76
- for template_file in self.template_dir.glob("*.json"):
77
- try:
78
- with open(template_file, 'r', encoding='utf-8') as f:
79
- template_data = json.load(f)
80
-
81
- # Process template_data
82
- template_name = template_file.stem
83
- self.templates[template_name] = template_data
84
- logger.info(f"Loaded template: {template_name}")
85
-
86
- except json.JSONDecodeError as e:
87
- logger.error(f"Error parsing template {template_file}: {e}")
88
- except Exception as e:
89
- logger.error(f"Error loading template {template_file}: {e}")
90
-
91
  except Exception as e:
92
- logger.error(f"Error loading templates: {e}")
93
-
94
- def get_template(self, name: str) -> Dict[str, Any]:
95
- """Retrieve a template by name"""
96
- return self.templates.get(name, {})
97
 
98
- def list_templates(self) -> List[Dict[str, str]]:
99
- """List all available templates"""
100
- return [{"name": name, "description": template.get("description", "")}
101
- for name, template in self.templates.items()]
102
 
103
- def save_template(self, name: str, template: Dict[str, Any]) -> bool:
104
- """Save a new template"""
 
105
  try:
106
- file_path = self.template_dir / f"{name}.json"
107
- with open(file_path, 'w', encoding='utf-8') as f:
108
- json.dump(template, f, indent=2)
109
- self.templates[name] = template
110
- logger.info(f"Saved template: {name}")
111
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  except Exception as e:
113
- logger.error(f"Error saving template {name}: {e}")
114
- return False
115
-
116
- class InterfaceAnalyzer:
117
- @staticmethod
118
- def extract_components(code: str) -> List[str]:
119
- """Extract components from the interface code"""
120
- # This is a placeholder implementation. In a real-world scenario,
121
- # you'd want to parse the code and extract the actual components.
122
- return ["Textbox", "Button"] # Example components
123
-
124
- @staticmethod
125
- def analyze_interface_structure(code: str) -> Dict[str, Any]:
126
- """Analyze the structure of the interface code"""
127
- # This is a placeholder implementation. In a real-world scenario,
128
- # you'd want to parse the code and extract the actual structure.
129
- return {
130
- "components": {"Textbox": 1, "Button": 1},
131
- "functions": {"submit": "def submit(text): ..."},
132
- "dependencies": ["gradio"]
133
- }
134
-
135
- class CodeGenerator:
136
- @staticmethod
137
- def generate_requirements(dependencies: List[str]) -> str:
138
- """Generate requirements.txt content"""
139
- return "\n".join(dependencies)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  class GradioInterface:
142
  def __init__(self):
143
  self.template_manager = TemplateManager(TEMPLATE_DIR)
144
  self.template_manager.load_templates()
145
  self.current_code = ""
146
- # Initialize other necessary components (e.g., rag_system, preview_manager)
147
- # self.rag_system = ...
148
- # self.preview_manager = ...
149
-
150
- def _get_template_choices(self) -> List[str]:
151
- """Get list of available templates"""
152
- templates = self.template_manager.list_templates()
153
- return [""] + [t["name"] for t in templates]
154
-
155
- def _generate_interface(
156
- self,
157
- description: str,
158
- screenshot: Optional[Image.Image],
159
- template_name: str
160
- ) -> Tuple[str, str]:
161
- """Generate interface code"""
162
- try:
163
- if template_name:
164
- template = self.template_manager.get_template(template_name)
165
- if template:
166
- code = self.rag_system.generate_code(description, template["code"])
167
- else:
168
- raise ValueError(f"Template {template_name} not found")
169
- else:
170
- code = self.rag_system.generate_interface(screenshot, description)
171
-
172
- self.current_code = code
173
- return code, "✅ Code generated successfully"
174
-
175
- except Exception as e:
176
- error_msg = f"❌ Error generating interface: {str(e)}"
177
- logger.error(error_msg)
178
- return "", error_msg
179
 
180
- def _save_as_template(self, code: str, description: str) -> Tuple[List[str], str]:
181
- """Save current code as template"""
 
 
182
  try:
183
- # Generate template name
184
- base_name = "custom_template"
185
- counter = 1
186
- name = base_name
187
- while self.template_manager.get_template(name):
188
- name = f"{base_name}_{counter}"
189
- counter += 1
190
-
191
- # Create template
192
- template = {
193
- "code": code,
194
- "description": description,
195
- "components": InterfaceAnalyzer.extract_components(code),
196
- "metadata": {"category": "custom"}
197
- }
198
-
199
- # Save template
200
  if self.template_manager.save_template(name, template):
201
- return self._get_template_choices(), f"✅ Template saved as {name}"
 
202
  else:
203
  raise Exception("Failed to save template")
204
-
205
  except Exception as e:
206
  error_msg = f"❌ Error saving template: {str(e)}"
207
  logger.error(error_msg)
208
  return self._get_template_choices(), error_msg
209
 
210
- def _load_template(self, template_name: str) -> str:
211
- """Load selected template"""
212
- if not template_name:
213
- return ""
214
-
215
- template = self.template_manager.get_template(template_name)
216
- if template:
217
- return template["code"]
218
- return ""
219
-
220
- def _analyze_interface(self, code: str) -> Tuple[Dict, Dict, Dict, str]:
221
- """Analyze interface structure"""
222
- try:
223
- analysis = InterfaceAnalyzer.analyze_interface_structure(code)
224
-
225
- # Generate requirements.txt
226
- dependencies = analysis.get("dependencies", [])
227
- requirements = CodeGenerator.generate_requirements(dependencies)
228
-
229
- return (
230
- analysis.get("components", {}),
231
- analysis.get("functions", {}),
232
- {"dependencies": dependencies},
233
- requirements
234
- )
235
-
236
- except Exception as e:
237
- logger.error(f"Error analyzing interface: {e}")
238
- return {}, {}, {}, ""
239
-
240
- # Add other necessary methods (e.g., _clear_interface, _validate_code, _format_code, _start_preview, _stop_preview)
241
 
242
  def launch(self, **kwargs):
243
- """Launch the interface"""
244
- # Implement the launch logic here
245
- pass
 
 
246
 
247
  def main():
248
  # Configure logging
@@ -272,4 +245,4 @@ def main():
272
  logger.info("=== Application Shutdown ===")
273
 
274
  if __name__ == "__main__":
275
- main()
 
 
1
  import threading
2
  import time
3
  import gradio as gr
 
8
  import tempfile
9
  import subprocess
10
  import ast
11
+ import os
12
+ import dataclasses
13
  from pathlib import Path
14
  from typing import Dict, List, Tuple, Optional, Any, Union
15
  from dataclasses import dataclass, field
16
  from enum import Enum
17
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
 
 
 
 
 
18
  from sentence_transformers import SentenceTransformer
19
  import faiss
20
  import numpy as np
21
  from PIL import Image
 
22
 
23
  # Configure logging
24
  logging.basicConfig(
 
36
  MODEL_CACHE_DIR = Path("model_cache")
37
  TEMPLATE_DIR = Path("templates")
38
  TEMP_DIR = Path("temp")
39
+ DATABASE_PATH = Path("code_database.json") #Path for our simple database
40
+
41
 
42
  # Ensure directories exist
43
  for directory in [MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR]:
44
+ directory.mkdir(exist_ok=True, parents=True)
45
+
46
 
47
  @dataclass
48
  class Template:
 
49
  code: str
50
  description: str
51
  components: List[str]
 
53
  version: str = "1.0"
54
 
55
  class TemplateManager:
56
+ # ... (TemplateManager remains the same) ...
 
 
57
 
 
 
 
 
58
 
59
+ class RAGSystem:
60
+ def __init__(self, model_name: str = "gpt2", device: str = "cuda" if torch.cuda.is_available() else "cpu", embedding_model="all-mpnet-base-v2"):
61
  try:
62
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
63
+ self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
64
+ self.device = device
65
+ self.pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, device=self.device)
66
+ self.embedding_model = SentenceTransformer(embedding_model)
67
+ self.load_database()
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  except Exception as e:
69
+ logger.error(f"Error loading language model or embedding model: {e}. Falling back to placeholder generation.")
70
+ self.pipe = None
71
+ self.embedding_model = None
72
+ self.code_embeddings = None
 
73
 
 
 
 
 
74
 
75
+ def load_database(self):
76
+ """Loads or creates the code database"""
77
+ if DATABASE_PATH.exists():
78
  try:
79
+ with open(DATABASE_PATH, 'r', encoding='utf-8') as f:
80
+ self.database = json.load(f)
81
+ self.code_embeddings = np.array(self.database['embeddings'])
82
+ logger.info("Loaded code database from file")
83
+ except (json.JSONDecodeError, KeyError) as e:
84
+ logger.error(f"Error loading code database: {e}. Creating new database.")
85
+ self.database = {'codes': [], 'embeddings': []}
86
+ self.code_embeddings = np.array([])
87
+
88
+ else:
89
+ logger.info("Code database does not exist. Creating new database.")
90
+ self.database = {'codes': [], 'embeddings': []}
91
+ self.code_embeddings = np.array([])
92
+
93
+ if self.embedding_model and len(self.database['codes']) != len(self.database['embeddings']):
94
+ logger.warning("Mismatch between number of codes and embeddings, rebuilding embeddings")
95
+ self.rebuild_embeddings()
96
+ elif self.embedding_model is None:
97
+ logger.warning("Embeddings are not supported in this context. ")
98
+ #Index the embeddings for efficient searching
99
+ if len(self.code_embeddings) > 0 and self.embedding_model:
100
+ self.index = faiss.IndexFlatL2(self.code_embeddings.shape[1]) #L2 distance
101
+ self.index.add(self.code_embeddings)
102
+
103
+ def add_to_database(self, code: str):
104
+ """Adds a code snippet to the database"""
105
+ try:
106
+ embedding = self.embedding_model.encode(code)
107
+ self.database['codes'].append(code)
108
+ self.database['embeddings'].append(embedding.tolist())
109
+ self.code_embeddings = np.vstack((self.code_embeddings, embedding))
110
+ self.index.add(np.array([embedding])) # update FAISS index
111
+ self.save_database()
112
+ logger.info(f"Added code snippet to database. Total size:{len(self.database['codes'])}")
113
  except Exception as e:
114
+ logger.error(f"Error adding to database: {e}")
115
+
116
+
117
+ def save_database(self):
118
+ """Saves the database to a file"""
119
+ try:
120
+ with open(DATABASE_PATH, 'w', encoding='utf-8') as f:
121
+ json.dump(self.database, f, indent=2)
122
+ logger.info(f"Saved database to {DATABASE_PATH}")
123
+ except Exception as e:
124
+ logger.error(f"Error saving database: {e}")
125
+
126
+ def rebuild_embeddings(self):
127
+ """rebuilds embeddings from the codes"""
128
+ try:
129
+ embeddings = self.embedding_model.encode(self.database['codes'])
130
+ self.code_embeddings = embeddings
131
+ self.database['embeddings'] = embeddings.tolist()
132
+ self.index = faiss.IndexFlatL2(embeddings.shape[1]) #L2 distance
133
+ self.index.add(embeddings)
134
+ self.save_database()
135
+ logger.info("Rebuilt and saved embeddings to the database")
136
+ except Exception as e:
137
+ logger.error(f"Error rebuilding embeddings: {e}")
138
+
139
+
140
+ def retrieve_similar_code(self, description: str, top_k: int = 3) -> List[str]:
141
+ """Retrieves similar code snippets from the database"""
142
+ if self.embedding_model is None:
143
+ return []
144
+ try:
145
+ embedding = self.embedding_model.encode(description)
146
+ D, I = self.index.search(np.array([embedding]), top_k)
147
+ return [self.database['codes'][i] for i in I[0]]
148
+ except Exception as e:
149
+ logger.error(f"Error retrieving similar code: {e}")
150
+ return []
151
+
152
+ def generate_code(self, description: str, template_code: str) -> str:
153
+ retrieved_codes = self.retrieve_similar_code(description)
154
+ prompt = f"Description: {description}\nRetrieved Code Snippets:\n{''.join([f'```python\n{code}\n```\n' for code in retrieved_codes])}\nTemplate:\n```python\n{template_code}\n```\nGenerated Code:\n```python\n"
155
+ if self.pipe:
156
+ try:
157
+ generated_text = self.pipe(prompt, max_length=500, num_return_sequences=1)[0]['generated_text']
158
+ generated_code = generated_text.split("Generated Code:")[1].strip().split('```')[0]
159
+ return generated_code
160
+ except Exception as e:
161
+ logger.error(f"Error generating code with language model: {e}. Returning template code.")
162
+ return template_code
163
+ else:
164
+ return f"# Placeholder code generation. Description: {description}\n{template_code}"
165
+
166
+ def generate_interface(self, screenshot: Optional[Image.Image], description: str) -> str:
167
+ retrieved_codes = self.retrieve_similar_code(description)
168
+ prompt = f"Create a Gradio interface based on this description: {description}\nRetrieved Code Snippets:\n{''.join([f'```python\n{code}\n```\n' for code in retrieved_codes])}"
169
+ if screenshot:
170
+ prompt += "\nThe interface should resemble the provided screenshot."
171
+ prompt += "\n```python\n"
172
+ if self.pipe:
173
+ try:
174
+ generated_text = self.pipe(prompt, max_length=500, num_return_sequences=1)[0]['generated_text']
175
+ generated_code = generated_text.split("```")[1].strip()
176
+ return generated_code
177
+ except Exception as e:
178
+ logger.error(f"Error generating interface with language model: {e}. Returning placeholder.")
179
+ return "import gradio as gr\n\ndemo = gr.Interface(fn=lambda x:x, inputs='text', outputs='text')\ndemo.launch()"
180
+ else:
181
+ return "import gradio as gr\n\ndemo = gr.Interface(fn=lambda x:x, inputs='text', outputs='text')\ndemo.launch()"
182
+
183
+ class PreviewManager:
184
+ # ... (PreviewManager remains largely the same) ...
185
+
186
 
187
  class GradioInterface:
188
  def __init__(self):
189
  self.template_manager = TemplateManager(TEMPLATE_DIR)
190
  self.template_manager.load_templates()
191
  self.current_code = ""
192
+ self.rag_system = RAGSystem()
193
+ self.preview_manager = PreviewManager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ # ... (other GradioInterface methods remain largely the same, but you may want to improve error handling) ...
196
+
197
+ def _save_as_template(self, code: str, name: str, description: str) -> Tuple[List[str], str]:
198
+ """Save current code as template and add to database"""
199
  try:
200
+ components = self._extract_components(code)
201
+ template = Template(code=code, description=description, components=components)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  if self.template_manager.save_template(name, template):
203
+ self.rag_system.add_to_database(code) #add code to the database
204
+ return self._get_template_choices(), f"✅ Template saved as {name}"
205
  else:
206
  raise Exception("Failed to save template")
 
207
  except Exception as e:
208
  error_msg = f"❌ Error saving template: {str(e)}"
209
  logger.error(error_msg)
210
  return self._get_template_choices(), error_msg
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  def launch(self, **kwargs):
214
+ with gr.Blocks() as interface:
215
+ # ... (Interface remains largely the same) ...
216
+
217
+ interface.launch(**kwargs)
218
+
219
 
220
  def main():
221
  # Configure logging
 
245
  logger.info("=== Application Shutdown ===")
246
 
247
  if __name__ == "__main__":
248
+ main()