Patched error
Browse files- logits_ngrams.py +5 -2
logits_ngrams.py
CHANGED
@@ -24,7 +24,10 @@ def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_
|
|
24 |
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
25 |
banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
26 |
for batch_idx in range(batch_size):
|
27 |
-
|
|
|
|
|
|
|
28 |
|
29 |
return logits
|
30 |
|
@@ -35,7 +38,7 @@ def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len
|
|
35 |
return [[] for _ in range(num_hypos)]
|
36 |
generated_ngrams = [{} for _ in range(num_hypos)]
|
37 |
for idx in range(num_hypos):
|
38 |
-
gen_tokens = prev_input_ids[idx]
|
39 |
generated_ngram = generated_ngrams[idx]
|
40 |
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
41 |
|
|
|
24 |
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
25 |
banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
26 |
for batch_idx in range(batch_size):
|
27 |
+
if skip_tokens is not None:
|
28 |
+
logits[batch_idx, [token for token in banned_tokens[batch_idx] if int(token) not in skip_tokens]] = -float("inf")
|
29 |
+
else:
|
30 |
+
logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
31 |
|
32 |
return logits
|
33 |
|
|
|
38 |
return [[] for _ in range(num_hypos)]
|
39 |
generated_ngrams = [{} for _ in range(num_hypos)]
|
40 |
for idx in range(num_hypos):
|
41 |
+
gen_tokens = prev_input_ids[idx].tolist()
|
42 |
generated_ngram = generated_ngrams[idx]
|
43 |
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
44 |
|