ANON-STUDIOS-254 commited on
Commit
ff464a5
·
verified ·
1 Parent(s): d22d09b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ import torch
5
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
6
+ from torchvision.utils import draw_bounding_boxes
7
+ import matplotlib.pyplot as plt
8
+ from io import BytesIO
9
+
10
+ weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
11
+ categories = weights.meta["categories"]
12
+ img_preprocess = weights.transforms()
13
+
14
+ def load_model():
15
+ model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.5)
16
+ model.eval()
17
+ return model
18
+
19
+ model = load_model()
20
+
21
+ def make_prediction(img):
22
+ img_processed = img_preprocess(img)
23
+ prediction = model(img_processed.unsqueeze(0))
24
+ prediction = prediction[0]
25
+ prediction["labels"] = [categories[label] for label in prediction["labels"]]
26
+ return prediction
27
+
28
+ def create_image_with_bboxes(img, prediction):
29
+ img_tensor = torch.tensor(img)
30
+ img_with_bboxes = draw_bounding_boxes(img_tensor, boxes=prediction["boxes"], labels=prediction["labels"],
31
+ colors=["red" if label=="person" else "green" for label in prediction["labels"]], width=2)
32
+ img_with_bboxes_np = img_with_bboxes.detach().numpy().transpose(1,2,0)
33
+ return img_with_bboxes_np
34
+
35
+ def process_image(image):
36
+ img = Image.fromarray(image.astype('uint8'), 'RGB')
37
+ prediction = make_prediction(img)
38
+ img_with_bbox = create_image_with_bboxes(np.array(img).transpose(2,0,1), prediction)
39
+
40
+ fig = plt.figure(figsize=(12,12))
41
+ ax = fig.add_subplot(111)
42
+ plt.imshow(img_with_bbox)
43
+ plt.xticks([],[])
44
+ plt.yticks([],[])
45
+ ax.spines[["top", "bottom", "right", "left"]].set_visible(False)
46
+
47
+ plt.tight_layout()
48
+ plt.close(fig)
49
+
50
+ # Save plot to a BytesIO object
51
+ img_bytes = BytesIO()
52
+ fig.savefig(img_bytes, format='png')
53
+ img_bytes.seek(0)
54
+
55
+ # Create a summary of detected objects
56
+ detected_objects = []
57
+ for label, score in zip(prediction["labels"], prediction["scores"]):
58
+ detected_objects.append(f"{label}: {score:.2f}")
59
+
60
+ prediction_data = {k: (v.tolist() if isinstance(v, torch.Tensor) else v) for k, v in prediction.items()}
61
+ return Image.open(img_bytes), detected_objects, prediction_data
62
+
63
+ gr.Interface(
64
+ fn=process_image,
65
+ inputs=gr.Image(type="numpy"),
66
+ outputs=[gr.Image(type="pil"), gr.Textbox(), gr.JSON()],
67
+ title="OBJECT_DETECTOR_254",
68
+ description="Upload an image to detect objects and display bounding boxes along with a summary of detected objects.",
69
+ ).launch()