fullstuckdev committed on
Commit 93374aa · 1 Parent(s): f6b6cd4

fixing training

Files changed (1): app.py (+186 -1)
app.py CHANGED
@@ -6,6 +6,9 @@ import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
from typing import List, Optional
+ from datasets import load_dataset
+ from transformers import TrainingArguments, Trainer, TrainerCallback, DataCollatorForLanguageModeling
+ import json

# Setup logging
logging.basicConfig(level=logging.INFO)

@@ -31,6 +34,26 @@ class HealthResponse(BaseModel):
    gpu_available: bool
    device: str

+ class TrainRequest(BaseModel):
+     dataset_path: str
+     num_epochs: Optional[int] = 3
+     batch_size: Optional[int] = 4
+     learning_rate: Optional[float] = 2e-5
+
+ class TrainResponse(BaseModel):
+     status: str
+     message: str
+
+ # Add training status tracking
+ class TrainingStatus:
+     def __init__(self):
+         self.is_training = False
+         self.current_epoch = 0
+         self.current_loss = None
+         self.status = "idle"
+
+ training_status = TrainingStatus()
+
# Initialize FastAPI app
app = FastAPI(
    title="Medical LLaMA API",
 
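For reference, the /train endpoint added in the next hunk takes a TrainRequest as its JSON body. A minimal sketch of a payload, noting that the dataset path here is purely illustrative and that every field except dataset_path may be omitted to pick up the defaults above:

train_payload = {
    "dataset_path": "data/medical_corpus.json",  # hypothetical file on the server
    "num_epochs": 1,                             # optional; defaults to 3
}
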
@@ -133,4 +156,166 @@ async def startup_event():
        tokenizer, model = init_model()
        logger.info("Model loaded successfully")
    except Exception as e:
-         logger.error(f"Failed to load model: {str(e)}")
+         logger.error(f"Failed to load model: {str(e)}")
+
+ @app.post("/train", response_model=TrainResponse, tags=["Training"])
+ async def train_model(request: TrainRequest, background_tasks: BackgroundTasks):
+     """
+     Start model training with the specified dataset.
+
+     Parameters:
+     - dataset_path: Path to the JSON dataset file
+     - num_epochs: Number of training epochs
+     - batch_size: Training batch size
+     - learning_rate: Learning rate for training
+     """
+     if training_status.is_training:
+         raise HTTPException(status_code=400, detail="Training is already in progress")
+
+     try:
+         # Verify the dataset exists before launching the background task
+         if not os.path.exists(request.dataset_path):
+             raise HTTPException(status_code=404, detail="Dataset file not found")
+
+         # Start training in the background
+         background_tasks.add_task(
+             run_training,
+             request.dataset_path,
+             request.num_epochs,
+             request.batch_size,
+             request.learning_rate
+         )
+
+         return TrainResponse(
+             status="started",
+             message="Training started in background"
+         )
+
+     except HTTPException:
+         # Let the 404 above propagate instead of being rewrapped as a 500
+         raise
+     except Exception as e:
+         logger.error(f"Training setup error: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get("/train/status", tags=["Training"])
+ async def get_training_status():
+     """
+     Get the current training status.
+     """
+     return {
+         "is_training": training_status.is_training,
+         "current_epoch": training_status.current_epoch,
+         "current_loss": training_status.current_loss,
+         "status": training_status.status
+     }
+
+ # Training function. Deliberately sync (not async): FastAPI runs sync
+ # background tasks in a worker thread, so the blocking trainer.train()
+ # call doesn't freeze the event loop and /train/status stays responsive.
+ def run_training(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
+     global model, tokenizer, training_status
+
+     try:
+         training_status.is_training = True
+         training_status.status = "loading_dataset"
+
+         # Load dataset
+         dataset = load_dataset("json", data_files=dataset_path)
+
+         training_status.status = "preprocessing"
+
+         # LLaMA tokenizers ship without a pad token; reuse EOS so that
+         # padding="max_length" below doesn't raise
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Preprocess function
+         def preprocess_function(examples):
+             return tokenizer(
+                 examples["text"],
+                 truncation=True,
+                 padding="max_length",
+                 max_length=512
+             )
+
+         # Tokenize dataset
+         tokenized_dataset = dataset.map(
+             preprocess_function,
+             batched=True,
+             remove_columns=dataset["train"].column_names
+         )
+
+         training_status.status = "training"
+
+         # Training arguments
+         training_args = TrainingArguments(
+             output_dir=f"{model_output_path}/checkpoints",
+             per_device_train_batch_size=batch_size,
+             gradient_accumulation_steps=4,
+             num_train_epochs=num_epochs,
+             learning_rate=learning_rate,
+             fp16=torch.cuda.is_available(),  # mixed precision only on GPU
+             save_steps=500,
+             logging_steps=100,
+         )
+
+         # Initialize trainer
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=tokenized_dataset["train"],
+             data_collator=DataCollatorForLanguageModeling(
+                 tokenizer=tokenizer,
+                 mlm=False  # causal LM objective, not masked LM
+             ),
+         )
+
+         # Training callback to update status (must subclass TrainerCallback,
+         # not the trainer's callback_handler instance)
+         class TrainingCallback(TrainerCallback):
+             def on_epoch_begin(self, args, state, control, **kwargs):
+                 training_status.current_epoch = state.epoch
+
+             def on_log(self, args, state, control, logs=None, **kwargs):
+                 if logs:
+                     training_status.current_loss = logs.get("loss", None)
+
+         trainer.add_callback(TrainingCallback())
+
+         # Start training
+         trainer.train()
+
+         # Save the model
+         training_status.status = "saving"
+         model.save_pretrained(model_output_path)
+         tokenizer.save_pretrained(model_output_path)
+
+         training_status.status = "completed"
+         logger.info("Training completed successfully")
+
+     except Exception as e:
+         training_status.status = f"failed: {str(e)}"
+         logger.error(f"Training error: {str(e)}")
+         raise
+
+     finally:
+         training_status.is_training = False
+
+ # Updated model initialization
+ def init_model():
+     try:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Loading model on device: {device}")
+
+         # Load the fine-tuned model if a saved copy exists (check for its
+         # config.json; the bare output directory may exist before any save)
+         if os.path.exists(os.path.join(model_output_path, "config.json")):
+             tokenizer = AutoTokenizer.from_pretrained(model_output_path)
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_output_path,
+                 torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+                 device_map="auto"
+             )
+         else:
+             # Fall back to the base model if no fine-tuned model exists
+             model_name = "nvidia/Meta-Llama-3.2-3B-Instruct-ONNX-INT4"
+             tokenizer = AutoTokenizer.from_pretrained(model_name)
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+                 device_map="auto"
+             )
+
+         return tokenizer, model
+     except Exception as e:
+         logger.error(f"Model initialization error: {str(e)}")
+         raise
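
For a quick end-to-end smoke test of the new endpoints, a client sketch along these lines should work. It assumes the API is reachable at localhost:8000, that the client runs from the server's working directory (dataset_path is checked server-side with os.path.exists), and that requests is installed; the file name and example texts are made up. Each record needs a "text" field, since that is the column preprocess_function tokenizes.

import json
import time

import requests  # third-party: pip install requests

BASE = "http://localhost:8000"  # assumed host/port

# Write a tiny JSON-lines dataset; load_dataset("json", ...) accepts this layout
with open("sample_medical.json", "w") as f:
    for text in [
        "Hypertension is persistently elevated arterial blood pressure.",
        "Metformin is a first-line oral agent for type 2 diabetes.",
    ]:
        f.write(json.dumps({"text": text}) + "\n")

# Kick off background training
resp = requests.post(f"{BASE}/train", json={"dataset_path": "sample_medical.json", "num_epochs": 1})
print(resp.json())  # expected: {"status": "started", "message": "Training started in background"}

# Poll until the background task reports a terminal state
while True:
    status = requests.get(f"{BASE}/train/status").json()
    print(status)
    if not status["is_training"] and status["status"] != "idle":
        break
    time.sleep(10)

Because run_training clears training_status.is_training in a finally block, the loop above terminates whether training ends in "completed" or a "failed: ..." status.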