---
license: mit
base_model:
- timm/deit_small_patch16_224.fb_in1k
- timm/deit_tiny_patch16_224.fb_in1k
- timm/cait_xxs24_224.fb_dist_in1k
metrics:
- accuracy
tags:
- Interpretability
- ViT
- Classification
- XAI
---

# ProtoViT: Interpretable Vision Transformer with Adaptive Prototype Learning

This repository contains pretrained ProtoViT models for interpretable image classification, as described in our paper "Interpretable Image Classification with Adaptive Prototype-based Vision Transformers".

## Model Description

[ProtoViT](https://github.com/Henrymachiyu/ProtoViT) combines Vision Transformers with prototype-based learning to create models that are both highly accurate and interpretable. Rather than functioning as a black box, ProtoViT learns interpretable prototypes that explain its classification decisions through visual similarities.

### Supported Architectures

We provide three variants of ProtoViT:

- **ProtoViT-T**: Built on the DeiT-Tiny backbone
- **ProtoViT-S**: Built on the DeiT-Small backbone
- **ProtoViT-CaiT**: Built on the CaiT-XXS24 backbone

## Performance

All models were trained and evaluated on the CUB-200-2011 fine-grained bird species classification dataset.

| Model Version | Backbone | Resolution | Top-1 Accuracy | Checkpoint |
|---------------|----------|------------|----------------|------------|
| ProtoViT-T | DeiT-Tiny | 224×224 | 83.36% | [Download](https://huggingface.co/chiyum609/ProtoViT/blob/main/DeiT_Tiny_finetuned0.8336.pth) |
| ProtoViT-S | DeiT-Small | 224×224 | 85.30% | [Download](https://huggingface.co/chiyum609/ProtoViT/blob/main/DeiT_Small_finetuned0.8530.pth) |
| ProtoViT-CaiT | CaiT-XXS24 | 224×224 | 86.02% | [Download](https://huggingface.co/chiyum609/ProtoViT/blob/main/CaiT_xxs24_224_finetuned0.8602.pth) |

## Features

- 🔍 **Interpretable Decisions**: The model classifies images with self-explanatory reasoning based on the input's similarity to learned prototypes, which capture the key features of each class.
- 🎯 **High Accuracy**: Achieves competitive performance on fine-grained classification tasks - πŸš€ **Multiple Architectures**: Supports various Vision Transformer backbones - πŸ“Š **Analysis Tools**: Comes with tools for both local and global prototype analysis ## Requirements - Python 3.8+ - PyTorch 1.8+ - timm==0.4.12 - torchvision - numpy - pillow ## Limitations and Bias - Data Bias: These models are trained on CUB-200-2011, which may not generalize well to images outside this dataset. - Resolution Constraints: The models are trained at a resolution of 224Γ—224; higher or lower resolutions may impact performance. - Location Misalignment: Same as the CNN based models, these models are not perfectly immune to location misalignment under adversarial attack. ## Citation If you use this model in your research, please cite: ```bibtex @article{ma2024interpretable, title={Interpretable Image Classification with Adaptive Prototype-based Vision Transformers}, author={Ma, Chiyu and Donnelly, Jon and Liu, Wenjun and Vosoughi, Soroush and Rudin, Cynthia and Chen, Chaofan}, journal={arXiv preprint arXiv:2410.20722}, year={2024} } ``` ## Acknowledgements This implementation builds upon the following excellent repositories: - [DeiT](https://github.com/facebookresearch/deit) - [CaiT](https://github.com/facebookresearch/deit) - [ProtoPNet](https://github.com/cfchen-duke/ProtoPNet) ## License This project is released under [MIT] license. ## Contact For any questions or feedback, please: 1. Open an issue in the GitHub repository 2. Contact [Chiyu.ma.gr@dartmouth.edu]