Commit
Update logits_process.py docstrings to clarify penalty and reward cases (attempt #2) (#26784)

* Update logits_process.py docstrings + match arg fields to __init__'s

* Ran `make style`
larekrow authored Oct 17, 2023
1 parent 85e9d64 commit 0b8604d
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/transformers/generation/logits_process.py
@@ -276,9 +276,14 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     selected. The formula can be seen in the original [paper](https://arxiv.org/pdf/1909.05858.pdf). According to the
     paper a penalty of around 1.2 yields a good balance between truthful generation and lack of repetition.
 
+    This technique can also be used to reward and thus encourage repetition in a similar manner. To penalize and reduce
+    repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
+    repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.
+
     Args:
-        repetition_penalty (`float`):
-            The parameter for repetition penalty. 1.0 means no penalty. See [this
+        penalty (`float`):
+            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
+            tokens. Between 0.0 and 1.0 rewards previously generated tokens. See [this
             paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
 
     Examples:
@@ -313,7 +318,7 @@ def __init__(self, penalty: float):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)
 
-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         scores.scatter_(1, input_ids, score)
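
For illustration only (not part of this commit): a minimal sketch of the two regimes the updated docstring describes, run directly against the public RepetitionPenaltyLogitsProcessor class. The tensor values are made-up assumptions.

import torch
from transformers import RepetitionPenaltyLogitsProcessor

input_ids = torch.tensor([[0, 1, 2]])           # tokens generated so far
scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])  # raw logits over a toy 4-token vocabulary

penalize = RepetitionPenaltyLogitsProcessor(penalty=1.2)  # > 1.0 discourages repeating tokens 0-2
reward = RepetitionPenaltyLogitsProcessor(penalty=0.8)    # in (0.0, 1.0) encourages repeating them

# __call__ writes back into the scores tensor via scatter_, so clone() keeps the original intact
print(penalize(input_ids, scores.clone()))  # tensor([[ 1.6667, -1.2000,  0.4167,  3.0000]])
print(reward(input_ids, scores.clone()))    # tensor([[ 2.5000, -0.8000,  0.6250,  3.0000]])

Positive logits are divided by the penalty and negative logits multiplied by it, so a value above 1.0 always pushes previously generated tokens away from selection, while a value below 1.0 pulls them toward it.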
@@ -322,11 +327,18 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 
 class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.
+    [`LogitsProcessor`] that avoids hallucination by boosting the probabilities of tokens found within the original
+    input.
+
+    This technique can also be used to reward and thus encourage hallucination (or creativity) in a similar manner. To
+    penalize and reduce hallucination, use `penalty` values above 1.0, where a higher value penalizes more strongly. To
+    reward and encourage hallucination, use `penalty` values between 0.0 and 1.0, where a lower value rewards more
+    strongly.
 
     Args:
-        hallucination_penalty (`float`):
-            The parameter for hallucination penalty. 1.0 means no penalty.
+        penalty (`float`):
+            The parameter for hallucination penalty. 1.0 means no penalty. Above 1.0 penalizes hallucination. Between
+            0.0 and 1.0 rewards hallucination.
         encoder_input_ids (`torch.LongTensor`):
             The encoder_input_ids that should be repeated within the decoder ids.
     """
@@ -342,7 +354,7 @@ def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, self.encoder_input_ids)
 
-        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        # if score < 0 then hallucination penalty has to be multiplied to increase the token probabilities
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         scores.scatter_(1, self.encoder_input_ids, score)
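
Again for illustration only (not part of this commit): a minimal sketch of the encoder case, where tokens from the original input are boosted when `penalty` is above 1.0. Tensor values are made-up assumptions.

import torch
from transformers import EncoderRepetitionPenaltyLogitsProcessor

encoder_input_ids = torch.tensor([[0, 2]])      # tokens from the original (encoder) input
input_ids = torch.tensor([[1]])                 # decoder tokens so far (unused by this processor)
scores = torch.tensor([[2.0, -1.0, 0.5, 3.0]])  # raw logits over a toy 4-token vocabulary

# penalty > 1.0 penalizes hallucination, i.e. boosts tokens 0 and 2 from the encoder input
boost = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=encoder_input_ids)
print(boost(input_ids, scores.clone()))  # tensor([[ 4., -1.,  1.,  3.]])

The class stores the reciprocal of `penalty` internally, which is why the same torch.where formula that lowers logits in RepetitionPenaltyLogitsProcessor raises them here.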