# -*- coding: utf-8 -*-
import torch
from itertools import accumulate

from onmt.constants import SubwordMarker


def make_batch_align_matrix(index_tensor, size=None, normalize=False):
    """
    Convert a sparse index_tensor into a batch of alignment matrices,
    with each row normalized to sum to 1 if ``normalize`` is set.

    Args:
        index_tensor (LongTensor): ``(N, 3)`` of [batch_id, tgt_id, src_id]
        size (List[int]): Size of the sparse tensor.
        normalize (bool): if True, normalize along the last (src) dim
            of the resulting tensor.
    """
    n_fill, device = index_tensor.size(0), index_tensor.device
    value_tensor = torch.ones([n_fill], dtype=torch.float)
    dense_tensor = torch.sparse_coo_tensor(
        index_tensor.t(), value_tensor, size=size, device=device
    ).to_dense()
    if normalize:
        row_sum = dense_tensor.sum(-1, keepdim=True)  # sum by row (tgt)
        # threshold at 1 to avoid division by 0
        torch.nn.functional.threshold(row_sum, 1, 1, inplace=True)
        dense_tensor.div_(row_sum)
    return dense_tensor


def extract_alignment(align_matrix, tgt_mask, src_len, n_best):
    """
    Extract a batched align_matrix into its src indices alignment lists,
    using tgt_mask to filter out invalid tgt positions such as EOS/PAD.
    BOS is already excluded from tgt_mask in order to match the prediction.

    Args:
        align_matrix (Tensor): ``(B, tgt_len, src_len)``,
            attention head normalized by Softmax(dim=-1)
        tgt_mask (BoolTensor): ``(B, tgt_len)``, True for EOS, PAD.
        src_len (LongTensor): ``(B,)``, containing valid src lengths
        n_best (int): number of parallel translations.
        * B: flattened batch size, where B = batch_size * n_best.

    Returns:
        alignments (List[List[FloatTensor|None]]): ``(batch_size, n_best,)``,
            containing the valid alignment matrix (or None for a blank
            prediction) for each translation.
    """
    batch_size_n_best = align_matrix.size(0)
    assert batch_size_n_best % n_best == 0

    alignments = [[] for _ in range(batch_size_n_best // n_best)]

    # treat alignment matrices one by one, as each has a different length
    for i, (am_b, tgt_mask_b, src_len) in enumerate(
        zip(align_matrix, tgt_mask, src_len)
    ):
        valid_tgt = ~tgt_mask_b
        valid_tgt_len = valid_tgt.sum()
        if valid_tgt_len == 0:
            # no alignment if there is no valid tgt token
            valid_alignment = None
        else:
            # get valid alignment (sub-matrix of the full padded alignment matrix)
            am_valid_tgt = am_b.masked_select(valid_tgt.unsqueeze(-1)).view(
                valid_tgt_len, -1
            )
            valid_alignment = am_valid_tgt[:, :src_len]  # only keep valid src
        alignments[i // n_best].append(valid_alignment)
    return alignments


def build_align_pharaoh(valid_alignment):
    """Convert a valid alignment matrix to i-j (0-indexed) Pharaoh format pairs,
    or an empty list if it is None.
    """
    align_pairs = []
    align_scores = []
    if isinstance(valid_alignment, torch.Tensor):
        tgt_align_src_id = valid_alignment.argmax(dim=-1)
        align_scores = torch.divide(
            valid_alignment.max(dim=-1).values, valid_alignment.sum(dim=-1)
        )
        for tgt_id, src_id in enumerate(tgt_align_src_id.tolist()):
            align_pairs.append(str(src_id) + "-" + str(tgt_id))
        align_scores = [
            "{0}-{1:.5f}".format(i, s) for i, s in enumerate(align_scores.tolist())
        ]
        align_pairs.sort(key=lambda x: int(x.split("-")[-1]))  # sort by tgt_id
        align_pairs.sort(key=lambda x: int(x.split("-")[0]))  # sort by src_id
    return align_pairs, align_scores
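
# Illustrative sketch of how the helpers above compose (the toy tensors below
# are assumptions for demonstration, not library data): one alignment triple
# per target position gives a dense (1, 3, 4) matrix, which build_align_pharaoh
# then reduces to src_id-tgt_id Pharaoh pairs via a per-row argmax:
#
#   index = torch.tensor([[0, 0, 0], [0, 1, 2], [0, 2, 3]])  # [batch, tgt, src]
#   matrix = make_batch_align_matrix(index, size=[1, 3, 4], normalize=True)
#   pairs, scores = build_align_pharaoh(matrix[0])
#   # pairs  -> ["0-0", "2-1", "3-2"]   (src_id-tgt_id)
#   # scores -> ["0-1.00000", "1-1.00000", "2-1.00000"]   (tgt_id-score)
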
def to_word_align(
    src, tgt, subword_align, subword_align_scores, m_src="joiner", m_tgt="joiner"
):
    """Convert subword alignment to word alignment.

    Args:
        src (string): tokenized sentence in source language.
        tgt (string): tokenized sentence in target language.
        subword_align (string): Pharaoh-format subword alignment
            corresponding to src-tgt.
        subword_align_scores (string): alignment scores corresponding to
            subword_align, in tgt_id-score format.
        m_src (string): tokenization mode used in src,
            can be ["joiner", "spacer"].
        m_tgt (string): tokenization mode used in tgt,
            can be ["joiner", "spacer"].

    Returns:
        word_align (string): converted alignments corresponding to
            detokenized src-tgt.
        word_align_scores (string): scores mapped to word-level tgt ids.
    """
    assert m_src in ["joiner", "spacer"], "Invalid value for argument m_src!"
    assert m_tgt in ["joiner", "spacer"], "Invalid value for argument m_tgt!"

    src, tgt = src.strip().split(), tgt.strip().split()
    subword_align = {
        (int(a), int(b)) for a, b in (x.split("-") for x in subword_align.split())
    }
    subword_align_scores = dict(
        (int(a), float(b))
        for a, b in (x.split("-") for x in subword_align_scores.split())
    )

    src_map = (
        subword_map_by_spacer(src) if m_src == "spacer" else subword_map_by_joiner(src)
    )

    tgt_map = (
        subword_map_by_spacer(tgt) if m_tgt == "spacer" else subword_map_by_joiner(tgt)
    )

    word_align = list(
        {"{}-{}".format(src_map[a], tgt_map[b]) for a, b in subword_align}
    )
    word_align_scores = list(
        {
            "{}-{}".format(tgt_map[a], subword_align_scores[a])
            for a in subword_align_scores.keys()
        }
    )
    word_align.sort(key=lambda x: int(x.split("-")[-1]))  # sort by tgt_id
    word_align.sort(key=lambda x: int(x.split("-")[0]))  # sort by src_id
    word_align_scores.sort(key=lambda x: int(x.split("-")[0]))

    return " ".join(word_align), " ".join(word_align_scores)


# Helper functions
def begin_uppercase(token):
    return token == SubwordMarker.BEGIN_UPPERCASE


def end_uppercase(token):
    return token == SubwordMarker.END_UPPERCASE


def begin_case(token):
    return token == SubwordMarker.BEGIN_CASED


def case_markup(token):
    return begin_uppercase(token) or end_uppercase(token) or begin_case(token)


def subword_map_by_joiner(
    subwords, original_subwords=None, marker=SubwordMarker.JOINER
):
    """Return the word id for each subword token (annotated by joiner)."""

    flags = [1] * len(subwords)
    j = 0
    finished = True
    for i, tok in enumerate(subwords):
        previous_tok = subwords[i - 1] if i else ""  # previous (N-1) token
        previous_tok_2 = subwords[i - 2] if i > 1 else ""  # previous (N-2) token

        # Keeps track of the original words/subwords
        # ('prior_tokenization' option)
        current_original_subword = (
            ""
            if not original_subwords
            else original_subwords[j]
            if j < len(original_subwords)
            else ""
        )

        if tok.startswith(marker) and tok != current_original_subword:
            flags[i] = 0
        elif (
            previous_tok.endswith(marker)
            or begin_case(previous_tok)
            or begin_uppercase(previous_tok)
        ) and not finished:
            flags[i] = 0
        elif (
            previous_tok_2.endswith(marker)
            and case_markup(previous_tok)
            and not finished
        ):
            flags[i] = 0
        elif end_uppercase(tok) and tok != current_original_subword:
            flags[i] = 0
        else:
            finished = False

        if tok == current_original_subword:
            finished = True
            j += 1

    flags[0] = 0
    word_group = list(accumulate(flags))

    if original_subwords:
        assert max(word_group) < len(original_subwords)

    return word_group


def subword_map_by_spacer(subwords, marker=SubwordMarker.SPACER):
    """Return the word id for each subword token (annotated by spacer)."""

    flags = [0] * len(subwords)
    for i, tok in enumerate(subwords):
        if marker in tok:
            if case_markup(tok.replace(marker, "")):
                if i < len(subwords) - 1:
                    flags[i] = 1
            else:
                if i > 0:
                    previous = subwords[i - 1].replace(marker, "")
                    if not case_markup(previous):
                        flags[i] = 1

    # In case there is a final case_markup when new_spacer is on
    for i in range(1, len(subwords) - 1):
        if case_markup(subwords[-i]):
            flags[-i] = 0
        elif subwords[-i] == marker:
            flags[-i] = 0
            break

    word_group = list(accumulate(flags))

    if word_group[0] == 1:  # when a dummy prefix is set
        word_group = [item - 1 for item in word_group]

    return word_group
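

# Illustrative sketch of the subword-to-word maps (hypothetical tokens;
# JOINER stands for SubwordMarker.JOINER):
#
#   JOINER = SubwordMarker.JOINER
#   subword_map_by_joiner(["he", JOINER + "llo", "world"])  # -> [0, 0, 1]
#
# i.e. "he" and JOINER + "llo" belong to the same source word (id 0), while
# "world" starts word 1. subword_map_by_spacer plays the same role for
# spacer-annotated (SentencePiece-style) tokenization, and to_word_align uses
# these maps to collapse subword alignment pairs onto word-level pairs.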