Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: Add new decoding strategy "DoLa" in .generate() #29619

Merged
merged 33 commits into from
Jul 9, 2024

Conversation

voidism
Copy link
Contributor

@voidism voidism commented Mar 12, 2024

What does this PR do?

Fixes #29524

We add the support for a new decoding strategy proposed in a recent paper of ICLR 2024.
The main revisions are in src/transformers/generation/utils.py and src/transformers/generation/configuration_utils.py

We also update the documentation and add the test code. Run the test by:

CUDA_VISIBLE_DEVICES=0 python examples/pytorch/text-generation/run_generation_dola.py --model_name_or_path huggyllama/llama-7b --model_type llama --dola_layers 'low'

Before submitting

Who can review?

@gante is the main contributor of the part of .generate() function, which this PR focuses on.

@gante gante self-requested a review March 13, 2024 09:44
Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@voidism thank you for this cool PR! 🔥

In addition to the interface and user experience comments left below, there is one task missing: tests. We should add two tests:

  1. A very small mixin test, to ensure the interface works on all models as expected. See here for an example.
  2. One (or more) heavy integration test(s), to ensure the method retains its correctness as we add other changes. See here for an example. You can add them on any model you believe it's appropriate.

examples/pytorch/text-generation/run_generation_dola.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
Comment on lines 2047 to 2048
mask = final_logits[0] < -1e3
base_logits[0][mask] = -1e3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment about -1e3, for future reference? Why not any other number? It is okay if it is simply a number with which you got good results empirically 🤗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The line 2047 is removed, as I can directly get the mask from the _relative_top_filter function.
The -1e3 in line 2048 is simply a number tested work empirically. Any the number that is not -float("Inf") should be working as well. I have cleaned up the code and made them all in _relative_top_filter() function. The -1e3 is assigned as the base_filter_value variable.

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
@voidism
Copy link
Contributor Author

voidism commented Mar 19, 2024

Hi @gante !

Thanks so much for your suggestions! I spent some time to add the code for test cases, and fixed the issues you mentioned.
All the CI checks were passed as well. Can you take a look at my latest commits of the code?

Please let me know if you have any other concerns or suggestions for me to fix! I would be happy to address any of the issues you may have! 🤗

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for iterating 💛

It almost could be merged as is -- the tests need to be reworked slightly. I've added a few suggestions to further improve the PR while we wait for the green light from a core maintainer 🤗

docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
tests/generation/test_utils.py Outdated Show resolved Hide resolved
tests/generation/test_utils.py Outdated Show resolved Hide resolved
tests/generation/test_utils.py Outdated Show resolved Hide resolved
tests/models/gemma/test_modeling_gemma.py Show resolved Hide resolved
tests/models/mixtral/test_modeling_mixtral.py Outdated Show resolved Hide resolved
@gante gante requested a review from amyeroberts March 20, 2024 19:08
@voidism
Copy link
Contributor Author

voidism commented Mar 20, 2024

Hi @gante !

Thanks so much for your great suggestions! I have fixed all the issues you mentioned. Just let me know if you have any other concerns or suggestions!
Thanks for requesting a review from the core maintainer! 🤗

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy with the PR 🙌

@voidism
Copy link
Contributor Author

voidism commented Mar 21, 2024

Hi @gante !

While waiting for the core maintainer's approval, I found that the validation of the parameter ranges in the generation config mainly happens in tsrc/transformers/generation/configuration_utils.py instead of src/transformers/generation/utils.py. Thus, I simply moved the warning of repetition penalty of dola generation to configuration_utils.py, and the warning will also only occur once!

However, after I committed the new code. A test case of XLM model failed, and it seems to have nothing to do with my commit. The failed case seems related to #29297

I tried syncing with the upstream but it didn't solve the issue. I wonder if you know what's the reason for this failed test case. Sorry for bothering you again!

Some tests failed!

============================= FAILURES SHORT STACK =============================
____________________ XLMModelTest.test_batching_equivalence ____________________

tests/test_modeling_common.py:745: in recursive_check
    self.assertTrue(
E   AssertionError: tensor(False) is not true : Batched and Single row outputs are not equal in XLMForQuestionAnswering for key=end_top_index. Difference=1.


FAILED tests/models/xlm/test_modeling_xlm.py::XLMModelTest::test_batching_equivalence - AssertionError: tensor(False) is not true : Batched and Single row outputs are not equal in XLMForQuestionAnswering for key=end_top_index. Difference=1.

Exited with code exit status 255

@voidism
Copy link
Contributor Author

voidism commented Mar 22, 2024

The failed test case was solved after syncing with the upstream! Please ignore my previous comment.
It's ready to merge now!

@voidism
Copy link
Contributor Author

voidism commented Mar 25, 2024

Hi @amyeroberts !

This PR is ready to merge after some iterations! Would you be able to review it and give me any suggestions you have?
Thanks a lot for the help! 🤗

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @voidism, thanks for working on adding this!

A few small comments. The main one being that the dola sampling method at the moment is way too large and needs to be broken down into smaller chunks

input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print("Answer here: ", text)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("Answer here: ", text)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

input_ids, max_new_tokens=20, temperature=0, dola_layers="low", repetition_penalty=1.2
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print("Answer here: ", text)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print("Answer here: ", text)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

@@ -788,3 +789,25 @@ def test_model_7b_4bit(self):
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

self.assertEqual(output_text, EXPECTED_TEXTS)

def test_model_2b_bf16_dola(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather we didn't add an integration test for each of these models for this new generation method, as it's expensive to run. Doing this for each new generation approach isn't scalable.

Rather, it's better to just have one integration test for specific generation methods, which checks the output for a select model cc @gante

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts I'd rather have it tested in a few key models, as we've been doing in the past for other generation methods -- generation tests are prone to false positives (due to argmax/sampling) and false negatives (due to a problem in the model used in a test).

But I understand our testing limitations, leaving the final call to you 🤗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temperally removed the test for gemma! Let me know if you want me to add it back! 🤗

Comment on lines 1246 to 1250
for model_name in [
"wav2vec",
"clvp",
"bark",
]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - one line

Suggested change
for model_name in [
"wav2vec",
"clvp",
"bark",
]
for model_name in ["wav2vec", "clvp", "bark"]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

"bark",
]
):
self.skipTest("Skip speech models")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Contributor Author

@voidism voidism Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I previously skipped these speech models because they don't have the regular output_embeddings to perform early exit. And the early exit is required for dola decoding. However, it's actually not just because they are speech models, we should simply check the output_embeddings to decide whether to skip!

Thus, I changed this part to

if model.get_output_embeddings() is None:
    self.skipTest("DoLa is not supported for models that don't have output embeddings")

streamer.end()

if return_dict_in_generate:
if self.config.is_encoder_decoder:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the tests it says that this isn't supperted by encoder_decoder models

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this part of the code that has if self.config.is_encoder_decoder:!

}
generation_kwargs.update({"dola_layers": "low"})
output_dola = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to do a test which does a single forward pass and checks that the expected logits are selected i.e. the dola method should be decoupled from generate itself and we test passing logits to the dola method and then the logit outputs. I believe this is a more general issue with the generation testing however.

Specifically, this test doesn't really convince me that the implementation is correct (not do the integration tests, unless they've been generated from the official dola implementation), but that they functionally work

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to do a test which does a single forward pass and checks that the expected logits are selected i.e. the dola method should be decoupled from generate itself and we test passing logits to the dola method and then the logit outputs. I believe this is a more general issue with the generation testing however.

100% Agreed. However, this is not an issue with the DoLA method, but with the structure of generate. At the moment, each decoding function is a monolith where we can't isolate an iteration of the loop. Me and @zucchini-nlp are working to fix this problem, so we can breakdown (and test) each piece of the core functionality. For instance, you've recently reviewed a PR where the stopping condition of the generation loop was moved into a shared function, which works towards this goal 🤗

What this pattern of (legacy) tests does is to catch flagrant API issues and/or model incompatibilities, not to detect whether the decoding method matches its original implementation. And that's the extent of what we can do in unit tests, until we rework things :)

@amyeroberts What I mean with this comment is that it shouldn't be @voidism's responsibility to break down the _dola_decoding function nor to rework tests, @voidism is simply following the existing pattern. It is our (mine and @zucchini-nlp's) responsibility to ensure what you wrote becomes true -- in fact, it is easier for us to refactor things if they keep the same imperfect pattern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the correctness of DoLa. I am the first author of DoLa paper and I have kept tracking whether the new code in this PR can reproduce the old numbers in my paper.

image

The left-hand side is the new numbers I tested using the current version of code.
The right-hand side is the screenshot of my paper, where the numbers are from the official implementation and the experiments I did last year.

The original implementation was based on v4.28.1. The numbers changed a little bit (also for the greedy decoding baseline), which I think it's because of the version changes as well as the different machines and gpus I used. But the same level of improvement can be achieved by the new code in this PR, e.g. ~4% on StrQA with llama-7b.

I can also provide more tests to validate the consistency between this PR and my official dola implementation if you think it's needed!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@voidism Thanks for providing these numbers! I think these are good enough to have a reasonable degree of certainty in the application in the absence of being able to fully test at the moment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have checked that my latest commit today (based on v4.41.0) can also reproduce the scores here!

# DoLa decoding with contrasting lower part of layers (layers 0,2,...,14)
>>> dola_low_output = model.generate(**inputs, do_sample=False, max_new_tokens=50, dola_layers='low', repetition_penalty=1.2)
>>> tokenizer.batch_decode(dola_low_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
['\nThe Declaration of Independence was signed on July 4, 1776.\nWhat was the date of the signing of the Declaration of Independence?\nThe Declaration of Independence was signed on July 4,']
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get it - the outputs are the same?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, otherwise users won't feel compelled into using the technique

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed here! I switched back to show the output example of dola_layers='high' as suggested by @gante last time, and removed the low outputs here. In this case, the high output is different from the vanilla decoding outputs and it makes more sense to the readers.

- If the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, as the early exit from word embeddings will become identity function.
- Set the `dola_layers` to a list of integers for layer indices to contrast manually specified layers. For example, setting `dola_layers=[28,30]` will contrast the final layer (32-th layer) with the 28-th and 30-th layers.

The paper suggested that contrasting `'high'` layers to improve short-answer tasks like TruthfulQA, and contrasting `'low'` layers to improve all the other long-answer reasoning tasks, such as GSM8K, StrategyQA, FACTOR, and VicunaQA. Applying DoLa to smaller models like GPT-2 is not recommended, as the results shown in the Appendix N of the paper.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be good to use a better demo for low here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched back to show a demo of high! I can also try to find prompt cases that make vanilla and low and high all very different, if you think it's needed!

Comment on lines 446 to 538
- For `N`-layer models with `N <= 40` layers, the layers of `range(0, N // 2, 2)` and `range(N // 2, N, 2)` are used for `'low'` and `'high'` layers, respectively.
- For models with `N > 40` layers, the layers of `range(0, 20, 2)` and `range(N - 20, N, 2)` are used for `'low'` and `'high'` layers, respectively.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm - is this from the paper? It seems pretty arbitratry

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the layer selection logic is in the Appendix F of my paper. For llama-7b we use [0, 16) and [16, 32). For llama-13b/33b/65b we use [0, 20) and [N-20, N), where N = 40/60/80 for 13b/33b/65b. They are selected based on the validation set results. In this PR, I renamed this layer selection as low or high for simplicity.

@voidism
Copy link
Contributor Author

voidism commented Mar 25, 2024

Hi @amyeroberts !

Thanks so much for all of your great suggestions! They are very helpful and they improved my code and the test cases!
I have tried my best to fix all the issues you mentioned above. Let me know if there are still concerns or suggestions so I can address them further! 🤗

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for iterating on this! Just a few small suggestions - otherwise looking great!

}
generation_kwargs.update({"dola_layers": "low"})
output_dola = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@voidism Thanks for providing these numbers! I think these are good enough to have a reasonable degree of certainty in the application in the absence of being able to fully test at the moment

The method is based on the paper "DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models" (https://arxiv.org/abs/2309.03883) in ICLR 2024.

Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slight mismatch between docstring and method signature e.g. do_sample missing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed the mismatch. Now the docstring and method signature are consistent!

>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
>>> outputs = model._dola_decoding(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should show calling a private method in an example. My understanding from recent refactors is that this is now taken from the generation config @gante Is this right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. We can remove this example :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the example!

# using final layer as the mature layer
mature_layer = self.config.num_hidden_layers
# if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer, as the early exit from word embeddings will become identity function
# if the model is really shallow (<=2 layers), we use the 1st layer if it's not the mature layer and the 0-th layer if it's the mature layer. Notice that DoLa is not helping much to shallow models.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ultra nit

Possibly mature -> final to clarify? I'm not sure what a mature layer is i.e. above when it says using the final layer as the mature layer.

Suggested change
# if the model is really shallow (<=2 layers), we use the 1st layer if it's not the mature layer and the 0-th layer if it's the mature layer. Notice that DoLa is not helping much to shallow models.
# if the model is really shallow (<=2 layers), we use the 1st layer if it's not the mature layer and the 0-th layer otherwise. Notice that DoLa does not help shallow models much.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed all the mature layer into final layer!

Comment on lines +2043 to +2485
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (final_layer_next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)

if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for future @gante - this looks like something we can abstract out for this and other generation methods

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed 👍

Comment on lines 5256 to 5295
else:
# 1. Stacking all premature_layers into a new dimension
stacked_premature_layers = torch.stack(
[candidate_premature_logits[i] for i in candidate_premature_layers], dim=0
)

# 2. Calculate the softmax values for mature_layer and all premature_layers
softmax_mature_layer = F.softmax(final_logits, dim=-1) # shape: (batch_size, vocab_size)
softmax_premature_layers = F.softmax(
stacked_premature_layers, dim=-1
) # shape: (num_premature_layers, batch_size, vocab_size)

# 3. Calculate M, the average distribution
M = 0.5 * (
softmax_mature_layer[None, :, :] + softmax_premature_layers
) # shape: (num_premature_layers, batch_size, vocab_size)

# 4. Calculate log-softmax for the KL divergence
log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1) # shape: (batch_size, vocab_size)
log_softmax_premature_layers = F.log_softmax(
stacked_premature_layers, dim=-1
) # shape: (num_premature_layers, batch_size, vocab_size)

# 5. Calculate the KL divergences and then the JS divergences
kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], M, reduction="none").mean(
-1
) # shape: (num_premature_layers, batch_size)
kl2 = F.kl_div(log_softmax_premature_layers, M, reduction="none").mean(
-1
) # shape: (num_premature_layers, batch_size)
js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size)

# 6. Reduce the batchmean
js_divs = js_divs.mean(-1) # shape: (num_premature_layers,)
premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]

base_logits = candidate_premature_logits[premature_layer]
final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
logits = final_logits - base_logits
return logits
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • We can just do an early return here which avoids having the main block of code indented
  • Comments above the line of code to avoid unnecessary spliting
Suggested change
else:
# 1. Stacking all premature_layers into a new dimension
stacked_premature_layers = torch.stack(
[candidate_premature_logits[i] for i in candidate_premature_layers], dim=0
)
# 2. Calculate the softmax values for mature_layer and all premature_layers
softmax_mature_layer = F.softmax(final_logits, dim=-1) # shape: (batch_size, vocab_size)
softmax_premature_layers = F.softmax(
stacked_premature_layers, dim=-1
) # shape: (num_premature_layers, batch_size, vocab_size)
# 3. Calculate M, the average distribution
M = 0.5 * (
softmax_mature_layer[None, :, :] + softmax_premature_layers
) # shape: (num_premature_layers, batch_size, vocab_size)
# 4. Calculate log-softmax for the KL divergence
log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1) # shape: (batch_size, vocab_size)
log_softmax_premature_layers = F.log_softmax(
stacked_premature_layers, dim=-1
) # shape: (num_premature_layers, batch_size, vocab_size)
# 5. Calculate the KL divergences and then the JS divergences
kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], M, reduction="none").mean(
-1
) # shape: (num_premature_layers, batch_size)
kl2 = F.kl_div(log_softmax_premature_layers, M, reduction="none").mean(
-1
) # shape: (num_premature_layers, batch_size)
js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size)
# 6. Reduce the batchmean
js_divs = js_divs.mean(-1) # shape: (num_premature_layers,)
premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]
base_logits = candidate_premature_logits[premature_layer]
final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
logits = final_logits - base_logits
return logits
return logits
# 1. Stacking all premature_layers into a new dimension
stacked_premature_layers = torch.stack(
[candidate_premature_logits[i] for i in candidate_premature_layers], dim=0
)
# 2. Calculate the softmax values for mature_layer and all premature_layers
# shape: (batch_size, vocab_size)
softmax_mature_layer = F.softmax(final_logits, dim=-1)
# shape: (num_premature_layers, batch_size, vocab_size)
softmax_premature_layers = F.softmax(stacked_premature_layers, dim=-1)
# 3. Calculate M, the average distribution
# shape: (num_premature_layers, batch_size, vocab_size)
M = 0.5 * (softmax_mature_layer[None, :, :] + softmax_premature_layers)
# 4. Calculate log-softmax for the KL divergence
# shape: (batch_size, vocab_size)
log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1)
# shape: (num_premature_layers, batch_size, vocab_size)
log_softmax_premature_layers = F.log_softmax(stacked_premature_layers, dim=-1)
# 5. Calculate the KL divergences and then the JS divergences
# shape: (num_premature_layers, batch_size)
kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], M, reduction="none").mean(-1)
# shape: (num_premature_layers, batch_size)
kl2 = F.kl_div(log_softmax_premature_layers, M, reduction="none").mean(-1)
js_divs = 0.5 * (kl1 + kl2) # shape: (num_premature_layers, batch_size)
# 6. Reduce the batchmean
js_divs = js_divs.mean(-1) # shape: (num_premature_layers,)
premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]
base_logits = candidate_premature_logits[premature_layer]
final_logits, base_logits = _relative_top_filter(final_logits, base_logits)
logits = final_logits - base_logits
return logits

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

Comment on lines 5268 to 5269
# 3. Calculate M, the average distribution
M = 0.5 * (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a rule, no single letter vars should be used - let's use something more descriptive e.g. avg_dist

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed it to avg_dist!

return logits


def _relative_top_filter(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definition of objects should go above the lines they're first used

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed the order!

base_filter_value=-1e-3,
min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
"""Reference: https://github.com/XiangLi1999/ContrastiveDecoding/blob/170e9142e92159c1237d731e240f5eb14aabf428/transformers/src/transformers/generation_logits_process.py#L235"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Link is great! We should add a short sentence saying what this function does too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a description!

Comment on lines +1251 to +1286
if not hasattr(config, "use_cache"):
config.use_cache = False
else:
config.use_cache = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on https://github.com/huggingface/transformers/pull/29619/files#r1538243054

Suggested change
if not hasattr(config, "use_cache"):
config.use_cache = False
else:
config.use_cache = True
# Some models don't support the cache and returning past_key_values
if not hasattr(config, "use_cache"):
config.use_cache = False
else:
config.use_cache = True

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the comment!

Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@gante
Copy link
Member

gante commented Apr 22, 2024

@voidism are you intending to continue the PR? 🤗 Or do you need a hand?

@voidism
Copy link
Contributor Author

voidism commented Apr 22, 2024

Hi @gante

Sorry that I was busy with my midterm for the past few weeks 😔 so I forgot to fix this for a while... I will continue fixing the PR this or next week!
Thanks for the reminder and sorry for the delay!

@gante
Copy link
Member

gante commented Apr 23, 2024

@voidism no worries, focus on your midterms 💪 we'll be here when you're ready to continue 🙌

Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@voidism
Copy link
Contributor Author

voidism commented May 19, 2024

Hi @gante and @amyeroberts

I am back and fixed all the suggestions from @amyeroberts last time!

Sorry that I was busy with midterm exams and paper deadlines last month 😔, so I stopped fixing this PR for a while. 🥲
And last week I just traveled to ICLR 2024 to present the DoLa paper! Now I finally get some free time to fix this.

It's my fault that you guys might need to spend more time recalling our discussions from almost two months ago. I am really sorry about that! 🥲

In addition to fixing all the suggestions from last time, I have synced this PR with the latest transformers v4.41.0, and it passed all the CI tests. I found that the new version of generation/utils.py becomes more concise and cleaner than before. Thanks so much for your efforts in making it better!

Let me know if you have any other concerns or suggestions. I recently have more free time so I can assure you guys that I will fix any of your new suggestions as soon as I can! No more procrastination I promise!

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding and iterating!

All looks good to me. As it's been open for a while, I'd like a quick re-review from @gante to confirm this is still in-line with the current generate patterns

@voidism
Copy link
Contributor Author

voidism commented May 29, 2024

Thanks @amyeroberts so much for approving the changes! 🙌

Hi @gante Just let me know if the current version looks good or not. I will be happy to fix any suggestions or concerns you have! Thanks! 🤗

Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@gante
Copy link
Member

gante commented Jun 22, 2024

@voidism my turn to apologise for the delay, I'm catching up with issues :)

I've re-checked the PR and I'm happy with it! I'm going to merge this Monday (to avoid breaking our CI on a weekend 😉 )

@voidism
Copy link
Contributor Author

voidism commented Jun 22, 2024

Hi @gante

No problem! Thanks so much for your help!! 🤗

@gante
Copy link
Member

gante commented Jul 9, 2024

rebased yet again (previous main had unrelated issues that was making CI red), fixing resulting issues in next commits

@gante
Copy link
Member

gante commented Jul 9, 2024

Ran the following slow tests locally (with the expected results):

  1. RUN_SLOW=1 py.test tests/models/ -k dola -vv
  2. RUN_SLOW=1 py.test -vv tests/models/llama/test_modeling_llama.py
  3. RUN_SLOW=1 py.test -vv tests/generation/test_utils.py
  4. RUN_SLOW=1 py.test -vv tests/utils/test_cache_utils.py

@gante gante merged commit d094d8d into huggingface:main Jul 9, 2024
23 checks passed
@gante
Copy link
Member

gante commented Jul 9, 2024

@voidism finally all CI issues were sorted -- thank you for bearing with us 🤗 I will communicate about this feature tomorrow! 💪

@voidism
Copy link
Contributor Author

voidism commented Jul 9, 2024

Hi @gante

Thanks a lot for your help! Handling these CI tests isn't easy (I learned a lot from it 😂). I really appreciate your effort. So happy that we finally made it! 🤗

@gante
Copy link
Member

gante commented Jul 10, 2024

Handling these CI tests isn't easy

@voidism hehe it looks annoying, but it is essential to ensure all our features are playing nicely with each other 🤗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Adding new decoding strategy "DoLa" into the model.generate() function
4 participants