From 38cd8b2c3e70bd2edc689232baf4eba15188cfee Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Fri, 8 Dec 2023 11:50:17 +0100
Subject: [PATCH 1/2] Update attn.py

---
 awq/modules/fused/attn.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index 44e63061..5eca30a5 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -6,12 +6,22 @@ from awq.modules.fused.cache import WindowedCache
 from awq.utils.fused_utils import get_attention_shapes
+
 
 try:
     import ft_inference_engine
     FT_INSTALLED = True
 except:
     FT_INSTALLED = False
 
+HF_NEW_CACHE_FORMAT = False
+
+import transformers
+# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
+HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
+if HF_NEW_CACHE_FORMAT:
+    from transformers.cache_utils import DynamicCache
+
+
 
 class RoPE(nn.Module):
     def __init__(self, hidden_size, n_heads, max_seq_len, device):
         super(RoPE, self).__init__()
@@ -222,5 +232,11 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar
         # past_key_value is replaced with cache_v, cache_k, returning empty data
         # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
         # about past key length
-        past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
+        past_key_value = [torch.zeros(1, 1, self.start_pos, 1), torch.zeros(1, 1, self.start_pos, 1)]
+
+        if HF_NEW_CACHE_FORMAT:
+            new_cache = DynamicCache()
+            new_cache.update(past_key_value[0], past_key_value[1], layer_idx=0)
+            past_key_value = new_cache
+
         return attn_output, attention_weight, past_key_value

From 7001b2a46166815259446d13005dd63358f7d32d Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Fri, 8 Dec 2023 11:52:41 +0100
Subject: [PATCH 2/2] Update attn.py

---
 awq/modules/fused/attn.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index 5eca30a5..48845a2d 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -232,11 +232,11 @@ def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwar
         # past_key_value is replaced with cache_v, cache_k, returning empty data
         # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
         # about past key length
-        past_key_value = [torch.zeros(1, 1, self.start_pos, 1), torch.zeros(1, 1, self.start_pos, 1)]
+        past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
 
-        if HF_NEW_CACHE_FORMAT:
+        if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
             new_cache = DynamicCache()
-            new_cache.update(past_key_value[0], past_key_value[1], layer_idx=0)
+            new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
             past_key_value = new_cache
 
         return attn_output, attention_weight, past_key_value
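
For reference, a minimal standalone sketch of the dummy-cache pattern the second commit settles on: a zero tensor whose only meaningful dimension is the sequence length is wrapped in a DynamicCache so that transformers' past-length bookkeeping reports the right value. This is an illustration outside the patch, not part of it; start_pos and the tensor shape are placeholders, and it assumes a transformers version that ships transformers.cache_utils (>= 4.36).

    import torch

    HF_NEW_CACHE_FORMAT = False
    try:
        # cache_utils exists only in transformers versions with the new cache format
        from transformers.cache_utils import DynamicCache
        HF_NEW_CACHE_FORMAT = True
    except ImportError:
        pass

    start_pos = 7  # placeholder for the current cache length (self.start_pos in the patch)
    dummy = torch.zeros(1, 1, start_pos, 1)  # [batch, heads, seq_len, head_dim]; only seq_len matters

    past_key_value = [dummy]  # legacy-style dummy past kv
    if HF_NEW_CACHE_FORMAT:
        cache = DynamicCache()
        # the same dummy tensor stands in for both keys and values, as in the second commit
        cache.update(dummy, dummy, layer_idx=0)
        past_key_value = cache
        assert cache.get_seq_length() == start_pos  # downstream length checks now see start_pos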