ricdomolm and winglian committed
Commit b4ac96a · unverified · Parent(s): 98b4762

fix learning rate scheduler's warnings (#1135) [skip ci]


* fix schedulers warnings

* chore: lint

---------

Co-authored-by: Wing Lian <[email protected]>
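
The patch below routes cosine scheduling through axolotl's get_cosine_schedule_with_quadratic_warmup, whose implementation is outside this diff. For orientation, here is a minimal sketch of a cosine schedule with quadratic warmup, assuming a LambdaLR-based implementation in the style of transformers' get_cosine_schedule_with_warmup (the body is illustrative, not the function's actual source; only the name and call signature appear in the diff):

import math
from functools import partial

from torch.optim.lr_scheduler import LambdaLR


def _quadratic_warmup_cosine_lambda(current_step, *, num_warmup_steps, num_training_steps):
    # Warmup: scale the base LR by (step / warmup)^2 instead of linearly.
    if current_step < num_warmup_steps:
        return (current_step / max(1, num_warmup_steps)) ** 2
    # After warmup: standard half-cosine decay from 1.0 toward 0.0.
    progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


def get_cosine_schedule_with_quadratic_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    lr_lambda = partial(
        _quadratic_warmup_cosine_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch)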

Files changed (1)
  src/axolotl/core/trainer_builder.py  +21 -8
src/axolotl/core/trainer_builder.py CHANGED
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -170,24 +170,30 @@ class AxolotlTrainer(Trainer):
             num_training_steps (int): The number of training steps to do.
             optimizer (torch.optim.Optimizer): The training optimizer
         """
+        use_cosine_quadratic = (
+            self.args.lr_scheduler_type == "cosine"
+            and self.args.lr_quadratic_warmup is True
+        )
+
+        use_cosine_min_lr = (
+            self.args.lr_scheduler_type == "cosine"
+            and self.args.cosine_min_lr_ratio is not None
+        )

         # fmt: off
         if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
             # fmt: on
-            if (
-                self.args.lr_scheduler_type == "cosine"
-                and self.args.lr_quadratic_warmup is True
-            ):
+            if use_cosine_quadratic:
+                if use_cosine_min_lr:
+                    LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
+
                 self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                     optimizer,
                     num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                     num_training_steps=num_training_steps,
                 )
-            elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None:
+            elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
                 assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
-                if self.args.deepspeed:
-                    LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
-                        in the deepspeed JSON")
                 self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                     optimizer,
                     num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
@@ -196,6 +202,13 @@ class AxolotlTrainer(Trainer):
             )
         else:
             return super().create_scheduler(num_training_steps, optimizer)
+        else:
+            if use_cosine_quadratic:
+                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
+
+            if use_cosine_min_lr:
+                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
+
         return self.lr_scheduler

     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
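
For reference, get_cosine_schedule_with_min_lr is likewise defined elsewhere in axolotl, and the keyword arguments between num_warmup_steps=... and the closing parenthesis are elided by the hunk above. A minimal sketch, assuming a min_lr_ratio keyword that mirrors args.cosine_min_lr_ratio and a LambdaLR-based implementation (the parameter names and body are assumptions, not the actual source):

import math

from torch.optim.lr_scheduler import LambdaLR


def get_cosine_schedule_with_min_lr(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.0, last_epoch=-1):
    # min_lr_ratio is a fraction of the base LR; the assert in the diff
    # constrains args.cosine_min_lr_ratio to [0.0, 1.0].
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # Linear warmup from 0 up to the base LR.
            return current_step / max(1, num_warmup_steps)
        progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
        # Rescale the cosine so the schedule bottoms out at min_lr_ratio * base LR
        # instead of decaying to zero.
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

    return LambdaLR(optimizer, lr_lambda, last_epoch)

With min_lr_ratio=0.0 this reduces to the ordinary warmup-plus-cosine schedule, which is consistent with the assert only needing to bound the ratio to [0.0, 1.0].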