model_files
#1
by
ManishThota
- opened
- README.md +1 -60
- README_old.md +0 -62
- config.json +4 -4
- configuration_imp.py +0 -175
- model.safetensors +0 -3
- modeling_imp.py +0 -1262
- pytorch_model.bin +0 -3
- tokenizer.json +0 -0
- vision_encoder.py +0 -593
- vocab.json +0 -0
README.md
CHANGED
@@ -1,62 +1,3 @@
|
|
1 |
---
|
2 |
-
license:
|
3 |
-
language:
|
4 |
-
- en
|
5 |
-
metrics:
|
6 |
-
- bleu
|
7 |
-
tags:
|
8 |
-
- endpoints
|
9 |
-
- text-generation-inference
|
10 |
-
inference: true
|
11 |
---
|
12 |
-
|
13 |
-
<h3 align='center' style='font-size: 24px;'>Blazzing Fast Tiny Vision Language Model</h3>
|
14 |
-
|
15 |
-
|
16 |
-
<p align='center', style='font-size: 16px;' >A Custom 3B parameter Model. Built by <a href="https://www.linkedin.com/in/manishkumarthota/">@Manish</a> The model is released for research purposes only, commercial use is not allowed. </p>
|
17 |
-
|
18 |
-
## How to use
|
19 |
-
|
20 |
-
|
21 |
-
**Install dependencies**
|
22 |
-
```bash
|
23 |
-
pip install transformers # latest version is ok, but we recommend v4.31.0
|
24 |
-
pip install -q pillow accelerate einops
|
25 |
-
```
|
26 |
-
|
27 |
-
You can use the following code for model inference. The format of text instruction is similar to [LLaVA](https://github.com/haotian-liu/LLaVA).
|
28 |
-
|
29 |
-
```Python
|
30 |
-
import torch
|
31 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
32 |
-
from PIL import Image
|
33 |
-
|
34 |
-
torch.set_default_device("cuda")
|
35 |
-
|
36 |
-
#Create model
|
37 |
-
model = AutoModelForCausalLM.from_pretrained(
|
38 |
-
"ManishThota/CustomModel",
|
39 |
-
torch_dtype=torch.float16,
|
40 |
-
device_map="auto",
|
41 |
-
trust_remote_code=True)
|
42 |
-
tokenizer = AutoTokenizer.from_pretrained("ManishThota/CustomModel", trust_remote_code=True)
|
43 |
-
|
44 |
-
#function to generate the answer
|
45 |
-
def predict(question, image_path):
|
46 |
-
#Set inputs
|
47 |
-
text = f"USER: <image>\n{question}? ASSISTANT:"
|
48 |
-
image = Image.open(image_path)
|
49 |
-
|
50 |
-
input_ids = tokenizer(text, return_tensors='pt').input_ids.to('cuda')
|
51 |
-
image_tensor = model.image_preprocess(image)
|
52 |
-
|
53 |
-
#Generate the answer
|
54 |
-
output_ids = model.generate(
|
55 |
-
input_ids,
|
56 |
-
max_new_tokens=25,
|
57 |
-
images=image_tensor,
|
58 |
-
use_cache=True)[0]
|
59 |
-
|
60 |
-
return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
|
61 |
-
|
62 |
-
```
|
|
|
1 |
---
|
2 |
+
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README_old.md
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
---
|
2 |
-
license: creativeml-openrail-m
|
3 |
-
language:
|
4 |
-
- en
|
5 |
-
metrics:
|
6 |
-
- bleu
|
7 |
-
---
|
8 |
-
<h1 align='center' style='font-size: 36px; font-weight: bold;'>Sparrow</h1>
|
9 |
-
<h3 align='center' style='font-size: 24px;'>Blazzing Fast Tiny Vision Language Model</h3>
|
10 |
-
|
11 |
-
|
12 |
-
<p align="center">
|
13 |
-
<img src="https://cdn-uploads.huggingface.co/production/uploads/650c7fbb8ffe1f53bdbe1aec/DTjDSq2yG-5Cqnk6giPFq.jpeg" width="50%" height="auto"/>
|
14 |
-
</p>
|
15 |
-
|
16 |
-
<p align='center', style='font-size: 16px;' >A Custom 3B parameter Model Enhanced for Educational Contexts: This specialized model integrates slide-text pairs from machine learning classes, leveraging a unique training approach. It connects a frozen pre-trained vision encoder (SigLip) with a frozen language model (Phi-2) through an innovative projector. The model employs attention mechanisms and language modeling loss to deeply understand and generate educational content, specifically tailored to the context of machine learning education. Built by <a href="https://www.linkedin.com/in/manishkumarthota/">@Manish</a> The model is released for research purposes only, commercial use is not allowed. </p>
|
17 |
-
|
18 |
-
## How to use
|
19 |
-
|
20 |
-
|
21 |
-
**Install dependencies**
|
22 |
-
```bash
|
23 |
-
pip install transformers # latest version is ok, but we recommend v4.31.0
|
24 |
-
pip install -q pillow accelerate einops
|
25 |
-
```
|
26 |
-
|
27 |
-
You can use the following code for model inference. The format of text instruction is similar to [LLaVA](https://github.com/haotian-liu/LLaVA).
|
28 |
-
|
29 |
-
```Python
|
30 |
-
import torch
|
31 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
32 |
-
from PIL import Image
|
33 |
-
|
34 |
-
torch.set_default_device("cuda")
|
35 |
-
|
36 |
-
#Create model
|
37 |
-
model = AutoModelForCausalLM.from_pretrained(
|
38 |
-
"ManishThota/Sparrow",
|
39 |
-
torch_dtype=torch.float16,
|
40 |
-
device_map="auto",
|
41 |
-
trust_remote_code=True)
|
42 |
-
tokenizer = AutoTokenizer.from_pretrained("ManishThota/SparrowVQE", trust_remote_code=True)
|
43 |
-
|
44 |
-
#function to generate the answer
|
45 |
-
def predict(question, image_path):
|
46 |
-
#Set inputs
|
47 |
-
text = f"USER: <image>\n{question}? ASSISTANT:"
|
48 |
-
image = Image.open(image_path)
|
49 |
-
|
50 |
-
input_ids = tokenizer(text, return_tensors='pt').input_ids.to('cuda')
|
51 |
-
image_tensor = model.image_preprocess(image)
|
52 |
-
|
53 |
-
#Generate the answer
|
54 |
-
output_ids = model.generate(
|
55 |
-
input_ids,
|
56 |
-
max_new_tokens=25,
|
57 |
-
images=image_tensor,
|
58 |
-
use_cache=True)[0]
|
59 |
-
|
60 |
-
return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
|
61 |
-
|
62 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.json
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "
|
3 |
"activation_function": "gelu_new",
|
4 |
"architectures": [
|
5 |
"ImpForCausalLM"
|
6 |
],
|
7 |
"attn_pdrop": 0.0,
|
8 |
"auto_map": {
|
9 |
-
"AutoConfig": "configuration_imp.ImpConfig",
|
10 |
-
"AutoModelForCausalLM": "modeling_imp.ImpForCausalLM"
|
11 |
},
|
12 |
"embd_pdrop": 0.0,
|
13 |
"eos_token_id": 50295,
|
@@ -29,7 +29,7 @@
|
|
29 |
"mm_vision_select_feature": "patch",
|
30 |
"mm_vision_select_layer": -2,
|
31 |
"mm_vision_tower": "google/siglip-so400m-patch14-384",
|
32 |
-
"model_type": "
|
33 |
"n_embd": 2560,
|
34 |
"n_head": 32,
|
35 |
"n_head_kv": null,
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "MILVLG/imp-v1-3b",
|
3 |
"activation_function": "gelu_new",
|
4 |
"architectures": [
|
5 |
"ImpForCausalLM"
|
6 |
],
|
7 |
"attn_pdrop": 0.0,
|
8 |
"auto_map": {
|
9 |
+
"AutoConfig": "MILVLG/imp-v1-3b--configuration_imp.ImpConfig",
|
10 |
+
"AutoModelForCausalLM": "MILVLG/imp-v1-3b--modeling_imp.ImpForCausalLM"
|
11 |
},
|
12 |
"embd_pdrop": 0.0,
|
13 |
"eos_token_id": 50295,
|
|
|
29 |
"mm_vision_select_feature": "patch",
|
30 |
"mm_vision_select_layer": -2,
|
31 |
"mm_vision_tower": "google/siglip-so400m-patch14-384",
|
32 |
+
"model_type": "imp",
|
33 |
"n_embd": 2560,
|
34 |
"n_head": 32,
|
35 |
"n_head_kv": null,
|
configuration_imp.py
DELETED
@@ -1,175 +0,0 @@
|
|
1 |
-
|
2 |
-
# ------------------------------- Phi-2 ---------------------------------------------
|
3 |
-
# Copyright (c) Microsoft Corporation.
|
4 |
-
# Licensed under the MIT license.
|
5 |
-
# https://huggingface.co/google/siglip-so400m-patch14-384
|
6 |
-
#
|
7 |
-
# Copyright (c) 2022, Tri Dao, [email protected].
|
8 |
-
# Licensed under the BSD 3-Clause License.
|
9 |
-
# ------------------------------- SigLIP --------------------------------------------
|
10 |
-
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
|
11 |
-
#
|
12 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
13 |
-
# you may not use this file except in compliance with the License.
|
14 |
-
# You may obtain a copy of the License at
|
15 |
-
#
|
16 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
17 |
-
#
|
18 |
-
# Unless required by applicable law or agreed to in writing, software
|
19 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
20 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
21 |
-
# See the License for the specific language governing permissions and
|
22 |
-
# limitations under the License.
|
23 |
-
# ------------------------------- Llava ---------------------------------------------
|
24 |
-
# Copyright 2023 Haotian Liu
|
25 |
-
#
|
26 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
27 |
-
# you may not use this file except in compliance with the License.
|
28 |
-
# You may obtain a copy of the License at
|
29 |
-
#
|
30 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
31 |
-
#
|
32 |
-
# Unless required by applicable law or agreed to in writing, software
|
33 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
34 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
35 |
-
# See the License for the specific language governing permissions and
|
36 |
-
# limitations under the License.
|
37 |
-
# -----------------------------------------------------------------------------------
|
38 |
-
|
39 |
-
|
40 |
-
import os
|
41 |
-
import math
|
42 |
-
from typing import Optional, Union
|
43 |
-
|
44 |
-
from transformers import PretrainedConfig
|
45 |
-
from transformers.utils import logging
|
46 |
-
|
47 |
-
logger = logging.get_logger(__name__)
|
48 |
-
|
49 |
-
|
50 |
-
class PhiConfig(PretrainedConfig):
|
51 |
-
"""Phi configuration."""
|
52 |
-
|
53 |
-
model_type = "phi-msft"
|
54 |
-
attribute_map = {
|
55 |
-
"max_position_embeddings": "n_positions",
|
56 |
-
"hidden_size": "n_embd",
|
57 |
-
"num_attention_heads": "n_head",
|
58 |
-
"num_hidden_layers": "n_layer",
|
59 |
-
}
|
60 |
-
|
61 |
-
def __init__(
|
62 |
-
self,
|
63 |
-
vocab_size: int = 50304,
|
64 |
-
n_positions: int = 2048,
|
65 |
-
n_embd: int = 1024,
|
66 |
-
n_layer: int = 20,
|
67 |
-
n_inner: Optional[int] = None,
|
68 |
-
n_head: int = 16,
|
69 |
-
n_head_kv: Optional[int] = None,
|
70 |
-
rotary_dim: Optional[int] = 32,
|
71 |
-
activation_function: Optional[str] = "gelu_new",
|
72 |
-
flash_attn: bool = False,
|
73 |
-
flash_rotary: bool = False,
|
74 |
-
fused_dense: bool = False,
|
75 |
-
attn_pdrop: float = 0.0,
|
76 |
-
embd_pdrop: float = 0.0,
|
77 |
-
resid_pdrop: float = 0.0,
|
78 |
-
layer_norm_epsilon: float = 1e-5,
|
79 |
-
initializer_range: float = 0.02,
|
80 |
-
tie_word_embeddings: bool = False,
|
81 |
-
pad_vocab_size_multiple: int = 64,
|
82 |
-
**kwargs
|
83 |
-
) -> None:
|
84 |
-
self.vocab_size = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
85 |
-
self.n_positions = n_positions
|
86 |
-
self.n_embd = n_embd
|
87 |
-
self.n_layer = n_layer
|
88 |
-
self.n_inner = n_inner
|
89 |
-
self.n_head = n_head
|
90 |
-
self.n_head_kv = n_head_kv
|
91 |
-
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
92 |
-
self.activation_function = activation_function
|
93 |
-
self.flash_attn = flash_attn
|
94 |
-
self.flash_rotary = flash_rotary
|
95 |
-
self.fused_dense = fused_dense
|
96 |
-
self.attn_pdrop = attn_pdrop
|
97 |
-
self.embd_pdrop = embd_pdrop
|
98 |
-
self.resid_pdrop = resid_pdrop
|
99 |
-
self.layer_norm_epsilon = layer_norm_epsilon
|
100 |
-
self.initializer_range = initializer_range
|
101 |
-
|
102 |
-
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
class SiglipVisionConfig(PretrainedConfig):
|
107 |
-
|
108 |
-
model_type = "siglip_vision_model"
|
109 |
-
|
110 |
-
def __init__(
|
111 |
-
self,
|
112 |
-
hidden_size=768,
|
113 |
-
intermediate_size=3072,
|
114 |
-
num_hidden_layers=12,
|
115 |
-
num_attention_heads=12,
|
116 |
-
num_channels=3,
|
117 |
-
image_size=224,
|
118 |
-
patch_size=16,
|
119 |
-
hidden_act="gelu_pytorch_tanh",
|
120 |
-
layer_norm_eps=1e-6,
|
121 |
-
attention_dropout=0.0,
|
122 |
-
**kwargs,
|
123 |
-
):
|
124 |
-
super().__init__(**kwargs)
|
125 |
-
|
126 |
-
self.hidden_size = hidden_size
|
127 |
-
self.intermediate_size = intermediate_size
|
128 |
-
self.num_hidden_layers = num_hidden_layers
|
129 |
-
self.num_attention_heads = num_attention_heads
|
130 |
-
self.num_channels = num_channels
|
131 |
-
self.patch_size = patch_size
|
132 |
-
self.image_size = image_size
|
133 |
-
self.attention_dropout = attention_dropout
|
134 |
-
self.layer_norm_eps = layer_norm_eps
|
135 |
-
self.hidden_act = hidden_act
|
136 |
-
|
137 |
-
@classmethod
|
138 |
-
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
139 |
-
cls._set_token_in_kwargs(kwargs)
|
140 |
-
|
141 |
-
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
142 |
-
|
143 |
-
# get the vision config dict if we are loading from SiglipConfig
|
144 |
-
if config_dict.get("model_type") == "siglip":
|
145 |
-
config_dict = config_dict["vision_config"]
|
146 |
-
|
147 |
-
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
148 |
-
logger.warning(
|
149 |
-
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
150 |
-
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
151 |
-
)
|
152 |
-
|
153 |
-
return cls.from_dict(config_dict, **kwargs)
|
154 |
-
|
155 |
-
|
156 |
-
class ImpConfig(PhiConfig):
|
157 |
-
model_type = "imp"
|
158 |
-
|
159 |
-
def __init__(self, **kwargs):
|
160 |
-
super().__init__(**kwargs)
|
161 |
-
self.image_token_index = getattr(self, "image_token_index", 50296)
|
162 |
-
self.image_token = getattr(self, "image_token", "<image>")
|
163 |
-
|
164 |
-
if not hasattr(self, "vision_tower_config") and hasattr(self, "mm_vision_tower"):
|
165 |
-
vision_tower_config = SiglipVisionConfig.from_pretrained(self.mm_vision_tower)
|
166 |
-
self.vision_tower_config = vision_tower_config.to_diff_dict()
|
167 |
-
|
168 |
-
@property
|
169 |
-
def vision_tower_cfg(self):
|
170 |
-
cfg = SiglipVisionConfig.from_dict(self.vision_tower_config)
|
171 |
-
# imp-v1 only supports `patch` feature for now w/o cls token
|
172 |
-
# cfg.mm_vision_select_feature = self.mm_vision_select_feature
|
173 |
-
cfg.mm_vision_select_layer = self.mm_vision_select_layer
|
174 |
-
cfg.mm_vision_tower = self.mm_vision_tower
|
175 |
-
return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.safetensors
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:f22e7b5e04ac6d134a269cbb2d6c724aafd81bb4446b3ad567225fb93b757e75
|
3 |
-
size 6373981888
|
|
|
|
|
|
|
|
modeling_imp.py
DELETED
@@ -1,1262 +0,0 @@
|
|
1 |
-
# Copyright (c) MILVLG team.
|
2 |
-
# Licensed under the Apache 2.0 license.
|
3 |
-
#
|
4 |
-
# Some code here is copied from the project Phi-2 (https://huggingface.co/microsoft/phi-2),
|
5 |
-
# SigLIP@transformers==4.37.0.dev0 (https://huggingface.co/google/siglip-so400m-patch14-384),
|
6 |
-
# and Llava (https://github.com/haotian-liu/LLaVA), and modified by
|
7 |
-
# Zhenwei Shao ([email protected]) @ MILVLG. We thank them for their great works.
|
8 |
-
# And their original licenses and copyright should be inherited (see the statements
|
9 |
-
# in `configuration_imp.py` for more details).
|
10 |
-
|
11 |
-
|
12 |
-
# Be careful: The way how `past_key_values.seqlen_offset` is updated is modified from
|
13 |
-
# the implementation of original Phi-2. See the comments below for details.
|
14 |
-
|
15 |
-
from __future__ import annotations
|
16 |
-
import os
|
17 |
-
import math
|
18 |
-
import re
|
19 |
-
from dataclasses import dataclass, field
|
20 |
-
from typing import Any, Dict, Optional, Tuple, Union, List
|
21 |
-
from abc import ABC, abstractmethod
|
22 |
-
|
23 |
-
import torch
|
24 |
-
import torch.nn as nn
|
25 |
-
from einops import rearrange, repeat
|
26 |
-
from transformers import (
|
27 |
-
PretrainedConfig,
|
28 |
-
PreTrainedModel,
|
29 |
-
AutoConfig,
|
30 |
-
AutoModelForCausalLM
|
31 |
-
)
|
32 |
-
from transformers.activations import ACT2FN
|
33 |
-
from transformers.modeling_outputs import CausalLMOutputWithPast
|
34 |
-
import sys
|
35 |
-
from .configuration_imp import PhiConfig, ImpConfig
|
36 |
-
from .vision_encoder import VisionTower
|
37 |
-
|
38 |
-
try:
|
39 |
-
from flash_attn.bert_padding import pad_input, unpad_input
|
40 |
-
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
41 |
-
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
|
42 |
-
from flash_attn.ops.fused_dense import FusedDense
|
43 |
-
except:
|
44 |
-
pad_input, unpad_input = None, None
|
45 |
-
FlashRotaryEmbedding = None
|
46 |
-
FlashSelfAttention, FlashCrossAttention = None, None
|
47 |
-
FusedDense = None
|
48 |
-
|
49 |
-
|
50 |
-
@dataclass
|
51 |
-
class InferenceParams:
|
52 |
-
"""Inference parameters passed to model to efficiently calculate
|
53 |
-
and store context during inference.
|
54 |
-
|
55 |
-
Reference:
|
56 |
-
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
|
57 |
-
|
58 |
-
Args:
|
59 |
-
max_seqlen: Maximum sequence length.
|
60 |
-
max_batch_size: Maximum batch size.
|
61 |
-
seqlen_offset: Sequence length offset.
|
62 |
-
batch_size_offset: Batch size offset.
|
63 |
-
key_value_memory_dict: Key value memory dictionary.
|
64 |
-
lengths_per_sample: Lengths per sample.
|
65 |
-
|
66 |
-
"""
|
67 |
-
|
68 |
-
max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
|
69 |
-
|
70 |
-
max_batch_size: int = field(metadata={"help": "Maximum batch size."})
|
71 |
-
|
72 |
-
seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
|
73 |
-
|
74 |
-
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
|
75 |
-
|
76 |
-
key_value_memory_dict: Dict[str, Any] = field(
|
77 |
-
default_factory=dict, metadata={"help": "Key value memory dictionary."}
|
78 |
-
)
|
79 |
-
|
80 |
-
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
|
81 |
-
|
82 |
-
|
83 |
-
class Embedding(nn.Module):
|
84 |
-
"""Token embedding with dropout."""
|
85 |
-
|
86 |
-
def __init__(self, config: PretrainedConfig) -> None:
|
87 |
-
super().__init__()
|
88 |
-
|
89 |
-
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
90 |
-
self.drop = nn.Dropout(config.embd_pdrop)
|
91 |
-
|
92 |
-
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
|
93 |
-
input_shape = input_ids.size()
|
94 |
-
input_ids = input_ids.view(-1, input_shape[-1])
|
95 |
-
|
96 |
-
hidden_states = self.wte(input_ids)
|
97 |
-
hidden_states = self.drop(hidden_states)
|
98 |
-
|
99 |
-
return hidden_states
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
def _apply_rotary_emb(
|
104 |
-
x: torch.FloatTensor,
|
105 |
-
cos: torch.FloatTensor,
|
106 |
-
sin: torch.FloatTensor,
|
107 |
-
) -> torch.FloatTensor:
|
108 |
-
_, seqlen, _, _ = x.shape
|
109 |
-
_, rotary_dim = cos.shape
|
110 |
-
rotary_dim *= 2
|
111 |
-
|
112 |
-
x_rot = x[:, :, :, :rotary_dim]
|
113 |
-
x_pass = x[:, :, :, rotary_dim:]
|
114 |
-
|
115 |
-
x1, x2 = x_rot.chunk(2, dim=-1)
|
116 |
-
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
117 |
-
x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
|
118 |
-
|
119 |
-
x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
|
120 |
-
|
121 |
-
return torch.cat([x_rot, x_pass], axis=-1)
|
122 |
-
|
123 |
-
|
124 |
-
def _apply_rotary_emb_kv(
|
125 |
-
kv: torch.FloatTensor,
|
126 |
-
cos: torch.FloatTensor,
|
127 |
-
sin: torch.FloatTensor,
|
128 |
-
cos_k: Optional[torch.FloatTensor] = None,
|
129 |
-
sin_k: Optional[torch.FloatTensor] = None,
|
130 |
-
) -> torch.FloatTensor:
|
131 |
-
_, seqlen, _, _, _ = kv.shape
|
132 |
-
_, rotary_dim = cos.shape
|
133 |
-
rotary_dim *= 2
|
134 |
-
|
135 |
-
k_rot = kv[:, :, 0, :, :rotary_dim]
|
136 |
-
k_pass = kv[:, :, 0, :, rotary_dim:]
|
137 |
-
|
138 |
-
k1, k2 = k_rot.chunk(2, dim=-1)
|
139 |
-
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
140 |
-
k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
|
141 |
-
|
142 |
-
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
|
143 |
-
|
144 |
-
return torch.cat(
|
145 |
-
[
|
146 |
-
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
147 |
-
kv[:, :, 1:2, :, :],
|
148 |
-
],
|
149 |
-
axis=2,
|
150 |
-
)
|
151 |
-
|
152 |
-
|
153 |
-
def _apply_rotary_emb_qkv(
|
154 |
-
qkv: torch.FloatTensor,
|
155 |
-
cos: torch.FloatTensor,
|
156 |
-
sin: torch.FloatTensor,
|
157 |
-
cos_k: Optional[torch.FloatTensor] = None,
|
158 |
-
sin_k: Optional[torch.FloatTensor] = None,
|
159 |
-
) -> torch.FloatTensor:
|
160 |
-
_, seqlen, _, _, _ = qkv.shape
|
161 |
-
_, rotary_dim = cos.shape
|
162 |
-
rotary_dim *= 2
|
163 |
-
|
164 |
-
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
165 |
-
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
166 |
-
|
167 |
-
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
168 |
-
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
169 |
-
|
170 |
-
q1, q2 = q_rot.chunk(2, dim=-1)
|
171 |
-
k1, k2 = k_rot.chunk(2, dim=-1)
|
172 |
-
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
173 |
-
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
|
174 |
-
|
175 |
-
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
176 |
-
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
177 |
-
|
178 |
-
return torch.cat(
|
179 |
-
[
|
180 |
-
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
181 |
-
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
182 |
-
qkv[:, :, 2:3, :, :],
|
183 |
-
],
|
184 |
-
axis=2,
|
185 |
-
)
|
186 |
-
|
187 |
-
|
188 |
-
class RotaryEmbedding(nn.Module):
|
189 |
-
"""Rotary positional embedding (RoPE).
|
190 |
-
|
191 |
-
Reference:
|
192 |
-
RoFormer: Enhanced Transformer with Rotary Position Embedding.
|
193 |
-
https://arxiv.org/pdf/2104.09864.pdf.
|
194 |
-
|
195 |
-
"""
|
196 |
-
|
197 |
-
def __init__(
|
198 |
-
self,
|
199 |
-
dim: int,
|
200 |
-
base: int = 10000,
|
201 |
-
scale_base: Optional[float] = None,
|
202 |
-
pos_idx_in_fp32: bool = True,
|
203 |
-
max_position_embeddings: int = 2048,
|
204 |
-
device: Optional[str] = None,
|
205 |
-
**kwargs,
|
206 |
-
) -> None:
|
207 |
-
super().__init__()
|
208 |
-
|
209 |
-
if scale_base is not None:
|
210 |
-
raise NotImplementedError
|
211 |
-
|
212 |
-
self.dim = dim
|
213 |
-
self.base = float(base)
|
214 |
-
self.scale_base = scale_base
|
215 |
-
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
216 |
-
self.max_position_embeddings = max_position_embeddings
|
217 |
-
self.device = device
|
218 |
-
|
219 |
-
# Generate and save the inverse frequency buffer (non-trainable)
|
220 |
-
inv_freq = self._compute_inv_freq(device)
|
221 |
-
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
222 |
-
|
223 |
-
# Generate and save the scale buffer (non-trainable)
|
224 |
-
scale = (
|
225 |
-
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
226 |
-
if scale_base is not None
|
227 |
-
else None
|
228 |
-
)
|
229 |
-
self.register_buffer("scale", scale, persistent=False)
|
230 |
-
|
231 |
-
# Initialize cached attributes since ONNX can't rely on dynamic initialization
|
232 |
-
self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
|
233 |
-
|
234 |
-
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
|
235 |
-
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
236 |
-
|
237 |
-
def _update_cos_sin_cache(
|
238 |
-
self,
|
239 |
-
seqlen: int,
|
240 |
-
device: Optional[str] = None,
|
241 |
-
dtype: Optional[torch.dtype] = None,
|
242 |
-
) -> None:
|
243 |
-
self._seq_len_cached = seqlen
|
244 |
-
|
245 |
-
# fp32 is preferred since the output of `torch.arange` can be quite large
|
246 |
-
# and bf16 would lose a lot of precision
|
247 |
-
if self.pos_idx_in_fp32:
|
248 |
-
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
249 |
-
if self.inv_freq.dtype != torch.float32:
|
250 |
-
inv_freq = self._compute_inv_freq(device=device)
|
251 |
-
else:
|
252 |
-
inv_freq = self.inv_freq
|
253 |
-
else:
|
254 |
-
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
255 |
-
inv_freq = self.inv_freq
|
256 |
-
|
257 |
-
# `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
|
258 |
-
freqs = torch.outer(t, inv_freq)
|
259 |
-
if self.scale is None:
|
260 |
-
self._cos_cached = torch.cos(freqs).to(dtype)
|
261 |
-
self._sin_cached = torch.sin(freqs).to(dtype)
|
262 |
-
else:
|
263 |
-
power = (
|
264 |
-
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
265 |
-
) / self.scale_base
|
266 |
-
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
267 |
-
|
268 |
-
# Force the scale multiplication to happen in fp32
|
269 |
-
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
270 |
-
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
271 |
-
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
272 |
-
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
273 |
-
|
274 |
-
def forward(
|
275 |
-
self,
|
276 |
-
qkv: torch.Tensor,
|
277 |
-
kv: Optional[torch.Tensor] = None,
|
278 |
-
seqlen_offset: int = 0,
|
279 |
-
**kwargs,
|
280 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
281 |
-
if (
|
282 |
-
self._seq_len_cached < qkv.shape[1] + seqlen_offset
|
283 |
-
or self._cos_cached.device != qkv.device
|
284 |
-
or self._cos_cached.dtype != qkv.dtype
|
285 |
-
or (self.training and self._cos_cached.is_inference())
|
286 |
-
):
|
287 |
-
self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
288 |
-
|
289 |
-
if kv is None:
|
290 |
-
return _apply_rotary_emb_qkv(
|
291 |
-
qkv,
|
292 |
-
self._cos_cached[seqlen_offset:],
|
293 |
-
self._sin_cached[seqlen_offset:],
|
294 |
-
)
|
295 |
-
else:
|
296 |
-
q = _apply_rotary_emb(
|
297 |
-
qkv,
|
298 |
-
self._cos_cached[seqlen_offset:],
|
299 |
-
self._sin_cached[seqlen_offset:],
|
300 |
-
)
|
301 |
-
kv = _apply_rotary_emb_kv(
|
302 |
-
kv,
|
303 |
-
self._cos_cached[seqlen_offset:],
|
304 |
-
self._sin_cached[seqlen_offset:],
|
305 |
-
)
|
306 |
-
|
307 |
-
return q, kv
|
308 |
-
|
309 |
-
|
310 |
-
class MLP(nn.Module):
|
311 |
-
"""Multi-Layer Perceptron.
|
312 |
-
|
313 |
-
Reference:
|
314 |
-
Attention Is All You Need.
|
315 |
-
https://arxiv.org/pdf/1706.03762.pdf.
|
316 |
-
|
317 |
-
"""
|
318 |
-
|
319 |
-
def __init__(
|
320 |
-
self,
|
321 |
-
config: PretrainedConfig,
|
322 |
-
n_inner: Optional[int] = None,
|
323 |
-
act_fn: Optional[str] = None,
|
324 |
-
) -> None:
|
325 |
-
super().__init__()
|
326 |
-
|
327 |
-
act_fn = config.activation_function if act_fn is None else act_fn
|
328 |
-
|
329 |
-
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
330 |
-
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
331 |
-
|
332 |
-
self.fc1 = nn.Linear(config.n_embd, n_inner)
|
333 |
-
self.fc2 = nn.Linear(n_inner, config.n_embd)
|
334 |
-
self.act = ACT2FN[act_fn]
|
335 |
-
|
336 |
-
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
337 |
-
hidden_states = self.fc1(hidden_states)
|
338 |
-
hidden_states = self.act(hidden_states)
|
339 |
-
hidden_states = self.fc2(hidden_states)
|
340 |
-
|
341 |
-
return hidden_states
|
342 |
-
|
343 |
-
|
344 |
-
class SelfAttention(nn.Module):
|
345 |
-
"""Self-attention layer (compatible with PyTorch).
|
346 |
-
|
347 |
-
Reference:
|
348 |
-
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
349 |
-
|
350 |
-
"""
|
351 |
-
|
352 |
-
def __init__(
|
353 |
-
self,
|
354 |
-
causal: bool = True,
|
355 |
-
softmax_scale: Optional[float] = None,
|
356 |
-
attention_dropout: float = 0.0,
|
357 |
-
) -> None:
|
358 |
-
super().__init__()
|
359 |
-
|
360 |
-
self.causal = causal
|
361 |
-
self.softmax_scale = softmax_scale
|
362 |
-
self.drop = nn.Dropout(attention_dropout)
|
363 |
-
|
364 |
-
@torch.autocast("cpu", enabled=False)
|
365 |
-
@torch.autocast("cuda", enabled=False)
|
366 |
-
def forward(
|
367 |
-
self,
|
368 |
-
qkv: torch.FloatTensor,
|
369 |
-
causal: bool = None,
|
370 |
-
key_padding_mask: Optional[torch.BoolTensor] = None,
|
371 |
-
**kwargs,
|
372 |
-
) -> torch.FloatTensor:
|
373 |
-
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
374 |
-
q, k, v = qkv.unbind(dim=2)
|
375 |
-
|
376 |
-
q = q.to(torch.float32)
|
377 |
-
k = k.to(torch.float32)
|
378 |
-
|
379 |
-
causal = self.causal if causal is None else causal
|
380 |
-
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
381 |
-
|
382 |
-
# Autocast is manually disabled to avoid `torch.einsum` performing the operation
|
383 |
-
# using float16, which might lead to overflow
|
384 |
-
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
385 |
-
|
386 |
-
if key_padding_mask is not None:
|
387 |
-
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
|
388 |
-
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
389 |
-
|
390 |
-
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
391 |
-
|
392 |
-
if causal:
|
393 |
-
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
394 |
-
scores = scores + causal_mask.to(dtype=scores.dtype)
|
395 |
-
|
396 |
-
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
397 |
-
attention = self.drop(attention)
|
398 |
-
|
399 |
-
output = torch.einsum("bhts,bshd->bthd", attention, v)
|
400 |
-
|
401 |
-
return output
|
402 |
-
|
403 |
-
|
404 |
-
class CrossAttention(nn.Module):
|
405 |
-
"""Cross-attention layer (compatible with PyTorch).
|
406 |
-
|
407 |
-
Reference:
|
408 |
-
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
409 |
-
|
410 |
-
"""
|
411 |
-
|
412 |
-
def __init__(
|
413 |
-
self,
|
414 |
-
causal: bool = True,
|
415 |
-
softmax_scale: Optional[float] = None,
|
416 |
-
attention_dropout: float = 0.0,
|
417 |
-
) -> None:
|
418 |
-
super().__init__()
|
419 |
-
|
420 |
-
self.causal = causal
|
421 |
-
self.softmax_scale = softmax_scale
|
422 |
-
self.drop = nn.Dropout(attention_dropout)
|
423 |
-
|
424 |
-
@torch.autocast("cpu", enabled=False)
|
425 |
-
@torch.autocast("cuda", enabled=False)
|
426 |
-
def forward(
|
427 |
-
self,
|
428 |
-
q: torch.FloatTensor,
|
429 |
-
kv: torch.FloatTensor,
|
430 |
-
causal: bool = None,
|
431 |
-
key_padding_mask: Optional[torch.BoolTensor] = None,
|
432 |
-
**kwargs,
|
433 |
-
) -> torch.FloatTensor:
|
434 |
-
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
435 |
-
seqlen_k = kv.shape[1]
|
436 |
-
|
437 |
-
if kv.shape[3] != q.shape[2]:
|
438 |
-
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
|
439 |
-
k, v = kv.unbind(dim=2)
|
440 |
-
|
441 |
-
q = q.to(torch.float32)
|
442 |
-
k = k.to(torch.float32)
|
443 |
-
|
444 |
-
causal = self.causal if causal is None else causal
|
445 |
-
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
446 |
-
|
447 |
-
# Autocast is manually disabled to avoid `torch.einsum` performing the operation
|
448 |
-
# using float16, which might lead to overflow
|
449 |
-
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
450 |
-
|
451 |
-
if key_padding_mask is not None:
|
452 |
-
padding_mask = torch.full(
|
453 |
-
(batch_size, seqlen_k),
|
454 |
-
-10000.0,
|
455 |
-
dtype=scores.dtype,
|
456 |
-
device=scores.device,
|
457 |
-
)
|
458 |
-
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
459 |
-
|
460 |
-
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
461 |
-
|
462 |
-
if causal:
|
463 |
-
rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
|
464 |
-
cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
|
465 |
-
causal_mask = cols > rows + seqlen_k - seqlen_q
|
466 |
-
|
467 |
-
scores = scores.masked_fill(causal_mask, -10000.0)
|
468 |
-
|
469 |
-
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
470 |
-
attention = self.drop(attention)
|
471 |
-
|
472 |
-
output = torch.einsum("bhts,bshd->bthd", attention, v)
|
473 |
-
|
474 |
-
return output
|
475 |
-
|
476 |
-
|
477 |
-
def _find_mha_dims(
|
478 |
-
config: PretrainedConfig,
|
479 |
-
n_head: Optional[int] = None,
|
480 |
-
n_head_kv: Optional[int] = None,
|
481 |
-
head_dim: Optional[int] = None,
|
482 |
-
) -> Tuple[int, int]:
|
483 |
-
if n_head is None and head_dim is None:
|
484 |
-
head_dim = config.n_embd // config.n_head
|
485 |
-
n_head = config.n_head
|
486 |
-
elif n_head is None or head_dim is None:
|
487 |
-
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
488 |
-
|
489 |
-
if n_head_kv is None:
|
490 |
-
n_head_kv = getattr(config, "n_head_kv", None) or n_head
|
491 |
-
|
492 |
-
return n_head, n_head_kv, head_dim
|
493 |
-
|
494 |
-
|
495 |
-
def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
|
496 |
-
num_heads, head_dim = kv.shape[-2:]
|
497 |
-
|
498 |
-
if layer_idx not in inference_params.key_value_memory_dict:
|
499 |
-
inference_params.key_value_memory_dict[layer_idx] = torch.empty(
|
500 |
-
inference_params.max_batch_size,
|
501 |
-
inference_params.max_seqlen,
|
502 |
-
2,
|
503 |
-
num_heads,
|
504 |
-
head_dim,
|
505 |
-
dtype=kv.dtype,
|
506 |
-
device=kv.device,
|
507 |
-
)
|
508 |
-
|
509 |
-
batch_start = inference_params.batch_size_offset
|
510 |
-
batch_end = batch_start + kv.shape[0]
|
511 |
-
|
512 |
-
sequence_start = inference_params.seqlen_offset
|
513 |
-
sequence_end = sequence_start + kv.shape[1]
|
514 |
-
|
515 |
-
# When the current sequence length is equal to or larger than the maximum sequence length,
|
516 |
-
# we need to concatenate the current `kv` with the cached `kv` to expand its length
|
517 |
-
if sequence_end >= inference_params.max_seqlen:
|
518 |
-
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
|
519 |
-
|
520 |
-
inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
521 |
-
kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
|
522 |
-
|
523 |
-
return kv
|
524 |
-
|
525 |
-
|
526 |
-
class MHA(nn.Module):
|
527 |
-
"""Multi-head attention layer."""
|
528 |
-
|
529 |
-
def __init__(
|
530 |
-
self,
|
531 |
-
config: PretrainedConfig,
|
532 |
-
dtype: Optional[torch.dtype] = None,
|
533 |
-
device: Optional[str] = None,
|
534 |
-
rotary_dim: Optional[int] = None,
|
535 |
-
rotary_base: float = 10000.0,
|
536 |
-
rotary_scale_base: Optional[float] = None,
|
537 |
-
n_head: Optional[int] = None,
|
538 |
-
n_head_kv: Optional[int] = None,
|
539 |
-
head_dim: Optional[int] = None,
|
540 |
-
bias: bool = True,
|
541 |
-
causal: bool = True,
|
542 |
-
softmax_scale: Optional[float] = None,
|
543 |
-
layer_idx: Optional[int] = None,
|
544 |
-
return_residual: bool = False,
|
545 |
-
checkpointing: bool = False,
|
546 |
-
) -> None:
|
547 |
-
super().__init__()
|
548 |
-
|
549 |
-
# Rotary embedding
|
550 |
-
self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
|
551 |
-
if self.rotary_dim > 0:
|
552 |
-
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
|
553 |
-
if rotary_cls is None:
|
554 |
-
rotary_cls = RotaryEmbedding
|
555 |
-
|
556 |
-
rotary_kwargs = {}
|
557 |
-
if rotary_cls is RotaryEmbedding:
|
558 |
-
rotary_kwargs["max_position_embeddings"] = config.n_positions
|
559 |
-
|
560 |
-
self.rotary_emb = rotary_cls(
|
561 |
-
self.rotary_dim,
|
562 |
-
base=rotary_base,
|
563 |
-
scale_base=rotary_scale_base,
|
564 |
-
device=device,
|
565 |
-
**rotary_kwargs,
|
566 |
-
)
|
567 |
-
|
568 |
-
# MLP
|
569 |
-
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
|
570 |
-
config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
|
571 |
-
)
|
572 |
-
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
573 |
-
hidden_size = config.n_embd
|
574 |
-
|
575 |
-
linear_cls = FusedDense if config.fused_dense else nn.Linear
|
576 |
-
if linear_cls is None:
|
577 |
-
linear_cls = nn.Linear
|
578 |
-
|
579 |
-
self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
|
580 |
-
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
581 |
-
|
582 |
-
# Attention
|
583 |
-
attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
|
584 |
-
if attn_cls is None:
|
585 |
-
attn_cls = SelfAttention
|
586 |
-
|
587 |
-
cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
|
588 |
-
if cross_attn_cls is None:
|
589 |
-
cross_attn_cls = CrossAttention
|
590 |
-
|
591 |
-
self.inner_attn = attn_cls(
|
592 |
-
causal=causal,
|
593 |
-
softmax_scale=softmax_scale,
|
594 |
-
attention_dropout=config.attn_pdrop,
|
595 |
-
)
|
596 |
-
self.inner_cross_attn = cross_attn_cls(
|
597 |
-
causal=causal,
|
598 |
-
softmax_scale=softmax_scale,
|
599 |
-
attention_dropout=config.attn_pdrop,
|
600 |
-
)
|
601 |
-
|
602 |
-
self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
|
603 |
-
self.layer_idx = layer_idx
|
604 |
-
self.return_residual = return_residual
|
605 |
-
self.checkpointing = checkpointing
|
606 |
-
|
607 |
-
def _forward_self_attn(
|
608 |
-
self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
|
609 |
-
) -> torch.FloatTensor:
|
610 |
-
qkv = self.Wqkv(x)
|
611 |
-
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
612 |
-
|
613 |
-
if self.rotary_dim > 0:
|
614 |
-
qkv = self.rotary_emb(qkv)
|
615 |
-
|
616 |
-
if self.flash_attn:
|
617 |
-
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
618 |
-
|
619 |
-
cu_seqlens, max_seqlen = None, None
|
620 |
-
if key_padding_mask is not None:
|
621 |
-
# If `key_padding_mask` is supplied, we need to unpad the input and retrieve
|
622 |
-
# the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
|
623 |
-
qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
|
624 |
-
|
625 |
-
if self.checkpointing:
|
626 |
-
attn_output = torch.utils.checkpoint.checkpoint(
|
627 |
-
self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
|
628 |
-
)
|
629 |
-
else:
|
630 |
-
attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
|
631 |
-
|
632 |
-
# If `key_padding_mask` is supplied, we need to pad the output back to the original shape
|
633 |
-
return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
|
634 |
-
|
635 |
-
if self.checkpointing:
|
636 |
-
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask)
|
637 |
-
|
638 |
-
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
|
639 |
-
|
640 |
-
def _forward_cross_attn(
|
641 |
-
self,
|
642 |
-
x: torch.FloatTensor,
|
643 |
-
past_key_values: Optional[InferenceParams],
|
644 |
-
key_padding_mask: Optional[torch.BoolTensor],
|
645 |
-
) -> torch.FloatTensor:
|
646 |
-
batch_size = x.shape[0]
|
647 |
-
|
648 |
-
qkv = self.Wqkv(x)
|
649 |
-
|
650 |
-
q = qkv[..., : self.n_head * self.head_dim]
|
651 |
-
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
652 |
-
|
653 |
-
kv = qkv[..., self.n_head * self.head_dim :]
|
654 |
-
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
655 |
-
|
656 |
-
seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
|
657 |
-
causal = None if seqlen_offset == 0 else False
|
658 |
-
if self.rotary_dim > 0:
|
659 |
-
q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
|
660 |
-
|
661 |
-
if past_key_values is not None:
|
662 |
-
kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
|
663 |
-
|
664 |
-
if self.flash_attn:
|
665 |
-
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
666 |
-
seqlen_k = kv.shape[1]
|
667 |
-
|
668 |
-
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
|
669 |
-
None,
|
670 |
-
None,
|
671 |
-
None,
|
672 |
-
None,
|
673 |
-
)
|
674 |
-
if key_padding_mask is not None:
|
675 |
-
kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
|
676 |
-
|
677 |
-
if seqlen_q == 1:
|
678 |
-
key_padding_mask = torch.ones(batch_size, 1, device=q.device)
|
679 |
-
elif seqlen_q != seqlen_k:
|
680 |
-
key_padding_mask = key_padding_mask[:, -seqlen_q:]
|
681 |
-
|
682 |
-
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
|
683 |
-
|
684 |
-
if self.checkpointing:
|
685 |
-
attn_output = torch.utils.checkpoint.checkpoint(
|
686 |
-
self.inner_cross_attn,
|
687 |
-
q,
|
688 |
-
kv,
|
689 |
-
causal=causal,
|
690 |
-
cu_seqlens=cu_seqlens_q,
|
691 |
-
max_seqlen=max_seqlen_q,
|
692 |
-
cu_seqlens_k=cu_seqlens_k,
|
693 |
-
max_seqlen_k=max_seqlen_k,
|
694 |
-
)
|
695 |
-
else:
|
696 |
-
attn_output = self.inner_cross_attn(
|
697 |
-
q,
|
698 |
-
kv,
|
699 |
-
causal=causal,
|
700 |
-
cu_seqlens=cu_seqlens_q,
|
701 |
-
max_seqlen=max_seqlen_q,
|
702 |
-
cu_seqlens_k=cu_seqlens_k,
|
703 |
-
max_seqlen_k=max_seqlen_k,
|
704 |
-
)
|
705 |
-
|
706 |
-
return (
|
707 |
-
pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
|
708 |
-
if key_padding_mask is not None
|
709 |
-
else attn_output
|
710 |
-
)
|
711 |
-
|
712 |
-
if self.checkpointing:
|
713 |
-
return torch.utils.checkpoint.checkpoint(
|
714 |
-
self.inner_cross_attn,
|
715 |
-
q,
|
716 |
-
kv,
|
717 |
-
key_padding_mask=key_padding_mask,
|
718 |
-
causal=causal,
|
719 |
-
)
|
720 |
-
|
721 |
-
return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
|
722 |
-
|
723 |
-
def forward(
|
724 |
-
self,
|
725 |
-
x: torch.FloatTensor,
|
726 |
-
past_key_values: Optional[InferenceParams] = None,
|
727 |
-
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
728 |
-
**kwargs,
|
729 |
-
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
730 |
-
if attention_mask is not None:
|
731 |
-
attention_mask = attention_mask.bool()
|
732 |
-
else:
|
733 |
-
attention_mask = None
|
734 |
-
|
735 |
-
# MHA
|
736 |
-
if self.n_head == self.n_head_kv:
|
737 |
-
if past_key_values is None:
|
738 |
-
# If `past_key_values` are not supplied, we run self-attention
|
739 |
-
attn_output = self._forward_self_attn(x, attention_mask)
|
740 |
-
else:
|
741 |
-
# If `past_key_values` are supplied, it means that we might have cached values and
|
742 |
-
# could take advantage of cross-attention
|
743 |
-
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
|
744 |
-
# MQA / GQA
|
745 |
-
else:
|
746 |
-
# Regardless of `past_key_values` being supplied or not, it always use cross-attention
|
747 |
-
# because `q` and `kv` lengths might be different
|
748 |
-
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
|
749 |
-
|
750 |
-
output = rearrange(attn_output, "... h d -> ... (h d)")
|
751 |
-
output = self.out_proj(output)
|
752 |
-
|
753 |
-
return output if not self.return_residual else (output, x)
|
754 |
-
|
755 |
-
|
756 |
-
class ParallelBlock(nn.Module):
|
757 |
-
"""Parallel block.
|
758 |
-
|
759 |
-
This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
|
760 |
-
|
761 |
-
"""
|
762 |
-
|
763 |
-
def __init__(
|
764 |
-
self,
|
765 |
-
config: PretrainedConfig,
|
766 |
-
block_idx: Optional[int] = None,
|
767 |
-
) -> None:
|
768 |
-
super().__init__()
|
769 |
-
|
770 |
-
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
771 |
-
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
772 |
-
self.block_idx = block_idx
|
773 |
-
|
774 |
-
self.mixer = MHA(config, layer_idx=block_idx)
|
775 |
-
self.mlp = MLP(config)
|
776 |
-
|
777 |
-
def forward(
|
778 |
-
self,
|
779 |
-
hidden_states: torch.FloatTensor,
|
780 |
-
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
781 |
-
attention_mask: Optional[torch.BoolTensor] = None,
|
782 |
-
**kwargs,
|
783 |
-
) -> torch.FloatTensor:
|
784 |
-
residual = hidden_states
|
785 |
-
hidden_states = self.ln(hidden_states)
|
786 |
-
|
787 |
-
attn_outputs = self.mixer(
|
788 |
-
hidden_states,
|
789 |
-
past_key_values=past_key_values,
|
790 |
-
attention_mask=attention_mask,
|
791 |
-
)
|
792 |
-
if isinstance(attn_outputs, tuple):
|
793 |
-
attn_outputs = attn_outputs[0]
|
794 |
-
|
795 |
-
attn_outputs = self.resid_dropout(attn_outputs)
|
796 |
-
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
797 |
-
|
798 |
-
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
799 |
-
|
800 |
-
return hidden_states
|
801 |
-
|
802 |
-
|
803 |
-
class CausalLMHead(nn.Module):
|
804 |
-
"""Causal Language Modeling head.
|
805 |
-
|
806 |
-
Reference:
|
807 |
-
Improving Language Understanding by Generative Pre-Training.
|
808 |
-
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
|
809 |
-
|
810 |
-
"""
|
811 |
-
|
812 |
-
def __init__(self, config: PretrainedConfig) -> None:
|
813 |
-
super().__init__()
|
814 |
-
|
815 |
-
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
816 |
-
self.linear = nn.Linear(config.n_embd, config.vocab_size)
|
817 |
-
|
818 |
-
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
819 |
-
hidden_states = self.ln(hidden_states)
|
820 |
-
logits = self.linear(hidden_states).to(torch.float32)
|
821 |
-
|
822 |
-
return logits
|
823 |
-
|
824 |
-
|
825 |
-
class PhiPreTrainedModel(PreTrainedModel):
|
826 |
-
"""Phi pre-trained model."""
|
827 |
-
|
828 |
-
config_class = PhiConfig
|
829 |
-
base_model_prefix = "transformer"
|
830 |
-
supports_gradient_checkpointing = True
|
831 |
-
_no_split_modules = ["ParallelBlock", "CLIPEncoderLayer", "Block"]
|
832 |
-
|
833 |
-
def __init__(self, *inputs, **kwargs) -> None:
|
834 |
-
super().__init__(*inputs, **kwargs)
|
835 |
-
|
836 |
-
def _init_weights(self, module: nn.Module) -> None:
|
837 |
-
if isinstance(module, (nn.Linear,)):
|
838 |
-
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
839 |
-
if module.bias is not None:
|
840 |
-
module.bias.data.zero_()
|
841 |
-
elif isinstance(module, nn.Embedding):
|
842 |
-
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
843 |
-
if module.padding_idx is not None:
|
844 |
-
module.weight.data[module.padding_idx].zero_()
|
845 |
-
elif isinstance(module, nn.LayerNorm):
|
846 |
-
if module.bias is not None:
|
847 |
-
module.bias.data.zero_()
|
848 |
-
module.weight.data.fill_(1.0)
|
849 |
-
|
850 |
-
def prepare_inputs_for_generation(
|
851 |
-
self,
|
852 |
-
input_ids: torch.LongTensor,
|
853 |
-
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
854 |
-
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
855 |
-
**kwargs,
|
856 |
-
) -> Dict[str, Any]:
|
857 |
-
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
858 |
-
past_key_values = InferenceParams(
|
859 |
-
max_seqlen=self.config.n_positions,
|
860 |
-
max_batch_size=input_ids.shape[0],
|
861 |
-
seqlen_offset=0,
|
862 |
-
batch_size_offset=0,
|
863 |
-
key_value_memory_dict={},
|
864 |
-
lengths_per_sample=None,
|
865 |
-
)
|
866 |
-
else:
|
867 |
-
# ======================================================================
|
868 |
-
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
869 |
-
# inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...]
|
870 |
-
# past_key_values.seqlen_offset = input_ids.shape[1] - 1
|
871 |
-
# ======================================================================
|
872 |
-
# I change the way of updating `past_key_values.seqlen_offset` to make the inference of imp work.
|
873 |
-
# [Edited by zhenwei - 2024-01-20 21:15]
|
874 |
-
input_ids = input_ids[:, -1].unsqueeze(-1)
|
875 |
-
|
876 |
-
return {
|
877 |
-
"input_ids": input_ids,
|
878 |
-
"past_key_values": past_key_values,
|
879 |
-
"attention_mask": attention_mask,
|
880 |
-
}
|
881 |
-
|
882 |
-
|
883 |
-
class LlavaMetaModel(ABC):
|
884 |
-
"""
|
885 |
-
Define the APIs for building components that are related to image perceiving.
|
886 |
-
This implementation is based on the implementation from the Llave project.
|
887 |
-
"""
|
888 |
-
|
889 |
-
def get_vision_tower(self):
|
890 |
-
vision_tower = getattr(self, 'vision_tower', None)
|
891 |
-
if type(vision_tower) is list:
|
892 |
-
vision_tower = vision_tower[0]
|
893 |
-
return vision_tower
|
894 |
-
|
895 |
-
def build_vision_tower(self, config):
|
896 |
-
self.vision_tower = VisionTower(config.vision_tower_cfg)
|
897 |
-
|
898 |
-
def build_vision_projector(self, config):
|
899 |
-
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
900 |
-
|
901 |
-
if projector_type == 'linear':
|
902 |
-
self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
|
903 |
-
return
|
904 |
-
|
905 |
-
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
906 |
-
if mlp_gelu_match:
|
907 |
-
mlp_depth = int(mlp_gelu_match.group(1))
|
908 |
-
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
909 |
-
for _ in range(1, mlp_depth):
|
910 |
-
modules.append(nn.GELU())
|
911 |
-
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
912 |
-
self.mm_projector = nn.Sequential(*modules)
|
913 |
-
return
|
914 |
-
|
915 |
-
if projector_type == 'identity':
|
916 |
-
self.mm_projector = nn.Identity()
|
917 |
-
return
|
918 |
-
|
919 |
-
raise ValueError(f'Unknown projector type: {projector_type}')
|
920 |
-
|
921 |
-
|
922 |
-
class ImpModel(PhiPreTrainedModel, LlavaMetaModel):
|
923 |
-
"""Imp model. This implementation is modified from the implementation of Phi-2"""
|
924 |
-
|
925 |
-
config_class = ImpConfig
|
926 |
-
# _keys_to_ignore_on_load_missing = [""]
|
927 |
-
# _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
|
928 |
-
|
929 |
-
def __init__(self, config: ImpConfig) -> None:
|
930 |
-
super().__init__(config)
|
931 |
-
|
932 |
-
self.embd = Embedding(config)
|
933 |
-
self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
|
934 |
-
self.gradient_checkpointing = False
|
935 |
-
|
936 |
-
if hasattr(config, "mm_vision_tower"):
|
937 |
-
self.build_vision_tower(config)
|
938 |
-
self.build_vision_projector(config)
|
939 |
-
|
940 |
-
self.post_init()
|
941 |
-
|
942 |
-
def embed_tokens(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
|
943 |
-
return self.embd(input_ids)[0]
|
944 |
-
|
945 |
-
def get_input_embeddings(self) -> nn.Embedding:
|
946 |
-
return self.embd.wte
|
947 |
-
|
948 |
-
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
949 |
-
self.embd.wte = new_embeddings
|
950 |
-
|
951 |
-
def forward(
|
952 |
-
self,
|
953 |
-
input_ids: torch.LongTensor,
|
954 |
-
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
955 |
-
attention_mask: Optional[torch.BoolTensor] = None,
|
956 |
-
inputs_embeds: Optional[torch.FloatTensor] = None
|
957 |
-
) -> torch.FloatTensor:
|
958 |
-
|
959 |
-
if inputs_embeds is None:
|
960 |
-
hidden_states = self.embd(input_ids)
|
961 |
-
else:
|
962 |
-
hidden_states = inputs_embeds
|
963 |
-
|
964 |
-
for layer in self.h:
|
965 |
-
if self.gradient_checkpointing and self.training:
|
966 |
-
|
967 |
-
def create_custom_forward(module):
|
968 |
-
def custom_forward(*inputs):
|
969 |
-
# None for past_key_value
|
970 |
-
return module(*inputs)
|
971 |
-
|
972 |
-
return custom_forward
|
973 |
-
|
974 |
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
975 |
-
create_custom_forward(layer),
|
976 |
-
hidden_states,
|
977 |
-
None,
|
978 |
-
attention_mask,
|
979 |
-
)
|
980 |
-
else:
|
981 |
-
hidden_states = layer(
|
982 |
-
hidden_states,
|
983 |
-
past_key_values=past_key_values,
|
984 |
-
attention_mask=attention_mask,
|
985 |
-
)
|
986 |
-
|
987 |
-
# I change the way of updating `past_key_values.seqlen_offset` to make the inference of imp work.
|
988 |
-
# [Edited by zhenwei - 2024-01-20 21:15]
|
989 |
-
if past_key_values is not None: # FIXME: when multi-batch inference, it is a bug
|
990 |
-
past_key_values.seqlen_offset += hidden_states.shape[1]
|
991 |
-
|
992 |
-
return hidden_states
|
993 |
-
|
994 |
-
|
995 |
-
class LlavaMetaForCausalLM(ABC):
|
996 |
-
"""This implementation is based on the implementation from the Llave project."""
|
997 |
-
|
998 |
-
def init_constants(self, config):
|
999 |
-
self.IGNORE_INDEX = getattr(config, 'ignore_index', -100)
|
1000 |
-
self.IMAGE_TOKEN_INDEX = getattr(config, 'image_token_index', 50296)
|
1001 |
-
self.DEFAULT_IMAGE_TOKEN = getattr(config, 'image_token', "<image>")
|
1002 |
-
|
1003 |
-
@abstractmethod
|
1004 |
-
def get_model(self):
|
1005 |
-
pass
|
1006 |
-
|
1007 |
-
def get_vision_tower(self):
|
1008 |
-
return self.get_model().get_vision_tower()
|
1009 |
-
|
1010 |
-
def encode_images(self, images):
|
1011 |
-
image_features = self.get_model().get_vision_tower()(images)
|
1012 |
-
image_features = self.get_model().mm_projector(image_features)
|
1013 |
-
return image_features
|
1014 |
-
|
1015 |
-
def prepare_inputs_labels_for_multimodal(
|
1016 |
-
self, input_ids, position_ids, attention_mask, past_key_values, labels, images
|
1017 |
-
):
|
1018 |
-
vision_tower = self.get_vision_tower()
|
1019 |
-
# if vision_tower is None or images is None or past_key_values.seqlen_offset != 0:
|
1020 |
-
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
1021 |
-
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
|
1022 |
-
target_shape = past_key_values.seqlen_offset + 1
|
1023 |
-
# inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...]
|
1024 |
-
attention_mask = torch.cat((attention_mask, torch.ones(
|
1025 |
-
(attention_mask.shape[0], target_shape - attention_mask.shape[1]),
|
1026 |
-
dtype=attention_mask.dtype,
|
1027 |
-
device=attention_mask.device
|
1028 |
-
)), dim=1)
|
1029 |
-
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
1030 |
-
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
1031 |
-
|
1032 |
-
if type(images) is list or images.ndim == 5:
|
1033 |
-
concat_images = torch.cat([image for image in images], dim=0)
|
1034 |
-
concat_images = concat_images.to(device=self.device, dtype=vision_tower.dtype)
|
1035 |
-
image_features = self.encode_images(concat_images)
|
1036 |
-
split_sizes = [image.shape[0] for image in images]
|
1037 |
-
image_features = torch.split(image_features, split_sizes, dim=0)
|
1038 |
-
image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
|
1039 |
-
else:
|
1040 |
-
images = images.to(device=self.device, dtype=vision_tower.dtype)
|
1041 |
-
image_features = self.encode_images(images).to(self.device)
|
1042 |
-
|
1043 |
-
# TODO: image start / end is not implemented here to support pretraining.
|
1044 |
-
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
1045 |
-
raise NotImplementedError
|
1046 |
-
|
1047 |
-
# Let's just add dummy tensors if they do not exist,
|
1048 |
-
# it is a headache to deal with None all the time.
|
1049 |
-
# But it is not ideal, and if you have a better idea,
|
1050 |
-
# please open an issue / submit a PR, thanks.
|
1051 |
-
_labels = labels
|
1052 |
-
_position_ids = position_ids
|
1053 |
-
_attention_mask = attention_mask
|
1054 |
-
if attention_mask is None:
|
1055 |
-
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
1056 |
-
else:
|
1057 |
-
attention_mask = attention_mask.bool()
|
1058 |
-
if position_ids is None:
|
1059 |
-
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
1060 |
-
if labels is None:
|
1061 |
-
labels = torch.full_like(input_ids, self.IGNORE_INDEX)
|
1062 |
-
|
1063 |
-
# remove the padding using attention_mask -- TODO: double check
|
1064 |
-
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
1065 |
-
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
1066 |
-
|
1067 |
-
new_input_embeds = []
|
1068 |
-
new_labels = []
|
1069 |
-
cur_image_idx = 0
|
1070 |
-
for batch_idx, cur_input_ids in enumerate(input_ids):
|
1071 |
-
num_images = (cur_input_ids == self.IMAGE_TOKEN_INDEX).sum()
|
1072 |
-
if num_images == 0:
|
1073 |
-
cur_image_features = image_features[cur_image_idx]
|
1074 |
-
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
1075 |
-
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
1076 |
-
new_input_embeds.append(cur_input_embeds)
|
1077 |
-
new_labels.append(labels[batch_idx])
|
1078 |
-
cur_image_idx += 1
|
1079 |
-
continue
|
1080 |
-
|
1081 |
-
image_token_indices = [-1] + torch.where(cur_input_ids == self.IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
|
1082 |
-
cur_input_ids_noim = []
|
1083 |
-
cur_labels = labels[batch_idx]
|
1084 |
-
cur_labels_noim = []
|
1085 |
-
for i in range(len(image_token_indices) - 1):
|
1086 |
-
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
|
1087 |
-
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
|
1088 |
-
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
1089 |
-
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
1090 |
-
# print(cur_input_embeds.shape)
|
1091 |
-
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
1092 |
-
cur_new_input_embeds = []
|
1093 |
-
cur_new_labels = []
|
1094 |
-
|
1095 |
-
for i in range(num_images + 1):
|
1096 |
-
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
1097 |
-
cur_new_labels.append(cur_labels_noim[i])
|
1098 |
-
if i < num_images:
|
1099 |
-
cur_image_features = image_features[cur_image_idx]
|
1100 |
-
cur_image_idx += 1
|
1101 |
-
cur_new_input_embeds.append(cur_image_features)
|
1102 |
-
cur_new_labels.append(torch.full((cur_image_features.shape[0],), self.IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
1103 |
-
|
1104 |
-
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
1105 |
-
cur_new_labels = torch.cat(cur_new_labels)
|
1106 |
-
|
1107 |
-
new_input_embeds.append(cur_new_input_embeds)
|
1108 |
-
new_labels.append(cur_new_labels)
|
1109 |
-
|
1110 |
-
# Truncate sequences to max length as image embeddings can make the sequence longer
|
1111 |
-
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
|
1112 |
-
if tokenizer_model_max_length is not None:
|
1113 |
-
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
1114 |
-
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
1115 |
-
|
1116 |
-
# Combine them
|
1117 |
-
max_len = max(x.shape[0] for x in new_input_embeds)
|
1118 |
-
batch_size = len(new_input_embeds)
|
1119 |
-
|
1120 |
-
new_input_embeds_padded = []
|
1121 |
-
new_labels_padded = torch.full((batch_size, max_len), self.IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
1122 |
-
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
1123 |
-
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
1124 |
-
|
1125 |
-
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
1126 |
-
cur_len = cur_new_embed.shape[0]
|
1127 |
-
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
|
1128 |
-
new_input_embeds_padded.append(torch.cat((
|
1129 |
-
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
|
1130 |
-
cur_new_embed
|
1131 |
-
), dim=0))
|
1132 |
-
if cur_len > 0:
|
1133 |
-
new_labels_padded[i, -cur_len:] = cur_new_labels
|
1134 |
-
attention_mask[i, -cur_len:] = True
|
1135 |
-
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
1136 |
-
else:
|
1137 |
-
new_input_embeds_padded.append(torch.cat((
|
1138 |
-
cur_new_embed,
|
1139 |
-
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
|
1140 |
-
), dim=0))
|
1141 |
-
if cur_len > 0:
|
1142 |
-
new_labels_padded[i, :cur_len] = cur_new_labels
|
1143 |
-
attention_mask[i, :cur_len] = True
|
1144 |
-
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
1145 |
-
|
1146 |
-
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
1147 |
-
|
1148 |
-
if _labels is None:
|
1149 |
-
new_labels = None
|
1150 |
-
else:
|
1151 |
-
new_labels = new_labels_padded
|
1152 |
-
|
1153 |
-
if _attention_mask is None:
|
1154 |
-
attention_mask = None
|
1155 |
-
else:
|
1156 |
-
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
1157 |
-
|
1158 |
-
if _position_ids is None:
|
1159 |
-
position_ids = None
|
1160 |
-
|
1161 |
-
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
1162 |
-
|
1163 |
-
|
1164 |
-
class ImpForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM):
|
1165 |
-
"""Imp for Causal Language Modeling."""
|
1166 |
-
|
1167 |
-
# _keys_to_ignore_on_load_missing = [""]
|
1168 |
-
# _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
|
1169 |
-
config_class = ImpConfig
|
1170 |
-
|
1171 |
-
def __init__(self, config: ImpConfig) -> None:
|
1172 |
-
super().__init__(config)
|
1173 |
-
|
1174 |
-
self.transformer = ImpModel(config)
|
1175 |
-
self.lm_head = CausalLMHead(config)
|
1176 |
-
|
1177 |
-
self.post_init()
|
1178 |
-
self.init_constants(config)
|
1179 |
-
|
1180 |
-
def get_output_embeddings(self) -> nn.Linear:
|
1181 |
-
return self.lm_head.linear
|
1182 |
-
|
1183 |
-
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
|
1184 |
-
self.lm_head.linear = new_embeddings
|
1185 |
-
|
1186 |
-
def get_model(self):
|
1187 |
-
return self.transformer
|
1188 |
-
|
1189 |
-
def image_preprocess(self, images):
|
1190 |
-
return self.get_vision_tower().image_processor(images)['pixel_values']
|
1191 |
-
|
1192 |
-
def backbone_forward(
|
1193 |
-
self,
|
1194 |
-
input_ids: torch.LongTensor,
|
1195 |
-
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
1196 |
-
attention_mask: Optional[torch.BoolTensor] = None,
|
1197 |
-
labels: Optional[torch.LongTensor] = None,
|
1198 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1199 |
-
**kwargs,
|
1200 |
-
) -> CausalLMOutputWithPast:
|
1201 |
-
hidden_states = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
|
1202 |
-
lm_logits = self.lm_head(hidden_states)
|
1203 |
-
|
1204 |
-
return CausalLMOutputWithPast(loss=None, logits=lm_logits, past_key_values=past_key_values)
|
1205 |
-
|
1206 |
-
def forward(
|
1207 |
-
self,
|
1208 |
-
input_ids: torch.LongTensor = None,
|
1209 |
-
attention_mask: Optional[torch.Tensor] = None,
|
1210 |
-
position_ids: Optional[torch.LongTensor] = None,
|
1211 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1212 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1213 |
-
labels: Optional[torch.LongTensor] = None,
|
1214 |
-
use_cache: Optional[bool] = None,
|
1215 |
-
output_attentions: Optional[bool] = None,
|
1216 |
-
output_hidden_states: Optional[bool] = None,
|
1217 |
-
images: Optional[torch.FloatTensor] = None,
|
1218 |
-
return_dict: Optional[bool] = None,
|
1219 |
-
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1220 |
-
|
1221 |
-
if inputs_embeds is None:
|
1222 |
-
(
|
1223 |
-
input_ids,
|
1224 |
-
position_ids,
|
1225 |
-
attention_mask,
|
1226 |
-
past_key_values,
|
1227 |
-
inputs_embeds,
|
1228 |
-
labels
|
1229 |
-
) = self.prepare_inputs_labels_for_multimodal(
|
1230 |
-
input_ids,
|
1231 |
-
position_ids,
|
1232 |
-
attention_mask,
|
1233 |
-
past_key_values,
|
1234 |
-
labels,
|
1235 |
-
images
|
1236 |
-
)
|
1237 |
-
|
1238 |
-
return self.backbone_forward(
|
1239 |
-
input_ids=input_ids,
|
1240 |
-
attention_mask=attention_mask,
|
1241 |
-
position_ids=position_ids,
|
1242 |
-
past_key_values=past_key_values,
|
1243 |
-
inputs_embeds=inputs_embeds,
|
1244 |
-
labels=labels,
|
1245 |
-
use_cache=use_cache,
|
1246 |
-
output_attentions=output_attentions,
|
1247 |
-
output_hidden_states=output_hidden_states,
|
1248 |
-
return_dict=return_dict
|
1249 |
-
)
|
1250 |
-
|
1251 |
-
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
1252 |
-
images = kwargs.pop("images", None)
|
1253 |
-
_inputs = super().prepare_inputs_for_generation(
|
1254 |
-
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
1255 |
-
)
|
1256 |
-
if images is not None:
|
1257 |
-
_inputs['images'] = images
|
1258 |
-
return _inputs
|
1259 |
-
|
1260 |
-
|
1261 |
-
AutoConfig.register("imp", ImpConfig)
|
1262 |
-
AutoModelForCausalLM.register(ImpConfig, ImpForCausalLM)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pytorch_model.bin
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:c162cd9d0a121183d6c71232d2e8bfbcbd293e9d37b9ecfea8534800a5350efd
|
3 |
-
size 6374152890
|
|
|
|
|
|
|
|
tokenizer.json
DELETED
The diff for this file is too large to render.
See raw diff
|
|
vision_encoder.py
DELETED
@@ -1,593 +0,0 @@
|
|
1 |
-
# Copyright (c) MILVLG team.
|
2 |
-
# Licensed under the Apache 2.0 license.
|
3 |
-
#
|
4 |
-
# Some code here is copied from the project Phi-2 (https://huggingface.co/microsoft/phi-2),
|
5 |
-
# SigLIP@transformers==4.37.0.dev0 (https://huggingface.co/google/siglip-so400m-patch14-384),
|
6 |
-
# and Llava (https://github.com/haotian-liu/LLaVA), and modified by
|
7 |
-
# Zhenwei Shao ([email protected]) @ MILVLG. We thank them for their great works.
|
8 |
-
# And their original licenses and copyright should be inherited (see the statements
|
9 |
-
# in `configuration_imp.py` for more details).
|
10 |
-
|
11 |
-
|
12 |
-
from typing import Any, Optional, Tuple, Union, List, Dict
|
13 |
-
from dataclasses import dataclass
|
14 |
-
import math
|
15 |
-
import warnings
|
16 |
-
from functools import partial, reduce
|
17 |
-
|
18 |
-
|
19 |
-
import numpy as np
|
20 |
-
from PIL import Image
|
21 |
-
import torch
|
22 |
-
import torch.utils.checkpoint
|
23 |
-
from torch import nn
|
24 |
-
|
25 |
-
from transformers.image_processing_utils import BatchFeature
|
26 |
-
from transformers.image_transforms import (
|
27 |
-
convert_to_rgb,
|
28 |
-
normalize,
|
29 |
-
rescale,
|
30 |
-
resize,
|
31 |
-
to_channel_dimension_format,
|
32 |
-
)
|
33 |
-
from transformers.image_utils import (
|
34 |
-
ChannelDimension,
|
35 |
-
PILImageResampling,
|
36 |
-
to_numpy_array,
|
37 |
-
)
|
38 |
-
from transformers.activations import ACT2FN
|
39 |
-
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
40 |
-
from transformers.modeling_utils import PreTrainedModel
|
41 |
-
from transformers.utils import ModelOutput
|
42 |
-
|
43 |
-
from .configuration_imp import SiglipVisionConfig
|
44 |
-
|
45 |
-
|
46 |
-
# ============================================================================
|
47 |
-
# A simple image preprocessor for SigLIP models.
|
48 |
-
# ============================================================================
|
49 |
-
|
50 |
-
def simple_image_processor(
|
51 |
-
images,
|
52 |
-
image_mean=(0.5, 0.5, 0.5),
|
53 |
-
image_std=(0.5, 0.5, 0.5),
|
54 |
-
size=(384, 384),
|
55 |
-
resample=PILImageResampling.BICUBIC,
|
56 |
-
rescale_factor=1 / 255,
|
57 |
-
data_format=ChannelDimension.FIRST,
|
58 |
-
return_tensors="pt"
|
59 |
-
):
|
60 |
-
|
61 |
-
if isinstance(images, Image.Image):
|
62 |
-
images = [images]
|
63 |
-
else:
|
64 |
-
assert isinstance(images, list)
|
65 |
-
|
66 |
-
transforms = [
|
67 |
-
convert_to_rgb,
|
68 |
-
to_numpy_array,
|
69 |
-
partial(resize, size=size, resample=resample, data_format=data_format),
|
70 |
-
partial(rescale, scale=rescale_factor, data_format=data_format),
|
71 |
-
partial(normalize, mean=image_mean, std=image_std, data_format=data_format),
|
72 |
-
partial(to_channel_dimension_format, channel_dim=data_format, input_channel_dim=data_format),
|
73 |
-
]
|
74 |
-
|
75 |
-
images = reduce(lambda x, f: [*map(f, x)], transforms, images)
|
76 |
-
data = {"pixel_values": images}
|
77 |
-
|
78 |
-
return BatchFeature(data=data, tensor_type=return_tensors)
|
79 |
-
|
80 |
-
# ============================================================================
|
81 |
-
# Definitions for SigLIP models.
|
82 |
-
# ============================================================================
|
83 |
-
|
84 |
-
@dataclass
|
85 |
-
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
|
86 |
-
class SiglipVisionModelOutput(ModelOutput):
|
87 |
-
"""
|
88 |
-
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
|
89 |
-
|
90 |
-
Args:
|
91 |
-
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
|
92 |
-
The image embeddings obtained by applying the projection layer to the pooler_output.
|
93 |
-
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
94 |
-
Sequence of hidden-states at the output of the last layer of the model.
|
95 |
-
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
96 |
-
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
97 |
-
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
98 |
-
|
99 |
-
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
100 |
-
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
101 |
-
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
102 |
-
sequence_length)`.
|
103 |
-
|
104 |
-
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
105 |
-
heads.
|
106 |
-
"""
|
107 |
-
|
108 |
-
image_embeds: Optional[torch.FloatTensor] = None
|
109 |
-
last_hidden_state: torch.FloatTensor = None
|
110 |
-
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
111 |
-
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
112 |
-
|
113 |
-
|
114 |
-
class SiglipVisionEmbeddings(nn.Module):
|
115 |
-
def __init__(self, config: SiglipVisionConfig):
|
116 |
-
super().__init__()
|
117 |
-
self.config = config
|
118 |
-
self.embed_dim = config.hidden_size
|
119 |
-
self.image_size = config.image_size
|
120 |
-
self.patch_size = config.patch_size
|
121 |
-
|
122 |
-
self.patch_embedding = nn.Conv2d(
|
123 |
-
in_channels=config.num_channels,
|
124 |
-
out_channels=self.embed_dim,
|
125 |
-
kernel_size=self.patch_size,
|
126 |
-
stride=self.patch_size,
|
127 |
-
padding="valid",
|
128 |
-
)
|
129 |
-
|
130 |
-
self.num_patches = (self.image_size // self.patch_size) ** 2
|
131 |
-
self.num_positions = self.num_patches
|
132 |
-
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
133 |
-
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
134 |
-
|
135 |
-
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
136 |
-
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
137 |
-
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
138 |
-
|
139 |
-
embeddings = embeddings + self.position_embedding(self.position_ids)
|
140 |
-
return embeddings
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
class SiglipAttention(nn.Module):
|
145 |
-
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
146 |
-
|
147 |
-
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
|
148 |
-
def __init__(self, config):
|
149 |
-
super().__init__()
|
150 |
-
self.config = config
|
151 |
-
self.embed_dim = config.hidden_size
|
152 |
-
self.num_heads = config.num_attention_heads
|
153 |
-
self.head_dim = self.embed_dim // self.num_heads
|
154 |
-
if self.head_dim * self.num_heads != self.embed_dim:
|
155 |
-
raise ValueError(
|
156 |
-
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
157 |
-
f" {self.num_heads})."
|
158 |
-
)
|
159 |
-
self.scale = self.head_dim**-0.5
|
160 |
-
self.dropout = config.attention_dropout
|
161 |
-
|
162 |
-
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
163 |
-
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
164 |
-
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
165 |
-
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
166 |
-
|
167 |
-
def forward(
|
168 |
-
self,
|
169 |
-
hidden_states: torch.Tensor,
|
170 |
-
attention_mask: Optional[torch.Tensor] = None,
|
171 |
-
output_attentions: Optional[bool] = False,
|
172 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
173 |
-
"""Input shape: Batch x Time x Channel"""
|
174 |
-
|
175 |
-
batch_size, q_len, _ = hidden_states.size()
|
176 |
-
|
177 |
-
query_states = self.q_proj(hidden_states)
|
178 |
-
key_states = self.k_proj(hidden_states)
|
179 |
-
value_states = self.v_proj(hidden_states)
|
180 |
-
|
181 |
-
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
182 |
-
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
183 |
-
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
184 |
-
|
185 |
-
k_v_seq_len = key_states.shape[-2]
|
186 |
-
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
187 |
-
|
188 |
-
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
|
189 |
-
raise ValueError(
|
190 |
-
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
191 |
-
f" {attn_weights.size()}"
|
192 |
-
)
|
193 |
-
|
194 |
-
if attention_mask is not None:
|
195 |
-
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
196 |
-
raise ValueError(
|
197 |
-
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
|
198 |
-
)
|
199 |
-
attn_weights = attn_weights + attention_mask
|
200 |
-
|
201 |
-
# upcast attention to fp32
|
202 |
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
203 |
-
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
204 |
-
attn_output = torch.matmul(attn_weights, value_states)
|
205 |
-
|
206 |
-
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
|
207 |
-
raise ValueError(
|
208 |
-
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
|
209 |
-
f" {attn_output.size()}"
|
210 |
-
)
|
211 |
-
|
212 |
-
attn_output = attn_output.transpose(1, 2).contiguous()
|
213 |
-
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
214 |
-
|
215 |
-
attn_output = self.out_proj(attn_output)
|
216 |
-
|
217 |
-
return attn_output, attn_weights
|
218 |
-
|
219 |
-
|
220 |
-
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
|
221 |
-
class SiglipMLP(nn.Module):
|
222 |
-
def __init__(self, config):
|
223 |
-
super().__init__()
|
224 |
-
self.config = config
|
225 |
-
self.activation_fn = ACT2FN[config.hidden_act]
|
226 |
-
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
227 |
-
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
228 |
-
|
229 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
230 |
-
hidden_states = self.fc1(hidden_states)
|
231 |
-
hidden_states = self.activation_fn(hidden_states)
|
232 |
-
hidden_states = self.fc2(hidden_states)
|
233 |
-
return hidden_states
|
234 |
-
|
235 |
-
|
236 |
-
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
|
237 |
-
class SiglipEncoderLayer(nn.Module):
|
238 |
-
def __init__(self, config: SiglipVisionConfig):
|
239 |
-
super().__init__()
|
240 |
-
self.embed_dim = config.hidden_size
|
241 |
-
self.self_attn = SiglipAttention(config)
|
242 |
-
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
243 |
-
self.mlp = SiglipMLP(config)
|
244 |
-
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
245 |
-
|
246 |
-
# Ignore copy
|
247 |
-
def forward(
|
248 |
-
self,
|
249 |
-
hidden_states: torch.Tensor,
|
250 |
-
attention_mask: torch.Tensor,
|
251 |
-
output_attentions: Optional[bool] = False,
|
252 |
-
) -> Tuple[torch.FloatTensor]:
|
253 |
-
"""
|
254 |
-
Args:
|
255 |
-
hidden_states (`torch.FloatTensor`):
|
256 |
-
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
257 |
-
attention_mask (`torch.FloatTensor`):
|
258 |
-
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
259 |
-
output_attentions (`bool`, *optional*, defaults to `False`):
|
260 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
261 |
-
returned tensors for more detail.
|
262 |
-
"""
|
263 |
-
residual = hidden_states
|
264 |
-
|
265 |
-
hidden_states = self.layer_norm1(hidden_states)
|
266 |
-
hidden_states, attn_weights = self.self_attn(
|
267 |
-
hidden_states=hidden_states,
|
268 |
-
attention_mask=attention_mask,
|
269 |
-
output_attentions=output_attentions,
|
270 |
-
)
|
271 |
-
hidden_states = residual + hidden_states
|
272 |
-
|
273 |
-
residual = hidden_states
|
274 |
-
hidden_states = self.layer_norm2(hidden_states)
|
275 |
-
hidden_states = self.mlp(hidden_states)
|
276 |
-
hidden_states = residual + hidden_states
|
277 |
-
|
278 |
-
outputs = (hidden_states,)
|
279 |
-
|
280 |
-
if output_attentions:
|
281 |
-
outputs += (attn_weights,)
|
282 |
-
|
283 |
-
return outputs
|
284 |
-
|
285 |
-
|
286 |
-
class SiglipPreTrainedModel(PreTrainedModel):
|
287 |
-
"""
|
288 |
-
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
289 |
-
models.
|
290 |
-
"""
|
291 |
-
|
292 |
-
config_class = SiglipVisionConfig
|
293 |
-
base_model_prefix = "siglip"
|
294 |
-
supports_gradient_checkpointing = True
|
295 |
-
|
296 |
-
def _init_weights(self, module):
|
297 |
-
"""Initialize the weights"""
|
298 |
-
pass
|
299 |
-
|
300 |
-
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
|
301 |
-
class SiglipEncoder(nn.Module):
|
302 |
-
"""
|
303 |
-
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
304 |
-
[`SiglipEncoderLayer`].
|
305 |
-
|
306 |
-
Args:
|
307 |
-
config: SiglipVisionConfig
|
308 |
-
"""
|
309 |
-
|
310 |
-
def __init__(self, config: SiglipVisionConfig):
|
311 |
-
super().__init__()
|
312 |
-
self.config = config
|
313 |
-
self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
314 |
-
self.gradient_checkpointing = False
|
315 |
-
|
316 |
-
# Ignore copy
|
317 |
-
def forward(
|
318 |
-
self,
|
319 |
-
inputs_embeds,
|
320 |
-
attention_mask: Optional[torch.Tensor] = None,
|
321 |
-
output_attentions: Optional[bool] = None,
|
322 |
-
output_hidden_states: Optional[bool] = None,
|
323 |
-
return_dict: Optional[bool] = None,
|
324 |
-
) -> Union[Tuple, BaseModelOutput]:
|
325 |
-
r"""
|
326 |
-
Args:
|
327 |
-
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
328 |
-
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
329 |
-
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
330 |
-
than the model's internal embedding lookup matrix.
|
331 |
-
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
332 |
-
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
333 |
-
|
334 |
-
- 1 for tokens that are **not masked**,
|
335 |
-
- 0 for tokens that are **masked**.
|
336 |
-
|
337 |
-
[What are attention masks?](../glossary#attention-mask)
|
338 |
-
output_attentions (`bool`, *optional*):
|
339 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
340 |
-
returned tensors for more detail.
|
341 |
-
output_hidden_states (`bool`, *optional*):
|
342 |
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
343 |
-
for more detail.
|
344 |
-
return_dict (`bool`, *optional*):
|
345 |
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
346 |
-
"""
|
347 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
348 |
-
output_hidden_states = (
|
349 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
350 |
-
)
|
351 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
352 |
-
|
353 |
-
encoder_states = () if output_hidden_states else None
|
354 |
-
all_attentions = () if output_attentions else None
|
355 |
-
|
356 |
-
hidden_states = inputs_embeds
|
357 |
-
for encoder_layer in self.layers:
|
358 |
-
if output_hidden_states:
|
359 |
-
encoder_states = encoder_states + (hidden_states,)
|
360 |
-
if self.gradient_checkpointing and self.training:
|
361 |
-
layer_outputs = self._gradient_checkpointing_func(
|
362 |
-
encoder_layer.__call__,
|
363 |
-
hidden_states,
|
364 |
-
attention_mask,
|
365 |
-
output_attentions,
|
366 |
-
)
|
367 |
-
else:
|
368 |
-
layer_outputs = encoder_layer(
|
369 |
-
hidden_states,
|
370 |
-
attention_mask,
|
371 |
-
output_attentions=output_attentions,
|
372 |
-
)
|
373 |
-
|
374 |
-
hidden_states = layer_outputs[0]
|
375 |
-
|
376 |
-
if output_attentions:
|
377 |
-
all_attentions = all_attentions + (layer_outputs[1],)
|
378 |
-
|
379 |
-
if output_hidden_states:
|
380 |
-
encoder_states = encoder_states + (hidden_states,)
|
381 |
-
|
382 |
-
if not return_dict:
|
383 |
-
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
384 |
-
return BaseModelOutput(
|
385 |
-
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
386 |
-
)
|
387 |
-
|
388 |
-
|
389 |
-
class SiglipVisionTransformer(nn.Module):
|
390 |
-
def __init__(self, config: SiglipVisionConfig):
|
391 |
-
super().__init__()
|
392 |
-
self.config = config
|
393 |
-
embed_dim = config.hidden_size
|
394 |
-
|
395 |
-
self.embeddings = SiglipVisionEmbeddings(config)
|
396 |
-
self.encoder = SiglipEncoder(config)
|
397 |
-
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
398 |
-
self.head = SiglipMultiheadAttentionPoolingHead(config)
|
399 |
-
|
400 |
-
def forward(
|
401 |
-
self,
|
402 |
-
pixel_values,
|
403 |
-
output_attentions: Optional[bool] = None,
|
404 |
-
output_hidden_states: Optional[bool] = None,
|
405 |
-
return_dict: Optional[bool] = None,
|
406 |
-
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
407 |
-
r"""
|
408 |
-
Returns:
|
409 |
-
|
410 |
-
"""
|
411 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
412 |
-
output_hidden_states = (
|
413 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
414 |
-
)
|
415 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
416 |
-
|
417 |
-
hidden_states = self.embeddings(pixel_values)
|
418 |
-
|
419 |
-
encoder_outputs = self.encoder(
|
420 |
-
inputs_embeds=hidden_states,
|
421 |
-
output_attentions=output_attentions,
|
422 |
-
output_hidden_states=output_hidden_states,
|
423 |
-
return_dict=return_dict,
|
424 |
-
)
|
425 |
-
|
426 |
-
last_hidden_state = encoder_outputs[0]
|
427 |
-
last_hidden_state = self.post_layernorm(last_hidden_state)
|
428 |
-
|
429 |
-
pooled_output = self.head(last_hidden_state)
|
430 |
-
|
431 |
-
if not return_dict:
|
432 |
-
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
433 |
-
|
434 |
-
return BaseModelOutputWithPooling(
|
435 |
-
last_hidden_state=last_hidden_state,
|
436 |
-
pooler_output=pooled_output,
|
437 |
-
hidden_states=encoder_outputs.hidden_states,
|
438 |
-
attentions=encoder_outputs.attentions,
|
439 |
-
)
|
440 |
-
|
441 |
-
|
442 |
-
class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
443 |
-
"""Multihead Attention Pooling."""
|
444 |
-
|
445 |
-
def __init__(self, config: SiglipVisionConfig):
|
446 |
-
super().__init__()
|
447 |
-
|
448 |
-
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
449 |
-
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
|
450 |
-
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
451 |
-
self.mlp = SiglipMLP(config)
|
452 |
-
|
453 |
-
def forward(self, hidden_state):
|
454 |
-
batch_size = hidden_state.shape[0]
|
455 |
-
probe = self.probe.repeat(batch_size, 1, 1)
|
456 |
-
|
457 |
-
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
|
458 |
-
|
459 |
-
residual = hidden_state
|
460 |
-
hidden_state = self.layernorm(hidden_state)
|
461 |
-
hidden_state = residual + self.mlp(hidden_state)
|
462 |
-
|
463 |
-
return hidden_state[:, 0]
|
464 |
-
|
465 |
-
|
466 |
-
class SiglipVisionModel(SiglipPreTrainedModel):
|
467 |
-
config_class = SiglipVisionConfig
|
468 |
-
main_input_name = "pixel_values"
|
469 |
-
_no_split_modules = ["SiglipEncoderLayer"]
|
470 |
-
|
471 |
-
def __init__(self, config: SiglipVisionConfig):
|
472 |
-
super().__init__(config)
|
473 |
-
|
474 |
-
self.vision_model = SiglipVisionTransformer(config)
|
475 |
-
|
476 |
-
# Initialize weights and apply final processing
|
477 |
-
self.post_init()
|
478 |
-
|
479 |
-
def get_input_embeddings(self) -> nn.Module:
|
480 |
-
return self.vision_model.embeddings.patch_embedding
|
481 |
-
|
482 |
-
def forward(
|
483 |
-
self,
|
484 |
-
pixel_values,
|
485 |
-
output_attentions: Optional[bool] = None,
|
486 |
-
output_hidden_states: Optional[bool] = None,
|
487 |
-
return_dict: Optional[bool] = None,
|
488 |
-
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
489 |
-
r"""
|
490 |
-
Returns:
|
491 |
-
|
492 |
-
Examples:
|
493 |
-
|
494 |
-
```python
|
495 |
-
>>> from PIL import Image
|
496 |
-
>>> import requests
|
497 |
-
>>> from transformers import AutoProcessor, SiglipVisionModel
|
498 |
-
|
499 |
-
>>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
|
500 |
-
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
|
501 |
-
|
502 |
-
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
503 |
-
>>> image = Image.open(requests.get(url, stream=True).raw)
|
504 |
-
|
505 |
-
>>> inputs = processor(images=image, return_tensors="pt")
|
506 |
-
|
507 |
-
>>> outputs = model(**inputs)
|
508 |
-
>>> last_hidden_state = outputs.last_hidden_state
|
509 |
-
>>> pooled_output = outputs.pooler_output # pooled features
|
510 |
-
```"""
|
511 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
512 |
-
|
513 |
-
return self.vision_model(
|
514 |
-
pixel_values=pixel_values,
|
515 |
-
output_attentions=output_attentions,
|
516 |
-
output_hidden_states=output_hidden_states,
|
517 |
-
return_dict=return_dict,
|
518 |
-
)
|
519 |
-
|
520 |
-
|
521 |
-
# ============================================================================
|
522 |
-
# VisionTower module for Imp
|
523 |
-
# ============================================================================
|
524 |
-
|
525 |
-
class VisionTower(nn.Module):
|
526 |
-
def __init__(self, vision_tower_cfg, delay_load=False):
|
527 |
-
super().__init__()
|
528 |
-
|
529 |
-
self.is_loaded = False
|
530 |
-
|
531 |
-
self.config = vision_tower_cfg
|
532 |
-
self.vision_tower_name = vision_tower_cfg.mm_vision_tower
|
533 |
-
self.select_layer = vision_tower_cfg.mm_vision_select_layer
|
534 |
-
# self.select_feature = getattr(vision_tower_cfg, 'mm_vision_select_feature', 'patch')
|
535 |
-
|
536 |
-
self.image_processor = simple_image_processor
|
537 |
-
|
538 |
-
if not delay_load:
|
539 |
-
self.load_model()
|
540 |
-
else:
|
541 |
-
raise NotImplementedError("delay load is not implemented yet.")
|
542 |
-
|
543 |
-
def load_model(self):
|
544 |
-
if self.is_loaded:
|
545 |
-
return
|
546 |
-
|
547 |
-
# "google/siglip-so400m-patch14-384"
|
548 |
-
# self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
|
549 |
-
self.vision_tower = SiglipVisionModel(self.config)
|
550 |
-
del self.vision_tower.vision_model.encoder.layers[(self.select_layer + 1):]
|
551 |
-
self.vision_tower.vision_model.head = nn.Identity()
|
552 |
-
self.vision_tower.requires_grad_(False)
|
553 |
-
self.vision_tower.eval()
|
554 |
-
|
555 |
-
self.is_loaded = True
|
556 |
-
|
557 |
-
@torch.no_grad()
|
558 |
-
def forward(self, images):
|
559 |
-
if type(images) is list:
|
560 |
-
image_features = []
|
561 |
-
for image in images:
|
562 |
-
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
563 |
-
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
|
564 |
-
assert image_features.shape[-2] == 729
|
565 |
-
image_features.append(image_feature)
|
566 |
-
else:
|
567 |
-
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
568 |
-
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
|
569 |
-
assert image_features.shape[-2] == 729
|
570 |
-
|
571 |
-
return image_features
|
572 |
-
|
573 |
-
@property
|
574 |
-
def dummy_feature(self):
|
575 |
-
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
576 |
-
|
577 |
-
@property
|
578 |
-
def dtype(self):
|
579 |
-
for p in self.vision_tower.parameters():
|
580 |
-
return p.dtype
|
581 |
-
|
582 |
-
@property
|
583 |
-
def device(self):
|
584 |
-
for p in self.vision_tower.parameters():
|
585 |
-
return p.device
|
586 |
-
|
587 |
-
@property
|
588 |
-
def hidden_size(self):
|
589 |
-
return self.config.hidden_size
|
590 |
-
|
591 |
-
@property
|
592 |
-
def num_patches(self):
|
593 |
-
return (self.config.image_size // self.config.patch_size) ** 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vocab.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|