Generate: fix end to end compilation #32465

Merged 2 commits on Aug 6, 2024
27 changes: 15 additions & 12 deletions src/transformers/cache_utils.py
```diff
@@ -1024,19 +1024,22 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
-            # Note: `torch.export()`` requires mutations to be registered as buffers.
-            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            key_cache = getattr(self, f"key_cache_{idx}")
-            value_cache = getattr(self, f"value_cache_{idx}")
-            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            # it is not needed anyway)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            # Notes:
+            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
+            #    it is not needed anyway)
+            # 2. `torch.export()` requires mutations to be registered as buffers.
             if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(key_cache)
-                torch._dynamo.mark_static_address(value_cache)
-            self.key_cache.append(key_cache)
-            self.value_cache.append(value_cache)
+                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
```

Review comment from the PR author on lines +1035 to +1036: "self.register_buffer can't be compiled, but is needed for torch.export" (a sketch of this pattern appears after the hunk below).

```diff
+                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
+                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)
 
     def update(
         self,
```
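To see the guarded-registration idea in isolation, here is a minimal, self-contained sketch — not the transformers implementation — of the pattern this hunk introduces. It assumes PyTorch ≥ 2.3 for `torch.compiler.is_compiling()`, which stands in for transformers' `is_torchdynamo_compiling()` helper; the class and parameter names are illustrative.

```python
import torch
import torch.nn as nn


class ToyStaticCache(nn.Module):
    """Toy per-layer KV cache mirroring the buffer-registration pattern above."""

    def __init__(self, num_layers: int, cache_shape: tuple, dtype=torch.float32, device="cpu"):
        super().__init__()
        self.key_cache, self.value_cache = [], []
        for idx in range(num_layers):
            key = torch.zeros(cache_shape, dtype=dtype, device=device)
            value = torch.zeros(cache_shape, dtype=dtype, device=device)
            # Neither register_buffer() nor mark_static_address() can run inside a
            # compiled region, so both are skipped when compiling.
            if not torch.compiler.is_compiling():
                # torch.export() requires mutated tensors to be registered as buffers.
                self.register_buffer(f"key_cache_{idx}", key)
                self.register_buffer(f"value_cache_{idx}", value)
                # Pin the data pointers so CUDA graphs are not re-captured on cache updates.
                torch._dynamo.mark_static_address(key)
                torch._dynamo.mark_static_address(value)
            self.key_cache.append(key)
            self.value_cache.append(value)


# Usage sketch: one cache tensor pair per decoder layer.
cache = ToyStaticCache(num_layers=2, cache_shape=(1, 8, 128, 64))
print(len(cache.key_cache), cache.key_cache[0].shape)
```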
17 changes: 9 additions & 8 deletions src/transformers/generation/utils.py
```diff
@@ -1429,7 +1429,9 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
         model_kwargs["cache_position"] = cache_position
         return model_kwargs
 
-    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
+    def _get_cache(
+        self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
+    ) -> Cache:
         """
         Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
         new `generate` call requires a larger cache or uses a different batch size.
```
```diff
@@ -1477,7 +1479,7 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l
             "config": self.config,
             "max_batch_size": max_batch_size,
             "max_cache_len": max_cache_len,
-            "device": self.device,
+            "device": device,
```
Review comment from the PR author on the changed line: "I dropped this line in the end-to-end PR, in a rebase 😢. self.device can't be called at compilation time" (see the sketch after this hunk).

"dtype": cache_dtype,
}
self._cache = cache_cls(**cache_kwargs)
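This hunk is small but load-bearing: the device now arrives as an explicit argument rather than being read from `self.device`, which, per the comment above, can't be called at compilation time. Below is a hedged sketch of the resulting calling convention; the function and variable names are illustrative, not transformers APIs.

```python
import torch


def build_static_cache_kwargs(
    config, max_batch_size: int, max_cache_len: int, device: torch.device, cache_dtype: torch.dtype
) -> dict:
    # `device` is threaded in by the caller (generate() in the real code), so this
    # function never has to touch module attributes while being compiled.
    return {
        "config": config,
        "max_batch_size": max_batch_size,
        "max_cache_len": max_cache_len,
        "device": device,  # previously read from self.device, which the PR author notes breaks under compilation
        "dtype": cache_dtype,
    }


# The caller resolves the device once, eagerly, and passes it down.
kwargs = build_static_cache_kwargs(
    config=None, max_batch_size=1, max_cache_len=256, device=torch.device("cpu"), cache_dtype=torch.float32
)
print(kwargs["device"])
```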
```diff
@@ -1813,12 +1815,11 @@ def generate(
                         "issue: https://github.com/huggingface/transformers/issues/28981"
                     )
                 model_kwargs[cache_name] = self._get_cache(
-                    generation_config.cache_implementation,
-                    getattr(generation_config, "num_beams", 1)
-                    * getattr(generation_config, "num_return_sequences", 1)
-                    * batch_size,
-                    generation_config.max_length,
-                    model_kwargs,
+                    cache_implementation=generation_config.cache_implementation,
+                    max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
+                    max_cache_len=generation_config.max_length,
+                    device=device,
+                    model_kwargs=model_kwargs,
                 )
             elif generation_config.cache_implementation == "quantized":
                 if not self._supports_quantized_cache:
```
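Taken together, these changes let the static-cache setup inside `generate()` run under compilation. Below is a hedged usage sketch of the end-to-end scenario this PR targets; the checkpoint, generation arguments, and the compile entry point are illustrative and may differ across transformers/PyTorch versions, and the model must support `cache_implementation="static"` (Llama-family models do).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative Llama-family checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Compile generation; the StaticCache is built inside generate(), which is why
# _get_cache() must avoid compile-unfriendly calls such as self.device.
compiled_generate = torch.compile(model.generate)

inputs = tokenizer("The key to compiling generation is", return_tensors="pt")
output = compiled_generate(**inputs, cache_implementation="static", max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```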