|
import re |
|
import torch |
|
from transformers import DonutProcessor |
|
from transformers.utils import add_start_docstrings |
|
from transformers.generation.logits_process import LogitsProcessor, LOGITS_PROCESSOR_INPUTS_DOCSTRING |
|
|
|
|
|
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    """Logits processor that suppresses tokens which would repeat an n-gram.

    The actual masking is delegated to `_no_repeat_ngram_logits`; this class
    only stores the configuration and adapts the `LogitsProcessor` call
    signature.

    Args:
        ngram_size: Size of the n-grams that may occur only once. Must be a
            strictly positive `int`.
        skip_tokens: Optional container of token ids forwarded unchanged to
            `_no_repeat_ngram_logits` as `skip_tokens`.

    Raises:
        ValueError: If `ngram_size` is not an `int` or is not positive.
    """

    def __init__(self, ngram_size: int, skip_tokens=None):
        is_positive_int = isinstance(ngram_size, int) and ngram_size > 0
        if not is_positive_int:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")

        self.skip_tokens = skip_tokens
        self.ngram_size = ngram_size

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # The first dimension of `scores` is used as the number of hypotheses
        # and the last dimension of `input_ids` as the current length.
        return _no_repeat_ngram_logits(
            input_ids,
            input_ids.shape[-1],
            scores,
            batch_size=scores.shape[0],
            no_repeat_ngram_size=self.ngram_size,
            skip_tokens=self.skip_tokens,
        )
|
|
|
def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_ngram_size=0, skip_tokens=None):
    """Set to ``-inf`` the logits of tokens that would repeat an existing n-gram.

    Args:
        input_ids: Token ids generated so far; forwarded to
            `_calc_banned_tokens` together with `batch_size`.
        cur_len: Current sequence length, forwarded to `_calc_banned_tokens`.
        logits: Next-token scores indexed as ``logits[batch_idx, token_id]``.
            Modified in place.
        batch_size: Number of rows (hypotheses) to process.
        no_repeat_ngram_size: N-gram size; values ``<= 0`` disable filtering.
        skip_tokens: Optional container of token ids that must never be
            banned. ``None`` means no token is exempt.

    Returns:
        The same `logits` object that was passed in.
    """
    if no_repeat_ngram_size <= 0:
        return logits

    banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)

    for batch_idx in range(batch_size):
        candidates = [int(token) for token in banned_tokens[batch_idx]]
        if skip_tokens is not None:
            # Exempt tokens stay allowed even when they would repeat an n-gram.
            candidates = [token for token in candidates if token not in skip_tokens]
        # Without `skip_tokens` every candidate is banned; previously the
        # `None` default silently disabled the filter altogether.
        if candidates:
            logits[batch_idx, candidates] = -float("inf")

    return logits
|
|
|
def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len): |
|
|
|
if cur_len + 1 < no_repeat_ngram_size: |
|
|
|
return [[] for _ in range(num_hypos)] |
|
generated_ngrams = [{} for _ in range(num_hypos)] |
|
for idx in range(num_hypos): |
|
gen_tokens = prev_input_ids[idx] |
|
generated_ngram = generated_ngrams[idx] |
|
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): |
|
|
|
prev_ngram_tuple = tuple(ngram[:-1]) |
|
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ |
|
ngram[-1] |
|
] |
|
|
|
def _get_generated_ngrams(hypo_idx): |
|
|
|
start_idx = cur_len + 1 - no_repeat_ngram_size |
|
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) |
|
|
|
return generated_ngrams[hypo_idx].get(ngram_idx, []) |
|
|
|
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] |
|
return banned_tokens |
|
|
|
|
|
def get_table_token_ids(processor):
    """Collect the ids of added-vocabulary tokens matching ``<t.*>``.

    Args:
        processor: Object exposing ``tokenizer.get_added_vocab()``, which is
            expected to return a mapping of token string to token id.

    Returns:
        Set of token ids whose token string matches the regular expression
        ``<t.*>`` (for example ``<td>`` or ``<tr>``).
    """
    table_tag = re.compile(r"<t.*>")
    added_vocab = processor.tokenizer.get_added_vocab()
    skip_tokens = {token_id for token, token_id in added_vocab.items() if table_tag.search(token)}
    # Return the set so callers do not receive None.
    return skip_tokens
|
|
|
|