
Cache: use batch_size instead of max_batch_size #32657

Merged

gante merged 3 commits into huggingface:main on Aug 16, 2024

Conversation

@gante (Member) commented on Aug 13, 2024

What does this PR do?

Renames the input argument max_batch_size, present in static-shaped caches, to batch_size. max_batch_size is imprecise: the cache needs the EXACT batch size being used.

The imprecise variable name and description were a source of issues, e.g. here.

NOTE: while it is technically feasible to accept smaller batch sizes in static-shaped caches, we would have to slice the cache at each layer. Slicing is an expensive operation, and the whole point of static-shaped caches is to be fast. In other words, accepting smaller batch sizes would silently enable incorrect usage of the class 🤗
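
For intuition, here is a minimal, illustrative sketch (not the actual transformers implementation) of why a static-shaped cache needs the exact batch size:

import torch

# A static cache pre-allocates its key/value tensors with a fixed batch
# dimension; every forward pass must then use exactly that batch size.
batch_size, num_heads, max_cache_len, head_dim = 8, 12, 100, 64
key_cache = torch.zeros(batch_size, num_heads, max_cache_len, head_dim)
value_cache = torch.zeros(batch_size, num_heads, max_cache_len, head_dim)

# Serving a smaller batch (say 4) would force a slice like key_cache[:4]
# at every layer and decoding step; those slices are exactly the expensive
# operations that static shapes are meant to avoid.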


✅ All changes are backwards compatible, and the user will only see a warning when passing the batch size through the deprecated keyword argument:

from transformers import AutoConfig, StaticCache

config = AutoConfig.from_pretrained("gpt2")

# No warnings
StaticCache(config, 8, 100, "cpu")
StaticCache(config=config, batch_size=8, max_cache_len=100, device="cpu")

# Warnings
StaticCache(config=config, max_batch_size=8, max_cache_len=100, device="cpu")

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member) left a comment

Great, thanks for making it clearer!

@ArthurZucker (Collaborator) left a comment

Thanks for updating; we could keep the name and make sure the description is better as well!

Comment on lines +1020 to +1025
if max_batch_size is not None:
    logger.warning_once(
        f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
        "v4.46. Use the more precisely named 'batch_size' argument instead."
    )

Collaborator

Code on the Hub will complain, but yes, it makes sense.
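
For context, here is a minimal sketch of how such a keyword deprecation shim typically slots into a constructor. The class and signature below are simplified stand-ins for illustration, not copied from the PR:

import logging

logger = logging.getLogger(__name__)

class StaticCacheSketch:
    # Simplified stand-in; the real StaticCache also takes config,
    # max_cache_len, device, and more.
    def __init__(self, batch_size=None, max_batch_size=None):
        if max_batch_size is not None:
            # transformers uses logger.warning_once to deduplicate;
            # a plain warning is used in this sketch.
            logger.warning(
                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and "
                "will be removed in v4.46. Use the more precisely named 'batch_size' argument instead."
            )
            # Fall back to the deprecated value so old call sites keep working.
            batch_size = batch_size if batch_size is not None else max_batch_size
        self.batch_size = batch_size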

src/transformers/cache_utils.py (review comment outdated; resolved)
@gante merged commit cf32ee1 into huggingface:main on Aug 16, 2024
25 checks passed
@gante deleted the cache_var_name branch on Aug 16, 2024 at 10:48