Cache: new Cache format in decoder-only models #31421

Merged

Commits (44 commits, all by zucchini-nlp)
183cd66  draft bart with new cache (Jun 14, 2024)
4578bca  add cache for decoder-only models (Jun 14, 2024)
9505ca4  revert utils (Jun 14, 2024)
2ab28f3  modify docstring (Jun 14, 2024)
5fe4e9e  revert bart (Jun 14, 2024)
09413c3  minor fixes (Jun 14, 2024)
3c27604  fix copies (not related) (Jun 14, 2024)
350acc5  revert tests (Jun 14, 2024)
c0adf10  remove enc-dec related code (Jun 17, 2024)
c18b177  remove bloom (Jun 17, 2024)
582f289  remove opt (enc-dec) (Jun 17, 2024)
3141a71  Merge remote-tracking branch 'upstream/main' into dynamic_cache_decod… (Jun 17, 2024)
33d54b4  update docstring (Jun 18, 2024)
dd05e6b  git, codegen, gpt_neo, gpt_neox, gpj (Jun 18, 2024)
cb878d5  clean up (Jun 19, 2024)
0588791  copied from statements (Jun 19, 2024)
a27b47c  revert (Jun 19, 2024)
1abcf30  tmp (Jun 19, 2024)
00ed88c  update warning msg (Jun 20, 2024)
6c3b3aa  forgot git (Jun 20, 2024)
fd5eeab  add more flags (Jun 21, 2024)
e233f29  run-slow git,codegen,gpt_neo,gpt_neox,gpj (Jun 21, 2024)
356d578  add cache flag to VLMs (Jul 9, 2024)
c906670  remove files (Jul 9, 2024)
08d9e6f  Merge branch 'main' into dynamic_cache_decoder_only (Jul 9, 2024)
56c05b2  style (Jul 9, 2024)
8510810  video LLMs also need a flag (Jul 9, 2024)
cebb55d  style (Jul 9, 2024)
8fd9dd1  llava will go in another PR (Jul 26, 2024)
4b9ced1  Merge branch 'main' into dynamic_cache_decoder_only (Jul 26, 2024)
aea219b  style (Jul 26, 2024)
4991863  [run-slow] codegen, falcon, git, gpt_neo, gpt_neox, gptj, idefics (Jul 26, 2024)
ec306a2  Update src/transformers/models/gpt_neo/modeling_gpt_neo.py (Jul 30, 2024)
cf793b7  copy from (Jul 30, 2024)
c92409c  deprecate until v4.45 and warn if not training (Jul 30, 2024)
c2b97e4  nit (Jul 30, 2024)
35b60de  fix test (Jul 30, 2024)
d2fca9a  test static cache (Aug 2, 2024)
0933350  Merge branch 'main' into dynamic_cache_decoder_only (Aug 2, 2024)
42349d4  add more tests and fix models (Aug 2, 2024)
45c3a1b  fix copies (Aug 2, 2024)
5f22616  return sliding window mask (Aug 2, 2024)
f5af6a2  run slow tests & fix + codestyle (Aug 6, 2024)
21b45c5  one more falcon fix for alibi (Aug 6, 2024)
4 changes: 3 additions & 1 deletion src/transformers/cache_utils.py
@@ -930,7 +930,9 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
        )

        self.key_cache: List[torch.Tensor] = []
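Context for the change above: `StaticCache.__init__` previously read `config.num_key_value_heads` directly, which raises `AttributeError` for configs that never define that attribute at all (common for older decoder-only models that use plain multi-head attention rather than grouped-query attention). The snippet below is a minimal, self-contained sketch of the new fallback; the `SimpleNamespace` object is only a stand-in for a real `PretrainedConfig` and is not part of the diff.

from types import SimpleNamespace

# Stand-in for a config without grouped-query attention: no `num_key_value_heads` attribute.
config = SimpleNamespace(num_attention_heads=16)

# Mirrors the updated logic in StaticCache.__init__: fall back to `num_attention_heads`
# instead of raising AttributeError when `num_key_value_heads` is absent or None.
num_key_value_heads = (
    config.num_attention_heads
    if getattr(config, "num_key_value_heads", None) is None
    else config.num_key_value_heads
)
assert num_key_value_heads == 16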
2 changes: 1 addition & 1 deletion src/transformers/generation/utils.py
@@ -1470,7 +1470,7 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
        # NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
        # Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
        # models. May cause trobles with non-text modalities.
-        cache_dtype = self.lm_head.weight.dtype
+        cache_dtype = self.get_output_embeddings().weight.dtype

        cache_kwargs = {
            "config": self.config,
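The switch to `get_output_embeddings()` matters because not every generative model exposes its output projection under the name `lm_head` (GIT, for instance, names its head differently), while `get_output_embeddings()` is part of the `PreTrainedModel` API. A small usage sketch follows; the tiny checkpoint name is illustrative only.

import torch
from transformers import AutoModelForCausalLM

# Any causal LM can report the dtype of its output projection this way,
# regardless of what the attribute holding that projection is called.
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-gpt2", torch_dtype=torch.float16
)
cache_dtype = model.get_output_embeddings().weight.dtype
print(cache_dtype)  # torch.float16, without assuming an `lm_head` attribute exists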
411 changes: 286 additions & 125 deletions src/transformers/models/codegen/modeling_codegen.py

Large diffs are not rendered by default.

453 changes: 312 additions & 141 deletions src/transformers/models/falcon/modeling_falcon.py

Large diffs are not rendered by default.

145 changes: 84 additions & 61 deletions src/transformers/models/git/modeling_git.py

Large diffs are not rendered by default.

394 changes: 286 additions & 108 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py

Large diffs are not rendered by default.

442 changes: 315 additions & 127 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py

Large diffs are not rendered by default.

447 changes: 303 additions & 144 deletions src/transformers/models/gptj/modeling_gptj.py

Large diffs are not rendered by default.

261 changes: 214 additions & 47 deletions src/transformers/models/idefics/modeling_idefics.py

Large diffs are not rendered by default.
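The modeling-file diffs above are collapsed, but the commit messages ("add cache for decoder-only models", "deprecate until v4.45 and warn if not training") indicate the per-model pattern: accept either the legacy tuple of key/value tensors or a `Cache` object, convert legacy input into a `DynamicCache` internally, and warn that the tuple format is deprecated. The sketch below illustrates that conversion step only; the helper name and warning wording are assumptions, not code from any of the edited files.

import logging

from transformers.cache_utils import Cache, DynamicCache

logger = logging.getLogger(__name__)


def maybe_convert_legacy_cache(past_key_values, use_cache, training):
    """Illustrative helper: convert a legacy tuple cache into a DynamicCache."""
    return_legacy_cache = False
    if use_cache and not training and not isinstance(past_key_values, Cache):
        return_legacy_cache = True  # remember to convert back so callers still receive tuples for now
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        logger.warning(
            "Passing a tuple of `past_key_values` is deprecated and will be removed in v4.45; "
            "pass a `Cache` instance (e.g. `DynamicCache`) instead."
        )
    return past_key_values, return_legacy_cache

A model that received the legacy format can convert back at the end of `forward` with `past_key_values.to_legacy_cache()`, so the return type still matches what the caller passed in.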

49 changes: 48 additions & 1 deletion tests/generation/test_utils.py
@@ -59,7 +59,7 @@
    ImageGPTForCausalImageModeling,
    SpeechEncoderDecoderModel,
)
-from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
+from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
    BeamSampleDecoderOnlyOutput,
    BeamSampleEncoderDecoderOutput,

@@ -1769,6 +1769,53 @@ def test_new_cache_format(self, num_beams, do_sample):
                )
            )

    def test_generate_with_static_cache(self):
        """
        Tests that `StaticCache` works when `cache_implementation="static"` is passed at generation
        time. This does not check generation quality; it only checks that models with
        `_supports_static_cache` do not raise an error when generating and that they return a
        `StaticCache` object at the end.
        """
        for model_class in self.all_generative_model_classes:
            if not model_class._supports_static_cache:
                self.skipTest(reason="This model does not support the static cache format")

            config, input_ids, attention_mask = self._get_input_ids_and_config()
            if config.is_encoder_decoder:
                self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")

            config.use_cache = True
            config.is_decoder = True
            batch_size, seq_length = input_ids.shape
            max_new_tokens = 20

            model = model_class(config).to(torch_device).eval()
            generation_kwargs = {
                "max_length": None,
                "max_new_tokens": max_new_tokens,
                "cache_implementation": "static",
                "return_dict_in_generate": True,  # Required to return `past_key_values`
            }

            max_cache_len = seq_length + max_new_tokens
            head_dim = (
                model.config.head_dim
                if hasattr(model.config, "head_dim")
                else model.config.hidden_size // model.config.num_attention_heads
            )
            num_key_value_heads = (
                model.config.num_attention_heads
                if getattr(config, "num_key_value_heads", None) is None
                else model.config.num_key_value_heads
            )
            num_hidden_layers = config.num_hidden_layers
            results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

            cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
            self.assertTrue(isinstance(results.past_key_values, StaticCache))
            self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
            self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)

    @require_quanto
    def test_generate_with_quant_cache(self):
        for model_class in self.all_generative_model_classes:
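For reference, the user-facing behaviour that `test_generate_with_static_cache` locks in looks roughly like the following; the checkpoint is only an example of a decoder-only architecture covered by this PR (GPT-NeoX), and any model with `_supports_static_cache = True` should behave the same way.

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

model_id = "EleutherAI/pythia-70m"  # illustrative GPT-NeoX checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

inputs = tokenizer("The new cache format", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation="static",
    return_dict_in_generate=True,  # required to get `past_key_values` back
)

assert isinstance(out.past_key_values, StaticCache)
# One pre-allocated key tensor per layer, shaped (batch, num_kv_heads, max_cache_len, head_dim).
print(len(out.past_key_values.key_cache), out.past_key_values.key_cache[0].shape)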
3 changes: 2 additions & 1 deletion tests/models/phi3/test_modeling_phi3.py
@@ -16,6 +16,7 @@
"""Testing suite for the PyTorch Phi-3 model."""

import unittest
+from typing import List

from parameterized import parameterized

@@ -69,7 +70,7 @@ def forward(
        ).logits

    @staticmethod
-    def generate(model: Phi3ForCausalLM, prompt_tokens: torch.LongTensor, max_seq_len: int) -> list[int]:
+    def generate(model: Phi3ForCausalLM, prompt_tokens: torch.LongTensor, max_seq_len: int) -> List[int]:
        model = Phi3MiniWithStaticCache(model, 1, max_seq_len + prompt_tokens.shape[-1])

        response_tokens = []
38 changes: 38 additions & 0 deletions tests/test_modeling_common.py
@@ -4587,6 +4587,44 @@ def test_custom_4d_attention_mask(self):
        normalized_1 = F.softmax(out_shared_prefix_last_tokens)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)

    def test_static_cache_matches_dynamic(self):
        """
        Tests that generating with a static cache gives almost the same results as generating with a
        dynamic cache. This test does not compile the model; it only checks that the logits are
        similar, allowing for small numerical-precision differences.
        """
        if len(self.all_generative_model_classes) == 0:
            self.skipTest(reason="Model architecture has no generative classes, so there is nothing to compare")

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_static_cache:
                self.skipTest(f"{model_class.__name__} does not support static cache")

            if not model_class._supports_cache_class:
                self.skipTest(f"{model_class.__name__} does not support cache class")

            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
            if getattr(config, "sliding_window", 0) > 0:
                self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")

            model = model_class(config).to(device=torch_device, dtype=torch.float32)
            model.eval()

            dynamic_out = model.generate(
                **inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True
            )
            static_out = model.generate(
                **inputs,
                do_sample=False,
                max_new_tokens=10,
                cache_implementation="static",
                output_logits=True,
                return_dict_in_generate=True,
            )
            self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4))

    # For now, Let's focus only on GPU for `torch.compile`
    @slow
    @require_torch_gpu
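Outside the test harness, the dynamic-versus-static comparison in `test_static_cache_matches_dynamic` can be reproduced with two `generate` calls; the checkpoint name is illustrative (GPT-Neo is one of the architectures updated here), and the tolerances mirror the ones used in the test.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-neo-125m"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

inputs = tokenizer("Static and dynamic caches should agree", return_tensors="pt")
common = dict(do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True)

dynamic_out = model.generate(**inputs, **common)                                # default DynamicCache
static_out = model.generate(**inputs, cache_implementation="static", **common)  # StaticCache

# Greedy decoding with either cache should produce near-identical per-step logits.
torch.testing.assert_close(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4)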