objectDetection / app.py
import gradio as gr
import requests
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForCausalLM
from io import BytesIO
import torch
# Set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Load model and processor
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
# List of colors to cycle through for bounding boxes
COLORS = ["red", "blue", "green", "yellow", "purple", "orange", "cyan", "magenta"]
# Prediction function
def predict_from_url(url):
    """Run Florence-2 object detection on an image URL and return a color legend plus the annotated image."""
    prompt = "<OD>"  # Florence-2 task prompt for object detection
    if not url:
        return "Error: Please input a URL", None
    try:
        # Fetch the image and convert to RGB so drawing with named colors works on any input mode
        image = Image.open(BytesIO(requests.get(url, timeout=30).content)).convert("RGB")
    except Exception as e:
        return f"Error: Failed to load image: {e}", None

    # Prepare inputs and run generation for the detection prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=4096,
        num_beams=3,
        do_sample=False
    )

    # Decode the raw output and parse it into labels and bounding boxes
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
    labels = parsed_answer.get('<OD>', {}).get('labels', [])
    bboxes = parsed_answer.get('<OD>', {}).get('bboxes', [])

    # Draw bounding boxes on the image and build a color legend
    draw = ImageDraw.Draw(image)
    legend = []  # Store legend entries
    for idx, (bbox, label) in enumerate(zip(bboxes, labels)):
        x1, y1, x2, y2 = bbox
        color = COLORS[idx % len(COLORS)]  # Cycle through colors
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        legend.append(f"{label}: {color}")

    return "\n".join(legend), image
# Gradio interface
demo = gr.Interface(
    fn=predict_from_url,
    inputs=gr.Textbox(label="Enter Image URL"),
    outputs=[
        gr.Textbox(label="Legend"),                  # Output the legend
        gr.Image(label="Image with Bounding Boxes")  # Output the processed image
    ],
    title="Item Classifier with Bounding Boxes and Legend",
    allow_flagging="never"
)
# Launch the interface
demo.launch()
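# Optional alternatives (not part of the original app): launch() also accepts
# share=True for a temporary public link, or server_name="0.0.0.0" to bind to
# all interfaces when hosting outside localhost, e.g.:
# demo.launch(share=True)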