Nanobit commited on
Commit
1c60c10
·
1 Parent(s): 903ea30

Lint flash_attn.py

Browse files
Files changed (1) hide show
  1. src/axolotl/flash_attn.py +8 -5
src/axolotl/flash_attn.py CHANGED
@@ -1,9 +1,10 @@
 
 
1
  # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
2
 
3
- from typing import List, Optional, Tuple
4
 
5
  import torch
6
- from torch import nn
7
 
8
  import transformers
9
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
@@ -14,7 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
14
  from flash_attn.bert_padding import unpad_input, pad_input
15
 
16
 
17
- def forward(
18
  self,
19
  hidden_states: torch.Tensor,
20
  attention_mask: Optional[torch.Tensor] = None,
@@ -82,6 +83,8 @@ def forward(
82
  output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
83
  else:
84
  nheads = qkv.shape[-2]
 
 
85
  x = rearrange(qkv, "b s three h d -> b s (three h d)")
86
  x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
87
  x_unpad = rearrange(
@@ -104,13 +107,13 @@ def forward(
104
  # requires the attention mask to be the same as the key_padding_mask
105
  def _prepare_decoder_attention_mask(
106
  self, attention_mask, input_shape, inputs_embeds, past_key_values_length
107
- ):
108
  # [bsz, seq_len]
109
  return attention_mask
110
 
111
 
112
  def replace_llama_attn_with_flash_attn():
113
- transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
114
  _prepare_decoder_attention_mask
115
  )
116
  transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
 
1
+ """Flash attention monkey patch for llama model"""
2
+
3
  # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
4
 
5
+ from typing import Optional, Tuple
6
 
7
  import torch
 
8
 
9
  import transformers
10
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
15
  from flash_attn.bert_padding import unpad_input, pad_input
16
 
17
 
18
+ def forward( # pylint: disable=too-many-arguments
19
  self,
20
  hidden_states: torch.Tensor,
21
  attention_mask: Optional[torch.Tensor] = None,
 
83
  output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
84
  else:
85
  nheads = qkv.shape[-2]
86
+
87
+ # pylint: disable=invalid-name
88
  x = rearrange(qkv, "b s three h d -> b s (three h d)")
89
  x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
90
  x_unpad = rearrange(
 
107
  # requires the attention mask to be the same as the key_padding_mask
108
  def _prepare_decoder_attention_mask(
109
  self, attention_mask, input_shape, inputs_embeds, past_key_values_length
110
+ ): # pylint: disable=unused-argument
111
  # [bsz, seq_len]
112
  return attention_mask
113
 
114
 
115
  def replace_llama_attn_with_flash_attn():
116
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
117
  _prepare_decoder_attention_mask
118
  )
119
  transformers.models.llama.modeling_llama.LlamaAttention.forward = forward