winglian commited on
Commit
7dc580b
·
1 Parent(s): 93dacba

add axolotl trainer and quadratic warmup

Browse files
src/axolotl/utils/schedulers.py CHANGED
@@ -1,6 +1,9 @@
1
  """Module for custom LRScheduler class"""
 
 
2
 
3
- from torch.optim.lr_scheduler import LRScheduler
 
4
 
5
 
6
  class InterpolatingLogScheduler(LRScheduler):
@@ -42,3 +45,58 @@ class InterpolatingLogScheduler(LRScheduler):
42
  lrs = [self.max_lr for base_lr in self.base_lrs]
43
 
44
  return lrs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Module for custom LRScheduler class"""
2
+ import math
3
+ from functools import partial
4
 
5
+ from torch.optim import Optimizer
6
+ from torch.optim.lr_scheduler import LambdaLR, LRScheduler
7
 
8
 
9
  class InterpolatingLogScheduler(LRScheduler):
 
45
  lrs = [self.max_lr for base_lr in self.base_lrs]
46
 
47
  return lrs
48
+
49
+
50
+ def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
51
+ current_step: int,
52
+ *,
53
+ num_warmup_steps: int,
54
+ num_training_steps: int,
55
+ num_cycles: float
56
+ ):
57
+ if current_step < num_warmup_steps:
58
+ return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
59
+ progress = float(current_step - num_warmup_steps) / float(
60
+ max(1, num_training_steps - num_warmup_steps)
61
+ )
62
+ return max(
63
+ 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
64
+ )
65
+
66
+
67
+ def get_cosine_schedule_with_quadratic_warmup(
68
+ optimizer: Optimizer,
69
+ num_warmup_steps: int,
70
+ num_training_steps: int,
71
+ num_cycles: float = 0.5,
72
+ last_epoch: int = -1,
73
+ ):
74
+ """
75
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
76
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
77
+ initial lr set in the optimizer.
78
+
79
+ Args:
80
+ optimizer ([`~torch.optim.Optimizer`]):
81
+ The optimizer for which to schedule the learning rate.
82
+ num_warmup_steps (`int`):
83
+ The number of steps for the warmup phase.
84
+ num_training_steps (`int`):
85
+ The total number of training steps.
86
+ num_cycles (`float`, *optional*, defaults to 0.5):
87
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
88
+ following a half-cosine).
89
+ last_epoch (`int`, *optional*, defaults to -1):
90
+ The index of the last epoch when resuming training.
91
+
92
+ Return:
93
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
94
+ """
95
+
96
+ lr_lambda = partial(
97
+ _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
98
+ num_warmup_steps=num_warmup_steps,
99
+ num_training_steps=num_training_steps,
100
+ num_cycles=num_cycles,
101
+ )
102
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
src/axolotl/utils/trainer.py CHANGED
@@ -17,10 +17,42 @@ from transformers import EarlyStoppingCallback, Trainer
17
  from transformers.trainer_pt_utils import get_parameter_names
18
 
19
  from axolotl.utils.callbacks import SavePeftModelCallback
20
- from axolotl.utils.schedulers import InterpolatingLogScheduler
 
 
 
21
 
22
 
23
- class OneCycleLRSchedulerTrainer(Trainer):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  """
25
  Trainer subclass that uses the OneCycleLR scheduler
26
  """
@@ -259,7 +291,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
259
  trainer_cls = (
260
  OneCycleLRSchedulerTrainer
261
  if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
262
- else transformers.Trainer
263
  )
264
  trainer = trainer_cls(
265
  model=model,
 
17
  from transformers.trainer_pt_utils import get_parameter_names
18
 
19
  from axolotl.utils.callbacks import SavePeftModelCallback
20
+ from axolotl.utils.schedulers import (
21
+ InterpolatingLogScheduler,
22
+ get_cosine_schedule_with_quadratic_warmup,
23
+ )
24
 
25
 
26
+ class AxolotlTrainer(Trainer):
27
+ """
28
+ Extend the base Trainer for axolotl helpers
29
+ """
30
+
31
+ def create_scheduler(
32
+ self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
33
+ ):
34
+ """
35
+ Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
36
+ passed as an argument.
37
+
38
+ Args:
39
+ num_training_steps (int): The number of training steps to do.
40
+ """
41
+
42
+ if self.lr_scheduler is None: # pylint: disable=access-member-before-definition
43
+ """# type: ignore"""
44
+ if self.args.lr_scheduler_type == "cosine_with_quadratic":
45
+ self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
46
+ optimizer,
47
+ num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
48
+ num_training_steps=num_training_steps,
49
+ )
50
+ else:
51
+ return super().create_scheduler(num_training_steps, optimizer)
52
+ return self.lr_scheduler
53
+
54
+
55
+ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
56
  """
57
  Trainer subclass that uses the OneCycleLR scheduler
58
  """
 
291
  trainer_cls = (
292
  OneCycleLRSchedulerTrainer
293
  if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
294
+ else AxolotlTrainer
295
  )
296
  trainer = trainer_cls(
297
  model=model,