Flax Latent Consistency Model (LCM) LoRA: SDXL UNet
SDXL UNet with the LCM LoRA weights merged in (lora_scale=0.7), converted for use with Flax.
Setup
To run on TPUs, install the `lcm_flax` branch of the diffusers fork together with JAX for TPU:
git clone https://github.com/entrpn/diffusers
cd diffusers
git checkout lcm_flax
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install transformers flax torch torchvision
pip install .
Run
import os
from diffusers import FlaxStableDiffusionXLPipeline
import torch
import time
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import numpy as np
from jax.experimental.compilation_cache import compilation_cache as cc
# Enable JAX's persistent compilation cache, stored under ~/jax_cache.
# NOTE(review): newer JAX releases deprecate `initialize_cache` in favour of the
# `jax_compilation_cache_dir` config option -- confirm against the installed JAX.
_jax_cache_dir = os.path.expanduser("~/jax_cache")
cc.initialize_cache(_jax_cache_dir)
from diffusers import (
FlaxUNet2DConditionModel,
FlaxLCMScheduler
)
# Load the SDXL base pipeline, then swap in the LCM-merged UNet and the LCM scheduler.
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
weight_dtype = jnp.bfloat16
revision = "refs/pr/95"

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    base_model, revision=revision, dtype=weight_dtype
)
# Discard the base UNet weights before loading the replacement
# (presumably to avoid holding both UNets in memory at once).
del params["unet"]

unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "jffacevedo/flax_lcm_unet", dtype=weight_dtype
)
scheduler, scheduler_state = FlaxLCMScheduler.from_pretrained(
    base_model, subfolder="scheduler", revision=revision, dtype=jnp.float32
)

# Wire the replacement UNet and scheduler into the pipeline and its params.
pipeline.unet = unet
pipeline.scheduler = scheduler
params["unet"] = unet_params

# Cast every parameter leaf to weight_dtype. The scheduler state is attached
# only afterwards, so it is NOT cast (it was loaded with dtype=jnp.float32).
params = jax.tree_util.tree_map(lambda leaf: leaf.astype(weight_dtype), params)
params["scheduler"] = scheduler_state
# Default arguments for generate(). guidance_scale=1.0 goes together with the
# pipeline call in generate(), which runs with classifier-free guidance disabled.
default_prompt, default_neg_prompt = (
    "high-quality photo of a baby dolphin playing in a pool and wearing a party hat",
    "",
)
default_seed, default_guidance_scale, default_num_steps = 42, 1.0, 4
def tokenize_prompt(prompt, neg_prompt):
    """Tokenize a prompt and its negative prompt with the module-level pipeline.

    Returns:
        A ``(prompt_ids, neg_prompt_ids)`` tuple, each element being the result
        of ``pipeline.prepare_inputs`` on the corresponding text.
    """
    return tuple(pipeline.prepare_inputs(text) for text in (prompt, neg_prompt))
NUM_DEVICES = jax.device_count()
# Replicate the full parameter tree (including the scheduler state) once, up
# front, so each generate() call only has to replicate its inputs.
p_params = replicate(params)


def replicate_all(prompt_ids, neg_prompt_ids, seed):
    """Replicate token ids across devices and derive one PRNG key per device.

    Args:
        prompt_ids: Token ids for the prompt (from ``tokenize_prompt``).
        neg_prompt_ids: Token ids for the negative prompt.
        seed: Integer seed for the base ``jax.random.PRNGKey``.

    Returns:
        ``(p_prompt_ids, p_neg_prompt_ids, rng)`` where ``rng`` holds
        ``NUM_DEVICES`` keys split from the base key.
    """
    # NOTE(review): jax.device_count() is the global device count, while flax's
    # replicate() targets local devices; these differ on multi-host setups.
    sharded_ids = [replicate(ids) for ids in (prompt_ids, neg_prompt_ids)]
    device_keys = jax.random.split(jax.random.PRNGKey(seed), NUM_DEVICES)
    return (*sharded_ids, device_keys)
def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    """Run the jitted, replicated pipeline and return the images as PIL images.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt. It is tokenized and replicated,
            but NOTE(review): its ids are never passed to the pipeline and
            classifier-free guidance is disabled, so it has no effect here.
        seed: Seed for the base PRNG key that is split across devices.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call.

    Returns:
        The result of ``pipeline.numpy_to_pil`` on the flattened image batch.
    """
    token_ids = tokenize_prompt(prompt, negative_prompt)
    sharded_prompt_ids, _unused_neg_ids, device_keys = replicate_all(*token_ids, seed)

    output = pipeline(
        sharded_prompt_ids,
        p_params,
        device_keys,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        do_classifier_free_guidance=False,
        jit=True,
    )
    batch = output.images
    print("images.shape: ", batch.shape)

    # Merge the two leading axes into one batch axis, keeping the trailing three
    # (presumably (devices, per-device batch, H, W, C) -> (N, H, W, C); confirm).
    lead_a, lead_b = batch.shape[:2]
    flat = batch.reshape((lead_a * lead_b, *batch.shape[-3:]))
    return pipeline.numpy_to_pil(np.array(flat))
# Benchmark driver.
# The first generate() call triggers compilation (and runs one inference), so
# it is timed separately from the steady-state batches below.
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
print("Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.perf_counter() - start}")

num_batches = 2
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
neg_prompt = ""

dts = []  # per-batch wall-clock inference times (seconds)
i = 0  # running image index across all batches; used for output filenames
for _ in range(num_batches):
    start = time.perf_counter()
    print(f"Prompt: {prompt}")
    images = generate(prompt, neg_prompt)
    t = time.perf_counter() - start  # measured before saving, so disk I/O is excluded
    print(f"Inference in {t}")
    dts.append(t)
    for img in images:
        img.save(f'{i:06d}.jpg')
        i += 1

mean = np.mean(dts)
# Standard error of the mean uses the sample std (ddof=1); guard against a
# single batch, where ddof=1 would yield NaN. With few batches the 1.96 (normal)
# factor understates the true 95% interval -- treat it as approximate.
stdev = np.std(dts, ddof=1) if len(dts) > 1 else 0.0
half_width = stdev * 1.96 / np.sqrt(len(dts))
# Report the batch count (len(dts)); `i` counts images, not batches.
print(f"batches: {len(dts)}, images: {i}, Mean {mean:.2f} sec/batch ± {half_width:.2f} (95%)")
- Downloads last month
- 9