import io
import os
from argparse import ArgumentParser
from functools import reduce
from typing import List, Optional

import numpy as np
import torch
from PIL import Image

import gradio as gr

from transformers import DetrConfig, DetrFeatureExtractor, DetrForSegmentation
from transformers.models.detr.feature_extraction_detr import rgb_to_id

from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline

parser = ArgumentParser()
parser.add_argument('--disable-cuda', action='store_true')
parser.add_argument('--attention-slicing', action='store_true')
args = parser.parse_args()

auth_token = os.environ.get("READ_TOKEN")
try_cuda = not args.disable_cuda

# Bare `torch.inference_mode()` / `torch.no_grad()` calls are no-ops: they
# merely construct context managers that are immediately discarded. Gradient
# tracking is instead disabled where inference happens (`torch.no_grad()` in
# fn_segmentation; the diffusers pipeline disables it internally).


def get_device(try_cuda=True):
    return torch.device('cuda' if try_cuda and torch.cuda.is_available() else 'cpu')


device = get_device(try_cuda=try_cuda)


def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
    feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
    model = DetrForSegmentation.from_pretrained(model_name)
    cfg = DetrConfig.from_pretrained(model_name)

    return feature_extractor, model, cfg
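

# The 'fp16' revision hosts half-precision weights; `torch_dtype` sets the
# dtype they are loaded as, so CPU runs get float32 (half precision is slow
# and patchily supported on CPU).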
def load_diffusion_pipeline(model_name: str = 'stabilityai/stable-diffusion-2-inpainting'):
    return StableDiffusionInpaintPipeline.from_pretrained(
        model_name,
        revision='fp16',
        torch_dtype=torch.float16 if try_cuda and torch.cuda.is_available() else torch.float32,
        use_auth_token=auth_token
    )
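

# Morphological ops via pooling: max_pool2d acts as dilation, and erosion
# (min-pooling) uses the identity min(x) = -max(-x). With stride 1 and
# padding (k - 1) // 2, an odd kernel size preserves spatial dimensions.
# Illustrative check of the erosion behavior (assumes a lone "on" pixel):
#   x = torch.zeros(1, 1, 5, 5); x[0, 0, 2, 2] = 1.0
#   min_pool(x, 3)  # the isolated pixel is eroded away -> all zeros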
def min_pool(x: torch.Tensor, kernel_size: int):
    pad_size = (kernel_size - 1) // 2
    return -torch.nn.functional.max_pool2d(-x, kernel_size, (1, 1), padding=pad_size)


def max_pool(x: torch.Tensor, kernel_size: int):
    pad_size = (kernel_size - 1) // 2
    return torch.nn.functional.max_pool2d(x, kernel_size, (1, 1), padding=pad_size)
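

# clean_mask runs a min-pool (erosion) to drop speckle noise, then a larger
# max-pool (dilation) to grow the mask outward so the inpainting region
# comfortably covers the selected segments. Kernel sizes come from the
# "Mask Denoising" and "Mask Overflow" sliders respectively.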
def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
    # torch.as_tensor handles bool or uint8 numpy masks and places the result
    # on the right device in one step.
    mask = torch.as_tensor(mask[None, None], dtype=torch.float32, device=device)
    mask = min_pool(mask, min_kernel)
    mask = max_pool(mask, max_kernel)
    mask = mask.bool().squeeze().cpu().numpy()
    return mask
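

# Load models once at startup. Swapping in DPMSolverMultistepScheduler is a
# drop-in change that typically reaches good quality in fewer steps than the
# pipeline's default scheduler.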
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
pipe = load_diffusion_pipeline()
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

segmentation_model = segmentation_model.to(device)
pipe = pipe.to(device)
if args.attention_slicing:
    pipe.enable_attention_slicing()
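

# DETR panoptic post-processing returns a PNG whose RGB values encode segment
# ids; rgb_to_id collapses them back to integers, and one binary mask is built
# per segment. Outputs map onto Gradio components: raw masks -> State,
# "id:label" strings -> CheckboxGroup choices, plus a blank mask preview and
# the untouched input image.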
def fn_segmentation(image, max_kernel, min_kernel):
    # max_kernel / min_kernel are accepted so the sliders can be wired as
    # event inputs; the actual mask cleaning happens in fn_update_mask.
    inputs = feature_extractor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = segmentation_model(**inputs)

    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]

    # The segmentation PNG encodes ids as colors, so resize with NEAREST: any
    # interpolating filter would blend neighboring ids into invalid ones.
    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize((image.width, image.height), resample=Image.NEAREST)
    panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)

    panoptic_seg_id = rgb_to_id(panoptic_seg)

    raw_masks = []
    for s in result['segments_info']:
        m = panoptic_seg_id == s['id']
        raw_masks.append(m.astype(np.uint8) * 255)

    checkbox_choices = [f"{s['id']}:{segmentation_cfg.id2label[s['category_id']]}" for s in result['segments_info']]

    checkbox_group = gr.CheckboxGroup.update(choices=checkbox_choices)

    return raw_masks, checkbox_group, gr.Image.update(value=np.zeros((image.height, image.width))), gr.Image.update(value=image)
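

# Rebuild the combined inpainting mask whenever any mask control changes:
# OR the selected segment masks together, optionally invert, then denoise and
# dilate via clean_mask. Also blacks out the masked pixels for the preview.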
def fn_update_mask(
    image: Image.Image,
    masks: List[np.ndarray],
    masks_enabled: List[str],
    max_kernel: int,
    min_kernel: int,
    invert_mask: bool
):
    # Nothing to do until "Compute Masks" has populated the mask storage.
    if not masks:
        return gr.Image.update(), gr.Image.update()

    # Checkbox values look like "<segment id>:<label>"; the ids follow the
    # order of segments_info, so they double as indices into the mask list.
    masks_enabled = [int(m.split(':')[0]) for m in masks_enabled]
    combined_mask = reduce(lambda x, y: x | y, [masks[i].astype(bool) for i in masks_enabled], np.zeros_like(masks[0], dtype=bool))

    if invert_mask:
        combined_mask = ~combined_mask

    combined_mask = clean_mask(combined_mask, max_kernel, min_kernel)

    masked_image = np.array(image).copy()
    masked_image[combined_mask] = 0

    return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
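

# Stable Diffusion 2 inpainting works best with the short edge at 512 px and
# requires both edges to be divisible by 8 (the VAE downsampling factor), so
# inputs are resized before the call and the result is scaled back afterwards.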
def fn_diffusion(
    prompt: str,
    masked_image: Image.Image,
    mask: np.ndarray,
    num_diffusion_steps: int,
    guidance_scale: float,
    negative_prompt: Optional[str] = None,
):
    # Treat an empty textbox (or None) as "no negative prompt".
    if not negative_prompt:
        negative_prompt = None

    STABLE_DIFFUSION_SMALL_EDGE = 512

    w, h = masked_image.size
    is_width_larger = w > h
    resize_ratio = STABLE_DIFFUSION_SMALL_EDGE / (h if is_width_larger else w)

    new_width = int(w * resize_ratio) if is_width_larger else STABLE_DIFFUSION_SMALL_EDGE
    new_height = STABLE_DIFFUSION_SMALL_EDGE if is_width_larger else int(h * resize_ratio)

    # Round the long edge up to the next multiple of 8 (`-x % 8` is 0 when x
    # is already a multiple of 8, so nothing is added in that case).
    new_width += -new_width % 8 if is_width_larger else 0
    new_height += 0 if is_width_larger else -new_height % 8

    mask = Image.fromarray(mask).convert("RGB").resize((new_width, new_height))
    masked_image = masked_image.convert("RGB").resize((new_width, new_height))

    inpainted_image = pipe(
        height=new_height,
        width=new_width,
        prompt=prompt,
        image=masked_image,
        mask_image=mask,
        num_inference_steps=num_diffusion_steps,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt
    ).images[0]

    inpainted_image = inpainted_image.resize((w, h))

    return inpainted_image
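

# --- Gradio UI ----------------------------------------------------------
# Layout and event wiring; styling lives in app.css and the license notice
# is rendered from app_license.html.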
demo = gr.Blocks(css=open('app.css').read())

with demo:
    input_image = gr.Image(type='pil', label="Input Image")

    bt_masks = gr.Button("Compute Masks")
    with gr.Row():
        masked_image = gr.Image(type='pil', label="Masked Image")
        mask_storage = gr.State()

    with gr.Row():
        max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2, label="Mask Overflow")
        min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2, label="Mask Denoising")

    with gr.Row():  # gr.Row takes no CSS style string; alignment tweaks belong in app.css
        invert_mask = gr.Checkbox(label="Invert Mask")
    with gr.Row():
        mask_checkboxes = gr.CheckboxGroup(interactive=True, label="Mask Selection")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox("An angry dog floating in outer deep space. Twinkling stars in the background. High definition.", label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt")
        with gr.Column():
            steps_slider = gr.Slider(minimum=1, maximum=100, value=50, label="Inference Steps")
            guidance_slider = gr.Slider(minimum=0.0, maximum=50.0, value=7.5, step=0.1, label="Guidance Scale")
    bt_diffusion = gr.Button("Run Diffusion")
    mask_image = gr.Image(type='numpy', label="Diffusion Mask")

    inpainted_image = gr.Image(type='pil', label="Inpainted Image")

    update_mask_inputs = [input_image, mask_storage, mask_checkboxes, max_slider, min_slider, invert_mask]
    update_mask_outputs = [mask_image, masked_image]

    # Reset mask-related controls whenever a new image is uploaded.
    input_image.change(lambda: gr.CheckboxGroup.update(choices=[], value=[]), outputs=mask_checkboxes)
    input_image.change(lambda: gr.Checkbox.update(value=False), outputs=invert_mask)

    bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])

    # Any change to a mask control recomputes the combined mask live.
    max_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs, show_progress=False)
    min_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs, show_progress=False)
    mask_checkboxes.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs, show_progress=False)
    invert_mask.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs, show_progress=False)

    bt_diffusion.click(fn_diffusion, inputs=[
        prompt,
        masked_image,
        mask_image,
        steps_slider,
        guidance_slider,
        negative_prompt
    ], outputs=inpainted_image)
    gr.HTML(open('app_license.html').read())

demo.launch()
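
# Example invocation (assumes this file is saved as app.py; READ_TOKEN is only
# needed while the checkpoint requires accepting a license on the Hub):
#   READ_TOKEN=hf_... python app.py --attention-slicing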