Using GPT-NeoX for text classification with the Trainer class

Hi guys
I am trying to use GPTNeoXForSequenceClassification to train on my custom dataset for single-label text classification with four classes (num_labels=4).
I am getting this output mismatch:
torch.Size([1, 4]) torch.Size([1])
And the predictions are very bad; as far as I can tell it only ever predicts the first class, as if the output were one-hot encoded.

As I understand it, GPT-NeoX does not predict the way GPT-J does, but instead returns logits for every label.
My question is: how can I adapt the Trainer class to use it with GPTNeoXForSequenceClassification?
Something like an argmax or a squeeze on the outputs?
Any ideas?
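To make the shapes concrete: logits of size [batch_size, num_labels] together with labels of size [batch_size] are exactly what cross-entropy expects, and a class prediction is an argmax over the label axis. A minimal sketch with made-up values:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4)   # [batch_size, num_labels], as in the printout below
labels = torch.tensor([2])   # [batch_size], one class id per example

loss = F.cross_entropy(logits, labels)  # these sizes are compatible
preds = logits.argmax(dim=-1)           # shape [1], one predicted class id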
Here's my code:
from datasets import DatasetDict

datasets = DatasetDict({
    "train": train_dataset,
    "valid": valid_dataset,
    "test": test_dataset,
})
print(datasets)

def tokenize_function(example):
    # No padding here; DataCollatorWithPadding below pads each batch dynamically
    return tokenizer(example["sentence"], truncation=True, max_length=512)

tokenized_datasets = datasets.map(tokenize_function, batched=True)

model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = model.config.eos_token_id

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

tokenized_datasets = tokenized_datasets.remove_columns(["sentence"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")
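The model loading itself is not shown above; for reference, a minimal sketch of how it could be set up for this task (assuming num_labels=4 and the EOS token reused for padding):

from transformers import AutoTokenizer, GPTNeoXForSequenceClassification

model_name = "EleutherAI/gpt-neox-20b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-NeoX ships without a pad token

model = GPTNeoXForSequenceClassification.from_pretrained(model_name, num_labels=4)
model.config.pad_token_id = tokenizer.pad_token_id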

from transformers import Trainer, TrainingArguments
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    tokenized_datasets["train"], shuffle=True, batch_size=1, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    tokenized_datasets["valid"], batch_size=1, collate_fn=data_collator
)
test_dataloader = DataLoader(
    tokenized_datasets["test"], batch_size=1, collate_fn=data_collator
)

# Grab one batch to sanity-check shapes, and move it to the GPU
batch = next(iter(train_dataloader))
batch = {k: v.to(device) for k, v in batch.items()}
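As a quick sanity check, a single forward pass on that batch already reproduces the shapes from the log below (a sketch, assuming the model is on the same device):

import torch

with torch.no_grad():
    outputs = model(**batch)

print(outputs.logits.shape, batch["labels"].shape)
# torch.Size([1, 4]) torch.Size([1])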

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # logits arrives as a tuple (logits, past_key_values), hence the [0]
    predictions = np.argmax(logits[0], axis=-1)
    return metric.compute(predictions=predictions, references=labels)
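Since GPTNeoXForSequenceClassification returns a SequenceClassifierOutputWithPast, what reaches compute_metrics can be a tuple of (logits, past_key_values), which is why logits[0] is indexed above. An alternative sketch that strips the extra outputs before they are accumulated, using the Trainer's preprocess_logits_for_metrics hook:

def preprocess_logits_for_metrics(logits, labels):
    # logits may arrive as a tuple (logits, past_key_values); keep the scores
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)  # reduce to class ids before accumulation

With this, compute_metrics would receive class ids directly (no argmax needed there), and the hook is passed as preprocess_logits_for_metrics=preprocess_logits_for_metrics when constructing the Trainer below.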

print("Loading configuration…")
training_args = TrainingArguments(
    output_dir="resultsseq",
    num_train_epochs=3,
    logging_steps=100,
    save_steps=100,
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=100,
    eval_accumulation_steps=10,
    load_best_model_at_end=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    optim="adamw_bnb_8bit",
    warmup_steps=0,
    weight_decay=0.02,
    logging_dir="logs",
    learning_rate=5e-06,
    fp16=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()
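After training, test-set predictions can be pulled out with trainer.predict; a sketch (indexing [0] again in case the logits come back as a tuple):

import numpy as np

test_output = trainer.predict(tokenized_datasets["test"])
logits = test_output.predictions
if isinstance(logits, tuple):
    logits = logits[0]
preds = np.argmax(logits, axis=-1)
print(test_output.metrics)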

The "error" I am getting is:
[2023-06-27 12:31:52,409] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
GPU memory occupied: 424 MB.
Original dataset shape Counter({1: 896, 0: 746, 2: 719, 3: 146})
Undersample dataset shape Counter({0: 146, 1: 146, 2: 146, 3: 146})
Counter({1: 896, 0: 746, 2: 719, 3: 146})
Counter({1: 896, 0: 746, 2: 719, 3: 292})
Counter({1: 896, 0: 746, 2: 719, 3: 292})
cuda

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label'],
        num_rows: 2653
    })
    valid: Dataset({
        features: ['label', 'sentence'],
        num_rows: 200
    })
    test: Dataset({
        features: ['label', 'sentence'],
        num_rows: 200
    })
})
Loading configuration…
skipped Embedding(50278, 6144): 294.59765625M params
skipped: 294.59765625M params
torch.Size([1, 4]) torch.Size([1])
…(the same shape line repeats for every step)
{'loss': 0.9142, 'learning_rate': 4.93717803744189e-06, 'epoch': 0.04}

The interesting thing about the problem is that everything works fine when I use the GPT-J model,
model_name = "EleutherAI/gpt-j-6B"
instead of
model_name = "EleutherAI/gpt-neox-20b"
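One difference worth checking when swapping the two models (an assumption on my side, not a confirmed cause): both classification heads pool on the last non-padding token, which they locate via config.pad_token_id, so the padding setup has to be consistent with the tokenizer:

# sanity-check the padding setup the classification head relies on
print(tokenizer.pad_token, tokenizer.pad_token_id)
print(model.config.pad_token_id)
assert model.config.pad_token_id == tokenizer.pad_token_id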