import functools
import json
import logging
import math
from pathlib import Path
from typing import Callable, Union

import safetensors.torch
import torch
import torch.distributed.fsdp.wrap as torch_wrap
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

from model.args import ModelArgs, MoeArgs
from model.transformer import Transformer, TransformerBlock

from .args import LoraArgs
from .checkpointing import Checkpointer
from .distributed import (
    get_rank,
    get_world_size,
)

logger = logging.getLogger(__name__)


def main_logger_info(message: str) -> None:
    if get_rank() == 0:
        logger.info(message)


def get_fsdp_policy(is_lora: bool) -> Callable[[torch.nn.Module], bool]:
    """
    This function instantiates the FSDP wrap policy.
    - Each Transformer block becomes its own FSDP group so that only a single Transformer block is sharded at a time.
    - If LoRA is enabled, we additionally create separate FSDP sub-groups for every trainable and non-trainable parameter group,
      since this is a requirement for mixed requires_grad=True/False training. See: https://pytorch.org/docs/stable/fsdp.html
    """
    # Each transformer block becomes an FSDP group, each being sharded separately
    transformer_block_wrap_policy = functools.partial(
        torch_wrap.transformer_auto_wrap_policy,
        transformer_layer_cls=(TransformerBlock,),
    )

    if not is_lora:
        return transformer_block_wrap_policy
    def fsdp_lora_policy_fn(module):
        return all(p.requires_grad for p in module.parameters())

    # For LoRA training, trainable and non-trainable parameters need to be put into
    # different FSDP groups
    fsdp_lora_policy = functools.partial(
        torch_wrap.lambda_auto_wrap_policy, lambda_fn=fsdp_lora_policy_fn
    )
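
    # _or_policy wraps a module as soon as any of the sub-policies matches, so
    # fully trainable (LoRA) sub-modules get their own FSDP units while every
    # remaining TransformerBlock is still wrapped as a whole.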
    policies = [fsdp_lora_policy, transformer_block_wrap_policy]

    return functools.partial(torch_wrap._or_policy, policies=policies)


def log_train_params(model: Union[torch.nn.Module, FullyShardedDataParallel]):
    world_size = get_world_size()
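
    # With FULL_SHARD, each rank only holds roughly 1/world_size of every
    # parameter of the wrapped model, so the local counts are multiplied by
    # world_size to approximate the global totals.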
    num_params = world_size * sum(p.numel() for p in model.parameters())
    num_train_params = world_size * sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )

    main_logger_info(
        f"{num_train_params:,.0f} out of {num_params:,.0f} parameters are finetuned ({num_train_params / num_params * 100:.2f}%)."
    )


def initialize_lora_parameters(model: torch.nn.Module, param_dtype: torch.dtype):
    """
    Initialize LoRA layers with Kaiming uniform and zeros.
    See original paper for more info: https://arxiv.org/abs/2106.09685 and
    original github repo: https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L122
    """
    for m_name, module in model.named_modules():
        if all(p.is_meta for p in module.parameters()):
            for p_name, param in module.named_parameters():
                module._parameters[p_name] = torch.nn.Parameter(
                    torch.empty_like(param, device="cpu", dtype=param_dtype)
                )
                param = module._parameters[p_name]

                if m_name.split(".")[-1] == "lora_A":
                    torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
                elif m_name.split(".")[-1] == "lora_B":
                    torch.nn.init.zeros_(param)
                else:
                    raise ValueError(
                        "Only LoRA layers should be randomly initialized."
                    )


def load_model(
    folder: Path,
    lora: LoraArgs,
    checkpoint: bool,
    param_dtype: torch.dtype,
) -> FullyShardedDataParallel:
    with open(folder / "params.json", "r") as f:
        args = json.loads(f.read())

    model_args = ModelArgs(
        lora=lora,
        dim=args["dim"],
        n_layers=args["n_layers"],
        head_dim=args["head_dim"],
        hidden_dim=args["hidden_dim"],
        n_heads=args["n_heads"],
        n_kv_heads=args["n_kv_heads"],
        norm_eps=args["norm_eps"],
        vocab_size=args["vocab_size"],
    )

    if model_args.vocab_size == 32000:
        raise ValueError(
            f"Fine-tuning is not supported for older model versions with vocab_size 32000. Make sure to extend your model to vocab_size=32768 using `python -m utils.extend_model_vocab --original_model_ckpt {folder} --extended_model_ckpt {folder}_extended`."
        )

    assert (
        model_args.vocab_size >= 32768
    ), "Make sure to use a model with a vocab size of at least 32768"

    if args.get("rope_theta") is not None:
        model_args.rope_theta = args["rope_theta"]

    if args.get("moe") is not None:
        model_args.moe = MoeArgs(**args["moe"])
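
    # Build the model on the meta device: no weight memory is allocated yet.
    # Rank 0 materializes the real weights from the checkpoint below; all other
    # ranks keep meta tensors until FSDP initializes and syncs them.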
    with torch.device("meta"):
        model = Transformer(args=model_args, checkpoint=checkpoint)

    if get_rank() == 0:
        state_dict = load_state_dict(folder, dtype=param_dtype)
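
        # assign=True re-uses the loaded tensors as the module's parameters
        # instead of copying into the existing (meta) ones, which would fail.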
        model.load_state_dict(state_dict, assign=True)  # type: ignore
        logger.info("Loaded model on cpu!")

        if lora.enable:
            logger.info("Initializing lora layers ...")
            # initialize LoRA layers
            initialize_lora_parameters(model, param_dtype)

        assert not any(
            p.is_meta for p in model.parameters()
        ), "All parameters should be initialized by now"
        assert all(
            p.dtype == param_dtype for p in model.parameters()
        ), f"All parameters should be on {param_dtype}"

        logger.info("Finished initialization!")
        param_init_fn = None
    else:
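        # Non-zero ranks never read the checkpoint: FSDP materializes their meta
        # tensors on the GPU via this param_init_fn, and sync_module_states below
        # broadcasts rank 0's weights into the local shards.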
        def param_init_fn(m):
            m.to_empty(device=torch.cuda.current_device(), recurse=False)
            m.to(param_dtype)

        assert all(
            p.is_meta for p in model.parameters()
        ), "All parameters should be on meta"

    torch.distributed.barrier()

    # only finetune LoRA parameters and freeze before wrapping
    if lora.enable:
        for name, param in model.named_parameters():
            if "lora" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

    auto_wrap_policy = get_fsdp_policy(model_args.lora.enable)

    main_logger_info(f"Sharding model over {get_world_size()} GPUs ...")
    wrapped_model = FullyShardedDataParallel(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
        param_init_fn=param_init_fn,
    )
    main_logger_info("Model sharded!")

    log_train_params(wrapped_model)

    return wrapped_model


@torch.no_grad()
def load_state_dict(path: Path, dtype: torch.dtype):
    assert path.is_dir(), path

    this_safetensors_path = Checkpointer.consolidated_path(path, use_safetensors=True)
    this_torch_path = Checkpointer.consolidated_path(path, use_safetensors=False)

    assert (
        this_safetensors_path.exists() or this_torch_path.exists()
    ), f"Either {this_safetensors_path} or {this_torch_path} must exist."
    assert not (
        this_safetensors_path.exists() and this_torch_path.exists()
    ), f"Only one of {this_safetensors_path} or {this_torch_path} should exist."

    if this_safetensors_path.exists():
        logger.info(f"Reloading model from {this_safetensors_path} ...")
        model_state_dict = safetensors.torch.load_file(this_safetensors_path)
    else:
        logger.info(f"Reloading model from {this_torch_path} ...")
        model_state_dict = torch.load(this_torch_path)

    logger.info(f"Converting model to dtype {dtype} ...")
    for k, v in model_state_dict.items():
        model_state_dict[k] = v.to(dtype)

    return model_state_dict
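

# Example usage (a minimal sketch; assumes a typical `torchrun` launch with the
# process group already initialized, and a folder containing params.json plus a
# consolidated checkpoint):
#
#   model = load_model(
#       folder=Path("/path/to/model"),   # hypothetical checkpoint folder
#       lora=LoraArgs(enable=True),      # assumed LoraArgs constructor signature
#       checkpoint=True,
#       param_dtype=torch.bfloat16,
#   )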