gmastrapas commited on
Commit
56fe6da
·
1 Parent(s): 949a53f

feat: initial commit

Browse files
Files changed (8) hide show
  1. README.md +14 -0
  2. configuration_clip.py +300 -0
  3. eva_model.py +763 -0
  4. hf_model.py +425 -0
  5. modeling_clip.py +317 -0
  6. processing_clip.py +66 -0
  7. rope_embeddings.py +165 -0
  8. transform.py +458 -0
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Jina CLIP
2
+
3
+ The Jina CLIP implementation is hosted in this repository. The model uses:
4
+ * the EVA 02 architecture for the vision tower
5
+ * the Jina BERT with Flash Attention model as a text tower
6
+
7
+ To use the Jina CLIP model, the following packages are required:
8
+ * `torch`
9
+ * `timm`
10
+ * `transformers`
11
+ * `einops`
12
+ * `xformers` to use x-attention
13
+ * `flash-attn` to use flash attention
14
+ * `apex` to use fused layer normalization
configuration_clip.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Code mainly copied from:
4
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/configuration_clip.py
5
+ # and adjusted for Jina CLIP
6
+
7
+ import os
8
+ from copy import deepcopy
9
+ from typing import Any, Dict, Optional, Union
10
+
11
+ from transformers import PretrainedConfig, logging
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ """ Jina CLIP model configuration """
17
+
18
+
19
+ class JinaCLIPTextConfig(PretrainedConfig):
20
+ model_type = 'jina_clip_text'
21
+
22
+ def __init__(
23
+ self,
24
+ embed_dim: int = 768,
25
+ hf_model_name_or_path: str = 'jinaai/jina-bert-v2-base-en-flash',
26
+ hf_model_config_kwargs: Optional[Dict[str, Any]] = None,
27
+ pooler_type: Optional[str] = None,
28
+ proj_type: Optional[str] = None,
29
+ proj_bias: bool = False,
30
+ **kwargs,
31
+ ):
32
+ super().__init__(**kwargs)
33
+
34
+ self.embed_dim = embed_dim
35
+ self.hf_model_name_or_path = hf_model_name_or_path
36
+ self.hf_model_config_kwargs = hf_model_config_kwargs or {}
37
+ self.pooler_type = pooler_type
38
+ self.proj_type = proj_type
39
+ self.proj_bias = proj_bias
40
+
41
+ @classmethod
42
+ def from_pretrained(
43
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
44
+ ) -> 'PretrainedConfig':
45
+ cls._set_token_in_kwargs(kwargs)
46
+
47
+ configdict, kwargs = cls.get_config_dict(
48
+ pretrained_model_name_or_path, **kwargs
49
+ )
50
+
51
+ # get the text config dict if we are loading from JinaCLIPConfig
52
+ if configdict.get('model_type') == 'jina_clip':
53
+ configdict = configdict['text_config']
54
+
55
+ if (
56
+ 'model_type' in configdict
57
+ and hasattr(cls, 'model_type')
58
+ and configdict['model_type'] != cls.model_type
59
+ ):
60
+ logger.warning(
61
+ f'You are using a model of type {configdict["model_type"]} to '
62
+ f'instantiate a model of type {cls.model_type}. This is not supported '
63
+ 'for all configurations of models and can yield errors.'
64
+ )
65
+
66
+ return cls.from_dict(configdict, **kwargs)
67
+
68
+
69
+ class JinaCLIPVisionConfig(PretrainedConfig):
70
+ model_type = 'jina_clip_vision'
71
+
72
+ def __init__(
73
+ self,
74
+ embed_dim: int = 768,
75
+ width: int = 768,
76
+ image_size: int = 224,
77
+ patch_size: int = 16,
78
+ layers: int = 12,
79
+ head_width: int = 64,
80
+ mlp_ratio: float = 4.0,
81
+ ls_init_value: Optional[float] = None,
82
+ patch_dropout: float = 0.0,
83
+ qkv_bias: bool = True,
84
+ fused_layer_norm: bool = False,
85
+ x_attention: bool = False,
86
+ post_norm: bool = False,
87
+ rope_embeddings: bool = False,
88
+ pt_hw_seq_len: int = 16,
89
+ intp_freq: bool = False,
90
+ naive_swiglu: bool = False,
91
+ subln: bool = False,
92
+ drop_path_rate: float = 0.0,
93
+ proj_type: Optional[str] = None,
94
+ **kwargs,
95
+ ):
96
+ super().__init__(**kwargs)
97
+
98
+ self.layers = layers
99
+ self.embed_dim = embed_dim
100
+ self.width = width
101
+ self.head_width = head_width
102
+ self.mlp_ratio = mlp_ratio
103
+ self.image_size = image_size
104
+ self.patch_size = patch_size
105
+ self.ls_init_value = ls_init_value
106
+ self.patch_dropout = patch_dropout
107
+ self.qkv_bias = qkv_bias
108
+ self.fused_layer_norm = fused_layer_norm
109
+ self.x_attention = x_attention
110
+ self.post_norm = post_norm
111
+ self.rope_embeddings = rope_embeddings
112
+ self.pt_hw_seq_len = pt_hw_seq_len
113
+ self.intp_freq = intp_freq
114
+ self.naive_swiglu = naive_swiglu
115
+ self.subln = subln
116
+ self.drop_path_rate = drop_path_rate
117
+ self.proj_type = proj_type
118
+
119
+ @classmethod
120
+ def from_pretrained(
121
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
122
+ ) -> 'PretrainedConfig':
123
+ cls._set_token_in_kwargs(kwargs)
124
+
125
+ configdict, kwargs = cls.get_config_dict(
126
+ pretrained_model_name_or_path, **kwargs
127
+ )
128
+
129
+ # get the vision config dict if we are loading from JinaCLIPConfig
130
+ if configdict.get('model_type') == 'jina_clip':
131
+ configdict = configdict['vision_config']
132
+
133
+ if (
134
+ 'model_type' in configdict
135
+ and hasattr(cls, 'model_type')
136
+ and configdict['model_type'] != cls.model_type
137
+ ):
138
+ logger.warning(
139
+ f'You are using a model of type {configdict["model_type"]} to '
140
+ f'instantiate a model of type {cls.model_type}. This is not supported '
141
+ 'for all configurations of models and can yield errors.'
142
+ )
143
+
144
+ return cls.from_dict(configdict, **kwargs)
145
+
146
+
147
+ class JinaCLIPConfig(PretrainedConfig):
148
+ model_type = 'jina_clip'
149
+ is_composition = True
150
+
151
+ def __init__(
152
+ self,
153
+ text_config: Optional[Dict] = None,
154
+ vision_config: Optional[Dict] = None,
155
+ add_projections: bool = False,
156
+ projection_dim: int = 768,
157
+ logit_scale_init_value: float = 2.6592,
158
+ **kwargs,
159
+ ):
160
+ # If `_config_dict` exist, we use them for the backward compatibility.
161
+ # We pop out these 2 attributes before calling `super().__init__` to avoid
162
+ # them being saved (which causes a lot of confusion!).
163
+
164
+ text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
165
+ vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
166
+
167
+ super().__init__(**kwargs)
168
+
169
+ if text_config_dict is not None:
170
+ if text_config is None:
171
+ text_config = {}
172
+
173
+ # This is the complete result when using `text_config_dict`.
174
+ _text_config_dict = JinaCLIPTextConfig(**text_config_dict).to_dict()
175
+
176
+ # Give a warning if the values exist in both `_text_config_dict` and
177
+ # `text_config` but being different.
178
+ for key, value in _text_config_dict.items():
179
+ if (
180
+ key in text_config
181
+ and value != text_config[key]
182
+ and key not in ['transformers_version']
183
+ ):
184
+ # If specified in `text_config_dict`
185
+ if key in text_config_dict:
186
+ message = (
187
+ f'`{key}` is found in both `text_config_dict` and '
188
+ f'`text_config` but with different values. '
189
+ f'The value `text_config_dict["{key}"]` will be used '
190
+ f'instead.'
191
+ )
192
+ # If inferred from default argument values (
193
+ # just to be super careful)
194
+ else:
195
+ message = (
196
+ f'`text_config_dict` is provided which will be used to '
197
+ f'initialize `JinaCLIPTextConfig`. The '
198
+ f'value `text_config["{key}"]` will be overriden.'
199
+ )
200
+ logger.info(message)
201
+
202
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
203
+ text_config.update(_text_config_dict)
204
+
205
+ if vision_config_dict is not None:
206
+ if vision_config is None:
207
+ vision_config = {}
208
+
209
+ # This is the complete result when using `vision_config_dict`.
210
+ _vision_config_dict = JinaCLIPVisionConfig(**vision_config_dict).to_dict()
211
+ # convert keys to string instead of integer
212
+ if 'id2label' in _vision_config_dict:
213
+ _vision_config_dict['id2label'] = {
214
+ str(key): value
215
+ for key, value in _vision_config_dict['id2label'].items()
216
+ }
217
+
218
+ # Give a warning if the values exist in both `_vision_config_dict`
219
+ # and `vision_config` but being different.
220
+ for key, value in _vision_config_dict.items():
221
+ if (
222
+ key in vision_config
223
+ and value != vision_config[key]
224
+ and key not in ['transformers_version']
225
+ ):
226
+ # If specified in `vision_config_dict`
227
+ if key in vision_config_dict:
228
+ message = (
229
+ f'`{key}` is found in both `vision_config_dict` and '
230
+ f'`vision_config` but with different '
231
+ f'values. The value `vision_config_dict["{key}"]` will '
232
+ f'be used instead.'
233
+ )
234
+ # If inferred from default argument values
235
+ # (just to be super careful)
236
+ else:
237
+ message = (
238
+ f'`vision_config_dict` is provided which will be used to '
239
+ f'initialize `JinaCLIPVisionConfig`. '
240
+ f'The value `vision_config["{key}"]` will be overriden.'
241
+ )
242
+ logger.info(message)
243
+
244
+ # Update all values in `vision_config` with the ones in
245
+ # `_vision_config_dict`.
246
+ vision_config.update(_vision_config_dict)
247
+
248
+ if text_config is None:
249
+ text_config = {}
250
+ logger.info(
251
+ '`text_config` is `None`. Initializing the `JinaCLIPTextConfig` with '
252
+ 'default values.'
253
+ )
254
+
255
+ if vision_config is None:
256
+ vision_config = {}
257
+ logger.info(
258
+ '`vision_config` is `None`. initializing the `JinaCLIPVisionConfig` '
259
+ 'with default values.'
260
+ )
261
+
262
+ self.text_config = JinaCLIPTextConfig(**text_config)
263
+ self.vision_config = JinaCLIPVisionConfig(**vision_config)
264
+
265
+ self.add_projections = add_projections
266
+ self.projection_dim = projection_dim
267
+ self.logit_scale_init_value = logit_scale_init_value
268
+ self.initializer_factor = 1.0
269
+
270
+ if not self.add_projections:
271
+ if self.text_config.embed_dim != self.vision_config.embed_dim:
272
+ raise ValueError(
273
+ 'When projections are disabled (`add_projections=False`), text '
274
+ 'and vision towers need to have the same embedding dimensionality. '
275
+ f'Currently text embedding dim is {self.text_config.embed_dim} != '
276
+ f'{self.vision_config.embed_dim} of the vision tower. '
277
+ 'Either set the same output dim for both towers, or enable '
278
+ 'projections with `add_projections=True`.'
279
+ )
280
+
281
+ @classmethod
282
+ def from_text_vision_configs(
283
+ cls,
284
+ text_config: JinaCLIPTextConfig,
285
+ vision_config: JinaCLIPVisionConfig,
286
+ **kwargs,
287
+ ):
288
+ return cls(
289
+ text_config=text_config.to_dict(),
290
+ vision_config=vision_config.to_dict(),
291
+ projection_dim=text_config.projection_dim,
292
+ **kwargs,
293
+ )
294
+
295
+ def to_dict(self):
296
+ output = deepcopy(self.__dict__)
297
+ output['text_config'] = self.text_config.to_dict()
298
+ output['vision_config'] = self.vision_config.to_dict()
299
+ output['model_type'] = self.__class__.model_type
300
+ return output
eva_model.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from EVA CLIP
3
+ # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
+ # --------------------------------------------------------
5
+
6
+ import math
7
+ import os
8
+ from functools import partial
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ try:
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ except ImportError or ModuleNotFoundError:
17
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
18
+
19
+ from .rope_embeddings import VisionRotaryEmbeddingFast
20
+
21
+ if os.getenv('ENV_TYPE') == 'deepspeed':
22
+ try:
23
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
24
+ except ImportError or ModuleNotFoundError:
25
+ from torch.utils.checkpoint import checkpoint
26
+ else:
27
+ from torch.utils.checkpoint import checkpoint
28
+
29
+ try:
30
+ import xformers.ops as xops
31
+ except ImportError:
32
+ xops = None
33
+
34
+
35
+ class PatchDropout(nn.Module):
36
+ """
37
+ https://arxiv.org/abs/2212.00794
38
+ """
39
+
40
+ def __init__(self, prob, exclude_first_token=True):
41
+ super().__init__()
42
+ assert 0 <= prob < 1.0
43
+ self.prob = prob
44
+ self.exclude_first_token = exclude_first_token # exclude CLS token
45
+
46
+ def forward(self, x):
47
+ if not self.training or self.prob == 0.0:
48
+ return x
49
+
50
+ if self.exclude_first_token:
51
+ cls_tokens, x = x[:, :1], x[:, 1:]
52
+ else:
53
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
54
+
55
+ batch = x.size()[0]
56
+ num_tokens = x.size()[1]
57
+
58
+ batch_indices = torch.arange(batch)
59
+ batch_indices = batch_indices[..., None]
60
+
61
+ keep_prob = 1 - self.prob
62
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
63
+
64
+ rand = torch.randn(batch, num_tokens)
65
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
66
+
67
+ x = x[batch_indices, patch_indices_keep]
68
+
69
+ if self.exclude_first_token:
70
+ x = torch.cat((cls_tokens, x), dim=1)
71
+
72
+ return x, patch_indices_keep
73
+
74
+
75
+ class DropPath(nn.Module):
76
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
77
+ residual blocks)."""
78
+
79
+ def __init__(self, drop_prob=None):
80
+ super(DropPath, self).__init__()
81
+ self.drop_prob = drop_prob
82
+
83
+ def forward(self, x):
84
+ return drop_path(x, self.drop_prob, self.training)
85
+
86
+ def extra_repr(self) -> str:
87
+ return 'p={}'.format(self.drop_prob)
88
+
89
+
90
+ class Mlp(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_features,
94
+ hidden_features=None,
95
+ out_features=None,
96
+ act_layer=nn.GELU,
97
+ norm_layer=nn.LayerNorm,
98
+ drop=0.0,
99
+ subln=False,
100
+ ):
101
+ super().__init__()
102
+ out_features = out_features or in_features
103
+ hidden_features = hidden_features or in_features
104
+ self.fc1 = nn.Linear(in_features, hidden_features)
105
+ self.act = act_layer()
106
+
107
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
108
+
109
+ self.fc2 = nn.Linear(hidden_features, out_features)
110
+ self.drop = nn.Dropout(drop)
111
+
112
+ def forward(self, x):
113
+ x = self.fc1(x)
114
+ x = self.act(x)
115
+ # x = self.drop(x)
116
+ # commit this for the orignal BERT implement
117
+ x = self.ffn_ln(x)
118
+
119
+ x = self.fc2(x)
120
+ x = self.drop(x)
121
+ return x
122
+
123
+
124
+ class SwiGLU(nn.Module):
125
+ def __init__(
126
+ self,
127
+ in_features,
128
+ hidden_features=None,
129
+ out_features=None,
130
+ act_layer=nn.SiLU,
131
+ drop=0.0,
132
+ norm_layer=nn.LayerNorm,
133
+ subln=False,
134
+ ):
135
+ super().__init__()
136
+ out_features = out_features or in_features
137
+ hidden_features = hidden_features or in_features
138
+
139
+ self.w1 = nn.Linear(in_features, hidden_features)
140
+ self.w2 = nn.Linear(in_features, hidden_features)
141
+
142
+ self.act = act_layer()
143
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
144
+ self.w3 = nn.Linear(hidden_features, out_features)
145
+
146
+ self.drop = nn.Dropout(drop)
147
+
148
+ def forward(self, x):
149
+ x1 = self.w1(x)
150
+ x2 = self.w2(x)
151
+ hidden = self.act(x1) * x2
152
+ x = self.ffn_ln(hidden)
153
+ x = self.w3(x)
154
+ x = self.drop(x)
155
+ return x
156
+
157
+
158
+ class Attention(nn.Module):
159
+ def __init__(
160
+ self,
161
+ dim,
162
+ num_heads=8,
163
+ qkv_bias=False,
164
+ qk_scale=None,
165
+ attn_drop=0.0,
166
+ proj_drop=0.0,
167
+ window_size=None,
168
+ attn_head_dim=None,
169
+ xattn=False,
170
+ rope=None,
171
+ subln=False,
172
+ norm_layer=nn.LayerNorm,
173
+ ):
174
+ super().__init__()
175
+ self.num_heads = num_heads
176
+ head_dim = dim // num_heads
177
+ if attn_head_dim is not None:
178
+ head_dim = attn_head_dim
179
+ all_head_dim = head_dim * self.num_heads
180
+ self.scale = qk_scale or head_dim**-0.5
181
+
182
+ self.subln = subln
183
+ if self.subln:
184
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
185
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
186
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
187
+ else:
188
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
189
+
190
+ if qkv_bias:
191
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
192
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
193
+ else:
194
+ self.q_bias = None
195
+ self.v_bias = None
196
+
197
+ if window_size:
198
+ self.window_size = window_size
199
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
200
+ 2 * window_size[1] - 1
201
+ ) + 3
202
+ self.relative_position_bias_table = nn.Parameter(
203
+ torch.zeros(self.num_relative_distance, num_heads)
204
+ ) # 2*Wh-1 * 2*Ww-1, nH
205
+ # cls to token & token 2 cls & cls to cls
206
+
207
+ # get pair-wise relative position index for each token inside the window
208
+ coords_h = torch.arange(window_size[0])
209
+ coords_w = torch.arange(window_size[1])
210
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
211
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
212
+ relative_coords = (
213
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
214
+ ) # 2, Wh*Ww, Wh*Ww
215
+ relative_coords = relative_coords.permute(
216
+ 1, 2, 0
217
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
218
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
219
+ relative_coords[:, :, 1] += window_size[1] - 1
220
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
221
+ relative_position_index = torch.zeros(
222
+ size=(window_size[0] * window_size[1] + 1,) * 2,
223
+ dtype=relative_coords.dtype,
224
+ )
225
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
226
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
227
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
228
+ relative_position_index[0, 0] = self.num_relative_distance - 1
229
+
230
+ self.register_buffer('relative_position_index', relative_position_index)
231
+ else:
232
+ self.window_size = None
233
+ self.relative_position_bias_table = None
234
+ self.relative_position_index = None
235
+
236
+ self.attn_drop = nn.Dropout(attn_drop)
237
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
238
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
239
+ self.proj = nn.Linear(all_head_dim, dim)
240
+ self.proj_drop = nn.Dropout(proj_drop)
241
+ self.xattn = xattn
242
+ self.xattn_drop = attn_drop
243
+
244
+ self.rope = rope
245
+
246
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
247
+ B, N, C = x.shape
248
+ if self.subln:
249
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
250
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
251
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
252
+
253
+ q = q.reshape(B, N, self.num_heads, -1).permute(
254
+ 0, 2, 1, 3
255
+ ) # B, num_heads, N, C
256
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
257
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
258
+ else:
259
+ qkv_bias = None
260
+ if self.q_bias is not None:
261
+ qkv_bias = torch.cat(
262
+ (
263
+ self.q_bias,
264
+ torch.zeros_like(self.v_bias, requires_grad=False),
265
+ self.v_bias,
266
+ )
267
+ )
268
+
269
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
270
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
271
+ 2, 0, 3, 1, 4
272
+ ) # 3, B, num_heads, N, C
273
+ q, k, v = qkv[0], qkv[1], qkv[2]
274
+
275
+ if self.rope:
276
+ # slightly fast impl
277
+ q_t = q[:, :, 1:, :]
278
+ ro_q_t = self.rope(q_t)
279
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
280
+
281
+ k_t = k[:, :, 1:, :]
282
+ ro_k_t = self.rope(k_t)
283
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
284
+
285
+ if self.xattn:
286
+ if xops is None:
287
+ raise ValueError(
288
+ "Can't use xattn without xformers. Please 'pip install xformers'"
289
+ )
290
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
291
+ k = k.permute(0, 2, 1, 3)
292
+ v = v.permute(0, 2, 1, 3)
293
+
294
+ x = xops.memory_efficient_attention(
295
+ q,
296
+ k,
297
+ v,
298
+ p=self.xattn_drop,
299
+ scale=self.scale,
300
+ )
301
+ x = x.reshape(B, N, -1)
302
+ x = self.inner_attn_ln(x)
303
+ x = self.proj(x)
304
+ x = self.proj_drop(x)
305
+ else:
306
+ q = q * self.scale
307
+ attn = q @ k.transpose(-2, -1)
308
+
309
+ if self.relative_position_bias_table is not None:
310
+ relative_position_bias = self.relative_position_bias_table[
311
+ self.relative_position_index.view(-1)
312
+ ].view(
313
+ self.window_size[0] * self.window_size[1] + 1,
314
+ self.window_size[0] * self.window_size[1] + 1,
315
+ -1,
316
+ ) # Wh*Ww,Wh*Ww,nH
317
+ relative_position_bias = relative_position_bias.permute(
318
+ 2, 0, 1
319
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
320
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
321
+
322
+ if rel_pos_bias is not None:
323
+ attn = attn + rel_pos_bias.type_as(attn)
324
+
325
+ if attn_mask is not None:
326
+ attn_mask = attn_mask.bool()
327
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float('-inf'))
328
+
329
+ attn = attn.softmax(dim=-1)
330
+ attn = self.attn_drop(attn)
331
+
332
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
333
+ x = self.inner_attn_ln(x)
334
+ x = self.proj(x)
335
+ x = self.proj_drop(x)
336
+ return x
337
+
338
+
339
+ class Block(nn.Module):
340
+ def __init__(
341
+ self,
342
+ dim,
343
+ num_heads,
344
+ mlp_ratio=4.0,
345
+ qkv_bias=False,
346
+ qk_scale=None,
347
+ drop=0.0,
348
+ attn_drop=0.0,
349
+ drop_path=0.0,
350
+ init_values=None,
351
+ act_layer=nn.GELU,
352
+ norm_layer=nn.LayerNorm,
353
+ window_size=None,
354
+ attn_head_dim=None,
355
+ xattn=False,
356
+ rope=None,
357
+ postnorm=False,
358
+ subln=False,
359
+ naiveswiglu=False,
360
+ ):
361
+ super().__init__()
362
+ self.norm1 = norm_layer(dim)
363
+ self.attn = Attention(
364
+ dim,
365
+ num_heads=num_heads,
366
+ qkv_bias=qkv_bias,
367
+ qk_scale=qk_scale,
368
+ attn_drop=attn_drop,
369
+ proj_drop=drop,
370
+ window_size=window_size,
371
+ attn_head_dim=attn_head_dim,
372
+ xattn=xattn,
373
+ rope=rope,
374
+ subln=subln,
375
+ norm_layer=norm_layer,
376
+ )
377
+ # NOTE: drop path for stochastic depth, we shall see if this is better
378
+ # than dropout here
379
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
380
+ self.norm2 = norm_layer(dim)
381
+ mlp_hidden_dim = int(dim * mlp_ratio)
382
+
383
+ if naiveswiglu:
384
+ self.mlp = SwiGLU(
385
+ in_features=dim,
386
+ hidden_features=mlp_hidden_dim,
387
+ subln=subln,
388
+ norm_layer=norm_layer,
389
+ )
390
+ else:
391
+ self.mlp = Mlp(
392
+ in_features=dim,
393
+ hidden_features=mlp_hidden_dim,
394
+ act_layer=act_layer,
395
+ subln=subln,
396
+ drop=drop,
397
+ )
398
+
399
+ if init_values is not None and init_values > 0:
400
+ self.gamma_1 = nn.Parameter(
401
+ init_values * torch.ones((dim,)), requires_grad=True
402
+ )
403
+ self.gamma_2 = nn.Parameter(
404
+ init_values * torch.ones((dim,)), requires_grad=True
405
+ )
406
+ else:
407
+ self.gamma_1, self.gamma_2 = None, None
408
+
409
+ self.postnorm = postnorm
410
+
411
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
412
+ if self.gamma_1 is None:
413
+ if self.postnorm:
414
+ x = x + self.drop_path(
415
+ self.norm1(
416
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
417
+ )
418
+ )
419
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
420
+ else:
421
+ x = x + self.drop_path(
422
+ self.attn(
423
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
424
+ )
425
+ )
426
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
427
+ else:
428
+ if self.postnorm:
429
+ x = x + self.drop_path(
430
+ self.gamma_1
431
+ * self.norm1(
432
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
433
+ )
434
+ )
435
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
436
+ else:
437
+ x = x + self.drop_path(
438
+ self.gamma_1
439
+ * self.attn(
440
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
441
+ )
442
+ )
443
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
444
+ return x
445
+
446
+
447
+ class PatchEmbed(nn.Module):
448
+ """Image to Patch Embedding"""
449
+
450
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
451
+ super().__init__()
452
+ img_size = to_2tuple(img_size)
453
+ patch_size = to_2tuple(patch_size)
454
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
455
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
456
+ self.img_size = img_size
457
+ self.patch_size = patch_size
458
+ self.num_patches = num_patches
459
+
460
+ self.proj = nn.Conv2d(
461
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
462
+ )
463
+
464
+ def forward(self, x, **kwargs):
465
+ B, C, H, W = x.shape
466
+ # FIXME look at relaxing size constraints
467
+ assert H == self.img_size[0] and W == self.img_size[1], (
468
+ f"Input image size ({H}*{W}) doesn't match model "
469
+ f'({self.img_size[0]}*{self.img_size[1]}).'
470
+ )
471
+ x = self.proj(x).flatten(2).transpose(1, 2)
472
+ return x
473
+
474
+
475
+ class RelativePositionBias(nn.Module):
476
+ def __init__(self, window_size, num_heads):
477
+ super().__init__()
478
+ self.window_size = window_size
479
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
480
+ 2 * window_size[1] - 1
481
+ ) + 3
482
+ self.relative_position_bias_table = nn.Parameter(
483
+ torch.zeros(self.num_relative_distance, num_heads)
484
+ ) # 2*Wh-1 * 2*Ww-1, nH
485
+ # cls to token & token 2 cls & cls to cls
486
+
487
+ # get pair-wise relative position index for each token inside the window
488
+ coords_h = torch.arange(window_size[0])
489
+ coords_w = torch.arange(window_size[1])
490
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
491
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
492
+ relative_coords = (
493
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
494
+ ) # 2, Wh*Ww, Wh*Ww
495
+ relative_coords = relative_coords.permute(
496
+ 1, 2, 0
497
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
498
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
499
+ relative_coords[:, :, 1] += window_size[1] - 1
500
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
501
+ relative_position_index = torch.zeros(
502
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
503
+ )
504
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
505
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
506
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
507
+ relative_position_index[0, 0] = self.num_relative_distance - 1
508
+
509
+ self.register_buffer('relative_position_index', relative_position_index)
510
+
511
+ def forward(self):
512
+ relative_position_bias = self.relative_position_bias_table[
513
+ self.relative_position_index.view(-1)
514
+ ].view(
515
+ self.window_size[0] * self.window_size[1] + 1,
516
+ self.window_size[0] * self.window_size[1] + 1,
517
+ -1,
518
+ ) # Wh*Ww,Wh*Ww,nH
519
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
520
+
521
+
522
+ class EVAVisionTransformer(nn.Module):
523
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
524
+
525
+ def __init__(
526
+ self,
527
+ img_size=224,
528
+ patch_size=16,
529
+ in_chans=3,
530
+ num_classes=0,
531
+ embed_dim=768,
532
+ depth=12,
533
+ num_heads=12,
534
+ mlp_ratio=4.0,
535
+ qkv_bias=False,
536
+ qk_scale=None,
537
+ drop_rate=0.0,
538
+ attn_drop_rate=0.0,
539
+ drop_path_rate=0.0,
540
+ norm_layer=nn.LayerNorm,
541
+ init_values=None,
542
+ patch_dropout=0.0,
543
+ use_abs_pos_emb=True,
544
+ use_rel_pos_bias=False,
545
+ use_shared_rel_pos_bias=False,
546
+ rope=False,
547
+ use_mean_pooling=True,
548
+ init_scale=0.001,
549
+ grad_checkpointing=False,
550
+ xattn=False,
551
+ postnorm=False,
552
+ pt_hw_seq_len=16,
553
+ intp_freq=False,
554
+ naiveswiglu=False,
555
+ subln=False,
556
+ proj_type=None,
557
+ ):
558
+ super().__init__()
559
+ self.image_size = img_size
560
+ self.num_classes = num_classes
561
+ self.num_features = (
562
+ self.embed_dim
563
+ ) = embed_dim # num_features for consistency with other models
564
+
565
+ self.patch_embed = PatchEmbed(
566
+ img_size=img_size,
567
+ patch_size=patch_size,
568
+ in_chans=in_chans,
569
+ embed_dim=embed_dim,
570
+ )
571
+ num_patches = self.patch_embed.num_patches
572
+
573
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
574
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
575
+ if use_abs_pos_emb:
576
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
577
+ else:
578
+ self.pos_embed = None
579
+ self.pos_drop = nn.Dropout(p=drop_rate)
580
+
581
+ if use_shared_rel_pos_bias:
582
+ self.rel_pos_bias = RelativePositionBias(
583
+ window_size=self.patch_embed.patch_shape, num_heads=num_heads
584
+ )
585
+ else:
586
+ self.rel_pos_bias = None
587
+
588
+ if rope:
589
+ half_head_dim = embed_dim // num_heads // 2
590
+ hw_seq_len = img_size // patch_size
591
+ self.rope = VisionRotaryEmbeddingFast(
592
+ dim=half_head_dim,
593
+ pt_seq_len=pt_hw_seq_len,
594
+ ft_seq_len=hw_seq_len if intp_freq else None,
595
+ patch_dropout=patch_dropout,
596
+ )
597
+ else:
598
+ self.rope = None
599
+
600
+ self.naiveswiglu = naiveswiglu
601
+
602
+ dpr = [
603
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
604
+ ] # stochastic depth decay rule
605
+ self.use_rel_pos_bias = use_rel_pos_bias
606
+ self.blocks = nn.ModuleList(
607
+ [
608
+ Block(
609
+ dim=embed_dim,
610
+ num_heads=num_heads,
611
+ mlp_ratio=mlp_ratio,
612
+ qkv_bias=qkv_bias,
613
+ qk_scale=qk_scale,
614
+ drop=drop_rate,
615
+ attn_drop=attn_drop_rate,
616
+ drop_path=dpr[i],
617
+ norm_layer=norm_layer,
618
+ init_values=init_values,
619
+ window_size=self.patch_embed.patch_shape
620
+ if use_rel_pos_bias
621
+ else None,
622
+ xattn=xattn,
623
+ rope=self.rope,
624
+ postnorm=postnorm,
625
+ subln=subln,
626
+ naiveswiglu=naiveswiglu,
627
+ )
628
+ for i in range(depth)
629
+ ]
630
+ )
631
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
632
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
633
+ if (num_classes == embed_dim) and (proj_type is None):
634
+ self.head = nn.Identity()
635
+ elif proj_type == 'linear':
636
+ self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias)
637
+ elif proj_type == 'mlp':
638
+ hidden_size = (embed_dim + num_classes) // 2
639
+ self.proj = nn.Sequential(
640
+ nn.Linear(embed_dim, hidden_size, bias=qkv_bias),
641
+ nn.GELU(),
642
+ nn.Linear(hidden_size, num_classes, bias=qkv_bias),
643
+ )
644
+
645
+ if self.pos_embed is not None:
646
+ trunc_normal_(self.pos_embed, std=0.02)
647
+
648
+ trunc_normal_(self.cls_token, std=0.02)
649
+
650
+ self.apply(self._init_weights)
651
+ self.fix_init_weight()
652
+
653
+ if isinstance(self.head, nn.Linear):
654
+ trunc_normal_(self.head.weight, std=0.02)
655
+ self.head.weight.data.mul_(init_scale)
656
+ if qkv_bias:
657
+ self.head.bias.data.mul_(init_scale)
658
+
659
+ # setting a patch_dropout of 0. would mean it is disabled and this function
660
+ # would be the identity fn
661
+ self.patch_dropout = (
662
+ PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
663
+ )
664
+
665
+ self.grad_checkpointing = grad_checkpointing
666
+
667
+ def fix_init_weight(self):
668
+ def rescale(param, layer_id):
669
+ param.div_(math.sqrt(2.0 * layer_id))
670
+
671
+ for layer_id, layer in enumerate(self.blocks):
672
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
673
+ if self.naiveswiglu:
674
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
675
+ else:
676
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
677
+
678
+ def get_cast_dtype(self) -> torch.dtype:
679
+ return self.blocks[0].mlp.fc2.weight.dtype
680
+
681
+ def _init_weights(self, m):
682
+ if isinstance(m, nn.Linear):
683
+ trunc_normal_(m.weight, std=0.02)
684
+ if m.bias is not None:
685
+ nn.init.constant_(m.bias, 0)
686
+ elif isinstance(m, nn.LayerNorm):
687
+ nn.init.constant_(m.bias, 0)
688
+ nn.init.constant_(m.weight, 1.0)
689
+
690
+ def get_num_layers(self):
691
+ return len(self.blocks)
692
+
693
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
694
+ assert (
695
+ unlocked_groups == 0
696
+ ), 'partial locking not currently supported for this model'
697
+ for param in self.parameters():
698
+ param.requires_grad = False
699
+
700
+ @torch.jit.ignore
701
+ def set_grad_checkpointing(self, enable=True):
702
+ self.grad_checkpointing = enable
703
+
704
+ @torch.jit.ignore
705
+ def no_weight_decay(self):
706
+ return {'pos_embed', 'cls_token'}
707
+
708
+ def get_classifier(self):
709
+ return self.head
710
+
711
+ def reset_classifier(self, num_classes, global_pool=''):
712
+ self.num_classes = num_classes
713
+ self.head = (
714
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
715
+ )
716
+
717
+ def forward_features(self, x, return_all_features=False):
718
+ x = self.patch_embed(x)
719
+ batch_size, seq_len, _ = x.size()
720
+
721
+ cls_tokens = self.cls_token.expand(
722
+ batch_size, -1, -1
723
+ ) # stole cls_tokens impl from Phil Wang, thanks
724
+ x = torch.cat((cls_tokens, x), dim=1)
725
+ if self.pos_embed is not None:
726
+ x = x + self.pos_embed
727
+ x = self.pos_drop(x)
728
+
729
+ # a patch_dropout of 0. would mean it is disabled and this function would do
730
+ # nothing but return what was passed in
731
+ if self.rope is not None:
732
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
733
+ x, patch_indices_keep = self.patch_dropout(x)
734
+ self.rope.forward = partial(
735
+ self.rope.forward, patch_indices_keep=patch_indices_keep
736
+ )
737
+ else:
738
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
739
+ x = self.patch_dropout(x)
740
+ else:
741
+ x = self.patch_dropout(x)
742
+
743
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
744
+ for blk in self.blocks:
745
+ if self.grad_checkpointing:
746
+ x = checkpoint(blk, x, (rel_pos_bias,))
747
+ else:
748
+ x = blk(x, rel_pos_bias=rel_pos_bias)
749
+
750
+ if not return_all_features:
751
+ x = self.norm(x)
752
+ if self.fc_norm is not None:
753
+ return self.fc_norm(x.mean(1))
754
+ else:
755
+ return x[:, 0]
756
+ return x
757
+
758
+ def forward(self, x, return_all_features=False):
759
+ if return_all_features:
760
+ return self.forward_features(x, return_all_features)
761
+ x = self.forward_features(x)
762
+ x = self.head(x)
763
+ return x
hf_model.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import AutoConfig, AutoModel, PretrainedConfig
7
+ from transformers.modeling_outputs import (
8
+ BaseModelOutput,
9
+ BaseModelOutputWithPooling,
10
+ BaseModelOutputWithPoolingAndCrossAttentions,
11
+ )
12
+
13
+ """
14
+ HF architecture mapping
15
+ """
16
+
17
+ _HF_ARCH_DICT = {
18
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
19
+ 'roberta': {
20
+ 'config_names': {
21
+ 'context_length': 'max_position_embeddings',
22
+ 'vocab_size': 'vocab_size',
23
+ 'width': 'hidden_size',
24
+ 'heads': 'num_attention_heads',
25
+ 'layers': 'num_hidden_layers',
26
+ 'layer_attr': 'layer',
27
+ 'token_embeddings_attr': 'embeddings',
28
+ },
29
+ 'pooler': 'mean_pooler',
30
+ },
31
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
32
+ 'xlm-roberta': {
33
+ 'config_names': {
34
+ 'context_length': 'max_position_embeddings',
35
+ 'vocab_size': 'vocab_size',
36
+ 'width': 'hidden_size',
37
+ 'heads': 'num_attention_heads',
38
+ 'layers': 'num_hidden_layers',
39
+ 'layer_attr': 'layer',
40
+ 'token_embeddings_attr': 'embeddings',
41
+ },
42
+ 'pooler': 'mean_pooler',
43
+ },
44
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
45
+ 'mt5': {
46
+ 'config_names': {
47
+ # unlimited seqlen
48
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
49
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
50
+ 'context_length': '',
51
+ 'vocab_size': 'vocab_size',
52
+ 'width': 'd_model',
53
+ 'heads': 'num_heads',
54
+ 'layers': 'num_layers',
55
+ 'layer_attr': 'block',
56
+ 'token_embeddings_attr': 'embed_tokens',
57
+ },
58
+ 'pooler': 'mean_pooler',
59
+ },
60
+ # https://huggingface.co/docs/transformers/model_doc/bert
61
+ 'bert': {
62
+ 'config_names': {
63
+ 'context_length': 'max_position_embeddings',
64
+ 'vocab_size': 'vocab_size',
65
+ 'width': 'hidden_size',
66
+ 'heads': 'num_attention_heads',
67
+ 'layers': 'num_hidden_layers',
68
+ },
69
+ 'pooler': 'cls_pooler',
70
+ },
71
+ # https://huggingface.co/docs/transformers/model_doc/m2m_100
72
+ 'm2m_100': {
73
+ 'config_names': {
74
+ 'context_length': 'max_position_embeddings',
75
+ 'vocab_size': 'vocab_size',
76
+ 'width': 'd_model',
77
+ 'heads': 'encoder_attention_heads',
78
+ 'layers': 'encoder_layers',
79
+ },
80
+ 'pooler': 'cls_pooler',
81
+ },
82
+ }
83
+
84
+
85
+ """
86
+ Pooling functions
87
+ """
88
+
89
+ _POOLERS = {}
90
+
91
+
92
+ def _camel2snake(s):
93
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
94
+
95
+
96
+ def register_pooler(cls):
97
+ """Decorator registering pooler class"""
98
+ _POOLERS[_camel2snake(cls.__name__)] = cls
99
+ return cls
100
+
101
+
102
+ @register_pooler
103
+ class MeanPooler(nn.Module):
104
+ """Mean pooling"""
105
+
106
+ @staticmethod
107
+ def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
108
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
109
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
110
+
111
+
112
+ @register_pooler
113
+ class MaxPooler(nn.Module):
114
+ """
115
+ Max pooling
116
+ """
117
+
118
+ @staticmethod
119
+ def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
120
+ masked_output = x.last_hidden_state.masked_fill(
121
+ attention_mask.unsqueeze(-1), -torch.inf
122
+ )
123
+ return masked_output.max(1).values
124
+
125
+
126
+ @register_pooler
127
+ class ClsPooler(nn.Module):
128
+ """
129
+ CLS token pooling
130
+ """
131
+
132
+ def __init__(self, use_pooler_output=True):
133
+ super().__init__()
134
+ self.cls_token_position = 0
135
+ self.use_pooler_output = use_pooler_output
136
+
137
+ def forward(self, x: BaseModelOutput, _: torch.Tensor):
138
+ if (
139
+ self.use_pooler_output
140
+ and isinstance(
141
+ x,
142
+ (
143
+ BaseModelOutputWithPooling,
144
+ BaseModelOutputWithPoolingAndCrossAttentions,
145
+ ),
146
+ )
147
+ and (x.pooler_output is not None)
148
+ ):
149
+ return x.pooler_output
150
+
151
+ return x.last_hidden_state[:, self.cls_token_position, :]
152
+
153
+
154
+ """
155
+ HF text model
156
+ """
157
+
158
+
159
+ class HFTextEncoder(nn.Module):
160
+ output_tokens: torch.jit.Final[bool]
161
+
162
+ def __init__(
163
+ self,
164
+ model_name_or_path: str,
165
+ output_dim: int,
166
+ config: PretrainedConfig = None,
167
+ pooler_type: str = None,
168
+ proj_type: str = None,
169
+ proj_bias: bool = False,
170
+ pretrained: bool = True,
171
+ output_tokens: bool = False,
172
+ trust_remote_code: bool = False,
173
+ revision: Optional[str] = None,
174
+ model_config_kwargs: Optional[Dict] = None,
175
+ ):
176
+ super().__init__()
177
+ self.output_tokens = output_tokens
178
+ self.output_dim = output_dim
179
+
180
+ # TODO: find better way to get this information
181
+ uses_transformer_pooler = pooler_type == 'cls_pooler'
182
+ model_config_kwargs = model_config_kwargs or {}
183
+
184
+ if config is None:
185
+ self.config = AutoConfig.from_pretrained(
186
+ model_name_or_path,
187
+ trust_remote_code=trust_remote_code,
188
+ code_revision=revision,
189
+ )
190
+ self.config.update(model_config_kwargs)
191
+ create_func, model_args = (
192
+ (AutoModel.from_pretrained, model_name_or_path)
193
+ if pretrained
194
+ else (AutoModel.from_config, self.config)
195
+ )
196
+ # TODO: do all model configs have this attribute?
197
+ # PretrainedConfig does so yes??
198
+ if (
199
+ hasattr(self.config, 'is_encoder_decoder')
200
+ and self.config.is_encoder_decoder
201
+ ):
202
+ self.transformer = create_func(model_args)
203
+ self.transformer = self.transformer.encoder
204
+ else:
205
+ self.transformer = create_func(
206
+ model_args,
207
+ trust_remote_code=trust_remote_code,
208
+ add_pooling_layer=uses_transformer_pooler,
209
+ code_revision=revision,
210
+ )
211
+ else:
212
+ self.config = config
213
+ self.config.update(model_config_kwargs)
214
+ self.transformer = AutoModel.from_config(self.config)
215
+
216
+ if pooler_type is None: # get default arch pooler
217
+ pooler_type = _HF_ARCH_DICT[self.config.model_type]['pooler']
218
+
219
+ # FIXME downstream users of OpenCLIP models use these attr,
220
+ # need to verify valid across all models
221
+ self.vocab_size = getattr(self.config, 'vocab_size', 0)
222
+ self.context_length = getattr(self.config, 'max_position_embeddings', 0)
223
+
224
+ self.pooler = _POOLERS[pooler_type]()
225
+
226
+ d_model = getattr(
227
+ self.config, _HF_ARCH_DICT[self.config.model_type]['config_names']['width']
228
+ )
229
+ if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
230
+ self.proj = nn.Identity()
231
+ elif proj_type == 'linear':
232
+ self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
233
+ elif proj_type == 'mlp':
234
+ hidden_size = (d_model + output_dim) // 2
235
+ self.proj = nn.Sequential(
236
+ nn.Linear(d_model, hidden_size, bias=proj_bias),
237
+ nn.GELU(),
238
+ nn.Linear(hidden_size, output_dim, bias=proj_bias),
239
+ )
240
+
241
+ def forward(self, x: torch.Tensor):
242
+ attn_mask = (x != self.config.pad_token_id).long()
243
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
244
+ pooled_out = self.pooler(out, attn_mask)
245
+ projected = self.proj(pooled_out)
246
+
247
+ seq_len = out.last_hidden_state.shape[1]
248
+ tokens = (
249
+ out.last_hidden_state[
250
+ :, torch.arange(seq_len) != self.pooler.cls_token_position, :
251
+ ]
252
+ if isinstance(self.pooler, ClsPooler)
253
+ else out.last_hidden_state
254
+ )
255
+
256
+ if self.output_tokens:
257
+ return projected, tokens
258
+ return projected
259
+
260
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
261
+ if not unlocked_layers: # full freezing
262
+ for n, p in self.transformer.named_parameters():
263
+ p.requires_grad = (
264
+ (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
265
+ )
266
+ return
267
+
268
+ encoder = (
269
+ self.transformer.encoder
270
+ if hasattr(self.transformer, 'encoder')
271
+ else self.transformer
272
+ )
273
+ layer_list = getattr(
274
+ encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
275
+ )
276
+ print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
277
+ embeddings = getattr(
278
+ self.transformer,
279
+ _HF_ARCH_DICT[self.config.model_type]['config_names'][
280
+ 'token_embeddings_attr'
281
+ ],
282
+ )
283
+ modules = [embeddings, *layer_list][:-unlocked_layers]
284
+ # freeze layers
285
+ for module in modules:
286
+ for n, p in module.named_parameters():
287
+ p.requires_grad = (
288
+ (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
289
+ )
290
+
291
+ @torch.jit.ignore
292
+ def set_grad_checkpointing(self, _=True):
293
+ self.transformer.gradient_checkpointing_enable()
294
+
295
+ def init_parameters(self):
296
+ pass
297
+
298
+
299
+ """
300
+ HF vision model
301
+ """
302
+
303
+
304
+ class HFVisionEncoder(nn.Module):
305
+ output_tokens: torch.jit.Final[bool]
306
+
307
+ def __init__(
308
+ self,
309
+ model_name_or_path: str,
310
+ image_size: int,
311
+ output_dim: int,
312
+ config: PretrainedConfig = None,
313
+ pool_type: str = 'tok',
314
+ proj_type: Optional[str] = None,
315
+ proj_bias: bool = False,
316
+ attn_drop: float = 0.0,
317
+ hidden_drop: float = 0.0,
318
+ drop_path: Optional[float] = None,
319
+ pretrained: bool = True,
320
+ output_tokens: bool = False,
321
+ trust_remote_code: bool = False,
322
+ ):
323
+ super().__init__()
324
+ self.output_tokens = output_tokens
325
+ self.output_dim = output_dim
326
+ self.image_size = (image_size, image_size)
327
+
328
+ if config is None:
329
+ self.config = AutoConfig.from_pretrained(
330
+ model_name_or_path,
331
+ trust_remote_code=trust_remote_code,
332
+ hidden_dropout_prob=hidden_drop,
333
+ attention_probs_dropout_prob=attn_drop,
334
+ drop_path_rate=drop_path,
335
+ )
336
+ create_func, model_args = (
337
+ (AutoModel.from_pretrained, model_name_or_path)
338
+ if pretrained
339
+ else (AutoModel.from_config, self.config)
340
+ )
341
+ self.transformer = create_func(
342
+ model_args,
343
+ trust_remote_code=trust_remote_code,
344
+ hidden_dropout_prob=hidden_drop,
345
+ attention_probs_dropout_prob=attn_drop,
346
+ )
347
+ else:
348
+ self.config = config
349
+ self.transformer = AutoModel.from_config(config)
350
+
351
+ if 'dinov2' in model_name_or_path:
352
+ self.transformer.embeddings.mask_token.requires_grad = False
353
+
354
+ assert pool_type in ('tok', 'avg', 'none')
355
+ self.pool_type = pool_type
356
+
357
+ d_model = self.config.hidden_size
358
+ if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
359
+ self.proj = nn.Identity()
360
+ elif proj_type == 'linear':
361
+ self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
362
+ elif proj_type == 'mlp':
363
+ hidden_size = (d_model + output_dim) // 2
364
+ self.proj = nn.Sequential(
365
+ nn.Linear(d_model, hidden_size, bias=proj_bias),
366
+ nn.GELU(),
367
+ nn.Linear(hidden_size, output_dim, bias=proj_bias),
368
+ )
369
+
370
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
371
+ if self.pool_type == 'avg':
372
+ pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
373
+ elif self.pool_type == 'tok':
374
+ pooled, tokens = x[:, 0], x[:, 1:]
375
+ else:
376
+ pooled = tokens = x
377
+
378
+ return pooled, tokens
379
+
380
+ def forward(self, x: torch.Tensor):
381
+ # returns a tuple of (final hidden states, token pooled outputs)
382
+ x = self.transformer(x)[0]
383
+ pooled, tokens = self._global_pool(x)
384
+ projected = self.proj(pooled)
385
+
386
+ return projected
387
+
388
+ def lock(self, unlocked_layers: int = 0, freeze_bn_stats: bool = True):
389
+ if not unlocked_layers: # full freezing
390
+ for n, p in self.transformer.named_parameters():
391
+ p.requires_grad = (
392
+ (not freeze_bn_stats) if 'LayerNorm' in n.split('.') else False
393
+ )
394
+ return
395
+
396
+ # TODO: make it work if unlocked_layers !=0
397
+ encoder = (
398
+ self.transformer.encoder
399
+ if hasattr(self.transformer, 'encoder')
400
+ else self.transformer
401
+ )
402
+ layer_list = getattr(
403
+ encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
404
+ )
405
+ print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
406
+ embeddings = getattr(
407
+ self.transformer,
408
+ _HF_ARCH_DICT[self.config.model_type]['config_names'][
409
+ 'token_embeddings_attr'
410
+ ],
411
+ )
412
+ modules = [embeddings, *layer_list][:-unlocked_layers]
413
+ # freeze layers
414
+ for module in modules:
415
+ for n, p in module.named_parameters():
416
+ p.requires_grad = (
417
+ (not freeze_bn_stats) if 'LayerNorm' in n.split('.') else False
418
+ )
419
+
420
+ @torch.jit.ignore
421
+ def set_grad_checkpointing(self, *_, **__):
422
+ self.transformer.gradient_checkpointing_enable()
423
+
424
+ def init_parameters(self):
425
+ pass
modeling_clip.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Code mainly copied from:
4
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
5
+ # and adjusted for Jina CLIP
6
+
7
+ from functools import partial
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn.functional as f
12
+ import torch.utils.checkpoint
13
+ from torch import nn
14
+ from transformers import BatchEncoding, BatchFeature, PreTrainedModel, logging
15
+ from transformers.models.clip.modeling_clip import (
16
+ CLIPOutput,
17
+ CLIPTextModelOutput,
18
+ CLIPVisionModelOutput,
19
+ clip_loss,
20
+ )
21
+
22
+ from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
23
+ from .eva_model import EVAVisionTransformer
24
+ from .hf_model import HFTextEncoder
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ """ Jina CLIP model implementation """
30
+
31
+
32
+ class LayerNorm(nn.LayerNorm):
33
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
34
+
35
+ def forward(self, x: torch.Tensor):
36
+ origtype = x.dtype
37
+ x = f.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
38
+ return x.to(origtype)
39
+
40
+
41
+ def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
42
+ return HFTextEncoder(
43
+ model_name_or_path=config.hf_model_name_or_path,
44
+ output_dim=config.embed_dim,
45
+ pooler_type=config.pooler_type,
46
+ proj_type=config.proj_type,
47
+ proj_bias=config.proj_bias,
48
+ pretrained=False,
49
+ output_tokens=False,
50
+ trust_remote_code=True,
51
+ revision=None,
52
+ model_config_kwargs=config.hf_model_config_kwargs,
53
+ )
54
+
55
+
56
+ def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
57
+ norm_layer = partial(LayerNorm, eps=1e-6)
58
+
59
+ if config.fused_layer_norm:
60
+ try:
61
+ from apex.normalization import FusedLayerNorm
62
+
63
+ norm_layer = partial(FusedLayerNorm, eps=1e-6)
64
+ except (ModuleNotFoundError, ImportError):
65
+ logger.warning('Please install apex to use fused layer norm, ignoring')
66
+
67
+ return EVAVisionTransformer(
68
+ img_size=config.image_size,
69
+ patch_size=config.patch_size,
70
+ num_classes=config.embed_dim,
71
+ use_mean_pooling=False,
72
+ init_values=config.ls_init_value,
73
+ patch_dropout=config.patch_dropout,
74
+ embed_dim=config.width,
75
+ depth=config.layers,
76
+ num_heads=config.width // config.head_width,
77
+ mlp_ratio=config.mlp_ratio,
78
+ qkv_bias=config.qkv_bias,
79
+ drop_path_rate=config.drop_path_rate,
80
+ norm_layer=norm_layer,
81
+ xattn=config.x_attention,
82
+ rope=config.rope_embeddings,
83
+ postnorm=config.post_norm,
84
+ pt_hw_seq_len=config.pt_hw_seq_len,
85
+ intp_freq=config.intp_freq,
86
+ naiveswiglu=config.naive_swiglu,
87
+ subln=config.subln,
88
+ proj_type=config.proj_type,
89
+ )
90
+
91
+
92
+ class JinaCLIPPreTrainedModel(PreTrainedModel):
93
+ """
94
+ An abstract class to handle weights initialization and a simple interface for
95
+ downloading and loading pretrained models.
96
+ """
97
+
98
+ config_class = JinaCLIPConfig
99
+ base_model_prefix = 'clip'
100
+ supports_gradient_checkpointing = True
101
+
102
+ def _init_weights(self, module):
103
+ """Initialize the weights"""
104
+ if isinstance(module, JinaCLIPModel):
105
+ if isinstance(module.text_projection, nn.Linear):
106
+ nn.init.normal_(
107
+ module.text_projection.weight,
108
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
109
+ )
110
+ if isinstance(module.text_projection, nn.Linear):
111
+ nn.init.normal_(
112
+ module.visual_projection.weight,
113
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
114
+ )
115
+ if isinstance(module, nn.LayerNorm):
116
+ module.bias.data.zero_()
117
+ module.weight.data.fill_(1.0)
118
+ if isinstance(module, nn.Linear) and module.bias is not None:
119
+ module.bias.data.zero_()
120
+
121
+
122
+ class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
123
+ config_class = JinaCLIPTextConfig
124
+
125
+ def __init__(self, config: JinaCLIPTextConfig):
126
+ super().__init__(config)
127
+ self.text_model = _build_text_tower(config)
128
+ self.post_init()
129
+
130
+ def forward(
131
+ self,
132
+ input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
133
+ return_dict: Optional[bool] = None,
134
+ *_,
135
+ **__,
136
+ ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
137
+ return_dict = (
138
+ return_dict if return_dict is not None else self.config.use_return_dict
139
+ )
140
+ x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
141
+ feats = self.text_model(x=x)
142
+ out = CLIPTextModelOutput(text_embeds=feats)
143
+ return out if return_dict else out.to_tuple()
144
+
145
+
146
+ class JinaCLIPVisionModel(JinaCLIPPreTrainedModel):
147
+ config_class = JinaCLIPVisionConfig
148
+ main_input_name = 'pixel_values'
149
+
150
+ def __init__(self, config: JinaCLIPVisionConfig):
151
+ super().__init__(config)
152
+ self.vision_model = _build_vision_tower(config)
153
+ self.post_init()
154
+
155
+ def forward(
156
+ self,
157
+ pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
158
+ return_dict: Optional[bool] = None,
159
+ *_,
160
+ **__,
161
+ ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
162
+ return_dict = (
163
+ return_dict if return_dict is not None else self.config.use_return_dict
164
+ )
165
+ x = (
166
+ pixel_values.pixel_values
167
+ if isinstance(pixel_values, BatchFeature)
168
+ else pixel_values
169
+ )
170
+ feats = self.vision_model(x=x)
171
+ out = CLIPVisionModelOutput(image_embeds=feats)
172
+ return out if return_dict else out.to_tuple()
173
+
174
+
175
+ class JinaCLIPModel(JinaCLIPPreTrainedModel):
176
+ config_class = JinaCLIPConfig
177
+
178
+ def __init__(self, config: JinaCLIPConfig):
179
+ super().__init__(config)
180
+
181
+ if not isinstance(config.text_config, JinaCLIPTextConfig):
182
+ raise ValueError(
183
+ 'Attribute config.text_config is expected to be of type '
184
+ f'JinaCLIPTextConfig but is of type {type(config.text_config)}.'
185
+ )
186
+
187
+ if not isinstance(config.vision_config, JinaCLIPVisionConfig):
188
+ raise ValueError(
189
+ 'Attribute config.vision_config is expected to be of type '
190
+ f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
191
+ )
192
+
193
+ text_config = config.text_config
194
+ vision_config = config.vision_config
195
+
196
+ self.add_projections = config.add_projections
197
+ self.projection_dim = config.projection_dim
198
+ self.text_embed_dim = text_config.embed_dim
199
+ self.vision_embed_dim = vision_config.embed_dim
200
+
201
+ self.text_model = _build_text_tower(text_config)
202
+ self.vision_model = _build_vision_tower(vision_config)
203
+ self.logit_scale = nn.Parameter(
204
+ torch.tensor(self.config.logit_scale_init_value)
205
+ )
206
+
207
+ if self.add_projections:
208
+ self.visual_projection = nn.Linear(
209
+ self.vision_embed_dim, self.projection_dim, bias=False
210
+ )
211
+ self.text_projection = nn.Linear(
212
+ self.text_embed_dim, self.projection_dim, bias=False
213
+ )
214
+ else:
215
+ self.visual_projection = nn.Identity()
216
+ self.text_projection = nn.Identity()
217
+
218
+ self.post_init()
219
+
220
+ def get_text_features(
221
+ self,
222
+ input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
223
+ *_,
224
+ **__,
225
+ ) -> torch.FloatTensor:
226
+ x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
227
+ return self.text_projection(self.text_model(x=x))
228
+
229
+ def get_image_features(
230
+ self,
231
+ pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
232
+ *_,
233
+ **__,
234
+ ) -> torch.FloatTensor:
235
+ x = (
236
+ pixel_values.pixel_values
237
+ if isinstance(pixel_values, BatchFeature)
238
+ else pixel_values
239
+ )
240
+ return self.visual_projection(self.vision_model(x=x))
241
+
242
+ def encode_text(
243
+ self,
244
+ input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
245
+ return_dict: Optional[bool] = None,
246
+ *_,
247
+ **__,
248
+ ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
249
+ return_dict = (
250
+ return_dict if return_dict is not None else self.config.use_return_dict
251
+ )
252
+ feats = self.get_text_features(input_ids=input_ids)
253
+ out = CLIPTextModelOutput(text_embeds=feats)
254
+ return out if return_dict else out.to_tuple()
255
+
256
+ def encode_image(
257
+ self,
258
+ pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
259
+ return_dict: Optional[bool] = None,
260
+ *_,
261
+ **__,
262
+ ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
263
+ return_dict = (
264
+ return_dict if return_dict is not None else self.config.use_return_dict
265
+ )
266
+ feats = self.get_image_features(pixel_values=pixel_values)
267
+ out = CLIPVisionModelOutput(image_embeds=feats)
268
+ return out if return_dict else out.to_tuple()
269
+
270
+ def forward(
271
+ self,
272
+ input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
273
+ pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
274
+ return_dict: Optional[bool] = None,
275
+ return_loss: Optional[bool] = None,
276
+ *_,
277
+ **__,
278
+ ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPOutput]:
279
+ return_dict = (
280
+ return_dict if return_dict is not None else self.config.use_return_dict
281
+ )
282
+ image_embeds = self.get_image_features(pixel_values=pixel_values)
283
+ text_embeds = self.get_text_features(input_ids=input_ids)
284
+
285
+ # normalized features
286
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
287
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
288
+
289
+ # cosine similarity as logits
290
+ logit_scale = self.logit_scale.exp()
291
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
292
+ logits_per_image = logits_per_text.t()
293
+
294
+ loss = None
295
+ if return_loss:
296
+ loss = clip_loss(logits_per_text)
297
+
298
+ if not return_dict:
299
+ output = (
300
+ logits_per_image,
301
+ logits_per_text,
302
+ text_embeds,
303
+ image_embeds,
304
+ None,
305
+ None,
306
+ )
307
+ return ((loss,) + output) if loss is not None else output
308
+
309
+ return CLIPOutput(
310
+ loss=loss,
311
+ logits_per_image=logits_per_image,
312
+ logits_per_text=logits_per_text,
313
+ text_embeds=text_embeds,
314
+ image_embeds=image_embeds,
315
+ text_model_output=None,
316
+ vision_model_output=None,
317
+ )
processing_clip.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Code mainly copied from:
4
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py
5
+ # and adjusted for Jina CLIP
6
+
7
+ from typing import Tuple, Union
8
+
9
+ import torch
10
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
11
+ from transformers.image_utils import ImageInput, make_list_of_images
12
+ from transformers.models.clip import CLIPProcessor
13
+
14
+ from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform
15
+
16
+ """ Jina CLIP processor implementation """
17
+
18
+
19
+ class JinaCLIPProcessor(CLIPProcessor):
20
+ image_processor_class = 'JinaCLIPImageProcessor'
21
+ tokenizer_class = 'CLIPTokenizer'
22
+
23
+
24
+ """ Jina CLIP image processor implementation """
25
+
26
+
27
+ class JinaCLIPImageProcessor(BaseImageProcessor):
28
+ model_input_names = ['pixel_values']
29
+
30
+ def __init__(
31
+ self,
32
+ size: Union[int, Tuple[int, int]] = 224,
33
+ mean: Union[float, Tuple[float]] = OPENAI_DATASET_MEAN,
34
+ std: Union[float, Tuple[float]] = OPENAI_DATASET_STD,
35
+ resize_mode: str = 'shortest',
36
+ interpolation: str = 'bicubic',
37
+ fill_color: int = 0,
38
+ **kwargs,
39
+ ) -> None:
40
+ super().__init__(**kwargs)
41
+ self.size = size
42
+ self.mean = mean
43
+ self.std = std
44
+ self.resize_mode = resize_mode
45
+ self.interpolation = interpolation
46
+ self.fill_color = fill_color
47
+ self.transform = image_transform(
48
+ image_size=size,
49
+ is_train=False,
50
+ mean=mean,
51
+ std=std,
52
+ resize_mode=resize_mode,
53
+ interpolation=interpolation,
54
+ fill_color=fill_color,
55
+ aug_cfg=None,
56
+ )
57
+
58
+ def to_dict(self):
59
+ output = super().to_dict()
60
+ output.pop('transform')
61
+ return output
62
+
63
+ def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
64
+ images = make_list_of_images(images)
65
+ out = torch.stack([self.transform(image) for image in images], dim=0)
66
+ return BatchFeature(data={'pixel_values': out})
rope_embeddings.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from EVA CLIP
3
+ # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
+ # --------------------------------------------------------
5
+
6
+ import logging
7
+ from math import pi
8
+
9
+ import torch
10
+ from einops import rearrange, repeat
11
+ from torch import nn
12
+
13
+
14
+ def broadcast(tensors, dim=-1):
15
+ num_tensors = len(tensors)
16
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
17
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
18
+ shape_len = list(shape_lens)[0]
19
+ dim = (dim + shape_len) if dim < 0 else dim
20
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
21
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
22
+ assert all(
23
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
24
+ ), 'invalid dimensions for broadcastable concatentation'
25
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
26
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
27
+ expanded_dims.insert(dim, (dim, dims[dim]))
28
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
29
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
30
+ return torch.cat(tensors, dim=dim)
31
+
32
+
33
+ def rotate_half(x):
34
+ x = rearrange(x, '... (d r) -> ... d r', r=2)
35
+ x1, x2 = x.unbind(dim=-1)
36
+ x = torch.stack((-x2, x1), dim=-1)
37
+ return rearrange(x, '... d r -> ... (d r)')
38
+
39
+
40
+ class VisionRotaryEmbedding(nn.Module):
41
+ def __init__(
42
+ self,
43
+ dim,
44
+ pt_seq_len,
45
+ ft_seq_len=None,
46
+ custom_freqs=None,
47
+ freqs_for='lang',
48
+ theta=10000,
49
+ max_freq=10,
50
+ num_freqs=1,
51
+ ):
52
+ super().__init__()
53
+ if custom_freqs:
54
+ freqs = custom_freqs
55
+ elif freqs_for == 'lang':
56
+ freqs = 1.0 / (
57
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
58
+ )
59
+ elif freqs_for == 'pixel':
60
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
61
+ elif freqs_for == 'constant':
62
+ freqs = torch.ones(num_freqs).float()
63
+ else:
64
+ raise ValueError(f'unknown modality {freqs_for}')
65
+
66
+ if ft_seq_len is None:
67
+ ft_seq_len = pt_seq_len
68
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
69
+
70
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
71
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)
72
+
73
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
74
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)
75
+
76
+ freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
77
+
78
+ self.register_buffer('freqs_cos', freqs.cos())
79
+ self.register_buffer('freqs_sin', freqs.sin())
80
+
81
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
82
+
83
+ def forward(self, t, start_index=0):
84
+ rot_dim = self.freqs_cos.shape[-1]
85
+ end_index = start_index + rot_dim
86
+ assert rot_dim <= t.shape[-1], (
87
+ f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in '
88
+ f'all the positions {rot_dim}'
89
+ )
90
+ t_left, t, t_right = (
91
+ t[..., :start_index],
92
+ t[..., start_index:end_index],
93
+ t[..., end_index:],
94
+ )
95
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
96
+
97
+ return torch.cat((t_left, t, t_right), dim=-1)
98
+
99
+
100
+ class VisionRotaryEmbeddingFast(nn.Module):
101
+ def __init__(
102
+ self,
103
+ dim,
104
+ pt_seq_len,
105
+ ft_seq_len=None,
106
+ custom_freqs=None,
107
+ freqs_for='lang',
108
+ theta=10000,
109
+ max_freq=10,
110
+ num_freqs=1,
111
+ patch_dropout=0.0,
112
+ ):
113
+ super().__init__()
114
+ if custom_freqs:
115
+ freqs = custom_freqs
116
+ elif freqs_for == 'lang':
117
+ freqs = 1.0 / (
118
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
119
+ )
120
+ elif freqs_for == 'pixel':
121
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
122
+ elif freqs_for == 'constant':
123
+ freqs = torch.ones(num_freqs).float()
124
+ else:
125
+ raise ValueError(f'unknown modality {freqs_for}')
126
+
127
+ if ft_seq_len is None:
128
+ ft_seq_len = pt_seq_len
129
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
130
+
131
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
132
+ freqs = repeat(freqs, '... n -> ... (n r)', r=2)
133
+ freqs = broadcast((freqs[:, None, :], freqs[None, :, :]), dim=-1)
134
+
135
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
136
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
137
+
138
+ self.patch_dropout = patch_dropout
139
+
140
+ self.register_buffer('freqs_cos', freqs_cos)
141
+ self.register_buffer('freqs_sin', freqs_sin)
142
+
143
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
144
+
145
+ def forward(self, t, patch_indices_keep=None):
146
+ if patch_indices_keep is not None:
147
+ batch = t.size()[0]
148
+ batch_indices = torch.arange(batch)
149
+ batch_indices = batch_indices[..., None]
150
+
151
+ freqs_cos = repeat(
152
+ self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
153
+ )
154
+ freqs_sin = repeat(
155
+ self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
156
+ )
157
+
158
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
159
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
160
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
161
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
162
+
163
+ return t * freqs_cos + rotate_half(t) * freqs_sin
164
+
165
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
transform.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+ import random
3
+ import warnings
4
+ from dataclasses import asdict, dataclass
5
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6
+
7
+ import torch
8
+ import torchvision.transforms.functional as F
9
+ from torchvision.transforms import (
10
+ CenterCrop,
11
+ ColorJitter,
12
+ Compose,
13
+ Grayscale,
14
+ InterpolationMode,
15
+ Normalize,
16
+ RandomResizedCrop,
17
+ Resize,
18
+ ToTensor,
19
+ )
20
+ from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
21
+
22
+ OPENAI_DATASET_MEAN = tuple(OPENAI_CLIP_MEAN)
23
+ OPENAI_DATASET_STD = tuple(OPENAI_CLIP_STD)
24
+
25
+
26
+ @dataclass
27
+ class PreprocessCfg:
28
+ size: Union[int, Tuple[int, int]] = 224
29
+ mode: str = 'RGB'
30
+ mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
31
+ std: Tuple[float, ...] = OPENAI_DATASET_STD
32
+ interpolation: str = 'bicubic'
33
+ resize_mode: str = 'shortest'
34
+ fill_color: int = 0
35
+
36
+ def __post_init__(self):
37
+ assert self.mode in ('RGB',)
38
+
39
+ @property
40
+ def num_channels(self):
41
+ return 3
42
+
43
+ @property
44
+ def input_size(self):
45
+ return (self.num_channels,) + (self.size, self.size)
46
+
47
+
48
+ _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
49
+
50
+
51
+ def merge_preprocess_dict(
52
+ base: Union[PreprocessCfg, Dict],
53
+ overlay: Dict,
54
+ ):
55
+ """Merge overlay key-value pairs on top of base preprocess cfg or dict.
56
+ Input dicts are filtered based on PreprocessCfg fields.
57
+ """
58
+ if isinstance(base, PreprocessCfg):
59
+ base_clean = asdict(base)
60
+ else:
61
+ base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
62
+ if overlay:
63
+ overlay_clean = {
64
+ k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None
65
+ }
66
+ base_clean.update(overlay_clean)
67
+ return base_clean
68
+
69
+
70
+ def merge_preprocess_kwargs(base: Union[PreprocessCfg, Dict], **kwargs):
71
+ return merge_preprocess_dict(base, kwargs)
72
+
73
+
74
+ @dataclass
75
+ class AugmentationCfg:
76
+ scale: Tuple[float, float] = (0.9, 1.0)
77
+ ratio: Optional[Tuple[float, float]] = None
78
+ color_jitter: Optional[
79
+ Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]
80
+ ] = None
81
+ re_prob: Optional[float] = None
82
+ re_count: Optional[int] = None
83
+ use_timm: bool = False
84
+
85
+ # params for simclr_jitter_gray
86
+ color_jitter_prob: float = None
87
+ gray_scale_prob: float = None
88
+
89
+
90
+ def _setup_size(size, error_msg):
91
+ if isinstance(size, numbers.Number):
92
+ return int(size), int(size)
93
+
94
+ if isinstance(size, Sequence) and len(size) == 1:
95
+ return size[0], size[0]
96
+
97
+ if len(size) != 2:
98
+ raise ValueError(error_msg)
99
+
100
+ return size
101
+
102
+
103
+ class ResizeKeepRatio:
104
+ """Resize and Keep Ratio
105
+
106
+ Copy & paste from `timm`
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ size,
112
+ longest=0.0,
113
+ interpolation=InterpolationMode.BICUBIC,
114
+ random_scale_prob=0.0,
115
+ random_scale_range=(0.85, 1.05),
116
+ random_aspect_prob=0.0,
117
+ random_aspect_range=(0.9, 1.11),
118
+ ):
119
+ if isinstance(size, (list, tuple)):
120
+ self.size = tuple(size)
121
+ else:
122
+ self.size = (size, size)
123
+ self.interpolation = interpolation
124
+ self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
125
+ self.random_scale_prob = random_scale_prob
126
+ self.random_scale_range = random_scale_range
127
+ self.random_aspect_prob = random_aspect_prob
128
+ self.random_aspect_range = random_aspect_range
129
+
130
+ @staticmethod
131
+ def get_params(
132
+ img,
133
+ target_size,
134
+ longest,
135
+ random_scale_prob=0.0,
136
+ random_scale_range=(0.85, 1.05),
137
+ random_aspect_prob=0.0,
138
+ random_aspect_range=(0.9, 1.11),
139
+ ):
140
+ """Get parameters"""
141
+ source_size = img.size[::-1] # h, w
142
+ h, w = source_size
143
+ target_h, target_w = target_size
144
+ ratio_h = h / target_h
145
+ ratio_w = w / target_w
146
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (
147
+ 1.0 - longest
148
+ )
149
+ if random_scale_prob > 0 and random.random() < random_scale_prob:
150
+ ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
151
+ ratio_factor = (ratio_factor, ratio_factor)
152
+ else:
153
+ ratio_factor = (1.0, 1.0)
154
+ if random_aspect_prob > 0 and random.random() < random_aspect_prob:
155
+ aspect_factor = random.uniform(
156
+ random_aspect_range[0], random_aspect_range[1]
157
+ )
158
+ ratio_factor = (
159
+ ratio_factor[0] / aspect_factor,
160
+ ratio_factor[1] * aspect_factor,
161
+ )
162
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
163
+ return size
164
+
165
+ def __call__(self, img):
166
+ """
167
+ Args:
168
+ img (PIL Image): Image to be cropped and resized.
169
+
170
+ Returns:
171
+ PIL Image: Resized, padded to at least target size, possibly
172
+ cropped to exactly target size
173
+ """
174
+ size = self.get_params(
175
+ img,
176
+ self.size,
177
+ self.longest,
178
+ self.random_scale_prob,
179
+ self.random_scale_range,
180
+ self.random_aspect_prob,
181
+ self.random_aspect_range,
182
+ )
183
+ img = F.resize(img, size, self.interpolation)
184
+ return img
185
+
186
+ def __repr__(self):
187
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
188
+ format_string += f', interpolation={self.interpolation})'
189
+ format_string += f', longest={self.longest:.3f})'
190
+ return format_string
191
+
192
+
193
+ def center_crop_or_pad(
194
+ img: torch.Tensor, output_size: List[int], fill=0
195
+ ) -> torch.Tensor:
196
+ """Center crops and/or pads the given image.
197
+ If the image is torch Tensor, it is expected
198
+ to have [..., H, W] shape, where ... means an arbitrary number of leading
199
+ dimensions. If image size is smaller than output size along any edge, image is
200
+ padded with 0 and then center cropped.
201
+
202
+ Args:
203
+ img (PIL Image or Tensor): Image to be cropped.
204
+ output_size (sequence or int): (height, width) of the crop box. If int or
205
+ sequence with single int, it is used for both directions.
206
+ fill (int, Tuple[int]): Padding color
207
+
208
+ Returns:
209
+ PIL Image or Tensor: Cropped image.
210
+ """
211
+ if isinstance(output_size, numbers.Number):
212
+ output_size = (int(output_size), int(output_size))
213
+ elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
214
+ output_size = (output_size[0], output_size[0])
215
+
216
+ _, image_height, image_width = F.get_dimensions(img)
217
+ crop_height, crop_width = output_size
218
+
219
+ if crop_width > image_width or crop_height > image_height:
220
+ padding_ltrb = [
221
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
222
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
223
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
224
+ (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
225
+ ]
226
+ img = F.pad(img, padding_ltrb, fill=fill)
227
+ _, image_height, image_width = F.get_dimensions(img)
228
+ if crop_width == image_width and crop_height == image_height:
229
+ return img
230
+
231
+ crop_top = int(round((image_height - crop_height) / 2.0))
232
+ crop_left = int(round((image_width - crop_width) / 2.0))
233
+ return F.crop(img, crop_top, crop_left, crop_height, crop_width)
234
+
235
+
236
+ class CenterCropOrPad(torch.nn.Module):
237
+ """Crops the given image at the center.
238
+ If the image is torch Tensor, it is expected
239
+ to have [..., H, W] shape, where ... means an arbitrary number of leading
240
+ dimensions. If image size is smaller than output size along any edge, image is
241
+ padded with 0 and then center cropped.
242
+
243
+ Args:
244
+ size (sequence or int): Desired output size of the crop. If size is an
245
+ int instead of sequence like (h, w), a square crop (size, size) is
246
+ made. If provided a sequence of length 1, it will be interpreted as
247
+ (size[0], size[0]).
248
+ """
249
+
250
+ def __init__(self, size, fill=0):
251
+ super().__init__()
252
+ self.size = _setup_size(
253
+ size, error_msg='Please provide only two dimensions (h, w) for size.'
254
+ )
255
+ self.fill = fill
256
+
257
+ def forward(self, img):
258
+ """
259
+ Args:
260
+ img (PIL Image or Tensor): Image to be cropped.
261
+
262
+ Returns:
263
+ PIL Image or Tensor: Cropped image.
264
+ """
265
+ return center_crop_or_pad(img, self.size, fill=self.fill)
266
+
267
+ def __repr__(self) -> str:
268
+ return f'{self.__class__.__name__}(size={self.size})'
269
+
270
+
271
+ def _convert_to_rgb(image):
272
+ return image.convert('RGB')
273
+
274
+
275
+ class _ColorJitter(object):
276
+ """
277
+ Apply Color Jitter to the PIL image with a specified probability.
278
+ """
279
+
280
+ def __init__(self, brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0, p=0.8):
281
+ assert 0.0 <= p <= 1.0
282
+ self.p = p
283
+ self.transf = ColorJitter(
284
+ brightness=brightness, contrast=contrast, saturation=saturation, hue=hue
285
+ )
286
+
287
+ def __call__(self, img):
288
+ if random.random() < self.p:
289
+ return self.transf(img)
290
+ else:
291
+ return img
292
+
293
+
294
+ class _GrayScale(object):
295
+ """
296
+ Apply Gray Scale to the PIL image with a specified probability.
297
+ """
298
+
299
+ def __init__(self, p=0.2):
300
+ assert 0.0 <= p <= 1.0
301
+ self.p = p
302
+ self.transf = Grayscale(num_output_channels=3)
303
+
304
+ def __call__(self, img):
305
+ if random.random() < self.p:
306
+ return self.transf(img)
307
+ else:
308
+ return img
309
+
310
+
311
+ def image_transform(
312
+ image_size: Union[int, Tuple[int, int]],
313
+ is_train: bool,
314
+ mean: Optional[Tuple[float, ...]] = None,
315
+ std: Optional[Tuple[float, ...]] = None,
316
+ resize_mode: Optional[str] = None,
317
+ interpolation: Optional[str] = None,
318
+ fill_color: int = 0,
319
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
320
+ ):
321
+ mean = mean or OPENAI_DATASET_MEAN
322
+ if not isinstance(mean, (list, tuple)):
323
+ mean = (mean,) * 3
324
+
325
+ std = std or OPENAI_DATASET_STD
326
+ if not isinstance(std, (list, tuple)):
327
+ std = (std,) * 3
328
+
329
+ interpolation = interpolation or 'bicubic'
330
+ assert interpolation in ['bicubic', 'bilinear', 'random']
331
+ # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for
332
+ # inference if set
333
+ interpolation_mode = (
334
+ InterpolationMode.BILINEAR
335
+ if interpolation == 'bilinear'
336
+ else InterpolationMode.BICUBIC
337
+ )
338
+
339
+ resize_mode = resize_mode or 'shortest'
340
+ assert resize_mode in ('shortest', 'longest', 'squash')
341
+
342
+ if isinstance(aug_cfg, dict):
343
+ aug_cfg = AugmentationCfg(**aug_cfg)
344
+ else:
345
+ aug_cfg = aug_cfg or AugmentationCfg()
346
+
347
+ normalize = Normalize(mean=mean, std=std)
348
+
349
+ if is_train:
350
+ aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
351
+ use_timm = aug_cfg_dict.pop('use_timm', False)
352
+ if use_timm:
353
+ from timm.data import create_transform # timm can still be optional
354
+
355
+ if isinstance(image_size, (tuple, list)):
356
+ assert len(image_size) >= 2
357
+ input_size = (3,) + image_size[-2:]
358
+ else:
359
+ input_size = (3, image_size, image_size)
360
+
361
+ aug_cfg_dict.setdefault('color_jitter', None) # disable by default
362
+ # drop extra non-timm items
363
+ aug_cfg_dict.pop('color_jitter_prob', None)
364
+ aug_cfg_dict.pop('gray_scale_prob', None)
365
+
366
+ train_transform = create_transform(
367
+ input_size=input_size,
368
+ is_training=True,
369
+ hflip=0.0,
370
+ mean=mean,
371
+ std=std,
372
+ re_mode='pixel',
373
+ interpolation=interpolation,
374
+ **aug_cfg_dict,
375
+ )
376
+ else:
377
+ train_transform = [
378
+ RandomResizedCrop(
379
+ image_size,
380
+ scale=aug_cfg_dict.pop('scale'),
381
+ interpolation=InterpolationMode.BICUBIC,
382
+ ),
383
+ _convert_to_rgb,
384
+ ]
385
+ if aug_cfg.color_jitter_prob:
386
+ assert (
387
+ aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
388
+ )
389
+ train_transform.extend(
390
+ [_ColorJitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)]
391
+ )
392
+ if aug_cfg.gray_scale_prob:
393
+ train_transform.extend([_GrayScale(aug_cfg.gray_scale_prob)])
394
+ train_transform.extend(
395
+ [
396
+ ToTensor(),
397
+ normalize,
398
+ ]
399
+ )
400
+ train_transform = Compose(train_transform)
401
+ if aug_cfg_dict:
402
+ warnings.warn(
403
+ f'Unused augmentation cfg items, specify `use_timm` to use '
404
+ f'({list(aug_cfg_dict.keys())}).'
405
+ )
406
+ return train_transform
407
+ else:
408
+ if resize_mode == 'longest':
409
+ transforms = [
410
+ ResizeKeepRatio(
411
+ image_size, interpolation=interpolation_mode, longest=1
412
+ ),
413
+ CenterCropOrPad(image_size, fill=fill_color),
414
+ ]
415
+ elif resize_mode == 'squash':
416
+ if isinstance(image_size, int):
417
+ image_size = (image_size, image_size)
418
+ transforms = [
419
+ Resize(image_size, interpolation=interpolation_mode),
420
+ ]
421
+ else:
422
+ assert resize_mode == 'shortest'
423
+ if not isinstance(image_size, (tuple, list)):
424
+ image_size = (image_size, image_size)
425
+ if image_size[0] == image_size[1]:
426
+ # simple case, use torchvision built-in Resize w/ shortest edge mode
427
+ # (scalar size arg)
428
+ transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
429
+ else:
430
+ # resize shortest edge to matching target dim for non-square target
431
+ transforms = [ResizeKeepRatio(image_size)]
432
+ transforms += [CenterCrop(image_size)]
433
+
434
+ transforms.extend(
435
+ [
436
+ _convert_to_rgb,
437
+ ToTensor(),
438
+ normalize,
439
+ ]
440
+ )
441
+ return Compose(transforms)
442
+
443
+
444
+ def image_transform_v2(
445
+ cfg: PreprocessCfg,
446
+ is_train: bool,
447
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
448
+ ):
449
+ return image_transform(
450
+ image_size=cfg.size,
451
+ is_train=is_train,
452
+ mean=cfg.mean,
453
+ std=cfg.std,
454
+ interpolation=cfg.interpolation,
455
+ resize_mode=cfg.resize_mode,
456
+ fill_color=cfg.fill_color,
457
+ aug_cfg=aug_cfg,
458
+ )