import os
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import List
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
import uvicorn
app = FastAPI(title="Medical LLaMA API")
model = None
tokenizer = None
model_output_path = "./model/medical_llama_3b"
class TrainRequest(BaseModel):
    dataset_path: str
    num_epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 2e-5
class Query(BaseModel):
    text: str
    max_length: int = 512
    temperature: float = 0.7
    num_return_sequences: int = 1
class Response(BaseModel):
    generated_text: List[str]
def train_model(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
    global model, tokenizer

    os.makedirs(model_output_path, exist_ok=True)

    # The ONNX/INT4 export cannot be loaded for fine-tuning with
    # AutoModelForCausalLM; train from the standard PyTorch checkpoint instead.
    model_name = "meta-llama/Llama-3.2-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # LLaMA tokenizers ship without a pad token, which padding="max_length" requires.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Load weights in fp32 and let Trainer's fp16=True handle mixed precision;
    # loading the model itself in float16 makes the optimizer fail to unscale gradients.
    model = AutoModelForCausalLM.from_pretrained(model_name)

    dataset = load_dataset("json", data_files=dataset_path)
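    # Assumed input format (inferred from preprocess_function below): a JSON or
    # JSONL file where each record carries a "text" field, e.g.
    #   {"text": "Question: ...\nAnswer: ..."}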
    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
    training_args = TrainingArguments(
        output_dir=f"{model_output_path}/checkpoints",
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        fp16=True,
        save_steps=500,
        logging_steps=100,
    )
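    # Effective batch size is per_device_train_batch_size * gradient_accumulation_steps
    # (i.e. batch_size * 4 here), trading steps-per-update for lower GPU memory use.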
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    # Start training
    trainer.train()

    # Save the final model and tokenizer
    model.save_pretrained(model_output_path)
    tokenizer.save_pretrained(model_output_path)
    print(f"Model and tokenizer saved to: {model_output_path}")
@app.post("/train")
async def train(request: TrainRequest, background_tasks: BackgroundTasks):
    background_tasks.add_task(
        train_model,
        request.dataset_path,
        request.num_epochs,
        request.batch_size,
        request.learning_rate,
    )
    return {"message": "Training started in the background"}
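# Illustrative request (the dataset path is a placeholder, not shipped with this repo):
#   curl -X POST http://localhost:8000/train \
#     -H "Content-Type: application/json" \
#     -d '{"dataset_path": "./data/medical_qa.jsonl", "num_epochs": 3}'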
@app.post("/generate", response_model=Response)
async def generate_text(query: Query):
    global model, tokenizer

    # Lazily load the fine-tuned weights on the first request.
    if model is None or tokenizer is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_output_path)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_output_path,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error loading model: {str(e)}")
    try:
        inputs = tokenizer(
            query.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=query.max_length,
        ).to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=query.max_length,
                num_return_sequences=query.num_return_sequences,
                # Sampling must be enabled, otherwise temperature is ignored and
                # multiple return sequences are not allowed under greedy decoding.
                do_sample=True,
                temperature=query.temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        generated_texts = [
            tokenizer.decode(g, skip_special_tokens=True)
            for g in generated_ids
        ]

        return Response(generated_text=generated_texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
    return {"status": "healthy"}
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
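# Illustrative generation request, once training has produced ./model/medical_llama_3b:
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "What are common symptoms of hypertension?", "max_length": 256}'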