Fabrice-TIERCELIN's picture
This PR allows uploading an existing mask
208244e verified
raw
history blame
11.9 kB
import io
import random
from functools import partial
from typing import Tuple, Optional

# NOTE(review): `spaces` is kept ahead of `torch`/`diffusers` as in the
# original ordering; ZeroGPU expects to be imported before CUDA-using
# packages — verify before reordering.
import cv2
import gradio as gr
import numpy as np
import requests
import spaces
import torch
from PIL import Image, ImageFilter
from diffusers import FluxInpaintPipeline
from gradio_client import Client, handle_file
MARKDOWN = """
# FLUX.1 Inpainting 🔥
Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for
creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
for taking it to the next level by enabling inpainting with the FLUX.
"""
# Upper bound for the seed slider: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Target length of the longer image side before inpainting.
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Loaded once at import time and shared by every request.
PIPE = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
# Shared client for the text-prompted masking Space; used as the fallback
# when no per-session client can be built.
CLIENT = Client("SkalskiP/florence-sam-masking")
# Seconds to wait for each example-image download before giving up.
EXAMPLE_DOWNLOAD_TIMEOUT = 30


def _open_remote_image(url: str, timeout: float = EXAMPLE_DOWNLOAD_TIMEOUT) -> Image.Image:
    """
    Download ``url`` and open it as a PIL image.

    ``requests`` has no default timeout, so without one a stalled download
    would block app start-up indefinitely. The body is read fully into memory
    (rather than handing ``response.raw`` to PIL), which lets ``requests``
    undo any ``Content-Encoding`` compression.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


# Each row follows the order of the ``inputs`` list passed to ``gr.Examples``:
# editor value, inpainting prompt, masking prompt, mask inflation, mask blur,
# seed, randomize seed, strength, inference steps, uploaded mask.
EXAMPLES = [
    [
        {
            "background": _open_remote_image("https://media.roboflow.com/spaces/doge-2-image.png"),
            "layers": [_open_remote_image("https://media.roboflow.com/spaces/doge-2-mask-2-removebg.png")],
            "composite": _open_remote_image("https://media.roboflow.com/spaces/doge-2-composite-2.png"),
        },
        "little lion",
        "",
        5,
        5,
        42,
        False,
        0.85,
        20,
        None
    ],
    [
        # No drawn layer: the mask is generated from the "eyes" masking prompt.
        {
            "background": _open_remote_image("https://media.roboflow.com/spaces/doge-5.jpeg"),
            "layers": None,
            "composite": None
        },
        "big blue eyes",
        "eyes",
        10,
        5,
        42,
        False,
        0.9,
        20,
        None
    ]
]
def calculate_image_dimensions_for_flux(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE,
    multiple: int = 32
) -> Tuple[int, int]:
    """
    Compute an aspect-preserving (width, height) suitable for FLUX.

    The longer side is scaled to ``maximum_dimension`` (small images are
    upscaled, large ones downscaled), then both sides are rounded down to a
    multiple of ``multiple``.

    Args:
        original_resolution_wh (Tuple[int, int]): Source ``(width, height)``.
        maximum_dimension (int): Target length of the longer side.
        multiple (int): Both output sides are rounded down to a multiple of
            this value, but never below one ``multiple``.

    Returns:
        Tuple[int, int]: The new ``(width, height)``.
    """
    width, height = original_resolution_wh
    longest_side = max(width, height)
    # Exact integer arithmetic: the float form ``maximum_dimension / width *
    # width`` can come out just below ``maximum_dimension`` (e.g. width=49
    # gives 1023.999...), which truncation plus rounding turns into 992.
    new_width = width * maximum_dimension // longest_side
    new_height = height * maximum_dimension // longest_side
    # Round down to the required multiple, but keep very thin images from
    # collapsing to a zero-sized side that the pipeline cannot render.
    new_width = max(multiple, new_width - new_width % multiple)
    new_height = max(multiple, new_height - new_height % multiple)
    return new_width, new_height
def is_mask_empty(image: Image.Image) -> bool:
    """
    Return True when the mask has no non-black pixels.

    The image is converted to grayscale ("L") first, so a pixel counts as
    empty when its grayscale value is 0.

    Args:
        image (Image.Image): Mask image in any PIL mode.

    Returns:
        bool: True if every grayscale pixel is 0 (including a 0x0 image).
    """
    # getbbox() scans in C and returns None when every pixel is zero; this
    # avoids materializing a Python list of ~1M pixels for a 1024x1024 mask.
    return image.convert("L").getbbox() is None
def process_mask(
    mask: Image.Image,
    mask_inflation: Optional[int] = None,
    mask_blur: Optional[int] = None
) -> Image.Image:
    """
    Grow and/or soften the bright (white) regions of a mask.

    Args:
        mask (Image.Image): The mask to post-process.
        mask_inflation (Optional[int]): Side length, in pixels, of the square
            dilation kernel; skipped when falsy or not positive.
        mask_blur (Optional[int]): Gaussian blur radius; skipped when falsy
            or not positive.

    Returns:
        Image.Image: The dilated and/or blurred mask, or the input unchanged
        when both steps are skipped.
    """
    result = mask
    if mask_inflation and mask_inflation > 0:
        # Dilation takes the max over the kernel, so white areas expand.
        dilation_kernel = np.ones((mask_inflation, mask_inflation), np.uint8)
        dilated = cv2.dilate(np.array(result), dilation_kernel, iterations=1)
        result = Image.fromarray(dilated)
    if mask_blur and mask_blur > 0:
        result = result.filter(ImageFilter.GaussianBlur(radius=mask_blur))
    return result
def set_client_for_session(request: gr.Request):
    """
    Build a masking-Space client for the current session.

    When the incoming request carries an ``x-ip-token`` header, it is
    forwarded to ``SkalskiP/florence-sam-masking`` as ``X-IP-Token``. This is
    presumably so upstream quota is attributed to the visiting user (verify).
    Otherwise, or if building the client fails, the shared ``CLIENT`` is
    returned, so this handler never raises.

    Args:
        request (gr.Request): The Gradio request for the session being loaded.

    Returns:
        Client: A per-session client, or the shared ``CLIENT`` fallback.
    """
    headers = getattr(request, "headers", None)
    x_ip_token = headers.get('x-ip-token') if headers is not None else None
    if x_ip_token is None:
        return CLIENT
    try:
        return Client("SkalskiP/florence-sam-masking", headers={"X-IP-Token": x_ip_token})
    except Exception as error:
        # Deliberate best-effort: keep the app usable with the shared client,
        # but leave a trace instead of failing silently.
        print(f"Could not create per-session client, using shared client: {error}")
        return CLIENT
@spaces.GPU(duration=50)
def run_flux(
    image: Image.Image,
    mask: Image.Image,
    prompt: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    resolution_wh: Tuple[int, int],
) -> Image.Image:
    """
    Inpaint the masked area of ``image`` with the shared FLUX pipeline.

    Args:
        image (Image.Image): Source image, already resized to ``resolution_wh``.
        mask (Image.Image): Inpainting mask, passed to the pipeline as
            ``mask_image``.
        prompt (str): Text prompt describing the inpainted content.
        seed_slicer (int): Seed for the generator; ignored when
            ``randomize_seed_checkbox`` is set.
        randomize_seed_checkbox (bool): Draw a fresh seed in
            ``[0, MAX_SEED]`` instead of using ``seed_slicer``.
        strength_slider (float): Pipeline ``strength``.
        num_inference_steps_slider (int): Number of denoising steps.
        resolution_wh (Tuple[int, int]): Output ``(width, height)``.

    Returns:
        Image.Image: The first image produced by the pipeline.
    """
    print("Running FLUX...")
    target_width, target_height = resolution_wh
    seed = random.randint(0, MAX_SEED) if randomize_seed_checkbox else seed_slicer
    pipeline_output = PIPE(
        prompt=prompt,
        image=image,
        mask_image=mask,
        width=target_width,
        height=target_height,
        strength=strength_slider,
        generator=torch.Generator().manual_seed(seed),
        num_inference_steps=num_inference_steps_slider,
    )
    return pipeline_output.images[0]
def process(
    client,
    input_image_editor: dict,
    inpainting_prompt_text: str,
    masking_prompt_text: str,
    mask_inflation_slider: int,
    mask_blur_slider: int,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    uploaded_mask: Image.Image
):
    """
    Gradio handler: validate inputs, resolve the mask, and run FLUX inpainting.

    The mask is taken from ``uploaded_mask`` if given, otherwise from the first
    layer of the image editor. If that mask is empty, one is generated by the
    masking Space from ``masking_prompt_text``. Providing both a non-empty mask
    and a masking prompt is rejected.

    Args:
        client: Masking-Space client used for text-prompted mask generation.
        input_image_editor (dict): ImageEditor value with ``'background'`` and
            ``'layers'`` entries (file paths, as the editor uses
            ``type='filepath'``).
        inpainting_prompt_text (str): Prompt for the inpainted content.
        masking_prompt_text (str): Prompt used to generate a mask when none
            is drawn or uploaded.
        mask_inflation_slider (int): Dilation kernel size for ``process_mask``.
        mask_blur_slider (int): Blur radius for ``process_mask``.
        seed_slicer (int): Seed forwarded to ``run_flux``.
        randomize_seed_checkbox (bool): Forwarded to ``run_flux``.
        strength_slider (float): Forwarded to ``run_flux``.
        num_inference_steps_slider (int): Forwarded to ``run_flux``.
        uploaded_mask (Image.Image): Optional pre-made mask; takes precedence
            over the drawn layer.

    Returns:
        tuple: ``(generated_image, mask_used)``, or ``(None, None)`` when the
        inputs are invalid (a ``gr.Info`` message explains why).
    """
    if not inpainting_prompt_text:
        gr.Info("Please enter inpainting text prompt.")
        return None, None

    # Check for an image *before* opening it: Image.open(None) raises, so a
    # check on the opened image could never show this message.
    editor_value = input_image_editor or {}
    image_path = editor_value.get('background')
    if not image_path:
        gr.Info("Please upload an image.")
        return None, None
    image = Image.open(image_path)

    if uploaded_mask is not None:
        mask = uploaded_mask
    else:
        layers = editor_value.get('layers') or []
        # With no drawn layer, fall back to an all-black (empty) mask so the
        # masking-prompt path below can still produce one.
        mask = Image.open(layers[0]) if layers else Image.new("L", image.size, 0)

    # Evaluate once; it is needed by all three checks below.
    mask_is_empty = is_mask_empty(mask)
    if mask_is_empty and not masking_prompt_text:
        gr.Info("Please draw a mask, upload a mask or enter a masking prompt.")
        return None, None
    if not mask_is_empty and masking_prompt_text:
        gr.Info("Both mask and masking prompt are provided. Please provide only one.")
        return None, None

    if mask_is_empty:
        print("Generating mask...")
        generated_mask_path = client.predict(
            image_input=handle_file(image_path),
            text_input=masking_prompt_text,
            api_name="/process_image")
        mask = Image.open(generated_mask_path)
        print("Mask generated.")

    width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
    image = image.resize((width, height), Image.LANCZOS)
    mask = mask.resize((width, height), Image.LANCZOS)
    mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
    image = run_flux(
        image=image,
        mask=mask,
        prompt=inpainting_prompt_text,
        seed_slicer=seed_slicer,
        randomize_seed_checkbox=randomize_seed_checkbox,
        strength_slider=strength_slider,
        num_inference_steps_slider=num_inference_steps_slider,
        resolution_wh=(width, height)
    )
    return image, mask
# Bind the shared client as the *first positional* argument. Gradio calls the
# examples fn with the input values positionally; binding ``client=`` by
# keyword would make the first input land on ``client`` as well and raise
# "got multiple values for argument 'client'".
process_example = partial(process, CLIENT)
with gr.Blocks() as demo:
    # Per-session masking client, filled in by ``demo.load`` below.
    client_component = gr.State()
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='filepath',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
            with gr.Row():
                inpainting_prompt_text_component = gr.Text(
                    label="Inpainting prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate inpainting",
                    container=False,
                )
                submit_button_component = gr.Button(
                    value='Submit', variant='primary', scale=0)
            with gr.Accordion("Upload a mask", open=False):
                # The uploaded mask ends up as FluxInpaintPipeline's
                # ``mask_image`` (same convention as the white brush above):
                # white pixels are repainted, black pixels are kept.
                uploaded_mask_component = gr.Image(
                    label="Already made mask (white pixels will be redrawn, black pixels will be preserved)",
                    sources=["upload"],
                    type="pil")
            with gr.Accordion("Advanced Settings", open=False):
                masking_prompt_text_component = gr.Text(
                    label="Masking prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter text to generate masking",
                    container=False,
                )
                with gr.Row():
                    mask_inflation_slider_component = gr.Slider(
                        label="Mask inflation",
                        info="Adjusts the amount of mask edge expansion before "
                             "inpainting.",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=5,
                    )
                    mask_blur_slider_component = gr.Slider(
                        label="Mask blur",
                        info="Controls the intensity of the Gaussian blur applied to "
                             "the mask edges.",
                        minimum=0,
                        maximum=20,
                        step=1,
                        value=5,
                    )
                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)
                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )
                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        # The original sentence stopped at "...at the".
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
        with gr.Column():
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Generated image', format="png")
            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png")
    # Example inputs map one-to-one onto this list; ``client`` is pre-bound
    # in ``process_example``.
    gr.Examples(
        fn=process_example,
        examples=EXAMPLES,
        inputs=[
            input_image_editor_component,
            inpainting_prompt_text_component,
            masking_prompt_text_component,
            mask_inflation_slider_component,
            mask_blur_slider_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component,
            uploaded_mask_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ],
        run_on_click=False
    )
    submit_button_component.click(
        fn=process,
        inputs=[
            client_component,
            input_image_editor_component,
            inpainting_prompt_text_component,
            masking_prompt_text_component,
            mask_inflation_slider_component,
            mask_blur_slider_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component,
            uploaded_mask_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ]
    )
    # Build the per-session masking client when the page loads.
    demo.load(set_client_for_session, None, client_component)
demo.launch(debug=False, show_error=True)