# Spaces: Running on Zero
import torch | |
from DeepCache import DeepCacheSDHelper | |
from diffusers import ( | |
DEISMultistepScheduler, | |
DPMSolverMultistepScheduler, | |
EulerAncestralDiscreteScheduler, | |
HeunDiscreteScheduler, | |
KDPM2AncestralDiscreteScheduler, | |
LMSDiscreteScheduler, | |
PNDMScheduler, | |
StableDiffusionImg2ImgPipeline, | |
StableDiffusionPipeline, | |
) | |
from diffusers.models import AutoencoderKL, AutoencoderTiny | |
from torch._dynamo import OptimizedModule | |
from .upscaler import RealESRGAN | |
# Hide FutureWarnings attributed to diffusers modules; `module` is a regex that
# must match the start of the warning's fully qualified module name.
__import__("warnings").filterwarnings(
    "ignore",
    module="diffusers",
    category=FutureWarning,
)
# inspired by ComfyUI
# https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_management.py
class Loader:
    """Singleton that loads and caches one Stable Diffusion pipeline plus upscaler.

    ``load()`` reuses the cached pipeline where possible and only swaps the
    parts whose settings changed: pipeline kind, model, scheduler, VAE, FreeU,
    DeepCache and the upscaler.
    """

    # The shared instance, created lazily by __new__.
    _instance = None

    def __new__(cls):
        # Every Loader() call returns the same object; pipe/upscaler start empty.
        if cls._instance is None:
            cls._instance = super(Loader, cls).__new__(cls)
            cls._instance.pipe = None
            cls._instance.upscaler = None
        return cls._instance

    def _load_upscaler(self, device=None, scale=4):
        """Make ``self.upscaler`` match the requested upscale factor.

        A scale of 1 removes the upscaler. A scale above 1 loads a new
        RealESRGAN model unless one with the same scale is already loaded.
        """
        same_scale = self.upscaler is not None and self.upscaler.scale == scale
        if scale == 1:
            self.upscaler = None
        if scale > 1 and not same_scale:
            self.upscaler = RealESRGAN(device=device, scale=scale)
            self.upscaler.load_weights()

    def _load_deepcache(self, interval=1):
        """Attach or reconfigure the DeepCache helper stored on ``self.pipe``.

        Does nothing when a helper with the same cache interval is already
        attached. Otherwise it sets the interval and (re-)enables the helper.
        """
        has_deepcache = hasattr(self.pipe, "deepcache")
        if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
            return
        if has_deepcache:
            # turn the existing helper off before changing its parameters
            self.pipe.deepcache.disable()
        else:
            self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
        self.pipe.deepcache.set_params(cache_interval=interval)
        self.pipe.deepcache.enable()

    def _load_freeu(self, freeu=False):
        """Enable or disable FreeU on the pipeline so it matches ``freeu``.

        FreeU counts as active when the UNet's first up block has all of its
        b1/b2/s1/s2 attributes set to non-None values.
        """
        # https://github.com/huggingface/diffusers/blob/v0.30.0/src/diffusers/models/unets/unet_2d_condition.py
        block = self.pipe.unet.up_blocks[0]
        attrs = ["b1", "b2", "s1", "s2"]
        has_freeu = all(getattr(block, attr, None) is not None for attr in attrs)
        if has_freeu and not freeu:
            self.pipe.disable_freeu()
        elif not has_freeu and freeu:
            # coefficients from https://github.com/ChenyangSi/FreeU
            self.pipe.enable_freeu(b1=1.5, b2=1.6, s1=0.9, s2=0.2)

    def _load_vae(self, model_name=None, taesd=False, variant=None):
        """Switch between the Tiny (TAESD) VAE and the model's KL VAE.

        Args:
            model_name: model id whose ``vae`` subfolder provides the KL VAE.
            taesd: use the Tiny VAE when True, the KL VAE when False.
            variant: weight variant passed to ``from_pretrained`` (e.g. "fp16").
        """
        vae_type = type(self.pipe.vae)
        # a KL VAE that went through torch.compile below is an OptimizedModule
        is_kl = issubclass(vae_type, (AutoencoderKL, OptimizedModule))
        is_tiny = issubclass(vae_type, AutoencoderTiny)

        # Load replacement VAEs in the pipeline's dtype. Without torch_dtype,
        # from_pretrained uses torch's default dtype (normally float32), which
        # would not match an fp16 pipeline.
        dtype = self.pipe.dtype

        # by default all models use KL
        if is_kl and taesd:
            # can't compile tiny VAE
            print("Switching to Tiny VAE...")
            self.pipe.vae = AutoencoderTiny.from_pretrained(
                pretrained_model_name_or_path="madebyollin/taesd",
                torch_dtype=dtype,
            ).to(self.pipe.device)
            return

        if is_tiny and not taesd:
            print("Switching to KL VAE...")
            model = AutoencoderKL.from_pretrained(
                pretrained_model_name_or_path=model_name,
                subfolder="vae",
                variant=variant,
                torch_dtype=dtype,
            ).to(self.pipe.device)
            self.pipe.vae = torch.compile(
                mode="reduce-overhead",
                fullgraph=True,
                model=model,
            )

    def _load_pipeline(self, kind, model, device, dtype, **kwargs):
        """Make ``self.pipe`` a pipeline of the requested ``kind``.

        Loads ``model`` from scratch when nothing is cached. If the cached
        pipeline is of a different kind, it is converted with ``from_pipe``,
        which reuses the existing components.

        Raises:
            KeyError: if ``kind`` is not "txt2img" or "img2img".
        """
        pipelines = {
            "txt2img": StableDiffusionPipeline,
            "img2img": StableDiffusionImg2ImgPipeline,
        }
        pipeline_cls = pipelines[kind]
        if self.pipe is None:
            self.pipe = pipeline_cls.from_pretrained(model, **kwargs).to(device, dtype)
        if not isinstance(self.pipe, pipeline_cls):
            # The DeepCache helper is stored on the old pipeline object but
            # (presumably) wraps the UNet that from_pipe shares with the new
            # pipeline. Disable it so no stale wrapper is left behind;
            # load() attaches a fresh helper afterwards via _load_deepcache.
            if hasattr(self.pipe, "deepcache"):
                self.pipe.deepcache.disable()
            self.pipe = pipeline_cls.from_pipe(self.pipe).to(device, dtype)

    def load(
        self,
        kind,
        model,
        scheduler,
        karras,
        taesd,
        freeu,
        deepcache,
        scale,
        device,
        dtype,
    ):
        """Return ``(pipe, upscaler)`` configured for the requested settings.

        The cached pipeline is reused when ``model`` matches the loaded one
        (case-insensitively). In that case only the pipeline kind, scheduler,
        VAE, FreeU, DeepCache and upscaler are adjusted. Otherwise the model is
        reloaded.

        Args:
            kind: "txt2img" or "img2img".
            model: model id passed to ``from_pretrained``; it is lowercased first.
            scheduler: a key of ``schedulers`` below, e.g. "DPM++ 2M".
            karras: use Karras sigmas (not passed to "Euler a" / "PNDM").
            taesd: use the Tiny VAE instead of the KL VAE.
            freeu: enable FreeU.
            deepcache: DeepCache cache interval.
            scale: upscale factor; 1 disables the upscaler.
            device: device the pipeline is moved to.
            dtype: dtype the pipeline is moved to.

        Returns:
            ``(self.pipe, self.upscaler)``; the upscaler may be None.

        Raises:
            KeyError: if ``kind`` or ``scheduler`` is not a known name.
        """
        model_lower = model.lower()

        schedulers = {
            "DEIS 2M": DEISMultistepScheduler,
            "DPM++ 2M": DPMSolverMultistepScheduler,
            "DPM2 a": KDPM2AncestralDiscreteScheduler,
            "Euler a": EulerAncestralDiscreteScheduler,
            "Heun": HeunDiscreteScheduler,
            "LMS": LMSDiscreteScheduler,
            "PNDM": PNDMScheduler,
        }

        scheduler_kwargs = {
            "beta_schedule": "scaled_linear",
            "timestep_spacing": "leading",
            "use_karras_sigmas": karras,
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "steps_offset": 1,
        }

        # these schedulers are constructed without use_karras_sigmas
        if scheduler in ["Euler a", "PNDM"]:
            del scheduler_kwargs["use_karras_sigmas"]

        # these models have no fp16 variant
        no_fp16_variant = {
            "sg161222/realistic_vision_v5.1_novae",
            "prompthero/openjourney-v4",
            "linaqruf/anything-v3-1",
        }
        variant = None if model_lower in no_fp16_variant else "fp16"

        pipe_kwargs = {
            "scheduler": schedulers[scheduler](**scheduler_kwargs),
            "requires_safety_checker": False,
            "safety_checker": None,
            "variant": variant,
        }

        if self.pipe is None:
            print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
            self._load_pipeline(kind, model_lower, device, dtype, **pipe_kwargs)

        model_name = self.pipe.config._name_or_path
        same_model = model_name.lower() == model_lower

        if same_model:
            # The weights match, but the cached pipeline may be the other kind
            # (txt2img <-> img2img). _load_pipeline converts it via from_pipe
            # and is a no-op when the kind already matches.
            self._load_pipeline(kind, model_lower, device, dtype, **pipe_kwargs)

            same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
            same_karras = (
                not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
                or self.pipe.scheduler.config.use_karras_sigmas == karras
            )
            if not same_scheduler:
                print(f"Switching to {scheduler}...")
            if not same_karras:
                print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
            if not same_scheduler or not same_karras:
                self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
        else:
            import gc  # only needed when swapping models

            self.pipe = None
            # The old pipeline may sit in a reference cycle (e.g. through the
            # DeepCache helper constructed with pipe=self.pipe). Collect it now
            # so its memory can be released before the next model is loaded.
            gc.collect()
            self._load_pipeline(kind, model_lower, device, dtype, **pipe_kwargs)

        self._load_vae(model_lower, taesd, variant)
        self._load_freeu(freeu)
        self._load_deepcache(deepcache)
        self._load_upscaler(device, scale)
        # hand cached-but-unused CUDA memory back to the allocator
        torch.cuda.empty_cache()
        return self.pipe, self.upscaler