YenChunChen
commited on
Commit
·
bfc0c05
1
Parent(s):
5dcfdfb
make flash attention usage configurable from user code
Browse files- config.json +1 -2
- modeling_phi3_v.py +7 -6
config.json
CHANGED
@@ -143,6 +143,5 @@
|
|
143 |
"torch_dtype": "bfloat16",
|
144 |
"transformers_version": "4.38.1",
|
145 |
"use_cache": true,
|
146 |
-
"vocab_size": 32064
|
147 |
-
"_attn_implementation": "flash_attention_2"
|
148 |
}
|
|
|
143 |
"torch_dtype": "bfloat16",
|
144 |
"transformers_version": "4.38.1",
|
145 |
"use_cache": true,
|
146 |
+
"vocab_size": 32064
|
|
|
147 |
}
|
modeling_phi3_v.py
CHANGED
@@ -40,7 +40,6 @@ from transformers.utils import (
|
|
40 |
add_code_sample_docstrings,
|
41 |
add_start_docstrings,
|
42 |
add_start_docstrings_to_model_forward,
|
43 |
-
is_flash_attn_2_available,
|
44 |
is_flash_attn_greater_or_equal_2_10,
|
45 |
logging,
|
46 |
replace_return_docstrings,
|
@@ -49,11 +48,13 @@ from .configuration_phi3_v import Phi3VConfig
|
|
49 |
from .image_embedding_phi3_v import Phi3ImageEmbedding
|
50 |
|
51 |
|
52 |
-
|
53 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
54 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
55 |
|
56 |
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
|
|
|
|
57 |
|
58 |
logger = logging.get_logger(__name__)
|
59 |
|
@@ -1000,8 +1001,8 @@ PHI3V_INPUTS_DOCSTRING = r"""
|
|
1000 |
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
1001 |
model's internal embedding lookup matrix.
|
1002 |
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
1003 |
-
The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`].
|
1004 |
-
See [`Phi3ImageProcessor.__call__`] for details.
|
1005 |
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
|
1006 |
The sizes of the images in the batch, being (height, width) for each image.
|
1007 |
use_cache (`bool`, *optional*):
|
@@ -1046,7 +1047,7 @@ class Phi3VModel(Phi3VPreTrainedModel):
|
|
1046 |
**config.embd_layer
|
1047 |
}
|
1048 |
self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config)
|
1049 |
-
# # set wte the same for vision embedding
|
1050 |
# self.vision_embed_tokens.wte.weight = self.embed_tokens.weight
|
1051 |
|
1052 |
self.layers = nn.ModuleList(
|
@@ -1629,4 +1630,4 @@ class Phi3VForTokenClassification(Phi3VPreTrainedModel):
|
|
1629 |
logits=logits,
|
1630 |
hidden_states=model_outputs.hidden_states,
|
1631 |
attentions=model_outputs.attentions,
|
1632 |
-
)
|
|
|
40 |
add_code_sample_docstrings,
|
41 |
add_start_docstrings,
|
42 |
add_start_docstrings_to_model_forward,
|
|
|
43 |
is_flash_attn_greater_or_equal_2_10,
|
44 |
logging,
|
45 |
replace_return_docstrings,
|
|
|
48 |
from .image_embedding_phi3_v import Phi3ImageEmbedding
|
49 |
|
50 |
|
51 |
+
try:
|
52 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
53 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
54 |
|
55 |
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
56 |
+
except ImportError:
|
57 |
+
pass
|
58 |
|
59 |
logger = logging.get_logger(__name__)
|
60 |
|
|
|
1001 |
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
1002 |
model's internal embedding lookup matrix.
|
1003 |
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
1004 |
+
The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`].
|
1005 |
+
See [`Phi3ImageProcessor.__call__`] for details.
|
1006 |
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
|
1007 |
The sizes of the images in the batch, being (height, width) for each image.
|
1008 |
use_cache (`bool`, *optional*):
|
|
|
1047 |
**config.embd_layer
|
1048 |
}
|
1049 |
self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config)
|
1050 |
+
# # set wte the same for vision embedding
|
1051 |
# self.vision_embed_tokens.wte.weight = self.embed_tokens.weight
|
1052 |
|
1053 |
self.layers = nn.ModuleList(
|
|
|
1630 |
logits=logits,
|
1631 |
hidden_states=model_outputs.hidden_states,
|
1632 |
attentions=model_outputs.attentions,
|
1633 |
+
)
|