diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 99d99bdf254b10..05288708c4ce19 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -438,10 +438,28 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
+ """
+    Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The n-grams generated for this single
+    hypothesis look like this: {(40,): [2883], (2883,): [2712], (2712,): [4346]}.
+
+ Args:
+ ngram_size (`int`):
+            The number of sequential tokens taken as a group which may only occur once before being banned.
+ prev_input_ids (`torch.Tensor`):
+            Generated token ids for all hypotheses in the batch.
+ num_hypos (`int`):
+ The number of hypotheses for which n-grams need to be generated.
+
+ Returns:
+        generated_ngrams (`list` of `dict`):
+            A list with one dictionary per hypothesis, mapping each (ngram_size - 1)-token prefix to the list of
+            tokens that have followed it.
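+
+    For instance, with ngram_size=3 the prefixes become 2-tuples, and the same input yields
+    {(40, 2883): [2712], (2883, 2712): [4346]}.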
+ """
+    # Initialize an empty n-gram dictionary for each of the num_hypos hypotheses
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
+        # Loop through each n-gram of size ngram_size in gen_tokens: zipping ngram_size shifted copies of the list yields the sliding windows
for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
@@ -449,6 +467,22 @@ def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
+ """
+ Determines the banned tokens for the current hypothesis based on previously generated n-grams.
+
+ Args:
+ banned_ngrams (`dict`):
+            The dictionary of previously generated n-grams for one hypothesis (a single entry of the `_get_ngrams` output).
+ prev_input_ids (`torch.Tensor`):
+ Generated token ids for the current hypothesis.
+ ngram_size (`int`):
+            The number of sequential tokens taken as a group which may only occur once before being banned.
+ cur_len (`int`):
+ The current length of the token sequences for which the n-grams are being checked.
+
+ Returns:
+        List of tokens that are banned for the next generation step.
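+
+    For example, continuing the `_get_ngrams` example with ngram_size=2: if prev_input_ids=tensor([40, 2883, 2712, 40])
+    and cur_len=4, the last ngram_size - 1 tokens form the prefix (40,), so this returns [2883], banning a repeat of
+    the bigram (40, 2883).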
+ """
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - ngram_size
ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
@@ -462,9 +496,7 @@ def _calc_banned_ngram_tokens(
if cur_len + 1 < ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]
-
generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
-
banned_tokens = [
_get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
for hypo_idx in range(num_hypos)
@@ -474,12 +506,43 @@ def _calc_banned_ngram_tokens(
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
r"""
- [`LogitsProcessor`] that enforces no repetition of n-grams. See
+ N-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+ sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). In text generation,
+ avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`] enforces no
+ repetition of n-grams by setting the scores of banned tokens to negative infinity, which eliminates those tokens
+ from consideration when further processing the scores. For the original implementation, see
[Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
+
+ Use n-gram penalties with care. For instance, penalizing 2-grams (bigrams) in an article about the city of New York
+ might lead to undesirable outcomes where the city's name appears only once in the entire text.
+ [Reference](https://huggingface.co/blog/how-to-generate)
+
Args:
ngram_size (`int`):
All ngrams of size `ngram_size` can only occur once.
+
+ Examples:
+
+ ```py
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ >>> inputs = tokenizer(["I enjoy playing football"], return_tensors="pt")
+
+ >>> output = model.generate(**inputs, max_length=50)
+ >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
+ "I enjoy playing football on the weekends, but I'm not a big fan of the idea of playing in the middle of the night. I'm not a big fan of the idea of playing in the middle of the night. I'm not a big"
+
+ >>> # Now let's pass no_repeat_ngram_size to model.generate. This stops the repetitions in the output.
+ >>> output = model.generate(**inputs, max_length=50, no_repeat_ngram_size=2)
+ >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
+ I enjoy playing football on the weekends, but I'm not a big fan of the idea of playing in the middle of a game. I think it's a bit of an overreaction to the fact that we're playing a team that's playing
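+
+ >>> # The processor itself can also be passed to `generate` (a sketch equivalent to `no_repeat_ngram_size=2`):
+ >>> from transformers import LogitsProcessorList, NoRepeatNGramLogitsProcessor
+ >>> processors = LogitsProcessorList([NoRepeatNGramLogitsProcessor(ngram_size=2)])
+ >>> output = model.generate(**inputs, max_length=50, logits_processor=processors)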
+ ```
"""
def __init__(self, ngram_size: int):
@@ -491,7 +554,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
num_batch_hypotheses = scores.shape[0]
cur_len = input_ids.shape[-1]
banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
-
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")