ajimeno commited on
Commit
a056c48
·
1 Parent(s): a55d38b

Patched error

Browse files
Files changed (1) hide show
  1. logits_ngrams.py +5 -2
logits_ngrams.py CHANGED
@@ -24,7 +24,10 @@ def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_
24
  # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
25
  banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
26
  for batch_idx in range(batch_size):
27
- logits[batch_idx, [token for token in banned_tokens[batch_idx] if skip_tokens is not None and int(token) not in skip_tokens]] = -float("inf")
 
 
 
28
 
29
  return logits
30
 
@@ -35,7 +38,7 @@ def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len
35
  return [[] for _ in range(num_hypos)]
36
  generated_ngrams = [{} for _ in range(num_hypos)]
37
  for idx in range(num_hypos):
38
- gen_tokens = prev_input_ids[idx] # .tolist()
39
  generated_ngram = generated_ngrams[idx]
40
  for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
41
 
 
24
  # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
25
  banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
26
  for batch_idx in range(batch_size):
27
+ if skip_tokens is not None:
28
+ logits[batch_idx, [token for token in banned_tokens[batch_idx] if int(token) not in skip_tokens]] = -float("inf")
29
+ else:
30
+ logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
31
 
32
  return logits
33
 
 
38
  return [[] for _ in range(num_hypos)]
39
  generated_ngrams = [{} for _ in range(num_hypos)]
40
  for idx in range(num_hypos):
41
+ gen_tokens = prev_input_ids[idx].tolist()
42
  generated_ngram = generated_ngrams[idx]
43
  for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
44