Add customized static cache implementation #4385
This PR is based on the prototype done in huggingface/transformers#31706. We want to export HF models to ExecuTorch, but we cannot do that right now because its `Cache` is not a `torch.nn.Module`. This PR solves the problem by implementing a customized `StaticCache` that is not only a `Cache` but also a `torch.nn.Module`. Most of its implementation is copied from transformers, with the following modifications:

- Make it a `torch.nn.Module` and register the cache tensors via `register_buffer` calls, copied from "[Demo][ExecuTorch] Lower and run native Gemma e2e in ExecuTorch" (huggingface/transformers#31706).
- Adjust the `StaticCache` implementation: 1. `get_seq_length` should return a number instead of a tensor; 2. `update` should return only the filled cache slots instead of the whole static cache. A sketch of the resulting class follows this list.
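The sketch below is a minimal illustration of the idea under assumed names and signatures (`ETStaticCache`, the constructor arguments, and a single layer's buffers instead of one pair per layer), not the exact code in this diff; the two numbered modifications above are marked with comments.

```python
# Minimal sketch, not the code from this PR. All names here are
# illustrative assumptions; the real class copies most of its body from
# transformers' StaticCache. For brevity this keeps one layer's cache,
# whereas a real implementation holds one buffer pair per layer and
# indexes them with `layer_idx`.
from typing import Any, Dict, Optional, Tuple

import torch
from transformers.cache_utils import Cache


class ETStaticCache(Cache, torch.nn.Module):
    """A static KV cache that is both a `Cache` and a `torch.nn.Module`."""

    def __init__(
        self,
        max_batch_size: int,
        max_cache_len: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        Cache.__init__(self)
        torch.nn.Module.__init__(self)
        cache_shape = (max_batch_size, num_heads, max_cache_len, head_dim)
        # Registering the cache tensors as buffers (rather than keeping
        # them as plain attributes) makes them part of the module state
        # that torch.export / ExecuTorch can capture.
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # Modification 1: return a plain Python number, not a tensor.
        # Occupied slots are detected by any nonzero entry along head_dim.
        return int(self.k_cache[0, 0].any(dim=-1).sum())

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # `cache_position` tells us which slots this call fills, mirroring
        # the upstream StaticCache.update contract.
        cache_position = cache_kwargs.get("cache_position")
        self.k_cache[:, :, cache_position] = key_states
        self.v_cache[:, :, cache_position] = value_states
        # Modification 2: return only the filled slots rather than the
        # whole (mostly empty) static cache.
        end = int(cache_position[-1]) + 1
        return self.k_cache[:, :, :end], self.v_cache[:, :, :end]
```

Because the cache is now a `torch.nn.Module`, its key/value buffers are part of the module state and visible to `torch.export`, which is what the ExecuTorch lowering path needs.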
Test Plan:
Make sure the following commands generate exactly the same output, and that the latter one, with the KV cache enabled, is faster: