## Overview
This experiment analyzes the training dynamics of vision transformers under the [Prisma](https://github.com/soniajoseph/ViT-Prisma.git) project. Small Vision Transformers (ViTs) were trained and evaluated on shape classification using the dSprites dataset, which consists of procedurally generated 2D shapes defined by six independent latent factors. The task is to classify which of the three dSprites shapes an image contains.
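For reference, dSprites contains every combination of its latent factor values exactly once (737,280 images), and the label for this task is the shape factor. The sketch below assumes the layout of DeepMind's released `.npz` file, where `latents_classes` has one integer column per factor in the order color, shape, scale, orientation, posX, posY; the helper `shape_labels` is a hypothetical name, not part of ViT-Prisma.

```python
import math

import numpy as np

# Number of distinct values per latent factor, in the column order assumed
# for `latents_classes` (color, shape, scale, orientation, posX, posY).
LATENT_SIZES = {
    "color": 1,
    "shape": 3,  # square, ellipse, heart
    "scale": 6,
    "orientation": 40,
    "pos_x": 32,
    "pos_y": 32,
}

# The dataset is the full Cartesian product of the factor values.
NUM_IMAGES = math.prod(LATENT_SIZES.values())

# Column index of the shape factor (assumed layout, see above).
SHAPE_COLUMN = list(LATENT_SIZES).index("shape")


def shape_labels(latents_classes: np.ndarray) -> np.ndarray:
    """Return the 3-way shape label (0, 1, 2) for each image (hypothetical helper)."""
    if latents_classes.ndim != 2 or latents_classes.shape[1] != len(LATENT_SIZES):
        raise ValueError("expected an (N, 6) array of latent class indices")
    return latents_classes[:, SHAPE_COLUMN].astype(np.int64)
```

With the released file loaded through `np.load`, `shape_labels(data["latents_classes"])` would give the classification targets.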
All of the training checkpoints are available on the Hugging Face Hub. The checkpoints are summarised in the following table with links to the models on the Hub:
| Size | No. Layers | AttentionOnly | Attention-and-MLP |
|:---:|:---:|:---:|:---:|
| tiny | 1 | [link](https://huggingface.co/IamYash/dSprites-tiny-AttentionOnly) | [link](https://huggingface.co/IamYash/dSprites-tiny-Attention-and-MLP) |
| base | 2 | [link](https://huggingface.co/IamYash/dSprites-base-AttentionOnly) | [link](https://huggingface.co/IamYash/dSprites-base-Attention-and-MLP) |
| small | 3 | [link](https://huggingface.co/IamYash/dSprites-small-AttentionOnly) | [link](https://huggingface.co/IamYash/dSprites-small-Attention-and-MLP) |
| medium | 4 | [link](https://huggingface.co/IamYash/dSprites-medium-AttentionOnly) | [link](https://huggingface.co/IamYash/dSprites-medium-Attention-and-MLP) |
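The repository names in the table follow a regular pattern, so a specific model can be addressed programmatically. This is a minimal sketch assuming every repo id has the form `IamYash/dSprites-{size}-{variant}`, as in the links above; `repo_id_for`, `NUM_LAYERS` and `ALL_REPOS` are hypothetical names.

```python
# Number of transformer layers for each model size, as listed in the table.
NUM_LAYERS = {"tiny": 1, "base": 2, "small": 3, "medium": 4}

VARIANTS = ("AttentionOnly", "Attention-and-MLP")


def repo_id_for(size: str, attention_only: bool) -> str:
    """Build the Hugging Face repo id for one cell of the checkpoint table."""
    if size not in NUM_LAYERS:
        raise ValueError(f"unknown size {size!r}; expected one of {sorted(NUM_LAYERS)}")
    variant = VARIANTS[0] if attention_only else VARIANTS[1]
    return f"IamYash/dSprites-{size}-{variant}"


# All eight repositories, ordered by depth (1 to 4 layers).
ALL_REPOS = [
    repo_id_for(size, attention_only)
    for size in sorted(NUM_LAYERS, key=NUM_LAYERS.get)
    for attention_only in (True, False)
]
```

Note that the size names do not sort alphabetically by depth (`base` has 2 layers, `small` has 3), which is why the list is ordered by `NUM_LAYERS` rather than by name.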
Each repository contains multiple intermediate checkpoints. Each checkpoint is stored as `checkpoint_{i}.pth`, where `i` is the number of training samples the model has been trained on.
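Because `i` counts training samples, sorting the files numerically (not as strings) recovers the training order needed to study training dynamics. A minimal sketch, assuming file names follow the `checkpoint_{i}.pth` pattern described above; `sort_checkpoints` is a hypothetical helper, and its `prefix` argument can be changed if a repository uses a different file-name prefix.

```python
import re
from typing import Iterable, List, Tuple


def sort_checkpoints(
    filenames: Iterable[str], prefix: str = "checkpoint_"
) -> List[Tuple[int, str]]:
    """Return (num_samples, filename) pairs for matching files, in training order.

    Files that do not match `<prefix><integer>.pth` (README.md, .gitattributes, ...)
    are skipped.
    """
    pattern = re.compile(rf"^{re.escape(prefix)}(\d+)\.pth$")
    found = []
    for name in filenames:
        match = pattern.match(name)
        if match:
            found.append((int(match.group(1)), name))
    # Sort on the integer so that e.g. 1000 comes after 200.
    return sorted(found)
```

In practice the file names can be obtained with `huggingface_hub.list_repo_files(repo_id)` and passed straight to this function.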
Further details on training and results are described [here](https://github.com/soniajoseph/ViT-Prisma/tree/main/docs).
## How to Use
```python
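# Run in a Jupyter/Colab notebook cell. `%cd` is used instead of `!cd`
# because each `!` command runs in its own subshell, so `!cd` would not
# change the directory for the following `pip install`.
!git clone https://github.com/soniajoseph/ViT-Prisma
%cd ViT-Prisma
!pip install -e .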
```
```python
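from huggingface_hub import hf_hub_download
import torch

REPO_ID = "IamYash/dSprites-tiny-AttentionOnly"
FILENAME = "model_0.pth"  # pick any checkpoint file listed in the repo

# Download the file (cached locally) and load it; map_location="cpu" lets
# checkpoints saved on a GPU load on a CPU-only machine.
checkpoint = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
    map_location="cpu",
)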
```
```python
from vit_prisma.models.base_vit import BaseViT
from vit_prisma.configs.DSpritesConfig import GlobalConfig
from vit_prisma.utils.wandb_utils import update_dataclass_from_dict

# Start from the default dSprites config.
config = GlobalConfig()
print(config)

# Override the transformer settings to match the checkpoint being loaded
# (here: the 1-layer AttentionOnly "tiny" model).
update_dict = {
    'transformer': {
        'attention_only': True,
        'hidden_dim': 512,
        'num_heads': 8,
        'num_layers': 1,
    }
}
update_dataclass_from_dict(config, update_dict)

# Build the model and load the downloaded weights.
model = BaseViT(config)
model.load_state_dict(checkpoint['model_state_dict'])
```
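The update dictionary above configures the 1-layer AttentionOnly model. To load a different repository, `num_layers` and `attention_only` should match the corresponding table row. The sketch below builds such a dictionary; `make_update_dict` is a hypothetical helper, and it assumes (not confirmed by this card) that `hidden_dim=512` and `num_heads=8` hold for every size, so these should be checked against the training docs. A mismatch will typically make `load_state_dict` fail.

```python
from typing import Any, Dict

# Layers per size, copied from the checkpoint table.
NUM_LAYERS = {"tiny": 1, "base": 2, "small": 3, "medium": 4}


def make_update_dict(
    size: str,
    attention_only: bool,
    hidden_dim: int = 512,  # ASSUMPTION: taken from the tiny example above
    num_heads: int = 8,  # ASSUMPTION: taken from the tiny example above
) -> Dict[str, Any]:
    """Build the `update_dict` for one cell of the checkpoint table (hypothetical helper)."""
    if size not in NUM_LAYERS:
        raise ValueError(f"unknown size {size!r}; expected one of {sorted(NUM_LAYERS)}")
    return {
        "transformer": {
            "attention_only": attention_only,
            "hidden_dim": hidden_dim,
            "num_heads": num_heads,
            "num_layers": NUM_LAYERS[size],
        }
    }
```

For example, `update_dataclass_from_dict(config, make_update_dict("base", attention_only=False))` would prepare the config for the 2-layer Attention-and-MLP model, provided the assumed widths are correct.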
---
license: mit
---