import os

from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import List

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
import uvicorn
app = FastAPI(title="Medical LLaMA API")

# Globals populated by training or lazily loaded on first generation request.
model = None
tokenizer = None
model_output_path = "./model/medical_llama_3b"


class TrainRequest(BaseModel):
    dataset_path: str
    num_epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 2e-5


class Query(BaseModel):
    text: str
    max_length: int = 512
    temperature: float = 0.7
    num_return_sequences: int = 1


class Response(BaseModel):
    generated_text: List[str]
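
# Example payloads for the request models above (all values illustrative,
# including the dataset path):
#   TrainRequest: {"dataset_path": "data/medical.jsonl", "num_epochs": 3,
#                  "batch_size": 4, "learning_rate": 2e-5}
#   Query:        {"text": "What are the symptoms of anemia?",
#                  "max_length": 256, "temperature": 0.7,
#                  "num_return_sequences": 1}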

def train_model(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
    global model, tokenizer

    os.makedirs(model_output_path, exist_ok=True)

    # Use the PyTorch checkpoint: an ONNX/INT4 export (e.g. the
    # nvidia/Meta-Llama-3.2-3B-Instruct-ONNX-INT4 repo) cannot be loaded or
    # fine-tuned with AutoModelForCausalLM and the Trainer.
    model_name = "meta-llama/Llama-3.2-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

    dataset = load_dataset("json", data_files=dataset_path)

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )
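
    # Note: load_dataset("json", ...) above expects dataset_path to point at a
    # JSON/JSONL file with a "text" field, e.g. (illustrative record):
    #   {"text": "Q: What is hypertension? A: Persistently elevated blood pressure."}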
    training_args = TrainingArguments(
        output_dir=f"{model_output_path}/checkpoints",
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        fp16=True,
        save_steps=500,
        logging_steps=100,
    )
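
    # With the defaults above, the effective batch size is
    # per_device_train_batch_size * gradient_accumulation_steps = 4 * 4 = 16.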
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    # Start training
    trainer.train()

    # Save the final model and tokenizer
    model.save_pretrained(model_output_path)
    tokenizer.save_pretrained(model_output_path)
    print(f"Model and tokenizer saved to: {model_output_path}")

# Route decorators below use assumed paths ("/train", "/generate", "/health");
# without them, FastAPI never registers these handlers.
@app.post("/train")
async def train(request: TrainRequest, background_tasks: BackgroundTasks):
    # Run fine-tuning off the request thread so the endpoint returns immediately.
    background_tasks.add_task(
        train_model,
        request.dataset_path,
        request.num_epochs,
        request.batch_size,
        request.learning_rate,
    )
    return {"message": "Training started in the background"}

@app.post("/generate", response_model=Response)
async def generate_text(query: Query):
    global model, tokenizer

    # Lazily load the fine-tuned model on the first request.
    if model is None or tokenizer is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_output_path)
            model = AutoModelForCausalLM.from_pretrained(
                model_output_path,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error loading model: {str(e)}")

    try:
        inputs = tokenizer(
            query.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=query.max_length,
        ).to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=query.max_length,
                num_return_sequences=query.num_return_sequences,
                do_sample=True,  # temperature only takes effect when sampling
                temperature=query.temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        generated_texts = [
            tokenizer.decode(g, skip_special_tokens=True)
            for g in generated_ids
        ]
        return Response(generated_text=generated_texts)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
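
# Example usage (assumes the server runs locally on port 8000, as configured
# above; the dataset path and prompt are illustrative):
#
#   curl -X POST http://localhost:8000/train \
#        -H "Content-Type: application/json" \
#        -d '{"dataset_path": "data/medical.jsonl"}'
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "What are the symptoms of anemia?", "max_length": 256}'
#
#   curl http://localhost:8000/health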