From e9c5b17776fd0c4bbb20882b4a0347fd55534504 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 8 Dec 2023 13:44:27 +0000 Subject: [PATCH 1/9] add test --- src/transformers/cache_utils.py | 19 +++++++--- .../models/llama/modeling_llama.py | 12 +++++++ .../models/mistral/modeling_mistral.py | 12 +++++++ .../models/persimmon/modeling_persimmon.py | 6 ++++ src/transformers/models/phi/modeling_phi.py | 12 +++++++ tests/test_cache_utils.py | 36 +++++++++++++++++++ 6 files changed, 93 insertions(+), 4 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index a6258182fb0035..6f2f58544ecb77 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -38,6 +38,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states, if there is any.""" + raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") + class DynamicCache(Cache): """ @@ -120,6 +124,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" + return None + def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" for layer_idx in range(len(self.key_cache)): @@ -209,8 +217,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length if len(self.key_cache) <= layer_idx: return 0 - cache_length = self.key_cache[layer_idx].shape[-2] - return min(cache_length, self.window_length - 1) + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + return self.window_length def update( self, @@ -239,8 +250,8 @@ def update( """ # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models # with partially rotated position embeddings, like Phi or Persimmon. - sin = cache_kwargs.get("sin") - cos = cache_kwargs.get("cos") + sin = cache_kwargs.get("sin")[: self.window_length] + cos = cache_kwargs.get("cos")[: self.window_length] partial_rotation_size = cache_kwargs.get("partial_rotation_size") using_rope = cos is not None and sin is not None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9791c99a1ffdb8..e2b7d8d84a399d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -405,6 +405,12 @@ def forward( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the + # attention mask accordingly. `.update()` handles the cache cropping internally if needed. + kv_max_length = past_key_value.get_max_length() + if kv_max_length is not None and kv_seq_len > kv_max_length: + kv_seq_len = kv_max_length + attention_mask = attention_mask[:, :, :, -kv_seq_len:] key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -511,6 +517,12 @@ def forward( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the + # attention mask accordingly. `.update()` handles the cache cropping internally if needed. + kv_max_length = past_key_value.get_max_length() + if kv_max_length is not None and kv_seq_len > kv_max_length: + kv_seq_len = kv_max_length + attention_mask = attention_mask[:, :, :, -kv_seq_len:] # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 34a8f0bfa8128e..b032eb27b81558 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -275,6 +275,12 @@ def forward( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the + # attention mask accordingly. `.update()` handles the cache cropping internally if needed. + kv_max_length = past_key_value.get_max_length() + if kv_max_length is not None and kv_seq_len > kv_max_length: + kv_seq_len = kv_max_length + attention_mask = attention_mask[:, :, :, -kv_seq_len:] # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -408,6 +414,12 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the + # attention mask accordingly. `.update()` handles the cache cropping internally if needed. + kv_max_length = past_key_value.get_max_length() + if kv_max_length is not None and kv_seq_len > kv_max_length: + kv_seq_len = kv_max_length + attention_mask = attention_mask[:, :, :, -kv_seq_len:] # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index eacccb1564c8d7..d60cb5d055c8cd 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -318,6 +318,12 @@ def forward( # Specific to RoPE models with partial rotation cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the + # attention mask accordingly. `.update()` handles the cache cropping internally if needed. + kv_max_length = past_key_value.get_max_length() + if kv_max_length is not None and kv_seq_len > kv_max_length: + kv_seq_len = kv_max_length + attention_mask = attention_mask[:, :, :, -kv_seq_len:] attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4431c10eebcec4..0997162b43b326 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -357,6 +357,12 @@ def forward( # Specific to RoPE models with partial rotation cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the + # attention mask accordingly. `.update()` handles the cache cropping internally if needed. + kv_max_length = past_key_value.get_max_length() + if kv_max_length is not None and kv_seq_len > kv_max_length: + kv_seq_len = kv_max_length + attention_mask = attention_mask[:, :, :, -kv_seq_len:] attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) @@ -466,6 +472,12 @@ def forward( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the + # attention mask accordingly. `.update()` handles the cache cropping internally if needed. + kv_max_length = past_key_value.get_max_length() + if kv_max_length is not None and kv_seq_len > kv_max_length: + kv_seq_len = kv_max_length + attention_mask = attention_mask[:, :, :, -kv_seq_len:] tgt_len = key_states.shape[2] diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 1ffc4962787905..f6d316b9f0b4a3 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -15,6 +15,8 @@ import unittest +from datasets import load_dataset + from transformers import set_seed from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow @@ -187,3 +189,37 @@ def test_sink_cache_hard(self): gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network")) + + @require_auto_gptq + def test_sink_cache_iterative_prompts(self): + """Tests that SinkCache supports more than one new token at once, when shifting the cache""" + tokenizer = AutoTokenizer.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ") + model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ", device_map="auto") + + # Loading the prompts to simulate user interactions + prompt_dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train") + prompts = [prompt for prompts in prompt_dataset["prompt"] for prompt in prompts] + + # Prepare generation settings + cache = SinkCache(window_length=512, num_sink_tokens=4) + input_ids = torch.tensor([], device=model.device, dtype=torch.int) + for prompt in sorted(prompts)[:8]: + # Tokenize the prompt with the correct chat template + chat = [{"role": "user", "content": prompt}] + tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( + model.device + ) + input_ids = torch.cat((input_ids, tokenized_chat), dim=1) + + # Perform the generation + gen_out = model.generate( + input_ids, do_sample=False, max_new_tokens=10, past_key_values=cache, use_cache=True + ) + input_ids = gen_out + + # We went well beyond the cache length + self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5) + + # And it still produces a coherent output + decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) + self.assertTrue(decoded[0].endswith("<|assistant|>\n<|> interface is uninspiringly")) From f1f12c1def6b951fb660519c581ac07efc89a223 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 8 Dec 2023 14:01:45 +0000 Subject: [PATCH 2/9] better test --- tests/test_cache_utils.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index f6d316b9f0b4a3..f0e580134d4d99 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -15,8 +15,6 @@ import unittest -from datasets import load_dataset - from transformers import set_seed from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow @@ -195,15 +193,15 @@ def test_sink_cache_iterative_prompts(self): """Tests that SinkCache supports more than one new token at once, when shifting the cache""" tokenizer = AutoTokenizer.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ") model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ", device_map="auto") - - # Loading the prompts to simulate user interactions - prompt_dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train") - prompts = [prompt for prompts in prompt_dataset["prompt"] for prompt in prompts] + prompt = ( + "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences " + "and must-see attractions." + ) # Prepare generation settings - cache = SinkCache(window_length=512, num_sink_tokens=4) + cache = SinkCache(window_length=256, num_sink_tokens=4) input_ids = torch.tensor([], device=model.device, dtype=torch.int) - for prompt in sorted(prompts)[:8]: + for _ in range(3): # Tokenize the prompt with the correct chat template chat = [{"role": "user", "content": prompt}] tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( @@ -213,7 +211,7 @@ def test_sink_cache_iterative_prompts(self): # Perform the generation gen_out = model.generate( - input_ids, do_sample=False, max_new_tokens=10, past_key_values=cache, use_cache=True + input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True ) input_ids = gen_out @@ -222,4 +220,11 @@ def test_sink_cache_iterative_prompts(self): # And it still produces a coherent output decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) - self.assertTrue(decoded[0].endswith("<|assistant|>\n<|> interface is uninspiringly")) + last_output = ( + "<|assistant|>\nHawaii, the Aloha State post for a travel destination you've taken. Your post's and " + "must-see landmarks. Use a descriptive and engaging writing style, incorporating personal anecdotes and " + "recommendations for fellow travelers. Your post should be at least 800 words and include high-quality " + "images to enhance the reader's experience. Be sure to cover a variety of experiences, from cultural " + "immersion to outdoor adventures, and provide practical" + ) + self.assertTrue(decoded[0].endswith(last_output)) From d50b85e9a0ff72484420b73c8821beec89f19288 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 8 Dec 2023 14:03:16 +0000 Subject: [PATCH 3/9] . --- tests/test_cache_utils.py | 41 --------------------------------------- 1 file changed, 41 deletions(-) diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index f0e580134d4d99..1ffc4962787905 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -187,44 +187,3 @@ def test_sink_cache_hard(self): gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network")) - - @require_auto_gptq - def test_sink_cache_iterative_prompts(self): - """Tests that SinkCache supports more than one new token at once, when shifting the cache""" - tokenizer = AutoTokenizer.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ") - model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ", device_map="auto") - prompt = ( - "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences " - "and must-see attractions." - ) - - # Prepare generation settings - cache = SinkCache(window_length=256, num_sink_tokens=4) - input_ids = torch.tensor([], device=model.device, dtype=torch.int) - for _ in range(3): - # Tokenize the prompt with the correct chat template - chat = [{"role": "user", "content": prompt}] - tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( - model.device - ) - input_ids = torch.cat((input_ids, tokenized_chat), dim=1) - - # Perform the generation - gen_out = model.generate( - input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True - ) - input_ids = gen_out - - # We went well beyond the cache length - self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5) - - # And it still produces a coherent output - decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) - last_output = ( - "<|assistant|>\nHawaii, the Aloha State post for a travel destination you've taken. Your post's and " - "must-see landmarks. Use a descriptive and engaging writing style, incorporating personal anecdotes and " - "recommendations for fellow travelers. Your post should be at least 800 words and include high-quality " - "images to enhance the reader's experience. Be sure to cover a variety of experiences, from cultural " - "immersion to outdoor adventures, and provide practical" - ) - self.assertTrue(decoded[0].endswith(last_output)) From 49e4d40249c613eaf4d3e02595b61da96fb336ba Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 8 Dec 2023 14:03:25 +0000 Subject: [PATCH 4/9] . --- tests/test_cache_utils.py | 41 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 1ffc4962787905..f0e580134d4d99 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -187,3 +187,44 @@ def test_sink_cache_hard(self): gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network")) + + @require_auto_gptq + def test_sink_cache_iterative_prompts(self): + """Tests that SinkCache supports more than one new token at once, when shifting the cache""" + tokenizer = AutoTokenizer.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ") + model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ", device_map="auto") + prompt = ( + "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences " + "and must-see attractions." + ) + + # Prepare generation settings + cache = SinkCache(window_length=256, num_sink_tokens=4) + input_ids = torch.tensor([], device=model.device, dtype=torch.int) + for _ in range(3): + # Tokenize the prompt with the correct chat template + chat = [{"role": "user", "content": prompt}] + tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( + model.device + ) + input_ids = torch.cat((input_ids, tokenized_chat), dim=1) + + # Perform the generation + gen_out = model.generate( + input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True + ) + input_ids = gen_out + + # We went well beyond the cache length + self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5) + + # And it still produces a coherent output + decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) + last_output = ( + "<|assistant|>\nHawaii, the Aloha State post for a travel destination you've taken. Your post's and " + "must-see landmarks. Use a descriptive and engaging writing style, incorporating personal anecdotes and " + "recommendations for fellow travelers. Your post should be at least 800 words and include high-quality " + "images to enhance the reader's experience. Be sure to cover a variety of experiences, from cultural " + "immersion to outdoor adventures, and provide practical" + ) + self.assertTrue(decoded[0].endswith(last_output)) From e319719774db4d37e89d540b05effc3f3aa930f1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 8 Dec 2023 14:05:49 +0000 Subject: [PATCH 5/9] add test comment --- tests/test_cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index f0e580134d4d99..441b365c35f758 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -218,7 +218,7 @@ def test_sink_cache_iterative_prompts(self): # We went well beyond the cache length self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5) - # And it still produces a coherent output + # And it still produces a coherent english (the repetition is due to the prompt being repeated 3 times) decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) last_output = ( "<|assistant|>\nHawaii, the Aloha State post for a travel destination you've taken. Your post's and " From 029fac9dd4ac39384388ef2f920cbdf29b8186e2 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 8 Dec 2023 18:30:12 +0000 Subject: [PATCH 6/9] rework as per PR comments --- src/transformers/cache_utils.py | 21 ++++++++++-- .../models/llama/modeling_llama.py | 31 +++++++---------- .../models/mistral/modeling_mistral.py | 33 +++++++------------ .../models/persimmon/modeling_persimmon.py | 23 +++++++------ src/transformers/models/phi/modeling_phi.py | 33 +++++++------------ tests/test_cache_utils.py | 11 +++---- 6 files changed, 70 insertions(+), 82 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 6f2f58544ecb77..ce941a333836f4 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -42,6 +42,19 @@ def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states, if there is any.""" raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_length() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None: + length_excess = previous_seq_length + new_seq_length - max_length + if length_excess > 0: + return max_length - length_excess + return previous_seq_length + class DynamicCache(Cache): """ @@ -250,8 +263,8 @@ def update( """ # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models # with partially rotated position embeddings, like Phi or Persimmon. - sin = cache_kwargs.get("sin")[: self.window_length] - cos = cache_kwargs.get("cos")[: self.window_length] + sin = cache_kwargs.get("sin") + cos = cache_kwargs.get("cos") partial_rotation_size = cache_kwargs.get("partial_rotation_size") using_rope = cos is not None and sin is not None @@ -278,7 +291,9 @@ def update( # On RoPE models, we need to recompute the Key rotation as the tokens are shifted if using_rope: - rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin) + rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( + key_states, cos[: self.window_length], sin[: self.window_length] + ) if partial_rotation_size is not None: keys_to_keep, keys_pass = ( keys_to_keep[..., :partial_rotation_size], diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e2b7d8d84a399d..f4234e9a775499 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -398,19 +398,13 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the - # attention mask accordingly. `.update()` handles the cache cropping internally if needed. - kv_max_length = past_key_value.get_max_length() - if kv_max_length is not None and kv_seq_len > kv_max_length: - kv_seq_len = kv_max_length - attention_mask = attention_mask[:, :, :, -kv_seq_len:] key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -509,7 +503,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -517,12 +511,6 @@ def forward( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the - # attention mask accordingly. `.update()` handles the cache cropping internally if needed. - kv_max_length = past_key_value.get_max_length() - if kv_max_length is not None and kv_seq_len > kv_max_length: - kv_seq_len = kv_max_length - attention_mask = attention_mask[:, :, :, -kv_seq_len:] # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -922,7 +910,7 @@ def forward( use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -1139,8 +1127,10 @@ def prepare_inputs_for_generation( if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() else: cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1154,10 +1144,13 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. - if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b032eb27b81558..ef65d8a0894b12 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -268,19 +268,13 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the - # attention mask accordingly. `.update()` handles the cache cropping internally if needed. - kv_max_length = past_key_value.get_max_length() - if kv_max_length is not None and kv_seq_len > kv_max_length: - kv_seq_len = kv_max_length - attention_mask = attention_mask[:, :, :, -kv_seq_len:] # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -369,7 +363,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 @@ -414,12 +408,6 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the - # attention mask accordingly. `.update()` handles the cache cropping internally if needed. - kv_max_length = past_key_value.get_max_length() - if kv_max_length is not None and kv_seq_len > kv_max_length: - kv_seq_len = kv_max_length - attention_mask = attention_mask[:, :, :, -kv_seq_len:] # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -862,15 +850,13 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - seq_length_with_past = seq_length past_key_values_length = 0 if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() - seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -1104,8 +1090,10 @@ def prepare_inputs_for_generation( if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() else: cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1119,10 +1107,13 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. - if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index d60cb5d055c8cd..17163dcd8edf9b 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -295,7 +295,7 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -318,12 +318,6 @@ def forward( # Specific to RoPE models with partial rotation cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the - # attention mask accordingly. `.update()` handles the cache cropping internally if needed. - kv_max_length = past_key_value.get_max_length() - if kv_max_length is not None and kv_seq_len > kv_max_length: - kv_seq_len = kv_max_length - attention_mask = attention_mask[:, :, :, -kv_seq_len:] attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) @@ -618,7 +612,7 @@ def forward( use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() + past_key_values_length = past_key_values.get_usable_length(seq_length) seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: @@ -837,8 +831,10 @@ def prepare_inputs_for_generation( if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() else: cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -852,10 +848,13 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. - if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 0997162b43b326..ca32193d535893 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -334,7 +334,7 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -357,12 +357,6 @@ def forward( # Specific to RoPE models with partial rotation cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the - # attention mask accordingly. `.update()` handles the cache cropping internally if needed. - kv_max_length = past_key_value.get_max_length() - if kv_max_length is not None and kv_seq_len > kv_max_length: - kv_seq_len = kv_max_length - attention_mask = attention_mask[:, :, :, -kv_seq_len:] attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) @@ -450,7 +444,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # Partial rotary embedding @@ -472,12 +466,6 @@ def forward( if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the - # attention mask accordingly. `.update()` handles the cache cropping internally if needed. - kv_max_length = past_key_value.get_max_length() - if kv_max_length is not None and kv_seq_len > kv_max_length: - kv_seq_len = kv_max_length - attention_mask = attention_mask[:, :, :, -kv_seq_len:] tgt_len = key_states.shape[2] @@ -867,15 +855,13 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - seq_length_with_past = seq_length past_key_values_length = 0 if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() - seq_length_with_past = seq_length_with_past + past_key_values_length + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -1097,8 +1083,10 @@ def prepare_inputs_for_generation( if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() else: cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1112,10 +1100,13 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. - if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 441b365c35f758..dd6cda6fcf7a77 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -218,13 +218,12 @@ def test_sink_cache_iterative_prompts(self): # We went well beyond the cache length self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5) - # And it still produces a coherent english (the repetition is due to the prompt being repeated 3 times) + # And it still produces a coherent english decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) last_output = ( - "<|assistant|>\nHawaii, the Aloha State post for a travel destination you've taken. Your post's and " - "must-see landmarks. Use a descriptive and engaging writing style, incorporating personal anecdotes and " - "recommendations for fellow travelers. Your post should be at least 800 words and include high-quality " - "images to enhance the reader's experience. Be sure to cover a variety of experiences, from cultural " - "immersion to outdoor adventures, and provide practical" + "<|assistant|>\nHawaii, the Aloha State, is a paradise on earth. From its stunning beaches to its lush " + "greenery, Hawaii is a destination that will leave you in awe. I recently had the privilege of visiting " + "this tropical paradise, and I'm excited to share my experiences with you.\n\nFirstly, let's talk about " + "the culture. Hawaii has a rich and unique culture that is deeply rooted in its history. One of the best" ) self.assertTrue(decoded[0].endswith(last_output)) From fa8bf6277064b57c826302d5ccdcc2f9319c38c5 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 8 Dec 2023 18:53:52 +0000 Subject: [PATCH 7/9] gptq not determinitic, bad for tests --- tests/test_cache_utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index dd6cda6fcf7a77..72d055c8806afd 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -188,11 +188,12 @@ def test_sink_cache_hard(self): decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network")) - @require_auto_gptq def test_sink_cache_iterative_prompts(self): """Tests that SinkCache supports more than one new token at once, when shifting the cache""" - tokenizer = AutoTokenizer.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ") - model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ", device_map="auto") + tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") + model = AutoModelForCausalLM.from_pretrained( + "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16 + ) prompt = ( "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences " "and must-see attractions." @@ -221,9 +222,10 @@ def test_sink_cache_iterative_prompts(self): # And it still produces a coherent english decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) last_output = ( - "<|assistant|>\nHawaii, the Aloha State, is a paradise on earth. From its stunning beaches to its lush " - "greenery, Hawaii is a destination that will leave you in awe. I recently had the privilege of visiting " - "this tropical paradise, and I'm excited to share my experiences with you.\n\nFirstly, let's talk about " - "the culture. Hawaii has a rich and unique culture that is deeply rooted in its history. One of the best" + "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of " + "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the " + "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences " + "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip " + "was visiting the historic district of Honolulu. Here," ) self.assertTrue(decoded[0].endswith(last_output)) From dd3d069bdd270239a09cf71a384b473d2ea13708 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 8 Dec 2023 19:08:23 +0000 Subject: [PATCH 8/9] fix logic --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ce941a333836f4..860555f04ff7c1 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -52,7 +52,7 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) - if max_length is not None: length_excess = previous_seq_length + new_seq_length - max_length if length_excess > 0: - return max_length - length_excess + return max_length - new_seq_length return previous_seq_length From e629c215460aa5a6e1bb7c4dfcd4802855bbe18a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 8 Dec 2023 19:14:14 +0000 Subject: [PATCH 9/9] simpler logic --- src/transformers/cache_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 860555f04ff7c1..b298a7bdd0f5d6 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -49,10 +49,8 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) - # length, we will need to evict part of the cache (and thus not all cache is usable) max_length = self.get_max_length() previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None: - length_excess = previous_seq_length + new_seq_length - max_length - if length_excess > 0: - return max_length - new_seq_length + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length return previous_seq_length