Offloaded KV Cache #31325

Merged Aug 1, 2024 (8 commits)
Changes from 5 commits
src/transformers/cache_utils.py (105 additions, 0 deletions)

@@ -447,6 +447,111 @@ def batch_select_indices(self, indices: torch.Tensor):
        self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
    Useful for generating from models with very long context.

    When layer k is executing, it moves the cache of layer k-1 to the CPU and prefetches the KV cache of layer k+1.
    """

    def __init__(self) -> None:
        assert torch.cuda.is_available(), "OffloadedCache can only be used with a GPU"
        super().__init__()
        self.original_device = []
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None  # used to delay beam search operations

    def prefetch_layer(self, layer_idx: int):
        "Starts prefetching the next layer cache"
        if layer_idx < len(self):
            with torch.cuda.stream(self.prefetch_stream):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
                self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the previous layer cache to the CPU"
        if len(self) > 2:
            # With two or fewer layers, the "previous" layer is also the next one to be
            # prefetched, so offloading would only thrash; keep everything on the GPU.
            prev_layer_idx = (layer_idx - 1) % len(self)
            self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
            self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        "Gets the cache for this layer on its original device. Prefetches the next layer and evicts the previous one."
        if layer_idx < len(self):
            # Evict the previous layer if necessary
            torch.cuda.current_stream().synchronize()
            self.evict_previous_layer(layer_idx)
            # Load current layer cache to its original device if not already there
            original_device = self.original_device[layer_idx]
            self.prefetch_stream.synchronize()
            key_tensor = self.key_cache[layer_idx]
            value_tensor = self.value_cache[layer_idx]
            # Now deal with beam search ops which were delayed
            if self.beam_idx is not None:
                self.beam_idx = self.beam_idx.to(original_device)
                key_tensor = key_tensor.index_select(0, self.beam_idx)
                value_tensor = value_tensor.index_select(0, self.beam_idx)
            # Prefetch the next layer
            self.prefetch_layer((layer_idx + 1) % len(self))
            return (key_tensor, value_tensor)
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices and reorders the cache when the tensor is back to its device."""
        # We delay this operation until the tensors are back to their original
        # device because performing torch.index_select on the CPU is very slow
        del self.beam_idx
        self.beam_idx = beam_idx.clone()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.original_device.append(key_states.device)
            self.evict_previous_layer(layer_idx)
        else:
            # Indexing through self[layer_idx] also triggers prefetch/eviction and any delayed beam reordering
            key_tensor, value_tensor = self[layer_idx]
            self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError,
    # if a method is not supposed to be supported in a subclass, we should set it to None.
    from_legacy_cache = None

    to_legacy_cache = None
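
As an aside, a minimal, hypothetical illustration of the `= None` convention used above (the classes here are invented for this note, not part of the diff): rebinding an inherited method to None marks it as deliberately unsupported, so a call fails immediately instead of silently falling back to the parent behavior.

# Hypothetical illustration of the "= None" convention; not part of this diff.
class Base:
    def to_legacy(self):
        return "legacy format"

class Sub(Base):
    to_legacy = None  # this subclass deliberately does not support the legacy format

try:
    Sub().to_legacy()
except TypeError as err:
    print(err)  # 'NoneType' object is not callable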


class QuantizedCache(DynamicCache):
    """
    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
src/transformers/generation/utils.py (3 additions, 0 deletions)

@@ -33,6 +33,7 @@
    HQQQuantizedCache,
    HybridCache,
    MambaCache,
    OffloadedCache,
    QuantizedCacheConfig,
    QuantoQuantizedCache,
    SlidingWindowCache,
@@ -1814,6 +1815,8 @@ def generate(
                )

                model_kwargs[cache_name] = cache_class(cache_config)
            elif generation_config.cache_implementation == "offloaded":
                model_kwargs[cache_name] = OffloadedCache()
        # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
        # keeps copying the cache thus using much more memory
        elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
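
With this branch in place, users opt in through the generation config alone. A minimal usage sketch (not part of the diff), assuming a CUDA GPU; the model name and prompt are illustrative, borrowed from the tests below:

# Minimal sketch of the user-facing API added by this diff; assumes a CUDA GPU.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "microsoft/Phi-3-mini-4k-instruct"  # illustrative; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)

inputs = tokenizer("Fun fact:", return_tensors="pt").to(model.device)
# "offloaded" is the string the new elif above routes to OffloadedCache().
config = GenerationConfig(cache_implementation="offloaded", max_new_tokens=20)
outputs = model.generate(generation_config=config, **inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))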
tests/utils/test_cache_utils.py (52 additions, 0 deletions)

@@ -35,6 +35,7 @@
    AutoModelForCausalLM,
    AutoTokenizer,
    DynamicCache,
    GenerationConfig,
    GPT2LMHeadModel,
    LlamaConfig,
    SinkCache,
@@ -455,3 +456,54 @@ def test_static_cache_extra_left_padding(self):
    @unittest.skip(reason="TODO @gante static cache does not support beam search yet")
    def test_static_cache_beam_search(self):
        pass

    @require_torch_gpu
    def test_offloaded_cache_equivalent_to_dynamic_cache(self):
        """Tests that OffloadedCache produces the same result as the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device
        input_text = "Fun fact:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        common = {
            "num_beams": 4,
            "num_beam_groups": 2,
            "num_return_sequences": 4,
            "diversity_penalty": 1.0,
            "max_new_tokens": 20,
            "early_stopping": True,
        }
        original = GenerationConfig(**common)
        offloaded = GenerationConfig(cache_implementation="offloaded", **common)
        original_outputs = model.generate(generation_config=original, **inputs)
        offloaded_outputs = model.generate(generation_config=offloaded, **inputs)
        for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
            assert torch.all(original_output == offloaded_output).item()

    @require_torch_gpu
    def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
        """Tests that OffloadedCache uses less memory than the default DynamicCache"""
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        device = model.device
        input_text = "Fun fact:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        common = {
            "num_beams": 4,
            "num_beam_groups": 2,
            "num_return_sequences": 4,
            "diversity_penalty": 1.0,
            "max_new_tokens": 20,
            "early_stopping": True,
        }
        original = GenerationConfig(**common)
        offloaded = GenerationConfig(cache_implementation="offloaded", **common)
        torch.cuda.reset_peak_memory_stats(device)
        model.generate(generation_config=original, **inputs)
        original_peak_memory = torch.cuda.max_memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        model.generate(generation_config=offloaded, **inputs)
        offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
        assert offloaded_peak_memory < original_peak_memory
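
Both tests require a CUDA device (the `@require_torch_gpu` marker skips them otherwise). Assuming a standard pytest setup for the repository, one way to run just this pair is `pytest tests/utils/test_cache_utils.py -k offloaded`.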