import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Initialize the FastAPI app
app = FastAPI()

# Load the model and tokenizer from Hugging Face
model_name = "Canstralian/RabbitRedux"  # Replace with your model's name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()  # Put the model in inference mode (disables dropout, etc.)


# Define the input and output formats for prediction requests
class PredictionRequest(BaseModel):
    text: str


class PredictionResponse(BaseModel):
    text: str
    prediction: str


# Define the prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        # Tokenize the input text
        inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True)

        # Perform inference without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)

        # Get the index of the highest-scoring class
        prediction = torch.argmax(outputs.logits, dim=-1).item()

        # Map the prediction to a label (adjust as per your model's labels);
        # if the checkpoint's config defines label names, model.config.id2label
        # can be used here instead of a hard-coded list.
        labels = ["Label 1", "Label 2", "Label 3"]  # Replace with your actual labels
        predicted_label = labels[prediction]

        # Return the prediction response
        return PredictionResponse(text=request.text, prediction=predicted_label)
    except Exception as e:
        # Surface the underlying error message in the HTTP response
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e


# Define the health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}
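
# --- Usage sketch ---
# A minimal example of running and querying the service. It assumes this file
# is saved as main.py and that uvicorn is installed; both are assumptions, not
# part of the code above.
#
# Start the server:
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Query the prediction endpoint, e.g. with curl:
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "example input"}'
#
# For an in-process check without a running server, FastAPI's TestClient can
# exercise the same endpoints:
#
#   from fastapi.testclient import TestClient
#   client = TestClient(app)
#   print(client.post("/predict", json={"text": "example input"}).json())
#   print(client.get("/health").json())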