[Gen] Fix FT kernel smem size, CG when batch size changed
tridao committed Apr 20, 2023
1 parent 96d10f6 commit 01c3eb1
Showing 4 changed files with 106 additions and 15 deletions.
15 changes: 7 additions & 8 deletions csrc/ft_attention/decoder_masked_multihead_attention.cu
@@ -29,14 +29,13 @@

#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_heads, params.batch_size); \
mmha::masked_multihead_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK, \
DO_CROSS_ATTENTION><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
if (smem_sz >= 48 * 1024) { \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
} \
dim3 grid(params.num_heads, params.batch_size); \
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

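Context on the kernel-launch fix above: CUDA caps a kernel's dynamic shared memory at 48 KB per block unless the kernel explicitly opts in, so launching with a larger smem_sz without first calling cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz) typically fails with an invalid-argument launch error. The revised macro therefore binds the kernel to a variable, raises that attribute when smem_sz is 48 KB or more, and only then launches with the same grid and dynamic shared-memory size as before.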
19 changes: 15 additions & 4 deletions flash_attn/modules/mha.py
@@ -490,10 +490,15 @@ def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seql
else:
assert inference_params.fused_ft_kernel
assert ft_attention is not None
batch_start = inference_params.batch_size_offset
batch_end = batch_start + qkv.shape[0]
k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
if inference_params.lengths_per_sample is not None else None)
context = ft_attention.single_query_attention(
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
lengths_per_sample, inference_params.sequence_len_offset,
self.rotary_emb_dim,
# neox_rotary_style
(not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
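The mha.py change makes the fused decode path batch-aware: the KV cache and lengths_per_sample are preallocated for the maximum batch size, and the kernel now receives only the slice starting at batch_size_offset for the sequences actually being decoded. A minimal sketch of that slicing, with a hypothetical helper name and an illustrative cache layout:

import torch

def select_cache_rows(k_cache, v_cache, lengths_per_sample, batch_size_offset, cur_batch_size):
    """Return the slices of a preallocated KV cache that belong to the current batch.

    The cache is allocated for the maximum batch size; only rows
    [batch_size_offset, batch_size_offset + cur_batch_size) are active for this step.
    """
    start = batch_size_offset
    end = start + cur_batch_size
    lengths = lengths_per_sample[start:end] if lengths_per_sample is not None else None
    return k_cache[start:end], v_cache[start:end], lengths

# Example: a cache sized for 8 sequences while only 3 are being decoded.
k_cache = torch.zeros(8, 128, 12, 64)  # (batch_max, seqlen_max, n_heads, head_dim); layout is illustrative
v_cache = torch.zeros_like(k_cache)
k, v, lengths = select_cache_rows(k_cache, v_cache, None, batch_size_offset=0, cur_batch_size=3)
assert k.shape == (3, 128, 12, 64) and lengths is None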
@@ -605,10 +610,16 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
else:
assert inference_params.fused_ft_kernel
assert ft_attention is not None
batch_start = inference_params.batch_size_offset
batch_end = batch_start + qkv.shape[0]
k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
if inference_params.lengths_per_sample is not None else None)
context = ft_attention.single_query_attention(
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
lengths_per_sample, inference_params.sequence_len_offset,
self.rotary_emb_dim, inference_params.sequence_len_offset,
self.rotary_emb_dim,
# neox_rotary_style
(not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
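Both hunks read their state off inference_params. As a reading aid, here is a simplified stand-in for that object; the field names come from the diff, while the two capacity fields and the defaults are assumptions rather than the library's definition:

from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

import torch

@dataclass
class InferenceParamsSketch:
    """Simplified stand-in for the decoding state used above; not the library's class."""
    max_sequence_len: int                 # assumed: time capacity of the preallocated KV cache
    max_batch_size: int                   # assumed: batch capacity of the preallocated KV cache
    sequence_len_offset: int = 0          # tokens already decoded so far
    batch_size_offset: int = 0            # first active cache row for the current request
    fused_ft_kernel: bool = True          # route single-token decoding through ft_attention
    key_value_memory_dict: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=dict)
    lengths_per_sample: Optional[torch.Tensor] = None  # per-sequence lengths, or None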
7 changes: 4 additions & 3 deletions flash_attn/utils/generation.py
@@ -238,15 +238,16 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
if s_type not in cache.callables:
if (batch_size, s_type) not in cache.callables:
max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
cache.callables[s_type] = capture_graph(
cache.callables[batch_size, s_type] = capture_graph(
model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
n_warmups=n_warmups
)

def dispatch(input_ids, position_ids, seqlen):
return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
batch_size = input_ids.shape[0]
return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)

cache.run = dispatch
cache.inference_params.sequence_len_offset = 0 # Reset so it's not confusing
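The generation.py change makes the CUDA-graph cache shape-aware in the batch dimension as well: captured callables are now memoized per (batch_size, seqlen bucket), and dispatch picks the key from the incoming tensors, so decoding with a new batch size triggers a fresh capture instead of replaying a graph recorded for different shapes. A minimal sketch of that pattern, using a hypothetical GraphCache class and purely illustrative bucketing:

from typing import Callable, Dict, Tuple

class GraphCache:
    """Memoize per-shape decode callables keyed by (batch_size, seqlen_type)."""

    def __init__(self, capture):
        # `capture` would wrap CUDA-graph capture for the given shapes; here it is any
        # factory returning a callable for (batch_size, seqlen_type).
        self._capture = capture
        self._callables: Dict[Tuple[int, int], Callable] = {}

    def get(self, batch_size: int, seqlen_type: int) -> Callable:
        key = (batch_size, seqlen_type)
        if key not in self._callables:
            # Capture once per shape; later calls with the same shapes replay the cached graph.
            self._callables[key] = self._capture(batch_size, seqlen_type)
        return self._callables[key]

def seqlen_to_seqlen_type(seqlen: int) -> int:
    # Illustrative bucketing only; the real helper lives in flash_attn/utils/generation.py.
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)

# Usage sketch: the dispatcher reads the batch size off the input itself, as the patched
# dispatch() now does, so changing the batch size selects (or captures) a different graph.
# run = cache.get(input_ids.shape[0], seqlen_to_seqlen_type(seqlen))
# logits = run(input_ids, position_ids, seqlen)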
80 changes: 80 additions & 0 deletions tests/models/test_gpt_generation_cg.py
@@ -0,0 +1,80 @@
import os
import re
import time

import torch
import pytest

from einops import rearrange

from transformers import GPT2Config

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import update_graph_cache


def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
out = model.generate(input_ids=input_ids, max_length=max_length, fused_ft_kernel=True,
teacher_outputs=teacher_outputs, return_dict_in_generate=True,
output_scores=True, timing=True, **kwargs)
return torch.stack(out.scores, dim=1)


@pytest.mark.parametrize('seqlen,maxlen', [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize('rotary', [None, "interleaved", "block"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph.
"""
dtype = torch.float16
device = 'cuda'
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
config.n_positions = 16 * 1024
assert seqlen <= maxlen <= config.n_positions
if rotary is not None:
config.n_positions = 0
config.rotary_emb_dim = 32
config.rotary_emb_interleaved = rotary == "interleaved"
config.residual_in_fp32 = True
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True

model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.eval()

torch.manual_seed(0)
batch_size = 1
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
device=device)

logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)

# Try increasing batch size and seqlen, then decrease them to see if it's still correct
batch_size = 3
maxlen += 30
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
device=device)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)

batch_size = 2
maxlen -= 35
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
device=device)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)
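Note: the new test needs a CUDA device and the compiled ft_attention extension, since it decodes with fused_ft_kernel=True and compares plain decoding against cg=True across growing and shrinking batch sizes; it can be run directly with, for example, pytest tests/models/test_gpt_generation_cg.py.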
