sarinam's picture
Added IMSToucan as normal directory instead of submodule
ef12a74
raw
history blame
9.85 kB
"""
Taken from ESPNet, modified by Florian Lux
"""
import os
from abc import ABC
import torch
def cumsum_durations(durations):
    """
    Accumulate durations into boundary positions and segment midpoints.

    Args:
        durations (Iterable): Durations of consecutive segments.

    Returns:
        tuple[list, list]: The running boundaries, starting with 0 and
            holding one more entry than there are durations, and the
            midpoint between each pair of neighbouring boundaries.
    """
    boundaries = [0]
    for duration in durations:
        boundaries.append(duration + boundaries[-1])
    # Each segment spans two neighbouring boundaries; its center is their mean.
    midpoints = [(start + end) / 2 for start, end in zip(boundaries, boundaries[1:])]
    return boundaries, midpoints
def delete_old_checkpoints(checkpoint_dir, keep=5):
    """
    Remove all but the newest numbered checkpoints from a directory.

    Only files named ``checkpoint_<step>.pt`` with a numeric step are
    considered. ``best.pt`` and every other file are left untouched.

    Args:
        checkpoint_dir (str): Directory that holds the checkpoint files.
        keep (int): Number of checkpoints with the highest step to retain.
            Values of 0 or below remove every numbered checkpoint.
    """
    prefix, suffix = "checkpoint_", ".pt"
    found = list()
    for name in os.listdir(checkpoint_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        step = name[len(prefix):-len(suffix)]
        # Skip files whose step is not a plain number instead of crashing on them.
        if step.isdecimal():
            # Remember the real file name so zero-padded steps are removed correctly.
            found.append((int(step), name))
    # Count from the front rather than slicing with -keep, because [:-0] is empty.
    surplus = len(found) - max(keep, 0)
    if surplus <= 0:
        return
    found.sort()
    for _, name in found[:surplus]:
        os.remove(os.path.join(checkpoint_dir, name))
def get_most_recent_checkpoint(checkpoint_dir, verbose=True):
    """
    Find the numbered checkpoint with the highest step in a directory.

    Only files named ``checkpoint_<step>.pt`` with a numeric step are
    considered. ``best.pt`` and every other file are ignored.

    Args:
        checkpoint_dir (str): Directory that holds the checkpoint files.
        verbose (bool): Whether to print which checkpoint was selected.

    Returns:
        str or None: Path to the checkpoint with the highest step, or None
            if the directory contains no numbered checkpoint.
    """
    prefix, suffix = "checkpoint_", ".pt"
    found = list()
    for name in os.listdir(checkpoint_dir):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        step = name[len(prefix):-len(suffix)]
        # Skip files whose step is not a plain number instead of crashing on them.
        if step.isdecimal():
            # Remember the real file name so the returned path always exists.
            found.append((int(step), name))
    if not found:
        print("No previous checkpoints found, cannot reload.")
        return None
    # Compare steps numerically so that step 10 ranks above step 2.
    _, newest = max(found)
    if verbose:
        print("Reloading {}".format(newest))
    return os.path.join(checkpoint_dir, newest)
def make_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Build a mask that marks the padded positions of a batch.

    Args:
        lengths (LongTensor or List): Valid length of each batch entry (B,).
        xs (Tensor, optional): Reference tensor. When given, the mask is
            expanded to its shape and moved to its device.
        length_dim (int, optional): Dimension of xs that runs along the
            sequence. Must not be 0.
        device (torch.device, optional): Device on which the index range
            is created.

    Returns:
        Tensor: Mask that is set wherever the position index is not below
            the length of the entry. Its shape is (B, max(lengths)) without
            xs and the shape of xs otherwise.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Raises:
        ValueError: If length_dim is 0.
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    batch_size = int(len(lengths))
    has_reference = xs is not None
    max_len = xs.size(length_dim) if has_reference else int(max(lengths))

    # device=None lets torch pick its default device, same as omitting it.
    positions = torch.arange(0, max_len, dtype=torch.int64, device=device)
    position_grid = positions.unsqueeze(0).expand(batch_size, max_len)
    length_column = position_grid.new(lengths).unsqueeze(-1)
    mask = position_grid >= length_column
    if not has_reference:
        return mask

    assert xs.size(0) == batch_size, (xs.size(0), batch_size)
    if length_dim < 0:
        length_dim = xs.dim() + length_dim
    # Keep the batch and length axes and insert a singleton axis everywhere
    # else: (:, None, ..., None, :, None, ..., None).
    index = tuple(slice(None) if axis in (0, length_dim) else None for axis in range(xs.dim()))
    return mask[index].expand_as(xs).to(xs.device)
def make_non_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Build a mask that marks the non-padded positions of a batch.

    This is the element-wise inverse of the mask from make_pad_mask.

    Args:
        lengths (LongTensor or List): Valid length of each batch entry (B,).
        xs (Tensor, optional): Reference tensor.
            If set, the mask has the same shape as this tensor.
        length_dim (int, optional): Dimension of xs that runs along the
            sequence.
        device (torch.device, optional): Passed on to make_pad_mask.

    Returns:
        Tensor: Mask that is set at the non-padded positions.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)
    """
    pad_mask = make_pad_mask(lengths, xs, length_dim, device=device)
    return ~pad_mask
def initialize(model, init):
    """
    Initialize the weights of a neural network module in place.

    Parameters with more than one dimension are filled with the chosen
    scheme and one-dimensional parameters are set to zero. Afterwards,
    Embedding and LayerNorm modules are reset to their own defaults.

    Args:
        model: Module whose parameters are initialized.
        init: Name of the scheme: "xavier_uniform", "xavier_normal",
            "kaiming_uniform" or "kaiming_normal".

    Raises:
        ValueError: If init is an unknown name and the model has at least
            one parameter with more than one dimension.
    """
    weight_fillers = {
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": lambda weight: torch.nn.init.kaiming_uniform_(weight, nonlinearity="relu"),
        "kaiming_normal": lambda weight: torch.nn.init.kaiming_normal_(weight, nonlinearity="relu"),
    }

    # Weights: everything with more than one dimension.
    for param in model.parameters():
        if param.dim() <= 1:
            continue
        # The name is only checked once a weight is actually encountered.
        if init not in weight_fillers:
            raise ValueError("Unknown initialization: " + init)
        weight_fillers[init](param.data)

    # Biases: every one-dimensional parameter.
    for param in model.parameters():
        if param.dim() == 1:
            param.data.zero_()

    # These module types get their own default initialization back.
    default_init_types = (torch.nn.Embedding, torch.nn.LayerNorm)
    for module in model.modules():
        if isinstance(module, default_init_types):
            module.reset_parameters()
def pad_list(xs, pad_value):
    """
    Stack tensors of different lengths into one padded batch tensor.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value written to the positions past each
            tensor's own length.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`), created from the first list
            entry via its `new` method.
    """
    batch_size = len(xs)
    longest = max(item.size(0) for item in xs)
    trailing_shape = xs[0].size()[1:]
    padded = xs[0].new(batch_size, longest, *trailing_shape).fill_(pad_value)
    # Copy each tensor over the start of its row; the rest stays padding.
    for row, item in enumerate(xs):
        padded[row, : item.size(0)] = item
    return padded
def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """
    Create a lower-triangular mask for subsequent steps (size, size).

    Args:
        size (int): Number of rows and columns of the mask.
        device (str): "cpu" or "cuda" or torch.Tensor.device.
        dtype (torch.dtype): Result dtype.

    Returns:
        Tensor: (size, size) tensor with ones on and below the main
            diagonal and zeros above it.
    """
    mask = torch.ones(size, size, device=device, dtype=dtype)
    # Clear everything above the main diagonal in place.
    return mask.tril_()
class ScorerInterface:
    """
    Scorer interface for beam search.

    A scorer assigns a score to every token in the vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
    """

    def init_state(self, x):
        """
        Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state, which is None in this base implementation
        """
        return None

    def select_state(self, state, i, new_id=None):
        """
        Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label index to select a state if necessary;
                not used by this base implementation

        Returns:
            state: the i-th entry of state, or None if state is None
        """
        if state is None:
            return None
        return state[i]

    def score(self, y, state, x):
        """
        Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def final_score(self, state):
        """
        Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score, which is 0.0 in this base implementation
        """
        return 0.0
class BatchScorerInterface(ScorerInterface, ABC):
    """
    Scorer interface with batched variants of the decoding hooks.
    """

    def batch_init_state(self, x):
        """
        Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state, as produced by `init_state`
        """
        return self.init_state(x)

    def batch_score(self, ys, states, xs):
        """
        Score new token batch (required).

        This default implementation calls `score` once per batch entry
        and joins the results.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        next_scores, next_states = list(), list()
        for prefix, state, feature in zip(ys, states, xs):
            token_scores, new_state = self.score(prefix, state, feature)
            next_states.append(new_state)
            next_scores.append(token_scores)
        # Join the per-entry scores and give every batch entry its own row.
        joined = torch.cat(next_scores, 0)
        return joined.view(ys.shape[0], -1), next_states
def to_device(m, x):
    """
    Send a tensor to the device of a module or of another tensor.

    Args:
        m (torch.nn.Module or Tensor): Object whose device is the target.
            For a module, the device of its first parameter is used, or
            that of its first buffer if it has no parameters.
        x (Tensor): Torch tensor to move.

    Returns:
        Tensor: x located on the device of m. x is returned unchanged when
            m is a module that holds neither parameters nor buffers.

    Raises:
        TypeError: If m is neither a module nor a tensor.
    """
    if isinstance(m, torch.nn.Module):
        # next() with a default keeps a StopIteration from escaping when the
        # module has no parameters.
        anchor = next(m.parameters(), None)
        if anchor is None:
            anchor = next(m.buffers(), None)
        if anchor is None:
            # Nothing in the module is tied to a device, so leave x in place.
            return x
        device = anchor.device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        raise TypeError(
            "Expected torch.nn.Module or torch.tensor, " f"but got: {type(m)}"
        )
    return x.to(device)