diff --git a/.pyproject/cmdclass.py b/.pyproject/cmdclass.py index 16855a6b3..1016f77f2 100644 --- a/.pyproject/cmdclass.py +++ b/.pyproject/cmdclass.py @@ -148,6 +148,7 @@ def initialize_options(self): self.no_opencv = None self.cc_debug = None self.cuda_archs = None + self.ort_pkg_dir = None def _parse_options(self, options): for segment in options.split(','): @@ -189,7 +190,8 @@ def build_cmake(self, extension): ext_fullpath = pathlib.Path( self.get_ext_fullpath(extension.name)).absolute() - config = 'RelWithDebInfo' if self.debug else 'Release' +# config = 'RelWithDebInfo' if self.debug else 'Release' + config = 'Debug' if self.debug else 'Release' cmake_args = [ '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(ext_fullpath.parent.absolute()), @@ -199,6 +201,9 @@ def build_cmake(self, extension): '-DCMAKE_BUILD_TYPE=' + config ] + if self.ort_pkg_dir: + cmake_args += ['-DONNXRUNTIME_PKG_DIR=' + self.ort_pkg_dir] + if self.no_opencv: # Disabling openCV can drastically reduce the build time. cmake_args += [ diff --git a/cmake/ext_cuda.cmake b/cmake/ext_cuda.cmake index aa7d3282c..4be088896 100644 --- a/cmake/ext_cuda.cmake +++ b/cmake/ext_cuda.cmake @@ -30,8 +30,8 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no add_compile_definitions(USE_CUDA) -set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use -set(OCOS_USE_FLASH_ATTENTION OFF) +#set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use +#set(OCOS_USE_FLASH_ATTENTION OFF) if (OCOS_USE_FLASH_ATTENTION) message(STATUS "Enable flash attention") add_compile_definitions(OCOS_USE_FLASH_ATTENTION) diff --git a/docs/How_to_write_custom_op.md b/docs/How_to_write_custom_op.md new file mode 100644 index 000000000..d04e04d16 --- /dev/null +++ b/docs/How_to_write_custom_op.md @@ -0,0 +1,51 @@ +# How to write custom ops + +Custom Ops are based on ONNXRuntime-extensions API, especially **OrtLiteCustomOp** and **Tensor** class. C++ template metaprogramming is heavily used under the hood to provide big flexibility to the Custom Op authors on the parameter's count, type and order. + +## Basic scenario + +You have 2 ways to write a custom op: by writing a function, or by writing a structure. + +### Custom op in the form of function + +If your kernel is simple, you can use this option by just providing a function to compute the customized kernel. That function can have arbitrary number of inputs and outputs. For the inputs that are mandatory, their type would be like: + +```C++ +const Ort::Custom::Tensor& +// or +const Ort::Custom::Tensor* +``` + +For the inputs that are optional, their type would be like: + +```C++ +std::optional*> +``` + +The function can also accept the pointer of **CUDAKernelContext**, where you can retrieve CUDA stream and other CUDA resources, if it requires to be run in CUDA GPU. + +The function will return the type **OrtStatusPtr** + +Please refer to [negpos_def.h](https://github.com/microsoft/onnxruntime-extensions/blob/main/operators/math/cuda/negpos_def.h) as an example and [tensor_tuple.inc](https://github.com/microsoft/onnxruntime-extensions/blob/main/include/custom_op/tensor_tuple.inc) for more possible parameter types. + +### Custom op in the form of structure + +If the kernel is complicated and there are extra properties of the custom op, you can use this option by providing a C++ structure where you can put these properties as the structure's member variables. Besides that, you also need to provide the following member functions: + +```C++ +OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) // This function initialize the properties of the custom op + +OrtStatusPtr Compute(...) const // This function computes the customized kernel. +``` + +The specification of the parameters of the Compute function is the same as the first way (custom op in the form of function) + +## Advanced scenario + +In some cases you need more control on the parameters, in this case you have to use the structure form, which you need to provide the implementations of the following member functions such as: + +```C++ +// By default the function will return OrtMemType::OrtMemTypeDefault for all the inputs, +// you can provide your own implementation to specify the ith input is in CPU or GPU. +static OrtMemType GetInputMemoryType(size_t input_index) +``` \ No newline at end of file diff --git a/include/custom_op/custom_op_lite.h b/include/custom_op/custom_op_lite.h index d6a47af84..77951b093 100644 --- a/include/custom_op/custom_op_lite.h +++ b/include/custom_op/custom_op_lite.h @@ -454,7 +454,7 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext { public: static const int cuda_resource_ver = 1; - OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) { + OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api), kernel_context_(ctx) { api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_); if (!cuda_stream_) { ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION); @@ -521,9 +521,16 @@ class OrtGraphCudaKernelContext : public CUDAKernelContext { int GetCudaDeviceId() const override { return device_id_; } + + void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* mem_info, size_t count_or_bytes) override { + void* ret = nullptr; + api_.KernelContext_GetScratchBuffer(&kernel_context_, mem_info, count_or_bytes, &ret); + return ret; + } private: const OrtApi& api_; + const OrtKernelContext& kernel_context_; OrtAllocator* cpu_allocator_; OrtAllocator* cuda_allocator_; void* cuda_stream_ = {}; diff --git a/include/custom_op/kernel_context.h b/include/custom_op/kernel_context.h index 039cf3bb7..036a300e0 100644 --- a/include/custom_op/kernel_context.h +++ b/include/custom_op/kernel_context.h @@ -2,6 +2,7 @@ #include #include #include +#include "onnxruntime_c_api.h" namespace Ort { namespace Custom { @@ -26,6 +27,7 @@ class CUDAKernelContext : public KernelContext { virtual void* GetCudaStream() const = 0; virtual void* GetCublasHandle() const = 0; virtual int GetCudaDeviceId() const = 0; + virtual void* GetScratchBufferUnderMultiStream(const OrtMemoryInfo* , size_t ) { return nullptr; } }; #endif diff --git a/include/ort_c_to_cpp.h b/include/ort_c_to_cpp.h index 92c2fb01d..152aa6633 100644 --- a/include/ort_c_to_cpp.h +++ b/include/ort_c_to_cpp.h @@ -81,6 +81,9 @@ class API { return instance()->KernelContext_GetAllocator(context, mem_info, out); } #endif + static void ReleaseMemoryInfo(OrtMemoryInfo* mem_info) { + return instance()->ReleaseMemoryInfo(mem_info); + } private: const OrtApi* operator->() const { return &api_; diff --git a/operators/cuda/attention_lib/flash_attention/flash.h b/operators/cuda/attention_lib/flash_attention/flash.h index 603a6e068..5f5be4078 100644 --- a/operators/cuda/attention_lib/flash_attention/flash.h +++ b/operators/cuda/attention_lib/flash_attention/flash.h @@ -87,6 +87,13 @@ struct Flash_fwd_params : public Qkv_params { // The indices to index into the KV cache. int* __restrict__ cache_batch_idx = nullptr; + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + float rp_dropout; + // Local window size int window_size_left = -1; int window_size_right = -1; @@ -102,6 +109,9 @@ struct Flash_fwd_params : public Qkv_params { int num_splits = 0; // For split-KV version + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + const cudaDeviceProp* dprops = nullptr; }; diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.cc b/operators/cuda/attention_lib/flash_attention/flash_api.cc index 46812b560..586a7a471 100644 --- a/operators/cuda/attention_lib/flash_attention/flash_api.cc +++ b/operators/cuda/attention_lib/flash_attention/flash_api.cc @@ -32,7 +32,9 @@ void set_params_fprop(Flash_fwd_params& params, bool is_bf16, bool kv_bsnh = true, int window_size_left = -1, - int window_size_right = -1) { + int window_size_right = -1, + bool paged_KV = false, + int page_block_size = -1) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; @@ -64,8 +66,8 @@ void set_params_fprop(Flash_fwd_params& params, if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) - params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) - params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.k_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0) + params.v_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0) params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) } else { params.q_batch_stride = 0; @@ -99,6 +101,10 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; + params.rp_dropout = 1.f; + params.alibi_slopes_ptr = nullptr; + params.alibi_slopes_batch_stride = 0; + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates // local and causal, meaning when we have local window size params.is_causal = is_causal; @@ -349,8 +355,8 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size - void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size @@ -374,7 +380,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded int local_window_size, bool is_rotary_interleaved, - bool is_packed_qkv) { + bool is_packed_qkv, + int32_t* block_table, // batch_size x max_num_blocks_per_seq + int32_t max_num_blocks_per_seq, + int32_t page_block_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -398,7 +407,9 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, is_bf16, past_bsnh, local_window_size, - is_causal ? 0 : -1); + is_causal ? 0 : -1, + block_table != nullptr, + page_block_size); params.dprops = &dprops; if (k_new != nullptr && v_new != nullptr) { @@ -454,6 +465,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, params.oaccum_ptr = nullptr; } + params.block_table = block_table; + params.block_table_batch_stride = max_num_blocks_per_seq; + params.page_block_size = page_block_size; + // Only split kernel supports appending to KV cache run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); diff --git a/operators/cuda/attention_lib/flash_attention/flash_api.h b/operators/cuda/attention_lib/flash_attention/flash_api.h index 4ad1b76e1..07640d4c8 100644 --- a/operators/cuda/attention_lib/flash_attention/flash_api.h +++ b/operators/cuda/attention_lib/flash_attention/flash_api.h @@ -53,8 +53,8 @@ OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops, OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size - void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table void* k, // batch_size x seqlen_k_new x num_heads_k x head_size void* v, // batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size @@ -78,7 +78,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded int local_window_size = -1, bool is_rotary_interleaved = false, - bool is_packed_qkv = false); + bool is_packed_qkv = false, + int32_t* block_table = nullptr, // batch_size x max_num_blocks_per_seq + int32_t max_num_blocks_per_seq = -1, + int32_t page_block_size = 1); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h index c44a470f6..47263d411 100644 --- a/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_kernel.h @@ -28,1027 +28,1006 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, - Tensor2& acc_o, float softmax_scale_log2) { - if (Is_first) { - flash::template reduce_max(scores, scores_max); - flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - flash::reduce_sum(scores, scores_sum); - } else { - cute::Tensor scores_max_prev = make_fragment_like(scores_max); - cute::copy(scores_max, scores_max_prev); - flash::template reduce_max(scores, scores_max); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); -#pragma unroll - for (int mi = 0; mi < cute::size(scores_max); ++mi) { - float scores_max_cur = !Check_inf - ? scores_max(mi) - : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); - float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - scores_sum(mi) *= scores_scale; -#pragma unroll - for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scores_scale; - } - } - flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - cute::Tensor scores_sum_cur = make_fragment_like(scores_sum); - flash::reduce_sum(scores, scores_sum_cur); -#pragma unroll - for (int mi = 0; mi < cute::size(scores_sum); ++mi) { - scores_sum(mi) += scores_sum_cur(mi); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void write_softmax_to_gmem( - cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) { - // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) - cute::Layout l = tOrP.layout(); - cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); - CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{}); - CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); -#pragma unroll - for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { - cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNWarps = Kernel_traits::kNWarps; - constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; - - const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); - int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal || Is_local) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); - // We exit early and write 0 to gO and gLSE. + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // Otherwise we might read OOB elements from gK and gV. - if (n_block_max <= n_block_min) { - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - - typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - Tensor tOrO = make_tensor(shape(tOgO)); - clear(tOrO); - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); -#pragma unroll - for (int m = 0; m < size<1>(tOgO); ++m) { - const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSE(row) = INFINITY; + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } } - } - return; + return; } - } - - // We iterate over the blocks in reverse order. This is because the last block is the only one - // that needs masking when we read K and V from global memory. Moreover, iterating in reverse - // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - cute::Shape, cute::Int>{}, - make_stride(params.q_row_stride, _1{})); - cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - cute::Shape, cute::Int>{}, - make_stride(params.k_row_stride, _1{})); - cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - cute::Shape, cute::Int>{}, - make_stride(params.v_row_stride, _1{})); - cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), - cute::Shape, cute::Int>{}, - make_stride(params.seqlen_k_rounded, _1{})); - - cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); - // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; - cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : cute::size(sQ)), - typename Kernel_traits::SmemLayoutKV{}); - cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{}); - cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; - auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); - - cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP); - - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - cute::Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - - cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // MMA, MMA_M, MMA_K - - // - // Copy Atom retiling - // - - auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - - auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); - cute::Tensor tSsK = smem_thr_copy_K.partition_S(sK); - - auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); - auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); - cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - - // TODO: this might need to change if we change the mma instruction in SM70 - cute::Tensor scores_max = make_tensor(cute::Shape(acc_o)>>{}); - cute::Tensor scores_sum = make_fragment_like(scores_max); - - // - // PREDICATES - // - - // Construct identity layout for sQ and sK - cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{})); + Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{})); + Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor tSgS = thr_mma.partition_C(gP); + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } - // Repeat the partitioning with identity layouts - cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - // Allocate predicate tensors for k - cute::Tensor tQpQ = make_tensor(make_shape(cute::size<2>(tQsQ))); - cute::Tensor tKVpKV = make_tensor(make_shape(cute::size<2>(tKsK))); + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - // Set predicates for k bounds - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < cute::size(tQpQ); ++k) { - tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; - } -#pragma unroll - for (int k = 0; k < cute::size(tKVpKV); ++k) { - tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } - } - - // Prologue - - cute::Tensor tQrQ = make_fragment_like(tQgQ); - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); - if (Kernel_traits::Is_Q_in_regs) { - cute::cp_async_fence(); - } - - if (Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<0>(); - __syncthreads(); - cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M - cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); - __syncthreads(); - } - - int n_block = n_block_max - 1; - // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); - cute::cp_async_fence(); - - if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { - flash::cp_async_wait<1>(); - __syncthreads(); - cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); - CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M - cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); - } - clear(acc_o); + // Prologue - // For performance reason, we separate out two kinds of iterations: - // those that need masking on S, and those that don't. - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = (!Is_causal && !Is_local) - ? 1 - : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); -#pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { - cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - - // Advance gV - if (masking_step > 0) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - } else { - // Clear the smem tiles to account for predicated off loads - flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); } - cute::cp_async_fence(); - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); - // if (cute::thread0()) { print(acc_s); } - - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - - // We don't put the masking before the matmul S = Q K^T because we don't clear sK - // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul - // can produce Inf / NaN. - if (!Is_causal && !Is_local) { - if (!Is_even_MN) { - flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); - } - } else { - // I can't get the stride from idx_row - flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); - flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); } - // TODO: when we have key_padding_mask we'll need to Check_inf - masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - - // Convert scores from fp32 to fp16/bf16 - cute::Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - // if (Return_softmax) { - // cute::Tensor tOrP_copy = make_fragment_like(tOrP); - // copy(tOrP, tOrP_copy); - // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); - // tPgP.data() = tPgP.data() + (-kBlockN); - // } - - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } - // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= n_block_min) { - --n_block; - break; + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); +// if (Return_softmax) { +// Tensor rP_drop = make_fragment_like(rP); +// cute::copy(rP, rP_drop); +// dropout.template apply_dropout( +// rP_drop, block_row_idx, block_col_idx, kNWarps +// ); +// cute::copy(rP_drop, tSgS); +// tSgS.data() = tSgS.data() + (-kBlockN); +// } +// if (Is_dropout) { +// dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); +// } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } } - } - - // These are the iterations where we don't need masking on S - for (; n_block >= n_block_min; --n_block) { - cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - // Advance gV - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - cute::cp_async_fence(); - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } - flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); +// if (Return_softmax) { +// Tensor rP_drop = make_fragment_like(rP); +// cute::copy(rP, rP_drop); +// dropout.template apply_dropout( +// rP_drop, block_row_idx, block_col_idx, kNWarps +// ); +// cute::copy(rP_drop, tSgS); +// tSgS.data() = tSgS.data() + (-kBlockN); +// } +// if (Is_dropout) { +// dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); +// } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { - flash::apply_mask_local( - scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - - cute::Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - // if (Return_softmax) { - // cute::Tensor tOrP_copy = make_fragment_like(tOrP); - // copy(tOrP, tOrP_copy); - // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); - // tPgP.data() = tPgP.data() + (-kBlockN); - // } + // Epilogue - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - } + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); - // Epilogue + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - cute::Tensor lse = make_fragment_like(scores_sum); -#pragma unroll - for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) { - float sum = scores_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); - float scale = inv_sum; -#pragma unroll - for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scale; - } - } + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } - // Convert acc_o from fp32 to fp16/bf16 - cute::Tensor rO = flash::convert_type(acc_o); - cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); - cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // sO has the same size as sQ, so we don't need to sync here. - if (Kernel_traits::Share_Q_K_smem) { - __syncthreads(); - } + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - cute::Shape, cute::Int>{}, - make_stride(params.o_row_stride, _1{})); - cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - cute::Shape>{}, cute::Stride<_1>{}); - - typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) - cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - - __syncthreads(); + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - cute::Tensor tOrO = make_tensor(cute::shape(tOgO)); - cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + __syncthreads(); - cute::Tensor caccO = make_identity_tensor(cute::Shape, cute::Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - cute::Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - static_assert(decltype(cute::size<0>(taccOcO))::value == 4); - // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. - cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(0)) == 0) { -#pragma unroll - for (int mi = 0; mi < cute::size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM) { - gLSE(row) = lse(mi); - } + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } } - } - // Construct identity layout for sO - cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - cute::Tensor tOpO = make_tensor(make_shape(cute::size<2>(tOgO))); - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < cute::size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNWarps = Kernel_traits::kNWarps; - - using GmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::GmemTiledCopyOaccum, - typename Kernel_traits::GmemTiledCopyO>; - using ElementO = std::conditional_t; - - const BlockInfo binfo(params, bidb); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } - // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; - const int n_block_min = !Is_local - ? n_split_idx * n_blocks_per_split - : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); - int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); - if (Is_causal || Is_local) { - n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); - } - if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 - // We exit early and write 0 to gOaccum and -inf to gLSEaccum. - // Otherwise we might read OOB elements from gK and gV, - // or get wrong results when we combine gOaccum from different blocks. - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; - const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), - Shape>{}, Stride<_1>{}); - - GmemTiledCopyO gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - clear(tOrOaccum); - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; - } + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyO, + typename Kernel_traits::GmemTiledCopyOaccum + >; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); -#pragma unroll - for (int m = 0; m < size<1>(tOgOaccum); ++m) { - const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { - gLSEaccum(row) = Split ? -INFINITY : INFINITY; - } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; } + } + return; } - return; - } - // We iterate over the blocks in reverse order. This is because the last block is the only one - // that needs masking when we read K and V from global memory. Moreover, iterating in reverse - // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - // We move K and V to the last block. - const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } - Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, - make_stride(params.v_row_stride, _1{})); - - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQ{}); - Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); - Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - - typename Kernel_traits::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(tidx); - Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) - - Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K - - // - // Copy Atom retiling - // - - auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); - - auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); - auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); - Tensor tSsK = smem_thr_copy_K.partition_S(sK); - - auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); - auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); - Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - - // TODO: this might need to change if we change the mma instruction in SM70 - Tensor scores_max = make_tensor(Shape(acc_o)>>{}); - Tensor scores_sum = make_fragment_like(scores_max); - - // - // PREDICATES - // - - // // Allocate predicate tensors for m and n - // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); - // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); - - // Construct identity layout for sQ and sK - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = block_table == nullptr + ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - // Allocate predicate tensors for k - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - // Set predicates for k bounds - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { - tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; - } -#pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { - tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } - } - // Prologue - // Copy from Knew to K, optionally apply rotary embedding. - typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; - auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; - auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); - if constexpr (Append_KV) { - // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to - // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. - // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. - const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); - Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(params.rotary_dim / 2, _1{})); - Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(params.rotary_dim / 2, _1{})); - Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, + // Prologue + + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); - Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); - Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); - Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); - Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); - Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); - // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } - // if (cute::thread(8, 0)) { print_tensor(gCos); } - // if (cute::thread(0, 0)) { print_tensor(tRgCos); } - - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; - // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, - // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. - // This maps to accessing the first 64 rows of knew_ptr. - Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), - Shape, Int>{}, - make_stride(params.knew_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } - Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), - Shape, Int>{}, - make_stride(params.vnew_row_stride, _1{})); - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) - - const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); - for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { - flash::copy_w_min_idx( - tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - if (params.rotary_dim == 0) { - flash::copy_w_min_idx( - tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } else { - if (params.is_rotary_interleaved) { - // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_interleaved( - tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, - binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); - tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); - tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); - } else { - // Don't clear OOB_K because we're writing to global memory - flash::copy_rotary_contiguous( - tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, - binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); - tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); - tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + + } + } + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + if (n_block > n_block_copy_min) { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; + const int offset_diff = block_table_offset_next - block_table_offset_cur; + tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; + tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + } + } } - } - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; } - // Need this before we can read in K again, so that we'll see the updated K values. - __syncthreads(); - if (n_block_max > n_block_copy_min) { - tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; - tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; - } - } - // Read Q from gmem to smem, optionally apply rotary embedding. - Tensor tQrQ = make_fragment_like(tQgQ); - if (!Append_KV || params.rotary_dim == 0) { - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); - } else { - const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); - // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. - // We do this by setting the row stride of gCos / gSin to 0. - Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), - Shape, Int>{}, - make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + // Read Q from gmem to smem, optionally apply rotary embedding. + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), Shape, Int>{}, make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), Shape, Int>{}, make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); - Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); - Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); - Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); - Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); - if (params.is_rotary_interleaved) { - flash::copy_rotary_interleaved( - tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, - 0, params.d, params.rotary_dim); - } else { - flash::copy_rotary_contiguous( - tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, - 0, params.d, params.rotary_dim); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } } - } - - int n_block = n_block_max - 1; - // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); - cute::cp_async_fence(); - - // flash::cp_async_wait<0>(); - // __syncthreads(); - // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } - // __syncthreads(); - - clear(acc_o); - - // For performance reason, we separate out two kinds of iterations: - // those that need masking on S, and those that don't. - // We need masking on S for the very last block when K and V has length not multiple of kBlockN. - // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. - // We will have at least 1 "masking" iteration. - - // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to - // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = (!Is_causal && !Is_local) - ? 1 - : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); -#pragma unroll - for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - // Advance gV - if (masking_step > 0) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - } else { - // Clear the smem tiles to account for predicated off loads - flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); - } + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); - // if (cute::thread0()) { print(acc_s); } - - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print(scores); } - // We don't put the masking before the matmul S = Q K^T because we don't clear sK - // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul - // can produce Inf / NaN. - if (!Is_causal && !Is_local) { - if (!Is_even_MN) { - flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); - } - } else { - flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); - } - - flash::cp_async_wait<0>(); - __syncthreads(); - // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } // __syncthreads(); - if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } - // We have key_padding_mask so we'll need to Check_inf - masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } - - // Convert scores from fp32 to fp16/bf16 - Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - // if (cute::thread0()) { print(scores); } - - // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= n_block_min) { - --n_block; - break; - } - } + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } - // These are the iterations where we don't need masking on S - for (; n_block >= n_block_min; --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); - flash::cp_async_wait<0>(); - __syncthreads(); - // Advance gV - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - cute::cp_async_fence(); - - flash::gemm( - acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - flash::cp_async_wait<0>(); - __syncthreads(); - if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { - flash::apply_mask_local( - scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right); + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } } - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } - flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - } + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - // Epilogue + Tensor rP = flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); - // if (cute::thread0()) { print(acc_o_rowcol); } - Tensor lse = make_fragment_like(scores_sum); -#pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = scores_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); - float scale = inv_sum; -#pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scale; + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } - } - // if (cute::thread0()) { print(lse); } - // if (cute::thread0()) { print(acc_o_rowcol); } - - Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) - // Partition sO to match the accumulator partitioning - using SmemTiledCopyO = std::conditional_t< - !Split, - typename Kernel_traits::SmemCopyAtomO, - typename Kernel_traits::SmemCopyAtomOaccum>; - auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); - auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor rO = flash::convert_type(acc_o); - Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - // sOaccum is larger than sQ, so we need to syncthreads here - // TODO: allocate enough smem for sOaccum - if constexpr (Split) { - __syncthreads(); - } - cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); - - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; - const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), - Shape, Int>{}, - make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), - Shape>{}, Stride<_1>{}); - // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + // if (cute::thread0()) { print(lse); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { __syncthreads(); } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - GmemTiledCopyO gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } - __syncthreads(); + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + __syncthreads(); - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) - static_assert(decltype(size<0>(taccOcO))::value == 4); - // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. - Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - if (get<1>(taccOcO_row(0)) == 0) { -#pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccOcO_row(mi)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM) { - gLSEaccum(row) = lse(mi); - } + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } } - } - // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { -#pragma unroll - for (int k = 0; k < size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); - // __syncthreads(); - // if (cute::thread0()) { print(tOgOaccum); } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1064,12 +1043,12 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1078,7 +1057,7 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h index e2f2505a7..750305fd4 100644 --- a/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h +++ b/operators/cuda/attention_lib/flash_attention/flash_fwd_launch_template.h @@ -9,20 +9,20 @@ namespace flash { -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn(params); + flash::compute_attn(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); #else (void)params; #endif @@ -38,7 +38,7 @@ __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { #endif } -template +template void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { constexpr size_t smem_size = Kernel_traits::kSmemSize; @@ -53,23 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); + BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); }); }); }); @@ -90,16 +92,18 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - kernel<<>>(params); + BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); }); }); }); @@ -143,7 +147,7 @@ template void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { constexpr static int Headdim = 32; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); }); } @@ -154,7 +158,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); }); @@ -168,12 +172,12 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), if (is_sm8x) { if constexpr (!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); @@ -192,12 +196,12 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. if (is_sm8x) { if constexpr (!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); @@ -220,12 +224,12 @@ void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { // and 128 x 64 with 8 warps is the fastest for non-causal. if (is_sm8x) { if constexpr (!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); @@ -241,7 +245,7 @@ template void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { constexpr int Headdim = 192; BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd>(params, stream); @@ -257,9 +261,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // run_flash_fwd, Is_causal>(params, stream); // run_flash_fwd, Is_causal>(params, stream); @@ -280,9 +284,9 @@ void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, false /*Is_dropout*/, Is_causal>(params, stream); } // 64 KB // run_flash_fwd, Is_causal>(params, stream); diff --git a/operators/cuda/attention_lib/flash_attention/softmax.h b/operators/cuda/attention_lib/flash_attention/softmax.h index 9c31336c9..a70406aed 100644 --- a/operators/cuda/attention_lib/flash_attention/softmax.h +++ b/operators/cuda/attention_lib/flash_attention/softmax.h @@ -54,10 +54,10 @@ __device__ inline void reduce_max(Tensor const& tensor, Tensor reduce_(tensor, max, max_op); } -template -__device__ inline void reduce_sum(Tensor const& tensor, Tensor& sum) { - SumOp sum_op; - reduce_(tensor, sum, sum_op); +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); } // Apply the exp to all the elements. @@ -212,4 +212,168 @@ inline __device__ void apply_mask_causal_w_idx( } } +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +template +struct Mask { + + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const float alibi_slope; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope=0.f) + : max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) + , window_size_left(window_size_left) + , window_size_right(window_size_right) + , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor &tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? + static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } + } + } + } + } + } else { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; + +}; + } // namespace flash diff --git a/operators/cuda/attention_lib/flash_attention/utils.h b/operators/cuda/attention_lib/flash_attention/utils.h index cd10bd534..f638a232a 100644 --- a/operators/cuda/attention_lib/flash_attention/utils.h +++ b/operators/cuda/attention_lib/flash_attention/utils.h @@ -198,6 +198,28 @@ inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) template inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { @@ -212,6 +234,25 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { //////////////////////////////////////////////////////////////////////////////////////////////////// +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. template diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index 39cc02f85..8770bb42a 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -5,6 +5,9 @@ #ifdef USE_CUDA #include "cuda/fast_gelu.h" +#if ORT_API_VERSION >= 18 +#include "cuda/paged_attention.h" +#endif #endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { @@ -13,8 +16,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { #ifdef USE_CUDA , CustomCudaStructV2("FastGelu", contrib::FastGelu), +#if ORT_API_VERSION >= 18 + CustomCudaStructV2("PagedAttention", PagedAttention), +#endif #if ORT_API_VERSION >= 16 - CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("FastGelu", contrib::FastGelu) #endif diff --git a/operators/cuda/paged_attention.h b/operators/cuda/paged_attention.h new file mode 100644 index 000000000..1dee204d1 --- /dev/null +++ b/operators/cuda/paged_attention.h @@ -0,0 +1,216 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "ocos.h" +#include "cuda_type.h" +#include "paged_attention_impl.h" +#include "device_prop.cuh" +#ifdef OCOS_USE_FLASH_ATTENTION +#include "attention_lib/flash_attention/flash_api.h" +#endif + +template +using UniquePtrWithDeletor = std::unique_ptr>; + +template +inline UniquePtrWithDeletor GetScratchBuffer(void* p, OrtAllocator* allocator) { + return UniquePtrWithDeletor{static_cast(p), [allocator = std::move(allocator)](T* p) { + allocator->Free(allocator, p); + }}; +} + +template +OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor& query, const ortc::Tensor& context_lens, + int32_t num_heads, int32_t num_kv_heads, int32_t head_size, float scale, bool prompt_mode, PackedAttentionParameters& parameters) { + const std::vector& query_shape = query.Shape(); + if (query_shape.size() < 2 || query_shape.size() > 3) { + return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT); + } + if (query_shape.back() != num_heads * head_size) { + return OrtW::CreateStatus(MakeString("Hidden size should equal to num_heads_ * head_size_"), ORT_INVALID_ARGUMENT); + } + + parameters.batch_size = context_lens.NumberOfElement(); + parameters.sequence_length = 1; + parameters.token_count = 0; + parameters.valid_token_count = query_shape[0]; + parameters.causal = true; + parameters.head_size = head_size; + parameters.num_heads = num_heads; + parameters.num_kv_heads = num_kv_heads; + parameters.scale = scale; + parameters.hidden_size = static_cast(head_size * num_heads); + parameters.v_hidden_size = static_cast(head_size * num_kv_heads); + parameters.v_head_size = static_cast(parameters.head_size); + return nullptr; +} + +template +struct PagedAttention { + static OrtMemType GetInputMemoryType(size_t input_index) { + if (input_index == 7 || input_index == 8) return OrtMemType::OrtMemTypeCPUInput; // make context_lens and is_prompt CPU input + return OrtMemType::OrtMemTypeDefault; + } + + using TT = typename contrib::CudaT::MappedType; + OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { + int64_t num_heads = 0, head_size = 0; + ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_int64(&info, "num_heads", &num_heads)); + assert(num_heads > 0); + num_heads_ = static_cast(num_heads); + num_kv_heads_ = static_cast(OrtW::GetOpAttributeOrDefault(info, "num_kv_heads", num_heads)); + + ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_int64(&info, "head_size", &head_size)); + assert(head_size > 0); + head_size_ = static_cast(head_size); + + ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_float(&info, "scale", &scale_)); + assert(scale_ >= 0); + + num_queries_per_kv_ = num_heads_ / num_kv_heads_; + OrtAllocator* allocator = nullptr; + ORTX_RETURN_IF_ERROR(api.KernelInfoGetAllocator(&info, OrtMemType::OrtMemTypeDefault, &allocator)); + allocator_ = UniquePtrWithDeletor{allocator, [&api](OrtAllocator* p){api.ReleaseAllocator(p);}}; + return nullptr; + } + + OrtStatusPtr RunMultiHeadAttention(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& query, const ortc::Tensor& key, const ortc::Tensor& value, + T* output, PackedAttentionParameters& parameters, const int32_t* seqinfo) const { + PackedMultiHeadAttentionData data; + data.use_flash_attention = false; + data.use_memory_efficient_attention = false; +#if OCOS_USE_FLASH_ATTENTION + data.use_flash_attention = true; +#endif +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + data.use_memory_efficient_attention = true; +#endif + data.query = reinterpret_cast(query.DataRaw()); + data.key = reinterpret_cast(key.DataRaw()); + data.value = reinterpret_cast(value.DataRaw()); + + // TODO(leca): +// // broadcast key,value for GQA +// TensorShape key_shape({parameters.valid_token_count, parameters.num_kv_heads, parameters.head_size}); +// size_t kv_repeat_space = key_shape.Size() * (num_queries_per_kv_ > 0 ? num_queries_per_kv_ : 0); +// IAllocatorUniquePtr key_out = GetScratchBuffer(kv_repeat_space, context->GetComputeStream()); +// IAllocatorUniquePtr value_out = GetScratchBuffer(kv_repeat_space, context->GetComputeStream()); +// if (num_queries_per_kv_ > 1 && !ParseEnvironmentVariableWithDefault("repeat_kv_tile", false)) { +// // repeat key and value +// LaunchRepeatKeyValue(Stream(context), key_out.get(), value_out.get(), +// data.key, data.value, key_shape.GetDims().data(), num_queries_per_kv_); +// CHECK_CUDA_ERROR(); +// data.key = key_out.get(); +// data.value = value_out.get(); +// parameters.num_kv_heads = parameters.num_heads; +// DumpTensor(Stream(context), data.key, "repeat_key", kv_repeat_space * sizeof(CudaT)); +// } + + size_t workSpaceSize = cuda::GetAttentionWorkspaceSize(sizeof(T), parameters.batch_size, parameters.num_heads, parameters.head_size, parameters.v_head_size, + parameters.sequence_length, nullptr, data.use_flash_attention, data.use_memory_efficient_attention, true); + UniquePtrWithDeletor workspace_unique = GetScratchBuffer(allocator_->Alloc(allocator_.get(), workSpaceSize), allocator_.get()); + data.workspace = reinterpret_cast(workspace_unique.get()); + data.cumulative_sequence_length = seqinfo; + data.output = reinterpret_cast(output); + data.fused_runner = nullptr; + data.no_qkv_workspace = data.fused_runner == nullptr || data.use_flash_attention || data.use_memory_efficient_attention; + data.source_qkv_format = data.key == nullptr ? AttentionQkvFormat::QKV_TN3H : AttentionQkvFormat::Q_K_V_TNH; + return cuda::QkvToContext(reinterpret_cast(ctx->GetCudaStream()), parameters, data); + } + + OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor& query, const ortc::Tensor& key, + const ortc::Tensor& value, const ortc::Tensor& key_cache, const ortc::Tensor& value_cache, + const ortc::Tensor& block_tables, const ortc::Tensor& slot_mappings, + const ortc::Tensor& context_lens, const ortc::Tensor& is_prompt, + std::optional*> cos_sin_cache, + std::optional*> positions, ortc::Tensor& attn_out) const { + bool prompt_mode = *(is_prompt.Data()) == 1; + PackedAttentionParameters parameters; + ORTX_RETURN_IF_ERROR(CheckInputs(reinterpret_cast(ctx->GetCudaStream()), allocator_.get(), query, + context_lens, num_heads_, num_kv_heads_, head_size_, scale_, prompt_mode, parameters)); + + UniquePtrWithDeletor seqinfo; + UniquePtrWithDeletor position_ids; + if (prompt_mode) { + parameters.token_count = parameters.valid_token_count; + + std::vector seqstart(context_lens.NumberOfElement() + 1, 0); + for (int64_t i = 0; i < context_lens.NumberOfElement(); i++) { + int32_t seqlen_i = *(context_lens.Data()+i); + if (seqlen_i > parameters.sequence_length) parameters.sequence_length = seqlen_i; + seqstart[i+1] = seqstart[i] + seqlen_i; + } + seqinfo = GetScratchBuffer(allocator_.get()->Alloc(allocator_.get(), seqstart.size() * sizeof(int32_t)), allocator_.get()); + cudaMemcpy(seqinfo.get(), seqstart.data(), seqstart.size() * sizeof(int32_t), cudaMemcpyHostToDevice); + } else { + seqinfo = GetScratchBuffer(allocator_.get()->Alloc(allocator_.get(), context_lens.SizeInBytes()), allocator_.get()); + cudaMemcpy(seqinfo.get(), context_lens.DataRaw(), context_lens.SizeInBytes(), cudaMemcpyHostToDevice); + } + + if (cos_sin_cache.has_value() && !positions.has_value()) { + std::vector position_ids_host; + if (prompt_mode) { + for (int64_t i = 0; i < context_lens.NumberOfElement(); i++) { + int32_t seqlen_i = *(context_lens.Data()+i); + if (seqlen_i == 0) continue; + std::vector position_id(seqlen_i); + std::iota(position_id.begin(), position_id.end(), 0); // fill position_id with [0, 1, 2, ...seqlen_i) + position_ids_host.insert(position_ids_host.end(), position_id.begin(), position_id.end()); + } + } else position_ids_host.assign(parameters.batch_size, 0); // TODO(leca): Does decoding case support seqlen_knew > 1? + + position_ids = GetScratchBuffer(allocator_.get()->Alloc(allocator_.get(), position_ids_host.size() * sizeof(int32_t)), allocator_.get()); + cudaMemcpy(position_ids.get(), position_ids_host.data(), position_ids_host.size() * sizeof(int32_t), cudaMemcpyHostToDevice); + } + + const std::vector& query_shape = query.Shape(); + T* output_data = attn_out.Allocate(query_shape); + + if (cos_sin_cache.has_value()) { + int64_t rot_dim = (*cos_sin_cache)->Shape()[1]; + assert(rot_dim == head_size_); + cuda::rotary_embedding_neox(reinterpret_cast(ctx->GetCudaStream()), positions.has_value() ? (*positions)->Data() : position_ids.get(), + const_cast(query.DataRaw()), const_cast(key.DataRaw()), head_size_, (*cos_sin_cache)->DataRaw(), parameters.valid_token_count, rot_dim, num_heads_, num_kv_heads_); + } + + const std::vector& key_cache_shape = key_cache.Shape(); + int block_size = key_cache_shape[1] / (num_kv_heads_ * head_size_); + if (parameters.valid_token_count > 0) { + int32_t key_shape_r[3] = {parameters.valid_token_count, num_kv_heads_, head_size_}; + int32_t value_shape_r[3] = {parameters.valid_token_count, num_kv_heads_, head_size_}; + // TODO(leca): or we just pass num_valid_tokens, num_kv_head, head_size and block_size as parameter? + cuda::reshape_and_cache(reinterpret_cast(ctx->GetCudaStream()), key.DataRaw(), value.DataRaw(), key_cache.DataRaw(), value_cache.DataRaw(), slot_mappings.Data(), + key_shape_r, value_shape_r, block_size); + } + + if (prompt_mode) { + return RunMultiHeadAttention(ctx, query, key, value, output_data, parameters, seqinfo.get()); // Don't handle prompt with decoding case for now + } + +#ifdef OCOS_USE_FLASH_ATTENTION + int seqlen_knew = 1; // TODO(leca): Decoding case, the sequence of k will always be 1? + int max_num_blocks_per_seq = block_tables.Shape()[1]; + int seqlen_k = max_num_blocks_per_seq * block_size; + parameters.causal = false; // flash code: if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + size_t workSpaceSize = cuda::GetAttentionWorkspaceSize(sizeof(T), parameters.batch_size, parameters.num_heads, parameters.head_size, parameters.v_head_size, + seqlen_knew, nullptr, true/*data.use_flash_attention*/, false/*data.use_memory_efficient_attention*/, true); + UniquePtrWithDeletor workspace_unique = GetScratchBuffer(allocator_->Alloc(allocator_.get(), workSpaceSize), allocator_.get()); // for softmax_lse + const cudaDeviceProp& device_prop = DeviceProp::GetCudaDeviceProp(); + return flash::mha_fwd_kvcache(device_prop, reinterpret_cast(ctx->GetCudaStream()), const_cast(query.DataRaw()), const_cast(key_cache.DataRaw()), + const_cast(value_cache.DataRaw()), const_cast(key.DataRaw()), const_cast(value.DataRaw()), output_data, + workspace_unique.get(), seqinfo.get(), + nullptr, nullptr, // rotary_sin and rotary_cos. TODO(leca): Do we still split the input cos_sin_cache as there is a seperate step to do rotary embedding + query_shape[0], num_heads_, num_kv_heads_, head_size_, 1 /*seqlen_q*/, seqlen_k, seqlen_knew, 1.0f/sqrt(head_size_), parameters.causal, false, true, + 1 /*num_splits*/, nullptr, nullptr, -1 /*local_window_size*/, false, false, const_cast(block_tables.Data()), max_num_blocks_per_seq, block_size); +#endif + } + +private: + int32_t num_heads_; // number of attention heads + int32_t num_kv_heads_; // number of attention kv_heads + int32_t head_size_; // number of attention heads + float scale_; // sqrt(head_size_) + int32_t num_queries_per_kv_; + UniquePtrWithDeletor allocator_; // make allocator_ declared first in order to release it last +}; \ No newline at end of file diff --git a/operators/cuda/paged_attention_impl.cu b/operators/cuda/paged_attention_impl.cu new file mode 100644 index 000000000..192382f92 --- /dev/null +++ b/operators/cuda/paged_attention_impl.cu @@ -0,0 +1,354 @@ +#include "paged_attention_impl.h" +#include "utils.cuh" +#include "device_prop.cuh" +#ifdef OCOS_USE_FLASH_ATTENTION +#include "attention_lib/flash_attention/flash_api.h" +#endif +#ifdef OCOS_USE_MEMORY_EFFICIENT_ATTENTION +#include "attention_lib/cutlass_fmha/memory_efficient_attention.h" +#endif +#include +#include + +namespace cuda { + +namespace vllm { + +template +__global__ void reshape_and_cache_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, head_size] + const int* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size) { + const int token_idx = blockIdx.x; + const int slot_idx = slot_mapping[token_idx]; + const int block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int src_key_idx = token_idx * key_stride + i; + const int src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + + const int tgt_value_idx = block_idx * num_heads * head_size * block_size + block_offset * num_heads * head_size + head_idx * head_size + head_offset; + const int tgt_key_idx = tgt_value_idx; + //{ + // if (key_cache[tgt_key_idx] - key[src_key_idx] > half(0.1)) { + // printf("key error find, %d,%d ", tgt_key_idx, src_key_idx); + // } + // if (value_cache[tgt_value_idx] - value[src_value_idx] > half(0.1)) { + // printf("key error find, %d %d", tgt_value_idx, src_value_idx); + // } + //} + key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]); + value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); + } +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) { + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = __ldg(cos_ptr + x_index); + sin = __ldg(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = __ldg(cos_ptr + x_index / 2); + sin = __ldg(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +__global__ void rotary_embedding_kernel( + const int32_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, + const int query_stride, + const int key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int32_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_rotary_embedding(query + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_rotary_embedding(key + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } +} +} // namespace vllm + +void rotary_embedding_neox( + const cudaStream_t stream, + const int32_t* positions, // [num_tokens] + void* query, // [num_tokens, num_heads * head_size] + void* key, // [num_tokens, num_kv_heads * head_size] + int head_size, + const void* cos_sin_cache, // [max_position, rot_dim] + int num_tokens, + int rot_dim, + int num_heads, + int num_kv_heads) { + const bool is_neox = true; + int query_stride = num_heads * head_size; + int key_stride = num_kv_heads * head_size; + // TORCH_CHECK(stride == key.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + + // half + using scalar_t = half; + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions, + static_cast(query), + static_cast(key), + static_cast(cos_sin_cache), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } else { + vllm::rotary_embedding_kernel<<>>( + positions, + static_cast(query), + static_cast(key), + static_cast(cos_sin_cache), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } +} + +void reshape_and_cache( + const cudaStream_t stream, + const void* key, // [num_tokens, num_heads, head_size] + const void* value, // [num_tokens, num_heads, head_size] + const void* key_cache, // [num_blocks, block_size, num_heads, head_size] + const void* value_cache, // [num_blocks, block_size, num_heads, head_size] + const int* slot_mapping, // [num_tokens] + const int32_t* key_shapes, + const int32_t* value_shapes, + const int64_t block_size) { + int num_tokens = key_shapes[0]; + int num_heads = key_shapes[1]; + int head_size = key_shapes[2]; + // int block_size = key_cache.size(3); + + int key_stride = key_shapes[1] * key_shapes[2]; + int value_stride = value_shapes[1] * value_shapes[2]; + + // static_assert(std::is_same_v, "Unsupported data type: "); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + + vllm::reshape_and_cache_kernel<<>>( + (const half*)key, + (const half*)value, + (half*)key_cache, + (half*)value_cache, + slot_mapping, + key_stride, + value_stride, + num_heads, + head_size, + block_size); +} + +#if OCOS_USE_FLASH_ATTENTION +template +OrtStatusPtr FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int num_kv_heads = parameters.num_kv_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // Q, K and V pointers + const int model_dimension_qk = num_heads * qk_head_size; + const int model_dimension_v = num_kv_heads * v_head_size; + const size_t elements_qk = static_cast(parameters.token_count) * static_cast(model_dimension_qk); + const size_t elements_v = static_cast(parameters.token_count) * static_cast(model_dimension_v); + + // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH + // TODO(leca): +// if (!data.no_qkv_workspace) { +// LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, +// batch_size, sequence_length, +// num_heads, qk_head_size, v_head_size, +// data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, +// data.token_offset, parameters.token_count, stream); +// } + + float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; + int32_t* cu_seqlens_q = const_cast(data.cumulative_sequence_length); + int32_t* cu_seqlens_k = const_cast(data.cumulative_sequence_length); + const void* query = data.no_qkv_workspace ? data.query : data.workspace; + const void* key = data.no_qkv_workspace ? data.key : (data.workspace + elements_qk); + const void* value = data.no_qkv_workspace ? data.value : (data.workspace + elements_qk + elements_qk); + void* softmax_lse_buffer = data.no_qkv_workspace + ? data.workspace + : (data.workspace + elements_qk + elements_v + elements_v); + + ORTX_RETURN_IF_ERROR( + flash::mha_varlen_fwd( + device_prop, + stream, + const_cast(query), + const_cast(key), + const_cast(value), + data.output, + cu_seqlens_q, + cu_seqlens_k, + softmax_lse_buffer, + batch_size, + num_heads, + num_kv_heads, // num_heads_k + qk_head_size, + sequence_length, + sequence_length, + scale, + parameters.causal, // is causal + false // is_bf16 TODO(leca) + )); + + return nullptr; +} +#endif + +template +OrtStatusPtr QkvToContext( + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data) { + const cudaDeviceProp& device_prop = DeviceProp::GetCudaDeviceProp(); +#if OCOS_USE_FLASH_ATTENTION + return FlashAttention(device_prop, stream, parameters, data); +#endif +#if OCOS_USE_MEMORY_EFFICIENT_ATTENTION + // TODO(leca): + //return FusedAttentionCutlass(device_prop, stream, parameters, data); +#endif + return nullptr; +} + +//template OrtStatusPtr QkvToContext( +// cudaStream_t stream, +// PackedAttentionParameters& parameters, +// PackedMultiHeadAttentionData& data); + +template OrtStatusPtr QkvToContext( + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data); + +constexpr size_t kCUDAMemoryAlignment = 256; + +size_t GetAttentionScratchSize( + size_t element_size, + size_t batch_size, + size_t num_heads, + size_t sequence_length) { + const size_t bytes = element_size * batch_size * num_heads * sequence_length * sequence_length; + return ((bytes + kCUDAMemoryAlignment - 1) / kCUDAMemoryAlignment) * kCUDAMemoryAlignment; +} + +size_t GetAttentionWorkspaceSize( + size_t element_size, + size_t batch_size, + size_t num_heads, + size_t qk_head_size, + size_t v_head_size, + size_t sequence_length, + void* fused_runner, + bool use_flash_attention, + bool use_memory_efficient_attention, + bool no_qkv_workspace) { + // Note that q, k and v might need alignment for fused attention kernels. + const size_t qkv_bytes = no_qkv_workspace ? 0 : (element_size * batch_size * num_heads * sequence_length * (qk_head_size + qk_head_size + v_head_size)); + + // Use portion of workspace for softmax buffer. + if (use_flash_attention) { + size_t flash_buffer_bytes = flash::get_softmax_lse_size(sequence_length, batch_size, num_heads); + return qkv_bytes + flash_buffer_bytes; + } + + if (fused_runner != nullptr) { + return qkv_bytes; + } + +//#if USE_MEMORY_EFFICIENT_ATTENTION +// if (use_memory_efficient_attention) { +// size_t fmha_buffer_bytes = 0; +// if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) { +// fmha_buffer_bytes = batch_size * sequence_length * num_heads * v_head_size * sizeof(float); +// } +// return qkv_bytes + fmha_buffer_bytes; +// } +//#else +// ORT_UNUSED_PARAMETER(use_memory_efficient_attention); +//#endif + + return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length); +} + +} // namespace cuda \ No newline at end of file diff --git a/operators/cuda/paged_attention_impl.h b/operators/cuda/paged_attention_impl.h new file mode 100644 index 000000000..d27e304ee --- /dev/null +++ b/operators/cuda/paged_attention_impl.h @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "ortx_common.h" +#include +#include + +enum AttentionQkvFormat { + UNKNOWN, // enum value not set, or depends on qkv projection implementation details + Q_K_V_BNSH, // for non-packed qkv, permuted + Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + QKV_BSN3H, // for TRT fused attention, qkv are packed + Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed + Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed +}; + +struct PackedAttentionParameters { + int batch_size; + int sequence_length; + int input_hidden_size; // hidden size of input + int hidden_size; // hidden size of Q or K + int head_size; // hidden size per head of Q or K + int v_hidden_size; // hidden size of V + int v_head_size; // hidden size per head of V + int num_heads; + int num_kv_heads; + float scale; + int token_count; + int valid_token_count; + bool has_relative_position_bias; + bool broadcast_res_pos_bias; + bool causal; +}; + +template +struct PackedMultiHeadAttentionData { + const T* query; + const T* key; + const T* value; + const T* bias; + const T* relative_position_bias; + const int32_t* token_offset; + const int32_t* cumulative_sequence_length; + + AttentionQkvFormat source_qkv_format; + + bool no_qkv_workspace; + T* workspace; + T* output; + + void* fused_runner; + + bool use_flash_attention; + bool use_memory_efficient_attention; +}; + +namespace cuda { +void reshape_and_cache( + const cudaStream_t stream, + const void* key, // [num_tokens, num_heads, head_size] + const void* value, // [num_tokens, num_heads, head_size] + const void* key_cache, // [num_blocks, block_size, num_heads, head_size] + const void* value_cache, // [num_blocks, block_size, num_heads, head_size] + const int* slot_mapping, // [num_tokens] + const int32_t* key_shapes, + const int32_t* value_shapes, + const int64_t block_size); +// void* kv_quant_param = nullptr, // [num_blocks, 2, num_heads, head_size / kv_quant_chunk_size, block_size] +// const int kv_quant_chunk_size = 0, +// const int kv_quant_param_dtype = 1); + +void rotary_embedding_neox( + const cudaStream_t stream, + const int32_t* positions, // [num_tokens] + void* query, // [num_tokens, num_heads * head_size] + void* key, // [num_tokens, num_kv_heads * head_size] + int head_size, + const void* cos_sin_cache, // [max_position, rot_dim] + int num_tokens, + int rot_dim, + int num_heads, + int num_kv_heads); + +template +OrtStatusPtr QkvToContext( + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data); + +size_t GetAttentionWorkspaceSize( + size_t element_size, + size_t batch_size, + size_t num_heads, + size_t qk_head_size, + size_t v_head_size, + size_t sequence_length, + void* fused_runner, + bool use_flash_attention, + bool use_memory_efficient_attention, + bool no_qkv_workspace); + +} // namespace cuda \ No newline at end of file diff --git a/operators/cuda/paged_dtype_float16.cuh b/operators/cuda/paged_dtype_float16.cuh new file mode 100644 index 000000000..132c8154c --- /dev/null +++ b/operators/cuda/paged_dtype_float16.cuh @@ -0,0 +1,469 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "paged_generic.cuh" +#include "paged_dtype_float32.cuh" + +#include +namespace cuda { +namespace vllm { + +// FP16 vector types for Q, K, V. +template <> +struct Vec { + using Type = uint16_t; +}; +template <> +struct Vec { + using Type = uint32_t; +}; +template <> +struct Vec { + using Type = uint2; +}; +template <> +struct Vec { + using Type = uint4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = Float4_; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" + : "=r"(b) + : "h"(a)); + return b; +} + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" + : "=f"(f) + : "h"(h)); + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" + : "=h"(lo), "=h"(hi) + : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" + : "=h"(tmp.u16[0]) + : "f"(f)); + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" + : "=h"(tmp.u16[0]) + : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" + : "=h"(tmp.u16[1]) + : "f"(f.y)); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" + : "=h"(c) + : "h"(a), "h"(b)); + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" + : "=r"(c) + : "r"(a), "r"(b)); + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. +template <> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" + : "=h"(c) + : "h"(a), "h"(b)); + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" + : "=r"(c) + : "r"(a), "r"(b)); + return c; +} + +template <> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template <> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template <> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template <> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template <> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template <> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template <> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. +inline __device__ float to_float(uint16_t u) { + return half_to_float(u); +} + +inline __device__ float2 to_float(uint32_t u) { + return half2_to_float2(u); +} + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +// Zero-out a vector. +inline __device__ void zero(uint16_t& dst) { + dst = uint16_t(0); +} + +} // namespace vllm +} // namespace cuda diff --git a/operators/cuda/paged_dtype_float32.cuh b/operators/cuda/paged_dtype_float32.cuh new file mode 100644 index 000000000..93f2c1f5c --- /dev/null +++ b/operators/cuda/paged_dtype_float32.cuh @@ -0,0 +1,274 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "paged_generic.cuh" + +#include +namespace cuda { +namespace vllm { + +// Define custom FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template <> +struct Vec { + using Type = float; +}; +template <> +struct Vec { + using Type = float2; +}; +template <> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { + return a + b; +} + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. +template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template <> +inline __device__ float sum(float v) { + return v; +} + +template <> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template <> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template <> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template <> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { + return a * b; +} + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { + dst = src; +} + +inline __device__ void from_float(float2& dst, float2 src) { + dst = src; +} + +inline __device__ void from_float(float4& dst, float4 src) { + dst = src; +} + +// From float to float. +inline __device__ float to_float(float u) { + return u; +} + +inline __device__ float2 to_float(float2 u) { + return u; +} + +inline __device__ float4 to_float(float4 u) { + return u; +} + +inline __device__ Float4_ to_float(Float4_ u) { + return u; +} + +inline __device__ Float8_ to_float(Float8_ u) { + return u; +} + +// Zero-out a variable. +inline __device__ void zero(float& dst) { + dst = 0.f; +} + +} // namespace vllm +} // namespace cuda diff --git a/operators/cuda/paged_generic.cuh b/operators/cuda/paged_generic.cuh new file mode 100644 index 000000000..b35500f54 --- /dev/null +++ b/operators/cuda/paged_generic.cuh @@ -0,0 +1,65 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +namespace cuda { +namespace vllm { + +// A vector type to store Q, K, V elements. +template +struct Vec {}; + +// A vector type to store FP32 accumulators. +template +struct FloatVec {}; + +// Template vector operations. +template +inline __device__ Acc mul(A a, B b); + +template +inline __device__ float sum(T v); + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +} // namespace vllm +} // namespace cuda diff --git a/operators/cuda/paged_utils.cuh b/operators/cuda/paged_utils.cuh new file mode 100644 index 000000000..3f01ece8c --- /dev/null +++ b/operators/cuda/paged_utils.cuh @@ -0,0 +1,59 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "paged_generic.cuh" +#include "paged_dtype_float16.cuh" +#include "paged_dtype_float32.cuh" + +#include +#include +namespace cuda { + +namespace vllm { + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +} // namespace vllm +} // namespace cuda diff --git a/operators/cuda/utils.cuh b/operators/cuda/utils.cuh index fe3d27daa..322e0c44f 100644 --- a/operators/cuda/utils.cuh +++ b/operators/cuda/utils.cuh @@ -191,3 +191,8 @@ __device__ __inline__ half2 _Tanh(half2 a) { template <> __device__ __inline__ ortc::BFloat16 _Tanh(ortc::BFloat16 a) { return tanhf(static_cast(a)); } + +inline OrtStatusPtr CudaCall(cudaError_t cuda_error) { + if (cuda_error == cudaSuccess) return nullptr; + return OrtW::API::CreateStatus(ORT_FAIL, MakeString("cuda error:", (int)cuda_error).c_str()); +} \ No newline at end of file diff --git a/test/cuda/block_table_2x3.npy b/test/cuda/block_table_2x3.npy new file mode 100644 index 000000000..b0163b5d2 Binary files /dev/null and b/test/cuda/block_table_2x3.npy differ diff --git a/test/cuda/cache_seqlens_2.npy b/test/cuda/cache_seqlens_2.npy new file mode 100644 index 000000000..fd2ad0b6a Binary files /dev/null and b/test/cuda/cache_seqlens_2.npy differ diff --git a/test/cuda/k_2x1x6x16.npy b/test/cuda/k_2x1x6x16.npy new file mode 100644 index 000000000..9eb34c3b9 Binary files /dev/null and b/test/cuda/k_2x1x6x16.npy differ diff --git a/test/cuda/k_cache_6x256x6x16.npy b/test/cuda/k_cache_6x256x6x16.npy new file mode 100644 index 000000000..49cebd4e5 Binary files /dev/null and b/test/cuda/k_cache_6x256x6x16.npy differ diff --git a/test/cuda/key.npy b/test/cuda/key.npy new file mode 100644 index 000000000..a9ec93b04 Binary files /dev/null and b/test/cuda/key.npy differ diff --git a/test/cuda/key_381x512_float16.npy b/test/cuda/key_381x512_float16.npy new file mode 100644 index 000000000..a218f76ff Binary files /dev/null and b/test/cuda/key_381x512_float16.npy differ diff --git a/test/cuda/o381x512_ortx.npy b/test/cuda/o381x512_ortx.npy new file mode 100644 index 000000000..3f6630c25 Binary files /dev/null and b/test/cuda/o381x512_ortx.npy differ diff --git a/test/cuda/o_2x96.npy b/test/cuda/o_2x96.npy new file mode 100644 index 000000000..07c441c7a Binary files /dev/null and b/test/cuda/o_2x96.npy differ diff --git a/test/cuda/q_2x1x6x16.npy b/test/cuda/q_2x1x6x16.npy new file mode 100644 index 000000000..11d718525 Binary files /dev/null and b/test/cuda/q_2x1x6x16.npy differ diff --git a/test/cuda/query.npy b/test/cuda/query.npy new file mode 100644 index 000000000..274a1851c Binary files /dev/null and b/test/cuda/query.npy differ diff --git a/test/cuda/query_381x512_float16.npy b/test/cuda/query_381x512_float16.npy new file mode 100644 index 000000000..53b842b33 Binary files /dev/null and b/test/cuda/query_381x512_float16.npy differ diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index d868fe675..ec9f8127d 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -6,7 +6,133 @@ from onnxruntime_extensions import get_library_path as _get_library_path import onnxruntime as _ort +import torch +from einops import rearrange, repeat +import math +import pdb +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + upcast=True, + reorder_ops=False, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + attention = torch.softmax(scores, dim=-1) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + +def generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype): + num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) + v_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) + block_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + # pytorch 1.12 doesn't have indexing with int32 + k_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks class TestCudaOps(unittest.TestCase): @staticmethod @@ -116,6 +242,309 @@ def test_cuda_fastgelu_f16(self): else: print ('CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.') + @staticmethod + def _create_pagedattention_test_model(batch_size, total_seqlen, hidden_size, slot_cnt_per_block, block_cnt_per_layer, block_cnt_needed_by_longest_seq, num_heads=32, num_kv_heads=32, head_size=16, scale=0.0, domain='ai.onnx.contrib'): + nodes = [ + helper.make_node('PagedAttention', + ['query', 'key', 'value', 'key_cache', 'value_cache', 'block_tables', 'slot_mappings', 'context_lens', 'is_prompt'], + ['attn_out'], + domain=domain, num_heads=num_heads, num_kv_heads=num_kv_heads, head_size=head_size, scale=scale) + ] + query = helper.make_tensor_value_info( + 'query', onnx_proto.TensorProto.FLOAT16, [None, hidden_size]) + key = helper.make_tensor_value_info( + 'key', onnx_proto.TensorProto.FLOAT16, [None, hidden_size]) + value = helper.make_tensor_value_info( + 'value', onnx_proto.TensorProto.FLOAT16, [None, hidden_size]) + key_cache = helper.make_tensor_value_info( + 'key_cache', onnx_proto.TensorProto.FLOAT16, [block_cnt_per_layer, hidden_size * slot_cnt_per_block]) + value_cache = helper.make_tensor_value_info( + 'value_cache', onnx_proto.TensorProto.FLOAT16, [block_cnt_per_layer, hidden_size * slot_cnt_per_block]) + block_tables = helper.make_tensor_value_info( + 'block_tables', onnx_proto.TensorProto.INT32, [batch_size, block_cnt_needed_by_longest_seq]) + slot_mappings = helper.make_tensor_value_info( + 'slot_mappings', onnx_proto.TensorProto.INT32, [None]) + context_lens = helper.make_tensor_value_info( + 'context_lens', onnx_proto.TensorProto.INT32, [batch_size]) + is_prompt = helper.make_tensor_value_info( + 'is_prompt', onnx_proto.TensorProto.INT32, [1]) + attn_out = helper.make_tensor_value_info( + 'attn_out', onnx_proto.TensorProto.FLOAT16, [None, hidden_size]) + graph = helper.make_graph(nodes, 'test_paged_attention', + [query, key, value, key_cache, value_cache, block_tables, slot_mappings, context_lens, is_prompt], + [attn_out]) + model = make_onnx_model(graph) + return model + + def test_cuda_paged_attention(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_pagedattention_test_model(5, 87, 512, 16, 32, 3) + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + providers=['CUDAExecutionProvider']) + #query = np.random.randn(87,512).astype(np.float16) # 87 is the token num of all the sequences (5+12+16+20+34) + #key = np.random.randn(87,512).astype(np.float16) + #value = np.random.randn(87,512).astype(np.float16) + query = np.load('query.npy') + key = np.load('key.npy') + value = np.load('value.npy') + key_cache = np.zeros([32,8192]).astype(np.float16) + value_cache = np.zeros([32,8192]).astype(np.float16) + block_tables = np.array([[0,-1,-1],[1,-1,-1],[2,-1,-1],[3,4,-1],[5,6,7]]).astype(np.int32) + slot1 = np.arange(0, 5, dtype=np.int32) + slot2 = np.arange(16, 28, dtype=np.int32) + slot3 = np.arange(32, 68, dtype=np.int32) + slot4 = np.arange(80, 114, dtype=np.int32) + slot_mappings = np.concatenate((slot1, slot2, slot3, slot4)) + context_lens = np.array([5, 12, 16, 20, 34]).astype(np.int32) + is_prompt = np.array([1]).astype(np.int32) + y = sess.run(None, {'query':query, 'key':key, 'value':value, 'key_cache':key_cache, 'value_cache':value_cache, 'block_tables':block_tables, 'slot_mappings':slot_mappings, 'context_lens':context_lens, 'is_prompt':is_prompt}) + print('Y=') + print(y) + + def test_cuda_paged_attention2(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_pagedattention_test_model(3, 381, 512, 16, 32, 8) + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + providers=['CUDAExecutionProvider']) + #query = np.random.randn(381,512).astype(np.float16) # 381 is the token num of all the sequences (127, 127, 127) + #key = np.random.randn(381,512).astype(np.float16) + #value = np.random.randn(381,512).astype(np.float16) + query = np.load('query_381x512_float16.npy') + key = np.load('key_381x512_float16.npy') + value = np.load('value_381x512_float16.npy') + key_cache = np.zeros([32,8192]).astype(np.float16) + value_cache = np.zeros([32,8192]).astype(np.float16) + block_tables = np.array([[0,1,2,3,4,5,6,7],[8,9,10,11,12,13,14,15],[16,17,18,19,20,21,22,23]]).astype(np.int32) # each sequence occupies 8 blocks (127/16) + slot1 = np.arange(0, 127, dtype=np.int32) + slot2 = np.arange(128, 255, dtype=np.int32) + slot3 = np.arange(256, 383, dtype=np.int32) + slot_mappings = np.concatenate((slot1, slot2, slot3)) + context_lens = np.array([127, 127, 127]).astype(np.int32) + is_prompt = np.array([1]).astype(np.int32) + y = sess.run(None, {'query':query, 'key':key, 'value':value, 'key_cache':key_cache, 'value_cache':value_cache, 'block_tables':block_tables, 'slot_mappings':slot_mappings, 'context_lens':context_lens, 'is_prompt':is_prompt}) + #pdb.set_trace() + print('Y=') + print(y) + q_pt = torch.from_numpy(query.reshape(3, 127, 32, 16)) + k_pt = torch.from_numpy(key.reshape(3, 127, 32, 16)) + v_pt = torch.from_numpy(value.reshape(3, 127, 32, 16)) + out, attention = attention_ref(q_pt, k_pt, v_pt, causal=True, window_size=[-1, 0]) + y_np = np.array(y).reshape(381, 512) + out_np = out.reshape(381, 512).numpy() + assert np.allclose(y_np, out_np, rtol=1e-3, atol=1e-3, equal_nan=True) + + def test_cuda_paged_attention3(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_pagedattention_test_model(3, 381, 512, 16, 32, 8) + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + providers=['CUDAExecutionProvider']) + + query = np.random.randn(381,512).astype(np.float16) # 381 is the token num of all the sequences (127, 127, 127) + key = np.random.randn(381,512).astype(np.float16) + value = np.random.randn(381,512).astype(np.float16) + key_cache = np.zeros([32,8192]).astype(np.float16) + value_cache = np.zeros([32,8192]).astype(np.float16) + block_tables = np.array([[0,1,2,3,4,5,6,7],[8,9,10,11,12,13,14,15],[16,17,18,19,20,21,22,23]]).astype(np.int32) # each sequence occupies 8 blocks (127/16) + slot1 = np.arange(0, 127, dtype=np.int32) + slot2 = np.arange(128, 255, dtype=np.int32) + slot3 = np.arange(256, 383, dtype=np.int32) + slot_mappings = np.concatenate((slot1, slot2, slot3)) + context_lens = np.array([127, 127, 127]).astype(np.int32) + is_prompt = np.array([1]).astype(np.int32) + y = sess.run(None, {'query':query, 'key':key, 'value':value, 'key_cache':key_cache, 'value_cache':value_cache, 'block_tables':block_tables, 'slot_mappings':slot_mappings, 'context_lens':context_lens, 'is_prompt':is_prompt}) + q_pt = torch.from_numpy(query.reshape(3, 127, 32, 16)) + k_pt = torch.from_numpy(key.reshape(3, 127, 32, 16)) + v_pt = torch.from_numpy(value.reshape(3, 127, 32, 16)) + out, attention = attention_ref(q_pt, k_pt, v_pt, causal=True, window_size=[-1, 0]) + y_np = np.array(y).reshape(381, 512) + out_np = out.reshape(381, 512).numpy() + assert np.allclose(y_np, out_np, rtol=1e-3, atol=1e-3, equal_nan=True) + + def test_cuda_paged_attention_prompt_decoding(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_pagedattention_test_model(3, 381, 512, 16, 32, 8) + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + providers=['CUDAExecutionProvider']) + + query = np.random.randn(381,512).astype(np.float16) # 381 is the token num of all the sequences (127, 127, 127) + key = np.random.randn(381,512).astype(np.float16) + value = np.random.randn(381,512).astype(np.float16) + key_cache = np.zeros([32,8192]).astype(np.float16) + value_cache = np.zeros([32,8192]).astype(np.float16) + block_tables = np.array([[0,1,2,3,4,5,6,7],[8,9,10,11,12,13,14,15],[16,17,18,19,20,21,22,23]]).astype(np.int32) # each sequence occupies 8 blocks (127/16) + slot1 = np.arange(0, 127, dtype=np.int32) + slot2 = np.arange(128, 255, dtype=np.int32) + slot3 = np.arange(256, 383, dtype=np.int32) + slot_mappings = np.concatenate((slot1, slot2, slot3)) + context_lens = np.array([127, 127, 127]).astype(np.int32) + is_prompt = np.array([1]).astype(np.int32) + + key_cache_ort = _ort.OrtValue.ortvalue_from_numpy(key_cache, "cuda") + value_cache_ort = _ort.OrtValue.ortvalue_from_numpy(value_cache, "cuda") + block_tables_ort = _ort.OrtValue.ortvalue_from_numpy(block_tables, "cuda") + slot_mappings_ort = _ort.OrtValue.ortvalue_from_numpy(slot_mappings, "cuda") + context_lens_ort = _ort.OrtValue.ortvalue_from_numpy(context_lens) + is_prompt_ort = _ort.OrtValue.ortvalue_from_numpy(is_prompt) + + # prompt case + io_binding = sess.io_binding() + io_binding.bind_cpu_input("query", query) + io_binding.bind_cpu_input("key", key) + io_binding.bind_cpu_input("value", value) + io_binding.bind_ortvalue_input("key_cache", key_cache_ort) + io_binding.bind_ortvalue_input("value_cache", value_cache_ort) + io_binding.bind_ortvalue_input("block_tables", block_tables_ort) + io_binding.bind_ortvalue_input("slot_mappings", slot_mappings_ort) + io_binding.bind_ortvalue_input("context_lens", context_lens_ort) + io_binding.bind_ortvalue_input("is_prompt", is_prompt_ort) + io_binding.bind_output("attn_out") + sess.run_with_iobinding(io_binding) + + # decoding case + query2 = np.random.randn(3, 512).astype(np.float16) + key2 = np.random.randn(3, 512).astype(np.float16) + value2 = np.random.randn(3, 512).astype(np.float16) + slot = np.array([127, 255, 383]).astype(np.int32) + io_binding.bind_cpu_input("query", query2) + io_binding.bind_cpu_input("key", key2) + io_binding.bind_cpu_input("value", value2) + io_binding.bind_cpu_input("slot_mappings", slot) + context_lens_ort.update_inplace(np.array([1,1,1]).astype(np.int32)) + is_prompt_ort.update_inplace(np.array([0]).astype(np.int32)) + sess.run_with_iobinding(io_binding) + + def test_cuda_paged_attention_decoding(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_pagedattention_test_model(batch_size=2, total_seqlen=0, hidden_size=96, slot_cnt_per_block=256, + block_cnt_per_layer=6, block_cnt_needed_by_longest_seq=3, num_heads=6, num_kv_heads=6, head_size=16) + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + providers=['CUDAExecutionProvider']) + +# query = np.random.randn(2,96).astype(np.float16) +# key = np.random.randn(2,96).astype(np.float16) +# value = np.random.randn(2,96).astype(np.float16) +# key_cache = np.zeros([6,24576]).astype(np.float16) # 24576 = 256x6x16 +# value_cache = np.zeros([6,24576]).astype(np.float16) +# block_tables = np.array([[0,1,2],[3,4,5]]).astype(np.int32) +# context_lens = np.array([83, 65]).astype(np.int32) + #pdb.set_trace() + query_2x1x6x16 = np.load('q_2x1x6x16.npy') + key_2x1x6x16 = np.load('k_2x1x6x16.npy') + value_2x1x6x16 = np.load('v_2x1x6x16.npy') + key_cache_6x256x6x16 = np.load('k_cache_6x256x6x16.npy') + value_cache_6x256x6x16 = np.load('v_cache_6x256x6x16.npy') + block_tables = np.load('block_table_2x3.npy') # [[2,4,1], [5,3,0]] + context_lens = np.load('cache_seqlens_2.npy') # [83, 65] + query = np.reshape(query_2x1x6x16, (2, 96)) + key = np.reshape(key_2x1x6x16, (2, 96)) + value = np.reshape(value_2x1x6x16, (2, 96)) + key_cache = np.reshape(key_cache_6x256x6x16, (6, 24576)) + value_cache = np.reshape(value_cache_6x256x6x16, (6, 24576)) + + slot_mappings = np.array([250, 500]).astype(np.int32) + is_prompt = np.array([0]).astype(np.int32) + y = sess.run(None, {'query':query, 'key':key, 'value':value, 'key_cache':key_cache, 'value_cache':value_cache, 'block_tables':block_tables, 'slot_mappings':slot_mappings, 'context_lens':context_lens, 'is_prompt':is_prompt}) + print('Y=') + print(y) + #y_np = np.array(y).reshape(2, 96) + #np.save('o_2x96', y_np) + + def test_cuda_paged_attention_decoding2(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_pagedattention_test_model(batch_size=2, total_seqlen=0, hidden_size=96, slot_cnt_per_block=256, + block_cnt_per_layer=6, block_cnt_needed_by_longest_seq=3, num_heads=6, num_kv_heads=6, head_size=16) + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + providers=['CUDAExecutionProvider']) + + torch.random.manual_seed(0) + seqlen_k = 127 + batch_size = 2 + nheads = 6 + d = 16 + paged_kv_block_size = 256 + #pdb.set_trace() + + query = np.random.randn(batch_size, nheads*d).astype(np.float16) + key = np.random.randn(batch_size, nheads*d).astype(np.float16) + value = np.random.randn(batch_size, nheads*d).astype(np.float16) + key_cache_6x256x6x16 = np.random.randn(6, paged_kv_block_size, nheads, d).astype(np.float16) + value_cache_6x256x6x16 = np.random.randn(6, paged_kv_block_size, nheads, d).astype(np.float16) + key_cache = key_cache_6x256x6x16.reshape(6, paged_kv_block_size * nheads * d) + value_cache = value_cache_6x256x6x16.reshape(6, paged_kv_block_size * nheads * d) +# block_tables = np.random.permutation(6).astype(np.int32).reshape(2,3) +# context_lens = np.random.randint(1, seqlen_k, size=(batch_size,)).astype(np.int32) +# (k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks) = generate_block_kvcache( +# seqlen_k, paged_kv_block_size, batch_size, nheads, d, 'cuda', torch.float16) +# key_cache = k_cache_paged.cpu().numpy().reshape(6, 24576) +# value_cache = v_cache_paged.cpu().numpy().reshape(6, 24576) +# block_tables = block_table.cpu().numpy() +# cache_seqlens = torch.randint(1, seqlen_k, (batch_size,), dtype=torch.int32) +# context_lens = cache_seqlens.numpy() + +# query_2x1x6x16 = np.load('q_2x1x6x16.npy') +# key_2x1x6x16 = np.load('k_2x1x6x16.npy') +# value_2x1x6x16 = np.load('v_2x1x6x16.npy') +# key_cache_6x256x6x16 = np.load('k_cache_6x256x6x16.npy') +# value_cache_6x256x6x16 = np.load('v_cache_6x256x6x16.npy') + block_tables = np.load('block_table_2x3.npy') # [[2,4,1], [5,3,0]] + context_lens = np.load('cache_seqlens_2.npy') # [83, 65] +# query = np.reshape(query_2x1x6x16, (2, 96)) +# key = np.reshape(key_2x1x6x16, (2, 96)) +# value = np.reshape(value_2x1x6x16, (2, 96)) + key_cache = np.reshape(key_cache_6x256x6x16, (6, 24576)) + value_cache = np.reshape(value_cache_6x256x6x16, (6, 24576)) + + slot_mappings = np.array([250, 500]).astype(np.int32) + is_prompt = np.array([0]).astype(np.int32) + y = sess.run(None, {'query':query, 'key':key, 'value':value, 'key_cache':key_cache, 'value_cache':value_cache, 'block_tables':block_tables, 'slot_mappings':slot_mappings, 'context_lens':context_lens, 'is_prompt':is_prompt}) + #print('Y=') + #print(y) + + cache_seqlens = torch.from_numpy(context_lens) + block_tables_pt = torch.from_numpy(block_tables) + key_cache_pt = torch.from_numpy(key_cache_6x256x6x16) + value_cache_pt = torch.from_numpy(value_cache_6x256x6x16) + k_cache_cpu = rearrange(key_cache_pt[block_tables_pt.flatten()], '(b nblocks) block_size ... -> b (nblocks block_size) ...', b = batch_size)[:, :seqlen_k] + v_cache_cpu = rearrange(value_cache_pt[block_tables_pt.flatten()], '(b nblocks) block_size ... -> b (nblocks block_size) ...', b = batch_size)[:, :seqlen_k] + + q = torch.from_numpy(query.reshape(batch_size, 1, nheads, d)) + k = torch.from_numpy(key.reshape(batch_size, 1, nheads, d)) + v = torch.from_numpy(value.reshape(batch_size, 1, nheads, d)) + + arange = rearrange(torch.arange(seqlen_k), 's->1 s') + cache_seqlens_expand = rearrange(cache_seqlens, 'b->b 1') + key_padding_mask = arange < cache_seqlens_expand + 1 + update_mask = torch.logical_and(cache_seqlens_expand <= arange, arange < cache_seqlens_expand + 1) +# k_cache_cpu = k_cache.cpu() +# v_cache_cpu = v_cache.cpu() + k_cache_cpu[update_mask] = rearrange(k, 'b s ... -> (b s) ...') + v_cache_cpu[update_mask] = rearrange(v, 'b s ... -> (b s) ...') + out_ref, _ = attention_ref(q, k_cache_cpu, v_cache_cpu, None, key_padding_mask, 0.0, None, causal=True) + #out_ref2, _ = attention_ref(q, k, v, causal=True) + #print('out_ref=') + #print(out_ref) + y_np = np.array(y).reshape(2, 96) + out_np = out_ref.reshape(2, 96).numpy() + #out_np2 = out_ref2.reshape(2, 96).numpy() + #print(y_np-out_np) + print(np.max(np.absolute(y_np - out_np))) + #print(query) + #print(out_np) + assert np.allclose(y_np, out_np, rtol=1e-3, atol=1e-3, equal_nan=True) + #assert np.allclose(y_np, out_np2, rtol=1e-3, atol=1e-3, equal_nan=True) + if __name__ == "__main__": unittest.main() diff --git a/test/cuda/v_2x1x6x16.npy b/test/cuda/v_2x1x6x16.npy new file mode 100644 index 000000000..093417fa1 Binary files /dev/null and b/test/cuda/v_2x1x6x16.npy differ diff --git a/test/cuda/v_cache_6x256x6x16.npy b/test/cuda/v_cache_6x256x6x16.npy new file mode 100644 index 000000000..40fff33dc Binary files /dev/null and b/test/cuda/v_cache_6x256x6x16.npy differ diff --git a/test/cuda/value.npy b/test/cuda/value.npy new file mode 100644 index 000000000..ffedf1bc2 Binary files /dev/null and b/test/cuda/value.npy differ diff --git a/test/cuda/value_381x512_float16.npy b/test/cuda/value_381x512_float16.npy new file mode 100644 index 000000000..f1ca878d3 Binary files /dev/null and b/test/cuda/value_381x512_float16.npy differ