diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 99d99bdf254b10..05288708c4ce19 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -438,10 +438,28 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
 
 
 def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
+    """
+    Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The output of generated ngrams looks
+    like this: [{(40,): [2883], (2883,): [2712], (2712,): [4346]}].
+
+    Args:
+        ngram_size (`int`):
+            The number of sequential tokens taken as a group which may only occur once before being banned.
+        prev_input_ids (`torch.Tensor`):
+            Generated token ids for the current hypothesis.
+        num_hypos (`int`):
+            The number of hypotheses for which n-grams need to be generated.
+
+    Returns:
+        generated_ngrams (`list` of `dict`):
+            One dictionary of generated ngrams per hypothesis.
+    """
+    # Initialize an empty list of dictionaries, one for each hypothesis (index) in the range of num_hypos
     generated_ngrams = [{} for _ in range(num_hypos)]
     for idx in range(num_hypos):
         gen_tokens = prev_input_ids[idx].tolist()
         generated_ngram = generated_ngrams[idx]
+        # Loop through each n-gram of size ngram_size in the list of tokens (gen_tokens)
         for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
             prev_ngram_tuple = tuple(ngram[:-1])
             generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
@@ -449,6 +467,22 @@ def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
 
 
 def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
+    """
+    Determines the banned tokens for the current hypothesis based on previously generated n-grams.
+
+    Args:
+        banned_ngrams (`dict`):
+            A dictionary containing previously generated n-grams for each hypothesis.
+        prev_input_ids (`torch.Tensor`):
+            Generated token ids for the current hypothesis.
+        ngram_size (`int`):
+            The number of sequential tokens taken as a group which may only occur once before being banned.
+        cur_len (`int`):
+            The current length of the token sequences for which the n-grams are being checked.
+
+    Returns:
+        List of tokens that are banned.
+    """
     # Before decoding the next token, prevent decoding of ngrams that have already appeared
     start_idx = cur_len + 1 - ngram_size
     ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
@@ -462,9 +496,7 @@ def _calc_banned_ngram_tokens(
     if cur_len + 1 < ngram_size:
         # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
         return [[] for _ in range(num_hypos)]
-
     generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
-
     banned_tokens = [
         _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
         for hypo_idx in range(num_hypos)
@@ -474,12 +506,43 @@ def _calc_banned_ngram_tokens(
 
 class NoRepeatNGramLogitsProcessor(LogitsProcessor):
     r"""
-    [`LogitsProcessor`] that enforces no repetition of n-grams. See
+    N-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
+    sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). In text generation,
+    avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`] enforces no
+    repetition of n-grams by setting the scores of banned tokens to negative infinity, which eliminates those tokens
+    from consideration when further processing the scores.
     [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
 
+    <Tip>
+
+    Use n-gram penalties with care. For instance, penalizing 2-grams (bigrams) in an article about the city of New York
+    might lead to undesirable outcomes where the city's name appears only once in the entire text.
+    [Reference](https://huggingface.co/blog/how-to-generate)
+
+    </Tip>
+
     Args:
         ngram_size (`int`):
             All ngrams of size `ngram_size` can only occur once.
+
+    Examples:
+
+    ```py
+    >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM
+
+    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+    >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+    >>> inputs = tokenizer(["I enjoy playing football"], return_tensors="pt")
+
+    >>> output = model.generate(**inputs, max_length=50)
+    >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
+    I enjoy playing football on the weekends, but I'm not a big fan of the idea of playing in the middle of the night. I'm not a big fan of the idea of playing in the middle of the night. I'm not a big
+
+    >>> # Now let's add no_repeat_ngram_size to model.generate. This should stop the repetitions in the output.
+    >>> output = model.generate(**inputs, max_length=50, no_repeat_ngram_size=2)
+    >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
+    I enjoy playing football on the weekends, but I'm not a big fan of the idea of playing in the middle of a game. I think it's a bit of an overreaction to the fact that we're playing a team that's playing
+    ```
     """
 
     def __init__(self, ngram_size: int):
@@ -491,7 +554,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         num_batch_hypotheses = scores.shape[0]
         cur_len = input_ids.shape[-1]
         banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
-
        for i, banned_tokens in enumerate(banned_batch_tokens):
             scores[i, banned_tokens] = -float("inf")
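For reviewers who want to see the banning logic outside the `generate()` stack, below is a minimal standalone sketch of the bookkeeping the new docstrings describe. It mirrors what `_get_ngrams` and `_get_generated_ngrams` do, but on plain Python lists instead of torch tensors; the helper names and the second token sequence are illustrative only, not part of the library.

```python
# Minimal sketch (not part of the diff) of the n-gram bookkeeping described above.
from typing import Dict, List, Tuple


def get_ngrams(ngram_size: int, tokens: List[int]) -> Dict[Tuple[int, ...], List[int]]:
    """Map each (ngram_size - 1)-token prefix to the tokens that have followed it."""
    ngrams: Dict[Tuple[int, ...], List[int]] = {}
    for ngram in zip(*[tokens[i:] for i in range(ngram_size)]):
        prefix = tuple(ngram[:-1])
        ngrams.setdefault(prefix, []).append(ngram[-1])
    return ngrams


def banned_tokens(ngrams: Dict[Tuple[int, ...], List[int]], tokens: List[int], ngram_size: int) -> List[int]:
    """Tokens that would complete an already-seen n-gram after the last (ngram_size - 1) tokens."""
    cur_len = len(tokens)
    prefix = tuple(tokens[cur_len + 1 - ngram_size : cur_len])
    return ngrams.get(prefix, [])


# Same example as the _get_ngrams docstring (single hypothesis, ngram_size=2).
print(get_ngrams(2, [40, 2883, 2712, 4346]))
# {(40,): [2883], (2883,): [2712], (2712,): [4346]}

# Illustrative history: the last token repeats an earlier prefix, so its known
# continuation is banned and the bigram (40, 2883) cannot be produced twice.
history = [40, 2883, 2712, 40]
print(banned_tokens(get_ngrams(2, history), history, 2))
# [2883]
```

In the processor itself this lookup runs once per hypothesis in the batch, and the returned token ids are exactly the positions whose scores `__call__` sets to `-float("inf")`.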