Generate: unify LogitsWarper and LogitsProcessor (#32626)
gante committed Aug 16, 2024
1 parent 5fd7ca7 commit 70d5df6
Showing 20 changed files with 186 additions and 623 deletions.
4 changes: 0 additions & 4 deletions docs/source/en/internal/generation_utils.md
@@ -158,9 +158,6 @@ generation.
[[autodoc]] LogitsProcessorList
- __call__

-[[autodoc]] LogitsWarper
-    - __call__
-
[[autodoc]] MinLengthLogitsProcessor
- __call__

@@ -421,4 +418,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] WatermarkDetector
- __call__

3 changes: 0 additions & 3 deletions docs/source/ja/internal/generation_utils.md
@@ -157,9 +157,6 @@ generation_output[:2]
[[autodoc]] LogitsProcessorList
- __call__

-[[autodoc]] LogitsWarper
-    - __call__
-
[[autodoc]] MinLengthLogitsProcessor
- __call__

3 changes: 0 additions & 3 deletions docs/source/zh/internal/generation_utils.md
@@ -151,9 +151,6 @@ generation_output[:2]
[[autodoc]] LogitsProcessorList
- __call__

-[[autodoc]] LogitsWarper
-    - __call__
-
[[autodoc]] MinLengthLogitsProcessor
- __call__

4 changes: 2 additions & 2 deletions src/transformers/generation/configuration_utils.py
@@ -190,9 +190,9 @@ class GenerationConfig(PushToHubMixin):
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
can allow different forms of each word.
renormalize_logits (`bool`, *optional*, defaults to `False`):
-Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
+Whether to renormalize the logits after applying all the logits processors (including the custom
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
-are normalized but some logit processors or warpers break the normalization.
+are normalized but some logit processors break the normalization.
constraints (`List[Constraint]`, *optional*):
Custom constraints that can be added to the generation to ensure that the output will contain the use of
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
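For illustration, a minimal sketch of how `renormalize_logits` is typically set; when enabled, generation appends a log-softmax renormalization step after the other processors. The sampling values below are arbitrary, and `model`/`inputs` are assumed to exist already:

```python
from transformers import GenerationConfig

# Sketch: renormalize so downstream search/sampling sees proper log-probabilities.
generation_config = GenerationConfig(
    do_sample=True,
    top_k=50,
    temperature=0.7,
    renormalize_logits=True,  # recommended whenever processors may break normalization
)
# model.generate(**inputs, generation_config=generation_config)
```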
54 changes: 30 additions & 24 deletions src/transformers/generation/logits_process.py
@@ -55,6 +55,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
class LogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

+    def __init__(self):
+        logger.warning_once(
+            "`LogitsWarper` is deprecated and will be removed in v4.48. Your class should inherit `LogitsProcessor` "
+            "instead, which has the same properties and interface."
+        )

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
raise NotImplementedError(
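As a hedged illustration of the migration this deprecation asks for, a hypothetical user-defined warper would now subclass `LogitsProcessor` directly. The class name and behaviour below are made up for the example; only the parent class changes:

```python
import torch
from transformers import LogitsProcessor  # was: from transformers import LogitsWarper

class ScaleByConstant(LogitsProcessor):  # hypothetical custom class, previously a `LogitsWarper`
    """Divides the next-token scores by a fixed constant."""

    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Same __call__ signature as before -- only the base class changes.
        return scores / self.factor
```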
@@ -64,9 +70,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to

class LogitsProcessorList(list):
"""
-This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
-`scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
-[`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
+This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor.
+This class inherits from list and adds a specific *__call__* method to apply each [`LogitsProcessor`] to the
+inputs.
"""

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
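A short usage sketch of `LogitsProcessorList` chaining two processors over dummy tensors; the token ids, vocabulary size, and parameter values are only illustrative:

```python
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper

processors = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.7),  # applied first
        TopKLogitsWarper(top_k=50),    # then top-k masking
    ]
)

input_ids = torch.tensor([[101, 2023, 2003]])  # (batch, seq_len) -- dummy token ids
scores = torch.randn(1, 32000)                 # (batch, vocab_size) -- dummy next-token logits
warped = processors(input_ids, scores)         # each processor is called in order
```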
@@ -233,9 +239,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores_processed


-class TemperatureLogitsWarper(LogitsWarper):
+class TemperatureLogitsWarper(LogitsProcessor):
r"""
-[`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
+[`LogitsProcessor`] for temperature (exponential scaling output probability distribution), which effectively means
that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
[`TopKLogitsWarper`].
Expand Down Expand Up @@ -408,10 +414,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores_processed


-class TopPLogitsWarper(LogitsWarper):
+class TopPLogitsWarper(LogitsProcessor):
"""
-[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Often
-used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
+[`LogitsProcessor`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
+Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
Args:
top_p (`float`):
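A usage sketch of the top-p warper on hand-picked logits; the values and vocabulary size are arbitrary:

```python
import torch
from transformers import TopPLogitsWarper

scores = torch.tensor([[5.0, 4.0, 1.0, 0.5, 0.1]])  # toy logits for a 5-token vocabulary
warper = TopPLogitsWarper(top_p=0.9)
filtered = warper(torch.tensor([[0]]), scores)
# Tokens outside the smallest set whose cumulative probability reaches top_p
# are set to -inf, so they can never be sampled.
print(filtered)
```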
@@ -475,10 +481,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores_processed


-class TopKLogitsWarper(LogitsWarper):
+class TopKLogitsWarper(LogitsProcessor):
r"""
-[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together
-with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
+[`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. Often used
+together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
Args:
top_k (`int`):
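The equivalent masking step written directly in PyTorch, as a simplified sketch (it omits the `min_tokens_to_keep` safeguard of the real processor):

```python
import torch

def top_k_mask(scores: torch.FloatTensor, top_k: int = 50) -> torch.FloatTensor:
    # Keep only the k highest logits; everything else becomes -inf.
    top_k = min(top_k, scores.size(-1))
    kth_best = torch.topk(scores, top_k).values[..., -1, None]  # value of the k-th best token
    return scores.masked_fill(scores < kth_best, float("-inf"))

masked = top_k_mask(torch.randn(1, 32000), top_k=50)
```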
@@ -528,9 +534,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores_processed


-class MinPLogitsWarper(LogitsWarper):
+class MinPLogitsWarper(LogitsProcessor):
"""
-[`LogitsWarper`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
+[`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
probability of the most likely token. As a result, the filter becomes more agressive in the presence of
high-probability tokens, which is a sign of a confident output that we shouldn't deviate from.
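A simplified PyTorch sketch of the min-p rule described above (the real processor additionally guarantees `min_tokens_to_keep` survivors):

```python
import torch

def min_p_mask(scores: torch.FloatTensor, min_p: float = 0.05) -> torch.FloatTensor:
    probs = torch.softmax(scores, dim=-1)
    # The threshold scales with the probability of the most likely token:
    # confident distributions prune more aggressively than flat ones.
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    return scores.masked_fill(probs < threshold, float("-inf"))

masked = min_p_mask(torch.randn(1, 32000), min_p=0.05)
```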
@@ -605,11 +611,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores_processed


-class TypicalLogitsWarper(LogitsWarper):
+class TypicalLogitsWarper(LogitsProcessor):
r"""
-[`LogitsWarper`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens whose
-log probability is close to the entropy of the token probability distribution. This means that the most likely
-tokens may be discarded in the process.
+[`LogitsProcessor`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens
+whose log probability is close to the entropy of the token probability distribution. This means that the most
+likely tokens may be discarded in the process.
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
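A condensed sketch of the typical-decoding filter described above: rank tokens by how close their surprisal is to the distribution's entropy and keep the most "typical" ones until `mass` is covered. This is simplified from the paper's description, not the library's exact code:

```python
import torch

def typical_mask(scores: torch.FloatTensor, mass: float = 0.9) -> torch.FloatTensor:
    log_probs = torch.log_softmax(scores, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1, keepdim=True)
    # Rank tokens by how close their surprisal (-log p) is to the entropy.
    distance = ((-log_probs) - entropy).abs()
    sorted_dist, sorted_idx = torch.sort(distance, dim=-1)
    cum_probs = probs.gather(-1, sorted_idx).cumsum(dim=-1)
    # Keep the most typical tokens until their cumulative probability reaches `mass`.
    keep_count = (cum_probs < mass).sum(dim=-1, keepdim=True) + 1
    remove_sorted = torch.arange(scores.size(-1)) >= keep_count
    remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
    return scores.masked_fill(remove, float("-inf"))

masked = typical_mask(torch.randn(1, 32000), mass=0.9)
```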
@@ -693,9 +699,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores_processed


-class EpsilonLogitsWarper(LogitsWarper):
+class EpsilonLogitsWarper(LogitsProcessor):
r"""
-[`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
+[`LogitsProcessor`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
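Epsilon-sampling reduces to a single probability threshold; a simplified sketch (without the `min_tokens_to_keep` fallback the real processor applies when nothing survives):

```python
import torch

def epsilon_mask(scores: torch.FloatTensor, epsilon: float = 3e-4) -> torch.FloatTensor:
    probs = torch.softmax(scores, dim=-1)
    # Every token whose probability falls below the fixed cutoff is removed.
    return scores.masked_fill(probs < epsilon, float("-inf"))

masked = epsilon_mask(torch.randn(1, 32000), epsilon=3e-4)
```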
@@ -762,15 +768,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores_processed


-class EtaLogitsWarper(LogitsWarper):
+class EtaLogitsWarper(LogitsProcessor):
r"""
-[`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
+[`LogitsProcessor`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
-must be set to `True` for this `LogitsWarper` to work.
+must be set to `True` for this `LogitsProcessor` to work.
Args:
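A sketch of the dynamic cutoff, following the eta-sampling paper's definition `eta = min(epsilon, sqrt(epsilon) * exp(-entropy))`; the real processor additionally keeps at least `min_tokens_to_keep` tokens:

```python
import torch

def eta_mask(scores: torch.FloatTensor, epsilon: float = 3e-4) -> torch.FloatTensor:
    probs = torch.softmax(scores, dim=-1)
    entropy = torch.distributions.Categorical(logits=scores).entropy().unsqueeze(-1)
    eps = torch.tensor(epsilon)
    # The cutoff shrinks when the distribution is high-entropy (uncertain),
    # so flat distributions keep more candidate tokens.
    eta = torch.min(eps, torch.sqrt(eps) * torch.exp(-entropy))
    return scores.masked_fill(probs < eta, float("-inf"))

masked = eta_mask(torch.randn(1, 32000), epsilon=3e-4)
```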
@@ -1708,9 +1714,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores_processed


-class LogitNormalization(LogitsProcessor, LogitsWarper):
+class LogitNormalization(LogitsProcessor):
r"""
-[`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
+[`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
the scores are normalized when comparing the hypotheses.
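A small sketch of what the normalization step does: after log-softmax the exponentiated scores sum to 1, which is what beam search assumes when comparing hypotheses. The input values are arbitrary:

```python
import torch
from transformers import LogitNormalization

scores = torch.tensor([[2.0, 1.0, 0.5]])            # unnormalized logits
normalized = LogitNormalization()(torch.tensor([[0]]), scores)
print(normalized)                                   # equivalent to scores.log_softmax(dim=-1)
print(normalized.exp().sum(dim=-1))                 # ~1.0 after renormalization
```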