|
""" |
|
This file is for models creation, which consults options |
|
and creates each encoder and decoder accordingly. |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn.init import xavier_uniform_, zeros_, uniform_ |
|
import onmt.modules |
|
from onmt.encoders import str2enc |
|
from onmt.decoders import str2dec |
|
from onmt.inputters.inputter import dict_to_vocabs |
|
from onmt.modules import Embeddings, CopyGenerator |
|
from onmt.utils.misc import use_gpu |
|
from onmt.utils.logging import logger |
|
from onmt.utils.parse import ArgumentParser |
|
from onmt.models.model_saver import load_checkpoint |
|
from onmt.constants import DefaultTokens, ModelTask |
|
from onmt.modules.lora import ( |
|
replace_lora_linear, |
|
replace_lora_embedding, |
|
mark_only_lora_as_trainable, |
|
) |
|
|
|
|
|
def build_embeddings(opt, vocabs, for_encoder=True): |
|
""" |
|
Args: |
|
opt: the option in current environment. |
|
vocab. |
|
for_encoder(bool): build Embeddings for encoder or decoder? |
|
""" |
|
feat_pad_indices = [] |
|
num_feat_embeddings = [] |
|
if for_encoder: |
|
emb_dim = opt.src_word_vec_size |
|
word_padding_idx = vocabs["src"][DefaultTokens.PAD] |
|
num_word_embeddings = len(vocabs["src"]) |
|
if "src_feats" in vocabs: |
|
feat_pad_indices = [fv[DefaultTokens.PAD] for fv in vocabs["src_feats"]] |
|
num_feat_embeddings = [len(fv) for fv in vocabs["src_feats"]] |
|
freeze_word_vecs = opt.freeze_word_vecs_enc |
|
else: |
|
emb_dim = opt.tgt_word_vec_size |
|
word_padding_idx = vocabs["tgt"][DefaultTokens.PAD] |
|
num_word_embeddings = len(vocabs["tgt"]) |
|
freeze_word_vecs = opt.freeze_word_vecs_dec |
|
|
|
emb = Embeddings( |
|
word_vec_size=emb_dim, |
|
position_encoding=opt.position_encoding, |
|
position_encoding_type=opt.position_encoding_type, |
|
feat_merge=opt.feat_merge, |
|
feat_vec_exponent=opt.feat_vec_exponent, |
|
feat_vec_size=opt.feat_vec_size, |
|
dropout=opt.dropout[0] if type(opt.dropout) is list else opt.dropout, |
|
word_padding_idx=word_padding_idx, |
|
feat_padding_idx=feat_pad_indices, |
|
word_vocab_size=num_word_embeddings, |
|
feat_vocab_sizes=num_feat_embeddings, |
|
sparse=opt.optim == "sparseadam", |
|
freeze_word_vecs=freeze_word_vecs, |
|
) |
|
return emb |
|
|
|
|
|
def build_encoder(opt, embeddings): |
|
""" |
|
Various encoder dispatcher function. |
|
Args: |
|
opt: the option in current environment. |
|
embeddings (Embeddings): vocab embeddings for this encoder. |
|
""" |
|
enc_type = opt.encoder_type if opt.model_type == "text" else opt.model_type |
|
return str2enc[enc_type].from_opt(opt, embeddings) |
|
|
|
|
|
def build_decoder(opt, embeddings): |
|
""" |
|
Various decoder dispatcher function. |
|
Args: |
|
opt: the option in current environment. |
|
embeddings (Embeddings): vocab embeddings for this decoder. |
|
""" |
|
dec_type = ( |
|
"ifrnn" if opt.decoder_type == "rnn" and opt.input_feed else opt.decoder_type |
|
) |
|
return str2dec[dec_type].from_opt(opt, embeddings) |
|
|
|
|
|
def load_test_model(opt, device_id=0, model_path=None): |
|
if model_path is None: |
|
model_path = opt.models[0] |
|
checkpoint = load_checkpoint(model_path) |
|
|
|
model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) |
|
|
|
model_opt.quant_layers = opt.quant_layers |
|
model_opt.quant_type = opt.quant_type |
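    # With tensor parallelism each rank loads only its shard of the weights;
    # "offset" identifies that shard when the state dict is loaded below.
    # Otherwise everything is loaded onto a single device (GPU if requested, else CPU).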
|
|
|
if opt.world_size > 1 and opt.parallel_mode == "tensor_parallel": |
|
model_opt.world_size = opt.world_size |
|
model_opt.parallel_mode = opt.parallel_mode |
|
model_opt.gpu_ranks = opt.gpu_ranks |
|
device = torch.device("cuda", device_id) |
|
offset = device_id |
|
else: |
|
if use_gpu(opt): |
|
if len(opt.gpu_ranks) > 0: |
|
device_id = opt.gpu_ranks[0] |
|
elif opt.gpu > -1: |
|
device_id = opt.gpu |
|
device = torch.device("cuda", device_id) |
|
else: |
|
device = torch.device("cpu") |
|
offset = 0 |
|
|
|
ArgumentParser.update_model_opts(model_opt) |
|
ArgumentParser.validate_model_opts(model_opt) |
|
vocabs = dict_to_vocabs(checkpoint["vocab"]) |
|
|
|
|
|
model_opt.update_vocab = False |
|
|
|
model = build_base_model(model_opt, vocabs) |
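    # Inference precision: fp32 by default, fp16 on request, or dynamic int8
    # quantization (CPU only).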
|
|
|
precision = torch.float32 |
|
|
|
if opt.precision == "fp16": |
|
precision = torch.float16 |
|
elif opt.precision == "int8": |
|
if opt.gpu >= 0: |
|
raise ValueError("Dynamic 8-bit quantization is not supported on GPU") |
|
else: |
|
precision = torch.int8 |
|
|
|
logger.info("Loading data into the model") |
|
|
|
if "model" in checkpoint.keys(): |
|
|
|
model.load_state_dict( |
|
checkpoint, |
|
precision=precision, |
|
device=device, |
|
strict=True, |
|
offset=offset, |
|
) |
|
else: |
|
|
|
base_name = model_path[:-3] if model_path[-3:] == ".pt" else model_path |
|
model.load_safe_state_dict( |
|
base_name, |
|
precision=precision, |
|
device=device, |
|
strict=True, |
|
offset=offset, |
|
) |
|
del checkpoint |
|
|
|
model.eval() |
|
model.generator.eval() |
|
return vocabs, model, model_opt |
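# A minimal usage sketch for load_test_model (the checkpoint path is
# hypothetical; in practice `opt` comes from the translation option parser):
#
#     opt.models = ["model_step_1000.pt"]
#     vocabs, model, model_opt = load_test_model(opt)
#     # the returned model is already in eval() mode on the requested device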
|
|
|
|
|
def build_src_emb(model_opt, vocabs): |
|
|
|
if model_opt.model_type == "text": |
|
src_emb = build_embeddings(model_opt, vocabs) |
|
else: |
|
src_emb = None |
|
return src_emb |
|
|
|
|
|
def build_encoder_with_embeddings(model_opt, vocabs): |
|
|
|
src_emb = build_src_emb(model_opt, vocabs) |
|
encoder = build_encoder(model_opt, src_emb) |
|
return encoder, src_emb |
|
|
|
|
|
def build_decoder_with_embeddings( |
|
model_opt, vocabs, share_embeddings=False, src_emb=None |
|
): |
|
|
|
tgt_emb = build_embeddings(model_opt, vocabs, for_encoder=False) |
|
|
|
if share_embeddings: |
|
tgt_emb.word_lut.weight = src_emb.word_lut.weight |
|
|
|
|
|
decoder = build_decoder(model_opt, tgt_emb) |
|
return decoder, tgt_emb |
|
|
|
|
|
def build_task_specific_model(model_opt, vocabs): |
|
|
|
if model_opt.share_embeddings: |
|
|
|
assert ( |
|
vocabs["src"] == vocabs["tgt"] |
|
), "preprocess with -share_vocab if you use share_embeddings" |
|
|
|
if model_opt.model_task == ModelTask.SEQ2SEQ: |
|
encoder, src_emb = build_encoder_with_embeddings(model_opt, vocabs) |
|
decoder, _ = build_decoder_with_embeddings( |
|
model_opt, |
|
vocabs, |
|
share_embeddings=model_opt.share_embeddings, |
|
src_emb=src_emb, |
|
) |
|
return onmt.models.NMTModel(encoder=encoder, decoder=decoder) |
|
elif model_opt.model_task == ModelTask.LANGUAGE_MODEL: |
|
src_emb = build_src_emb(model_opt, vocabs) |
|
decoder, _ = build_decoder_with_embeddings( |
|
model_opt, vocabs, share_embeddings=True, src_emb=src_emb |
|
) |
|
return onmt.models.LanguageModel(decoder=decoder) |
|
else: |
|
raise ValueError(f"No model defined for {model_opt.model_task} task") |
|
|
|
|
|
def use_embeddings_from_checkpoint(vocabs, model, checkpoint): |
|
|
|
logger.info("Updating vocabulary embeddings with checkpoint embeddings") |
|
|
|
enc_emb_name = "encoder.embeddings.make_embedding.emb_luts.0.weight" |
|
dec_emb_name = "decoder.embeddings.make_embedding.emb_luts.0.weight" |
|
model_dict = {k: v for k, v in model.state_dict().items() if "generator" not in k} |
|
generator_dict = model.generator.state_dict() |
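    # Copy embedding rows (and, on the target side, generator rows) for tokens
    # that already existed in the checkpoint vocabulary; tokens new to the
    # updated vocabulary keep their freshly initialized values.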
|
for side, emb_name in [("src", enc_emb_name), ("tgt", dec_emb_name)]: |
|
if emb_name not in checkpoint["model"]: |
|
continue |
|
new_tokens = [] |
|
ckp_vocabs = dict_to_vocabs(checkpoint["vocab"]) |
|
for i, tok in enumerate(vocabs[side].ids_to_tokens): |
|
if tok in ckp_vocabs[side]: |
|
old_i = ckp_vocabs[side].lookup_token(tok) |
|
model_dict[emb_name][i] = checkpoint["model"][emb_name][old_i] |
|
if side == "tgt": |
|
generator_dict["weight"][i] = checkpoint["generator"]["weight"][ |
|
old_i |
|
] |
|
generator_dict["bias"][i] = checkpoint["generator"]["bias"][old_i] |
|
else: |
|
|
|
new_tokens.append(tok) |
|
logger.info("%s: %d new tokens" % (side, len(new_tokens))) |
|
|
|
|
|
del checkpoint["model"][emb_name] |
|
del checkpoint["generator"]["weight"], checkpoint["generator"]["bias"] |
|
fake_ckpt = {"model": model_dict, "generator": generator_dict} |
|
model.load_state_dict(fake_ckpt) |
|
|
|
|
|
def build_base_model(model_opt, vocabs): |
|
"""Build a model from opts. |
|
|
|
Args: |
|
model_opt: the option loaded from checkpoint. It's important that |
|
the opts have been updated and validated. See |
|
:class:`onmt.utils.parse.ArgumentParser`. |
|
vocabs (dict[str, Vocab]): |
|
`Field` objects for the model. |
|
|
|
Returns: |
|
the NMTModel. |
|
""" |
|
|
|
|
|
try: |
|
model_opt.attention_dropout |
|
except AttributeError: |
|
model_opt.attention_dropout = model_opt.dropout |
|
|
|
|
|
model = build_task_specific_model(model_opt, vocabs) |
|
|
|
nonlora_to_quant = [ |
|
layer |
|
for layer in getattr(model_opt, "quant_layers", []) |
|
if layer not in getattr(model_opt, "lora_layers", []) |
|
] |
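    # Quantize the non-LoRa layers with bitsandbytes here; LoRa-wrapped layers
    # are quantized later, when the LoRa modules are injected.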
|
|
|
if hasattr(model_opt, "quant_layers") and len(nonlora_to_quant) > 0: |
|
if model_opt.quant_type in ["bnb_8bit", "bnb_FP4", "bnb_NF4"]: |
|
logger.info( |
|
"%s compression of layer %s" % (model_opt.quant_type, nonlora_to_quant) |
|
) |
|
try: |
|
from onmt.modules.bnb_linear import replace_bnb_linear |
|
except ImportError: |
|
raise ImportError("Install bitsandbytes to use 4/8bit compression") |
|
model = replace_bnb_linear( |
|
model, module_to_convert=nonlora_to_quant, q_type=model_opt.quant_type |
|
) |
|
else: |
|
logger.info("compression type %s not supported." % model_opt.quant_type) |
|
|
|
mark_lora = False |
|
if hasattr(model_opt, "lora_layers") and len(model_opt.lora_layers) > 0: |
|
if model_opt.freeze_encoder or model_opt.freeze_decoder: |
|
raise ValueError("Cannot use LoRa with Enc/Dec-oder freezing") |
|
for layer in model_opt.lora_layers: |
|
if hasattr(model_opt, "quant_layers") and layer in model_opt.quant_layers: |
|
quant_type = model_opt.quant_type |
|
else: |
|
quant_type = None |
|
logger.info("Adding LoRa layers for %s quant %s" % (layer, quant_type)) |
|
model = replace_lora_linear( |
|
model, |
|
r=model_opt.lora_rank, |
|
lora_alpha=model_opt.lora_alpha, |
|
lora_dropout=model_opt.lora_dropout, |
|
layer=layer, |
|
quant_type=quant_type, |
|
use_ckpting=model_opt.use_ckpting, |
|
) |
|
mark_lora = True |
|
if hasattr(model_opt, "lora_embedding") and model_opt.lora_embedding: |
|
if model_opt.freeze_encoder or model_opt.freeze_decoder: |
|
raise ValueError("Cannot use LoRa with Enc/Dec-oder freezing") |
|
logger.info("Adding LoRa Embeddings") |
|
model = replace_lora_embedding( |
|
model, r=model_opt.lora_rank, lora_alpha=model_opt.lora_alpha |
|
) |
|
mark_lora = True |
|
|
|
if mark_lora: |
|
mark_only_lora_as_trainable(model, bias="lora_only") |
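    # Build the generator that maps decoder states to target-vocabulary logits;
    # with copy attention a CopyGenerator replaces the plain Linear projection.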
|
|
|
|
|
if not model_opt.copy_attn: |
|
generator = nn.Linear(model_opt.dec_hid_size, len(vocabs["tgt"])) |
|
if model_opt.share_decoder_embeddings: |
|
generator.weight = model.decoder.embeddings.word_lut.weight |
|
else: |
|
vocab_size = len(vocabs["tgt"]) |
|
pad_idx = vocabs["tgt"][DefaultTokens.PAD] |
|
generator = CopyGenerator(model_opt.dec_hid_size, vocab_size, pad_idx) |
|
if model_opt.share_decoder_embeddings: |
|
generator.linear.weight = model.decoder.embeddings.word_lut.weight |
|
|
|
model.generator = generator |
|
|
|
return model |
|
|
|
|
|
def build_model(model_opt, opt, vocabs, checkpoint, device_id): |
|
logger.info("Building model...") |
|
|
|
model = build_base_model(model_opt, vocabs) |
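    # Parameters are initialized here when training from scratch or when the
    # vocabulary is updated; otherwise the checkpoint loaded below provides them.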
|
|
|
|
|
|
|
if checkpoint is None or model_opt.update_vocab: |
|
if model_opt.param_init != 0.0: |
|
for param in model.parameters(): |
|
uniform_(param, -model_opt.param_init, model_opt.param_init) |
|
elif model_opt.param_init_glorot: |
|
for name, module in model.named_modules(): |
|
for param_name, param in module.named_parameters(): |
|
if param_name == "weight" and param.dim() > 1: |
|
xavier_uniform_(param) |
|
elif param_name == "bias": |
|
zeros_(param) |
|
else: |
|
raise ValueError("You need either param_init != 0 OR init_glorot True") |
|
|
|
if hasattr(model, "encoder") and hasattr(model.encoder, "embeddings"): |
|
model.encoder.embeddings.load_pretrained_vectors( |
|
model_opt.pre_word_vecs_enc |
|
) |
|
if hasattr(model.decoder, "embeddings"): |
|
model.decoder.embeddings.load_pretrained_vectors( |
|
model_opt.pre_word_vecs_dec |
|
) |
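    # Only the legacy FusedAdam path (fp16 training without an apex opt level)
    # keeps true fp16 weights; otherwise the model stays in fp32 and amp/apex
    # handles mixed precision.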
|
|
|
|
|
if ( |
|
model_opt.model_dtype == "fp16" |
|
and model_opt.apex_opt_level not in ["O0", "O1", "O2", "O3"] |
|
and model_opt.optim == "fusedadam" |
|
): |
|
precision = torch.float16 |
|
logger.info("Switching model to half() for FusedAdam legacy") |
|
logger.info("Non quantized layer compute is %s", model_opt.model_dtype) |
|
else: |
|
precision = torch.float32 |
|
logger.info("Switching model to float32 for amp/apex_amp") |
|
logger.info("Non quantized layer compute is %s", model_opt.model_dtype) |
|
|
|
if opt.world_size > 1 and opt.parallel_mode == "tensor_parallel": |
|
device = torch.device("cuda") |
|
offset = device_id |
|
else: |
|
if use_gpu(opt): |
|
device = torch.device("cuda") |
|
else: |
|
device = torch.device("cpu") |
|
offset = 0 |
|
|
|
if checkpoint is not None: |
|
if model_opt.update_vocab: |
|
if "model" in checkpoint.keys(): |
|
|
|
|
|
use_embeddings_from_checkpoint(vocabs, model, checkpoint) |
|
|
|
else: |
|
raise ValueError( |
|
"Update Vocab is not compatible with safetensors mode (yet" |
|
) |
|
|
|
|
|
|
|
strict = not model_opt.update_vocab |
|
|
|
if "model" in checkpoint.keys(): |
|
|
|
model.load_state_dict( |
|
checkpoint, |
|
precision=precision, |
|
device=device, |
|
strict=strict, |
|
offset=offset, |
|
) |
|
else: |
|
|
|
model_path = ( |
|
opt.train_from[:-3] if opt.train_from[-3:] == ".pt" else opt.train_from |
|
) |
|
model.load_safe_state_dict( |
|
model_path, |
|
precision=precision, |
|
device=device, |
|
strict=strict, |
|
offset=offset, |
|
) |
|
else: |
|
model.to(precision) |
|
model.to(device) |
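    # Optionally freeze the encoder/decoder while keeping their embeddings trainable.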
|
|
|
if model_opt.freeze_encoder: |
|
model.encoder.requires_grad_(False) |
|
model.encoder.embeddings.requires_grad_() |
|
|
|
if model_opt.freeze_decoder: |
|
model.decoder.requires_grad_(False) |
|
model.decoder.embeddings.requires_grad_() |
|
|
|
logger.info(model) |
|
return model |
|
|