---
license: mit
base_model:
  - timm/deit_small_patch16_224.fb_in1k
  - timm/deit_tiny_patch16_224.fb_in1k
  - timm/cait_xxs24_224.fb_dist_in1k
metrics:
  - accuracy
tags:
  - Interpretability
  - ViT
  - Classification
  - XAI
---

# ProtoViT: Interpretable Vision Transformer with Adaptive Prototype Learning

This repository contains pretrained ProtoViT models for interpretable image classification, as described in our paper "Interpretable Image Classification with Adaptive Prototype-based Vision Transformers".

## Model Description

ProtoViT combines Vision Transformers with prototype-based learning to create models that are both highly accurate and interpretable. Rather than functioning as a black box, ProtoViT learns interpretable prototypes that explain its classification decisions through visual similarities.

## Supported Architectures

We provide three variants of ProtoViT:

- **ProtoViT-T**: Built on the DeiT-Tiny backbone
- **ProtoViT-S**: Built on the DeiT-Small backbone
- **ProtoViT-CaiT**: Built on the CaiT-XXS24 backbone
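All three backbones listed above use 16×16 patches at 224×224 input resolution, so the prototype layer operates on a 14×14 grid of patch tokens. A quick arithmetic sanity check:

```python
# Patch-grid arithmetic shared by DeiT-Tiny, DeiT-Small, and CaiT-XXS24
# at the 224x224 training resolution.
image_size, patch_size = 224, 16
grid = image_size // patch_size     # patches per side -> 14
num_patch_tokens = grid * grid      # tokens the prototypes match against -> 196
print(grid, num_patch_tokens)       # 14 196
```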

## Performance

All models were trained and evaluated on the CUB-200-2011 fine-grained bird species classification dataset.

| Model Version | Backbone | Resolution | Top-1 Accuracy | Checkpoint |
|---------------|----------|------------|----------------|------------|
| ProtoViT-T    | DeiT-Tiny  | 224×224 | 83.36% | Download |
| ProtoViT-S    | DeiT-Small | 224×224 | 85.30% | Download |
| ProtoViT-CaiT | CaiT-XXS24 | 224×224 | 86.02% | Download |

## Features

- 🔍 **Interpretable Decisions**: The model classifies with self-explanatory reasoning based on the input's similarity to learned prototypes, the key features of each class
- 🎯 **High Accuracy**: Achieves competitive performance on fine-grained classification tasks
- 🚀 **Multiple Architectures**: Supports various Vision Transformer backbones
- 📊 **Analysis Tools**: Comes with tools for both local and global prototype analysis
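The prototype-based reasoning described above can be sketched in plain NumPy. This is an illustrative toy, not the official ProtoViT implementation: the shapes, prototype count, cosine similarity, and max-pooling aggregation are all assumptions chosen to show the general prototype-classification pattern.

```python
# Toy sketch of prototype-based classification (illustrative only):
# compare patch tokens to learned prototypes, pool each prototype's
# best match, then map prototype evidence to class logits.
import numpy as np

rng = np.random.default_rng(0)

num_patches, dim = 196, 192              # 14x14 patch tokens, DeiT-Tiny width
num_prototypes, num_classes = 2000, 200  # e.g. 10 prototypes per CUB class

tokens = rng.standard_normal((num_patches, dim))        # patch embeddings
prototypes = rng.standard_normal((num_prototypes, dim)) # learned prototypes

# Cosine similarity between every patch token and every prototype
t = tokens / np.linalg.norm(tokens, axis=1, keepdims=True)
p = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
sim = t @ p.T                            # (196, 2000)

# Each prototype's evidence is its best-matching patch (max over patches);
# the argmax patch is what makes the decision visually explainable.
evidence = sim.max(axis=0)               # (2000,)

# A linear layer turns prototype evidence into class logits
w = rng.standard_normal((num_prototypes, num_classes))
logits = evidence @ w                    # (200,)
print(logits.argmax())
```

Because each logit is a weighted sum of per-prototype evidence, the prediction can be decomposed into "this image region looks like that prototype" explanations.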

## Requirements

- Python 3.8+
- PyTorch 1.8+
- timm==0.4.12
- torchvision
- numpy
- pillow
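A minimal environment setup matching the list above might look like this (only the timm pin comes from the requirements; the torch version bound is an assumption, and you may need a CUDA-specific PyTorch wheel instead):

```shell
pip install "torch>=1.8" torchvision "timm==0.4.12" numpy pillow
```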

## Limitations and Bias

- **Data Bias**: These models are trained on CUB-200-2011 and may not generalize well to images outside this dataset.
- **Resolution Constraints**: The models are trained at a resolution of 224×224; higher or lower resolutions may impact performance.
- **Location Misalignment**: Like CNN-based prototype models, these models are not perfectly immune to prototype location misalignment under adversarial attacks.

## Citation

If you use this model in your research, please cite:

```bibtex
@article{ma2024interpretable,
  title={Interpretable Image Classification with Adaptive Prototype-based Vision Transformers},
  author={Ma, Chiyu and Donnelly, Jon and Liu, Wenjun and Vosoughi, Soroush and Rudin, Cynthia and Chen, Chaofan},
  journal={arXiv preprint arXiv:2410.20722},
  year={2024}
}
```

## Acknowledgements

This implementation builds upon the following excellent repositories:

## License

This project is released under the MIT license.

## Contact

For any questions or feedback, please:

1. Open an issue in the GitHub repository
2. Contact [[email protected]]