import json
import logging
import os

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Set up logging configuration
logging.basicConfig(level=logging.INFO)

# Initialize the FastAPI app
app = FastAPI()

# Location of the serialized model. The MODEL_PATH environment variable
# overrides the placeholder default so a deployment does not have to edit
# this file; with the variable unset the path is the same as before.
MODEL_PATH = os.environ.get("MODEL_PATH", "path/to/your/model.pth")

# Load the trained model once at import time and switch it to eval mode.
# NOTE(review): torch.load unpickles the file, so only point it at trusted
# checkpoints. Recent PyTorch releases default to weights_only=True, which
# rejects a fully pickled module -- confirm against the installed version.
model = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
model.eval()

# NOTE(review): nothing in this file loads a tokenizer, yet predict() needs
# one. Assign the tokenizer that matches the model here (or set this module
# attribute at startup); until then predict() fails with a clear log message
# instead of a NameError.
tokenizer = None


# Define the input and output format for prediction requests
class PredictionRequest(BaseModel):
    """
    Data model for the prediction request.

    Attributes:
        text (str): Input text for model inference.
    """

    text: str


class PredictionResponse(BaseModel):
    """
    Data model for the prediction response.

    Attributes:
        text (str): The original input text.
        prediction (str): The predicted result from the model.
    """

    text: str
    prediction: str


# Define prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest) -> PredictionResponse:
    """
    Endpoint for generating a prediction based on input text.

    Args:
        request (PredictionRequest): The request body containing the input text.

    Returns:
        PredictionResponse: The response body containing the original text and prediction.

    Raises:
        HTTPException: If any error occurs during the prediction process.
    """
    try:
        # Fail with an explicit message when no tokenizer has been wired in;
        # the handler below still turns this into the usual 500 response.
        if tokenizer is None:
            raise RuntimeError("tokenizer is not configured")

        # Tokenize the input text into PyTorch tensors.
        inputs = tokenizer(request.text, return_tensors="pt")

        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = model(**inputs)

        # Take the arg-max over the last logits axis for the first batch item
        # and decode it back to text.
        # NOTE(review): assumes the model output exposes `.logits` whose
        # arg-max ids are decodable by the tokenizer -- TODO confirm.
        prediction = tokenizer.decode(
            outputs.logits.argmax(-1)[0], skip_special_tokens=True
        )

        return PredictionResponse(text=request.text, prediction=prediction)
    except Exception as e:
        # Log the full traceback server-side, but keep the client-facing
        # detail generic so internals are not leaked.
        logging.exception("Error during prediction")
        raise HTTPException(status_code=500, detail="Prediction failed") from e


# Define health check endpoint
@app.get("/health")
async def health_check():
    """
    Health check endpoint to verify if the service is up and running.

    Returns:
        dict: A dictionary containing the status of the service.
    """
    logging.info("Health check requested.")
    return {"status": "healthy"}