From 0927ccab3c46367c7c77842c4a3a1629319dfe8e Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 6 May 2024 15:09:26 +0000 Subject: [PATCH 01/17] support paged attention --- include/custom_op/custom_op_lite.h | 9 +- include/custom_op/kernel_context.h | 2 + include/ort_c_to_cpp.h | 3 + operators/cuda/cuda_ops.cc | 7 +- operators/cuda/paged_attention.h | 111 +++++++++++++++++++ operators/cuda/paged_attention_impl.cu | 145 +++++++++++++++++++++++++ operators/cuda/paged_attention_impl.h | 113 +++++++++++++++++++ operators/cuda/utils.cuh | 5 + 8 files changed, 393 insertions(+), 2 deletions(-) create mode 100644 operators/cuda/paged_attention.h create mode 100644 operators/cuda/paged_attention_impl.cu create mode 100644 operators/cuda/paged_attention_impl.h diff --git a/include/custom_op/custom_op_lite.h b/include/custom_op/custom_op_lite.h index d6a47af84..77951b093 100644 --- a/include/custom_op/custom_op_lite.h +++ b/include/custom_op/custom_op_lite.h @@ -454,7 +454,7 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext { public: static const int cuda_resource_ver = 1; - OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) { + OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api), kernel_context_(ctx) { api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_); if (!cuda_stream_) { ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION); @@ -521,9 +521,16 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext { int GetCudaDeviceId() const override { return device_id_; } + + void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* mem_info, size_t count_or_bytes) override { + void* ret = nullptr; + api_.KernelContext_GetScratchBuffer(&kernel_context_, mem_info, count_or_bytes, &ret); + return ret; + } private: const OrtApi& api_; + const OrtKernelContext& kernel_context_; OrtAllocator* cpu_allocator_; OrtAllocator* cuda_allocator_; void* cuda_stream_ = {}; diff --git a/include/custom_op/kernel_context.h b/include/custom_op/kernel_context.h index 039cf3bb7..036a300e0 100644 --- a/include/custom_op/kernel_context.h +++ b/include/custom_op/kernel_context.h @@ -2,6 +2,7 @@ #include #include #include +#include "onnxruntime_c_api.h" namespace Ort { namespace Custom { @@ -26,6 +27,7 @@ class CUDAKernelContext : public KernelContext { virtual void* GetCudaStream() const = 0; virtual void* GetCublasHandle() const = 0; virtual int GetCudaDeviceId() const = 0; + virtual void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* , size_t ) { return nullptr; } }; #endif diff --git a/include/ort_c_to_cpp.h b/include/ort_c_to_cpp.h index 92c2fb01d..152aa6633 100644 --- a/include/ort_c_to_cpp.h +++ b/include/ort_c_to_cpp.h @@ -81,6 +81,9 @@ class API { return instance()->KernelContext_GetAllocator(context, mem_info, out); } #endif + static void ReleaseMemoryInfo(OrtMemoryInfo* mem_info) { + return instance()->ReleaseMemoryInfo(mem_info); + } private: const OrtApi* operator->() const { return &api_; diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index 39cc02f85..ef29dc936 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -5,6 +5,9 @@ #ifdef USE_CUDA #include "cuda/fast_gelu.h" +#if ORT_API_VERSION >= 18 +#include "cuda/paged_attention.h" +#endif #endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { @@ -13,8 +16,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> 
CustomOpArray& { #ifdef USE_CUDA , CustomCudaStructV2("FastGelu", contrib::FastGelu), +#if ORT_API_VERSION >= 18 + CustomCudaStructV2("PagedAttention", contrib::PagedAttention), +#endif #if ORT_API_VERSION >= 16 - CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("FastGelu", contrib::FastGelu) #endif diff --git a/operators/cuda/paged_attention.h b/operators/cuda/paged_attention.h new file mode 100644 index 000000000..704c170c0 --- /dev/null +++ b/operators/cuda/paged_attention.h @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "paged_attention_impl.h" + +template +struct PagedAttention { + OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { + int64_t num_heads = 0, head_size = 0; + ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_int64(&info, "num_heads", &num_heads)); + assert(num_heads > 0); + num_heads_ = static_cast(num_heads); + num_kv_heads_ = static_cast(OrtW::GetOpAttributeOrDefault(info, "num_kv_heads", num_heads)); + + ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_int64(&info, "head_size", &head_size)); + assert(head_size > 0); + head_size_ = static_cast(head_size); + + ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_float(&info, "scale", &scale_)); + assert(scale_ > 0); + + num_queries_per_kv_ = num_heads_ / num_kv_heads_; + std::vector head_mapping_host(num_heads_); + for (int i = 0; i < num_kv_heads_; i++) { + for (int j = 0; j < num_queries_per_kv_; j++) { + head_mapping_host[i * num_queries_per_kv_ + j] = i; + } + } + + OrtAllocator* allocator = nullptr; + ORTX_RETURN_IF_ERROR(api.KernelInfoGetAllocator(&info, OrtMemType::OrtMemTypeDefault, &allocator)); + allocator_ = UniquePtrWithDeletor{allocator, [&api](OrtAllocator* p){api.ReleaseAllocator(p);}}; + head_mapping_ = GetScratchBuffer(allocator_->Alloc(allocator_.get(), num_heads_), allocator_.get()); + InitializeHeadMapping(head_mapping_.get(), head_mapping_host.data(), head_mapping_host.size()); + } + + OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& query, const ortc::Tensor& key, + const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, + const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, + std::optional*> context_lens, + std::optional*> positions + std::optional*> cos_sin_cache, ortc::Tensor& attn_out) const { + InputMetadata input_metadata; + ORTX_RETURN_IF_ERROR(CheckInputs(ctx.GetCudaStream(), allocator_.get(), query, key, value, key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, input_metadata)); + const std::vector& query_shape = query.Shape(); + T* output_data = attn_out.Allocate(query_shape); + + if (cos_sin_cache.has_value()) { + int64_t rot_dim = (*cos_sin_cache)->Shape()[1]; + assert(rot_dim == head_size_); + rotary_embedding_neox(reinterpret_cast(ctx.GetCudaStream()), (*positions)->Data(), query.DataRaw(), key.DataRaw(), head_size_, + (*cos_sin_cache)->DataRaw(), input_metadata.num_valid_tokens, rot_dim, num_heads_, num_kv_heads_, 1); + } + + const std::vector& key_cache_shape = key_cache.Shape(); + if (input_metadata.num_valid_tokens > 0 && key_cache_shape.size() > 3) { + int64_t key_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_}; + int64_t value_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_}; + int block_size = gsl::narrow(key_cache_shape[3]); + reshape_and_cache(reinterpret_cast(ctx.GetCudaStream()), 
key.DataRaw(), value.DataRaw(), key_cache.DataRaw(), value_cache.DataRaw(), slot_mappings.Data(), + key_shape_r, value_shape_r, block_size, key_cache_shape[4], 1); + } + + using TT = typename CudaT::MappedType; + if (input_metadata.num_prompt_tokens > 0) { + //TODO(leca): flash attention for prompt > 0 case + return nullptr; // Don't handle prompt with decoding case for now + } + + if (input_metadata.num_generation_tokens > 0) { + constexpr int PARTITION_SIZE = 512; + int max_num_partitions = (input_metadata.max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE; + int64_t generation_qeury_shape[3] = {input_metadata.num_valid_tokens, num_heads_, head_size_}; + if (use_v1) { + paged_attention_v1(reinterpret_cast(ctx.GetCudaStream()), reinterpret_cast(output_data), query.DataRaw(), + key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, + block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr, + value_cache.Shape()[3], input_metadata.max_context_len, nullptr, + input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); + } else { + OrtMemoryInfo* mem_info = nullptr; + ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx.device_id, OrtMemTypeDefault, &mem_info)); + void* tmp_output_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape.size() * max_num_partitions * sizeof(T)); + UniquePtrWithDeletor tmp_output = GetScratchBuffer(tmp_output_raw, allocator_.get()); // TODO(leca): should deallocate inside ORT + void* exp_sums_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); + UniquePtrWithDeletor exp_sums = GetScratchBuffer(exp_sums_raw, allocator_.get()); + void* max_logits_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); + UniquePtrWithDeletor max_logits = GetScratchBuffer(max_logits_raw, allocator_.get()); + paged_attention_v2(reinterpret_cast(ctx.GetCudaStream()), exp_sums_raw, max_logits_raw, tmp_output_raw, reinterpret_cast(output_data), query.DataRaw(), + key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, + block_tables.Data(), context_lens.has_value() ? 
(*context_lens)->Data() : nullptr, + value_cache.Shape()[3], input_metadata.max_context_len, nullptr, + input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); + + OrtW::API::ReleaseMemoryInfo(mem_info); + } + } + return nullptr; + } + +private: + int32_t num_heads_; // number of attention heads + int32_t num_kv_heads_; // number of attention kv_heads + int32_t head_size_; // number of attention heads + float scale_; // sqrt(head_size_) + UniquePtrWithDeletor head_mapping_; + int32_t num_queries_per_kv_; + UniquePtrWithDeletor allocator_; +}; \ No newline at end of file diff --git a/operators/cuda/paged_attention_impl.cu b/operators/cuda/paged_attention_impl.cu new file mode 100644 index 000000000..4d5c286bb --- /dev/null +++ b/operators/cuda/paged_attention_impl.cu @@ -0,0 +1,145 @@ +#include "paged_attention_impl.h" +#include + +namespace cuda { + +inline OrtStatusPtr CudaCall(cudaError_t cuda_error) { + if (cuda_error == cudaSuccess) return nullptr; + return OrtW::API::CreateStatus(ORT_FAIL, MakeString("cuda error:", (int)cuda_error).c_str()); +} + +void InitializeHeadMapping(void* dest_data, const void* src_data, size_t count) { + cudaMemcpy(dest_data, src_data, count, cudaMemcpyHostToDevice); +} + +template +OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, + const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, + const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, + std::optional*> context_lens, + std::optional*> positions, InputMetadata& input_metadata) { + const std::vector& query_shape = query.Shape(); + if (query_shape.size() < 2 || query_shape.size() > 3) { + return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT); + } + if (query_shape.back() != num_heads_ * head_size_) { + return OrtW::CreateStatus(MakesString("query shape should equal to num_heads_ * head_size_")); + } + + // TODO(leca): Cpu input or CUDA input? + int seq_len = query_shape.size() == 3 ? query_shape[1] : query_shape[0]; + if (positions.has_value()) { + std::vector positions_host((*positions)->Shape().size()); + ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost))); + while (positions_host.back() == 0) { + positions_host.pop_back(); + seq_len--; + } + + input_metadata.max_num_blocks_per_seq = 0; + // in prompt mode + if (positions_host.size() > 1 || positions_host.back() == 0) { + input_metadata.num_prompt_tokens = seq_len; + input_metadata.num_generation_tokens = 0; + } else { + input_metadata.num_prompt_tokens = 0; + input_metadata.num_generation_tokens = seq_len; + input_metadata.max_context_len = positions_host.back() + 1; // TODO(leca): what if position_host is empty? + + int32_t block_size = gsl::narrow(key_cache.Shape()[3]); + for (int i = 0; i < positions_host.back() + 1; i += block_size) input_metadata.max_num_blocks_per_seq++; + } + } else { + // TODO(leca): context_lens is nullptr? 
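// When no positions input is given, position ids are reconstructed from context_lens instead:
// context_lens is copied to the host, and for every sequence with context_len_host[i] > 0 the ids
// 0..context_len_host[i]-1 are concatenated into position_ids, which is then uploaded into the
// device buffer kept in input_metadata.position_ids.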
+ std::vector context_len_host((*context_lens)->SizeInBytes()); + ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(context_len_host.data(), *(context_lens)->DataRaw(), *(context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost))); + std::vector position_ids; + for (size_t i = 0; i < context_len_host.size(); i++) { + if (context_len_host[i] == 0) continue; + std::vector position_id(context_len_host[i]); + std::iota(position_id.begin(), position_id.end(), 0); // fill position_id with {0, 1, 2, ...context_len_span[i]-1} + position_ids.insert(position_ids.end(), position_id.begin(), position_id.end()); + } + input_metadata.position_ids = GetScratchBuffer(allocator->Alloc(allocator, cnt), allocator); + ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream))); + } + input_metadata.num_valid_tokens = seq_len; + + return nullptr; +} + +void paged_attention_v1( + const cudaStream_t stream, + void* out, // [num_seqs, num_heads, head_size] + const void* query, // [num_seqs, num_heads, head_size] + const void* key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const void* value_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* head_mapping, // [num_heads] + float scale, + const int* block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* context_lens, // [num_seqs] + int block_size, + int max_context_len, + const float* __restrict__ alibi_slopes, + const int max_num_blocks_per_seq, + const int64_t* query_shapes, + int num_queries_per_kv, + int dtype) { + +} + +template +void paged_attention_v2( + const cudaStream_t stream, + void* out, // [num_seqs, num_heads, head_size] + void* exp_sums, // [num_seqs, num_heads, max_num_partitions] + void* max_logits, // [num_seqs, num_heads, max_num_partitions] + void* tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const void* query, // [num_seqs, num_heads, head_size] + const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const void* value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* head_mapping, // [num_heads] + float scale, + const int* block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* context_lens, // [num_seqs] + int block_size, + int max_context_len, + const float* alibi_slopes, + const int max_num_blocks_per_seq, + const int64_t* query_shapes, + int num_queries_per_kv, + int dtype) { + +} + +void rotary_embedding_neox( + const cudaStream_t stream, + const int64_t* positions, // [num_tokens] + void* query, // [num_tokens, num_heads * head_size] + void* key, // [num_tokens, num_kv_heads * head_size] + int head_size, + const void* cos_sin_cache, // [max_position, rot_dim] + int num_tokens, + int rot_dim, + int num_heads, + int num_kv_heads, + int dtype) { + +} + +void reshape_and_cache( + const cudaStream_t stream, + const void* key, // [num_tokens, num_heads, head_size] + const void* value, // [num_tokens, num_heads, head_size] + const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const void* value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* slot_mapping, // [num_tokens] + const int64_t* key_shapes, + const int64_t* value_shapes, + const int64_t block_size, + const int vec_x, + int dtype) { + +} + +} // namespace cuda \ No newline at end of file diff --git a/operators/cuda/paged_attention_impl.h b/operators/cuda/paged_attention_impl.h new file mode 100644 index 000000000..499e29a09 
--- /dev/null +++ b/operators/cuda/paged_attention_impl.h @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "ocos.h" +#include + +template +using UniquePtrWithDeletor = std::unique_ptr>; + +template +inline UniquePtrWithDeletor GetScratchBuffer(void* p, OrtAllocator* allocator) { + return UniquePtrWithDeletor{static_cast(p), [allocator = std::move(allocator)](T* p) { + allocator->Free(allocator, p); + }}; +} + +namespace cuda { +struct InputMetadata { + //int64_t schedule_type; // 0: vllm. 1:sarathi, 2:custom, 3:self-build + //int64_t block_tables; + int64_t max_num_blocks_per_seq; + //int64_t context_lens; + int64_t max_context_len = 0; + int64_t num_prompt_tokens = 0; + int64_t num_valid_tokens = 0; + //int64_t slot_mapping; + int64_t num_generation_tokens = 0; + + UniquePtrWithDeletor position_ids; +}; + +void InitializeHeadMapping(); + +// TODO(leca): remove unnecessary parameters +template +OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, + const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, + const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, + std::optional*> context_lens, + std::optional*> positions, InputMetadata& input_metadata); + +void paged_attention_v1( + const cudaStream_t stream, + void* out, // [num_seqs, num_heads, head_size] + const void* query, // [num_seqs, num_heads, head_size] + const void* key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const void* value_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* head_mapping, // [num_heads] + float scale, + const int* block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* context_lens, // [num_seqs] + int block_size, + int max_context_len, + const float* __restrict__ alibi_slopes, + const int max_num_blocks_per_seq, + const int64_t* query_shapes, + int num_queries_per_kv, + int dtype); +// const void* kv_quant_params_cache = nullptr, // [num_blocks, 2, num_kv_heads, head_size / kv_quant_chunk_size, block_size] +// int kv_quant_chunk_size = 0, +// int kv_quant_param_dtype = 0); + +void paged_attention_v2( + const cudaStream_t stream, + void* out, // [num_seqs, num_heads, head_size] + void* exp_sums, // [num_seqs, num_heads, max_num_partitions] + void* max_logits, // [num_seqs, num_heads, max_num_partitions] + void* tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const void* query, // [num_seqs, num_heads, head_size] + const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const void* value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* head_mapping, // [num_heads] + float scale, + const int* block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* context_lens, // [num_seqs] + int block_size, + int max_context_len, + const float* alibi_slopes, + const int max_num_blocks_per_seq, + const int64_t* query_shapes, + int num_queries_per_kv, + int dtype); + +void reshape_and_cache( + const cudaStream_t stream, + const void* key, // [num_tokens, num_heads, head_size] + const void* value, // [num_tokens, num_heads, head_size] + const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const void* value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* slot_mapping, // [num_tokens] + const int64_t* key_shapes, + const int64_t* value_shapes, 
+ const int64_t block_size, + const int vec_x, + int dtype); +// void* kv_quant_param = nullptr, // [num_blocks, 2, num_heads, head_size / kv_quant_chunk_size, block_size] +// const int kv_quant_chunk_size = 0, +// const int kv_quant_param_dtype = 1); + +void rotary_embedding_neox( + const cudaStream_t stream, + const int64_t* positions, // [num_tokens] + void* query, // [num_tokens, num_heads * head_size] + void* key, // [num_tokens, num_kv_heads * head_size] + int head_size, + const void* cos_sin_cache, // [max_position, rot_dim] + int num_tokens, + int rot_dim, + int num_heads, + int num_kv_heads, + int dtype); +} // namespace cuda \ No newline at end of file diff --git a/operators/cuda/utils.cuh b/operators/cuda/utils.cuh index fe3d27daa..322e0c44f 100644 --- a/operators/cuda/utils.cuh +++ b/operators/cuda/utils.cuh @@ -191,3 +191,8 @@ __device__ __inline__ half2 _Tanh(half2 a) { template <> __device__ __inline__ ortc::BFloat16 _Tanh(ortc::BFloat16 a) { return tanhf(static_cast(a)); } + +inline OrtStatusPtr CudaCall(cudaError_t cuda_error) { + if (cuda_error == cudaSuccess) return nullptr; + return OrtW::API::CreateStatus(ORT_FAIL, MakeString("cuda error:", (int)cuda_error).c_str()); +} \ No newline at end of file From 6931a5211ec48604c82b91bcdbf33a77c76323b8 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Tue, 7 May 2024 17:38:07 +0000 Subject: [PATCH 02/17] add kernel functions and build successfully --- operators/cuda/cuda_ops.cc | 2 +- operators/cuda/paged_attention.h | 45 +- operators/cuda/paged_attention_impl.cu | 1101 ++++++++++++++++++++++-- operators/cuda/paged_attention_impl.h | 46 +- operators/cuda/paged_dtype_float16.cuh | 469 ++++++++++ operators/cuda/paged_dtype_float32.cuh | 274 ++++++ operators/cuda/paged_generic.cuh | 65 ++ operators/cuda/paged_utils.cuh | 59 ++ 8 files changed, 1962 insertions(+), 99 deletions(-) create mode 100644 operators/cuda/paged_dtype_float16.cuh create mode 100644 operators/cuda/paged_dtype_float32.cuh create mode 100644 operators/cuda/paged_generic.cuh create mode 100644 operators/cuda/paged_utils.cuh diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index ef29dc936..8770bb42a 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -17,7 +17,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { , CustomCudaStructV2("FastGelu", contrib::FastGelu), #if ORT_API_VERSION >= 18 - CustomCudaStructV2("PagedAttention", contrib::PagedAttention), + CustomCudaStructV2("PagedAttention", PagedAttention), #endif #if ORT_API_VERSION >= 16 CustomCudaStructV2("FastGelu", contrib::FastGelu), diff --git a/operators/cuda/paged_attention.h b/operators/cuda/paged_attention.h index 704c170c0..528a8f323 100644 --- a/operators/cuda/paged_attention.h +++ b/operators/cuda/paged_attention.h @@ -2,8 +2,36 @@ // Licensed under the MIT License. #pragma once +#include "ocos.h" +#include "cuda_type.h" #include "paged_attention_impl.h" +void InitializeHeadMapping(void* dest_data, const void* src_data, size_t count); + +template +using UniquePtrWithDeletor = std::unique_ptr>; + +template +inline UniquePtrWithDeletor GetScratchBuffer(void* p, OrtAllocator* allocator) { + return UniquePtrWithDeletor{static_cast(p), [allocator = std::move(allocator)](T* p) { + allocator->Free(allocator, p); + }}; +} + +struct InputMetadata { + //int64_t schedule_type; // 0: vllm. 
1:sarathi, 2:custom, 3:self-build + //int64_t block_tables; + int64_t max_num_blocks_per_seq; + //int64_t context_lens; + int64_t max_context_len = 0; + int64_t num_prompt_tokens = 0; + int64_t num_valid_tokens = 0; + //int64_t slot_mapping; + int64_t num_generation_tokens = 0; + + UniquePtrWithDeletor position_ids; +}; + template struct PagedAttention { OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { @@ -33,23 +61,24 @@ struct PagedAttention { allocator_ = UniquePtrWithDeletor{allocator, [&api](OrtAllocator* p){api.ReleaseAllocator(p);}}; head_mapping_ = GetScratchBuffer(allocator_->Alloc(allocator_.get(), num_heads_), allocator_.get()); InitializeHeadMapping(head_mapping_.get(), head_mapping_host.data(), head_mapping_host.size()); + return nullptr; } OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& query, const ortc::Tensor& key, const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, std::optional*> context_lens, - std::optional*> positions + std::optional*> positions, std::optional*> cos_sin_cache, ortc::Tensor& attn_out) const { InputMetadata input_metadata; - ORTX_RETURN_IF_ERROR(CheckInputs(ctx.GetCudaStream(), allocator_.get(), query, key, value, key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, input_metadata)); +// ORTX_RETURN_IF_ERROR(CheckInputs(ctx->GetCudaStream(), allocator_.get(), query, key, value, key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, input_metadata, num_heads_, head_size_)); const std::vector& query_shape = query.Shape(); T* output_data = attn_out.Allocate(query_shape); if (cos_sin_cache.has_value()) { int64_t rot_dim = (*cos_sin_cache)->Shape()[1]; assert(rot_dim == head_size_); - rotary_embedding_neox(reinterpret_cast(ctx.GetCudaStream()), (*positions)->Data(), query.DataRaw(), key.DataRaw(), head_size_, + cuda::rotary_embedding_neox(reinterpret_cast(ctx->GetCudaStream()), (*positions)->Data(), const_cast(query.DataRaw()), const_cast(key.DataRaw()), head_size_, (*cos_sin_cache)->DataRaw(), input_metadata.num_valid_tokens, rot_dim, num_heads_, num_kv_heads_, 1); } @@ -58,11 +87,11 @@ struct PagedAttention { int64_t key_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_}; int64_t value_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_}; int block_size = gsl::narrow(key_cache_shape[3]); - reshape_and_cache(reinterpret_cast(ctx.GetCudaStream()), key.DataRaw(), value.DataRaw(), key_cache.DataRaw(), value_cache.DataRaw(), slot_mappings.Data(), + cuda::reshape_and_cache(reinterpret_cast(ctx->GetCudaStream()), key.DataRaw(), value.DataRaw(), key_cache.DataRaw(), value_cache.DataRaw(), slot_mappings.Data(), key_shape_r, value_shape_r, block_size, key_cache_shape[4], 1); } - using TT = typename CudaT::MappedType; + using TT = typename contrib::CudaT::MappedType; if (input_metadata.num_prompt_tokens > 0) { //TODO(leca): flash attention for prompt > 0 case return nullptr; // Don't handle prompt with decoding case for now @@ -74,21 +103,21 @@ struct PagedAttention { bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE; int64_t generation_qeury_shape[3] = {input_metadata.num_valid_tokens, num_heads_, head_size_}; if (use_v1) { - paged_attention_v1(reinterpret_cast(ctx.GetCudaStream()), reinterpret_cast(output_data), query.DataRaw(), + 
cuda::paged_attention_v1(reinterpret_cast(ctx->GetCudaStream()), reinterpret_cast(output_data), query.DataRaw(), key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr, value_cache.Shape()[3], input_metadata.max_context_len, nullptr, input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); } else { OrtMemoryInfo* mem_info = nullptr; - ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx.device_id, OrtMemTypeDefault, &mem_info)); + ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx->GetCudaDeviceId(), OrtMemTypeDefault, &mem_info)); void* tmp_output_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape.size() * max_num_partitions * sizeof(T)); UniquePtrWithDeletor tmp_output = GetScratchBuffer(tmp_output_raw, allocator_.get()); // TODO(leca): should deallocate inside ORT void* exp_sums_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); UniquePtrWithDeletor exp_sums = GetScratchBuffer(exp_sums_raw, allocator_.get()); void* max_logits_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); UniquePtrWithDeletor max_logits = GetScratchBuffer(max_logits_raw, allocator_.get()); - paged_attention_v2(reinterpret_cast(ctx.GetCudaStream()), exp_sums_raw, max_logits_raw, tmp_output_raw, reinterpret_cast(output_data), query.DataRaw(), + cuda::paged_attention_v2(reinterpret_cast(ctx->GetCudaStream()), exp_sums_raw, max_logits_raw, tmp_output_raw, reinterpret_cast(output_data), query.DataRaw(), key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr, value_cache.Shape()[3], input_metadata.max_context_len, nullptr, diff --git a/operators/cuda/paged_attention_impl.cu b/operators/cuda/paged_attention_impl.cu index 4d5c286bb..e3318d985 100644 --- a/operators/cuda/paged_attention_impl.cu +++ b/operators/cuda/paged_attention_impl.cu @@ -1,73 +1,822 @@ #include "paged_attention_impl.h" +#include "utils.cuh" + +#include "paged_generic.cuh" +#include "paged_dtype_float16.cuh" +#include "paged_dtype_float32.cuh" +#include "paged_utils.cuh" #include +void InitializeHeadMapping(void* dest_data, const void* src_data, size_t count) { + cudaMemcpy(dest_data, src_data, count, cudaMemcpyHostToDevice); +} + namespace cuda { -inline OrtStatusPtr CudaCall(cudaError_t cuda_error) { - if (cuda_error == cudaSuccess) return nullptr; - return OrtW::API::CreateStatus(ORT_FAIL, MakeString("cuda error:", (int)cuda_error).c_str()); -} +#define WARP_SIZE 32 +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b)) -void InitializeHeadMapping(void* dest_data, const void* src_data, size_t count) { - cudaMemcpy(dest_data, src_data, count, cudaMemcpyHostToDevice); +namespace vllm { +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. 
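// This is a butterfly (XOR) reduction: on each step every lane adds the value held by the lane
// whose index differs by the current mask (16, 8, 4, 2, 1), so after log2(WARP_SIZE) = 5 steps
// every lane of the warp holds the sum of all 32 lane values.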
+#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); } -template -OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, - const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, - const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, - std::optional*> context_lens, - std::optional*> positions, InputMetadata& input_metadata) { - const std::vector& query_shape = query.Shape(); - if (query_shape.size() < 2 || query_shape.size() > 3) { - return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT); +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template < + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + int PARTITION_SIZE = 0> // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. 
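// Concrete example of the partitioning above: with BLOCK_SIZE = 16, PARTITION_SIZE = 512 and
// context_len = 1000, num_context_blocks = 63 and num_blocks_per_partition = 32, so partition 0
// covers blocks [0, 32), i.e. tokens [0, 512), and partition 1 covers blocks [32, 63), i.e.
// tokens [512, 1000).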
+ const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int kv_head_idx = head_mapping[head_idx]; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + + // Load a key to registers. 
+ // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. 
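// For example, with scalar_t = half (2 bytes) and BLOCK_SIZE = 16: V_VEC_SIZE = min(16/2, 16) = 8,
// NUM_V_VECS_PER_ROW = 16/8 = 2, NUM_ROWS_PER_ITER = 32/2 = 16, and for HEAD_SIZE = 128 each
// thread accumulates NUM_ROWS_PER_THREAD = ceil(128/16) = 8 rows of the per-head output.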
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + + const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + if (block_idx == num_context_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the context, + // we should explicitly zero out the values since they may contain NaNs. + // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } } - if (query_shape.back() != num_heads_ * head_size_) { - return OrtW::CreateStatus(MakesString("query shape should equal to num_heads_ * head_size_")); + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); - // TODO(leca): Cpu input or CUDA input? - int seq_len = query_shape.size() == 3 ? query_shape[1] : query_shape[0]; - if (positions.has_value()) { - std::vector positions_host((*positions)->Shape().size()); - ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost))); - while (positions_host.back() == 0) { - positions_host.pop_back(); - seq_len--; + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. 
+ if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } } + } + __syncthreads(); - input_metadata.max_num_blocks_per_seq = 0; - // in prompt mode - if (positions_host.size() > 1 || positions_host.back() == 0) { - input_metadata.num_prompt_tokens = seq_len; - input_metadata.num_generation_tokens = 0; - } else { - input_metadata.num_prompt_tokens = 0; - input_metadata.num_generation_tokens = seq_len; - input_metadata.max_context_len = positions_host.back() + 1; // TODO(leca): what if position_host is empty? - - int32_t block_size = gsl::narrow(key_cache.Shape()[3]); - for (int i = 0; i < positions_host.back() + 1; i += block_size) input_metadata.max_num_blocks_per_seq++; + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } } - } else { - // TODO(leca): context_lens is nullptr? - std::vector context_len_host((*context_lens)->SizeInBytes()); - ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(context_len_host.data(), *(context_lens)->DataRaw(), *(context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost))); - std::vector position_ids; - for (size_t i = 0; i < context_len_host.size(); i++) { - if (context_len_host[i] == 0) continue; - std::vector position_id(context_len_host[i]); - std::iota(position_id.begin(), position_id.end(), 0); // fill position_id with {0, 1, 2, ...context_len_span[i]-1} - position_ids.insert(position_ids.end(), position_id.begin(), position_id.end()); + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); } - input_metadata.position_ids = GetScratchBuffer(allocator->Alloc(allocator, cnt), allocator); - ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream))); } - input_metadata.num_valid_tokens = seq_len; - - return nullptr; + } } +// Grid: (num_heads, num_seqs, 1). 
+template < + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, + out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template < + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs). +template < + typename scalar_t, + int HEAD_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. 
+ scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. 
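// Each partition's partial output is re-weighted by its share of the global softmax mass:
// out[i] = sum_j tmp_out_j[i] * exp_sum_j * exp(max_logit_j - global_max) / global_exp_sum,
// where global_exp_sum = sum_k exp_sum_k * exp(max_logit_k - global_max). This is the standard
// log-sum-exp merge of per-partition softmax results.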
+ const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + +template +__global__ void reshape_and_cache_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + const int* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size, + const int x) { + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. + return; + } + const int block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int src_key_idx = token_idx * key_stride + i; + const int src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + block_offset * x + x_offset; + const int tgt_value_idx = block_idx * num_heads * head_size * block_size + head_idx * head_size * block_size + head_offset * block_size + block_offset; + //{ + // if (key_cache[tgt_key_idx] - key[src_key_idx] > half(0.1)) { + // printf("key error find, %d,%d ", tgt_key_idx, src_key_idx); + // } + // if (value_cache[tgt_value_idx] - value[src_value_idx] > half(0.1)) { + // printf("key error find, %d %d", tgt_value_idx, src_value_idx); + // } + //} + key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]); + value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); + } +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) { + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = __ldg(cos_ptr + x_index); + sin = __ldg(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. 
+ x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = __ldg(cos_ptr + x_index / 2); + sin = __ldg(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +__global__ void rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, + const int query_stride, + const int key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_rotary_embedding(query + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_rotary_embedding(key + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } +} +} // namespace vllm + +//template +//OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, +// const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, +// const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, +// std::optional*> context_lens, +// std::optional*> positions, InputMetadata& input_metadata, int32_t num_heads, int32_t head_size) { +// const std::vector& query_shape = query.Shape(); +// if (query_shape.size() < 2 || query_shape.size() > 3) { +// return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT); +// } +// if (query_shape.back() != num_heads * head_size) { +// return OrtW::CreateStatus(MakeString("query shape should equal to num_heads_ * head_size_"), ORT_INVALID_ARGUMENT); +// } +// +// // TODO(leca): Cpu input or CUDA input? +// int seq_len = query_shape.size() == 3 ? 
query_shape[1] : query_shape[0]; +// if (positions.has_value()) { +// std::vector positions_host((*positions)->Shape().size()); +// ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost))); +// while (positions_host.back() == 0) { +// positions_host.pop_back(); +// seq_len--; +// } +// +// input_metadata.max_num_blocks_per_seq = 0; +// // in prompt mode +// if (positions_host.size() > 1 || positions_host.back() == 0) { +// input_metadata.num_prompt_tokens = seq_len; +// input_metadata.num_generation_tokens = 0; +// } else { +// input_metadata.num_prompt_tokens = 0; +// input_metadata.num_generation_tokens = seq_len; +// input_metadata.max_context_len = positions_host.back() + 1; // TODO(leca): what if position_host is empty? +// +// int32_t block_size = gsl::narrow(key_cache.Shape()[3]); +// for (int i = 0; i < positions_host.back() + 1; i += block_size) input_metadata.max_num_blocks_per_seq++; +// } +// } else { +// // TODO(leca): context_lens is nullptr? +// std::vector context_len_host((*context_lens)->SizeInBytes()); +// ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(context_len_host.data(), (*context_lens)->DataRaw(), (*context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost))); +// std::vector position_ids; +// for (size_t i = 0; i < context_len_host.size(); i++) { +// if (context_len_host[i] == 0) continue; +// std::vector position_id(context_len_host[i]); +// std::iota(position_id.begin(), position_id.end(), 0); // fill position_id with {0, 1, 2, ...context_len_span[i]-1} +// position_ids.insert(position_ids.end(), position_id.begin(), position_id.end()); +// } +// input_metadata.position_ids = GetScratchBuffer(allocator->Alloc(allocator, position_ids.size()), allocator); +// ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream))); +// } +// input_metadata.num_valid_tokens = seq_len; +// +// return nullptr; +//} + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + cudaFuncSetAttribute( \ + vllm::paged_attention_v1_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); + +// TODO(woosuk): Tune NUM_THREADS. 
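The LAUNCH_PAGED_ATTENTION_V1 macro above raises the kernel's dynamic shared-memory limit with cudaFuncSetAttribute but does not verify the device cap. A minimal sketch of such a check (illustrative only; the helper name and the device_id parameter are assumptions, not part of this patch):

inline bool v1_shared_mem_fits(int shared_mem_size, int device_id) {
  int max_optin_bytes = 0;
  // Largest dynamic shared-memory size a block may opt into via
  // cudaFuncAttributeMaxDynamicSharedMemorySize on this device.
  cudaDeviceGetAttribute(&max_optin_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id);
  return shared_mem_size <= max_optin_bytes;
}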
+template < + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void paged_attention_v1_launcher( + const cudaStream_t stream, + T* out, + const T* query, + const T* key_cache, + const T* value_cache, + const int* head_mapping, + float scale, + const int* block_tables, + const int* context_lens, + int max_context_len, + const float* alibi_slopes, + const int64_t max_num_blocks_per_seq, + const int64_t* query_shapes, + const int64_t num_queries_per_kv) { + int num_seqs = query_shapes[0]; + int num_heads = query_shapes[1]; + int head_size = query_shapes[2]; + //int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = num_heads * head_size; // query.stride(0); + int kv_block_stride = q_stride / num_queries_per_kv * BLOCK_SIZE; // key_cache.stride(0); + int kv_head_stride = head_size * BLOCK_SIZE;//key_cache.stride(1); + +#ifndef NDEBUG + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); +#endif + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? alibi_slopes : nullptr; + + T* out_ptr = reinterpret_cast(out); + const T* query_ptr = reinterpret_cast(query); + const T* key_cache_ptr = reinterpret_cast(key_cache); + const T* value_cache_ptr = reinterpret_cast(value_cache); + const int* head_mapping_ptr = reinterpret_cast(head_mapping); + const int* block_tables_ptr = block_tables; + const int* context_lens_ptr = context_lens; + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + // TORCH_CHECK(false, "Unsupported head size: ", head_size); + abort(); + break; + } +} + +#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_launcher( \ + stream, \ + (T*)out, \ + (const T*)query, \ + (const T*)key_cache, \ + (const T*)value_cache, \ + (const int*)head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes, \ + max_num_blocks_per_seq, \ + query_shapes, \ + num_queries_per_kv); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. 
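For intuition about the shared_mem_size computed in paged_attention_v1_launcher above, a worked example with assumed values (not taken from this patch): max_context_len = 1000 and BLOCK_SIZE = 16 give padded_max_context_len = DIVIDE_ROUND_UP(1000, 16) * 16 = 1008, so logits_size = 1008 * sizeof(float) = 4032 bytes; NUM_THREADS = 128 (NUM_WARPS = 4) with head_size = 128 gives outputs_size = (4 / 2) * 128 * 4 = 1024 bytes, and the launch therefore requests max(4032, 1024) = 4032 bytes of dynamic shared memory. The block-size dispatch macro for this launcher follows below.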
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, 32); \ + break; \ + default: \ + abort(); \ + break; \ + } + void paged_attention_v1( const cudaStream_t stream, void* out, // [num_seqs, num_heads, head_size] @@ -85,10 +834,174 @@ void paged_attention_v1( const int64_t* query_shapes, int num_queries_per_kv, int dtype) { + if (dtype == 0) { // Float + CALL_V1_LAUNCHER_BLOCK_SIZE(float); + } else if (dtype == 1) { // Half + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (dtype == 2) { // BFloat16 + // CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + // TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + context_lens_ptr, \ + max_num_partitions); + +template < + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128, + int PARTITION_SIZE = 512> +void paged_attention_v2_launcher( + const cudaStream_t stream, + T* out, + void* exp_sums, + void* max_logits, + void* tmp_out, + const T* query, + const T* key_cache, + const T* value_cache, + const int* head_mapping, + float scale, + const int* block_tables, + const int* context_lens, + int max_context_len, + const float* alibi_slopes, + const int64_t max_num_blocks_per_seq, + const int64_t* query_shapes, + const int64_t num_queries_per_kv) { + int num_seqs = query_shapes[0]; + int num_heads = query_shapes[1]; + int head_size = query_shapes[2]; + // int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = num_heads * head_size; // query.stride(0); + int kv_block_stride = q_stride / num_queries_per_kv * BLOCK_SIZE; // key_cache.stride(0); + int kv_head_stride = head_size * BLOCK_SIZE; // key_cache.stride(1); + +#ifndef NDEBUG + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); +#endif + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? alibi_slopes : nullptr; + + T* out_ptr = reinterpret_cast(out); + float* exp_sums_ptr = reinterpret_cast(exp_sums); + float* max_logits_ptr = reinterpret_cast(max_logits); + T* tmp_out_ptr = reinterpret_cast(tmp_out); + + const T* query_ptr = reinterpret_cast(query); + const T* key_cache_ptr = reinterpret_cast(key_cache); + const T* value_cache_ptr = reinterpret_cast(value_cache); + const int* head_mapping_ptr = reinterpret_cast(head_mapping); + const int* block_tables_ptr = block_tables; + const int* context_lens_ptr = context_lens; + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. 
+ dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + abort(); + break; + } } -template +#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_launcher( \ + stream, \ + (T*)out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + (const T*)query, \ + (const T*)key_cache, \ + (const T*)value_cache, \ + (const int*)head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes, \ + max_num_blocks_per_seq, \ + query_shapes, \ + num_queries_per_kv); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, 32); \ + break; \ + default: \ + abort(); \ + break; \ + } + void paged_attention_v2( const cudaStream_t stream, void* out, // [num_seqs, num_heads, head_size] @@ -109,7 +1022,15 @@ void paged_attention_v2( const int64_t* query_shapes, int num_queries_per_kv, int dtype) { - + if (dtype == 0) { // Float + CALL_V2_LAUNCHER_BLOCK_SIZE(float); + } else if (dtype == 1) { // Half + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (dtype == 2) { // BFloat16 + // CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + //TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } void rotary_embedding_neox( @@ -123,8 +1044,51 @@ void rotary_embedding_neox( int rot_dim, int num_heads, int num_kv_heads, - int dtype) { + int dtype) { // TODO(leca): only implemented dtype==1 + const bool is_neox = true; + int query_stride = num_heads * head_size; + int key_stride = num_kv_heads * head_size; + // TORCH_CHECK(stride == key.stride(0)); + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + + if (dtype == 0) { + // float + // CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); + // } else if constexpr (std::is_same_v) { + } else if (dtype == 1) { + // half + using scalar_t = half; + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions, + static_cast(query), + static_cast(key), + static_cast(cos_sin_cache), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } else { + vllm::rotary_embedding_kernel<<>>( + positions, + static_cast(query), + static_cast(key), + static_cast(cos_sin_cache), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } + //} else if constexpr (std::is_same_v) { + } else if (dtype == 2) { + // CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } } void reshape_and_cache( @@ -139,7 +1103,34 @@ void reshape_and_cache( const int64_t block_size, const int vec_x, int dtype) { + int num_tokens = key_shapes[0]; + int num_heads = key_shapes[1]; + int head_size = key_shapes[2]; + // int block_size = key_cache.size(3); + int x = vec_x; + + int key_stride = key_shapes[1] 
* key_shapes[2]; + int value_stride = value_shapes[1] * value_shapes[2]; + + // static_assert(std::is_same_v, "Unsupported data type: "); + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + // if constexpr (std::is_same_v) { + if (dtype == 1) { + vllm::reshape_and_cache_kernel<<>>( + (const half*)key, + (const half*)value, + (half*)key_cache, + (half*)value_cache, + slot_mapping, + key_stride, + value_stride, + num_heads, + head_size, + block_size, + x); + } } } // namespace cuda \ No newline at end of file diff --git a/operators/cuda/paged_attention_impl.h b/operators/cuda/paged_attention_impl.h index 499e29a09..2726381a7 100644 --- a/operators/cuda/paged_attention_impl.h +++ b/operators/cuda/paged_attention_impl.h @@ -2,44 +2,20 @@ // Licensed under the MIT License. #pragma once -#include "ocos.h" +#include "ortx_common.h" +#include "gsl/narrow" #include -template -using UniquePtrWithDeletor = std::unique_ptr>; - -template -inline UniquePtrWithDeletor GetScratchBuffer(void* p, OrtAllocator* allocator) { - return UniquePtrWithDeletor{static_cast(p), [allocator = std::move(allocator)](T* p) { - allocator->Free(allocator, p); - }}; -} - namespace cuda { -struct InputMetadata { - //int64_t schedule_type; // 0: vllm. 1:sarathi, 2:custom, 3:self-build - //int64_t block_tables; - int64_t max_num_blocks_per_seq; - //int64_t context_lens; - int64_t max_context_len = 0; - int64_t num_prompt_tokens = 0; - int64_t num_valid_tokens = 0; - //int64_t slot_mapping; - int64_t num_generation_tokens = 0; - - UniquePtrWithDeletor position_ids; -}; - -void InitializeHeadMapping(); - -// TODO(leca): remove unnecessary parameters -template -OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, - const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, - const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, - std::optional*> context_lens, - std::optional*> positions, InputMetadata& input_metadata); - +// +//// TODO(leca): move the implementation to paged_attention.h and remove unnecessary parameters +//template +//OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, +// const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, +// const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, +// std::optional*> context_lens, +// std::optional*> positions, InputMetadata& input_metadata, int32_t num_heads, int32_t head_size); +// void paged_attention_v1( const cudaStream_t stream, void* out, // [num_seqs, num_heads, head_size] diff --git a/operators/cuda/paged_dtype_float16.cuh b/operators/cuda/paged_dtype_float16.cuh new file mode 100644 index 000000000..132c8154c --- /dev/null +++ b/operators/cuda/paged_dtype_float16.cuh @@ -0,0 +1,469 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "paged_generic.cuh" +#include "paged_dtype_float32.cuh" + +#include +namespace cuda { +namespace vllm { + +// FP16 vector types for Q, K, V. +template <> +struct Vec { + using Type = uint16_t; +}; +template <> +struct Vec { + using Type = uint32_t; +}; +template <> +struct Vec { + using Type = uint2; +}; +template <> +struct Vec { + using Type = uint4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = Float4_; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" + : "=r"(b) + : "h"(a)); + return b; +} + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" + : "=f"(f) + : "h"(h)); + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" + : "=h"(lo), "=h"(hi) + : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" + : "=h"(tmp.u16[0]) + : "f"(f)); + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" + : "=h"(tmp.u16[0]) + : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" + : "=h"(tmp.u16[1]) + : "f"(f.y)); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" + : "=h"(c) + : "h"(a), "h"(b)); + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" + : "=r"(c) + : "r"(a), "r"(b)); + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. 
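Before the packed multiplication helpers below, a minimal round-trip sketch showing how the conversion routines above map between fp32 and the uint16_t fp16 storage type (illustrative only; the kernel name and launch shape are assumptions, not part of this header):

__global__ void half_roundtrip_check(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    uint16_t h = float_to_half(in[i]);  // fp32 -> fp16 bit pattern stored as uint16_t
    out[i] = half_to_float(h);          // fp16 -> fp32; differs from in[i] only by fp16 rounding
  }
}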
+template <> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" + : "=h"(c) + : "h"(a), "h"(b)); + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" + : "=r"(c) + : "r"(a), "r"(b)); + return c; +} + +template <> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template <> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template <> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template <> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. 
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template <> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template <> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template <> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. 
+inline __device__ float to_float(uint16_t u) { + return half_to_float(u); +} + +inline __device__ float2 to_float(uint32_t u) { + return half2_to_float2(u); +} + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +// Zero-out a vector. +inline __device__ void zero(uint16_t& dst) { + dst = uint16_t(0); +} + +} // namespace vllm +} // namespace cuda diff --git a/operators/cuda/paged_dtype_float32.cuh b/operators/cuda/paged_dtype_float32.cuh new file mode 100644 index 000000000..93f2c1f5c --- /dev/null +++ b/operators/cuda/paged_dtype_float32.cuh @@ -0,0 +1,274 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "paged_generic.cuh" + +#include +namespace cuda { +namespace vllm { + +// Define custom FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template <> +struct Vec { + using Type = float; +}; +template <> +struct Vec { + using Type = float2; +}; +template <> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { + return a + b; +} + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. 
+template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template <> +inline __device__ float sum(float v) { + return v; +} + +template <> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template <> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template <> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template <> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { + return a * b; +} + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { + dst = src; +} + +inline __device__ void from_float(float2& dst, float2 src) { + dst = src; +} + +inline __device__ void from_float(float4& dst, float4 src) { + dst = src; +} + +// From float to float. +inline __device__ float to_float(float u) { + return u; +} + +inline __device__ float2 to_float(float2 u) { + return u; +} + +inline __device__ float4 to_float(float4 u) { + return u; +} + +inline __device__ Float4_ to_float(Float4_ u) { + return u; +} + +inline __device__ Float8_ to_float(Float8_ u) { + return u; +} + +// Zero-out a variable. 
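Before the zero() helper below, a small usage sketch of the fp32 dot() overloads defined above (illustrative only; the function and fragment names are assumptions, not part of this header):

__device__ float example_accumulate_qk(Float8_ q0, Float8_ k0, Float8_ q1, Float8_ k1) {
  float acc = dot(q0, k0);  // eight element-wise products reduced to one float
  acc += dot(q1, k1);       // accumulate a second eight-element fragment
  return acc;
}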
+inline __device__ void zero(float& dst) { + dst = 0.f; +} + +} // namespace vllm +} // namespace cuda diff --git a/operators/cuda/paged_generic.cuh b/operators/cuda/paged_generic.cuh new file mode 100644 index 000000000..b35500f54 --- /dev/null +++ b/operators/cuda/paged_generic.cuh @@ -0,0 +1,65 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +namespace cuda { +namespace vllm { + +// A vector type to store Q, K, V elements. +template +struct Vec {}; + +// A vector type to store FP32 accumulators. +template +struct FloatVec {}; + +// Template vector operations. +template +inline __device__ Acc mul(A a, B b); + +template +inline __device__ float sum(T v); + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +} // namespace vllm +} // namespace cuda diff --git a/operators/cuda/paged_utils.cuh b/operators/cuda/paged_utils.cuh new file mode 100644 index 000000000..3f01ece8c --- /dev/null +++ b/operators/cuda/paged_utils.cuh @@ -0,0 +1,59 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "paged_generic.cuh" +#include "paged_dtype_float16.cuh" +#include "paged_dtype_float32.cuh" + +#include +#include +namespace cuda { + +namespace vllm { + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). 
+ A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +} // namespace vllm +} // namespace cuda From 77ae3b0b69c357d6346c771137665e8da36d26f8 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 8 May 2024 14:46:50 +0000 Subject: [PATCH 03/17] move checkInput() to paged_attention.h --- operators/cuda/paged_attention.h | 92 +++++++++++++-- operators/cuda/paged_attention_impl.cu | 149 +++++++++++++++---------- operators/cuda/paged_attention_impl.h | 67 +++++++++-- 3 files changed, 230 insertions(+), 78 deletions(-) diff --git a/operators/cuda/paged_attention.h b/operators/cuda/paged_attention.h index 528a8f323..12ff6be31 100644 --- a/operators/cuda/paged_attention.h +++ b/operators/cuda/paged_attention.h @@ -6,8 +6,6 @@ #include "cuda_type.h" #include "paged_attention_impl.h" -void InitializeHeadMapping(void* dest_data, const void* src_data, size_t count); - template using UniquePtrWithDeletor = std::unique_ptr>; @@ -18,6 +16,16 @@ inline UniquePtrWithDeletor GetScratchBuffer(void* p, OrtAllocator* allocator }}; } +struct AttnBias { + typedef struct { + int64_t seqstart; + int64_t max_seqlen; + int64_t seqstart_py; + } block_tables; + block_tables q_seqinfo; + int64_t batchsize; +}; + struct InputMetadata { //int64_t schedule_type; // 0: vllm. 1:sarathi, 2:custom, 3:self-build //int64_t block_tables; @@ -28,12 +36,73 @@ struct InputMetadata { int64_t num_valid_tokens = 0; //int64_t slot_mapping; int64_t num_generation_tokens = 0; - + AttnBias attn_bias; UniquePtrWithDeletor position_ids; }; +//// TODO(leca): remove unnecessary parameters, move all cuda call to .cu file and check return value by calling CudaCall(). +template +OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, + const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, + const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, + std::optional*> context_lens, + std::optional*> positions, int32_t num_heads, int32_t head_size, InputMetadata& input_metadata) { + const std::vector& query_shape = query.Shape(); + if (query_shape.size() < 2 || query_shape.size() > 3) { + return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT); + } + if (query_shape.back() != num_heads * head_size) { + return OrtW::CreateStatus(MakeString("query shape should equal to num_heads_ * head_size_"), ORT_INVALID_ARGUMENT); + } + + // TODO(leca): Cpu input or CUDA input? + int seq_len = query_shape.size() == 3 ? 
query_shape[1] : query_shape[0]; + if (positions.has_value()) { + std::vector positions_host((*positions)->Shape().size()); + //ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost))); + cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost); + while (positions_host.back() == 0) { + positions_host.pop_back(); + seq_len--; + } + + input_metadata.max_num_blocks_per_seq = 0; + // in prompt mode + if (positions_host.size() > 1 || positions_host.back() == 0) { + input_metadata.num_prompt_tokens = seq_len; + input_metadata.num_generation_tokens = 0; + } else { + input_metadata.num_prompt_tokens = 0; + input_metadata.num_generation_tokens = seq_len; + input_metadata.max_context_len = positions_host.back() + 1; // TODO(leca): what if position_host is empty? + + int32_t block_size = gsl::narrow(key_cache.Shape()[3]); + for (int i = 0; i < positions_host.back() + 1; i += block_size) input_metadata.max_num_blocks_per_seq++; + } + } else { + // TODO(leca): context_lens is nullptr? + std::vector context_len_host((*context_lens)->SizeInBytes()); + //ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(context_len_host.data(), (*context_lens)->DataRaw(), (*context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost))); + cudaMemcpy(context_len_host.data(), (*context_lens)->DataRaw(), (*context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost); + std::vector position_ids; + for (size_t i = 0; i < context_len_host.size(); i++) { + if (context_len_host[i] == 0) continue; + std::vector position_id(context_len_host[i]); + std::iota(position_id.begin(), position_id.end(), 0); // fill position_id with {0, 1, 2, ...context_len_span[i]-1} + position_ids.insert(position_ids.end(), position_id.begin(), position_id.end()); + } + input_metadata.position_ids = GetScratchBuffer(allocator->Alloc(allocator, position_ids.size()), allocator); + //ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream))); + cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream); + } + input_metadata.num_valid_tokens = seq_len; + + return nullptr; +} + template struct PagedAttention { + using TT = typename contrib::CudaT::MappedType; OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { int64_t num_heads = 0, head_size = 0; ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_int64(&info, "num_heads", &num_heads)); @@ -60,10 +129,17 @@ struct PagedAttention { ORTX_RETURN_IF_ERROR(api.KernelInfoGetAllocator(&info, OrtMemType::OrtMemTypeDefault, &allocator)); allocator_ = UniquePtrWithDeletor{allocator, [&api](OrtAllocator* p){api.ReleaseAllocator(p);}}; head_mapping_ = GetScratchBuffer(allocator_->Alloc(allocator_.get(), num_heads_), allocator_.get()); - InitializeHeadMapping(head_mapping_.get(), head_mapping_host.data(), head_mapping_host.size()); + cudaMemcpy(head_mapping_.get(), head_mapping_host.data(), head_mapping_host.size(), cudaMemcpyHostToDevice); return nullptr; } + OrtStatusPtr RunMultiHeadAttention(Ort::Custom::CUDAKernelContext* ctx, PackedAttentionParameters& parameters) const { + PackedMultiHeadAttentionData data; + + return cuda::QkvToContext(reinterpret_cast(ctx->GetCudaStream()), parameters, data); +// return nullptr; + } + OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& query, const 
ortc::Tensor& key, const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, @@ -71,7 +147,9 @@ struct PagedAttention { std::optional*> positions, std::optional*> cos_sin_cache, ortc::Tensor& attn_out) const { InputMetadata input_metadata; -// ORTX_RETURN_IF_ERROR(CheckInputs(ctx->GetCudaStream(), allocator_.get(), query, key, value, key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, input_metadata, num_heads_, head_size_)); + PackedAttentionParameters parameters; + ORTX_RETURN_IF_ERROR(CheckInputs(reinterpret_cast(ctx->GetCudaStream()), allocator_.get(), query, key, value, + key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, num_heads_, head_size_, input_metadata)); const std::vector& query_shape = query.Shape(); T* output_data = attn_out.Allocate(query_shape); @@ -91,10 +169,8 @@ struct PagedAttention { key_shape_r, value_shape_r, block_size, key_cache_shape[4], 1); } - using TT = typename contrib::CudaT::MappedType; if (input_metadata.num_prompt_tokens > 0) { - //TODO(leca): flash attention for prompt > 0 case - return nullptr; // Don't handle prompt with decoding case for now + return RunMultiHeadAttention(ctx, parameters); // Don't handle prompt with decoding case for now } if (input_metadata.num_generation_tokens > 0) { diff --git a/operators/cuda/paged_attention_impl.cu b/operators/cuda/paged_attention_impl.cu index e3318d985..38dc93b4b 100644 --- a/operators/cuda/paged_attention_impl.cu +++ b/operators/cuda/paged_attention_impl.cu @@ -1,16 +1,12 @@ #include "paged_attention_impl.h" #include "utils.cuh" - +#include "device_prop.cuh" #include "paged_generic.cuh" #include "paged_dtype_float16.cuh" #include "paged_dtype_float32.cuh" #include "paged_utils.cuh" #include -void InitializeHeadMapping(void* dest_data, const void* src_data, size_t count) { - cudaMemcpy(dest_data, src_data, count, cudaMemcpyHostToDevice); -} - namespace cuda { #define WARP_SIZE 32 @@ -623,62 +619,6 @@ __global__ void rotary_embedding_kernel( } } // namespace vllm -//template -//OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, -// const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, -// const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, -// std::optional*> context_lens, -// std::optional*> positions, InputMetadata& input_metadata, int32_t num_heads, int32_t head_size) { -// const std::vector& query_shape = query.Shape(); -// if (query_shape.size() < 2 || query_shape.size() > 3) { -// return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT); -// } -// if (query_shape.back() != num_heads * head_size) { -// return OrtW::CreateStatus(MakeString("query shape should equal to num_heads_ * head_size_"), ORT_INVALID_ARGUMENT); -// } -// -// // TODO(leca): Cpu input or CUDA input? -// int seq_len = query_shape.size() == 3 ? 
query_shape[1] : query_shape[0]; -// if (positions.has_value()) { -// std::vector positions_host((*positions)->Shape().size()); -// ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost))); -// while (positions_host.back() == 0) { -// positions_host.pop_back(); -// seq_len--; -// } -// -// input_metadata.max_num_blocks_per_seq = 0; -// // in prompt mode -// if (positions_host.size() > 1 || positions_host.back() == 0) { -// input_metadata.num_prompt_tokens = seq_len; -// input_metadata.num_generation_tokens = 0; -// } else { -// input_metadata.num_prompt_tokens = 0; -// input_metadata.num_generation_tokens = seq_len; -// input_metadata.max_context_len = positions_host.back() + 1; // TODO(leca): what if position_host is empty? -// -// int32_t block_size = gsl::narrow(key_cache.Shape()[3]); -// for (int i = 0; i < positions_host.back() + 1; i += block_size) input_metadata.max_num_blocks_per_seq++; -// } -// } else { -// // TODO(leca): context_lens is nullptr? -// std::vector context_len_host((*context_lens)->SizeInBytes()); -// ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(context_len_host.data(), (*context_lens)->DataRaw(), (*context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost))); -// std::vector position_ids; -// for (size_t i = 0; i < context_len_host.size(); i++) { -// if (context_len_host[i] == 0) continue; -// std::vector position_id(context_len_host[i]); -// std::iota(position_id.begin(), position_id.end(), 0); // fill position_id with {0, 1, 2, ...context_len_span[i]-1} -// position_ids.insert(position_ids.end(), position_id.begin(), position_id.end()); -// } -// input_metadata.position_ids = GetScratchBuffer(allocator->Alloc(allocator, position_ids.size()), allocator); -// ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream))); -// } -// input_metadata.num_valid_tokens = seq_len; -// -// return nullptr; -//} - #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ vllm::paged_attention_v1_kernel, \ @@ -1133,4 +1073,91 @@ void reshape_and_cache( } } +#if USE_FLASH_ATTENTION +template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int num_kv_heads = parameters.num_kv_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // Q, K and V pointers + const int model_dimension_qk = num_heads * qk_head_size; + const int model_dimension_v = num_kv_heads * v_head_size; + const size_t elements_qk = static_cast(parameters.token_count) * static_cast(model_dimension_qk); + const size_t elements_v = static_cast(parameters.token_count) * static_cast(model_dimension_v); + + // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH + if (!data.no_qkv_workspace) { + LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); + } + + float scale = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; + int32_t* cu_seqlens_q = const_cast(data.cumulative_sequence_length); + int32_t* cu_seqlens_k = const_cast(data.cumulative_sequence_length); + const void* query = data.no_qkv_workspace ? data.query : data.workspace; + const void* key = data.no_qkv_workspace ? data.key : (data.workspace + elements_qk); + const void* value = data.no_qkv_workspace ? data.value : (data.workspace + elements_qk + elements_qk); + void* softmax_lse_buffer = data.no_qkv_workspace + ? data.workspace + : (data.workspace + elements_qk + elements_v + elements_v); + + ORT_RETURN_IF_ERROR( + onnxruntime::flash::mha_varlen_fwd( + device_prop, + stream, + const_cast(query), + const_cast(key), + const_cast(value), + data.output, + cu_seqlens_q, + cu_seqlens_k, + softmax_lse_buffer, + batch_size, + num_heads, + num_kv_heads, // num_heads_k + qk_head_size, + sequence_length, + sequence_length, + scale, + parameters.causal // is causal + )); + + return nullptr; +} +#endif + +template +OrtStatusPtr QkvToContext( + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data) { + const cudaDeviceProp& device_prop = DeviceProp::GetCudaDeviceProp(); +#if OCOS_USE_FLASH_ATTENTION + return FlashAttention(device_prop, stream, parameters, data); +#endif + return nullptr; +} + +//template OrtStatusPtr QkvToContext( +// cudaStream_t stream, +// PackedAttentionParameters& parameters, +// PackedMultiHeadAttentionData& data); + +template OrtStatusPtr QkvToContext( + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data); + } // namespace cuda \ No newline at end of file diff --git a/operators/cuda/paged_attention_impl.h b/operators/cuda/paged_attention_impl.h index 2726381a7..6c8149764 100644 --- a/operators/cuda/paged_attention_impl.h +++ b/operators/cuda/paged_attention_impl.h @@ -5,17 +5,60 @@ #include "ortx_common.h" #include "gsl/narrow" #include +#include + +enum AttentionQkvFormat { + UNKNOWN, // enum value not set, or depends on qkv projection implementation details + Q_K_V_BNSH, // for non-packed qkv, permuted + Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + QKV_BSN3H, // for TRT fused attention, qkv are packed + Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed + Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed. 
+ QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed +}; + +struct PackedAttentionParameters { + int batch_size; + int sequence_length; + int input_hidden_size; // hidden size of input + int hidden_size; // hidden size of Q or K + int head_size; // hidden size per head of Q or K + int v_hidden_size; // hidden size of V + int v_head_size; // hidden size per head of V + int num_heads; + int num_kv_heads; + float scale; + int token_count; + int valid_token_count; + bool has_relative_position_bias; + bool broadcast_res_pos_bias; + bool causal; +}; + +template +struct PackedMultiHeadAttentionData { + const T* query; + const T* key; + const T* value; + const T* bias; + const T* relative_position_bias; + const int32_t* token_offset; + const int32_t* cumulative_sequence_length; + + AttentionQkvFormat source_qkv_format; + + bool no_qkv_workspace; + T* workspace; + T* output; + + void* fused_runner; + + bool use_flash_attention; + bool use_memory_efficient_attention; +}; namespace cuda { -// -//// TODO(leca): move the implementation to paged_attention.h and remove unnecessary parameters -//template -//OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, -// const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, -// const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, -// std::optional*> context_lens, -// std::optional*> positions, InputMetadata& input_metadata, int32_t num_heads, int32_t head_size); -// void paged_attention_v1( const cudaStream_t stream, void* out, // [num_seqs, num_heads, head_size] @@ -86,4 +129,10 @@ void rotary_embedding_neox( int num_heads, int num_kv_heads, int dtype); + +template +OrtStatusPtr QkvToContext( + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data); } // namespace cuda \ No newline at end of file From 9ac5f18e7ead3ce88e2840b56bba1a50c0a8044f Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 8 May 2024 18:46:33 +0000 Subject: [PATCH 04/17] call flash attention code for prompt mode --- cmake/ext_cuda.cmake | 4 +- operators/cuda/paged_attention.h | 90 ++++++++++++++++++++++---- operators/cuda/paged_attention_impl.cu | 88 +++++++++++++++++++++---- operators/cuda/paged_attention_impl.h | 13 ++++ 4 files changed, 168 insertions(+), 27 deletions(-) diff --git a/cmake/ext_cuda.cmake b/cmake/ext_cuda.cmake index aa7d3282c..4be088896 100644 --- a/cmake/ext_cuda.cmake +++ b/cmake/ext_cuda.cmake @@ -30,8 +30,8 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no add_compile_definitions(USE_CUDA) -set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use -set(OCOS_USE_FLASH_ATTENTION OFF) +#set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. 
Turn them on when these 2 libs are really in use +#set(OCOS_USE_FLASH_ATTENTION OFF) if (OCOS_USE_FLASH_ATTENTION) message(STATUS "Enable flash attention") add_compile_definitions(OCOS_USE_FLASH_ATTENTION) diff --git a/operators/cuda/paged_attention.h b/operators/cuda/paged_attention.h index 12ff6be31..cff4c1fff 100644 --- a/operators/cuda/paged_attention.h +++ b/operators/cuda/paged_attention.h @@ -38,6 +38,7 @@ struct InputMetadata { int64_t num_generation_tokens = 0; AttnBias attn_bias; UniquePtrWithDeletor position_ids; + UniquePtrWithDeletor seqinfo; }; //// TODO(leca): remove unnecessary parameters, move all cuda call to .cu file and check return value by calling CudaCall(). @@ -46,7 +47,7 @@ OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, con const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, std::optional*> context_lens, - std::optional*> positions, int32_t num_heads, int32_t head_size, InputMetadata& input_metadata) { + std::optional*> positions, int32_t num_heads, int32_t head_size, InputMetadata& input_metadata, PackedAttentionParameters& parameters) { const std::vector& query_shape = query.Shape(); if (query_shape.size() < 2 || query_shape.size() > 3) { return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT); @@ -71,6 +72,14 @@ OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, con if (positions_host.size() > 1 || positions_host.back() == 0) { input_metadata.num_prompt_tokens = seq_len; input_metadata.num_generation_tokens = 0; + + std::vector seqstart(2, 0); + seqstart[1] = input_metadata.num_prompt_tokens; + input_metadata.seqinfo = GetScratchBuffer(allocator->Alloc(allocator, seqstart.size() * sizeof(int32_t)), allocator); + cudaMemcpy(input_metadata.seqinfo.get(), seqstart.data(), seqstart.size() * sizeof(int32_t), cudaMemcpyHostToDevice); + input_metadata.attn_bias.q_seqinfo.seqstart = reinterpret_cast(input_metadata.seqinfo.get()); + input_metadata.attn_bias.q_seqinfo.max_seqlen = input_metadata.num_prompt_tokens; + input_metadata.attn_bias.batchsize = 1; } else { input_metadata.num_prompt_tokens = 0; input_metadata.num_generation_tokens = seq_len; @@ -91,12 +100,20 @@ OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, con std::iota(position_id.begin(), position_id.end(), 0); // fill position_id with {0, 1, 2, ...context_len_span[i]-1} position_ids.insert(position_ids.end(), position_id.begin(), position_id.end()); } - input_metadata.position_ids = GetScratchBuffer(allocator->Alloc(allocator, position_ids.size()), allocator); + input_metadata.position_ids = GetScratchBuffer(allocator->Alloc(allocator, position_ids.size()), allocator); // TODO(leca): position_ids.size() or position_ids.size() * sizeof(int64_t)? 
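OrtAllocator::Alloc takes a size in bytes and cudaMemcpy copies a byte count, so sizing both calls by position_ids.size() alone covers only one eighth of the int64_t elements; a byte-sized sketch of the two calls (illustrative only, not this patch's resolution of the TODO above):

const size_t position_ids_bytes = position_ids.size() * sizeof(int64_t);
input_metadata.position_ids = GetScratchBuffer<int64_t>(allocator->Alloc(allocator, position_ids_bytes), allocator);
cudaMemcpy(input_metadata.position_ids.get(), position_ids.data(), position_ids_bytes, cudaMemcpyHostToDevice);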
//ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream))); - cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream); + cudaMemcpy(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice); } - input_metadata.num_valid_tokens = seq_len; - + input_metadata.num_valid_tokens = seq_len; + + parameters.batch_size = input_metadata.attn_bias.batchsize; + parameters.sequence_length = static_cast(input_metadata.attn_bias.q_seqinfo.max_seqlen); + parameters.input_hidden_size = -1; + parameters.token_count = static_cast(input_metadata.num_prompt_tokens); + parameters.valid_token_count = static_cast(input_metadata.num_valid_tokens); + parameters.has_relative_position_bias = false; + parameters.broadcast_res_pos_bias = false; + parameters.causal = true; return nullptr; } @@ -133,11 +150,49 @@ struct PagedAttention { return nullptr; } - OrtStatusPtr RunMultiHeadAttention(Ort::Custom::CUDAKernelContext* ctx, PackedAttentionParameters& parameters) const { + OrtStatusPtr RunMultiHeadAttention(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& query, const ortc::Tensor& key, const ortc::Tensor& value, + T* output, OrtMemoryInfo* mem_info, PackedAttentionParameters& parameters, InputMetadata& input_metadata) const { PackedMultiHeadAttentionData data; - + data.use_flash_attention = false; + data.use_memory_efficient_attention = false; +#if OCOS_USE_FLASH_ATTENTION + data.use_flash_attention = true; +#endif +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + data.use_memory_efficient_attention = true; +#endif + data.query = reinterpret_cast(query.DataRaw()); + data.key = reinterpret_cast(key.DataRaw()); + data.value = reinterpret_cast(value.DataRaw()); + + // TODO(leca): +// // broadcast key,value for GQA +// TensorShape key_shape({parameters.valid_token_count, parameters.num_kv_heads, parameters.head_size}); +// size_t kv_repeat_space = key_shape.Size() * (num_queries_per_kv_ > 0 ? 
num_queries_per_kv_ : 0); +// IAllocatorUniquePtr key_out = GetScratchBuffer(kv_repeat_space, context->GetComputeStream()); +// IAllocatorUniquePtr value_out = GetScratchBuffer(kv_repeat_space, context->GetComputeStream()); +// if (num_queries_per_kv_ > 1 && !ParseEnvironmentVariableWithDefault("repeat_kv_tile", false)) { +// // repeat key and value +// LaunchRepeatKeyValue(Stream(context), key_out.get(), value_out.get(), +// data.key, data.value, key_shape.GetDims().data(), num_queries_per_kv_); +// CHECK_CUDA_ERROR(); +// data.key = key_out.get(); +// data.value = value_out.get(); +// parameters.num_kv_heads = parameters.num_heads; +// DumpTensor(Stream(context), data.key, "repeat_key", kv_repeat_space * sizeof(CudaT)); +// } + + size_t workSpaceSize = cuda::GetAttentionWorkspaceSize(sizeof(T), parameters.batch_size, parameters.num_heads, parameters.head_size, parameters.v_head_size, + parameters.sequence_length, nullptr, data.use_flash_attention, data.use_memory_efficient_attention, true); + void* workspace_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, workSpaceSize); + UniquePtrWithDeletor workspace_unique = GetScratchBuffer(workspace_raw, allocator_.get()); + data.workspace = reinterpret_cast(workspace_unique.get()); + data.cumulative_sequence_length = reinterpret_cast(input_metadata.attn_bias.q_seqinfo.seqstart); + data.output = reinterpret_cast(output); + data.fused_runner = nullptr; + data.no_qkv_workspace = data.fused_runner == nullptr || data.use_flash_attention || data.use_memory_efficient_attention; + data.source_qkv_format = data.key == nullptr ? AttentionQkvFormat::QKV_TN3H : AttentionQkvFormat::Q_K_V_TNH; return cuda::QkvToContext(reinterpret_cast(ctx->GetCudaStream()), parameters, data); -// return nullptr; } OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& query, const ortc::Tensor& key, @@ -148,8 +203,18 @@ struct PagedAttention { std::optional*> cos_sin_cache, ortc::Tensor& attn_out) const { InputMetadata input_metadata; PackedAttentionParameters parameters; + OrtMemoryInfo* mem_info = nullptr; + ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx->GetCudaDeviceId(), OrtMemTypeDefault, &mem_info)); ORTX_RETURN_IF_ERROR(CheckInputs(reinterpret_cast(ctx->GetCudaStream()), allocator_.get(), query, key, value, - key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, num_heads_, head_size_, input_metadata)); + key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, num_heads_, head_size_, input_metadata, parameters)); + parameters.head_size = head_size_; + parameters.num_heads = num_heads_; + parameters.num_kv_heads = num_kv_heads_; + parameters.scale = scale_; + parameters.hidden_size = static_cast(head_size_ * num_heads_); + parameters.v_hidden_size = static_cast(head_size_ * num_kv_heads_); + parameters.v_head_size = static_cast(parameters.head_size); + const std::vector& query_shape = query.Shape(); T* output_data = attn_out.Allocate(query_shape); @@ -170,7 +235,8 @@ struct PagedAttention { } if (input_metadata.num_prompt_tokens > 0) { - return RunMultiHeadAttention(ctx, parameters); // Don't handle prompt with decoding case for now + // TODO(leca): deallocate mem_info + return RunMultiHeadAttention(ctx, query, key, value, output_data, mem_info, parameters, input_metadata); // Don't handle prompt with decoding case for now } if (input_metadata.num_generation_tokens > 0) { @@ -185,8 +251,6 @@ struct PagedAttention { value_cache.Shape()[3], 
input_metadata.max_context_len, nullptr, input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); } else { - OrtMemoryInfo* mem_info = nullptr; - ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx->GetCudaDeviceId(), OrtMemTypeDefault, &mem_info)); void* tmp_output_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape.size() * max_num_partitions * sizeof(T)); UniquePtrWithDeletor tmp_output = GetScratchBuffer(tmp_output_raw, allocator_.get()); // TODO(leca): should deallocate inside ORT void* exp_sums_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); @@ -199,9 +263,9 @@ struct PagedAttention { value_cache.Shape()[3], input_metadata.max_context_len, nullptr, input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); - OrtW::API::ReleaseMemoryInfo(mem_info); } } + OrtW::API::ReleaseMemoryInfo(mem_info); return nullptr; } diff --git a/operators/cuda/paged_attention_impl.cu b/operators/cuda/paged_attention_impl.cu index 38dc93b4b..49f745a7f 100644 --- a/operators/cuda/paged_attention_impl.cu +++ b/operators/cuda/paged_attention_impl.cu @@ -5,6 +5,12 @@ #include "paged_dtype_float16.cuh" #include "paged_dtype_float32.cuh" #include "paged_utils.cuh" +#ifdef OCOS_USE_FLASH_ATTENTION +#include "attention_lib/flash_attention/flash_api.h" +#endif +#ifdef OCOS_USE_MEMORY_EFFICIENT_ATTENTION +#include "attention_lib/cutlass_fmha/memory_efficient_attention.h" +#endif #include namespace cuda { @@ -1073,9 +1079,9 @@ void reshape_and_cache( } } -#if USE_FLASH_ATTENTION +#if OCOS_USE_FLASH_ATTENTION template -Status FlashAttention( +OrtStatusPtr FlashAttention( const cudaDeviceProp& device_prop, cudaStream_t stream, PackedAttentionParameters& parameters, @@ -1094,13 +1100,14 @@ Status FlashAttention( const size_t elements_v = static_cast(parameters.token_count) * static_cast(model_dimension_v); // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH - if (!data.no_qkv_workspace) { - LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, - batch_size, sequence_length, - num_heads, qk_head_size, v_head_size, - data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, - data.token_offset, parameters.token_count, stream); - } + // TODO(leca): +// if (!data.no_qkv_workspace) { +// LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, +// batch_size, sequence_length, +// num_heads, qk_head_size, v_head_size, +// data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, +// data.token_offset, parameters.token_count, stream); +// } float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) : parameters.scale; @@ -1113,8 +1120,8 @@ Status FlashAttention( ? 
data.workspace : (data.workspace + elements_qk + elements_v + elements_v); - ORT_RETURN_IF_ERROR( - onnxruntime::flash::mha_varlen_fwd( + ORTX_RETURN_IF_ERROR( + flash::mha_varlen_fwd( device_prop, stream, const_cast(query), @@ -1131,7 +1138,8 @@ Status FlashAttention( sequence_length, sequence_length, scale, - parameters.causal // is causal + parameters.causal, // is causal + false // is_bf16 TODO(leca) )); return nullptr; @@ -1146,6 +1154,10 @@ OrtStatusPtr QkvToContext( const cudaDeviceProp& device_prop = DeviceProp::GetCudaDeviceProp(); #if OCOS_USE_FLASH_ATTENTION return FlashAttention(device_prop, stream, parameters, data); +#endif +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + // TODO(leca): + //return FusedAttentionCutlass(device_prop, stream, parameters, data); #endif return nullptr; } @@ -1160,4 +1172,56 @@ template OrtStatusPtr QkvToContext( PackedAttentionParameters& parameters, PackedMultiHeadAttentionData& data); +constexpr size_t kCUDAMemoryAlignment = 256; + +size_t GetAttentionScratchSize( + size_t element_size, + size_t batch_size, + size_t num_heads, + size_t sequence_length) { + const size_t bytes = element_size * batch_size * num_heads * sequence_length * sequence_length; + return ((bytes + kCUDAMemoryAlignment - 1) / kCUDAMemoryAlignment) * kCUDAMemoryAlignment; +} + +size_t GetAttentionWorkspaceSize( + size_t element_size, + size_t batch_size, + size_t num_heads, + size_t qk_head_size, + size_t v_head_size, + size_t sequence_length, + void* fused_runner, + bool use_flash_attention, + bool use_memory_efficient_attention, + bool no_qkv_workspace) { + // Note that q, k and v might need alignment for fused attention kernels. + const size_t qkv_bytes = no_qkv_workspace ? 0 : (element_size * batch_size * num_heads * sequence_length * (qk_head_size + qk_head_size + v_head_size)); + +#if USE_FLASH_ATTENTION + // Use portion of workspace for softmax buffer. 
+ if (use_flash_attention) { + size_t flash_buffer_bytes = onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads); + return qkv_bytes + flash_buffer_bytes; + } +#endif + + if (fused_runner != nullptr) { + return qkv_bytes; + } + +//#if USE_MEMORY_EFFICIENT_ATTENTION +// if (use_memory_efficient_attention) { +// size_t fmha_buffer_bytes = 0; +// if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) { +// fmha_buffer_bytes = batch_size * sequence_length * num_heads * v_head_size * sizeof(float); +// } +// return qkv_bytes + fmha_buffer_bytes; +// } +//#else +// ORT_UNUSED_PARAMETER(use_memory_efficient_attention); +//#endif + + return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length); +} + } // namespace cuda \ No newline at end of file diff --git a/operators/cuda/paged_attention_impl.h b/operators/cuda/paged_attention_impl.h index 6c8149764..ec7f4b2c3 100644 --- a/operators/cuda/paged_attention_impl.h +++ b/operators/cuda/paged_attention_impl.h @@ -135,4 +135,17 @@ OrtStatusPtr QkvToContext( cudaStream_t stream, PackedAttentionParameters& parameters, PackedMultiHeadAttentionData& data); + +size_t GetAttentionWorkspaceSize( + size_t element_size, + size_t batch_size, + size_t num_heads, + size_t qk_head_size, + size_t v_head_size, + size_t sequence_length, + void* fused_runner, + bool use_flash_attention, + bool use_memory_efficient_attention, + bool no_qkv_workspace); + } // namespace cuda \ No newline at end of file From 90ee6a6ed16b967633ec05cc2f7b870b6b957572 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Mon, 20 May 2024 14:57:37 +0000 Subject: [PATCH 05/17] runtime error in RunMultiHeadAttention --- operators/cuda/paged_attention.h | 90 ++++++++++++++------------ operators/cuda/paged_attention_impl.cu | 48 ++++++-------- operators/cuda/paged_attention_impl.h | 6 +- test/cuda/test_cudaops.py | 55 ++++++++++++++++ 4 files changed, 122 insertions(+), 77 deletions(-) diff --git a/operators/cuda/paged_attention.h b/operators/cuda/paged_attention.h index cff4c1fff..3c64fb493 100644 --- a/operators/cuda/paged_attention.h +++ b/operators/cuda/paged_attention.h @@ -45,8 +45,7 @@ struct InputMetadata { template OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, - const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, - std::optional*> context_lens, + const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, const ortc::Tensor& context_lens, std::optional*> positions, int32_t num_heads, int32_t head_size, InputMetadata& input_metadata, PackedAttentionParameters& parameters) { const std::vector& query_shape = query.Shape(); if (query_shape.size() < 2 || query_shape.size() > 3) { @@ -56,8 +55,7 @@ OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, con return OrtW::CreateStatus(MakeString("query shape should equal to num_heads_ * head_size_"), ORT_INVALID_ARGUMENT); } - // TODO(leca): Cpu input or CUDA input? - int seq_len = query_shape.size() == 3 ? 
query_shape[1] : query_shape[0]; + int seq_len = query_shape[0]; if (positions.has_value()) { std::vector positions_host((*positions)->Shape().size()); //ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost))); @@ -90,9 +88,9 @@ OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, con } } else { // TODO(leca): context_lens is nullptr? - std::vector context_len_host((*context_lens)->SizeInBytes()); + std::vector context_len_host(context_lens.SizeInBytes()); //ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(context_len_host.data(), (*context_lens)->DataRaw(), (*context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost))); - cudaMemcpy(context_len_host.data(), (*context_lens)->DataRaw(), (*context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost); + cudaMemcpy(context_len_host.data(), context_lens.DataRaw(), context_lens.SizeInBytes(), cudaMemcpyDeviceToHost); std::vector position_ids; for (size_t i = 0; i < context_len_host.size(); i++) { if (context_len_host[i] == 0) continue; @@ -102,7 +100,7 @@ OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, con } input_metadata.position_ids = GetScratchBuffer(allocator->Alloc(allocator, position_ids.size()), allocator); // TODO(leca): position_ids.size() or position_ids.size() * sizeof(int64_t)? //ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream))); - cudaMemcpy(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice); + cudaMemcpy(input_metadata.position_ids.get(), position_ids.data(), position_ids.size() * sizeof(int64_t), cudaMemcpyHostToDevice); } input_metadata.num_valid_tokens = seq_len; @@ -119,6 +117,11 @@ OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, con template struct PagedAttention { + static OrtMemType GetInputMemoryType(size_t input_index) { + if (input_index == 8) return OrtMemType::OrtMemTypeCPUInput; // make is_prompt CPU input + return OrtMemType::OrtMemTypeDefault; + } + using TT = typename contrib::CudaT::MappedType; OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { int64_t num_heads = 0, head_size = 0; @@ -146,7 +149,7 @@ struct PagedAttention { ORTX_RETURN_IF_ERROR(api.KernelInfoGetAllocator(&info, OrtMemType::OrtMemTypeDefault, &allocator)); allocator_ = UniquePtrWithDeletor{allocator, [&api](OrtAllocator* p){api.ReleaseAllocator(p);}}; head_mapping_ = GetScratchBuffer(allocator_->Alloc(allocator_.get(), num_heads_), allocator_.get()); - cudaMemcpy(head_mapping_.get(), head_mapping_host.data(), head_mapping_host.size(), cudaMemcpyHostToDevice); + cudaMemcpy(head_mapping_.get(), head_mapping_host.data(), head_mapping_host.size() * sizeof(int32_t), cudaMemcpyHostToDevice); return nullptr; } @@ -198,7 +201,7 @@ struct PagedAttention { OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& query, const ortc::Tensor& key, const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, - std::optional*> context_lens, + const ortc::Tensor& context_lens, const ortc::Tensor& is_prompt, std::optional*> positions, std::optional*> cos_sin_cache, ortc::Tensor& attn_out) const { InputMetadata input_metadata; @@ -226,46 +229,47 @@ struct PagedAttention { } const std::vector& key_cache_shape = 
key_cache.Shape(); - if (input_metadata.num_valid_tokens > 0 && key_cache_shape.size() > 3) { + if (input_metadata.num_valid_tokens > 0) { int64_t key_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_}; int64_t value_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_}; - int block_size = gsl::narrow(key_cache_shape[3]); + int block_size = key_cache_shape[1] / (num_kv_heads_ * head_size_); + // TODO(leca): or we just pass num_valid_tokens, num_kv_head, head_size and block_size as parameter? cuda::reshape_and_cache(reinterpret_cast(ctx->GetCudaStream()), key.DataRaw(), value.DataRaw(), key_cache.DataRaw(), value_cache.DataRaw(), slot_mappings.Data(), - key_shape_r, value_shape_r, block_size, key_cache_shape[4], 1); + key_shape_r, value_shape_r, block_size); } - if (input_metadata.num_prompt_tokens > 0) { + if (*(is_prompt.Data())) { // TODO(leca): deallocate mem_info return RunMultiHeadAttention(ctx, query, key, value, output_data, mem_info, parameters, input_metadata); // Don't handle prompt with decoding case for now } - - if (input_metadata.num_generation_tokens > 0) { - constexpr int PARTITION_SIZE = 512; - int max_num_partitions = (input_metadata.max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE; - int64_t generation_qeury_shape[3] = {input_metadata.num_valid_tokens, num_heads_, head_size_}; - if (use_v1) { - cuda::paged_attention_v1(reinterpret_cast(ctx->GetCudaStream()), reinterpret_cast(output_data), query.DataRaw(), - key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, - block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr, - value_cache.Shape()[3], input_metadata.max_context_len, nullptr, - input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); - } else { - void* tmp_output_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape.size() * max_num_partitions * sizeof(T)); - UniquePtrWithDeletor tmp_output = GetScratchBuffer(tmp_output_raw, allocator_.get()); // TODO(leca): should deallocate inside ORT - void* exp_sums_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); - UniquePtrWithDeletor exp_sums = GetScratchBuffer(exp_sums_raw, allocator_.get()); - void* max_logits_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); - UniquePtrWithDeletor max_logits = GetScratchBuffer(max_logits_raw, allocator_.get()); - cuda::paged_attention_v2(reinterpret_cast(ctx->GetCudaStream()), exp_sums_raw, max_logits_raw, tmp_output_raw, reinterpret_cast(output_data), query.DataRaw(), - key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, - block_tables.Data(), context_lens.has_value() ? 
(*context_lens)->Data() : nullptr, - value_cache.Shape()[3], input_metadata.max_context_len, nullptr, - input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); - - } - } - OrtW::API::ReleaseMemoryInfo(mem_info); +// +// if (input_metadata.num_generation_tokens > 0) { +// constexpr int PARTITION_SIZE = 512; +// int max_num_partitions = (input_metadata.max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; +// bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE; +// int64_t generation_qeury_shape[3] = {input_metadata.num_valid_tokens, num_heads_, head_size_}; +// if (use_v1) { +// cuda::paged_attention_v1(reinterpret_cast(ctx->GetCudaStream()), reinterpret_cast(output_data), query.DataRaw(), +// key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, +// block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr, +// value_cache.Shape()[3], input_metadata.max_context_len, nullptr, +// input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); +// } else { +// void* tmp_output_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape.size() * max_num_partitions * sizeof(T)); +// UniquePtrWithDeletor tmp_output = GetScratchBuffer(tmp_output_raw, allocator_.get()); // TODO(leca): should deallocate inside ORT +// void* exp_sums_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); +// UniquePtrWithDeletor exp_sums = GetScratchBuffer(exp_sums_raw, allocator_.get()); +// void* max_logits_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); +// UniquePtrWithDeletor max_logits = GetScratchBuffer(max_logits_raw, allocator_.get()); +// cuda::paged_attention_v2(reinterpret_cast(ctx->GetCudaStream()), exp_sums_raw, max_logits_raw, tmp_output_raw, reinterpret_cast(output_data), query.DataRaw(), +// key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, +// block_tables.Data(), context_lens.has_value() ? 
(*context_lens)->Data() : nullptr, +// value_cache.Shape()[3], input_metadata.max_context_len, nullptr, +// input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); +// +// } +// } +// OrtW::API::ReleaseMemoryInfo(mem_info); return nullptr; } @@ -274,7 +278,7 @@ struct PagedAttention { int32_t num_kv_heads_; // number of attention kv_heads int32_t head_size_; // number of attention heads float scale_; // sqrt(head_size_) - UniquePtrWithDeletor head_mapping_; int32_t num_queries_per_kv_; - UniquePtrWithDeletor allocator_; + UniquePtrWithDeletor allocator_; // make allocator_ declared first in order to release it last + UniquePtrWithDeletor head_mapping_; }; \ No newline at end of file diff --git a/operators/cuda/paged_attention_impl.cu b/operators/cuda/paged_attention_impl.cu index 49f745a7f..a629124d9 100644 --- a/operators/cuda/paged_attention_impl.cu +++ b/operators/cuda/paged_attention_impl.cu @@ -12,6 +12,7 @@ #include "attention_lib/cutlass_fmha/memory_efficient_attention.h" #endif #include +#include namespace cuda { @@ -512,21 +513,16 @@ template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size, block_size] scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] const int* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, const int head_size, - const int block_size, - const int x) { + const int block_size) { const int token_idx = blockIdx.x; const int slot_idx = slot_mapping[token_idx]; - if (slot_idx < 0) { - // Padding token that should be ignored. 
- return; - } const int block_idx = slot_idx / block_size; const int block_offset = slot_idx % block_size; @@ -537,11 +533,9 @@ __global__ void reshape_and_cache_kernel( const int head_idx = i / head_size; const int head_offset = i % head_size; - const int x_idx = head_offset / x; - const int x_offset = head_offset % x; - const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + block_offset * x + x_offset; const int tgt_value_idx = block_idx * num_heads * head_size * block_size + head_idx * head_size * block_size + head_offset * block_size + block_offset; + const int tgt_key_idx = tgt_value_idx; //{ // if (key_cache[tgt_key_idx] - key[src_key_idx] > half(0.1)) { // printf("key error find, %d,%d ", tgt_key_idx, src_key_idx); @@ -1041,19 +1035,16 @@ void reshape_and_cache( const cudaStream_t stream, const void* key, // [num_tokens, num_heads, head_size] const void* value, // [num_tokens, num_heads, head_size] - const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const void* key_cache, // [num_blocks, num_heads, head_size, block_size] const void* value_cache, // [num_blocks, num_heads, head_size, block_size] const int* slot_mapping, // [num_tokens] const int64_t* key_shapes, const int64_t* value_shapes, - const int64_t block_size, - const int vec_x, - int dtype) { + const int64_t block_size) { int num_tokens = key_shapes[0]; int num_heads = key_shapes[1]; int head_size = key_shapes[2]; // int block_size = key_cache.size(3); - int x = vec_x; int key_stride = key_shapes[1] * key_shapes[2]; int value_stride = value_shapes[1] * value_shapes[2]; @@ -1062,21 +1053,18 @@ void reshape_and_cache( dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); - // if constexpr (std::is_same_v) { - if (dtype == 1) { - vllm::reshape_and_cache_kernel<<>>( - (const half*)key, - (const half*)value, - (half*)key_cache, - (half*)value_cache, - slot_mapping, - key_stride, - value_stride, - num_heads, - head_size, - block_size, - x); - } + + vllm::reshape_and_cache_kernel<<>>( + (const half*)key, + (const half*)value, + (half*)key_cache, + (half*)value_cache, + slot_mapping, + key_stride, + value_stride, + num_heads, + head_size, + block_size); } #if OCOS_USE_FLASH_ATTENTION diff --git a/operators/cuda/paged_attention_impl.h b/operators/cuda/paged_attention_impl.h index ec7f4b2c3..c23017720 100644 --- a/operators/cuda/paged_attention_impl.h +++ b/operators/cuda/paged_attention_impl.h @@ -105,14 +105,12 @@ void reshape_and_cache( const cudaStream_t stream, const void* key, // [num_tokens, num_heads, head_size] const void* value, // [num_tokens, num_heads, head_size] - const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const void* key_cache, // [num_blocks, num_heads, head_size, block_size] const void* value_cache, // [num_blocks, num_heads, head_size, block_size] const int* slot_mapping, // [num_tokens] const int64_t* key_shapes, const int64_t* value_shapes, - const int64_t block_size, - const int vec_x, - int dtype); + const int64_t block_size); // void* kv_quant_param = nullptr, // [num_blocks, 2, num_heads, head_size / kv_quant_chunk_size, block_size] // const int kv_quant_chunk_size = 0, // const int kv_quant_param_dtype = 1); diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index d868fe675..758648be3 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -116,6 +116,61 @@ def test_cuda_fastgelu_f16(self): 
else: print ('CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.') + @staticmethod + def _create_pagedattention_test_model(domain='ai.onnx.contrib'): + nodes = [ + helper.make_node('PagedAttention', + ['query', 'key', 'value', 'key_cache', 'value_cache', 'block_tables', 'slot_mappings', 'context_lens', 'is_prompt'], + ['attn_out'], + domain=domain, num_heads=32, num_kv_heads=32, head_size=16, scale=1.0) + ] + query = helper.make_tensor_value_info( + 'query', onnx_proto.TensorProto.FLOAT16, [87,512]) + key = helper.make_tensor_value_info( + 'key', onnx_proto.TensorProto.FLOAT16, [87,512]) + value = helper.make_tensor_value_info( + 'value', onnx_proto.TensorProto.FLOAT16, [87,512]) + key_cache = helper.make_tensor_value_info( + 'key_cache', onnx_proto.TensorProto.FLOAT16, [32,8192]) + value_cache = helper.make_tensor_value_info( + 'value_cache', onnx_proto.TensorProto.FLOAT16, [32,8192]) + block_tables = helper.make_tensor_value_info( + 'block_tables', onnx_proto.TensorProto.INT32, [5,3]) + slot_mappings = helper.make_tensor_value_info( + 'slot_mappings', onnx_proto.TensorProto.INT32, [87]) + context_lens = helper.make_tensor_value_info( + 'context_lens', onnx_proto.TensorProto.INT32, [5]) + is_prompt = helper.make_tensor_value_info( + 'is_prompt', onnx_proto.TensorProto.INT32, [1]) + attn_out = helper.make_tensor_value_info( + 'attn_out', onnx_proto.TensorProto.FLOAT16, [87,512]) + graph = helper.make_graph(nodes, 'test_paged_attention', + [query, key, value, key_cache, value_cache, block_tables, slot_mappings, context_lens, is_prompt], + [attn_out]) + model = make_onnx_model(graph) + return model + + def test_cuda_paged_attention(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_pagedattention_test_model() + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + providers=['CUDAExecutionProvider']) + query = np.random.randn(87,512).astype(np.float16) # 87 is the token num of all the sequences (5+12+16+20+34) + key = np.random.randn(87,512).astype(np.float16) + value = np.random.randn(87,512).astype(np.float16) + key_cache = np.zeros([32,8192]).astype(np.float16) + value_cache = np.zeros([32,8192]).astype(np.float16) + block_tables = np.array([[0,-1,-1],[1,-1,-1],[2,-1,-1],[3,4,-1],[5,6,7]]).astype(np.int32) + slot1 = np.arange(0, 5, dtype=np.int32) + slot2 = np.arange(16, 28, dtype=np.int32) + slot3 = np.arange(32, 68, dtype=np.int32) + slot4 = np.arange(80, 114, dtype=np.int32) + slot_mappings = np.concatenate((slot1, slot2, slot3, slot4)) + context_lens = np.array([5, 12, 16, 20, 34]).astype(np.int32) + is_prompt = np.array([1]).astype(np.int32) + y = sess.run(None, {'query':query, 'key':key, 'value':value, 'key_cache':key_cache, 'value_cache':value_cache, 'block_tables':block_tables, 'slot_mappings':slot_mappings, 'context_lens':context_lens, 'is_prompt':is_prompt}) if __name__ == "__main__": unittest.main() From 17c2c4138e7faddf61b9087617bb35372f0467a8 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Tue, 21 May 2024 01:08:55 +0000 Subject: [PATCH 06/17] UT can run now --- .pyproject/cmdclass.py | 7 +- operators/cuda/paged_attention.h | 98 ++++++++++++------------- operators/cuda/paged_attention_impl.cu | 26 +------ operators/cuda/paged_attention_impl.h | 8 +- test/cuda/key.npy | Bin 0 -> 89216 bytes test/cuda/query.npy | Bin 0 -> 89216 bytes test/cuda/test_cudaops.py | 11 ++- test/cuda/value.npy | Bin 0 -> 89216 bytes 8 files changed, 67 insertions(+), 83 deletions(-) create mode 100644 
test/cuda/key.npy create mode 100644 test/cuda/query.npy create mode 100644 test/cuda/value.npy diff --git a/.pyproject/cmdclass.py b/.pyproject/cmdclass.py index 16855a6b3..1016f77f2 100644 --- a/.pyproject/cmdclass.py +++ b/.pyproject/cmdclass.py @@ -148,6 +148,7 @@ def initialize_options(self): self.no_opencv = None self.cc_debug = None self.cuda_archs = None + self.ort_pkg_dir = None def _parse_options(self, options): for segment in options.split(','): @@ -189,7 +190,8 @@ def build_cmake(self, extension): ext_fullpath = pathlib.Path( self.get_ext_fullpath(extension.name)).absolute() - config = 'RelWithDebInfo' if self.debug else 'Release' +# config = 'RelWithDebInfo' if self.debug else 'Release' + config = 'Debug' if self.debug else 'Release' cmake_args = [ '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(ext_fullpath.parent.absolute()), @@ -199,6 +201,9 @@ def build_cmake(self, extension): '-DCMAKE_BUILD_TYPE=' + config ] + if self.ort_pkg_dir: + cmake_args += ['-DONNXRUNTIME_PKG_DIR=' + self.ort_pkg_dir] + if self.no_opencv: # Disabling openCV can drastically reduce the build time. cmake_args += [ diff --git a/operators/cuda/paged_attention.h b/operators/cuda/paged_attention.h index 3c64fb493..29ce603f1 100644 --- a/operators/cuda/paged_attention.h +++ b/operators/cuda/paged_attention.h @@ -46,7 +46,7 @@ template OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& key, const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, const ortc::Tensor& context_lens, - std::optional*> positions, int32_t num_heads, int32_t head_size, InputMetadata& input_metadata, PackedAttentionParameters& parameters) { + std::optional*> positions, int32_t num_heads, int32_t head_size, bool prompt_mode, InputMetadata& input_metadata, PackedAttentionParameters& parameters) { const std::vector& query_shape = query.Shape(); if (query_shape.size() < 2 || query_shape.size() > 3) { return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT); @@ -56,18 +56,8 @@ OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, con } int seq_len = query_shape[0]; - if (positions.has_value()) { - std::vector positions_host((*positions)->Shape().size()); - //ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost))); - cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost); - while (positions_host.back() == 0) { - positions_host.pop_back(); - seq_len--; - } - - input_metadata.max_num_blocks_per_seq = 0; - // in prompt mode - if (positions_host.size() > 1 || positions_host.back() == 0) { + input_metadata.max_num_blocks_per_seq = 0; + if (prompt_mode) { input_metadata.num_prompt_tokens = seq_len; input_metadata.num_generation_tokens = 0; @@ -78,17 +68,18 @@ OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, con input_metadata.attn_bias.q_seqinfo.seqstart = reinterpret_cast(input_metadata.seqinfo.get()); input_metadata.attn_bias.q_seqinfo.max_seqlen = input_metadata.num_prompt_tokens; input_metadata.attn_bias.batchsize = 1; - } else { + } else { + std::vector positions_host((*positions)->Shape().size()); // TODO(leca): input_metadata.num_prompt_tokens = 0; input_metadata.num_generation_tokens = seq_len; - 
input_metadata.max_context_len = positions_host.back() + 1; // TODO(leca): what if position_host is empty? + input_metadata.max_context_len = positions_host.back() + 1; int32_t block_size = gsl::narrow(key_cache.Shape()[3]); - for (int i = 0; i < positions_host.back() + 1; i += block_size) input_metadata.max_num_blocks_per_seq++; - } - } else { - // TODO(leca): context_lens is nullptr? - std::vector context_len_host(context_lens.SizeInBytes()); + for (int i = 0; i < positions_host.back() + 1; i += block_size) input_metadata.max_num_blocks_per_seq++; + } + + if (!positions.has_value()) { // TODO(leca): only generate position when cos_sin_cache is provided? As position and cos_sin_cache are only used for rotary embeding + std::vector context_len_host(context_lens.SizeInBytes() / sizeof(int32_t)); //ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(context_len_host.data(), (*context_lens)->DataRaw(), (*context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost))); cudaMemcpy(context_len_host.data(), context_lens.DataRaw(), context_lens.SizeInBytes(), cudaMemcpyDeviceToHost); std::vector position_ids; @@ -204,12 +195,11 @@ struct PagedAttention { const ortc::Tensor& context_lens, const ortc::Tensor& is_prompt, std::optional*> positions, std::optional*> cos_sin_cache, ortc::Tensor& attn_out) const { + bool prompt_mode = *(is_prompt.Data()) == 1; InputMetadata input_metadata; PackedAttentionParameters parameters; - OrtMemoryInfo* mem_info = nullptr; - ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx->GetCudaDeviceId(), OrtMemTypeDefault, &mem_info)); ORTX_RETURN_IF_ERROR(CheckInputs(reinterpret_cast(ctx->GetCudaStream()), allocator_.get(), query, key, value, - key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, num_heads_, head_size_, input_metadata, parameters)); + key_cache, value_cache, block_tables, slot_mappings, context_lens, positions, num_heads_, head_size_, prompt_mode, input_metadata, parameters)); parameters.head_size = head_size_; parameters.num_heads = num_heads_; parameters.num_kv_heads = num_kv_heads_; @@ -229,47 +219,49 @@ struct PagedAttention { } const std::vector& key_cache_shape = key_cache.Shape(); + int block_size = key_cache_shape[1] / (num_kv_heads_ * head_size_); if (input_metadata.num_valid_tokens > 0) { int64_t key_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_}; int64_t value_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_}; - int block_size = key_cache_shape[1] / (num_kv_heads_ * head_size_); // TODO(leca): or we just pass num_valid_tokens, num_kv_head, head_size and block_size as parameter? 
cuda::reshape_and_cache(reinterpret_cast(ctx->GetCudaStream()), key.DataRaw(), value.DataRaw(), key_cache.DataRaw(), value_cache.DataRaw(), slot_mappings.Data(), key_shape_r, value_shape_r, block_size); } - if (*(is_prompt.Data())) { + OrtMemoryInfo* mem_info = nullptr; + ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx->GetCudaDeviceId(), OrtMemTypeDefault, &mem_info)); + if (prompt_mode) { // TODO(leca): deallocate mem_info return RunMultiHeadAttention(ctx, query, key, value, output_data, mem_info, parameters, input_metadata); // Don't handle prompt with decoding case for now } -// -// if (input_metadata.num_generation_tokens > 0) { -// constexpr int PARTITION_SIZE = 512; -// int max_num_partitions = (input_metadata.max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; -// bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE; -// int64_t generation_qeury_shape[3] = {input_metadata.num_valid_tokens, num_heads_, head_size_}; -// if (use_v1) { -// cuda::paged_attention_v1(reinterpret_cast(ctx->GetCudaStream()), reinterpret_cast(output_data), query.DataRaw(), -// key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, -// block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr, -// value_cache.Shape()[3], input_metadata.max_context_len, nullptr, -// input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); -// } else { -// void* tmp_output_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape.size() * max_num_partitions * sizeof(T)); -// UniquePtrWithDeletor tmp_output = GetScratchBuffer(tmp_output_raw, allocator_.get()); // TODO(leca): should deallocate inside ORT -// void* exp_sums_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); -// UniquePtrWithDeletor exp_sums = GetScratchBuffer(exp_sums_raw, allocator_.get()); -// void* max_logits_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); -// UniquePtrWithDeletor max_logits = GetScratchBuffer(max_logits_raw, allocator_.get()); -// cuda::paged_attention_v2(reinterpret_cast(ctx->GetCudaStream()), exp_sums_raw, max_logits_raw, tmp_output_raw, reinterpret_cast(output_data), query.DataRaw(), -// key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, -// block_tables.Data(), context_lens.has_value() ? 
(*context_lens)->Data() : nullptr, -// value_cache.Shape()[3], input_metadata.max_context_len, nullptr, -// input_metadata.max_num_blocks_per_seq, generation_qeury_shape, num_queries_per_kv_, 1); -// -// } -// } -// OrtW::API::ReleaseMemoryInfo(mem_info); + + if (input_metadata.num_generation_tokens > 0) { + constexpr int PARTITION_SIZE = 512; + int max_num_partitions = (input_metadata.max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE; + int64_t generation_qeury_shape[3] = {input_metadata.num_valid_tokens, num_heads_, head_size_}; + if (use_v1) { + cuda::paged_attention_v1(reinterpret_cast(ctx->GetCudaStream()), reinterpret_cast(output_data), query.DataRaw(), + key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, + block_tables.Data(), context_lens.Data(), + block_size, input_metadata.max_context_len, nullptr, + block_tables.Shape()[1], generation_qeury_shape, num_queries_per_kv_); // TODO(leca): block_tables.Shape()[1] replacing input_metadata.max_num_blocks_per_seq + } else { + void* tmp_output_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape.size() * max_num_partitions * sizeof(T)); + UniquePtrWithDeletor tmp_output = GetScratchBuffer(tmp_output_raw, allocator_.get()); // TODO(leca): should deallocate inside ORT + void* exp_sums_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); + UniquePtrWithDeletor exp_sums = GetScratchBuffer(exp_sums_raw, allocator_.get()); + void* max_logits_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T)); + UniquePtrWithDeletor max_logits = GetScratchBuffer(max_logits_raw, allocator_.get()); + cuda::paged_attention_v2(reinterpret_cast(ctx->GetCudaStream()), exp_sums_raw, max_logits_raw, tmp_output_raw, reinterpret_cast(output_data), query.DataRaw(), + key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_, + block_tables.Data(), context_lens.Data(), + block_size, input_metadata.max_context_len, nullptr, + block_tables.Shape()[1], generation_qeury_shape, num_queries_per_kv_); + + } + } + OrtW::API::ReleaseMemoryInfo(mem_info); return nullptr; } diff --git a/operators/cuda/paged_attention_impl.cu b/operators/cuda/paged_attention_impl.cu index a629124d9..f2c8c0c9c 100644 --- a/operators/cuda/paged_attention_impl.cu +++ b/operators/cuda/paged_attention_impl.cu @@ -772,17 +772,8 @@ void paged_attention_v1( const float* __restrict__ alibi_slopes, const int max_num_blocks_per_seq, const int64_t* query_shapes, - int num_queries_per_kv, - int dtype) { - if (dtype == 0) { // Float - CALL_V1_LAUNCHER_BLOCK_SIZE(float); - } else if (dtype == 1) { // Half - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (dtype == 2) { // BFloat16 - // CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); - } else { - // TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } + int num_queries_per_kv) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); } #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ @@ -960,17 +951,8 @@ void paged_attention_v2( const float* alibi_slopes, const int max_num_blocks_per_seq, const int64_t* query_shapes, - int num_queries_per_kv, - int dtype) { - if (dtype == 0) { // Float - CALL_V2_LAUNCHER_BLOCK_SIZE(float); - } else if (dtype == 1) { // Half - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (dtype == 2) { // BFloat16 - // 
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); - } else { - //TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } + int num_queries_per_kv) { + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); } void rotary_embedding_neox( diff --git a/operators/cuda/paged_attention_impl.h b/operators/cuda/paged_attention_impl.h index c23017720..b7697dc0d 100644 --- a/operators/cuda/paged_attention_impl.h +++ b/operators/cuda/paged_attention_impl.h @@ -74,8 +74,8 @@ void paged_attention_v1( const float* __restrict__ alibi_slopes, const int max_num_blocks_per_seq, const int64_t* query_shapes, - int num_queries_per_kv, - int dtype); + int num_queries_per_kv); +// int dtype); // const void* kv_quant_params_cache = nullptr, // [num_blocks, 2, num_kv_heads, head_size / kv_quant_chunk_size, block_size] // int kv_quant_chunk_size = 0, // int kv_quant_param_dtype = 0); @@ -98,8 +98,8 @@ void paged_attention_v2( const float* alibi_slopes, const int max_num_blocks_per_seq, const int64_t* query_shapes, - int num_queries_per_kv, - int dtype); + int num_queries_per_kv); +// int dtype); void reshape_and_cache( const cudaStream_t stream, diff --git a/test/cuda/key.npy b/test/cuda/key.npy new file mode 100644 index 0000000000000000000000000000000000000000..a9ec93b04abd90a19413687499949a2a8180177f GIT binary patch literal 89216 zcmbT7WnV`NFTtO@K z$9ktYr&|yfaNXUn!Um&+CjIL(Aq=wY+>ilUtqwuiP6 zq~!doeFyd6K6E^m!P!J*_+QY;*Vc*v4IKA2M$OUV;B$Bn{|YaZd=PZyB5kDA(RYkF zm0g+cuJu!w(#?q0L52;ejf`=0I-JdSS7ES?MwtUKiTA+q{B~bWlZP7qLG(8N#$!;u zgyqISkK`D@?qRCHfbo!hW302hkq5;r*bloM$N33j4LsfG;DlUh3k3VHL&-mIZ>1ra z$VcH6SV>-^1?m>h`;5974A+J%VBE^{>;qsZYGIW(Ml(|zGgeC+iOS(iatWxB-}tiX zZQ-=B&+Npv1%HC(YIV?Qe^Txgbc{O1k4BGU1~t0J{^m+pU*QsEY+eVo1FG+jO8Tf> z5=U4QQiF!*uNZqMwioE7tTEs4SxI{WpNT@!PjVYb*TcBoQX_U6c@`byGcB6Q2!7=& z(5)lpkV#}kV>Ek+xZ!>dR|(C)TK%UdHP{@sxBmoRxoV6AM;`7WjAVvVQt30=!&3khNbZ19wGAf%0e*EQ$7OndBJkCla_0>`zAz7~yJ`P@wkE{myialWmajj7@|GqFNp!jENubrpVFf#**@VE30hBKS>Wx*xK9so70BtqX?z1Vn*T*k z;GfWo;X1ks-QH?$^&>6jCHK-iYR^|bX`kc~WCL^>C&}N*_tIcM9 zK1w|l4xWPEz+Nl1%ztDHJ zKnS?XdjFsXQU}#IbUV-h9VEjn!n6_^(dD&K=#jSA*az>qV!1WU40A1U4SSifiM_>& zuoQe0@7K>+oM!TCnU&@c!i5KCy?_mqCgv0B0Bt`SsLyiK-~oTf>lR3W`%U@^`KXw}DD?|7M)QnM}N=ESbdAamQx=!=>8#*lR-qH}Nf!ef&<8%zb7PtN@A; zTe?T6>GD0Muh&}g_7 z6+vFK9?!P2k9^}#l9kk0u0|j{Z%$5EW+q=fci!HQs55NL^vhX~N)e~k%i3sMM;wAK z!7O!t@K0)A!~0Z;sYk|r}Rr4ME)PtO1_1~5&)XaCpW6i1T`9^As~ znE-BTd{u8qZS@ISfygQCktW3HI=6v2BzIw+gEr=*Mfu9e5W)AcGaPIY2%Bsz#6NDLa zu(`kyJeh371p8Xw`Q`E(4Z$PM_BmZ(O{z4%4Q_>t8rG$M1%A5_E|8Z{6HL~sObn(9 z>y_%U${Er=Xh*Zl&1tU(`?P(WFMj*+CV=;MymJuh2|%5mc5~{+QX1d z4!+H?S>udthEq-BW|$?3?3^THm5`>7Q5_&pA82)z%J@>$?x14UcJlCgk)&W0+NP za}Apna>B`9xE#7IR;SMgA9`-^T)N2X1;XQtVvcNsweFmN(Xnu_@f$Nl!{+$`Ct`GGatUbdIIIR3suM&Sr z-_UI9i+TX=1dY|hOm}gWFbmOG%wA3v79Ru`@l)IzatQ6anyp?i>iZjzvGAKT!|Lsh zWzi8TI}Hv*@tOA%Pby91jln4963XF6L?$v6vmzLXUz^j59_P)^8JQR+&4ZT%arRWb z8|bKy=F8xxkr6;=k^S`qG$H$`kuHq5* z*(iCc&^GFCHAP(?70wPP9L9b975j=SshnUusEcMs*m5Q*5`QPTv$D^52G!FJv+)+` zSs;7CR;Hm)QhlcGB~~4pjWoPlkM~Sjl|4ZwudV^K;=J$4>uQ z*vLi3weWn`e$&l@!{lzDMPRg52VU?D@dLd*xQdJF3)wBo-ki&HDa6K<()K#eGTT6< z;4XF^$RY-)E2Hn~e}fA2zj`}e7I!i~0cQ5AB)TQPVg@mc@c?EWX=4{*tKbTL0&&); z!yWM#5a)&N;18~{0QIBx@8(7I9Qba!4rDmad%F;{=(2{`3**C}62XWKwC=uK?|UT{ z#2S~}b^O<*FxwBNBY5R$#neZ~`G!Qg_K|9UC*^OHOPSq_@meJ322t9ajECZ8pW85q z|K)5Z&r0)s6ys*jz-h1)*?@RV)um2SUzrP_uydBwM4qkHa*yC@WDhd><}7j^2pogI z9UR%rx`*Y^YU^t64WII-Ts|o!jS2mW38)epstr=k=0q5!s6xRW#w@a=z9Qqad0Z|f zK9ibiFVg-JZlNeRLtn&y#{Y<2jZF2dHQYL;*FijR%Z1@DEjF;7FR4E^7`RzzqCS9F z8_y3Cbs-Pb0k5cck)PR_@>MF@oXP(nONo6PH`50xeL)gTPa 
zOiVIe$)Q@G`0M!*%t;zg_*OeJZ;RH-`0CqScm&Lo&Rdq6v`k8NQ{|Mg8CUa?upB&# zTh7PmeS_6ebU`xv6g6b`7EHuPdD`olmWPzan2;q?SB>R)=8^tZn(X{W-j_GZm(3o< zHv9eDXML7q3b_jWW4;)#nl9Bny30~hIV^QnJ0(p=d2pUE*85qANGO&0*`Oq31jSA4 zVvffbdaYd1z;JFW(bQf>Xm70#cBQtZeqsNEH`r8R7x&K9UTenfY1%@p2RHjVN-N|4 z-QP;ygF;xzbnH`(?(}Z3m+GJwDr>39@Qm6wsyls4m~Qqm zpCRvO_vr!OV)?e2`>dD#NZANYHxe(=2{CgD8{=c_9689D%Ern)EjP^`ydow&F^(}E zaqcAF7gq)DATWCs(LtD|Rt?r+D&T$KN6iIZag=i2+kkMwRCtDcXKjx*p+r||LOK?v z9pp2GInIgD7I>y^AkH#KAiyek(j+L&Mucq=Uy*kDD#%`#5V&By&-M~n?{IdX^4$Mh zt{zygl)#$^HBc1FCc(uPq!IMVbh zZB(v64xD!%OSCF?$*tm&0Pd`=%oNvQgNf_5j-U}DovDT6lxL1k)H1jfZ3)IP3&4fI z$^2}!c)?|>`;F#;UV+PIfK}fH>B{hKV3WNh%w%w6F8F_l1x#Ey)KJ?PNFQB;QN6Y zHssjPfY{T%>EsBwi48b{N-O*%-$NZ~x`VoN^;F7xA33bm?TzTSAi*HfaAmJjnJUIj zF(YNwU>4EYJ>>joZNY z(T65%W*hovm?0ij2Z;Y2A!^Qoh$G4+JG?#Li{nP_D^a zh$VegSn_NsU893fxgxe6CNDe`w8duWbDSd_{VfBzJ!oWLxiHQ<5J2Xbp;OWPaqJ_X zV&=iJ5{tk@lipm%ch+POCfkRjv(AT}V_Jc%5VweOV4YOKaM?O0nwBicxQ7~ zWIW<=ZX|WnZhDxIOpY<>O|m^YH%vdS-=`l4K0pWCF|MTdusB%Cs?M*;<#GmD|MRbv zmeD0Nugzk+;&|&=>gg!|;_h%~$c2`ybSzk$CaJV!J2dEzA{u_4Z{gd3Xs<3b$$-{OCCa}5>B1*t`qO++! zXX;q3JaV|Zqv)g*ZXz4z7{!Nj2YovgyFDy<5Gin9eYtE6y*3(ywS%#ypC=S7Cu=8f zQ4XQiYBtrw)SOR7KL9@WlH1PYAgV)I@nopn`Lda8uc2vT(i#)eQ5!6@8oi zDICUU#x0`O0#18mDX*U(ajBYp2|Walr0(UHO{mMf)RFd6c!e6^1h^LSjQt>H+b&@{ zgptm{=raG1S1Ad+RWeNSm0cg7lj7{>yyV^OpJJ@`-9}g8Ro`}Hb$lNpT8iT`N z=MUk>@Qt)-Xp;>(7bfK3>jPu>T&hFUENEA^DIK_3Oh0~pb^@B8Gev1i%tiP3qWpb$ zo}Ixq%)A%a4P)SKJ>b2R)D;Ymx!t6sumHUkPKZ(ZYx22x6ILZlo211v!hGrzTFF*K zW%R#s!-8uatM!s#sDDRN6