I have a binary mask that I apply to the training loss so that tokens I don’t want to train on do not contribute to the parameter updates. Until now I have only set the loss of those tokens to zero. Now I want to exclude these tokens from backpropagation entirely, in order to speed up training. Does anyone know how to make this change when I’m using the following CustomTrainer?
class CustomTrainer(transformers.Trainer):
    """Trainer with a per-token masked cross-entropy loss and plain DataLoaders.

    Each batch is expected to carry two extra entries next to the model
    inputs: ``labels`` (target token ids) and ``loss_mask`` (per-token weights,
    zero for tokens that must not contribute to the loss).
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the mask-weighted mean cross-entropy over the selected tokens.

        Cross-entropy is evaluated only at positions whose ``loss_mask`` entry
        is non-zero, so ignored positions never enter the loss computation and
        cannot turn the result into NaN through ``inf * 0``.

        Args:
            model: Model called as ``model(**inputs)``; its output must expose
                ``.logits``.
            inputs: Batch dict. ``labels`` and ``loss_mask`` are popped from it
                before the forward pass; the rest is forwarded to the model.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass it; unused because the loss is already
                normalised by the mask sum.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        labels = inputs.pop('labels')
        loss_mask = inputs.pop('loss_mask')

        # forward
        outputs = model(**inputs)
        logits = outputs.logits
        if torch.isnan(logits).any():
            print('NaN detected in logits')
            print(logits)

        # Flatten to one row per token. The class count is read from the logits
        # themselves so that a padded/resized output layer or a wrapped model
        # cannot disagree with a config value.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_labels = labels.reshape(-1)
        # NOTE(review): assumes loss_mask has (or broadcasts to) the shape of
        # labels -- confirm against the data collator.
        flat_weights = loss_mask.expand_as(labels).reshape(-1)

        # Keep only the tokens that are actually trained on; gradients reach
        # the logits exclusively through these rows.
        active = flat_weights != 0
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        token_losses = loss_fct(flat_logits[active], flat_labels[active])

        weighted = token_losses * flat_weights[active]
        # The epsilon keeps the division finite when no token is selected.
        loss = weighted.sum() / (loss_mask.sum() + 1e-9)
        return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self):
        """Build the training DataLoader from the trainer's dataset and args.

        Map-style datasets are shuffled and honour ``dataloader_drop_last``;
        iterable datasets get neither option.
        """
        train_dataset = self.train_dataset
        dataloader_params = {
            'batch_size': self.args.train_batch_size,
            'collate_fn': self.data_collator,
            'num_workers': self.args.dataloader_num_workers,
            'pin_memory': self.args.dataloader_pin_memory,
        }
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params['shuffle'] = True
            dataloader_params['drop_last'] = self.args.dataloader_drop_last
        return DataLoader(train_dataset, **dataloader_params)

    def get_eval_dataloader(self, eval_dataset=None):
        """Build the evaluation DataLoader.

        Args:
            eval_dataset: Dataset to evaluate on; falls back to
                ``self.eval_dataset`` when ``None``.

        Map-style datasets are read in order (no shuffling) and honour
        ``dataloader_drop_last``; iterable datasets get neither option.
        """
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        dataloader_params = {
            'batch_size': self.args.eval_batch_size,
            'collate_fn': self.data_collator,
            'num_workers': self.args.dataloader_num_workers,
            'pin_memory': self.args.dataloader_pin_memory,
        }
        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params['shuffle'] = False
            dataloader_params['drop_last'] = self.args.dataloader_drop_last
        return DataLoader(eval_dataset, **dataloader_params)
I’ve already tried some modifications, such as detaching the rows of the logits tensor that correspond to the tokens I don’t want backpropagation to go through, but I don’t know whether this is the right approach.
I need backpropagation to behave like this, where only the first token and the first EOS token contribute to the update during training (ignore the active SOS in the output part of the first image):