# Hugging Face Spaces status when this file was captured: Runtime error
import torch | |
import numpy as np | |
from PIL import Image | |
from diffusers import FluxInpaintPipeline | |
from utils.florence import load_florence_model, run_florence_inference, FLORENCE_OPEN_VOCABULARY_DETECTION_TASK | |
from utils.sam import load_sam_image_model, run_sam_inference | |
import gradio as gr | |
import supervision as sv | |
import spaces | |
# ---------------------------------------------------------------------------
# Model loading (runs once, at import time)
# ---------------------------------------------------------------------------
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# FLUX inpainting pipeline, loaded in bfloat16 and then moved to DEVICE.
FLUX_PIPE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
FLUX_PIPE = FLUX_PIPE.to(DEVICE)

# Florence (detection) and SAM (segmentation) models from the project's
# utils package; both loaders receive the same DEVICE.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_MODEL = load_sam_image_model(device=DEVICE)

# ---------------------------------------------------------------------------
# Annotators used to draw detections
# ---------------------------------------------------------------------------
COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)

# Every annotator shares the palette and looks colors up by detection index.
_PALETTE_KWARGS = {
    "color": COLOR_PALETTE,
    "color_lookup": sv.ColorLookup.INDEX,
}

BOX_ANNOTATOR = sv.BoxAnnotator(**_PALETTE_KWARGS)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5,
    **_PALETTE_KWARGS,
)
MASK_ANNOTATOR = sv.MaskAnnotator(**_PALETTE_KWARGS)
def visualize_detections(image, detections):
    """Return a copy of ``image`` with masks, boxes and indexed labels drawn.

    The "Selected Objects" textbox asks the user for detection *indices*, so
    every label is prefixed with the detection's position in ``detections``.
    That makes the number to type visible on the preview.

    Args:
        image: Image to annotate; it is copied first and never modified.
        detections: ``sv.Detections`` to draw.

    Returns:
        The annotated copy of ``image``.
    """
    count = len(detections)

    # NOTE(review): class names are assumed to live under the "class_name"
    # data key when the detector provides them; confirm against the
    # supervision version in use. Bare indices are used when they are absent.
    class_names = detections.data.get("class_name")
    if class_names is not None and len(class_names) == count:
        labels = [f"{index}: {name}" for index, name in enumerate(class_names)]
    else:
        labels = [str(index) for index in range(count)]

    output_image = image.copy()
    output_image = MASK_ANNOTATOR.annotate(output_image, detections)
    output_image = BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image
def detect_objects(image, text_prompt):
    """Detect the objects named by ``text_prompt`` and attach SAM output.

    Florence runs the open-vocabulary detection task on ``image``; its result
    is converted to ``sv.Detections`` and then passed through
    ``run_sam_inference``.

    Args:
        image: Input image; its ``size`` is forwarded as ``resolution_wh``.
        text_prompt: Text passed to Florence as the detection prompt.

    Returns:
        Whatever ``run_sam_inference`` returns for the converted detections.
    """
    # Step 1: Florence detection. Only the second element of the returned
    # pair is used.
    _unused, florence_result = run_florence_inference(
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
        image=image,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        text=text_prompt,
    )

    # Step 2: convert the raw Florence result into supervision detections.
    florence_detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=florence_result,
        resolution_wh=image.size,
    )

    # Step 3: hand the detections to SAM for mask refinement.
    return run_sam_inference(SAM_MODEL, image, florence_detections)
def inpaint_selected_objects(image, detections, selected_indices, inpaint_prompt):
    """Inpaint the union of the selected detection masks with FLUX.

    Args:
        image: PIL image to edit. Its ``size`` (width, height) is reversed to
            get the (height, width) shape of the combined mask.
        detections: Object whose ``mask`` attribute holds one 2-D mask per
            detection, indexable by detection position.
        selected_indices: Iterable of integer detection positions to inpaint.
        inpaint_prompt: Text prompt passed to the FLUX pipeline.

    Returns:
        The first image produced by ``FLUX_PIPE``.

    Raises:
        ValueError: If ``detections`` carries no masks.
        IndexError: If an index is negative or not smaller than the number
            of masks. Negative values are rejected rather than silently
            wrapping around to a different object.
    """
    masks = detections.mask
    if masks is None or len(masks) == 0:
        raise ValueError("No segmentation masks are available to inpaint.")

    # Validate everything before any expensive work is started.
    indices = list(selected_indices)
    total = len(masks)
    for idx in indices:
        if not 0 <= idx < total:
            raise IndexError(
                f"Object index {idx} is out of range; "
                f"valid indices are 0 to {total - 1}."
            )

    # Union the masks as booleans so the result is strictly 0/255 regardless
    # of whether the stored masks are bool or integer-valued.
    combined = np.zeros(image.size[::-1], dtype=bool)
    for idx in indices:
        combined |= np.asarray(masks[idx], dtype=bool)
    mask_image = Image.fromarray(combined.astype(np.uint8) * 255)

    result = FLUX_PIPE(
        prompt=inpaint_prompt,
        image=image,
        mask_image=mask_image,
        num_inference_steps=30,
        strength=0.85,
    ).images[0]
    return result
def process_image(input_image, detection_prompt, inpaint_prompt, selected_objects):
    """Detect objects and, when indices are given, inpaint the selected ones.

    Args:
        input_image: Image to process; must not be ``None``.
        detection_prompt: Text describing the objects to detect.
        inpaint_prompt: Text describing what to paint into the selection.
        selected_objects: Comma-separated detection indices, e.g. ``"0, 2"``.
            Surrounding whitespace and empty entries (such as a trailing
            comma) are ignored. An empty or ``None`` value means "detect
            only".

    Returns:
        A ``(detected_image, inpainted_image)`` pair. ``inpainted_image`` is
        ``None`` when no indices were selected.

    Raises:
        ValueError: If ``input_image`` is ``None`` or ``selected_objects``
            contains a non-integer entry.
    """
    if input_image is None:
        raise ValueError("An input image is required.")

    # Parse the selection before running the models so malformed input fails
    # fast instead of after a full detection pass.
    tokens = [token.strip() for token in (selected_objects or "").split(",")]
    try:
        selected_indices = [int(token) for token in tokens if token]
    except ValueError as err:
        raise ValueError(
            "Selected objects must be comma-separated integers, "
            f"got {selected_objects!r}."
        ) from err

    detections = detect_objects(input_image, detection_prompt)
    # Visualize detected objects
    detected_image = visualize_detections(input_image, detections)

    if not selected_indices:
        return detected_image, None

    inpainted_image = inpaint_selected_objects(
        input_image, detections, selected_indices, inpaint_prompt
    )
    return detected_image, inpainted_image
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
def _run_detection_only(img, prompt):
    """Run ``process_image`` with no selection and keep only the preview."""
    return process_image(img, prompt, "", "")[0]


# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed. In particular, `output_image` is assumed to sit below the
# two-column row rather than inside the second column. Confirm against the
# deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("# Object Detection and Inpainting with FLUX, Florence, and SAM")

    with gr.Row():
        # Left column: image upload and detection controls.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            detection_prompt = gr.Textbox(
                label="Detection Prompt",
                placeholder="Enter objects to detect",
            )
            detect_button = gr.Button("Detect Objects")

        # Right column: detection preview and inpainting controls.
        with gr.Column():
            detected_image = gr.Image(type="pil", label="Detected Objects")
            selected_objects = gr.Textbox(
                label="Selected Objects",
                placeholder="Enter indices of objects to inpaint (comma-separated)",
            )
            inpaint_prompt = gr.Textbox(
                label="Inpainting Prompt",
                placeholder="Describe what to inpaint",
            )
            inpaint_button = gr.Button("Inpaint Selected Objects")

    output_image = gr.Image(type="pil", label="Inpainted Result")

    # "Detect Objects" refreshes only the annotated preview.
    detect_button.click(
        fn=_run_detection_only,
        inputs=[input_image, detection_prompt],
        outputs=detected_image,
    )

    # "Inpaint Selected Objects" re-runs detection, then inpaints.
    inpaint_button.click(
        fn=process_image,
        inputs=[input_image, detection_prompt, inpaint_prompt, selected_objects],
        outputs=[detected_image, output_image],
    )

demo.launch(debug=False, show_error=True)