Yingxu He
commited on
Upload MERaLiONForConditionalGeneration
Browse files- config.json +6 -2
- modeling_text_decoder.py +361 -133
config.json
CHANGED
@@ -1,7 +1,10 @@
|
|
1 |
{
|
2 |
-
"
|
|
|
|
|
3 |
"auto_map": {
|
4 |
-
"AutoConfig": "configuration_meralion.MERaLiONConfig"
|
|
|
5 |
},
|
6 |
"head_dim": 256,
|
7 |
"hidden_size": 3584,
|
@@ -163,5 +166,6 @@
|
|
163 |
"sliding_window_size": 4096,
|
164 |
"torch_dtype": "bfloat16"
|
165 |
},
|
|
|
166 |
"transformers_version": "4.46.3"
|
167 |
}
|
|
|
1 |
{
|
2 |
+
"architectures": [
|
3 |
+
"MERaLiONForConditionalGeneration"
|
4 |
+
],
|
5 |
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_meralion.MERaLiONConfig",
|
7 |
+
"AutoModelForSpeechSeq2Seq": "modeling_meralion.MERaLiONForConditionalGeneration"
|
8 |
},
|
9 |
"head_dim": 256,
|
10 |
"hidden_size": 3584,
|
|
|
166 |
"sliding_window_size": 4096,
|
167 |
"torch_dtype": "bfloat16"
|
168 |
},
|
169 |
+
"torch_dtype": "bfloat16",
|
170 |
"transformers_version": "4.46.3"
|
171 |
}
|
modeling_text_decoder.py
CHANGED
@@ -16,21 +16,24 @@
|
|
16 |
from typing import List, Optional, Tuple, Union
|
17 |
|
18 |
import torch
|
|
|
19 |
import torch.utils.checkpoint
|
20 |
-
from torch import nn
|
21 |
-
from torch.nn import CrossEntropyLoss
|
22 |
|
23 |
from transformers.activations import ACT2FN
|
24 |
from transformers.cache_utils import Cache, HybridCache
|
|
|
|
|
25 |
from transformers.modeling_outputs import (
|
26 |
BaseModelOutputWithPast,
|
27 |
CausalLMOutputWithPast,
|
|
|
|
|
28 |
)
|
29 |
from transformers.modeling_utils import PreTrainedModel
|
30 |
from transformers.utils import (
|
|
|
31 |
add_start_docstrings,
|
32 |
add_start_docstrings_to_model_forward,
|
33 |
-
is_flash_attn_2_available,
|
34 |
is_flash_attn_greater_or_equal,
|
35 |
is_flash_attn_greater_or_equal_2_10,
|
36 |
logging,
|
@@ -39,64 +42,7 @@ from transformers.utils import (
|
|
39 |
from .configuration_meralion import MERaLiONTextConfig
|
40 |
|
41 |
|
42 |
-
|
43 |
-
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
44 |
-
|
45 |
-
|
46 |
-
logger = logging.get_logger(__name__)
|
47 |
-
|
48 |
-
|
49 |
-
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
50 |
-
def _prepare_4d_causal_attention_mask_with_cache_position(
|
51 |
-
attention_mask: torch.Tensor,
|
52 |
-
sequence_length: int,
|
53 |
-
target_length: int,
|
54 |
-
dtype: torch.dtype,
|
55 |
-
device: torch.device,
|
56 |
-
min_dtype: float,
|
57 |
-
cache_position: torch.Tensor,
|
58 |
-
batch_size: int,
|
59 |
-
):
|
60 |
-
"""
|
61 |
-
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
62 |
-
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
63 |
-
|
64 |
-
Args:
|
65 |
-
attention_mask (`torch.Tensor`):
|
66 |
-
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
67 |
-
sequence_length (`int`):
|
68 |
-
The sequence length being processed.
|
69 |
-
target_length (`int`):
|
70 |
-
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
71 |
-
dtype (`torch.dtype`):
|
72 |
-
The dtype to use for the 4D attention mask.
|
73 |
-
device (`torch.device`):
|
74 |
-
The device to plcae the 4D attention mask on.
|
75 |
-
min_dtype (`float`):
|
76 |
-
The minimum value representable with the dtype `dtype`.
|
77 |
-
cache_position (`torch.Tensor`):
|
78 |
-
Indices depicting the position of the input sequence tokens in the sequence.
|
79 |
-
batch_size (`torch.Tensor`):
|
80 |
-
Batch size.
|
81 |
-
"""
|
82 |
-
if attention_mask is not None and attention_mask.dim() == 4:
|
83 |
-
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
84 |
-
causal_mask = attention_mask
|
85 |
-
else:
|
86 |
-
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
87 |
-
if sequence_length != 1:
|
88 |
-
causal_mask = torch.triu(causal_mask, diagonal=1)
|
89 |
-
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
90 |
-
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
91 |
-
if attention_mask is not None:
|
92 |
-
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
93 |
-
mask_length = attention_mask.shape[-1]
|
94 |
-
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
95 |
-
padding_mask = padding_mask == 0
|
96 |
-
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
97 |
-
padding_mask, min_dtype
|
98 |
-
)
|
99 |
-
return causal_mask
|
100 |
|
101 |
|
102 |
class MERaLiONTextRMSNorm(nn.Module):
|
@@ -119,6 +65,24 @@ class MERaLiONTextRMSNorm(nn.Module):
|
|
119 |
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
120 |
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
class MERaLiONTextRotaryEmbedding(nn.Module):
|
123 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
124 |
super().__init__()
|
@@ -181,21 +145,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
181 |
return q_embed, k_embed
|
182 |
|
183 |
|
184 |
-
class MERaLiONTextMLP(nn.Module):
|
185 |
-
def __init__(self, config):
|
186 |
-
super().__init__()
|
187 |
-
self.config = config
|
188 |
-
self.hidden_size = config.hidden_size
|
189 |
-
self.intermediate_size = config.intermediate_size
|
190 |
-
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
191 |
-
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
192 |
-
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
193 |
-
self.act_fn = ACT2FN[config.hidden_activation]
|
194 |
-
|
195 |
-
def forward(self, x):
|
196 |
-
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
197 |
-
|
198 |
-
|
199 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
200 |
"""
|
201 |
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
@@ -243,12 +192,12 @@ class MERaLiONTextAttention(nn.Module):
|
|
243 |
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
244 |
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
245 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
|
|
246 |
self.rotary_emb = MERaLiONTextRotaryEmbedding(
|
247 |
self.head_dim,
|
248 |
max_position_embeddings=self.max_position_embeddings,
|
249 |
base=self.rope_theta,
|
250 |
)
|
251 |
-
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
252 |
|
253 |
def forward(
|
254 |
self,
|
@@ -492,9 +441,11 @@ class MERaLiONTextSdpaAttention(MERaLiONTextAttention):
|
|
492 |
|
493 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
494 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
|
495 |
causal_mask = attention_mask
|
496 |
if attention_mask is not None:
|
497 |
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
|
|
498 |
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
499 |
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
500 |
if query_states.device.type == "cuda" and causal_mask is not None:
|
@@ -505,6 +456,7 @@ class MERaLiONTextSdpaAttention(MERaLiONTextAttention):
|
|
505 |
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
506 |
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
507 |
is_causal = True if causal_mask is None and q_len > 1 else False
|
|
|
508 |
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
509 |
query_states,
|
510 |
key_states,
|
@@ -523,7 +475,7 @@ class MERaLiONTextSdpaAttention(MERaLiONTextAttention):
|
|
523 |
return attn_output, None, past_key_value
|
524 |
|
525 |
|
526 |
-
|
527 |
"eager": MERaLiONTextAttention,
|
528 |
"flash_attention_2": MERaLiONTextFlashAttention2,
|
529 |
"sdpa": MERaLiONTextSdpaAttention,
|
@@ -533,19 +485,16 @@ MERaLiONText_ATTENTION_CLASSES = {
|
|
533 |
class MERaLiONTextDecoderLayer(nn.Module):
|
534 |
def __init__(self, config: MERaLiONTextConfig, layer_idx: int):
|
535 |
super().__init__()
|
536 |
-
self.config = config
|
537 |
self.hidden_size = config.hidden_size
|
538 |
-
|
539 |
-
self.self_attn = MERaLiONText_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
540 |
-
|
541 |
self.mlp = MERaLiONTextMLP(config)
|
542 |
self.input_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
543 |
-
self.
|
544 |
-
|
545 |
self.is_sliding = not bool(layer_idx % 2)
|
546 |
self.pre_feedforward_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
547 |
self.post_feedforward_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
548 |
self.sliding_window = config.sliding_window
|
|
|
549 |
|
550 |
def forward(
|
551 |
self,
|
@@ -557,6 +506,25 @@ class MERaLiONTextDecoderLayer(nn.Module):
|
|
557 |
use_cache: Optional[bool] = False,
|
558 |
cache_position: Optional[torch.LongTensor] = None,
|
559 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
560 |
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
561 |
# Flash-attn is a 2D tensor
|
562 |
if self.config._attn_implementation == "flash_attention_2":
|
@@ -570,6 +538,7 @@ class MERaLiONTextDecoderLayer(nn.Module):
|
|
570 |
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
571 |
if attention_mask.shape[-1] <= 1: # when decoding
|
572 |
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
|
|
573 |
residual = hidden_states
|
574 |
|
575 |
hidden_states = self.input_layernorm(hidden_states)
|
@@ -648,6 +617,20 @@ class MERaLiONTextPreTrainedModel(PreTrainedModel):
|
|
648 |
if module.padding_idx is not None:
|
649 |
module.weight.data[module.padding_idx].zero_()
|
650 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
651 |
|
652 |
_CONFIG_FOR_DOC = "MERaLiONTextConfig"
|
653 |
|
@@ -693,7 +676,8 @@ MERALION_TEXT_INPUTS_DOCSTRING = r"""
|
|
693 |
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
694 |
|
695 |
Two formats are allowed:
|
696 |
-
- a [`~cache_utils.Cache`] instance
|
|
|
697 |
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
698 |
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
699 |
cache format.
|
@@ -765,7 +749,7 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
|
|
765 |
input_ids: torch.LongTensor = None,
|
766 |
attention_mask: Optional[torch.Tensor] = None,
|
767 |
position_ids: Optional[torch.LongTensor] = None,
|
768 |
-
past_key_values: Optional[
|
769 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
770 |
use_cache: Optional[bool] = None,
|
771 |
output_attentions: Optional[bool] = None,
|
@@ -781,9 +765,7 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
|
|
781 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
782 |
|
783 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
784 |
-
raise ValueError(
|
785 |
-
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
786 |
-
)
|
787 |
|
788 |
if self.gradient_checkpointing and self.training and use_cache:
|
789 |
logger.warning_once(
|
@@ -794,8 +776,21 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
|
|
794 |
if inputs_embeds is None:
|
795 |
inputs_embeds = self.embed_tokens(input_ids)
|
796 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
797 |
if cache_position is None:
|
798 |
-
|
|
|
|
|
|
|
799 |
|
800 |
if position_ids is None:
|
801 |
position_ids = cache_position.unsqueeze(0)
|
@@ -813,6 +808,7 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
|
|
813 |
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
814 |
hidden_states = hidden_states * normalizer
|
815 |
|
|
|
816 |
all_hidden_states = () if output_hidden_states else None
|
817 |
all_self_attns = () if output_attentions else None
|
818 |
|
@@ -849,7 +845,6 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
|
|
849 |
|
850 |
hidden_states = self.norm(hidden_states)
|
851 |
|
852 |
-
# add hidden states from the last decoder layer
|
853 |
if output_hidden_states:
|
854 |
all_hidden_states += (hidden_states,)
|
855 |
|
@@ -869,7 +864,7 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
|
|
869 |
attention_mask: torch.Tensor,
|
870 |
input_tensor: torch.Tensor,
|
871 |
cache_position: torch.Tensor,
|
872 |
-
past_key_values:
|
873 |
output_attentions: bool,
|
874 |
):
|
875 |
# Flash Attention currently doesn't support static cache but MERaLiONText work only with static cache.
|
@@ -880,28 +875,82 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
|
|
880 |
return attention_mask
|
881 |
|
882 |
dtype, device = input_tensor.dtype, input_tensor.device
|
883 |
-
min_dtype = torch.finfo(dtype).min
|
884 |
sequence_length = input_tensor.shape[1]
|
885 |
if isinstance(past_key_values, HybridCache):
|
886 |
-
target_length = past_key_values.
|
887 |
else:
|
888 |
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
889 |
|
890 |
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
891 |
-
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
892 |
attention_mask,
|
893 |
sequence_length=sequence_length,
|
894 |
target_length=target_length,
|
895 |
dtype=dtype,
|
896 |
device=device,
|
897 |
-
min_dtype=min_dtype,
|
898 |
cache_position=cache_position,
|
899 |
batch_size=input_tensor.shape[0],
|
900 |
)
|
901 |
return causal_mask
|
902 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
903 |
|
904 |
-
class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
|
905 |
_tied_weights_keys = ["lm_head.weight"]
|
906 |
|
907 |
def __init__(self, config):
|
@@ -938,7 +987,7 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
|
|
938 |
input_ids: torch.LongTensor = None,
|
939 |
attention_mask: Optional[torch.Tensor] = None,
|
940 |
position_ids: Optional[torch.LongTensor] = None,
|
941 |
-
past_key_values: Optional[
|
942 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
943 |
labels: Optional[torch.LongTensor] = None,
|
944 |
use_cache: Optional[bool] = None,
|
@@ -946,6 +995,8 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
|
|
946 |
output_hidden_states: Optional[bool] = None,
|
947 |
return_dict: Optional[bool] = None,
|
948 |
cache_position: Optional[torch.LongTensor] = None,
|
|
|
|
|
949 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
950 |
r"""
|
951 |
Args:
|
@@ -954,24 +1005,14 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
|
|
954 |
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
955 |
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
956 |
|
957 |
-
|
958 |
-
|
959 |
-
|
|
|
960 |
|
961 |
-
|
962 |
-
|
963 |
-
|
964 |
-
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
|
965 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
966 |
-
|
967 |
-
>>> prompt = "What is your favorite condiment?"
|
968 |
-
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
969 |
|
970 |
-
>>> # Generate
|
971 |
-
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
972 |
-
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
973 |
-
"What is your favorite condiment?"
|
974 |
-
```"""
|
975 |
if self.training and self.config._attn_implementation != "eager":
|
976 |
logger.warning_once(
|
977 |
"It is strongly recommended to train MERaLiONText models with the `eager` attention implementation "
|
@@ -983,7 +1024,6 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
|
|
983 |
)
|
984 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
985 |
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
986 |
-
|
987 |
outputs = self.model(
|
988 |
input_ids=input_ids,
|
989 |
attention_mask=attention_mask,
|
@@ -998,25 +1038,16 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
|
|
998 |
)
|
999 |
|
1000 |
hidden_states = outputs[0]
|
1001 |
-
logits
|
|
|
1002 |
if self.config.final_logit_softcapping is not None:
|
1003 |
logits = logits / self.config.final_logit_softcapping
|
1004 |
logits = torch.tanh(logits)
|
1005 |
logits = logits * self.config.final_logit_softcapping
|
1006 |
|
1007 |
-
logits = logits.float()
|
1008 |
loss = None
|
1009 |
if labels is not None:
|
1010 |
-
|
1011 |
-
shift_logits = logits[..., :-1, :].contiguous()
|
1012 |
-
shift_labels = labels[..., 1:].contiguous()
|
1013 |
-
# Flatten the tokens
|
1014 |
-
loss_fct = CrossEntropyLoss()
|
1015 |
-
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1016 |
-
shift_labels = shift_labels.view(-1)
|
1017 |
-
# Enable model parallelism
|
1018 |
-
shift_labels = shift_labels.to(shift_logits.device)
|
1019 |
-
loss = loss_fct(shift_logits, shift_labels)
|
1020 |
|
1021 |
if not return_dict:
|
1022 |
output = (logits,) + outputs[1:]
|
@@ -1039,8 +1070,11 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
|
|
1039 |
cache_position=None,
|
1040 |
position_ids=None,
|
1041 |
use_cache=True,
|
|
|
1042 |
**kwargs,
|
1043 |
):
|
|
|
|
|
1044 |
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
1045 |
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
1046 |
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
@@ -1080,18 +1114,20 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
|
|
1080 |
else:
|
1081 |
batch_size, sequence_length = model_inputs["input_ids"].shape
|
1082 |
device = model_inputs["input_ids"].device
|
1083 |
-
|
1084 |
-
|
1085 |
-
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
1086 |
attention_mask,
|
1087 |
sequence_length=sequence_length,
|
1088 |
-
target_length=past_key_values.
|
1089 |
-
dtype=dtype,
|
1090 |
device=device,
|
1091 |
-
min_dtype=min_dtype,
|
1092 |
cache_position=cache_position,
|
1093 |
batch_size=batch_size,
|
1094 |
)
|
|
|
|
|
|
|
|
|
1095 |
model_inputs.update(
|
1096 |
{
|
1097 |
"position_ids": position_ids,
|
@@ -1102,3 +1138,195 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
|
|
1102 |
}
|
1103 |
)
|
1104 |
return model_inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from typing import List, Optional, Tuple, Union
|
17 |
|
18 |
import torch
|
19 |
+
import torch.nn as nn
|
20 |
import torch.utils.checkpoint
|
|
|
|
|
21 |
|
22 |
from transformers.activations import ACT2FN
|
23 |
from transformers.cache_utils import Cache, HybridCache
|
24 |
+
from transformers.generation import GenerationMixin
|
25 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
26 |
from transformers.modeling_outputs import (
|
27 |
BaseModelOutputWithPast,
|
28 |
CausalLMOutputWithPast,
|
29 |
+
SequenceClassifierOutputWithPast,
|
30 |
+
TokenClassifierOutput,
|
31 |
)
|
32 |
from transformers.modeling_utils import PreTrainedModel
|
33 |
from transformers.utils import (
|
34 |
+
add_code_sample_docstrings,
|
35 |
add_start_docstrings,
|
36 |
add_start_docstrings_to_model_forward,
|
|
|
37 |
is_flash_attn_greater_or_equal,
|
38 |
is_flash_attn_greater_or_equal_2_10,
|
39 |
logging,
|
|
|
42 |
from .configuration_meralion import MERaLiONTextConfig
|
43 |
|
44 |
|
45 |
+
_CHECKPOINT_FOR_DOC = "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
|
48 |
class MERaLiONTextRMSNorm(nn.Module):
|
|
|
65 |
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
66 |
|
67 |
|
68 |
+
class MERaLiONTextMLP(nn.Module):
|
69 |
+
def __init__(self, config):
|
70 |
+
super().__init__()
|
71 |
+
self.config = config
|
72 |
+
self.hidden_size = config.hidden_size
|
73 |
+
self.intermediate_size = config.intermediate_size
|
74 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
75 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
76 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
77 |
+
self.act_fn = ACT2FN[config.hidden_activation]
|
78 |
+
|
79 |
+
def forward(self, x):
|
80 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
81 |
+
|
82 |
+
|
83 |
+
logger = logging.get_logger(__name__)
|
84 |
+
|
85 |
+
|
86 |
class MERaLiONTextRotaryEmbedding(nn.Module):
|
87 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
88 |
super().__init__()
|
|
|
145 |
return q_embed, k_embed
|
146 |
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
149 |
"""
|
150 |
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
|
|
192 |
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
193 |
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
194 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
195 |
+
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
|
196 |
self.rotary_emb = MERaLiONTextRotaryEmbedding(
|
197 |
self.head_dim,
|
198 |
max_position_embeddings=self.max_position_embeddings,
|
199 |
base=self.rope_theta,
|
200 |
)
|
|
|
201 |
|
202 |
def forward(
|
203 |
self,
|
|
|
441 |
|
442 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
443 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
444 |
+
|
445 |
causal_mask = attention_mask
|
446 |
if attention_mask is not None:
|
447 |
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
448 |
+
|
449 |
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
450 |
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
451 |
if query_states.device.type == "cuda" and causal_mask is not None:
|
|
|
456 |
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
457 |
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
458 |
is_causal = True if causal_mask is None and q_len > 1 else False
|
459 |
+
|
460 |
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
461 |
query_states,
|
462 |
key_states,
|
|
|
475 |
return attn_output, None, past_key_value
|
476 |
|
477 |
|
478 |
+
MERALION_TEXT_ATTENTION_CLASSES = {
|
479 |
"eager": MERaLiONTextAttention,
|
480 |
"flash_attention_2": MERaLiONTextFlashAttention2,
|
481 |
"sdpa": MERaLiONTextSdpaAttention,
|
|
|
485 |
class MERaLiONTextDecoderLayer(nn.Module):
|
486 |
def __init__(self, config: MERaLiONTextConfig, layer_idx: int):
|
487 |
super().__init__()
|
|
|
488 |
self.hidden_size = config.hidden_size
|
489 |
+
self.self_attn = MERALION_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
|
|
|
|
490 |
self.mlp = MERaLiONTextMLP(config)
|
491 |
self.input_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
492 |
+
self.config = config
|
|
|
493 |
self.is_sliding = not bool(layer_idx % 2)
|
494 |
self.pre_feedforward_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
495 |
self.post_feedforward_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
496 |
self.sliding_window = config.sliding_window
|
497 |
+
self.post_attention_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
498 |
|
499 |
def forward(
|
500 |
self,
|
|
|
506 |
use_cache: Optional[bool] = False,
|
507 |
cache_position: Optional[torch.LongTensor] = None,
|
508 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
509 |
+
"""
|
510 |
+
Args:
|
511 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
512 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
513 |
+
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
514 |
+
query_sequence_length, key_sequence_length)` if default attention is used.
|
515 |
+
output_attentions (`bool`, *optional*):
|
516 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
517 |
+
returned tensors for more detail.
|
518 |
+
use_cache (`bool`, *optional*):
|
519 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
520 |
+
(see `past_key_values`).
|
521 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
522 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
523 |
+
Indices depicting the position of the input sequence tokens in the sequence
|
524 |
+
kwargs (`dict`, *optional*):
|
525 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
526 |
+
into the model
|
527 |
+
"""
|
528 |
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
|
529 |
# Flash-attn is a 2D tensor
|
530 |
if self.config._attn_implementation == "flash_attention_2":
|
|
|
538 |
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
|
539 |
if attention_mask.shape[-1] <= 1: # when decoding
|
540 |
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
|
541 |
+
|
542 |
residual = hidden_states
|
543 |
|
544 |
hidden_states = self.input_layernorm(hidden_states)
|
|
|
617 |
if module.padding_idx is not None:
|
618 |
module.weight.data[module.padding_idx].zero_()
|
619 |
|
620 |
+
@classmethod
|
621 |
+
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
|
622 |
+
"""
|
623 |
+
Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on MERaLiONText models.
|
624 |
+
SDPA reduces the model performance on MERaLiONText because of the logits softcapping.
|
625 |
+
"""
|
626 |
+
config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
|
627 |
+
|
628 |
+
# if using the default path -> swap sdpa by eager
|
629 |
+
if not hard_check_only and config._attn_implementation == "sdpa":
|
630 |
+
config._attn_implementation = "eager"
|
631 |
+
|
632 |
+
return config
|
633 |
+
|
634 |
|
635 |
_CONFIG_FOR_DOC = "MERaLiONTextConfig"
|
636 |
|
|
|
676 |
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
677 |
|
678 |
Two formats are allowed:
|
679 |
+
- a [`~cache_utils.Cache`] instance, see our
|
680 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
681 |
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
682 |
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
683 |
cache format.
|
|
|
749 |
input_ids: torch.LongTensor = None,
|
750 |
attention_mask: Optional[torch.Tensor] = None,
|
751 |
position_ids: Optional[torch.LongTensor] = None,
|
752 |
+
past_key_values: Optional[HybridCache] = None,
|
753 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
754 |
use_cache: Optional[bool] = None,
|
755 |
output_attentions: Optional[bool] = None,
|
|
|
765 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
766 |
|
767 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
768 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
|
|
|
769 |
|
770 |
if self.gradient_checkpointing and self.training and use_cache:
|
771 |
logger.warning_once(
|
|
|
776 |
if inputs_embeds is None:
|
777 |
inputs_embeds = self.embed_tokens(input_ids)
|
778 |
|
779 |
+
if use_cache and past_key_values is None and not self.training:
|
780 |
+
batch_size, seq_len, _ = inputs_embeds.shape
|
781 |
+
past_key_values = HybridCache(
|
782 |
+
self.config,
|
783 |
+
batch_size=batch_size,
|
784 |
+
max_cache_len=seq_len,
|
785 |
+
device=self.device,
|
786 |
+
dtype=inputs_embeds.dtype,
|
787 |
+
)
|
788 |
+
|
789 |
if cache_position is None:
|
790 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
791 |
+
cache_position = torch.arange(
|
792 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
793 |
+
)
|
794 |
|
795 |
if position_ids is None:
|
796 |
position_ids = cache_position.unsqueeze(0)
|
|
|
808 |
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
809 |
hidden_states = hidden_states * normalizer
|
810 |
|
811 |
+
# decoder layers
|
812 |
all_hidden_states = () if output_hidden_states else None
|
813 |
all_self_attns = () if output_attentions else None
|
814 |
|
|
|
845 |
|
846 |
hidden_states = self.norm(hidden_states)
|
847 |
|
|
|
848 |
if output_hidden_states:
|
849 |
all_hidden_states += (hidden_states,)
|
850 |
|
|
|
864 |
attention_mask: torch.Tensor,
|
865 |
input_tensor: torch.Tensor,
|
866 |
cache_position: torch.Tensor,
|
867 |
+
past_key_values: HybridCache,
|
868 |
output_attentions: bool,
|
869 |
):
|
870 |
# Flash Attention currently doesn't support static cache but MERaLiONText work only with static cache.
|
|
|
875 |
return attention_mask
|
876 |
|
877 |
dtype, device = input_tensor.dtype, input_tensor.device
|
|
|
878 |
sequence_length = input_tensor.shape[1]
|
879 |
if isinstance(past_key_values, HybridCache):
|
880 |
+
target_length = past_key_values.get_max_cache_shape()
|
881 |
else:
|
882 |
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
|
883 |
|
884 |
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
885 |
+
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
886 |
attention_mask,
|
887 |
sequence_length=sequence_length,
|
888 |
target_length=target_length,
|
889 |
dtype=dtype,
|
890 |
device=device,
|
|
|
891 |
cache_position=cache_position,
|
892 |
batch_size=input_tensor.shape[0],
|
893 |
)
|
894 |
return causal_mask
|
895 |
|
896 |
+
@staticmethod
|
897 |
+
def _prepare_4d_causal_attention_mask_with_cache_position(
|
898 |
+
attention_mask: torch.Tensor,
|
899 |
+
sequence_length: int,
|
900 |
+
target_length: int,
|
901 |
+
dtype: torch.dtype,
|
902 |
+
device: torch.device,
|
903 |
+
cache_position: torch.Tensor,
|
904 |
+
batch_size: int,
|
905 |
+
**kwargs,
|
906 |
+
):
|
907 |
+
"""
|
908 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
909 |
+
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
910 |
+
|
911 |
+
Args:
|
912 |
+
attention_mask (`torch.Tensor`):
|
913 |
+
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
914 |
+
`(batch_size, 1, query_length, key_value_length)`.
|
915 |
+
sequence_length (`int`):
|
916 |
+
The sequence length being processed.
|
917 |
+
target_length (`int`):
|
918 |
+
The target length: when generating with static cache, the mask should be as long as the static cache,
|
919 |
+
to account for the 0 padding, the part of the cache that is not filled yet.
|
920 |
+
dtype (`torch.dtype`):
|
921 |
+
The dtype to use for the 4D attention mask.
|
922 |
+
device (`torch.device`):
|
923 |
+
The device to plcae the 4D attention mask on.
|
924 |
+
cache_position (`torch.Tensor`):
|
925 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
926 |
+
batch_size (`torch.Tensor`):
|
927 |
+
Batch size.
|
928 |
+
"""
|
929 |
+
if attention_mask is not None and attention_mask.dim() == 4:
|
930 |
+
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
931 |
+
causal_mask = attention_mask
|
932 |
+
else:
|
933 |
+
min_dtype = torch.finfo(dtype).min
|
934 |
+
causal_mask = torch.full(
|
935 |
+
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
936 |
+
)
|
937 |
+
if sequence_length != 1:
|
938 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
939 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
940 |
+
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
941 |
+
if attention_mask is not None:
|
942 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
943 |
+
mask_length = attention_mask.shape[-1]
|
944 |
+
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
945 |
+
padding_mask = padding_mask == 0
|
946 |
+
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
947 |
+
padding_mask, min_dtype
|
948 |
+
)
|
949 |
+
|
950 |
+
return causal_mask
|
951 |
+
|
952 |
|
953 |
+
class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel, GenerationMixin):
|
954 |
_tied_weights_keys = ["lm_head.weight"]
|
955 |
|
956 |
def __init__(self, config):
|
|
|
987 |
input_ids: torch.LongTensor = None,
|
988 |
attention_mask: Optional[torch.Tensor] = None,
|
989 |
position_ids: Optional[torch.LongTensor] = None,
|
990 |
+
past_key_values: Optional[HybridCache] = None,
|
991 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
992 |
labels: Optional[torch.LongTensor] = None,
|
993 |
use_cache: Optional[bool] = None,
|
|
|
995 |
output_hidden_states: Optional[bool] = None,
|
996 |
return_dict: Optional[bool] = None,
|
997 |
cache_position: Optional[torch.LongTensor] = None,
|
998 |
+
num_logits_to_keep: int = 0,
|
999 |
+
**loss_kwargs,
|
1000 |
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1001 |
r"""
|
1002 |
Args:
|
|
|
1005 |
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1006 |
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1007 |
|
1008 |
+
num_logits_to_keep (`int`, *optional*):
|
1009 |
+
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
1010 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
1011 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
1012 |
|
1013 |
+
Returns:
|
1014 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
1015 |
|
|
|
|
|
|
|
|
|
|
|
1016 |
if self.training and self.config._attn_implementation != "eager":
|
1017 |
logger.warning_once(
|
1018 |
"It is strongly recommended to train MERaLiONText models with the `eager` attention implementation "
|
|
|
1024 |
)
|
1025 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1026 |
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
|
1027 |
outputs = self.model(
|
1028 |
input_ids=input_ids,
|
1029 |
attention_mask=attention_mask,
|
|
|
1038 |
)
|
1039 |
|
1040 |
hidden_states = outputs[0]
|
1041 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
1042 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
1043 |
if self.config.final_logit_softcapping is not None:
|
1044 |
logits = logits / self.config.final_logit_softcapping
|
1045 |
logits = torch.tanh(logits)
|
1046 |
logits = logits * self.config.final_logit_softcapping
|
1047 |
|
|
|
1048 |
loss = None
|
1049 |
if labels is not None:
|
1050 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1051 |
|
1052 |
if not return_dict:
|
1053 |
output = (logits,) + outputs[1:]
|
|
|
1070 |
cache_position=None,
|
1071 |
position_ids=None,
|
1072 |
use_cache=True,
|
1073 |
+
num_logits_to_keep=None,
|
1074 |
**kwargs,
|
1075 |
):
|
1076 |
+
# Overwritten: has a special cache type, `HybridCache`
|
1077 |
+
|
1078 |
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
1079 |
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
1080 |
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
|
|
1114 |
else:
|
1115 |
batch_size, sequence_length = model_inputs["input_ids"].shape
|
1116 |
device = model_inputs["input_ids"].device
|
1117 |
+
|
1118 |
+
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
|
|
1119 |
attention_mask,
|
1120 |
sequence_length=sequence_length,
|
1121 |
+
target_length=past_key_values.get_max_cache_shape(),
|
1122 |
+
dtype=self.lm_head.weight.dtype,
|
1123 |
device=device,
|
|
|
1124 |
cache_position=cache_position,
|
1125 |
batch_size=batch_size,
|
1126 |
)
|
1127 |
+
|
1128 |
+
if num_logits_to_keep is not None:
|
1129 |
+
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
1130 |
+
|
1131 |
model_inputs.update(
|
1132 |
{
|
1133 |
"position_ids": position_ids,
|
|
|
1138 |
}
|
1139 |
)
|
1140 |
return model_inputs
|
1141 |
+
|
1142 |
+
|
1143 |
+
@add_start_docstrings(
|
1144 |
+
"""
|
1145 |
+
The MERaLiONText Model transformer with a sequence classification head on top (linear layer).
|
1146 |
+
|
1147 |
+
[`MERaLiONTextForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1148 |
+
(e.g. GPT-2) do.
|
1149 |
+
|
1150 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1151 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1152 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1153 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1154 |
+
each row of the batch).
|
1155 |
+
""",
|
1156 |
+
MERALION_TEXT_START_DOCSTRING,
|
1157 |
+
)
|
1158 |
+
class MERaLiONTextForSequenceClassification(MERaLiONTextPreTrainedModel):
|
1159 |
+
def __init__(self, config):
|
1160 |
+
super().__init__(config)
|
1161 |
+
self.num_labels = config.num_labels
|
1162 |
+
self.model = MERaLiONTextModel(config)
|
1163 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1164 |
+
|
1165 |
+
# Initialize weights and apply final processing
|
1166 |
+
self.post_init()
|
1167 |
+
|
1168 |
+
def get_input_embeddings(self):
|
1169 |
+
return self.model.embed_tokens
|
1170 |
+
|
1171 |
+
def set_input_embeddings(self, value):
|
1172 |
+
self.model.embed_tokens = value
|
1173 |
+
|
1174 |
+
@add_start_docstrings_to_model_forward(MERALION_TEXT_INPUTS_DOCSTRING)
|
1175 |
+
def forward(
|
1176 |
+
self,
|
1177 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1178 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1179 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1180 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
1181 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1182 |
+
labels: Optional[torch.LongTensor] = None,
|
1183 |
+
use_cache: Optional[bool] = None,
|
1184 |
+
output_attentions: Optional[bool] = None,
|
1185 |
+
output_hidden_states: Optional[bool] = None,
|
1186 |
+
return_dict: Optional[bool] = None,
|
1187 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1188 |
+
r"""
|
1189 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1190 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1191 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1192 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1193 |
+
"""
|
1194 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1195 |
+
|
1196 |
+
transformer_outputs = self.model(
|
1197 |
+
input_ids,
|
1198 |
+
attention_mask=attention_mask,
|
1199 |
+
position_ids=position_ids,
|
1200 |
+
past_key_values=past_key_values,
|
1201 |
+
inputs_embeds=inputs_embeds,
|
1202 |
+
use_cache=use_cache,
|
1203 |
+
output_attentions=output_attentions,
|
1204 |
+
output_hidden_states=output_hidden_states,
|
1205 |
+
return_dict=return_dict,
|
1206 |
+
)
|
1207 |
+
hidden_states = transformer_outputs[0]
|
1208 |
+
logits = self.score(hidden_states)
|
1209 |
+
|
1210 |
+
if input_ids is not None:
|
1211 |
+
batch_size = input_ids.shape[0]
|
1212 |
+
else:
|
1213 |
+
batch_size = inputs_embeds.shape[0]
|
1214 |
+
|
1215 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1216 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1217 |
+
if self.config.pad_token_id is None:
|
1218 |
+
sequence_lengths = -1
|
1219 |
+
else:
|
1220 |
+
if input_ids is not None:
|
1221 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1222 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1223 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1224 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1225 |
+
else:
|
1226 |
+
sequence_lengths = -1
|
1227 |
+
|
1228 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1229 |
+
|
1230 |
+
loss = None
|
1231 |
+
if labels is not None:
|
1232 |
+
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
1233 |
+
|
1234 |
+
if not return_dict:
|
1235 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1236 |
+
return ((loss,) + output) if loss is not None else output
|
1237 |
+
|
1238 |
+
return SequenceClassifierOutputWithPast(
|
1239 |
+
loss=loss,
|
1240 |
+
logits=pooled_logits,
|
1241 |
+
past_key_values=transformer_outputs.past_key_values,
|
1242 |
+
hidden_states=transformer_outputs.hidden_states,
|
1243 |
+
attentions=transformer_outputs.attentions,
|
1244 |
+
)
|
1245 |
+
|
1246 |
+
|
1247 |
+
@add_start_docstrings(
|
1248 |
+
"""
|
1249 |
+
The MERaLiONText Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
1250 |
+
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
1251 |
+
""",
|
1252 |
+
MERALION_TEXT_START_DOCSTRING,
|
1253 |
+
)
|
1254 |
+
class MERaLiONTextForTokenClassification(MERaLiONTextPreTrainedModel):
|
1255 |
+
def __init__(self, config):
|
1256 |
+
super().__init__(config)
|
1257 |
+
self.num_labels = config.num_labels
|
1258 |
+
self.model = MERaLiONTextModel(config)
|
1259 |
+
if getattr(config, "classifier_dropout", None) is not None:
|
1260 |
+
classifier_dropout = config.classifier_dropout
|
1261 |
+
elif getattr(config, "hidden_dropout", None) is not None:
|
1262 |
+
classifier_dropout = config.hidden_dropout
|
1263 |
+
else:
|
1264 |
+
classifier_dropout = 0.1
|
1265 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1266 |
+
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
1267 |
+
|
1268 |
+
# Initialize weights and apply final processing
|
1269 |
+
self.post_init()
|
1270 |
+
|
1271 |
+
def get_input_embeddings(self):
|
1272 |
+
return self.model.embed_tokens
|
1273 |
+
|
1274 |
+
def set_input_embeddings(self, value):
|
1275 |
+
self.model.embed_tokens = value
|
1276 |
+
|
1277 |
+
@add_start_docstrings_to_model_forward(MERALION_TEXT_INPUTS_DOCSTRING)
|
1278 |
+
@add_code_sample_docstrings(
|
1279 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1280 |
+
output_type=TokenClassifierOutput,
|
1281 |
+
config_class=_CONFIG_FOR_DOC,
|
1282 |
+
)
|
1283 |
+
def forward(
|
1284 |
+
self,
|
1285 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1286 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1287 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1288 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1289 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1290 |
+
labels: Optional[torch.LongTensor] = None,
|
1291 |
+
use_cache: Optional[bool] = None,
|
1292 |
+
output_attentions: Optional[bool] = None,
|
1293 |
+
output_hidden_states: Optional[bool] = None,
|
1294 |
+
return_dict: Optional[bool] = None,
|
1295 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
1296 |
+
r"""
|
1297 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1298 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1299 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1300 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1301 |
+
"""
|
1302 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1303 |
+
|
1304 |
+
outputs = self.model(
|
1305 |
+
input_ids,
|
1306 |
+
attention_mask=attention_mask,
|
1307 |
+
position_ids=position_ids,
|
1308 |
+
past_key_values=past_key_values,
|
1309 |
+
inputs_embeds=inputs_embeds,
|
1310 |
+
use_cache=use_cache,
|
1311 |
+
output_attentions=output_attentions,
|
1312 |
+
output_hidden_states=output_hidden_states,
|
1313 |
+
return_dict=return_dict,
|
1314 |
+
)
|
1315 |
+
sequence_output = outputs[0]
|
1316 |
+
sequence_output = self.dropout(sequence_output)
|
1317 |
+
logits = self.score(sequence_output)
|
1318 |
+
|
1319 |
+
loss = None
|
1320 |
+
if labels is not None:
|
1321 |
+
loss = self.loss_function(logits, labels, self.config)
|
1322 |
+
|
1323 |
+
if not return_dict:
|
1324 |
+
output = (logits,) + outputs[2:]
|
1325 |
+
return ((loss,) + output) if loss is not None else output
|
1326 |
+
|
1327 |
+
return TokenClassifierOutput(
|
1328 |
+
loss=loss,
|
1329 |
+
logits=logits,
|
1330 |
+
hidden_states=outputs.hidden_states,
|
1331 |
+
attentions=outputs.attentions,
|
1332 |
+
)
|