BAAI
/

ryanzhangfan commited on
Commit
827d192
·
1 Parent(s): e50ad9f

Upload 26 files

Browse files
added_tokens.json ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</delimiter_of_multi_objects/>": 32013,
3
+ "</object>": 32012,
4
+ "</phrase>": 32010,
5
+ "<REC>": 32014,
6
+ "<grounding>": 32008,
7
+ "<image>": 32003,
8
+ "<object>": 32011,
9
+ "<patch_index_0000>": 32015,
10
+ "<patch_index_0001>": 32016,
11
+ "<patch_index_0002>": 32017,
12
+ "<patch_index_0003>": 32018,
13
+ "<patch_index_0004>": 32019,
14
+ "<patch_index_0005>": 32020,
15
+ "<patch_index_0006>": 32021,
16
+ "<patch_index_0007>": 32022,
17
+ "<patch_index_0008>": 32023,
18
+ "<patch_index_0009>": 32024,
19
+ "<patch_index_0010>": 32025,
20
+ "<patch_index_0011>": 32026,
21
+ "<patch_index_0012>": 32027,
22
+ "<patch_index_0013>": 32028,
23
+ "<patch_index_0014>": 32029,
24
+ "<patch_index_0015>": 32030,
25
+ "<patch_index_0016>": 32031,
26
+ "<patch_index_0017>": 32032,
27
+ "<patch_index_0018>": 32033,
28
+ "<patch_index_0019>": 32034,
29
+ "<patch_index_0020>": 32035,
30
+ "<patch_index_0021>": 32036,
31
+ "<patch_index_0022>": 32037,
32
+ "<patch_index_0023>": 32038,
33
+ "<patch_index_0024>": 32039,
34
+ "<patch_index_0025>": 32040,
35
+ "<patch_index_0026>": 32041,
36
+ "<patch_index_0027>": 32042,
37
+ "<patch_index_0028>": 32043,
38
+ "<patch_index_0029>": 32044,
39
+ "<patch_index_0030>": 32045,
40
+ "<patch_index_0031>": 32046,
41
+ "<patch_index_0032>": 32047,
42
+ "<patch_index_0033>": 32048,
43
+ "<patch_index_0034>": 32049,
44
+ "<patch_index_0035>": 32050,
45
+ "<patch_index_0036>": 32051,
46
+ "<patch_index_0037>": 32052,
47
+ "<patch_index_0038>": 32053,
48
+ "<patch_index_0039>": 32054,
49
+ "<patch_index_0040>": 32055,
50
+ "<patch_index_0041>": 32056,
51
+ "<patch_index_0042>": 32057,
52
+ "<patch_index_0043>": 32058,
53
+ "<patch_index_0044>": 32059,
54
+ "<patch_index_0045>": 32060,
55
+ "<patch_index_0046>": 32061,
56
+ "<patch_index_0047>": 32062,
57
+ "<patch_index_0048>": 32063,
58
+ "<patch_index_0049>": 32064,
59
+ "<patch_index_0050>": 32065,
60
+ "<patch_index_0051>": 32066,
61
+ "<patch_index_0052>": 32067,
62
+ "<patch_index_0053>": 32068,
63
+ "<patch_index_0054>": 32069,
64
+ "<patch_index_0055>": 32070,
65
+ "<patch_index_0056>": 32071,
66
+ "<patch_index_0057>": 32072,
67
+ "<patch_index_0058>": 32073,
68
+ "<patch_index_0059>": 32074,
69
+ "<patch_index_0060>": 32075,
70
+ "<patch_index_0061>": 32076,
71
+ "<patch_index_0062>": 32077,
72
+ "<patch_index_0063>": 32078,
73
+ "<patch_index_0064>": 32079,
74
+ "<patch_index_0065>": 32080,
75
+ "<patch_index_0066>": 32081,
76
+ "<patch_index_0067>": 32082,
77
+ "<patch_index_0068>": 32083,
78
+ "<patch_index_0069>": 32084,
79
+ "<patch_index_0070>": 32085,
80
+ "<patch_index_0071>": 32086,
81
+ "<patch_index_0072>": 32087,
82
+ "<patch_index_0073>": 32088,
83
+ "<patch_index_0074>": 32089,
84
+ "<patch_index_0075>": 32090,
85
+ "<patch_index_0076>": 32091,
86
+ "<patch_index_0077>": 32092,
87
+ "<patch_index_0078>": 32093,
88
+ "<patch_index_0079>": 32094,
89
+ "<patch_index_0080>": 32095,
90
+ "<patch_index_0081>": 32096,
91
+ "<patch_index_0082>": 32097,
92
+ "<patch_index_0083>": 32098,
93
+ "<patch_index_0084>": 32099,
94
+ "<patch_index_0085>": 32100,
95
+ "<patch_index_0086>": 32101,
96
+ "<patch_index_0087>": 32102,
97
+ "<patch_index_0088>": 32103,
98
+ "<patch_index_0089>": 32104,
99
+ "<patch_index_0090>": 32105,
100
+ "<patch_index_0091>": 32106,
101
+ "<patch_index_0092>": 32107,
102
+ "<patch_index_0093>": 32108,
103
+ "<patch_index_0094>": 32109,
104
+ "<patch_index_0095>": 32110,
105
+ "<patch_index_0096>": 32111,
106
+ "<patch_index_0097>": 32112,
107
+ "<patch_index_0098>": 32113,
108
+ "<patch_index_0099>": 32114,
109
+ "<patch_index_0100>": 32115,
110
+ "<patch_index_0101>": 32116,
111
+ "<patch_index_0102>": 32117,
112
+ "<patch_index_0103>": 32118,
113
+ "<patch_index_0104>": 32119,
114
+ "<patch_index_0105>": 32120,
115
+ "<patch_index_0106>": 32121,
116
+ "<patch_index_0107>": 32122,
117
+ "<patch_index_0108>": 32123,
118
+ "<patch_index_0109>": 32124,
119
+ "<patch_index_0110>": 32125,
120
+ "<patch_index_0111>": 32126,
121
+ "<patch_index_0112>": 32127,
122
+ "<patch_index_0113>": 32128,
123
+ "<patch_index_0114>": 32129,
124
+ "<patch_index_0115>": 32130,
125
+ "<patch_index_0116>": 32131,
126
+ "<patch_index_0117>": 32132,
127
+ "<patch_index_0118>": 32133,
128
+ "<patch_index_0119>": 32134,
129
+ "<patch_index_0120>": 32135,
130
+ "<patch_index_0121>": 32136,
131
+ "<patch_index_0122>": 32137,
132
+ "<patch_index_0123>": 32138,
133
+ "<patch_index_0124>": 32139,
134
+ "<patch_index_0125>": 32140,
135
+ "<patch_index_0126>": 32141,
136
+ "<patch_index_0127>": 32142,
137
+ "<patch_index_0128>": 32143,
138
+ "<patch_index_0129>": 32144,
139
+ "<patch_index_0130>": 32145,
140
+ "<patch_index_0131>": 32146,
141
+ "<patch_index_0132>": 32147,
142
+ "<patch_index_0133>": 32148,
143
+ "<patch_index_0134>": 32149,
144
+ "<patch_index_0135>": 32150,
145
+ "<patch_index_0136>": 32151,
146
+ "<patch_index_0137>": 32152,
147
+ "<patch_index_0138>": 32153,
148
+ "<patch_index_0139>": 32154,
149
+ "<patch_index_0140>": 32155,
150
+ "<patch_index_0141>": 32156,
151
+ "<patch_index_0142>": 32157,
152
+ "<patch_index_0143>": 32158,
153
+ "<patch_index_0144>": 32159,
154
+ "<patch_index_0145>": 32160,
155
+ "<patch_index_0146>": 32161,
156
+ "<patch_index_0147>": 32162,
157
+ "<patch_index_0148>": 32163,
158
+ "<patch_index_0149>": 32164,
159
+ "<patch_index_0150>": 32165,
160
+ "<patch_index_0151>": 32166,
161
+ "<patch_index_0152>": 32167,
162
+ "<patch_index_0153>": 32168,
163
+ "<patch_index_0154>": 32169,
164
+ "<patch_index_0155>": 32170,
165
+ "<patch_index_0156>": 32171,
166
+ "<patch_index_0157>": 32172,
167
+ "<patch_index_0158>": 32173,
168
+ "<patch_index_0159>": 32174,
169
+ "<patch_index_0160>": 32175,
170
+ "<patch_index_0161>": 32176,
171
+ "<patch_index_0162>": 32177,
172
+ "<patch_index_0163>": 32178,
173
+ "<patch_index_0164>": 32179,
174
+ "<patch_index_0165>": 32180,
175
+ "<patch_index_0166>": 32181,
176
+ "<patch_index_0167>": 32182,
177
+ "<patch_index_0168>": 32183,
178
+ "<patch_index_0169>": 32184,
179
+ "<patch_index_0170>": 32185,
180
+ "<patch_index_0171>": 32186,
181
+ "<patch_index_0172>": 32187,
182
+ "<patch_index_0173>": 32188,
183
+ "<patch_index_0174>": 32189,
184
+ "<patch_index_0175>": 32190,
185
+ "<patch_index_0176>": 32191,
186
+ "<patch_index_0177>": 32192,
187
+ "<patch_index_0178>": 32193,
188
+ "<patch_index_0179>": 32194,
189
+ "<patch_index_0180>": 32195,
190
+ "<patch_index_0181>": 32196,
191
+ "<patch_index_0182>": 32197,
192
+ "<patch_index_0183>": 32198,
193
+ "<patch_index_0184>": 32199,
194
+ "<patch_index_0185>": 32200,
195
+ "<patch_index_0186>": 32201,
196
+ "<patch_index_0187>": 32202,
197
+ "<patch_index_0188>": 32203,
198
+ "<patch_index_0189>": 32204,
199
+ "<patch_index_0190>": 32205,
200
+ "<patch_index_0191>": 32206,
201
+ "<patch_index_0192>": 32207,
202
+ "<patch_index_0193>": 32208,
203
+ "<patch_index_0194>": 32209,
204
+ "<patch_index_0195>": 32210,
205
+ "<patch_index_0196>": 32211,
206
+ "<patch_index_0197>": 32212,
207
+ "<patch_index_0198>": 32213,
208
+ "<patch_index_0199>": 32214,
209
+ "<patch_index_0200>": 32215,
210
+ "<patch_index_0201>": 32216,
211
+ "<patch_index_0202>": 32217,
212
+ "<patch_index_0203>": 32218,
213
+ "<patch_index_0204>": 32219,
214
+ "<patch_index_0205>": 32220,
215
+ "<patch_index_0206>": 32221,
216
+ "<patch_index_0207>": 32222,
217
+ "<patch_index_0208>": 32223,
218
+ "<patch_index_0209>": 32224,
219
+ "<patch_index_0210>": 32225,
220
+ "<patch_index_0211>": 32226,
221
+ "<patch_index_0212>": 32227,
222
+ "<patch_index_0213>": 32228,
223
+ "<patch_index_0214>": 32229,
224
+ "<patch_index_0215>": 32230,
225
+ "<patch_index_0216>": 32231,
226
+ "<patch_index_0217>": 32232,
227
+ "<patch_index_0218>": 32233,
228
+ "<patch_index_0219>": 32234,
229
+ "<patch_index_0220>": 32235,
230
+ "<patch_index_0221>": 32236,
231
+ "<patch_index_0222>": 32237,
232
+ "<patch_index_0223>": 32238,
233
+ "<patch_index_0224>": 32239,
234
+ "<patch_index_0225>": 32240,
235
+ "<patch_index_0226>": 32241,
236
+ "<patch_index_0227>": 32242,
237
+ "<patch_index_0228>": 32243,
238
+ "<patch_index_0229>": 32244,
239
+ "<patch_index_0230>": 32245,
240
+ "<patch_index_0231>": 32246,
241
+ "<patch_index_0232>": 32247,
242
+ "<patch_index_0233>": 32248,
243
+ "<patch_index_0234>": 32249,
244
+ "<patch_index_0235>": 32250,
245
+ "<patch_index_0236>": 32251,
246
+ "<patch_index_0237>": 32252,
247
+ "<patch_index_0238>": 32253,
248
+ "<patch_index_0239>": 32254,
249
+ "<patch_index_0240>": 32255,
250
+ "<patch_index_0241>": 32256,
251
+ "<patch_index_0242>": 32257,
252
+ "<patch_index_0243>": 32258,
253
+ "<patch_index_0244>": 32259,
254
+ "<patch_index_0245>": 32260,
255
+ "<patch_index_0246>": 32261,
256
+ "<patch_index_0247>": 32262,
257
+ "<patch_index_0248>": 32263,
258
+ "<patch_index_0249>": 32264,
259
+ "<patch_index_0250>": 32265,
260
+ "<patch_index_0251>": 32266,
261
+ "<patch_index_0252>": 32267,
262
+ "<patch_index_0253>": 32268,
263
+ "<patch_index_0254>": 32269,
264
+ "<patch_index_0255>": 32270,
265
+ "<patch_index_0256>": 32271,
266
+ "<phrase>": 32009,
267
+ "[/IMG]": 32002,
268
+ "[/gIMG]": 32005,
269
+ "[ASSISTANT]": 32273,
270
+ "[EOC]": 32006,
271
+ "[IMG]": 32001,
272
+ "[PAD]": 32000,
273
+ "[USER]": 32272,
274
+ "[VIDEO]": 32007,
275
+ "[gIMG]": 32004
276
+ }
config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "emu2",
3
+ "architectures": [
4
+ "EmuForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_emu.EmuConfig",
10
+ "AutoModelForCausalLM": "modeling_emu.EmuForCausalLM"
11
+ },
12
+ "bos_token_id": 1,
13
+ "d_model": 1792,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 6656,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 17920,
19
+ "max_position_embeddings": 2048,
20
+ "model_version": "chat",
21
+ "num_attention_heads": 52,
22
+ "num_hidden_layers": 60,
23
+ "num_key_value_heads": 52,
24
+ "pad_token_id": 32000,
25
+ "pretraining_tp": 1,
26
+ "rms_norm_eps": 1e-06,
27
+ "rope_scaling": null,
28
+ "rope_theta": 10000.0,
29
+ "tie_word_embeddings": false,
30
+ "torch_dtype": "float32",
31
+ "transformers_version": "4.31.0",
32
+ "use_cache": true,
33
+ "vision_config": {
34
+ "drop_path_rate": 0,
35
+ "eva_model_name": "eva-clip-E-14-plus",
36
+ "head_width": 112,
37
+ "image_size": 448,
38
+ "intermediate_size": 15360,
39
+ "layer_norm_eps": 1e-06,
40
+ "layers": 64,
41
+ "mlp_ratio": 8.571428571428571,
42
+ "n_query": 256,
43
+ "patch_size": 14,
44
+ "postnorm": true,
45
+ "qkv_bias": true,
46
+ "v_query": 64,
47
+ "width": 1792,
48
+ "xattn": false
49
+ },
50
+ "vocab_size": 32274
51
+ }
configuration_emu.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class EmuConfig(PretrainedConfig):
6
+ _auto_class = "AutoConfig"
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=32000,
11
+ hidden_size=4096,
12
+ intermediate_size=11008,
13
+ num_hidden_layers=32,
14
+ num_attention_heads=32,
15
+ hidden_act='silu',
16
+ max_position_embeddings=2048,
17
+ initializer_range=0.02,
18
+ rms_norm_eps=1e-06,
19
+ model_version: Literal["base", "chat"] = "base",
20
+ pad_token_id=0,
21
+ bos_token_id=1,
22
+ eos_token_id=2,
23
+ tie_word_embeddings=False,
24
+ use_cache=True,
25
+ pretraining_tp=1,
26
+ rope_theta=10000.0,
27
+ rope_scaling=None,
28
+ attention_bias=False,
29
+ attention_dropout=0.0,
30
+ **kwargs,
31
+ ):
32
+ self.hidden_size = hidden_size
33
+ self.intermediate_size = intermediate_size
34
+ self.num_attention_heads = num_attention_heads
35
+ self.max_position_embeddings = max_position_embeddings
36
+ self.rms_norm_eps = rms_norm_eps
37
+ self.initializer_range = initializer_range
38
+ self.vocab_size = vocab_size
39
+ self.num_hidden_layers = num_hidden_layers
40
+ self.hidden_act = hidden_act
41
+ self.model_version = model_version
42
+ self.use_cache = use_cache
43
+ self.pretraining_tp = pretraining_tp
44
+ self.use_cache = use_cache
45
+ self.rope_theta = rope_theta
46
+ self.rope_scaling = rope_scaling
47
+ self._rope_scaling_validation()
48
+ self.attention_bias = attention_bias
49
+ self.attention_dropout = attention_dropout
50
+ super().__init__(
51
+ pad_token_id=pad_token_id,
52
+ bos_token_id=bos_token_id,
53
+ eos_token_id=eos_token_id,
54
+ tie_word_embeddings=tie_word_embeddings,
55
+ **kwargs,
56
+ )
57
+
58
+ def _rope_scaling_validation(self):
59
+ """
60
+ Validate the `rope_scaling` configuration.
61
+ """
62
+ if self.rope_scaling is None:
63
+ return
64
+
65
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
66
+ raise ValueError(
67
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
68
+ f"got {self.rope_scaling}"
69
+ )
70
+ rope_scaling_type = self.rope_scaling.get("type", None)
71
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
72
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
73
+ raise ValueError(
74
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
75
+ )
76
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
77
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
constants.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ EVA_IMAGE_SIZE = 448
2
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
3
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
4
+
5
+ DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.png', 'png', 'jpeg', 'webp']
6
+ DEFAULT_TEXT_FILE_SUFFIX = ['txt', '0.txt']
7
+
8
+ IGNORE_INDEX = -100
9
+
10
+ # special tokens
11
+ # START
12
+ DEFAULT_PAD_TOKEN = "[PAD]"
13
+ DEFAULT_BOS_TOKEN = '<s>'
14
+ DEFAULT_EOS_TOKEN = '</s>'
15
+ DEFAULT_UNK_TOKEN = "<unk>"
16
+
17
+ DEFAULT_IMG_TOKEN = "[IMG]"
18
+ DEFAULT_IMG_END_TOKEN = "[/IMG]"
19
+ DEFAULT_IMAGE_TOKEN = "<image>"
20
+ DEFAULT_gIMG_TOKEN = "[gIMG]"
21
+ DEFAULT_gIMG_END_TOKEN = "[/gIMG]"
22
+ DEFAULT_EOC_TOKEN = "[EOC]"
23
+ DEFAULT_VIDEO_TOKEN = "[VIDEO]"
24
+
25
+ GRD_SYMBOL = "<grounding>"
26
+ BOP_SYMBOL = "<phrase>"
27
+ EOP_SYMBOL = "</phrase>"
28
+ BOO_SYMBOL = "<object>"
29
+ EOO_SYMBOL = "</object>"
30
+ DOM_SYMBOL = "</delimiter_of_multi_objects/>"
31
+
32
+ REC_SYMBOL = "<REC>"
33
+
34
+ USER_TOKEN = "[USER]"
35
+ ASSISTANT_TOKEN = "[ASSISTANT]"
36
+ # END
37
+
38
+ # special token id
39
+ # START
40
+ IMAGE = 32003
41
+ BOI = 32001
42
+ VIDEO = 32004
43
+ # END
44
+
45
+ DEFAULT_IMG_PLACEHOLDER = "[<IMG_PLH>]"
46
+ DEFAULT_VID_PLACEHOLDER = "[<VID_PLH>]"
47
+ FAKE_VIDEO_END_TOKEN = "[/VIDEO]"
modeling_emu.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, List, Optional, Mapping, Callable
3
+ from collections import OrderedDict
4
+ from argparse import Namespace
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ import torchvision.transforms as T
9
+ import PIL
10
+ from PIL import Image
11
+ import transformers
12
+ from transformers import PreTrainedModel, PreTrainedTokenizer
13
+
14
+ from .configuration_emu import EmuConfig
15
+ from .constants import *
16
+ from .modeling_llama import LlamaForCausalLM
17
+ from .visual import EVAVisionTransformer
18
+
19
+
20
+ class EmuPreTrainedModel(PreTrainedModel):
21
+ config_class = EmuConfig
22
+ base_model_prefix = "model"
23
+ supports_gradient_checkpointing = False
24
+ _no_split_modules = ["LlamaDecoderLayer", "Block"]
25
+ _skip_keys_device_placement = "past_key_values"
26
+
27
+ def _init_weights(self, module):
28
+ std = self.config.initializer_range
29
+ if isinstance(module, nn.Linear):
30
+ module.weight.data.normal_(mean=0.0, std=std)
31
+ if module.bias is not None:
32
+ module.bias.data.zero_()
33
+ elif isinstance(module, nn.Embedding):
34
+ module.weight.data.normal_(mean=0.0, std=std)
35
+ if module.padding_idx is not None:
36
+ module.weight.data[module.padding_idx].zero_()
37
+
38
+
39
+ class EmuForClsAndRegression(EmuPreTrainedModel):
40
+
41
+ def __init__(self, config):
42
+ super(EmuForClsAndRegression, self).__init__(config)
43
+
44
+ self.lm = LlamaForCausalLM(config=config)
45
+
46
+ self.lm.model.embed_tokens.padding_idx = config.pad_token_id
47
+
48
+ def get_num_layers(self):
49
+ return len(self.lm.model.layers)
50
+
51
+
52
+ class EmuModel(EmuPreTrainedModel):
53
+
54
+ def __init__(self, config):
55
+ super().__init__(config)
56
+
57
+ vision_config = Namespace(**config.vision_config)
58
+
59
+ self.visual = EVAVisionTransformer(
60
+ img_size=vision_config.image_size,
61
+ patch_size=vision_config.patch_size,
62
+ embed_dim=vision_config.width,
63
+ depth=vision_config.layers,
64
+ num_heads=vision_config.width // vision_config.head_width,
65
+ mlp_ratio=vision_config.mlp_ratio,
66
+ qkv_bias=vision_config.qkv_bias,
67
+ drop_path_rate=vision_config.drop_path_rate,
68
+ norm_layer=partial(nn.LayerNorm, eps=vision_config.layer_norm_eps),
69
+ xattn=vision_config.xattn,
70
+ postnorm=vision_config.postnorm,
71
+ )
72
+
73
+ self.decoder = EmuForClsAndRegression(config)
74
+
75
+ self.gradient_checkpointing = False
76
+
77
+ self.n_query = vision_config.n_query
78
+ self.v_query = vision_config.v_query
79
+
80
+ @property
81
+ def device(self):
82
+ return next(iter(self.parameters())).device
83
+
84
+ @property
85
+ def dtype(self):
86
+ return next(iter(self.parameters())).dtype
87
+
88
+ @torch.no_grad()
89
+ def encode_image(self, image: torch.Tensor, *, n_query=None):
90
+ n_query = n_query if n_query is not None else self.n_query
91
+
92
+ image_embeds = self.visual(image)
93
+ image_embeds = image_embeds[:, 1:, :]
94
+ b, n, c = image_embeds.shape
95
+ sqrt_n = int(n**0.5)
96
+ image_embeds = image_embeds.permute(0, 2, 1).view(b, c, sqrt_n, sqrt_n)
97
+
98
+ stride = int(sqrt_n // (n_query ** 0.5))
99
+ image_embeds = F.avg_pool2d(image_embeds, kernel_size=(stride, stride), stride=stride)
100
+ image_embeds = image_embeds.view(b, c, -1).permute(0, 2, 1).contiguous()
101
+ return image_embeds
102
+
103
+
104
+ class EmuForCausalLM(EmuPreTrainedModel):
105
+ _auto_class = "AutoModelForCausalLM"
106
+
107
+ def __init__(self, config):
108
+ super().__init__(config)
109
+
110
+ self.config = config
111
+ self.model = EmuModel(config)
112
+ # LM to EVA
113
+ self.project_down = nn.Linear(config.hidden_size, config.d_model, bias=False)
114
+ # EVA to LM
115
+ self.project_up = nn.Linear(config.d_model, config.hidden_size, bias=False)
116
+
117
+ self.n_query = self.model.n_query
118
+ self.v_query = self.model.v_query
119
+ self.image_placeholder = DEFAULT_IMG_TOKEN + DEFAULT_IMAGE_TOKEN * self.n_query + DEFAULT_IMG_END_TOKEN
120
+
121
+ # temporarily borrow [gIMG] as the video frame feature placeholder.
122
+ self.video_placeholder = DEFAULT_IMG_TOKEN + DEFAULT_gIMG_TOKEN * self.v_query + DEFAULT_IMG_END_TOKEN
123
+
124
+ def device(self, module=None):
125
+ if module is not None:
126
+ return next(module.parameters()).device
127
+ return next(iter(self.parameters())).device
128
+
129
+ def dtype(self, module=None):
130
+ if module is not None:
131
+ return next(module.parameters()).dtype
132
+ return next(iter(self.parameters())).dtype
133
+
134
+ @torch.no_grad()
135
+ def generate(
136
+ self,
137
+ input_ids,
138
+ attention_mask,
139
+ image: Optional[torch.Tensor] = None,
140
+ video: Optional[torch.Tensor] = None,
141
+ num_beams=5,
142
+ max_new_tokens=10,
143
+ min_len=1,
144
+ do_sample=False,
145
+ penalty_alpha=None,
146
+ top_p=None,
147
+ top_k=None,
148
+ temperature=None,
149
+ length_penalty=-1,
150
+ repetition_penalty=1.0,
151
+ **kwargs
152
+ ):
153
+
154
+ text_embeds = self.model.decoder.lm.model.embed_tokens(input_ids)
155
+ if image is not None:
156
+ prompt_image_embeds = self.model.encode_image(image, n_query=self.n_query)
157
+ _, _, c = prompt_image_embeds.shape
158
+ prompt_image_embeds = prompt_image_embeds.view(-1, c)
159
+ prompt_image_embeds = self.project_up(prompt_image_embeds)
160
+ image_idx = (input_ids == IMAGE)
161
+ text_embeds[image_idx] = prompt_image_embeds.to(text_embeds.device)
162
+
163
+ if video is not None:
164
+ prompt_video_embeds = self.model.encode_image(video, n_query=self.v_query)
165
+ _, _, c = prompt_video_embeds.shape
166
+ prompt_video_embeds = prompt_video_embeds.view(-1, c)
167
+ prompt_video_embeds = self.project_up(prompt_video_embeds)
168
+ video_idx = (input_ids == VIDEO)
169
+ text_embeds[video_idx] = prompt_video_embeds.to(text_embeds.device)
170
+
171
+ outputs = self.model.decoder.lm.generate(
172
+ inputs_embeds=text_embeds,
173
+ attention_mask=attention_mask,
174
+ do_sample=do_sample,
175
+ num_beams=num_beams,
176
+ max_new_tokens=max_new_tokens,
177
+ min_length=min_len,
178
+ length_penalty=length_penalty,
179
+ repetition_penalty=repetition_penalty,
180
+ penalty_alpha=penalty_alpha,
181
+ top_k=top_k,
182
+ top_p=top_p,
183
+ temperature=temperature,
184
+ **kwargs,
185
+ )
186
+
187
+ return outputs
188
+
189
+ def prepare_image_input(self, images):
190
+ image_size: int = self.config.vision_config['image_size']
191
+ transform = T.Compose(
192
+ [
193
+ T.Resize(
194
+ (image_size, image_size), interpolation=T.InterpolationMode.BICUBIC
195
+ ),
196
+ T.ToTensor(),
197
+ T.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
198
+ ]
199
+ )
200
+ images = [transform(image) for image in images]
201
+ return torch.stack(images, 0)
202
+
203
+ def _prepare_chat_template(self, text, system_msg=""):
204
+ text = [
205
+ system_msg + USER_TOKEN + ": " + t + ASSISTANT_TOKEN +":"
206
+ for t in text
207
+ ]
208
+ return text
209
+
210
+ def prepare_text_input(
211
+ self,
212
+ text: List[str],
213
+ tokenizer: PreTrainedTokenizer,
214
+ image_placeholder: str = DEFAULT_IMG_PLACEHOLDER,
215
+ video_placeholder: str = DEFAULT_VID_PLACEHOLDER,
216
+ ):
217
+ text = [
218
+ t.replace(image_placeholder, self.image_placeholder).replace(video_placeholder, self.video_placeholder)
219
+ for t in text
220
+ ]
221
+ input_ids = tokenizer(text, padding="longest", return_tensors="pt")
222
+ return input_ids
223
+
224
+
225
+ def build_input_ids(
226
+ self,
227
+ text: List[str],
228
+ tokenizer: PreTrainedTokenizer,
229
+ image: Optional[List["PIL.Image"]] = None,
230
+ video: Optional[List["PIL.Image"]] = None,
231
+ system_msg: str = "",
232
+ to_cuda: bool = True
233
+ ):
234
+
235
+ if self.config.model_version == "chat":
236
+ text = self._prepare_chat_template(text, system_msg)
237
+
238
+ if image is not None:
239
+ image = self.prepare_image_input(image)
240
+ if video is not None:
241
+ video = self.prepare_image_input(video)
242
+ inputs = self.prepare_text_input(text, tokenizer)
243
+ input_ids = inputs.input_ids
244
+ attention_mask = inputs.attention_mask
245
+
246
+ if to_cuda:
247
+ device = self.device(self.model.decoder.lm.model.embed_tokens)
248
+ input_ids = input_ids.to(device)
249
+ attention_mask = attention_mask.to(device)
250
+
251
+ device = self.device(self.model.visual)
252
+ if image is not None:
253
+ image = image.to(device)
254
+ if video is not None:
255
+ video = video.to(device)
256
+
257
+ return {
258
+ 'input_ids': input_ids,
259
+ 'attention_mask': attention_mask,
260
+ 'image': image,
261
+ 'video': video
262
+ }
modeling_llama.py ADDED
@@ -0,0 +1,1013 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers import PreTrainedModel
31
+ from transformers import LlamaConfig
32
+ from transformers.activations import ACT2FN
33
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
34
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
35
+
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _CONFIG_FOR_DOC = "LlamaConfig"
41
+
42
+
43
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
44
+ def _make_causal_mask(
45
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
46
+ ):
47
+ """
48
+ Make causal mask used for bi-directional self-attention.
49
+ """
50
+ bsz, tgt_len = input_ids_shape
51
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
52
+ mask_cond = torch.arange(mask.size(-1), device=device)
53
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
54
+ mask = mask.to(dtype)
55
+
56
+ if past_key_values_length > 0:
57
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
58
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
59
+
60
+
61
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
62
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
63
+ """
64
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
65
+ """
66
+ bsz, src_len = mask.size()
67
+ tgt_len = tgt_len if tgt_len is not None else src_len
68
+
69
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
70
+
71
+ inverted_mask = 1.0 - expanded_mask
72
+
73
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
74
+
75
+
76
+ class LlamaRMSNorm(nn.Module):
77
+ def __init__(self, hidden_size, eps=1e-6):
78
+ """
79
+ LlamaRMSNorm is equivalent to T5LayerNorm
80
+ """
81
+ super().__init__()
82
+ self.weight = nn.Parameter(torch.ones(hidden_size))
83
+ self.variance_epsilon = eps
84
+
85
+ def forward(self, hidden_states):
86
+ input_dtype = hidden_states.dtype
87
+ hidden_states = hidden_states.to(torch.float32)
88
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
89
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
90
+ return self.weight * hidden_states.to(input_dtype)
91
+
92
+
93
+ class LlamaRotaryEmbedding(torch.nn.Module):
94
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
95
+ super().__init__()
96
+
97
+ self.dim = dim
98
+ self.max_position_embeddings = max_position_embeddings
99
+ self.base = base
100
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
101
+ self.register_buffer("inv_freq", inv_freq)
102
+
103
+ # Build here to make `torch.jit.trace` work.
104
+ self._set_cos_sin_cache(
105
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
106
+ )
107
+
108
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
109
+ self.max_seq_len_cached = seq_len
110
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
111
+
112
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
113
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
114
+ emb = torch.cat((freqs, freqs), dim=-1)
115
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
116
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
117
+
118
+ def forward(self, x, seq_len=None):
119
+ # x: [bs, num_attention_heads, seq_len, head_size]
120
+ if seq_len > self.max_seq_len_cached:
121
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
122
+
123
+ return (
124
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
125
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
126
+ )
127
+
128
+
129
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
130
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
131
+
132
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
133
+ self.scaling_factor = scaling_factor
134
+ super().__init__(dim, max_position_embeddings, base, device)
135
+
136
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
137
+ self.max_seq_len_cached = seq_len
138
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
139
+ t = t / self.scaling_factor
140
+
141
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
142
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
143
+ emb = torch.cat((freqs, freqs), dim=-1)
144
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
145
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
146
+
147
+
148
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
149
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
150
+
151
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
152
+ self.scaling_factor = scaling_factor
153
+ super().__init__(dim, max_position_embeddings, base, device)
154
+
155
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
156
+ self.max_seq_len_cached = seq_len
157
+
158
+ if seq_len > self.max_position_embeddings:
159
+ base = self.base * (
160
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
161
+ ) ** (self.dim / (self.dim - 2))
162
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
163
+ self.register_buffer("inv_freq", inv_freq)
164
+
165
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
166
+
167
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
168
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
169
+ emb = torch.cat((freqs, freqs), dim=-1)
170
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
171
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
172
+
173
+
174
+ def rotate_half(x):
175
+ """Rotates half the hidden dims of the input."""
176
+ x1 = x[..., : x.shape[-1] // 2]
177
+ x2 = x[..., x.shape[-1] // 2 :]
178
+ return torch.cat((-x2, x1), dim=-1)
179
+
180
+
181
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
182
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
183
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
184
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
185
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
186
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
187
+ q_embed = (q * cos) + (rotate_half(q) * sin)
188
+ k_embed = (k * cos) + (rotate_half(k) * sin)
189
+ return q_embed, k_embed
190
+
191
+
192
+ class LlamaMLP(nn.Module):
193
+ def __init__(self, config):
194
+ super().__init__()
195
+ self.pretraining_tp = config.pretraining_tp
196
+ self.hidden_size = config.hidden_size
197
+ self.intermediate_size = config.intermediate_size
198
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
199
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
200
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
201
+ self.act_fn = ACT2FN[config.hidden_act]
202
+
203
+ def forward(self, x):
204
+ if self.pretraining_tp > 1:
205
+ slice = self.intermediate_size // self.pretraining_tp
206
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
207
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
208
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
209
+
210
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
211
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
212
+
213
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
214
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
215
+ down_proj = sum(down_proj)
216
+ else:
217
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
218
+
219
+ return down_proj
220
+
221
+
222
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
223
+ """
224
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
225
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
226
+ """
227
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
228
+ if n_rep == 1:
229
+ return hidden_states
230
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
231
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
232
+
233
+
234
+ class LlamaAttention(nn.Module):
235
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
236
+
237
+ def __init__(self, config: LlamaConfig):
238
+ super().__init__()
239
+ self.config = config
240
+ self.hidden_size = config.hidden_size
241
+ self.num_heads = config.num_attention_heads
242
+ self.head_dim = self.hidden_size // self.num_heads
243
+ self.num_key_value_heads = config.num_key_value_heads
244
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
245
+ self.pretraining_tp = config.pretraining_tp
246
+ self.max_position_embeddings = config.max_position_embeddings
247
+
248
+ if (self.head_dim * self.num_heads) != self.hidden_size:
249
+ raise ValueError(
250
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
251
+ f" and `num_heads`: {self.num_heads})."
252
+ )
253
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
254
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
255
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
256
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
257
+ self._init_rope()
258
+
259
+ def _init_rope(self):
260
+ if self.config.rope_scaling is None:
261
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
262
+ else:
263
+ scaling_type = self.config.rope_scaling["type"]
264
+ scaling_factor = self.config.rope_scaling["factor"]
265
+ if scaling_type == "linear":
266
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
267
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
268
+ )
269
+ elif scaling_type == "dynamic":
270
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
271
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
272
+ )
273
+ else:
274
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
275
+
276
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
277
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
278
+
279
+ def forward(
280
+ self,
281
+ hidden_states: torch.Tensor,
282
+ attention_mask: Optional[torch.Tensor] = None,
283
+ position_ids: Optional[torch.LongTensor] = None,
284
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
285
+ output_attentions: bool = False,
286
+ use_cache: bool = False,
287
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
288
+ bsz, q_len, _ = hidden_states.size()
289
+
290
+ if self.pretraining_tp > 1:
291
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
292
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
293
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
294
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
295
+
296
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
297
+ query_states = torch.cat(query_states, dim=-1)
298
+
299
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
300
+ key_states = torch.cat(key_states, dim=-1)
301
+
302
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
303
+ value_states = torch.cat(value_states, dim=-1)
304
+
305
+ else:
306
+ query_states = self.q_proj(hidden_states)
307
+ key_states = self.k_proj(hidden_states)
308
+ value_states = self.v_proj(hidden_states)
309
+
310
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
311
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
312
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
313
+
314
+ kv_seq_len = key_states.shape[-2]
315
+ if past_key_value is not None:
316
+ kv_seq_len += past_key_value[0].shape[-2]
317
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
318
+
319
+ # query_states, key_states, cos, sin, position_ids = query_states.to(hidden_states.device), key_states.to(hidden_states.device), cos.to(hidden_states.device), sin.to(hidden_states.device), position_ids.to(hidden_states.device)
320
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
321
+
322
+ if past_key_value is not None:
323
+ # reuse k, v, self_attention
324
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
325
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
326
+
327
+ past_key_value = (key_states, value_states) if use_cache else None
328
+
329
+ # repeat k/v heads if n_kv_heads < n_heads
330
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
331
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
332
+
333
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
334
+
335
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
336
+ raise ValueError(
337
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
338
+ f" {attn_weights.size()}"
339
+ )
340
+
341
+ if attention_mask is not None:
342
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
343
+ raise ValueError(
344
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
345
+ )
346
+ attn_weights = attn_weights + attention_mask
347
+
348
+ # upcast attention to fp32
349
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
350
+ attn_output = torch.matmul(attn_weights, value_states)
351
+
352
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
353
+ raise ValueError(
354
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
355
+ f" {attn_output.size()}"
356
+ )
357
+
358
+ attn_output = attn_output.transpose(1, 2).contiguous()
359
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
360
+
361
+ if self.pretraining_tp > 1:
362
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
363
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
364
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
365
+ else:
366
+ attn_output = self.o_proj(attn_output)
367
+
368
+ if not output_attentions:
369
+ attn_weights = None
370
+
371
+ return attn_output, attn_weights, past_key_value
372
+
373
+
374
+ class LlamaDecoderLayer(nn.Module):
375
+ def __init__(self, config: LlamaConfig):
376
+ super().__init__()
377
+ self.hidden_size = config.hidden_size
378
+ self.self_attn = LlamaAttention(config=config)
379
+ self.mlp = LlamaMLP(config)
380
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
381
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
382
+
383
+ def forward(
384
+ self,
385
+ hidden_states: torch.Tensor,
386
+ attention_mask: Optional[torch.Tensor] = None,
387
+ position_ids: Optional[torch.LongTensor] = None,
388
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
389
+ output_attentions: Optional[bool] = False,
390
+ use_cache: Optional[bool] = False,
391
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
392
+ """
393
+ Args:
394
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
395
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
396
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
397
+ output_attentions (`bool`, *optional*):
398
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
399
+ returned tensors for more detail.
400
+ use_cache (`bool`, *optional*):
401
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
402
+ (see `past_key_values`).
403
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
404
+ """
405
+
406
+ residual = hidden_states
407
+
408
+ hidden_states = self.input_layernorm(hidden_states)
409
+
410
+ # Self Attention
411
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
412
+ hidden_states=hidden_states,
413
+ attention_mask=attention_mask,
414
+ position_ids=position_ids,
415
+ past_key_value=past_key_value,
416
+ output_attentions=output_attentions,
417
+ use_cache=use_cache,
418
+ )
419
+ hidden_states = residual + hidden_states
420
+
421
+ # Fully Connected
422
+ residual = hidden_states
423
+ hidden_states = self.post_attention_layernorm(hidden_states)
424
+ hidden_states = self.mlp(hidden_states)
425
+ hidden_states = residual + hidden_states
426
+
427
+ outputs = (hidden_states,)
428
+
429
+ if output_attentions:
430
+ outputs += (self_attn_weights,)
431
+
432
+ if use_cache:
433
+ outputs += (present_key_value,)
434
+
435
+ return outputs
436
+
437
+
438
+ LLAMA_START_DOCSTRING = r"""
439
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
440
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
441
+ etc.)
442
+
443
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
444
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
445
+ and behavior.
446
+
447
+ Parameters:
448
+ config ([`LlamaConfig`]):
449
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
450
+ load the weights associated with the model, only the configuration. Check out the
451
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
452
+ """
453
+
454
+
455
+ @add_start_docstrings(
456
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
457
+ LLAMA_START_DOCSTRING,
458
+ )
459
+ class LlamaPreTrainedModel(PreTrainedModel):
460
+ config_class = LlamaConfig
461
+ base_model_prefix = "model"
462
+ supports_gradient_checkpointing = True
463
+ _no_split_modules = ["LlamaDecoderLayer"]
464
+ _skip_keys_device_placement = "past_key_values"
465
+
466
+ def _init_weights(self, module):
467
+ std = self.config.initializer_range
468
+ if isinstance(module, nn.Linear):
469
+ module.weight.data.normal_(mean=0.0, std=std)
470
+ if module.bias is not None:
471
+ module.bias.data.zero_()
472
+ elif isinstance(module, nn.Embedding):
473
+ module.weight.data.normal_(mean=0.0, std=std)
474
+ if module.padding_idx is not None:
475
+ module.weight.data[module.padding_idx].zero_()
476
+
477
+ def _set_gradient_checkpointing(self, module, value=False):
478
+ if isinstance(module, LlamaModel):
479
+ module.gradient_checkpointing = value
480
+
481
+
482
+ LLAMA_INPUTS_DOCSTRING = r"""
483
+ Args:
484
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
485
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
486
+ it.
487
+
488
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
489
+ [`PreTrainedTokenizer.__call__`] for details.
490
+
491
+ [What are input IDs?](../glossary#input-ids)
492
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
493
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
494
+
495
+ - 1 for tokens that are **not masked**,
496
+ - 0 for tokens that are **masked**.
497
+
498
+ [What are attention masks?](../glossary#attention-mask)
499
+
500
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
501
+ [`PreTrainedTokenizer.__call__`] for details.
502
+
503
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
504
+ `past_key_values`).
505
+
506
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
507
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
508
+ information on the default strategy.
509
+
510
+ - 1 indicates the head is **not masked**,
511
+ - 0 indicates the head is **masked**.
512
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
513
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
514
+ config.n_positions - 1]`.
515
+
516
+ [What are position IDs?](../glossary#position-ids)
517
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
518
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
519
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
520
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
521
+
522
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
523
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
524
+
525
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
526
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
527
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
528
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
529
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
530
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
531
+ model's internal embedding lookup matrix.
532
+ use_cache (`bool`, *optional*):
533
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
534
+ `past_key_values`).
535
+ output_attentions (`bool`, *optional*):
536
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
537
+ tensors for more detail.
538
+ output_hidden_states (`bool`, *optional*):
539
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
540
+ more detail.
541
+ return_dict (`bool`, *optional*):
542
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
543
+ """
544
+
545
+
546
+ @add_start_docstrings(
547
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
548
+ LLAMA_START_DOCSTRING,
549
+ )
550
+ class LlamaModel(LlamaPreTrainedModel):
551
+ """
552
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
553
+
554
+ Args:
555
+ config: LlamaConfig
556
+ """
557
+
558
+ def __init__(self, config: LlamaConfig):
559
+ super().__init__(config)
560
+ self.padding_idx = config.pad_token_id
561
+ self.vocab_size = config.vocab_size
562
+
563
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
564
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
565
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
566
+
567
+ self.gradient_checkpointing = False
568
+ # Initialize weights and apply final processing
569
+ self.post_init()
570
+
571
+ def get_input_embeddings(self):
572
+ return self.embed_tokens
573
+
574
+ def set_input_embeddings(self, value):
575
+ self.embed_tokens = value
576
+
577
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
578
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
579
+ # create causal mask
580
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
581
+ combined_attention_mask = None
582
+ if input_shape[-1] > 1:
583
+ combined_attention_mask = _make_causal_mask(
584
+ input_shape,
585
+ inputs_embeds.dtype,
586
+ device=inputs_embeds.device,
587
+ past_key_values_length=past_key_values_length,
588
+ )
589
+
590
+ if attention_mask is not None:
591
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
592
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
593
+ inputs_embeds.device
594
+ )
595
+ combined_attention_mask = (
596
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
597
+ )
598
+
599
+ return combined_attention_mask
600
+
601
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
602
+ def forward(
603
+ self,
604
+ input_ids: torch.LongTensor = None,
605
+ attention_mask: Optional[torch.Tensor] = None,
606
+ position_ids: Optional[torch.LongTensor] = None,
607
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
608
+ inputs_embeds: Optional[torch.FloatTensor] = None,
609
+ use_cache: Optional[bool] = None,
610
+ output_attentions: Optional[bool] = None,
611
+ output_hidden_states: Optional[bool] = None,
612
+ return_dict: Optional[bool] = None,
613
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
614
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
615
+ output_hidden_states = (
616
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
617
+ )
618
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
619
+
620
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
621
+
622
+ # retrieve input_ids and inputs_embeds
623
+ if input_ids is not None and inputs_embeds is not None:
624
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
625
+ elif input_ids is not None:
626
+ batch_size, seq_length = input_ids.shape
627
+ elif inputs_embeds is not None:
628
+ batch_size, seq_length, _ = inputs_embeds.shape
629
+ else:
630
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
631
+
632
+ seq_length_with_past = seq_length
633
+ past_key_values_length = 0
634
+
635
+ if past_key_values is not None:
636
+ past_key_values_length = past_key_values[0][0].shape[2]
637
+ seq_length_with_past = seq_length_with_past + past_key_values_length
638
+
639
+ if position_ids is None:
640
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
641
+ position_ids = torch.arange(
642
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
643
+ )
644
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
645
+ else:
646
+ position_ids = position_ids.view(-1, seq_length).long()
647
+
648
+ if inputs_embeds is None:
649
+ inputs_embeds = self.embed_tokens(input_ids)
650
+ # embed positions
651
+ if attention_mask is None:
652
+ attention_mask = torch.ones(
653
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
654
+ )
655
+ attention_mask = self._prepare_decoder_attention_mask(
656
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
657
+ )
658
+
659
+ hidden_states = inputs_embeds
660
+
661
+ if self.gradient_checkpointing and self.training:
662
+ if use_cache:
663
+ logger.warning_once(
664
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
665
+ )
666
+ use_cache = False
667
+
668
+ # decoder layers
669
+ all_hidden_states = () if output_hidden_states else None
670
+ all_self_attns = () if output_attentions else None
671
+ next_decoder_cache = () if use_cache else None
672
+
673
+ for idx, decoder_layer in enumerate(self.layers):
674
+ if output_hidden_states:
675
+ all_hidden_states += (hidden_states,)
676
+
677
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
678
+
679
+ if self.gradient_checkpointing and self.training:
680
+
681
+ def create_custom_forward(module):
682
+ def custom_forward(*inputs):
683
+ # None for past_key_value
684
+ return module(*inputs, output_attentions, None)
685
+
686
+ return custom_forward
687
+
688
+ layer_outputs = torch.utils.checkpoint.checkpoint(
689
+ create_custom_forward(decoder_layer),
690
+ hidden_states,
691
+ attention_mask,
692
+ position_ids,
693
+ None,
694
+ )
695
+ else:
696
+ layer_outputs = decoder_layer(
697
+ hidden_states,
698
+ attention_mask=attention_mask,
699
+ position_ids=position_ids,
700
+ past_key_value=past_key_value,
701
+ output_attentions=output_attentions,
702
+ use_cache=use_cache,
703
+ )
704
+
705
+ hidden_states = layer_outputs[0]
706
+
707
+ if use_cache:
708
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
709
+
710
+ if output_attentions:
711
+ all_self_attns += (layer_outputs[1],)
712
+
713
+ hidden_states = self.norm(hidden_states)
714
+
715
+ # add hidden states from the last decoder layer
716
+ if output_hidden_states:
717
+ all_hidden_states += (hidden_states,)
718
+
719
+ next_cache = next_decoder_cache if use_cache else None
720
+ if not return_dict:
721
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
722
+ return BaseModelOutputWithPast(
723
+ last_hidden_state=hidden_states,
724
+ past_key_values=next_cache,
725
+ hidden_states=all_hidden_states,
726
+ attentions=all_self_attns,
727
+ )
728
+
729
+
730
+ class LlamaForCausalLM(LlamaPreTrainedModel):
731
+ _tied_weights_keys = ["lm_head.weight"]
732
+
733
+ def __init__(self, config):
734
+ super().__init__(config)
735
+ self.model = LlamaModel(config)
736
+ self.pretraining_tp = config.pretraining_tp
737
+ self.vocab_size = config.vocab_size
738
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
739
+
740
+ # Initialize weights and apply final processing
741
+ self.post_init()
742
+
743
+ def get_input_embeddings(self):
744
+ return self.model.embed_tokens
745
+
746
+ def set_input_embeddings(self, value):
747
+ self.model.embed_tokens = value
748
+
749
+ def get_output_embeddings(self):
750
+ return self.lm_head
751
+
752
+ def set_output_embeddings(self, new_embeddings):
753
+ self.lm_head = new_embeddings
754
+
755
+ def set_decoder(self, decoder):
756
+ self.model = decoder
757
+
758
+ def get_decoder(self):
759
+ return self.model
760
+
761
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
762
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
763
+ def forward(
764
+ self,
765
+ input_ids: torch.LongTensor = None,
766
+ attention_mask: Optional[torch.Tensor] = None,
767
+ position_ids: Optional[torch.LongTensor] = None,
768
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
769
+ inputs_embeds: Optional[torch.FloatTensor] = None,
770
+ labels: Optional[torch.LongTensor] = None,
771
+ use_cache: Optional[bool] = None,
772
+ output_attentions: Optional[bool] = None,
773
+ output_hidden_states: Optional[bool] = None,
774
+ return_dict: Optional[bool] = None,
775
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
776
+ r"""
777
+ Args:
778
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
779
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
780
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
781
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
782
+
783
+ Returns:
784
+
785
+ Example:
786
+
787
+ ```python
788
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
789
+
790
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
791
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
792
+
793
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
794
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
795
+
796
+ >>> # Generate
797
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
798
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
799
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
800
+ ```"""
801
+
802
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
803
+ output_hidden_states = (
804
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
805
+ )
806
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
807
+
808
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
809
+ outputs = self.model(
810
+ input_ids=input_ids,
811
+ attention_mask=attention_mask,
812
+ position_ids=position_ids,
813
+ past_key_values=past_key_values,
814
+ inputs_embeds=inputs_embeds,
815
+ use_cache=use_cache,
816
+ output_attentions=output_attentions,
817
+ output_hidden_states=output_hidden_states,
818
+ return_dict=return_dict,
819
+ )
820
+
821
+ hidden_states = outputs[0]
822
+ if self.pretraining_tp > 1:
823
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
824
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
825
+ logits = torch.cat(logits, dim=-1)
826
+ else:
827
+ logits = self.lm_head(hidden_states)
828
+
829
+ logits = logits.float()
830
+
831
+ loss = None
832
+ if labels is not None:
833
+ # Shift so that tokens < n predict n
834
+ shift_logits = logits[..., :-1, :].contiguous()
835
+ shift_labels = labels[..., 1:].contiguous()
836
+ # Flatten the tokens
837
+ loss_fct = CrossEntropyLoss()
838
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
839
+ shift_labels = shift_labels.view(-1)
840
+ # Enable model parallelism
841
+ shift_labels = shift_labels.to(shift_logits.device)
842
+ loss = loss_fct(shift_logits, shift_labels)
843
+
844
+ if not return_dict:
845
+ output = (logits,) + outputs[1:]
846
+ return (loss,) + output if loss is not None else output
847
+
848
+ return CausalLMOutputWithPast(
849
+ loss=loss,
850
+ logits=logits,
851
+ past_key_values=outputs.past_key_values,
852
+ hidden_states=outputs.hidden_states,
853
+ attentions=outputs.attentions,
854
+ )
855
+
856
+ def prepare_inputs_for_generation(
857
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
858
+ ):
859
+ if past_key_values:
860
+ input_ids = input_ids[:, -1:]
861
+
862
+ position_ids = kwargs.get("position_ids", None)
863
+ if attention_mask is not None and position_ids is None:
864
+ # create position_ids on the fly for batch generation
865
+ position_ids = attention_mask.long().cumsum(-1) - 1
866
+ position_ids.masked_fill_(attention_mask == 0, 1)
867
+ if past_key_values:
868
+ position_ids = position_ids[:, -1].unsqueeze(-1)
869
+
870
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
871
+ if inputs_embeds is not None and past_key_values is None:
872
+ model_inputs = {"inputs_embeds": inputs_embeds}
873
+ else:
874
+ model_inputs = {"input_ids": input_ids}
875
+
876
+ model_inputs.update(
877
+ {
878
+ "position_ids": position_ids,
879
+ "past_key_values": past_key_values,
880
+ "use_cache": kwargs.get("use_cache"),
881
+ "attention_mask": attention_mask,
882
+ }
883
+ )
884
+ return model_inputs
885
+
886
+ @staticmethod
887
+ def _reorder_cache(past_key_values, beam_idx):
888
+ reordered_past = ()
889
+ for layer_past in past_key_values:
890
+ reordered_past += (
891
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
892
+ )
893
+ return reordered_past
894
+
895
+
896
+ @add_start_docstrings(
897
+ """
898
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
899
+
900
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
901
+ (e.g. GPT-2) do.
902
+
903
+ Since it does classification on the last token, it requires to know the position of the last token. If a
904
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
905
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
906
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
907
+ each row of the batch).
908
+ """,
909
+ LLAMA_START_DOCSTRING,
910
+ )
911
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
912
+ def __init__(self, config):
913
+ super().__init__(config)
914
+ self.num_labels = config.num_labels
915
+ self.model = LlamaModel(config)
916
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
917
+
918
+ # Initialize weights and apply final processing
919
+ self.post_init()
920
+
921
+ def get_input_embeddings(self):
922
+ return self.model.embed_tokens
923
+
924
+ def set_input_embeddings(self, value):
925
+ self.model.embed_tokens = value
926
+
927
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
928
+ def forward(
929
+ self,
930
+ input_ids: torch.LongTensor = None,
931
+ attention_mask: Optional[torch.Tensor] = None,
932
+ position_ids: Optional[torch.LongTensor] = None,
933
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
934
+ inputs_embeds: Optional[torch.FloatTensor] = None,
935
+ labels: Optional[torch.LongTensor] = None,
936
+ use_cache: Optional[bool] = None,
937
+ output_attentions: Optional[bool] = None,
938
+ output_hidden_states: Optional[bool] = None,
939
+ return_dict: Optional[bool] = None,
940
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
941
+ r"""
942
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
943
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
944
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
945
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
946
+ """
947
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
948
+
949
+ transformer_outputs = self.model(
950
+ input_ids,
951
+ attention_mask=attention_mask,
952
+ position_ids=position_ids,
953
+ past_key_values=past_key_values,
954
+ inputs_embeds=inputs_embeds,
955
+ use_cache=use_cache,
956
+ output_attentions=output_attentions,
957
+ output_hidden_states=output_hidden_states,
958
+ return_dict=return_dict,
959
+ )
960
+ hidden_states = transformer_outputs[0]
961
+ logits = self.score(hidden_states)
962
+
963
+ if input_ids is not None:
964
+ batch_size = input_ids.shape[0]
965
+ else:
966
+ batch_size = inputs_embeds.shape[0]
967
+
968
+ if self.config.pad_token_id is None and batch_size != 1:
969
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
970
+ if self.config.pad_token_id is None:
971
+ sequence_lengths = -1
972
+ else:
973
+ if input_ids is not None:
974
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
975
+ else:
976
+ sequence_lengths = -1
977
+
978
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
979
+
980
+ loss = None
981
+ if labels is not None:
982
+ labels = labels.to(logits.device)
983
+ if self.config.problem_type is None:
984
+ if self.num_labels == 1:
985
+ self.config.problem_type = "regression"
986
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
987
+ self.config.problem_type = "single_label_classification"
988
+ else:
989
+ self.config.problem_type = "multi_label_classification"
990
+
991
+ if self.config.problem_type == "regression":
992
+ loss_fct = MSELoss()
993
+ if self.num_labels == 1:
994
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
995
+ else:
996
+ loss = loss_fct(pooled_logits, labels)
997
+ elif self.config.problem_type == "single_label_classification":
998
+ loss_fct = CrossEntropyLoss()
999
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1000
+ elif self.config.problem_type == "multi_label_classification":
1001
+ loss_fct = BCEWithLogitsLoss()
1002
+ loss = loss_fct(pooled_logits, labels)
1003
+ if not return_dict:
1004
+ output = (pooled_logits,) + transformer_outputs[1:]
1005
+ return ((loss,) + output) if loss is not None else output
1006
+
1007
+ return SequenceClassifierOutputWithPast(
1008
+ loss=loss,
1009
+ logits=pooled_logits,
1010
+ past_key_values=transformer_outputs.past_key_values,
1011
+ hidden_states=transformer_outputs.hidden_states,
1012
+ attentions=transformer_outputs.attentions,
1013
+ )
pytorch_model-00001-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9819a2b493c192e83063e439bfb7ac4c5185b111e0863c8dc033aad23ab15ad4
3
+ size 9954784612
pytorch_model-00002-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc6cebc3842b9590f1319ad48cb89d8627356a679b307dc1906613d24ef2946a
3
+ size 9968618628
pytorch_model-00003-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d61878aa9742237c6ecf2a3f569ce40799bcad00910ae62b3652662a8926b391
3
+ size 9746799062
pytorch_model-00004-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f4b78795c21e763c09ff5dc7e074310183b30713d361209c344821f0ebf05de
3
+ size 9992164572
pytorch_model-00005-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:026bd6e40783a703d65c2233e717ea21b8a2283dc61a24b4c4df42fedff308d6
3
+ size 9746745214
pytorch_model-00006-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffbf0818a2358b68780bb4609b7ef04d7d91be36ac8b10d0dfcdc5c999777bc2
3
+ size 9869481610
pytorch_model-00007-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:892f2b27949ad0488a7887d635bf7c53bbef88b87246e3dcbefc2f7007a39fd6
3
+ size 9869428168
pytorch_model-00008-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2cc905fce61211c0fcd4e87f0eb3a83d6ffab056776468715ae2fb3da597914
3
+ size 9746799126
pytorch_model-00009-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3053a7546a8f425bdd2965b7f0c6f8edc218119a5a15d9604dce6aa98de2f44a
3
+ size 9992164636
pytorch_model-00010-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3887df7f66f5877b12f65abcf033c5ea820841ed4f70a750838f0892b27a9bc
3
+ size 9746745214
pytorch_model-00011-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d174b7c56c23e0cfa6fa1a2a6b595ee4bed3d576d5dfa4ae58250a1ce92e433
3
+ size 9869481610
pytorch_model-00012-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d3032e2d2124fbe0d86599ff0fe81773df7c0c59f3372ff5daa28639737f7d3
3
+ size 9869428168
pytorch_model-00013-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45209632e2d11384f68a084c65748e0ae7b59f08b3ad39919d01d5a89e5abbef
3
+ size 9746799126
pytorch_model-00014-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6f4b9d8d2ec90715cca7d6c824ad4edacc4ae9866ff70d245d132be3584f3e7
3
+ size 9992164636
pytorch_model-00015-of-00015.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:832389d9df8f0743cd755b0b81349237ebf13f8a2dd9a9f888ef6472e369dddb
3
+ size 9515514626
pytorch_model.bin.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "[IMG]",
4
+ "[/IMG]",
5
+ "<image>",
6
+ "[gIMG]",
7
+ "[/gIMG]",
8
+ "[EOC]",
9
+ "[VIDEO]",
10
+ "<grounding>",
11
+ "<phrase>",
12
+ "</phrase>",
13
+ "<object>",
14
+ "</object>",
15
+ "</delimiter_of_multi_objects/>",
16
+ "<REC>",
17
+ "<patch_index_0000>",
18
+ "<patch_index_0001>",
19
+ "<patch_index_0002>",
20
+ "<patch_index_0003>",
21
+ "<patch_index_0004>",
22
+ "<patch_index_0005>",
23
+ "<patch_index_0006>",
24
+ "<patch_index_0007>",
25
+ "<patch_index_0008>",
26
+ "<patch_index_0009>",
27
+ "<patch_index_0010>",
28
+ "<patch_index_0011>",
29
+ "<patch_index_0012>",
30
+ "<patch_index_0013>",
31
+ "<patch_index_0014>",
32
+ "<patch_index_0015>",
33
+ "<patch_index_0016>",
34
+ "<patch_index_0017>",
35
+ "<patch_index_0018>",
36
+ "<patch_index_0019>",
37
+ "<patch_index_0020>",
38
+ "<patch_index_0021>",
39
+ "<patch_index_0022>",
40
+ "<patch_index_0023>",
41
+ "<patch_index_0024>",
42
+ "<patch_index_0025>",
43
+ "<patch_index_0026>",
44
+ "<patch_index_0027>",
45
+ "<patch_index_0028>",
46
+ "<patch_index_0029>",
47
+ "<patch_index_0030>",
48
+ "<patch_index_0031>",
49
+ "<patch_index_0032>",
50
+ "<patch_index_0033>",
51
+ "<patch_index_0034>",
52
+ "<patch_index_0035>",
53
+ "<patch_index_0036>",
54
+ "<patch_index_0037>",
55
+ "<patch_index_0038>",
56
+ "<patch_index_0039>",
57
+ "<patch_index_0040>",
58
+ "<patch_index_0041>",
59
+ "<patch_index_0042>",
60
+ "<patch_index_0043>",
61
+ "<patch_index_0044>",
62
+ "<patch_index_0045>",
63
+ "<patch_index_0046>",
64
+ "<patch_index_0047>",
65
+ "<patch_index_0048>",
66
+ "<patch_index_0049>",
67
+ "<patch_index_0050>",
68
+ "<patch_index_0051>",
69
+ "<patch_index_0052>",
70
+ "<patch_index_0053>",
71
+ "<patch_index_0054>",
72
+ "<patch_index_0055>",
73
+ "<patch_index_0056>",
74
+ "<patch_index_0057>",
75
+ "<patch_index_0058>",
76
+ "<patch_index_0059>",
77
+ "<patch_index_0060>",
78
+ "<patch_index_0061>",
79
+ "<patch_index_0062>",
80
+ "<patch_index_0063>",
81
+ "<patch_index_0064>",
82
+ "<patch_index_0065>",
83
+ "<patch_index_0066>",
84
+ "<patch_index_0067>",
85
+ "<patch_index_0068>",
86
+ "<patch_index_0069>",
87
+ "<patch_index_0070>",
88
+ "<patch_index_0071>",
89
+ "<patch_index_0072>",
90
+ "<patch_index_0073>",
91
+ "<patch_index_0074>",
92
+ "<patch_index_0075>",
93
+ "<patch_index_0076>",
94
+ "<patch_index_0077>",
95
+ "<patch_index_0078>",
96
+ "<patch_index_0079>",
97
+ "<patch_index_0080>",
98
+ "<patch_index_0081>",
99
+ "<patch_index_0082>",
100
+ "<patch_index_0083>",
101
+ "<patch_index_0084>",
102
+ "<patch_index_0085>",
103
+ "<patch_index_0086>",
104
+ "<patch_index_0087>",
105
+ "<patch_index_0088>",
106
+ "<patch_index_0089>",
107
+ "<patch_index_0090>",
108
+ "<patch_index_0091>",
109
+ "<patch_index_0092>",
110
+ "<patch_index_0093>",
111
+ "<patch_index_0094>",
112
+ "<patch_index_0095>",
113
+ "<patch_index_0096>",
114
+ "<patch_index_0097>",
115
+ "<patch_index_0098>",
116
+ "<patch_index_0099>",
117
+ "<patch_index_0100>",
118
+ "<patch_index_0101>",
119
+ "<patch_index_0102>",
120
+ "<patch_index_0103>",
121
+ "<patch_index_0104>",
122
+ "<patch_index_0105>",
123
+ "<patch_index_0106>",
124
+ "<patch_index_0107>",
125
+ "<patch_index_0108>",
126
+ "<patch_index_0109>",
127
+ "<patch_index_0110>",
128
+ "<patch_index_0111>",
129
+ "<patch_index_0112>",
130
+ "<patch_index_0113>",
131
+ "<patch_index_0114>",
132
+ "<patch_index_0115>",
133
+ "<patch_index_0116>",
134
+ "<patch_index_0117>",
135
+ "<patch_index_0118>",
136
+ "<patch_index_0119>",
137
+ "<patch_index_0120>",
138
+ "<patch_index_0121>",
139
+ "<patch_index_0122>",
140
+ "<patch_index_0123>",
141
+ "<patch_index_0124>",
142
+ "<patch_index_0125>",
143
+ "<patch_index_0126>",
144
+ "<patch_index_0127>",
145
+ "<patch_index_0128>",
146
+ "<patch_index_0129>",
147
+ "<patch_index_0130>",
148
+ "<patch_index_0131>",
149
+ "<patch_index_0132>",
150
+ "<patch_index_0133>",
151
+ "<patch_index_0134>",
152
+ "<patch_index_0135>",
153
+ "<patch_index_0136>",
154
+ "<patch_index_0137>",
155
+ "<patch_index_0138>",
156
+ "<patch_index_0139>",
157
+ "<patch_index_0140>",
158
+ "<patch_index_0141>",
159
+ "<patch_index_0142>",
160
+ "<patch_index_0143>",
161
+ "<patch_index_0144>",
162
+ "<patch_index_0145>",
163
+ "<patch_index_0146>",
164
+ "<patch_index_0147>",
165
+ "<patch_index_0148>",
166
+ "<patch_index_0149>",
167
+ "<patch_index_0150>",
168
+ "<patch_index_0151>",
169
+ "<patch_index_0152>",
170
+ "<patch_index_0153>",
171
+ "<patch_index_0154>",
172
+ "<patch_index_0155>",
173
+ "<patch_index_0156>",
174
+ "<patch_index_0157>",
175
+ "<patch_index_0158>",
176
+ "<patch_index_0159>",
177
+ "<patch_index_0160>",
178
+ "<patch_index_0161>",
179
+ "<patch_index_0162>",
180
+ "<patch_index_0163>",
181
+ "<patch_index_0164>",
182
+ "<patch_index_0165>",
183
+ "<patch_index_0166>",
184
+ "<patch_index_0167>",
185
+ "<patch_index_0168>",
186
+ "<patch_index_0169>",
187
+ "<patch_index_0170>",
188
+ "<patch_index_0171>",
189
+ "<patch_index_0172>",
190
+ "<patch_index_0173>",
191
+ "<patch_index_0174>",
192
+ "<patch_index_0175>",
193
+ "<patch_index_0176>",
194
+ "<patch_index_0177>",
195
+ "<patch_index_0178>",
196
+ "<patch_index_0179>",
197
+ "<patch_index_0180>",
198
+ "<patch_index_0181>",
199
+ "<patch_index_0182>",
200
+ "<patch_index_0183>",
201
+ "<patch_index_0184>",
202
+ "<patch_index_0185>",
203
+ "<patch_index_0186>",
204
+ "<patch_index_0187>",
205
+ "<patch_index_0188>",
206
+ "<patch_index_0189>",
207
+ "<patch_index_0190>",
208
+ "<patch_index_0191>",
209
+ "<patch_index_0192>",
210
+ "<patch_index_0193>",
211
+ "<patch_index_0194>",
212
+ "<patch_index_0195>",
213
+ "<patch_index_0196>",
214
+ "<patch_index_0197>",
215
+ "<patch_index_0198>",
216
+ "<patch_index_0199>",
217
+ "<patch_index_0200>",
218
+ "<patch_index_0201>",
219
+ "<patch_index_0202>",
220
+ "<patch_index_0203>",
221
+ "<patch_index_0204>",
222
+ "<patch_index_0205>",
223
+ "<patch_index_0206>",
224
+ "<patch_index_0207>",
225
+ "<patch_index_0208>",
226
+ "<patch_index_0209>",
227
+ "<patch_index_0210>",
228
+ "<patch_index_0211>",
229
+ "<patch_index_0212>",
230
+ "<patch_index_0213>",
231
+ "<patch_index_0214>",
232
+ "<patch_index_0215>",
233
+ "<patch_index_0216>",
234
+ "<patch_index_0217>",
235
+ "<patch_index_0218>",
236
+ "<patch_index_0219>",
237
+ "<patch_index_0220>",
238
+ "<patch_index_0221>",
239
+ "<patch_index_0222>",
240
+ "<patch_index_0223>",
241
+ "<patch_index_0224>",
242
+ "<patch_index_0225>",
243
+ "<patch_index_0226>",
244
+ "<patch_index_0227>",
245
+ "<patch_index_0228>",
246
+ "<patch_index_0229>",
247
+ "<patch_index_0230>",
248
+ "<patch_index_0231>",
249
+ "<patch_index_0232>",
250
+ "<patch_index_0233>",
251
+ "<patch_index_0234>",
252
+ "<patch_index_0235>",
253
+ "<patch_index_0236>",
254
+ "<patch_index_0237>",
255
+ "<patch_index_0238>",
256
+ "<patch_index_0239>",
257
+ "<patch_index_0240>",
258
+ "<patch_index_0241>",
259
+ "<patch_index_0242>",
260
+ "<patch_index_0243>",
261
+ "<patch_index_0244>",
262
+ "<patch_index_0245>",
263
+ "<patch_index_0246>",
264
+ "<patch_index_0247>",
265
+ "<patch_index_0248>",
266
+ "<patch_index_0249>",
267
+ "<patch_index_0250>",
268
+ "<patch_index_0251>",
269
+ "<patch_index_0252>",
270
+ "<patch_index_0253>",
271
+ "<patch_index_0254>",
272
+ "<patch_index_0255>",
273
+ "<patch_index_0256>",
274
+ "[USER]",
275
+ "[ASSISTANT]"
276
+ ],
277
+ "bos_token": "<s>",
278
+ "eos_token": "</s>",
279
+ "pad_token": "[PAD]",
280
+ "unk_token": {
281
+ "content": "<unk>",
282
+ "lstrip": false,
283
+ "normalized": true,
284
+ "rstrip": false,
285
+ "single_word": false
286
+ }
287
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "</s>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": true,
22
+ "model_max_length": 1000000000000000019884624838656,
23
+ "pad_token": null,
24
+ "sp_model_kwargs": {},
25
+ "tokenizer_class": "LlamaTokenizer",
26
+ "unk_token": {
27
+ "__type": "AddedToken",
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
visual.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
+ # --------------------------------------------------------
4
+
5
+ import os
6
+ from functools import partial
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ try:
14
+ from timm.models.layers import drop_path, to_2tuple
15
+ except:
16
+ from timm.layers import drop_path, to_2tuple
17
+
18
+ try:
19
+ import xformers.ops as xops
20
+ except ImportError:
21
+ xops = None
22
+ print("Please 'pip install xformers'")
23
+
24
+
25
+ class PatchDropout(nn.Module):
26
+ """
27
+ https://arxiv.org/abs/2212.00794
28
+ """
29
+
30
+ def __init__(self, prob, exclude_first_token=True):
31
+ super().__init__()
32
+ assert 0 <= prob < 1.
33
+ self.prob = prob
34
+ self.exclude_first_token = exclude_first_token # exclude CLS token
35
+ print(f"os.getenv('RoPE')={os.getenv('RoPE')}")
36
+
37
+ def forward(self, x):
38
+ if not self.training or self.prob == 0.:
39
+ return x
40
+
41
+ if self.exclude_first_token:
42
+ cls_tokens, x = x[:, :1], x[:, 1:]
43
+ else:
44
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
45
+
46
+ batch = x.size()[0]
47
+ num_tokens = x.size()[1]
48
+
49
+ batch_indices = torch.arange(batch)
50
+ batch_indices = batch_indices[..., None]
51
+
52
+ keep_prob = 1 - self.prob
53
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
54
+
55
+ rand = torch.randn(batch, num_tokens)
56
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
57
+
58
+ x = x[batch_indices, patch_indices_keep]
59
+
60
+ if self.exclude_first_token:
61
+ x = torch.cat((cls_tokens, x), dim=1)
62
+
63
+ if self.training and os.getenv('RoPE') == '1':
64
+ return x, patch_indices_keep
65
+
66
+ return x
67
+
68
+ class DropPath(nn.Module):
69
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
70
+ """
71
+ def __init__(self, drop_prob=None):
72
+ super(DropPath, self).__init__()
73
+ self.drop_prob = drop_prob
74
+
75
+ def forward(self, x):
76
+ return drop_path(x, self.drop_prob, self.training)
77
+
78
+ def extra_repr(self) -> str:
79
+ return 'p={}'.format(self.drop_prob)
80
+
81
+
82
+ class Mlp(nn.Module):
83
+ def __init__(
84
+ self,
85
+ in_features,
86
+ hidden_features=None,
87
+ out_features=None,
88
+ act_layer=nn.GELU,
89
+ norm_layer=nn.LayerNorm,
90
+ drop=0.,
91
+ subln=False,
92
+
93
+ ):
94
+ super().__init__()
95
+ out_features = out_features or in_features
96
+ hidden_features = hidden_features or in_features
97
+ self.fc1 = nn.Linear(in_features, hidden_features)
98
+ self.act = act_layer()
99
+
100
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
101
+
102
+ self.fc2 = nn.Linear(hidden_features, out_features)
103
+ self.drop = nn.Dropout(drop)
104
+
105
+ def forward(self, x):
106
+ x = self.fc1(x)
107
+ x = self.act(x)
108
+ # x = self.drop(x)
109
+ # commit this for the orignal BERT implement
110
+ x = self.ffn_ln(x)
111
+
112
+ x = self.fc2(x)
113
+ x = self.drop(x)
114
+ return x
115
+
116
+ class SwiGLU(nn.Module):
117
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
118
+ norm_layer=nn.LayerNorm, subln=False):
119
+ super().__init__()
120
+ out_features = out_features or in_features
121
+ hidden_features = hidden_features or in_features
122
+
123
+ self.w1 = nn.Linear(in_features, hidden_features)
124
+ self.w2 = nn.Linear(in_features, hidden_features)
125
+
126
+ self.act = act_layer()
127
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
128
+ self.w3 = nn.Linear(hidden_features, out_features)
129
+
130
+ self.drop = nn.Dropout(drop)
131
+
132
+ def forward(self, x):
133
+ x1 = self.w1(x)
134
+ x2 = self.w2(x)
135
+ hidden = self.act(x1) * x2
136
+ x = self.ffn_ln(hidden)
137
+ x = self.w3(x)
138
+ x = self.drop(x)
139
+ return x
140
+
141
+ class Attention(nn.Module):
142
+ def __init__(
143
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
144
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
145
+ super().__init__()
146
+ self.num_heads = num_heads
147
+ head_dim = dim // num_heads
148
+ if attn_head_dim is not None:
149
+ head_dim = attn_head_dim
150
+ all_head_dim = head_dim * self.num_heads
151
+ self.scale = qk_scale or head_dim ** -0.5
152
+
153
+ self.subln = subln
154
+ if self.subln:
155
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
156
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
157
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
158
+ else:
159
+ if qkv_bias:
160
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True)
161
+ else:
162
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
163
+
164
+ # if qkv_bias:
165
+ # self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
166
+ # self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
167
+ # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
168
+ # self.qkv.bias.data = qkv_bias
169
+ # else:
170
+ # self.q_bias = None
171
+ # self.v_bias = None
172
+
173
+ self.window_size = None
174
+ self.relative_position_bias_table = None
175
+ self.relative_position_index = None
176
+
177
+ self.attn_drop = nn.Dropout(attn_drop)
178
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
179
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
180
+ self.proj = nn.Linear(all_head_dim, dim)
181
+ self.proj_drop = nn.Dropout(proj_drop)
182
+ self.xattn = xattn
183
+ self.xattn_drop = attn_drop
184
+
185
+ self.rope = rope
186
+
187
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
188
+ B, N, C = x.shape
189
+ if self.subln:
190
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
191
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
192
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
193
+
194
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
195
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
196
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
197
+ else:
198
+
199
+ # qkv_bias = None
200
+ # if self.q_bias is not None:
201
+ # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
202
+
203
+ # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
204
+
205
+ qkv = self.qkv(x)
206
+
207
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
208
+ q, k, v = qkv[0], qkv[1], qkv[2]
209
+
210
+ if self.rope:
211
+ q_t = q[:, :, 1:, :]
212
+ ro_q_t = self.rope(q_t)
213
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
214
+
215
+ k_t = k[:, :, 1:, :]
216
+ ro_k_t = self.rope(k_t)
217
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
218
+
219
+ if self.xattn:
220
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
221
+ k = k.permute(0, 2, 1, 3)
222
+ v = v.permute(0, 2, 1, 3)
223
+
224
+ x = xops.memory_efficient_attention(
225
+ q, k, v,
226
+ p=self.xattn_drop,
227
+ scale=self.scale,
228
+ )
229
+ x = x.reshape(B, N, -1)
230
+ x = self.inner_attn_ln(x)
231
+ x = self.proj(x)
232
+ x = self.proj_drop(x)
233
+ else:
234
+ q = q * self.scale
235
+ attn = (q @ k.transpose(-2, -1))
236
+
237
+ if self.relative_position_bias_table is not None:
238
+ relative_position_bias = \
239
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
240
+ self.window_size[0] * self.window_size[1] + 1,
241
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
242
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
243
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
244
+
245
+ if rel_pos_bias is not None:
246
+ attn = attn + rel_pos_bias.type_as(attn)
247
+
248
+ if attn_mask is not None:
249
+ attn_mask = attn_mask.bool()
250
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
251
+
252
+ attn = attn.softmax(dim=-1)
253
+ attn = self.attn_drop(attn)
254
+
255
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
256
+ x = self.inner_attn_ln(x)
257
+ x = self.proj(x)
258
+ x = self.proj_drop(x)
259
+ return x
260
+
261
+
262
+ class Block(nn.Module):
263
+
264
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
265
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
266
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
267
+ subln=False, naiveswiglu=False):
268
+ super().__init__()
269
+ self.norm1 = norm_layer(dim)
270
+ self.attn = Attention(
271
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
272
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
273
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
274
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
275
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
276
+ self.norm2 = norm_layer(dim)
277
+ mlp_hidden_dim = int(dim * mlp_ratio)
278
+
279
+ if naiveswiglu:
280
+ self.mlp = SwiGLU(
281
+ in_features=dim,
282
+ hidden_features=mlp_hidden_dim,
283
+ subln=subln,
284
+ norm_layer=norm_layer,
285
+ )
286
+ else:
287
+ self.mlp = Mlp(
288
+ in_features=dim,
289
+ hidden_features=mlp_hidden_dim,
290
+ act_layer=act_layer,
291
+ subln=subln,
292
+ drop=drop
293
+ )
294
+
295
+ if init_values is not None and init_values > 0:
296
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
297
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
298
+ else:
299
+ self.gamma_1, self.gamma_2 = None, None
300
+
301
+ self.postnorm = postnorm
302
+
303
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
304
+ if self.gamma_1 is None:
305
+ if self.postnorm:
306
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
307
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
308
+ else:
309
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
310
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
311
+ else:
312
+ if self.postnorm:
313
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
314
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
315
+ else:
316
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
317
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
318
+ return x
319
+
320
+
321
+ class PatchEmbed(nn.Module):
322
+ """ Image to Patch Embedding
323
+ """
324
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
325
+ super().__init__()
326
+ img_size = to_2tuple(img_size)
327
+ patch_size = to_2tuple(patch_size)
328
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
329
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
330
+ self.img_size = img_size
331
+ self.patch_size = patch_size
332
+ self.num_patches = num_patches
333
+
334
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
335
+
336
+ def forward(self, x, **kwargs):
337
+ B, C, H, W = x.shape
338
+ # FIXME look at relaxing size constraints
339
+ assert H == self.img_size[0] and W == self.img_size[1], \
340
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
341
+ x = self.proj(x).flatten(2).transpose(1, 2)
342
+ return x
343
+
344
+
345
+ class EVAVisionTransformer(nn.Module):
346
+ """ Vision Transformer with support for patch or hybrid CNN input stage
347
+ """
348
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
349
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
350
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
351
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
352
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
353
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False,
354
+ ):
355
+ super().__init__()
356
+ self.image_size = img_size
357
+ # self.num_classes = num_classes
358
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
359
+
360
+ self.patch_embed = PatchEmbed(
361
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
362
+ num_patches = self.patch_embed.num_patches
363
+
364
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
365
+ if use_abs_pos_emb:
366
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
367
+ else:
368
+ self.pos_embed = None
369
+ self.pos_drop = nn.Dropout(p=drop_rate)
370
+
371
+ self.rel_pos_bias = None
372
+ self.rope = None
373
+
374
+ self.naiveswiglu = naiveswiglu
375
+
376
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
377
+ self.use_rel_pos_bias = use_rel_pos_bias
378
+ self.blocks = nn.ModuleList([
379
+ Block(
380
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
381
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
382
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
383
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
384
+ for i in range(depth)])
385
+
386
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
387
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
388
+
389
+ self.grad_checkpointing = grad_checkpointing
390
+
391
+
392
+ def get_num_layers(self):
393
+ return len(self.blocks)
394
+
395
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
396
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
397
+ for param in self.parameters():
398
+ param.requires_grad = False
399
+
400
+ @torch.jit.ignore
401
+ def set_grad_checkpointing(self, enable=True):
402
+ self.grad_checkpointing = enable
403
+
404
+ @torch.jit.ignore
405
+ def no_weight_decay(self):
406
+ return {'pos_embed', 'cls_token'}
407
+
408
+
409
+ def forward_features(self, x):
410
+ x = self.patch_embed(x)
411
+ batch_size, seq_len, _ = x.size()
412
+
413
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
414
+ x = torch.cat((cls_tokens, x), dim=1)
415
+ if self.pos_embed is not None:
416
+ x = x + self.pos_embed
417
+ x = self.pos_drop(x)
418
+
419
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
420
+ if os.getenv('RoPE') == '1':
421
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
422
+ x, patch_indices_keep = self.patch_dropout(x)
423
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
424
+ else:
425
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
426
+ x = self.patch_dropout(x)
427
+ else:
428
+ x = self.patch_dropout(x)
429
+
430
+ rel_pos_bias = None
431
+
432
+ for blk in self.blocks:
433
+ if self.grad_checkpointing:
434
+ x = checkpoint(blk, x, (rel_pos_bias,))
435
+ else:
436
+ x = blk(x, rel_pos_bias=rel_pos_bias)
437
+
438
+ return x
439
+
440
+ def forward(self, x):
441
+
442
+ """
443
+ :return:
444
+ forward_features function returns raw features of ViT,
445
+ forward with return_all_features returns normalized features of ViT
446
+ :param x:
447
+ :param return_all_features:
448
+ """
449
+
450
+ features = self.forward_features(x) # [B, n_patch, C]
451
+
452
+ return features