VGG-like Kolmogorov-Arnold Convolutional network with Gram polynomials
This model is a convolutional version of the Kolmogorov-Arnold Network (KAN) with a VGG-11-like architecture, pretrained on the Imagenet1k dataset. KANs were originally presented in [1, 2]. The Gram version of KAN was originally presented in [3]. For more details, visit our torch-conv-kan repository on GitHub.
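The following is a minimal pure-Python sketch of a single KAN edge function built on a Gram-style polynomial basis. It assumes the three-term recurrence p_0(x) = 1, p_1(x) = x, p_{k+1}(x) = x * p_k(x) - beta_k * p_{k-1}(x), with inputs squashed into (-1, 1) by tanh. In [3] the beta_k coefficients are learnable. Here they, and the weights w_k, are arbitrary illustrative numbers. The names `gram_basis` and `edge_function` are hypothetical. This is not the code from [3] or from torch-conv-kan.

```python
import math
from typing import List, Sequence


def gram_basis(x: float, degree: int, betas: Sequence[float]) -> List[float]:
    """Evaluate p_0..p_degree at x with the (assumed) recurrence
    p_0 = 1, p_1 = x, p_{k+1} = x * p_k - betas[k] * p_{k-1}.

    betas[k] is the coefficient used to build p_{k+1}; betas[0] is unused.
    """
    basis = [1.0]
    if degree == 0:
        return basis
    basis.append(x)
    for k in range(1, degree):
        basis.append(x * basis[k] - betas[k] * basis[k - 1])
    return basis


def edge_function(x: float, weights: Sequence[float], betas: Sequence[float]) -> float:
    """phi(x) = sum_k w_k * p_k(tanh(x)); len(weights) - 1 is the polynomial degree."""
    degree = len(weights) - 1
    # tanh keeps the polynomial argument inside (-1, 1) (an assumption of this sketch).
    basis = gram_basis(math.tanh(x), degree, betas)
    return sum(w * p for w, p in zip(weights, basis))


# Degree 5, as in this model: six basis functions p_0..p_5 per edge.
weights = [0.1, -0.3, 0.2, 0.05, -0.1, 0.4]  # illustrative, not trained values
betas = [0.0, 0.25, 0.3, 0.3, 0.3]           # betas[1..4] build p_2..p_5
print(len(gram_basis(0.5, 5, betas)))  # -> 6
print(edge_function(0.0, weights, betas))
```

In a KAN, every input-output connection carries its own learnable function like `phi` instead of a scalar weight. A convolutional KAN applies this idea across the channels of each kernel window.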
Model description
The model consists of 10 consecutive Gram ConvKAN (KAGN) layers with InstanceNorm2d and polynomial degree 5, followed by global average pooling and a linear classification head:
- KAGN Convolution, 32 filters, 3x3
- Max pooling, 2x2
- KAGN Convolution, 64 filters, 3x3
- Max pooling, 2x2
- KAGN Convolution, 128 filters, 3x3
- KAGN Convolution, 128 filters, 3x3
- Max pooling, 2x2
- KAGN Convolution, 256 filters, 3x3
- KAGN Convolution, 256 filters, 3x3
- Max pooling, 2x2
- KAGN Convolution, 256 filters, 3x3
- KAGN Convolution, 256 filters, 3x3
- Max pooling, 2x2
- KAGN Convolution, 256 filters, 3x3
- KAGN Convolution, 256 filters, 3x3
- Global Average pooling
- Output layer, 1000 nodes.
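As a sanity check on the layer list, the sketch below tracks the spatial size of a 224x224 input through the stack. It assumes that each 3x3 KAGN convolution uses stride 1 and padding 1, which preserves height and width, and that each 2x2 max pooling uses stride 2. The card does not state these settings. The filter counts are copied from the list as written. `LAYERS` and `trace_shapes` are hypothetical names.

```python
# Layer list as written in the model card: ("conv", filters) or ("pool", None).
LAYERS = [
    ("conv", 32), ("pool", None),
    ("conv", 64), ("pool", None),
    ("conv", 128), ("conv", 128), ("pool", None),
    ("conv", 256), ("conv", 256), ("pool", None),
    ("conv", 256), ("conv", 256), ("pool", None),
    ("conv", 256), ("conv", 256),
]


def trace_shapes(size: int = 224, in_channels: int = 3):
    """Return (channels, height, width) after every layer.

    Assumes 3x3 convs with stride 1 and padding 1 (size-preserving)
    and 2x2 max pooling with stride 2 (floor-halves the size).
    """
    channels, h, w = in_channels, size, size
    shapes = []
    for kind, filters in LAYERS:
        if kind == "conv":
            # out = (in + 2 * padding - kernel) // stride + 1, which equals in for k=3, p=1, s=1
            channels = filters
            h = (h + 2 * 1 - 3) // 1 + 1
            w = (w + 2 * 1 - 3) // 1 + 1
        else:
            h, w = h // 2, w // 2
        shapes.append((channels, h, w))
    return shapes


shapes = trace_shapes()
print(sum(kind == "conv" for kind, _ in LAYERS))  # -> 10 convolutional layers
print(shapes[-1])  # -> (256, 7, 7)
```

Under these assumptions, the last feature map is 256x7x7. Global average pooling reduces it to a 256-dimensional vector, which the linear head maps to the 1000 output nodes.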
Intended uses & limitations
You can use the raw model for image classification or as a pretrained model for further fine-tuning.
How to use
First, clone the repository and install its requirements:
```bash
git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt
```
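Then initialize the model and load the pretrained weights:
```python
import torch
from models import vggkagn

# Model configuration used for this checkpoint.
model = vggkagn(
    3,     # input channels (RGB)
    1000,  # number of output classes (Imagenet1k)
    groups=1,
    degree=5,
    dropout=0.15,
    l1_decay=0,
    dropout_linear=0.25,
    width_scale=2,
    vgg_type='VGG11v2',
    expected_feature_shape=(1, 1),
    affine=True,
)

# from_pretrained is assumed to be the huggingface_hub PyTorchModelHubMixin
# classmethod, which returns a new model with the weights loaded,
# so keep the returned object.
model = model.from_pretrained('brivangl/vgg_kagn11_v2')
model.eval()  # disable dropout for inference
```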
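Transforms used for validation on Imagenet1k:
```python
import torch
from torchvision.transforms import v2

transforms_val = v2.Compose([
    v2.ToImage(),                           # PIL image / ndarray -> image tensor
    v2.Resize(256, antialias=True),         # shorter side -> 256 px
    v2.CenterCrop(224),                     # central 224x224 window
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])
```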
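The validation pipeline can be made concrete by working through its geometry and arithmetic for a single image in pure Python. This sketch assumes torchvision's behavior as we understand it:

- `Resize(256)` with an integer size scales the shorter edge to 256 and keeps the aspect ratio.
- `CenterCrop(224)` takes the central 224x224 window.
- `ToDtype(..., scale=True)` divides uint8 values by 255.
- `Normalize` computes (x - mean) / std per channel.

The helper names are hypothetical. The longer edge is truncated to an integer here, which may differ by a pixel from the library in edge cases.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def resize_then_center_crop(width: int, height: int, resize: int = 256, crop: int = 224):
    """Return the resized (w, h) and the (left, top) offset of the center crop."""
    if width <= height:
        new_w, new_h = resize, int(resize * height / width)
    else:
        new_w, new_h = int(resize * width / height), resize
    # Offset of the central crop window inside the resized image.
    left = int(round((new_w - crop) / 2.0))
    top = int(round((new_h - crop) / 2.0))
    return (new_w, new_h), (left, top)


def normalize_pixel(rgb_uint8):
    """Apply ToDtype(float32, scale=True) followed by Normalize to one RGB pixel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb_uint8, MEAN, STD))


print(resize_then_center_crop(640, 480))  # -> ((341, 256), (58, 16))
print(normalize_pixel((255, 255, 255)))
```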
Training data
This model was trained on the Imagenet1k dataset (1,281,167 images in the training set).
Training procedure
The model was trained for 200 full epochs with the AdamW optimizer, using the following parameters:
```python
{'learning_rate': 0.0009, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 5e-06,
 'adam_epsilon': 1e-08, 'lr_warmup_steps': 7500, 'lr_power': 0.3, 'lr_end': 1e-07, 'set_grads_to_none': False}
```
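The parameter names `lr_warmup_steps`, `lr_power`, and `lr_end` suggest a linear warmup followed by polynomial decay, similar to `get_polynomial_decay_schedule_with_warmup` in Hugging Face Transformers. The card does not confirm this, so the pure-Python sketch below is only one plausible reading. It also needs the total number of optimizer steps, which depends on a batch size the card does not give. `BATCH_SIZE` below is a hypothetical value chosen purely for illustration.

```python
import math

# Values from the training configuration above.
LR_INIT = 0.0009
LR_END = 1e-07
LR_POWER = 0.3
WARMUP_STEPS = 7500
EPOCHS = 200
TRAIN_IMAGES = 1_281_167

# Hypothetical: the card does not state the batch size.
BATCH_SIZE = 256


def learning_rate(step: int, total_steps: int) -> float:
    """Linear warmup to LR_INIT, then polynomial decay to LR_END (assumed schedule)."""
    if step < WARMUP_STEPS:
        return LR_INIT * step / max(1, WARMUP_STEPS)
    if step > total_steps:
        return LR_END
    remaining = 1.0 - (step - WARMUP_STEPS) / (total_steps - WARMUP_STEPS)
    return (LR_INIT - LR_END) * remaining ** LR_POWER + LR_END


steps_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)
total_steps = EPOCHS * steps_per_epoch
for step in (0, WARMUP_STEPS // 2, WARMUP_STEPS, total_steps // 2, total_steps):
    print(step, f"{learning_rate(step, total_steps):.3e}")
```

With a power of 0.3, the decay stays close to the peak rate for most of training and drops steeply only near the end. If this reading is right, the optimizer itself would be `torch.optim.AdamW(model.parameters(), lr=0.0009, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-06)`. The `set_grads_to_none: False` entry most likely corresponds to calling `optimizer.zero_grad(set_to_none=False)`.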
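The following augmentations were applied during training:
```python
import torch
from torchvision.transforms import AutoAugmentPolicy, v2

transforms_train = v2.Compose([
    v2.ToImage(),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomResizedCrop(224, antialias=True),
    # Apply one of the two AutoAugment policies, picked at random per image.
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
                     v2.AutoAugment(AutoAugmentPolicy.IMAGENET)]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```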
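Most of the geometric variation in this pipeline comes from `RandomResizedCrop(224)`. The sketch below reimplements, in pure Python, the box-sampling procedure that torchvision's implementation uses (as we understand it), with the default `scale=(0.08, 1.0)` and `ratio=(3/4, 4/3)`. It picks a random area fraction and a log-uniform aspect ratio and retries up to 10 times if the box does not fit. If no attempt fits, it falls back to a central crop. The function name is hypothetical, and this is an illustration, not torchvision's code. The chosen box is then resized to 224x224.

```python
import math
import random


def sample_resized_crop_box(width, height, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), rng=random):
    """Return (top, left, h, w) of a random crop box, RandomResizedCrop-style."""
    area = width * height
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(10):
        target_area = area * rng.uniform(scale[0], scale[1])
        aspect = math.exp(rng.uniform(log_ratio[0], log_ratio[1]))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            top = rng.randint(0, height - h)
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Fallback: central crop clamped to the allowed aspect-ratio range.
    in_ratio = width / height
    if in_ratio < min(ratio):
        w = width
        h = int(round(w / min(ratio)))
    elif in_ratio > max(ratio):
        h = height
        w = int(round(h * max(ratio)))
    else:
        w, h = width, height
    return (height - h) // 2, (width - w) // 2, h, w


rng = random.Random(0)
for _ in range(3):
    print(sample_resized_crop_box(640, 480, rng=rng))
```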
Evaluation results
On Imagenet1k Validation:
| Top-1 accuracy (%) | Top-5 accuracy (%) | AUC, one-vs-one (%) | AUC, one-vs-rest (%) |
|---|---|---|---|
| 59.1 | 82.29 | 99.43 | 99.43 |
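Top-1 accuracy counts a prediction as correct when the true class has the highest score. Top-5 accuracy counts it as correct when the true class is among the five highest scores. A minimal pure-Python sketch of that computation follows. It is illustrative and not the authors' evaluation script. For the AUC columns, one common choice is scikit-learn's `roc_auc_score` with `multi_class='ovo'` or `multi_class='ovr'` on softmax probabilities. The card does not say how its AUC values were computed.

```python
from typing import Sequence


def top_k_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Percentage of samples whose true label is among the k highest scores."""
    correct = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores for this sample (ties keep index order).
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        correct += label in top_k
    return 100.0 * correct / len(labels)


# Toy example with 3 samples and 4 classes.
scores = [
    [0.1, 0.6, 0.2, 0.1],  # true class 1 ranked first
    [0.5, 0.1, 0.3, 0.1],  # true class 2 ranked second
    [0.7, 0.1, 0.1, 0.1],  # true class 3 tied for last
]
labels = [1, 2, 3]
print(top_k_accuracy(scores, labels, k=1))
print(top_k_accuracy(scores, labels, k=2))
```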
On Imagenet1k Test: Coming soon
BibTeX entry and citation info
If you use this project in your research or wish to refer to the baseline results, please use the following BibTeX entry.
```bibtex
@misc{torch-conv-kan,
  author = {Ivan Drokin},
  title = {Torch Conv KAN},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/IvanDrokin/torch-conv-kan}}
}
```
References
- [1] Ziming Liu et al., "KAN: Kolmogorov-Arnold Networks", 2024, arXiv. https://arxiv.org/abs/2404.19756
- [2] pykan, the reference implementation of KANs, GitHub. https://github.com/KindXiaoming/pykan
- [3] GRAMKAN, GitHub. https://github.com/Khochawongwat/GRAMKAN