|
import io
import logging
import os
from collections.abc import Mapping

import torch
from PIL import Image
from torchvision.transforms import ToTensor

from model import get_model
|
|
|
|
|
# Class count the detector is constructed with (forwarded to get_model).
NUM_CLASSES = 4

# Detections whose score is below this value are left out of the response.
CONFIDENCE_THRESHOLD = 0.5

# Use the current CUDA device when CUDA is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
class EndpointHandler:
    """Run the object detector on one uploaded image per request.

    The handler is constructed once, which builds the model and loads its
    weights. It is then called with each request payload. Calls never raise:
    every failure is reported as ``{"error": <message>}``.
    """

    def __init__(self, path: str = ""):
        """Build the model and load its weights from ``<path>/model.pt``.

        Args:
            path: Directory that contains ``model.pt``. The default ``""``
                resolves ``model.pt`` relative to the working directory.

        Errors raised by ``torch.load`` or ``load_state_dict`` propagate to
        the caller.
        """
        self.model_weights_path = os.path.join(path, "model.pt")
        self.model = get_model(NUM_CLASSES).to(DEVICE)

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
        # The checkpoint is expected to wrap the weights under
        # "model_state_dict". Also accept a bare state dict, as written by
        # torch.save(model.state_dict()).
        if isinstance(checkpoint, Mapping) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)
        # Put layers such as dropout / batch-norm into inference behaviour.
        self.model.eval()

        self.preprocess = ToTensor()

        # Model label id -> colour name. Presumably id 0 is a background
        # class counted in NUM_CLASSES; TODO confirm against training code.
        # Ids missing from this map are reported as "unknown".
        self.label_map = {1: "yellow", 2: "red", 3: "blue"}

    def preprocess_frame(self, image_bytes):
        """Decode encoded image bytes into a batched tensor on ``DEVICE``.

        Args:
            image_bytes: Encoded image file contents (any format PIL can open).

        Returns:
            A tensor of shape ``(1, 3, H, W)``. The image is forced to RGB, and
            ToTensor scales pixel values to floats in ``[0, 1]``.
        """
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image_tensor = self.preprocess(image).unsqueeze(0).to(DEVICE)
        return image_tensor

    def _format_detections(self, prediction):
        """Keep detections scoring >= CONFIDENCE_THRESHOLD, as JSON-ready dicts.

        ``prediction`` is indexed as a dict of tensors under ``"boxes"``,
        ``"labels"`` and ``"scores"``. This presumably matches the torchvision
        detection output format; TODO confirm against ``get_model``.
        """
        boxes = prediction["boxes"].cpu().tolist()
        labels = prediction["labels"].cpu().tolist()
        scores = prediction["scores"].cpu().tolist()
        return [
            {
                # Corner coordinates (x1, y1, x2, y2), truncated to ints.
                "box": [int(coord) for coord in box],
                "label": self.label_map.get(label, "unknown"),
                "score": round(score, 2),
            }
            for box, label, score in zip(boxes, labels, scores)
            if score >= CONFIDENCE_THRESHOLD
        ]

    def __call__(self, data):
        """Run detection on the image carried in ``data["body"]``.

        Args:
            data: Request payload. It is expected to be a mapping whose
                ``"body"`` entry holds the encoded image bytes.

        Returns:
            On success, ``{"predictions": [...]}``. There is one
            ``{"box", "label", "score"}`` dict per detection at or above
            ``CONFIDENCE_THRESHOLD``.

            On failure, ``{"error": message}``. This covers a missing or empty
            body and any error raised while decoding or running the model.
        """
        try:
            if not isinstance(data, Mapping) or not data.get("body"):
                return {"error": "No image data provided in request."}

            image_tensor = self.preprocess_frame(data["body"])

            with torch.no_grad():
                predictions = self.model(image_tensor)

            # The batch holds a single image, so only the first result matters.
            return {"predictions": self._format_detections(predictions[0])}

        except Exception as e:
            # Endpoint boundary: hand the message back to the client, but keep
            # the full traceback in the server logs for debugging.
            logging.getLogger(__name__).exception("Inference request failed")
            return {"error": str(e)}