Canstralian committed
Commit 177d7f0 · verified · 1 Parent(s): dd3e589

Update app.py

Files changed (1):
  app.py +16 -47
app.py CHANGED
@@ -1,80 +1,49 @@
 import torch
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-import logging
-import json
-import os
-
-# Set up logging configuration
-logging.basicConfig(level=logging.INFO)
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 # Initialize the FastAPI app
 app = FastAPI()
 
-# Load the trained model (adjust the path to your saved model)
-model = torch.load("path/to/your/model.pth", map_location=torch.device("cpu"))  # Replace with your actual model path
+# Load the model and tokenizer from Hugging Face
+model_name = "Canstralian/RabbitRedux"  # Replace with your model's name
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSequenceClassification.from_pretrained(model_name)
 model.eval()
 
 # Define the input and output format for prediction requests
 class PredictionRequest(BaseModel):
-    """
-    Data model for the prediction request.
-
-    Attributes:
-        text (str): Input text for model inference.
-    """
     text: str
 
 class PredictionResponse(BaseModel):
-    """
-    Data model for the prediction response.
-
-    Attributes:
-        text (str): The original input text.
-        prediction (str): The predicted result from the model.
-    """
     text: str
     prediction: str
 
 # Define prediction endpoint
 @app.post("/predict", response_model=PredictionResponse)
 async def predict(request: PredictionRequest):
-    """
-    Endpoint for generating a prediction based on input text.
-
-    Args:
-        request (PredictionRequest): The request body containing the input text.
-
-    Returns:
-        PredictionResponse: The response body containing the original text and prediction.
-
-    Raises:
-        HTTPException: If any error occurs during the prediction process.
-    """
     try:
-        # Tokenize the input text (assuming you're using a tokenizer for text inputs)
-        inputs = tokenizer(request.text, return_tensors="pt")
+        # Tokenize the input text
+        inputs = tokenizer(request.text, return_tensors="pt", truncation=True, padding=True)
 
         # Perform inference with the model
-        outputs = model(**inputs)
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # Get the predicted class
+        prediction = torch.argmax(outputs.logits, dim=-1).item()
 
-        # Get the predicted token and decode it back to text
-        prediction = tokenizer.decode(outputs.logits.argmax(-1)[0], skip_special_tokens=True)
+        # Map the prediction to a label (adjust as per your model's labels)
+        labels = ["Label 1", "Label 2", "Label 3"]  # Replace with your actual labels
+        predicted_label = labels[prediction]
 
         # Return the prediction response
-        return PredictionResponse(text=request.text, prediction=prediction)
+        return PredictionResponse(text=request.text, prediction=predicted_label)
     except Exception as e:
-        logging.error("Error during prediction", exc_info=True)
         raise HTTPException(status_code=500, detail="Prediction failed")
 
 # Define health check endpoint
 @app.get("/health")
 async def health_check():
-    """
-    Health check endpoint to verify if the service is up and running.
-
-    Returns:
-        dict: A dictionary containing the status of the service.
-    """
-    logging.info("Health check requested.")
     return {"status": "healthy"}
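A possible follow-up to this change (not part of the commit): the hardcoded labels list duplicates information that fine-tuned transformers checkpoints usually store in their config as id2label. A minimal sketch of reading the mapping from the config instead, assuming the RabbitRedux checkpoint actually defines one (the helper name is hypothetical):

import torch

def map_prediction_to_label(model, logits: torch.Tensor) -> str:
    # Pick the highest-scoring class index
    predicted_id = torch.argmax(logits, dim=-1).item()
    # id2label is a dict {index: name} on the model config;
    # fall back to the raw index if no name is configured for it
    return model.config.id2label.get(predicted_id, str(predicted_id))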
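To smoke-test the updated endpoints, a sketch using FastAPI's TestClient (assumes this file is importable as app.py and that httpx, which TestClient depends on, is installed):

from fastapi.testclient import TestClient
from app import app  # the FastAPI instance defined above

client = TestClient(app)

# Health check should report the service as up
assert client.get("/health").json() == {"status": "healthy"}

# Prediction round-trip: the response echoes the input text plus a label
resp = client.post("/predict", json={"text": "example input"})
print(resp.status_code, resp.json())
# e.g. 200 {"text": "example input", "prediction": "Label 1"}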