# ReactSeq/onmt/models/model.py
""" Onmt NMT Model base class definition """
import torch
import torch.nn as nn
import glob


class BaseModel(nn.Module):
    """Core trainable object in OpenNMT. Implements a trainable interface
    for a simple, generic encoder / decoder or decoder only model.

    Args:
        encoder (onmt.encoders.EncoderBase): an encoder object
        decoder (onmt.decoders.DecoderBase): a decoder object"""

    def __init__(self, encoder, decoder):
        super(BaseModel, self).__init__()

    def forward(self, src, tgt, src_len, bptt=False, with_align=False):
        """Forward propagate a `src` and `tgt` pair for training.

        Args:
            src (Tensor): A source sequence passed to encoder.
                Typically for input this will be a padded `LongTensor`
                of size ``(batch, len, features)``. However, may be an
                image or other generic input depending on encoder.
            tgt (LongTensor): A target sequence passed to decoder.
                Size ``(batch, tgt_len, features)``.
            src_len (LongTensor): The src lengths, pre-padding ``(batch,)``.
            bptt (Boolean): A flag indicating if truncated bptt is set.
                If bptt is false then init decoder state.
            with_align (Boolean): A flag indicating whether to output attention
                alignments. Only valid for the transformer decoder.

        Returns:
            (FloatTensor, dict[str, FloatTensor]):

            * decoder output ``(batch, tgt_len, hidden)``
            * dictionary of attention weights ``(batch, tgt_len, src_len)``"""
        raise NotImplementedError

    def update_dropout(self, dropout, attention_dropout):
        raise NotImplementedError

    def count_parameters(self, log=print):
        raise NotImplementedError
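
    # A minimal contract sketch for the abstract methods above (illustrative
    # only, not executed): any concrete subclass is expected to behave like
    #
    #     dec_out, attns = model(src, tgt, src_len)
    #     # dec_out: (batch, tgt_len, hidden)
    #     # attns:   dict of attention tensors, each (batch, tgt_len, src_len)
    #
    # `model`, `src`, `tgt` and `src_len` are hypothetical placeholders here.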

    def load_state_dict(
        self,
        checkpoint,
        precision=torch.float32,
        device=torch.device("cpu"),
        strict=True,
        offset=0,
    ):
        """Custom state_dict loading to enable moving modules to the target
        device as they are loaded.

        Args:
            checkpoint: Pytorch serialized checkpoint
            precision: precision to move each module to
            device: device to move each module to
            strict: if True checks model keys wrt state_dict (both ways)
        """
        # bitsandbytes quantizes weights when .cuda() is called;
        # for huge models we need to save RAM, so we load the weights
        # module by module and transfer them to GPU for quantization
        if device == torch.device("cpu"):
            offset = 0
        buf_list = []
        for name, module in self.named_modules():
            for buf_name, buf in module.named_buffers():
                buf_list.append(buf_name)
                if len(buf_name.split(".")) == 1:  # only last key
                    if precision == torch.int8:
                        torch.quantization.quantize_dynamic(module, inplace=True)
                    else:
                        module.to(precision)
                    module.to(device)
            for param_name, param in module.named_parameters():
                if len(param_name.split(".")) == 1:  # only last key
                    if name + "." + param_name in checkpoint["model"].keys():
                        ckpt_t = checkpoint["model"][name + "." + param_name]
                        # with a non-zero offset (partitioned model), keep only
                        # this rank's slice of the checkpoint tensor; the sliced
                        # dimension depends on the layer name
                        if name.split(".")[-1] in [
                            "linear_keys",
                            "linear_values",
                            "linear_query",
                            "w_1",
                            "w_3",
                        ]:
                            col_slice_start = param.data.size(0) * offset
                            col_slice_end = param.data.size(0) * (offset + 1)
                        else:
                            col_slice_start = 0
                            col_slice_end = param.data.size(0)
                        if param.data.dim() == 2:
                            if name.split(".")[-1] in ["final_linear", "w_2"]:
                                row_slice_start = param.data.size(1) * offset
                                row_slice_end = param.data.size(1) * (offset + 1)
                            else:
                                row_slice_start = 0
                                row_slice_end = param.data.size(1)
                            assert (
                                param.data.size()
                                == ckpt_t[
                                    col_slice_start:col_slice_end,
                                    row_slice_start:row_slice_end,
                                ].size()
                            ), "An error in model's partition and checkpoint's slice was detected"
                            param.data = ckpt_t[
                                col_slice_start:col_slice_end,
                                row_slice_start:row_slice_end,
                            ]
                        else:
                            assert (
                                param.data.size()
                                == ckpt_t[col_slice_start:col_slice_end].size()
                            ), "An error in model's partition and checkpoint's slice was detected"
                            param.data = ckpt_t[col_slice_start:col_slice_end]
                        del checkpoint["model"][name + "." + param_name]
                    elif (
                        "generator" in checkpoint.keys()
                        and name == "generator"
                        and checkpoint["generator"] is not None
                        and param_name in checkpoint["generator"].keys()
                    ):
                        param.data = checkpoint["generator"][param_name]
                        del checkpoint["generator"][param_name]
                    elif strict and "lora" not in param_name:
                        raise ValueError(
                            "Missing key in checkpoint: %s" % (name + "." + param_name)
                        )
                    if precision == torch.int8:
                        torch.quantization.quantize_dynamic(module, inplace=True)
                    else:
                        module.to(precision)
                    module.to(device)
        # if some keys are left in the checkpoint after deletion, they are extras
        for key in checkpoint["model"].keys():
            if key not in buf_list:
                raise ValueError(
                    "Extra keys in model state_dict do not match the model config %s"
                    % checkpoint["model"].keys()
                )
        if checkpoint["generator"]:
            for key in checkpoint["generator"].keys():
                if key not in buf_list:
                    raise ValueError(
                        "Extra keys in generator state_dict do not match the model config %s"
                        % checkpoint["generator"].keys()
                    )
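
    # Usage sketch (illustrative only, not executed): loading a regular torch
    # checkpoint module by module. The path and the `model` instance are
    # hypothetical; the checkpoint is assumed to carry the "model" and
    # "generator" state_dicts handled above.
    #
    #     checkpoint = torch.load("model_step_10000.pt", map_location="cpu")
    #     model.load_state_dict(
    #         checkpoint,
    #         precision=torch.float16,         # each module is cast before the move
    #         device=torch.device("cuda", 0),  # modules are moved one by one to save RAM
    #         strict=True,
    #         offset=0,                        # non-zero only for a partitioned model
    #     )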

    def load_safe_state_dict(
        self,
        model_path,
        precision=torch.float32,
        device=torch.device("cpu"),
        strict=True,
        offset=0,
    ):
        """Custom state_dict loading to enable moving modules to the target
        device as they are loaded.

        Args:
            model_path: Model path
            precision: same as above
            device: same as above
            strict: same as above
        """
        # bitsandbytes quantizes weights when .cuda() is called;
        # for huge models we need to save RAM, so we load the weights
        # module by module and transfer them to GPU for quantization
        try:
            import safetensors
        except ImportError:
            raise ImportError("Please install safetensors: pip install safetensors")
        keyfound = {}
        shards = glob.glob(model_path + ".*.safetensors")
        if len(shards) == 0:
            raise ValueError("No safetensors file found")
        f = []
        keys_shard = {}
        # map each checkpoint key to the shard that stores it
        for i, shard in enumerate(shards):
            f.append(safetensors.safe_open(shard, framework="pt", device="cpu"))
            for key in f[i].keys():
                keys_shard[key] = i
        buf_list = []
        for name, module in self.named_modules():
            for buf_name, buf in module.named_buffers():
                buf_list.append(buf_name)
                if len(buf_name.split(".")) == 1:  # only last key
                    if precision == torch.int8:
                        torch.quantization.quantize_dynamic(module, inplace=True)
                    else:
                        module.to(precision)
                    module.to(device)
            for param_name, param in module.named_parameters():
                if len(param_name.split(".")) == 1:  # only last key
                    if name + "." + param_name in keys_shard.keys():
                        ckpt_t = f[keys_shard[name + "." + param_name]].get_tensor(
                            name + "." + param_name
                        )
                        # same slicing rule as in load_state_dict: with a non-zero
                        # offset, keep only this rank's slice of the tensor
                        if name.split(".")[-1] in [
                            "linear_keys",
                            "linear_values",
                            "linear_query",
                            "w_1",
                            "w_3",
                        ]:
                            col_slice_start = param.data.size(0) * offset
                            col_slice_end = param.data.size(0) * (offset + 1)
                        else:
                            col_slice_start = 0
                            col_slice_end = param.data.size(0)
                        if param.data.dim() == 2:
                            if name.split(".")[-1] in ["final_linear", "w_2"]:
                                row_slice_start = param.data.size(1) * offset
                                row_slice_end = param.data.size(1) * (offset + 1)
                            else:
                                row_slice_start = 0
                                row_slice_end = param.data.size(1)
                            assert (
                                param.data.size()
                                == ckpt_t[
                                    col_slice_start:col_slice_end,
                                    row_slice_start:row_slice_end,
                                ].size()
                            ), "An error in model's partition and checkpoint's slice was detected"
                            param.data = ckpt_t[
                                col_slice_start:col_slice_end,
                                row_slice_start:row_slice_end,
                            ]
                        else:
                            assert (
                                param.data.size()
                                == ckpt_t[col_slice_start:col_slice_end].size()
                            ), "An error in model's partition and checkpoint's slice was detected"
                            param.data = ckpt_t[col_slice_start:col_slice_end]
                        keyfound[name + "." + param_name] = True
                    elif strict and "lora" not in param_name:
                        raise ValueError(
                            "Missing key in safetensors checkpoint: %s"
                            % (name + "." + param_name)
                        )
                    if precision == torch.int8:
                        torch.quantization.quantize_dynamic(module, inplace=True)
                    else:
                        module.to(precision)
                    module.to(device)
        for key in keys_shard.keys():
            if key not in keyfound.keys() and key not in buf_list:
                raise ValueError(
                    "Extra keys in model state_dict do not match the model config %s"
                    % key
                )
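
    # Usage sketch (illustrative only, not executed): loading sharded
    # safetensors saved on disk as "<model_path>.<N>.safetensors". The path
    # prefix and the `model` instance are hypothetical.
    #
    #     model.load_safe_state_dict(
    #         "step_10000/big_model",          # matches big_model.*.safetensors
    #         precision=torch.float16,         # torch.int8 would route through
    #                                          # torch.quantization.quantize_dynamic
    #         device=torch.device("cuda", 0),
    #         strict=True,
    #     )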


class NMTModel(BaseModel):
    """NMTModel Class
    See :class:`~onmt.models.BaseModel` for options."""

    def __init__(self, encoder, decoder):
        super(NMTModel, self).__init__(encoder, decoder)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, src_len, bptt=False, with_align=False):
        """An NMTModel forwards the src side to the encoder.
        Then the encoder output ``enc_out`` is forwarded to the
        decoder along with the target, excluding the last token.

        The decoder state is initialized with:
        * enc_final_hs in the case of RNNs
        * enc_out + enc_final_hs in the case of CNNs
        * src in the case of Transformer"""
        dec_in = tgt[:, :-1, :]
        enc_out, enc_final_hs, src_len = self.encoder(src, src_len)
        if not bptt:
            self.decoder.init_state(src, enc_out, enc_final_hs)
        dec_out, attns = self.decoder(
            dec_in, enc_out, src_len=src_len, with_align=with_align
        )
        return dec_out, attns
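
    # Shape sketch for the forward pass above (illustrative values, not
    # executed), assuming batch=2, src_len=7, tgt_len=5, features=1 and a
    # hidden size of 512:
    #
    #     src:     (2, 7, 1)   LongTensor fed to the encoder
    #     tgt:     (2, 5, 1)   -> dec_in = tgt[:, :-1, :] has shape (2, 4, 1)
    #     dec_out: (2, 4, 512)
    #     attns:   dict of attention tensors, each (2, 4, 7) over source positions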

    def update_dropout(self, dropout, attention_dropout):
        self.encoder.update_dropout(dropout, attention_dropout)
        self.decoder.update_dropout(dropout, attention_dropout)

    def count_parameters(self, log=print):
        """Count number of parameters in model (& print with `log` callback).

        Returns:
            (int, int):

            * encoder side parameter count
            * decoder side parameter count"""
        enc, dec = 0, 0
        for name, param in self.named_parameters():
            if "encoder" in name:
                enc += param.nelement()
            else:
                dec += param.nelement()
        if callable(log):
            log("encoder: {}".format(enc))
            log("decoder: {}".format(dec))
            log("* number of parameters: {}".format(enc + dec))
        return enc, dec


class LanguageModel(BaseModel):
    """LanguageModel Class
    Currently TransformerLMDecoder is the only LM decoder implemented

    Args:
        decoder (onmt.decoders.TransformerLMDecoder): a transformer decoder"""

    def __init__(self, encoder=None, decoder=None):
        super(LanguageModel, self).__init__(encoder, decoder)
        if encoder is not None:
            raise ValueError("LanguageModel should not be used with an encoder")
        self.decoder = decoder

    def forward(self, src, tgt, src_len, bptt=False, with_align=False):
        """A LanguageModel forwards the src side to the decoder along
        with the source lengths vector. It is a decoder-only LM (cf. GPT-2)."""
        if not bptt:
            self.decoder.init_state()
        dec_out, attns = self.decoder(
            src, enc_out=None, src_len=src_len, with_align=with_align
        )
        return dec_out, attns
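
    # Usage sketch (illustrative only, not executed): decoder-only scoring of a
    # batch of token ids. `lm`, `src` and `src_len` are hypothetical; `tgt` is
    # accepted for interface compatibility but is not used in this forward.
    #
    #     dec_out, attns = lm(src, tgt, src_len)
    #     # dec_out: (batch, len, hidden), fed to the generator for next-token logits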

    def update_dropout(self, dropout, attention_dropout):
        self.decoder.update_dropout(dropout, attention_dropout)

    def count_parameters(self, log=print):
        """Count number of parameters in model (& print with `log` callback).

        Returns:
            (int, int):

            * encoder side parameter count
            * decoder side parameter count"""
        enc, dec = 0, 0
        for name, param in self.named_parameters():
            if "decoder" in name:
                dec += param.nelement()
        if callable(log):
            # No encoder in LM, seq2seq count formatting kept
            log("encoder: {}".format(enc))
            log("decoder: {}".format(dec))
            log("* number of parameters: {}".format(enc + dec))
        return enc, dec
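
# Usage sketch (illustrative only, not executed): reporting parameter counts
# through a logging callback instead of print. `model` and `logger` are
# hypothetical.
#
#     enc, dec = model.count_parameters(log=logger.info)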