"""Utilities for loading Stable Diffusion / SDXL components and schedulers.

The module exposes loaders for diffusers-format model folders and for
single-file checkpoints (``.ckpt`` / ``.safetensors``), a factory for noise
schedulers, and a few small helpers for picking a torch device.
"""

import gc
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from diffusers import (
    UNet2DConditionModel,
    SchedulerMixin,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoencoderKL,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    convert_ldm_unet_checkpoint,
)
from safetensors.torch import load_file
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    LMSDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    UniPCMultistepScheduler,
)
from omegaconf import OmegaConf

# Model parameters of the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8
# UNET_PARAMS_USE_LINEAR_PROJECTION = False

VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True

TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"

AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "euler", "uniPC"]

SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]

DIFFUSERS_CACHE_DIR = None  # if you want to change the cache dir, change this

# File suffixes that select the single-file checkpoint loaders.
_SINGLE_FILE_SUFFIXES = (".ckpt", ".safetensors")

# Normalized scheduler name -> scheduler class used by create_noise_scheduler.
# Docs: https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/
_SCHEDULER_CLASSES = {
    "ddim": DDIMScheduler,  # .../schedulers/ddim
    "ddpm": DDPMScheduler,  # .../schedulers/ddpm
    "lms": LMSDiscreteScheduler,  # .../schedulers/lms_discrete
    "euler_a": EulerAncestralDiscreteScheduler,  # .../schedulers/euler_ancestral
    "euler": EulerDiscreteScheduler,  # .../schedulers/euler
    "unipc": UniPCMultistepScheduler,  # .../schedulers/unipc
}


def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
    """Load a checkpoint file and normalize its text-encoder key names.

    Some checkpoints store the text encoder without the ``text_model``
    path component; those keys are renamed so that it is present.

    Args:
        ckpt_path: Path of the checkpoint. Files ending in ``.safetensors``
            are read with ``safetensors.torch.load_file``; anything else is
            read with ``torch.load``.
        device: ``map_location`` for ``torch.load``. It is not used for
            ``.safetensors`` files.

    Returns:
        ``(checkpoint, state_dict)``. ``checkpoint`` is the full loaded
        object only when it wraps the weights under a ``"state_dict"`` key;
        otherwise it is ``None``.
    """
    # Support models whose text encoder is stored in a different format
    # (without 'text_model').
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        (
            "cond_stage_model.transformer.embeddings.",
            "cond_stage_model.transformer.text_model.embeddings.",
        ),
        (
            "cond_stage_model.transformer.encoder.",
            "cond_stage_model.transformer.text_model.encoder.",
        ),
        (
            "cond_stage_model.transformer.final_layer_norm.",
            "cond_stage_model.transformer.text_model.final_layer_norm.",
        ),
    ]

    if ckpt_path.endswith(".safetensors"):
        checkpoint = None
        state_dict = load_file(ckpt_path)  # , device)  # may cause an error
    else:
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # sources you trust.
        checkpoint = torch.load(ckpt_path, map_location=device)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
            checkpoint = None

    # Collect the renames first so the dict is not mutated while iterating.
    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict:
            if key.startswith(rep_from):
                key_reps.append((key, rep_to + key[len(rep_from):]))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict.pop(key)

    return checkpoint, state_dict


def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
    """Build the keyword config for a diffusers ``UNet2DConditionModel``.

    The values come from the module-level ``UNET_PARAMS_*`` constants
    (and the ``V2_UNET_PARAMS_*`` ones when ``v2`` is truthy).

    Args:
        v2: Select the V2 cross-attention dim and attention head dims.
        use_linear_projection_in_v2: When both this and ``v2`` are truthy,
            ``use_linear_projection=True`` is added to the config.

    Returns:
        A dict of constructor keyword arguments.
    """
    # unet_params = original_config.model.params.unet_config.params

    block_out_channels = [
        UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
    ]
    last_index = len(block_out_channels) - 1

    # Going down, the resolution factor doubles after every block but the
    # last one; blocks at an attention resolution get cross attention.
    down_block_types = []
    resolution = 1
    for index in range(len(block_out_channels)):
        if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS:
            down_block_types.append("CrossAttnDownBlock2D")
        else:
            down_block_types.append("DownBlock2D")
        if index != last_index:
            resolution *= 2

    # Going up, start from the deepest resolution factor and halve it.
    up_block_types = []
    for _ in block_out_channels:
        if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS:
            up_block_types.append("CrossAttnUpBlock2D")
        else:
            up_block_types.append("UpBlock2D")
        resolution //= 2

    config = dict(
        sample_size=UNET_PARAMS_IMAGE_SIZE,
        in_channels=UNET_PARAMS_IN_CHANNELS,
        out_channels=UNET_PARAMS_OUT_CHANNELS,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
        if not v2
        else V2_UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=UNET_PARAMS_NUM_HEADS
        if not v2
        else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
        # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
    )
    if v2 and use_linear_projection_in_v2:
        config["use_linear_projection"] = True

    return config


def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
    """Load tokenizer, text encoder, UNet and VAE from a diffusers model.

    Args:
        pretrained_model_name_or_path: Passed to ``from_pretrained`` for the
            text encoder, UNet and VAE. The tokenizer is always taken from
            ``TOKENIZER_V1_MODEL_NAME`` / ``TOKENIZER_V2_MODEL_NAME``.
        v2: Select the V2 tokenizer and a 24-layer text-encoder base.
        clip_skip: When given, the text encoder is loaded with
            ``total_layers - (clip_skip - 1)`` hidden layers. When ``None``
            the count is 23 for v2 (i.e. clip skip 2) and 12 otherwise.
        weight_dtype: ``torch_dtype`` for the text encoder and UNet. The VAE
            is loaded without a ``torch_dtype`` argument.

    Returns:
        ``(tokenizer, text_encoder, unet, vae)``.
    """
    tokenizer_name = TOKENIZER_V2_MODEL_NAME if v2 else TOKENIZER_V1_MODEL_NAME
    total_layers, default_layers = (24, 23) if v2 else (12, 12)
    if clip_skip is None:
        num_hidden_layers = default_layers
    else:
        num_hidden_layers = total_layers - (clip_skip - 1)

    tokenizer = CLIPTokenizer.from_pretrained(
        tokenizer_name,
        subfolder="tokenizer",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        num_hidden_layers=num_hidden_layers,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    # cache_dir is passed so the VAE honors DIFFUSERS_CACHE_DIR like the
    # other components do.
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    return tokenizer, text_encoder, unet, vae


def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
    """Load tokenizer, text encoder, UNet and VAE from a single-file checkpoint.

    The tokenizer, text encoder and VAE are taken from a
    ``StableDiffusionPipeline`` built with ``from_single_file``. The UNet is
    rebuilt from the raw state dict using ``create_unet_diffusers_config``
    and ``convert_ldm_unet_checkpoint``.

    Args:
        checkpoint_path: Path of the ``.ckpt`` / ``.safetensors`` file.
        v2: Enables ``upcast_attention`` for the pipeline and the V2 UNet
            config (including linear projection).
        clip_skip: When given, ``text_encoder.config.num_hidden_layers`` is
            set to ``(24 if v2 else 12) - (clip_skip - 1)``.
        weight_dtype: ``torch_dtype`` for the pipeline; the rebuilt UNet is
            cast to it as well.

    Returns:
        ``(tokenizer, text_encoder, unet, vae)``.
    """
    pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        upcast_attention=bool(v2),
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)
    unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
    unet_config["class_embed_type"] = None
    unet_config["addition_embed_type"] = None
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)
    # A freshly constructed UNet uses the default dtype; cast it so it
    # matches the other components, which were loaded with weight_dtype.
    unet.to(dtype=weight_dtype)

    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    if clip_skip is not None:
        # NOTE(review): this only changes the config value on an already
        # constructed encoder; confirm it actually changes how many layers
        # are run.
        text_encoder.config.num_hidden_layers = (24 if v2 else 12) - (clip_skip - 1)

    del pipe

    return tokenizer, text_encoder, unet, vae


def _is_single_file_checkpoint(path: str) -> bool:
    """Return True when *path* ends with a single-file checkpoint suffix."""
    return path.endswith(_SINGLE_FILE_SUFFIXES)


def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    CLIPTokenizer,
    CLIPTextModel,
    UNet2DConditionModel,
    Optional[SchedulerMixin],
    AutoencoderKL,
]:
    """Load the Stable Diffusion components and, optionally, a scheduler.

    Args:
        pretrained_model_name_or_path: A ``.ckpt`` / ``.safetensors`` path
            (loaded with ``load_checkpoint_model``) or anything else (loaded
            with ``load_diffusers_model``).
        scheduler_name: Name for ``create_noise_scheduler``. When falsy, no
            scheduler is created.
        v2: Forwarded to the model loader.
        v_pred: Use ``"v_prediction"`` instead of ``"epsilon"`` as the
            scheduler prediction type.
        weight_dtype: Forwarded to the model loader.

    Returns:
        ``(tokenizer, text_encoder, unet, scheduler, vae)`` where
        ``scheduler`` is ``None`` if ``scheduler_name`` is falsy.
    """
    if _is_single_file_checkpoint(pretrained_model_name_or_path):
        loader = load_checkpoint_model
    else:  # diffusers
        loader = load_diffusers_model
    tokenizer, text_encoder, unet, vae = loader(
        pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
    )

    scheduler = None
    if scheduler_name:
        scheduler = create_noise_scheduler(
            scheduler_name,
            prediction_type="v_prediction" if v_pred else "epsilon",
        )

    return tokenizer, text_encoder, unet, scheduler, vae


def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    AutoencoderKL,
]:
    """Load the SDXL components from a diffusers-format model.

    Args:
        pretrained_model_name_or_path: Passed to every ``from_pretrained``.
        weight_dtype: ``torch_dtype`` for the text encoders and UNet. The
            VAE is loaded without a ``torch_dtype`` argument.

    Returns:
        ``(tokenizers, text_encoders, unet, vae)`` where ``tokenizers`` is
        ``[tokenizer, tokenizer_2]`` and ``text_encoders`` is
        ``[text_encoder, text_encoder_2]``.
    """
    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,  # same as open clip
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    # cache_dir is passed so the VAE honors DIFFUSERS_CACHE_DIR like the
    # other components do.
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    return tokenizers, text_encoders, unet, vae


def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    AutoencoderKL,
]:
    """Load the SDXL components from a single-file checkpoint.

    Args:
        checkpoint_path: Path handed to
            ``StableDiffusionXLPipeline.from_single_file``.
        weight_dtype: ``torch_dtype`` for the pipeline.

    Returns:
        ``(tokenizers, text_encoders, unet, vae)`` with the same layout as
        ``load_diffusers_model_xl``.
    """
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    vae = pipe.vae
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]

    # Match load_diffusers_model_xl, which gives tokenizer_2 pad_token_id=0.
    tokenizers[1].pad_token_id = 0
    # Kept for backward compatibility: earlier code set this attribute on
    # the second text encoder.
    text_encoders[1].pad_token_id = 0

    del pipe

    return tokenizers, text_encoders, unet, vae


def load_models_xl(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    weight_dtype: torch.dtype = torch.float32,
    noise_scheduler_kwargs=None,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    Optional[SchedulerMixin],
    AutoencoderKL,
]:
    """Load the SDXL components and, optionally, a scheduler.

    Args:
        pretrained_model_name_or_path: A ``.ckpt`` / ``.safetensors`` path
            (loaded with ``load_checkpoint_model_xl``) or anything else
            (loaded with ``load_diffusers_model_xl``).
        scheduler_name: Name for ``create_noise_scheduler``. When falsy, no
            scheduler is created.
        weight_dtype: Forwarded to the model loader.
        noise_scheduler_kwargs: Forwarded to ``create_noise_scheduler``.

    Returns:
        ``(tokenizers, text_encoders, unet, scheduler, vae)`` where
        ``scheduler`` is ``None`` if ``scheduler_name`` is falsy.
    """
    if _is_single_file_checkpoint(pretrained_model_name_or_path):
        loader = load_checkpoint_model_xl
    else:  # diffusers
        loader = load_diffusers_model_xl
    tokenizers, text_encoders, unet, vae = loader(
        pretrained_model_name_or_path, weight_dtype
    )

    scheduler = None
    if scheduler_name:
        scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)

    return tokenizers, text_encoders, unet, scheduler, vae


def _resolve_scheduler_kwargs(
    noise_scheduler_kwargs: Optional[Any], prediction_type: str
) -> Dict[str, Any]:
    """Turn the caller's scheduler settings into a plain kwargs dict."""
    if noise_scheduler_kwargs is None:
        # No configuration supplied: fall back to the module-level
        # constants.
        # NOTE(review): "scaled_linear" is assumed to be the beta schedule
        # that goes with these beta values — confirm.
        kwargs = {
            "num_train_timesteps": NUM_TRAIN_TIMESTEPS,
            "beta_start": BETA_START,
            "beta_end": BETA_END,
            "beta_schedule": "scaled_linear",
        }
    elif OmegaConf.is_config(noise_scheduler_kwargs):
        kwargs = OmegaConf.to_container(noise_scheduler_kwargs)
    else:
        kwargs = dict(noise_scheduler_kwargs)

    # An explicit prediction_type inside the kwargs takes precedence.
    kwargs.setdefault("prediction_type", prediction_type)
    return kwargs


def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    noise_scheduler_kwargs=None,
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    """Instantiate a diffusers noise scheduler by name.

    Args:
        scheduler_name: One of ``AVAILABLE_SCHEDULERS``. Matching is
            case-insensitive and spaces are treated as underscores.
        noise_scheduler_kwargs: Constructor arguments for the scheduler, as
            an OmegaConf config or a plain mapping. When ``None``, defaults
            built from ``NUM_TRAIN_TIMESTEPS``, ``BETA_START`` and
            ``BETA_END`` are used.
        prediction_type: Passed to the scheduler unless
            ``noise_scheduler_kwargs`` already contains a
            ``prediction_type`` entry.

    Returns:
        The constructed scheduler.

    Raises:
        ValueError: If ``scheduler_name`` is not a known scheduler.
    """
    name = scheduler_name.lower().replace(" ", "_")
    scheduler_cls = _SCHEDULER_CLASSES.get(name)
    if scheduler_cls is None:
        raise ValueError(f"Unknown scheduler name: {name}")

    kwargs = _resolve_scheduler_kwargs(noise_scheduler_kwargs, prediction_type)
    return scheduler_cls(**kwargs)


def torch_gc():
    """Run the Python garbage collector and release cached CUDA memory."""
    gc.collect()

    if torch.cuda.is_available():
        with torch.cuda.device("cuda"):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


class CPUState(Enum):
    """Which kind of device the module should hand out."""

    GPU = 0
    CPU = 1
    MPS = 2


cpu_state = CPUState.GPU
xpu_available = False
directml_enabled = False


def is_intel_xpu():
    """Return True when running in GPU mode and an Intel XPU was detected."""
    return cpu_state == CPUState.GPU and xpu_available


# Best-effort backend probing at import time: a missing or broken backend
# must not prevent the module from being imported.
try:
    import intel_extension_for_pytorch as ipex  # noqa: F401

    if torch.xpu.is_available():
        xpu_available = True
except Exception:
    pass

try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps
except Exception:
    pass


def get_torch_device():
    """Return the torch device selected by the module-level state flags.

    The checks run in this order: DirectML (when ``directml_enabled``),
    MPS, CPU, Intel XPU, and finally the current CUDA device.
    """
    if directml_enabled:
        # NOTE(review): `directml_device` is never defined in this module;
        # whoever sets `directml_enabled` must also set it, otherwise this
        # raises NameError.
        global directml_device
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    if is_intel_xpu():
        return torch.device("xpu")
    return torch.device(torch.cuda.current_device())