How do I backpropagate only specific output tokens when using Trainer?

I have a binary mask that zeroes out the training loss for tokens I don't want to be updated during backpropagation. So far I have only been setting the loss of those tokens to zero, but now I want to skip backpropagation for them entirely to speed up training. Does anyone have an idea how to make this modification with the CustomTrainer below?

import torch
from torch import nn
from torch.utils.data import DataLoader
import transformers


class CustomTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # Pull the labels and the per-token mask out of the batch so they are
        # not passed to the model's forward().
        labels = inputs.pop('labels')
        loss_mask = inputs.pop('loss_mask')

        # Forward pass
        outputs = model(**inputs)
        logits = outputs.logits

        if torch.isnan(logits).any():
            print('NaN detected in logits')
            print(logits)

        # Greedy predictions (not used in the loss; kept for debugging)
        probs = nn.functional.softmax(logits, dim=-1)
        predicted_token_ids = torch.argmax(probs, dim=-1)

        # Per-token cross-entropy with no reduction so it can be masked afterwards
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        losses = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1))
        losses = losses.view(-1, inputs['input_ids'].size(1))

        # Zero out the loss of tokens that should not be trained, then average
        # over the remaining (unmasked) tokens.
        masked_loss = losses * loss_mask
        loss = masked_loss.sum() / (loss_mask.sum() + 1e-9)

        return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self):
        train_dataset = self.train_dataset
        data_collator = self.data_collator

        dataloader_params = {
            'batch_size': self.args.train_batch_size,
            'collate_fn': data_collator,
            'num_workers': self.args.dataloader_num_workers,
            'pin_memory': self.args.dataloader_pin_memory,
        }

        # Shuffling and dropping the last batch only apply to map-style datasets
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params['shuffle'] = True
            dataloader_params['drop_last'] = self.args.dataloader_drop_last

        return DataLoader(train_dataset, **dataloader_params)

    def get_eval_dataloader(self, eval_dataset=None):
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        data_collator = self.data_collator

        dataloader_params = {
            'batch_size': self.args.eval_batch_size,
            'collate_fn': data_collator,
            'num_workers': self.args.dataloader_num_workers,
            'pin_memory': self.args.dataloader_pin_memory,
            'shuffle': False,
            'drop_last': self.args.dataloader_drop_last,
        }

        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params.pop('shuffle', None)
            dataloader_params.pop('drop_last', None)

        return DataLoader(eval_dataset, **dataloader_params)
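As an aside, the mask-multiply in compute_loss above is functionally close to what nn.CrossEntropyLoss already provides through ignore_index: any position whose label is -100 contributes neither to the loss nor to the gradient, which is the same convention Hugging Face models use for their built-in loss. Below is a minimal sketch of that variant; it still runs the full forward and backward passes, it only skips the masked tokens in the loss.

import torch
from torch import nn

def masked_lm_loss(logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Cross-entropy that skips positions where loss_mask == 0 by relabelling
    them with -100 (the ignore_index convention)."""
    masked_labels = labels.masked_fill(loss_mask == 0, -100)
    loss_fct = nn.CrossEntropyLoss(ignore_index=-100)  # -100 is also the default
    return loss_fct(logits.view(-1, logits.size(-1)), masked_labels.view(-1))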

I've already tried some modifications, such as detaching the rows of the logits tensor that I don't want backpropagation to flow through, but I don't know whether this is the right way.
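Detaching is one reasonable way to express "no gradient through these positions". Here is a minimal sketch of that idea, assuming loss_mask has shape (batch, seq_len) with 1 marking trainable positions. Keep in mind that the forward pass still computes every logit, and the layers below still receive gradients from the unmasked positions, so the speed gain from pruning the backward graph this way is usually small.

import torch

def detach_masked_logits(logits: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Let gradients flow only where loss_mask == 1 (hypothetical helper).

    logits:    (batch, seq_len, vocab_size)
    loss_mask: (batch, seq_len), 1 = train this position, 0 = freeze it
    """
    keep = loss_mask.bool().unsqueeze(-1)          # (batch, seq_len, 1)
    # Where the mask is 0, use a detached copy of the logits so backprop
    # does not go through those positions.
    return torch.where(keep, logits, logits.detach())

In compute_loss this would be applied right after the forward pass, e.g. logits = detach_masked_logits(logits, loss_mask); the existing loss_mask multiplication is still needed so the frozen positions do not affect the reported loss value.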

I need backpropagation to behave like this, where only the first token and the first EOS token are updated during training (ignore the active SOS in the output part of the first image):


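For illustration only, here is a hypothetical way to build a loss_mask for that pattern, assuming labels has shape (batch, seq_len) and eos_token_id is the tokenizer's EOS id:

import torch

def first_token_and_first_eos_mask(labels: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    """Return a (batch, seq_len) mask with 1 at the first position and at the
    first EOS of each sequence, 0 everywhere else (hypothetical helper)."""
    mask = torch.zeros_like(labels)
    mask[:, 0] = 1                                   # first output token
    is_eos = labels == eos_token_id                  # (batch, seq_len) bool
    first_eos = is_eos.float().argmax(dim=1)         # index of the first EOS (0 if none)
    has_eos = is_eos.any(dim=1)
    rows = torch.arange(labels.size(0))
    mask[rows[has_eos], first_eos[has_eos]] = 1      # only mark rows that contain an EOS
    return mask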