pere commited on
Commit
3a83068
·
1 Parent(s): 3f31643
README.md CHANGED
@@ -1,3 +1,27 @@
 
1
  ---
2
- license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
1
+
2
  ---
3
+ language:
4
+ - multilingual
5
+ - en
6
+ - fo
7
+ - is
8
+ - nn
9
+ - nb
10
+ - no
11
+ - da
12
+ - sv
13
+ license: cc-by-4.0
14
+ tags:
15
+ - norwegian
16
+ - bert
17
+ pipeline_tag: fill-mask
18
+ widget:
19
+ - text: På biblioteket kan du <mask> en bok.
20
+ - text: Dette er et <mask> eksempel.
21
+ - text: Av og til kan en språkmodell gi et <mask> resultat.
22
+ - text: Som ansat får du <mask> for at bidrage til borgernes adgang til dansk kulturarv, til forskning og til samfundets demokratiske udvikling.
23
  ---
24
+
25
+ # Scandinavian XLM-RoBERTa (base-sized model)
26
+
27
+ This model is currently being created. Do not use yet.
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./",
3
+ "architectures": [
4
+ "XLMRobertaForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "classifier_dropout": null,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 768,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 514,
17
+ "model_type": "xlm-roberta",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "output_past": true,
21
+ "pad_token_id": 1,
22
+ "position_embedding_type": "absolute",
23
+ "torch_dtype": "float32",
24
+ "transformers_version": "4.23.1",
25
+ "type_vocab_size": 1,
26
+ "use_cache": true,
27
+ "vocab_size": 250002
28
+ }
copy_tokenizer_config.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code used for copying tokenizer and config
2
+
3
+ from transformers import XLMRobertaTokenizerFast, XLMRobertaConfig
4
+
5
+ tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base")
6
+ config = XLMRobertaConfig.from_pretrained("xlm-roberta-base")
7
+
8
+ tokenizer.save_pretrained("./")
9
+ config.save_pretrained("./")
10
+
11
+
generate_pt_model.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from transformers import XLMRobertaForMaskedLM, XLMRobertaConfig
2
+ config = XLMRobertaConfig.from_pretrained("./")
3
+ model = XLMRobertaForMaskedLM.from_pretrained("./",config=config,from_flax=True)
4
+ model.save_pretrained("./")
5
+
6
+
run.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python run_mlm_flax_stream.py \
2
+ --output_dir="../roberta-base-exp-32" \
3
+ --model_name_or_path="xlm-roberta-base" \
4
+ --config_name="./" \
5
+ --tokenizer_name="./" \
6
+ --dataset_name="NbAiLab/scandinavian" \
7
+ --max_seq_length="512" \
8
+ --weight_decay="0.01" \
9
+ --per_device_train_batch_size="62" \
10
+ --per_device_eval_batch_size="62" \
11
+ --learning_rate="1e-4" \
12
+ --warmup_steps="10000" \
13
+ --overwrite_output_dir \
14
+ --num_train_steps="1000000" \
15
+ --adam_beta1="0.9" \
16
+ --adam_beta2="0.98" \
17
+ --logging_steps="5000" \
18
+ --save_steps="25000" \
19
+ --eval_steps="25000" \
20
+ --dtype="bfloat16" \
21
+ --push_to_hub
run_mlm_flax_stream.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=fill-mask
22
+ """
23
+ import logging
24
+ import os
25
+ import sys
26
+ import time
27
+ from collections import defaultdict
28
+ from dataclasses import dataclass, field
29
+
30
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
31
+ from pathlib import Path
32
+ from typing import Dict, List, Optional, Tuple
33
+
34
+ import datasets
35
+ import numpy as np
36
+ from datasets import load_dataset
37
+ from tqdm import tqdm
38
+
39
+ import flax
40
+ import jax
41
+ import jax.numpy as jnp
42
+ import optax
43
+ from flax import jax_utils, traverse_util
44
+ from flax.training import train_state
45
+ from flax.training.common_utils import get_metrics, onehot, shard
46
+ from transformers import (
47
+ CONFIG_MAPPING,
48
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
49
+ AutoConfig,
50
+ AutoTokenizer,
51
+ FlaxAutoModelForMaskedLM,
52
+ HfArgumentParser,
53
+ PreTrainedTokenizerBase,
54
+ TensorType,
55
+ TrainingArguments,
56
+ is_tensorboard_available,
57
+ set_seed,
58
+ )
59
+
60
+ #from jax_smi import initialise_tracking
61
+ #initialise_tracking()
62
+
63
+
64
+ if datasets.__version__ <= "1.8.0":
65
+ raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
66
+
67
+
68
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
69
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
70
+
71
+
72
+ @dataclass
73
+ class ModelArguments:
74
+ """
75
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
76
+ """
77
+
78
+ model_name_or_path: Optional[str] = field(
79
+ default=None,
80
+ metadata={
81
+ "help": (
82
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
83
+ )
84
+ },
85
+ )
86
+ model_type: Optional[str] = field(
87
+ default=None,
88
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
89
+ )
90
+ config_name: Optional[str] = field(
91
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
92
+ )
93
+ tokenizer_name: Optional[str] = field(
94
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
95
+ )
96
+ cache_dir: Optional[str] = field(
97
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
98
+ )
99
+ use_fast_tokenizer: bool = field(
100
+ default=True,
101
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
102
+ )
103
+ dtype: Optional[str] = field(
104
+ default="float32",
105
+ metadata={
106
+ "help": (
107
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
108
+ " `[float32, float16, bfloat16]`."
109
+ )
110
+ },
111
+ )
112
+
113
+
114
+ @dataclass
115
+ class DataTrainingArguments:
116
+ """
117
+ Arguments pertaining to what data we are going to input our model for training and eval.
118
+ """
119
+
120
+ dataset_name: Optional[str] = field(
121
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
122
+ )
123
+ dataset_config_name: Optional[str] = field(
124
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
125
+ )
126
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
127
+ validation_file: Optional[str] = field(
128
+ default=None,
129
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
130
+ )
131
+ train_ref_file: Optional[str] = field(
132
+ default=None,
133
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
134
+ )
135
+ validation_ref_file: Optional[str] = field(
136
+ default=None,
137
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
138
+ )
139
+ overwrite_cache: bool = field(
140
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
141
+ )
142
+ validation_split_percentage: Optional[int] = field(
143
+ default=5,
144
+ metadata={
145
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
146
+ },
147
+ )
148
+ max_seq_length: Optional[int] = field(
149
+ default=None,
150
+ metadata={
151
+ "help": (
152
+ "The maximum total input sequence length after tokenization. Sequences longer "
153
+ "than this will be truncated. Default to the max input length of the model."
154
+ )
155
+ },
156
+ )
157
+ preprocessing_num_workers: Optional[int] = field(
158
+ default=None,
159
+ metadata={"help": "The number of processes to use for the preprocessing."},
160
+ )
161
+ mlm_probability: float = field(
162
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
163
+ )
164
+ pad_to_max_length: bool = field(
165
+ default=False,
166
+ metadata={
167
+ "help": (
168
+ "Whether to pad all samples to `max_seq_length`. "
169
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
170
+ )
171
+ },
172
+ )
173
+ line_by_line: bool = field(
174
+ default=False,
175
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
176
+ )
177
+ text_column_name: str = field(
178
+ default="text", metadata={"help": "The name of the column to retrieve the training text."}
179
+ )
180
+ shuffle_buffer_size: int = field(
181
+ default=10000, metadata={"help": "The number of examples to pre-load for shuffling."}
182
+ )
183
+ num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
184
+ num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
185
+
186
+ def __post_init__(self):
187
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
188
+ raise ValueError("Need either a dataset name or a training/validation file.")
189
+ else:
190
+ if self.train_file is not None:
191
+ extension = self.train_file.split(".")[-1]
192
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
193
+ if self.validation_file is not None:
194
+ extension = self.validation_file.split(".")[-1]
195
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
196
+
197
+
198
+ @flax.struct.dataclass
199
+ class FlaxDataCollatorForLanguageModeling:
200
+ """
201
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
202
+ are not all of the same length.
203
+
204
+ Args:
205
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
206
+ The tokenizer used for encoding the data.
207
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
208
+ The probability with which to (randomly) mask tokens in the input.
209
+
210
+ .. note::
211
+
212
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
213
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
214
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
215
+ argument :obj:`return_special_tokens_mask=True`.
216
+ """
217
+
218
+ tokenizer: PreTrainedTokenizerBase
219
+ mlm_probability: float = 0.15
220
+
221
+ def __post_init__(self):
222
+ if self.tokenizer.mask_token is None:
223
+ raise ValueError(
224
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
225
+ "You should pass `mlm=False` to train on causal language modeling instead."
226
+ )
227
+
228
+ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
229
+ # Handle dict or lists with proper padding and conversion to tensor.
230
+ batch = self.tokenizer.pad(examples, return_tensors=TensorType.NUMPY)
231
+
232
+ # If special token mask has been preprocessed, pop it from the dict.
233
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
234
+
235
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
236
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
237
+ )
238
+ return batch
239
+
240
+ def mask_tokens(
241
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
242
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
243
+ """
244
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
245
+ """
246
+ labels = inputs.copy()
247
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
248
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
249
+ special_tokens_mask = special_tokens_mask.astype("bool")
250
+
251
+ probability_matrix[special_tokens_mask] = 0.0
252
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
253
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
254
+
255
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
256
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
257
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
258
+
259
+ # 10% of the time, we replace masked input tokens with random word
260
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
261
+ indices_random &= masked_indices & ~indices_replaced
262
+
263
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
264
+ inputs[indices_random] = random_words[indices_random]
265
+
266
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
267
+ return inputs, labels
268
+
269
+
270
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
271
+ num_samples = len(samples_idx)
272
+ samples_to_remove = num_samples % batch_size
273
+
274
+ if samples_to_remove != 0:
275
+ samples_idx = samples_idx[:-samples_to_remove]
276
+ sections_split = num_samples // batch_size
277
+ batch_idx = np.split(samples_idx, sections_split)
278
+ return batch_idx
279
+
280
+
281
+ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
282
+ """
283
+ The training iterator is advanced so that after groupifying the samples,
284
+ `num_samples` of length `max_seq_length` are returned.
285
+ """
286
+ num_total_tokens = max_seq_length * num_samples
287
+ samples = defaultdict(list)
288
+
289
+ i = 0
290
+ while i < num_total_tokens:
291
+ tokenized_samples = next(train_iterator)
292
+ i += len(tokenized_samples["input_ids"])
293
+
294
+ # concatenate tokenized samples to list (excluding "id" and "text")
295
+ samples = {
296
+ k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
297
+ }
298
+
299
+ # Concatenated tokens are split to lists of length `max_seq_length`.
300
+ # Note that remainedr of % max_seq_length are thrown away.
301
+ def group_texts(examples):
302
+ result = {
303
+ k: [t[i : i + max_seq_length] for i in range(0, num_total_tokens, max_seq_length)]
304
+ for k, t in examples.items()
305
+ }
306
+ return result
307
+
308
+ grouped_samples = group_texts(samples)
309
+ return grouped_samples
310
+
311
+
312
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
313
+ summary_writer.scalar("train_time", train_time, step)
314
+
315
+ train_metrics = get_metrics(train_metrics)
316
+ for key, vals in train_metrics.items():
317
+ tag = f"train_{key}"
318
+ for i, val in enumerate(vals):
319
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
320
+
321
+
322
+ def write_eval_metric(summary_writer, eval_metrics, step):
323
+ for metric_name, value in eval_metrics.items():
324
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
325
+
326
+
327
+ if __name__ == "__main__":
328
+ # See all possible arguments in src/transformers/training_args.py
329
+ # or by passing the --help flag to this script.
330
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
331
+
332
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
333
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
334
+ # If we pass only one argument to the script and it's the path to a json file,
335
+ # let's parse it to get our arguments.
336
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
337
+ else:
338
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
339
+
340
+ if (
341
+ os.path.exists(training_args.output_dir)
342
+ and os.listdir(training_args.output_dir)
343
+ and training_args.do_train
344
+ and not training_args.overwrite_output_dir
345
+ ):
346
+ raise ValueError(
347
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
348
+ "Use --overwrite_output_dir to overcome."
349
+ )
350
+
351
+ # Setup logging
352
+ logging.basicConfig(
353
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
354
+ level="INFO",
355
+ datefmt="[%X]",
356
+ )
357
+
358
+ # Log on each process the small summary:
359
+ logger = logging.getLogger(__name__)
360
+ logger.warning(
361
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
362
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
363
+ )
364
+
365
+ # Set the verbosity to info of the Transformers logger (on main process only):
366
+ logger.info(f"Training/evaluation parameters {training_args}")
367
+
368
+ # Set seed before initializing model.
369
+ set_seed(training_args.seed)
370
+
371
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
372
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
373
+ # (the dataset will be downloaded automatically from the datasets Hub).
374
+ #
375
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
376
+ # 'text' is found. You can easily tweak this behavior (see below).
377
+ if data_args.dataset_name is not None:
378
+ # Downloading and loading a dataset from the hub.
379
+ dataset = load_dataset(
380
+ data_args.dataset_name,
381
+ data_args.dataset_config_name,
382
+ cache_dir=model_args.cache_dir,
383
+ streaming=True,
384
+ use_auth_token=True,
385
+ split="train",
386
+ )
387
+
388
+ if model_args.config_name:
389
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
390
+ elif model_args.model_name_or_path:
391
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
392
+ else:
393
+ config = CONFIG_MAPPING[model_args.model_type]()
394
+ logger.warning("You are instantiating a new config instance from scratch.")
395
+
396
+ if model_args.tokenizer_name:
397
+ tokenizer = AutoTokenizer.from_pretrained(
398
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
399
+ )
400
+ elif model_args.model_name_or_path:
401
+ tokenizer = AutoTokenizer.from_pretrained(
402
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
403
+ )
404
+ else:
405
+ raise ValueError(
406
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
407
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
408
+ )
409
+
410
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
411
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
412
+ # efficient when it receives the `special_tokens_mask`.
413
+ def tokenize_function(examples):
414
+ return tokenizer(examples[data_args.text_column_name], return_special_tokens_mask=True, truncation=True)
415
+
416
+ tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=list(dataset.features.keys()))
417
+
418
+ shuffle_seed = training_args.seed
419
+ tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
420
+
421
+ has_tensorboard = is_tensorboard_available()
422
+ if has_tensorboard and jax.process_index() == 0:
423
+ try:
424
+ from flax.metrics.tensorboard import SummaryWriter
425
+ except ImportError as ie:
426
+ has_tensorboard = False
427
+ logger.warning(
428
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
429
+ )
430
+
431
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
432
+
433
+ # Data collator
434
+ # This one will take care of randomly masking the tokens.
435
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
436
+
437
+ # Initialize our training
438
+ rng = jax.random.PRNGKey(training_args.seed)
439
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
440
+
441
+ if model_args.model_name_or_path:
442
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
443
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
444
+ )
445
+ else:
446
+ model = FlaxAutoModelForMaskedLM.from_config(
447
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
448
+ )
449
+
450
+ # Store some constant
451
+ num_epochs = int(training_args.num_train_epochs)
452
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
453
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
454
+
455
+ # define number steps per stream epoch
456
+ num_train_steps = data_args.num_train_steps
457
+
458
+ # Create learning rate schedule
459
+ warmup_fn = optax.linear_schedule(
460
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
461
+ )
462
+ decay_fn = optax.linear_schedule(
463
+ init_value=training_args.learning_rate,
464
+ end_value=0,
465
+ transition_steps=num_train_steps - training_args.warmup_steps,
466
+ )
467
+ linear_decay_lr_schedule_fn = optax.join_schedules(
468
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
469
+ )
470
+
471
+ # We use Optax's "masking" functionality to not apply weight decay
472
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
473
+ # mask boolean with the same structure as the parameters.
474
+ # The mask is True for parameters that should be decayed.
475
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
476
+ # For other models, one should correct the layer norm parameter naming
477
+ # accordingly.
478
+ def decay_mask_fn(params):
479
+ flat_params = traverse_util.flatten_dict(params)
480
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
481
+ return traverse_util.unflatten_dict(flat_mask)
482
+
483
+ # create adam optimizer
484
+ adamw = optax.adamw(
485
+ learning_rate=linear_decay_lr_schedule_fn,
486
+ b1=training_args.adam_beta1,
487
+ b2=training_args.adam_beta2,
488
+ eps=training_args.adam_epsilon,
489
+ weight_decay=training_args.weight_decay,
490
+ mask=decay_mask_fn,
491
+ )
492
+
493
+ # Setup train state
494
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
495
+
496
+ # Define gradient update step fn
497
+ def train_step(state, batch, dropout_rng):
498
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
499
+
500
+ def loss_fn(params):
501
+ labels = batch.pop("labels")
502
+
503
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
504
+
505
+ # compute loss, ignore padded input tokens
506
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
507
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
508
+
509
+ # take average
510
+ loss = loss.sum() / label_mask.sum()
511
+
512
+ return loss
513
+
514
+ grad_fn = jax.value_and_grad(loss_fn)
515
+ loss, grad = grad_fn(state.params)
516
+ grad = jax.lax.pmean(grad, "batch")
517
+ new_state = state.apply_gradients(grads=grad)
518
+
519
+ metrics = jax.lax.pmean(
520
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
521
+ )
522
+
523
+ return new_state, metrics, new_dropout_rng
524
+
525
+ # Create parallel version of the train step
526
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
527
+
528
+ # Define eval fn
529
+ def eval_step(params, batch):
530
+ labels = batch.pop("labels")
531
+
532
+ logits = model(**batch, params=params, train=False)[0]
533
+
534
+ # compute loss, ignore padded input tokens
535
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
536
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
537
+
538
+ # compute accuracy
539
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
540
+
541
+ # summarize metrics
542
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
543
+ metrics = jax.lax.psum(metrics, axis_name="batch")
544
+
545
+ return metrics
546
+
547
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
548
+
549
+ # Replicate the train state on each device
550
+ state = jax_utils.replicate(state)
551
+
552
+ train_time = 0
553
+ train_start = time.time()
554
+ train_metrics = []
555
+ eval_metrics = []
556
+
557
+ training_iter = iter(tokenized_datasets)
558
+
559
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
560
+ eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
561
+
562
+ steps = tqdm(range(num_train_steps), desc="Training...", position=0)
563
+ for step in range(num_train_steps):
564
+ # ======================== Training ================================
565
+ try:
566
+ samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
567
+ except StopIteration:
568
+ # Once the end of the dataset stream is reached, the training iterator
569
+ # is reinitialized and reshuffled and a new eval dataset is randomely chosen.
570
+ shuffle_seed += 1
571
+ tokenized_datasets.set_epoch(shuffle_seed)
572
+
573
+ training_iter = iter(tokenized_datasets)
574
+
575
+ eval_dataset = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
576
+ samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
577
+
578
+ # process input samples
579
+ model_inputs = data_collator(samples)
580
+
581
+ # Model forward
582
+ model_inputs = shard(model_inputs.data)
583
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
584
+
585
+ train_metrics.append(train_metric)
586
+
587
+ if step % training_args.logging_steps == 0 and step > 0:
588
+ steps.write(
589
+ f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
590
+ f" {train_metric['learning_rate'].mean()})"
591
+ )
592
+ train_time += time.time() - train_start
593
+ if has_tensorboard and jax.process_index() == 0:
594
+ write_train_metric(summary_writer, train_metrics, train_time, step)
595
+ train_metrics = []
596
+
597
+ # ======================== Evaluating ==============================
598
+ if step % training_args.eval_steps == 0 and step > 0:
599
+ # Avoid using jax.numpy here in case of TPU training
600
+ eval_samples_idx = np.arange(data_args.num_eval_samples)
601
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
602
+
603
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
604
+ # process input samples
605
+ batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
606
+ model_inputs = data_collator(batch_eval_samples)
607
+
608
+ # Model forward
609
+ model_inputs = shard(model_inputs.data)
610
+ metrics = p_eval_step(state.params, model_inputs)
611
+ eval_metrics.append(metrics)
612
+
613
+ # normalize eval metrics
614
+ eval_metrics = get_metrics(eval_metrics)
615
+ eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
616
+ eval_normalizer = eval_metrics.pop("normalizer")
617
+ eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
618
+
619
+ # Update progress bar
620
+ steps.desc = (
621
+ f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc:"
622
+ f" {eval_metrics['accuracy']})"
623
+ )
624
+
625
+ if has_tensorboard and jax.process_index() == 0:
626
+ write_eval_metric(summary_writer, eval_metrics, step)
627
+ eval_metrics = []
628
+
629
+ # save checkpoint after each epoch and push checkpoint to the hub
630
+ if jax.process_index() == 0:
631
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
632
+ print("*** Printing debug info")
633
+ print(training_args.output_dir)
634
+ print(training_args.push_to_hub)
635
+ try:
636
+ model.save_pretrained(
637
+ training_args.output_dir,
638
+ params=params,
639
+ push_to_hub=training_args.push_to_hub,
640
+ commit_message=f"Saving weights and logs of step {step+1}",
641
+ )
642
+ except:
643
+
644
+ model.save_pretrained(
645
+ training_args.output_dir,
646
+ params=params
647
+ )
648
+ print("Problems pushing this to the hub. The bug should be fixed.")
649
+
650
+ # update tqdm bar
651
+ steps.update(1)
sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
3
+ size 5069051
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff