Use Flash Attention for CLIP image encoder (#52)
Browse files
- add flash attention support for CLIP (5d83217d777dcbf324206dba8b723231761f8a29)
Co-authored-by: Yen-Chun Chen <[email protected]>
- image_embedding_phi3_v.py +62 -17
image_embedding_phi3_v.py
CHANGED
@@ -13,13 +13,18 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
|
16 |
-
import
|
|
|
17 |
import torch
|
18 |
-
|
19 |
-
from transformers import CLIPVisionModel, PretrainedConfig
|
20 |
-
from transformers import
|
21 |
from transformers.utils import logging
|
22 |
-
|
|
|
|
|
|
|
|
|
23 |
|
24 |
logger = logging.get_logger(__name__)
|
25 |
|
@@ -37,9 +42,42 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
|
|
37 |
num_channels=3,
|
38 |
num_hidden_layers=24,
|
39 |
patch_size=14,
|
40 |
-
projection_dim=768
|
41 |
)
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
class Phi3ImageEmbedding(nn.Module):
|
44 |
"""Phi3 Image embedding."""
|
45 |
|
@@ -65,6 +103,13 @@ class Phi3ImageEmbedding(nn.Module):
|
|
65 |
self.img_processor = CLIPVisionModel(clip_config)
|
66 |
image_dim_out = config.img_processor['image_dim_out']
|
67 |
self.num_img_tokens = config.img_processor['num_img_tokens']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
else:
|
69 |
raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented')
|
70 |
|
@@ -157,15 +202,15 @@ class Phi3ImageEmbedding(nn.Module):
|
|
157 |
|
158 |
with torch.no_grad():
|
159 |
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False)
|
160 |
-
|
161 |
select = False
|
162 |
|
163 |
-
if isinstance(self.img_projection, nn.Sequential):
|
164 |
-
target_device = self.img_projection[0].bias.device
|
165 |
-
target_dtype = self.img_projection[0].bias.dtype
|
166 |
-
else: # It's a single nn.Linear layer
|
167 |
-
target_device = self.img_projection.bias.device
|
168 |
-
target_dtype = self.img_projection.bias.dtype
|
169 |
|
170 |
if len(positions.tolist()) > 0:
|
171 |
with torch.no_grad():
|
@@ -197,7 +242,7 @@ class Phi3ImageEmbedding(nn.Module):
|
|
197 |
img_sizes = img_sizes.view(-1, 2)
|
198 |
for _bs in range(bs):
|
199 |
h, w = img_sizes[_bs]
|
200 |
-
h = h // 336
|
201 |
w = w // 336
|
202 |
B_ = h * w
|
203 |
|
@@ -235,7 +280,7 @@ class Phi3ImageEmbedding(nn.Module):
|
|
235 |
temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
|
236 |
assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}'
|
237 |
output_len.append(temp_len)
|
238 |
-
|
239 |
num_img_tokens = output_len
|
240 |
img_set_tensor = []
|
241 |
for _output_img in output_imgs:
|
@@ -267,10 +312,10 @@ class Phi3ImageEmbedding(nn.Module):
|
|
267 |
else:
|
268 |
raise NotImplementedError
|
269 |
select = True
|
270 |
-
|
271 |
with torch.no_grad():
|
272 |
input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
|
273 |
-
|
274 |
hidden_states = self.wte(input_ids)
|
275 |
|
276 |
if select:
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
|
16 |
+
from datetime import datetime
|
17 |
+
|
18 |
import torch
|
19 |
+
from torch import nn
|
20 |
+
from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
|
21 |
+
from transformers.models.clip.modeling_clip import CLIPAttention
|
22 |
from transformers.utils import logging
|
23 |
+
|
24 |
+
try:
|
25 |
+
from flash_attn import flash_attn_func
|
26 |
+
except ImportError:
|
27 |
+
pass
|
28 |
|
29 |
logger = logging.get_logger(__name__)
|
30 |
|
|
|
42 |
num_channels=3,
|
43 |
num_hidden_layers=24,
|
44 |
patch_size=14,
|
45 |
+
projection_dim=768
|
46 |
)
|
47 |
|
48 |
+
class CLIPAttentionFA2(CLIPAttention):
    """CLIPAttention variant that computes attention with FlashAttention-2.

    Only used in the vision encoder. It handles just the unmasked,
    non-causal case and never returns attention weights.
    """

    def forward(self,
                hidden_states,
                attention_mask=None,
                causal_attention_mask=None,
                output_attentions=False,
                ):
        """Run self-attention through ``flash_attn_func``.

        Input shape: Batch x Time x Channel.

        Args:
            hidden_states: Tensor whose size unpacks as (batch, time, channel).
            attention_mask: Not supported; must be ``None``.
            causal_attention_mask: Not supported; must be ``None``.
            output_attentions: Not supported; must be falsy.

        Returns:
            A tuple ``(attn_output, None)``: the output projection applied to
            the attention result, and ``None`` in place of attention weights.

        Raises:
            NotImplementedError: If a mask is supplied or attention weights
                are requested.
            ImportError: If ``flash_attn`` could not be imported.
        """
        # Explicit raises instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let a mask be silently ignored.
        if attention_mask is not None:
            raise NotImplementedError("CLIPAttentionFA2 does not support attention_mask")
        if causal_attention_mask is not None:
            raise NotImplementedError("CLIPAttentionFA2 does not support causal_attention_mask")
        if output_attentions:
            raise NotImplementedError("CLIPAttentionFA2 does not support output_attentions")

        # The module-level import of flash_attn is guarded and may have been
        # skipped; surface that as a clear ImportError rather than a NameError.
        try:
            attn_fn = flash_attn_func
        except NameError:
            raise ImportError(
                "CLIPAttentionFA2 requires the `flash_attn` package, which could not be imported"
            ) from None

        bsz, tgt_len, embed_dim = hidden_states.size()
        # Split the projected channels into (num_heads, head_dim) for the kernel.
        head_shape = (bsz, tgt_len, self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).reshape(head_shape)
        key_states = self.k_proj(hidden_states).reshape(head_shape)
        value_states = self.v_proj(hidden_states).reshape(head_shape)

        attn_output = attn_fn(
            query_states,
            key_states,
            value_states,
            # Attention dropout is only applied while training.
            dropout_p=self.dropout if self.training else 0.0,
            softmax_scale=self.scale,
            causal=False,
        ).reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)
        # Attention weights are never returned, so the second element is None.
        return attn_output, None
|
79 |
+
|
80 |
+
|
81 |
class Phi3ImageEmbedding(nn.Module):
|
82 |
"""Phi3 Image embedding."""
|
83 |
|
|
|
103 |
self.img_processor = CLIPVisionModel(clip_config)
|
104 |
image_dim_out = config.img_processor['image_dim_out']
|
105 |
self.num_img_tokens = config.img_processor['num_img_tokens']
|
106 |
+
|
107 |
+
# FA2 in CLIP
|
108 |
+
if config._attn_implementation == 'flash_attention_2':
|
109 |
+
for layer in self.img_processor.vision_model.encoder.layers:
|
110 |
+
clip_fa2 = CLIPAttentionFA2(clip_config)
|
111 |
+
del layer.self_attn
|
112 |
+
layer.self_attn = clip_fa2
|
113 |
else:
|
114 |
raise NotImplementedError(f'img_processor = {config.img_processor}, not implemented')
|
115 |
|
|
|
202 |
|
203 |
with torch.no_grad():
|
204 |
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=False)
|
205 |
+
|
206 |
select = False
|
207 |
|
208 |
+
if isinstance(self.img_projection, nn.Sequential):
|
209 |
+
target_device = self.img_projection[0].bias.device
|
210 |
+
target_dtype = self.img_projection[0].bias.dtype
|
211 |
+
else: # It's a single nn.Linear layer
|
212 |
+
target_device = self.img_projection.bias.device
|
213 |
+
target_dtype = self.img_projection.bias.dtype
|
214 |
|
215 |
if len(positions.tolist()) > 0:
|
216 |
with torch.no_grad():
|
|
|
242 |
img_sizes = img_sizes.view(-1, 2)
|
243 |
for _bs in range(bs):
|
244 |
h, w = img_sizes[_bs]
|
245 |
+
h = h // 336
|
246 |
w = w // 336
|
247 |
B_ = h * w
|
248 |
|
|
|
280 |
temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
|
281 |
assert temp_len == output_imgs[-1].shape[1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}'
|
282 |
output_len.append(temp_len)
|
283 |
+
|
284 |
num_img_tokens = output_len
|
285 |
img_set_tensor = []
|
286 |
for _output_img in output_imgs:
|
|
|
312 |
else:
|
313 |
raise NotImplementedError
|
314 |
select = True
|
315 |
+
|
316 |
with torch.no_grad():
|
317 |
input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
|
318 |
+
|
319 |
hidden_states = self.wte(input_ids)
|
320 |
|
321 |
if select:
|