[DOCS] Example for LogitsProcessor class #24848

Merged
merged 11 commits on Jul 20, 2023
Changes from 2 commits
43 changes: 40 additions & 3 deletions src/transformers/generation/logits_process.py
@@ -193,12 +193,49 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
[`LogitsProcessor`] that prevents the repetition of previous tokens through an exponential penalty.
This technique involves adjusting the scores assigned to previously generated tokens, discouraging repetition.
This technique shares some similarities with coverage mechanisms and other techniques aimed at reducing repetition.
During text generation, the probability distribution for the next token is computed from token scores that are adjusted according to each token's occurrence in the generated sequence so far. Tokens that have already appeared receive lower scores and are therefore less likely to be selected again.
The formula takes several variables into account, which are described in detail in the [paper](https://arxiv.org/pdf/1909.05858.pdf).
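
As a rough, hypothetical sketch, such an exponential penalty can be applied to the scores of every token that already appears in `input_ids` (the `apply_repetition_penalty` helper below is illustrative only, not the implementation used by this class):

```python
import torch


def apply_repetition_penalty(scores: torch.FloatTensor, input_ids: torch.LongTensor, penalty: float = 1.2):
    # Gather the current scores of every token already present in the generated sequence
    repeated = torch.gather(scores, 1, input_ids)
    # Dividing positive scores (or multiplying negative ones) by the penalty lowers them
    repeated = torch.where(repeated < 0, repeated * penalty, repeated / penalty)
    # Write the lowered scores back at the positions of the repeated tokens
    return scores.scatter(1, input_ids, repeated)
```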

<Tip>

The effectiveness of penalized sampling relies on the model having learned a reliable distribution of tokens during training.

</Tip>

Args:
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
The parameter for repetition penalty. 1.0 means no penalty. According to the paper, a penalty of around 1.2
yields a good balance between truthful generation and lack of repetition.
See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.

Examples:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> # Initialize the model and the tokenizer
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> # add_prefix_space adds an initial space to the input so the leading word is treated like any other word
>>> # (not important for this demonstration)
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
>>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")

>>> # This shows a normal call to generate without any specific parameters
>>> summary_ids = model.generate(inputs["input_ids"], max_length=20)
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
I'm not going to lie, I'm not going to lie. I'm not going to lie

>>> # This applies a penalty to repeated tokens
>>> penalized_ids = model.generate(inputs["input_ids"], max_length=20, repetition_penalty=1.2)
>>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
I'm not going to lie, I was really excited about this. It's a great game

```
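
The same penalty can also be applied by passing the processor to `generate` explicitly. This is a minimal sketch reusing `model` and `inputs` from above, and is meant to be equivalent in effect to `repetition_penalty=1.2`:

```python
>>> from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor

>>> processors = LogitsProcessorList([RepetitionPenaltyLogitsProcessor(penalty=1.2)])
>>> penalized_ids = model.generate(inputs["input_ids"], max_length=20, logits_processor=processors)
```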
"""

def __init__(self, penalty: float):