# Spaces: Running on A100
import torch | |
from diffusers import AutoencoderKL | |
from einops import rearrange | |
from torch import Tensor | |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder | |
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder | |
try: | |
import torch_xla.core.xla_model as xm | |
except ImportError: | |
xm = None | |
def vae_encode(
    media_items: Tensor,
    vae: AutoencoderKL,
    split_size: int = 1,
    vae_per_channel_normalize=False,
) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.
    The function supports processing batches of images or video frames and can handle the processing
    in smaller sub-batches if needed.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos. The channel dimension must be 3.
        vae (AutoencoderKL): The VAE used for encoding. `VideoAutoencoder` and
            `CausalVideoAutoencoder` instances receive 5-D input unchanged; for any other VAE,
            5-D input is folded into a batch of frames first.
        split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
            If set to more than 1, the input media items are processed in smaller batches according to
            this value. Defaults to 1, which processes all items in a single batch.
        vae_per_channel_normalize (bool, optional): Forwarded to `normalize_latents` to select
            how the sampled latents are normalized. Defaults to False.

    Returns:
        Tensor: The normalized latents. If the input was folded into frames, the result is
            reshaped back to (batch_size, channels, frames, height, width).

    Raises:
        ValueError: If the input does not have 3 channels, or if `split_size` is greater than 1
            and does not evenly divide the (possibly frame-folded) batch.

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> videos = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
        >>> latents = vae_encode(videos, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        A video is encoded frame-by-frame only when `vae` is not a video autoencoder;
        video autoencoders receive the whole 5-D tensor.
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    # Only non-video VAEs need the frame axis folded into the batch axis.
    fold_frames = is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    )
    if fold_frames:
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")

    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(media_items) // split_size
        # NOTE(review): `xm` is None when torch_xla failed to import; this assumes
        # tensors can only live on an "xla" device when torch_xla is installed.
        on_xla = media_items.device.type == "xla"
        latents = []
        if on_xla:
            xm.mark_step()
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            if on_xla:
                xm.mark_step()
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    if fold_frames:
        # Undo the frame folding so callers get a video-shaped latent tensor back.
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents
def vae_decode(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool = True,
    split_size: int = 1,
    vae_per_channel_normalize=False,
) -> Tensor:
    """
    Decodes latent representations back into media items using a specified VAE model.

    Args:
        latents (Tensor): The latents to decode, either 4-D or 5-D. 5-D latents are treated as
            (batch_size, channels, frames, height, width).
        vae (AutoencoderKL): The VAE used for decoding. `VideoAutoencoder` and
            `CausalVideoAutoencoder` instances receive 5-D latents unchanged; for any other VAE,
            5-D latents are folded into a batch of frames first.
        is_video (bool, optional): Forwarded to `_run_decoder`. Defaults to True.
        split_size (int, optional): The number of sub-batches to split the latent batch into for
            decoding. Defaults to 1, which decodes all latents in a single batch.
        vae_per_channel_normalize (bool, optional): Forwarded to `_run_decoder` to select how
            the latents are un-normalized before decoding. Defaults to False.

    Returns:
        Tensor: The decoded media. If the latents were folded into frames, the result is
            reshaped back to (batch_size, channels, frames, height, width).

    Raises:
        ValueError: If `split_size` is greater than 1 and does not evenly divide the
            (possibly frame-folded) batch.
    """
    is_video_shaped = latents.dim() == 5
    batch_size = latents.shape[0]

    # Only non-video VAEs need the frame axis folded into the batch axis.
    fold_frames = is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    )
    if fold_frames:
        latents = rearrange(latents, "b c n h w -> (b n) c h w")

    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        decode_bs = len(latents) // split_size
        image_batch = [
            _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
            for latent_batch in latents.split(decode_bs)
        ]
        images = torch.cat(image_batch, dim=0)
    else:
        images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)

    if fold_frames:
        # Undo the frame folding so callers get a video-shaped tensor back.
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images
def _run_decoder(
    latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False
) -> Tensor:
    """
    Un-normalizes one batch of latents and decodes it with `vae`.

    For `VideoAutoencoder` / `CausalVideoAutoencoder` instances the latents are cast to
    `vae.dtype` and the decoder is given an explicit `target_shape` derived from the last
    three latent dimensions and the VAE scale factors. Any other VAE is called with the
    un-normalized latents only.

    Args:
        latents (Tensor): The latents to decode.
        vae (AutoencoderKL): The VAE used for decoding.
        is_video (bool): For video autoencoders, whether the requested output keeps
            `frames * temporal_scale` frames (True) or a single frame (False).
        vae_per_channel_normalize (bool, optional): Forwarded to `un_normalize_latents`.

    Returns:
        Tensor: The first element of the tuple returned by `vae.decode(..., return_dict=False)`.
    """
    if not isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
        raw_latents = un_normalize_latents(latents, vae, vae_per_channel_normalize)
        return vae.decode(raw_latents, return_dict=False)[0]

    # Trailing latent dimensions: frames, height, width.
    *_, num_frames, height, width = latents.shape
    temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
    latents = latents.to(vae.dtype)

    raw_latents = un_normalize_latents(latents, vae, vae_per_channel_normalize)
    out_frames = num_frames * temporal_scale if is_video else 1
    target_shape = (
        1,
        3,
        out_frames,
        height * spatial_scale,
        width * spatial_scale,
    )
    decoded = vae.decode(raw_latents, return_dict=False, target_shape=target_shape)
    return decoded[0]
def get_vae_size_scale_factor(vae: AutoencoderKL) -> tuple:
    """
    Returns the temporal and spatial scale factors of `vae`.

    For a `CausalVideoAutoencoder` the factors are read from its
    `spatial_downscale_factor` and `temporal_downscale_factor` attributes. Otherwise they
    are computed as the configured patch size multiplied by 2 for every encoder down block
    whose `downsample` is a `Downsample3D`; the temporal factor is 1 unless `vae` is a
    `VideoAutoencoder`.

    Args:
        vae (AutoencoderKL): The VAE to inspect.

    Returns:
        tuple: `(temporal, spatial, spatial)`. The spatial factor is repeated so the tuple
            lines up with the (frames, height, width) axes.
    """
    if isinstance(vae, CausalVideoAutoencoder):
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        # Count the down blocks that actually downsample with a Downsample3D layer.
        down_blocks = sum(
            isinstance(block.downsample, Downsample3D)
            for block in vae.encoder.down_blocks
        )
        spatial = vae.config.patch_size * 2**down_blocks
        temporal = (
            vae.config.patch_size_t * 2**down_blocks
            if isinstance(vae, VideoAutoencoder)
            else 1
        )

    return (temporal, spatial, spatial)
def normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    """
    Normalizes latents produced by `vae`.

    Args:
        latents (Tensor): The latents to normalize, with channels on dimension 1.
        vae (AutoencoderKL): The VAE providing the normalization constants.
        vae_per_channel_normalize (bool, optional): If True, subtract `vae.mean_of_means` and
            divide by `vae.std_of_means`, broadcast along the channel dimension. If False,
            multiply by `vae.config.scaling_factor`. Defaults to False.

    Returns:
        Tensor: The normalized latents.
    """
    if not vae_per_channel_normalize:
        return latents * vae.config.scaling_factor

    # Reshape the statistics to (1, C, 1, ..., 1) matching the latents' rank, so that both
    # 5-D video latents and 4-D frame-folded latents broadcast per channel.
    stats_shape = (1, -1) + (1,) * (latents.dim() - 2)
    mean = vae.mean_of_means.to(latents.dtype).view(stats_shape)
    std = vae.std_of_means.to(latents.dtype).view(stats_shape)
    return (latents - mean) / std
def un_normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    """
    Reverses the normalization applied to latents before they are decoded by `vae`.

    Args:
        latents (Tensor): The normalized latents, with channels on dimension 1.
        vae (AutoencoderKL): The VAE providing the normalization constants.
        vae_per_channel_normalize (bool, optional): If True, multiply by `vae.std_of_means`
            and add `vae.mean_of_means`, broadcast along the channel dimension. If False,
            divide by `vae.config.scaling_factor`. Defaults to False.

    Returns:
        Tensor: The un-normalized latents.
    """
    if not vae_per_channel_normalize:
        return latents / vae.config.scaling_factor

    # Reshape the statistics to (1, C, 1, ..., 1) matching the latents' rank, so that both
    # 5-D video latents and 4-D frame-folded latents broadcast per channel.
    stats_shape = (1, -1) + (1,) * (latents.dim() - 2)
    mean = vae.mean_of_means.to(latents.dtype).view(stats_shape)
    std = vae.std_of_means.to(latents.dtype).view(stats_shape)
    return latents * std + mean