Image Classification
mlx-image
Safetensors
MLX
vision
riccardomusmeci commited on
Commit
e0ba3e9
·
verified ·
1 Parent(s): 3986eb9

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +54 -59
README.md CHANGED
@@ -1,81 +1,76 @@
1
  ---
2
- {}
 
 
 
 
 
 
 
 
3
  ---
 
4
 
5
- ---
6
- license: apache-2.0
7
- tags:
8
- - mlx
9
- - mlx-image
10
- - vision
11
- - image-classification
12
- datasets:
13
- - imagenet-1k
14
- library_name: mlx-image
15
- ---
16
- # vit_small_patch16_224.dino
17
 
18
- A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).
19
 
20
- The model was trained in self-supervised fashion on ImageNet-1k dataset. No classification head was trained, only the backbone.
21
 
22
- Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
 
 
23
 
24
- <div align="center">
25
- <img width="100%" alt="DINO illustration" src="dino.gif">
26
- </div>
27
 
 
 
 
 
28
 
29
- ## How to use
30
- ```bash
31
- pip install mlx-image
32
- ```
33
 
34
- Here is how to use this model for image classification:
 
 
 
35
 
36
- ```python
37
- from mlxim.model import create_model
38
- from mlxim.io import read_rgb
39
- from mlxim.transform import ImageNetTransform
40
 
41
- transform = ImageNetTransform(train=False, img_size=224)
42
- x = transform(read_rgb("cat.png"))
43
- x = mx.expand_dims(x, 0)
44
 
45
- model = create_model("vit_small_patch16_224.dino")
46
- model.eval()
47
 
48
- logits, attn_masks = model(x, attn_masks=True)
49
- ```
 
 
 
50
 
51
- You can also use the embeds from layer before head:
52
- ```python
53
- from mlxim.model import create_model
54
- from mlxim.io import read_rgb
55
- from mlxim.transform import ImageNetTransform
56
 
57
- transform = ImageNetTransform(train=False, img_size=512)
58
- x = transform(read_rgb("cat.png"))
59
- x = mx.expand_dims(x, 0)
60
 
61
- # first option
62
- model = create_model("vit_small_patch16_224.dino", num_classes=0)
63
- model.eval()
64
 
65
- embeds = model(x)
 
 
66
 
67
- # second option
68
- model = create_model("vit_small_patch16_224.dino")
69
- model.eval()
70
 
71
- embeds, attn_masks = model.get_features(x)
72
- ```
73
 
74
- ## Attention maps
75
- You can visualize the attention maps using the `attn_masks` returned by the model. Go check the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/blob/main/notebooks/dino_attention.ipynb).
 
76
 
77
- <div align="center">
78
- <img width="100%" alt="Attention Map" src="attention_maps.png">
79
- </div>
80
-
81
-
 
1
  ---
2
+ license: apache-2.0
3
+ tags:
4
+ - mlx
5
+ - mlx-image
6
+ - vision
7
+ - image-classification
8
+ datasets:
9
+ - imagenet-1k
10
+ library_name: mlx-image
11
  ---
12
+ # vit_small_patch16_224.dino
13
 
14
+ A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ The model was trained in self-supervised fashion on ImageNet-1k dataset. No classification head was trained, only the backbone.
17
 
18
+ Disclaimer: This is a porting of the torch model weights to Apple MLX Framework.
19
 
20
+ <div align="center">
21
+ <img width="100%" alt="DINO illustration" src="dino.gif">
22
+ </div>
23
 
 
 
 
24
 
25
+ ## How to use
26
+ ```bash
27
+ pip install mlx-image
28
+ ```
29
 
30
+ Here is how to use this model for image classification:
 
 
 
31
 
32
+ ```python
33
+ from mlxim.model import create_model
34
+ from mlxim.io import read_rgb
35
+ from mlxim.transform import ImageNetTransform
36
 
37
+ transform = ImageNetTransform(train=False, img_size=224)
38
+ x = transform(read_rgb("cat.png"))
39
+ x = mx.expand_dims(x, 0)
 
40
 
41
+ model = create_model("vit_small_patch16_224.dino")
42
+ model.eval()
 
43
 
44
+ logits, attn_masks = model(x, attn_masks=True)
45
+ ```
46
 
47
+ You can also use the embeds from layer before head:
48
+ ```python
49
+ from mlxim.model import create_model
50
+ from mlxim.io import read_rgb
51
+ from mlxim.transform import ImageNetTransform
52
 
53
+ transform = ImageNetTransform(train=False, img_size=512)
54
+ x = transform(read_rgb("cat.png"))
55
+ x = mx.expand_dims(x, 0)
 
 
56
 
57
+ # first option
58
+ model = create_model("vit_small_patch16_224.dino", num_classes=0)
59
+ model.eval()
60
 
61
+ embeds = model(x)
 
 
62
 
63
+ # second option
64
+ model = create_model("vit_small_patch16_224.dino")
65
+ model.eval()
66
 
67
+ embeds, attn_masks = model.get_features(x)
68
+ ```
 
69
 
70
+ ## Attention maps
71
+ You can visualize the attention maps using the `attn_masks` returned by the model. Go check the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/blob/main/notebooks/dino_attention.ipynb).
72
 
73
+ <div align="center">
74
+ <img width="100%" alt="Attention Map" src="attention_maps.png">
75
+ </div>
76