From e31ff367db9e99f9482040f3be899c805937bf02 Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Thu, 25 Jul 2024 12:29:31 -0700
Subject: [PATCH] Add customized static cache implementation

---
 examples/models/phi-3-mini/eager.py        |  18 ++-
 extension/llm/transformers/__init__.py     |   0
 extension/llm/transformers/static_cache.py | 148 +++++++++++++++++++++
 3 files changed, 165 insertions(+), 1 deletion(-)
 create mode 100644 extension/llm/transformers/__init__.py
 create mode 100644 extension/llm/transformers/static_cache.py

diff --git a/examples/models/phi-3-mini/eager.py b/examples/models/phi-3-mini/eager.py
index a2aeb8415c..71f15e09a4 100644
--- a/examples/models/phi-3-mini/eager.py
+++ b/examples/models/phi-3-mini/eager.py
@@ -12,6 +12,8 @@
 
 import torch
 
+from extension.llm.transformers.static_cache import ETStaticCache
+
 from transformers import AutoTokenizer, Phi3ForCausalLM
 
 end_of_text_token = 32000
@@ -40,7 +42,18 @@ def _generate_token(args, model, prompt_tokens):
 
 def _generate_token_with_kv_cache(args, model, prompt_tokens):
     print("Generating tokens:", end="", flush=True)
-    result = model.forward(input_ids=prompt_tokens, use_cache=True, return_dict=True)
+    result = model.forward(
+        input_ids=prompt_tokens,
+        use_cache=True,
+        return_dict=True,
+        past_key_values=ETStaticCache(
+            model.config,
+            prompt_tokens.shape[0],
+            args.seq_len + prompt_tokens.shape[-1],
+            device=model.device,
+            dtype=model.dtype,
+        ),
+    )
 
     current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
     current_key_value = result.past_key_values
@@ -55,6 +68,9 @@ def _generate_token_with_kv_cache(args, model, prompt_tokens):
             use_cache=True,
             return_dict=True,
             past_key_values=current_key_value,
+            cache_position=torch.arange(
+                0, prompt_tokens.shape[-1] + len(generated_tokens), device=model.device
+            ),
         )
         current_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()
         current_key_value = result.past_key_values
diff --git a/extension/llm/transformers/__init__.py b/extension/llm/transformers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/extension/llm/transformers/static_cache.py b/extension/llm/transformers/static_cache.py
new file mode 100644
index 0000000000..73a2423a70
--- /dev/null
+++ b/extension/llm/transformers/static_cache.py
@@ -0,0 +1,148 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from transformers import PretrainedConfig, StaticCache
+
+
+class ETStaticCache(torch.nn.Module, StaticCache):
+    """
+    Customized static cache implementation, adapted from Hugging Face's
+    `StaticCache`. The per-layer key/value caches are registered as
+    `nn.Module` buffers so they can be carried as state through export.
+
+    Parameters:
+        config (`PretrainedConfig`):
+            The configuration file defining the shape-related attributes required to initialize the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used.
+        max_cache_len (`int`):
+            The maximum sequence length with which the model will be used.
+        device (`torch.device`):
+            The device on which the cache should be initialized. Should be the same as the layer's.
+        dtype (*optional*, defaults to `torch.float32`):
+            The default `dtype` to use when initializing the cache.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_batch_size: int,
+        max_cache_len: int,
+        device,
+        dtype=torch.float32,
+    ) -> None:
+        super().__init__()
+        self.max_batch_size = max_batch_size
+        self.max_cache_len = (
+            config.max_position_embeddings if max_cache_len is None else max_cache_len
+        )
+        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+        self.head_dim = (
+            config.head_dim
+            if hasattr(config, "head_dim")
+            else config.hidden_size // config.num_attention_heads
+        )
+        self.dtype = dtype if dtype is not None else torch.float32
+        self.num_key_value_heads = (
+            config.num_attention_heads
+            if config.num_key_value_heads is None
+            else config.num_key_value_heads
+        )
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        cache_shape = (
+            max_batch_size,
+            self.num_key_value_heads,
+            self.max_cache_len,
+            self.head_dim,
+        )
+        for idx in range(config.num_hidden_layers):
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+            # preventing cuda graph breaks when updating the cache.
+            self.register_buffer(
+                f"key_cache_{idx}",
+                torch.zeros(cache_shape, dtype=self.dtype, device=device),
+            )
+            self.register_buffer(
+                f"val_cache_{idx}",
+                torch.zeros(cache_shape, dtype=self.dtype, device=device),
+            )
+            key_cache = getattr(self, f"key_cache_{idx}")
+            val_cache = getattr(self, f"val_cache_{idx}")
+            torch._dynamo.mark_static_address(key_cache)
+            torch._dynamo.mark_static_address(val_cache)
+            self.key_cache.append(key_cache)
+            self.value_cache.append(val_cache)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The `StaticCache` needs the
+                `cache_position` input to know where to write in the cache.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        cache_position = cache_kwargs.get("cache_position")
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]
+
+        k_out[:, :, cache_position] = key_states
+        v_out[:, :, cache_position] = value_states
+
+        # Return only the occupied prefix of the cache so attention never sees
+        # the zero-initialized padding slots.
+        seq_len = self.get_seq_length(layer_idx)
+        return (
+            k_out[:, :, torch.arange(0, seq_len, device=k_out.device), :],
+            v_out[:, :, torch.arange(0, seq_len, device=v_out.device), :],
+        )
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states that were seen by the model."""
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value.
+        # To save on compute, let's limit the check to the first batch member and head dimension.
+        # TODO: deprecate this function in favor of `cache_position`
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()
+
+    def get_usable_length(
+        self, new_seq_length: int, layer_idx: Optional[int] = 0
+    ) -> int:
+        return self.get_seq_length(layer_idx)
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.max_cache_len
+
+    def reset(self):
+        """Resets the cache values while preserving the objects."""
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            self.key_cache[layer_idx].zero_()
+            self.value_cache[layer_idx].zero_()
+
+    def from_legacy_cache(
+        self,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """Populates this cache in place from a legacy (tuple of per-layer key/value tensors) cache."""
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                self.update(key_states, value_states, layer_idx, cache_kwargs)
+
+    def __hash__(self):
+        # Identity-based hashing/equality, so the cache can be compared and used
+        # as a key without triggering tensor comparisons.
+        return id(self)
+
+    def __eq__(self, other):
+        return id(self) == id(other)
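
Note for reviewers (not part of the diff): a minimal sketch of how the new cache
is driven end to end, assuming `model`, `tokenizer`, and `args.seq_len` are set
up as elsewhere in eager.py. The prefill pass writes the whole prompt into the
cache in one forward call; each subsequent decode step passes a `cache_position`
covering every token seen so far, matching the `torch.arange` added above.

    import torch
    from extension.llm.transformers.static_cache import ETStaticCache

    prompt_tokens = tokenizer("Tell me a story", return_tensors="pt").input_ids
    cache = ETStaticCache(
        model.config,
        prompt_tokens.shape[0],                  # max_batch_size
        args.seq_len + prompt_tokens.shape[-1],  # max_cache_len
        device=model.device,
        dtype=model.dtype,
    )

    # Prefill: the prompt's keys/values land in cache positions [0, prompt_len).
    result = model.forward(
        input_ids=prompt_tokens, use_cache=True, return_dict=True, past_key_values=cache
    )
    next_token = torch.argmax(result.logits[:, -1, :], dim=-1).item()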
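A second sketch exercises the cache in isolation. `Phi3Config()` with its
library defaults is only a stand-in for the real `model.config`; the point is
to show that `update` writes at the given `cache_position`, that it returns
just the occupied prefix, and that `get_seq_length` counts the same slots.

    import torch
    from transformers import Phi3Config
    from extension.llm.transformers.static_cache import ETStaticCache

    config = Phi3Config()  # illustrative config, not the exported model's
    cache = ETStaticCache(config, 1, 128, device="cpu", dtype=torch.float32)

    # Write four positions into layer 0, as an attention layer would.
    k = torch.ones(1, cache.num_key_value_heads, 4, cache.head_dim)
    v = torch.ones_like(k)
    pos = torch.arange(4)
    k_out, v_out = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": pos})

    assert cache.get_seq_length(0) == 4  # four occupied slots in layer 0
    assert k_out.shape[2] == 4           # only the occupied prefix is returned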