ANON-STUDIOS-254
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
|
6 |
+
from torchvision.utils import draw_bounding_boxes
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from io import BytesIO
|
9 |
+
|
10 |
+
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
|
11 |
+
categories = weights.meta["categories"]
|
12 |
+
img_preprocess = weights.transforms()
|
13 |
+
|
14 |
+
def load_model():
|
15 |
+
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.5)
|
16 |
+
model.eval()
|
17 |
+
return model
|
18 |
+
|
19 |
+
model = load_model()
|
20 |
+
|
21 |
+
def make_prediction(img):
|
22 |
+
img_processed = img_preprocess(img)
|
23 |
+
prediction = model(img_processed.unsqueeze(0))
|
24 |
+
prediction = prediction[0]
|
25 |
+
prediction["labels"] = [categories[label] for label in prediction["labels"]]
|
26 |
+
return prediction
|
27 |
+
|
28 |
+
def create_image_with_bboxes(img, prediction):
|
29 |
+
img_tensor = torch.tensor(img)
|
30 |
+
img_with_bboxes = draw_bounding_boxes(img_tensor, boxes=prediction["boxes"], labels=prediction["labels"],
|
31 |
+
colors=["red" if label=="person" else "green" for label in prediction["labels"]], width=2)
|
32 |
+
img_with_bboxes_np = img_with_bboxes.detach().numpy().transpose(1,2,0)
|
33 |
+
return img_with_bboxes_np
|
34 |
+
|
35 |
+
def process_image(image):
|
36 |
+
img = Image.fromarray(image.astype('uint8'), 'RGB')
|
37 |
+
prediction = make_prediction(img)
|
38 |
+
img_with_bbox = create_image_with_bboxes(np.array(img).transpose(2,0,1), prediction)
|
39 |
+
|
40 |
+
fig = plt.figure(figsize=(12,12))
|
41 |
+
ax = fig.add_subplot(111)
|
42 |
+
plt.imshow(img_with_bbox)
|
43 |
+
plt.xticks([],[])
|
44 |
+
plt.yticks([],[])
|
45 |
+
ax.spines[["top", "bottom", "right", "left"]].set_visible(False)
|
46 |
+
|
47 |
+
plt.tight_layout()
|
48 |
+
plt.close(fig)
|
49 |
+
|
50 |
+
# Save plot to a BytesIO object
|
51 |
+
img_bytes = BytesIO()
|
52 |
+
fig.savefig(img_bytes, format='png')
|
53 |
+
img_bytes.seek(0)
|
54 |
+
|
55 |
+
# Create a summary of detected objects
|
56 |
+
detected_objects = []
|
57 |
+
for label, score in zip(prediction["labels"], prediction["scores"]):
|
58 |
+
detected_objects.append(f"{label}: {score:.2f}")
|
59 |
+
|
60 |
+
prediction_data = {k: (v.tolist() if isinstance(v, torch.Tensor) else v) for k, v in prediction.items()}
|
61 |
+
return Image.open(img_bytes), detected_objects, prediction_data
|
62 |
+
|
63 |
+
gr.Interface(
|
64 |
+
fn=process_image,
|
65 |
+
inputs=gr.Image(type="numpy"),
|
66 |
+
outputs=[gr.Image(type="pil"), gr.Textbox(), gr.JSON()],
|
67 |
+
title="OBJECT_DETECTOR_254",
|
68 |
+
description="Upload an image to detect objects and display bounding boxes along with a summary of detected objects.",
|
69 |
+
).launch()
|