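"""FastAPI service for fine-tuning and serving a medical Llama model.

POST /train launches fine-tuning in the background, POST /generate runs
inference with the saved weights, and GET /health is a liveness probe.
"""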
import os
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import List, Optional
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
import uvicorn

app = FastAPI(title="Medical LLaMA API")

# Loaded lazily: populated by train_model() or on the first /generate call.
model = None
tokenizer = None
model_output_path = "./model/medical_llama_3b"

class TrainRequest(BaseModel):
    dataset_path: str
    num_epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 2e-5

class Query(BaseModel):
    text: str
    max_length: int = 512
    temperature: float = 0.7
    num_return_sequences: int = 1

class Response(BaseModel):
    generated_text: List[str]

def train_model(dataset_path: str, num_epochs: int, batch_size: int, learning_rate: float):
    """Fine-tune the base model on a JSON dataset and save the result."""
    global model, tokenizer

    os.makedirs(model_output_path, exist_ok=True)

    # The ONNX INT4 export (nvidia/Meta-Llama-3.2-3B-Instruct-ONNX-INT4) cannot
    # be fine-tuned with Trainer; the standard PyTorch checkpoint is assumed
    # here (gated on the Hub, so Hugging Face authentication is required).
    model_name = "meta-llama/Llama-3.2-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # Llama ships no pad token
    # Keep weights in fp32; fp16 weights break Trainer's AMP gradient unscaling.
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Expects a JSON/JSONL file whose records carry a "text" field.
    dataset = load_dataset("json", data_files=dataset_path)

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

    # Labels are added later by the collator, so only input_ids and
    # attention_mask are needed here.
    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names
    )

    training_args = TrainingArguments(
        output_dir=f"{model_output_path}/checkpoints",
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,  # effective batch size = 4 * batch_size
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        fp16=True,  # mixed-precision training
        save_steps=500,
        logging_steps=100,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        # mlm=False yields causal-LM labels (input_ids shifted by one).
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    # Start training
    trainer.train()

    # Save the final model and tokenizer
    model.save_pretrained(model_output_path)
    tokenizer.save_pretrained(model_output_path)

    print(f"Model and tokenizer saved to: {model_output_path}")

@app.post("/train")
async def train(request: TrainRequest, background_tasks: BackgroundTasks):
    background_tasks.add_task(train_model, request.dataset_path, request.num_epochs, request.batch_size, request.learning_rate)
    return {"message": "Training started in the background"}

@app.post("/generate", response_model=Response)
async def generate_text(query: Query):
    global model, tokenizer

    if model is None or tokenizer is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_output_path)
            model = AutoModelForCausalLM.from_pretrained(
                model_output_path,
                torch_dtype=torch.float16,
                device_map="auto"
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error loading model: {str(e)}")

    try:
        inputs = tokenizer(
            query.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=query.max_length
        ).to(model.device)

        with torch.no_grad():
            generated_ids = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=query.max_length,
                do_sample=True,  # temperature has no effect under greedy decoding
                num_return_sequences=query.num_return_sequences,
                temperature=query.temperature,
                # Fall back to EOS if the tokenizer has no pad token.
                pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        generated_texts = [
            tokenizer.decode(g, skip_special_tokens=True)
            for g in generated_ids
        ]

        return Response(generated_text=generated_texts)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

if __name__ == "__main__":
    # Pass the app object directly so the module's filename does not matter.
    uvicorn.run(app, host="0.0.0.0", port=8000)
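
# Example requests once the server is running (values illustrative; the
# dataset file must contain records with a "text" field):
#
#   curl -X POST localhost:8000/train \
#     -H "Content-Type: application/json" \
#     -d '{"dataset_path": "data/medical_notes.json"}'
#
#   curl -X POST localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "List common symptoms of iron-deficiency anemia.", "max_length": 256}'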