import torch

from itertools import accumulate

from onmt.constants import SubwordMarker


def make_batch_align_matrix(index_tensor, size=None, normalize=False):
    """
    Convert a sparse index_tensor into a batch of alignment matrices,
    with each row normalized to sum to 1 if normalize is set.

    Args:
        index_tensor (LongTensor): ``(N, 3)`` of [batch_id, tgt_id, src_id]
        size (List[int]): size of the resulting dense tensor,
            e.g. [batch_size, tgt_len, src_len].
        normalize (bool): if True, normalize the last dim of the resulting
            tensor so that each row sums to 1.
    """
    n_fill, device = index_tensor.size(0), index_tensor.device
    value_tensor = torch.ones([n_fill], dtype=torch.float)
    dense_tensor = torch.sparse_coo_tensor(
        index_tensor.t(), value_tensor, size=size, device=device
    ).to_dense()
    if normalize:
        row_sum = dense_tensor.sum(-1, keepdim=True)
        # clamp row sums <= 1 up to 1 to avoid dividing empty rows by zero
        torch.nn.functional.threshold(row_sum, 1, 1, inplace=True)
        dense_tensor.div_(row_sum)
    return dense_tensor
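# Illustrative usage (indices and size below are assumptions for the example):
#     indices = torch.tensor([[0, 0, 1], [0, 1, 0], [0, 1, 2]])  # [batch, tgt, src]
#     m = make_batch_align_matrix(indices, size=[1, 2, 3], normalize=True)
#     # m[0] == [[0.0, 1.0, 0.0], [0.5, 0.0, 0.5]]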


def extract_alignment(align_matrix, tgt_mask, src_len, n_best):
    """
    Extract a batched align_matrix into per-sentence alignment matrices,
    using tgt_mask to filter out invalid tgt positions such as EOS/PAD.
    BOS is already excluded from tgt_mask in order to match the prediction.

    Args:
        align_matrix (Tensor): ``(B, tgt_len, src_len)``,
            attention head normalized by Softmax(dim=-1)
        tgt_mask (BoolTensor): ``(B, tgt_len)``, True for EOS, PAD.
        src_len (LongTensor): ``(B,)``, containing valid src lengths
        n_best (int): number of parallel translations per example.
        * B: flattened batch size, B = batch_size * n_best.

    Returns:
        alignments (List[List[FloatTensor|None]]): ``(batch_size, n_best,)``,
            containing a valid alignment matrix (or None for a blank
            prediction) for each translation.
    """
    batch_size_n_best = align_matrix.size(0)
    assert batch_size_n_best % n_best == 0

    alignments = [[] for _ in range(batch_size_n_best // n_best)]

    for i, (am_b, tgt_mask_b, src_len_b) in enumerate(
        zip(align_matrix, tgt_mask, src_len)
    ):
        valid_tgt = ~tgt_mask_b
        valid_tgt_len = valid_tgt.sum()
        if valid_tgt_len == 0:
            # blank prediction: no valid tgt token to align
            valid_alignment = None
        else:
            # keep only rows of valid tgt tokens, then trim padded src columns
            am_valid_tgt = am_b.masked_select(valid_tgt.unsqueeze(-1)).view(
                valid_tgt_len, -1
            )
            valid_alignment = am_valid_tgt[:, :src_len_b]
        alignments[i // n_best].append(valid_alignment)

    return alignments
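# Illustrative usage (shapes and mask values below are assumptions for the example):
#     align = torch.rand(2, 3, 4).softmax(-1)          # B = 2 (batch_size=1, n_best=2)
#     tgt_mask = torch.tensor([[False, False, True],   # last tgt position is EOS/PAD
#                              [True, True, True]])    # fully masked -> blank prediction
#     out = extract_alignment(align, tgt_mask, torch.tensor([3, 3]), n_best=2)
#     # out == [[FloatTensor of shape (2, 3), None]]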


def build_align_pharaoh(valid_alignment):
    """Convert a valid alignment matrix to i-j (0-indexed) Pharaoh format
    pairs and their scores, or empty lists if the alignment is None.
    """
    align_pairs = []
    align_scores = []
    if isinstance(valid_alignment, torch.Tensor):
        # for each tgt token, pick the src token with the highest attention
        tgt_align_src_id = valid_alignment.argmax(dim=-1)
        # score: share of the attention mass that goes to the chosen src token
        align_scores = torch.divide(
            valid_alignment.max(dim=-1).values, valid_alignment.sum(dim=-1)
        )
        for tgt_id, src_id in enumerate(tgt_align_src_id.tolist()):
            align_pairs.append(str(src_id) + "-" + str(tgt_id))
        align_scores = [
            "{0}-{1:.5f}".format(i, s) for i, s in enumerate(align_scores.tolist())
        ]
        align_pairs.sort(key=lambda x: int(x.split("-")[-1]))  # sort by tgt_id
        align_pairs.sort(key=lambda x: int(x.split("-")[0]))  # sort by src_id
    return align_pairs, align_scores
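# Illustrative usage (the 2x2 attention matrix below is an assumption for the example):
#     attn = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
#     pairs, scores = build_align_pharaoh(attn)
#     # pairs  == ["0-0", "1-1"]               (src_id-tgt_id)
#     # scores == ["0-0.90000", "1-0.80000"]   (tgt_id-score)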


def to_word_align(
    src, tgt, subword_align, subword_align_scores, m_src="joiner", m_tgt="joiner"
):
    """Convert subword alignment to word alignment.

    Args:
        src (string): tokenized sentence in source language.
        tgt (string): tokenized sentence in target language.
        subword_align (string): Pharaoh-format subword alignment
            ("src-tgt" pairs) corresponding to src-tgt.
        subword_align_scores (string): subword alignment scores
            ("tgt-score" pairs) corresponding to subword_align.
        m_src (string): tokenization mode used in src,
            can be ["joiner", "spacer"].
        m_tgt (string): tokenization mode used in tgt,
            can be ["joiner", "spacer"].

    Returns:
        word_align (string): converted alignments corresponding to
            detokenized src-tgt.
        word_align_scores (string): converted scores corresponding to
            detokenized tgt.
    """
    assert m_src in ["joiner", "spacer"], "Invalid value for argument m_src!"
    assert m_tgt in ["joiner", "spacer"], "Invalid value for argument m_tgt!"

    src, tgt = src.strip().split(), tgt.strip().split()
    subword_align = {
        (int(a), int(b)) for a, b in (x.split("-") for x in subword_align.split())
    }

    subword_align_scores = dict(
        (int(a), float(b))
        for a, b in (x.split("-") for x in subword_align_scores.split())
    )

    # map each subword index to the index of the word it belongs to
    src_map = (
        subword_map_by_spacer(src) if m_src == "spacer" else subword_map_by_joiner(src)
    )

    tgt_map = (
        subword_map_by_spacer(tgt) if m_tgt == "spacer" else subword_map_by_joiner(tgt)
    )

    word_align = list(
        {"{}-{}".format(src_map[a], tgt_map[b]) for a, b in subword_align}
    )

    word_align_scores = list(
        {
            "{}-{}".format(tgt_map[a], subword_align_scores[a])
            for a in subword_align_scores.keys()
        }
    )

    word_align.sort(key=lambda x: int(x.split("-")[-1]))  # sort by tgt word id
    word_align.sort(key=lambda x: int(x.split("-")[0]))  # sort by src word id

    word_align_scores.sort(key=lambda x: int(x.split("-")[0]))  # sort by tgt word id

    return " ".join(word_align), " ".join(word_align_scores)


def begin_uppercase(token):
    return token == SubwordMarker.BEGIN_UPPERCASE


def end_uppercase(token):
    return token == SubwordMarker.END_UPPERCASE


def begin_case(token):
    return token == SubwordMarker.BEGIN_CASED


def case_markup(token):
    return begin_uppercase(token) or end_uppercase(token) or begin_case(token)


def subword_map_by_joiner(
    subwords, original_subwords=None, marker=SubwordMarker.JOINER
):
    """Return the word id for each subword token (annotated by joiner)."""

    # flags[i] == 1 iff subwords[i] starts a new word; accumulate() then
    # turns the flags into 0-based word ids.
    flags = [1] * len(subwords)
    j = 0
    finished = True
    for i, tok in enumerate(subwords):
        previous_tok = subwords[i - 1] if i else ""
        previous_tok_2 = subwords[i - 2] if i > 1 else ""

        # original subword this position is expected to match, if provided
        current_original_subword = (
            ""
            if not original_subwords
            else original_subwords[j]
            if j < len(original_subwords)
            else ""
        )

        if tok.startswith(marker) and tok != current_original_subword:
            # joiner-prefixed token attaches to the previous word
            flags[i] = 0
        elif (
            previous_tok.endswith(marker)
            or begin_case(previous_tok)
            or begin_uppercase(previous_tok)
        ) and not finished:
            # continuation after a trailing joiner or an opening case markup
            flags[i] = 0
        elif (
            previous_tok_2.endswith(marker)
            and case_markup(previous_tok)
            and not finished
        ):
            # case markup sandwiched between a joiner and its continuation
            flags[i] = 0
        elif end_uppercase(tok) and tok != current_original_subword:
            # closing case markup attaches to the current word
            flags[i] = 0
        else:
            finished = False
        if tok == current_original_subword:
            finished = True
            j += 1

    # the first subword always opens word 0
    flags[0] = 0
    word_group = list(accumulate(flags))

    if original_subwords:
        assert max(word_group) < len(original_subwords)
    return word_group
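# Illustrative usage (tokens below are assumptions for the example):
#     j = SubwordMarker.JOINER
#     subword_map_by_joiner(["how", j + "ever", j + ",", "it"])
#     # -> [0, 0, 0, 1]   (the first three subwords form word 0, "it" is word 1)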


def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER):
    """Return the word id for each subword token (annotated by spacer)."""
    # flags[i] == 1 iff subwords[i] starts a new word; accumulate() then
    # turns the flags into 0-based word ids.
    flags = [0] * len(subwords)
    for i, tok in enumerate(subwords):
        if marker in tok:
            if case_markup(tok.replace(marker, "")):
                if i < len(subwords) - 1:
                    flags[i] = 1
            else:
                if i > 0:
                    previous = subwords[i - 1].replace(marker, "")
                    if not case_markup(previous):
                        flags[i] = 1

    # a trailing case markup or lone spacer should not open a new word
    for i in range(1, len(subwords) - 1):
        if case_markup(subwords[-i]):
            flags[-i] = 0
        elif subwords[-i] == marker:
            flags[-i] = 0
            break

    word_group = list(accumulate(flags))
    if word_group[0] == 1:
        # the first token carried a spacer: shift ids so they start at 0
        word_group = [item - 1 for item in word_group]
    return word_group
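# Illustrative usage (tokens below are assumptions for the example):
#     sp = SubwordMarker.SPACER
#     subword_map_by_spacer([sp + "Hello", sp + "wor", "ld"])
#     # -> [0, 1, 1]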