AbdulManan093 commited on
Commit
b8d1773
·
verified ·
1 Parent(s): b88bacc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -0
app.py CHANGED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection
2
+ import torch
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.patches as patches
6
+ import gradio as gr
7
+ import io
8
+
9
+ # Load the processor and model
10
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
11
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
12
+
13
+ def detect_and_display_image(image):
14
+ # Ensure image is in PIL format
15
+ if isinstance(image, bytes):
16
+ image = Image.open(io.BytesIO(image))
17
+ elif isinstance(image, str):
18
+ image = Image.open(image)
19
+
20
+ # Process the image
21
+ inputs = processor(images=image, return_tensors="pt")
22
+
23
+ # Perform object detection
24
+ outputs = model(**inputs)
25
+
26
+ # Convert outputs to COCO API format
27
+ target_sizes = torch.tensor([image.size[::-1]])
28
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
29
+
30
+ # Create a figure and axis for visualization
31
+ fig, ax = plt.subplots(1, figsize=(12, 9))
32
+ ax.imshow(image)
33
+
34
+ # Add bounding boxes and labels to the image
35
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
36
+ box = [round(i, 2) for i in box.tolist()]
37
+ # Create a Rectangle patch
38
+ rect = patches.Rectangle(
39
+ (box[0], box[1]),
40
+ box[2] - box[0],
41
+ box[3] - box[1],
42
+ linewidth=2,
43
+ edgecolor='red',
44
+ facecolor='none'
45
+ )
46
+ # Add the patch to the Axes
47
+ ax.add_patch(rect)
48
+ # Add label and confidence score
49
+ plt.text(
50
+ box[0], box[1],
51
+ f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}',
52
+ color='red',
53
+ fontsize=12,
54
+ bbox=dict(facecolor='yellow', alpha=0.5)
55
+ )
56
+
57
+ plt.axis('off') # Hide the axes
58
+
59
+ # Save the figure to a BytesIO object and return it
60
+ buf = io.BytesIO()
61
+ plt.savefig(buf, format='png')
62
+ buf.seek(0)
63
+ return Image.open(buf)
64
+
65
+ # Create a Gradio interface
66
+ iface = gr.Interface(
67
+ fn=detect_and_display_image,
68
+ inputs=gr.Image(type="pil"),
69
+ outputs=gr.Image(type="pil"),
70
+ title="Object Detection with DETR",
71
+ description="Upload an image to detect objects using the DETR model.",
72
+ live=True
73
+ )
74
+
75
+ # Launch the Gradio app
76
+ iface.launch()
77
+
78
+