|
import torch |
|
from fastapi import FastAPI, HTTPException |
|
from pydantic import BaseModel |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
# FastAPI application instance; the /predict and /health endpoints below
# are registered on it.
app = FastAPI()

# Hugging Face Hub model id for the sequence-classification checkpoint.
model_name = "Canstralian/RabbitRedux"

# Load tokenizer and model once at import time so every request reuses them.
# NOTE(review): from_pretrained downloads from the Hub on first run — assumes
# network access or a warm local cache; confirm for the deployment environment.
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Inference-only service: switch off dropout/batch-norm training behavior.
model.eval()
|
|
|
|
|
class PredictionRequest(BaseModel):
    """Request body for POST /predict."""

    # Raw text to classify.
    text: str
|
|
|
class PredictionResponse(BaseModel):
    """Response body for POST /predict."""

    # The input text, echoed back to the caller.
    text: str

    # Human-readable predicted class label.
    prediction: str
|
|
|
|
|
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Classify ``request.text`` with the loaded model.

    Args:
        request: Validated request body carrying the input text.

    Returns:
        PredictionResponse echoing the input text and the predicted label.

    Raises:
        HTTPException: 500 with a generic message if any stage of
            tokenization or inference fails; the underlying exception is
            chained so server logs retain the real cause.
    """
    try:
        # truncation/padding keep arbitrary-length input within model limits.
        inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True)

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # Index of the highest-scoring class.
        prediction = torch.argmax(outputs.logits, dim=-1).item()

        # Use the model's own label mapping instead of hard-coded placeholder
        # names; this also avoids an IndexError when the checkpoint has more
        # classes than a fixed list. Fall back to the raw class index if the
        # config carries no mapping.
        id2label = getattr(model.config, "id2label", None) or {}
        predicted_label = id2label.get(prediction, str(prediction))

        return PredictionResponse(text=request.text, prediction=predicted_label)
    except Exception as e:
        # Chain the original exception so logs show the real failure while the
        # client still receives a generic 500.
        raise HTTPException(status_code=500, detail="Prediction failed") from e
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service process is up."""
    payload = {"status": "healthy"}
    return payload
|
|