huseinzol05 commited on
Commit
0710ee7
·
verified ·
1 Parent(s): f67f43f

Upload MistralBiForMNTP

Browse files
Files changed (3) hide show
  1. bidirectional_mistral.py +281 -0
  2. config.json +4 -1
  3. model.safetensors +1 -1
bidirectional_mistral.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+
4
+ from transformers import (
5
+ MistralModel,
6
+ MistralPreTrainedModel,
7
+ MistralForCausalLM,
8
+ MistralConfig,
9
+ )
10
+ from transformers.modeling_outputs import BaseModelOutputWithPast
11
+ from transformers.cache_utils import Cache, DynamicCache
12
+ from transformers.models.mistral.modeling_mistral import (
13
+ MistralDecoderLayer,
14
+ MistralRMSNorm,
15
+ MistralAttention,
16
+ MistralFlashAttention2,
17
+ MistralSdpaAttention,
18
+ MistralMLP,
19
+ )
20
+ from torch import nn
21
+ from transformers.utils import logging
22
+ from attn_mask_utils import (
23
+ _prepare_4d_causal_attention_mask,
24
+ _prepare_4d_causal_attention_mask_for_sdpa,
25
+ )
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class ModifiedMistralAttention(MistralAttention):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+ self.is_causal = False
34
+
35
+
36
+ class ModifiedMistralFlashAttention2(MistralFlashAttention2):
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs)
39
+ self.is_causal = False
40
+
41
+
42
+ class ModifiedMistralSdpaAttention(MistralSdpaAttention):
43
+ def __init__(self, *args, **kwargs):
44
+ super().__init__(*args, **kwargs)
45
+ self.is_causal = False
46
+
47
+
48
+ MISTRAL_ATTENTION_CLASSES = {
49
+ "eager": ModifiedMistralAttention,
50
+ "flash_attention_2": ModifiedMistralFlashAttention2,
51
+ "sdpa": ModifiedMistralSdpaAttention,
52
+ }
53
+
54
+
55
+ class ModifiedMistralDecoderLayer(MistralDecoderLayer):
56
+ def __init__(self, config: MistralConfig, layer_idx: int):
57
+ nn.Module.__init__(self)
58
+ self.hidden_size = config.hidden_size
59
+
60
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](
61
+ config, layer_idx
62
+ )
63
+
64
+ self.mlp = MistralMLP(config)
65
+ self.input_layernorm = MistralRMSNorm(
66
+ config.hidden_size, eps=config.rms_norm_eps
67
+ )
68
+ self.post_attention_layernorm = MistralRMSNorm(
69
+ config.hidden_size, eps=config.rms_norm_eps
70
+ )
71
+
72
+
73
+ class MistralBiModel(MistralModel):
74
+ def __init__(self, config: MistralConfig):
75
+ MistralPreTrainedModel.__init__(self, config)
76
+ self.padding_idx = config.pad_token_id
77
+ self.vocab_size = config.vocab_size
78
+
79
+ self.embed_tokens = nn.Embedding(
80
+ config.vocab_size, config.hidden_size, self.padding_idx
81
+ )
82
+ self.layers = nn.ModuleList(
83
+ [
84
+ ModifiedMistralDecoderLayer(config, layer_idx)
85
+ for layer_idx in range(config.num_hidden_layers)
86
+ ]
87
+ )
88
+ self._attn_implementation = config._attn_implementation
89
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
90
+
91
+ self.gradient_checkpointing = False
92
+ # Initialize weights and apply final processing
93
+ self.post_init()
94
+
95
+ # Copied from forward() in transformers.models.mistral.modeling_mistral.MistralModel
96
+ def forward(
97
+ self,
98
+ input_ids: torch.LongTensor = None,
99
+ attention_mask: Optional[torch.Tensor] = None,
100
+ position_ids: Optional[torch.LongTensor] = None,
101
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
102
+ inputs_embeds: Optional[torch.FloatTensor] = None,
103
+ use_cache: Optional[bool] = None,
104
+ output_attentions: Optional[bool] = None,
105
+ output_hidden_states: Optional[bool] = None,
106
+ return_dict: Optional[bool] = None,
107
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
108
+ output_attentions = (
109
+ output_attentions
110
+ if output_attentions is not None
111
+ else self.config.output_attentions
112
+ )
113
+ output_hidden_states = (
114
+ output_hidden_states
115
+ if output_hidden_states is not None
116
+ else self.config.output_hidden_states
117
+ )
118
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
119
+
120
+ return_dict = (
121
+ return_dict if return_dict is not None else self.config.use_return_dict
122
+ )
123
+
124
+ # retrieve input_ids and inputs_embeds
125
+ if input_ids is not None and inputs_embeds is not None:
126
+ raise ValueError(
127
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
128
+ )
129
+ elif input_ids is not None:
130
+ batch_size, seq_length = input_ids.shape
131
+ elif inputs_embeds is not None:
132
+ batch_size, seq_length, _ = inputs_embeds.shape
133
+ else:
134
+ raise ValueError(
135
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
136
+ )
137
+
138
+ if self.gradient_checkpointing and self.training:
139
+ if use_cache:
140
+ logger.warning_once(
141
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
142
+ )
143
+ use_cache = False
144
+
145
+ past_key_values_length = 0
146
+
147
+ if use_cache:
148
+ use_legacy_cache = not isinstance(past_key_values, Cache)
149
+ if use_legacy_cache:
150
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
151
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
152
+
153
+ if position_ids is None:
154
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
155
+ position_ids = torch.arange(
156
+ past_key_values_length,
157
+ seq_length + past_key_values_length,
158
+ dtype=torch.long,
159
+ device=device,
160
+ )
161
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
162
+ else:
163
+ position_ids = position_ids.view(-1, seq_length).long()
164
+
165
+ if inputs_embeds is None:
166
+ inputs_embeds = self.embed_tokens(input_ids)
167
+
168
+ if (
169
+ attention_mask is not None
170
+ and self._attn_implementation == "flash_attention_2"
171
+ and use_cache
172
+ ):
173
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
174
+ if is_padding_right:
175
+ raise ValueError(
176
+ "You are attempting to perform batched generation with padding_side='right'"
177
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
178
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
179
+ )
180
+
181
+ if self._attn_implementation == "flash_attention_2":
182
+ # 2d mask is passed through the layers
183
+ attention_mask = (
184
+ attention_mask
185
+ if (attention_mask is not None and 0 in attention_mask)
186
+ else None
187
+ )
188
+ elif self._attn_implementation == "sdpa" and not output_attentions:
189
+ # The original implementation is by-passed, see attn_mask_utils.py
190
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
191
+ attention_mask,
192
+ (batch_size, seq_length),
193
+ inputs_embeds,
194
+ past_key_values_length,
195
+ )
196
+ else:
197
+ # 4d mask is passed through the layers
198
+ attention_mask = _prepare_4d_causal_attention_mask(
199
+ attention_mask,
200
+ (batch_size, seq_length),
201
+ inputs_embeds,
202
+ past_key_values_length,
203
+ sliding_window=self.config.sliding_window,
204
+ )
205
+
206
+ hidden_states = inputs_embeds
207
+
208
+ # decoder layers
209
+ all_hidden_states = () if output_hidden_states else None
210
+ all_self_attns = () if output_attentions else None
211
+ next_decoder_cache = None
212
+
213
+ for decoder_layer in self.layers:
214
+ if output_hidden_states:
215
+ all_hidden_states += (hidden_states,)
216
+
217
+ if self.gradient_checkpointing and self.training:
218
+ layer_outputs = self._gradient_checkpointing_func(
219
+ decoder_layer.__call__,
220
+ hidden_states,
221
+ attention_mask,
222
+ position_ids,
223
+ past_key_values,
224
+ output_attentions,
225
+ use_cache,
226
+ )
227
+ else:
228
+ layer_outputs = decoder_layer(
229
+ hidden_states,
230
+ attention_mask=attention_mask,
231
+ position_ids=position_ids,
232
+ past_key_value=past_key_values,
233
+ output_attentions=output_attentions,
234
+ use_cache=use_cache,
235
+ )
236
+
237
+ hidden_states = layer_outputs[0]
238
+
239
+ if use_cache:
240
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
241
+
242
+ if output_attentions:
243
+ all_self_attns += (layer_outputs[1],)
244
+
245
+ hidden_states = self.norm(hidden_states)
246
+
247
+ # add hidden states from the last decoder layer
248
+ if output_hidden_states:
249
+ all_hidden_states += (hidden_states,)
250
+
251
+ next_cache = None
252
+ if use_cache:
253
+ next_cache = (
254
+ next_decoder_cache.to_legacy_cache()
255
+ if use_legacy_cache
256
+ else next_decoder_cache
257
+ )
258
+
259
+ if not return_dict:
260
+ return tuple(
261
+ v
262
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
263
+ if v is not None
264
+ )
265
+ return BaseModelOutputWithPast(
266
+ last_hidden_state=hidden_states,
267
+ past_key_values=next_cache,
268
+ hidden_states=all_hidden_states,
269
+ attentions=all_self_attns,
270
+ )
271
+
272
+
273
+ class MistralBiForMNTP(MistralForCausalLM):
274
+ def __init__(self, config):
275
+ MistralPreTrainedModel.__init__(self, config)
276
+ self.model = MistralBiModel(config)
277
+ self.vocab_size = config.vocab_size
278
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
279
+
280
+ # Initialize weights and apply final processing
281
+ self.post_init()
config.json CHANGED
@@ -1,9 +1,12 @@
1
  {
2
- "_name_or_path": "mistral-64M-mlm/checkpoint-60000",
3
  "architectures": [
4
  "MistralBiForMNTP"
5
  ],
6
  "attention_dropout": 0.0,
 
 
 
7
  "bos_token_id": 1,
8
  "eos_token_id": 2,
9
  "hidden_act": "silu",
 
1
  {
2
+ "_name_or_path": "mistral-64M-mlm/checkpoint-64000",
3
  "architectures": [
4
  "MistralBiForMNTP"
5
  ],
6
  "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoModel": "bidirectional_mistral.MistralBiForMNTP"
9
+ },
10
  "bos_token_id": 1,
11
  "eos_token_id": 2,
12
  "hidden_act": "silu",
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7c46275cd03ee2867e64076edc9ad8f0b9698357eba8061aef34928a50be335d
3
  size 256944240
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9ab5196eb6b2b102a9cda8ab8b5244d95de18d6aa7f4fdf7f13aeabfffcd72d
3
  size 256944240