Hzqhssn committed
Commit ce2dc14
1 Parent(s): d4a7e13

initial push

Files changed (3)
  1. .gitignore +2 -0
  2. app.py +69 -0
  3. requirements.txt +10 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ # Virtual environments
+ venv
app.py ADDED
@@ -0,0 +1,69 @@
+ import gradio as gr
+ import requests
+ from PIL import Image, ImageDraw
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from io import BytesIO
+ import torch
+
+ # Set device
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ # Load model and processor
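+ # (trust_remote_code=True is needed because Florence-2 ships its modeling/processing code with the checkpoint)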
+ model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
+
+ # List of colors to cycle through for bounding boxes
+ COLORS = ["red", "blue", "green", "yellow", "purple", "orange", "cyan", "magenta"]
+
+ # Prediction function
+ def predict_from_url(url):
+     prompt = "<OD>"  # Florence-2 task token for object detection
+     if not url:
+         return "Error: Please input a URL", None
+
+     try:
+         # Convert to RGB so preprocessing and drawing behave consistently for palette/RGBA images
+         image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
+     except Exception as e:
+         return f"Error: Failed to load image: {e}", None
+
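+     # Preprocess the prompt and image, then generate the raw detection sequence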
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
+     generated_ids = model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=4096,
+         num_beams=3,
+         do_sample=False
+     )
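+     # Decode the output and parse it into '<OD>' labels and pixel-coordinate bounding boxes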
+     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
+
+     labels = parsed_answer.get('<OD>', {}).get('labels', [])
+     bboxes = parsed_answer.get('<OD>', {}).get('bboxes', [])
+
+     # Draw bounding boxes on the image
+     draw = ImageDraw.Draw(image)
+     legend = []  # Store legend entries
+     for idx, (bbox, label) in enumerate(zip(bboxes, labels)):
+         x1, y1, x2, y2 = bbox
+         color = COLORS[idx % len(COLORS)]  # Cycle through colors
+         draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
+         legend.append(f"{label}: {color}")
+
+     return "\n".join(legend), image
+
+ # Gradio interface
+ demo = gr.Interface(
+     fn=predict_from_url,
+     inputs=gr.Textbox(label="Enter Image URL"),
+     outputs=[
+         gr.Textbox(label="Legend"),  # Output the legend
+         gr.Image(label="Image with Bounding Boxes")  # Output the processed image
+     ],
+     title="Item Classifier with Bounding Boxes and Legend",
+     allow_flagging="never"
+ )
+
+ # Launch the interface
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ transformers
+ torch
+ requests
+ Pillow
+ open_clip_torch
+ ftfy
+ einops
+
+ # This is only needed for local deployment
+ gradio