Cache: use batch_size instead of max_batch_size #32657

Merged (3 commits) on Aug 16, 2024
Changes from 2 commits
4 changes: 2 additions & 2 deletions docs/source/en/llm_optims.md
@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16

past_key_values = StaticCache(
config=model.config,
max_batch_size=1,
batch_size=1,
# If you plan to reuse the cache, make sure the cache length is large enough for all cases
max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
device=model.device,
@@ -161,7 +161,7 @@ There are a few important things you must do to enable static kv-cache and `torc
batch_size, seq_length = inputs["input_ids"].shape
with torch.no_grad():
past_key_values = StaticCache(
config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
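
For orientation, here is a minimal end-to-end sketch of the updated documentation snippet above; the checkpoint and prompt are illustrative and not part of this diff.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to(device)
model.generation_config.max_new_tokens = 16

inputs = tokenizer("The theory of special relativity states ", return_tensors="pt").to(device)
prompt_length = inputs.input_ids.shape[1]

# After this PR, the cache size is declared with `batch_size` (formerly `max_batch_size`)
past_key_values = StaticCache(
    config=model.config,
    batch_size=1,
    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length + (model.generation_config.max_new_tokens * 2),
    device=model.device,
    dtype=model.dtype,
)
outputs = model.generate(**inputs, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
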
127 changes: 90 additions & 37 deletions src/transformers/cache_utils.py
@@ -977,13 +977,14 @@ class StaticCache(Cache):
Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
Example:
@@ -999,22 +1000,37 @@ class StaticCache(Cache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
```
"""

def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
if max_batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
)

Collaborator comment on lines +1020 to +1025: "code on the hub will complain but yes it makes sense"

self.batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)

self.dtype = dtype if dtype is not None else torch.float32
self.dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
@@ -1024,7 +1040,7 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for idx in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
@@ -1130,13 +1146,14 @@ class SlidingWindowCache(StaticCache):
Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
device (`torch.device` or `str`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
Example:
@@ -1152,13 +1169,22 @@ class SlidingWindowCache(StaticCache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
```
"""

def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
@@ -1168,7 +1194,12 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
)
max_cache_len = min(config.sliding_window, max_cache_len)
super().__init__(
config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
config=config,
batch_size=batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
max_batch_size=max_batch_size,
)

def update(
@@ -1407,13 +1438,14 @@ class HybridCache(Cache):
Parameters:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`, *optional*, defaults to `"cpu"`):
device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
Example:
@@ -1429,28 +1461,42 @@ class HybridCache(Cache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
```
"""

def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: Union[torch.device, str] = "cpu",
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
) -> None:
super().__init__()
if max_batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
)
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
self.batch_size = batch_size or max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)

self.dtype = dtype if dtype is not None else torch.float32
self.dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
@@ -1459,9 +1505,9 @@ def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, devi
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
sliding_cache_shape = (
max_batch_size,
self.batch_size,
self.num_key_value_heads,
min(config.sliding_window, max_cache_len),
self.head_dim,
@@ -1564,11 +1610,12 @@ class MambaCache:
Arguments:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
dtype (*optional*, defaults to `torch.float16`):
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer.
device (`torch.device`, *optional*):
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. Should be the same as the layer.
Attributes:
@@ -1596,37 +1643,43 @@ class MambaCache:
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
>>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> past_kv = outputs.past_key_values
```
"""

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
max_batch_size: int,
batch_size: int = None,
dtype: torch.dtype = torch.float16,
device: Optional[str] = None,
**kwargs,
device: Optional[Union[torch.device, str]] = None,
max_batch_size: Optional[int] = None,
):
if max_batch_size is not None:
logger.warning_once(
f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.46. Use the more precisely named 'batch_size' argument instead."
)
self.dtype = dtype
self.max_batch_size = max_batch_size
self.batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel

self.conv_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.batch_size,
self.intermediate_size,
self.conv_kernel_size,
device=device,
dtype=dtype,
)
self.ssm_states: torch.Tensor = torch.zeros(
config.num_hidden_layers,
self.max_batch_size,
self.batch_size,
self.intermediate_size,
self.ssm_state_size,
device=device,
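
To make the deprecation path above concrete, here is a hedged sketch of how both spellings behave until the planned removal in v4.46; the tiny `LlamaConfig` is a stand-in for a real model config.

```python
import torch
from transformers import LlamaConfig, StaticCache

# Tiny illustrative config, just to exercise the constructor
config = LlamaConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)

# New spelling
cache = StaticCache(config=config, batch_size=2, max_cache_len=128, device="cpu", dtype=torch.float16)

# Old spelling still works, but logs a deprecation warning and is slated for removal in v4.46
legacy = StaticCache(config=config, max_batch_size=2, max_cache_len=128, device="cpu", dtype=torch.float16)

# Both end up with the same `batch_size` attribute and per-layer cache shape:
# (batch_size, num_key_value_heads, max_cache_len, head_dim)
assert cache.batch_size == legacy.batch_size == 2
assert tuple(cache.key_cache[0].shape) == (2, 4, 128, 16)
```
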
8 changes: 4 additions & 4 deletions src/transformers/generation/utils.py
@@ -1430,7 +1430,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
return model_kwargs

def _get_cache(
self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
) -> Cache:
"""
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
@@ -1452,7 +1452,7 @@ def _get_cache(
need_new_cache = (
not hasattr(self, "_cache")
or (not isinstance(cache_to_check, cache_cls))
or cache_to_check.max_batch_size != max_batch_size
or cache_to_check.batch_size != batch_size
)
if cache_implementation != "mamba":
need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
@@ -1477,7 +1477,7 @@ def _get_cache(

cache_kwargs = {
"config": self.config,
"max_batch_size": max_batch_size,
"batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
@@ -1816,7 +1816,7 @@ def generate(
)
model_kwargs[cache_name] = self._get_cache(
cache_implementation=generation_config.cache_implementation,
max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
max_cache_len=generation_config.max_length,
device=device,
model_kwargs=model_kwargs,
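
The `generate` hunk above passes the expanded batch size through to the cache; a small sketch of that arithmetic with illustrative numbers:

```python
# Rows allocated in the static cache by `generate` (illustrative values)
input_batch_size = 4        # input_ids.shape[0]
num_beams = 3               # generation_config.num_beams
num_return_sequences = 1    # generation_config.num_return_sequences

cache_batch_size = num_beams * num_return_sequences * input_batch_size
assert cache_batch_size == 12  # forwarded to _get_cache(..., batch_size=12, ...)
```
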
2 changes: 1 addition & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -818,7 +818,7 @@ def forward(
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
self.config,
max_batch_size=batch_size,
batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
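
The Gemma-2 hunk above sizes the automatically created `HybridCache` with `batch_size`; a hedged sketch of building the equivalent cache by hand, using a tiny stand-in config rather than a real checkpoint:

```python
import torch
from transformers import Gemma2Config, HybridCache

# HybridCache requires a config with `sliding_window`; these values are illustrative
config = Gemma2Config(
    hidden_size=64, num_hidden_layers=4, num_attention_heads=4,
    num_key_value_heads=2, head_dim=16, sliding_window=32,
)

batch_size, seq_len = 2, 16
past_key_values = HybridCache(
    config=config,
    batch_size=batch_size,  # formerly `max_batch_size`
    max_cache_len=seq_len,
    device="cpu",
    dtype=torch.float16,
)
print(past_key_values.key_cache[0].shape)  # torch.Size([2, 2, 16, 16])
```
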
4 changes: 2 additions & 2 deletions tests/models/llama/test_modeling_llama.py
@@ -1040,7 +1040,7 @@ def test_stacked_causal_mask_static_cache(self):
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache(
config=self.model.config,
max_batch_size=1,
batch_size=1,
max_cache_len=max_cache_len,
device=torch_device,
dtype=self.model.dtype,
@@ -1088,7 +1088,7 @@ def test_partial_stacked_causal_mask_static_cache(self):
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
past_key_values = StaticCache(
config=self.model.config,
max_batch_size=1,
batch_size=1,
max_cache_len=max_cache_len,
device=torch_device,
dtype=self.model.dtype,
4 changes: 2 additions & 2 deletions tests/models/phi3/test_modeling_phi3.py
@@ -47,12 +47,12 @@
end_of_text_token = 32000

class Phi3MiniWithStaticCache(torch.nn.Module):
def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int):
super().__init__()
self.model = model
self.cache = StaticCache(
config=model.config,
max_batch_size=max_batch_size,
batch_size=batch_size,
max_cache_len=max_seq_len,
device=self.model.device,
dtype=self.model.dtype,
2 changes: 1 addition & 1 deletion tests/quantization/aqlm_integration/test_aqlm.py
@@ -216,7 +216,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
# Setup static KV cache for generation
past_key_values = StaticCache(
config=self.quantized_model.config,
max_batch_size=1,
batch_size=1,
max_cache_len=seq_length + self.max_new_tokens + 1,
device=torch_device,
dtype=self.quantized_model.config._pre_quantization_dtype,