Fix grad checkpoint and outputs param
Browse files
src/axolotl/monkeypatch/llama_landmark_attn.py
CHANGED
@@ -27,7 +27,6 @@ from typing import List, Optional, Tuple, Union
|
|
27 |
|
28 |
import torch
|
29 |
import torch.utils.checkpoint
|
30 |
-
import transformers
|
31 |
from torch import nn
|
32 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
33 |
from transformers.activations import ACT2FN
|
@@ -52,10 +51,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
|
|
52 |
MEM_TOKEN = "<landmark>" # nosec
|
53 |
|
54 |
|
55 |
-
def hijack_llama_landmark_attn():
|
56 |
-
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
|
57 |
-
|
58 |
-
|
59 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
60 |
def _make_causal_mask(
|
61 |
input_ids_shape: torch.Size,
|
@@ -1125,7 +1120,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|
1125 |
def create_custom_forward(module):
|
1126 |
def custom_forward(*inputs):
|
1127 |
# None for past_key_value
|
1128 |
-
return module(*inputs
|
1129 |
|
1130 |
return custom_forward
|
1131 |
|
@@ -1135,6 +1130,8 @@ class LlamaModel(LlamaPreTrainedModel):
|
|
1135 |
attention_mask,
|
1136 |
position_ids,
|
1137 |
None,
|
|
|
|
|
1138 |
is_mem,
|
1139 |
last_section_mask,
|
1140 |
)
|
@@ -1300,7 +1297,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
1300 |
return_dict=return_dict,
|
1301 |
offload_cache_to_cpu=offload_cache_to_cpu,
|
1302 |
)
|
1303 |
-
past_key_values = outputs
|
1304 |
if last_logits is not None:
|
1305 |
last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
|
1306 |
last_logits = outputs[0]
|
|
|
27 |
|
28 |
import torch
|
29 |
import torch.utils.checkpoint
|
|
|
30 |
from torch import nn
|
31 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
32 |
from transformers.activations import ACT2FN
|
|
|
51 |
MEM_TOKEN = "<landmark>" # nosec
|
52 |
|
53 |
|
|
|
|
|
|
|
|
|
54 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
55 |
def _make_causal_mask(
|
56 |
input_ids_shape: torch.Size,
|
|
|
1120 |
def create_custom_forward(module):
|
1121 |
def custom_forward(*inputs):
|
1122 |
# None for past_key_value
|
1123 |
+
return module(*inputs)
|
1124 |
|
1125 |
return custom_forward
|
1126 |
|
|
|
1130 |
attention_mask,
|
1131 |
position_ids,
|
1132 |
None,
|
1133 |
+
output_attentions,
|
1134 |
+
None,
|
1135 |
is_mem,
|
1136 |
last_section_mask,
|
1137 |
)
|
|
|
1297 |
return_dict=return_dict,
|
1298 |
offload_cache_to_cpu=offload_cache_to_cpu,
|
1299 |
)
|
1300 |
+
past_key_values = outputs.past_key_values
|
1301 |
if last_logits is not None:
|
1302 |
last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
|
1303 |
last_logits = outputs[0]
|