Commit 9aecc37 · ighoshsubho committed · 0 parent(s)

Florence sam flux first commit
Files changed:
- .gitignore +3 -0
- README.md +12 -0
- app.py +121 -0
- requirements.txt +13 -0
- utils/florence.py +58 -0
- utils/sam.py +45 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
/venv
/.idea
/tmp
README.md
ADDED
@@ -0,0 +1,12 @@
---
title: Florence2 + SAM2 + FLUX
emoji: 🔥
colorFrom: purple
colorTo: green
sdk: gradio
sdk_version: 4.40.0
app_file: app.py
pinned: false
license: apache-2.0
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,121 @@
import torch
import numpy as np
from PIL import Image
from diffusers import FluxInpaintPipeline
from utils.florence import load_florence_model, run_florence_inference, FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
import gradio as gr
import supervision as sv

# Load models
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLUX_PIPE = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(
    DEVICE)
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_MODEL = load_sam_image_model(device=DEVICE)

COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5
)
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX
)


def visualize_detections(image, detections):
    output_image = image.copy()
    output_image = MASK_ANNOTATOR.annotate(output_image, detections)
    output_image = BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
    return output_image


def detect_objects(image, text_prompt):
    # Use Florence for object detection
    _, result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_prompt
    )
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=result,
        resolution_wh=image.size
    )

    # Use SAM to refine masks
    detections = run_sam_inference(SAM_MODEL, image, detections)
    return detections


def inpaint_selected_objects(image, detections, selected_indices, inpaint_prompt):
    mask = np.zeros(image.size[::-1], dtype=np.uint8)
    for idx in selected_indices:
        mask |= detections.mask[idx]

    mask_image = Image.fromarray(mask * 255)

    result = FLUX_PIPE(
        prompt=inpaint_prompt,
        image=image,
        mask_image=mask_image,
        num_inference_steps=30,
        strength=0.85,
    ).images[0]

    return result


def process_image(input_image, detection_prompt, inpaint_prompt, selected_objects):
    detections = detect_objects(input_image, detection_prompt)

    # Visualize detected objects
    detected_image = visualize_detections(input_image, detections)

    if selected_objects:
        selected_indices = [int(idx) for idx in selected_objects.split(',')]
        inpainted_image = inpaint_selected_objects(input_image, detections, selected_indices, inpaint_prompt)
        return detected_image, inpainted_image
    else:
        return detected_image, None


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Object Detection and Inpainting with FLUX, Florence, and SAM")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            detection_prompt = gr.Textbox(label="Detection Prompt", placeholder="Enter objects to detect")
            detect_button = gr.Button("Detect Objects")
        with gr.Column():
            detected_image = gr.Image(type="pil", label="Detected Objects")
            selected_objects = gr.Textbox(label="Selected Objects",
                                          placeholder="Enter indices of objects to inpaint (comma-separated)")
            inpaint_prompt = gr.Textbox(label="Inpainting Prompt", placeholder="Describe what to inpaint")
            inpaint_button = gr.Button("Inpaint Selected Objects")
            output_image = gr.Image(type="pil", label="Inpainted Result")

    detect_button.click(
        fn=lambda img, prompt: process_image(img, prompt, "", "")[0],
        inputs=[input_image, detection_prompt],
        outputs=detected_image
    )
    inpaint_button.click(
        fn=process_image,
        inputs=[input_image, detection_prompt, inpaint_prompt, selected_objects],
        outputs=[detected_image, output_image]
    )

demo.launch(debug=False, show_error=True)
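The glue between detection and generation in app.py is the mask union in inpaint_selected_objects: the boolean SAM masks of the selected indices are ORed into a single binary mask, which FLUX then treats as the region to repaint. Below is a minimal sketch of just that step, using made-up 4x4 toy masks in place of real detections.mask output (not part of the commit):

import numpy as np
from PIL import Image

h, w = 4, 4
object_masks = np.zeros((2, h, w), dtype=bool)  # stand-in for detections.mask
object_masks[0, 0:2, 0:2] = True                # toy object 0
object_masks[1, 2:4, 2:4] = True                # toy object 1

selected_indices = [0, 1]
mask = np.zeros((h, w), dtype=np.uint8)
for idx in selected_indices:
    mask |= object_masks[idx]                   # OR each boolean mask into the union

mask_image = Image.fromarray(mask * 255)        # white pixels = region FLUX repaints
print(mask)                                     # union of both toy masks
print(mask_image.mode, mask_image.size)         # L (4, 4)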
requirements.txt
ADDED
@@ -0,0 +1,13 @@
tqdm
einops
spaces
timm
transformers
samv2
gradio
supervision
opencv-python
pytest
torch
numpy
diffusers
utils/florence.py
ADDED
@@ -0,0 +1,58 @@
import os
from typing import Union, Any, Tuple, Dict
from unittest.mock import patch

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports

FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'


def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    """Workaround for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports


def load_florence_model(
    device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
) -> Tuple[Any, Any]:
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint, trust_remote_code=True).to(device).eval()
        processor = AutoProcessor.from_pretrained(
            checkpoint, trust_remote_code=True)
        return model, processor


def run_florence_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image,
    task: str,
    text: str = ""
) -> Tuple[str, Dict]:
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False)[0]
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size)
    return generated_text, response
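run_florence_inference concatenates the task token and the free-text prompt into a single prompt string, generates with beam search, and returns both the raw decoded text and the processor's post-processed result dict. A stand-alone usage sketch; the image path and detection text below are placeholders, not part of the commit:

import torch
from PIL import Image

from utils.florence import (
    load_florence_model,
    run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, processor = load_florence_model(device=device)

image = Image.open("example.jpg")  # placeholder path
_, result = run_florence_inference(
    model=model,
    processor=processor,
    device=device,
    image=image,
    task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    text="a dog",  # placeholder detection text
)
# result is keyed by the task token, e.g. result["<OPEN_VOCABULARY_DETECTION>"]
# with the detected boxes and labels.
print(result)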
utils/sam.py
ADDED
@@ -0,0 +1,45 @@
from typing import Any

import numpy as np
import supervision as sv
import torch
from PIL import Image
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor

SAM_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
SAM_CONFIG = "sam2_hiera_s.yaml"


def load_sam_image_model(
    device: torch.device,
    config: str = SAM_CONFIG,
    checkpoint: str = SAM_CHECKPOINT
) -> SAM2ImagePredictor:
    model = build_sam2(config, checkpoint, device=device)
    return SAM2ImagePredictor(sam_model=model)


def load_sam_video_model(
    device: torch.device,
    config: str = SAM_CONFIG,
    checkpoint: str = SAM_CHECKPOINT
) -> Any:
    return build_sam2_video_predictor(config, checkpoint, device=device)


def run_sam_inference(
    model: Any,
    image: Image,
    detections: sv.Detections
) -> sv.Detections:
    image = np.array(image.convert("RGB"))
    model.set_image(image)
    mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False)

    # dirty fix; remove this later
    if len(mask.shape) == 4:
        mask = np.squeeze(mask)

    detections.mask = mask.astype(bool)
    return detections
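run_sam_inference prompts SAM 2 with the Florence boxes in detections.xyxy and overwrites detections.mask with one boolean mask per box. A stand-alone usage sketch; the checkpoint download, image path, and box coordinates below are assumptions, not part of the commit:

import numpy as np
import supervision as sv
import torch
from PIL import Image

from utils.sam import load_sam_image_model, run_sam_inference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam_model = load_sam_image_model(device=device)  # expects the checkpoint at SAM_CHECKPOINT

image = Image.open("example.jpg")  # placeholder path
# In the app these boxes come from Florence-2; here they are made up.
detections = sv.Detections(
    xyxy=np.array([[10.0, 10.0, 200.0, 200.0]]),
    class_id=np.array([0]),
)

detections = run_sam_inference(sam_model, image, detections)
print(detections.mask.shape)  # (1, H, W): one boolean mask per input box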