# lolcats/src/finetune.py
# Uploaded by ariG23498 (HF staff)
# Commit ae81e0f: "chore: adding lolcats configs scrc and src"
"""
Finetuning functions to do post-distillation
"""
from os.path import join
from typing import Any

from omegaconf import OmegaConf
import torch
from torch.nn import Module

from src.utils.setup import update_config_from_args
from src.dataloaders import load_data
from src.trainer import get_trainer, get_optimizer, get_scheduler
def prepare_finetune_configs(args, model_config: dict,
                             finetune_config_name: str = None,
                             finetune_checkpoint_name: str = None,
                             config_dir='./configs/experiment'):
    """
    Prepare finetuning configs

    Loads ``<config_dir>/<name>.yaml``, passes it through
    ``update_config_from_args`` (with 'lr' and 'weight_decay' given as
    ``ignore_args``), and then copies the finetuning settings onto ``args``.

    Args:
        args: Attribute-style object (e.g. parsed CLI arguments). It is handed
            to ``update_config_from_args`` and is then mutated in place: every
            ``trainer`` entry except 'name', plus the 'dataloader',
            'optimizer' and 'lr_scheduler' sections, are set as attributes.
        model_config: Mapping whose ``model_config['model']`` provides
            'pretrained_model_name_or_path' and 'cache_dir'. Only read when
            the finetuning dataset config has a ``pretrained_model_config``.
        finetune_config_name: Name of the YAML file (without extension) to
            load. Takes precedence over ``finetune_checkpoint_name``.
        finetune_checkpoint_name: Fallback source for the config name: the
            text after the last '-f=' up to the next '-' is used.
        config_dir: Directory containing the finetuning YAML configs.

    Returns:
        Tuple ``(finetune_config, args)`` of the loaded config and the
        updated ``args`` object.

    Raises:
        ValueError: If both ``finetune_config_name`` and
            ``finetune_checkpoint_name`` are None.
    """
    # Resolve which finetuning config to load
    if finetune_config_name is not None:
        config_name = finetune_config_name
    elif finetune_checkpoint_name is not None:
        # Parse the name out of the checkpoint string; if no '-f=' marker is
        # present this falls back to the first '-'-separated token.
        config_name = finetune_checkpoint_name.split('-f=')[-1].split('-')[0]
    else:
        # Fail early with an actionable message instead of an AttributeError
        # from calling .split() on None.
        raise ValueError(
            'Either finetune_config_name or finetune_checkpoint_name '
            'must be provided to locate the finetuning config.'
        )

    # Load finetuning config
    finetune_config_path = join(config_dir, f'{config_name}.yaml')
    finetune_config = OmegaConf.load(finetune_config_path)
    finetune_config = update_config_from_args(finetune_config, args,
                                              ignore_args=['lr', 'weight_decay'])

    # Update data tokenizer to match model
    if getattr(finetune_config.dataset, 'pretrained_model_config', None) is not None:
        for k in ['pretrained_model_name_or_path', 'cache_dir']:
            finetune_config.dataset.pretrained_model_config[k] = model_config['model'][k]

    # Set finetuning args ('name' selects the trainer class and is not copied)
    for arg, value in finetune_config.trainer.items():
        if arg != 'name':
            setattr(args, arg, value)

    # Store these sections on args as plain Python containers
    for _config in ['dataloader', 'optimizer', 'lr_scheduler']:
        setattr(args, _config, OmegaConf.to_container(getattr(finetune_config, _config)))
    return finetune_config, args
def get_finetuner(model: Module, finetune_config: dict, device: torch.device,
                  args: Any, wandb: Any, initial_eval: bool = False):
    """
    Initialize finetuning trainer

    Args:
        model: Model to finetune. It is moved to ``device`` and switched to
            train mode before the optimizer is built.
        finetune_config: Finetuning config with 'optimizer', 'lr_scheduler',
            'dataset', 'dataloader' and 'trainer' sections. The sections are
            read via attribute access, so an attribute-style config is
            expected despite the ``dict`` annotation.
        device: Device the model is placed on; also forwarded to the trainer.
        args: Run arguments, forwarded to the trainer unchanged.
        wandb: Logging handle, forwarded to the trainer unchanged.
        initial_eval: Accepted for interface compatibility; not read by this
            function.

    Returns:
        Trainer instance of the class returned by
        ``get_trainer(finetune_config.trainer.name)``.
    """
    model.to(device)  # if using a fused optimizer
    model.train()

    # Initialize optimizer and scheduler
    optimizer = get_optimizer(model=model, **finetune_config.optimizer)
    scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler)

    # Build dataloaders and select the splits named in the trainer config
    dataloaders = load_data(finetune_config.dataset, finetune_config.dataloader)
    train_loader = dataloaders[finetune_config.trainer.train_split]
    eval_loader = dataloaders[finetune_config.trainer.val_split]

    # Every trainer config entry (including 'name') is forwarded as a keyword
    # argument to the trainer constructor.
    OurTrainer = get_trainer(finetune_config.trainer.name)
    return OurTrainer(model=model,
                      args=args,
                      train_loader=train_loader,
                      eval_loader=eval_loader,
                      optimizer_and_scheduler=(optimizer, scheduler),
                      device=device,
                      wandb=wandb,
                      checkpoint_suffix='_ft',
                      **finetune_config.trainer)