Generate: SinkCache can handle iterative prompts #27907

Merged: 9 commits, Dec 8, 2023
30 changes: 27 additions & 3 deletions src/transformers/cache_utils.py
@@ -38,6 +38,21 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states, if there is any."""
raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
# Cache without size limit -> all cache is usable
# Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum
# cache length, we will need to evict part of the cache (and thus not all of the cache is usable)
max_length = self.get_max_length()
previous_seq_length = self.get_seq_length(layer_idx)
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length


class DynamicCache(Cache):
"""
@@ -120,6 +135,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return 0
return self.key_cache[layer_idx].shape[-2]

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return None

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
@@ -209,8 +228,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
if len(self.key_cache) <= layer_idx:
return 0
cache_length = self.key_cache[layer_idx].shape[-2]
return min(cache_length, self.window_length - 1)
Review comment (Member Author): This line was a bit of a hack; get_max_length means we no longer need it. get_seq_length now always does what the fn name and the docstring say it does.
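To make that concrete, a minimal before/after sketch with hypothetical numbers (illustrative only, not code from this PR):

window_length = 256   # SinkCache window
cached_states = 256   # key_cache[layer_idx].shape[-2] once the cache is full

# Before: get_seq_length() was capped as a workaround, hiding the true cache size.
old_get_seq_length = min(cached_states, window_length - 1)   # 255

# After: get_seq_length() reports the real cache length, and the size limit is
# exposed separately through get_max_length().
new_get_seq_length = cached_states   # 256
max_length = window_length           # returned by get_max_length()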

return self.key_cache[layer_idx].shape[-2]

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
return self.window_length

def update(
self,
Expand Down Expand Up @@ -267,7 +289,9 @@ def update(

# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
if using_rope:
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
key_states, cos[: self.window_length], sin[: self.window_length]
)
if partial_rotation_size is not None:
keys_to_keep, keys_pass = (
keys_to_keep[..., :partial_rotation_size],
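Taken together, the new base-class helpers behave as below. This is a small illustrative sketch of the Cache.get_usable_length logic with hypothetical lengths, not an addition to the PR:

from typing import Optional

def usable_length(previous_seq_length: int, new_seq_length: int, max_length: Optional[int]) -> int:
    # Mirrors Cache.get_usable_length: an unbounded cache is fully usable, while a
    # bounded cache must leave room for the incoming tokens.
    if max_length is not None and previous_seq_length + new_seq_length > max_length:
        return max_length - new_seq_length
    return previous_seq_length

# DynamicCache: get_max_length() is None, so the whole cache is usable.
assert usable_length(previous_seq_length=300, new_seq_length=20, max_length=None) == 300

# SinkCache with window_length=256 and a 20-token follow-up prompt: only 236 cached
# positions remain usable, so new tokens + usable cache stays at the 256-position window.
assert usable_length(previous_seq_length=256, new_seq_length=20, max_length=256) == 236

Returning max_length - new_seq_length, rather than the full cached length, is what lets the model code below size its rotary embeddings and attention masks consistently with what the cache will actually hold after eviction.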
19 changes: 12 additions & 7 deletions src/transformers/models/llama/modeling_llama.py
@@ -398,7 +398,7 @@ def forward(
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

@@ -503,7 +503,7 @@ def forward(

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -910,7 +910,7 @@ def forward(
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()
past_key_values_length = past_key_values.get_usable_length(seq_length)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1127,8 +1127,10 @@ def prepare_inputs_for_generation(
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1142,10 +1144,13 @@
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
Review comment (Collaborator): okay, better! thanks

if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
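The attention-mask cropping added in prepare_inputs_for_generation above can be illustrated with hypothetical sizes (a sketch only; the variable names follow the surrounding code):

import torch

max_cache_length = 256           # past_key_values.get_max_length() for a SinkCache
cache_length = 256               # past_key_values.get_seq_length(), cache already full
new_tokens = 20                  # unprocessed input_ids kept by the slicing above

attention_mask = torch.ones(1, cache_length + new_tokens)     # 276 positions
input_ids = torch.zeros(1, new_tokens, dtype=torch.long)

# About to go beyond the maximum cache length, so keep only the most recent positions.
if (
    max_cache_length is not None
    and attention_mask is not None
    and cache_length + input_ids.shape[1] > max_cache_length
):
    attention_mask = attention_mask[:, -max_cache_length:]

# 256 positions: the 236 cached states that survive eviction plus the 20 new tokens.
assert attention_mask.shape[1] == max_cache_length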
21 changes: 12 additions & 9 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -268,7 +268,7 @@ def forward(
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
Review comment (Collaborator): slice_length or window_length might be a better name, but it's a nit; feel free to ignore.

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

@@ -363,7 +363,7 @@ def forward(

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -850,15 +850,13 @@ def forward(
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

seq_length_with_past = seq_length
past_key_values_length = 0

if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()
seq_length_with_past = seq_length_with_past + past_key_values_length
past_key_values_length = past_key_values.get_usable_length(seq_length)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1092,8 +1090,10 @@ def prepare_inputs_for_generation(
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1107,10 +1107,13 @@
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
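For the attention layers, the switch from get_seq_length to get_usable_length is what keeps kv_seq_len inside the cache window when a multi-token prompt arrives on a full SinkCache. A rough arithmetic sketch with hypothetical numbers:

window_length = 256   # past_key_value.get_max_length()
new_tokens = 20       # key_states.shape[-2] for an iterative, multi-token prompt

# Previously, SinkCache.get_seq_length() was hard-capped at window_length - 1, which
# only keeps kv_seq_len <= window_length when a single token is added at a time.
kv_seq_len_old = new_tokens + (window_length - 1)    # 275, overshoots the window

# With get_usable_length(), the cache reports how much of it can still be used
# alongside the incoming tokens, so kv_seq_len never exceeds the window.
usable = window_length - new_tokens                  # get_usable_length(new_tokens) -> 236
kv_seq_len_new = new_tokens + usable                 # 256, matches the window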
17 changes: 11 additions & 6 deletions src/transformers/models/persimmon/modeling_persimmon.py
@@ -295,7 +295,7 @@ def forward(
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

# Partial rotary embedding
@@ -612,7 +612,7 @@ def forward(
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()
past_key_values_length = past_key_values.get_usable_length(seq_length)
seq_length_with_past = seq_length_with_past + past_key_values_length

if position_ids is None:
@@ -831,8 +831,10 @@ def prepare_inputs_for_generation(
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -846,10 +848,13 @@
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
21 changes: 12 additions & 9 deletions src/transformers/models/phi/modeling_phi.py
@@ -334,7 +334,7 @@ def forward(
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

# Partial rotary embedding
@@ -444,7 +444,7 @@ def forward(

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

# Partial rotary embedding
@@ -855,15 +855,13 @@ def forward(
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

seq_length_with_past = seq_length
past_key_values_length = 0

if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()
seq_length_with_past = seq_length_with_past + past_key_values_length
past_key_values_length = past_key_values.get_usable_length(seq_length)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1085,8 +1083,10 @@ def prepare_inputs_for_generation(
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1100,10 +1100,13 @@
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
42 changes: 42 additions & 0 deletions tests/test_cache_utils.py
@@ -187,3 +187,45 @@ def test_sink_cache_hard(self):
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))

def test_sink_cache_iterative_prompts(self):
Review comment (Member Author): This test would fail on main, where SinkCache could not handle more than one new token at once when shifting the cache.

"""Tests that SinkCache supports more than one new token at once, when shifting the cache"""
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
)
prompt = (
"Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
"and must-see attractions."
)

# Prepare generation settings
cache = SinkCache(window_length=256, num_sink_tokens=4)
input_ids = torch.tensor([], device=model.device, dtype=torch.int)
for _ in range(3):
# Tokenize the prompt with the correct chat template
chat = [{"role": "user", "content": prompt}]
tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
)
input_ids = torch.cat((input_ids, tokenized_chat), dim=1)

# Perform the generation
gen_out = model.generate(
input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
)
input_ids = gen_out

# We went well beyond the cache length
self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)

# And it still produces coherent English
decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
last_output = (
"<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
"Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
"beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
"and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
"was visiting the historic district of Honolulu. Here,"
)
self.assertTrue(decoded[0].endswith(last_output))