|
"""Module for custom LRScheduler class""" |
|
|
|
from torch.optim.lr_scheduler import LRScheduler |
|
|
|
|
|
class InterpolatingLogScheduler(LRScheduler):
    """Ramp the learning rate geometrically from ``min_lr`` to ``max_lr``.

    The same learning rate is returned for every parameter group; the
    optimizer's own base learning rates are only used to determine how many
    groups there are.

    Schedule, as a function of ``last_epoch``:

    - ``last_epoch <= 0``: ``min_lr``
    - ``0 < last_epoch < num_steps``: ``min_lr * q ** (last_epoch - 1)``
      with ``q = (max_lr / min_lr) ** (1 / (num_steps - 1))``
    - ``last_epoch >= num_steps``: ``max_lr``

    Note that epochs 0 and 1 both yield ``min_lr``; ``max_lr`` is reached
    exactly at ``last_epoch == num_steps``.
    """

    def __init__(
        self,
        optimizer,
        num_steps: int,
        min_lr: float,
        max_lr: float,
        last_epoch: int = -1,
    ):
        """Create the scheduler.

        Args:
            optimizer: pytorch optimizer whose learning rate is scheduled.
            num_steps: the number of steps over which to increase from
                ``min_lr`` to ``max_lr``.
            min_lr: the learning rate at the start of the ramp. Must be
                non-zero, since the growth factor divides by it.
            max_lr: the learning rate at (and after) the end of the ramp.
            last_epoch: index of the last epoch, forwarded to
                ``LRScheduler``. Defaults to -1.

        Usage:
            fc = nn.Linear(1, 1)
            optimizer = optim.Adam(fc.parameters())
            lr_scheduler = InterpolatingLogScheduler(
                optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4
            )
        """
        # These attributes must exist before super().__init__ runs, because
        # the base-class constructor already calls get_lr().
        self.num_steps = num_steps
        self.min_lr = min_lr
        self.max_lr = max_lr

        if num_steps == 1:
            # A single-step ramp has no intermediate values, so the growth
            # factor is never read by get_lr(). Use a neutral value instead
            # of dividing by zero in the exponent below.
            self.q = 1.0
        else:
            # Per-step multiplicative factor of the geometric progression.
            self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate, one entry per parameter group."""
        epoch = self.last_epoch

        if epoch <= 0:
            lr = self.min_lr
        elif epoch < self.num_steps:
            # The exponent is shifted by one, so epoch 1 still yields min_lr.
            lr = self.min_lr * (self.q ** (epoch - 1))
        else:
            lr = self.max_lr

        return [lr] * len(self.base_lrs)
|
|