|
# Quick training guide |
|
Use this together with the guide below — it's really helpful:
|
|
|
[Fine-Tune ViT for Image Classification with Hugging Face Transformers](https://huggingface.co/blog/fine-tune-vit) |
|
## Start |
|
```bash |
|
# torch: model + training loop, pillow: image decoding in `datasets`, tqdm: progress bar
pip install torch transformers datasets pillow tqdm
|
``` |
|
## Preparing the data: |
|
|
|
Your data shouldn't look like this: |
|
```json |
|
{ |
|
"file_name": "train/aeae3547df6be819a42dcbb83e65586fd6deb424f134375c1dbc00188b37e2bf.jpeg", |
|
"labels": ["general", "furina (genshin impact)", "1girl", "ahoge", "bangs", "bare shoulders", ...] |
|
} |
|
``` |
|
But it should look more like this: |
|
```json |
|
|
|
{ |
|
"file_name": "train/aeae3547df6be819a42dcbb83e65586fd6deb424f134375c1dbc00188b37e2bf.jpeg", |
|
"labels": ["0", "3028", "4", "702", "8", "9", "382", ...] |
|
} |
|
|
|
``` |
|
|
|
Here, `labels` holds the tag IDs from `labels.csv` instead of the tag names. The IDs are written as strings (e.g. `"3028"`). They must be integers from 0 to N−1, where N is the number of tags, because the preprocessing step uses each ID directly as a position in the multi-hot label vector.
|
|
|
Loading labels and their IDs: |
|
|
|
```python |
|
import csv

# labels.csv: a header row (tag_id,name,category) followed by one row per tag.
# IDs stay as strings because the dataset stores label IDs as strings (e.g. "3028").
# newline="" is what the csv module requires for correct handling of quoted fields.
with open("labels.csv", "r", encoding="utf-8", newline="") as f:
    reader = csv.reader(f)
    header = next(reader)  # tag_id,name,category
    rows = list(reader)

id2labels = {}
labels2id = {}

for row in rows:
    if not row:  # csv.reader yields [] for blank lines; skip them instead of crashing
        continue
    tag_id, name = row[0], row[1]
    id2labels[tag_id] = name
    labels2id[name] = tag_id
|
``` |
|
|
|
Here, `labels.csv` is a CSV file with a header row (`tag_id,name,category`) followed by one row per tag. Only the first two columns (the ID and the name) are used.
|
|
|
Load dataset: |
|
```python |
|
from datasets import load_dataset

# Load the local dataset folder (images + the file_name/labels records shown above).
dataset = load_dataset(path="./vit_dataset")
|
``` |
|
Congratulations! You've completed the toughest part. Why, you ask? For this model, gathering and labeling the data alone took me a whole week.
|
|
|
## Preprocess: |
|
|
|
```python |
|
from transformers import ViTImageProcessor
import torch

model_name_or_path = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(model_name_or_path)


def transform(example_batch):
    """Turn a batch of raw examples into model inputs.

    - ``pixel_values``: the images, converted to tensors by the ViT image processor.
    - ``labels``: one multi-hot list per example, of length ``len(id2labels)``
      (the same count used for ``num_labels``), with 1 at every tag ID the
      example has.
    - ``label_names``: the readable tag names, kept for evaluation printouts.

    Tag IDs are used directly as positions, so they must be integers in
    ``0 .. len(id2labels) - 1``.
    """
    inputs = processor(list(example_batch['image']), return_tensors='pt')

    # Size from id2labels, the same source as the model's num_labels. labels2id can
    # be shorter when two IDs share a name, which would misalign or overflow indices.
    num_labels = len(id2labels)

    multi_hot = []
    for tag_ids in example_batch['labels']:
        one_hot = [0] * num_labels
        for tag_id in tag_ids:
            one_hot[int(tag_id)] = 1
        multi_hot.append(one_hot)

    inputs['labels'] = multi_hot
    inputs['label_names'] = [[id2labels[tag_id] for tag_id in tag_ids] for tag_ids in example_batch['labels']]

    return inputs
|
``` |
|
This code might not look pretty, but it gets the job done! The ViT image processor resizes the images (inputs) to 224x224 and normalizes them into pixel tensors; nothing is flattened. The labels (targets) become multi-hot vectors: one 0/1 entry per tag, with a 1 for every tag the image has. This is the natural target format for multi-label classification, where a single image can carry many tags.
|
|
|
## Training |
|
These parts are relatively simple, so I'll go through them quickly.
|
|
|
- Load dataset: |
|
|
|
```python |
|
import torch
from torch.utils.data import DataLoader

batch_size = 16

# Attach the preprocessing: `transform` runs lazily whenever an example is read.
prepared_dataset = dataset.with_transform(transform)


def collate_fn(batch):
    """Stack per-example tensors into batch tensors; keep tag names as a list of lists."""
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([torch.tensor(x['labels']) for x in batch]),
        'label_names': [x['label_names'] for x in batch],
    }


# Reshuffle the training set every epoch so batches differ between epochs.
train_dataloader = DataLoader(prepared_dataset['train'], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)

# Evaluation runs one image at a time.
eval_dataloader = DataLoader(prepared_dataset['test'], collate_fn=collate_fn, batch_size=1)
|
``` |
|
- Initialize the model: |
|
|
|
```python |
|
from transformers import ViTForImageClassification, ViTConfig

# Start from the same pretrained checkpoint the processor was loaded from, instead
# of random weights; the classification head is newly initialized for our tags.
# Label maps use int IDs so `model.config.id2label[idx]` works with integer indices.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(id2labels),
    id2label={int(tag_id): name for tag_id, name in id2labels.items()},
    label2id={name: int(tag_id) for name, tag_id in labels2id.items()},
    # Each image can carry many tags, so use the multi-label (sigmoid) setup explicitly.
    problem_type="multi_label_classification",
)
configuration = model.config
|
``` |
|
|
|
Set up training:
|
|
|
```python |
|
device = torch.device('cuda')
test_steps = 5000  # run eval()/test() every `test_steps` training steps
epochs = 50
mix_precision = torch.float16  # dtype passed to torch.autocast in the training loop
global_steps = 0

# The training loop sends every batch to `device`, so the weights must live there too.
# Move them before building the optimizer.
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
# NOTE: with its defaults, LinearLR ramps the LR from lr/3 up to lr over the first
# 5 scheduler steps (one step per batch here) and then holds it constant. Pass
# start_factor / end_factor / total_iters if you want a decay schedule instead.
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=optimizer)
|
``` |
|
|
|
Test and Evaluation: |
|
```python |
|
import torch
from transformers.modeling_outputs import ImageClassifierOutput


def test(eval_dataloader : DataLoader, model : ViTForImageClassification, device, t=0.7):
    """Print the true tags and the predicted tags for the first eval example.

    A quick sanity check. It runs the model on the first batch of
    ``eval_dataloader`` and prints every tag whose sigmoid probability exceeds
    ``t``, together with that probability.

    Args:
        eval_dataloader: yields dicts with ``pixel_values`` and ``label_names``.
        model: the classifier; switching it to eval mode is the caller's job.
        device: device the model lives on.
        t: probability threshold for reporting a tag.
    """
    # Only the first batch is needed. list(iter(...)) would preprocess the whole
    # eval set just to throw it away.
    batch = next(iter(eval_dataloader))

    with torch.no_grad():
        pixel_values = batch['pixel_values'].to(device=device)
        outputs : ImageClassifierOutput = model(pixel_values=pixel_values)
        probabilities = torch.sigmoid(outputs.logits)

    # .tolist() moves the row to Python once instead of syncing per tag.
    predictions = [
        (model.config.id2label[idx], p)
        for idx, p in enumerate(probabilities[0].tolist())
        if p > t
    ]

    print(f"label_names : {batch['label_names'][0]}")
    print(f"predictions : {predictions}")
|
|
|
def eval(eval_dataloader : DataLoader, model : ViTForImageClassification, device, t=0.7):
    """Compute average loss and tag recall over ``eval_dataloader``.

    For each example, every tag whose sigmoid probability exceeds ``t`` is
    predicted. The example's score is (predicted tags that appear in
    ``label_names``) / len(``label_names``), i.e. recall.

    Result keys:
        eval_predictions: mean score per example.
        eval_loss: mean of the per-batch losses.
        total_predictions / total_loss: the raw sums behind those means.

    The dict is printed and also returned.
    """
    result = {
        "eval_predictions" : 0,
        "eval_loss" : 0,
        "total_predictions" : 0,
        "total_loss" : 0
    }
    num_examples = 0

    for batch in eval_dataloader:
        pixel_values = batch['pixel_values'].to(device=device)
        labels = batch['labels'].to(device=device, dtype=torch.float)

        with torch.no_grad():
            outputs : ImageClassifierOutput = model(pixel_values=pixel_values, labels=labels)

        # Threshold probabilities, not raw logits, so `t` means the same thing as in test().
        probabilities = torch.sigmoid(outputs.logits).cpu()

        # Score every example in the batch, not just the first one.
        for sample_probs, label_names in zip(probabilities, batch['label_names']):
            predictions = [
                model.config.id2label[idx]
                for idx, p in enumerate(sample_probs.tolist())
                if p > t
            ]
            true_names = set(label_names)
            hits = sum(1 for name in predictions if name in true_names)
            if label_names:  # an example without tags contributes a score of 0
                result['total_predictions'] += hits / len(label_names)
            num_examples += 1

        result['total_loss'] += outputs.loss.item()

    if num_examples:  # avoid ZeroDivisionError on an empty loader
        result['eval_predictions'] = result['total_predictions'] / num_examples
        result['eval_loss'] = result['total_loss'] / len(eval_dataloader)

    print(result)
    return result
|
``` |
|
Train: |
|
```python |
|
import tqdm |
|
from transformers.modeling_outputs import ImageClassifierOutput |
|
|
|
process_bar = tqdm.tqdm(total=epochs * len(train_dataloader))

# float16 autocast needs loss scaling to stop small gradients underflowing to zero.
# The scaler is a no-op when mix_precision is anything other than float16.
scaler = torch.cuda.amp.GradScaler(enabled=(mix_precision == torch.float16))

# autocast takes a device *type* ("cuda"), not a full device string like "cuda:0".
autocast_device_type = torch.device(device).type

for e in range(1, epochs + 1):
    model.train()
    total_loss = 0

    for idx, batch in enumerate(train_dataloader):
        pixel_values = batch['pixel_values'].to(device=device)
        labels = batch['labels'].to(device=device, dtype=torch.float)

        with torch.autocast(device_type=autocast_device_type, dtype=mix_precision):
            outputs : ImageClassifierOutput = model(pixel_values=pixel_values, labels=labels)

        loss = outputs.loss
        # Stop before a NaN loss is back-propagated into the weights.
        # Raise instead of assert so the check also runs under `python -O`.
        if torch.isnan(loss):
            raise RuntimeError("NaN detection.")
        total_loss += loss.detach().float()

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        optimizer.zero_grad()

        process_bar.update(1)
        process_bar.set_description_str(f"{model.config.problem_type} - Epoch: {e}/{epochs}", refresh=False)
        # train_loss = running mean over the batches seen so far in this epoch.
        process_bar.set_postfix({'loss' : f'{loss.item():.5f}', "train_loss" : total_loss.item() / (idx + 1)})

        if global_steps % test_steps == 0 and global_steps > 1:
            model.eval()
            process_bar.set_description_str(f"Evalute - Epoch: {e}/{epochs}")
            eval(eval_dataloader=eval_dataloader, model=model, device=device, t=0.3)
            test(eval_dataloader, model, device, 0.3)
            model.train()

        global_steps += 1

process_bar.close()
|
``` |
|
|
|
Thank you for reading through all this verbose stuff. Of course, all the code above is impromptu; there might be some inconsistencies. Your contributions are highly appreciated. |