Gregor committed on
Commit d1c467d · verified · 1 Parent(s): 47a9186

Upload 3 files

Files changed (2)
  1. __init__.py +0 -0
  2. modeling_centurio.py +768 -0
__init__.py ADDED
File without changes
modeling_centurio.py ADDED
@@ -0,0 +1,768 @@
1
+ # ADAPTED FROM https://raw.githubusercontent.com/huggingface/transformers/main/src/transformers/models/llava/modeling_llava.py
2
+ # coding=utf-8
3
+ # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch Centurio model, adapted from the LLaVA implementation."""
17
+ import math
18
+
19
+ import logging
20
+ from dataclasses import dataclass
21
+ from functools import partial
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import timm
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from transformers import LlavaConfig, PreTrainedModel, add_start_docstrings, AutoModel, AutoModelForCausalLM, Cache, \
29
+ T5ForConditionalGeneration, HybridCache, Gemma2ForCausalLM
30
+ from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
31
+
+ from transformers.activations import ACT2FN
+ from einops import rearrange, repeat
+ from torch import einsum
37
+
38
+ from .configuration_centurio import CenturioConfig
39
+
40
+ class LlavaMLPProjector(nn.Module):
41
+ def __init__(self, config: LlavaConfig):
42
+ super().__init__()
43
+
44
+ self.linear_1 = nn.Linear(config.image_hidden_size, config.text_config.hidden_size, bias=True)
45
+ self.act = ACT2FN["gelu"]
46
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
47
+
48
+ def forward(self, image_features):
49
+ hidden_states = self.linear_1(image_features)
50
+ hidden_states = self.act(hidden_states)
51
+ hidden_states = self.linear_2(hidden_states)
52
+ return hidden_states
53
+
54
+ class LlavaMultiModalAdapter(nn.Module):
55
+ def __init__(self, config: LlavaConfig):
56
+ super().__init__()
57
+
58
+ if config.adapter_type == "window-pool":
59
+ self.adapter = WindowPoolProjector(config)
60
+ elif config.adapter_type == "window-shuffel":
61
+ self.adapter = WindowShuffelProjector(config)
62
+ elif config.adapter_type == "multiscale-pool":
63
+ self.adapter = MultiscalePoolProjector(config)
64
+ elif config.adapter_type == "multiscale-shuffel":
65
+ self.adapter = MultiscaleShuffleProjector(config)
66
+ else:
67
+ self.adapter = LlavaMLPProjector(config)
68
+
69
+ def forward(self, image_features):
70
+ return self.adapter(image_features)
71
+
72
+
73
+
74
+ class WindowMLPProjector(nn.Module):
75
+ def __init__(self, config: LlavaConfig):
76
+ super().__init__()
77
+ self.multi_scale = getattr(config, "adapter_multi_scale", 2)
78
+ self.linear_1 = nn.Linear(config.image_hidden_size, config.text_config.hidden_size, bias=True)
79
+ self.act = ACT2FN["gelu"]
80
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
81
+
82
+ def forward(self, image_features):
83
+ hidden_states = self.linear_1(image_features)
84
+ hidden_states = self.act(hidden_states)
85
+ hidden_states = self.linear_2(hidden_states)
86
+
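+ # The batch dimension packs 1 + multi_scale**2 crops per image (one global view plus
+ # multi_scale x multi_scale tiles), so this folds the per-crop token sequences back
+ # into a single sequence per image.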
87
+ windows = 1 + self.multi_scale**2
88
+ hidden_states = rearrange(hidden_states, "(b h) w d -> b (h w) d", h=windows)
89
+
90
+ return hidden_states
91
+
92
+
93
+ class WindowPoolProjector(nn.Module):
94
+ def __init__(self, config: LlavaConfig):
95
+ super().__init__()
96
+ self.multi_scale = getattr(config, "adapter_multi_scale", 2)
97
+ self.pool = nn.AdaptiveAvgPool2d(getattr(config, "adapter_pool", 8))
98
+ self.linear_1 = nn.Linear(config.image_hidden_size, config.text_config.hidden_size, bias=True)
99
+ self.act = ACT2FN["gelu"]
100
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
101
+
102
+ def forward(self, image_features):
103
+ hidden_states = self.linear_1(image_features)
104
+ hidden_states = self.act(hidden_states)
105
+ hidden_states = self.linear_2(hidden_states)
106
+
107
+ b, num_tokens, c = hidden_states.shape
108
+ h = int(math.sqrt(num_tokens))
109
+
110
+ hidden_states = rearrange(hidden_states, "b (h w) d -> b d h w", h=h, w=h)
111
+ hidden_states = self.pool(hidden_states)
112
+ hidden_states = rearrange(hidden_states, "b d h w -> b (h w) d")
113
+
114
+ windows = 1 + self.multi_scale**2
115
+ hidden_states = rearrange(hidden_states, "(b h) w d -> b (h w) d", h=windows)
116
+ return hidden_states
117
+
118
+
119
+ class WindowShuffelProjector(nn.Module):
120
+ def __init__(self, config: LlavaConfig):
121
+ super().__init__()
122
+ self.multi_scale = getattr(config, "adapter_multi_scale", 2)
123
+ self.scale_factor = getattr(config, "adapter_pool", 2)
124
+ self.pixel_unshuffle = nn.PixelUnshuffle(self.scale_factor)
125
+ self.linear_1 = nn.Linear(config.image_hidden_size*(self.scale_factor**2), config.text_config.hidden_size, bias=True)
126
+ self.act = ACT2FN["gelu"]
127
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
128
+
129
+
130
+
131
+ def forward(self, image_features):
132
+ bsz, seq, embed_dim = image_features.size()
133
+ height = width = int(seq ** 0.5)
134
+ hidden_states = rearrange(image_features, "b (w h) d -> b d w h", w=width, h=height)
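+ # nn.PixelUnshuffle trades spatial resolution for channels: the token count shrinks by
+ # scale_factor**2 while the feature dimension grows by the same factor, matching the
+ # image_hidden_size * scale_factor**2 input size of linear_1.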
135
+ hidden_states = self.pixel_unshuffle(hidden_states)
136
+ hidden_states = rearrange(hidden_states, "b d w h -> b (w h) d")
137
+
138
+ hidden_states = self.linear_1(hidden_states)
139
+ hidden_states = self.act(hidden_states)
140
+ hidden_states = self.linear_2(hidden_states)
141
+
142
+ windows = 1 + self.multi_scale ** 2
143
+ hidden_states = rearrange(hidden_states, "(b h) w d -> b (h w) d", h=windows)
144
+ return hidden_states
145
+
146
+
147
+ class MultiscalePoolProjector(nn.Module):
148
+ def __init__(self, config: LlavaConfig):
149
+ super().__init__()
150
+
151
+ self.multi_scale = getattr(config, "adapter_multi_scale", 2)
152
+ self.pool = nn.AvgPool2d(self.multi_scale)
153
+ self.linear_1 = nn.Linear(config.image_hidden_size*2, config.text_config.hidden_size, bias=True)
154
+ self.act = ACT2FN["gelu"]
155
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
156
+
157
+ def forward(self, image_features):
158
+ b, num_tokens, c = image_features.shape
159
+ h = int(math.sqrt(num_tokens))
160
+ assert h * h == num_tokens
161
+ image_features = rearrange(image_features, "b (h w) d -> b d h w", h=h, w=h)
162
+
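+ # Crops are interleaved per image along the batch axis as [global view, tile_1, ..., tile_{multi_scale**2}]:
+ # every `steps`-th entry is the low-resolution global view and the remaining entries are the
+ # high-resolution tiles, which are stitched back into a multi_scale x multi_scale grid,
+ # average-pooled to the global view's resolution, and concatenated with it channel-wise.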
163
+ steps = 1 + self.multi_scale**2
164
+ low_res_features = image_features[::steps]
165
+ high_res_features = image_features[[i for i in range(image_features.size(0)) if i%steps > 0]]
166
+
167
+ merged_features = rearrange(high_res_features, "(b m) d h w -> b d h (m w)", m=self.multi_scale)
168
+ merged_features = rearrange(merged_features, "(b m) d h w -> b d (m h) w", m=self.multi_scale)
169
+
170
+ merged_features = self.pool(merged_features)
171
+
172
+ concat_features = torch.cat([low_res_features, merged_features], dim=1)
173
+ concat_features = rearrange(concat_features, "b d h w -> b (h w) d")
174
+
175
+ hidden_states = self.linear_1(concat_features)
176
+ hidden_states = self.act(hidden_states)
177
+ hidden_states = self.linear_2(hidden_states)
178
+ return hidden_states
179
+
180
+ class MultiscaleShuffleProjector(nn.Module):
181
+ def __init__(self, config):
182
+ super().__init__()
183
+
184
+ self.multi_scale = getattr(config, "adapter_multi_scale", 2)
185
+ self.shuffle = nn.PixelUnshuffle(self.multi_scale)
186
+
187
+ inc, ouc = config.image_hidden_size*(1+self.multi_scale**2), config.text_config.hidden_size
188
+ #
189
+ self.mlp = nn.Sequential(
190
+ nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc)
191
+ )
192
+
193
+ self.dwn = nn.AvgPool2d(2) #TokenDownLayer((12, 12))
194
+ self.peg = nn.Conv2d(ouc, ouc, 3, 1, 1, bias=True, groups=ouc) #PosInjectLayer(ouc, ouc, stride=1)
195
+
196
+ def forward(self, x):
197
+ b, num_tokens, c = x.shape
198
+ h = int(math.sqrt(num_tokens))
199
+ assert h * h == num_tokens
200
+ image_features = rearrange(x, "b (h w) d -> b d h w", h=h, w=h)
201
+
202
+ steps = 1 + self.multi_scale ** 2
203
+ low_res_features = image_features[::steps]
204
+ high_res_features = image_features[[i for i in range(image_features.size(0)) if i % steps > 0]]
205
+
206
+ merged_features = rearrange(high_res_features, "(b m) d h w -> b d h (m w)", m=self.multi_scale)
207
+ merged_features = rearrange(merged_features, "(b m) d h w -> b d (m h) w", m=self.multi_scale)
208
+
209
+ merged_features = self.shuffle(merged_features)
210
+
211
+ concat_features = torch.cat([low_res_features, merged_features], dim=1)
212
+ concat_features = rearrange(concat_features, "b d h w -> b (h w) d")
213
+
214
+ x = self.mlp(concat_features)
215
+
216
+ # x = self.dwn(x)
217
+ b, num_tokens, c = x.shape
218
+ h = int(math.sqrt(num_tokens))
219
+ assert h * h == num_tokens
220
+ x = rearrange(x, "b (h w) d -> b d h w", h=h, w=h) #x.permute(0, 2, 1).reshape(b, -1, h, h)
221
+ x = self.dwn(x)
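+ # peg: a depth-wise 3x3 convolution applied with a residual connection to inject local
+ # positional information after downsampling (standing in for the PosInjectLayer noted above).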
222
+ x = self.peg(x) + x
223
+ x = rearrange(x, "b d h w -> b (h w) d") #x.flatten(2).transpose(1, 2)
224
+
225
+ return x
226
+ #
227
+
228
+ _CONFIG_FOR_DOC = "LlavaConfig"
229
+
230
+ LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
231
+ "llava-hf/llava-1.5-7b-hf",
232
+ "llava-hf/llava-1.5-13b-hf",
233
+ "llava-hf/bakLlava-v1-hf",
234
+ # See all Llava models at https://huggingface.co/models?filter=llava
235
+ ]
236
+
237
+
238
+ @dataclass
239
+ # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
240
+ class LlavaCausalLMOutputWithPast(ModelOutput):
241
+ """
242
+ Base class for Llava causal language model (or autoregressive) outputs.
243
+
244
+ Args:
245
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
246
+ Language modeling loss (for next-token prediction).
247
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
248
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
249
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
250
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
251
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
252
+
253
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
254
+ `past_key_values` input) to speed up sequential decoding.
255
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
256
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
257
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
258
+
259
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
260
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
261
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
262
+ sequence_length)`.
263
+
264
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
265
+ heads.
266
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
267
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
268
+ sequence_length, hidden_size)`.
269
+
270
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
271
+ """
272
+
273
+ loss: Optional[torch.FloatTensor] = None
274
+ logits: torch.FloatTensor = None
275
+ past_key_values: Optional[List[torch.FloatTensor]] = None
276
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
277
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
278
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
279
+ labels: Optional[torch.LongTensor] = None
280
+
281
+
282
+
283
+ LLAVA_START_DOCSTRING = r"""
284
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
285
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
286
+ etc.)
287
+
288
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
289
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
290
+ and behavior.
291
+
292
+ Parameters:
293
+ config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
294
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
295
+ load the weights associated with the model, only the configuration. Check out the
296
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
297
+ """
298
+
299
+
300
+ @add_start_docstrings(
301
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
302
+ LLAVA_START_DOCSTRING,
303
+ )
304
+ class LlavaPreTrainedModel(PreTrainedModel):
305
+ config_class = LlavaConfig
306
+ base_model_prefix = "model"
307
+ supports_gradient_checkpointing = True
308
+ _no_split_modules = ["LlavaVisionAttention"]
309
+ _skip_keys_device_placement = "past_key_values"
310
+ _supports_flash_attn_2 = True
311
+
312
+ def _init_weights(self, module):
313
+ # important: this ported version of Llava isn't meant for training from scratch - only
314
+ # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
315
+ # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
316
+ std = (
317
+ self.config.initializer_range
318
+ if hasattr(self.config, "initializer_range")
319
+ else self.config.text_config.initializer_range
320
+ )
321
+
322
+ if hasattr(module, "class_embedding"):
323
+ module.class_embedding.data.normal_(mean=0.0, std=std)
324
+
325
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
326
+ module.weight.data.normal_(mean=0.0, std=std)
327
+ if module.bias is not None:
328
+ module.bias.data.zero_()
329
+ elif isinstance(module, nn.Embedding):
330
+ module.weight.data.normal_(mean=0.0, std=std)
331
+ if module.padding_idx is not None:
332
+ module.weight.data[module.padding_idx].zero_()
333
+
334
+ @property
335
+ def _supports_sdpa(self):
336
+ """
337
+ Retrieve language_model's attribute to check whether the model supports
338
+ SDPA or not.
339
+ """
340
+ return self.language_model._supports_sdpa
341
+
342
+
343
+ LLAVA_INPUTS_DOCSTRING = r"""
344
+ Args:
345
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
346
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
347
+ it.
348
+
349
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
350
+ [`PreTrainedTokenizer.__call__`] for details.
351
+
352
+ [What are input IDs?](../glossary#input-ids)
353
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
354
+ The tensors corresponding to the input images. Pixel values can be obtained using
355
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
356
+ [`CLIPImageProcessor`] for processing images).
357
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
358
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
359
+
360
+ - 1 for tokens that are **not masked**,
361
+ - 0 for tokens that are **masked**.
362
+
363
+ [What are attention masks?](../glossary#attention-mask)
364
+
365
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
366
+ [`PreTrainedTokenizer.__call__`] for details.
367
+
368
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
369
+ `past_key_values`).
370
+
371
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
372
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
373
+ information on the default strategy.
374
+
375
+ - 1 indicates the head is **not masked**,
376
+ - 0 indicates the head is **masked**.
377
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
378
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
379
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
380
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
381
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
382
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
383
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
384
+
385
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
386
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
387
+
388
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
389
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
390
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
391
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
392
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
393
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
394
+ model's internal embedding lookup matrix.
395
+ use_cache (`bool`, *optional*):
396
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
397
+ `past_key_values`).
398
+ output_attentions (`bool`, *optional*):
399
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
400
+ tensors for more detail.
401
+ output_hidden_states (`bool`, *optional*):
402
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
403
+ more detail.
404
+ return_dict (`bool`, *optional*):
405
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
406
+ """
407
+
408
+
409
+ class CenturioForConditionalGeneration(LlavaPreTrainedModel):
410
+ config_class = CenturioConfig
411
+ _supports_cache_class = True
412
+ _supports_quantized_cache = False
413
+ _supports_static_cache = True
414
+
415
+ def __init__(self, config: CenturioConfig):
416
+ super().__init__(config)
417
+ # self.vision_tower = AutoModel.from_config(config.vision_config)
418
+ self.vision_tower = timm.create_model(
419
+ config.timm_model,
420
+ pretrained=False,
421
+ num_classes=0,
422
+ )
423
+ # https://github.com/TRI-ML/prismatic-vlms/blob/main/prismatic/models/backbones/vision/base_vision.py#L125
424
+ def unpack_tuple(fn):
425
+ def wrapper(*args, **kwargs):
426
+ result = fn(*args, **kwargs)
427
+ return result[0] if isinstance(result, tuple) or isinstance(result, list) else result
428
+
429
+ return wrapper
430
+ self.vision_tower.forward = unpack_tuple(
431
+ partial(
432
+ self.vision_tower.get_intermediate_layers, n={len(self.vision_tower.blocks) - 2}
433
+ )
434
+ )
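+ # With this override, the timm tower's forward returns the patch-token sequence from the
+ # penultimate transformer block instead of the pooled head output, following the
+ # prismatic-vlms setup linked above.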
435
+
436
+ config.image_hidden_size = self.vision_tower.embed_dim
437
+
438
+ self.multi_modal_projector = LlavaMultiModalAdapter(config)
439
+ self.vocab_size = config.text_config.vocab_size
440
+ # if getattr(config, "delay_init", False):
441
+ # self.language_model = None
442
+ # else:
443
+ self.language_model = AutoModelForCausalLM.from_config(
444
+ config.text_config, attn_implementation=config._attn_implementation, torch_dtype=config.torch_dtype,
445
+ trust_remote_code = True
446
+ )
447
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
448
+ self.post_init()
449
+
450
+
451
+
452
+ def get_input_embeddings(self):
453
+ return self.language_model.get_input_embeddings()
454
+
455
+ def set_input_embeddings(self, value):
456
+ self.language_model.set_input_embeddings(value)
457
+
458
+ def get_output_embeddings(self):
459
+ return self.language_model.get_output_embeddings()
460
+
461
+ def set_output_embeddings(self, new_embeddings):
462
+ self.language_model.set_output_embeddings(new_embeddings)
463
+
464
+ def set_decoder(self, decoder):
465
+ self.language_model.set_decoder(decoder)
466
+
467
+ def get_decoder(self):
468
+ return self.language_model.get_decoder()
469
+
470
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
471
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
472
+ # update vocab size
473
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
474
+ self.config.vocab_size = model_embeds.num_embeddings
475
+ self.vocab_size = model_embeds.num_embeddings
476
+ return model_embeds
477
+
478
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
479
+ num_images, num_image_patches, embed_dim = image_features.shape
480
+ batch_size, sequence_length = input_ids.shape
481
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
482
+ # 1. Create a mask to know where special image tokens are
483
+ special_image_token_mask = input_ids == self.config.image_token_index
484
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
485
+
486
+ # Check whether preprocessing already expanded the prompt to the full number of <image_token> placeholders, so they can be replaced directly
487
+ if torch.sum(special_image_token_mask) == image_features.shape[:-1].numel():
488
+ new_inputs_embeds = inputs_embeds.clone()
489
+ reshaped_image_hidden_states = image_features.view(-1, embed_dim)
490
+ new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
491
+
492
+ position_ids = (attention_mask.cumsum(-1) - 1).masked_fill_((attention_mask == 0), 1)
493
+
494
+ return new_inputs_embeds, attention_mask, labels, position_ids
495
+
496
+
497
+ # Compute the maximum embed dimension
498
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
499
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
500
+
501
+ # 2. Compute the positions where text should be written
502
+ # Calculate new positions for text tokens in merged image-text sequence.
503
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
504
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
505
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
506
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
507
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
508
+ if left_padding:
509
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
510
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
511
+
512
+ # 3. Create the full embedding, already padded to the maximum position
513
+ final_embedding = torch.zeros(
514
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
515
+ )
516
+ final_attention_mask = torch.zeros(
517
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
518
+ )
519
+ if labels is not None:
520
+ final_labels = torch.full(
521
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
522
+ )
523
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
524
+ # set the corresponding tensors into their correct target device.
525
+ target_device = inputs_embeds.device
526
+ batch_indices, non_image_indices, text_to_overwrite = (
527
+ batch_indices.to(target_device),
528
+ non_image_indices.to(target_device),
529
+ text_to_overwrite.to(target_device),
530
+ )
531
+ attention_mask = attention_mask.to(target_device)
532
+
533
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
534
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
535
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
536
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
537
+ if labels is not None:
538
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
539
+
540
+ # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
541
+ ## BUG: this does NOT work for models (Phi-3) that have set some embedding (padding) to be 0. Replaced with the below three lines.
542
+ # image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
543
+ image_to_overwrite = torch.ones_like(final_attention_mask)
544
+ image_to_overwrite[batch_indices, text_to_overwrite] = torch.zeros_like(attention_mask)[batch_indices, non_image_indices]
545
+ image_to_overwrite = image_to_overwrite.bool()
546
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
547
+
548
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
549
+ raise ValueError(
550
+ f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
551
+ f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
552
+ )
553
+
554
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
555
+ final_attention_mask |= image_to_overwrite
556
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
557
+
558
+ if labels is None:
559
+ final_labels = None
560
+
561
+ return final_embedding, final_attention_mask, final_labels, position_ids
562
+
563
+ @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
564
+ @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
565
+ def forward(
566
+ self,
567
+ input_ids: torch.LongTensor = None,
568
+ pixel_values: torch.FloatTensor = None,
569
+ attention_mask: Optional[torch.Tensor] = None,
570
+ position_ids: Optional[torch.LongTensor] = None,
571
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
572
+ inputs_embeds: Optional[torch.FloatTensor] = None,
573
+ labels: Optional[torch.LongTensor] = None,
574
+ use_cache: Optional[bool] = None,
575
+ cache_position: Optional[torch.LongTensor] = None,
576
+ output_attentions: Optional[bool] = None,
577
+ output_hidden_states: Optional[bool] = None,
578
+ return_dict: Optional[bool] = None,
579
+ **kwargs
580
+ ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
581
+ r"""
582
+ Args:
583
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
584
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
585
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
586
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
587
+
588
+ Returns:
589
+
590
+ """
591
+
592
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
593
+ output_hidden_states = (
594
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
595
+ )
596
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
597
+
598
+ if inputs_embeds is None:
599
+ # 1. Extract the input embeddings
600
+ inputs_embeds = self.get_input_embeddings()(input_ids)
601
+
602
+ # 2. Merge text and images
603
+ if pixel_values is not None and input_ids.shape[1] != 1:
604
+ image_outputs = self.vision_tower(pixel_values)
605
+
606
+ image_features = self.multi_modal_projector(image_outputs)
607
+ image_features = image_features.to(inputs_embeds.dtype)
608
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
609
+ image_features, inputs_embeds, input_ids, attention_mask, labels
610
+ )
611
+ if labels is None:
612
+ labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
613
+ else:
614
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
615
+ # generation with cache
616
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
617
+ if isinstance(past_key_values, Cache):
618
+ first_layer_past_key_value = past_key_values.key_cache[0][:, :, :, 0]
619
+ else:
620
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
621
+
622
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
623
+ extended_attention_mask = torch.ones(
624
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
625
+ dtype=attention_mask.dtype,
626
+ device=attention_mask.device,
627
+ )
628
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
629
+
630
+
631
+
632
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
633
+ # cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
634
+ # -target_length:
635
+ # ]
636
+
637
+ outputs = self.language_model(
638
+ attention_mask=attention_mask,
639
+ position_ids=position_ids,
640
+ past_key_values=past_key_values,
641
+ inputs_embeds=inputs_embeds,
642
+ use_cache=use_cache,
643
+ # cache_position=cache_position,
644
+ output_attentions=output_attentions,
645
+ output_hidden_states=output_hidden_states,
646
+ return_dict=return_dict,
647
+ )
648
+
649
+ logits = outputs[0]
650
+
651
+ loss = None
652
+ if labels is not None:
653
+ # Shift so that tokens < n predict n
654
+ if attention_mask is not None:
655
+ shift_attention_mask = attention_mask[..., 1:]
656
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
657
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
658
+ else:
659
+ shift_logits = logits[..., :-1, :].contiguous()
660
+ shift_labels = labels[..., 1:].contiguous()
661
+ # Flatten the tokens
662
+ loss_fct = nn.CrossEntropyLoss()
663
+ loss = loss_fct(
664
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
665
+ )
666
+
667
+ if not return_dict:
668
+ output = (logits,) + outputs[1:]
669
+ return (loss,) + output if loss is not None else output
670
+
671
+ return LlavaCausalLMOutputWithPast(
672
+ loss=loss,
673
+ logits=logits,
674
+ labels=labels,
675
+ past_key_values=outputs.past_key_values,
676
+ hidden_states=outputs.hidden_states,
677
+ attentions=outputs.attentions,
678
+ )
679
+
680
+ def prepare_inputs_for_generation(
681
+ self,
682
+ input_ids,
683
+ past_key_values=None,
684
+ inputs_embeds=None,
685
+ pixel_values=None,
686
+ attention_mask=None,
687
+ cache_position=None,
688
+ use_cache=True,
689
+ position_ids=None,
690
+ **kwargs
691
+ ):
692
+ model_inputs = self.language_model.prepare_inputs_for_generation(
693
+ input_ids,
694
+ past_key_values=past_key_values,
695
+ inputs_embeds=inputs_embeds,
696
+ attention_mask=attention_mask,
697
+ cache_position=cache_position,
698
+ **kwargs,
699
+ )
700
+ # Ugly comparison. Should use a config variable that records how many image tokens we have, like HF does.
701
+ # But we are unlikely to use >30 images in one sample or use <=30 tokens per image.
702
+ if cache_position[0] == 0:
703
+ model_inputs["pixel_values"] = pixel_values
704
+ # "legacy" mode
705
+ if (input_ids == self.config.image_token_index).sum(1).max() < 30:
706
+ if past_key_values is not None:
707
+ if isinstance(past_key_values, Cache):
708
+ # branch for Gemma2 with hybrid cache
709
+ if past_key_values.seen_tokens is None:
710
+ past_length = cache_position[0] # torch.tensor(0, device=input_ids.device)
711
+ max_cache_length = (
712
+ torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
713
+ if past_key_values.get_max_length() is not None
714
+ else None
715
+ )
716
+ cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
717
+ # old default branch
718
+ else:
719
+ cache_length = past_key_values.get_seq_length()
720
+ past_length = past_key_values.seen_tokens
721
+
722
+ else:
723
+ cache_length = past_length = past_key_values[0][0].shape[2]
724
+
725
+ # Keep only the unprocessed tokens:
726
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
727
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
728
+ # input)
729
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
730
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
731
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
732
+ # input_ids based on the past_length.
733
+ elif past_length < input_ids.shape[1]:
734
+ input_ids = input_ids[:, past_length:]
735
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
736
+ elif self.config.image_token_index in input_ids:
737
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
738
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
739
+ # older attention values, as their corresponding values are not part of the input.
740
+ # if cache_length < past_length and attention_mask is not None:
741
+ # attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
742
+ if attention_mask is not None and position_ids is None:
743
+ # create position_ids on the fly for batch generation
744
+ position_ids = attention_mask.long().cumsum(-1) - 1
745
+ position_ids.masked_fill_(attention_mask == 0, 1)
746
+ if past_key_values:
747
+ position_ids = position_ids[:, -input_ids.shape[1] :]
748
+
749
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
750
+ if inputs_embeds is not None and past_key_values is None:
751
+ model_inputs = {"inputs_embeds": inputs_embeds}
752
+ else:
753
+ model_inputs = {"input_ids": input_ids}
754
+
755
+ # if cache_position[0] == 0 or (input_ids == self.config.image_token_index).sum(1).max() > 0:
756
+ model_inputs.update(
757
+ {
758
+ "position_ids": position_ids,
759
+ "past_key_values": past_key_values,
760
+ "attention_mask": attention_mask,
761
+ "use_cache": use_cache,
762
+ "pixel_values": pixel_values,
763
+ }
764
+ )
765
+ return model_inputs
766
+
767
+ def _reorder_cache(self, *args, **kwargs):
768
+ return self.language_model._reorder_cache(*args, **kwargs)
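
For orientation, here is a minimal usage sketch for the model class added in this commit. It is a sketch under assumptions rather than part of the upload: the repository id, the image placeholder token, and the crop preprocessing below are hypothetical, and it presumes the checkpoint ships the configuration_centurio.py imported above together with a compatible tokenizer.

import torch
from transformers import AutoTokenizer

from modeling_centurio import CenturioForConditionalGeneration

repo = "<org>/<centurio-checkpoint>"  # hypothetical repository id
model = CenturioForConditionalGeneration.from_pretrained(repo, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(repo)

# input_ids must contain config.image_token_index placeholders, which
# _merge_input_ids_with_image_features replaces with the projected image features;
# pixel_values stacks the crops expected by the timm vision tower
# (1 global view + multi_scale**2 tiles per image).
prompt = "<image_placeholder>\nDescribe the image."  # placeholder token name is hypothetical
inputs = tokenizer(prompt, return_tensors="pt")
pixel_values = ...  # (num_crops, 3, H, W) tensor from the matching image preprocessor

output_ids = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    pixel_values=pixel_values,
    max_new_tokens=64,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))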