Accelerated PyTorch 2.0 support in Diffusers
Starting from version 0.13.0, Diffusers supports the latest optimizations from the upcoming PyTorch 2.0 release. These include:
- Support for accelerated transformers implementation with memory-efficient attention – no extra dependencies required.
- torch.compile support for extra performance boost when individual models are compiled.
Installation
To benefit from the accelerated transformers implementation and `torch.compile`, we need to install the nightly version of PyTorch, as the stable version is yet to be released. The first step is to install CUDA 11.7 or CUDA 11.8, as PyTorch 2.0 does not support earlier versions. Once CUDA is installed, the torch nightly build can be installed using:

```bash
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu117
```
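After installing, it can be worth confirming that the environment actually exposes the new attention function. The following is a minimal sketch and not part of Diffusers: the helper name `has_sdpa` is hypothetical, and the two `SimpleNamespace` objects are stand-ins for `torch.nn.functional` so the check can be tried even where PyTorch is not installed.

```python
from types import SimpleNamespace


def has_sdpa(functional_module) -> bool:
    """Return True if the module exposes scaled_dot_product_attention."""
    return hasattr(functional_module, "scaled_dot_product_attention")


# Stand-ins for torch.nn.functional (hypothetical, for illustration only).
without_sdpa = SimpleNamespace()
with_sdpa = SimpleNamespace(scaled_dot_product_attention=lambda q, k, v: None)

print(has_sdpa(without_sdpa))  # False
print(has_sdpa(with_sdpa))     # True

# On a real installation, run the same check against PyTorch itself.
try:
    import torch

    print(torch.__version__, has_sdpa(torch.nn.functional))
except ImportError:
    print("PyTorch is not installed in this environment")
```

If the last line prints `True` next to the version string, the accelerated attention path should be available to the pipelines.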
Using accelerated transformers and torch.compile.
Accelerated Transformers implementation
PyTorch 2.0 includes an optimized and memory-efficient attention implementation through the `torch.nn.functional.scaled_dot_product_attention` function, which automatically enables several optimizations depending on the inputs and the GPU type. This is similar to the `memory_efficient_attention` from xFormers, but built natively into PyTorch.

These optimizations are enabled by default in Diffusers if PyTorch 2.0 is installed and `torch.nn.functional.scaled_dot_product_attention` is available. To use them, just install `torch 2.0` as suggested above and use the pipeline as usual. For example:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
If you want to enable it explicitly (which is not required), you can do so as shown below.
```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
# Explicitly select the PyTorch 2.0 attention processor for the UNet.
pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
This should be as fast and memory efficient as `xFormers`. More details are in our benchmark.
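To make concrete what the attention function computes, here is a naive reference sketch in pure Python of softmax(QK^T / sqrt(d)) V. It is only an illustration of the math: the name `naive_attention` is hypothetical, it ignores the masking, dropout and causal options of the PyTorch function, and it says nothing about how PyTorch's fused, memory-efficient kernels are implemented.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def naive_attention(query, key, value):
    """Compute softmax(Q K^T / sqrt(d)) V on nested lists (one row per token)."""
    d = len(query[0])
    scale = 1.0 / math.sqrt(d)
    output = []
    for q in query:
        # Scaled similarity of this query with every key.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in key]
        weights = softmax(scores)
        # Weighted average of the value rows.
        output.append([
            sum(w * v[j] for w, v in zip(weights, value))
            for j in range(len(value[0]))
        ])
    return output


# Two identical keys receive equal weight, so the output is the mean of the values.
query = [[1.0, 0.0]]
key = [[1.0, 0.0], [1.0, 0.0]]
value = [[1.0, 2.0], [3.0, 4.0]]
result = naive_attention(query, key, value)
print(result)  # [[2.0, 3.0]]
```

The naive version materializes one full row of attention weights per query; avoiding that kind of intermediate storage is where memory-efficient implementations save memory.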
torch.compile
To get an additional speedup, we can use the new `torch.compile` feature. To do so, we simply wrap our `unet` with `torch.compile`. For more information and different options, refer to the torch compile docs.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
# Compile only the UNet, which is called at every denoising step.
pipe.unet = torch.compile(pipe.unet)

batch_size = 10
steps = 50  # number of denoising steps; 50 matches the benchmark setup
prompt = "A photo of an astronaut riding a horse on mars."
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images
```
Depending on the type of GPU, `compile()` can yield an additional speed-up of 2-9% over the accelerated transformer optimizations. Note, however, that compilation is able to squeeze more performance improvements out of more recent GPU architectures such as Ampere (A100, 3090), Ada (4090) and Hopper (H100).

Compilation takes some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times.
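The trade-off between compilation time and per-run savings can be estimated with simple arithmetic. The sketch below is illustrative only: `break_even_runs` is a hypothetical helper, and the 60-second compilation overhead is an assumed figure, not a measurement. The per-run timings are the A100, batch size 10, fp16 values from the benchmark table (8.79 s with SDPA, 7.89 s with SDPA + `torch.compile`).

```python
import math


def break_even_runs(compile_overhead_s, baseline_s, compiled_s):
    """Number of runs after which compilation has paid for itself.

    Returns None when the compiled model is not faster than the baseline.
    """
    saving_per_run = baseline_s - compiled_s
    if saving_per_run <= 0:
        return None
    return math.ceil(compile_overhead_s / saving_per_run)


# Assumed compile overhead of 60 s (hypothetical); timings from the fp16 table.
runs = break_even_runs(compile_overhead_s=60.0, baseline_s=8.79, compiled_s=7.89)
print(runs)  # 67
```

Under these assumptions, compilation only pays off after several dozen pipeline calls, which is why it suits long-running services better than one-off generations.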
Benchmark
We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile` + `torch.nn.functional.scaled_dot_product_attention`.
For the benchmark we used the stable-diffusion-v1-4 model with 50 steps. The `xFormers` benchmark was run with `torch==1.13.1`, while the accelerated transformers optimizations were tested using nightly versions of PyTorch 2.0. The tables below summarize the results.
The `Speed over xformers` columns denote the speed-up gained over `xFormers` using `torch.compile` + `torch.nn.functional.scaled_dot_product_attention`.
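The speed-up percentages in the tables are consistent with the relative reduction in run time, (baseline - optimized) / baseline. This formula is inferred from the reported numbers rather than stated by the benchmark authors, and `speedup_percent` is a hypothetical helper name. A short sketch that reproduces two fp16 rows:

```python
def speedup_percent(baseline_s, optimized_s):
    """Relative time saved by the optimized run, as a percentage of the baseline."""
    return (baseline_s - optimized_s) / baseline_s * 100.0


# A100, batch size 10, fp16: xFormers 8.7 s vs SDPA + torch.compile 7.89 s.
a100 = round(speedup_percent(8.7, 7.89), 2)
# A10, batch size 4, fp16: xFormers 9.81 s vs SDPA + torch.compile 9.35 s.
a10 = round(speedup_percent(9.81, 9.35), 2)
print(a100, a10)  # 9.31 4.69
```

A negative value means the compiled run was slower than the baseline for that configuration.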
FP16 benchmark
The table below shows the benchmark results for inference using `fp16`. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as `xFormers` (sometimes slightly faster or slower) on all the GPUs we tested. Using `torch.compile` gives a further speed-up of up to 10% over `xFormers`, but it’s mostly noticeable on the A100 GPU.
The time reported is in seconds.
GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) |
---|---|---|---|---|---|---|
A100 | 10 | 12.02 | 8.7 | 8.79 | 7.89 | 9.31 |
A100 | 16 | 18.95 | 13.57 | 13.67 | 12.25 | 9.73 |
A100 | 32 (1) | OOM | 26.56 | 26.68 | 24.08 | 9.34 |
A100 | 64 | | 52.51 | 53.03 | 47.81 | 8.95 |
A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 |
A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 |
A10 | 10 | 33.69 | 23.53 | 24.19 | 22.52 | 4.29 |
A10 | 16 | OOM | 37.55 | 38.31 | 36.81 | 1.97 |
A10 | 32 (1) | | 77.19 | 78.43 | 76.64 | 0.71 |
A10 | 64 (1) | | 173.59 | 158.99 | 155.14 | 10.63 |
T4 | 4 | 38.81 | 30.09 | 29.74 | 27.55 | 8.44 |
T4 | 8 | OOM | 55.71 | 55.99 | 53.85 | 3.34 |
T4 | 10 | OOM | 68.96 | 69.86 | 65.35 | 5.23 |
T4 | 16 | OOM | 111.47 | 113.26 | 106.93 | 4.07 |
V100 | 4 | 9.84 | 8.16 | 8.09 | 7.65 | 6.25 |
V100 | 8 | OOM | 15.62 | 15.44 | 14.59 | 6.59 |
V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 |
V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 |
3090 | 4 | 10.04 | 7.82 | 7.89 | 7.47 | 4.48 |
3090 | 8 | 19.27 | 14.97 | 15.04 | 14.22 | 5.01 |
3090 | 10 | 24.08 | 18.7 | 18.7 | 17.69 | 5.40 |
3090 | 16 | OOM | 29.06 | 29.06 | 28.2 | 2.96 |
3090 | 32 (1) | | 58.05 | 58 | 54.88 | 5.46 |
3090 | 64 (1) | | 126.54 | 126.03 | 117.33 | 7.28 |
3090 Ti | 4 | 9.07 | 7.14 | 7.15 | 6.81 | 4.62 |
3090 Ti | 8 | 17.51 | 13.65 | 13.72 | 12.99 | 4.84 |
3090 Ti | 10 (2) | 21.79 | 16.85 | 16.93 | 16.02 | 4.93 |
3090 Ti | 16 | OOM | 26.1 | 26.28 | 25.46 | 2.45 |
3090 Ti | 32 (1) | | 51.78 | 52.04 | 49.15 | 5.08 |
3090 Ti | 64 (1) | | 112.02 | 112.33 | 103.91 | 7.24 |
4090 | 4 | 10.48 | 8.37 | 8.32 | 8.01 | 4.30 |
4090 | 8 | 14.33 | 10.22 | 10.42 | 9.78 | 4.31 |
4090 | 16 | | 17.07 | 17.46 | 17.15 | -0.47 |
4090 | 32 (1) | | 39.03 | 39.86 | 37.97 | 2.72 |
4090 | 64 (1) | | 77.29 | 79.44 | 77.67 | -0.49 |
FP32 benchmark
The table below shows the benchmark results for inference using `fp32`. In this case, `torch.nn.functional.scaled_dot_product_attention` is faster than `xFormers` on all the GPUs we tested.

Using `torch.compile` in addition to the accelerated transformers implementation can yield up to 19% performance improvement over `xFormers` on Ampere and Ada cards, and up to 20% (Ampere) or 28% (Ada) over vanilla attention.
GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) |
---|---|---|---|---|---|---|---|
A100 | 4 | 16.56 | 12.42 | 12.2 | 11.84 | 4.67 | 28.50 |
A100 | 10 | OOM | 29.93 | 29.44 | 28.5 | 4.78 | |
A100 | 16 | | 47.08 | 46.27 | 44.8 | 4.84 | |
A100 | 32 | | 92.89 | 91.34 | 88.35 | 4.89 | |
A100 | 64 | | 185.3 | 182.71 | 176.48 | 4.76 | |
A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 |
A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 |
A10 | 8 | | 56.19 | 43.53 | 43.86 | 21.94 | |
A10 | 16 | | 116.49 | 88.56 | 86.64 | 25.62 | |
A10 | 32 | | 221.95 | 175.74 | 168.18 | 24.23 | |
A10 | 48 | | 333.23 | 264.84 | | 20.52 | |
T4 | 1 | 28.2 | 24.49 | 23.93 | 23.56 | 3.80 | 16.45 |
T4 | 2 | 52.77 | 45.7 | 45.88 | 45.06 | 1.40 | 14.61 |
T4 | 4 | OOM | 85.72 | 85.78 | 84.48 | 1.45 | |
T4 | 8 | | 149.64 | 150.75 | 148.4 | 0.83 | |
V100 | 1 | 7.4 | 6.84 | 6.8 | 6.66 | 2.63 | 10.00 |
V100 | 2 | 13.85 | 12.81 | 12.66 | 12.35 | 3.59 | 10.83 |
V100 | 4 | OOM | 25.73 | 25.31 | 24.78 | 3.69 | |
V100 | 8 | | 43.95 | 43.37 | 42.25 | 3.87 | |
V100 | 16 | | 84.99 | 84.73 | 82.55 | 2.87 | |
3090 | 1 | 7.09 | 6.78 | 6.11 | 6.03 | 11.06 | 14.95 |
3090 | 4 | 22.69 | 21.45 | 18.67 | 18.09 | 15.66 | 20.27 |
3090 | 8 | | 42.59 | 36.75 | 35.59 | 16.44 | |
3090 | 16 | | 85.35 | 72.37 | 70.25 | 17.69 | |
3090 | 32 (1) | | 162.05 | 138.99 | 134.53 | 16.98 | |
3090 | 48 | | 241.91 | 207.75 | | 14.12 | |
3090 Ti | 1 | 6.45 | 6.19 | 5.64 | 5.49 | 11.31 | 14.88 |
3090 Ti | 4 | 20.32 | 19.31 | 16.9 | 16.37 | 15.23 | 19.44 |
3090 Ti | 8 (2) | | 37.93 | 33.05 | 31.99 | 15.66 | |
3090 Ti | 16 | | 75.37 | 65.25 | 64.32 | 14.66 | |
3090 Ti | 32 (1) | | 142.55 | 124.44 | 120.74 | 15.30 | |
3090 Ti | 48 | | 213.19 | 186.55 | | 12.50 | |
4090 | 1 | 5.54 | 4.99 | 4.51 | 4.44 | 11.02 | 19.86 |
4090 | 4 | 13.67 | 11.4 | 10.3 | 9.84 | 13.68 | 28.02 |
4090 | 8 | | 19.79 | 17.13 | 16.19 | 18.19 | |
4090 | 16 | | 38.62 | 33.14 | 32.31 | 16.34 | |
4090 | 32 (1) | | 76.57 | 65.96 | 62.05 | 18.96 | |
4090 | 48 | | 114.44 | 98.78 | | 13.68 | |
(1) Batch Size >= 32 requires `enable_vae_slicing()` because of https://github.com/pytorch/pytorch/issues/81665. This is required for PyTorch 1.13.1, and also for PyTorch 2.0 with a batch size of 64.
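As an illustration of note (1), VAE slicing can be switched on only when the batch is large enough to need it. This is a minimal sketch under stated assumptions: `prepare_pipeline` and its `slicing_threshold` parameter are hypothetical names, the threshold of 32 simply mirrors the note, and `FakePipeline` is a stand-in that only mimics the `enable_vae_slicing()` method of a real pipeline so the logic can be run without a GPU.

```python
def prepare_pipeline(pipe, batch_size, slicing_threshold=32):
    """Enable VAE slicing on the pipeline when the batch size reaches the threshold."""
    if batch_size >= slicing_threshold:
        pipe.enable_vae_slicing()
    return pipe


class FakePipeline:
    """Stand-in for a Diffusers pipeline (hypothetical, records the call only)."""

    def __init__(self):
        self.vae_slicing = False

    def enable_vae_slicing(self):
        self.vae_slicing = True


small = prepare_pipeline(FakePipeline(), batch_size=16)
large = prepare_pipeline(FakePipeline(), batch_size=64)
print(small.vae_slicing, large.vae_slicing)  # False True
```

With a real pipeline, the same helper would be called on the object returned by `StableDiffusionPipeline.from_pretrained(...)` before running inference.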
For more details about how this benchmark was run, please refer to this PR.