import logging

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Assumed: the original snippet calls `tokenizer(...)` without ever defining it.
# A Hugging Face tokenizer matches the call signature used below, so one is
# imported here; swap in whatever tokenizer actually pairs with your model.
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)

app = FastAPI()

# Assumed: a tokenizer checkpoint that pairs with the model; the path below is
# a placeholder, mirroring the model path further down.
tokenizer = AutoTokenizer.from_pretrained("path/to/your/tokenizer")

# Load the model once at startup and switch it to inference mode. torch.load
# on a fully pickled model requires the defining class to be importable;
# map_location keeps the weights on CPU.
model = torch.load("path/to/your/model.pth", map_location=torch.device("cpu"))
model.eval()
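
# If the checkpoint holds only a state_dict (the more common PyTorch
# convention) rather than a pickled model object, loading would instead look
# like the sketch below. `MyModel` is hypothetical, standing in for whatever
# architecture produced the weights:
#
#     model = MyModel()
#     model.load_state_dict(torch.load("path/to/your/model.pth", map_location="cpu"))
#     model.eval()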


class PredictionRequest(BaseModel):
    """
    Data model for the prediction request.

    Attributes:
        text (str): Input text for model inference.
    """

    text: str


class PredictionResponse(BaseModel):
    """
    Data model for the prediction response.

    Attributes:
        text (str): The original input text.
        prediction (str): The predicted result from the model.
    """

    text: str
    prediction: str


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """
    Endpoint for generating a prediction based on input text.

    Args:
        request (PredictionRequest): The request body containing the input text.

    Returns:
        PredictionResponse: The response body containing the original text and prediction.

    Raises:
        HTTPException: If any error occurs during the prediction process.
    """
    try:
        # Tokenize the input text into model-ready tensors.
        inputs = tokenizer(request.text, return_tensors="pt")

        # Run the forward pass without tracking gradients.
        with torch.no_grad():
            outputs = model(**inputs)

        # Greedy decode: take the highest-scoring token id at each position,
        # then convert the id sequence back to text.
        prediction = tokenizer.decode(outputs.logits.argmax(-1)[0], skip_special_tokens=True)

        return PredictionResponse(text=request.text, prediction=prediction)
    except Exception as e:
        logging.error("Error during prediction", exc_info=True)
        raise HTTPException(status_code=500, detail="Prediction failed") from e


@app.get("/health")
async def health_check():
    """
    Health check endpoint to verify that the service is up and running.

    Returns:
        dict: A dictionary containing the status of the service.
    """
    logging.info("Health check requested.")
    return {"status": "healthy"}