I’ve been trying to train a customized version of GPT-NeoX on the MiniPile dataset. My training code is fairly minimal; the only oddities are ReLU^2 as the activation, T5’s tokenizer, and tied embeddings, since I’m comparing this model against a non-transformer model that uses the same choices. I’m training on 8 GPUs.
from transformers import AutoTokenizer, GPTNeoXConfig, GPTNeoXForCausalLM, DataCollatorForLanguageModeling
from datasets import load_dataset
tokenizer = AutoTokenizer.from_pretrained("t5-small")
cfg = GPTNeoXConfig(
    vocab_size=len(tokenizer),
    hidden_size=768,
    intermediate_size=768 * 4,
    num_hidden_layers=11,
    num_attention_heads=12,
    hidden_act="relu2",
    max_position_embeddings=1024,
    tie_word_embeddings=True,
)
model = GPTNeoXForCausalLM(cfg)
ds = load_dataset("JeanKaddour/minipile", split="train", cache_dir="/workspace/hf_cache")
ds_val = load_dataset("JeanKaddour/minipile", split="validation", cache_dir="/workspace/hf_cache")
x_n_max = 1024
def tokenize_fn(x):
    return tokenizer(x["text"], max_length=x_n_max, truncation=True)
toks = ds.map(tokenize_fn, batched = True, remove_columns=["text"]).shuffle(seed=42)
toks_val = ds_val.map(tokenize_fn, batched = True, remove_columns=["text"])
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors='pt')
from transformers import Trainer, TrainingArguments
import wandb
wandb.login()
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {pytorch_total_params/1000000}M")
args = TrainingArguments(
    output_dir="gptx",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="steps",
    eval_steps=100,
    logging_steps=100,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_steps=1000,
    bf16=True,
    report_to="wandb",
    load_best_model_at_end=True,
)
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=toks,
    eval_dataset=toks_val,
)
trainer.train()
wandb.finish()
trainer.save_model("gptx")
However, although the loss initially decreases, after a few hundred steps it starts increasing again:
Step Training Loss Validation Loss
100 9.824500 9.099127
200 8.797100 8.515280
300 8.215400 7.917030
400 7.648100 7.450994
500 7.399400 7.420805
600 7.374300 7.341578
700 7.382400 7.465085
800 7.594300 7.760824
900 7.914200 8.104758
1000 8.256700 8.487398
1100 8.696600 8.936168
1200 9.142600 9.377343
When I once left it running, the loss kept climbing past 11.5, significantly worse than even the cross-entropy of picking tokens uniformly at random (see the quick baseline check after the second table below). The same thing happens when I use LLaMA with a similar configuration instead of GPT-NeoX, when I use the default activation function instead of relu2, when I untie the embeddings, and across a range of learning rates, so the problem is not the nonstandard choice of relu2. For example:
cfg = GPTNeoXConfig(
    vocab_size=len(tokenizer),
    hidden_size=768,
    intermediate_size=768 * 4,
    num_hidden_layers=12,
    num_attention_heads=12,
    tie_word_embeddings=True,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    tokenizer_class="T5TokenizerFast",
)
yields:
Step Training Loss Validation Loss
100 9.347200 8.959894
200 8.786000 8.624897
300 8.459400 8.292471
400 8.108800 7.929269
500 7.737300 7.550851
600 7.386000 7.236368
700 7.134700 7.065014
800 7.073900 7.122225
900 7.251800 7.446449
1000 7.621200 7.849643
1100 8.031700 8.260697
1200 8.493300 8.792882
1300 9.046300 9.347054
1400 9.596400 9.945238
1500 10.211000 10.569137
1600 10.859000 11.239925
1700 11.563300 11.997416
1800 12.349400 12.686112
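For reference, here is the uniform-random baseline I mentioned above, as a rough back-of-the-envelope check. It only assumes the t5-small tokenizer's vocabulary of roughly 32k entries:

import math

# Cross-entropy (in nats) of guessing tokens uniformly at random over the
# vocabulary; a model that learns nothing should plateau near this value.
vocab_size = 32100  # len(tokenizer) for t5-small, extra sentinel ids included
print(math.log(vocab_size))  # ≈ 10.38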
In contrast, a custom JAX model (non-transformer) reaches a loss of 3.6 on the same dataset without any trouble.
Am I doing something wrong? I can’t seem to find anything obviously incorrect, but perhaps someone here has a better idea of what’s going on?