Yingxu He committed on
Commit 3094fb4 · verified
1 Parent(s): 3f50621

Upload MERaLiONForConditionalGeneration

Files changed (2)
  1. config.json +6 -2
  2. modeling_text_decoder.py +361 -133
config.json CHANGED
@@ -1,7 +1,10 @@
 {
-  "_attn_implementation_autoset": true,
+  "architectures": [
+    "MERaLiONForConditionalGeneration"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_meralion.MERaLiONConfig"
+    "AutoConfig": "configuration_meralion.MERaLiONConfig",
+    "AutoModelForSpeechSeq2Seq": "modeling_meralion.MERaLiONForConditionalGeneration"
   },
   "head_dim": 256,
   "hidden_size": 3584,
@@ -163,5 +166,6 @@
   "sliding_window_size": 4096,
   "torch_dtype": "bfloat16"
   },
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.46.3"
 }
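With the new `architectures` and `auto_map` entries, the checkpoint can be loaded through the Auto classes when remote code is trusted. A minimal sketch (the repo id is taken from `_CHECKPOINT_FOR_DOC` in the modeling file below and is only assumed here for illustration):

```python
# Hedged sketch: load the model via the newly mapped AutoModelForSpeechSeq2Seq entry.
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq

repo_id = "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"  # assumed checkpoint id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    repo_id,
    trust_remote_code=True,   # resolves configuration_meralion / modeling_meralion through auto_map
    torch_dtype="bfloat16",   # matches the torch_dtype recorded in config.json
)
```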
modeling_text_decoder.py CHANGED
@@ -16,21 +16,24 @@
16
  from typing import List, Optional, Tuple, Union
17
 
18
  import torch
 
19
  import torch.utils.checkpoint
20
- from torch import nn
21
- from torch.nn import CrossEntropyLoss
22
 
23
  from transformers.activations import ACT2FN
24
  from transformers.cache_utils import Cache, HybridCache
 
 
25
  from transformers.modeling_outputs import (
26
  BaseModelOutputWithPast,
27
  CausalLMOutputWithPast,
 
 
28
  )
29
  from transformers.modeling_utils import PreTrainedModel
30
  from transformers.utils import (
 
31
  add_start_docstrings,
32
  add_start_docstrings_to_model_forward,
33
- is_flash_attn_2_available,
34
  is_flash_attn_greater_or_equal,
35
  is_flash_attn_greater_or_equal_2_10,
36
  logging,
@@ -39,64 +42,7 @@ from transformers.utils import (
39
  from .configuration_meralion import MERaLiONTextConfig
40
 
41
 
42
- if is_flash_attn_2_available():
43
- from transformers.modeling_flash_attention_utils import _flash_attention_forward
44
-
45
-
46
- logger = logging.get_logger(__name__)
47
-
48
-
49
- # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
50
- def _prepare_4d_causal_attention_mask_with_cache_position(
51
- attention_mask: torch.Tensor,
52
- sequence_length: int,
53
- target_length: int,
54
- dtype: torch.dtype,
55
- device: torch.device,
56
- min_dtype: float,
57
- cache_position: torch.Tensor,
58
- batch_size: int,
59
- ):
60
- """
61
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
62
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
63
-
64
- Args:
65
- attention_mask (`torch.Tensor`):
66
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
67
- sequence_length (`int`):
68
- The sequence length being processed.
69
- target_length (`int`):
70
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
71
- dtype (`torch.dtype`):
72
- The dtype to use for the 4D attention mask.
73
- device (`torch.device`):
74
- The device to place the 4D attention mask on.
75
- min_dtype (`float`):
76
- The minimum value representable with the dtype `dtype`.
77
- cache_position (`torch.Tensor`):
78
- Indices depicting the position of the input sequence tokens in the sequence.
79
- batch_size (`torch.Tensor`):
80
- Batch size.
81
- """
82
- if attention_mask is not None and attention_mask.dim() == 4:
83
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
84
- causal_mask = attention_mask
85
- else:
86
- causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
87
- if sequence_length != 1:
88
- causal_mask = torch.triu(causal_mask, diagonal=1)
89
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
90
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
91
- if attention_mask is not None:
92
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
93
- mask_length = attention_mask.shape[-1]
94
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
95
- padding_mask = padding_mask == 0
96
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
97
- padding_mask, min_dtype
98
- )
99
- return causal_mask
100
 
101
 
102
  class MERaLiONTextRMSNorm(nn.Module):
@@ -119,6 +65,24 @@ class MERaLiONTextRMSNorm(nn.Module):
119
  return f"{tuple(self.weight.shape)}, eps={self.eps}"
120
 
121
122
  class MERaLiONTextRotaryEmbedding(nn.Module):
123
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
124
  super().__init__()
@@ -181,21 +145,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
181
  return q_embed, k_embed
182
 
183
 
184
- class MERaLiONTextMLP(nn.Module):
185
- def __init__(self, config):
186
- super().__init__()
187
- self.config = config
188
- self.hidden_size = config.hidden_size
189
- self.intermediate_size = config.intermediate_size
190
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
191
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
192
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
193
- self.act_fn = ACT2FN[config.hidden_activation]
194
-
195
- def forward(self, x):
196
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
197
-
198
-
199
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
200
  """
201
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -243,12 +192,12 @@ class MERaLiONTextAttention(nn.Module):
243
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
244
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
245
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
246
  self.rotary_emb = MERaLiONTextRotaryEmbedding(
247
  self.head_dim,
248
  max_position_embeddings=self.max_position_embeddings,
249
  base=self.rope_theta,
250
  )
251
- self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
252
 
253
  def forward(
254
  self,
@@ -492,9 +441,11 @@ class MERaLiONTextSdpaAttention(MERaLiONTextAttention):
492
 
493
  key_states = repeat_kv(key_states, self.num_key_value_groups)
494
  value_states = repeat_kv(value_states, self.num_key_value_groups)
 
495
  causal_mask = attention_mask
496
  if attention_mask is not None:
497
  causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
498
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
499
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
500
  if query_states.device.type == "cuda" and causal_mask is not None:
@@ -505,6 +456,7 @@ class MERaLiONTextSdpaAttention(MERaLiONTextAttention):
505
  # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
506
  # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
507
  is_causal = True if causal_mask is None and q_len > 1 else False
 
508
  attn_output = torch.nn.functional.scaled_dot_product_attention(
509
  query_states,
510
  key_states,
@@ -523,7 +475,7 @@ class MERaLiONTextSdpaAttention(MERaLiONTextAttention):
523
  return attn_output, None, past_key_value
524
 
525
 
526
- MERaLiONText_ATTENTION_CLASSES = {
527
  "eager": MERaLiONTextAttention,
528
  "flash_attention_2": MERaLiONTextFlashAttention2,
529
  "sdpa": MERaLiONTextSdpaAttention,
@@ -533,19 +485,16 @@ MERaLiONText_ATTENTION_CLASSES = {
533
  class MERaLiONTextDecoderLayer(nn.Module):
534
  def __init__(self, config: MERaLiONTextConfig, layer_idx: int):
535
  super().__init__()
536
- self.config = config
537
  self.hidden_size = config.hidden_size
538
-
539
- self.self_attn = MERaLiONText_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
540
-
541
  self.mlp = MERaLiONTextMLP(config)
542
  self.input_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
543
- self.post_attention_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
544
-
545
  self.is_sliding = not bool(layer_idx % 2)
546
  self.pre_feedforward_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
547
  self.post_feedforward_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
548
  self.sliding_window = config.sliding_window
 
549
 
550
  def forward(
551
  self,
@@ -557,6 +506,25 @@ class MERaLiONTextDecoderLayer(nn.Module):
557
  use_cache: Optional[bool] = False,
558
  cache_position: Optional[torch.LongTensor] = None,
559
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
560
  if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
561
  # Flash-attn is a 2D tensor
562
  if self.config._attn_implementation == "flash_attention_2":
@@ -570,6 +538,7 @@ class MERaLiONTextDecoderLayer(nn.Module):
570
  attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
571
  if attention_mask.shape[-1] <= 1: # when decoding
572
  attention_mask = attention_mask[:, :, :, -self.sliding_window :]
 
573
  residual = hidden_states
574
 
575
  hidden_states = self.input_layernorm(hidden_states)
@@ -648,6 +617,20 @@ class MERaLiONTextPreTrainedModel(PreTrainedModel):
648
  if module.padding_idx is not None:
649
  module.weight.data[module.padding_idx].zero_()
650
651
 
652
  _CONFIG_FOR_DOC = "MERaLiONTextConfig"
653
 
@@ -693,7 +676,8 @@ MERALION_TEXT_INPUTS_DOCSTRING = r"""
693
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
694
 
695
  Two formats are allowed:
696
- - a [`~cache_utils.Cache`] instance;
 
697
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
698
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
699
  cache format.
@@ -765,7 +749,7 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
765
  input_ids: torch.LongTensor = None,
766
  attention_mask: Optional[torch.Tensor] = None,
767
  position_ids: Optional[torch.LongTensor] = None,
768
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
769
  inputs_embeds: Optional[torch.FloatTensor] = None,
770
  use_cache: Optional[bool] = None,
771
  output_attentions: Optional[bool] = None,
@@ -781,9 +765,7 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
781
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
782
 
783
  if (input_ids is None) ^ (inputs_embeds is not None):
784
- raise ValueError(
785
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
786
- )
787
 
788
  if self.gradient_checkpointing and self.training and use_cache:
789
  logger.warning_once(
@@ -794,8 +776,21 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
794
  if inputs_embeds is None:
795
  inputs_embeds = self.embed_tokens(input_ids)
796
 
 
 
 
 
 
 
 
 
 
 
797
  if cache_position is None:
798
- cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
 
 
 
799
 
800
  if position_ids is None:
801
  position_ids = cache_position.unsqueeze(0)
@@ -813,6 +808,7 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
813
  normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
814
  hidden_states = hidden_states * normalizer
815
 
 
816
  all_hidden_states = () if output_hidden_states else None
817
  all_self_attns = () if output_attentions else None
818
 
@@ -849,7 +845,6 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
849
 
850
  hidden_states = self.norm(hidden_states)
851
 
852
- # add hidden states from the last decoder layer
853
  if output_hidden_states:
854
  all_hidden_states += (hidden_states,)
855
 
@@ -869,7 +864,7 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
869
  attention_mask: torch.Tensor,
870
  input_tensor: torch.Tensor,
871
  cache_position: torch.Tensor,
872
- past_key_values: Cache,
873
  output_attentions: bool,
874
  ):
875
  # Flash Attention currently doesn't support static cache but MERaLiONText works only with static cache.
@@ -880,28 +875,82 @@ class MERaLiONTextModel(MERaLiONTextPreTrainedModel):
880
  return attention_mask
881
 
882
  dtype, device = input_tensor.dtype, input_tensor.device
883
- min_dtype = torch.finfo(dtype).min
884
  sequence_length = input_tensor.shape[1]
885
  if isinstance(past_key_values, HybridCache):
886
- target_length = past_key_values.get_max_length()
887
  else:
888
  target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
889
 
890
  # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
891
- causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
892
  attention_mask,
893
  sequence_length=sequence_length,
894
  target_length=target_length,
895
  dtype=dtype,
896
  device=device,
897
- min_dtype=min_dtype,
898
  cache_position=cache_position,
899
  batch_size=input_tensor.shape[0],
900
  )
901
  return causal_mask
902
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
903
 
904
- class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
905
  _tied_weights_keys = ["lm_head.weight"]
906
 
907
  def __init__(self, config):
@@ -938,7 +987,7 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
938
  input_ids: torch.LongTensor = None,
939
  attention_mask: Optional[torch.Tensor] = None,
940
  position_ids: Optional[torch.LongTensor] = None,
941
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
942
  inputs_embeds: Optional[torch.FloatTensor] = None,
943
  labels: Optional[torch.LongTensor] = None,
944
  use_cache: Optional[bool] = None,
@@ -946,6 +995,8 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
946
  output_hidden_states: Optional[bool] = None,
947
  return_dict: Optional[bool] = None,
948
  cache_position: Optional[torch.LongTensor] = None,
 
 
949
  ) -> Union[Tuple, CausalLMOutputWithPast]:
950
  r"""
951
  Args:
@@ -954,24 +1005,14 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
954
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
955
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
956
 
957
- Returns:
958
-
959
- Example:
 
960
 
961
- ```python
962
- >>> from transformers import AutoTokenizer, GemmaForCausalLM
963
-
964
- >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
965
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
966
-
967
- >>> prompt = "What is your favorite condiment?"
968
- >>> inputs = tokenizer(prompt, return_tensors="pt")
969
 
970
- >>> # Generate
971
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
972
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
973
- "What is your favorite condiment?"
974
- ```"""
975
  if self.training and self.config._attn_implementation != "eager":
976
  logger.warning_once(
977
  "It is strongly recommended to train MERaLiONText models with the `eager` attention implementation "
@@ -983,7 +1024,6 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
983
  )
984
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
985
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
986
-
987
  outputs = self.model(
988
  input_ids=input_ids,
989
  attention_mask=attention_mask,
@@ -998,25 +1038,16 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
998
  )
999
 
1000
  hidden_states = outputs[0]
1001
- logits = self.lm_head(hidden_states)
 
1002
  if self.config.final_logit_softcapping is not None:
1003
  logits = logits / self.config.final_logit_softcapping
1004
  logits = torch.tanh(logits)
1005
  logits = logits * self.config.final_logit_softcapping
1006
 
1007
- logits = logits.float()
1008
  loss = None
1009
  if labels is not None:
1010
- # Shift so that tokens < n predict n
1011
- shift_logits = logits[..., :-1, :].contiguous()
1012
- shift_labels = labels[..., 1:].contiguous()
1013
- # Flatten the tokens
1014
- loss_fct = CrossEntropyLoss()
1015
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1016
- shift_labels = shift_labels.view(-1)
1017
- # Enable model parallelism
1018
- shift_labels = shift_labels.to(shift_logits.device)
1019
- loss = loss_fct(shift_logits, shift_labels)
1020
 
1021
  if not return_dict:
1022
  output = (logits,) + outputs[1:]
@@ -1039,8 +1070,11 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
1039
  cache_position=None,
1040
  position_ids=None,
1041
  use_cache=True,
 
1042
  **kwargs,
1043
  ):
 
 
1044
  # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1045
  # Exception 1: when passing input_embeds, input_ids may be missing entries
1046
  # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
@@ -1080,18 +1114,20 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
1080
  else:
1081
  batch_size, sequence_length = model_inputs["input_ids"].shape
1082
  device = model_inputs["input_ids"].device
1083
- dtype = self.lm_head.weight.dtype
1084
- min_dtype = torch.finfo(dtype).min
1085
- attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1086
  attention_mask,
1087
  sequence_length=sequence_length,
1088
- target_length=past_key_values.get_max_length(),
1089
- dtype=dtype,
1090
  device=device,
1091
- min_dtype=min_dtype,
1092
  cache_position=cache_position,
1093
  batch_size=batch_size,
1094
  )
 
 
 
 
1095
  model_inputs.update(
1096
  {
1097
  "position_ids": position_ids,
@@ -1102,3 +1138,195 @@ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel):
1102
  }
1103
  )
1104
  return model_inputs
16
  from typing import List, Optional, Tuple, Union
17
 
18
  import torch
19
+ import torch.nn as nn
20
  import torch.utils.checkpoint
 
 
21
 
22
  from transformers.activations import ACT2FN
23
  from transformers.cache_utils import Cache, HybridCache
24
+ from transformers.generation import GenerationMixin
25
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
26
  from transformers.modeling_outputs import (
27
  BaseModelOutputWithPast,
28
  CausalLMOutputWithPast,
29
+ SequenceClassifierOutputWithPast,
30
+ TokenClassifierOutput,
31
  )
32
  from transformers.modeling_utils import PreTrainedModel
33
  from transformers.utils import (
34
+ add_code_sample_docstrings,
35
  add_start_docstrings,
36
  add_start_docstrings_to_model_forward,
 
37
  is_flash_attn_greater_or_equal,
38
  is_flash_attn_greater_or_equal_2_10,
39
  logging,
 
42
  from .configuration_meralion import MERaLiONTextConfig
43
 
44
 
45
+ _CHECKPOINT_FOR_DOC = "MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION"
46
 
47
 
48
  class MERaLiONTextRMSNorm(nn.Module):
 
65
  return f"{tuple(self.weight.shape)}, eps={self.eps}"
66
 
67
 
68
+ class MERaLiONTextMLP(nn.Module):
69
+ def __init__(self, config):
70
+ super().__init__()
71
+ self.config = config
72
+ self.hidden_size = config.hidden_size
73
+ self.intermediate_size = config.intermediate_size
74
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
75
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
76
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
77
+ self.act_fn = ACT2FN[config.hidden_activation]
78
+
79
+ def forward(self, x):
80
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
81
+
82
+
83
+ logger = logging.get_logger(__name__)
84
+
85
+
86
  class MERaLiONTextRotaryEmbedding(nn.Module):
87
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
88
  super().__init__()
 
145
  return q_embed, k_embed
146
 
147
148
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
149
  """
150
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
192
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
193
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
194
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
195
+ self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
196
  self.rotary_emb = MERaLiONTextRotaryEmbedding(
197
  self.head_dim,
198
  max_position_embeddings=self.max_position_embeddings,
199
  base=self.rope_theta,
200
  )
 
201
 
202
  def forward(
203
  self,
 
441
 
442
  key_states = repeat_kv(key_states, self.num_key_value_groups)
443
  value_states = repeat_kv(value_states, self.num_key_value_groups)
444
+
445
  causal_mask = attention_mask
446
  if attention_mask is not None:
447
  causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
448
+
449
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
450
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
451
  if query_states.device.type == "cuda" and causal_mask is not None:
 
456
  # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
457
  # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
458
  is_causal = True if causal_mask is None and q_len > 1 else False
459
+
460
  attn_output = torch.nn.functional.scaled_dot_product_attention(
461
  query_states,
462
  key_states,
 
475
  return attn_output, None, past_key_value
476
 
477
 
478
+ MERALION_TEXT_ATTENTION_CLASSES = {
479
  "eager": MERaLiONTextAttention,
480
  "flash_attention_2": MERaLiONTextFlashAttention2,
481
  "sdpa": MERaLiONTextSdpaAttention,
 
485
  class MERaLiONTextDecoderLayer(nn.Module):
486
  def __init__(self, config: MERaLiONTextConfig, layer_idx: int):
487
  super().__init__()
 
488
  self.hidden_size = config.hidden_size
489
+ self.self_attn = MERALION_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
 
490
  self.mlp = MERaLiONTextMLP(config)
491
  self.input_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
492
+ self.config = config
 
493
  self.is_sliding = not bool(layer_idx % 2)
494
  self.pre_feedforward_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
495
  self.post_feedforward_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
496
  self.sliding_window = config.sliding_window
497
+ self.post_attention_layernorm = MERaLiONTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
498
 
499
  def forward(
500
  self,
 
506
  use_cache: Optional[bool] = False,
507
  cache_position: Optional[torch.LongTensor] = None,
508
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
509
+ """
510
+ Args:
511
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
512
+ attention_mask (`torch.FloatTensor`, *optional*):
513
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
514
+ query_sequence_length, key_sequence_length)` if default attention is used.
515
+ output_attentions (`bool`, *optional*):
516
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
517
+ returned tensors for more detail.
518
+ use_cache (`bool`, *optional*):
519
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
520
+ (see `past_key_values`).
521
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
522
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
523
+ Indices depicting the position of the input sequence tokens in the sequence
524
+ kwargs (`dict`, *optional*):
525
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
526
+ into the model
527
+ """
528
  if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
529
  # Flash-attn is a 2D tensor
530
  if self.config._attn_implementation == "flash_attention_2":
 
538
  attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
539
  if attention_mask.shape[-1] <= 1: # when decoding
540
  attention_mask = attention_mask[:, :, :, -self.sliding_window :]
541
+
542
  residual = hidden_states
543
 
544
  hidden_states = self.input_layernorm(hidden_states)
 
617
  if module.padding_idx is not None:
618
  module.weight.data[module.padding_idx].zero_()
619
 
620
+ @classmethod
621
+ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
622
+ """
623
+ Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on MERaLiONText models.
624
+ SDPA reduces the model performance on MERaLiONText because of the logits softcapping.
625
+ """
626
+ config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
627
+
628
+ # if using the default path -> swap sdpa by eager
629
+ if not hard_check_only and config._attn_implementation == "sdpa":
630
+ config._attn_implementation = "eager"
631
+
632
+ return config
633
+
634
 
635
  _CONFIG_FOR_DOC = "MERaLiONTextConfig"
636
 
 
676
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
677
 
678
  Two formats are allowed:
679
+ - a [`~cache_utils.Cache`] instance, see our
680
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
681
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
682
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
683
  cache format.
 
749
  input_ids: torch.LongTensor = None,
750
  attention_mask: Optional[torch.Tensor] = None,
751
  position_ids: Optional[torch.LongTensor] = None,
752
+ past_key_values: Optional[HybridCache] = None,
753
  inputs_embeds: Optional[torch.FloatTensor] = None,
754
  use_cache: Optional[bool] = None,
755
  output_attentions: Optional[bool] = None,
 
765
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
766
 
767
  if (input_ids is None) ^ (inputs_embeds is not None):
768
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
 
769
 
770
  if self.gradient_checkpointing and self.training and use_cache:
771
  logger.warning_once(
 
776
  if inputs_embeds is None:
777
  inputs_embeds = self.embed_tokens(input_ids)
778
 
779
+ if use_cache and past_key_values is None and not self.training:
780
+ batch_size, seq_len, _ = inputs_embeds.shape
781
+ past_key_values = HybridCache(
782
+ self.config,
783
+ batch_size=batch_size,
784
+ max_cache_len=seq_len,
785
+ device=self.device,
786
+ dtype=inputs_embeds.dtype,
787
+ )
788
+
789
  if cache_position is None:
790
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
791
+ cache_position = torch.arange(
792
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
793
+ )
794
 
795
  if position_ids is None:
796
  position_ids = cache_position.unsqueeze(0)
 
808
  normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
809
  hidden_states = hidden_states * normalizer
810
 
811
+ # decoder layers
812
  all_hidden_states = () if output_hidden_states else None
813
  all_self_attns = () if output_attentions else None
814
 
 
845
 
846
  hidden_states = self.norm(hidden_states)
847
 
 
848
  if output_hidden_states:
849
  all_hidden_states += (hidden_states,)
850
 
 
864
  attention_mask: torch.Tensor,
865
  input_tensor: torch.Tensor,
866
  cache_position: torch.Tensor,
867
+ past_key_values: HybridCache,
868
  output_attentions: bool,
869
  ):
870
  # Flash Attention currently doesn't support static cache but MERaLiONText works only with static cache.
 
875
  return attention_mask
876
 
877
  dtype, device = input_tensor.dtype, input_tensor.device
 
878
  sequence_length = input_tensor.shape[1]
879
  if isinstance(past_key_values, HybridCache):
880
+ target_length = past_key_values.get_max_cache_shape()
881
  else:
882
  target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
883
 
884
  # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
885
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
886
  attention_mask,
887
  sequence_length=sequence_length,
888
  target_length=target_length,
889
  dtype=dtype,
890
  device=device,
 
891
  cache_position=cache_position,
892
  batch_size=input_tensor.shape[0],
893
  )
894
  return causal_mask
895
 
896
+ @staticmethod
897
+ def _prepare_4d_causal_attention_mask_with_cache_position(
898
+ attention_mask: torch.Tensor,
899
+ sequence_length: int,
900
+ target_length: int,
901
+ dtype: torch.dtype,
902
+ device: torch.device,
903
+ cache_position: torch.Tensor,
904
+ batch_size: int,
905
+ **kwargs,
906
+ ):
907
+ """
908
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
909
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
910
+
911
+ Args:
912
+ attention_mask (`torch.Tensor`):
913
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
914
+ `(batch_size, 1, query_length, key_value_length)`.
915
+ sequence_length (`int`):
916
+ The sequence length being processed.
917
+ target_length (`int`):
918
+ The target length: when generating with static cache, the mask should be as long as the static cache,
919
+ to account for the 0 padding, the part of the cache that is not filled yet.
920
+ dtype (`torch.dtype`):
921
+ The dtype to use for the 4D attention mask.
922
+ device (`torch.device`):
923
+ The device to place the 4D attention mask on.
924
+ cache_position (`torch.Tensor`):
925
+ Indices depicting the position of the input sequence tokens in the sequence.
926
+ batch_size (`torch.Tensor`):
927
+ Batch size.
928
+ """
929
+ if attention_mask is not None and attention_mask.dim() == 4:
930
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
931
+ causal_mask = attention_mask
932
+ else:
933
+ min_dtype = torch.finfo(dtype).min
934
+ causal_mask = torch.full(
935
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
936
+ )
937
+ if sequence_length != 1:
938
+ causal_mask = torch.triu(causal_mask, diagonal=1)
939
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
940
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
941
+ if attention_mask is not None:
942
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
943
+ mask_length = attention_mask.shape[-1]
944
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
945
+ padding_mask = padding_mask == 0
946
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
947
+ padding_mask, min_dtype
948
+ )
949
+
950
+ return causal_mask
951
+
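As an aside, a self-contained sketch (not part of this commit) of what the mask construction above produces for a tiny prompt; it mirrors the logic of `_prepare_4d_causal_attention_mask_with_cache_position` exactly as written:

```python
# Illustrative only: tiny reproduction of the 4D causal-mask construction above.
import torch

batch_size, sequence_length, target_length = 1, 3, 5      # 3 query tokens, static cache of 5 slots
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(sequence_length)            # positions 0..2 being processed

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)                                # hide future tokens
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)       # keep unfilled cache slots masked
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)        # (batch, 1, q_len, kv_len)

print(causal_mask[0, 0])  # row i is 0.0 up to position i and min_dtype afterwards
```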
952
 
953
+ class MERaLiONTextForCausalLM(MERaLiONTextPreTrainedModel, GenerationMixin):
954
  _tied_weights_keys = ["lm_head.weight"]
955
 
956
  def __init__(self, config):
 
987
  input_ids: torch.LongTensor = None,
988
  attention_mask: Optional[torch.Tensor] = None,
989
  position_ids: Optional[torch.LongTensor] = None,
990
+ past_key_values: Optional[HybridCache] = None,
991
  inputs_embeds: Optional[torch.FloatTensor] = None,
992
  labels: Optional[torch.LongTensor] = None,
993
  use_cache: Optional[bool] = None,
 
995
  output_hidden_states: Optional[bool] = None,
996
  return_dict: Optional[bool] = None,
997
  cache_position: Optional[torch.LongTensor] = None,
998
+ num_logits_to_keep: int = 0,
999
+ **loss_kwargs,
1000
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1001
  r"""
1002
  Args:
 
1005
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1006
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1007
 
1008
+ num_logits_to_keep (`int`, *optional*):
1009
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1010
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1011
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1012
 
1013
+ Returns:
1014
+ """
1015
1016
  if self.training and self.config._attn_implementation != "eager":
1017
  logger.warning_once(
1018
  "It is strongly recommended to train MERaLiONText models with the `eager` attention implementation "
 
1024
  )
1025
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1026
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
 
1027
  outputs = self.model(
1028
  input_ids=input_ids,
1029
  attention_mask=attention_mask,
 
1038
  )
1039
 
1040
  hidden_states = outputs[0]
1041
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1042
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1043
  if self.config.final_logit_softcapping is not None:
1044
  logits = logits / self.config.final_logit_softcapping
1045
  logits = torch.tanh(logits)
1046
  logits = logits * self.config.final_logit_softcapping
1047
 
 
1048
  loss = None
1049
  if labels is not None:
1050
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1051
 
1052
  if not return_dict:
1053
  output = (logits,) + outputs[1:]
 
1070
  cache_position=None,
1071
  position_ids=None,
1072
  use_cache=True,
1073
+ num_logits_to_keep=None,
1074
  **kwargs,
1075
  ):
1076
+ # Overwritten: has a special cache type, `HybridCache`
1077
+
1078
  # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1079
  # Exception 1: when passing input_embeds, input_ids may be missing entries
1080
  # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
 
1114
  else:
1115
  batch_size, sequence_length = model_inputs["input_ids"].shape
1116
  device = model_inputs["input_ids"].device
1117
+
1118
+ attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
 
1119
  attention_mask,
1120
  sequence_length=sequence_length,
1121
+ target_length=past_key_values.get_max_cache_shape(),
1122
+ dtype=self.lm_head.weight.dtype,
1123
  device=device,
 
1124
  cache_position=cache_position,
1125
  batch_size=batch_size,
1126
  )
1127
+
1128
+ if num_logits_to_keep is not None:
1129
+ model_inputs["num_logits_to_keep"] = num_logits_to_keep
1130
+
1131
  model_inputs.update(
1132
  {
1133
  "position_ids": position_ids,
 
1138
  }
1139
  )
1140
  return model_inputs
1141
+
1142
+
1143
+ @add_start_docstrings(
1144
+ """
1145
+ The MERaLiONText Model transformer with a sequence classification head on top (linear layer).
1146
+
1147
+ [`MERaLiONTextForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1148
+ (e.g. GPT-2) do.
1149
+
1150
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1151
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1152
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1153
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1154
+ each row of the batch).
1155
+ """,
1156
+ MERALION_TEXT_START_DOCSTRING,
1157
+ )
1158
+ class MERaLiONTextForSequenceClassification(MERaLiONTextPreTrainedModel):
1159
+ def __init__(self, config):
1160
+ super().__init__(config)
1161
+ self.num_labels = config.num_labels
1162
+ self.model = MERaLiONTextModel(config)
1163
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1164
+
1165
+ # Initialize weights and apply final processing
1166
+ self.post_init()
1167
+
1168
+ def get_input_embeddings(self):
1169
+ return self.model.embed_tokens
1170
+
1171
+ def set_input_embeddings(self, value):
1172
+ self.model.embed_tokens = value
1173
+
1174
+ @add_start_docstrings_to_model_forward(MERALION_TEXT_INPUTS_DOCSTRING)
1175
+ def forward(
1176
+ self,
1177
+ input_ids: Optional[torch.LongTensor] = None,
1178
+ attention_mask: Optional[torch.Tensor] = None,
1179
+ position_ids: Optional[torch.LongTensor] = None,
1180
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1181
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1182
+ labels: Optional[torch.LongTensor] = None,
1183
+ use_cache: Optional[bool] = None,
1184
+ output_attentions: Optional[bool] = None,
1185
+ output_hidden_states: Optional[bool] = None,
1186
+ return_dict: Optional[bool] = None,
1187
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1188
+ r"""
1189
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1190
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1191
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1192
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1193
+ """
1194
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1195
+
1196
+ transformer_outputs = self.model(
1197
+ input_ids,
1198
+ attention_mask=attention_mask,
1199
+ position_ids=position_ids,
1200
+ past_key_values=past_key_values,
1201
+ inputs_embeds=inputs_embeds,
1202
+ use_cache=use_cache,
1203
+ output_attentions=output_attentions,
1204
+ output_hidden_states=output_hidden_states,
1205
+ return_dict=return_dict,
1206
+ )
1207
+ hidden_states = transformer_outputs[0]
1208
+ logits = self.score(hidden_states)
1209
+
1210
+ if input_ids is not None:
1211
+ batch_size = input_ids.shape[0]
1212
+ else:
1213
+ batch_size = inputs_embeds.shape[0]
1214
+
1215
+ if self.config.pad_token_id is None and batch_size != 1:
1216
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1217
+ if self.config.pad_token_id is None:
1218
+ sequence_lengths = -1
1219
+ else:
1220
+ if input_ids is not None:
1221
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1222
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1223
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1224
+ sequence_lengths = sequence_lengths.to(logits.device)
1225
+ else:
1226
+ sequence_lengths = -1
1227
+
1228
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1229
+
1230
+ loss = None
1231
+ if labels is not None:
1232
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1233
+
1234
+ if not return_dict:
1235
+ output = (pooled_logits,) + transformer_outputs[1:]
1236
+ return ((loss,) + output) if loss is not None else output
1237
+
1238
+ return SequenceClassifierOutputWithPast(
1239
+ loss=loss,
1240
+ logits=pooled_logits,
1241
+ past_key_values=transformer_outputs.past_key_values,
1242
+ hidden_states=transformer_outputs.hidden_states,
1243
+ attentions=transformer_outputs.attentions,
1244
+ )
1245
+
1246
+
1247
+ @add_start_docstrings(
1248
+ """
1249
+ The MERaLiONText Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1250
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1251
+ """,
1252
+ MERALION_TEXT_START_DOCSTRING,
1253
+ )
1254
+ class MERaLiONTextForTokenClassification(MERaLiONTextPreTrainedModel):
1255
+ def __init__(self, config):
1256
+ super().__init__(config)
1257
+ self.num_labels = config.num_labels
1258
+ self.model = MERaLiONTextModel(config)
1259
+ if getattr(config, "classifier_dropout", None) is not None:
1260
+ classifier_dropout = config.classifier_dropout
1261
+ elif getattr(config, "hidden_dropout", None) is not None:
1262
+ classifier_dropout = config.hidden_dropout
1263
+ else:
1264
+ classifier_dropout = 0.1
1265
+ self.dropout = nn.Dropout(classifier_dropout)
1266
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1267
+
1268
+ # Initialize weights and apply final processing
1269
+ self.post_init()
1270
+
1271
+ def get_input_embeddings(self):
1272
+ return self.model.embed_tokens
1273
+
1274
+ def set_input_embeddings(self, value):
1275
+ self.model.embed_tokens = value
1276
+
1277
+ @add_start_docstrings_to_model_forward(MERALION_TEXT_INPUTS_DOCSTRING)
1278
+ @add_code_sample_docstrings(
1279
+ checkpoint=_CHECKPOINT_FOR_DOC,
1280
+ output_type=TokenClassifierOutput,
1281
+ config_class=_CONFIG_FOR_DOC,
1282
+ )
1283
+ def forward(
1284
+ self,
1285
+ input_ids: Optional[torch.LongTensor] = None,
1286
+ attention_mask: Optional[torch.Tensor] = None,
1287
+ position_ids: Optional[torch.LongTensor] = None,
1288
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1289
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1290
+ labels: Optional[torch.LongTensor] = None,
1291
+ use_cache: Optional[bool] = None,
1292
+ output_attentions: Optional[bool] = None,
1293
+ output_hidden_states: Optional[bool] = None,
1294
+ return_dict: Optional[bool] = None,
1295
+ ) -> Union[Tuple, TokenClassifierOutput]:
1296
+ r"""
1297
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1298
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1299
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1300
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1301
+ """
1302
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1303
+
1304
+ outputs = self.model(
1305
+ input_ids,
1306
+ attention_mask=attention_mask,
1307
+ position_ids=position_ids,
1308
+ past_key_values=past_key_values,
1309
+ inputs_embeds=inputs_embeds,
1310
+ use_cache=use_cache,
1311
+ output_attentions=output_attentions,
1312
+ output_hidden_states=output_hidden_states,
1313
+ return_dict=return_dict,
1314
+ )
1315
+ sequence_output = outputs[0]
1316
+ sequence_output = self.dropout(sequence_output)
1317
+ logits = self.score(sequence_output)
1318
+
1319
+ loss = None
1320
+ if labels is not None:
1321
+ loss = self.loss_function(logits, labels, self.config)
1322
+
1323
+ if not return_dict:
1324
+ output = (logits,) + outputs[2:]
1325
+ return ((loss,) + output) if loss is not None else output
1326
+
1327
+ return TokenClassifierOutput(
1328
+ loss=loss,
1329
+ logits=logits,
1330
+ hidden_states=outputs.hidden_states,
1331
+ attentions=outputs.attentions,
1332
+ )
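For readers checking the new classification heads, a toy sketch (not part of the commit) of the last-non-padding-token pooling that `MERaLiONTextForSequenceClassification.forward` performs; the pad id and tensors are made up for illustration:

```python
# Toy illustration of the pooling used by the sequence-classification head above.
import torch

pad_token_id = 0                                  # assumed pad id for this example
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 0, 0, 0]])       # two right-padded sequences
logits = torch.randn(2, 5, 3)                     # (batch, seq_len, num_labels), as returned by self.score

# Same ONNX-friendly trick as in the forward pass: locate the first pad token,
# step back one position, and wrap with a modulo so unpadded rows use the last position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]

pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)                        # torch.Size([2, 3])
```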