[Unity] Support TIR kernel for PagedKVCache #16374

Merged
116 changes: 77 additions & 39 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -277,15 +277,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   PackedFunc f_transpose_append_;
   PackedFunc f_attention_prefill_;
   PackedFunc f_attention_decode_;
-  PackedFunc f_attention_prefill_ragged_;
-  PackedFunc f_attention_prefill_ragged_begin_forward_;
-  PackedFunc f_attention_prefill_ragged_end_forward_;
-  PackedFunc f_attention_prefill_begin_forward_;
-  PackedFunc f_attention_prefill_end_forward_;
-  PackedFunc f_attention_decode_begin_forward_;
-  PackedFunc f_attention_decode_end_forward_;
+  Optional<PackedFunc> f_attention_prefill_ragged_;
+  Optional<PackedFunc> f_attention_prefill_ragged_begin_forward_;
+  Optional<PackedFunc> f_attention_prefill_ragged_end_forward_;
+  Optional<PackedFunc> f_attention_prefill_begin_forward_;
+  Optional<PackedFunc> f_attention_prefill_end_forward_;
+  Optional<PackedFunc> f_attention_decode_begin_forward_;
+  Optional<PackedFunc> f_attention_decode_end_forward_;
   PackedFunc f_rotary_;
-  PackedFunc f_merge_inplace_;
+  Optional<PackedFunc> f_merge_inplace_;
   Optional<PackedFunc> f_debug_get_kv_;
 
   /*! \brief Number of fork depth in the current round of forward. */
@@ -297,19 +297,23 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
 
  public:
   /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */
-  explicit PagedAttentionKVCacheObj(
-      int64_t page_size,  //
-      int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads,
-      int64_t head_dim,  //
-      int64_t reserved_num_seqs, int64_t num_total_pages,  //
-      double rotary_scale, double rotary_theta,  //
-      DLDataType dtype, DLDevice device, PackedFunc f_transpose_append,
-      PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-      PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_ragged_begin_forward,
-      PackedFunc f_attention_prefill_ragged_end_forward,
-      PackedFunc f_attention_prefill_begin_forward, PackedFunc f_attention_prefill_end_forward,
-      PackedFunc f_attention_decode_begin_forward, PackedFunc f_attention_decode_end_forward,
-      PackedFunc f_rotary, PackedFunc f_merge_inplace, Optional<PackedFunc> f_debug_get_kv)
+  explicit PagedAttentionKVCacheObj(int64_t page_size,  //
+                                    int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads,
+                                    int64_t head_dim,  //
+                                    int64_t reserved_num_seqs, int64_t num_total_pages,  //
+                                    double rotary_scale, double rotary_theta,  //
+                                    DLDataType dtype, DLDevice device,
+                                    PackedFunc f_transpose_append, PackedFunc f_attention_prefill,
+                                    PackedFunc f_attention_decode,
+                                    Optional<PackedFunc> f_attention_prefill_ragged,
+                                    Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
+                                    Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
+                                    Optional<PackedFunc> f_attention_prefill_begin_forward,
+                                    Optional<PackedFunc> f_attention_prefill_end_forward,
+                                    Optional<PackedFunc> f_attention_decode_begin_forward,
+                                    Optional<PackedFunc> f_attention_decode_end_forward,
+                                    PackedFunc f_rotary, Optional<PackedFunc> f_merge_inplace,
+                                    Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
         num_qo_heads_(num_qo_heads),
@@ -418,6 +422,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
         << "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
     CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
         << "The child sequence \"" << child_seq_id << "\" is already in the KV cache.";
+    CHECK(f_merge_inplace_.defined() && f_attention_prefill_ragged_.defined())
+        << "Attention merge-score function not available. ForkSequence is thereby not supported.";
 
     int32_t parent_block_idx = parent_it->second.last_block_idx;
     // Create a child block with the parent block pointer.
@@ -558,13 +564,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   }
 
   void EndForward() final {
+    if (!f_attention_prefill_end_forward_.defined() || !f_attention_decode_end_forward_.defined() ||
+        !f_attention_prefill_ragged_end_forward_.defined()) {
+      return;
+    }
     // Mark the dirty flag as true, so that BeginForward is required
     // to be invoked before the next round of model forward.
     dirty_aux_data_device_ = true;
-    f_attention_prefill_ragged_end_forward_();
+    f_attention_prefill_ragged_end_forward_.value()();
     for (int d = 0; d < num_depths_; ++d) {
-      f_attention_prefill_end_forward_(d);
-      f_attention_decode_end_forward_(d);
+      f_attention_prefill_end_forward_.value()(d);
+      f_attention_decode_end_forward_.value()(d);
     }
   }
 
@@ -845,30 +855,36 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
 
   /*! \brief Invoke the "begin forward" functions of underlying kernels. */
   void KernelBeginForward() {
+    if (!f_attention_prefill_begin_forward_.defined() ||
+        !f_attention_decode_begin_forward_.defined() ||
+        !f_attention_prefill_ragged_begin_forward_.defined()) {
+      return;
+    }
+
     if (num_depths_ == 1) {
       if (use_decode_kernel_[0]) {
-        f_attention_decode_begin_forward_(
+        f_attention_decode_begin_forward_.value()(
            /*depth=*/0, page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0],
            num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/true);
       } else {
-        f_attention_prefill_begin_forward_(/*depth=*/0, qo_indptr_on_depths_view_[0],
-                                           cur_batch_size_, num_qo_heads_, num_kv_heads_);
+        f_attention_prefill_begin_forward_.value()(/*depth=*/0, qo_indptr_on_depths_view_[0],
+                                                   cur_batch_size_, num_qo_heads_, num_kv_heads_);
       }
     } else {
-      f_attention_prefill_ragged_begin_forward_(cur_append_length_indptr_view_, cur_batch_size_,
-                                                num_qo_heads_, num_kv_heads_);
+      f_attention_prefill_ragged_begin_forward_.value()(
+          cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_, num_kv_heads_);
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
           continue;
         }
         if (use_decode_kernel_[d]) {
-          f_attention_decode_begin_forward_(
+          f_attention_decode_begin_forward_.value()(
              d, page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_,
              num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/false);
         } else {
-          f_attention_prefill_begin_forward_(/*depth=*/d, qo_indptr_on_depths_view_[d],
-                                             last_page_len_on_depths_view_[d]->shape[0],
-                                             num_qo_heads_, num_kv_heads_);
+          f_attention_prefill_begin_forward_.value()(/*depth=*/d, qo_indptr_on_depths_view_[d],
+                                                     last_page_len_on_depths_view_[d]->shape[0],
+                                                     num_qo_heads_, num_kv_heads_);
         }
       }
     }
@@ -896,10 +912,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       }
     } else {
       // Compute appended text self-attention
-      f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data,
-                                  cur_append_length_indptr_view_, output, merged_attn_scores_view_,
-                                  /*causal=*/1,
-                                  /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
+      f_attention_prefill_ragged_.value()(q_data, cur_append_length_indptr_view_, k_data, v_data,
+                                          cur_append_length_indptr_view_, output,
+                                          merged_attn_scores_view_,
+                                          /*causal=*/1,
+                                          /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
 
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
@@ -920,8 +937,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
                              /*causal=*/0,
                              /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
         }
-        f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_,
-                         temp_attn_scores_view_);
+        f_merge_inplace_.value()(output, merged_attn_scores_view_, temp_attn_output_view_,
+                                 temp_attn_scores_view_);
       }
     }
   }
@@ -1068,6 +1085,27 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       return PagedAttentionKVCache(std::move(n));
     });
 
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
+    .set_body_typed([](ShapeTuple cache_config, int64_t num_layers, int64_t num_qo_heads,
+                       int64_t num_kv_heads, int64_t head_dim, double rotary_scale,
+                       double rotary_theta, NDArray init, PackedFunc f_transpose_append,
+                       PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
+                       PackedFunc f_rotary, Optional<PackedFunc> f_debug_get_kv) {
+      CHECK_EQ(cache_config.size(), 3);
+      int64_t reserved_num_seqs = cache_config[0];
+      int64_t total_token_capacity = cache_config[1];
+      int64_t page_size = cache_config[2];
+      int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size;
+      ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
+          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
+          num_total_pages, rotary_scale, rotary_theta, init->dtype, init->device,
+          std::move(f_transpose_append), std::move(f_attention_prefill),
+          std::move(f_attention_decode),                                  //
+          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,  //
+          std::move(f_rotary), NullOpt, std::move(f_debug_get_kv));
+      return PagedAttentionKVCache(std::move(n));
+    });
+
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
     .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::Clear);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence")
@@ -81,12 +81,20 @@ def kv_cache_transpose_append(
     for global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim):
         with T.block("k_transpose_append"):
             vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
+            T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
+            T.writes(
+                pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf]
+            )
             position: T.int64 = T.Cast("int64", position_map[vgpos])
             pages[
                 T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf
             ] = k_data[vgpos, vh, vf]
         with T.block("v_transpose_append"):
             vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
+            T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
+            T.writes(
+                pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf]
+            )
             position: T.int64 = T.Cast("int64", position_map[vgpos])
             pages[
                 T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf
@@ -115,6 +123,11 @@ def copy_cache(
     for p, h, d in T.grid(seqlen, num_kv_heads, head_dim):
         with T.block("copy0"):
             vp, vh, vd = T.axis.remap("SSS", [p, h, d])
+            T.reads(
+                position_map[vp],
+                pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd],
+            )
+            T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
             position: T.int64 = T.Cast("int64", position_map[vp])
             k_data[layer_id, vp, vh, vd] = pages[
                 T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd
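The two hunks above add explicit T.reads/T.writes annotations to blocks whose buffer accesses are indexed through position_map, i.e. data-dependent indices. A minimal standalone sketch of the same annotation pattern follows; the function and buffer names are illustrative and not from this PR.

from tvm.script import tir as T

@T.prim_func
def gather(indices: T.Buffer((16,), "int32"),
           src: T.Buffer((64,), "float32"),
           dst: T.Buffer((16,), "float32")):
    for i in T.serial(16):
        with T.block("gather"):
            vi = T.axis.spatial(16, i)
            # The element of src that is read depends on a runtime value, so the
            # read region is declared explicitly (conservatively, the whole buffer)
            # rather than left to automatic region detection.
            T.reads(indices[vi], src[0:64])
            T.writes(dst[vi])
            dst[vi] = src[indices[vi]]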
@@ -457,7 +470,7 @@ def test_paged_attention_kv_cache_popn(kv_cache):
         if pop_length != 0:
             cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...]
             cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
-    verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v)
+    verify_cached_kv(kv_cache, seq_ids=list(range(5)), expected_k=cached_k, expected_v=cached_v)
 
 
 if __name__ == "__main__":