fix learning rate scheduler's warnings (#1135) [skip ci]
* fix schedulers warnings
* chore: lint
---------
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/core/trainer_builder.py (CHANGED)
@@ -170,24 +170,30 @@ class AxolotlTrainer(Trainer):
             num_training_steps (int): The number of training steps to do.
             optimizer (torch.optim.Optimizer): The training optimizer
         """
+        use_cosine_quadratic = (
+            self.args.lr_scheduler_type == "cosine"
+            and self.args.lr_quadratic_warmup is True
+        )
+
+        use_cosine_min_lr = (
+            self.args.lr_scheduler_type == "cosine"
+            and self.args.cosine_min_lr_ratio is not None
+        )
 
         # fmt: off
         if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
             # fmt: on
-            if (
-                self.args.lr_scheduler_type == "cosine"
-                and self.args.lr_quadratic_warmup is True
-            ):
+            if use_cosine_quadratic:
+                if use_cosine_min_lr:
+                    LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
+
                 self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                     optimizer,
                     num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                     num_training_steps=num_training_steps,
                 )
-            elif self.args.cosine_min_lr_ratio:
+            elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
                 assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
-                if self.args.deepspeed:
-                    LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
-                        in the deepspeed JSON")
                 self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                     optimizer,
                     num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
@@ -196,6 +202,13 @@ class AxolotlTrainer(Trainer):
                 )
             else:
                 return super().create_scheduler(num_training_steps, optimizer)
+        else:
+            if use_cosine_quadratic:
+                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
+
+            if use_cosine_min_lr:
+                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
+
         return self.lr_scheduler
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
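For quick reference, below is a minimal, self-contained sketch of the selection logic after this change. It is an illustration only, not axolotl API: Args and pick_scheduler are hypothetical stand-ins for the trainer arguments and for AxolotlTrainer.create_scheduler, and the returned strings stand in for the real torch LR schedulers. The branch structure and warning conditions mirror the patch: the custom-cosine warnings now fire only when the scheduler was already set elsewhere (e.g., by a DeepSpeed config), and a new warning covers the case where both quadratic warmup and min lr are requested.

# Sketch of the new scheduler-selection/warning flow; names are illustrative stand-ins.
import logging
from dataclasses import dataclass
from typing import Optional

LOG = logging.getLogger(__name__)


@dataclass
class Args:
    lr_scheduler_type: str = "cosine"
    lr_quadratic_warmup: bool = False
    cosine_min_lr_ratio: Optional[float] = None


def pick_scheduler(args: Args, scheduler_already_set: bool) -> Optional[str]:
    # The two config combinations are computed once, up front, as in the patch.
    use_cosine_quadratic = (
        args.lr_scheduler_type == "cosine" and args.lr_quadratic_warmup is True
    )
    use_cosine_min_lr = (
        args.lr_scheduler_type == "cosine" and args.cosine_min_lr_ratio is not None
    )

    if not scheduler_already_set:
        if use_cosine_quadratic:
            if use_cosine_min_lr:
                LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
            return "cosine_with_quadratic_warmup"
        if use_cosine_min_lr:
            assert 0 <= args.cosine_min_lr_ratio <= 1.0
            return "cosine_with_min_lr"
        return "default"

    # Scheduler already created elsewhere (e.g., by a DeepSpeed config): warn only
    # if the user asked for one of the custom cosine schedules that will be skipped.
    if use_cosine_quadratic:
        LOG.warning("cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
    if use_cosine_min_lr:
        LOG.warning("cosine scheduler with min lr not used (e.g., because of deepspeed).")
    return None


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    # Requesting min-lr cosine while a scheduler is already set emits exactly one warning.
    print(pick_scheduler(Args(cosine_min_lr_ratio=0.1), scheduler_already_set=True))

Running the sketch as shown emits only the single "min lr not used" warning, rather than warning unconditionally whenever DeepSpeed is enabled, which is the behavior change this commit is after.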