---
license: apache-2.0
tags:
- mlx
- mlx-image
- vision
- image-classification
datasets:
- imagenet-1k
library_name: mlx-image
---
# vit_small_patch16_224.dino

A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) backbone (ViT-S/16, 224×224 input) pretrained on the ImageNet-1k dataset with [DINO](https://arxiv.org/abs/2104.14294).

The model was trained in a self-supervised fashion on ImageNet-1k, so no labels were used. Only the backbone was trained; there is no trained classification head.

Disclaimer: these weights are a port of the PyTorch model weights to Apple's MLX framework.

<div align="center">
<img width="100%" alt="DINO illustration" src="dino.gif">
</div>
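To make "self-supervised" more concrete: DINO trains a student network to match the output distribution of a teacher network (an exponential moving average of the student) on different augmented views of the same image. The NumPy sketch below illustrates that objective for a single pair of views, following the paper's description of centering and sharpening the teacher output. It is not the training code behind these weights: it omits multi-crop, the projection head, and the teacher EMA update, and the helper names (`dino_loss`, `update_center`) are hypothetical. The temperatures (0.1 for the student, 0.04 for the teacher) and the center momentum (0.9) are assumed defaults from the official DINO setup.

```python
import numpy as np


def log_softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log-softmax."""
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def dino_loss(student_out, teacher_out, center, student_temp=0.1, teacher_temp=0.04):
    """Cross-entropy H(P_t, P_s) for one (teacher view, student view) pair.

    student_out, teacher_out: (batch, K) outputs for two different augmented views.
    center: (K,) running mean of teacher outputs, subtracted to avoid collapse.
    """
    # Teacher: centered, then sharpened with a low temperature (a fixed target, no gradient)
    p_t = np.exp(log_softmax((teacher_out - center) / teacher_temp))
    # Student: log-probabilities at a higher temperature
    log_p_s = log_softmax(student_out / student_temp)
    return float(-(p_t * log_p_s).sum(axis=-1).mean())


def update_center(center, teacher_out, momentum=0.9):
    """EMA of the batch-mean teacher output, used for centering."""
    return momentum * center + (1.0 - momentum) * teacher_out.mean(axis=0)


# Toy usage with random outputs over K = 16 hypothetical prototypes
rng = np.random.default_rng(0)
student = rng.normal(size=(8, 16))
teacher = rng.normal(size=(8, 16))
center = np.zeros(16)

loss = dino_loss(student, teacher, center)
center = update_center(center, teacher)
print(f"loss = {loss:.4f}")
```

The low teacher temperature sharpens the target while centering pushes it toward uniform; the paper describes this balance as what prevents the trivial collapsed solution.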


## How to use
```bash
pip install mlx-image
```

Here is how to run a forward pass with this model:

```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

# Preprocess the image and add a batch dimension
transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.png"))
x = mx.expand_dims(x, 0)

model = create_model("vit_small_patch16_224.dino")
model.eval()

# attn_masks=True also returns the attention maps
logits, attn_masks = model(x, attn_masks=True)
```
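`ImageNetTransform(train=False, ...)` prepares the image before the forward pass. As a rough illustration, assuming it follows the common ImageNet evaluation recipe (resize, center crop, scale to [0, 1], normalize with the standard ImageNet channel mean and std), the NumPy sketch below reproduces the crop and normalization steps only. The resize step is omitted, the exact mlx-image implementation has not been checked, and the channels-last (H, W, C) layout is an assumption based on MLX convolutions using channels-last inputs. `preprocess` and its helpers are hypothetical names.

```python
import numpy as np

# Standard ImageNet RGB channel statistics (assumed to match what ImageNetTransform uses)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def center_crop(img: np.ndarray, size: int) -> np.ndarray:
    """Take the central size x size window of an (H, W, C) image."""
    h, w = img.shape[:2]
    if h < size or w < size:
        raise ValueError("image is smaller than the crop; resize it first")
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size]


def normalize(img_uint8: np.ndarray) -> np.ndarray:
    """Scale uint8 RGB values to [0, 1], then standardize each channel."""
    x = img_uint8.astype(np.float32) / 255.0
    return (x - IMAGENET_MEAN) / IMAGENET_STD


def preprocess(img_uint8: np.ndarray, size: int = 224) -> np.ndarray:
    """Crop + normalize, then add a leading batch axis -> (1, size, size, 3)."""
    return np.expand_dims(normalize(center_crop(img_uint8, size)), 0)


# Toy usage on a random image that is assumed to be already resized
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(256, 320, 3)).astype(np.uint8)
batch = preprocess(image)
print(batch.shape, batch.dtype)  # (1, 224, 224, 3) float32
```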

You can also extract the embeddings from the layer before the head:
```python
import mlx.core as mx
from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

transform = ImageNetTransform(train=False, img_size=224)
x = transform(read_rgb("cat.png"))
x = mx.expand_dims(x, 0)

# Option 1: build the model without a head (num_classes=0), so the forward pass returns embeddings
model = create_model("vit_small_patch16_224.dino", num_classes=0)
model.eval()

embeds = model(x)

# Option 2: keep the default model and call get_features
model = create_model("vit_small_patch16_224.dino")
model.eval()

embeds, attn_masks = model.get_features(x)
```
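Frozen DINO embeddings are commonly evaluated with a weighted k-nearest-neighbour classifier, and the DINO paper reports k-NN accuracy on ImageNet. The NumPy sketch below works in that spirit. It assumes you have converted the `embeds` of a labelled reference set and of your query images to NumPy arrays (for example with `np.array(embeds)`), one 384-dimensional row per image; that per-image shape is an assumption about mlx-image's output. `knn_predict` is a hypothetical helper, not part of mlx-image, and k = 20 with temperature 0.07 are the settings commonly used for this evaluation.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit length so dot products become cosine similarities."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def knn_predict(query, bank, bank_labels, num_classes, k=20, temperature=0.07):
    """Weighted k-NN: each of the k nearest bank items votes exp(cos_sim / T) for its label."""
    q, b = l2_normalize(query), l2_normalize(bank)   # (Q, D), (N, D)
    sims = q @ b.T                                   # (Q, N) cosine similarities
    k = min(k, b.shape[0])
    idx = np.argsort(-sims, axis=1)[:, :k]           # indices of the k most similar items
    weights = np.exp(np.take_along_axis(sims, idx, axis=1) / temperature)
    votes = np.zeros((q.shape[0], num_classes))
    rows = np.arange(q.shape[0])[:, None]            # broadcast row index against (Q, k)
    np.add.at(votes, (rows, bank_labels[idx]), weights)
    return votes.argmax(axis=1)


# Toy usage: two synthetic clusters of 384-d vectors (ViT-S embedding width)
rng = np.random.default_rng(0)
dim = 384
mu = rng.normal(size=dim)
bank = np.concatenate([mu + 0.1 * rng.normal(size=(50, dim)),
                       -mu + 0.1 * rng.normal(size=(50, dim))])
labels = np.array([0] * 50 + [1] * 50)
queries = np.stack([mu, -mu]) + 0.1 * rng.normal(size=(2, dim))
print(knn_predict(queries, bank, labels, num_classes=2))  # [0 1]
```

Because the backbone has no trained head, a classifier like this (or a linear probe) is one way to turn the embeddings into class predictions without fine-tuning.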

## Attention maps
You can visualize the attention maps using the `attn_masks` returned by the model. See the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/blob/main/notebooks/dino_attention.ipynb) for an example.

<div align="center">
<img width="100%" alt="Attention Map" src="attention_maps.png">
</div>
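DINO-style attention maps are usually built from the last layer's attention of the CLS token to every patch token, one map per head, reshaped to the patch grid (224 / 16 = 14 patches per side, so 196 patch tokens plus 1 CLS token). The DINO paper's figures additionally threshold each map to keep 60% of the attention mass. The NumPy sketch below assumes you have already pulled one image's last-layer attention out of `attn_masks` as an array of shape (heads, tokens, tokens) with the CLS token first; that layout has not been checked against mlx-image, and `cls_attention_maps` and `keep_top_mass` are hypothetical helpers.

```python
import numpy as np


def cls_attention_maps(attn: np.ndarray, img_size: int = 224, patch_size: int = 16) -> np.ndarray:
    """Reshape one layer's attention, shaped (heads, T, T), into per-head CLS maps (heads, g, g)."""
    grid = img_size // patch_size                     # 224 // 16 = 14 patches per side
    num_tokens = grid * grid + 1                      # 196 patch tokens + 1 CLS token = 197
    if attn.shape[-1] != num_tokens:
        raise ValueError(f"expected {num_tokens} tokens, got {attn.shape[-1]}")
    cls_to_patches = attn[:, 0, 1:]                   # CLS query row, without the CLS key column
    return cls_to_patches.reshape(-1, grid, grid)


def keep_top_mass(maps: np.ndarray, mass: float = 0.6) -> np.ndarray:
    """Boolean mask of the most-attended patches holding at least `mass` of each head's attention."""
    heads = maps.shape[0]
    flat = maps.reshape(heads, -1)
    flat = flat / flat.sum(axis=1, keepdims=True)     # renormalize after dropping the CLS column
    order = np.argsort(flat, axis=1)                  # ascending
    cum = np.cumsum(np.take_along_axis(flat, order, axis=1), axis=1)
    keep_sorted = cum > (1.0 - mass)                  # the largest values covering `mass`
    mask = np.zeros_like(keep_sorted)
    np.put_along_axis(mask, order, keep_sorted, axis=1)  # undo the sort
    return mask.reshape(maps.shape)


# Toy usage with random row-stochastic attention: 6 heads (ViT-S), 197 tokens at 224x224
rng = np.random.default_rng(0)
scores = rng.normal(size=(6, 197, 197))
attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
maps = cls_attention_maps(attn)
mask = keep_top_mass(maps)
print(maps.shape, mask.shape)  # (6, 14, 14) (6, 14, 14)
```

Each 14×14 map (or mask) can then be upsampled by the patch size to overlay it on the 224×224 input.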