File size: 3,408 Bytes
4792343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a056c48
 
 
 
4792343
 
 
 
 
 
 
 
 
 
a056c48
4792343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec9e91a
4792343
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import re
import torch
from transformers import DonutProcessor
from transformers.utils import add_start_docstrings
from transformers.generation.logits_process import LogitsProcessor, LOGITS_PROCESSOR_INPUTS_DOCSTRING

# Inspired by https://github.com/huggingface/transformers/blob/8e3980a290acc6d2f8ea76dba111b9ef0ef00309/src/transformers/generation/logits_process.py#L706
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    """Logits processor that blocks repeated n-grams, with optional exempt tokens.

    Args:
        ngram_size: Length of the n-grams that may occur only once; must be a
            strictly positive ``int``.
        skip_tokens: Optional container of token ids that are never masked,
            even when they would complete a repeated n-gram. ``None`` (the
            default) disables the exemption.

    Raises:
        ValueError: If ``ngram_size`` is not an ``int`` or is not positive.
    """

    def __init__(self, ngram_size: int, skip_tokens = None):
        is_positive_int = isinstance(ngram_size, int) and ngram_size > 0
        if not is_positive_int:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.skip_tokens = skip_tokens
        self.ngram_size = ngram_size

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # The first axis of `scores` gives the number of hypotheses and the
        # last axis of `input_ids` gives the current sequence length; the
        # actual masking is delegated to the module-level helper.
        return _no_repeat_ngram_logits(
            input_ids,
            input_ids.shape[-1],
            scores,
            batch_size=scores.shape[0],
            no_repeat_ngram_size=self.ngram_size,
            skip_tokens=self.skip_tokens,
        )

def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_ngram_size=0, skip_tokens=None):
    """Set the logits of tokens that would repeat an n-gram to ``-inf``.

    Args:
        input_ids: Token ids generated so far; forwarded to
            ``_calc_banned_tokens``.
        cur_len: Current sequence length; forwarded to ``_calc_banned_tokens``.
        logits: Scores indexed as ``logits[row, token_ids]``. Modified in
            place and also returned.
        batch_size: Number of rows of ``logits`` to process.
        no_repeat_ngram_size: N-gram length; when not greater than zero the
            logits are returned untouched.
        skip_tokens: Optional container of token ids that must never be
            masked. ``None`` means every banned token is masked.

    Returns:
        The same ``logits`` object that was passed in.
    """
    if not no_repeat_ngram_size > 0:
        return logits

    # Per-row list of tokens that would complete an already generated n-gram
    # (approach taken from fairseq:
    # https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
    banned_per_row = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)

    for row in range(batch_size):
        to_mask = banned_per_row[row]
        if skip_tokens is not None:
            # Exempt tokens stay available even if they repeat an n-gram.
            to_mask = [tok for tok in to_mask if int(tok) not in skip_tokens]
        logits[row, to_mask] = -float("inf")

    return logits

def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    """Return, per hypothesis, the tokens that would complete an already-seen n-gram.

    Adapted from fairseq's ``no_repeat_ngram`` handling in beam search.

    Args:
        prev_input_ids: Token ids generated so far. It is indexed as
            ``prev_input_ids[hypo]`` and ``prev_input_ids[hypo, start:end]``,
            and both results must provide ``.tolist()``.
        num_hypos: Number of hypotheses (rows) to process.
        no_repeat_ngram_size: Length ``n`` of the n-grams that must not repeat.
        cur_len: Number of tokens generated so far in each row.

    Returns:
        A list with ``num_hypos`` entries. Entry ``i`` lists the token ids that
        followed earlier occurrences of row ``i``'s last ``n - 1`` tokens, in
        order of occurrence and with duplicates kept. Every entry is empty
        when fewer than ``n - 1`` tokens have been generated.
    """
    if cur_len + 1 < no_repeat_ngram_size:
        # Not enough tokens yet to form the prefix of an n-gram: nothing is banned.
        return [[] for _ in range(num_hypos)]

    # The last (n - 1) generated tokens form the prefix the next token would extend.
    start_idx = cur_len + 1 - no_repeat_ngram_size

    banned_tokens = []
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()

        # Map every (n - 1)-token prefix to the tokens that followed it.
        # Appending in place avoids copying the whole follower list for each
        # n-gram, which made the previous implementation quadratic on
        # highly repetitive sequences.
        followers = {}
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            followers.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])

        # Before decoding the next token, ban whatever followed the current prefix.
        current_prefix = tuple(prev_input_ids[idx, start_idx:cur_len].tolist())
        banned_tokens.append(followers.get(current_prefix, []))

    return banned_tokens


def get_table_token_ids(processor):
    """Collect the ids of added-vocabulary tokens that start with ``<t`` or ``</t``.

    Args:
        processor: Object exposing ``tokenizer.get_added_vocab()``, which
            returns a mapping from token string to token id.

    Returns:
        A set of token ids whose token string starts with ``"<t"`` or
        ``"</t"``. NOTE(review): presumably meant for table markup tokens such
        as ``<table>``/``<tr>``/``<td>``, but any added token with one of
        these prefixes matches.
    """
    table_prefixes = ("<t", "</t")
    added_vocab = processor.tokenizer.get_added_vocab()
    return {tok_id for tok, tok_id in added_vocab.items() if tok.startswith(table_prefixes)}