Accelerated PyTorch 2.0 support in Diffusers

This documentation is for Diffusers v0.16.0.

Starting from version 0.13.0, Diffusers supports the latest optimizations from PyTorch 2.0. These include:

  1. Support for an accelerated transformers implementation with memory-efficient attention, with no extra dependencies required.
  2. torch.compile support for an extra performance boost when individual models are compiled.

Installation

To benefit from the accelerated attention implementation and `torch.compile`, install PyTorch 2.0 or later from `pip` and make sure you are on Diffusers 0.13.0 or later. As explained below, `diffusers` automatically uses the attention optimizations (but not `torch.compile`) when they are available.
pip install --upgrade torch torchvision diffusers
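Before running the examples, you may want to confirm which versions are installed. The following is a minimal sketch that uses only the standard library's `importlib.metadata`; the helper names are hypothetical, and the minimum versions (PyTorch 2.0, Diffusers 0.13.0) are the ones stated above.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum versions mentioned in this guide, as (major, minor) tuples.
REQUIREMENTS = {"torch": (2, 0), "diffusers": (0, 13)}


def parse_major_minor(raw):
    """Extract (major, minor) from a version string such as "2.0.1+cu118"."""
    match = re.match(r"(\d+)\.(\d+)", raw)
    return (int(match.group(1)), int(match.group(2))) if match else None


def installed_version(package):
    """Return (major, minor) for an installed distribution, or None if it is missing."""
    try:
        return parse_major_minor(version(package))
    except PackageNotFoundError:
        return None


for package, minimum in REQUIREMENTS.items():
    found = installed_version(package)
    status = "OK" if found is not None and found >= minimum else "install or upgrade needed"
    print(f"{package}: found {found}, need >= {minimum}: {status}")
```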

Using accelerated transformers and torch.compile

  1. Accelerated Transformers implementation

    PyTorch 2.0 includes an optimized and memory-efficient attention implementation through the torch.nn.functional.scaled_dot_product_attention function, which automatically enables several optimizations depending on the inputs and the GPU type. This is similar to the memory_efficient_attention from xFormers, but built natively into PyTorch.
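To make concrete what this function computes, the following sketch (an illustration, not Diffusers code) compares `F.scaled_dot_product_attention` with a naive softmax(QK^T / sqrt(d)) V on small random CPU tensors. It assumes PyTorch 2.0 or later; the tensor shapes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V, materializing the full attention matrix."""
    scale = 1.0 / math.sqrt(q.size(-1))
    weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return weights @ v


torch.manual_seed(0)
# Shapes: (batch, heads, sequence length, head dimension)
q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)

# The fused kernel uses the same default scale of 1 / sqrt(head dimension).
fused = F.scaled_dot_product_attention(q, k, v)
reference = naive_attention(q, k, v)
print("max abs difference:", (fused - reference).abs().max().item())
```

The naive version stores the whole sequence-by-sequence weight matrix; avoiding that is where the memory savings of the fused implementation come from.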

    These optimizations are enabled by default in Diffusers when PyTorch 2.0 is installed and torch.nn.functional.scaled_dot_product_attention is available. To use them, install PyTorch 2.0 as described above and use the pipeline as usual. For example:

    import torch
    from diffusers import DiffusionPipeline
    
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    
    prompt = "a photo of an astronaut riding a horse on mars"
    image = pipe(prompt).images[0]

    If you want to enable it explicitly (which is not required), you can do so as shown below.

    import torch
    from diffusers import DiffusionPipeline
    from diffusers.models.attention_processor import AttnProcessor2_0
    
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    pipe.unet.set_attn_processor(AttnProcessor2_0())
    
    prompt = "a photo of an astronaut riding a horse on mars"
    image = pipe(prompt).images[0]

    This should be as fast and memory-efficient as xFormers. See the benchmark below for details.

  2. torch.compile

    To get an additional speed-up, we can use the new torch.compile feature. To do so, we simply wrap the pipeline's UNet with torch.compile. For more information and other options, refer to the torch.compile documentation.

    import torch
    from diffusers import DiffusionPipeline
    
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    # Compile only the UNet, which runs at every denoising step
    pipe.unet = torch.compile(pipe.unet)
    
    batch_size = 10
    steps = 50  # number of denoising steps, as in the benchmark below
    prompt = "A photo of an astronaut riding a horse on mars."
    images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images

    Depending on the GPU, torch.compile can yield an additional speed-up of 2% to 9% over the accelerated transformers optimizations. Compilation tends to squeeze out larger improvements on more recent GPU architectures such as Ampere (A100, 3090), Ada (4090), and Hopper (H100).
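To check which of these architecture families your GPU belongs to, you can compare its CUDA compute capability against 8.0, the first Ampere version. The helper below is our own heuristic, not a Diffusers or PyTorch API; only `torch.cuda.is_available()` and `torch.cuda.get_device_capability()` come from PyTorch.

```python
def is_ampere_or_newer(capability):
    """Return True for CUDA compute capability 8.0 or higher.

    Reference points: V100 = (7, 0), T4 = (7, 5), A100 = (8, 0),
    RTX 3090 and A10 = (8, 6), RTX 4090 = (8, 9), H100 = (9, 0).
    """
    return tuple(capability) >= (8, 0)


try:
    import torch
except ImportError:  # the helper itself does not need PyTorch
    torch = None

if torch is not None and torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()  # (major, minor) of device 0
    if is_ampere_or_newer(capability):
        print(f"Compute capability {capability}: torch.compile may give larger gains.")
    else:
        print(f"Compute capability {capability}: expect smaller gains from torch.compile.")
```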

    Compilation takes some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times.
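Because the first call pays the compilation cost, it helps to time several consecutive calls separately. A minimal sketch, assuming a hypothetical helper `time_repeated_calls` (not part of Diffusers) that works with any zero-argument callable:

```python
import time


def time_repeated_calls(fn, n_runs=3, sync=None):
    """Call fn() n_runs times and return the wall-clock seconds of each call.

    With a model wrapped in torch.compile, the first duration includes the
    compilation time, and later durations show the steady-state speed.
    `sync` is an optional callable, such as torch.cuda.synchronize, used so
    that queued GPU work is finished before and after each measurement.
    """
    durations = []
    for _ in range(n_runs):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        durations.append(time.perf_counter() - start)
    return durations


# Usage with the compiled pipeline from the example above (not run here):
# durations = time_repeated_calls(
#     lambda: pipe(prompt, num_inference_steps=50), sync=torch.cuda.synchronize
# )
# print(f"first call: {durations[0]:.1f}s, later calls: {durations[1:]}")
```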

Benchmark

We ran a simple benchmark on different GPUs to compare vanilla attention, xFormers, torch.nn.functional.scaled_dot_product_attention, and torch.compile + torch.nn.functional.scaled_dot_product_attention. For the benchmark we used the stable-diffusion-v1-4 model with 50 inference steps. The xFormers benchmark was run with torch==1.13.1, while the accelerated transformers optimizations were tested with nightly builds of PyTorch 2.0. The tables below summarize the results.

Please refer to our featured blog post on the PyTorch site for more details.

FP16 benchmark

The table below shows the benchmark results for inference using fp16. As the table shows, torch.nn.functional.scaled_dot_product_attention is about as fast as xFormers (sometimes slightly faster, sometimes slightly slower) on all the GPUs we tested except the RTX 4090, where it is considerably faster. Using torch.compile on top of it gives a further speed-up of up to about 11% over xFormers on the other GPUs.

The time reported is in seconds.

| GPU | Batch size | Vanilla attention | xFormers | PyTorch 2.0 SDPA | SDPA + torch.compile | Speed over xFormers (%) |
|---|---|---|---|---|---|---|
| A100 | 1 | 2.69 | 2.7 | 1.98 | 2.47 | 8.52 |
| A100 | 2 | 3.21 | 3.04 | 2.38 | 2.78 | 8.55 |
| A100 | 4 | 5.27 | 3.91 | 3.89 | 3.53 | 9.72 |
| A100 | 8 | 9.74 | 7.03 | 7.04 | 6.62 | 5.83 |
| A100 | 10 | 12.02 | 8.7 | 8.67 | 8.45 | 2.87 |
| A100 | 16 | 18.95 | 13.57 | 13.55 | 13.20 | 2.73 |
| A100 | 32 (1) | OOM | 26.56 | 26.68 | 25.85 | 2.67 |
| A100 | 64 |  | 52.51 | 53.03 | 50.93 | 3.01 |
| A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
| A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
| A10 | 10 | 33.69 | 23.53 | 24.19 | 22.52 | 4.29 |
| A10 | 16 | OOM | 37.55 | 38.31 | 36.81 | 1.97 |
| A10 | 32 (1) |  | 77.19 | 78.43 | 76.64 | 0.71 |
| A10 | 64 (1) |  | 173.59 | 158.99 | 155.14 | 10.63 |
| T4 | 4 | 38.81 | 30.09 | 29.74 | 27.55 | 8.44 |
| T4 | 8 | OOM | 55.71 | 55.99 | 53.85 | 3.34 |
| T4 | 10 | OOM | 68.96 | 69.86 | 65.35 | 5.23 |
| T4 | 16 | OOM | 111.47 | 113.26 | 106.93 | 4.07 |
| V100 | 4 | 9.84 | 8.16 | 8.09 | 7.65 | 6.25 |
| V100 | 8 | OOM | 15.62 | 15.44 | 14.59 | 6.59 |
| V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
| V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
| 3090 | 1 | 2.94 | 2.5 | 2.42 | 2.33 | 6.80 |
| 3090 | 4 | 10.04 | 7.82 | 7.72 | 7.38 | 5.63 |
| 3090 | 8 | 19.27 | 14.97 | 14.88 | 14.15 | 5.48 |
| 3090 | 10 | 24.08 | 18.7 | 18.62 | 18.12 | 3.10 |
| 3090 | 16 | OOM | 29.06 | 28.88 | 28.2 | 2.96 |
| 3090 | 32 (1) |  | 58.05 | 57.42 | 56.28 | 3.05 |
| 3090 | 64 (1) |  | 126.54 | 114.27 | 112.21 | 11.32 |
| 3090 Ti | 1 | 2.7 | 2.26 | 2.19 | 2.12 | 6.19 |
| 3090 Ti | 4 | 9.07 | 7.14 | 7.00 | 6.71 | 6.02 |
| 3090 Ti | 8 | 17.51 | 13.65 | 13.53 | 12.94 | 5.20 |
| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.77 | 16.44 | 2.43 |
| 3090 Ti | 16 | OOM | 26.1 | 26.04 | 25.53 | 2.18 |
| 3090 Ti | 32 (1) |  | 51.78 | 51.71 | 50.91 | 1.68 |
| 3090 Ti | 64 (1) |  | 112.02 | 102.78 | 100.89 | 9.94 |
| 4090 | 1 | 4.47 | 3.98 | 1.28 | 1.21 | 69.60 |
| 4090 | 4 | 10.48 | 8.37 | 3.76 | 3.56 | 57.47 |
| 4090 | 8 | 14.33 | 10.22 | 7.43 | 6.99 | 31.60 |
| 4090 | 16 |  | 17.07 | 14.98 | 14.58 | 14.59 |
| 4090 | 32 (1) |  | 39.03 | 30.18 | 29.49 | 24.44 |
| 4090 | 64 (1) |  | 77.29 | 61.34 | 59.96 | 22.42 |
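The table does not state how the last column is computed. Judging from the numbers, it appears to be the percentage of time saved by SDPA + torch.compile relative to xFormers. A small sketch that reproduces two rows under that assumption (`speedup_percent` is a hypothetical helper):

```python
def speedup_percent(baseline_s, optimized_s):
    """Percentage of time saved relative to the baseline (both in seconds)."""
    return (baseline_s - optimized_s) / baseline_s * 100


# A100, batch size 1: xFormers 2.7 s vs. SDPA + torch.compile 2.47 s
print(round(speedup_percent(2.7, 2.47), 2))  # 8.52
# RTX 4090, batch size 1: xFormers 3.98 s vs. SDPA + torch.compile 1.21 s
print(round(speedup_percent(3.98, 1.21), 2))  # 69.6
```

With vanilla attention as the baseline, the same formula also matches the "Speed over vanilla" column of the FP32 table, for example (4.97 - 2.86) / 4.97 ≈ 42.45% for the A100 at batch size 1.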

FP32 benchmark

The table below shows the benchmark results for inference using fp32. In this case, torch.nn.functional.scaled_dot_product_attention is faster than xFormers on all the GPUs we tested except the T4, where the two perform about the same.

Using torch.compile in addition to the accelerated transformers implementation gives the largest gains on Ampere and Ada cards: in the table below, the speed-up over xFormers reaches about 39% on the A100 and 48% on the RTX 4090, and the speed-up over vanilla attention reaches about 55% on the A100 and 53% on the RTX 4090.

| GPU | Batch size | Vanilla attention | xFormers | PyTorch 2.0 SDPA | SDPA + torch.compile | Speed over xFormers (%) | Speed over vanilla (%) |
|---|---|---|---|---|---|---|---|
| A100 | 1 | 4.97 | 3.86 | 2.6 | 2.86 | 25.91 | 42.45 |
| A100 | 2 | 9.03 | 6.76 | 4.41 | 4.21 | 37.72 | 53.38 |
| A100 | 4 | 16.70 | 12.42 | 7.94 | 7.54 | 39.29 | 54.85 |
| A100 | 10 | OOM | 29.93 | 18.70 | 18.46 | 38.32 |  |
| A100 | 16 |  | 47.08 | 29.41 | 29.04 | 38.32 |  |
| A100 | 32 |  | 92.89 | 57.55 | 56.67 | 38.99 |  |
| A100 | 64 |  | 185.3 | 114.8 | 112.98 | 39.03 |  |
| A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 |
| A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 |
| A10 | 8 |  | 56.19 | 43.53 | 43.86 | 21.94 |  |
| A10 | 16 |  | 116.49 | 88.56 | 86.64 | 25.62 |  |
| A10 | 32 |  | 221.95 | 175.74 | 168.18 | 24.23 |  |
| A10 | 48 |  | 333.23 | 264.84 |  | 20.52 |  |
| T4 | 1 | 28.2 | 24.49 | 23.93 | 23.56 | 3.80 | 16.45 |
| T4 | 2 | 52.77 | 45.7 | 45.88 | 45.06 | 1.40 | 14.61 |
| T4 | 4 | OOM | 85.72 | 85.78 | 84.48 | 1.45 |  |
| T4 | 8 |  | 149.64 | 150.75 | 148.4 | 0.83 |  |
| V100 | 1 | 7.4 | 6.84 | 6.8 | 6.66 | 2.63 | 10.00 |
| V100 | 2 | 13.85 | 12.81 | 12.66 | 12.35 | 3.59 | 10.83 |
| V100 | 4 | OOM | 25.73 | 25.31 | 24.78 | 3.69 |  |
| V100 | 8 |  | 43.95 | 43.37 | 42.25 | 3.87 |  |
| V100 | 16 |  | 84.99 | 84.73 | 82.55 | 2.87 |  |
| 3090 | 1 | 7.09 | 6.78 | 5.34 | 5.35 | 21.09 | 24.54 |
| 3090 | 4 | 22.69 | 21.45 | 18.56 | 18.18 | 15.24 | 19.88 |
| 3090 | 8 |  | 42.59 | 36.68 | 35.61 | 16.39 |  |
| 3090 | 16 |  | 85.35 | 72.93 | 70.18 | 17.77 |  |
| 3090 | 32 (1) |  | 162.05 | 143.46 | 138.67 | 14.43 |  |
| 3090 Ti | 1 | 6.45 | 6.19 | 4.99 | 4.89 | 21.00 | 24.19 |
| 3090 Ti | 4 | 20.32 | 19.31 | 17.02 | 16.48 | 14.66 | 18.90 |
| 3090 Ti | 8 |  | 37.93 | 33.21 | 32.24 | 15.00 |  |
| 3090 Ti | 16 |  | 75.37 | 66.63 | 64.5 | 14.42 |  |
| 3090 Ti | 32 (1) |  | 142.55 | 128.89 | 124.92 | 12.37 |  |
| 4090 | 1 | 5.54 | 4.99 | 2.66 | 2.58 | 48.30 | 53.43 |
| 4090 | 4 | 13.67 | 11.4 | 8.81 | 8.46 | 25.79 | 38.11 |
| 4090 | 8 |  | 19.79 | 17.55 | 16.62 | 16.02 |  |
| 4090 | 16 |  | 38.62 | 35.65 | 34.07 | 11.78 |  |
| 4090 | 32 (1) |  | 76.57 | 69.48 | 65.35 | 14.65 |  |
| 4090 | 48 |  | 114.44 | 106.3 |  | 7.11 |  |

(1) Batch sizes of 32 or more require enable_vae_slicing() because of https://github.com/pytorch/pytorch/issues/81665. This is required for PyTorch 1.13.1, and also for PyTorch 2.0 with large batch sizes.
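If you run the same script with varying batch sizes, a tiny wrapper can apply this rule automatically. A minimal sketch: `prepare_for_batch` and its `slicing_threshold` parameter are hypothetical names, the threshold of 32 comes from the note above, and only `enable_vae_slicing()` is the pipeline method the note refers to.

```python
def prepare_for_batch(pipe, batch_size, slicing_threshold=32):
    """Turn on VAE slicing for large batches, following note (1).

    `prepare_for_batch` and `slicing_threshold` are illustrative names; the
    default of 32 is the batch size quoted in the note.
    """
    if batch_size >= slicing_threshold:
        # Decode the latents in slices instead of all at once.
        pipe.enable_vae_slicing()
    return pipe


# Usage, assuming `pipe` and `prompt` are set up as in the examples above:
# pipe = prepare_for_batch(pipe, batch_size=32)
# images = pipe(prompt, num_images_per_prompt=32).images
```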

For more details about how this benchmark was run, please refer to this PR and to the blog post.