Adaptive batch sizing (#181)
casper-hansen authored Nov 11, 2023
1 parent df909e8 commit c5581b2
Showing 3 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -120,7 +120,7 @@ Fused modules are a large part of the speedup you get from AutoAWQ. The idea is

 - Fused modules are activated when you use `fuse_layers=True`.
 - A custom cache is implemented. It preallocates based on batch size and sequence length.
-- You cannot change the sequence length or batch size after you have created your model.
+- You cannot change the sequence length after you have created your model.
 - Reference: `AutoAWQForCausalLM.from_quantized(max_new_tokens=seq_len, batch_size=batch_size)`
 - The main accelerator in the fused modules comes from FasterTransformer, which is only compatible with Linux.
 - The `past_key_values` from `model.generate()` are only dummy values, so they cannot be used after generation.
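The README bullet above references the `from_quantized` arguments that size the fused cache. A minimal usage sketch, assuming a placeholder model path and generation settings that are not part of this commit:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "path/to/awq-quantized-model"  # placeholder, not from the commit

# Preallocate the fused KV cache for up to 8 sequences of 512 new tokens.
# Only the sequence length stays fixed; with this commit the batch size
# of later calls may differ from the preallocated one.
model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    fuse_layers=True,
    max_new_tokens=512,
    batch_size=8,
)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```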
11 changes: 7 additions & 4 deletions awq/modules/fused/attn.py
@@ -123,11 +123,14 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max
     def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs):
         bsz, seqlen, _ = hidden_states.shape

+        # Reallocate cache if batch size changes
         if bsz != self.cache_batch_size:
-            raise RuntimeError(
-                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
-                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
-            )
+            if bsz > self.cache_batch_size:
+                self.cache.increase_batch_size(bsz)
+                self.cache_batch_size = bsz
+            elif bsz < self.cache_batch_size:
+                self.cache.decrease_batch_size(bsz)
+                self.cache_batch_size = bsz

         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
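For illustration, the caller-side effect of this hunk, reusing the placeholder `model` and `tokenizer` from the sketch above (an assumption about usage, not part of the diff):

```python
# Previously, a batch size different from the one passed to from_quantized
# raised a RuntimeError. With this change, the fused attention layers resize
# their KV cache to match the incoming batch.
single = tokenizer(["One prompt"], return_tensors="pt", padding=True).to("cuda")
triple = tokenizer(
    ["First prompt", "Second prompt", "Third prompt"],
    return_tensors="pt",
    padding=True,
).to("cuda")

model.generate(**single, max_new_tokens=16)  # cache shrinks to batch size 1
model.generate(**triple, max_new_tokens=16)  # cache grows to batch size 3
```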
11 changes: 10 additions & 1 deletion awq/modules/fused/cache.py
@@ -47,4 +47,13 @@ def roll_kv_n_steps(self, start_pos, n=100):
     def to(self, device):
         self.k = self.k.to(device)
         self.v = self.v.to(device)

+    def increase_batch_size(self, to_bsz):
+        """Dynamically allocate new kv when batch size changes."""
+        self.v = torch.zeros(to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device)
+        self.k = torch.zeros(to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device)
+
+    def decrease_batch_size(self, to_bsz):
+        """Dynamically remove part of cache if batch size changes."""
+        self.v = self.v[:to_bsz, :, :, :]
+        self.k = self.k[:to_bsz, :, :, :, :]
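The resize pattern added here can be shown in isolation. A small self-contained sketch with made-up tensor shapes (4-D values and a 5-D packed key layout, mirroring the ranks sliced above), not the actual cache class:

```python
import torch

# Stand-in tensors with the same rank as the fused cache:
# values are 4-D (batch, heads, seq, head_dim), keys use a 5-D packed layout.
v = torch.zeros(4, 8, 128, 64)
k = torch.zeros(4, 8, 8, 128, 8)

def increase_batch_size(v, k, to_bsz):
    # Growing: allocate fresh zeroed tensors with the larger batch dimension,
    # preserving dtype and device; previous contents are discarded.
    v = torch.zeros(to_bsz, *v.shape[1:], dtype=v.dtype, device=v.device)
    k = torch.zeros(to_bsz, *k.shape[1:], dtype=k.dtype, device=k.device)
    return v, k

def decrease_batch_size(v, k, to_bsz):
    # Shrinking: slice along the batch dimension, which returns views
    # and allocates no new memory.
    return v[:to_bsz], k[:to_bsz]

v, k = increase_batch_size(v, k, 8)   # (8, 8, 128, 64), (8, 8, 8, 128, 8)
v, k = decrease_batch_size(v, k, 2)   # (2, 8, 128, 64), (2, 8, 8, 128, 8)
print(v.shape, k.shape)
```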
