diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 740bb4b0719c61..93f2c96d2d9df0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -99,6 +99,8 @@ sections: - local: generation_strategies title: Customize the generation strategy + - local: kv_cache + title: Best Practices for Generation with Cache title: Generation - isExpanded: false sections: diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index 1f4df78b9a6616..3a9392ddd07d9b 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -174,117 +174,6 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te ``` -## KV Cache Quantization - -The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However the key and value -cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models. -Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed. - -KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache] -(https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings see the paper. - -To enable quantization of the key-value cache, one needs to indicate `cache_implementation="quantized"` in the `generation_config`. -Quantization related arguments should be passed to the `generation_config` either as a `dict` or an instance of a [`QuantizedCacheConfig`] class. -One has to indicate which quantization backend to use in the [`QuantizedCacheConfig`], the default is `quanto`. - - - -Cache quantization can be detrimental if the context length is short and there is enough GPU VRAM available to run without cache quantization. - - - - -```python ->>> import torch ->>> from transformers import AutoTokenizer, AutoModelForCausalLM - ->>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") ->>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0") ->>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device) - ->>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"}) ->>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0]) -I like rock music because it's loud and energetic. It's a great way to express myself and rel - ->>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20) ->>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0]) -I like rock music because it's loud and energetic. I like to listen to it when I'm feeling -``` - -## KV Cache Offloading - -Similarly to KV cache quantization, this strategy aims to reduce GPU VRAM usage. -It does so by moving the KV cache for most layers to the CPU. -As the model's `forward()` method iterates over the layers, this strategy maintains the current layer cache on the GPU. -At the same time it asynchronously prefetches the next layer cache as well as sending the previous layer cache back to the CPU. -Unlike KV cache quantization, this strategy always produces the same result as the default KV cache implementation. -Thus, it can serve as a drop-in replacement or a fallback for it. 
- -Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.) -you may notice a small degradation in generation throughput compared to the default KV cache implementation. - -To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config`. - -```python ->>> import torch ->>> from transformers import AutoTokenizer, AutoModelForCausalLM ->>> ckpt = "microsoft/Phi-3-mini-4k-instruct" - ->>> tokenizer = AutoTokenizer.from_pretrained(ckpt) ->>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0") ->>> inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device) - ->>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded") ->>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0]) -Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896. - ->>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23) ->>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0]) -Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896. -``` - - - -Cache offloading requires a GPU and can be slower than the default KV cache. Use it if you are getting CUDA out of memory errors. - - - -The example below shows how KV cache offloading can be used as a fallback strategy. -```python ->>> import torch ->>> from transformers import AutoTokenizer, AutoModelForCausalLM ->>> def resilient_generate(model, *args, **kwargs): -... oom = False -... try: -... return model.generate(*args, **kwargs) -... except torch.cuda.OutOfMemoryError as e: -... print(e) -... print("retrying with cache_implementation='offloaded'") -... oom = True -... if oom: -... torch.cuda.empty_cache() -... kwargs["cache_implementation"] = "offloaded" -... return model.generate(*args, **kwargs) -... -... ->>> ckpt = "microsoft/Phi-3-mini-4k-instruct" ->>> tokenizer = AutoTokenizer.from_pretrained(ckpt) ->>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0") ->>> prompt = ["okay "*1000 + "Fun fact: The most"] ->>> inputs = tokenizer(prompt, return_tensors="pt").to(model.device) ->>> beams = { "num_beams": 40, "num_beam_groups": 40, "num_return_sequences": 40, "diversity_penalty": 1.0, "max_new_tokens": 23, "early_stopping": True, } ->>> out = resilient_generate(model, **inputs, **beams) ->>> responses = tokenizer.batch_decode(out[:,-28:], skip_special_tokens=True) -``` - -On a GPU with 50 GB of RAM, running this code will print -``` -CUDA out of memory. Tried to allocate 4.83 GiB. GPU -retrying with cache_implementation='offloaded' -``` -before successfully generating 40 beams. - - ## Watermarking The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green". 
diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index da7ea25e54b6b0..1172e32fd0cc5a 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -386,11 +386,24 @@ A [`Constraint`] can be used to force the generation to include specific tokens - get_seq_length - reorder_cache +[[autodoc]] OffloadedCache + - update + - prefetch_layer + - evict_previous_layer + [[autodoc]] StaticCache - update - get_seq_length - reset +[[autodoc]] HybridCache + - update + - reset + +[[autodoc]] SlidingWindowCache + - update + - reset + [[autodoc]] EncoderDecoderCache - get_seq_length - to_legacy_cache @@ -398,6 +411,11 @@ A [`Constraint`] can be used to force the generation to include specific tokens - reset - reorder_cache +[[autodoc]] MambaCache + - update_conv_state + - update_ssm_state + - reset + ## Watermark Utils [[autodoc]] WatermarkDetector diff --git a/docs/source/en/kv_cache.md b/docs/source/en/kv_cache.md new file mode 100644 index 00000000000000..c0ccc49d41e683 --- /dev/null +++ b/docs/source/en/kv_cache.md @@ -0,0 +1,346 @@ + + +# Best Practices for Generation with Cache + +Efficient caching is crucial for optimizing the performance of models in various generative tasks, +including text generation, translation, summarization and other transformer-based applications. +Effective caching helps reduce computation time and improve response rates, especially in real-time or resource-intensive applications. + +Transformers support various caching methods, leveraging "Cache" classes to abstract and manage the caching logic. +This document outlines best practices for using these classes to maximize performance and efficiency. +Check out all the available `Cache` classes in the [API documentation](./internal/generation_utils.md). + +## What is Cache and why we should care? + +Imagine youโ€™re having a conversation with someone, and instead of remembering what was said previously, you have to start from scratch every time you respond. This would be slow and inefficient, right? In the world of Transformer models, a similar concept applies, and that's where Caching keys and values come into play. From now on, I'll refer to the concept as KV Cache. + +KV cache is needed to optimize the generation in autoregressive models, where the model predicts text token by token. This process can be slow since the model can generate only one token at a time, and each new prediction is dependent on the previous context. That means, to predict token number 1000 in the generation, you need information from the previous 999 tokens, which comes in the form of some matrix multiplications across the representations of those tokens. But to predict token number 1001, you also need the same information from the first 999 tokens, plus additional information from token number 1000. That is where key-value cache is used to optimize the sequential generation process by storing previous calculations to reuse in subsequent tokens, so they don't need to be computed again. + +More concretely, key-value cache acts as a memory bank for these generative models, where the model stores key-value pairs derived from self-attention layers for previously processed tokens. By storing this information, the model can avoid redundant computations and instead retrieve keys and values of previous tokens from the cache. + +
+  For the Curious Minds Who Like to Dive Deep
+
+  ### Under the Hood: How Cache Object Works in Attention Mechanism
+
+  When utilizing a cache object in the input, the Attention module performs several critical steps to integrate past and present information seamlessly.
+
+  The Attention module concatenates the current key-values with the past key-values stored in the cache, and the concatenated key-values are used to compute the attention scores. This results in attention weights of shape `(new_tokens_length, past_kv_length + new_tokens_length)`, ensuring that the model considers both the previous context and the new input.
+
+  Therefore, when iteratively calling `forward()` instead of the `generate()` method, it's crucial to ensure that the attention mask shape matches the combined length of past and current key-values. The attention mask should have the shape `(batch_size, past_kv_length + new_tokens_length)`. This is usually handled internally when you call the `generate()` method. If you want to implement your own generation loop with Cache classes, take this into consideration and prepare the attention mask to cover both the past and the current tokens.
+
+  One important concept you need to know when writing your own generation loop is `cache_position`. If you want to reuse an already filled Cache object by calling `forward()`, you have to pass in a valid `cache_position`, which indicates the positions of the inputs in the sequence. Note that `cache_position` is not affected by padding and always advances by one position for each token. For example, if the key/value cache already contains 10 tokens (regardless of how many of them are padding tokens), the cache position for the next token should be `torch.tensor([10])`.
+
+  See an example below for how to implement your own generation loop.
+
+  ```python
+  >>> import torch
+  >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
+
+  >>> model_id = "meta-llama/Llama-2-7b-chat-hf"
+  >>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
+  >>> tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+  >>> past_key_values = DynamicCache()
+  >>> messages = [{"role": "user", "content": "Hello, what's your name."}]
+  >>> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0")
+
+  >>> generated_ids = inputs.input_ids
+  >>> cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
+  >>> max_new_tokens = 10
+
+  >>> for _ in range(max_new_tokens):
+  ...     outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
+  ...     # Greedily pick the next token (argmax over the vocabulary dimension)
+  ...     next_token_ids = outputs.logits[:, -1:].argmax(-1)
+  ...     generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
+  ...
+  ...     # Prepare inputs for the next generation step: keep only the unprocessed tokens (here, the single new token)
+  ...     # and extend the attention mask for the new token, as explained above
+  ...     attention_mask = inputs["attention_mask"]
+  ...     attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
+  ...     inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
+  ...     
cache_position = cache_position[-1:] + 1 # add one more position for the next token + + >>> print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]) + "[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA," + ``` + +
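+  If you are curious about what the cache actually holds after a loop like the one above, the `Cache` API exposes a few helpers for inspecting it. The snippet below is a small illustrative addition (it assumes the `DynamicCache` attributes and methods available in recent versions, namely `key_cache`, `get_seq_length()`, `to_legacy_cache()` and `from_legacy_cache()`):
+
+  ```python
+  >>> # Inspect the cache that was filled by the loop above
+  >>> num_cached_tokens = past_key_values.get_seq_length()  # prompt length + number of generated tokens
+  >>> layer0_keys = past_key_values.key_cache[0]  # keys of the first layer: [batch_size, num_heads, seq_len, head_dim]
+
+  >>> # Convert to the legacy tuple-of-tuples format and back, e.g. for code that still expects tuples
+  >>> legacy_cache = past_key_values.to_legacy_cache()
+  >>> past_key_values = DynamicCache.from_legacy_cache(legacy_cache)
+  ```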
+
+
+## Generate with Cache
+
+In ๐Ÿค— Transformers, we support various Cache types to optimize the performance across different models and tasks. By default, all models generate with caching,
+with the [`~DynamicCache`] class being the default cache for most models. It lets the cache grow dynamically, saving more and more keys and values as generation proceeds. If for some reason you don't want to use caches, you can pass `use_cache=False` into the `generate()` method.
+
+Refer to the table below to see the difference between cache types and choose the one that best suits your use case.
+
+| Cache Type           | Memory Efficient | Supports torch.compile() | Initialization Recommended | Latency | Long Context Generation |
+|----------------------|------------------|--------------------------|----------------------------|---------|-------------------------|
+| Dynamic Cache        | No               | No                       | No                         | Mid     | No                      |
+| Static Cache         | No               | Yes                      | Yes                        | High    | No                      |
+| Quantized Cache      | Yes              | No                       | No                         | Low     | Yes                     |
+| Offloaded Cache      | Yes              | No                       | No                         | Low     | No                      |
+| Sliding Window Cache | No               | Yes                      | Yes                        | High    | No                      |
+| Sink Cache           | Yes              | No                       | Yes                        | Mid     | Yes                     |
+
+These cache classes can be set with a `cache_implementation` argument when generating. To learn about the available options for the `cache_implementation` flag, please refer to the [API Documentation](./main_classes/text_generation.md#transformers.GenerationConfig). Now, let's explore each cache type in detail and see how to use them. Note that the examples below are for decoder-only Transformer-based models. We also support dedicated cache classes for models such as Mamba or Jamba; see the "Model-specific Cache Classes" section below for more details.
+
+### Quantized Cache
+
+The key and value cache can occupy a large portion of memory, becoming a [bottleneck for long-context generation](https://huggingface.co/blog/llama31#inference-memory-requirements), especially for Large Language Models.
+Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.
+
+KV Cache quantization in `transformers` is largely inspired by the paper ["KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache"](https://arxiv.org/abs/2402.02750) and currently supports the [`~QuantoQuantizedCache`] and [`~HQQQuantizedCache`] classes. For more information on the inner workings, see the paper.
+
+To enable quantization of the key-value cache, one needs to indicate `cache_implementation="quantized"` in the `generation_config`.
+Quantization related arguments should be passed to the `generation_config` either as a `dict` or an instance of a [`~QuantizedCacheConfig`] class.
+One has to indicate which quantization backend to use in the [`~QuantizedCacheConfig`]; the default is `quanto`.
+
+Cache quantization can be detrimental in terms of latency if the context length is short and there is enough GPU VRAM available to run without cache quantization. It is recommended to seek a balance between memory efficiency and latency.
+
+
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
+>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)
+
+>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
+>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
+I like rock music because it's loud and energetic. It's a great way to express myself and rel
+
+>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
+>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
+I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
+```
+
+### Offloaded Cache
+
+Similarly to KV cache quantization, the [`~OffloadedCache`] strategy aims to reduce GPU VRAM usage.
+It does so by moving the KV cache for most layers to the CPU.
+As the model's `forward()` method iterates over the layers, this strategy keeps the current layer's cache on the GPU.
+At the same time, it asynchronously prefetches the next layer's cache and sends the previous layer's cache back to the CPU.
+Unlike KV cache quantization, this strategy always produces the same result as the default KV cache implementation.
+Thus, it can serve as a drop-in replacement or a fallback for it.
+
+Depending on your model and the characteristics of your generation task (size of context, number of generated tokens, number of beams, etc.),
+you may notice a small degradation in generation throughput compared to the default KV cache implementation.
+
+To enable KV cache offloading, pass `cache_implementation="offloaded"` in the `generation_config` or directly to the `generate()` call.
+
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"
+
+>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
+>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
+>>> inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)
+
+>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded")
+>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
+Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.
+
+>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23)
+>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
+Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.
+```
+
+Cache offloading requires a GPU and can be slower than the dynamic KV cache. Use it if you are getting CUDA out of memory errors.
+
+The example below shows how KV cache offloading can be used as a fallback strategy.
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+>>> def resilient_generate(model, *args, **kwargs):
+...     oom = False
+...     try:
+...         return model.generate(*args, **kwargs)
+...     except torch.cuda.OutOfMemoryError as e:
+...         print(e)
+...         print("retrying with cache_implementation='offloaded'")
+...         oom = True
+...     if oom:
+...         torch.cuda.empty_cache()
+...         kwargs["cache_implementation"] = "offloaded"
+...         
return model.generate(*args, **kwargs)
+...
+...
+>>> ckpt = "microsoft/Phi-3-mini-4k-instruct"
+>>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
+>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
+>>> prompt = ["okay "*1000 + "Fun fact: The most"]
+>>> inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+>>> beams = { "num_beams": 40, "num_beam_groups": 40, "num_return_sequences": 40, "diversity_penalty": 1.0, "max_new_tokens": 23, "early_stopping": True, }
+>>> out = resilient_generate(model, **inputs, **beams)
+>>> responses = tokenizer.batch_decode(out[:,-28:], skip_special_tokens=True)
+```
+
+On a GPU with 50 GB of RAM, running this code will print
+```
+CUDA out of memory. Tried to allocate 4.83 GiB. GPU
+retrying with cache_implementation='offloaded'
+```
+before successfully generating 40 beams.
+
+### Static Cache
+
+Since the [`~DynamicCache`] dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
+a specific maximum size for the keys and values, allowing you to generate up to the maximum length without having to modify the cache size. See the usage example below.
+
+For more examples with Static Cache and JIT compilation, take a look at [StaticCache & torchcompile](./llm_optims.md#static-kv-cache-and-torchcompile).
+
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
+>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
+
+>>> # simply pass cache_implementation="static"
+>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="static")
+>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
+"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"
+```
+
+### Sliding Window Cache
+
+As the name suggests, this cache type implements a sliding window over the previous keys and values, retaining only the last `sliding_window` tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, it is JIT-friendly and can be used with the same compilation techniques as Static Cache.
+
+Note that this cache can only be used with models that support sliding window attention, e.g. Mistral models.
+
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16).to("cuda:0")
+>>> inputs = tokenizer("Yesterday I was on a rock concert and.", return_tensors="pt").to(model.device)
+
+>>> # can be enabled by passing in cache_implementation="sliding_window"
+>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=30, cache_implementation="sliding_window")
+>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
+"Yesterday I was on a rock concert and. I was so excited to see my favorite band. I was so excited that I was jumping up and down and screaming. I was so excited that I"
+```
+
+### Sink Cache
+
+Sink Cache was introduced in ["Efficient Streaming Language Models with Attention Sinks"](https://arxiv.org/abs/2309.17453).
It allows you to generate long sequences of text ("infinite length" according to the paper) without any fine-tuning. That is achieved by smart handling of previous keys and values: specifically, it retains a few initial tokens from the sequence, called "sink tokens". This is based on the observation that these initial tokens attract a significant portion of attention scores during the generation process. Tokens that come after the "sink tokens" are discarded on a sliding-window basis, keeping only the latest `window_size` tokens. By keeping these initial tokens as "attention sinks," the model maintains stable performance even when dealing with very long texts, although most of the earlier context is discarded.
+
+Unlike other cache classes, this one can't be used by simply indicating a `cache_implementation`. You have to initialize the Cache yourself before calling `generate()`, as follows.
+
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
+
+>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
+>>> inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)
+
+>>> # get our cache, specify number of sink tokens and window size
+>>> # Note that the window size already includes the sink tokens, so it has to be larger than `num_sink_tokens`
+>>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
+>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=30, past_key_values=past_key_values)
+>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
+"This is a long story about unicorns, fairies and magic. It is a fantasy world where unicorns and fairies live together in harmony. The story follows a young girl named Lily"
+```
+
+### Encoder-Decoder Cache
+
+The [`~EncoderDecoderCache`] is a wrapper designed to handle the caching needs of encoder-decoder models. This cache type is specifically built to manage both self-attention and cross-attention caches, ensuring storage and retrieval of the past key/values required for these more complex models. A nice thing about the Encoder-Decoder Cache is that you can set different cache types for the encoder and for the decoder, depending on your use case. Currently this cache is only supported in [Whisper](./model_doc/whisper.md) models, but we will be adding more models soon.
+
+In terms of usage, there is nothing special to be done; calling `generate()` or `forward()` will handle everything for you.
+
+### Model-specific Cache Classes
+
+Some models require storing previous keys, values, or states in a specific way, so the above cache classes cannot be used. For such cases, we have several specialized cache classes that are designed for specific models. These models only accept their own dedicated cache classes and do not support using any other cache types. Some examples include [`~HybridCache`] for [Gemma2](./model_doc/gemma2.md) series models or [`~MambaCache`] for [Mamba](./model_doc/mamba.md) architecture models.
+
+## Iterative Generation with Cache
+
+We have seen how to use each of the cache types when generating. What if you want to use the cache in an iterative generation setting, for example in applications like chatbots, where interactions involve multiple turns and continuous back-and-forth exchanges?
Iterative generation with the cache allows these systems to handle ongoing conversations effectively without reprocessing the entire context at each step. But there are a few things you should know before you start implementing:
+
+The general pattern for iterative generation is shown below. First, initialize an empty cache of the type you want; then you can start feeding in new prompts iteratively. Keeping track of the dialogue history and formatting it can be done with chat templates; read more on that in [chat_templating](./chat_templating.md).
+
+If you are using Sink Cache, crop your inputs to the maximum cache length: Sink Cache can generate text longer than its maximum window size, but it expects the first input not to exceed the maximum cache length.
+
+```python
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM
+>>> from transformers.cache_utils import (
+...     DynamicCache,
+...     SinkCache,
+...     StaticCache,
+...     SlidingWindowCache,
+...     QuantoQuantizedCache,
+...     QuantizedCacheConfig,
+... )
+
+>>> model_id = "meta-llama/Llama-2-7b-chat-hf"
+>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='auto')
+>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+>>> user_prompts = ["Hello, what's your name?", "Btw, yesterday I was on a rock concert."]
+
+>>> past_key_values = DynamicCache()
+>>> max_cache_length = past_key_values.get_max_length()
+
+>>> messages = []
+>>> for prompt in user_prompts:
+...     messages.append({"role": "user", "content": prompt})
+...     inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
+...     if isinstance(past_key_values, SinkCache):
+...         inputs = {k: v[:, -max_cache_length:] for k, v in inputs.items()}
+...
+...     input_length = inputs["input_ids"].shape[1]
+...
+...     outputs = model.generate(**inputs, do_sample=False, max_new_tokens=256, past_key_values=past_key_values)
+...     completion = tokenizer.decode(outputs[0, input_length: ], skip_special_tokens=True)
+...     messages.append({"role": "assistant", "content": completion})
+
+>>> print(messages)
+[{'role': 'user', 'content': "Hello, what's your name?"}, {'role': 'assistant', 'content': " Hello! My name is LLaMA, I'm a large language model trained by a team of researcher at Meta AI. ๐Ÿ˜Š"}, {'role': 'user', 'content': 'Btw, yesterday I was on a rock concert.'}, {'role': 'assistant', 'content': ' Oh, cool! That sounds like a lot of fun! ๐ŸŽ‰ Did you enjoy the concert? What was the band like? ๐Ÿค”'}]
+```
+
+## Re-use Cache to continue generation
+
+Sometimes you would want to first fill a cache object with the key/values of a certain prefix prompt and re-use it several times to generate different sequences from it. We are working hard on adding this feature to ๐Ÿค— Transformers and will update this section soon.
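+Until that lands, a manual approach along these lines may already work with a [`~DynamicCache`]: run a single `forward()` pass over the shared prefix to fill a cache, then copy that cache for every continuation. The sketch below is illustrative rather than an official recipe; it assumes `generate()` accepts a pre-filled cache via `past_key_values` and that the full prefix is passed again together with each new prompt (only the non-cached tokens are then processed). The prompts and variable names are made up for the example.
+
+```python
+>>> import copy
+>>> import torch
+>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
+
+>>> model_id = "meta-llama/Llama-2-7b-chat-hf"
+>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
+>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+>>> # Run a forward pass over the shared prefix once to fill the cache
+>>> prefix = "You are a helpful assistant. Answer briefly. "
+>>> prefix_inputs = tokenizer(prefix, return_tensors="pt").to(model.device)
+>>> prefix_cache = DynamicCache()
+>>> with torch.no_grad():
+...     prefix_cache = model(**prefix_inputs, past_key_values=prefix_cache, use_cache=True).past_key_values
+
+>>> prompts = ["Why is the sky blue?", "Name a famous rock band."]
+>>> responses = []
+>>> for prompt in prompts:
+...     # generate() mutates the cache in place, so give every continuation its own copy
+...     past_key_values = copy.deepcopy(prefix_cache)
+...     inputs = tokenizer(prefix + prompt, return_tensors="pt").to(model.device)
+...     outputs = model.generate(**inputs, past_key_values=past_key_values, do_sample=False, max_new_tokens=20)
+...     responses.append(tokenizer.decode(outputs[0, inputs.input_ids.shape[1]:], skip_special_tokens=True))
+```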
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4c953bab6be4b0..f971b5ffe4917e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1226,10 +1226,14 @@ "DynamicCache", "EncoderDecoderCache", "HQQQuantizedCache", + "HybridCache", + "MambaCache", + "OffloadedCache", "QuantizedCache", "QuantizedCacheConfig", "QuantoQuantizedCache", "SinkCache", + "SlidingWindowCache", "StaticCache", ] _import_structure["data.datasets"] = [ @@ -5948,10 +5952,14 @@ DynamicCache, EncoderDecoderCache, HQQQuantizedCache, + HybridCache, + MambaCache, + OffloadedCache, QuantizedCache, QuantizedCacheConfig, QuantoQuantizedCache, SinkCache, + SlidingWindowCache, StaticCache, ) from .data.datasets import ( diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d9a3a3a5a50d6c..141964c1add6e3 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -299,6 +299,22 @@ class DynamicCache(Cache): It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicCache() + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` """ def __init__(self) -> None: @@ -657,6 +673,24 @@ class QuantoQuantizedCache(QuantizedCache): Parameters: cache_config (`QuantizedCacheConfig`): A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. + + Example: + + ```python + >>> # Run pip install quanto first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4) + >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` """ def __init__(self, cache_config: CacheConfig) -> None: @@ -698,6 +732,24 @@ class HQQQuantizedCache(QuantizedCache): Parameters: cache_config (`QuantizedCacheConfig`): A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. 
+ + Example: + + ```python + >>> # Run pip install hqq first if you don't have it yet + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) + >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` """ def __init__(self, cache_config: CacheConfig) -> None: @@ -748,6 +800,22 @@ class SinkCache(Cache): The length of the context window. num_sink_tokens (`int`): The number of sink tokens. See the original paper for more information. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` """ def __init__(self, window_length: int, num_sink_tokens: int) -> None: @@ -917,6 +985,24 @@ class StaticCache(Cache): The device on which the cache should be initialized. Should be the same as the layer. dtype (*optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` """ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: @@ -1047,6 +1133,24 @@ class SlidingWindowCache(StaticCache): The device on which the cache should be initialized. Should be the same as the layer. dtype (*optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. 
+
+    Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache
+
+        >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+        >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+        >>> max_generated_length = inputs.input_ids.shape[1] + 10
+        >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+        ```
     """

     def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
@@ -1125,6 +1229,25 @@ class EncoderDecoderCache(Cache):
     """
     Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
     cross-attention caches.
+
+    Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, DynamicCache, EncoderDecoderCache
+
+        >>> model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
+        >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
+
+        >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
+
+        >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
+        >>> self_attention_cache = DynamicCache()
+        >>> cross_attention_cache = DynamicCache()
+        >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+        ```
+
     """

     def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
@@ -1271,6 +1394,42 @@ def batch_select_indices(self, indices: torch.Tensor):
 class HybridCache(Cache):
+    """
+    Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention
+    and global attention in every other layer. Under the hood, Hybrid Cache leverages [`SlidingWindowCache`] for sliding window attention
+    and [`StaticCache`] for global attention. For more information, see the documentation of each subcomponent cache class.
+
+    Parameters:
+        config (`PretrainedConfig`):
+            The configuration file defining the shape-related attributes required to initialize the cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used.
+        max_cache_len (`int`):
+            The maximum sequence length with which the model will be used.
+        device (`torch.device`, *optional*, defaults to `"cpu"`):
+            The device on which the cache should be initialized. Should be the same as the layer.
+        dtype (*optional*, defaults to `torch.float32`):
+            The default `dtype` to use when initializing the layer.
+
+    Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
+
+        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
+
+        >>> # Prepare a cache class and pass it to model's forward
+        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+        >>> max_generated_length = inputs.input_ids.shape[1] + 10
+        >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
+        ```
+    """
+
     def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1398,18 +1557,44 @@ class MambaCache:
     """
     Cache for mamba model which does not have attention mechanism and key value states.

     Arguments:
-        config: MambaConfig
-        max_batch_size: int
-        dtype: torch.dtype
-        device: torch.device
+        config (`PretrainedConfig`):
+            The configuration file defining the shape-related attributes required to initialize the cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used.
+        dtype (*optional*, defaults to `torch.float16`):
+            The default `dtype` to use when initializing the layer.
+        device (`torch.device`, *optional*):
+            The device on which the cache should be initialized. Should be the same as the layer.

     Attributes:
-        dtype: torch.dtype
-        intermediate_size: int
-        ssm_state_size: int
-        conv_kernel_size: int
-        conv_states: torch.Tensor [layer_idx, batch_size, intermediate_size, conv_kernel_size]
-        ssm_states: torch.Tensor [layer_idx, batch_size, intermediate_size, ssm_state_size]
+        dtype: (`torch.dtype`):
+            The default `dtype` used to initialize the cache.
+        intermediate_size: (`int`):
+            Model's `intermediate_size` taken from config.
+        ssm_state_size: (`int`):
+            Model's `state_size` taken from config.
+        conv_kernel_size: (`int`):
+            Model's convolution kernel size taken from config.
+        conv_states: (`torch.Tensor`):
+            A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
+ ssm_states: (`torch.Tensor`): + A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states + + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv = outputs.past_key_values + ``` """ def __init__( diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index de739c6e70044a..258cc5191e136b 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -51,6 +51,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class HybridCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MambaCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class OffloadedCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class QuantizedCache(metaclass=DummyObject): _backends = ["torch"] @@ -79,6 +100,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class SlidingWindowCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class StaticCache(metaclass=DummyObject): _backends = ["torch"]