IamYash committed on
Commit e5445d2 · 1 Parent(s): 7cdb4b7

Create README.md

Files changed (1): README.md (+61, -0)
## Overview

This experiment analyzes the training dynamics of Vision Transformers under the [Prisma](https://github.com/soniajoseph/ViT-Prisma.git) project. Small Vision Transformers were trained and evaluated on shape classification with the dSprites dataset, which consists of procedurally generated 2D shapes governed by six independent latent factors. The specific task is to classify the three distinct shapes in dSprites using ViTs.
All of the training checkpoints are available on the Hugging Face Hub. They are summarised in the following table, with links to the models:

| Size | No. Layers | AttentionOnly | Attention-and-MLP |
|:---:|:---:|:---:|:---:|
| tiny | 1 | [link](https://huggingface.co/IamYash/dSprites-tiny-AttentionOnly) | [link](https://huggingface.co/IamYash/dSprites-tiny-Attention-and-MLP) |
| base | 2 | [link](https://huggingface.co/IamYash/dSprites-base-AttentionOnly) | [link](https://huggingface.co/IamYash/dSprites-base-Attention-and-MLP) |
| small | 3 | [link](https://huggingface.co/IamYash/dSprites-small-AttentionOnly) | [link](https://huggingface.co/IamYash/dSprites-small-Attention-and-MLP) |
| medium | 4 | [link](https://huggingface.co/IamYash/dSprites-medium-AttentionOnly) | [link](https://huggingface.co/IamYash/dSprites-medium-Attention-and-MLP) |

Each repo contains multiple intermediate checkpoints. Each checkpoint is stored as `"checkpoint_{i}.pth"`, where `i` is the number of training samples the model had seen at that point.
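
As a small illustration of this naming scheme (the filenames below are hypothetical, not a listing of any actual repo), the sample count can be parsed back out of a checkpoint name:

```python
import re

def samples_seen(filename):
    # Extract the training-sample count i from "checkpoint_{i}.pth".
    match = re.fullmatch(r"checkpoint_(\d+)\.pth", filename)
    return int(match.group(1)) if match else None

# Illustrative filenames only.
names = ["checkpoint_0.pth", "checkpoint_6400.pth", "checkpoint_128000.pth"]
counts = [samples_seen(n) for n in names]  # [0, 6400, 128000]
```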

Further details on training and results are described [here](https://github.com/soniajoseph/ViT-Prisma/tree/main/docs).

## How to Use

Clone and install Prisma (the `!`/`%` prefixes are IPython notebook syntax; in a plain shell, drop them and use `cd`):

```python
!git clone https://github.com/soniajoseph/ViT-Prisma
%cd ViT-Prisma
!pip install -e .
```
Download a checkpoint from the Hub and load it:

```python
from huggingface_hub import hf_hub_download
import torch

REPO_ID = "IamYash/dSprites-tiny-AttentionOnly"
FILENAME = "model_0.pth"

checkpoint = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
)
```
Build a matching model configuration and load the weights:

```python
from vit_prisma.models.base_vit import BaseViT
from vit_prisma.configs.DSpritesConfig import GlobalConfig
from vit_prisma.utils.wandb_utils import update_dataclass_from_dict

config = GlobalConfig()
print(config)

# Match the architecture of the downloaded checkpoint
# (here: the 1-layer attention-only "tiny" model).
update_dict = {
    'transformer': {
        'attention_only': True,
        'hidden_dim': 512,
        'num_heads': 8,
        'num_layers': 1
    }
}
update_dataclass_from_dict(config, update_dict)

model = BaseViT(config)
model.load_state_dict(checkpoint['model_state_dict'])
```
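
For readers unfamiliar with the nested-dict update used above, here is a minimal, self-contained sketch of the same pattern (these dataclasses are illustrative stand-ins, not Prisma's actual `GlobalConfig`):

```python
from dataclasses import dataclass, field

@dataclass
class TransformerConfig:
    attention_only: bool = False
    hidden_dim: int = 256
    num_heads: int = 4
    num_layers: int = 2

@dataclass
class Config:
    transformer: TransformerConfig = field(default_factory=TransformerConfig)

def update_from_dict(obj, updates):
    # Recursively copy dict entries onto matching dataclass attributes;
    # nested dicts descend into nested dataclasses.
    for key, value in updates.items():
        attr = getattr(obj, key)
        if isinstance(value, dict):
            update_from_dict(attr, value)
        else:
            setattr(obj, key, value)

cfg = Config()
update_from_dict(cfg, {'transformer': {'attention_only': True, 'num_layers': 1}})
# cfg.transformer.hidden_dim keeps its default of 256.
```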

---
license: mit
---