torinriley commited on
Commit
f53d612
·
verified ·
1 Parent(s): 70e7526

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +43 -35
handler.py CHANGED
@@ -1,64 +1,72 @@
1
  import torch
2
- from torchvision import transforms
 
3
  from PIL import Image
4
  import io
5
- import os
6
 
7
- from model import get_model
 
 
 
8
 
9
  class EndpointHandler:
10
  def __init__(self, path: str = ""):
11
  """
12
- Initialize the handler. Load the Faster R-CNN model.
13
  """
14
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- self.model_weights_path = os.path.join(path, "model.pt") # Adjust path
16
-
17
- # Load model
18
- self.model = get_model(num_classes=4)
19
- checkpoint = torch.load(self.model_weights_path, map_location=self.device)
20
  self.model.load_state_dict(checkpoint["model_state_dict"])
21
- self.model.to(self.device)
22
  self.model.eval()
23
 
24
- # Image preprocessing
25
- self.transform = transforms.Compose([
26
- transforms.Resize((640, 640)),
27
- transforms.ToTensor(),
28
- ])
 
 
 
 
 
 
 
 
 
29
 
30
  def __call__(self, data):
31
  """
32
- Process incoming binary image data and return object detection results.
33
  """
34
  try:
35
- # Read raw binary data (image file)
36
- image_bytes = data.get("body", b"")
37
- if not image_bytes:
38
  return {"error": "No image data provided in request."}
39
 
40
-
41
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
42
-
43
-
44
- input_tensor = self.transform(image).unsqueeze(0).to(self.device)
45
-
46
 
 
47
  with torch.no_grad():
48
- predictions = self.model(input_tensor)
49
 
50
-
51
  boxes = predictions[0]["boxes"].cpu().tolist()
52
  labels = predictions[0]["labels"].cpu().tolist()
53
  scores = predictions[0]["scores"].cpu().tolist()
54
 
55
-
56
- threshold = 0.5
57
- results = [
58
- {"box": box, "label": label, "score": score}
59
- for box, label, score in zip(boxes, labels, scores)
60
- if score > threshold
61
- ]
 
 
 
 
62
 
63
  return {"predictions": results}
64
  except Exception as e:
 
1
  import torch
2
+ from model import get_model
3
+ from torchvision.transforms import ToTensor
4
  from PIL import Image
5
  import io
 
6
 
7
+ # Constants
8
+ NUM_CLASSES = 4
9
+ CONFIDENCE_THRESHOLD = 0.5
10
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  class EndpointHandler:
13
  def __init__(self, path: str = ""):
14
  """
15
+ Initialize the handler: load the model.
16
  """
17
+ # Load the model
18
+ self.model_weights_path = os.path.join(path, "model.pt")
19
+ self.model = get_model(NUM_CLASSES).to(DEVICE)
20
+ checkpoint = torch.load(self.model_weights_path, map_location=DEVICE)
 
 
21
  self.model.load_state_dict(checkpoint["model_state_dict"])
 
22
  self.model.eval()
23
 
24
+ # Preprocessing function
25
+ self.preprocess = ToTensor()
26
+
27
+ # Class labels
28
+ self.label_map = {1: "yellow", 2: "red", 3: "blue"}
29
+
30
+ def preprocess_frame(self, image_bytes):
31
+ """
32
+ Convert raw binary image data to a tensor.
33
+ """
34
+ # Load image from binary data
35
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
36
+ image_tensor = self.preprocess(image).unsqueeze(0).to(DEVICE)
37
+ return image_tensor
38
 
39
  def __call__(self, data):
40
  """
41
+ Process incoming raw binary image data.
42
  """
43
  try:
44
+ if "body" not in data:
 
 
45
  return {"error": "No image data provided in request."}
46
 
47
+ image_bytes = data["body"]
48
+ image_tensor = self.preprocess_frame(image_bytes)
 
 
 
 
49
 
50
+ # Perform inference
51
  with torch.no_grad():
52
+ predictions = self.model(image_tensor)
53
 
54
+ # Extract predictions
55
  boxes = predictions[0]["boxes"].cpu().tolist()
56
  labels = predictions[0]["labels"].cpu().tolist()
57
  scores = predictions[0]["scores"].cpu().tolist()
58
 
59
+ # Filter predictions by confidence threshold
60
+ results = []
61
+ for box, label, score in zip(boxes, labels, scores):
62
+ if score >= CONFIDENCE_THRESHOLD:
63
+ x1, y1, x2, y2 = map(int, box)
64
+ label_text = self.label_map.get(label, "unknown")
65
+ results.append({
66
+ "box": [x1, y1, x2, y2],
67
+ "label": label_text,
68
+ "score": round(score, 2)
69
+ })
70
 
71
  return {"predictions": results}
72
  except Exception as e: