|
"""Module containing trainer setup utilities (dataset preparation, step estimation, distributed env setup) and related functions"""
|
import math |
|
import os |
|
import random |
|
from contextlib import contextmanager |
|
from functools import partial |
|
from typing import List, Optional |
|
|
|
import numpy as np |
|
import torch |
|
import torch.cuda |
|
from accelerate.logging import get_logger |
|
from datasets import set_caching_enabled |
|
from torch.utils.data import DataLoader, RandomSampler |
|
from transformers.utils import is_torch_bf16_gpu_available |
|
|
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder |
|
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first |
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths |
|
|
|
# Module-wide logger obtained from accelerate's get_logger; call sites in this
# module pass main_process_only=True (presumably so each message is logged once
# rather than on every rank).
LOG = get_logger("axolotl")
|
|
|
|
|
@torch.jit.script
def weighted_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
):
    """
    Return the sum of per-token cross-entropy losses scaled by ``weights``.

    ``logits`` is flattened to ``(-1, logits.size(-1))`` and ``labels`` /
    ``weights`` to 1-D, so all three must describe the same number of tokens.
    Tokens whose label is -100 contribute 0 (the default ``ignore_index`` of
    ``torch.nn.functional.cross_entropy``).

    ``reshape`` is used instead of ``view`` so non-contiguous inputs (sliced or
    transposed tensors) are accepted; for contiguous inputs it returns the same
    view that ``view`` would.
    """
    # (..., vocab) -> (num_tokens, vocab); labels/weights -> (num_tokens,)
    logits = logits.reshape(-1, logits.size(-1))
    labels = labels.reshape(-1)
    weights = weights.reshape(-1)

    # unreduced so that every token's loss can be weighted individually
    losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")

    return (weights * losses).sum()
|
|
|
|
|
@torch.jit.script
def create_weighted_mask(labels: torch.Tensor):
    """
    Build per-token loss weights from ``labels``.

    Every maximal run of consecutive positions whose label is not -100 forms
    one group; each position in a group gets weight ``1 / len(group)`` and
    positions labelled -100 get weight 0. Runs never continue across rows.

    Accepts a 1-D tensor (treated as a single row) or a 2-D tensor of rows and
    returns float32 weights passed through ``squeeze()``. Note that ``squeeze``
    also drops a size-1 batch (or sequence) dimension of 2-D input; this is kept
    for backward compatibility (callers in this module flatten the result).

    The grouping is computed with a cumulative sum + bincount instead of a
    per-element loop; results are identical but avoid rows * seq_len scalar
    tensor operations.
    """
    if len(labels.shape) == 1:
        labels = labels.unsqueeze(0)

    mask = labels != -100

    # mask of the previous position, forced to False at the start of every
    # row so a run can never be continued from the end of the preceding row
    prev_mask = torch.zeros_like(mask)
    prev_mask[:, 1:] = mask[:, :-1]
    run_starts = torch.logical_and(mask, torch.logical_not(prev_mask))

    # running count of run starts over the flattened batch gives every run a
    # unique id (>= 1); reshape back to the row layout
    group_ids = torch.cumsum(run_starts.reshape(-1).long(), dim=0).view_as(mask)

    masked_ids = group_ids[mask]
    group_counts = torch.bincount(masked_ids)

    weights = torch.zeros_like(labels).float()
    weights[mask] = 1.0 / group_counts[masked_ids]

    return weights.squeeze()
|
|
|
|
|
def trainer_weighted_loss(model_output, labels, shift_labels=True):
    """
    Weighted cross-entropy loss for a model's output.

    ``model_output`` is either a dict holding ``"logits"`` or a sequence whose
    first element is the logits. When ``shift_labels`` is True the last logit
    and the first label are dropped, so ``logits[..., t, :]`` is scored against
    ``labels[..., t + 1]``. Per-token weights come from ``create_weighted_mask``.
    """
    if isinstance(model_output, dict):
        logits = model_output["logits"]
    else:
        logits = model_output[0]

    if shift_labels:
        logits, labels = (
            logits[..., :-1, :].contiguous(),
            labels[..., 1:].contiguous(),
        )

    return weighted_cross_entropy(logits, labels, create_weighted_mask(labels))
|
|
|
|
|
@contextmanager
def disable_datasets_caching():
    """
    Context manager that turns off ``datasets`` caching for the enclosed block.

    The caching state in effect on entry is restored on exit (also when the
    body raises). Caching is not unconditionally re-enabled, so nesting this
    manager, or using it while caching is already disabled, does not turn
    caching back on.
    """
    # local import: the module-level import only provides set_caching_enabled
    from datasets import (  # pylint: disable=import-outside-toplevel
        is_caching_enabled,
    )

    was_enabled = is_caching_enabled()
    set_caching_enabled(False)
    try:
        yield
    finally:
        set_caching_enabled(was_enabled)
|
|
|
|
|
def add_position_ids(sample):
    """
    Attach ``position_ids`` (a ``torch.arange`` over the tokens) and ``length``
    (the token count) to a tokenized ``sample``; mutates and returns it.
    """
    num_tokens = len(sample["input_ids"])
    positions = torch.arange(num_tokens)
    sample["position_ids"] = positions
    sample["length"] = num_tokens
    return sample
|
|
|
|
|
def add_pose_position_ids(
    sample,
    max_context_len=32768,
    split_on_token_ids: Optional[List[int]] = None,
    chunks: int = 2,
):
    """
    Use the PoSE technique to extend the context length by randomly skipping
    positions in the context. Skips are only inserted right before tokens in
    the ``split_on_token_ids`` list (or, when that list is empty, at ``chunks``
    evenly spaced boundaries). Skips are distributed randomly, and the final
    position_ids need not reach the full ``max_context_len``. The number of
    remaining skips is tracked so the largest position id never exceeds
    ``max_context_len - 1``.

    Args:
        sample: mapping with an ``"input_ids"`` sequence; mutated in place.
        max_context_len: upper bound (exclusive) for the generated position ids.
        split_on_token_ids: token ids before which a skip may be inserted.
        chunks: number of equal segments used when no split tokens are given.

    Returns:
        ``sample`` with ``position_ids`` (int64 tensor), ``length`` (token count)
        and ``sequence_len`` (the largest position id, 0 for an empty sample).
    """
    input_ids = sample["input_ids"]
    sample_len = len(input_ids)
    # Clamp at 0: a sample longer than max_context_len gets contiguous
    # positions instead of random.randint(0, <negative>) raising ValueError.
    max_skips = max(0, max_context_len - sample_len)

    if split_on_token_ids:
        split_indices = [
            i for i, token_id in enumerate(input_ids) if token_id in split_on_token_ids
        ]
    else:
        chunk_len = sample_len // chunks
        split_indices = [i * chunk_len for i in range(1, chunks)]
    # the last segment always runs to the end of the sample
    split_indices.append(sample_len)
    # Drop a split sitting within the first two tokens, but never the final
    # (end-of-sample) index: dropping it would discard every token.
    if len(split_indices) > 1 and split_indices[0] < 2:
        split_indices = split_indices[1:]

    position_ids = []
    prev_index = 0
    total_skips = 0

    for split_index in split_indices:
        # no skip before the first segment, so position ids start at 0
        num_skips = (
            random.randint(0, max_skips)  # nosec B311 - not security sensitive
            if prev_index != 0 and max_skips
            else 0
        )
        max_skips -= num_skips
        total_skips += num_skips

        position_ids.extend(range(prev_index + total_skips, split_index + total_skips))
        prev_index = split_index

    # largest position id (not the token count); used elsewhere for sorting
    sample["sequence_len"] = position_ids[-1] if position_ids else 0
    position_ids = torch.tensor(position_ids, dtype=torch.long)

    sample["position_ids"] = position_ids
    sample["length"] = len(position_ids)
    assert len(position_ids) == len(input_ids)

    return sample
|
|
|
|
|
def add_length(sample):
    """Record the number of ``input_ids`` under ``"length"``; mutates and returns ``sample``."""
    token_count = len(sample["input_ids"])
    sample["length"] = token_count
    return sample
|
|
|
|
|
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    """
    Filter predicate: True when the sample's ``input_ids`` length lies within
    ``[min_sequence_len, sequence_len]`` (both bounds inclusive).
    """
    num_tokens = len(sample["input_ids"])
    fits_max = num_tokens <= sequence_len
    return fits_max and num_tokens >= min_sequence_len
|
|
|
|
|
def _remove_column_if_present(dataset, column_name):
    """Return ``dataset`` without ``column_name``, or unchanged if it has no such column."""
    if column_name in dataset.column_names:
        return dataset.remove_columns(column_name)
    return dataset


def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
    """
    Prepare the train/eval datasets for (optionally packed) training.

    Runs inside ``zero_first(is_main_process())`` and, in order:
      * logs min/max input lengths when ``cfg.is_preprocess``;
      * drops ``attention_mask`` for mistral-derived models with flash attention
        and for mamba, and ``token_type_ids`` for falcon (each only if present);
      * keeps samples whose length is in ``[cfg.min_sample_len or 2, cfg.sequence_len]``;
      * adds a ``length`` column to train when ``cfg.group_by_length``;
      * adds position ids via PoSE (``cfg.use_pose``; train is then sorted by
        ``sequence_len``) or plain sample packing (``cfg.sample_packing``). The
        eval dataset gets the same mapping unless ``cfg.eval_sample_packing is False``.

    ``eval_dataset`` is only transformed when truthy.

    Returns:
        ``(train_dataset, eval_dataset)``
    """
    drop_long = partial(
        drop_long_seq,
        sequence_len=cfg.sequence_len,
        min_sequence_len=cfg.min_sample_len or 2,
    )
    # keyword arguments shared by every filter/map call below
    proc_kwargs = {
        "num_proc": cfg.dataset_processes,
        "load_from_cache_file": not cfg.is_preprocess,
    }

    with zero_first(is_main_process()):
        if cfg.is_preprocess:
            input_lengths = get_dataset_lengths(train_dataset)
            min_input_len = np.min(input_lengths)
            LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
            max_input_len = np.max(input_lengths)
            LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

        if (
            cfg.is_mistral_derived_model and cfg.flash_attention
        ) or cfg.model_config_type == "mamba":
            LOG.info("dropping attention_mask column")
            # guarded: remove_columns raises if the column does not exist
            train_dataset = _remove_column_if_present(train_dataset, "attention_mask")
            if eval_dataset:
                eval_dataset = _remove_column_if_present(eval_dataset, "attention_mask")

        if cfg.model_config_type == "falcon":
            LOG.info("dropping token_type_ids column if it exists")
            train_dataset = _remove_column_if_present(train_dataset, "token_type_ids")
            if eval_dataset:
                eval_dataset = _remove_column_if_present(eval_dataset, "token_type_ids")

        train_dataset = train_dataset.filter(
            drop_long, desc="Dropping Long Sequences", **proc_kwargs
        )
        if eval_dataset:
            eval_dataset = eval_dataset.filter(
                drop_long, desc="Dropping Long Sequences", **proc_kwargs
            )

        if cfg.group_by_length:
            train_dataset = train_dataset.map(
                add_length, desc="Group By Length", **proc_kwargs
            )

        # choose the position-id mapping (PoSE takes precedence over packing)
        position_fn = None
        position_desc = ""
        if cfg.use_pose:
            pose_kwargs = {}
            if cfg.pose_num_chunks is not None:
                pose_kwargs["chunks"] = cfg.pose_num_chunks
            position_fn = partial(
                add_pose_position_ids,
                max_context_len=cfg.pose_max_context_len,
                split_on_token_ids=cfg.pose_split_on_token_ids,
                **pose_kwargs,
            )
            position_desc = "Add position_id column (PoSE)"
        elif cfg.sample_packing:
            position_fn = add_position_ids
            position_desc = "Add position_id column (Sample Packing)"

        if position_fn is not None:
            train_dataset = train_dataset.map(
                position_fn, desc=position_desc, **proc_kwargs
            )
            if cfg.use_pose:
                train_dataset = train_dataset.sort("sequence_len")
            if cfg.eval_sample_packing is not False and eval_dataset:
                eval_dataset = eval_dataset.map(
                    position_fn, desc=position_desc, **proc_kwargs
                )

    return train_dataset, eval_dataset
|
|
|
|
|
def process_pretraining_datasets_for_packing(
    train_dataset, sequence_len, skip_position_ids=True
):
    """
    Filter a pretraining dataset to samples of at most ``sequence_len`` tokens
    (and at least ``drop_long_seq``'s default minimum), then optionally add a
    ``position_ids`` column.

    NOTE(review): position ids are *added* when ``skip_position_ids`` is True,
    which looks inverted relative to the parameter name; confirm against the
    callers before changing it.
    """
    filtered = train_dataset.filter(
        partial(drop_long_seq, sequence_len=sequence_len),
        desc="Dropping Long Sequences",
    )
    if not skip_position_ids:
        return filtered

    return filtered.map(
        add_position_ids,
        desc="Add position_id column (Pretraining Sample Packing)",
    )
|
|
|
|
|
def calculate_total_num_steps(cfg, train_dataset, update=True):
    """
    Estimate the total number of training steps for ``train_dataset``.

    Token statistics are computed from the dataset when not already set on
    ``cfg``: ``total_num_tokens`` (sum of ``input_ids`` lengths) and, except
    for mamba models, ``total_supervised_tokens`` (labels != -100). The step
    count is then derived as follows:
      * sample packing with a preset ``cfg.sample_packing_eff_est``: the count
        comes from the token count, packing efficiency, sequence length,
        batch size and epochs;
      * sample packing without it: a ``MultipackBatchSampler`` is built, and
        its batch count gives the steps. Its efficiency (max across ranks,
        rounded up to 2 decimals) becomes the new estimate;
      * otherwise: ``ceil(len(train_dataset) * num_epochs / batch_size)``.

    Args:
        cfg: config object providing the sequence, batch, epoch and packing
            settings read below.
        train_dataset: tokenized dataset with ``input_ids`` and ``labels`` columns.
        update: when True, computed ``total_num_tokens``,
            ``total_supervised_tokens`` and ``sample_packing_eff_est`` are
            written back to ``cfg``. When False they are still used locally
            but ``cfg`` is left untouched.

    Returns:
        The estimated total number of steps.
    """
    # Use a local so the update=False path does not fall back to an unset cfg value.
    total_num_tokens = cfg.total_num_tokens
    if not total_num_tokens:
        total_num_tokens = np.sum(
            train_dataset.data.column("input_ids").to_pandas().apply(len).values
        )
        LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
        if update:
            cfg.total_num_tokens = total_num_tokens

    skip_estimates = cfg.model_config_type == "mamba"

    if not skip_estimates and not cfg.total_supervised_tokens:
        total_supervised_tokens = (
            train_dataset.data.column("labels")
            .to_pandas()
            .apply(lambda x: np.sum(np.array(x) != -100))
            .sum()
        )
        LOG.debug(
            f"`total_supervised_tokens: {total_supervised_tokens:_}`",
            main_process_only=True,
        )
        if update:
            cfg.total_supervised_tokens = total_supervised_tokens

    if not skip_estimates and cfg.sample_packing:
        if cfg.sample_packing_eff_est:
            total_num_steps = (
                (
                    math.floor(
                        0.99
                        * total_num_tokens
                        / cfg.sample_packing_eff_est
                        / cfg.sequence_len
                        // cfg.batch_size
                    )
                    - 1
                )
                * cfg.num_epochs
            )
            LOG.debug(
                f"total_num_tokens: {total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
                main_process_only=True,
            )
        else:
            if cfg.flash_attention:
                # one packed row holding micro_batch_size sequences' worth of tokens
                sampler_batch_size = 1
                batch_max_len = cfg.micro_batch_size * cfg.sequence_len
            else:
                sampler_batch_size = cfg.micro_batch_size
                batch_max_len = cfg.sequence_len
            sampler = MultipackBatchSampler(
                sampler=RandomSampler(train_dataset),
                lengths=get_dataset_lengths(train_dataset),
                batch_size=sampler_batch_size,
                batch_max_len=batch_max_len,
                group_size=cfg.sample_packing_group_size,
                bin_size=cfg.sample_packing_bin_size,
                drop_last=True,
            )

            data_loader = DataLoader(
                train_dataset.remove_columns(["length"]),
                batch_sampler=sampler,
            )
            data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
            LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)

            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))

            def calc_sample_packing_eff_est(estimates: List[float]):
                LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
                return max(estimates)

            sample_packing_actual_eff_all = reduce_and_broadcast(
                sampler.efficiency,
                calc_sample_packing_eff_est,
            )
            # round up to two decimals
            sample_packing_eff_est = (
                math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
            )
            if update:
                cfg.sample_packing_eff_est = sample_packing_eff_est
            LOG.debug(
                f"sample_packing_eff_est: {sample_packing_eff_est}",
                main_process_only=True,
            )
    else:
        total_num_steps = int(
            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
        )
    LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
    return total_num_steps
|
|
|
|
|
def setup_fsdp_envs(cfg):
    """
    Export FSDP settings from ``cfg.fsdp_config`` as environment variables.

    ``ACCELERATE_USE_FSDP`` is always set to ``"true"``. Boolean options are
    exported as ``"true"`` when truthy; the remaining options are assigned
    as-is when truthy. Falsy options leave the environment untouched.
    """
    os.environ["ACCELERATE_USE_FSDP"] = "true"

    # (fsdp_config attribute, environment variable) pairs, in export order
    boolean_options = (
        ("fsdp_activation_checkpointing", "FSDP_ACTIVATION_CHECKPOINTING"),
        ("fsdp_offload_params", "FSDP_OFFLOAD_PARAMS"),
        ("fsdp_sync_module_states", "FSDP_SYNC_MODULE_STATES"),
        ("fsdp_cpu_ram_efficient_loading", "FSDP_CPU_RAM_EFFICIENT_LOADING"),
        ("fsdp_use_orig_params", "FSDP_USE_ORIG_PARAMS"),
    )
    valued_options = (
        ("fsdp_state_dict_type", "FSDP_STATE_DICT_TYPE"),
        ("fsdp_auto_wrap_policy", "FSDP_AUTO_WRAP_POLICY"),
        ("fsdp_transformer_layer_cls_to_wrap", "FSDP_TRANSFORMER_CLS_TO_WRAP"),
    )

    fsdp_config = cfg.fsdp_config
    for option_name, env_name in boolean_options:
        if getattr(fsdp_config, option_name):
            os.environ[env_name] = "true"
    for option_name, env_name in valued_options:
        option_value = getattr(fsdp_config, option_name)
        if option_value:
            os.environ[env_name] = option_value
|
|
|
|
|
def prepare_optim_env(cfg):
    """
    Export accelerate environment variables for the distributed backend and
    mixed precision.

    FSDP settings take precedence over DeepSpeed. ``ACCELERATE_MIXED_PRECISION``
    becomes ``"bf16"`` when ``cfg.bf16`` is True, or is ``"auto"`` and a bf16
    capable GPU is available; otherwise it becomes ``"fp16"`` when ``cfg.fp16``
    is set. When neither applies it is left untouched.
    """
    if cfg.fsdp:
        setup_fsdp_envs(cfg)
    elif cfg.deepspeed:
        os.environ.update(
            {
                "ACCELERATE_USE_DEEPSPEED": "true",
                "ACCELERATE_DEEPSPEED_CONFIG_FILE": cfg.deepspeed,
            }
        )

    use_bf16 = cfg.bf16 is True or (
        cfg.bf16 == "auto" and is_torch_bf16_gpu_available()
    )
    if use_bf16:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
    elif cfg.fp16:
        os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
|
|
|
|
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
    """
    Build the trainer for ``cfg``.

    ``model`` is indexable: element 0 is the model being trained. For the RL
    methods listed below, elements 1 and 2 are attached to the builder as
    ``model_ref`` and ``peft_config``. All other runs use the causal-LM builder.

    Returns:
        Whatever the builder's ``build(total_num_steps)`` returns.
    """
    rl_methods = ("dpo", "ipo", "kto_pair", "orpo", "kto")

    if cfg.rl not in rl_methods:
        builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
    else:
        builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
        builder.model_ref = model[1]
        builder.peft_config = model[2]

    builder.train_dataset = train_dataset
    builder.eval_dataset = eval_dataset

    return builder.build(total_num_steps)
|
|