# Hugging Face Space (page status at capture time: Sleeping)
import gradio as gr | |
import requests | |
from PIL import Image, ImageDraw | |
from transformers import AutoProcessor, AutoModelForCausalLM | |
from io import BytesIO | |
import torch | |
# Select the compute target: the first CUDA GPU with half precision when one
# is available, otherwise the CPU with full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
# Load model and processor from a single checkpoint id so the two can never
# silently point at different repos.
# NOTE(review): trust_remote_code=True runs Python code shipped with the
# checkpoint repository; consider pinning `revision=` to a reviewed commit.
MODEL_ID = "microsoft/Florence-2-large"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    trust_remote_code=True,
).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
# Palette of PIL color names, cycled through when outlining detected boxes.
COLORS = [
    "red",
    "blue",
    "green",
    "yellow",
    "purple",
    "orange",
    "cyan",
    "magenta",
]
# Prediction function and its helpers
def _load_image(url, timeout):
    """Download *url* and return it as a fully decoded RGB PIL image.

    Raises whatever ``requests`` or Pillow raise: connection errors, timeouts,
    HTTP 4xx/5xx (via ``raise_for_status``), or undecodable image data.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors directly instead of trying to decode an error page.
    response.raise_for_status()
    # Image.open() is lazy. convert() forces the full decode here, so corrupt
    # data fails inside the caller's try block. It also normalises RGBA/P/L/CMYK
    # inputs to RGB before they reach the processor and the drawing code.
    return Image.open(BytesIO(response.content)).convert("RGB")


def _detect_objects(image, prompt):
    """Run the Florence-2 model on *image* for task *prompt*; return ``(labels, bboxes)``."""
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    # Pure inference: disable autograd so no graph is kept for the forward pass.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=4096,
            num_beams=3,
            do_sample=False,
        )
    # Special tokens are kept as in the original implementation; presumably
    # post_process_generation needs them to parse the boxes -- confirm before changing.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=prompt, image_size=(image.width, image.height)
    )
    detections = parsed_answer.get(prompt, {})
    return detections.get("labels", []), detections.get("bboxes", [])


def predict_from_url(url, *, timeout=10):
    """Detect objects in the image at *url* and outline them.

    Args:
        url: Image URL typed by the user.
        timeout: Seconds to wait for the download before giving up. Without a
            timeout ``requests.get`` can block the worker indefinitely.

    Returns:
        A ``(legend, image)`` pair. On success, ``legend`` is a newline-joined
        string of ``"<label>: <color>"`` entries (one per drawn box), and
        ``image`` is the RGB PIL image with the boxes drawn on it. On failure,
        ``legend`` is a ``{"Error": message}`` dict and ``image`` is ``None``.
    """
    prompt = "<OD>"
    # Treat None / whitespace-only input the same as an empty textbox.
    url = (url or "").strip()
    if not url:
        return {"Error": "Please input a URL"}, None
    try:
        image = _load_image(url, timeout)
    except Exception as e:  # UI boundary: report any download/decode failure to the user
        return {"Error": f"Failed to load image: {str(e)}"}, None

    labels, bboxes = _detect_objects(image, prompt)

    # Draw bounding boxes on the image, one palette color per box.
    draw = ImageDraw.Draw(image)
    legend = []  # "label: color" entries, in box order
    for idx, (bbox, label) in enumerate(zip(bboxes, labels)):
        x1, y1, x2, y2 = bbox
        color = COLORS[idx % len(COLORS)]  # Cycle through colors
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        legend.append(f"{label}: {color}")
    return "\n".join(legend), image
# Gradio UI: a URL textbox goes in; a legend textbox and the annotated image
# come out. Flagging is turned off.
demo = gr.Interface(
    fn=predict_from_url,
    title="Item Classifier with Bounding Boxes and Legend",
    allow_flagging="never",
    inputs=gr.Textbox(label="Enter Image URL"),
    outputs=[
        gr.Textbox(label="Legend"),
        gr.Image(label="Image with Bounding Boxes"),
    ],
)
# Launch the interface only when this file is executed as a script
# (e.g. `python app.py`), so importing the module -- for tests or tooling --
# does not start a blocking web server.
if __name__ == "__main__":
    demo.launch()