Generate: Add new decoding strategy "DoLa" in .generate() #29619

Merged
33 commits merged on Jul 9, 2024

Commits (33)
5c35893 add support for DoLa decoding (voidism, Mar 5, 2024)
bb66df5 add docs; remove deprecated function (voidism, Mar 5, 2024)
9769d6d add test code for DoLa decoding (voidism, Mar 12, 2024)
4dea208 update docs and paper link (voidism, Mar 12, 2024)
1bcdf79 solved the issues that made tests failed on CircleCI (voidism, Mar 12, 2024)
c31d664 ruff reformatted (voidism, Mar 12, 2024)
8ea188b update DoLa decoding; test cases for llama/mistral/mixtral/gemma; docs (voidism, Mar 18, 2024)
76708c8 fix formatting; fix failed test cases (voidism, Mar 18, 2024)
419be60 fix other test cases (voidism, Mar 18, 2024)
69a11e6 ruff reformatted; mamba cache issue unsolved (voidism, Mar 18, 2024)
87fc406 remove keyword argument 'model_inputs' to match upstream changes (voidism, Mar 18, 2024)
efe3e34 improve documentation (voidism, Mar 19, 2024)
8654787 fixed suggestions from @gante (voidism, Mar 20, 2024)
7fcf990 ruff reformated (voidism, Mar 20, 2024)
21a646a moved config warning of dola generation from `utils.py` to `configura… (voidism, Mar 21, 2024)
2b81d73 fixed suggestions from @amyeroberts (voidism, Mar 25, 2024)
cc067cc fixed format issue; removed print; added explanation (voidism, Mar 25, 2024)
679afea remove trailing whitespace (voidism, Mar 25, 2024)
7e9367b ruff reformat to pass test (voidism, Mar 25, 2024)
4bd54c8 fixed suggestions from @amyeroberts on Mar 28 (voidism, May 19, 2024)
92de5cd fix failed CI tests (voidism, May 19, 2024)
cff5661 ruff reformatted; fixed missing argument generation_config (voidism, May 19, 2024)
ba61bf4 make `dola_layers` not optional (voidism, May 19, 2024)
87ea8d8 fix divergence w main (gante, Jun 27, 2024)
57c89af fix dola test on mamba (gante, Jun 27, 2024)
1c44900 rwkv test (wont fix (gante, Jun 27, 2024)
9d9f894 slow tests running in fp16 (gante, Jun 28, 2024)
a8993ef make fixup (gante, Jul 9, 2024)
aaf560f remove redundant fn (gante, Jul 9, 2024)
dc2192c final rebase divergences (gante, Jul 9, 2024)
520202d this one was missing (gante, Jul 9, 2024)
ce64a5f a few more nits (gante, Jul 9, 2024)
8b6653c skip stateful models (gante, Jul 9, 2024)
64 changes: 60 additions & 4 deletions docs/source/en/generation_strategies.md
@@ -178,7 +178,7 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te

The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.

KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper.
@@ -213,11 +213,11 @@ I like rock music because it's loud and energetic. I like to listen to it when I

## Watermarking

The `generate()` method supports watermarking the generated text by randomly marking a portion of tokens as "green".
When generating, the "green" tokens will have a small 'bias' value added to their logits, thus having a higher chance of being generated.
The watermarked text can be detected by calculating the proportion of "green" tokens in the text and estimating how likely it is
statistically to obtain that amount of "green" tokens for human-generated text. This watermarking strategy was proposed in the paper
["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). For more information on
the inner functioning of watermarking, it is recommended to refer to the paper.

The watermarking can be used with any generative model in `transformers` and does not require an extra classification model
@@ -484,3 +484,59 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t

Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model-based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
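
As a quick illustration (not part of this PR's changes), prompt lookup only needs that one extra argument; the sketch below assumes a `model`, `tokenizer`, and tokenized `inputs` defined as in the other examples in this guide:

```python
# n-gram based assisted decoding: no assistant model is required, just set `prompt_lookup_num_tokens`
outputs = model.generate(**inputs, prompt_lookup_num_tokens=3, max_new_tokens=50)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
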
### DoLa Decoding

**D**ecoding by C**o**ntrasting **La**yers (DoLa) is a contrastive decoding strategy for improving the factuality and reducing the
hallucinations of LLMs, as described in the ICLR 2024 paper [DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models](https://arxiv.org/abs/2309.03883).

DoLa works by contrasting the logits obtained from the final layer against those from earlier layers, thus
amplifying the factual knowledge localized to particular parts of the transformer layers.

To activate DoLa decoding when calling the `model.generate` function, do the following two things:
1. Set the `dola_layers` argument, which can be either a string or a list of integers.
    - If set to a string, it can be one of `low` or `high`.
    - If set to a list of integers, it should contain layer indices between 0 and the total number of layers in the model. The 0-th layer is the word embedding layer, the 1st layer is the first transformer layer, and so on.
2. Setting `repetition_penalty = 1.2` is suggested to reduce repetition in DoLa decoding.

See the following examples for DoLa decoding with the 32-layer LLaMA-7B model.

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
>>> model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16)
>>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
>>> model.to(device)
>>> set_seed(42)

>>> text = "On what date was the Declaration of Independence officially signed?"
>>> inputs = tokenizer(text, return_tensors="pt").to(device)

# Vanilla greedy decoding
>>> vanilla_output = model.generate(**inputs, do_sample=False, max_new_tokens=50)
>>> tokenizer.batch_decode(vanilla_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nThe Declaration of Independence was signed on July 4, 1776.\nWhat was the date of the signing of the Declaration of Independence?\nThe Declaration of Independence was signed on July 4,']

# DoLa decoding contrasting the higher part of the layers (layers 16, 18, ..., 30)
>>> dola_high_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers='high')
>>> tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nJuly 4, 1776, when the Continental Congress voted to separate from Great Britain. The 56 delegates to the Continental Congress signed the Declaration on August 2, 1776.']

# DoLa decoding with contrasting specific layers (layers 28 and 30)
>>> dola_custom_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers=[28,30], repetition_penalty=1.2)
>>> tokenizer.batch_decode(dola_custom_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nIt was officially signed on 2 August 1776, when 56 members of the Second Continental Congress, representing the original 13 American colonies, voted unanimously for the resolution for independence. The 2']
```

#### Understanding the `dola_layers` argument

`dola_layers` stands for the candidate layers in premature layer selection, as described in the DoLa paper. The selected premature layer will be contrasted with the final layer.

Setting `dola_layers` to `'low'` or `'high'` will select the lower or higher part of the layers to contrast, respectively (a short illustrative sketch of this selection logic follows the list).
- For `N`-layer models with `N <= 40` layers, the layers in `range(0, N // 2, 2)` and `range(N // 2, N, 2)` are used for `'low'` and `'high'`, respectively.
- For models with `N > 40` layers, the layers in `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for `'low'` and `'high'`, respectively.
- If the model has tied word embeddings, the word embedding (0-th) layer is skipped and the candidates start from the 2nd layer, as an early exit from the word embeddings becomes an identity function.
- To contrast manually specified layers, set `dola_layers` to a list of integer layer indices. For example, setting `dola_layers=[28,30]` will contrast the final layer (the 32nd layer) with the 28th and 30th layers.
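
The ranges above can be reproduced with a small illustrative sketch; the helper name `candidate_premature_layers` is made up for this example and is not the function used inside `transformers`:

```python
def candidate_premature_layers(num_layers: int, dola_layers: str, tied_word_embeddings: bool = False) -> list:
    """Illustrative re-derivation of the candidate premature layers for 'low'/'high'."""
    start = 2 if tied_word_embeddings else 0  # skip the word-embedding (0-th) layer for tied embeddings
    if num_layers <= 40:
        low, high = range(start, num_layers // 2, 2), range(num_layers // 2, num_layers, 2)
    else:
        low, high = range(start, 20, 2), range(num_layers - 20, num_layers, 2)
    return list(low if dola_layers == "low" else high)

print(candidate_premature_layers(32, "high"))  # [16, 18, 20, 22, 24, 26, 28, 30], contrasted with the final (32nd) layer
print(candidate_premature_layers(32, "low"))   # [0, 2, 4, 6, 8, 10, 12, 14]
```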

The paper suggests contrasting the `'high'` layers to improve short-answer tasks like TruthfulQA, and contrasting the `'low'` layers to improve all other long-answer reasoning tasks, such as GSM8K, StrategyQA, FACTOR, and VicunaQA. Applying DoLa to smaller models like GPT-2 is not recommended, as shown by the results in Appendix N of the paper.
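
As an additional sketch (not part of the original example set), a `'low'`-layer call for a longer reasoning prompt could look like the following; the prompt is an arbitrary GSM8K-style question, the output is omitted because it depends on the model and hardware, and `model`, `tokenizer`, and `device` are reused from the example above:

```python
reasoning_prompt = (
    "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. "
    "How many clips did Natalia sell altogether in April and May? Let's think step by step."
)
reasoning_inputs = tokenizer(reasoning_prompt, return_tensors="pt").to(device)

# DoLa decoding contrasting the lower part of the layers, with the recommended repetition penalty
dola_low_output = model.generate(
    **reasoning_inputs, do_sample=False, max_new_tokens=100, dola_layers="low", repetition_penalty=1.2
)
print(tokenizer.batch_decode(dola_low_output[:, reasoning_inputs.input_ids.shape[-1]:], skip_special_tokens=True))
```
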
38 changes: 38 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -60,6 +60,7 @@ class GenerationMode(ExplicitEnum):
GREEDY_SEARCH = "greedy_search"
SAMPLE = "sample"
ASSISTED_GENERATION = "assisted_generation"
DOLA_GENERATION = "dola_generation"
# Beam methods
BEAM_SEARCH = "beam_search"
BEAM_SAMPLE = "beam_sample"
@@ -81,6 +82,7 @@ class GenerationConfig(PushToHubMixin):
- *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
- *dola decoding* if `dola_layers` is passed to `.generate()`

To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).

@@ -305,6 +307,18 @@ class GenerationConfig(PushToHubMixin):
max_matching_ngram_size (`int`, *optional*, default to `None`):
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.

> Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)

dola_layers (`str` or `List[int]`, *optional*):
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
"low" selects the lower half of the layers (at most the first 20 layers), and "high" selects the upper half of the
layers (at most the last 20 layers).
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.

> Parameters specific to the caching mechanism:

cache_implementation (`str`, *optional*, default to `None`):
@@ -397,6 +411,9 @@ def __init__(self, **kwargs):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

# DoLa generation
self.dola_layers = kwargs.pop("dola_layers", None)

# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
@@ -495,6 +512,16 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
)

# DoLa generation may extend some generation modes
if self.dola_layers is not None:
    if generation_mode in ("greedy_search", "sample"):
        generation_mode = GenerationMode.DOLA_GENERATION
    else:
        raise ValueError(
            "You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
            "is only supported with Greedy Search and Sample."
        )
return generation_mode
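
As an illustrative sketch (not part of the diff itself), the routing above can be exercised directly on a `GenerationConfig`; the import path assumes both classes live in `configuration_utils.py`, as shown in this file:

```python
from transformers.generation.configuration_utils import GenerationConfig, GenerationMode

# With default (greedy) settings, setting `dola_layers` switches the mode to DoLa generation.
config = GenerationConfig(dola_layers="high", repetition_penalty=1.2)
assert config.get_generation_mode() == GenerationMode.DOLA_GENERATION

# DoLa only extends greedy search and sampling, so combining it with beam search raises the ValueError above.
beam_config = GenerationConfig(dola_layers="high", num_beams=4, repetition_penalty=1.2)
try:
    beam_config.get_generation_mode()
except ValueError as err:
    print(err)
```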

def validate(self, is_init=False):
@@ -700,6 +727,17 @@ def validate(self, is_init=False):
"`generate()` (or a pipeline) directly."
)

# 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
    dola_decoding_wrong_parameter_msg = (
        "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
        "which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
    )
    warnings.warn(
        dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
        UserWarning,
    )
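
A small illustrative sketch of the check above (not part of the diff): leaving `repetition_penalty` at its default of 1.0 while setting `dola_layers` makes `validate()` emit the `UserWarning`:

```python
import warnings

from transformers.generation.configuration_utils import GenerationConfig

config = GenerationConfig(dola_layers="low")  # repetition_penalty defaults to 1.0
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config.validate()

# Expect the recommendation to use `repetition_penalty>=1.2` with DoLa decoding.
print([str(w.message) for w in caught if "repetition_penalty" in str(w.message)])
```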

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],