import os
from typing import List

import torch
import uvicorn
from datasets import load_dataset
from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

app = FastAPI(title="Medical LLaMA API")

model = None
tokenizer = None
model_output_path = "./model/medical_llama_3b"


class TrainRequest(BaseModel):
    dataset_path: str
    num_epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 2e-5


class Query(BaseModel):
    text: str
    max_length: int = 512
    temperature: float = 0.7
    num_return_sequences: int = 1


class Response(BaseModel):
    generated_text: List[str]


def train_model(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
    global model, tokenizer

    os.makedirs(model_output_path, exist_ok=True)

    # Use the standard PyTorch checkpoint. ONNX/INT4 exports such as
    # "nvidia/Meta-Llama-3.2-3B-Instruct-ONNX-INT4" cannot be loaded with
    # AutoModelForCausalLM or fine-tuned with Trainer. Note that the
    # meta-llama repository is gated and requires an authenticated HF token.
    model_name = "meta-llama/Llama-3.2-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Llama tokenizers ship without a pad token; padding below would fail otherwise.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

    dataset = load_dataset("json", data_files=dataset_path)

    def preprocess_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
        )

    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    training_args = TrainingArguments(
        output_dir=f"{model_output_path}/checkpoints",
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        fp16=True,  # requires a CUDA GPU; set to False for CPU-only runs
        save_steps=500,
        logging_steps=100,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        # mlm=False makes the collator produce causal-LM labels from the inputs
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    # Start training
    trainer.train()

    # Save the final model and tokenizer
    model.save_pretrained(model_output_path)
    tokenizer.save_pretrained(model_output_path)
    print(f"Model and tokenizer saved to: {model_output_path}")


@app.post("/train")
async def train(request: TrainRequest, background_tasks: BackgroundTasks):
    background_tasks.add_task(
        train_model,
        request.dataset_path,
        request.num_epochs,
        request.batch_size,
        request.learning_rate,
    )
    return {"message": "Training started in the background"}


@app.post("/generate", response_model=Response)
async def generate_text(query: Query):
    global model, tokenizer

    # Lazily load the fine-tuned model if training hasn't run in this process.
    if model is None or tokenizer is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_output_path)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_output_path,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error loading model: {str(e)}")

    try:
        inputs = tokenizer(
            query.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=query.max_length,
        ).to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=query.max_length,
                do_sample=True,  # temperature only takes effect when sampling
                num_return_sequences=query.num_return_sequences,
                temperature=query.temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Note: the decoded output includes the prompt tokens.
        generated_texts = [
            tokenizer.decode(g, skip_special_tokens=True) for g in generated_ids
        ]
        return Response(generated_text=generated_texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
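
# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the API itself. It assumes the
# server above is running locally on port 8000, that `requests` is installed
# as an extra dependency, and that "training_data.jsonl" is a hypothetical
# path chosen for illustration.
#
# The /train endpoint expects a JSON-lines file with one "text" field per
# record, matching the `examples["text"]` access in preprocess_function:
#
#   {"text": "Patient presents with persistent cough and low-grade fever ..."}
#   {"text": "Recommended adult dosage for amoxicillin is ..."}
#
#   import requests
#
#   # Kick off fine-tuning; the call returns immediately because training
#   # runs as a background task.
#   requests.post(
#       "http://localhost:8000/train",
#       json={"dataset_path": "training_data.jsonl", "num_epochs": 3},
#   )
#
#   # Once training has finished and the model is saved, query it.
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"text": "What are common symptoms of influenza?", "max_length": 256},
#   )
#   print(resp.json()["generated_text"])
# ---------------------------------------------------------------------------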