ajimeno committed on
Commit
4792343
·
1 Parent(s): 5d1e573

Added function to skip table tokens from no repeat ngrams

Browse files
Files changed (2) hide show
  1. app.py +8 -7
  2. logits_ngrams.py +60 -0
app.py CHANGED
@@ -8,17 +8,17 @@ from PIL import Image
8
  from io import BytesIO
9
  from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
10
 
 
 
11
  def run_prediction(sample, model, processor, mode):
12
 
 
 
 
13
  if mode == "OCR":
14
  prompt = "<s><s_pretraining>"
15
- no_repeat_ngram_size = 10
16
- elif mode == "Table":
17
- prompt = "<s><s_hierarchical>"
18
- no_repeat_ngram_size = 0
19
  else:
20
  prompt = "<s><s_hierarchical>"
21
- no_repeat_ngram_size = 10
22
 
23
 
24
  print("prompt:", prompt)
@@ -33,10 +33,11 @@ def run_prediction(sample, model, processor, mode):
33
  outputs = model.generate(
34
  pixel_values.to(device),
35
  decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
 
36
  do_sample=True,
37
  top_p=0.92, #.92,
38
  top_k=5,
39
- no_repeat_ngram_size=no_repeat_ngram_size,
40
  num_beams=3,
41
  output_attentions=False,
42
  output_hidden_states=False,
@@ -70,7 +71,7 @@ with st.sidebar:
70
  image_bytes_data = uploaded_file.getvalue()
71
  image_upload = Image.open(BytesIO(image_bytes_data))
72
 
73
- mode = st.selectbox('Mode', ('OCR', 'Table', 'Element annotation'), index=2)
74
 
75
  if image_upload:
76
  image = image_upload
 
8
  from io import BytesIO
9
  from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, DonutProcessor, DonutImageProcessor, AutoTokenizer
10
 
11
+ from logits_ngrams import NoRepeatNGramLogitsProcessor, get_table_token_ids
12
+
13
  def run_prediction(sample, model, processor, mode):
14
 
15
+ skip_tokens = get_table_token_ids(processor)
16
+ no_repeat_ngram_size = 10
17
+
18
  if mode == "OCR":
19
  prompt = "<s><s_pretraining>"
 
 
 
 
20
  else:
21
  prompt = "<s><s_hierarchical>"
 
22
 
23
 
24
  print("prompt:", prompt)
 
33
  outputs = model.generate(
34
  pixel_values.to(device),
35
  decoder_input_ids=processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device),
36
+ logits_processor=[NoRepeatNGramLogitsProcessor(no_repeat_ngram_size, skip_tokens)],
37
  do_sample=True,
38
  top_p=0.92, #.92,
39
  top_k=5,
40
+ no_repeat_ngram_size=0,
41
  num_beams=3,
42
  output_attentions=False,
43
  output_hidden_states=False,
 
71
  image_bytes_data = uploaded_file.getvalue()
72
  image_upload = Image.open(BytesIO(image_bytes_data))
73
 
74
+ mode = st.selectbox('Mode', ('OCR', 'Element annotation'), index=1)
75
 
76
  if image_upload:
77
  image = image_upload
logits_ngrams.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ from transformers import DonutProcessor
4
+ from transformers.utils import add_start_docstrings
5
+ from transformers.generation.logits_process import LogitsProcessor, LOGITS_PROCESSOR_INPUTS_DOCSTRING
6
+
7
# Inspired on https://github.com/huggingface/transformers/blob/8e3980a290acc6d2f8ea76dba111b9ef0ef00309/src/transformers/generation/logits_process.py#L706
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    """Logits processor that bans repeated n-grams during generation.

    Unlike the stock transformers implementation, tokens listed in
    ``skip_tokens`` are exempted from banning (presumably table-structure
    markers that must legitimately repeat — see ``get_table_token_ids``).
    """

    def __init__(self, ngram_size: int, skip_tokens = None):
        # Reject sizes that would make the downstream ban a silent no-op.
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size
        self.skip_tokens = skip_tokens

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        hypothesis_count = scores.shape[0]
        sequence_length = input_ids.shape[-1]
        return _no_repeat_ngram_logits(
            input_ids,
            sequence_length,
            scores,
            batch_size=hypothesis_count,
            no_repeat_ngram_size=self.ngram_size,
            skip_tokens=self.skip_tokens,
        )
20
+
21
def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_ngram_size=0, skip_tokens=None):
    """Set to -inf the logits of tokens that would complete an already-seen n-gram.

    Args:
        input_ids: generated token ids so far, one row per batch hypothesis.
        cur_len: current generated sequence length (``input_ids.shape[-1]``).
        logits: next-token scores to modify in place, shape (batch_size, vocab).
        batch_size: number of hypotheses (rows) in ``logits``.
        no_repeat_ngram_size: n-gram size to ban; 0 disables banning entirely.
        skip_tokens: optional collection of token ids exempt from banning
            (e.g. table-structure tokens); None means no exemptions.

    Returns:
        The (mutated) ``logits`` tensor.
    """
    if no_repeat_ngram_size > 0:
        # calculate a list of banned tokens to prevent repetitively generating the same ngrams
        # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
        banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
        for batch_idx in range(batch_size):
            # BUG FIX: the original condition (`skip_tokens is not None and ...`)
            # banned no tokens at all whenever skip_tokens was None, silently
            # disabling the processor. Exemption must only apply to tokens that
            # are actually in the provided skip set.
            banned = [
                token
                for token in banned_tokens[batch_idx]
                if skip_tokens is None or int(token) not in skip_tokens
            ]
            logits[batch_idx, banned] = -float("inf")

    return logits
30
+
31
+ def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
32
+ # Copied from fairseq for no_repeat_ngram in beam_search"""
33
+ if cur_len + 1 < no_repeat_ngram_size:
34
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
35
+ return [[] for _ in range(num_hypos)]
36
+ generated_ngrams = [{} for _ in range(num_hypos)]
37
+ for idx in range(num_hypos):
38
+ gen_tokens = prev_input_ids[idx] # .tolist()
39
+ generated_ngram = generated_ngrams[idx]
40
+ for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
41
+
42
+ prev_ngram_tuple = tuple(ngram[:-1])
43
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [
44
+ ngram[-1]
45
+ ]
46
+
47
+ def _get_generated_ngrams(hypo_idx):
48
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
49
+ start_idx = cur_len + 1 - no_repeat_ngram_size
50
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
51
+
52
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
53
+
54
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
55
+ return banned_tokens
56
+
57
+
58
def get_table_token_ids(processor):
    """Collect the ids of added tokens that look like table-structure markers.

    Scans the tokenizer's added vocabulary for tokens matching ``<t.*>``
    (e.g. ``<thead>``, ``<tr>``) so they can be exempted from n-gram
    repetition banning.

    Args:
        processor: a processor exposing ``processor.tokenizer.get_added_vocab()``.

    Returns:
        A set of token ids.
    """
    # BUG FIX: the original computed the set comprehension but never returned
    # it, so the function always returned None.
    return {
        token_id
        for token, token_id in processor.tokenizer.get_added_vocab().items()
        if re.search('<t.*>', token)
    }