How to use ViT MAE for image classification?

My use case is to use the pre-trained MAE weights for image classification on my own custom datasets.

For the ViT model, there are implementations for both ViTModel (which returns raw hidden-states without any specific head on top) and ViTForImageClassification (which adds a linear layer on top of the final hidden state of the [CLS] token).

MAE has been considered the state-of-the-art image classification model among models trained without external data. However, Hugging Face only provides ViTMAEModel, which returns raw hidden states without any specific head on top; it provides neither ViTMAEForImageClassification nor ViTMAEModelForImageClassification. Is there a particular reason for this?

One explanation I could think of is that MAE shares the same model architecture as ViT, so one could do the following to fulfil my use case:

from transformers import AutoImageProcessor, ViTForImageClassification

image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
model = ViTForImageClassification.from_pretrained("facebook/vit-mae-base")

Is the above code the correct way to use the pre-trained weights of MAE to do image classification?

If yes, I wonder what the intended use cases of ViTMAEModel are?

For the MAE pretraining, we already have ViTMAEForPreTraining, so ViTMAEModel is not used for pre-training.

@nielsr @sayakpaul Could you please help me?


MAE has been considered the state-of-the-art image classification model among models trained without external data. However, Hugging Face only provides ViTMAEModel, which returns raw hidden states without any specific head on top; it provides neither ViTMAEForImageClassification nor ViTMAEModelForImageClassification. Is there a particular reason for this?

It’s likely because the checkpoint did not include classification head parameters when the model was open-sourced.

One explanation I could think of is that MAE shares the same model architecture as ViT, so one could do the following to fulfil my use case:

If you do the above, the classification head parameters will be randomly initialized, which is likely what you want when fine-tuning.
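To make the point about the randomly initialized head concrete, here is a minimal sketch. It assumes a hypothetical tiny ViTMAEConfig (so it runs offline without downloading "facebook/vit-mae-base") and a temporary directory standing in for the real checkpoint. The idea is that the encoder weights stored under the `vit.` prefix should be picked up by ViTForImageClassification, the MAE decoder weights are ignored, and the classifier layer is created fresh.

```python
import tempfile

import torch
from transformers import ViTForImageClassification, ViTMAEConfig, ViTMAEForPreTraining

# Hypothetical tiny configuration, chosen only so the sketch runs quickly
# and offline. The real "facebook/vit-mae-base" checkpoint is much larger.
tiny_config = ViTMAEConfig(
    image_size=32,
    patch_size=8,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    decoder_hidden_size=16,
    decoder_num_hidden_layers=1,
    decoder_num_attention_heads=2,
    decoder_intermediate_size=32,
)
mae = ViTMAEForPreTraining(tiny_config)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    # Stand-in for the pre-trained MAE checkpoint.
    mae.save_pretrained(checkpoint_dir)
    # Encoder weights ("vit.*") are reused, decoder weights are dropped, and
    # the classification head is newly (randomly) initialized. Expect warnings
    # about the model type mismatch and about unused / missing weights.
    classifier = ViTForImageClassification.from_pretrained(checkpoint_dir, num_labels=3)

# Compare one encoder tensor to confirm it was copied from the MAE checkpoint.
key = "vit.embeddings.patch_embeddings.projection.weight"
encoder_reused = torch.equal(mae.state_dict()[key], classifier.state_dict()[key])

classifier.eval()
with torch.no_grad():
    logits = classifier(pixel_values=torch.randn(1, 3, 32, 32)).logits

print("encoder reused:", encoder_reused)
print("logits shape:", tuple(logits.shape))
```

With the real checkpoint you would pass "facebook/vit-mae-base" instead of the temporary directory and set num_labels to match your dataset. The head still needs fine-tuning before its predictions mean anything.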

If yes, I wonder what the intended use cases of ViTMAEModel are?

It’s useful for feature extraction. See this notebook, where I used a similar model class called TFData2VecVisionModel: https://github.com/sayakpaul/TF-2.0-Hacks/blob/master/data2vec_vision_image_classification.ipynb


Thanks a lot for your reply!

It’s useful for feature extraction. See this notebook, where I used a similar model class called TFData2VecVisionModel: https://github.com/sayakpaul/TF-2.0-Hacks/blob/master/data2vec_vision_image_classification.ipynb

For this use case (feature extraction), I am wondering whether it is sufficient to use ViTModel as shown below. Is that correct?

If yes, then ViTMAEModel still does not have a unique use case of its own, but it can be seen as a utility class for building ViTMAEForPreTraining.

from transformers import ViTModel
model = ViTModel.from_pretrained("facebook/vit-mae-base")
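As a rough illustration of feature extraction with a plain ViTModel backbone, the sketch below again assumes a hypothetical tiny ViTMAEConfig and a temporary directory in place of "facebook/vit-mae-base", so that it runs offline. Because ViTModel applies no masking, every patch shows up in the output. The CLS feature and the mean-pooled patch feature are two common choices, not something prescribed in this thread.

```python
import tempfile

import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining, ViTModel

# Hypothetical tiny configuration so the sketch runs without a download.
tiny_config = ViTMAEConfig(
    image_size=32,
    patch_size=8,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    decoder_hidden_size=16,
    decoder_num_hidden_layers=1,
    decoder_num_attention_heads=2,
    decoder_intermediate_size=32,
)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    # Stand-in for the pre-trained MAE checkpoint.
    ViTMAEForPreTraining(tiny_config).save_pretrained(checkpoint_dir)
    # Loads the encoder weights only. ViTModel's default pooler layer has no
    # counterpart in the MAE checkpoint, so it is randomly initialized; only
    # last_hidden_state is used below.
    backbone = ViTModel.from_pretrained(checkpoint_dir)

backbone.eval()
with torch.no_grad():
    hidden = backbone(pixel_values=torch.randn(2, 3, 32, 32)).last_hidden_state

cls_features = hidden[:, 0]                  # one vector per image from the [CLS] token
patch_features = hidden[:, 1:].mean(dim=1)   # mean over all patch tokens

print(tuple(hidden.shape), tuple(cls_features.shape), tuple(patch_features.shape))
```

One difference worth keeping in mind: ViTMAEModel masks patches according to config.mask_ratio, so for feature extraction over the full image it would presumably need mask_ratio set to 0.0, whereas ViTModel never masks.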

Hello! Do you have any update about your idea?

ViTMAEModel returns the [CLS] token in addition to the patch embeddings, while ViTMAEForPreTraining does not. You can check this by inspecting:
outputs.last_hidden_state.shape
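The shape check can be sketched as follows, assuming a hypothetical tiny ViTMAEConfig with randomly initialized weights (no download needed). The encoder output should contain one [CLS] token plus only the visible (unmasked) patches, while the pre-training model returns per-patch pixel reconstructions in `logits` with no [CLS] position.

```python
import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining, ViTMAEModel

# Hypothetical tiny configuration; weights are random, only shapes matter here.
config = ViTMAEConfig(
    image_size=32,
    patch_size=8,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    decoder_hidden_size=16,
    decoder_num_hidden_layers=1,
    decoder_num_attention_heads=2,
    decoder_intermediate_size=32,
    mask_ratio=0.75,
)
pixel_values = torch.randn(1, 3, 32, 32)

encoder = ViTMAEModel(config).eval()
pretrainer = ViTMAEForPreTraining(config).eval()
with torch.no_grad():
    encoder_out = encoder(pixel_values=pixel_values)
    pretrain_out = pretrainer(pixel_values=pixel_values)

num_patches = (config.image_size // config.patch_size) ** 2
num_visible = int(num_patches * (1 - config.mask_ratio))

# Encoder: [CLS] token + visible patches, each of size hidden_size.
encoder_shape = tuple(encoder_out.last_hidden_state.shape)
# Pre-training head: one reconstruction per patch (patch_size^2 * channels values).
logits_shape = tuple(pretrain_out.logits.shape)

print("ViTMAEModel last_hidden_state:", encoder_shape, "-> 1 +", num_visible, "tokens")
print("ViTMAEForPreTraining logits:", logits_shape, "->", num_patches, "patches")
```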
