xcurvnubaim commited on
Commit
a9f1bab
·
1 Parent(s): 765c987

feat: add object detection

Browse files
Files changed (2) hide show
  1. main.py +111 -2
  2. requirements.txt +4 -1
main.py CHANGED
@@ -3,11 +3,17 @@ from fastapi import FastAPI, File, UploadFile
3
  import tensorflow as tf
4
  from PIL import Image
5
  from io import BytesIO
 
 
 
 
6
 
7
  app = FastAPI()
8
 
9
  labels = []
10
- model = tf.keras.models.load_model('./models.h5')
 
 
11
  with open("labels.txt") as f:
12
  for line in f:
13
  labels.append(line.replace('\n', ''))
@@ -17,13 +23,116 @@ def classify_image(img):
17
  img_array = np.asarray(img.resize((224, 224)))[..., :3]
18
  img_array = img_array.reshape((1, 224, 224, 3)) # Add batch dimension
19
  img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
20
- prediction = model.predict(img_array).flatten()
21
  confidences = {labels[i]: float(prediction[i]) for i in range(90)}
22
  # Sort the confidences dictionary by value and get the top 3 items
23
  # top_3_confidences = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:3])
24
 
25
  return confidences
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  @app.post("/predict")
28
  async def predict(file: bytes = File(...)):
29
  img = Image.open(BytesIO(file))
 
3
  import tensorflow as tf
4
  from PIL import Image
5
  from io import BytesIO
6
+ from ultralytics import YOLO
7
+ import cv2
8
+ from datetime import datetime
9
+ from fastapi.responses import FileResponse
10
 
11
  app = FastAPI()
12
 
13
  labels = []
14
+ classification_model = tf.keras.models.load_model('./models.h5')
15
+ detection_model = YOLO('./best.pt')
16
+
17
  with open("labels.txt") as f:
18
  for line in f:
19
  labels.append(line.replace('\n', ''))
 
23
  img_array = np.asarray(img.resize((224, 224)))[..., :3]
24
  img_array = img_array.reshape((1, 224, 224, 3)) # Add batch dimension
25
  img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
26
+ prediction = classification_model.predict(img_array).flatten()
27
  confidences = {labels[i]: float(prediction[i]) for i in range(90)}
28
  # Sort the confidences dictionary by value and get the top 3 items
29
  # top_3_confidences = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:3])
30
 
31
  return confidences
32
 
33
+ def animal_detect_and_classify(img_path):
34
+ # Read the image
35
+ img = cv2.imread(img_path)
36
+
37
+ # Pass the image through the detection model and get the result
38
+ detect_results = detection_model(img)
39
+
40
+ combined_results = []
41
+ # print("dss", detect_results[0])
42
+ # Iterate over the detected objects
43
+ # Iterate over detections
44
+ for result in detect_results:
45
+ for box in result.boxes:
46
+ # print(box)
47
+ # Crop the RoI
48
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
49
+ detect_img = img[y1:y2, x1:x2]
50
+ # Convert the image to RGB format
51
+ detect_img = cv2.cvtColor(detect_img, cv2.COLOR_BGR2RGB)
52
+
53
+ # Resize the input image to the expected shape (224, 224)
54
+ detect_img = cv2.resize(detect_img, (224, 224))
55
+
56
+ # Convert the image to a numpy array
57
+ inp_array = np.array(detect_img)
58
+
59
+ # Reshape the array to match the expected input shape
60
+ inp_array = inp_array.reshape((-1, 224, 224, 3))
61
+
62
+ # Preprocess the input array
63
+ inp_array = tf.keras.applications.efficientnet.preprocess_input(inp_array)
64
+
65
+ # Make predictions using the classification model
66
+ prediction = classification_model.predict(inp_array)
67
+ # Map predictions to labels
68
+ threshold = 0.75
69
+ predicted_labels = [labels[np.argmax(pred)] if np.max(pred) >= threshold else "animal" for pred in prediction]
70
+ print(predicted_labels)
71
+ combined_results.append(((x1, y1, x2, y2), predicted_labels))
72
+
73
+ return combined_results
74
+
75
+ def generate_color(class_name):
76
+ # Generate a hash from the class name
77
+ color_hash = hash(class_name)
78
+ print(color_hash)
79
+ # Normalize the hash value to fit within the range of valid color values (0-255)
80
+ color_hash = abs(color_hash) % 16777216
81
+ R = color_hash//(256*256)
82
+ G = (color_hash//256) % 256
83
+ B = color_hash % 256
84
+ # Convert the hash value to RGB color format
85
+ color = (R, G, B)
86
+
87
+ return color
88
+
89
+ def plot_detected_rectangles(image, detections, output_path):
90
+ # Create a copy of the image to draw on
91
+ img_with_rectangles = image.copy()
92
+
93
+ # Iterate over each detected rectangle and its corresponding class name
94
+ for rectangle, class_names in detections:
95
+ # Extract the coordinates of the rectangle
96
+ x1, y1, x2, y2 = rectangle
97
+
98
+ # Generate a random color
99
+ color = generate_color(class_names[0])
100
+
101
+ # Draw the rectangle on the image
102
+ cv2.rectangle(img_with_rectangles, (x1, y1), (x2, y2), color, 2)
103
+
104
+ # Put the class names above the rectangle
105
+ for i, class_name in enumerate(class_names):
106
+ cv2.putText(img_with_rectangles, class_name, (x1, y1 - 10 - i*20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
107
+
108
+ # Show the image with rectangles and class names
109
+ cv2.imwrite(output_path, img_with_rectangles)
110
+
111
+
112
+ # Call the animal_detect_and_classify function to get detections
113
+ detections = animal_detect_and_classify('/content/cat_tiger.jpg')
114
+
115
+ # Plot the detected rectangles with their corresponding class names
116
+ plot_detected_rectangles(cv2.imread('/content/cat_tiger.jpg'), detections)
117
+
118
+
119
+ @app.post("/predict/v2")
120
+ async def predict_v2(file: UploadFile = File(...)):
121
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
122
+ filename = timestamp + file.filename
123
+ contents = await file.read()
124
+ image = Image.open(BytesIO(contents))
125
+ image.save("input/" + filename)
126
+ detections = animal_detect_and_classify("input/" + filename)
127
+ plot_detected_rectangles(cv2.imread("input/" + filename), detections, "output/" + filename)
128
+ return {"message": "Detection and classification completed successfully"}
129
+
130
+ @app.get("/image/")
131
+ async def get_image(image_name: str):
132
+ # Assume the images are stored in a directory named "images"
133
+ image_path = f"images/{image_name}"
134
+ return FileResponse(image_path)
135
+
136
  @app.post("/predict")
137
  async def predict(file: bytes = File(...)):
138
  img = Image.open(BytesIO(file))
requirements.txt CHANGED
@@ -10,4 +10,7 @@ uvicorn
10
  python-multipart
11
  numpy==1.25.2
12
  Pillow==9.4.0
13
- keras==2.15.0
 
 
 
 
10
  python-multipart
11
  numpy==1.25.2
12
  Pillow==9.4.0
13
+ keras==2.15.0
14
+ ultralytics
15
+ squarify
16
+ cv2