Spaces:
Paused
Paused
from typing import Literal, Union, Optional, Tuple, List | |
import torch | |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection | |
from diffusers import ( | |
UNet2DConditionModel, | |
SchedulerMixin, | |
StableDiffusionPipeline, | |
StableDiffusionXLPipeline, | |
AutoencoderKL, | |
) | |
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( | |
convert_ldm_unet_checkpoint, | |
) | |
from safetensors.torch import load_file | |
from diffusers.schedulers import ( | |
DDIMScheduler, | |
DDPMScheduler, | |
LMSDiscreteScheduler, | |
EulerDiscreteScheduler, | |
EulerAncestralDiscreteScheduler, | |
UniPCMultistepScheduler, | |
) | |
from omegaconf import OmegaConf | |
# Model parameters for the Diffusers version of Stable Diffusion.

# --- noise schedule ---
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.012

# --- U-Net ---
UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8
# UNET_PARAMS_USE_LINEAR_PROJECTION = False

# --- VAE ---
VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# --- V2 ---
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True

# Hub repositories the tokenizers are loaded from.
TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
# Scheduler names used to annotate `create_noise_scheduler`.
AVAILABLE_SCHEDULERS = Literal[
    "ddim",
    "ddpm",
    "lms",
    "euler_a",
    "euler",
    "uniPC",
]

# Either of the two text-encoder classes returned by the SDXL loaders.
SDXL_TEXT_ENCODER_TYPE = Union[
    CLIPTextModel,
    CLIPTextModelWithProjection,
]

# Forwarded as `cache_dir` to the loaders below; change this to relocate the cache.
DIFFUSERS_CACHE_DIR = None
def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
    """Load a checkpoint file and normalise its text-encoder key names.

    Some checkpoints store the text encoder without the ``text_model``
    path component; those keys are renamed so that it is present.

    Args:
        ckpt_path: Path to the checkpoint. A ``.safetensors`` suffix selects
            ``load_file``; anything else is read with ``torch.load``.
        device: ``map_location`` for ``torch.load``. Not used for
            ``.safetensors`` files.

    Returns:
        ``(checkpoint, state_dict)``. ``checkpoint`` is the full loaded object
        only when it wraps the weights under a ``"state_dict"`` key, otherwise
        ``None``. ``state_dict`` is modified in place by the renaming.
    """
    # (old prefix, new prefix) pairs for layouts that lack 'text_model'.
    prefix_renames = (
        (
            "cond_stage_model.transformer.embeddings.",
            "cond_stage_model.transformer.text_model.embeddings.",
        ),
        (
            "cond_stage_model.transformer.encoder.",
            "cond_stage_model.transformer.text_model.encoder.",
        ),
        (
            "cond_stage_model.transformer.final_layer_norm.",
            "cond_stage_model.transformer.text_model.final_layer_norm.",
        ),
    )

    checkpoint = None
    if ckpt_path.endswith(".safetensors"):
        # `device` is deliberately not forwarded here (it may cause an error).
        state_dict = load_file(ckpt_path)
    else:
        loaded = torch.load(ckpt_path, map_location=device)
        if "state_dict" in loaded:
            checkpoint = loaded
            state_dict = loaded["state_dict"]
        else:
            state_dict = loaded

    # Collect every rename first so the dict is not mutated while iterating.
    pending = [
        (old_key, new_prefix + old_key[len(old_prefix):])
        for old_prefix, new_prefix in prefix_renames
        for old_key in state_dict.keys()
        if old_key.startswith(old_prefix)
    ]
    for old_key, new_key in pending:
        state_dict[new_key] = state_dict.pop(old_key)

    return checkpoint, state_dict
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
    """
    Build the diffusers U-Net config dict from the module-level
    UNET_PARAMS_* / V2_UNET_PARAMS_* constants.

    `v2` selects the V2 cross-attention dim and attention head dims.
    `use_linear_projection` is only added to the config when both `v2`
    and `use_linear_projection_in_v2` are truthy.
    """
    block_out_channels = tuple(
        UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
    )

    # The resolution factor starts at 1 and doubles per down block; the up
    # path walks the same factors in reverse.
    down_resolutions = [2**level for level in range(len(block_out_channels))]

    down_block_types = tuple(
        "CrossAttnDownBlock2D"
        if factor in UNET_PARAMS_ATTENTION_RESOLUTIONS
        else "DownBlock2D"
        for factor in down_resolutions
    )
    up_block_types = tuple(
        "CrossAttnUpBlock2D"
        if factor in UNET_PARAMS_ATTENTION_RESOLUTIONS
        else "UpBlock2D"
        for factor in reversed(down_resolutions)
    )

    if v2:
        cross_attention_dim = V2_UNET_PARAMS_CONTEXT_DIM
        attention_head_dim = V2_UNET_PARAMS_ATTENTION_HEAD_DIM
    else:
        cross_attention_dim = UNET_PARAMS_CONTEXT_DIM
        attention_head_dim = UNET_PARAMS_NUM_HEADS

    config = {
        "sample_size": UNET_PARAMS_IMAGE_SIZE,
        "in_channels": UNET_PARAMS_IN_CHANNELS,
        "out_channels": UNET_PARAMS_OUT_CHANNELS,
        "down_block_types": down_block_types,
        "up_block_types": up_block_types,
        "block_out_channels": block_out_channels,
        "layers_per_block": UNET_PARAMS_NUM_RES_BLOCKS,
        "cross_attention_dim": cross_attention_dim,
        "attention_head_dim": attention_head_dim,
    }
    if v2 and use_linear_projection_in_v2:
        config["use_linear_projection"] = True

    return config
def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
    """Load tokenizer, text encoder, U-Net and VAE from a diffusers-format model.

    Args:
        pretrained_model_name_or_path: Model id or directory passed to
            ``from_pretrained`` for the text encoder, U-Net and VAE.
        v2: Selects the V2 tokenizer repository and the 24-layer base used
            for the text-encoder layer count.
        clip_skip: When given, the text encoder is built with
            ``base - (clip_skip - 1)`` hidden layers (base 24 for v2, 12
            otherwise). When ``None``, 23 layers are used for v2 and 12
            otherwise.
        weight_dtype: ``torch_dtype`` for the tokenizer, text encoder and
            U-Net calls. It is not applied to the VAE.

    Returns:
        ``(tokenizer, text_encoder, unet, vae)``.
    """
    if v2:
        tokenizer_name = TOKENIZER_V2_MODEL_NAME
        # default is clip skip 2
        num_hidden_layers = 24 - (clip_skip - 1) if clip_skip is not None else 23
    else:
        tokenizer_name = TOKENIZER_V1_MODEL_NAME
        num_hidden_layers = 12 - (clip_skip - 1) if clip_skip is not None else 12

    # The tokenizer always comes from the fixed repository for the selected
    # version, not from `pretrained_model_name_or_path`.
    tokenizer = CLIPTokenizer.from_pretrained(
        tokenizer_name,
        subfolder="tokenizer",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        num_hidden_layers=num_hidden_layers,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    # Use the same cache directory as the other components; the VAE dtype is
    # intentionally left at the library default, as before.
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    return tokenizer, text_encoder, unet, vae
def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
    """Load tokenizer, text encoder, U-Net and VAE from a single checkpoint file.

    The tokenizer, text encoder and VAE are taken from a
    ``StableDiffusionPipeline`` built with ``from_single_file``. The U-Net is
    rebuilt separately: the raw state dict is converted with
    ``convert_ldm_unet_checkpoint`` and loaded into a fresh
    ``UNet2DConditionModel`` created from ``create_unet_diffusers_config``.

    Args:
        checkpoint_path: Path to the ``.ckpt`` / ``.safetensors`` file.
        v2: Enables ``upcast_attention`` on the pipeline and the V2 U-Net
            config (with linear projection).
        clip_skip: When given, ``text_encoder.config.num_hidden_layers`` is
            set to ``base - (clip_skip - 1)`` (base 24 for v2, 12 otherwise).
        weight_dtype: ``torch_dtype`` for the pipeline load. The rebuilt
            U-Net is not cast here.

    Returns:
        ``(tokenizer, text_encoder, unet, vae)``.
    """
    pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        upcast_attention=bool(v2),
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)

    unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
    unet_config["class_embed_type"] = None
    unet_config["addition_embed_type"] = None
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae

    if clip_skip is not None:
        base_layers = 24 if v2 else 12
        # NOTE(review): this only rewrites the config value on an encoder that
        # is already constructed; confirm that it actually changes how many
        # layers run, as it does when passed to from_pretrained.
        text_encoder.config.num_hidden_layers = base_layers - (clip_skip - 1)

    # Drop the pipeline wrapper; the components extracted above stay alive.
    del pipe

    return tokenizer, text_encoder, unet, vae
def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    CLIPTokenizer,
    CLIPTextModel,
    UNet2DConditionModel,
    Optional[SchedulerMixin],
    AutoencoderKL,
]:
    """Load the model components and, optionally, a noise scheduler.

    Args:
        pretrained_model_name_or_path: A path ending in ``.ckpt`` or
            ``.safetensors`` is loaded with ``load_checkpoint_model``;
            anything else with ``load_diffusers_model``.
        scheduler_name: Name passed to ``create_noise_scheduler``. A falsy
            value skips scheduler creation.
        v2: Forwarded to the selected loader.
        v_pred: Requests ``"v_prediction"`` instead of ``"epsilon"`` as the
            scheduler prediction type.
        weight_dtype: Forwarded to the selected loader.

    Returns:
        ``(tokenizer, text_encoder, unet, scheduler, vae)``; ``scheduler`` is
        ``None`` when ``scheduler_name`` is falsy.
    """
    is_single_file = pretrained_model_name_or_path.endswith((".ckpt", ".safetensors"))
    loader = load_checkpoint_model if is_single_file else load_diffusers_model
    tokenizer, text_encoder, unet, vae = loader(
        pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
    )

    scheduler = None
    if scheduler_name:
        scheduler = create_noise_scheduler(
            scheduler_name,
            prediction_type="v_prediction" if v_pred else "epsilon",
        )

    return tokenizer, text_encoder, unet, scheduler, vae
def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    AutoencoderKL,
]:
    """Load the SDXL components from a diffusers-format model.

    Args:
        pretrained_model_name_or_path: Model id or directory passed to every
            ``from_pretrained`` call.
        weight_dtype: ``torch_dtype`` for the tokenizer, text encoder and
            U-Net calls. It is not applied to the VAE.

    Returns:
        ``(tokenizers, text_encoders, unet, vae)`` where ``tokenizers`` is
        ``[tokenizer, tokenizer_2]`` and ``text_encoders`` is
        ``[text_encoder, text_encoder_2]``.
    """
    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,  # same as open clip
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    # Use the same cache directory as the other components; the VAE dtype is
    # intentionally left at the library default, as before.
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    return tokenizers, text_encoders, unet, vae
def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    AutoencoderKL,
]:
    """Load the SDXL components from a single checkpoint file.

    All components are taken from a ``StableDiffusionXLPipeline`` built with
    ``from_single_file``.

    Args:
        checkpoint_path: Path to the ``.ckpt`` / ``.safetensors`` file.
        weight_dtype: ``torch_dtype`` for the pipeline load.

    Returns:
        ``(tokenizers, text_encoders, unet, vae)`` where ``tokenizers`` is
        ``[tokenizer, tokenizer_2]`` and ``text_encoders`` is
        ``[text_encoder, text_encoder_2]``.
    """
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    vae = pipe.vae
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]

    # The list above always has two entries, so no length check is needed.
    # NOTE(review): the diffusers-format loader sets pad_token_id=0 on the
    # second *tokenizer*; here it is set on the second text encoder. Confirm
    # which object is meant before changing it.
    text_encoders[1].pad_token_id = 0

    # Drop the pipeline wrapper; the components extracted above stay alive.
    del pipe

    return tokenizers, text_encoders, unet, vae
def load_models_xl(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    weight_dtype: torch.dtype = torch.float32,
    noise_scheduler_kwargs=None,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    Optional[SchedulerMixin],
    AutoencoderKL,
]:
    """Load the SDXL components and, optionally, a noise scheduler.

    Args:
        pretrained_model_name_or_path: A path ending in ``.ckpt`` or
            ``.safetensors`` is loaded with ``load_checkpoint_model_xl``;
            anything else with ``load_diffusers_model_xl``.
        scheduler_name: Name passed to ``create_noise_scheduler``. A falsy
            value skips scheduler creation.
        weight_dtype: Forwarded to the selected loader.
        noise_scheduler_kwargs: Forwarded to ``create_noise_scheduler``.

    Returns:
        ``(tokenizers, text_encoders, unet, scheduler, vae)``; ``scheduler``
        is ``None`` when ``scheduler_name`` is falsy.
    """
    is_single_file = pretrained_model_name_or_path.endswith((".ckpt", ".safetensors"))
    loader = load_checkpoint_model_xl if is_single_file else load_diffusers_model_xl
    tokenizers, text_encoders, unet, vae = loader(
        pretrained_model_name_or_path, weight_dtype
    )

    scheduler = None
    if scheduler_name:
        scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)

    return tokenizers, text_encoders, unet, scheduler, vae
def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    noise_scheduler_kwargs=None,
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    """Instantiate a diffusers noise scheduler by name.

    Args:
        scheduler_name: One of ``ddim``, ``ddpm``, ``lms``, ``euler_a``,
            ``euler`` or ``uniPC``. Matching is case-insensitive and spaces
            are treated as underscores.
        noise_scheduler_kwargs: Constructor keyword arguments, either as an
            OmegaConf config or as a plain mapping. When ``None``, the
            module-level ``NUM_TRAIN_TIMESTEPS`` / ``BETA_START`` /
            ``BETA_END`` values are used with a ``"scaled_linear"`` beta
            schedule. Previously ``None`` raised inside
            ``OmegaConf.to_container``.
        prediction_type: Passed to the scheduler unless
            ``noise_scheduler_kwargs`` already contains ``prediction_type``.
            Previously this argument was ignored.

    Returns:
        The constructed scheduler.

    Raises:
        ValueError: If ``scheduler_name`` is not recognised.
    """
    # Scheduler docs: https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/overview
    scheduler_classes = {
        "ddim": DDIMScheduler,
        "ddpm": DDPMScheduler,
        "lms": LMSDiscreteScheduler,
        "euler_a": EulerAncestralDiscreteScheduler,
        "euler": EulerDiscreteScheduler,
        "unipc": UniPCMultistepScheduler,
    }

    name = scheduler_name.lower().replace(" ", "_")
    scheduler_cls = scheduler_classes.get(name)
    if scheduler_cls is None:
        raise ValueError(f"Unknown scheduler name: {name}")

    if noise_scheduler_kwargs is None:
        # No explicit configuration: fall back to this module's schedule constants.
        kwargs = {
            "num_train_timesteps": NUM_TRAIN_TIMESTEPS,
            "beta_start": BETA_START,
            "beta_end": BETA_END,
            "beta_schedule": "scaled_linear",
        }
    elif OmegaConf.is_config(noise_scheduler_kwargs):
        kwargs = OmegaConf.to_container(noise_scheduler_kwargs)
    else:
        # Plain mapping: copy it so the caller's object is never mutated.
        kwargs = dict(noise_scheduler_kwargs)

    # An explicit prediction_type in the kwargs wins over the argument.
    kwargs.setdefault("prediction_type", prediction_type)

    return scheduler_cls(**kwargs)
def torch_gc():
    """Run the Python garbage collector, then release cached CUDA memory.

    The CUDA steps (``empty_cache`` and ``ipc_collect``) are skipped when
    ``torch.cuda.is_available()`` is false.
    """
    import gc

    gc.collect()

    if not torch.cuda.is_available():
        return

    with torch.cuda.device("cuda"):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
from enum import Enum | |
class CPUState(Enum):
    """Device family that `get_torch_device` selects from."""

    GPU = 0
    CPU = 1
    MPS = 2


# Module-wide backend state. The import-time probes further down may
# overwrite `cpu_state` and `xpu_available`.
cpu_state = CPUState.GPU
xpu_available = False
directml_enabled = False
def is_intel_xpu():
    """Return True only when the state is GPU and an XPU was detected."""
    # The module-level flags are only read here, so no `global` is needed.
    if cpu_state != CPUState.GPU:
        return False
    return True if xpu_available else False
# Best-effort backend probing at import time. A failure in either probe just
# leaves the corresponding default in place. `except Exception` keeps that
# behaviour while no longer swallowing KeyboardInterrupt / SystemExit the way
# a bare `except:` did.
try:
    # Imported before the `torch.xpu` query below; the alias itself is unused.
    import intel_extension_for_pytorch as ipex

    if torch.xpu.is_available():
        xpu_available = True
except Exception:
    pass

try:
    if torch.backends.mps.is_available():
        # The state is switched before the import, so it stays MPS even if
        # `import torch.mps` fails.
        cpu_state = CPUState.MPS
        import torch.mps
except Exception:
    pass
def get_torch_device():
    """Return the device selected by the module-level backend state.

    Checked in order: DirectML (when `directml_enabled`), MPS, CPU, Intel
    XPU, and finally the current CUDA device.
    """
    # The module-level names are only read here, so no `global` is needed.
    if directml_enabled:
        # NOTE(review): `directml_device` is not assigned anywhere in the
        # visible part of this file; confirm it is set wherever
        # `directml_enabled` is turned on.
        return directml_device

    if cpu_state == CPUState.MPS:
        return torch.device("mps")

    if cpu_state == CPUState.CPU:
        return torch.device("cpu")

    if is_intel_xpu():
        return torch.device("xpu")

    return torch.device(torch.cuda.current_device())