winglian commited on
Commit
6c81c61
·
unverified ·
1 Parent(s): 9b43e7e

refactor setup trainer so we can add more hooks (#773)

Browse files

* refactor setup trainer so we can add more hooks

* Remove stray comma

src/axolotl/core/__init__.py ADDED
File without changes
src/axolotl/core/trainer_builder.py ADDED
@@ -0,0 +1,689 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Builder for the training args and trainer
3
+ """
4
+
5
+ import abc
6
+ import importlib
7
+ import logging
8
+ import math
9
+ import os
10
+ import sys
11
+ from abc import abstractmethod
12
+ from dataclasses import dataclass, field
13
+ from functools import partial
14
+ from pathlib import Path
15
+ from typing import Optional, Union
16
+
17
+ import torch
18
+ import transformers
19
+ from datasets import Dataset
20
+ from torch.optim.lr_scheduler import OneCycleLR
21
+ from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
22
+ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
23
+ from transformers.trainer_pt_utils import SequentialDistributedSampler
24
+
25
+ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
26
+ from axolotl.utils.callbacks import (
27
+ EvalFirstStepCallback,
28
+ GPUStatsCallback,
29
+ SaveAxolotlConfigtoWandBCallback,
30
+ SaveBetterTransformerModelCallback,
31
+ bench_eval_callback_factory,
32
+ log_prediction_callback_factory,
33
+ )
34
+ from axolotl.utils.collators import DataCollatorForSeq2Seq
35
+ from axolotl.utils.dataloader import MultipackDistributedDataloader
36
+ from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
37
+
38
+ try:
39
+ import torch._dynamo # pylint: disable=ungrouped-imports
40
+ except ImportError:
41
+ pass
42
+
43
+ LOG = logging.getLogger("axolotl.core.trainer_builder")
44
+
45
+
46
+ @dataclass
47
+ class AxolotlTrainingArguments(TrainingArguments):
48
+ """
49
+ Extend the base TrainingArguments for axolotl helpers
50
+ """
51
+
52
+ lr_quadratic_warmup: bool = field(
53
+ default=False,
54
+ metadata={"help": "Use quadratic warmup for cosine scheduling."},
55
+ )
56
+ sample_packing: bool = field(
57
+ default=False,
58
+ metadata={"help": "Use sample packing for efficient training."},
59
+ )
60
+ eval_sample_packing: Optional[bool] = field(
61
+ default=None,
62
+ metadata={"help": "Use sample packing for efficient evals."},
63
+ )
64
+ sample_packing_efficiency: float = field(
65
+ default=1.0,
66
+ metadata={"help": "Sample packing efficiency for calculating batch length."},
67
+ )
68
+ max_seq_length: int = field(
69
+ default=2048,
70
+ metadata={"help": "The maximum sequence length the model can handle"},
71
+ )
72
+ sample_packing_seq_len_multiplier: int = field(
73
+ default=1,
74
+ metadata={"help": "the multiplier for the max len for packed sequences"},
75
+ )
76
+ relora_steps: Optional[int] = field(
77
+ default=None,
78
+ metadata={"help": "how often to reset for ReLoRA"},
79
+ )
80
+ relora_warmup_steps: Optional[int] = field(
81
+ default=None,
82
+ metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
83
+ )
84
+ bench_split: Optional[str] = field(
85
+ default="eval", metadata={"help": "The benchmark split to run on"}
86
+ )
87
+ bench_dataset: Optional[str] = field(
88
+ default="pharaouk/dharma-1/dharma_1_mini.json",
89
+ metadata={
90
+ "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
91
+ },
92
+ )
93
+ do_bench_eval: Optional[bool] = field(
94
+ default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
95
+ )
96
+ max_bench_samples: Optional[int] = field(
97
+ default=None,
98
+ metadata={
99
+ "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
100
+ },
101
+ )
102
+ bench_source_max_len: int = field(
103
+ default=2048, metadata={"help": "Maximum source sequence length for bench."}
104
+ )
105
+
106
+
107
+ class AxolotlTrainer(Trainer):
108
+ """
109
+ Extend the base Trainer for axolotl helpers
110
+ """
111
+
112
+ args = None # type: AxolotlTrainingArguments
113
+
114
+ def __init__(self, *args, bench_data_collator=None, **kwargs):
115
+ self.bench_data_collator = bench_data_collator
116
+ super().__init__(*args, **kwargs)
117
+
118
+ def create_scheduler(
119
+ self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
120
+ ):
121
+ """
122
+ Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
123
+ passed as an argument.
124
+
125
+ Args:
126
+ num_training_steps (int): The number of training steps to do.
127
+ optimizer (torch.optim.Optimizer): The training optimizer
128
+ """
129
+
130
+ # fmt: off
131
+ if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
132
+ # fmt: on
133
+ if (
134
+ self.args.lr_scheduler_type == "cosine"
135
+ and self.args.lr_quadratic_warmup is True
136
+ ):
137
+ self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
138
+ optimizer,
139
+ num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
140
+ num_training_steps=num_training_steps,
141
+ )
142
+ else:
143
+ return super().create_scheduler(num_training_steps, optimizer)
144
+ return self.lr_scheduler
145
+
146
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
147
+ if self.args.world_size > 1 and self.args.sample_packing:
148
+ return DistributedSampler(
149
+ self.train_dataset,
150
+ num_replicas=self.args.world_size,
151
+ rank=self.args.process_index,
152
+ seed=self.args.seed,
153
+ )
154
+ return super()._get_train_sampler()
155
+
156
+ def _get_eval_sampler(
157
+ self, eval_dataset: Dataset
158
+ ) -> Optional[torch.utils.data.Sampler]:
159
+ if (
160
+ self.args.world_size > 1
161
+ and self.args.sample_packing
162
+ and self.args.eval_sample_packing is not False
163
+ ):
164
+ return SequentialDistributedSampler(
165
+ eval_dataset,
166
+ num_replicas=self.args.world_size,
167
+ rank=self.args.process_index,
168
+ batch_size=self.args.per_device_eval_batch_size,
169
+ )
170
+ return super()._get_eval_sampler(eval_dataset)
171
+
172
+ def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
173
+ if self.args.sample_packing:
174
+ train_sampler = self._get_train_sampler()
175
+ return self.accelerator.prepare(
176
+ MultipackDistributedDataloader(
177
+ self.train_dataset,
178
+ batch_size=self._train_batch_size,
179
+ seq_max_length=self.args.max_seq_length,
180
+ collate_fn=self.data_collator,
181
+ sampler=train_sampler,
182
+ packing_efficiency_estimate=self.args.sample_packing_efficiency,
183
+ sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
184
+ device_count=int(os.environ.get("WORLD_SIZE", 1)),
185
+ )
186
+ )
187
+ return super().get_train_dataloader()
188
+
189
+ def get_eval_dataloader(
190
+ self, eval_dataset: Optional[Dataset] = None
191
+ ) -> Union[DataLoader, MultipackDistributedDataloader]:
192
+ if self.args.sample_packing and self.args.eval_sample_packing is not False:
193
+ eval_dataset = (
194
+ eval_dataset if eval_dataset is not None else self.eval_dataset
195
+ )
196
+
197
+ eval_sampler = self._get_eval_sampler(eval_dataset)
198
+ return self.accelerator.prepare(
199
+ MultipackDistributedDataloader(
200
+ eval_dataset,
201
+ batch_size=self.args.eval_batch_size,
202
+ seq_max_length=self.args.max_seq_length,
203
+ collate_fn=self.data_collator,
204
+ sampler=eval_sampler,
205
+ packing_efficiency_estimate=self.args.sample_packing_efficiency,
206
+ sample_packing_seq_len_multiplier=self.args.eval_batch_size,
207
+ device_count=int(os.environ.get("WORLD_SIZE", 1)),
208
+ )
209
+ )
210
+ return super().get_eval_dataloader(eval_dataset)
211
+
212
+ def _get_bench_sampler(
213
+ self, bench_dataset: Dataset
214
+ ) -> Optional[torch.utils.data.Sampler]:
215
+ if self.args.world_size <= 1:
216
+ return SequentialSampler(bench_dataset)
217
+ return None
218
+
219
+ def get_bench_dataloader(
220
+ self,
221
+ bench_dataset: Dataset,
222
+ ) -> Union[DataLoader, MultipackDistributedDataloader]:
223
+ dataloader_params = {
224
+ "batch_size": self.args.eval_batch_size,
225
+ "collate_fn": self.bench_data_collator,
226
+ "num_workers": self.args.dataloader_num_workers,
227
+ "pin_memory": self.args.dataloader_pin_memory,
228
+ }
229
+
230
+ if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
231
+ dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
232
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
233
+
234
+ return DataLoader(bench_dataset, **dataloader_params)
235
+ # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
236
+
237
+ def compute_loss(self, model, inputs, return_outputs=False):
238
+ # use one's weighted cross entropy loss calc
239
+ # if self.args.sample_packing:
240
+ # labels = inputs.pop("labels")
241
+ # outputs = model(**inputs)
242
+ # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
243
+ # return (loss, outputs) if return_outputs else loss
244
+ return super().compute_loss(model, inputs, return_outputs=return_outputs)
245
+
246
+
247
+ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
248
+ """
249
+ Trainer subclass that uses the OneCycleLR scheduler
250
+ """
251
+
252
+ def __init__(self, *args, **kwargs):
253
+ super().__init__(*args, **kwargs)
254
+ self.lr_scheduler = None
255
+
256
+ def create_scheduler(
257
+ self,
258
+ num_training_steps: int,
259
+ optimizer: Optional[torch.optim.Optimizer] = None,
260
+ ):
261
+ optimizer = self.optimizer if optimizer is None else optimizer
262
+ num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
263
+ pct_start = num_warmup_steps / num_training_steps
264
+
265
+ self.lr_scheduler = OneCycleLR(
266
+ optimizer,
267
+ max_lr=self.args.learning_rate,
268
+ total_steps=num_training_steps,
269
+ pct_start=pct_start,
270
+ div_factor=6,
271
+ )
272
+
273
+ return self.lr_scheduler
274
+
275
+
276
+ class ReLoRATrainer(AxolotlTrainer):
277
+ """
278
+ Trainer subclass that uses the OneCycleLR scheduler
279
+ """
280
+
281
+ def __init__(self, *args, **kwargs):
282
+ super().__init__(*args, **kwargs)
283
+ self.lr_scheduler = None
284
+
285
+ def create_scheduler(
286
+ self,
287
+ num_training_steps: int,
288
+ optimizer: Optional[torch.optim.Optimizer] = None,
289
+ ):
290
+ optimizer = self.optimizer if optimizer is None else optimizer
291
+ lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
292
+
293
+ if self.args.relora_steps:
294
+ warmup_steps = (
295
+ self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
296
+ )
297
+ self.lr_scheduler = ReLoRAScheduler(
298
+ optimizer,
299
+ lr_scheduler,
300
+ self.args.relora_steps,
301
+ warmup_steps,
302
+ )
303
+ else:
304
+ self.lr_scheduler = lr_scheduler
305
+
306
+ return self.lr_scheduler
307
+
308
+
309
+ class TrainerBuilderBase(abc.ABC):
310
+ """
311
+ Base class for trainer builder
312
+ """
313
+
314
+ _train_dataset = None
315
+ _eval_dataset = None
316
+
317
+ def __init__(self, cfg, model, tokenizer):
318
+ self.cfg = cfg
319
+ self.model = model
320
+ self.tokenizer = tokenizer
321
+
322
+ @property
323
+ def train_dataset(self):
324
+ return self._train_dataset
325
+
326
+ @train_dataset.setter
327
+ def train_dataset(self, dataset):
328
+ self._train_dataset = dataset
329
+
330
+ @property
331
+ def eval_dataset(self):
332
+ return self._eval_dataset
333
+
334
+ @eval_dataset.setter
335
+ def eval_dataset(self, dataset):
336
+ self._eval_dataset = dataset
337
+
338
+ @abstractmethod
339
+ def build(self, total_num_steps):
340
+ pass
341
+
342
+ @abstractmethod
343
+ def get_callbacks(self):
344
+ pass
345
+
346
+ @abstractmethod
347
+ def get_post_trainer_create_callbacks(self, trainer):
348
+ """
349
+ Callbacks added after the trainer is created, usually b/c these need access to the trainer
350
+ """
351
+
352
+
353
+ class HFCausalTrainerBuilder(TrainerBuilderBase):
354
+ """
355
+ Build the HuggingFace training args/trainer for Causal models
356
+ """
357
+
358
+ def hook_pre_create_training_args(self, training_arguments_kwargs):
359
+ # TODO
360
+ return training_arguments_kwargs
361
+
362
+ def hook_post_create_training_args(self, training_arguments):
363
+ # TODO
364
+ return training_arguments
365
+
366
+ def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
367
+ # TODO
368
+ return trainer_kwargs, trainer_cls
369
+
370
+ def hook_post_create_trainer(self, trainer):
371
+ # TODO
372
+ return trainer
373
+
374
+ def get_callbacks(self):
375
+ callbacks = []
376
+ callbacks.append(GPUStatsCallback(self.cfg))
377
+ callbacks.append(EvalFirstStepCallback)
378
+
379
+ if self.cfg.relora_steps:
380
+ callbacks.append(ReLoRACallback(self.cfg))
381
+
382
+ if (
383
+ hasattr(self.model, "use_bettertransformer")
384
+ and self.model.use_bettertransformer is True
385
+ ):
386
+ callbacks.append(SaveBetterTransformerModelCallback)
387
+
388
+ if self.cfg.use_wandb:
389
+ callbacks.append(
390
+ SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
391
+ )
392
+
393
+ return callbacks
394
+
395
+ def get_post_trainer_create_callbacks(self, trainer):
396
+ callbacks = []
397
+ if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
398
+ LogPredictionCallback = log_prediction_callback_factory(
399
+ trainer, self.tokenizer
400
+ )
401
+ callbacks.append(LogPredictionCallback(self.cfg))
402
+
403
+ if self.cfg.do_bench_eval:
404
+ callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
405
+
406
+ if self.cfg.early_stopping_patience:
407
+ early_stop_cb = EarlyStoppingCallback(
408
+ self.cfg.early_stopping_patience,
409
+ )
410
+ callbacks.append(early_stop_cb)
411
+
412
+ return callbacks
413
+
414
+ def _get_trainer_cls(self):
415
+ if self.cfg.lr_scheduler == "one_cycle" and (
416
+ self.cfg.fsdp or self.cfg.adapter == "qlora"
417
+ ):
418
+ return OneCycleLRSchedulerTrainer
419
+ if self.cfg.relora_steps:
420
+ return ReLoRATrainer
421
+ return AxolotlTrainer
422
+
423
+ def build(self, total_num_steps):
424
+ warmup_steps = (
425
+ self.cfg.warmup_steps
426
+ if self.cfg.warmup_steps is not None
427
+ else min(int(0.03 * total_num_steps), 100)
428
+ )
429
+ logging_steps = (
430
+ self.cfg.logging_steps
431
+ if self.cfg.logging_steps is not None
432
+ else max(min(int(0.005 * total_num_steps), 10), 1)
433
+ )
434
+
435
+ training_arguments_kwargs = {}
436
+ if self.cfg.bf16 == "full":
437
+ training_arguments_kwargs["bf16_full_eval"] = True
438
+ else:
439
+ training_arguments_kwargs["bf16"] = self.cfg.bf16
440
+ training_arguments_kwargs["fp16"] = (
441
+ self.cfg.fp16 and not self.cfg.bf16
442
+ ) or False
443
+ training_arguments_kwargs["tf32"] = self.cfg.tf32
444
+ training_arguments_kwargs["warmup_steps"] = warmup_steps
445
+ training_arguments_kwargs["logging_steps"] = logging_steps
446
+
447
+ if self.cfg.seed:
448
+ training_arguments_kwargs["seed"] = self.cfg.seed
449
+
450
+ if self.cfg.gradient_checkpointing:
451
+ training_arguments_kwargs[
452
+ "gradient_checkpointing"
453
+ ] = self.cfg.gradient_checkpointing
454
+ if self.cfg.fsdp:
455
+ training_arguments_kwargs["fsdp"] = self.cfg.fsdp
456
+ if self.cfg.fsdp_config:
457
+ training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
458
+
459
+ # deepspeed
460
+ if self.cfg.deepspeed:
461
+ training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
462
+
463
+ if self.cfg.lr_quadratic_warmup is not None:
464
+ training_arguments_kwargs[
465
+ "lr_quadratic_warmup"
466
+ ] = self.cfg.lr_quadratic_warmup
467
+
468
+ if self.cfg.adam_beta1:
469
+ training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
470
+ if self.cfg.adam_beta2:
471
+ training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
472
+ if self.cfg.adam_epsilon:
473
+ training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
474
+ if self.cfg.max_grad_norm:
475
+ training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
476
+
477
+ if self.cfg.hub_model_id:
478
+ training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
479
+ training_arguments_kwargs["push_to_hub"] = True
480
+ training_arguments_kwargs["hub_private_repo"] = True
481
+
482
+ if self.cfg.hub_strategy:
483
+ training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
484
+
485
+ if self.cfg.save_safetensors:
486
+ training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
487
+
488
+ if self.cfg.sample_packing_eff_est:
489
+ training_arguments_kwargs[
490
+ "sample_packing_efficiency"
491
+ ] = self.cfg.sample_packing_eff_est
492
+
493
+ if self.cfg.eval_steps:
494
+ training_arguments_kwargs["evaluation_strategy"] = "steps"
495
+ training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
496
+ elif self.cfg.evaluation_strategy:
497
+ training_arguments_kwargs[
498
+ "evaluation_strategy"
499
+ ] = self.cfg.evaluation_strategy
500
+ elif self.cfg.val_set_size == 0:
501
+ # no eval set, so don't eval
502
+ training_arguments_kwargs["evaluation_strategy"] = "no"
503
+ else:
504
+ # we have an eval set, but no steps defined, default to use epoch
505
+ training_arguments_kwargs["evaluation_strategy"] = "epoch"
506
+
507
+ if self.cfg.save_steps:
508
+ training_arguments_kwargs["save_strategy"] = "steps"
509
+ training_arguments_kwargs["save_steps"] = self.cfg.save_steps
510
+ elif self.cfg.save_strategy:
511
+ training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
512
+ else:
513
+ # default to saving each epoch if not defined
514
+ training_arguments_kwargs["save_strategy"] = "epoch"
515
+
516
+ if self.cfg.do_bench_eval:
517
+ training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
518
+ if self.cfg.bench_dataset:
519
+ training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
520
+ if self.cfg.metric_for_best_model:
521
+ training_arguments_kwargs[
522
+ "metric_for_best_model"
523
+ ] = self.cfg.metric_for_best_model
524
+ if self.cfg.greater_is_better:
525
+ training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
526
+
527
+ if self.cfg.torch_compile:
528
+ if torch.__version__ < "2.1.0": # pylint: disable=protected-access
529
+ LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
530
+ elif torch._dynamo: # pylint: disable=protected-access
531
+ torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
532
+ True
533
+ )
534
+ training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
535
+ if self.cfg.torch_compile_backend:
536
+ training_arguments_kwargs[
537
+ "torch_compile_backend"
538
+ ] = self.cfg.torch_compile_backend
539
+
540
+ # DDP Config
541
+ if self.cfg.ddp_timeout:
542
+ training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
543
+ # see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
544
+ if self.cfg.ddp_bucket_cap_mb:
545
+ training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
546
+ if self.cfg.ddp_broadcast_buffers is not None:
547
+ training_arguments_kwargs[
548
+ "ddp_broadcast_buffers"
549
+ ] = self.cfg.ddp_broadcast_buffers
550
+
551
+ # these are all the "standard" kwargs that are def used
552
+ training_arguments_kwargs["max_steps"] = (
553
+ total_num_steps if self.cfg.max_steps else -1
554
+ )
555
+ training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
556
+ training_arguments_kwargs[
557
+ "per_device_train_batch_size"
558
+ ] = self.cfg.micro_batch_size
559
+ training_arguments_kwargs[
560
+ "per_device_eval_batch_size"
561
+ ] = self.cfg.eval_batch_size
562
+ training_arguments_kwargs[
563
+ "gradient_accumulation_steps"
564
+ ] = self.cfg.gradient_accumulation_steps
565
+ training_arguments_kwargs[
566
+ "eval_accumulation_steps"
567
+ ] = self.cfg.gradient_accumulation_steps
568
+ training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
569
+ training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
570
+ training_arguments_kwargs["output_dir"] = self.cfg.output_dir
571
+ training_arguments_kwargs["save_total_limit"] = (
572
+ self.cfg.save_total_limit if self.cfg.save_total_limit else 4
573
+ )
574
+ training_arguments_kwargs["load_best_model_at_end"] = (
575
+ (
576
+ self.cfg.load_best_model_at_end is not False
577
+ or self.cfg.early_stopping_patience
578
+ )
579
+ and self.cfg.val_set_size > 0
580
+ and self.cfg.save_steps
581
+ and self.cfg.eval_steps
582
+ and self.cfg.save_steps % self.cfg.eval_steps == 0
583
+ ) or False
584
+ training_arguments_kwargs["ddp_find_unused_parameters"] = (
585
+ False if self.cfg.ddp else None
586
+ )
587
+ training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
588
+ training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
589
+ training_arguments_kwargs["run_name"] = (
590
+ self.cfg.wandb_run_id if self.cfg.use_wandb else None
591
+ )
592
+ training_arguments_kwargs["optim"] = (
593
+ self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
594
+ )
595
+ training_arguments_kwargs["lr_scheduler_type"] = (
596
+ self.cfg.lr_scheduler
597
+ if self.cfg.lr_scheduler
598
+ and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
599
+ else "cosine"
600
+ )
601
+ training_arguments_kwargs["weight_decay"] = (
602
+ self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
603
+ )
604
+ training_arguments_kwargs["sample_packing"] = (
605
+ self.cfg.sample_packing if self.cfg.sample_packing else False
606
+ )
607
+ training_arguments_kwargs["eval_sample_packing"] = (
608
+ self.cfg.sample_packing if self.cfg.sample_packing else False
609
+ )
610
+ training_arguments_kwargs[
611
+ "sample_packing_seq_len_multiplier"
612
+ ] = self.cfg.micro_batch_size
613
+ training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
614
+ training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
615
+ training_arguments_kwargs = self.hook_pre_create_training_args(
616
+ training_arguments_kwargs
617
+ )
618
+ training_args = (
619
+ AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
620
+ **training_arguments_kwargs,
621
+ )
622
+ )
623
+ training_args = self.hook_post_create_training_args(training_args)
624
+ trainer_kwargs = {}
625
+
626
+ if self.cfg.optimizer == "adamw_anyprecision":
627
+ if Path(self.cfg.torchdistx_path).exists():
628
+ sys.path.append(self.cfg.torchdistx_path)
629
+ importlib.import_module("torchdistx")
630
+
631
+ data_collator_kwargs = {
632
+ "padding": True, # True/"longest" is the default
633
+ }
634
+ if self.cfg.pad_to_sequence_len:
635
+ data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
636
+ self.cfg.sequence_len / 64
637
+ )
638
+ else:
639
+ # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
640
+ # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
641
+ data_collator_kwargs["pad_to_multiple_of"] = 64
642
+
643
+ if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
644
+ from axolotl.monkeypatch.llama_landmark_attn import (
645
+ add_mem_tokens,
646
+ get_mem_id,
647
+ set_model_mem_id,
648
+ )
649
+
650
+ set_model_mem_id(self.model, self.tokenizer)
651
+
652
+ LOG.info("Adding landmark attention tokens to dataset")
653
+
654
+ for dataset in [self.train_dataset, self.eval_dataset]:
655
+ dataset = dataset.map(
656
+ partial(
657
+ add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
658
+ ),
659
+ batched=False,
660
+ num_proc=32,
661
+ )
662
+
663
+ trainer_cls = self._get_trainer_cls()
664
+ trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
665
+ trainer_kwargs, trainer_cls
666
+ )
667
+ trainer = trainer_cls(
668
+ model=self.model,
669
+ train_dataset=self.train_dataset,
670
+ eval_dataset=self.eval_dataset,
671
+ args=training_args,
672
+ data_collator=DataCollatorForSeq2Seq(
673
+ self.tokenizer,
674
+ return_tensors="pt",
675
+ **data_collator_kwargs,
676
+ ),
677
+ bench_data_collator=transformers.DataCollatorForSeq2Seq(
678
+ self.tokenizer,
679
+ return_tensors="pt",
680
+ **data_collator_kwargs,
681
+ ),
682
+ callbacks=self.get_callbacks(),
683
+ **trainer_kwargs,
684
+ )
685
+ trainer = self.hook_post_create_trainer(trainer)
686
+ for callback in self.get_post_trainer_create_callbacks(trainer):
687
+ trainer.add_callback(callback)
688
+
689
+ return trainer
src/axolotl/utils/callbacks.py CHANGED
@@ -37,7 +37,7 @@ from axolotl.utils.distributed import (
37
  )
38
 
39
  if TYPE_CHECKING:
40
- from axolotl.utils.trainer import AxolotlTrainingArguments
41
 
42
  LOG = logging.getLogger("axolotl.callbacks")
43
  IGNORE_INDEX = -100
 
37
  )
38
 
39
  if TYPE_CHECKING:
40
+ from axolotl.core.trainer_builder import AxolotlTrainingArguments
41
 
42
  LOG = logging.getLogger("axolotl.callbacks")
43
  IGNORE_INDEX = -100
src/axolotl/utils/trainer.py CHANGED
@@ -1,40 +1,19 @@
1
  """Module containing the Trainer class and related functions"""
2
- import importlib
3
  import logging
4
  import math
5
  import os
6
- import sys
7
  from contextlib import contextmanager
8
- from dataclasses import dataclass, field
9
  from functools import partial
10
- from pathlib import Path
11
- from typing import List, Optional, Union
12
 
13
  import numpy as np
14
  import torch
15
  import torch.cuda
16
  import torch.distributed as dist
17
- import transformers
18
- from datasets import Dataset, set_caching_enabled
19
- from torch.optim.lr_scheduler import OneCycleLR
20
- from torch.utils.data import (
21
- DataLoader,
22
- DistributedSampler,
23
- RandomSampler,
24
- SequentialSampler,
25
- )
26
- from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
27
- from transformers.trainer_pt_utils import SequentialDistributedSampler
28
-
29
- from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
30
- from axolotl.utils.callbacks import (
31
- EvalFirstStepCallback,
32
- GPUStatsCallback,
33
- SaveAxolotlConfigtoWandBCallback,
34
- SaveBetterTransformerModelCallback,
35
- bench_eval_callback_factory,
36
- log_prediction_callback_factory,
37
- )
38
  from axolotl.utils.collators import DataCollatorForSeq2Seq
39
  from axolotl.utils.dataloader import MultipackDistributedDataloader
40
  from axolotl.utils.distributed import (
@@ -43,7 +22,6 @@ from axolotl.utils.distributed import (
43
  reduce_and_broadcast,
44
  zero_first,
45
  )
46
- from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
47
 
48
  LOG = logging.getLogger("axolotl")
49
 
@@ -110,269 +88,6 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
110
  return weighted_cross_entropy(logits, labels, weights)
111
 
112
 
113
- @dataclass
114
- class AxolotlTrainingArguments(TrainingArguments):
115
- """
116
- Extend the base TrainingArguments for axolotl helpers
117
- """
118
-
119
- lr_quadratic_warmup: bool = field(
120
- default=False,
121
- metadata={"help": "Use quadratic warmup for cosine scheduling."},
122
- )
123
- sample_packing: bool = field(
124
- default=False,
125
- metadata={"help": "Use sample packing for efficient training."},
126
- )
127
- eval_sample_packing: Optional[bool] = field(
128
- default=None,
129
- metadata={"help": "Use sample packing for efficient evals."},
130
- )
131
- sample_packing_efficiency: float = field(
132
- default=1.0,
133
- metadata={"help": "Sample packing efficiency for calculating batch length."},
134
- )
135
- max_seq_length: int = field(
136
- default=2048,
137
- metadata={"help": "The maximum sequence length the model can handle"},
138
- )
139
- sample_packing_seq_len_multiplier: int = field(
140
- default=1,
141
- metadata={"help": "the multiplier for the max len for packed sequences"},
142
- )
143
- relora_steps: Optional[int] = field(
144
- default=None,
145
- metadata={"help": "how often to reset for ReLoRA"},
146
- )
147
- relora_warmup_steps: Optional[int] = field(
148
- default=None,
149
- metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
150
- )
151
- bench_split: Optional[str] = field(
152
- default="eval", metadata={"help": "The benchmark split to run on"}
153
- )
154
- bench_dataset: Optional[str] = field(
155
- default="pharaouk/dharma-1/dharma_1_mini.json",
156
- metadata={
157
- "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
158
- },
159
- )
160
- do_bench_eval: Optional[bool] = field(
161
- default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
162
- )
163
- max_bench_samples: Optional[int] = field(
164
- default=None,
165
- metadata={
166
- "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
167
- },
168
- )
169
- bench_source_max_len: int = field(
170
- default=2048, metadata={"help": "Maximum source sequence length for bench."}
171
- )
172
-
173
-
174
- class AxolotlTrainer(Trainer):
175
- """
176
- Extend the base Trainer for axolotl helpers
177
- """
178
-
179
- args = None # type: AxolotlTrainingArguments
180
-
181
- def __init__(self, *args, bench_data_collator=None, **kwargs):
182
- self.bench_data_collator = bench_data_collator
183
- super().__init__(*args, **kwargs)
184
-
185
- def create_scheduler(
186
- self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
187
- ):
188
- """
189
- Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
190
- passed as an argument.
191
-
192
- Args:
193
- num_training_steps (int): The number of training steps to do.
194
- optimizer (torch.optim.Optimizer): The training optimizer
195
- """
196
-
197
- # fmt: off
198
- if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
199
- # fmt: on
200
- if (
201
- self.args.lr_scheduler_type == "cosine"
202
- and self.args.lr_quadratic_warmup is True
203
- ):
204
- self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
205
- optimizer,
206
- num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
207
- num_training_steps=num_training_steps,
208
- )
209
- else:
210
- return super().create_scheduler(num_training_steps, optimizer)
211
- return self.lr_scheduler
212
-
213
- def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
214
- if self.args.world_size > 1 and self.args.sample_packing:
215
- return DistributedSampler(
216
- self.train_dataset,
217
- num_replicas=self.args.world_size,
218
- rank=self.args.process_index,
219
- seed=self.args.seed,
220
- )
221
- return super()._get_train_sampler()
222
-
223
- def _get_eval_sampler(
224
- self, eval_dataset: Dataset
225
- ) -> Optional[torch.utils.data.Sampler]:
226
- if (
227
- self.args.world_size > 1
228
- and self.args.sample_packing
229
- and self.args.eval_sample_packing is not False
230
- ):
231
- return SequentialDistributedSampler(
232
- eval_dataset,
233
- num_replicas=self.args.world_size,
234
- rank=self.args.process_index,
235
- batch_size=self.args.per_device_eval_batch_size,
236
- )
237
- return super()._get_eval_sampler(eval_dataset)
238
-
239
- def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
240
- if self.args.sample_packing:
241
- train_sampler = self._get_train_sampler()
242
- return self.accelerator.prepare(
243
- MultipackDistributedDataloader(
244
- self.train_dataset,
245
- batch_size=self._train_batch_size,
246
- seq_max_length=self.args.max_seq_length,
247
- collate_fn=self.data_collator,
248
- sampler=train_sampler,
249
- packing_efficiency_estimate=self.args.sample_packing_efficiency,
250
- sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
251
- device_count=int(os.environ.get("WORLD_SIZE", 1)),
252
- )
253
- )
254
- return super().get_train_dataloader()
255
-
256
- def get_eval_dataloader(
257
- self, eval_dataset: Optional[Dataset] = None
258
- ) -> Union[DataLoader, MultipackDistributedDataloader]:
259
- if self.args.sample_packing and self.args.eval_sample_packing is not False:
260
- eval_dataset = (
261
- eval_dataset if eval_dataset is not None else self.eval_dataset
262
- )
263
-
264
- eval_sampler = self._get_eval_sampler(eval_dataset)
265
- return self.accelerator.prepare(
266
- MultipackDistributedDataloader(
267
- eval_dataset,
268
- batch_size=self.args.eval_batch_size,
269
- seq_max_length=self.args.max_seq_length,
270
- collate_fn=self.data_collator,
271
- sampler=eval_sampler,
272
- packing_efficiency_estimate=self.args.sample_packing_efficiency,
273
- sample_packing_seq_len_multiplier=self.args.eval_batch_size,
274
- device_count=int(os.environ.get("WORLD_SIZE", 1)),
275
- )
276
- )
277
- return super().get_eval_dataloader(eval_dataset)
278
-
279
- def _get_bench_sampler(
280
- self, bench_dataset: Dataset
281
- ) -> Optional[torch.utils.data.Sampler]:
282
- if self.args.world_size <= 1:
283
- return SequentialSampler(bench_dataset)
284
- return None
285
-
286
- def get_bench_dataloader(
287
- self,
288
- bench_dataset: Dataset,
289
- ) -> Union[DataLoader, MultipackDistributedDataloader]:
290
- dataloader_params = {
291
- "batch_size": self.args.eval_batch_size,
292
- "collate_fn": self.bench_data_collator,
293
- "num_workers": self.args.dataloader_num_workers,
294
- "pin_memory": self.args.dataloader_pin_memory,
295
- }
296
-
297
- if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
298
- dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
299
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
300
-
301
- return DataLoader(bench_dataset, **dataloader_params)
302
- # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
303
-
304
- def compute_loss(self, model, inputs, return_outputs=False):
305
- # use one's weighted cross entropy loss calc
306
- # if self.args.sample_packing:
307
- # labels = inputs.pop("labels")
308
- # outputs = model(**inputs)
309
- # loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
310
- # return (loss, outputs) if return_outputs else loss
311
- return super().compute_loss(model, inputs, return_outputs=return_outputs)
312
-
313
-
314
- class OneCycleLRSchedulerTrainer(AxolotlTrainer):
315
- """
316
- Trainer subclass that uses the OneCycleLR scheduler
317
- """
318
-
319
- def __init__(self, *args, **kwargs):
320
- super().__init__(*args, **kwargs)
321
- self.lr_scheduler = None
322
-
323
- def create_scheduler(
324
- self,
325
- num_training_steps: int,
326
- optimizer: Optional[torch.optim.Optimizer] = None,
327
- ):
328
- optimizer = self.optimizer if optimizer is None else optimizer
329
- num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
330
- pct_start = num_warmup_steps / num_training_steps
331
-
332
- self.lr_scheduler = OneCycleLR(
333
- optimizer,
334
- max_lr=self.args.learning_rate,
335
- total_steps=num_training_steps,
336
- pct_start=pct_start,
337
- div_factor=6,
338
- )
339
-
340
- return self.lr_scheduler
341
-
342
-
343
- class ReLoRATrainer(AxolotlTrainer):
344
- """
345
- Trainer subclass that uses the OneCycleLR scheduler
346
- """
347
-
348
- def __init__(self, *args, **kwargs):
349
- super().__init__(*args, **kwargs)
350
- self.lr_scheduler = None
351
-
352
- def create_scheduler(
353
- self,
354
- num_training_steps: int,
355
- optimizer: Optional[torch.optim.Optimizer] = None,
356
- ):
357
- optimizer = self.optimizer if optimizer is None else optimizer
358
- lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
359
-
360
- if self.args.relora_steps:
361
- warmup_steps = (
362
- self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
363
- )
364
- self.lr_scheduler = ReLoRAScheduler(
365
- optimizer,
366
- lr_scheduler,
367
- self.args.relora_steps,
368
- warmup_steps,
369
- )
370
- else:
371
- self.lr_scheduler = lr_scheduler
372
-
373
- return self.lr_scheduler
374
-
375
-
376
  def add_position_ids(sample):
377
  sample_len = len(sample["input_ids"])
378
  sample["position_ids"] = torch.arange(len(sample["input_ids"]))
@@ -550,245 +265,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
550
  elif cfg.deepspeed:
551
  os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
552
 
553
- warmup_steps = (
554
- cfg.warmup_steps
555
- if cfg.warmup_steps is not None
556
- else min(int(0.03 * total_num_steps), 100)
557
- )
558
- logging_steps = (
559
- cfg.logging_steps
560
- if cfg.logging_steps is not None
561
- else max(min(int(0.005 * total_num_steps), 10), 1)
562
- )
563
-
564
- training_arguments_kwargs = {}
565
- if cfg.bf16 == "full":
566
- training_arguments_kwargs["bf16_full_eval"] = True
567
- else:
568
- training_arguments_kwargs["bf16"] = cfg.bf16
569
- training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
570
- training_arguments_kwargs["tf32"] = cfg.tf32
571
- training_arguments_kwargs["warmup_steps"] = warmup_steps
572
- training_arguments_kwargs["logging_steps"] = logging_steps
573
-
574
- if cfg.seed:
575
- training_arguments_kwargs["seed"] = cfg.seed
576
-
577
- if cfg.gradient_checkpointing:
578
- training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
579
- if cfg.fsdp:
580
- training_arguments_kwargs["fsdp"] = cfg.fsdp
581
- if cfg.fsdp_config:
582
- training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
583
-
584
- # deepspeed
585
- if cfg.deepspeed:
586
- training_arguments_kwargs["deepspeed"] = cfg.deepspeed
587
-
588
- if cfg.lr_quadratic_warmup is not None:
589
- training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
590
-
591
- if cfg.adam_beta1:
592
- training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
593
- if cfg.adam_beta2:
594
- training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
595
- if cfg.adam_epsilon:
596
- training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
597
- if cfg.max_grad_norm:
598
- training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
599
-
600
- if cfg.hub_model_id:
601
- training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
602
- training_arguments_kwargs["push_to_hub"] = True
603
- training_arguments_kwargs["hub_private_repo"] = True
604
-
605
- if cfg.hub_strategy:
606
- training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
607
-
608
- if cfg.save_safetensors:
609
- training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
610
-
611
- if cfg.sample_packing_eff_est:
612
- training_arguments_kwargs[
613
- "sample_packing_efficiency"
614
- ] = cfg.sample_packing_eff_est
615
-
616
- if cfg.eval_steps:
617
- training_arguments_kwargs["evaluation_strategy"] = "steps"
618
- training_arguments_kwargs["eval_steps"] = cfg.eval_steps
619
- elif cfg.evaluation_strategy:
620
- training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
621
- elif cfg.val_set_size == 0:
622
- # no eval set, so don't eval
623
- training_arguments_kwargs["evaluation_strategy"] = "no"
624
- else:
625
- # we have an eval set, but no steps defined, default to use epoch
626
- training_arguments_kwargs["evaluation_strategy"] = "epoch"
627
-
628
- if cfg.save_steps:
629
- training_arguments_kwargs["save_strategy"] = "steps"
630
- training_arguments_kwargs["save_steps"] = cfg.save_steps
631
- elif cfg.save_strategy:
632
- training_arguments_kwargs["save_strategy"] = cfg.save_strategy
633
- else:
634
- # default to saving each epoch if not defined
635
- training_arguments_kwargs["save_strategy"] = "epoch"
636
-
637
- if cfg.do_bench_eval:
638
- training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
639
- if cfg.bench_dataset:
640
- training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
641
- if cfg.metric_for_best_model:
642
- training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
643
- if cfg.greater_is_better:
644
- training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
645
-
646
- if cfg.torch_compile:
647
- if torch.__version__ < "2.1.0": # pylint: disable=protected-access
648
- LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
649
- else:
650
- import torch._dynamo # pylint: disable=redefined-outer-name
651
-
652
- torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
653
- True
654
- )
655
- training_arguments_kwargs["torch_compile"] = cfg.torch_compile
656
- if cfg.torch_compile_backend:
657
- training_arguments_kwargs[
658
- "torch_compile_backend"
659
- ] = cfg.torch_compile_backend
660
-
661
- # DDP Config
662
- if cfg.ddp_timeout:
663
- training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
664
- # see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
665
- if cfg.ddp_bucket_cap_mb:
666
- training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
667
- if cfg.ddp_broadcast_buffers is not None:
668
- training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
669
-
670
- training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
671
- max_steps=total_num_steps if cfg.max_steps else -1,
672
- max_seq_length=cfg.sequence_len,
673
- per_device_train_batch_size=cfg.micro_batch_size,
674
- per_device_eval_batch_size=cfg.eval_batch_size,
675
- gradient_accumulation_steps=cfg.gradient_accumulation_steps,
676
- eval_accumulation_steps=cfg.gradient_accumulation_steps,
677
- num_train_epochs=cfg.num_epochs,
678
- learning_rate=cfg.learning_rate,
679
- output_dir=cfg.output_dir,
680
- save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
681
- load_best_model_at_end=(
682
- (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
683
- and cfg.val_set_size > 0
684
- and cfg.save_steps
685
- and cfg.eval_steps
686
- and cfg.save_steps % cfg.eval_steps == 0
687
- )
688
- or False,
689
- ddp_find_unused_parameters=False if cfg.ddp else None,
690
- group_by_length=cfg.group_by_length,
691
- report_to="wandb" if cfg.use_wandb else None,
692
- run_name=cfg.wandb_run_id if cfg.use_wandb else None,
693
- optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
694
- lr_scheduler_type=cfg.lr_scheduler
695
- if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
696
- else "cosine",
697
- weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
698
- sample_packing=cfg.sample_packing if cfg.sample_packing else False,
699
- eval_sample_packing=cfg.eval_sample_packing,
700
- sample_packing_seq_len_multiplier=cfg.micro_batch_size,
701
- relora_steps=cfg.relora_steps,
702
- relora_warmup_steps=cfg.relora_warmup_steps,
703
- **training_arguments_kwargs,
704
- )
705
-
706
- trainer_kwargs = {}
707
-
708
- if cfg.optimizer == "adamw_anyprecision":
709
- if Path(cfg.torchdistx_path).exists():
710
- sys.path.append(cfg.torchdistx_path)
711
- importlib.import_module("torchdistx")
712
-
713
- callbacks = []
714
- callbacks.append(GPUStatsCallback(cfg))
715
- callbacks.append(EvalFirstStepCallback)
716
-
717
- if cfg.relora_steps:
718
- callbacks.append(ReLoRACallback(cfg))
719
-
720
- if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
721
- callbacks.append(SaveBetterTransformerModelCallback)
722
-
723
- data_collator_kwargs = {
724
- "padding": True, # True/"longest" is the default
725
- }
726
- if cfg.pad_to_sequence_len:
727
- data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
728
- cfg.sequence_len / 64
729
- )
730
- else:
731
- # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
732
- # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
733
- data_collator_kwargs["pad_to_multiple_of"] = 64
734
-
735
- if cfg.is_llama_derived_model and cfg.landmark_attention:
736
- from axolotl.monkeypatch.llama_landmark_attn import (
737
- add_mem_tokens,
738
- get_mem_id,
739
- set_model_mem_id,
740
- )
741
-
742
- set_model_mem_id(model, tokenizer)
743
-
744
- LOG.info("Adding landmark attention tokens to dataset")
745
-
746
- for dataset in [train_dataset, eval_dataset]:
747
- dataset = dataset.map(
748
- partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
749
- batched=False,
750
- num_proc=32,
751
- )
752
-
753
- trainer_cls = AxolotlTrainer
754
- if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
755
- trainer_cls = OneCycleLRSchedulerTrainer
756
- elif cfg.relora_steps:
757
- trainer_cls = ReLoRATrainer
758
- trainer = trainer_cls(
759
- model=model,
760
- train_dataset=train_dataset,
761
- eval_dataset=eval_dataset,
762
- args=training_args,
763
- data_collator=DataCollatorForSeq2Seq(
764
- tokenizer,
765
- return_tensors="pt",
766
- **data_collator_kwargs,
767
- ),
768
- bench_data_collator=transformers.DataCollatorForSeq2Seq(
769
- tokenizer,
770
- return_tensors="pt",
771
- **data_collator_kwargs,
772
- ),
773
- callbacks=callbacks,
774
- **trainer_kwargs,
775
- )
776
-
777
- if cfg.use_wandb and cfg.eval_table_size > 0:
778
- LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
779
- trainer.add_callback(LogPredictionCallback(cfg))
780
-
781
- if cfg.use_wandb:
782
- trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
783
-
784
- if cfg.do_bench_eval:
785
- trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
786
-
787
- # TODO on_save callback to sync checkpoints to GCP/AWS in background
788
- if cfg.early_stopping_patience:
789
- early_stop_cb = EarlyStoppingCallback(
790
- cfg.early_stopping_patience,
791
- )
792
- trainer.add_callback(early_stop_cb)
793
 
794
- return trainer
 
1
  """Module containing the Trainer class and related functions"""
 
2
  import logging
3
  import math
4
  import os
 
5
  from contextlib import contextmanager
 
6
  from functools import partial
7
+ from typing import List
 
8
 
9
  import numpy as np
10
  import torch
11
  import torch.cuda
12
  import torch.distributed as dist
13
+ from datasets import set_caching_enabled
14
+ from torch.utils.data import DistributedSampler, RandomSampler
15
+
16
+ from axolotl.core.trainer_builder import HFCausalTrainerBuilder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  from axolotl.utils.collators import DataCollatorForSeq2Seq
18
  from axolotl.utils.dataloader import MultipackDistributedDataloader
19
  from axolotl.utils.distributed import (
 
22
  reduce_and_broadcast,
23
  zero_first,
24
  )
 
25
 
26
  LOG = logging.getLogger("axolotl")
27
 
 
88
  return weighted_cross_entropy(logits, labels, weights)
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def add_position_ids(sample):
92
  sample_len = len(sample["input_ids"])
93
  sample["position_ids"] = torch.arange(len(sample["input_ids"]))
 
265
  elif cfg.deepspeed:
266
  os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
267
 
268
+ trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
269
+ trainer_builder.train_dataset = train_dataset
270
+ trainer_builder.eval_dataset = eval_dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
+ return trainer_builder.build(total_num_steps)