Skip to content

Commit

Permalink
Merge pull request #75 from casper-hansen/fix_runtime
Browse files Browse the repository at this point in the history
Fix KV cache shapes error
  • Loading branch information
casper-hansen authored Sep 27, 2023
2 parents 8eb26eb + cba9a28 commit c57da6b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 28 deletions.
64 changes: 37 additions & 27 deletions awq/modules/fused/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,32 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max
self.start_pos = 0
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
self.max_seq_len = max_seq_len
self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)
self.cache_v = ( torch.zeros(self.attention_shapes["cache_v"]).to(dev).half() )
self.cache_k = ( torch.zeros(self.attention_shapes["cache_k"]).to(dev).half() )

if use_alibi:
alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
self.alibi_slopes = alibi_slopes.float().to(dev)
self.alibi_bias = alibi_bias.float().to(dev)
self.rotary_dim = 0
self.is_neox = False
else:
self.freqs_cis = precompute_freqs_cis(
hidden_size // n_heads,
max_seq_len * 2,
).to(dev)
self.rotary_dim = self.head_dim
self.alibi_slopes = None
self.is_neox = True

def _get_attention_shapes(self, attention_shapes, max_seq_len):
if attention_shapes is not None:
self.attention_shapes = attention_shapes
attention_shapes = attention_shapes

elif self.n_kv_heads == 0:
self.attention_shapes = {
attention_shapes = {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
Expand All @@ -104,7 +124,7 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max
}

else:
self.attention_shapes = {
attention_shapes = {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
Expand All @@ -121,40 +141,30 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max
"single_xk_view": (self.n_kv_heads, self.head_dim),
"single_xv_view": (self.n_kv_heads, self.head_dim)
}

self.cache_v = (
torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
)

self.cache_k = (
torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
)

if use_alibi:
alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
self.alibi_slopes = alibi_slopes.float().to(dev)
self.alibi_bias = alibi_bias.float().to(dev)
self.rotary_dim = 0
self.is_neox = False
else:
self.freqs_cis = precompute_freqs_cis(
hidden_size // n_heads,
max_seq_len * 2,
).to(dev)
self.rotary_dim = self.head_dim
self.alibi_slopes = None
self.is_neox = True
return attention_shapes

def forward(
self,
hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
):
bsz, seqlen, _ = hidden_states.shape
if bsz != self.cache_batch_size:
raise RuntimeError(
f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
)

if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
# Roll cache to the left
roll_len = self.start_pos
self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2)
self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3)
# Zero out the new part
self.cache_v[:, :, -roll_len:, :] = 0
self.cache_k[:, :, :, -roll_len:, :] = 0
self.start_pos = 0

xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

Expand All @@ -179,7 +189,7 @@ def forward(
.permute(0, 2, 3, 1, 4)
.contiguous()
)

self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

Expand Down
7 changes: 6 additions & 1 deletion awq/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,9 @@ def clear_memory(weight=None):
if weight is not None:
del weight
gc.collect()
torch.cuda.empty_cache()
torch.cuda.empty_cache()

def compute_memory_used_pct(device):
memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
return memory_pct

0 comments on commit c57da6b

Please sign in to comment.