AbdulManan093's picture
Update app.py
b8d1773 verified
raw
history blame
2.41 kB
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import gradio as gr
import io
# Load the DETR image processor and detection model once, at import time, so
# every request reuses the same objects. Both come from the same checkpoint
# and revision.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
_DETR_REVISION = "no_timm"
processor = DetrImageProcessor.from_pretrained(
    _DETR_CHECKPOINT, revision=_DETR_REVISION
)
model = DetrForObjectDetection.from_pretrained(
    _DETR_CHECKPOINT, revision=_DETR_REVISION
)
def detect_and_display_image(image):
    """Run DETR object detection on an image and return an annotated rendering.

    Args:
        image: A PIL image, raw encoded image bytes, or a path to an image
            file. ``None`` is accepted (the UI may call this with no image,
            e.g. after the input is cleared) and yields ``None``.

    Returns:
        A PIL image of the rendered matplotlib figure: the input image with a
        red rectangle and a "label: score" caption for every detection whose
        score passes the 0.9 threshold. ``None`` if no image was supplied.
    """
    # Nothing to detect on; returning None leaves the output component empty
    # instead of raising on `image.size` further down.
    if image is None:
        return None

    # Ensure image is in PIL format
    if isinstance(image, bytes):
        image = Image.open(io.BytesIO(image))
    elif isinstance(image, str):
        image = Image.open(image)
    # Normalise to 3-channel RGB so grayscale / RGBA / palette uploads are
    # handled uniformly by both the processor and imshow.
    image = image.convert("RGB")

    # Process the image
    inputs = processor(images=image, return_tensors="pt")

    # Perform object detection. This is inference only, so skip autograd
    # bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert outputs to COCO API format. PIL reports (width, height); the
    # reversed tuple gives (height, width) for the post-processing step.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Create a figure and axis for visualization. All drawing below goes
    # through `fig` / `ax` explicitly rather than pyplot's implicit "current
    # figure", so overlapping requests cannot draw on or save each other's
    # figure.
    fig, ax = plt.subplots(1, figsize=(12, 9))
    try:
        ax.imshow(image)

        # Add bounding boxes and labels to the image
        for score, label, box in zip(
            results["scores"], results["labels"], results["boxes"]
        ):
            # Box corners: (x_min, y_min) is top-left, (x_max, y_max) is
            # bottom-right, in pixel coordinates of the input image.
            x_min, y_min, x_max, y_max = (round(v, 2) for v in box.tolist())

            # Create a Rectangle patch and add it to the Axes
            ax.add_patch(
                patches.Rectangle(
                    (x_min, y_min),
                    x_max - x_min,
                    y_max - y_min,
                    linewidth=2,
                    edgecolor='red',
                    facecolor='none',
                )
            )

            # Add label and confidence score at the box's top-left corner
            ax.text(
                x_min, y_min,
                f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}',
                color='red',
                fontsize=12,
                bbox=dict(facecolor='yellow', alpha=0.5),
            )

        ax.axis('off')  # Hide the axes

        # Save the figure to an in-memory PNG
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each request would leak one figure.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# Build the Gradio UI: a single image upload feeding the detector, and a
# single image component showing the annotated result. Both components
# exchange PIL images with the callback.
upload_component = gr.Image(type="pil")
result_component = gr.Image(type="pil")

iface = gr.Interface(
    fn=detect_and_display_image,
    inputs=upload_component,
    outputs=result_component,
    title="Object Detection with DETR",
    description="Upload an image to detect objects using the DETR model.",
    live=True,  # re-run detection whenever the input changes, no submit click
)

# Start serving the app
iface.launch()