Accelerate documentation

Big Model Inference

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v1.3.0).

One of the biggest advancements Accelerate provides is Big Model Inference, which allows you to perform inference with models that don’t fully fit on your graphics card.

This tutorial will show you how to use Big Model Inference in Accelerate and the Hugging Face ecosystem.

Accelerate

A typical workflow for loading a PyTorch model is shown below. ModelClass is a model that exceeds the GPU memory of your device (mps or cuda or xpu).

import torch

my_model = ModelClass(...)
state_dict = torch.load(checkpoint_file)
my_model.load_state_dict(state_dict)

With Big Model Inference, the first step is to initialize an empty skeleton of the model with the init_empty_weights context manager. This requires almost no memory because my_model is “parameterless”: its parameters have shapes but no allocated data.

from accelerate import init_empty_weights
with init_empty_weights():
    my_model = ModelClass(...)

Next, the weights are loaded into the model for inference.

The load_checkpoint_and_dispatch() function loads a checkpoint into your empty model and dispatches the weights for each layer across all available devices, starting with the fastest devices (GPU, MPS, XPU, NPU, MLU, MUSA) before moving to the slower ones (CPU and hard drive).

Setting device_map="auto" automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.

Refer to the Designing a device map guide for more details on how to design your own device map.

from accelerate import load_checkpoint_and_dispatch

# Pass the empty skeleton created above (my_model); the dispatched model is returned.
model = load_checkpoint_and_dispatch(
    my_model, checkpoint=checkpoint_file, device_map="auto"
)
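To make the "GPU first, then CPU, then disk" order concrete, here is a minimal pure-Python sketch of a greedy placement. It is an illustration only and not Accelerate's actual algorithm: the function `greedy_device_map`, the layer names, the sizes, and the memory budgets are all hypothetical. The result mimics the general shape of a device map, a dictionary from module name to device.

```python
# Illustrative sketch only: NOT Accelerate's actual placement algorithm.
# Layer names, sizes (GB), and memory budgets (GB) below are hypothetical.

def greedy_device_map(layer_sizes, budgets):
    """Assign layers, in order, to the first device with enough room left.

    layer_sizes: dict of layer name -> size in GB (insertion-ordered).
    budgets: dict of device name -> capacity in GB, fastest device first.
    Layers that fit on no listed device fall through to "disk".
    """
    devices = list(budgets)
    remaining = dict(budgets)
    device_map = {}
    current = 0  # index of the device currently being filled
    for name, size in layer_sizes.items():
        # Move on to slower devices until the layer fits; never go back.
        while current < len(devices) and remaining[devices[current]] < size:
            current += 1
        if current < len(devices):
            device = devices[current]
            remaining[device] -= size
        else:
            device = "disk"
        device_map[name] = device
    return device_map


layer_sizes = {"embed": 2, "block_0": 4, "block_1": 4, "block_2": 4, "head": 2}
budgets = {"cuda:0": 8, "cpu": 6}
device_map = greedy_device_map(layer_sizes, budgets)
print(device_map)
```

In this toy setup the first two layers land on the GPU, the next one on the CPU, and the rest spill to disk, which is the slowest place for a layer to live.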

If there are certain “chunks” of layers that shouldn’t be split, pass them to no_split_module_classes (see here for more details).

A model's weights can also be sharded into multiple checkpoints to save memory, such as when the state_dict doesn’t fit in memory (see here for more details).
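The idea behind sharding can be sketched in a few lines: group tensors into files that each stay under a size limit, so only one shard needs to be held in memory at a time while loading. The helper `shard_by_size`, the tensor names, and the byte sizes below are hypothetical, and real sharding utilities may group tensors differently.

```python
# Hypothetical sketch: split a checkpoint's tensors into shards by size.
# Tensor names and byte sizes are made up for illustration.

def shard_by_size(tensor_sizes, max_shard_bytes):
    """Group tensors, in order, into shards of at most max_shard_bytes.

    A single tensor larger than the limit gets a shard of its own.
    """
    shards, current, current_bytes = [], {}, 0
    for name, nbytes in tensor_sizes.items():
        # Start a new shard when adding this tensor would exceed the limit.
        if current and current_bytes + nbytes > max_shard_bytes:
            shards.append(current)
            current, current_bytes = {}, 0
        current[name] = nbytes
        current_bytes += nbytes
    if current:
        shards.append(current)
    return shards


tensor_sizes = {
    "embed.weight": 300,
    "block_0.weight": 500,
    "block_1.weight": 500,
    "head.weight": 300,
}
shards = shard_by_size(tensor_sizes, max_shard_bytes=800)
for index, shard in enumerate(shards):
    print(index, list(shard), sum(shard.values()))
```

With these numbers, peak memory while loading is bounded by the largest shard (800 bytes) rather than by the full checkpoint (1600 bytes).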

Now that the model is fully dispatched, you can perform inference.

input = torch.randn(2,3)
device_type = next(iter(model.parameters())).device.type
input = input.to(device_type)
output = model(input)

Each time an input reaches an offloaded layer, that layer's weights are sent from the CPU to the GPU (or from disk to CPU to GPU), the output is calculated, and the weights are removed from the GPU again. While this adds some overhead to inference, it enables you to run a model of any size on your system, as long as the largest layer fits on your GPU.
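The claim that only the largest layer has to fit on the GPU can be checked with a toy simulation. Nothing here touches a real device: `simulate_offloaded_forward`, the layer names, and the sizes are hypothetical, and the model is deliberately simplified (it ignores activations and transfer time).

```python
# Toy simulation of layer-by-layer offloaded inference. No real devices are
# used; layer names and sizes (in GB) are hypothetical.

def simulate_offloaded_forward(layer_sizes, resident=()):
    """Return (peak GPU usage in GB, transfer log) for one forward pass.

    Layers named in `resident` stay on the GPU for the whole pass; every
    other layer is copied in just before it runs and evicted right after.
    """
    base = sum(size for name, size in layer_sizes.items() if name in resident)
    peak, log = base, []
    for name, size in layer_sizes.items():
        if name in resident:
            continue  # already on the GPU, nothing to move
        log.append(f"load {name}")
        # An offloaded layer occupies the GPU only while it runs.
        peak = max(peak, base + size)
        log.append(f"evict {name}")
    return peak, log


layer_sizes = {"embed": 2, "block_0": 4, "block_1": 4, "head": 2}
peak, log = simulate_offloaded_forward(layer_sizes)
print(peak)
print(log)
```

With every layer offloaded, the peak equals the size of the largest layer, at the cost of one load and one eviction per layer on every forward pass.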

Multiple GPUs, or “model parallelism”, can be utilized, but only one GPU will be active at any given moment because each GPU has to wait for the previous GPU to send it the output. You should launch your script normally with Python instead of other tools like torchrun and accelerate launch.

You may also be interested in pipeline parallelism, which utilizes all available GPUs at once instead of only having one GPU active at a time. This approach is less flexible though. For more details, refer to the Memory-efficient pipeline parallelism guide.
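The difference between the two approaches can be estimated with an idealized timing model. This is a rough sketch under strong assumptions: each GPU holds one stage, every stage takes exactly one time step per batch, and transfers are free. The function names are hypothetical and are not part of any library.

```python
# Idealized timing model: every GPU holds one stage and each stage takes one
# time step per (micro-)batch. Real schedules also pay transfer costs.

def sequential_steps(num_gpus, num_batches):
    """Only one GPU works at a time, so each batch costs num_gpus steps."""
    return num_gpus * num_batches


def pipelined_steps(num_gpus, num_batches):
    """Stages overlap: once the pipeline is full, one batch finishes per step."""
    return num_gpus + num_batches - 1


def utilization(num_gpus, num_batches, steps):
    """Fraction of GPU time slots that do useful work."""
    return (num_gpus * num_batches) / (num_gpus * steps)


gpus, batches = 4, 8
seq = sequential_steps(gpus, batches)
pipe = pipelined_steps(gpus, batches)
print(seq, utilization(gpus, batches, seq))
print(pipe, utilization(gpus, batches, pipe))
```

Under these assumptions, four GPUs running one at a time are each busy a quarter of the time, while overlapping the stages raises that share considerably as the number of batches grows.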

Take a look at a full example of Big Model Inference below.

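import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# Build the model skeleton without allocating memory for its weights.
with init_empty_weights():
    model = ModelClass(...)

# Load the checkpoint and spread the layers across the available devices.
model = load_checkpoint_and_dispatch(
    model, checkpoint=checkpoint_file, device_map="auto"
)

# Move the input to the same device type as the model's first parameter.
input = torch.randn(2, 3)
device_type = next(iter(model.parameters())).device.type
input = input.to(device_type)
output = model(input)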

Hugging Face ecosystem

Other libraries in the Hugging Face ecosystem, like Transformers or Diffusers, support Big Model Inference in their from_pretrained constructors.

You just need to add device_map="auto" in from_pretrained to enable Big Model Inference.

For example, load BigScience's 11-billion-parameter T0pp model with Big Model Inference.

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto")

After loading the model, the empty init and smart dispatch steps from before are executed and the model is fully ready to make use of all the resources in your machine. Through these constructors, you can also save more memory by specifying the torch_dtype parameter to load a model in a lower precision.

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto", torch_dtype=torch.float16)
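As a rough guide to how much a lower precision saves, the weight memory can be estimated from the parameter count alone. This sketch uses the 11 billion parameter figure quoted above and assumes the weights would otherwise be loaded in float32. The helper `weight_memory_gb` is hypothetical, and the estimate ignores activations, buffers, and framework overhead.

```python
# Back-of-the-envelope estimate of weight memory at different precisions.
# Covers the weights only; activations and overhead are not included.

BYTES_PER_PARAM = {"float32": 4, "float16": 2}


def weight_memory_gb(num_params, dtype):
    """Memory needed for the weights alone, in GB (10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


num_params = 11_000_000_000  # parameter count quoted for T0pp above
fp32_gb = weight_memory_gb(num_params, "float32")
fp16_gb = weight_memory_gb(num_params, "float16")
print(fp32_gb, fp16_gb)
```

Halving the bytes per parameter halves the weight footprint, which can be the difference between a model fitting on the GPU and part of it spilling to the CPU or disk.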

Next steps

For a more detailed explanation of Big Model Inference, make sure to check out the conceptual guide!
