Added function to skip table tokens from no repeat ngrams
- app.py +8 -7
- logits_ngrams.py +60 -0
app.py
CHANGED
@@ -8,17 +8,17 @@ from PIL import Image
 from io import BytesIO
 from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer

+from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids
+
 def run_prediction(sample, model, processor, mode):

+    skip_tokens = get_table_token_ids(processor)
+    no_repeat_ngram_size = 10
+
     if mode == "OCR":
         prompt = "<s><s_pretraining>"
-        no_repeat_ngram_size = 10
-    elif mode == "Table":
-        prompt = "<s><s_hierarchical>"
-        no_repeat_ngram_size = 0
     else:
         prompt = "<s><s_hierarchical>"
-        no_repeat_ngram_size = 10


     print("prompt:", prompt)
@@ -33,10 +33,11 @@ def run_prediction(sample, model, processor, mode):
     outputs = model.generate(
         pixel_values.to(device),
         decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
+        logits_processor=[NoRepeatNGramLogitsProcessor(no_repeat_ngram_size, skip_tokens)],
         do_sample=True,
         top_p=0.92, #.92,
         top_k=5,
-        no_repeat_ngram_size=no_repeat_ngram_size,
+        no_repeat_ngram_size=0,
         num_beams=3,
         output_attentions=False,
         output_hidden_states=False,
@@ -70,7 +71,7 @@ with st.sidebar:
     image_bytes_data = uploaded_file.getvalue()
     image_upload = Image.open(BytesIO(image_bytes_data))

-    mode = st.selectbox('Mode', ('OCR', '…
+    mode = st.selectbox('Mode', ('OCR', 'Element annotation'), index=1)

     if image_upload:
         image = image_upload
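In net, app.py now turns off generate's built-in n-gram blocking and hands the job to the custom processor, which knows which token ids to exempt. A minimal sketch of the resulting call pattern, assuming a Donut-style checkpoint; the checkpoint name, image path, bare "<s>" prompt and max_length below are placeholders, not this app's actual values:

# Sketch only; checkpoint, image path, prompt and max_length are placeholders.
from PIL import Image
from transformers import VisionEncoderDecoderModel, DonutProcessor
from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

image = Image.open("page.png").convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values
prompt_ids = processor.tokenizer("<s>", add_special_tokens=False, return_tensors="pt").input_ids

skip_tokens = get_table_token_ids(processor)  # ids of <t...> tags exempt from blocking
outputs = model.generate(
    pixel_values,
    decoder_input_ids=prompt_ids,
    logits_processor=[NoRepeatNGramLogitsProcessor(10, skip_tokens)],
    no_repeat_ngram_size=0,  # keep the stock blocker off; the custom one replaces it
    num_beams=3,
    max_length=512,
)
print(processor.tokenizer.batch_decode(outputs))

Passing no_repeat_ngram_size=0 in the same call matters: a non-zero value would stack the stock NoRepeatNGramLogitsProcessor on top of the custom one, and the table tags would be banned again.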
logits_ngrams.py
ADDED
@@ -0,0 +1,60 @@
+import re
+import torch
+from transformers import DonutProcessor
+from transformers.utils import add_start_docstrings
+from transformers.generation.logits_process import LogitsProcessor, LOGITS_PROCESSOR_INPUTS_DOCSTRING
+
+# Inspired by https://github.com/huggingface/transformers/blob/8e3980a290acc6d2f8ea76dba111b9ef0ef00309/src/transformers/generation/logits_process.py#L706
+class NoRepeatNGramLogitsProcessor(LogitsProcessor):
+    def __init__(self, ngram_size: int, skip_tokens=None):
+        if not isinstance(ngram_size, int) or ngram_size <= 0:
+            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+        self.ngram_size = ngram_size
+        self.skip_tokens = skip_tokens
+
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        num_batch_hypotheses = scores.shape[0]
+        cur_len = input_ids.shape[-1]
+        return _no_repeat_ngram_logits(input_ids, cur_len, scores, batch_size=num_batch_hypotheses, no_repeat_ngram_size=self.ngram_size, skip_tokens=self.skip_tokens)
+
+
+def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_ngram_size=0, skip_tokens=None):
+    if no_repeat_ngram_size > 0:
+        # calculate a list of banned tokens to prevent repetitively generating the same ngrams
+        # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
+        banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
+        for batch_idx in range(batch_size):
+            # mask every banned token except the ones in skip_tokens, so table tags may repeat
+            logits[batch_idx, [token for token in banned_tokens[batch_idx] if skip_tokens is None or int(token) not in skip_tokens]] = -float("inf")
+
+    return logits
+
+
+def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
+    # Copied from fairseq for no_repeat_ngram in beam_search
+    if cur_len + 1 < no_repeat_ngram_size:
+        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+        return [[] for _ in range(num_hypos)]
+    generated_ngrams = [{} for _ in range(num_hypos)]
+    for idx in range(num_hypos):
+        gen_tokens = prev_input_ids[idx].tolist()
+        generated_ngram = generated_ngrams[idx]
+        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
+            prev_ngram_tuple = tuple(ngram[:-1])
+            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [
+                ngram[-1]
+            ]
+
+    def _get_generated_ngrams(hypo_idx):
+        # Before decoding the next token, prevent decoding of ngrams that have already appeared
+        start_idx = cur_len + 1 - no_repeat_ngram_size
+        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
+        return generated_ngrams[hypo_idx].get(ngram_idx, [])
+
+    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
+    return banned_tokens
+
+
+def get_table_token_ids(processor):
+    # ids of the added table tokens (e.g. <table>, <thead>, <tr>, <td>) that are allowed to repeat
+    skip_tokens = {token_id for token, token_id in processor.tokenizer.get_added_vocab().items() if re.search('<t.*>', token)}
+    return skip_tokens
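get_table_token_ids collects every added-vocabulary token matching the regex '<t.*>', i.e. <table>, <thead>, <tr>, <td> and any other added tag containing "<t...>". As a quick sanity check of the skip behaviour (a toy example with fabricated token ids, no model or processor involved): banned candidates from a repeated n-gram are filtered against skip_tokens before masking, so an exempted tag keeps a finite score.

# Toy check; token ids are made up, and id 7 stands in for a table tag like <td>.
import torch
from logits_ngrams import NoRepeatNGramLogitsProcessor

TD = 7
input_ids = torch.tensor([[3, TD, 3, TD, 3]])  # the bigram (3, TD) already occurred
scores = torch.zeros(1, 10)

with_skip = NoRepeatNGramLogitsProcessor(2, skip_tokens={TD})(input_ids, scores.clone())
print(with_skip[0, TD])   # tensor(0.): TD is exempt, still allowed after "... 3"

no_skip = NoRepeatNGramLogitsProcessor(2, skip_tokens=set())(input_ids, scores.clone())
print(no_skip[0, TD])     # tensor(-inf): without the skip set, TD is banned after "... 3"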