Understanding memory consumption during inference

#34 opened by andrew-kirfman

Hello!
Is there a good way to estimate how much memory this model will consume during inference, based on the input token count of the data I'm generating embeddings for?

When I start the model, its initial consumption is approximately 14-15 GB of VRAM (15,330 MiB summed across the GPUs). The example below was run on an AWS EC2 g5.12xlarge with 4 A10s:

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     44696      C   .pyenv/versions/3.11.8/bin/python3         3580MiB |
|    1   N/A  N/A     44696      C   .pyenv/versions/3.11.8/bin/python3         4210MiB |
|    2   N/A  N/A     44696      C   .pyenv/versions/3.11.8/bin/python3         4210MiB |
|    3   N/A  N/A     44696      C   .pyenv/versions/3.11.8/bin/python3         3330MiB |
+---------------------------------------------------------------------------------------+

However, when I send input through, usage seems to balloon significantly. For example, this is the memory consumption after sending a 4096-token input (48,050 MiB summed across the GPUs):

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     44696      C   .pyenv/versions/3.11.8/bin/python3        10762MiB |
|    1   N/A  N/A     44696      C   .pyenv/versions/3.11.8/bin/python3        13334MiB |
|    2   N/A  N/A     44696      C   .pyenv/versions/3.11.8/bin/python3        13334MiB |
|    3   N/A  N/A     44696      C   .pyenv/versions/3.11.8/bin/python3        10620MiB |
+---------------------------------------------------------------------------------------+

Is it expected that the model's memory consumption scales so significantly with input size? Or am I doing something wrong in my hosting config? Is there a good way to cap memory consumption while still allowing embeddings of longer text sequences?
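One way to quantify this empirically is to measure peak memory per GPU at several input lengths using PyTorch's allocator statistics. This is a minimal sketch assuming the model runs in PyTorch; `report_gpu_memory` and `reset_peaks` are hypothetical helper names. Keep in mind that nvidia-smi shows memory held by PyTorch's caching allocator (plus the CUDA context), so it can stay high after a request even once the tensors are freed. `torch.cuda.memory_allocated` tracks live tensors, and `torch.cuda.empty_cache()` returns unused cached blocks to the driver.

```python
import torch


def reset_peaks() -> None:
    """Reset peak counters on every visible GPU so the next reading covers one call."""
    for idx in range(torch.cuda.device_count()):
        torch.cuda.reset_peak_memory_stats(idx)


def report_gpu_memory() -> dict[int, dict[str, float]]:
    """Return live, peak and reserved memory (GiB) for each visible GPU.

    Returns an empty dict on machines without CUDA devices.
    """
    gib = 1024 ** 3
    stats: dict[int, dict[str, float]] = {}
    for idx in range(torch.cuda.device_count()):
        stats[idx] = {
            # memory currently occupied by live tensors
            "allocated_gib": torch.cuda.memory_allocated(idx) / gib,
            # highest tensor usage since the last reset_peaks()
            "peak_allocated_gib": torch.cuda.max_memory_allocated(idx) / gib,
            # memory held by the caching allocator (closer to what nvidia-smi shows)
            "reserved_gib": torch.cuda.memory_reserved(idx) / gib,
        }
    return stats
```

To build a memory-vs-length curve for a given setup, call `reset_peaks()`, embed a single input of N tokens, read `peak_allocated_gib`, and repeat for several values of N.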

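I am afraid this is indeed expected. With standard (eager) attention, each self-attention layer materializes a score matrix with one entry per pair of tokens per head, so its memory grows quadratically with sequence length and becomes very large on long inputs.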
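As a rough back-of-the-envelope check of that quadratic term, the sketch below computes the size of one layer's eager attention score matrix (batch × heads × tokens × tokens elements). It assumes a Mistral-7B-style configuration with 32 attention heads, and it covers only this one term: weights, MLP activations and allocator overhead are ignored, so it does not predict what nvidia-smi will show. `eager_attention_scores_bytes` is a hypothetical helper name.

```python
def eager_attention_scores_bytes(
    seq_len: int,
    num_heads: int = 32,        # assumption: Mistral-7B-style config
    batch_size: int = 1,
    bytes_per_element: int = 2,  # 2 for fp16/bf16, 4 for fp32
) -> int:
    """Bytes for one layer's attention score matrix: batch x heads x seq x seq."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


GIB = 1024 ** 3

for n in (512, 1024, 2048, 4096):
    half = eager_attention_scores_bytes(n, bytes_per_element=2) / GIB
    full = eager_attention_scores_bytes(n, bytes_per_element=4) / GIB
    print(f"{n:>5} tokens: {half:6.3f} GiB (fp16/bf16), {full:6.3f} GiB (fp32) per layer")
```

Doubling the input length quadruples this term: at 4096 tokens it is 1 GiB per layer in fp16/bf16 and 2 GiB in fp32. Under `torch.no_grad()` these buffers can be freed layer by layer, whereas with autograd enabled the tensors saved for the backward pass accumulate across all layers. FlashAttention-2 computes attention in tiles and never materializes the full matrix.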

A few things you can try:
1. Use FlashAttention-2 to reduce GPU memory consumption (it also speeds up inference), as described in https://huggingface.co/docs/transformers/main/model_doc/mistral#speeding-up-mistral-by-using-flash-attention
2. Make sure inference runs inside a torch.no_grad() context, and use fp16/bf16 if possible.
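A minimal sketch of both tips with the Hugging Face transformers API, assuming a Mistral-style model that loads with `AutoModel` and that the `flash-attn` and `accelerate` packages are installed. `load_embedding_model`, `encode` and the `model_id` argument are hypothetical names, and pooling the hidden states into a single embedding vector is model-specific, so it is left out. `device_map="auto"` lets accelerate spread the layers across every visible GPU, and `max_length` with truncation puts a hard cap on the number of tokens per input.

```python
def load_embedding_model(model_id: str):
    """Load a model in bf16 with FlashAttention-2, sharded across visible GPUs."""
    # Imported lazily so this sketch can be imported without the GPU stack installed.
    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        # Mistral-style tokenizers may ship without a pad token; reuse EOS for batching.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModel.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,               # half the bytes of fp32
        attn_implementation="flash_attention_2",  # needs flash-attn and a supported GPU
        device_map="auto",                        # needs accelerate; splits layers over GPUs
    )
    model.eval()
    return tokenizer, model


def encode(tokenizer, model, texts: list[str], max_length: int = 4096):
    """Run a forward pass without autograd and return the last hidden states."""
    import torch

    batch = tokenizer(
        texts,
        max_length=max_length,  # caps tokens per input, and with it attention memory
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).to(model.device)  # assumption: the first layers live on model.device
    with torch.no_grad():  # no activations are kept for a backward pass
        outputs = model(**batch)
    return outputs.last_hidden_state
```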

I've been having a hell of a time getting this running on all 4 GPUs on a g5.12xlarge (using all of the GPU memory).
Do you by chance have an example from your code showing how you managed that (based on the numbers above)?
@andrew-kirfman?
