# app.py — Iris classifier web app by Prathamesh1420.
# Last change: "Update app.py" (commit 1ff350d, verified).
from functools import lru_cache

import joblib
import numpy as np
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from sklearn.datasets import load_iris
# On-disk artifacts. Both are relative paths, so they are resolved against
# the process's current working directory, not this file's location.
_MODEL_PATH = "iris_model.pkl"
_TEMPLATE_DIR = "templates"

# Trained classifier, deserialised once at import time and reused by every
# request. joblib.load unpickles the file, so it must come from a trusted source.
model = joblib.load(_MODEL_PATH)

# ASGI application that the route handlers in this module are registered on.
app = FastAPI()

# Jinja2 renderer for the HTML pages returned by the endpoints.
templates = Jinja2Templates(directory=_TEMPLATE_DIR)
# Pydantic models for input and output data
# NOTE(review): IrisInput is not referenced by the endpoints in this module;
# /predict reads the same four fields from HTML form data instead.
class IrisInput(BaseModel):
    # The four flower measurements, listed in the same order in which /predict
    # stacks them into the feature row it passes to the model.
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float
# Schema describing one classification result.
class IrisPrediction(BaseModel):
    # Numeric label returned by the model's predict().
    predicted_class: int
    # Human-readable class name; /predict looks it up in
    # load_iris().target_names using predicted_class as the index.
    predicted_class_name: str
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    # Landing page: render templates/index.html. The incoming request is the
    # only value placed in the template context, under the "request" key.
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
@lru_cache(maxsize=None)
def _iris_target_names():
    # Class-index -> class-name lookup from scikit-learn's bundled iris dataset.
    # load_iris() builds the entire dataset object (data, targets, description),
    # so build it once on first use rather than on every request; the names
    # never change at runtime.
    return load_iris().target_names


@app.post("/predict", response_class=HTMLResponse)
async def predict(
    request: Request,
    sepal_length: float = Form(...),
    sepal_width: float = Form(...),
    petal_length: float = Form(...),
    petal_width: float = Form(...),
):
    """Classify one iris flower from the submitted form and render the result page."""
    # This endpoint returns an HTML page rendered from result.html, so it is
    # declared with response_class=HTMLResponse. A response_model would only
    # advertise a JSON schema in the OpenAPI docs that is never actually sent.

    # One sample with the four features, shape (1, 4).
    # NOTE(review): assumes the model was trained with the features in this
    # column order (the iris dataset's order) — confirm against training code.
    features = np.array([[sepal_length, sepal_width, petal_length, petal_width]])

    # predict() returns one label per sample; take the only one and turn the
    # NumPy scalar into a plain Python int for indexing and for the template.
    # NOTE(review): assumes the model was trained on load_iris() integer
    # targets, so the label indexes target_names directly.
    predicted_class = int(model.predict(features)[0])
    predicted_class_name = _iris_target_names()[predicted_class]

    return templates.TemplateResponse(
        "result.html",
        {
            "request": request,
            "predicted_class": predicted_class,
            "predicted_class_name": predicted_class_name,
            # The submitted measurements are passed through to the template too.
            "sepal_length": sepal_length,
            "sepal_width": sepal_width,
            "petal_length": petal_length,
            "petal_width": petal_width,
        },
    )
if __name__ == "__main__":
    # This module only defines an ASGI application (`app`); it contains no
    # server of its own. The previous `demo.launch()` referenced a name that is
    # never defined here (it looks like a leftover from a Gradio template) and
    # always crashed with NameError. Exit with instructions instead.
    raise SystemExit(
        "This module defines a FastAPI app; serve it with an ASGI server, "
        "e.g. `uvicorn app:app`."
    )