batch inference supported?

#7
by chenkq - opened

Thanks for your amazing work! It's awesome!

I am wondering if batch inference is currently supported, as I noticed there’s a function called model.generate_from_batch in the example. The example code in the README file only demonstrates inference using a single sample.

Specifically, I am not sure how to concatenate the "images," "image_input_idx," and "image_masks" fields in the "inputs" provided to the model when dealing with multiple images of different sizes.

Batch inference is supported. You would need to run the processor on each input separately and then concatenate the fields, padding them with -1 so they have the same shape. We will look at adding automatic functionality for that.
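
A minimal sketch of the padding step described above, using plain Python lists in place of tensors. `pad_to_longest` is a hypothetical helper, not part of the Molmo processor. The real fields (`images`, `image_input_idx`, `image_masks`) are multi-dimensional and would be padded along their first axis in the same spirit.

```python
from typing import List, Tuple

PAD = -1  # padding value suggested in the reply above


def pad_to_longest(rows: List[List[int]], pad: int = PAD) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad ragged rows to a common length and return (padded, mask)."""
    width = max((len(r) for r in rows), default=0)
    padded = [r + [pad] * (width - len(r)) for r in rows]
    # The mask is built from the original lengths: 1 for real entries, 0 for padding.
    mask = [[1] * len(r) + [0] * (width - len(r)) for r in rows]
    return padded, mask


padded, mask = pad_to_longest([[5, 6, 7], [8], [9, 10]])
print(padded)
print(mask)
```

Building the mask from the row lengths, and not from the padded values, avoids confusing padding with fields that may legitimately contain -1.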

@chrisc36 Can you please provide example code for streaming batched inference? :)

is there any update on this?

@d-rau allenai is working on adding support for this model in vllm. Seems like it should be in place soon.

https://github.com/vllm-project/vllm/pull/9016#issue

You should be able to run the branch that supports molmo even now if you're willing to tinker a bit with your vllm installation.

Hi @chrisc36, I tried to do what you said regarding batch inference. The tokenizer's pad_token_id is 151643. Setting the pad value to -1 does generate the attention masks as required, but it leads to an error when generate_from_batch() calls super().generate(...) in modeling_molmo.py; I have attached the error below. The error does not happen if I choose a different padding value, but then the attention mask does not ignore those tokens and the model generates weird outputs:
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [2,0,0], thread: [8,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [2,0,0], thread: [9,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [2,0,0], thread: [10,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [2,0,0], thread: [11,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [2,0,0], thread: [12,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [2,0,0], thread: [13,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [2,0,0], thread: [14,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [2,0,0], thread: [15,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [32,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [33,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [34,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [35,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [36,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [37,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [38,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [39,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [40,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [41,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [42,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [43,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [108,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [109,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [110,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [111,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [112,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [113,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [114,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [115,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [4,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [5,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [6,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [7,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [8,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [9,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [10,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [11,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [12,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [13,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [14,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [15,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [16,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [17,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [18,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [0,0,0], thread: [19,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [84,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [85,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [86,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [87,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [88,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [89,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [90,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [91,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [92,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [93,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [94,0,0] Assertion srcIndex < srcSelectDimSize failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1,0,0], thread: [95,0,0] Assertion srcIndex < srcSelectDimSize failed.
Traceback (most recent call last):
  File "/home/ubuntu/pratyushp/quantize/molmo-inference.py", line 138, in <module>
    output = model.generate_from_batch(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/prats/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/allenai/Molmo-7B-D-0924/1721478b71306fb7dc671176d5c204dc7a4d27d7/modeling_molmo.py", line 2212, in generate_from_batch
    out = super().generate(
          ^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/prats/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/prats/lib/python3.11/site-packages/transformers/generation/utils.py", line 2047, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/opt/conda/envs/prats/lib/python3.11/site-packages/transformers/generation/utils.py", line 3061, in _sample
    unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
                                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/prats/lib/python3.11/site-packages/transformers/generation/stopping_criteria.py", line 496, in __call__
    is_done = is_done | criteria(input_ids, scores, **kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/prats/lib/python3.11/site-packages/transformers/generation/stopping_criteria.py", line 402, in __call__
    embedded = F.embedding(flipped_ids, self.embedding_vec)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/prats/lib/python3.11/site-packages/torch/nn/functional.py", line 2267, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
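
The last frames of the traceback are in `F.embedding`, called from the stop-string criterion, and `srcIndex < srcSelectDimSize` is an index-range check. That would be consistent with the -1 padding ids reaching an embedding lookup, though this is an inference from the trace and not something confirmed in this thread. A minimal sketch of a check that could be run on the CPU before calling `generate_from_batch`; `find_out_of_range_ids` is a hypothetical helper, and the vocabulary size below is a placeholder, not Molmo's actual value:

```python
from typing import List, Tuple


def find_out_of_range_ids(batch: List[List[int]], vocab_size: int) -> List[Tuple[int, int, int]]:
    """Return (row, column, token_id) for every id outside [0, vocab_size)."""
    bad = []
    for row, ids in enumerate(batch):
        for col, token_id in enumerate(ids):
            if not 0 <= token_id < vocab_size:
                bad.append((row, col, token_id))
    return bad


# Placeholder vocabulary size and ids, chosen only for illustration.
batch = [[11, 42, 7], [3, -1, -1]]
bad = find_out_of_range_ids(batch, vocab_size=100)
print(bad)
```

With tensors, `inputs["input_ids"].tolist()` gives the nested lists this helper expects. Setting the environment variable `CUDA_LAUNCH_BLOCKING=1` makes CUDA kernels run synchronously, so the Python traceback points at the call that actually failed.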

Hopefully this is helpful:

import numpy as np
import requests
import torch
from PIL import Image, ImageOps
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from typing import List, Dict

processor = AutoProcessor.from_pretrained(
    "allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    device_map="auto",
)
model = AutoModelForCausalLM.from_pretrained(
    "allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    torch_dtype=torch.float32,
    device_map="auto",
)
urls = [
    "https://picsum.photos/id/237/536/354",
    "https://picsum.photos/id/238/536/354",
    "https://picsum.photos/id/239/536/354",
]
prompts = [
    "What breed is this dog?",
    "Describe the colors in this image.",
    "Is this an indoor or outdoor scene?",
]

images_list = []
for url in urls:
    # Download each image once and decode it straight from the streamed response.
    image = Image.open(requests.get(url, stream=True).raw)
    images_list.append([image])

texts = ["User: " + prompt + " Assistant:" for prompt in prompts]


def process_batch(
    processor: AutoProcessor,
    texts: List[str],
    images_list: List[List[Image.Image]]
) -> Dict[str, torch.Tensor]:
    """
    Process in batch.
    
    Args:
        processor: The original processor.
        texts: List of text inputs
        images_list: List of lists containing PIL images.
        
    Returns:
        Dict with padded input_ids, images, image_input_idx, image_masks.
    """
    batch_size = len(texts)
    tokens_list = []
    for text in texts:
        tokens = processor.tokenizer.encode(" " + text, add_special_tokens=False)
        tokens_list.append(tokens)
    images_arrays_list = []
    image_idxs_list = []
    for images in images_list:
        if images:
            image_arrays = []
            for image in images:
                if isinstance(image, Image.Image):
                    image = image.convert("RGB")
                    image = ImageOps.exif_transpose(image)
                    image_arrays.append(np.array(image))
                else:
                    assert len(image.shape) == 3 and image.shape[-1] == 3
                    image_arrays.append(image.astype(np.uint8))
            images_arrays_list.append(image_arrays)
            image_idx = [-1] * len(image_arrays)
            image_idxs_list.append(image_idx)
        else:
            images_arrays_list.append(None)
            image_idxs_list.append(None)
    images_kwargs = {
        "max_crops": 12,
        "overlap_margins": [4, 4],
        "base_image_input_size": [336, 336],
        "image_token_length_w": 12,
        "image_token_length_h": 12,
        "image_patch_size": 14,
        "image_padding_mask": True,
    }
    outputs_list = []
    for i in range(batch_size):
        tokens = tokens_list[i]
        images = images_arrays_list[i]
        image_idx = image_idxs_list[i]
        out = processor.image_processor.multimodal_preprocess(
            images=images,
            image_idx=image_idx,
            tokens=np.asarray(tokens).astype(np.int32),
            sequence_length=1536,
            image_patch_token_id=processor.special_token_ids["<im_patch>"],
            image_col_token_id=processor.special_token_ids["<im_col>"],
            image_start_token_id=processor.special_token_ids["<im_start>"],
            image_end_token_id=processor.special_token_ids["<im_end>"],
            **images_kwargs,
        )
        outputs_list.append(out)

    batch_outputs = {}
    for key in outputs_list[0].keys():
        tensors = [torch.from_numpy(out[key]) for out in outputs_list]
        batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
            tensors, batch_first=True, padding_value=-1
        )
    bos = processor.tokenizer.bos_token_id or processor.tokenizer.eos_token_id
    batch_outputs["input_ids"] = torch.nn.functional.pad(
        batch_outputs["input_ids"], (1, 0), value=bos
    )
    if "image_input_idx" in batch_outputs:
        image_input_idx = batch_outputs["image_input_idx"]
        batch_outputs["image_input_idx"] = torch.where(
            image_input_idx < 0, image_input_idx, image_input_idx + 1
        )
    return batch_outputs


inputs = process_batch(processor, texts, images_list)

inputs = {k: v.to(model.device) for k, v in inputs.items()}

output = model.generate_from_batch(
    inputs,
    GenerationConfig(
        max_new_tokens=200,
        stop_sequences=["<|endoftext|>"],
        eos_token_id=processor.tokenizer.eos_token_id,
        pad_token_id=processor.tokenizer.pad_token_id,
    ),
    tokenizer=processor.tokenizer,
)

generated_texts = processor.tokenizer.batch_decode(
    output[:, inputs["input_ids"].size(1) :], skip_special_tokens=True
)
for prompt, text in zip(prompts, generated_texts):
    print(f"\nPrompt: {prompt}")
    print(f"Response: {text}")
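
The `torch.where` step in `process_batch` is easy to misread. Prepending one BOS token moves every existing token one position to the right, so each valid entry of `image_input_idx` has to grow by one while the -1 padding entries stay as they are. A minimal sketch of that bookkeeping with plain lists; `prepend_bos` and the token ids are made up for illustration:

```python
from typing import List, Tuple


def prepend_bos(input_ids: List[int], image_input_idx: List[int], bos: int) -> Tuple[List[int], List[int]]:
    """Prepend BOS and shift valid image positions; negative entries are padding."""
    shifted = [i if i < 0 else i + 1 for i in image_input_idx]
    return [bos] + input_ids, shifted


# Toy ids: 900 stands in for an image patch token, 1 for BOS.
ids, idx = prepend_bos([900, 900, 17, 18], [0, 1, -1], bos=1)
# Every valid index still points at a patch token after the shift.
assert all(ids[i] == 900 for i in idx if i >= 0)
print(ids, idx)
```

Without the shift, each index would point one token too early, so the image features would be written over the wrong positions.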

@chrisc36 I do think the batch processing should be handled in preprocessing_molmo.py though. Shall I open a PR for that or is it not worth it atm? I see there was a PR open for native transformers support.
