# Scraped page header (not program code), kept for provenance:
#   uploader: ighoshsubho
#   commit message: "Bug fix spaces"
#   commit: 566b641 (verified)
import torch
import numpy as np
from PIL import Image
from diffusers import FluxInpaintPipeline
from utils.florence import load_florence_model, run_florence_inference, FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
import gradio as gr
import supervision as sv
import spaces
# Load models
# Everything in this section runs once, when the module is imported.

# Use CUDA when torch reports it as available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# FLUX inpainting pipeline: loaded with bfloat16 weights, then moved to DEVICE.
FLUX_PIPE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
FLUX_PIPE = FLUX_PIPE.to(DEVICE)

# Detection (Florence) and segmentation (SAM) models, both placed on DEVICE.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_MODEL = load_sam_image_model(device=DEVICE)

# One palette shared by every annotator; colors are looked up by detection index.
COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
_PALETTE_KWARGS = {
    "color": COLOR_PALETTE,
    "color_lookup": sv.ColorLookup.INDEX,
}

BOX_ANNOTATOR = sv.BoxAnnotator(**_PALETTE_KWARGS)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5,
    **_PALETTE_KWARGS,
)
MASK_ANNOTATOR = sv.MaskAnnotator(**_PALETTE_KWARGS)
def visualize_detections(image, detections):
    """Return a copy of ``image`` annotated with the given detections.

    Drawing starts from ``image.copy()``, and the annotators are applied in
    a fixed order: masks first, then boxes, then labels.
    """
    annotated = image.copy()
    for annotator in (MASK_ANNOTATOR, BOX_ANNOTATOR, LABEL_ANNOTATOR):
        annotated = annotator.annotate(annotated, detections)
    return annotated
def detect_objects(image, text_prompt):
    """Detect the objects described by ``text_prompt`` in ``image``.

    Florence is run with the open-vocabulary detection task, its result is
    converted into ``sv.Detections`` (using ``image.size`` as the
    resolution), and the detections are then passed through
    ``run_sam_inference`` before being returned.
    """
    # Step 1: Florence open-vocabulary detection; only the second element
    # of the returned pair is used.
    _unused, florence_result = run_florence_inference(
        image=image,
        text=text_prompt,
        task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
        model=FLORENCE_MODEL,
        processor=FLORENCE_PROCESSOR,
        device=DEVICE,
    )

    # Step 2: convert the raw Florence result into supervision detections.
    florence_detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=florence_result,
        resolution_wh=image.size,
    )

    # Step 3: let SAM refine the detections into masks.
    return run_sam_inference(SAM_MODEL, image, florence_detections)
def inpaint_selected_objects(image, detections, selected_indices, inpaint_prompt):
    """Inpaint the union of the selected detection masks with FLUX.

    Args:
        image: PIL image to inpaint; ``image.size`` fixes the mask size.
        detections: Object whose ``mask`` attribute holds one mask per
            detected object, each matching the image's (height, width).
        selected_indices: Iterable of integer indices into ``detections.mask``.
        inpaint_prompt: Text prompt passed to the inpainting pipeline.

    Returns:
        The first image produced by ``FLUX_PIPE``.

    Raises:
        ValueError: If objects are selected but no masks are available.
        IndexError: If a selected index does not refer to a detected object.
    """
    # PIL reports (width, height); the mask array is (height, width).
    selection = np.zeros(image.size[::-1], dtype=bool)

    indices = list(selected_indices)
    if indices:
        masks = detections.mask
        mask_count = 0 if masks is None else len(masks)
        if mask_count == 0:
            raise ValueError(
                "No object masks are available to inpaint; "
                "run detection with a prompt that matches something first."
            )
        for idx in indices:
            # Same index range numpy accepts, but with an explicit message.
            if not -mask_count <= idx < mask_count:
                raise IndexError(
                    f"Object index {idx} is out of range for "
                    f"{mask_count} detected object(s)."
                )
            # Accumulate in bool so the union works whatever dtype the masks use.
            selection |= np.asarray(masks[idx], dtype=bool)

    # Selected pixels become 255 and everything else 0 in an 8-bit mask image.
    mask_image = Image.fromarray(selection.astype(np.uint8) * 255)

    # NOTE(review): no height/width is passed, so the pipeline's own default
    # output size applies — confirm that is intended for non-square inputs.
    result = FLUX_PIPE(
        prompt=inpaint_prompt,
        image=image,
        mask_image=mask_image,
        num_inference_steps=30,
        strength=0.85,
    ).images[0]
    return result
def _parse_selected_indices(selected_objects):
    """Parse a comma-separated index string such as ``"0, 2"`` into ints.

    Blank entries (trailing commas, doubled commas, whitespace-only input)
    are ignored. Returns an empty list when nothing is selected.

    Raises:
        ValueError: If an entry is not an integer.
    """
    if not selected_objects:
        return []

    tokens = [token.strip() for token in selected_objects.split(',')]
    indices = []
    for token in tokens:
        if not token:
            continue
        try:
            indices.append(int(token))
        except ValueError:
            raise ValueError(
                f"Invalid object index {token!r}: expected comma-separated "
                "integers such as '0,2'."
            ) from None
    return indices


@spaces.GPU(duration=200)
def process_image(input_image, detection_prompt, inpaint_prompt, selected_objects):
    """Detect objects in an image and optionally inpaint the selected ones.

    Args:
        input_image: Image to process.
        detection_prompt: Text describing the objects to detect.
        inpaint_prompt: Text prompt for the inpainting step.
        selected_objects: Comma-separated indices of detected objects to
            inpaint; empty (or only blank entries) to skip inpainting.

    Returns:
        A ``(detected_image, inpainted_image)`` pair; ``inpainted_image`` is
        ``None`` when no objects were selected.

    Raises:
        ValueError: If no image was provided or an index is not an integer.
    """
    if input_image is None:
        # Fail early with a readable message instead of inside model code.
        raise ValueError("Please provide an input image before running detection.")

    # Validate the selection before spending time on detection.
    selected_indices = _parse_selected_indices(selected_objects)

    detections = detect_objects(input_image, detection_prompt)
    # Visualize detected objects
    detected_image = visualize_detections(input_image, detections)

    if not selected_indices:
        return detected_image, None

    inpainted_image = inpaint_selected_objects(
        input_image, detections, selected_indices, inpaint_prompt
    )
    return detected_image, inpainted_image
# Gradio interface
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been lost — confirm the column layout (in particular
# where output_image sits) against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# Object Detection and Inpainting with FLUX, Florence, and SAM")

    with gr.Row():
        # First column: source image and detection controls.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            detection_prompt = gr.Textbox(
                label="Detection Prompt",
                placeholder="Enter objects to detect",
            )
            detect_button = gr.Button("Detect Objects")

        # Second column: detection preview, inpainting controls and result.
        with gr.Column():
            detected_image = gr.Image(label="Detected Objects", type="pil")
            selected_objects = gr.Textbox(
                label="Selected Objects",
                placeholder="Enter indices of objects to inpaint (comma-separated)",
            )
            inpaint_prompt = gr.Textbox(
                label="Inpainting Prompt",
                placeholder="Describe what to inpaint",
            )
            inpaint_button = gr.Button("Inpaint Selected Objects")
            output_image = gr.Image(label="Inpainted Result", type="pil")

    # Both buttons take the image and the detection prompt as inputs.
    detection_inputs = [input_image, detection_prompt]

    # Detection only: call process_image with empty inpainting arguments and
    # keep just the first element of its result (the annotated image).
    detect_button.click(
        fn=lambda img, prompt: process_image(img, prompt, "", "")[0],
        inputs=detection_inputs,
        outputs=detected_image,
    )

    # Full run: detection followed by inpainting of the selected objects.
    inpaint_button.click(
        fn=process_image,
        inputs=[*detection_inputs, inpaint_prompt, selected_objects],
        outputs=[detected_image, output_image],
    )

demo.launch(show_error=True, debug=False)