myMagicClothing / garment_adapter /garment_diffusion.py
yzy0713's picture
Add files
6a05036
raw
history blame
10.9 kB
import copy
import logging

import torch
from diffusers import UNet2DConditionModel
from safetensors import safe_open

from garment_seg.process import load_seg_model, generate_mask
from utils.utils import is_torch2_available, prepare_image, prepare_mask

if is_torch2_available():
    from .attention_processor import REFAttnProcessor2_0 as REFAttnProcessor
    from .attention_processor import AttnProcessor2_0 as AttnProcessor
    from .attention_processor import REFAnimateDiffAttnProcessor2_0 as REFAnimateDiffAttnProcessor
else:
    # NOTE(review): REFAnimateDiffAttnProcessor is not imported on this path, so
    # ClothAdapter_AnimateDiff.set_adapter raises NameError without torch 2.x.
    from .attention_processor import REFAttnProcessor, AttnProcessor


class ClothAdapter:
    """Condition a Stable Diffusion pipeline on a garment image.

    A deep copy of the pipeline's UNet is loaded with the adapter weights from
    ``ref_path`` and used as a *reference* UNet. Each ``generate*`` call:

    1. encodes the masked garment with the pipeline's VAE;
    2. runs the reference UNet on it once (timestep 0, empty-prompt embeddings);
    3. runs the pipeline.

    Both UNets receive the shared ``self.attn_store`` dict through
    ``cross_attention_kwargs``. Their ``attn1`` processors are built with type
    ``"read"`` (reference UNet) and ``"write"`` (pipeline UNet). Presumably the
    reference pass records garment features that the denoising pass then uses;
    confirm against ``attention_processor``.
    """

    DEFAULT_PROMPT = "a photography of a model"
    DEFAULT_NEGATIVE_PROMPT = "monochrome, lowres, bad anatomy, worst quality, low quality"

    def __init__(self, sd_pipe, ref_path, device, enable_cloth_guidance, set_seg_model=True):
        """Wrap ``sd_pipe`` and build the reference UNet.

        Args:
            sd_pipe: diffusers pipeline exposing ``unet``, ``vae``, ``encode_prompt``
                and ``dtype``. It is moved to ``device``.
            ref_path: ``.safetensors`` file with the reference-UNet weights,
                loaded with ``strict=False``.
            device: torch device used for the models and tensors.
            enable_cloth_guidance: when true, :meth:`generate` forwards
                ``cloth_guidance_scale`` to the pipeline.
            set_seg_model: load the garment segmentation model, so that
                ``cloth_mask_image`` may be omitted in the ``generate*`` calls.
        """
        self.enable_cloth_guidance = enable_cloth_guidance
        self.device = device
        self.attn_store = {}
        self.seg_net = None
        self.pipe = sd_pipe.to(self.device)
        self.set_adapter(self.pipe.unet, "write")

        logging.getLogger(__name__).info("Loading reference UNet weights from %s", ref_path)
        ref_unet = copy.deepcopy(sd_pipe.unet)
        if ref_unet.config.in_channels == 9:
            # A 9-channel UNet (presumably an inpainting checkpoint) expects extra
            # input channels. The reference UNet is only ever fed the cloth
            # latent, so it gets a 4-channel input convolution instead.
            ref_unet.conv_in = self._four_channel_conv_in(ref_unet.conv_in)
            ref_unet.register_to_config(in_channels=4)
        ref_unet.load_state_dict(self._read_safetensors(ref_path), strict=False)
        self.ref_unet = ref_unet.to(self.device, dtype=self.pipe.dtype)
        self.set_adapter(self.ref_unet, "read")

        if set_seg_model:
            self.set_seg_model()

    @staticmethod
    def _four_channel_conv_in(conv_in):
        """Return a fresh 4-input-channel ``Conv2d`` with ``conv_in``'s other geometry.

        The output width is copied from ``conv_in`` instead of being hard-coded,
        so UNets whose first block is not 320 channels wide also work.
        """
        return torch.nn.Conv2d(
            4,
            conv_in.out_channels,
            conv_in.kernel_size,
            conv_in.stride,
            conv_in.padding,
        )

    @staticmethod
    def _read_safetensors(path):
        """Load every tensor of a ``.safetensors`` file into a CPU state dict."""
        with safe_open(path, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}

    def set_seg_model(self, checkpoint_path="checkpoints/cloth_segm.pth"):
        """Load the segmentation network used to auto-generate garment masks.

        Args:
            checkpoint_path: checkpoint passed to ``load_seg_model``.
        """
        self.seg_net = load_seg_model(checkpoint_path, device=self.device)

    def set_adapter(self, unet, type):
        """Install attention processors on ``unet``.

        Every processor whose name contains ``"attn1"`` becomes a
        ``REFAttnProcessor`` tagged with its name and ``type`` (``"read"`` or
        ``"write"`` in this class). All other processors are reset to the plain
        ``AttnProcessor``.
        """
        processors = {}
        for name in unet.attn_processors:
            if "attn1" in name:
                processors[name] = REFAttnProcessor(name=name, type=type)
            else:
                processors[name] = AttnProcessor()
        unet.set_attn_processor(processors)

    @classmethod
    def _compose_prompts(cls, prompt, a_prompt, negative_prompt):
        """Fill in default prompts and append ``a_prompt`` to the positive prompt."""
        if prompt is None:
            prompt = cls.DEFAULT_PROMPT
        if negative_prompt is None:
            negative_prompt = cls.DEFAULT_NEGATIVE_PROMPT
        return prompt + ", " + a_prompt, negative_prompt

    def _make_generator(self, seed):
        """Return a seeded ``torch.Generator`` on ``self.device``, or None if ``seed`` is None."""
        # NOTE(review): the public methods default to seed=-1, but only None
        # disables seeding here. -1 is handed to manual_seed like any other value.
        # Confirm whether -1 was meant to mean "random".
        if seed is None:
            return None
        return torch.Generator(self.device).manual_seed(seed)

    def _prepare_reference(self, cloth_image, cloth_mask_image, num_images_per_prompt, height, width):
        """Encode the masked garment and run the reference UNet to fill ``attn_store``.

        Returns:
            ``(cloth_mask_image, prompt_embeds_null)``: the mask actually used
            (produced by the segmentation model when none was given), and the
            empty-prompt embeddings fed to the reference UNet.

        Raises:
            ValueError: no mask was given and no segmentation model is loaded.
        """
        if cloth_mask_image is None:
            if getattr(self, "seg_net", None) is None:
                raise ValueError(
                    "cloth_mask_image is required when no segmentation model is loaded "
                    "(construct with set_seg_model=True or call set_seg_model())"
                )
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)
        cloth = prepare_image(cloth_image, height, width)
        cloth_mask = prepare_mask(cloth_mask_image, height, width)
        vae = self.pipe.vae
        # Follow the VAE's dtype rather than hard-coding float16, so fp32 pipelines work.
        cloth = (cloth * cloth_mask).to(self.device, dtype=vae.dtype)
        with torch.inference_mode():
            prompt_embeds_null = self.pipe.encode_prompt(
                [""],
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=False,
            )[0]
            cloth_latent = vae.encode(cloth).latent_dist.mode() * vae.config.scaling_factor
            # The reference UNet was cast to the pipeline dtype in __init__.
            cloth_latent = cloth_latent.to(dtype=self.ref_unet.dtype)
            self.ref_unet(
                torch.cat([cloth_latent] * num_images_per_prompt),
                0,
                prompt_embeds_null,
                cross_attention_kwargs={"attn_store": self.attn_store},
            )
        return cloth_mask_image, prompt_embeds_null

    def generate(
        self,
        cloth_image,
        cloth_mask_image=None,
        prompt=None,
        a_prompt="best quality, high quality",
        num_images_per_prompt=4,
        negative_prompt=None,
        seed=-1,
        guidance_scale=7.5,
        cloth_guidance_scale=2.5,
        num_inference_steps=20,
        height=512,
        width=384,
        **kwargs,
    ):
        """Generate images of a model wearing ``cloth_image``.

        Args:
            cloth_image: garment image, as accepted by ``prepare_image`` /
                ``generate_mask`` (presumably a PIL image; confirm).
            cloth_mask_image: garment mask. It is segmented automatically when None.
            prompt: text prompt (default ``DEFAULT_PROMPT``). ``", " + a_prompt``
                is appended to it.
            a_prompt: quality suffix appended to the prompt.
            num_images_per_prompt: number of images to generate.
            negative_prompt: negative prompt (default ``DEFAULT_NEGATIVE_PROMPT``).
            seed: generator seed. None leaves the pipeline unseeded.
            guidance_scale: classifier-free guidance scale. Values > 1.0 also set
                ``do_classifier_free_guidance`` for the attention processors.
            cloth_guidance_scale: forwarded only when ``enable_cloth_guidance`` is set.
            num_inference_steps, height, width: forwarded to the pipeline.
            **kwargs: extra pipeline arguments.

        Returns:
            ``(images, cloth_mask_image)``: the pipeline's ``.images`` and the mask used.

        Raises:
            ValueError: no mask was given and no segmentation model is loaded.
        """
        prompt, negative_prompt = self._compose_prompts(prompt, a_prompt, negative_prompt)
        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
        cloth_mask_image, _ = self._prepare_reference(
            cloth_image, cloth_mask_image, num_images_per_prompt, height, width
        )
        pipe_kwargs = {
            "prompt_embeds": prompt_embeds,
            "negative_prompt_embeds": negative_prompt_embeds,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "generator": self._make_generator(seed),
            "height": height,
            "width": width,
            "cross_attention_kwargs": {
                "attn_store": self.attn_store,
                "do_classifier_free_guidance": guidance_scale > 1.0,
                "enable_cloth_guidance": self.enable_cloth_guidance,
            },
        }
        if self.enable_cloth_guidance:
            # Only passed in this mode, presumably because the plain pipeline
            # does not accept this argument.
            pipe_kwargs["cloth_guidance_scale"] = cloth_guidance_scale
        images = self.pipe(**pipe_kwargs, **kwargs).images
        return images, cloth_mask_image

    def generate_inpainting(
        self,
        cloth_image,
        cloth_mask_image=None,
        num_images_per_prompt=4,
        seed=-1,
        cloth_guidance_scale=2.5,
        num_inference_steps=20,
        height=512,
        width=384,
        **kwargs,
    ):
        """Run the pipeline conditioned on the garment, with empty-prompt embeddings.

        ``cloth_guidance_scale > 1.0`` sets ``do_classifier_free_guidance`` for
        the attention processors. ``enable_cloth_guidance`` is always False here.
        ``kwargs`` are forwarded to the pipeline (presumably the person image and
        inpainting mask; confirm against the pipeline).

        Returns:
            ``(images, cloth_mask_image)``: the pipeline's ``.images`` and the mask used.

        Raises:
            ValueError: no mask was given and no segmentation model is loaded.
        """
        cloth_mask_image, prompt_embeds_null = self._prepare_reference(
            cloth_image, cloth_mask_image, num_images_per_prompt, height, width
        )
        images = self.pipe(
            prompt_embeds=prompt_embeds_null,
            cloth_guidance_scale=cloth_guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=self._make_generator(seed),
            height=height,
            width=width,
            cross_attention_kwargs={
                "attn_store": self.attn_store,
                "do_classifier_free_guidance": cloth_guidance_scale > 1.0,
                "enable_cloth_guidance": False,
            },
            **kwargs,
        ).images
        return images, cloth_mask_image


class ClothAdapter_AnimateDiff:
    """Condition an AnimateDiff video pipeline on a garment image.

    The reference UNet is loaded fresh from ``pipe_path`` (``unet`` subfolder)
    and then patched with the adapter weights from ``ref_path``. Each
    :meth:`generate` call encodes the masked garment with the pipeline's VAE,
    runs the reference UNet on it once, and then runs the pipeline. Both UNets
    share ``self.attn_store`` through ``cross_attention_kwargs``. Processors
    whose names contain ``"motion_modules"`` keep the plain ``AttnProcessor``.
    """

    DEFAULT_PROMPT = "a photography of a model"
    DEFAULT_NEGATIVE_PROMPT = "bare, naked, nude, undressed, monochrome, lowres, bad anatomy, worst quality, low quality"

    def __init__(self, sd_pipe, pipe_path, ref_path, device, set_seg_model=True):
        """Wrap ``sd_pipe`` and build the reference UNet.

        Args:
            sd_pipe: diffusers AnimateDiff-style pipeline exposing ``unet``,
                ``vae``, ``encode_prompt`` and ``dtype``. It is moved to ``device``.
            pipe_path: pretrained model path whose ``unet`` subfolder seeds the
                reference UNet, loaded in ``sd_pipe.dtype``.
            ref_path: ``.safetensors`` file with the reference-UNet weights,
                loaded with ``strict=False``.
            device: torch device used for the models and tensors.
            set_seg_model: load the garment segmentation model, so that
                ``cloth_mask_image`` may be omitted in :meth:`generate`.
        """
        self.device = device
        self.attn_store = {}
        self.seg_net = None
        self.pipe = sd_pipe.to(self.device)
        self.set_adapter(self.pipe.unet, "write")

        ref_unet = UNet2DConditionModel.from_pretrained(pipe_path, subfolder="unet", torch_dtype=sd_pipe.dtype)
        ref_unet.load_state_dict(self._read_safetensors(ref_path), strict=False)
        self.ref_unet = ref_unet.to(self.device)
        self.set_adapter(self.ref_unet, "read")

        if set_seg_model:
            self.set_seg_model()

    @staticmethod
    def _read_safetensors(path):
        """Load every tensor of a ``.safetensors`` file into a CPU state dict."""
        with safe_open(path, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}

    def set_seg_model(self, checkpoint_path="checkpoints/cloth_segm.pth"):
        """Load the segmentation network used to auto-generate garment masks.

        Args:
            checkpoint_path: checkpoint passed to ``load_seg_model``.
        """
        self.seg_net = load_seg_model(checkpoint_path, device=self.device)

    def set_adapter(self, unet, type):
        """Install attention processors on ``unet``.

        Processors whose name contains ``"attn1"`` but not ``"motion_modules"``
        become ``REFAnimateDiffAttnProcessor`` tagged with their name and
        ``type`` (``"read"`` or ``"write"`` in this class). All others are reset
        to the plain ``AttnProcessor``.
        """
        processors = {}
        for name in unet.attn_processors:
            if "attn1" in name and "motion_modules" not in name:
                processors[name] = REFAnimateDiffAttnProcessor(name=name, type=type)
            else:
                processors[name] = AttnProcessor()
        unet.set_attn_processor(processors)

    @classmethod
    def _compose_prompts(cls, prompt, a_prompt, negative_prompt):
        """Fill in default prompts and append ``a_prompt`` to the positive prompt."""
        if prompt is None:
            prompt = cls.DEFAULT_PROMPT
        if negative_prompt is None:
            negative_prompt = cls.DEFAULT_NEGATIVE_PROMPT
        return prompt + ", " + a_prompt, negative_prompt

    def _make_generator(self, seed):
        """Return a seeded ``torch.Generator`` on ``self.device``, or None if ``seed`` is None."""
        # NOTE(review): generate() defaults to seed=-1, but only None disables
        # seeding here. -1 is handed to manual_seed like any other value.
        # Confirm whether -1 was meant to mean "random".
        if seed is None:
            return None
        return torch.Generator(self.device).manual_seed(seed)

    def _prepare_reference(self, cloth_image, cloth_mask_image, num_images_per_prompt, height, width):
        """Encode the masked garment and run the reference UNet to fill ``attn_store``.

        Returns:
            The mask actually used (produced by the segmentation model when none
            was given).

        Raises:
            ValueError: no mask was given and no segmentation model is loaded.
        """
        if cloth_mask_image is None:
            if getattr(self, "seg_net", None) is None:
                raise ValueError(
                    "cloth_mask_image is required when no segmentation model is loaded "
                    "(construct with set_seg_model=True or call set_seg_model())"
                )
            cloth_mask_image = generate_mask(cloth_image, net=self.seg_net, device=self.device)
        cloth = prepare_image(cloth_image, height, width)
        cloth_mask = prepare_mask(cloth_mask_image, height, width)
        vae = self.pipe.vae
        # Follow the VAE's dtype rather than hard-coding float16, so fp32 pipelines work.
        cloth = (cloth * cloth_mask).to(self.device, dtype=vae.dtype)
        with torch.inference_mode():
            prompt_embeds_null = self.pipe.encode_prompt(
                [""],
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=False,
            )[0]
            cloth_latent = vae.encode(cloth).latent_dist.mode() * vae.config.scaling_factor
            # The reference UNet was loaded in the pipeline's dtype in __init__.
            cloth_latent = cloth_latent.to(dtype=self.ref_unet.dtype)
            self.ref_unet(
                torch.cat([cloth_latent] * num_images_per_prompt),
                0,
                prompt_embeds_null,
                cross_attention_kwargs={"attn_store": self.attn_store},
            )
        return cloth_mask_image

    def generate(
        self,
        cloth_image,
        cloth_mask_image=None,
        prompt=None,
        a_prompt="best quality, high quality",
        num_images_per_prompt=4,
        negative_prompt=None,
        seed=-1,
        guidance_scale=7.5,
        cloth_guidance_scale=3.,
        num_inference_steps=20,
        height=512,
        width=384,
        **kwargs,
    ):
        """Generate video frames of a model wearing ``cloth_image``.

        Args:
            cloth_image: garment image, as accepted by ``prepare_image`` /
                ``generate_mask`` (presumably a PIL image; confirm).
            cloth_mask_image: garment mask. It is segmented automatically when None.
            prompt: text prompt (default ``DEFAULT_PROMPT``). ``", " + a_prompt``
                is appended to it.
            a_prompt: quality suffix appended to the prompt.
            num_images_per_prompt: batch multiplier for the embeddings.
            negative_prompt: negative prompt (default ``DEFAULT_NEGATIVE_PROMPT``).
            seed: generator seed. None leaves the pipeline unseeded.
            guidance_scale: classifier-free guidance scale. Values > 1.0 also set
                ``do_classifier_free_guidance`` for the attention processors.
            cloth_guidance_scale: forwarded to the pipeline.
            num_inference_steps, height, width: forwarded to the pipeline.
            **kwargs: extra pipeline arguments.

        Returns:
            ``(frames, cloth_mask_image)``: the pipeline's ``.frames`` and the mask used.

        Raises:
            ValueError: no mask was given and no segmentation model is loaded.
        """
        prompt, negative_prompt = self._compose_prompts(prompt, a_prompt, negative_prompt)
        with torch.inference_mode():
            prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
        cloth_mask_image = self._prepare_reference(
            cloth_image, cloth_mask_image, num_images_per_prompt, height, width
        )
        frames = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            cloth_guidance_scale=cloth_guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=self._make_generator(seed),
            height=height,
            width=width,
            cross_attention_kwargs={
                "attn_store": self.attn_store,
                "do_classifier_free_guidance": guidance_scale > 1.0,
            },
            **kwargs,
        ).frames
        return frames, cloth_mask_image