Vision Transformer (base-sized model)

Vision Transformer (ViT) model pre-trained on ImageNet-21k (14 million images, 21,843 classes) at resolution 224x224, and fine-tuned on ImageNet 2012 (1 million images, 1,000 classes) at resolution 384x384. It was introduced in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Dosovitskiy et al. and first released in this repository. The weights, however, were converted from the timm repository by Ross Wightman, who had already converted them from JAX to PyTorch. Credit goes to him.

Finally, the ViT was fine-tuned on the Chaoyang dataset at resolution 384x384, using a fixed 10% of the training set as the validation set. It was evaluated on the official test set using the checkpoint with the lowest validation loss.
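The split and model-selection protocol can be sketched as follows. This is a minimal illustration under assumptions: the helper names `fixed_val_split` and `select_best_epoch`, the seed (42), and the per-epoch loss values are hypothetical, since the card does not say how the 10% subset was drawn.

```python
import random

def fixed_val_split(num_samples, val_fraction=0.1, seed=42):
    """Split sample indices into train/val once, reproducibly (hypothetical helper).

    The seed is an assumption; the card only states that a fixed 10% of the
    training set was held out for validation.
    """
    indices = list(range(num_samples))
    rng = random.Random(seed)  # local RNG so the split does not depend on global state
    rng.shuffle(indices)
    num_val = int(round(num_samples * val_fraction))
    return indices[num_val:], indices[:num_val]

def select_best_epoch(val_losses):
    """Return the index of the epoch with the lowest validation loss."""
    return min(range(len(val_losses)), key=lambda i: val_losses[i])

train_idx, val_idx = fixed_val_split(1000)
print(len(train_idx), len(val_idx))  # 900 100

# Hypothetical per-epoch validation losses; the checkpoint saved at the best
# epoch is the one that would be evaluated on the official test set.
losses = [0.92, 0.71, 0.65, 0.68, 0.70]
print(select_best_epoch(losses))  # 2
```

Because the seed is fixed, every run holds out the same 10%, which keeps validation losses comparable across experiments.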

Augmentation pipeline

To address class imbalance in our training set, we performed oversampling with repetition. Specifically, we duplicated minority-class images until all classes were evenly represented. This produced a larger training set but ensured that the model saw an equal number of samples from each class during training. We verified that this approach did not lead to overfitting by evaluating on a validation set that kept the original class distribution. We used the following Albumentations augmentation pipeline for our experiments:

  • A.Resize(img_size, img_size),
  • A.HorizontalFlip(p=0.5),
  • A.VerticalFlip(p=0.5),
  • A.RandomRotate90(p=0.5),
  • A.RandomResizedCrop(img_size, img_size, scale=(0.5, 1.0), p=0.5),
  • ToTensorV2(p=1.0)

This pipeline consists of the following transformations:

  • Resize: resizes the image to a fixed size of (img_size, img_size).
  • HorizontalFlip: flips the image horizontally with a probability of 0.5.
  • VerticalFlip: flips the image vertically with a probability of 0.5.
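  • RandomRotate90: with a probability of 0.5, rotates the image by a random multiple of 90 degrees (0, 90, 180, or 270 degrees, chosen uniformly).
  • RandomResizedCrop: with a probability of 0.5, crops a random region covering 50% to 100% of the image area (with an aspect ratio between 3:4 and 4:3, the library default) and resizes it back to (img_size, img_size).
  • ToTensorV2: converts the image from a NumPy HWC array to a PyTorch CHW tensor, without rescaling pixel values.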
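The geometric steps listed above can be illustrated without Albumentations using plain NumPy. This is a simplified stand-in, not the library's implementation: the helper names `augment` and `random_resized_crop` are hypothetical, the crop here is always square and resized by nearest neighbour (Albumentations also samples the aspect ratio and interpolates bilinearly by default), and the input is assumed to be already resized to (img_size, img_size).

```python
import numpy as np

def random_resized_crop(img, size, scale=(0.5, 1.0), rng=None):
    """Simplified RandomResizedCrop: a square crop covering a random fraction
    of the image area, resized back to (size, size) by nearest neighbour."""
    if rng is None:
        rng = np.random.default_rng()
    h, w = img.shape[:2]
    area_frac = rng.uniform(*scale)                  # fraction of the area to keep
    side = int(round(np.sqrt(area_frac * h * w)))
    side = max(1, min(side, h, w))                   # keep the crop inside the image
    top = rng.integers(0, h - side + 1)
    left = rng.integers(0, w - side + 1)
    crop = img[top:top + side, left:left + side]
    idx = np.arange(size) * side // size             # nearest-neighbour source indices
    return crop[idx][:, idx]

def augment(img, size, rng):
    """Apply the geometric part of the pipeline, each step with p = 0.5."""
    if rng.random() < 0.5:
        img = img[:, ::-1]                           # HorizontalFlip (mirror width axis)
    if rng.random() < 0.5:
        img = img[::-1, :]                           # VerticalFlip (mirror height axis)
    if rng.random() < 0.5:
        img = np.rot90(img, k=rng.integers(0, 4))    # RandomRotate90: 0-3 quarter turns
    if rng.random() < 0.5:
        img = random_resized_crop(img, size, rng=rng)
    return np.ascontiguousarray(img)

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(384, 384, 3), dtype=np.uint8)  # stand-in image
out = augment(img, 384, rng)
print(out.shape, out.dtype)  # (384, 384, 3) uint8

# Flips and 90-degree rotations only permute pixels, so no pixel values are lost.
permuted = np.rot90(img[:, ::-1], k=1)
print(np.array_equal(np.sort(permuted, axis=None), np.sort(img, axis=None)))  # True
```

On square inputs every step preserves the (img_size, img_size) shape, which is why the pipeline can apply them in any combination.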

These transformations were chosen to augment the dataset with a variety of geometric transformations, while preserving important visual features.
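The class-balancing step described at the start of this section (duplicating minority-class images until every class is equally represented) can be sketched as follows. This is a minimal sketch under assumptions: the helper name `oversample_by_repetition` is hypothetical, and the card does not say how duplicates were chosen, so here each minority class is repeated whole as often as possible and the remainder is drawn at random. It should be applied to the training split only, so the validation set keeps the original class distribution.

```python
import random
from collections import Counter, defaultdict

def oversample_by_repetition(labels, seed=0):
    """Return sample indices in which every class appears equally often.

    Minority classes are filled up by repeating their images until each
    class matches the size of the largest one.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    target = max(len(idxs) for idxs in by_class.values())
    rng = random.Random(seed)
    balanced = []
    for idxs in by_class.values():
        full_repeats, remainder = divmod(target, len(idxs))
        balanced.extend(idxs * full_repeats)          # whole copies of the class
        balanced.extend(rng.sample(idxs, remainder))  # partial copy to reach the target
    rng.shuffle(balanced)
    return balanced

# Toy imbalanced label list with 4 classes (not the real Chaoyang counts).
labels = [0] * 7 + [1] * 3 + [2] * 2 + [3] * 1
balanced = oversample_by_repetition(labels)
print(sorted(Counter(labels[i] for i in balanced).items()))  # [(0, 7), (1, 7), (2, 7), (3, 7)]
```

Since the duplicates are passed through the random augmentations above, repeated images rarely reach the model in identical form.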

Results

At the time of writing, our model outperforms the previous state-of-the-art models listed on Papers with Code for this dataset, as reported in the dataset's reference paper. The table below summarizes the results on the official test set. F1-score, precision, and recall are macro averages, and the arrows give the absolute improvement over the baseline in percentage points.

Model                       Accuracy     F1-Score     Precision    Recall
Baseline                    0.83         0.77         0.78         0.75
Vit-384-finetuned           0.86 (↑3%)   0.81 (↑4%)   0.82 (↑4%)   0.80 (↑5%)
Vit-384-from-scratch        0.78         0.74         0.74         0.74
Vit-224-distilled-resnet50  0.74         0.00         0.00         0.00
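Macro averaging computes each metric per class and then takes the unweighted mean, so small classes count as much as large ones. The sketch below assumes conventions similar to scikit-learn's macro average (a class that is never predicted gets precision 0, and macro F1 is the mean of per-class F1 scores); the card does not state which library produced the numbers, and the helper name `macro_metrics` is hypothetical.

```python
def macro_metrics(y_true, y_pred, num_classes):
    """Return accuracy plus macro-averaged precision, recall, and F1."""
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        prec = tp / (tp + fp) if tp + fp else 0.0  # 0 when the class is never predicted
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

    def mean(xs):
        return sum(xs) / len(xs)

    return accuracy, mean(precisions), mean(recalls), mean(f1s)

# Toy example: a classifier that always predicts the majority class 0.
y_true = [0, 0, 0, 0, 1, 1, 2, 3]
y_pred = [0] * 8
print([round(m, 3) for m in macro_metrics(y_true, y_pred, 4)])  # [0.5, 0.125, 0.25, 0.167]
```

The toy example shows why accuracy and the macro scores can diverge under class imbalance: ignoring the minority classes costs little accuracy but drags every macro metric down.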

How to use

Here is how to use this model to classify an image of the Chaoyang dataset into one of the 4 classes:

from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests
import torch

# Generic sample image (COCO); use a Chaoyang histopathology image for meaningful predictions
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

processor = ViTImageProcessor.from_pretrained('Snarci/ViT-base-patch16-384-Chaoyang-finetuned')
model = ViTForImageClassification.from_pretrained('Snarci/ViT-base-patch16-384-Chaoyang-finetuned')
model.eval()  # disable dropout for inference

# Resize to 384x384 and normalize, returning PyTorch tensors
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():  # no gradients needed for inference
    outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 4 Chaoyang classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

The example above runs with PyTorch. Note that ViTFeatureExtractor is deprecated in recent versions of Hugging Face Transformers in favor of ViTImageProcessor, which is called the same way.

Training data

The ViT model was pre-trained on ImageNet-21k, a dataset of 14 million images and 21k classes, and then fine-tuned on ImageNet, a dataset of 1 million images and 1k classes. Finally, the ViT was fine-tuned on the Chaoyang dataset at resolution 384x384, using a fixed 10% of the training set as the validation set.

Training procedure

Preprocessing

The exact details of preprocessing of images during training/validation can be found here.

Images are resized/rescaled to the same resolution (224x224 during pre-training, 384x384 during fine-tuning) and normalized across the RGB channels with mean (0.5, 0.5, 0.5) and standard deviation (0.5, 0.5, 0.5).
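As a worked example of the normalization arithmetic: 8-bit pixels are first rescaled by 1/255 to [0, 1], and then x' = (x - 0.5) / 0.5 maps them onto [-1, 1]. The helper `vit_normalize` is hypothetical and only illustrates the math; the Hugging Face image processor performs equivalent rescale and normalize steps internally (with a default rescale factor of 1/255), so you do not need it for inference.

```python
import numpy as np

def vit_normalize(img_uint8, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Rescale 8-bit RGB pixels to [0, 1], normalize per channel, return CHW.

    With mean = std = 0.5 this maps [0, 255] onto [-1, 1]:
    x' = (x / 255 - 0.5) / 0.5
    """
    x = img_uint8.astype(np.float32) / 255.0     # rescale to [0, 1]
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    x = (x - mean) / std                         # broadcasts over the channel axis (HWC)
    return x.transpose(2, 0, 1)                  # HWC -> CHW, the layout PyTorch models expect

# A single pixel with channel values 0, 128, 255.
pixel = np.array([[[0, 128, 255]]], dtype=np.uint8)
print(vit_normalize(pixel)[:, 0, 0])  # ≈ [-1.0, 0.0039, 1.0]
```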

License

This model is provided for non-commercial use only and may not be used in any research or publication without prior written consent from the author.
