Generate: New Cache abstraction and Attention Sinks support #26681

Merged — 35 commits, Dec 8, 2023

Changes from all commits

Commits
d200a73
Draft version of new KV Caching
tomaarsen Oct 9, 2023
ffd7ba4
Address numerous PR suggestions
tomaarsen Oct 12, 2023
be0f917
Implement the SinkCache through backward+forward rotations
tomaarsen Oct 16, 2023
e9ffd60
Integrate (Sink)Cache with Llama FA2
tomaarsen Oct 16, 2023
c0e327c
Set use_legacy_cache=True as default, allows for test passes
tomaarsen Oct 17, 2023
8565e9d
Move from/to_legacy_cache to ...Model class
tomaarsen Oct 19, 2023
4c3bf9a
Undo unnecessary newline change
tomaarsen Oct 19, 2023
1b9ec3d
Remove copy utility from deprecated OpenLlama
tomaarsen Oct 24, 2023
c3c4d5a
Match import style
tomaarsen Oct 24, 2023
a40037d
manual rebase with main
gante Nov 21, 2023
2e66bc6
Cache class working with generate (#1)
gante Nov 21, 2023
ac766e3
move import
gante Nov 21, 2023
3490e1e
add default to model_kwargs.get('use_legacy_cache')
gante Nov 22, 2023
3520c47
correct failing test
gante Nov 22, 2023
3534699
Apply suggestions from code review
gante Nov 22, 2023
f4ced8a
apply PR suggestions
gante Nov 22, 2023
f6e7d2e
fix failing test
gante Nov 22, 2023
1510746
Apply suggestions from code review
gante Nov 30, 2023
00f373b
PR comments
gante Dec 4, 2023
89ffc8d
tmp commit
gante Dec 4, 2023
5aa4573
add docstrings
gante Dec 4, 2023
6675c20
more tests, more docstrings, add to docs
gante Dec 4, 2023
7bf1fe0
derp
gante Dec 4, 2023
2cd20a4
tmp commit
gante Dec 4, 2023
4d87439
tmp dbg
gante Dec 4, 2023
7f0fc57
more dbg
gante Dec 4, 2023
7389b6b
fix beam search bug
gante Dec 4, 2023
69085bf
cache can be a list of tuples in some models
gante Dec 4, 2023
ebd223b
fix group beam search
gante Dec 4, 2023
03fa241
all but sinkcache integration tests
gante Dec 6, 2023
a9fe510
fix sink cache and add hard integration test
gante Dec 6, 2023
e370d33
now also compatible with input_embeds input
gante Dec 6, 2023
e7a6df7
PR comments
gante Dec 7, 2023
4bff583
add Cache support to Phi+FA2
gante Dec 7, 2023
ee60b1c
make fixup
gante Dec 7, 2023
17 changes: 17 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -368,3 +368,20 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] TextStreamer

[[autodoc]] TextIteratorStreamer

## Caches

[[autodoc]] Cache
- update

[[autodoc]] DynamicCache
- update
- get_seq_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] SinkCache
- update
- get_seq_length
- reorder_cache
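
For orientation, a minimal sketch of how the newly documented classes are meant to be used (the checkpoint name and prompt are illustrative, not part of this diff):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, SinkCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
inputs = tokenizer("The key to life is", return_tensors="pt")

# DynamicCache is the structured replacement for the legacy tuple-of-tuples
# `past_key_values`; it grows with every processed token.
out = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

# SinkCache keeps `num_sink_tokens` initial tokens plus a sliding window,
# bounding cache memory no matter how long generation runs.
sink = SinkCache(window_length=256, num_sink_tokens=4)
out = model(**inputs, past_key_values=sink, use_cache=True)
```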
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1303,6 +1303,7 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -5945,6 +5946,7 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
298 changes: 298 additions & 0 deletions src/transformers/cache_utils.py
@@ -0,0 +1,298 @@
from typing import Any, Dict, List, Optional, Tuple

import torch


class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
cache to be created.

Return:
A tuple containing the updated key and value states.
"""
raise NotImplementedError("Make sure to implement `update` in a subclass.")

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
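
To make the contract above concrete, a hedged sketch of how an attention layer in the modeling code is expected to call into a cache (the surrounding variable names follow the Llama changes elsewhere in this PR; `sin`/`cos` only matter for `SinkCache` on RoPE models):

```python
# Inside an attention forward pass (sketch, not the verbatim modeling diff):
if past_key_value is not None:
    cache_kwargs = {"sin": sin, "cos": cos}  # consumed by SinkCache only
    key_states, value_states = past_key_value.update(
        key_states, value_states, self.layer_idx, cache_kwargs
    )
# `update` returns the full cached sequence, so the attention computation
# below can use `key_states`/`value_states` directly, regardless of which
# Cache subclass is in use.
```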


class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.

It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
"""

def __init__(self) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values
"""
for layer_idx in range(len(self)):
yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

Return:
A tuple containing the updated key and value states.
"""
# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
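
A quick sketch of the round trip these two helpers enable (shapes are illustrative):

```python
import torch

from transformers import DynamicCache

# Legacy format: a tuple (one entry per layer) of (key, value) pairs, each
# of shape [batch_size, num_heads, seq_len, head_dim].
legacy = tuple((torch.rand(1, 2, 4, 8), torch.rand(1, 2, 4, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)
assert len(cache) == 2  # one entry per layer
assert cache.get_seq_length() == 4

roundtrip = cache.to_legacy_cache()
assert torch.equal(roundtrip[0][0], legacy[0][0])
```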


class SinkCache(Cache):
"""
A cache that, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), allows the model to
generate beyond the length of its context window without losing fluency in the conversation. As it discards past
tokens, the model will lose the ability to generate tokens that depend on the discarded context.

It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.

Parameters:
window_length (`int`):
The length of the context window.
num_sink_tokens (`int`):
The number of sink tokens. See the original paper for more information.
"""

def __init__(self, window_length: int, num_sink_tokens: int) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_cache = {}
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

@staticmethod
def _rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def _apply_key_rotary_pos_emb(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
return rotated_key_states

def _get_rerotation_cos_sin(
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if key_states.shape[-2] not in self.cos_sin_cache:
# Upcast to float32 temporarily for better accuracy
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)

# Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
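# With `a` the original angle and `b` the shifted (earlier) angle, these are
# the angle-difference identities cos(b - a) = cos(a)cos(b) + sin(a)sin(b) and
# sin(b - a) = cos(a)sin(b) - sin(a)cos(b): the resulting (cos, sin) pair
# rotates each kept key back by exactly the number of evicted positions.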
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

self.cos_sin_cache[key_states.shape[-2]] = (
rerotation_cos.to(key_states.dtype).unsqueeze(0),
rerotation_sin.to(key_states.dtype).unsqueeze(0),
)
return self.cos_sin_cache[key_states.shape[-2]]

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
if len(self.key_cache) <= layer_idx:
return 0
cache_length = self.key_cache[layer_idx].shape[-2]
return min(cache_length, self.window_length - 1)

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
rotation as the tokens are shifted.

Return:
A tuple containing the updated key and value states.
"""
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
# with partially rotated position embeddings, like Phi or Persimmon.
sin = cache_kwargs.get("sin")
cos = cache_kwargs.get("cos")
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
using_rope = cos is not None and sin is not None

# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]

# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
# Empty cache
self.key_cache.append(key_states)
self.value_cache.append(value_states)

elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
# Growing cache
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

else:
# Shifting cache
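# Keep only the most recent `window_length - num_sink_tokens - incoming` keys;
# everything older (apart from the sink tokens re-attached below) is evicted.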
keys_to_keep = self.key_cache[layer_idx][
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
]

# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
if using_rope:
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
if partial_rotation_size is not None:
keys_to_keep, keys_pass = (
keys_to_keep[..., :partial_rotation_size],
keys_to_keep[..., partial_rotation_size:],
)
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
if partial_rotation_size is not None:
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
values_to_keep = self.value_cache[layer_idx][
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
]
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
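
Finally, an end-to-end sketch of `SinkCache` driving `generate` — the checkpoint, prompt, and lengths are illustrative; passing a `Cache` instance through `past_key_values` is the usage pattern this PR wires up:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
inputs = tokenizer("Tell me a very long story.", return_tensors="pt")

# 4 sink tokens + a sliding window: the cache never holds more than
# `window_length` entries, so generation can run past the model's trained
# context length without the usual fluency collapse.
cache = SinkCache(window_length=1024, num_sink_tokens=4)
out = model.generate(**inputs, past_key_values=cache, max_new_tokens=2000)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```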