FitDiT / app.py
BoyuanJiang's picture
update app
35f7043
import spaces
import gradio as gr
import os
import math
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.dwpose import DWposeDetector
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
import torch
import torch.nn as nn
from src.pose_guider import PoseGuider
from PIL import Image
from src.utils_mask import get_mask_location
import numpy as np
from src.pipeline_stable_diffusion_3_tryon import StableDiffusion3TryOnPipeline
from src.transformer_sd3_garm import SD3Transformer2DModel as SD3Transformer2DModel_Garm
from src.transformer_sd3_vton import SD3Transformer2DModel as SD3Transformer2DModel_Vton
import cv2
import random
from huggingface_hub import snapshot_download
example_path = os.path.join(os.path.dirname(__file__), 'examples')
fitdit_repo = "BoyuanJiang/FitDiT"
repo_path = snapshot_download(repo_id=fitdit_repo)
weight_dtype = torch.bfloat16
device = "cuda"
transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))
pose_guider.load_state_dict(torch.load(os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin")))
image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype)
image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype)
pose_guider.to(device=device, dtype=weight_dtype)
image_encoder_large.to(device=device)
image_encoder_bigG.to(device=device)
pipeline = StableDiffusion3TryOnPipeline.from_pretrained(repo_path, torch_dtype=weight_dtype, \
transformer_garm=transformer_garm, transformer_vton=transformer_vton, pose_guider=pose_guider, \
image_encoder_large=image_encoder_large, image_encoder_bigG=image_encoder_bigG)
pipeline.to(device)
dwprocessor = DWposeDetector(model_root=repo_path, device=device)
parsing_model = Parsing(model_root=repo_path, device=device)
def generate_mask(vton_img, category, offset_top, offset_bottom, offset_left, offset_right):
with torch.inference_mode():
vton_img = Image.open(vton_img)
vton_img_det = resize_image(vton_img)
pose_image, keypoints, _, candidate = dwprocessor(np.array(vton_img_det)[:,:,::-1])
candidate[candidate<0]=0
candidate = candidate[0]
candidate[:, 0]*=vton_img_det.width
candidate[:, 1]*=vton_img_det.height
pose_image = pose_image[:,:,::-1] #rgb
pose_image = Image.fromarray(pose_image)
model_parse, _ = parsing_model(vton_img_det)
mask, mask_gray = get_mask_location(category, model_parse, \
candidate, model_parse.width, model_parse.height, \
offset_top, offset_bottom, offset_left, offset_right)
mask = mask.resize(vton_img.size)
mask_gray = mask_gray.resize(vton_img.size)
mask = mask.convert("L")
mask_gray = mask_gray.convert("L")
masked_vton_img = Image.composite(mask_gray, vton_img, mask)
im = {}
im['background'] = np.array(vton_img.convert("RGBA"))
im['layers'] = [np.concatenate((np.array(mask_gray.convert("RGB")), np.array(mask)[:,:,np.newaxis]),axis=2)]
im['composite'] = np.array(masked_vton_img.convert("RGBA"))
return im, pose_image
@spaces.GPU
def process(vton_img, garm_img, pre_mask, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution):
assert resolution in ["768x1024", "1152x1536", "1536x2048"]
new_width, new_height = resolution.split("x")
new_width = int(new_width)
new_height = int(new_height)
with torch.inference_mode():
garm_img = Image.open(garm_img)
vton_img = Image.open(vton_img)
model_image_size = vton_img.size
garm_img, _, _ = pad_and_resize(garm_img, new_width=new_width, new_height=new_height)
vton_img, pad_w, pad_h = pad_and_resize(vton_img, new_width=new_width, new_height=new_height)
mask = pre_mask["layers"][0][:,:,3]
mask = Image.fromarray(mask)
mask, _, _ = pad_and_resize(mask, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
mask = mask.convert("L")
pose_image = Image.fromarray(pose_image)
pose_image, _, _ = pad_and_resize(pose_image, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
if seed==-1:
seed = random.randint(0, 2147483647)
res = pipeline(
height=new_height,
width=new_width,
guidance_scale=image_scale,
num_inference_steps=n_steps,
generator=torch.Generator("cpu").manual_seed(seed),
cloth_image=garm_img,
model_image=vton_img,
mask=mask,
pose_image=pose_image,
num_images_per_prompt=num_images_per_prompt
).images
for idx in range(len(res)):
res[idx] = unpad_and_resize(res[idx], pad_w, pad_h, model_image_size[0], model_image_size[1])
return res
def pad_and_resize(im, new_width=768, new_height=1024, pad_color=(255, 255, 255), mode=Image.LANCZOS):
old_width, old_height = im.size
ratio_w = new_width / old_width
ratio_h = new_height / old_height
if ratio_w < ratio_h:
new_size = (new_width, round(old_height * ratio_w))
else:
new_size = (round(old_width * ratio_h), new_height)
im_resized = im.resize(new_size, mode)
pad_w = math.ceil((new_width - im_resized.width) / 2)
pad_h = math.ceil((new_height - im_resized.height) / 2)
new_im = Image.new('RGB', (new_width, new_height), pad_color)
new_im.paste(im_resized, (pad_w, pad_h))
return new_im, pad_w, pad_h
def unpad_and_resize(padded_im, pad_w, pad_h, original_width, original_height):
width, height = padded_im.size
left = pad_w
top = pad_h
right = width - pad_w
bottom = height - pad_h
cropped_im = padded_im.crop((left, top, right, bottom))
resized_im = cropped_im.resize((original_width, original_height), Image.LANCZOS)
return resized_im
def resize_image(img, target_size=768):
width, height = img.size
if width < height:
scale = target_size / width
else:
scale = target_size / height
new_width = int(round(width * scale))
new_height = int(round(height * scale))
resized_img = img.resize((new_width, new_height), Image.LANCZOS)
return resized_img
HEADER = """
<h1 style="text-align: center;"> FitDiT: Advancing the Authentic Garment Details for High-fidelity Virtual Try-on </h1>
<div style="display: flex; justify-content: center; align-items: center;">
<a href="https://github.com/BoyuanJiang/FitDiT" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
</a>
<a href="https://arxiv.org/abs/2411.10499" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/arXiv-2411.10499-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
</a>
<a href="http://demo.fitdit.byjiang.com/" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
</a>
<a href='https://byjiang.com/FitDiT/' style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
</a>
<a href="https://raw.githubusercontent.com/BoyuanJiang/FitDiT/refs/heads/main/LICENSE" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
</a>
</div>
<br>
FitDiT is designed for high-fidelity virtual try-on using Diffusion Transformers (DiT). It can only be used for <b>Non-commercial Use</b>.<br>
If you like our work, please star <a href="https://github.com/BoyuanJiang/FitDiT" style="color: blue; text-decoration: underline;">our github repository</a>. A <b>ComfyUI version</b> of FitDiT is available <a href="https://github.com/BoyuanJiang/FitDiT/tree/FitDiT-ComfyUI" style="color: blue; text-decoration: underline;">here</a>.
"""
def create_demo():
with gr.Blocks(title="FitDiT") as demo:
gr.Markdown(HEADER)
with gr.Row():
with gr.Column():
vton_img = gr.Image(label="Model", sources=None, type="filepath", height=512)
with gr.Column():
garm_img = gr.Image(label="Garment", sources=None, type="filepath", height=512)
with gr.Row():
with gr.Column():
masked_vton_img = gr.ImageEditor(label="masked_vton_img", type="numpy", height=512, interactive=True, brush=gr.Brush(default_color="rgb(127, 127, 127)", colors=[
"rgb(128, 128, 128)"
]))
pose_image = gr.Image(label="pose_image", visible=False, interactive=False)
with gr.Column():
result_gallery = gr.Gallery(label="Output", elem_id="output-img", interactive=False, columns=[2], rows=[2], object_fit="contain", height="auto")
with gr.Row():
with gr.Column():
offset_top = gr.Slider(label="mask offset top", minimum=-200, maximum=200, step=1, value=0)
with gr.Column():
offset_bottom = gr.Slider(label="mask offset bottom", minimum=-200, maximum=200, step=1, value=0)
with gr.Column():
offset_left = gr.Slider(label="mask offset left", minimum=-200, maximum=200, step=1, value=0)
with gr.Column():
offset_right = gr.Slider(label="mask offset right", minimum=-200, maximum=200, step=1, value=0)
with gr.Row():
with gr.Column():
n_steps = gr.Slider(label="Steps", minimum=15, maximum=30, value=20, step=1)
with gr.Column():
image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2, step=0.1)
with gr.Column():
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
with gr.Column():
num_images_per_prompt = gr.Slider(label="num_images", minimum=1, maximum=4, step=1, value=1)
with gr.Row():
with gr.Column():
example = gr.Examples(
label="Model (upper-body)",
inputs=vton_img,
examples_per_page=7,
examples=[
os.path.join(example_path, 'model/0279.jpg'),
os.path.join(example_path, 'model/0303.jpg'),
os.path.join(example_path, 'model/2.jpg'),
os.path.join(example_path, 'model/0083.jpg'),
])
example = gr.Examples(
label="Model (upper-body/lower-body)",
inputs=vton_img,
examples_per_page=7,
examples=[
os.path.join(example_path, 'model/0.jpg'),
os.path.join(example_path, 'model/0179.jpg'),
os.path.join(example_path, 'model/0223.jpg'),
os.path.join(example_path, 'model/0347.jpg'),
])
example = gr.Examples(
label="Model (dresses)",
inputs=vton_img,
examples_per_page=7,
examples=[
os.path.join(example_path, 'model/4.jpg'),
os.path.join(example_path, 'model/5.jpg'),
os.path.join(example_path, 'model/6.jpg'),
os.path.join(example_path, 'model/7.jpg'),
])
with gr.Column():
example = gr.Examples(
label="Garment (upper-body)",
inputs=garm_img,
examples_per_page=7,
examples=[
os.path.join(example_path, 'garment/12.jpg'),
os.path.join(example_path, 'garment/0012.jpg'),
os.path.join(example_path, 'garment/0047.jpg'),
os.path.join(example_path, 'garment/0049.jpg'),
])
example = gr.Examples(
label="Garment (lower-body)",
inputs=garm_img,
examples_per_page=7,
examples=[
os.path.join(example_path, 'garment/0317.jpg'),
os.path.join(example_path, 'garment/0327.jpg'),
os.path.join(example_path, 'garment/0329.jpg'),
os.path.join(example_path, 'garment/0362.jpg'),
])
example = gr.Examples(
label="Garment (dresses)",
inputs=garm_img,
examples_per_page=7,
examples=[
os.path.join(example_path, 'garment/8.jpg'),
os.path.join(example_path, 'garment/9.png'),
os.path.join(example_path, 'garment/10.jpg'),
os.path.join(example_path, 'garment/11.jpg'),
])
with gr.Column():
category = gr.Dropdown(label="Garment category", choices=["Upper-body", "Lower-body", "Dresses"], value="Upper-body")
resolution = gr.Dropdown(label="Try-on resolution", choices=["768x1024", "1152x1536", "1536x2048"], value="768x1024")
with gr.Column():
run_mask_button = gr.Button(value="Step1: Run Mask")
run_button = gr.Button(value="Step2: Run Try-on")
ips1 = [vton_img, category, offset_top, offset_bottom, offset_left, offset_right]
ips2 = [vton_img, garm_img, masked_vton_img, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution]
run_mask_button.click(fn=generate_mask, inputs=ips1, outputs=[masked_vton_img, pose_image])
run_button.click(fn=process, inputs=ips2, outputs=[result_gallery])
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch()