# FTCVision-PyTorch / handler.py
# Author: torinriley — last commit f53d612 ("Update handler.py", verified), 2.47 kB
import io
import logging
import os

import torch
from PIL import Image
from torchvision.transforms import ToTensor

from model import get_model
# Module-wide detection settings used by EndpointHandler.
# Class count handed to get_model(); presumably background + 3 colour classes — confirm.
NUM_CLASSES = 4
# Detections whose score falls below this value are dropped from the response.
CONFIDENCE_THRESHOLD = 0.5
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """
    Object-detection request handler built around ``model.get_model``.

    Presumably follows the Hugging Face Inference Endpoints custom-handler
    convention (``__init__(path)`` then ``__call__(data)``) — confirm against
    the deployment configuration.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler: load the model.

        Args:
            path: Directory holding the ``model.pt`` checkpoint. The empty
                default resolves ``model.pt`` relative to the working directory.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise when the file is
            missing or unreadable, and ``KeyError`` when the checkpoint has no
            ``"model_state_dict"`` entry.
        """
        # Load the model
        self.model_weights_path = os.path.join(path, "model.pt")
        self.model = get_model(NUM_CLASSES).to(DEVICE)
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code, so only ship trusted checkpoints. Consider weights_only=True once
        # the checkpoint is confirmed to contain only tensors/primitives.
        checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Put the model in evaluation mode for inference.
        self.model.eval()
        # Preprocessing function: PIL image -> tensor (torchvision ToTensor).
        self.preprocess = ToTensor()
        # Class labels. Label 0 is presumably background (NUM_CLASSES is 4 but
        # only 3 names are mapped) — confirm against the training code.
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}
        # Minimum score a detection needs to be returned. It is stored per
        # instance so it can be tuned without editing the module constant.
        self.confidence_threshold = CONFIDENCE_THRESHOLD

    def preprocess_frame(self, image_bytes):
        """
        Convert raw binary image data to a batched tensor on DEVICE.

        Args:
            image_bytes: Encoded image file contents (any format PIL can open).

        Returns:
            The RGB image as a tensor with a leading batch dimension of 1.
        """
        # Load image from binary data; force RGB so grayscale/RGBA inputs match.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image_tensor = self.preprocess(image).unsqueeze(0).to(DEVICE)
        return image_tensor

    def __call__(self, data):
        """
        Run detection on one raw binary image taken from the request.

        Args:
            data: Request payload. The encoded image bytes are read from
                ``data["body"]``.

        Returns:
            On success, ``{"predictions": [...]}``. Each entry has ``"box"``
            (``[x1, y1, x2, y2]`` truncated to ints), ``"label"`` (colour name
            or ``"unknown"``) and ``"score"`` (rounded to 2 decimals). Only
            detections scoring at least ``self.confidence_threshold`` are kept.
            On failure, ``{"error": message}``: the body is missing or empty,
            or any step raised.
        """
        try:
            # An absent or empty body means there is nothing to decode. Reject
            # it up front instead of letting PIL fail with an opaque message.
            if "body" not in data or not data["body"]:
                return {"error": "No image data provided in request."}
            image_tensor = self.preprocess_frame(data["body"])
            # Perform inference without tracking gradients.
            with torch.no_grad():
                predictions = self.model(image_tensor)
            # One image was sent (batch of 1), so only the first output matters.
            return {"predictions": self._format_predictions(predictions[0])}
        except Exception as e:
            # Top-level boundary: keep the error-dict contract for callers, but
            # record the traceback so failures are diagnosable server-side.
            logging.getLogger(__name__).exception("Inference request failed")
            return {"error": str(e)}

    def _format_predictions(self, prediction):
        """Turn one model output dict (boxes/labels/scores) into JSON-friendly detections above the threshold."""
        boxes = prediction["boxes"].cpu().tolist()
        labels = prediction["labels"].cpu().tolist()
        scores = prediction["scores"].cpu().tolist()
        results = []
        for box, label, score in zip(boxes, labels, scores):
            if score < self.confidence_threshold:
                continue
            # Unpacking also asserts each box has exactly four coordinates.
            x1, y1, x2, y2 = map(int, box)
            results.append({
                "box": [x1, y1, x2, y2],
                "label": self.label_map.get(label, "unknown"),
                "score": round(score, 2),
            })
        return results