# RabbitRedux / app.py
import logging

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
# Set up logging configuration
logging.basicConfig(level=logging.INFO)

# Initialize the FastAPI app
app = FastAPI()

# Load the trained model (adjust the path to your saved model).
# Note: recent PyTorch versions default torch.load to weights_only=True,
# so loading a fully pickled model object may require weights_only=False.
model = torch.load("path/to/your/model.pth", map_location=torch.device("cpu"))  # Replace with your actual model path
model.eval()

# Load the matching tokenizer. The prediction code below calls the Hugging Face
# tokenizer API, so one must be loaded here; replace the path with whatever
# tokenizer your model was trained with.
tokenizer = AutoTokenizer.from_pretrained("path/to/your/tokenizer")  # Replace with your actual tokenizer path
# Define the input and output formats for prediction requests
class PredictionRequest(BaseModel):
    """
    Data model for the prediction request.

    Attributes:
        text (str): Input text for model inference.
    """

    text: str


class PredictionResponse(BaseModel):
    """
    Data model for the prediction response.

    Attributes:
        text (str): The original input text.
        prediction (str): The predicted result from the model.
    """

    text: str
    prediction: str
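
# Example payloads for these two models (hypothetical values, for illustration only):
#   request:  {"text": "Hello, world!"}
#   response: {"text": "Hello, world!", "prediction": "..."}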
# Define the prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """
    Endpoint for generating a prediction based on input text.

    Args:
        request (PredictionRequest): The request body containing the input text.

    Returns:
        PredictionResponse: The response body containing the original text and prediction.

    Raises:
        HTTPException: If any error occurs during the prediction process.
    """
    try:
        # Tokenize the input text
        inputs = tokenizer(request.text, return_tensors="pt")

        # Perform inference without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)

        # Take the highest-scoring token at each position and decode back to text
        prediction = tokenizer.decode(outputs.logits.argmax(-1)[0], skip_special_tokens=True)

        # Return the prediction response
        return PredictionResponse(text=request.text, prediction=prediction)
    except Exception:
        logging.error("Error during prediction", exc_info=True)
        raise HTTPException(status_code=500, detail="Prediction failed")
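
# Once the server is running, the endpoint can be exercised with a request like
# the following (assuming the default uvicorn host/port; the text is arbitrary):
#
#   curl -X POST http://127.0.0.1:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello, world!"}'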
# Define the health check endpoint
@app.get("/health")
async def health_check():
    """
    Health check endpoint to verify if the service is up and running.

    Returns:
        dict: A dictionary containing the status of the service.
    """
    logging.info("Health check requested.")
    return {"status": "healthy"}