Clean up landmark patching
Browse files
src/axolotl/monkeypatch/llama_landmark_attn.py
CHANGED
@@ -28,15 +28,23 @@ from typing import List, Optional, Tuple, Union
|
|
28 |
import torch
|
29 |
import torch.utils.checkpoint
|
30 |
from torch import nn
|
31 |
-
from torch.nn import
|
32 |
-
from transformers.activations import ACT2FN
|
33 |
from transformers.modeling_outputs import (
|
34 |
BaseModelOutputWithPast,
|
35 |
CausalLMOutputWithPast,
|
36 |
-
SequenceClassifierOutputWithPast,
|
37 |
)
|
38 |
-
from transformers.modeling_utils import PreTrainedModel
|
39 |
from transformers.models.llama.configuration_llama import LlamaConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
from transformers.utils import (
|
41 |
add_start_docstrings,
|
42 |
add_start_docstrings_to_model_forward,
|
@@ -51,131 +59,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
|
|
51 |
MEM_TOKEN = "<landmark>" # nosec
|
52 |
|
53 |
|
54 |
-
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
55 |
-
def _make_causal_mask(
|
56 |
-
input_ids_shape: torch.Size,
|
57 |
-
dtype: torch.dtype,
|
58 |
-
device: torch.device,
|
59 |
-
past_key_values_length: int = 0,
|
60 |
-
):
|
61 |
-
"""
|
62 |
-
Make causal mask used for bi-directional self-attention.
|
63 |
-
"""
|
64 |
-
bsz, tgt_len = input_ids_shape
|
65 |
-
mask = torch.full(
|
66 |
-
(tgt_len, tgt_len),
|
67 |
-
torch.tensor(torch.finfo(dtype).min, device=device),
|
68 |
-
device=device,
|
69 |
-
)
|
70 |
-
mask_cond = torch.arange(mask.size(-1), device=device)
|
71 |
-
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
72 |
-
mask = mask.to(dtype)
|
73 |
-
|
74 |
-
if past_key_values_length > 0:
|
75 |
-
mask = torch.cat(
|
76 |
-
[
|
77 |
-
torch.zeros(
|
78 |
-
tgt_len, past_key_values_length, dtype=dtype, device=device
|
79 |
-
),
|
80 |
-
mask,
|
81 |
-
],
|
82 |
-
dim=-1,
|
83 |
-
)
|
84 |
-
return mask[None, None, :, :].expand(
|
85 |
-
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
86 |
-
)
|
87 |
-
|
88 |
-
|
89 |
-
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
90 |
-
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
91 |
-
"""
|
92 |
-
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
93 |
-
"""
|
94 |
-
bsz, src_len = mask.size()
|
95 |
-
tgt_len = tgt_len if tgt_len is not None else src_len
|
96 |
-
|
97 |
-
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
98 |
-
|
99 |
-
inverted_mask = 1.0 - expanded_mask
|
100 |
-
|
101 |
-
return inverted_mask.masked_fill(
|
102 |
-
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
103 |
-
)
|
104 |
-
|
105 |
-
|
106 |
-
class LlamaRMSNorm(nn.Module):
|
107 |
-
def __init__(self, hidden_size, eps=1e-6):
|
108 |
-
"""
|
109 |
-
LlamaRMSNorm is equivalent to T5LayerNorm
|
110 |
-
"""
|
111 |
-
super().__init__()
|
112 |
-
self.weight = nn.Parameter(torch.ones(hidden_size))
|
113 |
-
self.variance_epsilon = eps
|
114 |
-
|
115 |
-
def forward(self, hidden_states):
|
116 |
-
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
117 |
-
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
118 |
-
|
119 |
-
# convert into half-precision if necessary
|
120 |
-
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
121 |
-
hidden_states = hidden_states.to(self.weight.dtype)
|
122 |
-
|
123 |
-
return self.weight * hidden_states
|
124 |
-
|
125 |
-
|
126 |
-
class LlamaRotaryEmbedding(torch.nn.Module):
|
127 |
-
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
128 |
-
super().__init__()
|
129 |
-
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
130 |
-
self.register_buffer("inv_freq", inv_freq)
|
131 |
-
|
132 |
-
# Build here to make `torch.jit.trace` work.
|
133 |
-
self.max_seq_len_cached = max_position_embeddings
|
134 |
-
t = torch.arange(
|
135 |
-
self.max_seq_len_cached,
|
136 |
-
device=self.inv_freq.device,
|
137 |
-
dtype=self.inv_freq.dtype,
|
138 |
-
)
|
139 |
-
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
140 |
-
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
141 |
-
emb = torch.cat((freqs, freqs), dim=-1)
|
142 |
-
self.register_buffer(
|
143 |
-
"cos_cached", emb.cos()[None, None, :, :], persistent=False
|
144 |
-
)
|
145 |
-
self.register_buffer(
|
146 |
-
"sin_cached", emb.sin()[None, None, :, :], persistent=False
|
147 |
-
)
|
148 |
-
|
149 |
-
def forward(self, x, seq_len=None):
|
150 |
-
# x: [bs, num_attention_heads, seq_len, head_size]
|
151 |
-
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
152 |
-
if seq_len > self.max_seq_len_cached:
|
153 |
-
self.max_seq_len_cached = seq_len
|
154 |
-
t = torch.arange(
|
155 |
-
self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
|
156 |
-
)
|
157 |
-
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
158 |
-
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
159 |
-
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
160 |
-
self.register_buffer(
|
161 |
-
"cos_cached", emb.cos()[None, None, :, :], persistent=False
|
162 |
-
)
|
163 |
-
self.register_buffer(
|
164 |
-
"sin_cached", emb.sin()[None, None, :, :], persistent=False
|
165 |
-
)
|
166 |
-
return (
|
167 |
-
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
168 |
-
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
169 |
-
)
|
170 |
-
|
171 |
-
|
172 |
-
def rotate_half(x):
|
173 |
-
"""Rotates half the hidden dims of the input."""
|
174 |
-
x1 = x[..., : x.shape[-1] // 2]
|
175 |
-
x2 = x[..., x.shape[-1] // 2 :]
|
176 |
-
return torch.cat((-x2, x1), dim=-1)
|
177 |
-
|
178 |
-
|
179 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
180 |
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
181 |
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
@@ -190,24 +73,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|
190 |
return q_embed, k_embed
|
191 |
|
192 |
|
193 |
-
class LlamaMLP(nn.Module):
|
194 |
-
def __init__(
|
195 |
-
self,
|
196 |
-
hidden_size: int,
|
197 |
-
intermediate_size: int,
|
198 |
-
hidden_act: str,
|
199 |
-
):
|
200 |
-
super().__init__()
|
201 |
-
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
202 |
-
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
203 |
-
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
204 |
-
self.act_fn = ACT2FN[hidden_act]
|
205 |
-
|
206 |
-
def forward(self, x):
|
207 |
-
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
208 |
-
|
209 |
-
|
210 |
class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
211 |
# Note that forward, setup_context, and backward are @staticmethods
|
212 |
@staticmethod
|
213 |
def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
|
@@ -682,16 +552,14 @@ class LlamaAttention(nn.Module):
|
|
682 |
# upcast attention to fp32
|
683 |
if is_mem is None:
|
684 |
raise ValueError("Don't use this without landmarks")
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
last_section_mask=last_section_mask,
|
694 |
-
).to(query_states.dtype)
|
695 |
if attn_prefix is not None:
|
696 |
attn_prefix, attn_weights = torch.split(
|
697 |
attn_weights,
|
@@ -722,6 +590,10 @@ class LlamaAttention(nn.Module):
|
|
722 |
|
723 |
|
724 |
class LlamaDecoderLayer(nn.Module):
|
|
|
|
|
|
|
|
|
725 |
def __init__(self, config: LlamaConfig):
|
726 |
super().__init__()
|
727 |
self.hidden_size = config.hidden_size
|
@@ -802,114 +674,6 @@ class LlamaDecoderLayer(nn.Module):
|
|
802 |
return outputs
|
803 |
|
804 |
|
805 |
-
LLAMA_START_DOCSTRING = r"""
|
806 |
-
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
807 |
-
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
808 |
-
etc.)
|
809 |
-
|
810 |
-
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
811 |
-
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
812 |
-
and behavior.
|
813 |
-
|
814 |
-
Parameters:
|
815 |
-
config ([`LlamaConfig`]):
|
816 |
-
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
817 |
-
load the weights associated with the model, only the configuration. Check out the
|
818 |
-
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
819 |
-
"""
|
820 |
-
|
821 |
-
|
822 |
-
@add_start_docstrings(
|
823 |
-
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
824 |
-
LLAMA_START_DOCSTRING,
|
825 |
-
)
|
826 |
-
class LlamaPreTrainedModel(PreTrainedModel):
|
827 |
-
config_class = LlamaConfig
|
828 |
-
base_model_prefix = "model"
|
829 |
-
supports_gradient_checkpointing = True
|
830 |
-
_no_split_modules = ["LlamaDecoderLayer"]
|
831 |
-
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
832 |
-
|
833 |
-
def _init_weights(self, module):
|
834 |
-
std = self.config.initializer_range
|
835 |
-
if isinstance(module, nn.Linear):
|
836 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
837 |
-
if module.bias is not None:
|
838 |
-
module.bias.data.zero_()
|
839 |
-
elif isinstance(module, nn.Embedding):
|
840 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
841 |
-
if module.padding_idx is not None:
|
842 |
-
module.weight.data[module.padding_idx].zero_()
|
843 |
-
|
844 |
-
def _set_gradient_checkpointing(self, module, value=False):
|
845 |
-
if isinstance(module, LlamaModel):
|
846 |
-
module.gradient_checkpointing = value
|
847 |
-
|
848 |
-
|
849 |
-
LLAMA_INPUTS_DOCSTRING = r"""
|
850 |
-
Args:
|
851 |
-
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
852 |
-
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
853 |
-
it.
|
854 |
-
|
855 |
-
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
856 |
-
[`PreTrainedTokenizer.__call__`] for details.
|
857 |
-
|
858 |
-
[What are input IDs?](../glossary#input-ids)
|
859 |
-
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
860 |
-
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
861 |
-
|
862 |
-
- 1 for tokens that are **not masked**,
|
863 |
-
- 0 for tokens that are **masked**.
|
864 |
-
|
865 |
-
[What are attention masks?](../glossary#attention-mask)
|
866 |
-
|
867 |
-
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
868 |
-
[`PreTrainedTokenizer.__call__`] for details.
|
869 |
-
|
870 |
-
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
871 |
-
`past_key_values`).
|
872 |
-
|
873 |
-
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
874 |
-
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
875 |
-
information on the default strategy.
|
876 |
-
|
877 |
-
- 1 indicates the head is **not masked**,
|
878 |
-
- 0 indicates the head is **masked**.
|
879 |
-
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
880 |
-
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
881 |
-
config.n_positions - 1]`.
|
882 |
-
|
883 |
-
[What are position IDs?](../glossary#position-ids)
|
884 |
-
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
885 |
-
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
886 |
-
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
887 |
-
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
888 |
-
|
889 |
-
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
890 |
-
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
891 |
-
|
892 |
-
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
893 |
-
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
894 |
-
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
895 |
-
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
896 |
-
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
897 |
-
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
898 |
-
model's internal embedding lookup matrix.
|
899 |
-
use_cache (`bool`, *optional*):
|
900 |
-
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
901 |
-
`past_key_values`).
|
902 |
-
output_attentions (`bool`, *optional*):
|
903 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
904 |
-
tensors for more detail.
|
905 |
-
output_hidden_states (`bool`, *optional*):
|
906 |
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
907 |
-
more detail.
|
908 |
-
return_dict (`bool`, *optional*):
|
909 |
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
910 |
-
"""
|
911 |
-
|
912 |
-
|
913 |
@add_start_docstrings(
|
914 |
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
915 |
LLAMA_START_DOCSTRING,
|
@@ -1178,6 +942,10 @@ class LlamaModel(LlamaPreTrainedModel):
|
|
1178 |
|
1179 |
|
1180 |
class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
|
|
|
|
|
|
|
1181 |
def __init__(self, config):
|
1182 |
super().__init__(config)
|
1183 |
self.model = LlamaModel(config)
|
@@ -1448,149 +1216,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
1448 |
return reordered_past
|
1449 |
|
1450 |
|
1451 |
-
@add_start_docstrings(
|
1452 |
-
"""
|
1453 |
-
The LLaMa Model transformer with a sequence classification head on top (linear layer).
|
1454 |
-
|
1455 |
-
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1456 |
-
(e.g. GPT-2) do.
|
1457 |
-
|
1458 |
-
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1459 |
-
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1460 |
-
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1461 |
-
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1462 |
-
each row of the batch).
|
1463 |
-
""",
|
1464 |
-
LLAMA_START_DOCSTRING,
|
1465 |
-
)
|
1466 |
-
class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
1467 |
-
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
1468 |
-
|
1469 |
-
def __init__(self, config):
|
1470 |
-
super().__init__(config)
|
1471 |
-
self.num_labels = config.num_labels
|
1472 |
-
self.model = LlamaModel(config)
|
1473 |
-
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1474 |
-
|
1475 |
-
# Initialize weights and apply final processing
|
1476 |
-
self.post_init()
|
1477 |
-
|
1478 |
-
def get_input_embeddings(self):
|
1479 |
-
return self.model.embed_tokens
|
1480 |
-
|
1481 |
-
def set_input_embeddings(self, value):
|
1482 |
-
self.model.embed_tokens = value
|
1483 |
-
|
1484 |
-
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
1485 |
-
def forward(
|
1486 |
-
self,
|
1487 |
-
input_ids: torch.LongTensor = None,
|
1488 |
-
attention_mask: Optional[torch.Tensor] = None,
|
1489 |
-
position_ids: Optional[torch.LongTensor] = None,
|
1490 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1491 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1492 |
-
labels: Optional[torch.LongTensor] = None,
|
1493 |
-
use_cache: Optional[bool] = None,
|
1494 |
-
output_attentions: Optional[bool] = None,
|
1495 |
-
output_hidden_states: Optional[bool] = None,
|
1496 |
-
return_dict: Optional[bool] = None,
|
1497 |
-
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1498 |
-
r"""
|
1499 |
-
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1500 |
-
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1501 |
-
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1502 |
-
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1503 |
-
"""
|
1504 |
-
return_dict = (
|
1505 |
-
return_dict if return_dict is not None else self.config.use_return_dict
|
1506 |
-
)
|
1507 |
-
|
1508 |
-
transformer_outputs = self.model(
|
1509 |
-
input_ids,
|
1510 |
-
attention_mask=attention_mask,
|
1511 |
-
position_ids=position_ids,
|
1512 |
-
past_key_values=past_key_values,
|
1513 |
-
inputs_embeds=inputs_embeds,
|
1514 |
-
use_cache=use_cache,
|
1515 |
-
output_attentions=output_attentions,
|
1516 |
-
output_hidden_states=output_hidden_states,
|
1517 |
-
return_dict=return_dict,
|
1518 |
-
)
|
1519 |
-
hidden_states = transformer_outputs[0]
|
1520 |
-
logits = self.score(hidden_states)
|
1521 |
-
|
1522 |
-
if input_ids is not None:
|
1523 |
-
batch_size = input_ids.shape[0]
|
1524 |
-
else:
|
1525 |
-
batch_size = inputs_embeds.shape[0]
|
1526 |
-
|
1527 |
-
if self.config.pad_token_id is None and batch_size != 1:
|
1528 |
-
raise ValueError(
|
1529 |
-
"Cannot handle batch sizes > 1 if no padding token is defined."
|
1530 |
-
)
|
1531 |
-
if self.config.pad_token_id is None:
|
1532 |
-
sequence_lengths = -1
|
1533 |
-
else:
|
1534 |
-
if input_ids is not None:
|
1535 |
-
sequence_lengths = (
|
1536 |
-
torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
1537 |
-
).to(logits.device)
|
1538 |
-
else:
|
1539 |
-
sequence_lengths = -1
|
1540 |
-
|
1541 |
-
pooled_logits = logits[
|
1542 |
-
torch.arange(batch_size, device=logits.device), sequence_lengths
|
1543 |
-
]
|
1544 |
-
|
1545 |
-
loss = None
|
1546 |
-
if labels is not None:
|
1547 |
-
labels = labels.to(logits.device)
|
1548 |
-
if self.config.problem_type is None:
|
1549 |
-
if self.num_labels == 1:
|
1550 |
-
self.config.problem_type = "regression"
|
1551 |
-
elif self.num_labels > 1 and (
|
1552 |
-
labels.dtype == torch.long or labels.dtype == torch.int
|
1553 |
-
):
|
1554 |
-
self.config.problem_type = "single_label_classification"
|
1555 |
-
else:
|
1556 |
-
self.config.problem_type = "multi_label_classification"
|
1557 |
-
|
1558 |
-
if self.config.problem_type == "regression":
|
1559 |
-
loss_fct = MSELoss()
|
1560 |
-
if self.num_labels == 1:
|
1561 |
-
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1562 |
-
else:
|
1563 |
-
loss = loss_fct(pooled_logits, labels)
|
1564 |
-
elif self.config.problem_type == "single_label_classification":
|
1565 |
-
loss_fct = CrossEntropyLoss()
|
1566 |
-
loss = loss_fct(
|
1567 |
-
pooled_logits.view(-1, self.num_labels), labels.view(-1)
|
1568 |
-
)
|
1569 |
-
elif self.config.problem_type == "multi_label_classification":
|
1570 |
-
loss_fct = BCEWithLogitsLoss()
|
1571 |
-
loss = loss_fct(pooled_logits, labels)
|
1572 |
-
if not return_dict:
|
1573 |
-
output = (pooled_logits,) + transformer_outputs[1:]
|
1574 |
-
return ((loss,) + output) if loss is not None else output
|
1575 |
-
|
1576 |
-
return SequenceClassifierOutputWithPast(
|
1577 |
-
loss=loss,
|
1578 |
-
logits=pooled_logits,
|
1579 |
-
past_key_values=transformer_outputs.past_key_values,
|
1580 |
-
hidden_states=transformer_outputs.hidden_states,
|
1581 |
-
attentions=transformer_outputs.attentions,
|
1582 |
-
)
|
1583 |
-
|
1584 |
-
|
1585 |
def add_mem_tokens(example, mem_freq, mem_id):
|
1586 |
-
|
1587 |
ret = []
|
1588 |
prev_idx = 0
|
1589 |
-
for t_idx in range(mem_freq, len(
|
1590 |
-
ret.extend(
|
1591 |
ret.append(mem_id)
|
1592 |
prev_idx = t_idx
|
1593 |
-
ret.extend(
|
1594 |
# drop attention_mask
|
1595 |
return {"input_ids": ret}
|
1596 |
|
@@ -1602,3 +1236,4 @@ def patch_llama_with_landmark_attn():
|
|
1602 |
transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
|
1603 |
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
|
1604 |
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
|
|
|
28 |
import torch
|
29 |
import torch.utils.checkpoint
|
30 |
from torch import nn
|
31 |
+
from torch.nn import CrossEntropyLoss
|
|
|
32 |
from transformers.modeling_outputs import (
|
33 |
BaseModelOutputWithPast,
|
34 |
CausalLMOutputWithPast,
|
|
|
35 |
)
|
|
|
36 |
from transformers.models.llama.configuration_llama import LlamaConfig
|
37 |
+
from transformers.models.llama.modeling_llama import (
|
38 |
+
LLAMA_INPUTS_DOCSTRING,
|
39 |
+
LLAMA_START_DOCSTRING,
|
40 |
+
LlamaMLP,
|
41 |
+
LlamaPreTrainedModel,
|
42 |
+
LlamaRMSNorm,
|
43 |
+
LlamaRotaryEmbedding,
|
44 |
+
_expand_mask,
|
45 |
+
_make_causal_mask,
|
46 |
+
rotate_half,
|
47 |
+
)
|
48 |
from transformers.utils import (
|
49 |
add_start_docstrings,
|
50 |
add_start_docstrings_to_model_forward,
|
|
|
59 |
MEM_TOKEN = "<landmark>" # nosec
|
60 |
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
63 |
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
64 |
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
|
|
73 |
return q_embed, k_embed
|
74 |
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
|
77 |
+
"""
|
78 |
+
Landmark grouped softmax function.
|
79 |
+
"""
|
80 |
+
|
81 |
# Note that forward, setup_context, and backward are @staticmethods
|
82 |
@staticmethod
|
83 |
def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
|
|
|
552 |
# upcast attention to fp32
|
553 |
if is_mem is None:
|
554 |
raise ValueError("Don't use this without landmarks")
|
555 |
+
|
556 |
+
attn_weights = landmark_grouped_softmax(
|
557 |
+
attn_weights,
|
558 |
+
dim=-1,
|
559 |
+
is_mem=is_mem.expand(-1, self.num_heads, -1, -1),
|
560 |
+
last_section_mask=last_section_mask,
|
561 |
+
).to(query_states.dtype)
|
562 |
+
|
|
|
|
|
563 |
if attn_prefix is not None:
|
564 |
attn_prefix, attn_weights = torch.split(
|
565 |
attn_weights,
|
|
|
590 |
|
591 |
|
592 |
class LlamaDecoderLayer(nn.Module):
|
593 |
+
"""
|
594 |
+
Llama Decoder layer
|
595 |
+
"""
|
596 |
+
|
597 |
def __init__(self, config: LlamaConfig):
|
598 |
super().__init__()
|
599 |
self.hidden_size = config.hidden_size
|
|
|
674 |
return outputs
|
675 |
|
676 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
677 |
@add_start_docstrings(
|
678 |
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
679 |
LLAMA_START_DOCSTRING,
|
|
|
942 |
|
943 |
|
944 |
class LlamaForCausalLM(LlamaPreTrainedModel):
|
945 |
+
"""
|
946 |
+
Llama model with a causal language modeling head.
|
947 |
+
"""
|
948 |
+
|
949 |
def __init__(self, config):
|
950 |
super().__init__(config)
|
951 |
self.model = LlamaModel(config)
|
|
|
1216 |
return reordered_past
|
1217 |
|
1218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1219 |
def add_mem_tokens(example, mem_freq, mem_id):
|
1220 |
+
ids = example["input_ids"]
|
1221 |
ret = []
|
1222 |
prev_idx = 0
|
1223 |
+
for t_idx in range(mem_freq, len(ids), mem_freq):
|
1224 |
+
ret.extend(ids[prev_idx:t_idx])
|
1225 |
ret.append(mem_id)
|
1226 |
prev_idx = t_idx
|
1227 |
+
ret.extend(ids[prev_idx:])
|
1228 |
# drop attention_mask
|
1229 |
return {"input_ids": ret}
|
1230 |
|
|
|
1236 |
transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
|
1237 |
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
|
1238 |
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
1239 |
+
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|