from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def glamm_collate_fn(instances: Sequence[Dict],
                     pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                     return_hf_format: bool = False,
                     use_varlen_attn: bool = False):
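    """Collate GLaMM training instances into a single batch dict.

    Pads ``input_ids``/``labels`` to the longest sequence in the batch,
    builds the attention mask and position ids (or varlen-attention
    metadata), optionally pads for sequence parallelism, and forwards any
    multimodal extras (images, grounding images, masks, boxes, points)
    alongside the text tensors.
    """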
    seq_parallel_world_size = get_sequence_parallel_world_size()

    input_ids, labels = [], []
    has_image = any(
        inst.get('pixel_values') is not None for inst in instances)
    has_grounding_image = any(
        inst.get('g_pixel_values') is not None for inst in instances)
    has_mask = any(inst.get('masks') is not None for inst in instances)
    has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
    has_points = any(inst.get('points') is not None for inst in instances)

    if use_varlen_attn:
        position_ids, cumulative_len = [], []
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        assert not has_image, (
            'Currently, it is not configured to accommodate the use of '
            'varlen Attention in multimodal training')

    if has_image:
        pixel_values = []
    if has_grounding_image:
        grounding_pixel_values = []
    if has_mask:
        object_masks = []
    if has_bboxes:
        object_bboxes = []
    if has_points:
        prompt_points = []

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))
        if has_image:
            pixel_values.append(example['pixel_values'])
        if has_grounding_image:
            grounding_pixel_values.append(example['g_pixel_values'])
        if has_mask:
            if 'masks' in example and example['masks'] is not None:
                object_masks.append(example['masks'])
        if has_bboxes:
            if 'bboxes' in example and example['bboxes'] is not None:
                object_bboxes.append(example['bboxes'])
        if has_points:
            if 'points' in example and example['points'] is not None:
                prompt_points.append(example['points'])
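
    # Right-pad every sequence to the longest one in the batch; labels are
    # padded with IGNORE_INDEX so padded positions do not contribute to the
    # loss.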
    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        # Some tokenizers have the same eos token and pad token, so input_ids
        # cannot be masked directly based on the pad token id.
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, length in enumerate(ori_length):
            attention_mask[i, :length] = True

        bs, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
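
    # With sequence parallelism enabled, pad the sequence dimension so it can
    # be split evenly across the sequence-parallel ranks.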
    if seq_parallel_world_size > 1:
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, 0)
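
    # max_seqlen is the longest packed sub-sequence, recovered from the
    # cumulative lengths of the single packed instance.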
    if use_varlen_attn:
        max_seqlen = (
            cumulative_len[0][1:] -  # noqa: W504
            cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels
        }
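
    # Multimodal extras are forwarded mostly as-is: pixel_values are stacked
    # only when every image in the batch shares the same shape; grounding
    # images, masks, boxes and points stay as lists of per-sample entries.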
    if has_image:
        if all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = torch.stack(pixel_values, dim=0)
        data_dict['pixel_values'] = pixel_values

    if has_grounding_image:
        # if all(x.shape == grounding_pixel_values[0].shape
        #        for x in grounding_pixel_values):
        #     grounding_pixel_values = torch.stack(
        #         grounding_pixel_values, dim=0)
        data_dict['g_pixel_values'] = grounding_pixel_values

    if has_mask:
        data_dict['masks'] = object_masks
    if has_bboxes:
        data_dict['bboxes'] = object_bboxes
    if has_points:
        data_dict['points'] = prompt_points

    if return_hf_format:
        return data_dict
    else:
        return {'data': data_dict, 'data_samples': None}
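

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): collate two toy
    # text-only instances. Assumes a single-process, non-distributed run in
    # which get_sequence_parallel_world_size() resolves to 1; the token ids
    # and label values below are arbitrary placeholders.
    toy_instances = [
        {'input_ids': [1, 5, 7, 9], 'labels': [-100, 5, 7, 9]},
        {'input_ids': [1, 5], 'labels': [-100, 5]},
    ]
    batch = glamm_collate_fn(toy_instances)
    print(batch['data']['input_ids'].shape)       # torch.Size([2, 4])
    print(batch['data']['attention_mask'].sum())  # 6 non-padded tokens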