Make flash attention configurable in user code

#26
by YenChunChen - opened
Files changed (2) hide show
  1. README.md +1 -19
  2. modeling_phi3_v.py +7 -6
README.md CHANGED
@@ -105,7 +105,7 @@ from transformers import AutoProcessor
105
 
106
  model_id = "microsoft/Phi-3-vision-128k-instruct"
107
 
108
- model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto")
109
 
110
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
111
 
@@ -217,24 +217,6 @@ Note that by default, the Phi-3-Vision-128K model uses flash attention, which re
217
  * NVIDIA A6000
218
  * NVIDIA H100
219
 
220
- ### Running on Windows or without flash attention
221
- To enable the model on these enviroment here are steps that you may consider to follow:
222
-
223
- Step 1: comment flash attention import code in modeling_phi3_v.py from line 52 to line 56.
224
- ```python
225
- # if is_flash_attn_2_available():
226
- # from flash_attn import flash_attn_func, flash_attn_varlen_func
227
- # from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
228
-
229
- # _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
230
- ```
231
-
232
- Step 2: change _"_attn_implementation"_ from _"flash_attention_2"_ to _"eager"_ in config.json or disable flash attention when you create the model as below.
233
-
234
- ```python
235
- model = AutoModelForCausalLM.from_pretrained('microsoft/Phi-3-vision-128k-instruct', device_map="cuda", trust_remote_code=True, torch_dtype="auto", _attn_implementation="eager")
236
- ```
237
-
238
  ## License
239
 
240
  The model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/resolve/main/LICENSE).
 
105
 
106
  model_id = "microsoft/Phi-3-vision-128k-instruct"
107
 
108
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto", _attn_implementation='flash_attention_2') # use _attn_implementation='eager' to disable flash attention
109
 
110
  processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
111
 
 
217
  * NVIDIA A6000
218
  * NVIDIA H100
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  ## License
221
 
222
  The model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/resolve/main/LICENSE).
modeling_phi3_v.py CHANGED
@@ -40,7 +40,6 @@ from transformers.utils import (
40
  add_code_sample_docstrings,
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
43
- is_flash_attn_2_available,
44
  is_flash_attn_greater_or_equal_2_10,
45
  logging,
46
  replace_return_docstrings,
@@ -49,11 +48,13 @@ from .configuration_phi3_v import Phi3VConfig
49
  from .image_embedding_phi3_v import Phi3ImageEmbedding
50
 
51
 
52
- if is_flash_attn_2_available():
53
  from flash_attn import flash_attn_func, flash_attn_varlen_func
54
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
55
 
56
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
 
57
 
58
  logger = logging.get_logger(__name__)
59
 
@@ -1000,8 +1001,8 @@ PHI3V_INPUTS_DOCSTRING = r"""
1000
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1001
  model's internal embedding lookup matrix.
1002
  pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
1003
- The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`].
1004
- See [`Phi3ImageProcessor.__call__`] for details.
1005
  image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
1006
  The sizes of the images in the batch, being (height, width) for each image.
1007
  use_cache (`bool`, *optional*):
@@ -1046,7 +1047,7 @@ class Phi3VModel(Phi3VPreTrainedModel):
1046
  **config.embd_layer
1047
  }
1048
  self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config)
1049
- # # set wte the same for vision embedding
1050
  # self.vision_embed_tokens.wte.weight = self.embed_tokens.weight
1051
 
1052
  self.layers = nn.ModuleList(
@@ -1629,4 +1630,4 @@ class Phi3VForTokenClassification(Phi3VPreTrainedModel):
1629
  logits=logits,
1630
  hidden_states=model_outputs.hidden_states,
1631
  attentions=model_outputs.attentions,
1632
- )
 
40
  add_code_sample_docstrings,
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
 
43
  is_flash_attn_greater_or_equal_2_10,
44
  logging,
45
  replace_return_docstrings,
 
48
  from .image_embedding_phi3_v import Phi3ImageEmbedding
49
 
50
 
51
+ try:
52
  from flash_attn import flash_attn_func, flash_attn_varlen_func
53
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
54
 
55
  _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
56
+ except ImportError:
57
+ pass
58
 
59
  logger = logging.get_logger(__name__)
60
 
 
1001
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1002
  model's internal embedding lookup matrix.
1003
  pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
1004
+ The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`].
1005
+ See [`Phi3ImageProcessor.__call__`] for details.
1006
  image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
1007
  The sizes of the images in the batch, being (height, width) for each image.
1008
  use_cache (`bool`, *optional*):
 
1047
  **config.embd_layer
1048
  }
1049
  self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config)
1050
+ # # set wte the same for vision embedding
1051
  # self.vision_embed_tokens.wte.weight = self.embed_tokens.weight
1052
 
1053
  self.layers = nn.ModuleList(
 
1630
  logits=logits,
1631
  hidden_states=model_outputs.hidden_states,
1632
  attentions=model_outputs.attentions,
1633
+ )