Hey everyone!
I’m writing a custom loss for a Transformers model in a plain PyTorch training loop. I need to combine the cross-entropy on my train set with the cross-entropy on a second, artificially labeled set (its labels were inferred by another model).
Loss = train_loss + artificial_loss.
I usually have far more artificial data than train data, so within a single training step I iterate over both sets, drawing N artificial batches for every labeled batch (with N being the ratio between the two set sizes).
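To make that concrete, this is roughly what I expect one training step to compute (a simplified sketch with made-up names, not my actual code):

import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()

def combined_step(model, labeled_batch, unlabeled_batches):
    # cross-entropy on one labeled (train-set) batch
    lab_loss = ce(model(**labeled_batch).logits, labeled_batch["labels"])
    # pool the N artificially labeled batches and score them with one cross-entropy
    unl_logits = torch.cat([model(**b).logits for b in unlabeled_batches])
    unl_labels = torch.cat([b["labels"] for b in unlabeled_batches])
    unl_loss = ce(unl_logits, unl_labels)
    # Loss = train_loss + artificial_loss
    return lab_loss + unl_loss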
The problem: when I print the losses, the combined loss and the labeled (train set) loss decrease, but the unlabeled (artificial set) loss keeps increasing to unusually large values. What am I doing wrong?
Training Log Example:
https://drive.google.com/file/d/1-tA9Mn0-yc4RHweP7wO6-H7ddUKc94/view?usp=sharing
Training code:
def train(
    model,
    train_dataloader,
    optimizer,
    scheduler,
    val_dataloader=None,
    evaluate_during_training=False,
    is_student=False,
    unlabeled_dataloader=None,
    unl_to_label_batch_ratio=None,
):
    progress_bar = tqdm(range(CFG.num_train_epochs * len(train_dataloader)))
    log("Start training...\n")
    for epoch_i in range(CFG.num_train_epochs):
        if is_student:
            log(
                f"{'Epoch':^7} | {'Labeled Batch':^14} | {'Unlabeled Batch':^16} | "
                f"{'Train Loss':^11} | {'Labeled Loss':^13} | "
                f"{'Unlabeled Loss':^15} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}"
            )
            log("-" * 130)
        else:
            log(
                f"{'Epoch':^7} | {'Train Batch':^12} | "
                f"{'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}"
            )
            log("-" * 80)

        # measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()

        # reset tracking variables at the beginning of each epoch
        total_loss, batch_loss, batch_unl_loss, batch_lab_loss, batch_counts = 0, 0, 0, 0, 0

        # train loop
        model.train()
        loss_fn = nn.CrossEntropyLoss()
        for step, batch in enumerate(train_dataloader):
            batch_counts += 1
            batch_inputs = {k: v.to(CFG.device) for k, v in batch.items()}
            optimizer.zero_grad()
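            # forward pass on the labeled batch; the batch carries "labels",
            # so output.loss is the labeled (train-set) loss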
            output = model(**batch_inputs)

            # if the model is a student, train on the noised data as well
            if is_student:
                text_col = "text_augmented" if CFG.augmented_data else "text"
                unl_logits = []
                unl_labels = []
                for i in range(unl_to_label_batch_ratio):
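                    # fetch one batch of artificially labeled data and tokenize it on the fly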
                    unl_batch = next(iter(unlabeled_dataloader))
                    unl_inputs = tokenizer.batch_encode_plus(
                        unl_batch[text_col],
                        padding="max_length",
                        truncation=True,
                        max_length=CFG.max_seq_len,
                        return_tensors="pt",
                    )
                    unl_inputs["labels"] = unl_batch["labels"].clone().detach()
                    unl_batch_inputs = {k: v.to(CFG.device) for k, v in unl_inputs.items()}
                    unl_output = model(**unl_batch_inputs)
                    unl_logits.append(unl_output.logits)
                    unl_labels.append(unl_inputs["labels"])
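                # pool the N artificial batches and score them with a single cross-entropy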
                unl_labels = torch.cat(unl_labels).to(CFG.device)
                unl_logits = torch.cat(unl_logits)
                unl_loss = loss_fn(unl_logits, unl_labels)
                lab_loss = output.loss
                loss = lab_loss + unl_loss
                batch_lab_loss += lab_loss.item()
                batch_unl_loss += unl_loss.item()
            else:
                loss = output.loss

            batch_loss += loss.item()
            total_loss += loss.item()
            loss.backward()
            if CFG.clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            progress_bar.update(1)

            if (step % 100 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                time_elapsed = time.time() - t0_batch
                # Print training results
                if is_student:
                    log(
                        f"{epoch_i + 1:^7} | {step:^14} | {(step * unl_to_label_batch_ratio):^16} | "
                        f"{batch_loss / batch_counts:^11.6f} | "
                        f"{batch_lab_loss / batch_counts:^13.6f} | "
                        f"{batch_unl_loss / batch_counts:^15.6f} | "
                        f"{'-':^10} | {'-':^9} | {time_elapsed:^9.2f}"
                    )
                else:
                    log(
                        f"{epoch_i + 1:^7} | {step:^12} | {batch_loss / batch_counts:^12.6f} | "
                        f"{'-':^10} | {'-':^9} | {time_elapsed:^9.2f}"
                    )
                batch_loss, batch_lab_loss, batch_unl_los, batch_counts = 0, 0, 0, 0
                t0_batch = time.time()

        # Calculate the average loss over the entire training data
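        # (for the student, total_loss already includes the unlabeled cross-entropy term)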
        avg_train_loss = total_loss / len(train_dataloader)

        if evaluate_during_training:
            val_loss, val_accuracy = evaluate(model, val_dataloader)
            time_elapsed = time.time() - t0_epoch
            if is_student:
                log("-" * 130)
                log(
                    f"{epoch_i + 1:^7} | {'-':^14} | {'-':^16} | {avg_train_loss:^11.6f} | "
                    f"{'-':^13} | {'-':^15} | {val_loss:^10.6f} | "
                    f"{val_accuracy:^9.2f} | {time_elapsed:^9.2f}"
                )
                log("-" * 130)
            else:
                log("-" * 80)
                log(
                    f"{epoch_i + 1:^7} | {'-':^12} | {avg_train_loss:^12.6f} | "
                    f"{val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}"
                )
                log("-" * 80)

        log("\n")
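For reference, this is roughly how I call it for the student (the exact loader names and sizes are placeholders, just to show how unl_to_label_batch_ratio is meant to be used):

# N = how many unlabeled batches to draw per labeled batch,
# i.e. the ratio between the two sets (here taken from the dataloader lengths)
ratio = len(unlabeled_dataloader) // len(train_dataloader)

train(
    model,
    train_dataloader,
    optimizer,
    scheduler,
    val_dataloader=val_dataloader,
    evaluate_during_training=True,
    is_student=True,
    unlabeled_dataloader=unlabeled_dataloader,
    unl_to_label_batch_ratio=ratio,
)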