"""Gradio demo: object detection on an image fetched from a URL.

The image is downloaded, passed through Microsoft's Florence-2 model with the
object-detection task prompt, and returned with one coloured rectangle per
detected object plus a text legend mapping each label to its box colour.
"""

from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForCausalLM

# Set device: use the first GPU with half precision when CUDA is available,
# otherwise fall back to CPU with full precision.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

MODEL_ID = "microsoft/Florence-2-large"

# Load model and processor (done once at import time; both need
# trust_remote_code because Florence-2 ships its own modelling code).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# List of colors to cycle through for bounding boxes
COLORS = ["red", "blue", "green", "yellow", "purple", "orange", "cyan", "magenta"]

# Florence-2 selects its task through a special prompt token. "<OD>" is the
# object-detection task, whose parsed result carries "bboxes" and "labels".
# The same string is also the key of the parsed result dictionary.
TASK_PROMPT = "<OD>"

# Upper bound (seconds) on the image download so a stalled server cannot hang
# the request handler indefinitely.
REQUEST_TIMEOUT_SECONDS = 30


# Prediction function
def predict_from_url(url):
    """Detect objects in the image at *url* and draw their bounding boxes.

    Args:
        url: Address of the image to download. Blank input is rejected.

    Returns:
        A ``(legend, image)`` tuple. On success ``legend`` is a string with
        one ``"<label>: <color>"`` line per detection (empty when nothing was
        detected) and ``image`` is the RGB PIL image with the boxes drawn on
        it. On failure ``legend`` is an ``"Error: ..."`` message and ``image``
        is ``None``.
    """
    if not url or not url.strip():
        return "Error: Please input a URL", None
    url = url.strip()

    # NOTE(review): the URL is user-supplied and fetched server-side; restrict
    # allowed hosts/schemes if this is deployed where SSRF matters.
    try:
        response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
        # Fail early on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # Normalise to RGB: palette, greyscale and alpha images would
        # otherwise reach the processor and the colour drawing in their
        # original mode.
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        # Deliberately broad: this is the UI boundary and any download or
        # decode failure should surface as a message, not a crash.
        return f"Error: Failed to load image: {e}", None

    inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(
        device, torch_dtype
    )
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=4096,
        num_beams=3,
        do_sample=False,
    )
    # Special tokens are kept because the post-processor parses the location
    # tokens out of the raw generated text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=TASK_PROMPT, image_size=(image.width, image.height)
    )

    detections = parsed_answer.get(TASK_PROMPT, {})
    labels = detections.get("labels", [])
    bboxes = detections.get("bboxes", [])

    # Draw bounding boxes on the image
    draw = ImageDraw.Draw(image)
    legend = []  # Store legend entries
    for idx, (bbox, label) in enumerate(zip(bboxes, labels)):
        x1, y1, x2, y2 = bbox
        color = COLORS[idx % len(COLORS)]  # Cycle through colors
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        legend.append(f"{label}: {color}")

    return "\n".join(legend), image


# Gradio interface
demo = gr.Interface(
    fn=predict_from_url,
    inputs=gr.Textbox(label="Enter Image URL"),
    outputs=[
        gr.Textbox(label="Legend"),  # Output the legend
        gr.Image(label="Image with Bounding Boxes"),  # Output the processed image
    ],
    title="Item Classifier with Bounding Boxes and Legend",
    allow_flagging="never",
)

# Launch the interface
demo.launch()