from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence
from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def glamm_collate_fn(instances: Sequence[Dict],
                     pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                     return_hf_format: bool = False,
                     use_varlen_attn: bool = False):
    """Collate a list of per-sample dicts into one batch dict.

    Args:
        instances: Per-sample dicts. Each must provide ``'input_ids'`` and
            ``'labels'``. Optional entries that are forwarded when present
            (and not ``None``) in at least one sample: ``'pixel_values'``,
            ``'g_pixel_values'``, ``'masks'``, ``'bboxes'``, ``'points'``.
            When ``use_varlen_attn`` is set, each sample must also provide
            ``'cumulative_len'`` and ``'position_ids'``.
        pad_index: Value used to pad ``input_ids``.
        return_hf_format: If ``True`` return the batch dict directly,
            otherwise wrap it as ``{'data': batch, 'data_samples': None}``.
        use_varlen_attn: Build a variable-length-attention batch. Requires a
            batch size of 1 and no ``'pixel_values'``.

    Returns:
        The batch dict (or its wrapper, see ``return_hf_format``). It always
        contains ``'input_ids'``, ``'position_ids'`` and ``'labels'``; plus
        ``'attention_mask'`` in the regular path, or ``'cumulative_len'`` and
        ``'max_seqlen'`` in the varlen path; plus whichever optional
        multimodal entries were present.

    Raises:
        AssertionError: If ``use_varlen_attn`` is set and the batch size is
            not 1, images are present, or the sequence length is not
            divisible by the sequence-parallel world size.
    """
    seq_parallel_world_size = get_sequence_parallel_world_size()

    def _any_has(key: str) -> bool:
        # True when at least one sample carries a non-None value for `key`.
        return any(inst.get(key) is not None for inst in instances)

    has_image = _any_has('pixel_values')
    has_grounding_image = _any_has('g_pixel_values')
    has_mask = _any_has('masks')
    has_bboxes = _any_has('bboxes')
    has_points = _any_has('points')

    if use_varlen_attn:
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        # The message is parenthesized so that both halves belong to the
        # assert; without the parentheses the second literal is a separate
        # no-op statement and the message is cut short.
        assert not has_image, (
            'Currently, it is not configured to '
            'accommodate the use of varlen Attention in multimodal training')

    input_ids, labels = [], []
    position_ids, cumulative_len = [], []
    pixel_values, grounding_pixel_values = [], []
    object_masks, object_bboxes, prompt_points = [], [], []

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        # NOTE: unlike masks/bboxes/points below, image entries are indexed
        # directly, so once any sample has one, every sample must have the
        # key (a missing key raises KeyError).
        if has_image:
            pixel_values.append(example['pixel_values'])
        if has_grounding_image:
            grounding_pixel_values.append(example['g_pixel_values'])

        # Region annotations are optional per sample; samples without them
        # are skipped, so these lists may be shorter than the batch.
        if has_mask and example.get('masks') is not None:
            object_masks.append(example['masks'])
        if has_bboxes and example.get('bboxes') is not None:
            object_bboxes.append(example['bboxes'])
        if has_points and example.get('points') is not None:
            prompt_points.append(example['points'])

    # Remember the unpadded lengths before padding; they define the mask.
    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        # Some tokenizers have the same eos token and pad token, so input_ids
        # cannot be masked directly based on the pad token id.
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, length in enumerate(ori_length):
            attention_mask[i, :length] = True

        bs, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

    if seq_parallel_world_size > 1:
        # Presumably pads the sequence dimension so it can be split across
        # sequence-parallel ranks -- see `pad_for_sequence_parallel`.
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, 0)

    if use_varlen_attn:
        # Largest gap between consecutive cumulative offsets of the (single)
        # sample.
        offsets = cumulative_len[0]
        max_seqlen = (offsets[1:] - offsets[:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels
        }

    if has_image:
        # Stack only when every image has the same shape; otherwise the
        # per-sample list is passed through unchanged.
        if all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = torch.stack(pixel_values, dim=0)
        data_dict['pixel_values'] = pixel_values

    if has_grounding_image:
        # Grounding images are deliberately kept as a per-sample list and
        # never stacked.
        data_dict['g_pixel_values'] = grounding_pixel_values

    if has_mask:
        data_dict['masks'] = object_masks

    if has_bboxes:
        data_dict['bboxes'] = object_bboxes

    if has_points:
        data_dict['points'] = prompt_points

    if return_hf_format:
        return data_dict
    return {'data': data_dict, 'data_samples': None}