From 8b7b72c2267e069b02f91b336eb734ebff080a9c Mon Sep 17 00:00:00 2001 From: xh Date: Sun, 31 Jul 2022 19:30:41 +0800 Subject: [PATCH 01/71] add support for attn mask --- .gitignore | 125 ++++++++++++++++++ benchmarks/test_example.py | 54 ++++++++ csrc/flash_attn/fmha_api.cpp | 29 +++- csrc/flash_attn/src/fmha.h | 6 + csrc/flash_attn/src/fmha/gmem_tile.h | 35 +++++ csrc/flash_attn/src/fmha/kernel_traits.h | 3 + .../src/fmha_fprop_fp16_kernel.sm80.cu | 24 ++++ csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 41 +++++- flash_attn/Attention.py | 56 ++++++++ setup.py | 12 +- 10 files changed, 374 insertions(+), 11 deletions(-) create mode 100644 .gitignore create mode 100644 benchmarks/test_example.py create mode 100644 flash_attn/Attention.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..826a7efe9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,125 @@ +*.pt +*.tfevents.* +# JetBrains PyCharm IDE +.idea/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# macOS dir files +.DS_Store + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.args +*.egg + +# Checkpoints +checkpoints + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mypy +.mypy_cache/ + +# VSCODE +.vscode/ftp-sync.json +.vscode/settings.json + +# too big to git +*.lmdb +*.sto +*.pt +*.pkl + +# pytest +.pytest_cache +test/.pytest_cache +/local* +/_* \ No newline at end of file diff --git a/benchmarks/test_example.py b/benchmarks/test_example.py new file mode 100644 index 000000000..1711ac865 --- /dev/null +++ b/benchmarks/test_example.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn + +from functools import partial +import math +from typing import Optional, Callable, List, Tuple, Sequence +import numpy as np + +from time import perf_counter_ns + +from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func + +# bs = 1 +# seq = 20 +# head = 32 +# c_dim = 16 + +torch.manual_seed(0) +# v2 +bs = 1 +seq = 128 +head = 16 +c_dim = 32 + +orig_tensor = torch.stack( + [ (i+1) * 0.1 * torch.ones((bs, seq, head, c_dim)) for i in range(seq) ] + ,dim = 1 +).cuda().to(torch.bfloat16) + +print ("origin shape: ", orig_tensor.shape) + +batch_size = bs * seq +seqlen = seq +max_s = seqlen +cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=orig_tensor.device) + +print (cu_seqlens) +tensor_2d_pad = orig_tensor.reshape(-1, head, c_dim) + +output3 = flash_attn_unpadded_func( + tensor_2d_pad, + tensor_2d_pad, + 
tensor_2d_pad, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + dropout_p = 0., + softmax_scale = 1., # q has been scaled already +) + +print (output3.shape, output3.shape) +output3 = output3.reshape((bs, seq, seq, head, c_dim)) \ No newline at end of file diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index dd9ba20a1..067331eae 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -53,7 +53,10 @@ void set_params_fprop(FMHA_fprop_params ¶ms, void *softmax_lse_d, float p_dropout, float softmax_scale, - bool is_causal) { + bool is_causal, + void *attn_mask, + void *attn_bias + ) { Data_type acc_type = DATA_TYPE_FP32; Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16; @@ -95,6 +98,10 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.seqlen_k = seqlen_k; params.d = d; + // attn mask & bias + params.attn_mask_ptr = attn_mask; + params.attn_bias_ptr = attn_bias; + // Set the different scale values. // const float scale_bmm1 = 1.f / sqrtf(d); const float scale_bmm1 = softmax_scale; @@ -152,7 +159,9 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, softmax_lse_d, p_dropout, softmax_scale, - is_causal); + is_causal, + nullptr, + nullptr); // Set the pointers and strides. params.dq_ptr = dq.data_ptr(); @@ -183,7 +192,10 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const bool zero_tensors, const bool is_causal, const bool return_softmax, - c10::optional gen_) { + c10::optional gen_, + const c10::optional &attn_mask, // attn_mask + const c10::optional &attn_bias // attn bias + ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm75 = dprops->major == 7 && dprops->minor == 5; @@ -239,6 +251,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q } int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; bool loop = max_seqlen_k > blocksize_c; + // loop over blocks more than once ? auto opts = q.options(); @@ -277,8 +290,12 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q softmax_lse.data_ptr(), p_dropout, softmax_scale, - is_causal); + is_causal, + attn_mask ? attn_mask->data_ptr() : nullptr, + attn_bias ? attn_bias->data_ptr() : nullptr + ); + printf ("debug fmha api start test\n"); run_fmha_fp16_sm80(launch_params, /*configure=*/ true); // number of times random will be generated per thread, to offset philox counter in thc random // state @@ -550,7 +567,9 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t softmax_lse.data_ptr(), p_dropout, softmax_scale, - is_causal); + is_causal, + nullptr, + nullptr); launch_params.params.blockmask = static_cast(blockmask.data_ptr()); run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true); diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index 7452dbe24..04fcba0f4 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -73,6 +73,12 @@ struct Qkv_params { struct FMHA_fprop_params : public Qkv_params { + // The attn mask matrix + void * __restrict__ attn_mask_ptr; + + // The attn bias matrix + void * __restrict__ attn_bias_ptr; + // The O matrix (output). void * __restrict__ o_ptr; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index e903c33ef..db21a0eb3 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -117,6 +117,20 @@ struct Gmem_tile_qkv { } } + // print data. 
+ inline __device__ void print() { + int row_ = tidx_ / THREADS_PER_ROW; + printf("print LDGS %d\n", LDGS); + #pragma unroll + for( int ii = 0; ii < LDGS; ++ii ) { + char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) { + printf("%f\n", *(ptr_ + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes)); + printf("%f\n", (ptr_ + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes)); + } + } + } + // Store data to memory. inline __device__ void store(const uint4 (&data)[LDGS]) { int row_ = tidx_ / THREADS_PER_ROW; @@ -404,6 +418,27 @@ struct Gmem_tile_mma_s : public Base { } }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// attn mask struct like s, maybe later can reuse the above declaration +template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > +struct Gmem_tile_mma_mask : public Base { + + // The number of mmas in the vertical dimension. + static constexpr int M = Base::MMAS_M; + // The number of mmas in the horizontal dimension. + static constexpr int N = Base::MMAS_N; + // The type of the vectors stored by each STG. + using Type = typename Base::Type; + + // Ctor. + template< typename Params, typename Block_info > + inline __device__ Gmem_tile_mma_mask(const Params ¶ms, const Block_info& binfo, const int tidx) + : Base(params.attn_mask_ptr, params, binfo.bidb, binfo.bidh, tidx) { + } + // TODO load data impl +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// template< diff --git a/csrc/flash_attn/src/fmha/kernel_traits.h b/csrc/flash_attn/src/fmha/kernel_traits.h index 7d7218c6b..809f19c52 100644 --- a/csrc/flash_attn/src/fmha/kernel_traits.h +++ b/csrc/flash_attn/src/fmha/kernel_traits.h @@ -71,6 +71,9 @@ struct FMHA_kernel_traits { // The global memory tile to load/store S. using Gmem_tile_s = fmha::Gmem_tile_mma_s; + // Gmem_tile_mma_mask + using Gmem_tile_mask = fmha::Gmem_tile_mma_mask; + // The shared memory tile to transpose S. 
using Smem_tile_st = fmha::Smem_tile_mma_transposed; diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index 67be427ad..193b77c3d 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -51,9 +51,20 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; launch_params.elts_per_thread = elts_per_head; + + printf("configuration----"); + printf("blocksize_c %d\n", blocksize_c); + printf("loop steps %d\n", loop_steps); + printf("Steps %zu\n", STEPS); + printf("Cta tile M %d\n", M); + printf("Cta tile MMAS_M %zu\n", MMAS_M); + printf("Cta tile MMAS_N %zu\n", MMAS_N); + printf("elts_per_thread %zu\n", elts_per_head); + printf("configuration----"); return; } + printf("deub kernel sm80\n"); constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; // Don't need smem_size_softmax_lse if we're not looping const int smem_size = fmha::get_dynamic_smem_size() @@ -75,6 +86,11 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } dim3 grid(launch_params.params.b, launch_params.params.h); + // grid: b, h (batch_size = len(cu_seq_q), 16), + // block: 32 + // kernel<<>> + // xh: static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA; + // static constexpr int THREADS_PER_WARP = 32; kernel<<>>( launch_params.params); FMHA_CHECK_CUDA(cudaPeekAtLastError()); @@ -89,6 +105,14 @@ void run_fmha_fp16_sm80(Launch_params &launch_params, if (launch_params.params.d == 16) { if( launch_params.params.seqlen_k == 128 ) { using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>; + // xh: template using Cta_tile_p = fmha::Cta_tile_extd; + // -> using Cta_tile_extd = Cta_tile_; + // -> static constexpr int M = M_(STEP 16), N = N_(s 128), K = K_ (D 16); + // -> // The number of warps. + // static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_; + // WARPS_M is 1, WARPS_N is 4 + // STEP = 16 ??? run_fmha_fp16_sm80_loop_(launch_params, configure); } else if( launch_params.params.seqlen_k == 256 ) { using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 91ef08e96..5d69f64fd 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -255,13 +255,26 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return; Gemm1 gemm_q_k(smem_, tidx); + // Allocate the global memory tile loader for Q. Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true); + // Gmem_tile_q is fmha::Gmem_tile_qkv; + // Gmem_tile_q is fmha::Gmem_tile_qkv, 取数逻辑在gmem_tile.h + // 指针q_ptr引入数据,构造 csrc/flash_attn/src/fmha/gmem_tile.h:67 + // row_stide is dim0, head_strid is dim1 + // Allocate the global memory tile loader for O. Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for S. 
Gmem_tile_s gmem_s(params, binfo, tidx); + + + // Allocate the global memory tile loader for mask. + using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; + // conctructor + // Gmem_tile_mask gmem_mask(params, binfo, tidx); + Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); // Wind gmem tiles to the correct position. @@ -275,11 +288,15 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_o_tmp.move(begin); if (Return_softmax) { gmem_s.move(begin); } gmem_softmax_lse.move(begin); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("begin = %d, steps = %d\n", begin, steps); - // } + + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("begin = %d, steps = %d\n", begin, steps); + } + fmha::Mask mask(binfo, tidx, loop_step_idx); + // mask in used pad actually, actual_seqlen_k is params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb] + // why works for the difference output ? // Allocate the global memory tile loader for K. Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false); @@ -302,6 +319,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Trigger the loads for K. gmem_k.load(); + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("Gemm load k...\n"); + gmem_k.print(); + } // Trigger the loads for Q. gmem_q.load(); // Trigger the loads for V. @@ -392,7 +413,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Convert from the accumulator type to FP32 for Softmax. softmax.unpack_noscale(acc_p); - // Apply the mask. + // Apply the mask. + // this impl is more like padding + // TODO apply + mask + bias softmax.apply_mask(mask); if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { @@ -651,6 +674,15 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { const int tidx = threadIdx.x; const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx; + // tidx_global = (blockIdx.x * params.h + blockIdx.y) * blockDim.x * 2 + threadIdx.x; + // what is mean of 2? 
+ if (tidx == 0) { + printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); + } + if (tidx == 1) { + printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); + } + auto seeds = at::cuda::philox::unpack(params.philox_args); Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds)); @@ -658,6 +690,7 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { const int STEPS = (params.seqlen_q + M - 1) / M; constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + // Tc, loop over k in algo2 line 6, blocksize_c in line 4 if (params.seqlen_k == blocksize_c) { fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); } else { diff --git a/flash_attn/Attention.py b/flash_attn/Attention.py new file mode 100644 index 000000000..0ff736f13 --- /dev/null +++ b/flash_attn/Attention.py @@ -0,0 +1,56 @@ +import torch + +from flash_attn.flash_attn_interface import flash_attn_unpadded_func + + +def flash_attn(q, k, v): + # import pdb; pdb.set_trace() + batch_dims = q.shape[:-3] + no_heads, n, c = q.shape[-3:] + dtype = q.dtype + + # [*, B, N, H, C] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + # [B_flat, N, H, C] + q = q.reshape(-1, *q.shape[-3:]) + k = k.reshape(-1, *k.shape[-3:]) + v = v.reshape(-1, *v.shape[-3:]) + + # Flattened batch size + batch_size = q.shape[0] + + # [B_flat * N, H, C] + q = q.reshape(-1, *q.shape[-2:]) + k = k.reshape(-1, *k.shape[-2:]) + v = v.reshape(-1, *v.shape[-2:]) + + q_max_s = n + q_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device + ) + + k_max_s = n + k_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device + ) + + out = flash_attn_unpadded_func( + q, + k, + v, + q_cu_seqlens, + k_cu_seqlens, + q_max_s, + k_max_s, + dropout_p = 0., + softmax_scale = 1., # q has been scaled already + ) + # [*, B, N, H, C] + out = out.reshape(*batch_dims, n, no_heads, c) + + out = out.to(dtype=dtype) + + return out \ No newline at end of file diff --git a/setup.py b/setup.py index eabcf0630..2f48a6fda 100644 --- a/setup.py +++ b/setup.py @@ -125,10 +125,12 @@ def append_nvcc_threads(nvcc_extra_args): "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu", ], extra_compile_args={ - "cxx": ["-O3"] + generator_flag, + # "cxx": ["-O3"] + generator_flag, + "cxx": ["-g"] + generator_flag, "nvcc": append_nvcc_threads( [ - "-O3", + #"-O3", + "-g", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "--expt-relaxed-constexpr", @@ -146,6 +148,12 @@ def append_nvcc_threads(nvcc_extra_args): Path(this_dir) / 'csrc' / 'flash_attn' / 'src', Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include', ], + # add depends for modification of header file + depends = [ + Path(this_dir) / 'csrc' / 'flash_attn', + Path(this_dir) / 'csrc' / 'flash_attn' / 'src', + Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include', + ], ) ) From b8acf761539c5207ec520263a7141dc82b5c174a Mon Sep 17 00:00:00 2001 From: xh Date: Mon, 1 Aug 2022 17:07:08 +0800 Subject: [PATCH 02/71] add mask operation --- csrc/flash_attn/src/fmha/gmem_tile.h | 32 +++++++++--- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 58 ++++++++++++++++----- 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 
db21a0eb3..ef89338a9 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -114,20 +114,23 @@ struct Gmem_tile_qkv { #pragma unroll for( int ii = 0; ii < LDGS; ++ii ) { fct.load(ii, preds[ii]); + // // The fetch registers. fetch_ + // -> inline __device__ void load(int ii, bool p) + // -> ldg(fetch_[ii], ptrs_[ii]); + // -> inline __device__ void ldg(uint4 &dst, const void *ptr) { + // dst = *reinterpret_cast(ptr); + // } } } // print data. + template inline __device__ void print() { - int row_ = tidx_ / THREADS_PER_ROW; + // int row_ = tidx_ / THREADS_PER_ROW; printf("print LDGS %d\n", LDGS); - #pragma unroll for( int ii = 0; ii < LDGS; ++ii ) { - char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; - if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) { - printf("%f\n", *(ptr_ + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes)); - printf("%f\n", (ptr_ + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes)); - } + // char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + printf("data: %f\n", *(elem_type *)(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes)); } } @@ -140,6 +143,8 @@ struct Gmem_tile_qkv { char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) { fmha::stg(ptr_, data[ii]); + // stg function, inline __device__ void stg(void *ptr, uint2 val) + // *reinterpret_cast(ptr) = val; } } } @@ -437,6 +442,19 @@ struct Gmem_tile_mma_mask : public Base { : Base(params.attn_mask_ptr, params, binfo.bidb, binfo.bidh, tidx) { } // TODO load data impl + + // Load from global memory. + template + inline __device__ void load(Fragment (&frag)[N][M]) { + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + frag[mi][ni] = make_uint4(0, 0, 0, 0); + Base::load(frag[mi][ni], mi, ni); + } + } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 5d69f64fd..c9f831c7a 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -52,6 +52,7 @@ struct Gemm_Q_K_base { using Mma_tile_p = fmha::Hmma_tile; static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; + // ? __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx) : smem_q(smem_ptr_q, tidx) @@ -90,6 +91,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base { static constexpr int SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE; static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE; static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE); + // ? offset of smem // Q | K / V // | O | SOFTMAX @@ -262,6 +264,8 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Gmem_tile_q is fmha::Gmem_tile_qkv, 取数逻辑在gmem_tile.h // 指针q_ptr引入数据,构造 csrc/flash_attn/src/fmha/gmem_tile.h:67 // row_stide is dim0, head_strid is dim1 + // constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8; + // using A_type = uint16_t; // Allocate the global memory tile loader for O. 
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); @@ -269,17 +273,18 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); - // Allocate the global memory tile loader for mask. using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; // conctructor - // Gmem_tile_mask gmem_mask(params, binfo, tidx); + Gmem_tile_mask gmem_mask(params, binfo, tidx); + // TODO: load fun as s Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); // Wind gmem tiles to the correct position. static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); const int begin_og = begin; + // begin is 0 begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin; const int steps_og = steps; steps -= begin - begin_og; @@ -287,12 +292,14 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_o.move(begin); gmem_o_tmp.move(begin); if (Return_softmax) { gmem_s.move(begin); } + // TODO: mask move + gmem_mask.move(begin); + gmem_softmax_lse.move(begin); - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("begin = %d, steps = %d\n", begin, steps); - } - + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("begin = %d, steps = %d\n", begin, steps); + // } fmha::Mask mask(binfo, tidx, loop_step_idx); // mask in used pad actually, actual_seqlen_k is params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb] @@ -304,6 +311,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false); // The base pointer of smem_v; char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; + // smem_ is continous memory, each part is v, o // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! Smem_tile_v smem_v(smem_v_, tidx); @@ -315,14 +323,12 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_k.move(loop_step_idx); gmem_v.move(loop_step_idx); if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); } + // TODO: mask move as s + gmem_mask.move(loop_step_idx * steps_og); } // Trigger the loads for K. gmem_k.load(); - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("Gemm load k...\n"); - gmem_k.print(); - } // Trigger the loads for Q. gmem_q.load(); // Trigger the loads for V. @@ -352,6 +358,12 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i __syncthreads(); + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("Gemm load k...\n"); + // https://stackoverflow.com/questions/7397934/calling-template-function-within-template-class + gmem_k.template print(); + } + // Load the fragments for Q. gemm_q_k.load_q(); @@ -391,7 +403,23 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i fmha::Clear_accumulator::apply(acc_p); // Do this part of P = Q * K^T. + // ? M / N 参数? 
gemm_q_k(acc_p); + // TODO acc_p += mask, index like gmem_s.store(frag_p, mask); + // move(1) + + using Frag_mask = fmha::Fragment_a; + // struct Fragment_a : public Fragment { + // struct Fragment_accumulator : public Fragment { + Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + gmem_mask.load(frag_mask); + acc_p.add(frag_mask); + gmem_mask.move(); + + if (Return_softmax) { + gmem_s.store(frag_p, mask); + gmem_s.move(); + } // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); @@ -415,7 +443,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Apply the mask. // this impl is more like padding - // TODO apply + mask + bias softmax.apply_mask(mask); if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { @@ -503,6 +530,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); softmax.template pack(frag_p); + // ? pack if (Return_softmax) { gmem_s.store(frag_p, mask); gmem_s.move(); @@ -679,9 +707,9 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { if (tidx == 0) { printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); } - if (tidx == 1) { - printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); - } + // if (tidx == 1) { + // printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); + // } auto seeds = at::cuda::philox::unpack(params.philox_args); Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); @@ -692,6 +720,8 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; // Tc, loop over k in algo2 line 6, blocksize_c in line 4 if (params.seqlen_k == blocksize_c) { + // inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); } else { const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; From abc409d3e5094011d0651699e586749d2bedb5cc Mon Sep 17 00:00:00 2001 From: xh Date: Mon, 1 Aug 2022 17:35:43 +0800 Subject: [PATCH 03/71] add mask operation --- csrc/flash_attn/src/fmha/gmem_tile.h | 10 +++++++--- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 17 ++++++----------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index ef89338a9..518d6cd59 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -444,14 +444,18 @@ struct Gmem_tile_mma_mask : public Base { // TODO load data impl // Load from global memory. 
- template + template inline __device__ void load(Fragment (&frag)[N][M]) { #pragma unroll for( int mi = 0; mi < M; mi++ ) { #pragma unroll for( int ni = 0; ni < N; ni++ ) { - frag[mi][ni] = make_uint4(0, 0, 0, 0); - Base::load(frag[mi][ni], mi, ni); + uint4 dst; + Base::load(dst mi, ni); + frag[ni][mi].reg(0) = dst.x; + frag[ni][mi].reg(2) = dst.y; + frag[ni][mi].reg(1) = dst.z; + frag[ni][mi].reg(3) = dst.w; } } } diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index c9f831c7a..defb1c18e 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -358,11 +358,11 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i __syncthreads(); - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("Gemm load k...\n"); - // https://stackoverflow.com/questions/7397934/calling-template-function-within-template-class - gmem_k.template print(); - } + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("Gemm load k...\n"); + // // https://stackoverflow.com/questions/7397934/calling-template-function-within-template-class + // gmem_k.template print(); + // } // Load the fragments for Q. gemm_q_k.load_q(); @@ -413,14 +413,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // struct Fragment_accumulator : public Fragment { Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; gmem_mask.load(frag_mask); - acc_p.add(frag_mask); + acc_p.template add(frag_mask); gmem_mask.move(); - if (Return_softmax) { - gmem_s.store(frag_p, mask); - gmem_s.move(); - } - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); // } From 98b290af526766afd609c4fa1a1fd1193c174682 Mon Sep 17 00:00:00 2001 From: xh Date: Mon, 1 Aug 2022 17:55:22 +0800 Subject: [PATCH 04/71] add mask operation --- csrc/flash_attn/src/fmha/gmem_tile.h | 2 +- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 518d6cd59..c7e62091e 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -451,7 +451,7 @@ struct Gmem_tile_mma_mask : public Base { #pragma unroll for( int ni = 0; ni < N; ni++ ) { uint4 dst; - Base::load(dst mi, ni); + Base::load(dst, mi, ni); frag[ni][mi].reg(0) = dst.x; frag[ni][mi].reg(2) = dst.y; frag[ni][mi].reg(1) = dst.z; diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index defb1c18e..099e0a848 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -412,8 +412,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // struct Fragment_a : public Fragment { // struct Fragment_accumulator : public Fragment { Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + gmem_mask.load(frag_mask); - acc_p.template add(frag_mask); + // acc_p.template add(frag_mask); + acc_p.add(frag_mask); gmem_mask.move(); // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { From ce0aff1e1782d327939ad38b10d2c4176d103ecc Mon Sep 17 00:00:00 2001 From: xh Date: Mon, 1 Aug 2022 19:51:11 +0800 Subject: [PATCH 05/71] add interface --- csrc/flash_attn/src/fmha/gmem_tile.h | 2 +- 
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 8 +++++- flash_attn/flash_attn_interface.py | 28 +++++++++++++-------- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index c7e62091e..8900ee8b2 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -443,7 +443,7 @@ struct Gmem_tile_mma_mask : public Base { } // TODO load data impl - // Load from global memory. + // Load from global memory to Fragment. template inline __device__ void load(Fragment (&frag)[N][M]) { #pragma unroll diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 099e0a848..f4ba3ef32 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -415,7 +415,13 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_mask.load(frag_mask); // acc_p.template add(frag_mask); - acc_p.add(frag_mask); + // acc_p.add(frag_mask); + // mask tranpose or not + for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { + for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { + acc_p[mi][ni].add(frag_mask[ni][mi]); + } + } gmem_mask.move(); // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 1076648c5..67cab2f1b 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -14,12 +14,18 @@ def _get_block_size(device, head_dim, is_dropout): return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128 -def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, +def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, dropout_p, softmax_scale, causal, return_softmax): - out, softmax_lse, *rest = flash_attn_cuda.fwd( - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, - False, causal, return_softmax, None - ) + if attn_mask is None: + out, softmax_lse, *rest = flash_attn_cuda.fwd( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, + False, causal, return_softmax, None, None, None + ) + else: + out, softmax_lse, *rest = flash_attn_cuda.fwd( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, + False, causal, return_softmax, None, attn_mask, None + ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() S_dmask = rest[0] if return_softmax else None @@ -114,14 +120,14 @@ def backward(ctx, dout, *args): class FlashAttnFunc(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, causal, return_softmax): + def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, + dropout_p, softmax_scale, causal, return_softmax): # Save rng_state because the backward pass will regenerate the dropout mask rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, softmax_lse, S_dmask = _flash_attn_forward( - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, dropout_p, softmax_scale, causal=causal, 
return_softmax=return_softmax ) ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) @@ -210,8 +216,8 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq return_attn_probs) -def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale=None, causal=False, return_attn_probs=False): +def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask=None, + dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False): """dropout_p should be set to 0.0 during evaluation Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. @@ -239,7 +245,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, dropout_p, softmax_scale, causal, return_attn_probs) From 5c023bad78e3ea811ad2a6045e926a41cb8ad4e5 Mon Sep 17 00:00:00 2001 From: xh Date: Mon, 1 Aug 2022 21:08:38 +0800 Subject: [PATCH 06/71] add mask support --- csrc/flash_attn/fmha_api.cpp | 1 + csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 50 ++++++++++++--------- flash_attn/flash_attn_interface.py | 1 + 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 067331eae..8ac9dfdd2 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -102,6 +102,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.attn_mask_ptr = attn_mask; params.attn_bias_ptr = attn_bias; + // Set the different scale values. // const float scale_bmm1 = 1.f / sqrtf(d); const float scale_bmm1 = softmax_scale; diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index f4ba3ef32..f646e1821 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -273,11 +273,14 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); + bool has_attn = !(params.attn_mask_ptr == nullptr); // Allocate the global memory tile loader for mask. 
using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; - // conctructor - Gmem_tile_mask gmem_mask(params, binfo, tidx); - // TODO: load fun as s + if (has_attn) { + // conctructor + Gmem_tile_mask gmem_mask(params, binfo, tidx); + // TODO: load fun as s + } Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); @@ -292,8 +295,11 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_o.move(begin); gmem_o_tmp.move(begin); if (Return_softmax) { gmem_s.move(begin); } - // TODO: mask move - gmem_mask.move(begin); + + if (has_attn) { + // TODO: mask move + gmem_mask.move(begin); + } gmem_softmax_lse.move(begin); @@ -323,8 +329,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_k.move(loop_step_idx); gmem_v.move(loop_step_idx); if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); } - // TODO: mask move as s - gmem_mask.move(loop_step_idx * steps_og); + if (has_attn) { + // TODO: mask move as s + gmem_mask.move(loop_step_idx * steps_og); + } } // Trigger the loads for K. @@ -408,21 +416,23 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // TODO acc_p += mask, index like gmem_s.store(frag_p, mask); // move(1) - using Frag_mask = fmha::Fragment_a; - // struct Fragment_a : public Fragment { - // struct Fragment_accumulator : public Fragment { - Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - - gmem_mask.load(frag_mask); - // acc_p.template add(frag_mask); - // acc_p.add(frag_mask); - // mask tranpose or not - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - acc_p[mi][ni].add(frag_mask[ni][mi]); + if (has_attn) { + using Frag_mask = fmha::Fragment_a; + // struct Fragment_a : public Fragment { + // struct Fragment_accumulator : public Fragment { + Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + + gmem_mask.load(frag_mask); + // acc_p.template add(frag_mask); + // acc_p.add(frag_mask); + // mask tranpose or not + for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { + for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { + acc_p[mi][ni].add(frag_mask[ni][mi]); + } } + gmem_mask.move(); } - gmem_mask.move(); // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 67cab2f1b..eba15bec6 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -16,6 +16,7 @@ def _get_block_size(device, head_dim, is_dropout): def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, dropout_p, softmax_scale, causal, return_softmax): + # import pdb; pdb.set_trace() if attn_mask is None: out, softmax_lse, *rest = flash_attn_cuda.fwd( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, From bc5aa56a00c865dbca0a2a2a8233fc122adf57b4 Mon Sep 17 00:00:00 2001 From: xh Date: Tue, 2 Aug 2022 17:30:12 +0800 Subject: [PATCH 07/71] add mask supprt --- csrc/flash_attn/src/fmha/gemm.h | 20 +++++ csrc/flash_attn/src/fmha/gmem_tile.h | 2 + csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 92 ++++++++++++++++++--- 3 files changed, 101 insertions(+), 13 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h index a142f0bf2..7c64db8c8 100644 --- a/csrc/flash_attn/src/fmha/gemm.h +++ 
b/csrc/flash_attn/src/fmha/gemm.h @@ -112,6 +112,12 @@ struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { return reinterpret_cast(&this->regs_[0])[ii]; } + // xh: Immutable access to the elements and cast to elem_type. + template< typename elem_type > + inline __device__ elem_type& elt(int ii) { + return reinterpret_cast(&this->regs_[0])[ii]; + } + // Immutable access to the elements with a cast. template< typename Cast_type > inline __device__ const Cast_type& elt_as(int ii) const { @@ -165,6 +171,12 @@ struct Fragment_b : public Fragment { //////////////////////////////////////////////////////////////////////////////////////////////////// +template< typename Layout > +struct Fragment_c : public Fragment { +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + struct Fragment_accumulator : public Fragment { // The base class. @@ -178,6 +190,14 @@ struct Fragment_accumulator : public Fragment { } } + template< typename elem_type, typename Other_fragment_ > + inline __device__ void add(const Other_fragment_ &other) { + for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { + this->elt(ii) = this->elt(ii) + other.template elt(ii); + + } + } + inline __device__ void mul_(const float other) { for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { this->elt(ii) *= other; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 8900ee8b2..540f72324 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -402,6 +402,8 @@ struct Gmem_tile_mma_s : public Base { dst.w = frag[ni][mi].reg(3); if( mask.any_valid(mi, ni) ) { Base::store(dst, mi, ni); + // uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; + // fmha::stg(ptr_ + offset, data); } } } diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index f646e1821..f8b4c4636 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -276,11 +276,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i bool has_attn = !(params.attn_mask_ptr == nullptr); // Allocate the global memory tile loader for mask. using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; - if (has_attn) { - // conctructor - Gmem_tile_mask gmem_mask(params, binfo, tidx); - // TODO: load fun as s - } + // conctructor + Gmem_tile_mask gmem_mask(params, binfo, tidx); + // TODO: load fun as s Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); @@ -291,6 +289,8 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i begin = Is_causal ? 
std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin; const int steps_og = steps; steps -= begin - begin_og; + // begin - begin_og = 0 + // steps -= 0 gmem_q.move(begin); gmem_o.move(begin); gmem_o_tmp.move(begin); @@ -301,6 +301,14 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_mask.move(begin); } + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("og begin %d\n", begin_og); + printf("begin %d\n", begin); + printf("og step %d\n", steps_og); + printf("begin %d\n", steps); + printf("loop_step_idx %d\n", loop_step_idx); + } + gmem_softmax_lse.move(begin); // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { @@ -416,27 +424,68 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // TODO acc_p += mask, index like gmem_s.store(frag_p, mask); // move(1) + // debug: before add mask + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + printf("before mask print acc_p\n"); + // for (int i = 0; i < acc_p[0][0].NUM_ELTS; i ++) { + for (int i = 0; i < 8; i ++) { + printf("i=%d, acc_p=%.6f \n", i, acc_p[0][0].elt(i)); + } + printf("\n"); + printf("end print acc_p\n"); + } + if (has_attn) { + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf(" add attn mask ====\n"); + } using Frag_mask = fmha::Fragment_a; + // template< + // // The type of the elements. + // typename Data_type_, + // // The number of elements. + // int NUM_ELTS_, // struct Fragment_a : public Fragment { // struct Fragment_accumulator : public Fragment { Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - gmem_mask.load(frag_mask); // acc_p.template add(frag_mask); // acc_p.add(frag_mask); // mask tranpose or not + __syncthreads(); + + for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - acc_p[mi][ni].add(frag_mask[ni][mi]); + acc_p[mi][ni].template add(frag_mask[ni][mi]); + } + } + + // debug: + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("print frag_mask %d\n", l); + // for (int i = 0; i < frag_mask[0][0].NUM_ELTS; i ++) { + for (int i = 0; i < 8; i ++) { + printf("i=%d, frag_mask=%.6f, %u\n", i, + frag_mask[0][0].elt(i), + frag_mask[0][0].elt(i)); } + printf("\n"); + printf("end print frag_mask\n"); } gmem_mask.move(); } - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); - // } + // debug: after add mask + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + printf("after mask print acc_p\n"); + // for (int i = 0; i < acc_p[0][0].NUM_ELTS; i ++) { + for (int i = 0; i < 8; i ++) { + printf("i=%d, acc_p=%.6f\n", i, acc_p[0][0].elt(i)); + } + printf("\n"); + printf("end print acc_p\n"); + } uint4 out[Gmem_tile_o::STGS_PER_LOOP]; if (!Is_first) { gmem_o_tmp.load(out, 0); } @@ -454,6 +503,23 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Convert from the accumulator type to FP32 for Softmax. 
softmax.unpack_noscale(acc_p); + // if (is_thread_0()) { + // for (int i_m = 0; i_m < Mma_tile_p::MMAS_M * 2; i_m ++){ + // printf("after_mask: loop_step_idx = %03d, l = %03d, threadIdx.x = %03d, threadIdx.y = %03d, i_m = %01d, accp ele(0) = %f, ele(1) = %f, ele(2) = %f, ele(3) = %f, ele(4) = %f, ele(5) = %f, ele(6) = %f, ele(7) = %f \n", + // loop_step_idx, l, threadIdx.x, threadIdx.y, i_m, + // softmax.elt_[i_m][0], + // softmax.elt_[i_m][1], + // softmax.elt_[i_m][2], + // softmax.elt_[i_m][3], + // softmax.elt_[i_m][4], + // softmax.elt_[i_m][5], + // softmax.elt_[i_m][6], + // softmax.elt_[i_m][7] + // ); + // } + // } + + // Apply the mask. // this impl is more like padding softmax.apply_mask(mask); @@ -717,9 +783,9 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx; // tidx_global = (blockIdx.x * params.h + blockIdx.y) * blockDim.x * 2 + threadIdx.x; // what is mean of 2? - if (tidx == 0) { - printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); - } + // if (tidx == 0) { + // printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); + // } // if (tidx == 1) { // printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); // } From a6f232b6416fcbe0a6e7371f628346aa613b42d7 Mon Sep 17 00:00:00 2001 From: xh Date: Thu, 4 Aug 2022 12:37:38 +0800 Subject: [PATCH 08/71] fix up --- csrc/flash_attn/src/fmha/gemm.h | 59 ++++++++++-- .../src/fmha_fprop_fp16_kernel.sm80.cu | 11 ++- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 91 +++++++++++++------ 3 files changed, 120 insertions(+), 41 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h index 7c64db8c8..e45064a65 100644 --- a/csrc/flash_attn/src/fmha/gemm.h +++ b/csrc/flash_attn/src/fmha/gemm.h @@ -36,6 +36,9 @@ #include #include +#include +#include + namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -113,10 +116,10 @@ struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { } // xh: Immutable access to the elements and cast to elem_type. - template< typename elem_type > - inline __device__ elem_type& elt(int ii) { - return reinterpret_cast(&this->regs_[0])[ii]; - } + // template< typename elem_type > + // inline __device__ elem_type& elt(int ii) { + // return reinterpret_cast(&this->regs_[0])[ii]; + // } // Immutable access to the elements with a cast. template< typename Cast_type > @@ -171,12 +174,28 @@ struct Fragment_b : public Fragment { //////////////////////////////////////////////////////////////////////////////////////////////////// -template< typename Layout > -struct Fragment_c : public Fragment { +template< typename Layout, typename elem_type > +struct Fragment_c : public Fragment { }; //////////////////////////////////////////////////////////////////////////////////////////////////// +template __device__ +inline float toFloat(T a) { + return (float)a; +} +template<> __device__ +inline float toFloat(half a) { + return __half2float(a); +} +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> __device__ +inline float toFloat(__nv_bfloat16 a) { + return __bfloat162float(a); +} +#endif + + struct Fragment_accumulator : public Fragment { // The base class. 
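Aside (illustrative sketch, not part of the patch series): the toFloat() overloads introduced in the hunk above and the addf() member changed in the next hunk together perform a widening add — half/bf16 mask or bias fragment elements are converted to fp32 and accumulated into the fp32 score registers before softmax. A minimal host-side sketch of that semantics follows; it compiles with nvcc, and the 8-element arrays are hypothetical stand-ins for a fragment's per-thread registers.

#include <cuda_fp16.h>   // __half, __half2float, __float2half
#include <cstdio>

int main() {
    // acc stands in for Fragment_accumulator's fp32 registers,
    // mask for a loaded half-precision attention-mask fragment.
    float acc[8];
    __half mask[8];
    for (int i = 0; i < 8; ++i) {
        acc[i]  = 1.0f;
        mask[i] = __float2half(i % 2 ? -10000.0f : 0.0f);  // additive mask values
    }
    // Same widening add as acc_p[mi][ni].addf(frag_mask[ni][mi]):
    // convert each half element to float, then accumulate in fp32.
    for (int i = 0; i < 8; ++i) {
        acc[i] += __half2float(mask[i]);
    }
    for (int i = 0; i < 8; ++i) {
        printf("acc[%d] = %f\n", i, acc[i]);
    }
    return 0;
}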
@@ -190,14 +209,34 @@ struct Fragment_accumulator : public Fragment { } } - template< typename elem_type, typename Other_fragment_ > - inline __device__ void add(const Other_fragment_ &other) { + template< typename Other_fragment_ > + inline __device__ void addf(const Other_fragment_ &other) { for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { - this->elt(ii) = this->elt(ii) + other.template elt(ii); - + this->elt(ii) = this->elt(ii) + toFloat(other.elt(ii)); } } + // cause invalid redeclaration of member function template + // template + // inline __device__ void addf(const Other_fragment_ &other); + + // // 特化 + // template + // inline __device__ void addf(const Other_fragment_ &other) { + // for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { + // this->elt(ii) = this->elt(ii) + __half2float(other.elt(ii)); + // } + // } + + // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // template + // inline __device__ void addf(const Other_fragment_ &other) { + // for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { + // this->elt(ii) = this->elt(ii) + __bfloat162float(other.elt(ii)); + // } + // } + // #endif + inline __device__ void mul_(const float other) { for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { this->elt(ii) *= other; diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index 193b77c3d..765132460 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -52,15 +52,15 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; launch_params.elts_per_thread = elts_per_head; - printf("configuration----"); + printf("configuration----\n"); printf("blocksize_c %d\n", blocksize_c); printf("loop steps %d\n", loop_steps); printf("Steps %zu\n", STEPS); printf("Cta tile M %d\n", M); - printf("Cta tile MMAS_M %zu\n", MMAS_M); - printf("Cta tile MMAS_N %zu\n", MMAS_N); + printf("Mma tile MMAS_M %zu\n", MMAS_M); + printf("Mma tile MMAS_N %zu\n", MMAS_N); printf("elts_per_thread %zu\n", elts_per_head); - printf("configuration----"); + printf("configuration----\n"); return; } @@ -86,6 +86,9 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } dim3 grid(launch_params.params.b, launch_params.params.h); + + printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + printf("block size: %d\n", Kernel_traits::THREADS); // grid: b, h (batch_size = len(cu_seq_q), 16), // block: 32 // kernel<<>> diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index f8b4c4636..d29b70e63 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -304,7 +304,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf("og begin %d\n", begin_og); printf("begin %d\n", begin); - printf("og step %d\n", steps_og); + printf("og step %d\n", steps_og); printf("begin %d\n", steps); printf("loop_step_idx %d\n", loop_step_idx); } @@ -431,7 +431,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i for (int i = 0; i < 8; i ++) { printf("i=%d, acc_p=%.6f \n", i, acc_p[0][0].elt(i)); } - printf("\n"); printf("end print acc_p\n"); } @@ -439,7 +438,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i if 
((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf(" add attn mask ====\n"); } - using Frag_mask = fmha::Fragment_a; + // method 1 + using Frag_mask = fmha::Fragment_c; + // method 2 + // using Frag_mask = fmha::Fragment_a; // template< // // The type of the elements. // typename Data_type_, @@ -448,29 +450,57 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // struct Fragment_a : public Fragment { // struct Fragment_accumulator : public Fragment { Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + gmem_mask.load(frag_mask); + // see the 463 // acc_p.template add(frag_mask); + // acc_p.template add(frag_mask); // acc_p.add(frag_mask); // mask tranpose or not __syncthreads(); - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - acc_p[mi][ni].template add(frag_mask[ni][mi]); + // acc_p[mi][ni].template addf(frag_mask[ni][mi]); + // acc_p[mi][ni].add(frag_mask[ni][mi]); + acc_p[mi][ni].addf(frag_mask[ni][mi]); } } // debug: if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("print frag_mask %d\n", l); + printf("print frag_mask: l=%d\n", l); + + float2 tmp_mask1 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[0])); + float2 tmp_mask2 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[0][0])); + float2 tmp_mask3 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[0][1])); + float2 tmp_mask4 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[1][0])); + printf("Per warp, threadIdx.x = %d, frag_mask[0] = %.6f, %.6f, frag_mask[0][0] = %.6f, %.6f\n", + threadIdx.x, tmp_mask1.x, tmp_mask1.y, tmp_mask2.x, tmp_mask2.y); + printf("Per warp, threadIdx.x = %d, frag_mask[0][1] = %.6f, %.6f, frag_mask[1][0] = %.6f, %.6f\n", + threadIdx.x, tmp_mask3.x, tmp_mask3.y, tmp_mask4.x, tmp_mask4.y); + + #pragma unroll + for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { + #pragma unroll + for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { + float2 tmp_mask2 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[mi][ki])); + printf("Per warp, threadIdx.x = %d, mi=%d, ni=%d, frag_mask[mi][ni]= %.6f, %.6f\n", + threadIdx.x, mi, ki, tmp_mask2.x, tmp_mask2.y); + } + } + + int num_elt = frag_mask[0][0].NUM_ELTS; // for (int i = 0; i < frag_mask[0][0].NUM_ELTS; i ++) { - for (int i = 0; i < 8; i ++) { - printf("i=%d, frag_mask=%.6f, %u\n", i, + // sometime correct, sometime wrong + for (int i = 0; i < num_elt; i ++) { + printf("i=%d, frag_mask=%.6f, hex=%d, %f\n", i, frag_mask[0][0].elt(i), - frag_mask[0][0].elt(i)); + frag_mask[0][0].elt(i), + // toFloat(frag_mask[0][0].elt(i)) + frag_mask[0][0].elt(i) + ); } - printf("\n"); printf("end print frag_mask\n"); } gmem_mask.move(); @@ -503,23 +533,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Convert from the accumulator type to FP32 for Softmax. 
softmax.unpack_noscale(acc_p); - // if (is_thread_0()) { - // for (int i_m = 0; i_m < Mma_tile_p::MMAS_M * 2; i_m ++){ - // printf("after_mask: loop_step_idx = %03d, l = %03d, threadIdx.x = %03d, threadIdx.y = %03d, i_m = %01d, accp ele(0) = %f, ele(1) = %f, ele(2) = %f, ele(3) = %f, ele(4) = %f, ele(5) = %f, ele(6) = %f, ele(7) = %f \n", - // loop_step_idx, l, threadIdx.x, threadIdx.y, i_m, - // softmax.elt_[i_m][0], - // softmax.elt_[i_m][1], - // softmax.elt_[i_m][2], - // softmax.elt_[i_m][3], - // softmax.elt_[i_m][4], - // softmax.elt_[i_m][5], - // softmax.elt_[i_m][6], - // softmax.elt_[i_m][7] - // ); - // } - // } - - // Apply the mask. // this impl is more like padding softmax.apply_mask(mask); @@ -610,6 +623,30 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); softmax.template pack(frag_p); // ? pack + + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + #pragma unroll + for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { + #pragma unroll + for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { + // 1st row - 4 elements per row. + float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; + + printf("softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); + printf("softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); + } + } + } + if (Return_softmax) { gmem_s.store(frag_p, mask); gmem_s.move(); From 2735ee9b5cfd5340c47c168f2803bd51b8954f8d Mon Sep 17 00:00:00 2001 From: xh Date: Thu, 4 Aug 2022 13:49:26 +0800 Subject: [PATCH 09/71] add bias --- csrc/flash_attn/src/fmha/gemm.h | 1 + csrc/flash_attn/src/fmha/gmem_tile.h | 37 +++++++++- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 78 ++++++++++++++++++++- flash_attn/flash_attn_interface.py | 12 ++-- 4 files changed, 120 insertions(+), 8 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h index e45064a65..2f9a9f275 100644 --- a/csrc/flash_attn/src/fmha/gemm.h +++ b/csrc/flash_attn/src/fmha/gemm.h @@ -211,6 +211,7 @@ struct Fragment_accumulator : public Fragment { template< typename Other_fragment_ > inline __device__ void addf(const Other_fragment_ &other) { + #pragma unroll for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { this->elt(ii) = this->elt(ii) + toFloat(other.elt(ii)); } diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 540f72324..3d18ead63 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -443,7 +443,42 @@ struct Gmem_tile_mma_mask : public Base { inline __device__ Gmem_tile_mma_mask(const Params ¶ms, const Block_info& binfo, const int tidx) : Base(params.attn_mask_ptr, params, binfo.bidb, binfo.bidh, tidx) { } - // TODO load data impl + + // Load from global memory to Fragment. 
+ template + inline __device__ void load(Fragment (&frag)[N][M]) { + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + uint4 dst; + Base::load(dst, mi, ni); + frag[ni][mi].reg(0) = dst.x; + frag[ni][mi].reg(2) = dst.y; + frag[ni][mi].reg(1) = dst.z; + frag[ni][mi].reg(3) = dst.w; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// attn bias struct like s, maybe later can reuse the above declaration +template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > +struct Gmem_tile_mma_bias : public Base { + + // The number of mmas in the vertical dimension. + static constexpr int M = Base::MMAS_M; + // The number of mmas in the horizontal dimension. + static constexpr int N = Base::MMAS_N; + // The type of the vectors stored by each STG. + using Type = typename Base::Type; + + // Ctor. + template< typename Params, typename Block_info > + inline __device__ Gmem_tile_mma_bias(const Params ¶ms, const Block_info& binfo, const int tidx) + : Base(params.attn_bias_ptr, params, binfo.bidb, binfo.bidh, tidx) { + } // Load from global memory to Fragment. template diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index d29b70e63..a343318d7 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -280,6 +280,13 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i Gmem_tile_mask gmem_mask(params, binfo, tidx); // TODO: load fun as s + bool has_bias = !(params.attn_bias_ptr == nullptr); + // Allocate the global memory tile loader for bias. + using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_mask; + // conctructor + Gmem_tile_bias gmem_bias(params, binfo, tidx); + // TODO: load fun as s + Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); // Wind gmem tiles to the correct position. @@ -301,6 +308,11 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_mask.move(begin); } + if (has_bias) { + // TODO: bias move + gmem_bias.move(begin); + } + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf("og begin %d\n", begin_og); printf("begin %d\n", begin); @@ -341,6 +353,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // TODO: mask move as s gmem_mask.move(loop_step_idx * steps_og); } + if (has_bias) { + gmem_bias.move(loop_step_idx * steps_og); + } } // Trigger the loads for K. 
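// Reference semantics for the additive mask/bias tiles wired in above, written as a plain
// host-side C++ sketch that the device path can be checked against. Everything here is
// illustrative: reference_attention, S and D are not names from this repo, and the mask and
// bias are assumed to be additive [seqlen_q, seqlen_k] terms, matching how attn_mask_ptr and
// attn_bias_ptr are indexed by the tile loaders; the addf() accumulation defined earlier
// performs the "+ mask" / "+ bias" step tile by tile on the fp32 accumulator.
#include <cmath>
#include <vector>

static std::vector<float> reference_attention(const std::vector<float> &q,    // [S, D], row-major
                                              const std::vector<float> &k,    // [S, D], row-major
                                              const std::vector<float> &v,    // [S, D], row-major
                                              const std::vector<float> &mask, // [S, S], additive
                                              const std::vector<float> &bias, // [S, S], additive
                                              int S, int D, float scale) {
    std::vector<float> out(S * D, 0.f), p(S, 0.f);
    for (int i = 0; i < S; ++i) {
        // Scores for row i: scale * (Q K^T) + bias + mask.
        float row_max = -INFINITY;
        for (int j = 0; j < S; ++j) {
            float s = 0.f;
            for (int d = 0; d < D; ++d) s += q[i * D + d] * k[j * D + d];
            p[j] = s * scale + bias[i * S + j] + mask[i * S + j];
            row_max = std::fmax(row_max, p[j]);
        }
        // Numerically stable row-wise softmax.
        float denom = 0.f;
        for (int j = 0; j < S; ++j) { p[j] = std::exp(p[j] - row_max); denom += p[j]; }
        // O = softmax(P) V.
        for (int j = 0; j < S; ++j)
            for (int d = 0; d < D; ++d) out[i * D + d] += (p[j] / denom) * v[j * D + d];
    }
    return out;
}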
@@ -450,7 +465,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // struct Fragment_a : public Fragment { // struct Fragment_accumulator : public Fragment { Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - + gmem_mask.load(frag_mask); // see the 463 // acc_p.template add(frag_mask); @@ -459,7 +474,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // mask tranpose or not __syncthreads(); + #pragma unroll for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { + #pragma unroll for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { // acc_p[mi][ni].template addf(frag_mask[ni][mi]); // acc_p[mi][ni].add(frag_mask[ni][mi]); @@ -506,6 +523,65 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_mask.move(); } + + if (has_bias) { + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf(" add attn mask ====\n"); + } + // method 1 + using Frag_bias = fmha::Fragment_c; + + Frag_bias frag_bias[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + gmem_bias.load(frag_bias); + + __syncthreads(); + + #pragma unroll + for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { + acc_p[mi][ni].addf(frag_bias[ni][mi]); + } + } + + // debug: + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("print frag_bias: l=%d\n", l); + + float2 tmp_mask1 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[0])); + float2 tmp_mask2 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[0][0])); + float2 tmp_mask3 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[0][1])); + float2 tmp_mask4 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[1][0])); + printf("Per warp, threadIdx.x = %d, frag_bias[0] = %.6f, %.6f, frag_bias[0][0] = %.6f, %.6f\n", + threadIdx.x, tmp_mask1.x, tmp_mask1.y, tmp_mask2.x, tmp_mask2.y); + printf("Per warp, threadIdx.x = %d, frag_bias[0][1] = %.6f, %.6f, frag_bias[1][0] = %.6f, %.6f\n", + threadIdx.x, tmp_mask3.x, tmp_mask3.y, tmp_mask4.x, tmp_mask4.y); + + #pragma unroll + for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { + #pragma unroll + for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { + float2 tmp_mask2 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[mi][ki])); + printf("Per warp, threadIdx.x = %d, mi=%d, ni=%d, frag_mask[mi][ni]= %.6f, %.6f\n", + threadIdx.x, mi, ki, tmp_mask2.x, tmp_mask2.y); + } + } + + int num_elt = frag_bias[0][0].NUM_ELTS; + for (int i = 0; i < num_elt; i ++) { + printf("i=%d, frag_mask=%.6f, hex=%d, %f\n", i, + frag_bias[0][0].elt(i), + frag_bias[0][0].elt(i), + // toFloat(frag_mask[0][0].elt(i)) + frag_bias[0][0].elt(i) + ); + } + printf("end print frag_bias\n"); + } + gmem_bias.move(); + } + + // debug: after add mask if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { printf("after mask print acc_p\n"); diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index eba15bec6..3593fda44 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -14,7 +14,7 @@ def _get_block_size(device, head_dim, is_dropout): return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128 -def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, dropout_p, +def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, attn_bias, dropout_p, softmax_scale, causal, 
return_softmax): # import pdb; pdb.set_trace() if attn_mask is None: @@ -25,7 +25,7 @@ def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_s else: out, softmax_lse, *rest = flash_attn_cuda.fwd( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, - False, causal, return_softmax, None, attn_mask, None + False, causal, return_softmax, None, attn_mask, attn_bias ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -121,14 +121,14 @@ def backward(ctx, dout, *args): class FlashAttnFunc(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, + def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, attn_bias, dropout_p, softmax_scale, causal, return_softmax): # Save rng_state because the backward pass will regenerate the dropout mask rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, softmax_lse, S_dmask = _flash_attn_forward( - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, attn_bias, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax ) ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) @@ -217,7 +217,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq return_attn_probs) -def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask=None, +def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask=None, attn_bias=None, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False): """dropout_p should be set to 0.0 during evaluation Arguments: @@ -246,7 +246,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ - return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, + return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, attn_bias, dropout_p, softmax_scale, causal, return_attn_probs) From 3232d8dd50d5c7be08ae1ecbca6eab3c4c92adfc Mon Sep 17 00:00:00 2001 From: xh Date: Thu, 4 Aug 2022 15:00:27 +0800 Subject: [PATCH 10/71] add template --- .../src/fmha_fprop_fp16_kernel.sm80.cu | 133 ++++++++++++++---- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 29 ++-- 2 files changed, 119 insertions(+), 43 deletions(-) diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index 765132460..7ccc78b9a 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -32,9 +32,9 @@ #include "fmha.h" #include "fmha_fprop_kernel_1xN.h" -template +template __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) { - fmha::device_1xN_loop(params); + fmha::device_1xN_loop(params); } template @@ -70,34 +70,111 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, const int smem_size = fmha::get_dynamic_smem_size() + (loop_steps > 1 ? smem_size_softmax_lse : 0); - // Work-around for gcc 7. 
It doesn't like nested BOOL_SWITCH. - // https://github.com/kokkos/kokkos-kernels/issues/349 - // https://github.com/HazyResearch/flash-attention/issues/21 - BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { - auto kernel = launch_params.params.is_causal - ? (launch_params.return_softmax - ? &fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax - ? &fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel); - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + bool has_attn = !(launch_params.params.attn_mask_ptr == nullptr); + bool has_bias = !(launch_params.params.attn_bias_ptr == nullptr); + + if (has_attn) + { + if (has_bias) { + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); + + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); + + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); } - dim3 grid(launch_params.params.b, launch_params.params.h); + }else{ + if (has_bias) { + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? 
&fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); - printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - printf("block size: %d\n", Kernel_traits::THREADS); - // grid: b, h (batch_size = len(cu_seq_q), 16), - // block: 32 - // kernel<<>> - // xh: static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA; - // static constexpr int THREADS_PER_WARP = 32; - kernel<<>>( - launch_params.params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - }); + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); + + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + } + } } void run_fmha_fp16_sm80(Launch_params &launch_params, diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index a343318d7..23117232b 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -198,7 +198,7 @@ constexpr size_t get_dynamic_smem_size(){ return Gemm_Q_K::SMEM_BYTES; } -template +template inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -273,14 +273,14 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); - bool has_attn = !(params.attn_mask_ptr == nullptr); + // bool has_attn = !(params.attn_mask_ptr == nullptr); // Allocate the global memory tile loader for mask. using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; // conctructor Gmem_tile_mask gmem_mask(params, binfo, tidx); // TODO: load fun as s - bool has_bias = !(params.attn_bias_ptr == nullptr); + // bool has_bias = !(params.attn_bias_ptr == nullptr); // Allocate the global memory tile loader for bias. 
using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_mask; // conctructor @@ -303,12 +303,12 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_o_tmp.move(begin); if (Return_softmax) { gmem_s.move(begin); } - if (has_attn) { + if constexpr (has_attn) { // TODO: mask move gmem_mask.move(begin); } - if (has_bias) { + if constexpr (has_bias) { // TODO: bias move gmem_bias.move(begin); } @@ -349,11 +349,11 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_k.move(loop_step_idx); gmem_v.move(loop_step_idx); if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); } - if (has_attn) { + if constexpr (has_attn) { // TODO: mask move as s gmem_mask.move(loop_step_idx * steps_og); } - if (has_bias) { + if constexpr (has_bias) { gmem_bias.move(loop_step_idx * steps_og); } } @@ -449,7 +449,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i printf("end print acc_p\n"); } - if (has_attn) { + if constexpr (has_attn) { if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf(" add attn mask ====\n"); } @@ -524,7 +524,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i } - if (has_bias) { + if constexpr (has_bias) { if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf(" add attn mask ====\n"); } @@ -883,7 +883,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void device_1xN_loop(const Params ¶ms) { // The block index for the batch. @@ -913,15 +913,14 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { // Tc, loop over k in algo2 line 6, blocksize_c in line 4 if (params.seqlen_k == blocksize_c) { // inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { - - fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); } else { const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; - fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { - fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx); + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx); } - fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1); + fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1); } } From f33f2adce59973a6eb396762bd4526b498d0ccfe Mon Sep 17 00:00:00 2001 From: xh Date: Thu, 4 Aug 2022 15:12:45 +0800 Subject: [PATCH 11/71] add test --- benchmarks/test_example.py | 181 ++++++++++++++++++++++++++++++++++--- 1 file changed, 170 insertions(+), 11 deletions(-) diff --git a/benchmarks/test_example.py b/benchmarks/test_example.py index 1711ac865..086d13463 100644 --- a/benchmarks/test_example.py +++ b/benchmarks/test_example.py @@ -5,29 +5,86 @@ import math from typing import Optional, Callable, List, Tuple, Sequence import numpy as np +import deepspeed from time import perf_counter_ns from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func -# bs = 1 -# seq = 20 -# head = 32 -# c_dim = 16 +def 
permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of + type bfloat16 + """ + d = t.dtype + if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): + with torch.cuda.amp.autocast(enabled=False): + s = torch.nn.functional.softmax(t, dim=dim) + else: + s = torch.nn.functional.softmax(t, dim=dim) + + return s + +def _attention(query, key, value, mask=None, biases=None) -> torch.Tensor: + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 0)) + + # import pdb; pdb.set_trace() + # [*, H, Q, K] + a = torch.matmul(query, key) + + print ("q * k: ", a) + + if biases is None: + biases = [] + for b in biases: + a += b + + print ("after bias:", a) + + if mask is not None: + a += mask + + print ("after mask:", a) + + a = softmax_no_cast(a, -1) + print ("softmax :", a) + + # [*, H, Q, C_hidden] + a = torch.matmul(a, value) + print ("p * v: ", a) + + return a + torch.manual_seed(0) # v2 bs = 1 -seq = 128 -head = 16 -c_dim = 32 +seq = 2 +head = 1 +c_dim = 16 + +# import pdb; pdb.set_trace() + +print (10 * "*" + "prepare data" + 10 * "*" ) +# dtype = torch.bfloat16 +dtype = torch.half +device = "cuda" orig_tensor = torch.stack( [ (i+1) * 0.1 * torch.ones((bs, seq, head, c_dim)) for i in range(seq) ] ,dim = 1 -).cuda().to(torch.bfloat16) +).cuda().to(dtype) +print ("tensor: ", orig_tensor) print ("origin shape: ", orig_tensor.shape) +# [bs, seq, seq, head, c_dim] batch_size = bs * seq seqlen = seq @@ -35,9 +92,22 @@ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=orig_tensor.device) -print (cu_seqlens) +print ("cu_seqlens: ", cu_seqlens) + +# [bs, seq, seq, head, c_dim] +orig_tensor = orig_tensor.permute([0, 1, 3, 2, 4]) +# [bs, seq, head, seq, c_dim] +print ("after permute: ", orig_tensor.shape) + +print (10 * "*" + "end prepare data" + 10 * "*" ) + +print (10 * "*" + "normal attn" + 10 * "*" ) +print ("normal attn: ", _attention(orig_tensor, orig_tensor, orig_tensor)) +print (10 * "*" + "end normal attn" + 10 * "*" ) + tensor_2d_pad = orig_tensor.reshape(-1, head, c_dim) +print (10 * "*" + "flash attn without mask" + 10 * "*" ) output3 = flash_attn_unpadded_func( tensor_2d_pad, tensor_2d_pad, @@ -50,5 +120,94 @@ softmax_scale = 1., # q has been scaled already ) -print (output3.shape, output3.shape) -output3 = output3.reshape((bs, seq, seq, head, c_dim)) \ No newline at end of file +print ("output3 shape: ", output3.shape) +output3 = output3.reshape((bs, seq, seq, head, c_dim)) +print ("output3: ", output3.shape) +print (10 * "*" + "end flash attn without mask" + 10 * "*" ) + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + +# mask = gen_attn_mask( +# ( +# # [bs, seq, h, seq, seq_k] +# # [bs, seq, 1, 1, seq_k] +# torch.ones( +# bs, +# seq, +# 1, +# 1, +# seq, +# dtype=dtype, +# device="cuda", +# ) +# > 0.2 +# ).type(dtype), +# -1e5, +# ) +# unicore mask +# torch.rand( +# n_batch, +# n_groups, +# 1, +# 1, +# last_dim, +# dtype=dtype, +# device=test_device, +# ) + +print (10 * "*" + "flash attn with mask" + 10 * "*" ) +mask = torch.randn( + bs, + seq, + 1, + 1, + seq, + dtype=dtype, + device="cuda", + ) + +# [bs, group, 1, 
1, seq_k] +seq_q = seq +seq_k = seq +print ("mask: ", mask.shape) +mask_exp = mask.expand([bs, seq_q, head, seq_q, seq_k]) +print ("mask_exp: ", mask_exp.shape) +mask_batch = mask_exp.reshape((bs*seq_q, head, seq_q, seq_k)) +print ("mask_exp: ", mask_batch.shape) + +print ("mask: ", mask_batch) +print ("tensor: ", tensor_2d_pad) +print ("mask maximum number :", mask_batch.abs().max()) + +# bs * seq +# batch_size, num_heads, max_seqlen_q, max_seqlen_k +output4 = flash_attn_unpadded_func(tensor_2d_pad, + tensor_2d_pad, + tensor_2d_pad, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + # None, + attn_mask=mask_batch, + attn_bias=mask_batch, + dropout_p=0.0, + softmax_scale=1.0) + +output4 = output4.reshape((bs, seq, seq, head, c_dim)) + +print ("output4: ", output4.shape) + +print (10 * "*" + "end flash attn with mask" + 10 * "*" ) + +print (10 * "*" + "normal attn with mask" + 10 * "*" ) +print ("normal attn: ", _attention(orig_tensor, orig_tensor, orig_tensor, mask=mask)) +print (10 * "*" + "end normal attn with mask" + 10 * "*" ) + +print ("all close on output3 and output4 max error", (output3 - output4).abs().max()) +print ("all close on output3 and output4 min error", (output3 - output4).abs().min()) +print ("all close on output3 and output4 num less min error", torch.sum( (output3 - output4).abs() <=(output3 - output4).abs().min() )) From 0402aa5e2f356ea7aa9fe57677b78fdb4073b130 Mon Sep 17 00:00:00 2001 From: xh Date: Thu, 4 Aug 2022 15:33:32 +0800 Subject: [PATCH 12/71] clean --- csrc/flash_attn/fmha_api.cpp | 1 - csrc/flash_attn/src/fmha/gemm.h | 29 +-- csrc/flash_attn/src/fmha/gmem_tile.h | 19 -- csrc/flash_attn/src/fmha/kernel_traits.h | 3 + .../src/fmha_fprop_fp16_kernel.sm80.cu | 19 -- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 184 +----------------- setup.py | 6 - 7 files changed, 7 insertions(+), 254 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 8ac9dfdd2..f076cf5f8 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -296,7 +296,6 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q attn_bias ? attn_bias->data_ptr() : nullptr ); - printf ("debug fmha api start test\n"); run_fmha_fp16_sm80(launch_params, /*configure=*/ true); // number of times random will be generated per thread, to offset philox counter in thc random // state diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h index 2f9a9f275..9d40713f2 100644 --- a/csrc/flash_attn/src/fmha/gemm.h +++ b/csrc/flash_attn/src/fmha/gemm.h @@ -115,12 +115,6 @@ struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { return reinterpret_cast(&this->regs_[0])[ii]; } - // xh: Immutable access to the elements and cast to elem_type. - // template< typename elem_type > - // inline __device__ elem_type& elt(int ii) { - // return reinterpret_cast(&this->regs_[0])[ii]; - // } - // Immutable access to the elements with a cast. template< typename Cast_type > inline __device__ const Cast_type& elt_as(int ii) const { @@ -194,6 +188,7 @@ inline float toFloat(__nv_bfloat16 a) { return __bfloat162float(a); } #endif +//////////////////////////////////////////////////////////////////////////////////////////////////// struct Fragment_accumulator : public Fragment { @@ -211,33 +206,13 @@ struct Fragment_accumulator : public Fragment { template< typename Other_fragment_ > inline __device__ void addf(const Other_fragment_ &other) { + // elt or reg? 
#pragma unroll for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { this->elt(ii) = this->elt(ii) + toFloat(other.elt(ii)); } } - // cause invalid redeclaration of member function template - // template - // inline __device__ void addf(const Other_fragment_ &other); - - // // 特化 - // template - // inline __device__ void addf(const Other_fragment_ &other) { - // for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { - // this->elt(ii) = this->elt(ii) + __half2float(other.elt(ii)); - // } - // } - - // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - // template - // inline __device__ void addf(const Other_fragment_ &other) { - // for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { - // this->elt(ii) = this->elt(ii) + __bfloat162float(other.elt(ii)); - // } - // } - // #endif - inline __device__ void mul_(const float other) { for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { this->elt(ii) *= other; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 3d18ead63..0b372b7f3 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -114,23 +114,6 @@ struct Gmem_tile_qkv { #pragma unroll for( int ii = 0; ii < LDGS; ++ii ) { fct.load(ii, preds[ii]); - // // The fetch registers. fetch_ - // -> inline __device__ void load(int ii, bool p) - // -> ldg(fetch_[ii], ptrs_[ii]); - // -> inline __device__ void ldg(uint4 &dst, const void *ptr) { - // dst = *reinterpret_cast(ptr); - // } - } - } - - // print data. - template - inline __device__ void print() { - // int row_ = tidx_ / THREADS_PER_ROW; - printf("print LDGS %d\n", LDGS); - for( int ii = 0; ii < LDGS; ++ii ) { - // char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; - printf("data: %f\n", *(elem_type *)(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes)); } } @@ -402,8 +385,6 @@ struct Gmem_tile_mma_s : public Base { dst.w = frag[ni][mi].reg(3); if( mask.any_valid(mi, ni) ) { Base::store(dst, mi, ni); - // uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; - // fmha::stg(ptr_ + offset, data); } } } diff --git a/csrc/flash_attn/src/fmha/kernel_traits.h b/csrc/flash_attn/src/fmha/kernel_traits.h index 809f19c52..c5f573dea 100644 --- a/csrc/flash_attn/src/fmha/kernel_traits.h +++ b/csrc/flash_attn/src/fmha/kernel_traits.h @@ -74,6 +74,9 @@ struct FMHA_kernel_traits { // Gmem_tile_mma_mask using Gmem_tile_mask = fmha::Gmem_tile_mma_mask; + // Gmem_tile_mma_bias + using Gmem_tile_bias = fmha::Gmem_tile_mma_bias; + // The shared memory tile to transpose S. 
using Smem_tile_st = fmha::Smem_tile_mma_transposed; diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index 7ccc78b9a..a9e84d3f9 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -51,20 +51,9 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; launch_params.elts_per_thread = elts_per_head; - - printf("configuration----\n"); - printf("blocksize_c %d\n", blocksize_c); - printf("loop steps %d\n", loop_steps); - printf("Steps %zu\n", STEPS); - printf("Cta tile M %d\n", M); - printf("Mma tile MMAS_M %zu\n", MMAS_M); - printf("Mma tile MMAS_N %zu\n", MMAS_N); - printf("elts_per_thread %zu\n", elts_per_head); - printf("configuration----\n"); return; } - printf("deub kernel sm80\n"); constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; // Don't need smem_size_softmax_lse if we're not looping const int smem_size = fmha::get_dynamic_smem_size() @@ -185,14 +174,6 @@ void run_fmha_fp16_sm80(Launch_params &launch_params, if (launch_params.params.d == 16) { if( launch_params.params.seqlen_k == 128 ) { using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>; - // xh: template using Cta_tile_p = fmha::Cta_tile_extd; - // -> using Cta_tile_extd = Cta_tile_; - // -> static constexpr int M = M_(STEP 16), N = N_(s 128), K = K_ (D 16); - // -> // The number of warps. - // static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_; - // WARPS_M is 1, WARPS_N is 4 - // STEP = 16 ??? run_fmha_fp16_sm80_loop_(launch_params, configure); } else if( launch_params.params.seqlen_k == 256 ) { using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 23117232b..a342977df 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -260,12 +260,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Allocate the global memory tile loader for Q. Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true); - // Gmem_tile_q is fmha::Gmem_tile_qkv; - // Gmem_tile_q is fmha::Gmem_tile_qkv, 取数逻辑在gmem_tile.h - // 指针q_ptr引入数据,构造 csrc/flash_attn/src/fmha/gmem_tile.h:67 - // row_stide is dim0, head_strid is dim1 - // constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8; - // using A_type = uint16_t; // Allocate the global memory tile loader for O. Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); @@ -282,7 +276,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // bool has_bias = !(params.attn_bias_ptr == nullptr); // Allocate the global memory tile loader for bias. - using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_mask; + using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; // conctructor Gmem_tile_bias gmem_bias(params, binfo, tidx); // TODO: load fun as s @@ -292,12 +286,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Wind gmem tiles to the correct position. static_assert(Cta_tile_p::N % Cta_tile_p::M == 0); const int begin_og = begin; - // begin is 0 begin = Is_causal ? 
std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin; const int steps_og = steps; steps -= begin - begin_og; - // begin - begin_og = 0 - // steps -= 0 gmem_q.move(begin); gmem_o.move(begin); gmem_o_tmp.move(begin); @@ -312,15 +303,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // TODO: bias move gmem_bias.move(begin); } - - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("og begin %d\n", begin_og); - printf("begin %d\n", begin); - printf("og step %d\n", steps_og); - printf("begin %d\n", steps); - printf("loop_step_idx %d\n", loop_step_idx); - } - gmem_softmax_lse.move(begin); // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { @@ -328,8 +310,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // } fmha::Mask mask(binfo, tidx, loop_step_idx); - // mask in used pad actually, actual_seqlen_k is params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb] - // why works for the difference output ? // Allocate the global memory tile loader for K. Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false); @@ -389,12 +369,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i __syncthreads(); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Gemm load k...\n"); - // // https://stackoverflow.com/questions/7397934/calling-template-function-within-template-class - // gmem_k.template print(); - // } - // Load the fragments for Q. gemm_q_k.load_q(); @@ -434,101 +408,28 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i fmha::Clear_accumulator::apply(acc_p); // Do this part of P = Q * K^T. - // ? M / N 参数? gemm_q_k(acc_p); // TODO acc_p += mask, index like gmem_s.store(frag_p, mask); - // move(1) - - // debug: before add mask - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - printf("before mask print acc_p\n"); - // for (int i = 0; i < acc_p[0][0].NUM_ELTS; i ++) { - for (int i = 0; i < 8; i ++) { - printf("i=%d, acc_p=%.6f \n", i, acc_p[0][0].elt(i)); - } - printf("end print acc_p\n"); - } if constexpr (has_attn) { - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf(" add attn mask ====\n"); - } - // method 1 using Frag_mask = fmha::Fragment_c; - // method 2 - // using Frag_mask = fmha::Fragment_a; - // template< - // // The type of the elements. - // typename Data_type_, - // // The number of elements. - // int NUM_ELTS_, - // struct Fragment_a : public Fragment { - // struct Fragment_accumulator : public Fragment { Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; gmem_mask.load(frag_mask); - // see the 463 - // acc_p.template add(frag_mask); - // acc_p.template add(frag_mask); - // acc_p.add(frag_mask); - // mask tranpose or not + // do we need sync ? 
__syncthreads(); #pragma unroll for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { #pragma unroll for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - // acc_p[mi][ni].template addf(frag_mask[ni][mi]); - // acc_p[mi][ni].add(frag_mask[ni][mi]); acc_p[mi][ni].addf(frag_mask[ni][mi]); } } - - // debug: - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("print frag_mask: l=%d\n", l); - - float2 tmp_mask1 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[0])); - float2 tmp_mask2 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[0][0])); - float2 tmp_mask3 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[0][1])); - float2 tmp_mask4 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[1][0])); - printf("Per warp, threadIdx.x = %d, frag_mask[0] = %.6f, %.6f, frag_mask[0][0] = %.6f, %.6f\n", - threadIdx.x, tmp_mask1.x, tmp_mask1.y, tmp_mask2.x, tmp_mask2.y); - printf("Per warp, threadIdx.x = %d, frag_mask[0][1] = %.6f, %.6f, frag_mask[1][0] = %.6f, %.6f\n", - threadIdx.x, tmp_mask3.x, tmp_mask3.y, tmp_mask4.x, tmp_mask4.y); - - #pragma unroll - for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - float2 tmp_mask2 = __half22float2(reinterpret_cast<__half2 &>(frag_mask[mi][ki])); - printf("Per warp, threadIdx.x = %d, mi=%d, ni=%d, frag_mask[mi][ni]= %.6f, %.6f\n", - threadIdx.x, mi, ki, tmp_mask2.x, tmp_mask2.y); - } - } - - int num_elt = frag_mask[0][0].NUM_ELTS; - // for (int i = 0; i < frag_mask[0][0].NUM_ELTS; i ++) { - // sometime correct, sometime wrong - for (int i = 0; i < num_elt; i ++) { - printf("i=%d, frag_mask=%.6f, hex=%d, %f\n", i, - frag_mask[0][0].elt(i), - frag_mask[0][0].elt(i), - // toFloat(frag_mask[0][0].elt(i)) - frag_mask[0][0].elt(i) - ); - } - printf("end print frag_mask\n"); - } gmem_mask.move(); } - if constexpr (has_bias) { - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf(" add attn mask ====\n"); - } - // method 1 using Frag_bias = fmha::Fragment_c; Frag_bias frag_bias[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; @@ -543,55 +444,8 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i acc_p[mi][ni].addf(frag_bias[ni][mi]); } } - - // debug: - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("print frag_bias: l=%d\n", l); - - float2 tmp_mask1 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[0])); - float2 tmp_mask2 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[0][0])); - float2 tmp_mask3 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[0][1])); - float2 tmp_mask4 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[1][0])); - printf("Per warp, threadIdx.x = %d, frag_bias[0] = %.6f, %.6f, frag_bias[0][0] = %.6f, %.6f\n", - threadIdx.x, tmp_mask1.x, tmp_mask1.y, tmp_mask2.x, tmp_mask2.y); - printf("Per warp, threadIdx.x = %d, frag_bias[0][1] = %.6f, %.6f, frag_bias[1][0] = %.6f, %.6f\n", - threadIdx.x, tmp_mask3.x, tmp_mask3.y, tmp_mask4.x, tmp_mask4.y); - - #pragma unroll - for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - float2 tmp_mask2 = __half22float2(reinterpret_cast<__half2 &>(frag_bias[mi][ki])); - printf("Per warp, threadIdx.x = %d, mi=%d, ni=%d, frag_mask[mi][ni]= %.6f, %.6f\n", - threadIdx.x, mi, ki, tmp_mask2.x, tmp_mask2.y); - } - } - - int num_elt = frag_bias[0][0].NUM_ELTS; - for (int i = 0; i < num_elt; i ++) { - printf("i=%d, frag_mask=%.6f, hex=%d, %f\n", i, - 
frag_bias[0][0].elt(i), - frag_bias[0][0].elt(i), - // toFloat(frag_mask[0][0].elt(i)) - frag_bias[0][0].elt(i) - ); - } - printf("end print frag_bias\n"); - } gmem_bias.move(); } - - - // debug: after add mask - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - printf("after mask print acc_p\n"); - // for (int i = 0; i < acc_p[0][0].NUM_ELTS; i ++) { - for (int i = 0; i < 8; i ++) { - printf("i=%d, acc_p=%.6f\n", i, acc_p[0][0].elt(i)); - } - printf("\n"); - printf("end print acc_p\n"); - } uint4 out[Gmem_tile_o::STGS_PER_LOOP]; if (!Is_first) { gmem_o_tmp.load(out, 0); } @@ -698,30 +552,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); softmax.template pack(frag_p); - // ? pack - - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - #pragma unroll - for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - // 1st row - 4 elements per row. - float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. - float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; - - printf("softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); - printf("softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); - } - } - } if (Return_softmax) { gmem_s.store(frag_p, mask); @@ -894,14 +724,6 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { const int tidx = threadIdx.x; const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx; - // tidx_global = (blockIdx.x * params.h + blockIdx.y) * blockDim.x * 2 + threadIdx.x; - // what is mean of 2? 
- // if (tidx == 0) { - // printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); - // } - // if (tidx == 1) { - // printf("blockIdx.x: %d, blockIdx.y: %d, threadIdx.x: %d, tidx_global: %d\n", bidb, bidh, tidx, tidx_global); - // } auto seeds = at::cuda::philox::unpack(params.philox_args); Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); @@ -910,9 +732,7 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { const int STEPS = (params.seqlen_q + M - 1) / M; constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - // Tc, loop over k in algo2 line 6, blocksize_c in line 4 if (params.seqlen_k == blocksize_c) { - // inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); } else { const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; diff --git a/setup.py b/setup.py index 2f48a6fda..8fad3ed94 100644 --- a/setup.py +++ b/setup.py @@ -148,12 +148,6 @@ def append_nvcc_threads(nvcc_extra_args): Path(this_dir) / 'csrc' / 'flash_attn' / 'src', Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include', ], - # add depends for modification of header file - depends = [ - Path(this_dir) / 'csrc' / 'flash_attn', - Path(this_dir) / 'csrc' / 'flash_attn' / 'src', - Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include', - ], ) ) From b4793faf4a176b00c3d56b6ab6e4b95318b22ee9 Mon Sep 17 00:00:00 2001 From: xh Date: Thu, 4 Aug 2022 15:37:20 +0800 Subject: [PATCH 13/71] clean code --- csrc/flash_attn/src/fmha/gmem_tile.h | 2 -- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 0b372b7f3..37a7fc53c 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -126,8 +126,6 @@ struct Gmem_tile_qkv { char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) { fmha::stg(ptr_, data[ii]); - // stg function, inline __device__ void stg(void *ptr, uint2 val) - // *reinterpret_cast(ptr) = val; } } } diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index a342977df..7d1b40a11 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -305,10 +305,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i } gmem_softmax_lse.move(begin); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("begin = %d, steps = %d\n", begin, steps); - // } - fmha::Mask mask(binfo, tidx, loop_step_idx); // Allocate the global memory tile loader for K. 
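// How the mask/bias flags reach the kernel: has_attn and has_bias are compile-time template
// parameters of device_1xN_, so the `if constexpr` blocks in the next hunk are compiled out
// entirely when no mask or bias pointer is passed. Below is a minimal stand-alone sketch of
// that dispatch pattern; run_cfg and dispatch are illustrative names, not functions from this
// repo, and the printf bodies stand in for the real kernel instantiations.
#include <cstdio>

template <bool Has_attn, bool Has_bias>
static void run_cfg() {
    if constexpr (Has_attn) { std::printf("mask path compiled in\n"); }
    if constexpr (Has_bias) { std::printf("bias path compiled in\n"); }
    if constexpr (!Has_attn && !Has_bias) { std::printf("plain attention path\n"); }
}

static void dispatch(bool has_attn, bool has_bias) {
    // Two runtime booleans (attn_mask_ptr / attn_bias_ptr being non-null) pick one of four
    // instantiations, mirroring the nested selection done in run_fmha_fp16_sm80_loop_.
    if (has_attn) { has_bias ? run_cfg<true, true>() : run_cfg<true, false>(); }
    else          { has_bias ? run_cfg<false, true>() : run_cfg<false, false>(); }
}

int main() {
    dispatch(/*has_attn=*/true, /*has_bias=*/false); // prints: mask path compiled in
    return 0;
}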
@@ -411,6 +407,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gemm_q_k(acc_p); // TODO acc_p += mask, index like gmem_s.store(frag_p, mask); + // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); + // } + if constexpr (has_attn) { using Frag_mask = fmha::Fragment_c; Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; From 3659eb0dc0d5710557ae000ded1ab863ca95e11d Mon Sep 17 00:00:00 2001 From: robotcator Date: Tue, 23 Aug 2022 16:29:27 +0800 Subject: [PATCH 14/71] add mask load --- csrc/flash_attn/fmha_api.cpp | 27 +- csrc/flash_attn/src/fmha/gmem_tile.h | 206 ++++++++++- .../src/fmha_fprop_fp16_kernel.sm80.cu | 327 ++++++++++-------- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 200 ++++++++--- tests/build.sh | 49 +++ tests/fmha_api.h | 24 ++ tests/test_forward.cu | 106 ++++++ 7 files changed, 722 insertions(+), 217 deletions(-) create mode 100644 tests/build.sh create mode 100644 tests/fmha_api.h create mode 100644 tests/test_forward.cu diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index f076cf5f8..3b90b7b0a 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -31,6 +31,10 @@ #include "fmha.h" +#ifdef DDEBUG_PRINT +#include "fmha_api.h" +#endif + #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -103,6 +107,26 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.attn_bias_ptr = attn_bias; +#ifdef DEBUG_PRINT + printf("========================================\n"); + printf("params.q_row_stride_in_elts = %d \n", params.q_row_stride_in_elts); + printf("params.k_row_stride_in_elts = %d \n", params.k_row_stride_in_elts); + printf("params.v_row_stride_in_elts = %d \n", params.v_row_stride_in_elts); + printf("params.q_head_stride_in_elts = %d \n", params.q_head_stride_in_elts); + printf("params.k_head_stride_in_elts = %d \n", params.k_head_stride_in_elts); + printf("params.v_head_stride_in_elts = %d \n", params.v_head_stride_in_elts); + printf("params.h = %d \n", params.h); + printf("params.b = %d \n", params.b); + printf("params.seqlen_q (max seq) = %d \n", params.seqlen_q); + printf("params.seqlen_k (max seq) = %d \n", params.seqlen_k); + printf("params.d = %d \n", params.d); + printf("params.o_row_stride_in_elts = %d \n", params.o_row_stride_in_elts); + printf("params.o_head_stride_in_elts = %d \n", params.o_head_stride_in_elts); + printf("params.s_stride_in_bytes = %d \n", params.s_stride_in_bytes); + printf("========================================\n"); +#endif + + // Set the different scale values. 
// const float scale_bmm1 = 1.f / sqrtf(d); const float scale_bmm1 = softmax_scale; @@ -736,7 +760,7 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size return { dq, dk, dv, softmax_d }; } - +#if !defined(DEBUG_USING_NVCC) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "Fused Multi-head Self-attention"; m.def("fwd", &mha_fwd, "Forward pass"); @@ -744,3 +768,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)"); m.def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)"); } +#endif \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 37a7fc53c..dc0ef20b3 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -41,11 +41,13 @@ template< // The number of rows of Q, K or V loaded by this tile. int ROWS_, // The number of columns. - int COLS + int COLS_ > struct Gmem_tile_qkv { using Cta_tile = Cta_tile_; + // xh + static constexpr int COLS = COLS_; static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8; // The size of each LDG. @@ -84,9 +86,41 @@ struct Gmem_tile_qkv { // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes); // Add the block index. +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("use_seqlen_q=%d\n", use_seqlen_q); + printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", + threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); + printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_LDG=%d, LDGS=%d\n", + threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_LDG, LDGS); + printf("\n"); + } + if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("use_seqlen_q=%d\n", use_seqlen_q); + printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", + threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); + printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_LDG=%d, LDGS=%d\n", + threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_LDG, LDGS); + printf("\n"); + } +#endif // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("use_seqlen_q=%d\n", use_seqlen_q); + printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d BYTES_PER_LDG=%d\n", + threadIdx.x, threadIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); + printf("\n"); + } + if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("use_seqlen_q=%d\n", use_seqlen_q); + printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, row_offset=%d BYTES_PER_LDG=%d\n", + threadIdx.x, blockIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); + printf("\n"); + } +#endif // Assemble the final pointer. 
ptr += row_offset + col * BYTES_PER_LDG; } @@ -213,7 +247,24 @@ struct Gmem_tile_o { row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); // Assemble the final pointer. ptr_ += row_offset + col * BYTES_PER_STG; - +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("print o parameter\n"); + printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", + threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); + printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_STG=%d, LDGS=%d\n", + threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); + printf("\n"); + } + if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("print o parameter\n"); + printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", + threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); + printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_STG=%d, LDGS=%d\n", + threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); + printf("\n"); + } +#endif // Is that thread active on the last STG? if( HAS_INCOMPLETE_STG ) { is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; @@ -356,6 +407,7 @@ struct Gmem_tile_mma_sd { template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > struct Gmem_tile_mma_s : public Base { + // mma matrix multiply // The number of mmas in the vertical dimension. static constexpr int M = Base::MMAS_M; // The number of mmas in the horizontal dimension. @@ -407,38 +459,154 @@ struct Gmem_tile_mma_s : public Base { //////////////////////////////////////////////////////////////////////////////////////////////////// // attn mask struct like s, maybe later can reuse the above declaration -template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > -struct Gmem_tile_mma_mask : public Base { +// template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > +// struct Gmem_tile_mma_mask : public Base { - // The number of mmas in the vertical dimension. - static constexpr int M = Base::MMAS_M; - // The number of mmas in the horizontal dimension. - static constexpr int N = Base::MMAS_N; +template< typename Cta_tile, int BYTES_PER_ELEMENT = 2> +struct Gmem_tile_mma_mask { + + using Mma_tile = fmha::Hmma_tile; // The type of the vectors stored by each STG. - using Type = typename Base::Type; + using StoreType = uint32_t; + + // static constexpr int LDG_ELEMENTS = 2 + // using Type = typename fmha::Uint_from_size_in_bytes< LDG_ELEMENTS * BYTES_PER_ELEMENT >::Type; + + // The number of MMAs in the M dimension. + static constexpr int M = Mma_tile::MMAS_M; + // The number of MMAs in the N dimension. + static constexpr int N = Mma_tile::MMAS_N; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS = Cta_tile::M; + static constexpr int COLS = Cta_tile::N; + + // The size of each LDG. + // load two elements of data + static constexpr int BYTES_PER_LDG = 2 * BYTES_PER_ELEMENT; + // The size of a row in bytes. + static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT; + + // The number of LDGS needed to store a chunk of the P matrix in total. 
+ // Tell me if has more efficient way + static constexpr int LDGS_PER_THREAD_PER_WARP = 4; + static constexpr int THREADS_PER_QUAD = 4; + static constexpr int COL_PER_MMA_PER_CTA = Cta_tile::THREADS_PER_WARP / THREADS_PER_QUAD; // Ctor. template< typename Params, typename Block_info > - inline __device__ Gmem_tile_mma_mask(const Params ¶ms, const Block_info& binfo, const int tidx) - : Base(params.attn_mask_ptr, params, binfo.bidb, binfo.bidh, tidx) { + inline __device__ Gmem_tile_mma_mask(const Params ¶ms, + // const uint32_t row_stride_in_elts, const uint32_t head_stride_in_elts, + const Block_info& binfo, const int tidx) + : ptr_(static_cast(params.attn_mask_ptr)) + // : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) + , actual_seqlen_q(binfo.actual_seqlen_q) + , actual_seqlen_k(binfo.actual_seqlen_k) + , tidx_(tidx) + { + row_stride_in_bytes = binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + + const int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + const int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + // this col is mean the 8x4 tile's cole + + row = warp_m * Mma_tile::M_PER_MMA + quad; + static_assert(Mma_tile::M_PER_MMA == 16, + "only support sm80 m16n8k16 tensor core"); + + col = warp_n * Mma_tile::N_PER_MMA + tid; + static_assert(Mma_tile::N_PER_MMA == 16, + "only support sm80 m16n8k16 tensor core"); + + // The distance between two blocks (in bytes). + // TODO: mask is [bs, head, seq_q, seq_k] + // The block index. + uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + + uint32_t row_offset = bidx * params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT; + // row_offset = (uint32_t)(row * row_stride_in_bytes); + row_offset += (uint32_t)(row * params.seqlen_k * BYTES_PER_ELEMENT); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", + tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); + printf("bidb=%d, bidh=%d, param.h=%d\n", binfo.bidb, binfo.bidh, params.h); + printf("\n"); + } +#endif + // do we need to move col first if seklen_k > cols + ptr_ += row_offset; } // Load from global memory to Fragment. - template - inline __device__ void load(Fragment (&frag)[N][M]) { + template + inline __device__ void load(Fragment (&frag)[M][N]) { + // using Fragment = typename fmha::Fragment; + // like Fragment_a + + const void *ptrs[LDGS_PER_THREAD_PER_WARP]; + uint32_t preds[LDGS_PER_THREAD_PER_WARP]; + #pragma unroll for( int mi = 0; mi < M; mi++ ) { #pragma unroll for( int ni = 0; ni < N; ni++ ) { - uint4 dst; - Base::load(dst, mi, ni); - frag[ni][mi].reg(0) = dst.x; - frag[ni][mi].reg(2) = dst.y; - frag[ni][mi].reg(1) = dst.z; - frag[ni][mi].reg(3) = dst.w; + #pragma unroll + for ( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for (int jj = 0; jj < 2; ++jj ) { + int offset = ii * 2 + jj; + const int current_row = mi * ROWS + ii * 8; + const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + // 8 is actually col of half data now, for more general case ? 
+#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d", + mi, ni, ii, jj, offset, current_row, current_col); + printf("\n"); + } +#endif + // the row is already in the right position + ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + + (uint32_t)current_col * BYTES_PER_ELEMENT; + + preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) < min(COLS, actual_seqlen_k)); + } + } + + // load data + Ldg_functor fct(frag[mi][ni].regs_, ptrs); + #pragma unroll + for(int kk = 0; kk < LDGS_PER_THREAD_PER_WARP; ++kk ) { + fct.load(kk, preds[kk]); + } } } } + + inline __device__ void move(const int steps = 1) { + ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; + this->actual_seqlen_q -= ROWS * steps; + } + + int row; + int col; + uint32_t row_stride_in_bytes; + // The pointer. + char *ptr_; + int actual_seqlen_q; + int actual_seqlen_k; + const int tidx_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index a9e84d3f9..d0af0c850 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -62,108 +62,124 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, bool has_attn = !(launch_params.params.attn_mask_ptr == nullptr); bool has_bias = !(launch_params.params.attn_bias_ptr == nullptr); - if (has_attn) - { - if (has_bias) { - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // https://github.com/kokkos/kokkos-kernels/issues/349 - // https://github.com/HazyResearch/flash-attention/issues/21 - BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { - auto kernel = launch_params.params.is_causal - ? (launch_params.return_softmax - ? &fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax - ? &fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel); - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(launch_params.params.b, launch_params.params.h); +#ifdef DEBUG_PRINT + printf ("has_attn=%d, has_bias=%d\n", has_attn, has_bias); +#endif - // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // printf("block size: %d\n", Kernel_traits::THREADS); - kernel<<>>( - launch_params.params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - }); - }else{ - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // https://github.com/kokkos/kokkos-kernels/issues/349 - // https://github.com/HazyResearch/flash-attention/issues/21 - BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { - auto kernel = launch_params.params.is_causal - ? (launch_params.return_softmax - ? &fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax - ? 
&fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel); - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(launch_params.params.b, launch_params.params.h); + // attn + bias on + // IsDropoutConst off + auto kernel = &fmha_fprop_fp16_sm80_loop_kernel; + dim3 grid(launch_params.params.b, launch_params.params.h); - // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // printf("block size: %d\n", Kernel_traits::THREADS); - kernel<<>>( - launch_params.params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - }); - } - }else{ - if (has_bias) { - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // https://github.com/kokkos/kokkos-kernels/issues/349 - // https://github.com/HazyResearch/flash-attention/issues/21 - BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { - auto kernel = launch_params.params.is_causal - ? (launch_params.return_softmax - ? &fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax - ? &fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel); - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(launch_params.params.b, launch_params.params.h); + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + - // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // printf("block size: %d\n", Kernel_traits::THREADS); - kernel<<>>( - launch_params.params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - }); - }else{ - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // https://github.com/kokkos/kokkos-kernels/issues/349 - // https://github.com/HazyResearch/flash-attention/issues/21 - BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { - auto kernel = launch_params.params.is_causal - ? (launch_params.return_softmax - ? &fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax - ? &fmha_fprop_fp16_sm80_loop_kernel - : &fmha_fprop_fp16_sm80_loop_kernel); - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(launch_params.params.b, launch_params.params.h); + // if (has_attn) + // { + // if (has_bias) { + // // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // // https://github.com/kokkos/kokkos-kernels/issues/349 + // // https://github.com/HazyResearch/flash-attention/issues/21 + // BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + // auto kernel = launch_params.params.is_causal + // ? (launch_params.return_softmax + // ? &fmha_fprop_fp16_sm80_loop_kernel + // : &fmha_fprop_fp16_sm80_loop_kernel) + // : (launch_params.return_softmax + // ? 
&fmha_fprop_fp16_sm80_loop_kernel + // : &fmha_fprop_fp16_sm80_loop_kernel); + // if( smem_size >= 48 * 1024 ) { + // FMHA_CHECK_CUDA(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // } + // dim3 grid(launch_params.params.b, launch_params.params.h); - // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // printf("block size: %d\n", Kernel_traits::THREADS); - kernel<<>>( - launch_params.params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - }); - } - } + // // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // // printf("block size: %d\n", Kernel_traits::THREADS); + // kernel<<>>( + // launch_params.params); + // FMHA_CHECK_CUDA(cudaPeekAtLastError()); + // }); + // }else{ + // // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // // https://github.com/kokkos/kokkos-kernels/issues/349 + // // https://github.com/HazyResearch/flash-attention/issues/21 + // BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + // auto kernel = launch_params.params.is_causal + // ? (launch_params.return_softmax + // ? &fmha_fprop_fp16_sm80_loop_kernel + // : &fmha_fprop_fp16_sm80_loop_kernel) + // : (launch_params.return_softmax + // ? &fmha_fprop_fp16_sm80_loop_kernel + // : &fmha_fprop_fp16_sm80_loop_kernel); + // if( smem_size >= 48 * 1024 ) { + // FMHA_CHECK_CUDA(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // } + // dim3 grid(launch_params.params.b, launch_params.params.h); + + // // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // // printf("block size: %d\n", Kernel_traits::THREADS); + // kernel<<>>( + // launch_params.params); + // FMHA_CHECK_CUDA(cudaPeekAtLastError()); + // }); + // } + // }else{ + // if (has_bias) { + // // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // // https://github.com/kokkos/kokkos-kernels/issues/349 + // // https://github.com/HazyResearch/flash-attention/issues/21 + // BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + // auto kernel = launch_params.params.is_causal + // ? (launch_params.return_softmax + // ? &fmha_fprop_fp16_sm80_loop_kernel + // : &fmha_fprop_fp16_sm80_loop_kernel) + // : (launch_params.return_softmax + // ? &fmha_fprop_fp16_sm80_loop_kernel + // : &fmha_fprop_fp16_sm80_loop_kernel); + // if( smem_size >= 48 * 1024 ) { + // FMHA_CHECK_CUDA(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // } + // dim3 grid(launch_params.params.b, launch_params.params.h); + + // // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // // printf("block size: %d\n", Kernel_traits::THREADS); + // kernel<<>>( + // launch_params.params); + // FMHA_CHECK_CUDA(cudaPeekAtLastError()); + // }); + // }else{ + // // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // // https://github.com/kokkos/kokkos-kernels/issues/349 + // // https://github.com/HazyResearch/flash-attention/issues/21 + // BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + // auto kernel = launch_params.params.is_causal + // ? (launch_params.return_softmax + // ? &fmha_fprop_fp16_sm80_loop_kernel + // : &fmha_fprop_fp16_sm80_loop_kernel) + // : (launch_params.return_softmax + // ? 
&fmha_fprop_fp16_sm80_loop_kernel + // : &fmha_fprop_fp16_sm80_loop_kernel); + // if( smem_size >= 48 * 1024 ) { + // FMHA_CHECK_CUDA(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // } + // dim3 grid(launch_params.params.b, launch_params.params.h); + + // // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // // printf("block size: %d\n", Kernel_traits::THREADS); + // kernel<<>>( + // launch_params.params); + // FMHA_CHECK_CUDA(cudaPeekAtLastError()); + // }); + // } + // } } void run_fmha_fp16_sm80(Launch_params &launch_params, @@ -173,62 +189,69 @@ void run_fmha_fp16_sm80(Launch_params &launch_params, auto dprops = at::cuda::getCurrentDeviceProperties(); if (launch_params.params.d == 16) { if( launch_params.params.seqlen_k == 128 ) { + // int S, int D, int STEP, int WARPS_M, int WARPS_N, + // D is [hidden_dim] using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>; run_fmha_fp16_sm80_loop_(launch_params, configure); - } else if( launch_params.params.seqlen_k == 256 ) { - using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else { - // TD [2022-05-15] 512 gives wrong results rn - // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>; - using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } - } else if (launch_params.params.d == 32) { - if( launch_params.params.seqlen_k == 128 ) { - using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else if( launch_params.params.seqlen_k == 256 ) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } - } else if (launch_params.params.d == 64) { - if( launch_params.params.seqlen_k == 128 ) { - using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else if( launch_params.params.seqlen_k >= 256 ) { - if (dprops->major == 8 && dprops->minor >= 0) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else if (dprops->major == 7 && dprops->minor == 5) { - if (launch_params.is_dropout) { // Need to use the same block size as backward - using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } - } - } - } else if (launch_params.params.d == 128) { - if( launch_params.params.seqlen_k == 128 ) { - using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else { - if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) { - // TD [2022-06-05] Keep K in registers to reduce register spilling - // Gives about 6% speedup compared to using block size 128. 
- using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else { // Need to use the same block size as backward - using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } - } - } + } + // else if( launch_params.params.seqlen_k == 256 ) { + // using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { + // // TD [2022-05-15] 512 gives wrong results rn + // // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>; + // using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } + } + // debug on comments + // else if (launch_params.params.d == 32) { + // if( launch_params.params.seqlen_k == 128 ) { + // using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else if( launch_params.params.seqlen_k == 256 ) { + // using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { + // using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } + // } else if (launch_params.params.d == 64) { + // if( launch_params.params.seqlen_k == 128 ) { + // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else if( launch_params.params.seqlen_k >= 256 ) { + // if (dprops->major == 8 && dprops->minor >= 0) { + // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else if (dprops->major == 7 && dprops->minor == 5) { + // if (launch_params.is_dropout) { // Need to use the same block size as backward + // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { + // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } + // } + // } + // } else if (launch_params.params.d == 128) { + // if( launch_params.params.seqlen_k == 128 ) { + // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { + // if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) { + // // TD [2022-06-05] Keep K in registers to reduce register spilling + // // Gives about 6% speedup compared to using block size 128. 
+ // using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { // Need to use the same block size as backward + // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } + // } + // } + // debug on comments + // if (launch_params.params.d == 64) { // // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; // // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>; diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 7d1b40a11..0870d009a 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -174,6 +174,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base; +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // Cta_tile_p + printf("Cta_tile_p::M = %d, Cta_tile_p::N = %d, Cta_tile_p::K = %d\n", + Cta_tile_p::M, Cta_tile_p::N, Cta_tile_p::K); + printf("Cta_tile_p::WARPS_M = %d, Cta_tile_p::WARPS_N = %d, Cta_tile_p::WARPS_K = %d\n", + Cta_tile_p::WARPS_M, Cta_tile_p::WARPS_N, Cta_tile_p::WARPS_K); + printf("Cta_tile_p::WARPS_PER_CTA = %d, Cta_tile_p::THREADS_PER_WARP = %d, Cta_tile_p::THREADS_PER_CTA = %d\n", + Cta_tile_p::WARPS_PER_CTA, Cta_tile_p::THREADS_PER_WARP, Cta_tile_p::THREADS_PER_CTA); + printf("\n"); + + // Cta_tile_o + printf("Cta_tile_o::M = %d, Cta_tile_o::N = %d, Cta_tile_o::K = %d\n", + Cta_tile_o::M, Cta_tile_o::N, Cta_tile_o::K); + printf("Cta_tile_o::WARPS_M = %d, Cta_tile_o::WARPS_N = %d, Cta_tile_o::WARPS_K = %d\n", + Cta_tile_o::WARPS_M, Cta_tile_o::WARPS_N, Cta_tile_o::WARPS_K); + printf("Cta_tile_o::WARPS_PER_CTA = %d, Cta_tile_o::THREADS_PER_WARP = %d, Cta_tile_o::THREADS_PER_CTA = %d\n", + Cta_tile_o::WARPS_PER_CTA, Cta_tile_o::THREADS_PER_WARP, Cta_tile_o::THREADS_PER_CTA); + printf("\n"); + + // Mma_tile_p + printf("Mma_tile_p::MMAS_M = %d, Mma_tile_p::MMAS_N = %d, Mma_tile_p::MMAS_K = %d\n", + Mma_tile_p::MMAS_M, Mma_tile_p::MMAS_N, Mma_tile_p::MMAS_K); + // The number of elements computed with a single CTA-MMA. 
+ printf("Mma_tile_p::M_PER_MMA_PER_CTA = %d, Mma_tile_p::N_PER_MMA_PER_CTA = %d, Mma_tile_p::K_PER_MMA_PER_CTA = %d\n", + Mma_tile_p::M_PER_MMA_PER_CTA, Mma_tile_p::N_PER_MMA_PER_CTA, Mma_tile_p::K_PER_MMA_PER_CTA); + printf("\n"); + + // Mma_tile_o + printf("Mma_tile_o::MMAS_M = %d, Mma_tile_o::MMAS_N = %d, Mma_tile_o::MMAS_K = %d\n", + Mma_tile_o::MMAS_M, Mma_tile_o::MMAS_N, Mma_tile_o::MMAS_K); + printf("Mma_tile_o::M_PER_MMA_PER_CTA = %d, Mma_tile_o::N_PER_MMA_PER_CTA = %d, Mma_tile_o::K_PER_MMA_PER_CTA = %d\n", + Mma_tile_o::M_PER_MMA_PER_CTA, Mma_tile_o::N_PER_MMA_PER_CTA, Mma_tile_o::K_PER_MMA_PER_CTA); + printf("\n"); + + // Gmem_tile_q + printf("Gmem_tile_q::BYTES_PER_ELEMENT = %d, Gmem_tile_q::ROWS = %d, Gmem_tile_q::COLS = %d, Gmem_tile_q::LDGS = %d, Gmem_tile_q::THREADS_PER_ROW = %d, Gmem_tile_q::ROWS_PER_LDG=%d\n", + Gmem_tile_q::BYTES_PER_ELEMENT, Gmem_tile_q::ROWS, Gmem_tile_q::COLS, Gmem_tile_q::LDGS, + Gmem_tile_q::THREADS_PER_ROW, Gmem_tile_q::ROWS_PER_LDG); + printf("\n"); + + // Gmem_tile_k + printf("Gmem_tile_k::BYTES_PER_ELEMENT = %d, Gmem_tile_k::ROWS = %d, Gmem_tile_k::COLS = %d, Gmem_tile_k::LDGS = %d, Gmem_tile_q::THREADS_PER_ROW = %d, Gmem_tile_q::ROWS_PER_LDG=%d\n", + Gmem_tile_k::BYTES_PER_ELEMENT, Gmem_tile_k::ROWS, Gmem_tile_k::COLS, Gmem_tile_k::LDGS, + Gmem_tile_q::THREADS_PER_ROW, Gmem_tile_q::ROWS_PER_LDG); + printf("\n"); + + // Gmem_tile_v + printf("Gmem_tile_v::BYTES_PER_ELEMENT = %d, Gmem_tile_v::ROWS = %d, Gmem_tile_v::COLS = %d, Gmem_tile_v::LDGS = %d, Gmem_tile_q::THREADS_PER_ROW = %d, Gmem_tile_q::ROWS_PER_LDG=%d\n", + Gmem_tile_v::BYTES_PER_ELEMENT, Gmem_tile_v::ROWS, Gmem_tile_v::COLS, Gmem_tile_v::LDGS, + Gmem_tile_q::THREADS_PER_ROW, Gmem_tile_q::ROWS_PER_LDG); + printf("\n"); + + // Gmem_tile_o + printf("Gmem_tile_o::ROWS = %d, Gmem_tile_o::COLS = %d, Gmem_tile_o::STGS = %d, Gmem_tile_o::STGS_PER_LOOP = %d\n", + Gmem_tile_o::ROWS, Gmem_tile_o::COLS, Gmem_tile_o::STGS, Gmem_tile_o::STGS_PER_LOOP); + printf("\n"); + + // Gmem_tile_s + printf("Gmem_tile_s::M = %d, Gmem_tile_s::N = %d\n", + Gmem_tile_s::M, Gmem_tile_s::N); + printf("\n"); + + // Gmem_softmax_sum + printf("Gmem_softmax_sum::MMAS_M = %d, Gmem_softmax_sum::ROWS = %d\n", + Gmem_softmax_sum::MMAS_M, Gmem_softmax_sum::ROWS); + printf("\n"); + + // Gemm1 + printf("Gemm1::SHARE_SMEM_FOR_K_AND_V = %d, Gemm1::SMEM_OFFSET_O = %d, Gemm1::SMEM_OFFSET_SOFTMAX = %d, Gemm1::SMEM_OFFSET_V = %d, Gemm1::SMEM_OFFSET_V = %d\n", + Gemm1::SHARE_SMEM_FOR_K_AND_V, Gemm1::SMEM_OFFSET_O, Gemm1::SMEM_OFFSET_SOFTMAX, Gemm1::SMEM_OFFSET_V, Gemm1::SMEM_OFFSET_V); + printf("\n"); + + // Softmax + printf("Softmax::WARPS_M = %d, Softmax::WARPS_N = %d, Softmax::MMAS_M = %d, Softmax::MMAS_N = %d\n", + Softmax::WARPS_M, Softmax::WARPS_N, Softmax::MMAS_M, Softmax::MMAS_N); + printf("\n"); + } +#endif + // Shared memory. extern __shared__ char smem_[]; @@ -276,9 +357,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // bool has_bias = !(params.attn_bias_ptr == nullptr); // Allocate the global memory tile loader for bias. 
- using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; - // conctructor - Gmem_tile_bias gmem_bias(params, binfo, tidx); + // using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; + // // conctructor + // Gmem_tile_bias gmem_bias(params, binfo, tidx); // TODO: load fun as s Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); @@ -294,15 +375,18 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_o_tmp.move(begin); if (Return_softmax) { gmem_s.move(begin); } - if constexpr (has_attn) { + // if constexpr (has_attn) { + if (!(params.attn_mask_ptr == nullptr)) { // TODO: mask move gmem_mask.move(begin); } - if constexpr (has_bias) { - // TODO: bias move - gmem_bias.move(begin); - } + // // if constexpr (has_bias) { + // if (!(params.attn_bias_ptr == nullptr)) { + // // TODO: bias move + // gmem_bias.move(begin); + // } + gmem_softmax_lse.move(begin); fmha::Mask mask(binfo, tidx, loop_step_idx); @@ -325,13 +409,15 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_k.move(loop_step_idx); gmem_v.move(loop_step_idx); if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); } - if constexpr (has_attn) { - // TODO: mask move as s - gmem_mask.move(loop_step_idx * steps_og); - } - if constexpr (has_bias) { - gmem_bias.move(loop_step_idx * steps_og); - } + // if constexpr (has_attn) { + // if (!(params.attn_mask_ptr == nullptr)) { + // // TODO: mask move as s, with col move + // gmem_mask.move(loop_step_idx * steps_og); + // } + // // if constexpr (has_bias) { + // if (!(params.attn_bias_ptr == nullptr)) { + // gmem_bias.move(loop_step_idx * steps_og); + // } } // Trigger the loads for K. @@ -341,6 +427,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Trigger the loads for V. gmem_v.load(); + using Frag_mask = fmha::Fragment_c; + Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + gmem_mask.template load(frag_mask); + if (!Is_first) { __syncthreads(); } float p_prev_lse[Mma_tile_p::MMAS_M * 2]; @@ -411,41 +501,55 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); // } - if constexpr (has_attn) { - using Frag_mask = fmha::Fragment_c; - Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - - gmem_mask.load(frag_mask); - // do we need sync ? - __syncthreads(); - - #pragma unroll - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - acc_p[mi][ni].addf(frag_mask[ni][mi]); +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { + for (int ii = 0; ii < Mma_tile_p::MMAS_M; ii ++) { + for (int jj = 0; jj < Mma_tile_p::MMAS_N; jj ++) { + for (int kk = 0; kk < acc_p[ii][jj].NUM_ELTS; kk ++) { + printf("ii=%d, jj=%d, kk=%d, acc_p=%.6f\n", ii, jj, kk, acc_p[ii][jj].elt(kk)); + } } } - gmem_mask.move(); } +#endif - if constexpr (has_bias) { - using Frag_bias = fmha::Fragment_c; + // if constexpr (has_attn) { + // if (!(params.attn_mask_ptr == nullptr)) { + // using Frag_mask = fmha::Fragment_c; + // Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + + // gmem_mask.load(frag_mask); + // // do we need sync ? 
+ // __syncthreads(); + + // #pragma unroll + // for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { + // #pragma unroll + // for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { + // acc_p[mi][ni].addf(frag_mask[ni][mi]); + // } + // } + // gmem_mask.move(); + // } + + // if constexpr (has_bias) { + // if (!(params.attn_bias_ptr == nullptr)) { + // using Frag_bias = fmha::Fragment_c; - Frag_bias frag_bias[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - gmem_bias.load(frag_bias); + // Frag_bias frag_bias[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; + // gmem_bias.load(frag_bias); - __syncthreads(); + // __syncthreads(); - #pragma unroll - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - acc_p[mi][ni].addf(frag_bias[ni][mi]); - } - } - gmem_bias.move(); - } + // #pragma unroll + // for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { + // #pragma unroll + // for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { + // acc_p[mi][ni].addf(frag_bias[ni][mi]); + // } + // } + // gmem_bias.move(); + // } uint4 out[Gmem_tile_o::STGS_PER_LOOP]; if (!Is_first) { gmem_o_tmp.load(out, 0); } @@ -481,7 +585,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i if (!Is_first) { smem_softmax_lse.store_pair(p_prev_lse, l % 2); // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; } - for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; } + for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { + p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; + } } // Trigger the load for the next LSE values. @@ -539,6 +645,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // // Finalize softmax on the accumulators of P^T. 
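For reference, the arithmetic this part of the kernel implements per (batch, head) block is: S = softmax_scale * Q K^T, plus an optional additive bias, plus an additive mask (the tests fill the mask with 0 at kept positions and a large negative value such as -3e4 at dropped ones), then a numerically stable row-wise softmax and P V. The host-side sketch below spells that out under those assumptions; it is a plain reference, not the kernel's tiled code path.

// Host-side reference of one attention block: S = scale * Q K^T (+ bias) (+ mask),
// row-wise softmax, then O = P V. Sizes are tiny and hard-coded; illustration only.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int M = 2, N = 3, D = 4;                       // queries, keys, head dim
    std::vector<float> Q(M * D, 0.1f), K(N * D, 0.2f), V(N * D, 0.3f);
    std::vector<float> mask(M * N, 0.f), bias(M * N, 0.f), O(M * D, 0.f);
    mask[1] = -3e4f;                                     // mask out key 1 for query 0
    const float scale = 1.f / std::sqrt(float(D));

    for (int i = 0; i < M; ++i) {
        std::vector<float> s(N);
        float row_max = -INFINITY, row_sum = 0.f;
        for (int j = 0; j < N; ++j) {                    // S = scale * Q K^T + bias + mask
            float acc = 0.f;
            for (int d = 0; d < D; ++d) acc += Q[i * D + d] * K[j * D + d];
            s[j] = scale * acc + bias[i * N + j] + mask[i * N + j];
            row_max = std::fmax(row_max, s[j]);
        }
        for (int j = 0; j < N; ++j) { s[j] = std::exp(s[j] - row_max); row_sum += s[j]; }
        for (int j = 0; j < N; ++j) s[j] /= row_sum;     // P = softmax(S), row-wise
        for (int d = 0; d < D; ++d)                      // O = P V
            for (int j = 0; j < N; ++j) O[i * D + d] += s[j] * V[j * D + d];
    }
    std::printf("O[0][0] = %f\n", O[0]);
    return 0;
}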
// softmax.scale(p_sum); + constexpr bool encode_dropout_in_sign_bit = Return_softmax; if (Is_dropout) { @@ -729,13 +836,16 @@ inline __device__ void device_1xN_loop(const Params ¶ms) { Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds)); constexpr int M = Kernel_traits::Cta_tile_p::M; - const int STEPS = (params.seqlen_q + M - 1) / M; + const int STEPS = (params.seqlen_q + M - 1) / M; + // iterative over q, stride with M, block size constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + if (params.seqlen_k == blocksize_c) { fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); } else { const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; + // iterative with k fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx); diff --git a/tests/build.sh b/tests/build.sh new file mode 100644 index 000000000..2f11ad907 --- /dev/null +++ b/tests/build.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# csrc_path=../csrc/flash_attn +# csrc_path=/workspace/openfold/single_test/flash_attn/flash-attention_v2/csrc/flash_attn +csrc_path=../csrc/flash_attn +src_file= +src_file+=test_forward.cu +src_file+=" ${csrc_path}/fmha_api.cpp" +src_file+=" ${csrc_path}/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu" +src_file+=" ${csrc_path}/src/fmha_block_fprop_fp16_kernel.sm80.cu" +src_file+=" ${csrc_path}/src/fmha_dgrad_fp16_kernel_loop.sm80.cu" +src_file+=" ${csrc_path}/src/fmha_fprop_fp16_kernel.sm80.cu" + +echo ${src_file} + +echo ${csrc_path}/ +echo ${csrc_path}/src +echo ${csrc_path}/cutlass/include + +# nvcc -o test ${src_file} \ +/usr/local/cuda-11.3/bin/nvcc -v -o test ${src_file} \ + --compiler-options='-Wl\,--no-as-needed' \ + -lc10 -ltorch -ltorch_cpu -lcudart -lc10_cuda -ltorch_cuda -ltorch_cuda_cu -ltorch_cuda_cpp \ + -I ./ \ + -I ${csrc_path} \ + -I ${csrc_path}/src \ + -I ${csrc_path}/cutlass/include \ + -I /opt/conda/lib/python3.7/site-packages/torch/include \ + -I /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include \ + -I /opt/conda/lib/python3.7/site-packages/torch/include/TH \ + -I /opt/conda/lib/python3.7/site-packages/torch/include/THC \ + -I /opt/conda/include \ + -I /opt/conda/include/python3.7m \ + -L /opt/conda/lib/python3.7/site-packages/torch/lib/ \ + -L /usr/local/cuda-11.3/lib64/ \ + -L /opt/conda/lib64/ \ + -L /opt/conda/lib/ \ + -g -G \ + -t 4 \ + -D_GLIBCXX_USE_CXX11_ABI=0 \ + -DDEBUG_PRINT \ + -DDEBUG_USING_NVCC \ + -gencode arch=compute_80,code=sm_80 \ + -U__CUDA_NO_HALF_OPERATORS__ \ + -U__CUDA_NO_HALF_CONVERSIONS__ \ + --expt-relaxed-constexpr \ + --expt-extended-lambda \ + --use_fast_math + + diff --git a/tests/fmha_api.h b/tests/fmha_api.h new file mode 100644 index 000000000..6d6ef7410 --- /dev/null +++ b/tests/fmha_api.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +#include "fmha.h" + +std::vector +mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool 
zero_tensors, + const bool is_causal, + const bool return_softmax, + c10::optional gen_, + const c10::optional &attn_mask, // attn_mask + const c10::optional &attn_bias // attn bias + ); diff --git a/tests/test_forward.cu b/tests/test_forward.cu new file mode 100644 index 000000000..0f95d281b --- /dev/null +++ b/tests/test_forward.cu @@ -0,0 +1,106 @@ +#include +//#include +#include + +void test_fwd() { + int batch_size = 1; + int nheads = 1; + int headdim = 16; + int max_seqlen_q_ = 128; + int max_seqlen_k_ = 128; + + float softmax_scale = 0.1; + + bool zero_tensors = false; + bool is_causal = false; + bool return_softmax = false; + + // q -> [bs * seq, head, head_dim] + // q -> [1 * 128, 1, 16] + // block q -> [128, 16] + + // k -> [bs * seq, head, head_dim] + // k -> [1 * 128, 1, 16] + // block k -> [128, 16] + + // v -> [bs * seq, head, head_dim] + // v -> [1 * 128, 1, 16] + // block k -> [128, 16] + + at::Tensor q_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor k_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor v_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + int cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < headdim; k ++) { + q_cpu[i][j][k] = cnt * 0.001; + k_cpu[i][j][k] = cnt * 0.001; + v_cpu[i][j][k] = cnt * 0.001; + cnt ++; + } + } + } + + auto q = q_cpu.cuda(); + auto k = k_cpu.cuda(); + auto v = v_cpu.cuda(); + + at::Tensor cu_seqlens_q_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); + at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); + + for (int i = 0; i < batch_size * max_seqlen_k_ + 1; ++i) { + cu_seqlens_q_cpu[i] = i * max_seqlen_q_; + cu_seqlens_k_cpu[i] = i * max_seqlen_k_; + } + + auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); + auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); + + at::Tensor attn_mask = at::ones({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf).cuda(); + + cnt = 0; + for (int i = 0; i < batch_size; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < max_seqlen_q_; k ++) { + for (int l = 0; l < max_seqlen_k_; l ++) { + attn_mask[i][j][k][l] = cnt * 0.001; + cnt ++; + } + } + } + } + + c10::optional gen_; + c10::optional attn_mask_op; + + std::vector ret = mha_fwd( + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + attn_mask, + attn_mask.clone() + ); + + std::cout << "Ret vec size is " << ret.size(); + for (int i = 0; i < ret.size(); i ++) { + ret[i].cpu(); + std::cout << ret[i] << std::endl; + } +} + +int main(){ + test_fwd(); + return 0; +} From 4368b1bdb289906e19e612976c116d6bc91eea10 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 24 Aug 2022 15:28:43 +0800 Subject: [PATCH 15/71] add mask test --- benchmarks/test/test_forward_with_mask.py | 233 ++++++++++++++++++++ csrc/flash_attn/src/fmha/gemm.h | 16 -- csrc/flash_attn/src/fmha/gmem_tile.h | 7 +- csrc/flash_attn/src/fmha/softmax.h | 23 ++ csrc/flash_attn/src/fmha/utils.h | 17 ++ csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 61 ++++- 6 
files changed, 333 insertions(+), 24 deletions(-) create mode 100644 benchmarks/test/test_forward_with_mask.py diff --git a/benchmarks/test/test_forward_with_mask.py b/benchmarks/test/test_forward_with_mask.py new file mode 100644 index 000000000..cf3a8b8e8 --- /dev/null +++ b/benchmarks/test/test_forward_with_mask.py @@ -0,0 +1,233 @@ +import torch +import torch.nn as nn + +from functools import partial +import math +from typing import Optional, Callable, List, Tuple, Sequence +import numpy as np +import deepspeed + +from flash_attn.flash_attn_interface import flash_attn_unpadded_func + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of + type bfloat16 + """ + d = t.dtype + # if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): + # with torch.cuda.amp.autocast(enabled=False): + # s = torch.nn.functional.softmax(t, dim=dim) + # else: + # s = torch.nn.functional.softmax(t, dim=dim) + s = torch.nn.functional.softmax(t, dim=dim) + return s + + +def _attention(query, key, value, mask=None, biases=None, upcast=False) -> torch.Tensor: + # upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + # output back to fp16/bf16. + dtype_og = query.dtype + if upcast: + query = query.float() + key = key.float() + value = value.float() + if mask is not None: + mask = mask.float() + + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + # print ("q * k: ", a) + + if biases is None: + biases = [] + for b in biases: + a += b + # print ("after bias:", a) + + if mask is not None: + a += mask + # print ("after mask:", a) + + a = softmax_no_cast(a, -1) + # print ("softmax :", a) + + # [*, H, Q, C_hidden] + b = torch.matmul(a, value) + # print ("p * v: ", a) + return b.to(dtype_og), a.to(dtype_og) + + +def _flash_attn(q, k, v, attn_mask=None): + batch_dims = q.shape[:-3] + no_heads, n, c = q.shape[-3:] + dtype = q.dtype + + if attn_mask is not None: + attn_mask = attn_mask.half() + + # [*, B, N, H, C] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + # [B_flat, N, H, C] + q = q.reshape(-1, *q.shape[-3:]) + k = k.reshape(-1, *k.shape[-3:]) + v = v.reshape(-1, *v.shape[-3:]) + + # Flattened batch size + batch_size = q.shape[0] + + # [B_flat * N, H, C] + q = q.reshape(-1, *q.shape[-2:]) + k = k.reshape(-1, *k.shape[-2:]) + v = v.reshape(-1, *v.shape[-2:]) + + q_max_s = n + q_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device + ) + + k_max_s = n + k_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device + ) + + out = flash_attn_unpadded_func( + q, + k, + v, + q_cu_seqlens, + k_cu_seqlens, + q_max_s, + k_max_s, + attn_mask=attn_mask, + dropout_p = 0., + softmax_scale = 1., # q has been scaled already + ) + + # [*, B, N, H, C] + out = out.reshape(*batch_dims, n, no_heads, c) + return out + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +torch.manual_seed(0) +# v2 +bs = 1 +seq = 128 +head = 1 +c_dim = 16 + +seq_q = seq_k = seq_v = seq + +print (10 * "*" + "prepare 
data" + 10 * "*" ) +dtype = torch.bfloat16 +# dtype = torch.half +device = "cuda" + +# orig_tensor = torch.stack( +# [ (i+1) * 0.1 * torch.randn((bs, seq, head, c_dim)) for i in range(seq) ] +# ,dim = 1 +# ).to(device).to(dtype) + +orig_tensor = torch.empty((bs, seq, head, seq, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) +orig_tensor.requires_grad = True +# print ("tensor: ", orig_tensor) +print ("origin shape: ", orig_tensor.shape) +# [bs, seq, seq, head, c_dim] + +mask = gen_attn_mask( + ( + torch.rand( + bs, + seq_q, + 1, + 1, + seq_k, + dtype=dtype, + device=device, + ) + > 0.2 + ).type(dtype), + -3e4, +) +print ("mask shape: ", mask.shape) +mask_broadcast = mask.expand([bs, seq_k, head, seq_q, seq_k]) +print ("mask_broadcast shape: ", mask_broadcast.shape) + + +print (10 * "*" + "normal attn fp32" + 10 * "*" ) +normal_attn_v1 = orig_tensor.clone() +output_ref, softmax_output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, mask=mask_broadcast, upcast=True) +# be careful here +output_ref = output_ref.transpose(-2, -3) +print ("attention output shape: ", output_ref.shape) +print (10 * "*" + "normal attn fp32" + 10 * "*" ) +print () + +print (10 * "*" + "normal attn fp16" + 10 * "*" ) +normal_attn_v2 = orig_tensor.clone() +output_pt, softmax_output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, mask=mask_broadcast) +# be careful here +output_pt = output_pt.transpose(-2, -3) +print ("attention output shape: ", output_pt.shape) +print (10 * "*" + "normal attn fp32" + 10 * "*" ) +print () + +print (10 * "*" + "flash attn" + 10 * "*" ) +normal_attn_flash = orig_tensor.clone() +output3 = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, attn_mask=mask_broadcast) +import pdb; pdb.set_trace() +print ("flash attn output shape: ", output3.shape) +print (10 * "*" + "flash attn" + 10 * "*" ) +print () + +# print ("max abs error: ", (output3 - output_ref).abs().max()) +# print ("all close at pre@.2: ", torch.allclose(output3, output_ref, atol=1e-2)) + +print (10 * "*" + "comparing forward" + 10 * "*" ) +print("Output max diff: {0}".format((output3 - output_ref).abs().max().item())) +print("Output mean diff: {0}".format((output3 - output_ref).abs().mean().item())) + +print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) +print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) + +print("Output max diff with Pytorch: {0}".format((output3 - output_pt).abs().max().item())) +print("Output mean diff with Pytorch: {0}".format((output3 - output_pt).abs().mean().item())) + +print ("less than twice error: ", (output3 - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) +print (10 * "*" + "comparing forward" + 10 * "*" ) +print () + + +# test backward + +# g = torch.randn_like(output3) +# dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) +# dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) +# dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) + +# print("dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) +# print("dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) +# print("dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) + +# print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) + diff --git 
a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h index 9d40713f2..d06cec6e5 100644 --- a/csrc/flash_attn/src/fmha/gemm.h +++ b/csrc/flash_attn/src/fmha/gemm.h @@ -174,22 +174,6 @@ struct Fragment_c : public Fragment { //////////////////////////////////////////////////////////////////////////////////////////////////// -template __device__ -inline float toFloat(T a) { - return (float)a; -} -template<> __device__ -inline float toFloat(half a) { - return __half2float(a); -} -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -template<> __device__ -inline float toFloat(__nv_bfloat16 a) { - return __bfloat162float(a); -} -#endif -//////////////////////////////////////////////////////////////////////////////////////////////////// - struct Fragment_accumulator : public Fragment { diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index dc0ef20b3..efb594fb8 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -497,12 +497,13 @@ struct Gmem_tile_mma_mask { template< typename Params, typename Block_info > inline __device__ Gmem_tile_mma_mask(const Params ¶ms, // const uint32_t row_stride_in_elts, const uint32_t head_stride_in_elts, - const Block_info& binfo, const int tidx) + const Block_info& binfo, const int tidx, const int loop_step_idx) : ptr_(static_cast(params.attn_mask_ptr)) // : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) , actual_seqlen_q(binfo.actual_seqlen_q) , actual_seqlen_k(binfo.actual_seqlen_k) , tidx_(tidx) + , loop_step_idx(loop_step_idx) { row_stride_in_bytes = binfo.actual_seqlen_k * BYTES_PER_ELEMENT; @@ -551,7 +552,6 @@ struct Gmem_tile_mma_mask { template inline __device__ void load(Fragment (&frag)[M][N]) { // using Fragment = typename fmha::Fragment; - // like Fragment_a const void *ptrs[LDGS_PER_THREAD_PER_WARP]; uint32_t preds[LDGS_PER_THREAD_PER_WARP]; @@ -566,7 +566,7 @@ struct Gmem_tile_mma_mask { for (int jj = 0; jj < 2; ++jj ) { int offset = ii * 2 + jj; const int current_row = mi * ROWS + ii * 8; - const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; // 8 is actually col of half data now, for more general case ? #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { @@ -601,6 +601,7 @@ struct Gmem_tile_mma_mask { int row; int col; + const int loop_step_idx; uint32_t row_stride_in_bytes; // The pointer. char *ptr_; diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index c4783ee50..66a67e6f9 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -30,6 +30,8 @@ #include #include +#include + namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -489,6 +491,27 @@ struct Softmax : public Softmax_base { , smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) { } + template + inline __device__ void apply_attn_mask(const Fragment (&mask)[MMAS_M][MMAS_N]) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ++ni ) { + #pragma unroll + for( int jj = 0; jj < 4; ++jj ) { + // if( abs(float(mask[mi][ni].elt(ii * 4 + jj))) > 0 ) { + // this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 
0.f : -INFINITY; + // } + // this->elt_[2 * mi + ii][4 * ni + jj] += float(mask[mi][ni].elt(ii * 4 + jj)); + this->elt_[2 * mi + ii][4 * ni + jj] += toFloat(mask[mi][ni].elt(ii * 4 + jj)); + } + } + } + } + } + // Pack the data to a fragment for the next GEMM. template inline __device__ void pack(Fragment_a (&dst)[K][M]) const { diff --git a/csrc/flash_attn/src/fmha/utils.h b/csrc/flash_attn/src/fmha/utils.h index ecb8aef7f..5ce679378 100644 --- a/csrc/flash_attn/src/fmha/utils.h +++ b/csrc/flash_attn/src/fmha/utils.h @@ -1212,4 +1212,21 @@ __device__ inline void quad_allreduce(__half2 (&dst)[M], float2 (&src)[M], Opera //////////////////////////////////////////////////////////////////////////////////////////////////// +template __device__ +inline float toFloat(T a) { + return (float)a; +} +template<> __device__ +inline float toFloat(half a) { + return __half2float(a); +} +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> __device__ +inline float toFloat(__nv_bfloat16 a) { + return __bfloat162float(a); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace fmha diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 0870d009a..c8c5721e3 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -352,7 +352,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Allocate the global memory tile loader for mask. using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; // conctructor - Gmem_tile_mask gmem_mask(params, binfo, tidx); + Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); // TODO: load fun as s // bool has_bias = !(params.attn_bias_ptr == nullptr); @@ -427,10 +427,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Trigger the loads for V. gmem_v.load(); - using Frag_mask = fmha::Fragment_c; - Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - gmem_mask.template load(frag_mask); - if (!Is_first) { __syncthreads(); } float p_prev_lse[Mma_tile_p::MMAS_M * 2]; @@ -567,6 +563,61 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Convert from the accumulator type to FP32 for Softmax. softmax.unpack_noscale(acc_p); + using Frag_mask = fmha::Fragment_c; + Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + gmem_mask.template load(frag_mask); + gmem_mask.move(); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_o::MMAS_N; ++ki ) { + // 1st row - 4 elements per row. + float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; + + printf("before attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); + printf("before attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); + } + } + printf("\n"); + } +#endif + // Apply the attn mask. 
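apply_attn_mask() above folds the loaded mask fragment into the softmax registers element by element: with the assumed fragment layout (each thread holding 2 rows x 4 columns per 16x16 tile), mask element ii*4+jj of fragment [mi][ni] lands on elt_[2*mi+ii][4*ni+jj]. At this point in the series the mask value is simply added (the tests supply 0 / -3e4 entries); a later hunk in this series replaces the add with an explicit -INFINITY fill for nonzero mask entries. A minimal host-side sketch of the additive form, with made-up values:

// Illustration only: one thread's 2x4 slice of softmax elements and the matching mask
// slice; the additive update mirrors apply_attn_mask() at this point in the patch series.
#include <cstdio>

int main() {
    float elt[2][4]  = { {0.5f, 0.7f, 0.2f, 0.9f}, {0.1f, 0.3f, 0.8f, 0.4f} };
    float mask[2][4] = { {0.f, -3e4f, 0.f, 0.f},   {-3e4f, 0.f, 0.f, 0.f} };
    for (int ii = 0; ii < 2; ++ii)
        for (int jj = 0; jj < 4; ++jj)
            elt[ii][jj] += mask[ii][jj];   // masked scores become very negative, so exp() ~ 0
    for (int ii = 0; ii < 2; ++ii)
        std::printf("%g %g %g %g\n", elt[ii][0], elt[ii][1], elt[ii][2], elt[ii][3]);
    return 0;
}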
+ softmax.apply_attn_mask(frag_mask); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_o::MMAS_N; ++ki ) { + // 1st row - 4 elements per row. + float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; + + printf("after attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); + printf("after attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); + } + } + printf("\n"); + } +#endif + // Apply the mask. // this impl is more like padding softmax.apply_mask(mask); From 81281c39e627596f1a82ccc66461b93308ed04fc Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 25 Aug 2022 11:20:18 +0800 Subject: [PATCH 16/71] fix forward bugs --- benchmarks/test/test_forward_with_mask.py | 66 ++++++++----- csrc/flash_attn/fmha_api.cpp | 4 +- csrc/flash_attn/src/fmha/gmem_tile.h | 28 +++--- csrc/flash_attn/src/fmha/softmax.h | 8 +- flash_attn/flash_attn_interface.py | 14 +-- tests/test_forward.cu | 114 +++++++++++++++++++++- 6 files changed, 187 insertions(+), 47 deletions(-) diff --git a/benchmarks/test/test_forward_with_mask.py b/benchmarks/test/test_forward_with_mask.py index cf3a8b8e8..865016dd2 100644 --- a/benchmarks/test/test_forward_with_mask.py +++ b/benchmarks/test/test_forward_with_mask.py @@ -56,7 +56,11 @@ def _attention(query, key, value, mask=None, biases=None, upcast=False) -> torch # print ("after bias:", a) if mask is not None: - a += mask + # a += mask + # import pdb; pdb.set_trace() + # please do not use add now + a.masked_fill_(mask < 0, float('-inf')) + # print ("after mask:", a) a = softmax_no_cast(a, -1) @@ -104,6 +108,10 @@ def _flash_attn(q, k, v, attn_mask=None): 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device ) + if attn_mask is not None: + # import pdb; pdb.set_trace() + attn_mask = attn_mask.reshape([bs * n, no_heads, n, n]) + out = flash_attn_unpadded_func( q, k, @@ -136,6 +144,12 @@ def gen_attn_mask(mask, neg_inf): head = 1 c_dim = 16 +# mini +# bs = 1 +# seq = 2 +# head = 1 +# c_dim = 16 + seq_q = seq_k = seq_v = seq print (10 * "*" + "prepare data" + 10 * "*" ) @@ -154,25 +168,29 @@ def gen_attn_mask(mask, neg_inf): print ("origin shape: ", orig_tensor.shape) # [bs, seq, seq, head, c_dim] -mask = gen_attn_mask( - ( - torch.rand( - bs, - seq_q, - 1, - 1, - seq_k, - dtype=dtype, - device=device, - ) - > 0.2 - ).type(dtype), +mask_data = torch.rand( + bs, + seq_q, + 1, + 1, + seq_k, + dtype=dtype, + device=device, + ) + +# fake data +# mask_data[:, :, :, :, :] = 0.02 +# mask_data[:, :, :, :, 0] = 0.001 + +mask = gen_attn_mask( + ( mask_data > 0.01 ).type(dtype), -3e4, ) print ("mask shape: ", mask.shape) mask_broadcast = mask.expand([bs, seq_k, head, seq_q, seq_k]) print ("mask_broadcast shape: ", mask_broadcast.shape) +print ("mask broadcast: ", mask_broadcast) print (10 * "*" + "normal attn fp32" + 10 * "*" ) normal_attn_v1 = orig_tensor.clone() @@ -183,6 +201,7 @@ def gen_attn_mask(mask, neg_inf): print (10 * "*" + "normal attn fp32" + 10 * "*" 
) print () + print (10 * "*" + "normal attn fp16" + 10 * "*" ) normal_attn_v2 = orig_tensor.clone() output_pt, softmax_output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, mask=mask_broadcast) @@ -192,10 +211,11 @@ def gen_attn_mask(mask, neg_inf): print (10 * "*" + "normal attn fp32" + 10 * "*" ) print () + print (10 * "*" + "flash attn" + 10 * "*" ) normal_attn_flash = orig_tensor.clone() output3 = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, attn_mask=mask_broadcast) -import pdb; pdb.set_trace() +# import pdb; pdb.set_trace() print ("flash attn output shape: ", output3.shape) print (10 * "*" + "flash attn" + 10 * "*" ) print () @@ -220,14 +240,14 @@ def gen_attn_mask(mask, neg_inf): # test backward -# g = torch.randn_like(output3) -# dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) -# dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) -# dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) +g = torch.randn_like(output3) +dq_ref, dk_ref, dv_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) +dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) +dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) -# print("dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) -# print("dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) -# print("dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) +print("dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) +print("dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) +print("dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) -# print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) +print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 3b90b7b0a..1ff138f80 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -358,7 +358,9 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const float softmax_scale, const bool zero_tensors, const bool is_causal, - c10::optional gen_ + c10::optional gen_, + const c10::optional &attn_mask, // attn_mask + const c10::optional &attn_bias // attn bias ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm75 = dprops->major == 7 && dprops->minor == 5; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index efb594fb8..66e16241a 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -532,9 +532,10 @@ struct Gmem_tile_mma_mask { // The block index. 
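The hunk below fixes how the constructor locates this block's slice of the mask: for a mask stored contiguously as [batch, heads, seq_q, seq_k], the element offset of (b, h, row, col) is ((b * heads + h) * seqlen_q + row) * seqlen_k + col, and the indexing switches to binfo.actual_seqlen_q / actual_seqlen_k instead of params.seqlen_*. A host-side sketch of that offset follows, with a hypothetical helper name; the column and the per-K-block column offset are added later, in load().

// Illustration only: flat element offset into a [batch, heads, seq_q, seq_k] mask.
// Multiply by 2 (BYTES_PER_ELEMENT) to get the byte offset the constructor computes.
#include <cstdint>
#include <cstdio>

uint64_t mask_elt_offset(int b, int h, int row, int col,
                         int num_heads, int seqlen_q, int seqlen_k) {
    const uint64_t bidx = uint64_t(b) * num_heads + h;              // same block index as in the constructor
    return (bidx * seqlen_q + row) * uint64_t(seqlen_k) + col;
}

int main() {
    // batch 0, head 0, row 8, col 16 of a 1 x 1 x 128 x 128 mask -> element 1040
    std::printf("%llu\n", (unsigned long long)mask_elt_offset(0, 0, 8, 16, 1, 128, 128));
    return 0;
}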
uint32_t bidx = binfo.bidb * params.h + binfo.bidh; - uint32_t row_offset = bidx * params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT; + // the index of bs and head dim + uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; // row_offset = (uint32_t)(row * row_stride_in_bytes); - row_offset += (uint32_t)(row * params.seqlen_k * BYTES_PER_ELEMENT); + row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { @@ -567,20 +568,25 @@ struct Gmem_tile_mma_mask { int offset = ii * 2 + jj; const int current_row = mi * ROWS + ii * 8; const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + // const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; // 8 is actually col of half data now, for more general case ? -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d", - mi, ni, ii, jj, offset, current_row, current_col); - printf("\n"); - } -#endif // the row is already in the right position ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + (uint32_t)current_col * BYTES_PER_ELEMENT; - preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) - && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) < min(COLS, actual_seqlen_k)); + preds[offset] = (current_row <= min(ROWS, actual_seqlen_q)) + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d, start_ptr=%p, ptrs[offset]=%p, preds[offset]=%d\n", + mi, ni, ii, jj, offset, current_row, current_col, ptr_, ptrs[offset], preds[offset]); + printf("current_row=%d, current_col=%d, ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d\n", + current_row, current_col, ROWS, actual_seqlen_q, COLS, actual_seqlen_k); + printf("cond 1=%d\n", (current_row <= min(ROWS, actual_seqlen_q))); + printf("cond 2=%d\n", ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k))); + printf("\n"); + } +#endif } } diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index 66a67e6f9..d3d67fd8a 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -501,11 +501,11 @@ struct Softmax : public Softmax_base { for( int ni = 0; ni < MMAS_N; ++ni ) { #pragma unroll for( int jj = 0; jj < 4; ++jj ) { - // if( abs(float(mask[mi][ni].elt(ii * 4 + jj))) > 0 ) { - // this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; - // } + if( abs(toFloat(mask[mi][ni].elt(ii * 4 + jj))) > 0 ) { + this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 
0.f : -INFINITY; + } // this->elt_[2 * mi + ii][4 * ni + jj] += float(mask[mi][ni].elt(ii * 4 + jj)); - this->elt_[2 * mi + ii][4 * ni + jj] += toFloat(mask[mi][ni].elt(ii * 4 + jj)); + // this->elt_[2 * mi + ii][4 * ni + jj] += toFloat(mask[mi][ni].elt(ii * 4 + jj)); } } } diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 3593fda44..aaffe5463 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -33,11 +33,11 @@ def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_s return out, softmax_lse, S_dmask -def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, +def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, attn_mask, attn_bias, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal): softmax_d = flash_attn_cuda.bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None) + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, attn_mask, None) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() return dq, dk, dv, softmax_d @@ -131,7 +131,7 @@ def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, attn_bias, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax ) - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, attn_mask, attn_bias) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k @@ -141,18 +141,20 @@ def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, attn_mask, attn_bias = ctx.saved_tensors if rng_state is not None: cur_rng_state = torch.cuda.get_rng_state() torch.cuda.set_rng_state(rng_state) dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + import pdb; pdb.set_trace() _flash_attn_backward( - dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, attn_mask, attn_bias, ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal ) if rng_state is not None: torch.cuda.set_rng_state(cur_rng_state) - return dq, dk, dv, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + # TODO: the last two is attn_mask, attn_bias, bias need gradient def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, diff --git a/tests/test_forward.cu b/tests/test_forward.cu index 0f95d281b..e63ced238 100644 --- a/tests/test_forward.cu +++ b/tests/test_forward.cu @@ -1,3 +1,4 @@ +#include #include //#include #include @@ -61,7 +62,7 @@ void test_fwd() { at::Tensor attn_mask = at::ones({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf).cuda(); cnt = 0; - for (int i = 0; i < batch_size; i ++) { + for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { for (int j = 0; j 
< nheads; j ++) { for (int k = 0; k < max_seqlen_q_; k ++) { for (int l = 0; l < max_seqlen_k_; l ++) { @@ -100,7 +101,116 @@ void test_fwd() { } } + +void test_fwd_mini() { + int batch_size = 1; + int nheads = 1; + int headdim = 16; + int max_seqlen_q_ = 2; + int max_seqlen_k_ = 2; + + float softmax_scale = 0.1; + + bool zero_tensors = false; + bool is_causal = false; + bool return_softmax = false; + + // q -> [bs * seq, head, head_dim] + // q -> [1 * 128, 1, 16] + // block q -> [128, 16] + + // k -> [bs * seq, head, head_dim] + // k -> [1 * 128, 1, 16] + // block k -> [128, 16] + + // v -> [bs * seq, head, head_dim] + // v -> [1 * 128, 1, 16] + // block k -> [128, 16] + + at::Tensor q_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor k_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor v_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + int cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < headdim; k ++) { + q_cpu[i][j][k] = cnt * 0.001; + k_cpu[i][j][k] = cnt * 0.001; + v_cpu[i][j][k] = cnt * 0.001; + cnt ++; + } + } + } + + auto q = q_cpu.cuda(); + auto k = k_cpu.cuda(); + auto v = v_cpu.cuda(); + + at::Tensor cu_seqlens_q_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); + at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); + + for (int i = 0; i < batch_size * max_seqlen_k_ + 1; ++i) { + cu_seqlens_q_cpu[i] = i * max_seqlen_q_; + cu_seqlens_k_cpu[i] = i * max_seqlen_k_; + } + + auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); + auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); + + at::Tensor attn_mask_cpu = at::zeros({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf); + + cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < max_seqlen_q_; k ++) { + for (int l = 0; l < max_seqlen_k_; l ++) { + // if (l == 0) attn_mask[i][j][k][l] = -INFINITY; + if (l == 0) attn_mask_cpu[i][j][k][l] = -3e4; + else attn_mask_cpu[i][j][k][l] = 0; + + attn_mask_cpu[i][j][k][l] = -3e4; + printf("i=%d, j=%d, k=%d, l=%d attn_mask=%f\n", i, j, k, l, attn_mask_cpu[i][j][k][l]); + } + } + } + } + + auto attn_mask = attn_mask_cpu.cuda(); + + c10::optional gen_; + c10::optional attn_mask_op; + + std::vector ret = mha_fwd( + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + attn_mask, + attn_mask.clone() + ); + + // ret: std::vector result = {o, softmax_lse}; + // [bs * seq * seq, head, head_dim] + // [1 * 2 * 2, 1, 16] + std::cout << "Ret vec size is " << ret.size(); + for (int i = 0; i < ret.size(); i ++) { + ret[i].cpu(); + std::cout << ret[i] << std::endl; + } +} + int main(){ - test_fwd(); + // test_fwd(); + test_fwd_mini(); return 0; } From 4bbe6b1a8bdc17e8849c03302540bcb596aa61ec Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 25 Aug 2022 11:24:16 +0800 Subject: [PATCH 17/71] add test --- ...h_mask.py => test_forward_with_mask_v2.py} | 0 .../test/test_forward_without_mask_v2.py | 230 ++++++++++++++++++ 2 
files changed, 230 insertions(+) rename benchmarks/test/{test_forward_with_mask.py => test_forward_with_mask_v2.py} (100%) create mode 100644 benchmarks/test/test_forward_without_mask_v2.py diff --git a/benchmarks/test/test_forward_with_mask.py b/benchmarks/test/test_forward_with_mask_v2.py similarity index 100% rename from benchmarks/test/test_forward_with_mask.py rename to benchmarks/test/test_forward_with_mask_v2.py diff --git a/benchmarks/test/test_forward_without_mask_v2.py b/benchmarks/test/test_forward_without_mask_v2.py new file mode 100644 index 000000000..40ddcf899 --- /dev/null +++ b/benchmarks/test/test_forward_without_mask_v2.py @@ -0,0 +1,230 @@ +import torch +import torch.nn as nn + +from functools import partial +import math +from typing import Optional, Callable, List, Tuple, Sequence +import numpy as np +import deepspeed + +from flash_attn.flash_attn_interface import flash_attn_unpadded_func + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of + type bfloat16 + """ + d = t.dtype + # if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): + # with torch.cuda.amp.autocast(enabled=False): + # s = torch.nn.functional.softmax(t, dim=dim) + # else: + # s = torch.nn.functional.softmax(t, dim=dim) + s = torch.nn.functional.softmax(t, dim=dim) + return s + + +def _attention(query, key, value, mask=None, biases=None, upcast=False) -> torch.Tensor: + # upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + # output back to fp16/bf16. 
+ dtype_og = query.dtype + if upcast: + query = query.float() + key = key.float() + value = value.float() + if mask is not None: + mask = mask.float() + + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + # print ("q * k: ", a) + + if biases is None: + biases = [] + for b in biases: + a += b + # print ("after bias:", a) + + if mask is not None: + # a += mask + # import pdb; pdb.set_trace() + # please do not use add now + a.masked_fill_(mask < 0, float('-inf')) + + # print ("after mask:", a) + + a = softmax_no_cast(a, -1) + # print ("softmax :", a) + + # [*, H, Q, C_hidden] + b = torch.matmul(a, value) + # print ("p * v: ", a) + return b.to(dtype_og), a.to(dtype_og) + + +def _flash_attn(q, k, v, attn_mask=None): + batch_dims = q.shape[:-3] + no_heads, n, c = q.shape[-3:] + dtype = q.dtype + + if attn_mask is not None: + attn_mask = attn_mask.half() + + # [*, B, N, H, C] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + # [B_flat, N, H, C] + q = q.reshape(-1, *q.shape[-3:]) + k = k.reshape(-1, *k.shape[-3:]) + v = v.reshape(-1, *v.shape[-3:]) + + # Flattened batch size + batch_size = q.shape[0] + + # [B_flat * N, H, C] + q = q.reshape(-1, *q.shape[-2:]) + k = k.reshape(-1, *k.shape[-2:]) + v = v.reshape(-1, *v.shape[-2:]) + + q_max_s = n + q_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device + ) + + k_max_s = n + k_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device + ) + + if attn_mask is not None: + # import pdb; pdb.set_trace() + attn_mask = attn_mask.reshape([bs * n, no_heads, n, n]) + + out = flash_attn_unpadded_func( + q, + k, + v, + q_cu_seqlens, + k_cu_seqlens, + q_max_s, + k_max_s, + attn_mask=attn_mask, + dropout_p = 0., + softmax_scale = 1., # q has been scaled already + ) + + # [*, B, N, H, C] + out = out.reshape(*batch_dims, n, no_heads, c) + return out + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +torch.manual_seed(0) +# v2 +bs = 1 +seq = 128 +head = 1 +c_dim = 16 + +# mini +# bs = 1 +# seq = 2 +# head = 1 +# c_dim = 16 + +seq_q = seq_k = seq_v = seq + +print (10 * "*" + "prepare data" + 10 * "*" ) +dtype = torch.bfloat16 +# dtype = torch.half +device = "cuda" + +# orig_tensor = torch.stack( +# [ (i+1) * 0.1 * torch.randn((bs, seq, head, c_dim)) for i in range(seq) ] +# ,dim = 1 +# ).to(device).to(dtype) + +orig_tensor = torch.empty((bs, seq, head, seq, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) +orig_tensor.requires_grad = True +# print ("tensor: ", orig_tensor) +print ("origin shape: ", orig_tensor.shape) +# [bs, seq, seq, head, c_dim] + + +print (10 * "*" + "normal attn fp32" + 10 * "*" ) +normal_attn_v1 = orig_tensor.clone() +output_ref, softmax_output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, mask=None, upcast=True) +# be careful here +output_ref = output_ref.transpose(-2, -3) +print ("attention output shape: ", output_ref.shape) +print (10 * "*" + "normal attn fp32" + 10 * "*" ) +print () + + +print (10 * "*" + "normal attn fp16" + 10 * "*" ) +normal_attn_v2 = orig_tensor.clone() +output_pt, softmax_output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, mask=None) +# be careful here +output_pt = output_pt.transpose(-2, -3) +print ("attention output shape: ", output_pt.shape) +print (10 * "*" + "normal attn fp32" + 10 * "*" ) 
+print () + + +print (10 * "*" + "flash attn" + 10 * "*" ) +normal_attn_flash = orig_tensor.clone() +output3 = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, attn_mask=None) +# import pdb; pdb.set_trace() +print ("flash attn output shape: ", output3.shape) +print (10 * "*" + "flash attn" + 10 * "*" ) +print () + +# print ("max abs error: ", (output3 - output_ref).abs().max()) +# print ("all close at pre@.2: ", torch.allclose(output3, output_ref, atol=1e-2)) + +print (10 * "*" + "comparing forward" + 10 * "*" ) +print("Output max diff: {0}".format((output3 - output_ref).abs().max().item())) +print("Output mean diff: {0}".format((output3 - output_ref).abs().mean().item())) + +print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) +print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) + +print("Output max diff with Pytorch: {0}".format((output3 - output_pt).abs().max().item())) +print("Output mean diff with Pytorch: {0}".format((output3 - output_pt).abs().mean().item())) + +print ("less than twice error: ", (output3 - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) +print (10 * "*" + "comparing forward" + 10 * "*" ) +print () + + +# test backward + +g = torch.randn_like(output3) +dq_ref, dk_ref, dv_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) +dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) +dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) + +print("dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) +print("dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) +print("dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) + +print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) + From d2b883b78c101daf48cc29b0f2c39250c9ef1a6f Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 25 Aug 2022 14:49:17 +0800 Subject: [PATCH 18/71] add mask in backward --- benchmarks/test/test_forward_with_mask_v2.py | 15 ++-- .../test/test_forward_without_mask_v2.py | 14 +++- csrc/flash_attn/fmha_api.cpp | 16 +++-- .../src/fmha_dgrad_kernel_1xN_loop.h | 70 +++++++++++++++++++ csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 43 ++++++++---- 5 files changed, 134 insertions(+), 24 deletions(-) diff --git a/benchmarks/test/test_forward_with_mask_v2.py b/benchmarks/test/test_forward_with_mask_v2.py index 865016dd2..6da081675 100644 --- a/benchmarks/test/test_forward_with_mask_v2.py +++ b/benchmarks/test/test_forward_with_mask_v2.py @@ -245,9 +245,16 @@ def gen_attn_mask(mask, neg_inf): dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) -print("dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) -print("dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) -print("dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) +print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) +print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) +print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) -print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) +print("Pytorch dQ max diff: 
{0}".format( (dq_pt - dq_ref).abs().max().item() )) +print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) +print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) +print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) +print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) +print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) + +print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file diff --git a/benchmarks/test/test_forward_without_mask_v2.py b/benchmarks/test/test_forward_without_mask_v2.py index 40ddcf899..f784fd346 100644 --- a/benchmarks/test/test_forward_without_mask_v2.py +++ b/benchmarks/test/test_forward_without_mask_v2.py @@ -222,9 +222,17 @@ def gen_attn_mask(mask, neg_inf): dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) -print("dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) -print("dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) -print("dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) +print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) +print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) +print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) + +print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) +print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) +print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) + +print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) +print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) +print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 1ff138f80..a2fbaaff0 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -171,7 +171,9 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, void *dsoftmax_sum_d, float p_dropout, float softmax_scale, - bool is_causal) { + bool is_causal, + void *attn_mask, + void *attn_bias) { set_params_fprop(params, b, seqlen_q, seqlen_k, h, d, @@ -185,8 +187,8 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, p_dropout, softmax_scale, is_causal, - nullptr, - nullptr); + attn_mask, + attn_bias); // Set the pointers and strides. params.dq_ptr = dq.data_ptr(); @@ -472,7 +474,9 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size softmax_d.data_ptr(), p_dropout, softmax_scale, - is_causal); + is_causal, + attn_mask ? attn_mask->data_ptr() : nullptr, + attn_bias ? 
attn_bias->data_ptr() : nullptr); auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator()); @@ -742,7 +746,9 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size softmax_d.data_ptr(), p_dropout, softmax_scale, - is_causal); + is_causal, + nullptr, + nullptr); params.blockmask = static_cast(blockmask.data_ptr()); auto gen = at::get_generator_or_default( diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 8fc079156..0ec4ddc82 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -148,6 +148,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); + // bool has_attn = !(params.attn_mask_ptr == nullptr); + // Allocate the global memory tile loader for mask. + using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; + // conctructor + Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); + // TODO: load fun as s + fmha::Mask mask(binfo, tidx, loop_step_idx); // Allocate the global memory tile loader for K. @@ -197,6 +204,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng gmem_softmax_lse.move(begin); gmem_softmax_d.move(begin); + // if constexpr (has_attn) { + if (!(params.attn_mask_ptr == nullptr)) { + // TODO: mask move + gmem_mask.move(begin); + } + if (!Is_first) { gmem_k.move(loop_step_idx); gmem_v.move(loop_step_idx); @@ -326,6 +339,63 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Convert from the accumulator type to FP32 for Softmax. softmax.unpack_noscale(acc_p); + // if constexpr (has_attn) { + if (!(params.attn_mask_ptr == nullptr)) { + using Frag_mask = fmha::Fragment_c; + Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + gmem_mask.template load(frag_mask); + gmem_mask.move(); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { + // 1st row - 4 elements per row. + float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; + + printf("before attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); + printf("before attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); + } + } + printf("\n"); + } +#endif + // Apply the attn mask. + softmax.apply_attn_mask(frag_mask); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { + // 1st row - 4 elements per row. + float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. 
+ float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; + + printf("after attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); + printf("after attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); + } + } + printf("\n"); + } +#endif + } // Apply the mask. softmax.apply_mask(mask); // Scale by log-sum-exp of the softmax diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index c8c5721e3..5e9de5ddc 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -563,15 +563,17 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Convert from the accumulator type to FP32 for Softmax. softmax.unpack_noscale(acc_p); - using Frag_mask = fmha::Fragment_c; - Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - gmem_mask.template load(frag_mask); - gmem_mask.move(); + // if constexpr (has_attn) { + if (!(params.attn_mask_ptr == nullptr)) { + using Frag_mask = fmha::Fragment_c; + Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + gmem_mask.template load(frag_mask); + gmem_mask.move(); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_o::MMAS_N; ++ki ) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { // 1st row - 4 elements per row. float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; @@ -591,13 +593,13 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i printf("\n"); } #endif - // Apply the attn mask. - softmax.apply_attn_mask(frag_mask); + // Apply the attn mask. + softmax.apply_attn_mask(frag_mask); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_o::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_o::MMAS_N; ++ki ) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { // 1st row - 4 elements per row. float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; @@ -617,6 +619,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i printf("\n"); } #endif + } // Apply the mask. // this impl is more like padding @@ -649,7 +652,15 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i } softmax.template reduce_max(p_max); - +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + // can we print the tile row? + for (int i = 0; i < Mma_tile_p::MMAS_M * 2; i ++) { + printf("i=%d, p_max=%f\n", i, p_max[i]); + } + printf("\n"); + } +#endif // if ((threadIdx.x == 0) && (l == 38)) { // printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? 
-10000.f : p_prev_lse[1]); // } @@ -682,7 +693,15 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // softmax.reduce_sum(p_sum); softmax.reduce_sum_before_sync_(p_sum); // softmax.template reduce_sum_before_sync_(p_sum); - +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + // can we print the tile row? + for (int i = 0; i < Mma_tile_p::MMAS_M * 2; i ++) { + printf("i=%d, p_max=%f\n", i, p_sum[i]); + } + printf("\n"); + } +#endif // float p_sum_log[Mma_tile_p::MMAS_M * 2]; // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) { // float sum = p_sum[mi]; From 050107cbbf5ad460499163104d1b0ddcfedce7c0 Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 25 Aug 2022 15:15:57 +0800 Subject: [PATCH 19/71] add test case --- benchmarks/test/test_forward_with_mask_v2.py | 15 ++++++++++++++- benchmarks/test/test_forward_without_mask_v2.py | 14 +++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/benchmarks/test/test_forward_with_mask_v2.py b/benchmarks/test/test_forward_with_mask_v2.py index 6da081675..2490e0864 100644 --- a/benchmarks/test/test_forward_with_mask_v2.py +++ b/benchmarks/test/test_forward_with_mask_v2.py @@ -257,4 +257,17 @@ def gen_attn_mask(mask, neg_inf): print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) -print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file + +print("Output dQ mean diff: {0}".format( (dq - dq_ref).abs().mean().item() )) +print("Output dK mean diff: {0}".format( (dk - dk_ref).abs().mean().item() )) +print("Output dV mean diff: {0}".format( (dv - dv_ref).abs().mean().item() )) + +print("Pytorch dQ mean diff: {0}".format( (dq_pt - dq_ref).abs().mean().item() )) +print("Pytorch dK mean diff: {0}".format( (dk_pt - dk_ref).abs().mean().item() )) +print("Pytorch dV mean diff: {0}".format( (dv_pt - dv_ref).abs().mean().item() )) + +print("Output dQ mean diff with Pytorch: {0}".format( (dq - dq_pt).abs().mean().item() )) +print("Output dK mean diff with Pytorch: {0}".format( (dk - dk_pt).abs().mean().item() )) +print("Output dV mean diff with Pytorch: {0}".format( (dv - dv_pt).abs().mean().item() )) + +print ("less than twice in max error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file diff --git a/benchmarks/test/test_forward_without_mask_v2.py b/benchmarks/test/test_forward_without_mask_v2.py index f784fd346..a24a31d6c 100644 --- a/benchmarks/test/test_forward_without_mask_v2.py +++ b/benchmarks/test/test_forward_without_mask_v2.py @@ -234,5 +234,17 @@ def gen_attn_mask(mask, neg_inf): print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) -print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) +print("Output dQ mean diff: {0}".format( (dq - dq_ref).abs().mean().item() )) +print("Output dK mean diff: {0}".format( (dk - dk_ref).abs().mean().item() )) +print("Output dV mean diff: {0}".format( (dv - dv_ref).abs().mean().item() )) + +print("Pytorch dQ mean diff: {0}".format( (dq_pt - dq_ref).abs().mean().item() )) +print("Pytorch dK mean diff: {0}".format( (dk_pt - dk_ref).abs().mean().item() )) +print("Pytorch dV mean 
diff: {0}".format( (dv_pt - dv_ref).abs().mean().item() )) + +print("Output dQ mean diff with Pytorch: {0}".format( (dq - dq_pt).abs().mean().item() )) +print("Output dK mean diff with Pytorch: {0}".format( (dk - dk_pt).abs().mean().item() )) +print("Output dV mean diff with Pytorch: {0}".format( (dv - dv_pt).abs().mean().item() )) + +print ("less than twice in max error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file From 6f14cb5a0276cd04efd03b9ba90b913d9e780e34 Mon Sep 17 00:00:00 2001 From: robotcator Date: Fri, 26 Aug 2022 16:06:10 +0800 Subject: [PATCH 20/71] add bias --- benchmarks/test/test_forward_with_bias_v2.py | 249 ++++++++++++++++++ csrc/flash_attn/src/fmha/gmem_tile.h | 182 ++++++++++--- csrc/flash_attn/src/fmha/softmax.h | 19 ++ .../src/fmha_dgrad_kernel_1xN_loop.h | 65 +++++ csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 75 +++++- flash_attn/flash_attn_interface.py | 8 +- tests/test_forward.cu | 245 ++++++++++++++++- 7 files changed, 790 insertions(+), 53 deletions(-) create mode 100644 benchmarks/test/test_forward_with_bias_v2.py diff --git a/benchmarks/test/test_forward_with_bias_v2.py b/benchmarks/test/test_forward_with_bias_v2.py new file mode 100644 index 000000000..272cabf01 --- /dev/null +++ b/benchmarks/test/test_forward_with_bias_v2.py @@ -0,0 +1,249 @@ +import torch +import torch.nn as nn + +from functools import partial +import math +from typing import Optional, Callable, List, Tuple, Sequence +import numpy as np +import deepspeed + +from flash_attn.flash_attn_interface import flash_attn_unpadded_func + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of + type bfloat16 + """ + d = t.dtype + # if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): + # with torch.cuda.amp.autocast(enabled=False): + # s = torch.nn.functional.softmax(t, dim=dim) + # else: + # s = torch.nn.functional.softmax(t, dim=dim) + s = torch.nn.functional.softmax(t, dim=dim) + return s + + +def _attention(query, key, value, mask=None, biases=None, upcast=False) -> torch.Tensor: + # upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + # output back to fp16/bf16. 
+ dtype_og = query.dtype + if upcast: + query = query.float() + key = key.float() + value = value.float() + if mask is not None: + mask = mask.float() + if bias is not None: + biases = biases.float() + + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + # print ("q * k: ", a) + # import pdb; pdb.set_trace() + + if biases is not None: + a += biases + # print ("after bias:", a) + + if mask is not None: + # a += mask + # import pdb; pdb.set_trace() + # please do not use add now + a.masked_fill_(mask < 0, float('-inf')) + + # print ("after mask:", a) + + a = softmax_no_cast(a, -1) + # print ("softmax :", a) + + # [*, H, Q, C_hidden] + b = torch.matmul(a, value) + # print ("p * v: ", a) + return b.to(dtype_og), a.to(dtype_og) + + +def _flash_attn(q, k, v, attn_mask=None, attn_bias=None): + batch_dims = q.shape[:-3] + no_heads, n, c = q.shape[-3:] + dtype = q.dtype + + # [*, B, N, H, C] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + # [B_flat, N, H, C] + q = q.reshape(-1, *q.shape[-3:]) + k = k.reshape(-1, *k.shape[-3:]) + v = v.reshape(-1, *v.shape[-3:]) + + # Flattened batch size + batch_size = q.shape[0] + + # [B_flat * N, H, C] + q = q.reshape(-1, *q.shape[-2:]) + k = k.reshape(-1, *k.shape[-2:]) + v = v.reshape(-1, *v.shape[-2:]) + + q_max_s = n + q_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device + ) + + k_max_s = n + k_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device + ) + + if attn_mask is not None: + # import pdb; pdb.set_trace() + attn_mask = attn_mask.reshape([bs * n, no_heads, n, n]) + + if attn_bias is not None: + # import pdb; pdb.set_trace() + attn_bias = attn_bias.reshape([bs * n, no_heads, n, n]) + + out = flash_attn_unpadded_func( + q, + k, + v, + q_cu_seqlens, + k_cu_seqlens, + q_max_s, + k_max_s, + # attn_mask=attn_mask, + attn_bias=attn_bias, + dropout_p = 0., + softmax_scale = 1., # q has been scaled already + ) + + # [*, B, N, H, C] + out = out.reshape(*batch_dims, n, no_heads, c) + return out + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +torch.manual_seed(0) +# v2 +bs = 1 +seq = 128 +head = 1 +c_dim = 16 + +# mini +# bs = 1 +# seq = 2 +# head = 1 +# c_dim = 16 + +seq_q = seq_k = seq_v = seq + +print (10 * "*" + "prepare data" + 10 * "*" ) +dtype = torch.bfloat16 +# dtype = torch.half +device = "cuda" + +# orig_tensor = torch.stack( +# [ (i+1) * 0.1 * torch.randn((bs, seq, head, c_dim)) for i in range(seq) ] +# ,dim = 1 +# ).to(device).to(dtype) + +orig_tensor = torch.empty((bs, seq, head, seq, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) +orig_tensor.requires_grad = True +# print ("tensor: ", orig_tensor) +print ("origin shape: ", orig_tensor.shape) +# [bs, seq, seq, head, c_dim] + +bias = torch.rand( + 1, 1, head, seq_q, seq_k, dtype=dtype, device=device +) * 0 + +print ("bias shape: ", bias.shape) +bias_broadcast = bias.expand([bs, seq_k, head, seq_q, seq_k]) +print ("bias_broadcast shape: ", bias_broadcast.shape) + +# print ("bias_broadcast: ", bias_broadcast) + +print (10 * "*" + "normal attn fp32" + 10 * "*" ) +normal_attn_v1 = orig_tensor.clone() +output_ref, softmax_output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, biases=bias_broadcast, upcast=True) +# be careful here +output_ref = output_ref.transpose(-2, -3) +print 
("attention output shape: ", output_ref.shape) +print (10 * "*" + "normal attn fp32" + 10 * "*" ) +print () + + +print (10 * "*" + "normal attn fp16" + 10 * "*" ) +normal_attn_v2 = orig_tensor.clone() +output_pt, softmax_output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, biases=bias_broadcast) +# be careful here +output_pt = output_pt.transpose(-2, -3) +print ("attention output shape: ", output_pt.shape) +print (10 * "*" + "normal attn fp32" + 10 * "*" ) +print () + + +print (10 * "*" + "flash attn" + 10 * "*" ) +normal_attn_flash = orig_tensor.clone() +output3 = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, attn_bias=bias_broadcast) +# import pdb; pdb.set_trace() +print ("flash attn output shape: ", output3.shape) +print (10 * "*" + "flash attn" + 10 * "*" ) +print () + +# print ("max abs error: ", (output3 - output_ref).abs().max()) +# print ("all close at pre@.2: ", torch.allclose(output3, output_ref, atol=1e-2)) + +print (10 * "*" + "comparing forward" + 10 * "*" ) +print("Output max diff: {0}".format((output3 - output_ref).abs().max().item())) +print("Output mean diff: {0}".format((output3 - output_ref).abs().mean().item())) + +print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) +print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) + +print("Output max diff with Pytorch: {0}".format((output3 - output_pt).abs().max().item())) +print("Output mean diff with Pytorch: {0}".format((output3 - output_pt).abs().mean().item())) + +print ("less than twice error: ", (output3 - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) +print (10 * "*" + "comparing forward" + 10 * "*" ) +print () + + +# test backward + +# g = torch.randn_like(output3) +# dq_ref, dk_ref, dv_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) +# dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) +# dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) + +# print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) +# print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) +# print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) + +# print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) +# print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) +# print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) + +# print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) +# print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) +# print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) + +# print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 66e16241a..5a4f8c995 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -114,12 +114,12 @@ struct Gmem_tile_qkv { threadIdx.x, threadIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); printf("\n"); } - if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("use_seqlen_q=%d\n", use_seqlen_q); - printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, 
row_offset=%d BYTES_PER_LDG=%d\n", - threadIdx.x, blockIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); - printf("\n"); - } + // if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("use_seqlen_q=%d\n", use_seqlen_q); + // printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, row_offset=%d BYTES_PER_LDG=%d\n", + // threadIdx.x, blockIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); + // printf("\n"); + // } #endif // Assemble the final pointer. ptr += row_offset + col * BYTES_PER_LDG; @@ -256,14 +256,14 @@ struct Gmem_tile_o { threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); printf("\n"); } - if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("print o parameter\n"); - printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", - threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); - printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_STG=%d, LDGS=%d\n", - threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); - printf("\n"); - } + // if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // printf("print o parameter\n"); + // printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", + // threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); + // printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_STG=%d, LDGS=%d\n", + // threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); + // printf("\n"); + // } #endif // Is that thread active on the last STG? if( HAS_INCOMPLETE_STG ) { @@ -618,38 +618,158 @@ struct Gmem_tile_mma_mask { //////////////////////////////////////////////////////////////////////////////////////////////////// // attn bias struct like s, maybe later can reuse the above declaration -template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > -struct Gmem_tile_mma_bias : public Base { +template< typename Cta_tile, int BYTES_PER_ELEMENT = 2> +struct Gmem_tile_mma_bias { - // The number of mmas in the vertical dimension. - static constexpr int M = Base::MMAS_M; - // The number of mmas in the horizontal dimension. - static constexpr int N = Base::MMAS_N; + using Mma_tile = fmha::Hmma_tile; // The type of the vectors stored by each STG. - using Type = typename Base::Type; + using StoreType = uint32_t; + + // static constexpr int LDG_ELEMENTS = 2 + // using Type = typename fmha::Uint_from_size_in_bytes< LDG_ELEMENTS * BYTES_PER_ELEMENT >::Type; + + // The number of MMAs in the M dimension. + static constexpr int M = Mma_tile::MMAS_M; + // The number of MMAs in the N dimension. + static constexpr int N = Mma_tile::MMAS_N; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS = Cta_tile::M; + static constexpr int COLS = Cta_tile::N; + + // The size of each LDG. + // load two elements of data + static constexpr int BYTES_PER_LDG = 2 * BYTES_PER_ELEMENT; + // The size of a row in bytes. + static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT; + + // The number of LDGS needed to store a chunk of the P matrix in total. 
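+    // 4 = 2 (upper/lower 8-row halves) x 2 (left/right 8-column halves): each thread issues four LDGs
+    // per 16x16 MMA tile, and each LDG fetches two adjacent 16-bit elements (BYTES_PER_LDG).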
+ // Tell me if has more efficient way + static constexpr int LDGS_PER_THREAD_PER_WARP = 4; + static constexpr int THREADS_PER_QUAD = 4; + static constexpr int COL_PER_MMA_PER_CTA = Cta_tile::THREADS_PER_WARP / THREADS_PER_QUAD; // Ctor. template< typename Params, typename Block_info > - inline __device__ Gmem_tile_mma_bias(const Params ¶ms, const Block_info& binfo, const int tidx) - : Base(params.attn_bias_ptr, params, binfo.bidb, binfo.bidh, tidx) { + inline __device__ Gmem_tile_mma_bias(const Params ¶ms, + // const uint32_t row_stride_in_elts, const uint32_t head_stride_in_elts, + const Block_info& binfo, const int tidx, const int loop_step_idx) + : ptr_(static_cast(params.attn_bias_ptr)) + // : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) + , actual_seqlen_q(binfo.actual_seqlen_q) + , actual_seqlen_k(binfo.actual_seqlen_k) + , tidx_(tidx) + , loop_step_idx(loop_step_idx) + { + row_stride_in_bytes = binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + + const int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + const int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + // this col is mean the 8x4 tile's cole + + row = warp_m * Mma_tile::M_PER_MMA + quad; + static_assert(Mma_tile::M_PER_MMA == 16, + "only support sm80 m16n8k16 tensor core"); + + col = warp_n * Mma_tile::N_PER_MMA + tid; + static_assert(Mma_tile::N_PER_MMA == 16, + "only support sm80 m16n8k16 tensor core"); + + // The distance between two blocks (in bytes). + // TODO: mask is [bs, head, seq_q, seq_k] + // The block index. + uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + + // the index of bs and head dim + uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + // row_offset = (uint32_t)(row * row_stride_in_bytes); + row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", + tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); + printf("bidb=%d, bidh=%d, param.h=%d\n", binfo.bidb, binfo.bidh, params.h); + printf("\n"); + } +#endif + // do we need to move col first if seklen_k > cols + ptr_ += row_offset; } // Load from global memory to Fragment. - template - inline __device__ void load(Fragment (&frag)[N][M]) { + template + inline __device__ void load(Fragment (&frag)[M][N]) { + // using Fragment = typename fmha::Fragment; + + const void *ptrs[LDGS_PER_THREAD_PER_WARP]; + uint32_t preds[LDGS_PER_THREAD_PER_WARP]; + #pragma unroll for( int mi = 0; mi < M; mi++ ) { #pragma unroll for( int ni = 0; ni < N; ni++ ) { - uint4 dst; - Base::load(dst, mi, ni); - frag[ni][mi].reg(0) = dst.x; - frag[ni][mi].reg(2) = dst.y; - frag[ni][mi].reg(1) = dst.z; - frag[ni][mi].reg(3) = dst.w; + #pragma unroll + for ( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for (int jj = 0; jj < 2; ++jj ) { + int offset = ii * 2 + jj; + const int current_row = mi * ROWS + ii * 8; + const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + // const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + // 8 is actually col of half data now, for more general case ? 
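+                    // (ii, jj) selects one of this thread's four 2-element LDGs for MMA tile (mi, ni):
+                    // ii picks the 8-row half (the thread's base row is already folded into ptr_),
+                    // jj picks the 8-column half on top of the thread's base column `col`.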
+ // the row is already in the right position + ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + + (uint32_t)current_col * BYTES_PER_ELEMENT; + + preds[offset] = (current_row <= min(ROWS, actual_seqlen_q)) + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d, start_ptr=%p, ptrs[offset]=%p, preds[offset]=%d\n", + mi, ni, ii, jj, offset, current_row, current_col, ptr_, ptrs[offset], preds[offset]); + printf("current_row=%d, current_col=%d, ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d\n", + current_row, current_col, ROWS, actual_seqlen_q, COLS, actual_seqlen_k); + printf("cond 1=%d\n", (current_row <= min(ROWS, actual_seqlen_q))); + printf("cond 2=%d\n", ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k))); + printf("\n"); + } +#endif + } + } + + // load data + Ldg_functor fct(frag[mi][ni].regs_, ptrs); + #pragma unroll + for(int kk = 0; kk < LDGS_PER_THREAD_PER_WARP; ++kk ) { + fct.load(kk, preds[kk]); + } } } } + + inline __device__ void move(const int steps = 1) { + ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; + this->actual_seqlen_q -= ROWS * steps; + } + + int row; + int col; + const int loop_step_idx; + uint32_t row_stride_in_bytes; + // The pointer. + char *ptr_; + int actual_seqlen_q; + int actual_seqlen_k; + const int tidx_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index d3d67fd8a..47e0a5b34 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -512,6 +512,25 @@ struct Softmax : public Softmax_base { } } + + template + inline __device__ void apply_attn_bias(const Fragment (&bias)[MMAS_M][MMAS_N]) { + #pragma unroll + for( int mi = 0; mi < MMAS_M; ++mi ) { + #pragma unroll + for( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for( int ni = 0; ni < MMAS_N; ++ni ) { + #pragma unroll + for( int jj = 0; jj < 4; ++jj ) { + this->elt_[2 * mi + ii][4 * ni + jj] += toFloat(bias[mi][ni].elt(ii * 4 + jj)); + } + } + } + } + } + + // Pack the data to a fragment for the next GEMM. template inline __device__ void pack(Fragment_a (&dst)[K][M]) const { diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 0ec4ddc82..425384013 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -155,6 +155,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); // TODO: load fun as s + // Allocate the global memory tile loader for bias. + using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; + // conctructor + Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); + // TODO: load fun as s + fmha::Mask mask(binfo, tidx, loop_step_idx); // Allocate the global memory tile loader for K. 
@@ -396,6 +402,64 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng } #endif } + + if (!(params.attn_bias_ptr == nullptr)) { + using Frag_Bias = fmha::Fragment_c; + Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + gmem_bias.template load(frag_bias); + gmem_bias.move(); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { + // 1st row - 4 elements per row. + float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; + + printf("before attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); + printf("before attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); + } + } + printf("\n"); + } +#endif + // Apply the attn mask. + softmax.apply_attn_bias(frag_bias); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { + // 1st row - 4 elements per row. + float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; + + printf("after attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); + printf("after attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); + } + } + printf("\n"); + } +#endif + } + // Apply the mask. softmax.apply_mask(mask); // Scale by log-sum-exp of the softmax @@ -414,6 +478,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng softmax.template pack(frag_p); // Store s * dmask to smem for transpose + // ? smem_s.store(frag_p); // Trigger the load for the next Q values. diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 5e9de5ddc..f280c329f 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -357,9 +357,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // bool has_bias = !(params.attn_bias_ptr == nullptr); // Allocate the global memory tile loader for bias. 
- // using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; - // // conctructor - // Gmem_tile_bias gmem_bias(params, binfo, tidx); + using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; + // conctructor + Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); // TODO: load fun as s Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); @@ -381,11 +381,11 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_mask.move(begin); } - // // if constexpr (has_bias) { - // if (!(params.attn_bias_ptr == nullptr)) { - // // TODO: bias move - // gmem_bias.move(begin); - // } + // if constexpr (has_bias) { + if (!(params.attn_bias_ptr == nullptr)) { + // TODO: bias move + gmem_bias.move(begin); + } gmem_softmax_lse.move(begin); @@ -621,6 +621,63 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i #endif } + if (!(params.attn_bias_ptr == nullptr)) { + using Frag_Bias = fmha::Fragment_c; + Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + gmem_bias.template load(frag_bias); + gmem_bias.move(); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { + // 1st row - 4 elements per row. + float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; + + printf("before attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); + printf("before attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); + } + } + printf("\n"); + } +#endif + // Apply the attn mask. + softmax.apply_attn_bias(frag_bias); + +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { + // 1st row - 4 elements per row. + float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; + float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; + float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; + float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; + + // 2nd row - 4 elements per row. + float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; + float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; + float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; + float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; + + printf("after attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); + printf("after attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); + } + } + printf("\n"); + } +#endif + } + // Apply the mask. // this impl is more like padding softmax.apply_mask(mask); @@ -673,6 +730,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Compute the exponential value. 
// softmax.apply_exp(p_max); + // Compute: exp(p - p_max) softmax.scale_apply_exp(p_max, params.scale_bmm1f); // if (!Is_first) { @@ -728,6 +786,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); + // frag_p = exp(s{i} - s{max}) softmax.template pack(frag_p); if (Return_softmax) { diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index aaffe5463..6a9a59966 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -17,13 +17,7 @@ def _get_block_size(device, head_dim, is_dropout): def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, attn_mask, attn_bias, dropout_p, softmax_scale, causal, return_softmax): # import pdb; pdb.set_trace() - if attn_mask is None: - out, softmax_lse, *rest = flash_attn_cuda.fwd( - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, - False, causal, return_softmax, None, None, None - ) - else: - out, softmax_lse, *rest = flash_attn_cuda.fwd( + out, softmax_lse, *rest = flash_attn_cuda.fwd( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, return_softmax, None, attn_mask, attn_bias ) diff --git a/tests/test_forward.cu b/tests/test_forward.cu index e63ced238..f137e89cd 100644 --- a/tests/test_forward.cu +++ b/tests/test_forward.cu @@ -3,7 +3,7 @@ //#include #include -void test_fwd() { +void test_fwd_with_mask() { int batch_size = 1; int nheads = 1; int headdim = 16; @@ -74,7 +74,9 @@ void test_fwd() { } c10::optional gen_; - c10::optional attn_mask_op; + c10::optional attn_bias; + + // std::cout << "attn bias" << attn_bias << std::endl; std::vector ret = mha_fwd( q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i @@ -91,7 +93,7 @@ void test_fwd() { return_softmax, gen_, attn_mask, - attn_mask.clone() + attn_bias ); std::cout << "Ret vec size is " << ret.size(); @@ -102,7 +104,7 @@ void test_fwd() { } -void test_fwd_mini() { +void test_fwd_with_mask_mini() { int batch_size = 1; int nheads = 1; int headdim = 16; @@ -179,7 +181,235 @@ void test_fwd_mini() { auto attn_mask = attn_mask_cpu.cuda(); c10::optional gen_; - c10::optional attn_mask_op; + c10::optional attn_bias; + + // std::cout << "attn bias: " << attn_bias << std::endl; + + std::vector ret = mha_fwd( + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + attn_mask, + attn_bias + ); + + // ret: std::vector result = {o, softmax_lse}; + // [bs * seq * seq, head, head_dim] + // [1 * 2 * 2, 1, 16] + std::cout << "Ret vec size is " << ret.size(); + for (int i = 0; i < ret.size(); i ++) { + ret[i].cpu(); + std::cout << ret[i] << std::endl; + } +} + + +void test_fwd_with_bias_mini() { + int batch_size = 1; + int nheads = 1; + int headdim = 16; + int max_seqlen_q_ = 2; + int max_seqlen_k_ = 2; + + float softmax_scale = 0.1; + + bool zero_tensors = false; + bool is_causal = false; + bool return_softmax = false; + + // q -> [bs * seq, head, head_dim] + // q -> [1 * 128, 1, 16] + // block q -> 
[128, 16] + + // k -> [bs * seq, head, head_dim] + // k -> [1 * 128, 1, 16] + // block k -> [128, 16] + + // v -> [bs * seq, head, head_dim] + // v -> [1 * 128, 1, 16] + // block k -> [128, 16] + + at::Tensor q_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor k_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor v_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + int cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < headdim; k ++) { + q_cpu[i][j][k] = cnt * 0.001; + k_cpu[i][j][k] = cnt * 0.001; + v_cpu[i][j][k] = cnt * 0.001; + cnt ++; + } + } + } + + auto q = q_cpu.cuda(); + auto k = k_cpu.cuda(); + auto v = v_cpu.cuda(); + + at::Tensor cu_seqlens_q_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); + at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); + + for (int i = 0; i < batch_size * max_seqlen_k_ + 1; ++i) { + cu_seqlens_q_cpu[i] = i * max_seqlen_q_; + cu_seqlens_k_cpu[i] = i * max_seqlen_k_; + } + + auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); + auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); + + at::Tensor attn_bias_cpu = at::zeros({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf); + + cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < max_seqlen_q_; k ++) { + for (int l = 0; l < max_seqlen_k_; l ++) { + // if (l == 0) attn_mask[i][j][k][l] = -INFINITY; + if (l == 0) attn_bias_cpu[i][j][k][l] = -3e4; + else attn_bias_cpu[i][j][k][l] = 0; + + attn_bias_cpu[i][j][k][l] = 100; + printf("i=%d, j=%d, k=%d, l=%d attn_bias=%f\n", i, j, k, l, attn_bias_cpu[i][j][k][l]); + // std::cout << "i=" << i << ", j=" << j << ", k=" << k << ", l" + // << l << << ", attn_bias=" << attn_bias_cpu[i][j][k][l] << std::endl; + } + } + } + } + + auto attn_bias = attn_bias_cpu.cuda(); + + c10::optional gen_; + c10::optional attn_mask; + + // std::cout << attn_mask << std::endl; + + std::vector ret = mha_fwd( + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + attn_mask, + attn_bias + ); + + // ret: std::vector result = {o, softmax_lse}; + // [bs * seq * seq, head, head_dim] + // [1 * 2 * 2, 1, 16] + std::cout << "Ret vec size is " << ret.size(); + for (int i = 0; i < ret.size(); i ++) { + ret[i].cpu(); + std::cout << ret[i] << std::endl; + } +} + + +void test_fwd_with_bias() { + int batch_size = 1; + int nheads = 1; + int headdim = 16; + int max_seqlen_q_ = 128; + int max_seqlen_k_ = 128; + + float softmax_scale = 0.1; + + bool zero_tensors = false; + bool is_causal = false; + bool return_softmax = false; + + // q -> [bs * seq, head, head_dim] + // q -> [1 * 128, 1, 16] + // block q -> [128, 16] + + // k -> [bs * seq, head, head_dim] + // k -> [1 * 128, 1, 16] + // block k -> [128, 16] + + // v -> [bs * seq, head, head_dim] + // v -> [1 * 128, 1, 16] + // block k -> [128, 16] + + at::Tensor q_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor k_cpu = 
at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor v_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + int cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < headdim; k ++) { + q_cpu[i][j][k] = cnt * 0.001; + k_cpu[i][j][k] = cnt * 0.001; + v_cpu[i][j][k] = cnt * 0.001; + cnt ++; + } + } + } + + auto q = q_cpu.cuda(); + auto k = k_cpu.cuda(); + auto v = v_cpu.cuda(); + + at::Tensor cu_seqlens_q_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); + at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); + + for (int i = 0; i < batch_size * max_seqlen_k_ + 1; ++i) { + cu_seqlens_q_cpu[i] = i * max_seqlen_q_; + cu_seqlens_k_cpu[i] = i * max_seqlen_k_; + } + + auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); + auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); + + at::Tensor attn_bias_cpu = at::zeros({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf); + + cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < max_seqlen_q_; k ++) { + for (int l = 0; l < max_seqlen_k_; l ++) { + // if (l == 0) attn_mask[i][j][k][l] = -INFINITY; + // if (l == 0) attn_bias_cpu[i][j][k][l] = -3e4; + // else attn_bias_cpu[i][j][k][l] = 0; + + attn_bias_cpu[i][j][k][l] = 0; + // attn_bias_cpu[i][j][k][l] = cnt * 0.001; + cnt ++; + // printf("i=%d, j=%d, k=%d, l=%d attn_bias=%f\n", i, j, k, l, attn_bias_cpu[i][j][k][l]); + // std::cout << "i=" << i << ", j=" << j << ", k=" << k << ", l" + // << l << << ", attn_bias=" << attn_bias_cpu[i][j][k][l] << std::endl; + } + } + } + } + + auto attn_bias = attn_bias_cpu.cuda(); + + c10::optional gen_; + c10::optional attn_mask; + + // std::cout << attn_mask << std::endl; std::vector ret = mha_fwd( q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i @@ -196,7 +426,7 @@ void test_fwd_mini() { return_softmax, gen_, attn_mask, - attn_mask.clone() + attn_bias ); // ret: std::vector result = {o, softmax_lse}; @@ -211,6 +441,7 @@ void test_fwd_mini() { int main(){ // test_fwd(); - test_fwd_mini(); + // test_fwd_with_bias_mini(); + test_fwd_with_bias(); return 0; } From dfcd4a75fd35e7025e34041dd19394c2abb6441a Mon Sep 17 00:00:00 2001 From: robotcator Date: Sat, 27 Aug 2022 17:26:02 +0800 Subject: [PATCH 21/71] add mask --- benchmarks/test/test_forward_with_bias_v2.py | 53 ++++++++++++++------ csrc/flash_attn/fmha_api.cpp | 12 +++++ csrc/flash_attn/src/fmha/gmem_tile.h | 2 +- csrc/flash_attn/src/fmha/softmax.h | 11 +++- 4 files changed, 59 insertions(+), 19 deletions(-) diff --git a/benchmarks/test/test_forward_with_bias_v2.py b/benchmarks/test/test_forward_with_bias_v2.py index 272cabf01..3b17ae64d 100644 --- a/benchmarks/test/test_forward_with_bias_v2.py +++ b/benchmarks/test/test_forward_with_bias_v2.py @@ -53,6 +53,7 @@ def _attention(query, key, value, mask=None, biases=None, upcast=False) -> torch # import pdb; pdb.set_trace() if biases is not None: + print ("attn_shape = {}, bias_shape = {}".format(a.shape, biases.shape)) a += biases # print ("after bias:", a) @@ -108,11 +109,19 @@ def _flash_attn(q, k, v, attn_mask=None, attn_bias=None): if attn_mask is not None: # import pdb; pdb.set_trace() - attn_mask = attn_mask.reshape([bs * n, no_heads, n, n]) + attn_mask = attn_mask.reshape([batch_size , no_heads, n, n]).contiguous() if attn_bias is not None: # 
import pdb; pdb.set_trace() - attn_bias = attn_bias.reshape([bs * n, no_heads, n, n]) + if attn_bias.is_contiguous: + print ("attn_bias it not contiguous") + attn_bias = attn_bias.reshape([batch_size , no_heads, n, n]).contiguous() + # attn_bias = attn_bias.reshape([batch_size , no_heads, n, n]) + + print ("check shapes q_shape = {} k_shape = {} v_shape = {}".format(q.shape, k.shape, v.shape)) + print ("check shapes q_cu_shape = {} k_cu_shape = {}".format(q_cu_seqlens.shape, k_cu_seqlens.shape)) + if attn_bias is not None: + print ("attn_bias shape = {}".format(attn_bias.shape)) out = flash_attn_unpadded_func( q, @@ -122,7 +131,7 @@ def _flash_attn(q, k, v, attn_mask=None, attn_bias=None): k_cu_seqlens, q_max_s, k_max_s, - # attn_mask=attn_mask, + attn_mask=None, attn_bias=attn_bias, dropout_p = 0., softmax_scale = 1., # q has been scaled already @@ -142,22 +151,22 @@ def gen_attn_mask(mask, neg_inf): torch.manual_seed(0) # v2 +# bs = 1 +# seq = 128 +# head = 1 +# c_dim = 16 + +# mini bs = 1 seq = 128 head = 1 c_dim = 16 -# mini -# bs = 1 -# seq = 2 -# head = 1 -# c_dim = 16 - -seq_q = seq_k = seq_v = seq +seq_q = seq_k = seq_v = 128 print (10 * "*" + "prepare data" + 10 * "*" ) -dtype = torch.bfloat16 -# dtype = torch.half +# dtype = torch.bfloat16 +dtype = torch.half device = "cuda" # orig_tensor = torch.stack( @@ -165,18 +174,18 @@ def gen_attn_mask(mask, neg_inf): # ,dim = 1 # ).to(device).to(dtype) -orig_tensor = torch.empty((bs, seq, head, seq, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) +orig_tensor = torch.empty((bs, seq, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) orig_tensor.requires_grad = True # print ("tensor: ", orig_tensor) print ("origin shape: ", orig_tensor.shape) # [bs, seq, seq, head, c_dim] -bias = torch.rand( +bias = torch.randn( 1, 1, head, seq_q, seq_k, dtype=dtype, device=device -) * 0 +) * 1 print ("bias shape: ", bias.shape) -bias_broadcast = bias.expand([bs, seq_k, head, seq_q, seq_k]) +bias_broadcast = bias.expand([bs, seq, head, seq_q, seq_k]) print ("bias_broadcast shape: ", bias_broadcast.shape) # print ("bias_broadcast: ", bias_broadcast) @@ -216,6 +225,9 @@ def gen_attn_mask(mask, neg_inf): print("Output max diff: {0}".format((output3 - output_ref).abs().max().item())) print("Output mean diff: {0}".format((output3 - output_ref).abs().mean().item())) +# print("Output max diff: {0}".format((output3[:,0,:,:,:] - output_ref[:,0,:,:,:]).abs().max().item())) +# print("Output max diff: {0}".format((output3[:,3,:,:,:] - output_ref[:,3,:,:,:]).abs().max().item())) + print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) @@ -226,6 +238,15 @@ def gen_attn_mask(mask, neg_inf): print (10 * "*" + "comparing forward" + 10 * "*" ) print () +# max_diff = (output3 - output_ref).abs().max().item() +# relative_diff = (output_pt - output_ref).abs().max().item() + +# for i in range(bs): +# for j in range(seq_q): +# for k in range(seq_k): +# if (output3[i, j, k, :, :] - output_ref[i, j, k, :, :]).abs().max().item() >= 2 * (relative_diff): +# print ("i={}, j={}, k={} output3={}".format(i, j, k, output3[i, j, k, :, :].data)) +# print ("i={}, j={}, k={} output_pt={}".format(i, j, k, output_ref[i, j, k, :, :].data)) # test backward diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index a2fbaaff0..11ac0e071 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -268,6 +268,18 @@ 
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (attn_bias.has_value()) { + TORCH_CHECK(attn_bias.value().is_cuda()); + TORCH_CHECK(attn_bias.value().dtype() == q_dtype); + TORCH_CHECK(attn_bias.value().is_contiguous()); + } + + if (attn_mask.has_value()) { + TORCH_CHECK(attn_mask.value().is_cuda()); + TORCH_CHECK(attn_mask.value().dtype() == q_dtype); + TORCH_CHECK(attn_mask.value().is_contiguous()); + } + int blocksize_c = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256; // Need to round max_seqlen_k to multiples of blocksize_c int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 5a4f8c995..ad074e3e2 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -528,7 +528,7 @@ struct Gmem_tile_mma_mask { "only support sm80 m16n8k16 tensor core"); // The distance between two blocks (in bytes). - // TODO: mask is [bs, head, seq_q, seq_k] + // TODO: mask is [bs * seq, head, seq_q, seq_k] // The block index. uint32_t bidx = binfo.bidb * params.h + binfo.bidh; diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index 47e0a5b34..aaa38786b 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -514,7 +514,7 @@ struct Softmax : public Softmax_base { template - inline __device__ void apply_attn_bias(const Fragment (&bias)[MMAS_M][MMAS_N]) { + inline __device__ void apply_attn_bias(const Fragment (&bias)[MMAS_M][MMAS_N], int l) { #pragma unroll for( int mi = 0; mi < MMAS_M; ++mi ) { #pragma unroll @@ -523,7 +523,14 @@ struct Softmax : public Softmax_base { for( int ni = 0; ni < MMAS_N; ++ni ) { #pragma unroll for( int jj = 0; jj < 4; ++jj ) { - this->elt_[2 * mi + ii][4 * ni + jj] += toFloat(bias[mi][ni].elt(ii * 4 + jj)); + float value = toFloat(bias[mi][ni].elt(ii * 4 + jj)); + this->elt_[2 * mi + ii][4 * ni + jj] += value; +#ifdef DEBUG_PRINT + if ((blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("AttnBias: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, ldx = %d, blockIdx.x = %d\n", + threadIdx.x, threadIdx.y, mi, ni, ii, jj, float(value), l, blockIdx.x); + } +#endif } } } From ffd07e85d425d8d0136cd95a3d6118a27646b339 Mon Sep 17 00:00:00 2001 From: xh Date: Tue, 30 Aug 2022 11:38:43 +0800 Subject: [PATCH 22/71] add bias test --- benchmarks/test/test_forward_with_bias_v2.py | 33 ++-- .../src/fmha_dgrad_kernel_1xN_loop.h | 33 +++- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 3 +- setup.py | 7 +- tests/fmha_api.h | 24 +++ tests/test_forward.cu | 152 ++++++++++++++---- tests/tools/check_output.py | 115 +++++++++++++ tests/tools/rebuild_mat.py | 94 +++++++++++ 8 files changed, 405 insertions(+), 56 deletions(-) create mode 100644 tests/tools/check_output.py create mode 100644 tests/tools/rebuild_mat.py diff --git a/benchmarks/test/test_forward_with_bias_v2.py b/benchmarks/test/test_forward_with_bias_v2.py index 3b17ae64d..4cac9f524 100644 --- a/benchmarks/test/test_forward_with_bias_v2.py +++ b/benchmarks/test/test_forward_with_bias_v2.py @@ -114,9 +114,10 @@ def _flash_attn(q, k, v, attn_mask=None, attn_bias=None): if attn_bias is not None: # import pdb; pdb.set_trace() if attn_bias.is_contiguous: - print ("attn_bias it not contiguous") + print ("attn_bias it 
not contiguous, stride is", attn_bias.stride()) attn_bias = attn_bias.reshape([batch_size , no_heads, n, n]).contiguous() # attn_bias = attn_bias.reshape([batch_size , no_heads, n, n]) + print ("attn_bias stride is", attn_bias.stride()) print ("check shapes q_shape = {} k_shape = {} v_shape = {}".format(q.shape, k.shape, v.shape)) print ("check shapes q_cu_shape = {} k_cu_shape = {}".format(q_cu_seqlens.shape, k_cu_seqlens.shape)) @@ -180,7 +181,7 @@ def gen_attn_mask(mask, neg_inf): print ("origin shape: ", orig_tensor.shape) # [bs, seq, seq, head, c_dim] -bias = torch.randn( +bias = torch.ones( 1, 1, head, seq_q, seq_k, dtype=dtype, device=device ) * 1 @@ -250,21 +251,21 @@ def gen_attn_mask(mask, neg_inf): # test backward -# g = torch.randn_like(output3) -# dq_ref, dk_ref, dv_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) -# dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) -# dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) +g = torch.randn_like(output3) +dq_ref, dk_ref, dv_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) +dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) +dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) -# print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) -# print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) -# print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) +print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) +print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) +print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) -# print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) -# print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) -# print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) +print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) +print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) +print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) -# print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) -# print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) -# print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) +print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) +print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) +print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) -# print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file +print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 425384013..9f8130e37 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -216,6 +216,12 @@ inline __device__ void 
compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng gmem_mask.move(begin); } + // if constexpr (has_attn) { + if (!(params.attn_bias_ptr == nullptr)) { + // TODO: mask move + gmem_bias.move(begin); + } + if (!Is_first) { gmem_k.move(loop_step_idx); gmem_v.move(loop_step_idx); @@ -240,6 +246,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Commit the data for Q, dO, and V to shared memory. gmem_q.commit(gemm_q_k.smem_q); gmem_do.commit(smem_do); + // D_sum if (Is_first) { dot_do_o( gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx @@ -345,7 +352,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Convert from the accumulator type to FP32 for Softmax. softmax.unpack_noscale(acc_p); - // if constexpr (has_attn) { + // if constexpr (has_attn) { if (!(params.attn_mask_ptr == nullptr)) { using Frag_mask = fmha::Fragment_c; Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; @@ -433,7 +440,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng } #endif // Apply the attn mask. - softmax.apply_attn_bias(frag_bias); + softmax.apply_attn_bias(frag_bias, l); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { @@ -464,7 +471,25 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng softmax.apply_mask(mask); // Scale by log-sum-exp of the softmax // softmax.apply_exp(p_lse); + // exp (x - (max+log(sum))) = exp(x - max) / sum softmax.template scale_apply_exp(p_lse, params.scale_bmm1f); +#ifdef DEBUG_PRINT + if ((blockIdx.x == 0) && (blockIdx.y == 0)) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { + for (int ii = 0; ii < 2; ii ++) { + for (int jj = 0; jj < 4; jj ++) { + int st_row = 2 * mi + ii; + int st_col = 4 * ki + jj; + printf("bwd softmax: threadIdx=%d, l=%d, mi=%d, ki=%d, ii=%d, jj=%d, elt=%d", + threadIdx.x, mi, ki, ii, jj, softmax.elt_[st_row][st_col]); + } + } + } + } + printf("\n"); + } +#endif if (Is_dropout) { // softmax.apply_dropout(ph, params.p_dropout_in_uint); // softmax.template apply_dropout(ph, params.p_dropout_in_uint); @@ -478,7 +503,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng softmax.template pack(frag_p); // Store s * dmask to smem for transpose - // ? + // how to test smem_s.store(frag_p); // Trigger the load for the next Q values. @@ -493,6 +518,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // __syncthreads(); // } + // what's meaning? fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; #pragma unroll for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) { @@ -771,6 +797,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than // the total amount of shared mem? 
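Before the epilogue, it may help to restate what this dgrad loop computes end to end. The sketch below is a vectorized NumPy version of the same math as tests/tools/check_output.py, assuming [batch, heads, seqlen, head_dim] tensors, an lse of shape [batch, heads, seqlen_q, 1], and a bias already materialized to the score shape; the names are mine, not kernel symbols.

```python
import numpy as np

def attn_bwd_reference(dout, q, k, v, lse, bias=None):
    s = q @ k.transpose(0, 1, 3, 2)
    if bias is not None:
        s = s + bias
    p = np.exp(s - lse)                       # recompute P: exp(x - (max + log(sum)))
    dv = p.transpose(0, 1, 3, 2) @ dout       # dV = P^T dO
    dp = dout @ v.transpose(0, 1, 3, 2)       # dP = dO V^T
    d_sum = (p * dp).sum(-1, keepdims=True)   # D_i = sum_j P_ij dP_ij = dO_i . O_i (dot_do_o)
    ds = p * (dp - d_sum)                     # dS_{i:} = P_{i:} * (dP_{i:} - D_i)
    dq = ds @ k                               # dQ = dS K
    dk = ds.transpose(0, 1, 3, 2) @ q         # dK = dS^T Q
    dbias = ds if bias is not None else None  # bias gradient before any broadcast reduction
    return dq, dk, dv, ds, dbias
```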
// Epilogue swizzle for dV + // data flow: fragment -> smem_dv -> global Smem_tile_dv smem_dv(&smem_[0], tidx); smem_dv.template store(acc_dv); diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index f280c329f..934f575db 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -624,6 +624,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i if (!(params.attn_bias_ptr == nullptr)) { using Frag_Bias = fmha::Fragment_c; Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::clear(frag_bias); gmem_bias.template load(frag_bias); gmem_bias.move(); @@ -651,7 +652,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i } #endif // Apply the attn mask. - softmax.apply_attn_bias(frag_bias); + softmax.apply_attn_bias(frag_bias, l); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { diff --git a/setup.py b/setup.py index 8fad3ed94..2467be80e 100644 --- a/setup.py +++ b/setup.py @@ -126,11 +126,12 @@ def append_nvcc_threads(nvcc_extra_args): ], extra_compile_args={ # "cxx": ["-O3"] + generator_flag, - "cxx": ["-g"] + generator_flag, + "cxx": ["-O3", "-DDEBUG_PRINT"] + generator_flag, "nvcc": append_nvcc_threads( [ - #"-O3", - "-g", + "-O3", + # "-g", + "-DDEBUG_PRINT", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "--expt-relaxed-constexpr", diff --git a/tests/fmha_api.h b/tests/fmha_api.h index 6d6ef7410..a1ff42204 100644 --- a/tests/fmha_api.h +++ b/tests/fmha_api.h @@ -22,3 +22,27 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const c10::optional &attn_mask, // attn_mask const c10::optional &attn_bias // attn bias ); + + +std::vector +mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size + const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &out, // total_q x num_heads x head_size + const at::Tensor &softmax_lse_, // b x h x s softmax logsumexp + at::Tensor &dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + at::Tensor &dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + at::Tensor &dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const at::Tensor &cu_seqlens_q, // b+1 + const at::Tensor &cu_seqlens_k, // b+1 + const int max_seqlen_q_, + const int max_seqlen_k_, // max sequence length to choose the kernel + const float p_dropout, // probability to drop + const float softmax_scale, + const bool zero_tensors, + const bool is_causal, + c10::optional gen_, + const c10::optional &attn_mask, // attn_mask + const c10::optional &attn_bias // attn bias +); \ No newline at end of file diff --git a/tests/test_forward.cu b/tests/test_forward.cu index f137e89cd..ee4f6d3c6 100644 --- a/tests/test_forward.cu +++ b/tests/test_forward.cu @@ -2,6 +2,9 @@ #include //#include #include +#include +#include + void test_fwd_with_mask() { int batch_size = 1; @@ -326,14 +329,31 @@ void test_fwd_with_bias_mini() { } -void test_fwd_with_bias() { +void dump_tensor(const std::string &tensor_name, at::Tensor &tensor) { + std::string file_name = tensor_name + ".data"; + std::ofstream file(file_name.c_str()); + // file << tensor_name << std::endl; + // file << 
tensor << std::endl; + std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; + auto flatten_tensor = tensor.flatten(); + auto size = flatten_tensor.numel(); + + for (int i = 0; i < size; i ++) { + file << flatten_tensor[i].item() << " "; + } + file << std::endl; +} + + +void test_fwd_with_bias(bool has_bias) { int batch_size = 1; int nheads = 1; int headdim = 16; - int max_seqlen_q_ = 128; - int max_seqlen_k_ = 128; + int seq = 8; + int max_seqlen_q_ = seq; + int max_seqlen_k_ = seq; - float softmax_scale = 0.1; + float softmax_scale = 1; bool zero_tensors = false; bool is_causal = false; @@ -393,8 +413,8 @@ void test_fwd_with_bias() { // if (l == 0) attn_bias_cpu[i][j][k][l] = -3e4; // else attn_bias_cpu[i][j][k][l] = 0; - attn_bias_cpu[i][j][k][l] = 0; - // attn_bias_cpu[i][j][k][l] = cnt * 0.001; + // attn_bias_cpu[i][j][k][l] = 0; + attn_bias_cpu[i][j][k][l] = cnt * 0.001; cnt ++; // printf("i=%d, j=%d, k=%d, l=%d attn_bias=%f\n", i, j, k, l, attn_bias_cpu[i][j][k][l]); // std::cout << "i=" << i << ", j=" << j << ", k=" << k << ", l" @@ -408,40 +428,106 @@ void test_fwd_with_bias() { c10::optional gen_; c10::optional attn_mask; - - // std::cout << attn_mask << std::endl; - - std::vector ret = mha_fwd( - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - return_softmax, - gen_, - attn_mask, - attn_bias - ); + std::vector ret ; + + if (has_bias) { + ret = mha_fwd( + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + attn_mask, + attn_bias + ); + }else{ + ret = mha_fwd( + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + attn_mask, + attn_mask + // no bias + ); + } // ret: std::vector result = {o, softmax_lse}; // [bs * seq * seq, head, head_dim] // [1 * 2 * 2, 1, 16] - std::cout << "Ret vec size is " << ret.size(); - for (int i = 0; i < ret.size(); i ++) { - ret[i].cpu(); - std::cout << ret[i] << std::endl; - } + std::cout << "fwd Ret vec size is " << ret.size(); + // for (int i = 0; i < ret.size(); i ++) { + // ret[i].cpu(); + // std::cout << ret[i] << std::endl; + // } + dump_tensor("attn_output", ret[0]); + dump_tensor("attn_lse", ret[1]); + + // at::Tensor dout_cpu = at::ones({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + // at::Tensor dq_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + // at::Tensor dk_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + // at::Tensor dv_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + // auto dout = dout_cpu.cuda(); + // auto 
dq = dq_cpu.cuda(); + // auto dk = dk_cpu.cuda(); + // auto dv = dv_cpu.cuda(); + + // std::vector bwd_ret = mha_bwd( + // dout, + // q, + // k, + // v, + // ret[0], + // ret[1], + // dq, + // dk, + // dv, + // cu_seqlens_q, // b + 1 + // cu_seqlens_k, // b + 1 + // max_seqlen_q_, + // max_seqlen_k_, + // 0.0, + // softmax_scale, + // zero_tensors, + // is_causal, + // gen_, + // attn_mask, + // attn_bias + // ); + + // std::cout << "bwd Ret vec size is " << ret.size(); + // for (int i = 0; i < bwd_ret.size(); i ++) { + // bwd_ret[i].cpu(); + // std::cout << bwd_ret[i] << std::endl; + // } } -int main(){ +int main(int argc, char** argv){ // test_fwd(); // test_fwd_with_bias_mini(); - test_fwd_with_bias(); + bool has_bias = false; + if( argc == 2 ) { + std::cout << "argv: " << argv[1] << std::endl; + has_bias = true; + } + test_fwd_with_bias(has_bias); return 0; } diff --git a/tests/tools/check_output.py b/tests/tools/check_output.py new file mode 100644 index 000000000..743ad17eb --- /dev/null +++ b/tests/tools/check_output.py @@ -0,0 +1,115 @@ +import numpy as np + +batch_size = 1 +nheads = 1 +headdim = 16 +seq = 8 +max_seqlen_q_ = seq +max_seqlen_k_ = seq + + +q_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=np.float16) +k_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=np.float16) +v_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=np.float16) + +cnt = 0 +for i in range(batch_size * max_seqlen_k_ * max_seqlen_k_): + for j in range(nheads): + for k in range(headdim): + q_cpu[i][j][k] = cnt * 0.001 + k_cpu[i][j][k] = cnt * 0.001 + v_cpu[i][j][k] = cnt * 0.001 + cnt += 1 + +bias_ref = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float32) +cnt = 0 + +for i in range(batch_size * max_seqlen_k_): + for j in range(nheads): + for k in range(max_seqlen_q_): + for l in range(max_seqlen_k_): + bias_ref[i][j][k][l] = cnt * 0.001 + cnt += 1 + + +def softmax(logit): + max_value_over_last_dim = np.max(logit, axis=-1, keepdims=True) + logit_sub_max_value = logit - max_value_over_last_dim + + exp_x = np.exp(logit_sub_max_value) + + probs = exp_x / np.sum(exp_x, axis=-1, keepdims=True) + return probs + + +def fwd(q, k, v, max_seqlen_q, bias=None): + + batch_size = int(q.shape[0] / max_seqlen_q) + head_num = q.shape[1] + head_dim = q.shape[2] + + q = q.reshape(batch_size, max_seqlen_q, head_num, head_dim) + k = k.reshape(batch_size, max_seqlen_q, head_num, head_dim) + v = v.reshape(batch_size, max_seqlen_q, head_num, head_dim) + + q = q.transpose(0,2,1,3) + k = k.transpose(0,2,1,3) + v = v.transpose(0,2,1,3) + + print ("data q block 0 = {}".format(q[0, 0, :, :])) + + s = np.matmul(q, k.transpose(0,1,3,2)) + + if bias is not None: + s = s + bias + + p = softmax(s) + + o = np.matmul(p, v) + + o = o.transpose(0,2,1,3).reshape(batch_size * max_seqlen_q, head_num, head_dim) + + return s, p, o + + +def compute_lse(s): + max_value_over_last_dim = np.max(s, axis=-1, keepdims=True) + logit_sub_max_value = s - max_value_over_last_dim + + exp_x = np.exp(logit_sub_max_value) + + softmax_lse = np.max(s, axis=-1, keepdims=True) + np.log(np.sum(exp_x, axis=-1, keepdims=True)) + return softmax_lse + + + +if __name__ == '__main__': + + has_bias = True + + if has_bias: + s, p, o = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) + else: + s, p, o = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_) + # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) + + # attn_output = 
np.loadtxt("attn_output.data", delimiter=" ") + attn_output = np.genfromtxt("attn_output.data", delimiter=" ", dtype=np.float16) + attn_output = attn_output.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) + # attn_output = attn_output.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) + print ("output max error: ", np.abs(o - attn_output).max()) + + attn_lse = np.genfromtxt("attn_lse.data", delimiter=" ", dtype=np.float32) + max_seqlen_q_pad = ((max_seqlen_q_ + 16 - 1) // 16) * 16 + attn_lse = attn_lse.reshape(batch_size * max_seqlen_k_ , nheads, max_seqlen_q_pad) + print ("attn lse: ", attn_lse) + attn_lse = attn_lse[:,:,:max_seqlen_q_] + + lse_ref = compute_lse(s) + lse_ref = lse_ref.reshape(batch_size * max_seqlen_k_ , nheads, max_seqlen_q_) + print ("ref lse: ", lse_ref) + + print ("lse_ref shape = {}, attn_lse shape = {}".format(lse_ref.shape, attn_lse.shape)) + print ("lse max error: ", np.abs(lse_ref - attn_lse).max()) + + diff --git a/tests/tools/rebuild_mat.py b/tests/tools/rebuild_mat.py new file mode 100644 index 000000000..1fdf98e0a --- /dev/null +++ b/tests/tools/rebuild_mat.py @@ -0,0 +1,94 @@ +from parse import parse +import sys +import numpy as np + +filename = "./output.log" +if len(sys.argv) > 1: + filename = sys.argv[1] + +# AttnBias: threadIdx.x = 0, threadIdx.y = 0, mi = 0, ni = 0, ii = 0, jj = 0, value = 0.000000 +format_string = 'AttnBias: threadIdx.x = {}, threadIdx.y = {}, mi = {}, ni = {}, ii = {}, jj = {}, value = {}, ldx = {}, blockIdx.x = {}' +batch_size = 1 +nheads = 1 +headdim = 16 +seq = 8 +seq_q = 8 +max_seqlen_q_ = seq_q +max_seqlen_k_ = seq_q + + +mask_ref = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) +cnt = 0 + +for i in range(batch_size * max_seqlen_k_): + for j in range(nheads): + for k in range(max_seqlen_q_): + for l in range(max_seqlen_k_): + mask_ref[i][j][k][l] = cnt * 0.001 + cnt += 1 + +# mask = np.zeros([1, 1, max_seqlen_q_, max_seqlen_k_], dtype=np.float32) +# batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_ +mask = np.zeros([batch_size * max_seqlen_k_, nheads, 16, 128], dtype=np.float16) + + +def parse_bias_load(filename): + with open(filename, "r") as f: + for line in f.readlines(): + # print (line) + if line.startswith("AttnBias:"): + # print (line.strip()) + + result = parse(format_string, line.strip()) + print (result) + # import pdb; pdb.set_trace() + # if result[0] == 0: + # print (result[0], result[1], result[2], result[3], result[4], result[5], result[6]) + tidx_ = int(result[0]) + mi = int(result[2]) + ni = int(result[3]) + ii = int(result[4]) + jj = int(result[5]) + value = float(result[6]) + block_idx = int(result[8]) + + warp = tidx_ // 32 + lane = tidx_ % 32 + + warp_n = (warp // 1) + warp_m = (warp % 1) + + quad = lane // 4 + tid = (lane % 4) * 2 + + row = warp_m * 16 + quad + col = warp_n * 16 + tid + + current_row = mi * 16 + ii * 8 + row + # current_col = ni * 64 + jj * 8 + col + current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col + + # if (current_row < 8 and current_col < 8): + # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + # warp, lane, quad, tid, current_row, current_col, value)) + # mask[0, 0, current_row, current_col] = value + print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_col={}, value={}".format( + warp, lane, quad, tid, current_row, current_col, value)) + mask[block_idx, 0, current_row, current_col] = value + + +def check(mask, mask_ref, 
block_idx=0): + flag = True + bs, nheads, max_seqlen_q_, max_seqlen_k_ = mask_ref.shape + for i in range(max_seqlen_q_): + for j in range(max_seqlen_k_): + if (abs(mask[0, 0, i, j] - mask_ref[block_idx, 0, i, j]) > 1e-3): + print ("False in block_idx = {}, i = {}, j = {}, mask = {}, mask_ref = {}".format(block_idx, + i, j, mask[0, 0, i, j] - mask_ref[block_idx, 0, i, j])) + flag = False + return flag + +parse_bias_load(filename) + +# block_idx = 1 +# print (check(mask, mask_ref, block_idx)) \ No newline at end of file From b824852d28253a07a372c184515e5cd32a3da030 Mon Sep 17 00:00:00 2001 From: xh Date: Tue, 30 Aug 2022 14:29:28 +0800 Subject: [PATCH 23/71] fix test case --- benchmarks/test/test_forward_with_mask_v2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/benchmarks/test/test_forward_with_mask_v2.py b/benchmarks/test/test_forward_with_mask_v2.py index 2490e0864..3906165f8 100644 --- a/benchmarks/test/test_forward_with_mask_v2.py +++ b/benchmarks/test/test_forward_with_mask_v2.py @@ -77,9 +77,6 @@ def _flash_attn(q, k, v, attn_mask=None): no_heads, n, c = q.shape[-3:] dtype = q.dtype - if attn_mask is not None: - attn_mask = attn_mask.half() - # [*, B, N, H, C] q = q.transpose(-2, -3) k = k.transpose(-2, -3) @@ -111,6 +108,7 @@ def _flash_attn(q, k, v, attn_mask=None): if attn_mask is not None: # import pdb; pdb.set_trace() attn_mask = attn_mask.reshape([bs * n, no_heads, n, n]) + attn_mask = attn_mask.contiguous() out = flash_attn_unpadded_func( q, From 319b94037645245d6313643326a8195dd65f92a2 Mon Sep 17 00:00:00 2001 From: xh Date: Tue, 30 Aug 2022 14:41:43 +0800 Subject: [PATCH 24/71] add without mask test --- ...rd_without_mask_v2.py => test_forward_without_bias_mask.py} | 3 --- 1 file changed, 3 deletions(-) rename benchmarks/test/{test_forward_without_mask_v2.py => test_forward_without_bias_mask.py} (99%) diff --git a/benchmarks/test/test_forward_without_mask_v2.py b/benchmarks/test/test_forward_without_bias_mask.py similarity index 99% rename from benchmarks/test/test_forward_without_mask_v2.py rename to benchmarks/test/test_forward_without_bias_mask.py index a24a31d6c..3306aa885 100644 --- a/benchmarks/test/test_forward_without_mask_v2.py +++ b/benchmarks/test/test_forward_without_bias_mask.py @@ -77,9 +77,6 @@ def _flash_attn(q, k, v, attn_mask=None): no_heads, n, c = q.shape[-3:] dtype = q.dtype - if attn_mask is not None: - attn_mask = attn_mask.half() - # [*, B, N, H, C] q = q.transpose(-2, -3) k = k.transpose(-2, -3) From 1be4a39f3accf5836135f3835fa4eb1b6c1affe1 Mon Sep 17 00:00:00 2001 From: xh Date: Tue, 30 Aug 2022 19:53:47 +0800 Subject: [PATCH 25/71] add kernel test --- build.sh | 10 ++ tests/test_forward.cu | 98 ++++++++---- tests/tools/check_output.py | 292 +++++++++++++++++++++++++++++++++--- 3 files changed, 348 insertions(+), 52 deletions(-) create mode 100644 build.sh diff --git a/build.sh b/build.sh new file mode 100644 index 000000000..f8309df33 --- /dev/null +++ b/build.sh @@ -0,0 +1,10 @@ + +#rm -rf build flash_attn_cuda.cpython-37m-x86_64-linux-gnu.so + +start=`date +%s` +#CXX="/usr/lib/ccache/c++" +python setup.py build -j 8 develop 2>&1 | tee build.log +end=`date +%s` + +runtime=$((end-start)) +echo ${runtime} diff --git a/tests/test_forward.cu b/tests/test_forward.cu index ee4f6d3c6..a12393991 100644 --- a/tests/test_forward.cu +++ b/tests/test_forward.cu @@ -414,7 +414,7 @@ void test_fwd_with_bias(bool has_bias) { // else attn_bias_cpu[i][j][k][l] = 0; // attn_bias_cpu[i][j][k][l] = 0; - attn_bias_cpu[i][j][k][l] = cnt * 0.001; 
+ attn_bias_cpu[i][j][k][l] = cnt * 0.1; cnt ++; // printf("i=%d, j=%d, k=%d, l=%d attn_bias=%f\n", i, j, k, l, attn_bias_cpu[i][j][k][l]); // std::cout << "i=" << i << ", j=" << j << ", k=" << k << ", l" @@ -480,38 +480,70 @@ void test_fwd_with_bias(bool has_bias) { dump_tensor("attn_output", ret[0]); dump_tensor("attn_lse", ret[1]); - // at::Tensor dout_cpu = at::ones({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - // at::Tensor dq_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - // at::Tensor dk_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - // at::Tensor dv_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - // auto dout = dout_cpu.cuda(); - // auto dq = dq_cpu.cuda(); - // auto dk = dk_cpu.cuda(); - // auto dv = dv_cpu.cuda(); - - // std::vector bwd_ret = mha_bwd( - // dout, - // q, - // k, - // v, - // ret[0], - // ret[1], - // dq, - // dk, - // dv, - // cu_seqlens_q, // b + 1 - // cu_seqlens_k, // b + 1 - // max_seqlen_q_, - // max_seqlen_k_, - // 0.0, - // softmax_scale, - // zero_tensors, - // is_causal, - // gen_, - // attn_mask, - // attn_bias - // ); + at::Tensor dout_cpu = at::ones({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + at::Tensor dq_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor dk_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor dv_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + auto dout = dout_cpu.cuda(); + auto dq = dq_cpu.cuda(); + auto dk = dk_cpu.cuda(); + auto dv = dv_cpu.cuda(); + std::vector bwd_ret; + + if (has_bias) { + bwd_ret = mha_bwd( + dout, + q, + k, + v, + ret[0], + ret[1], + dq, + dk, + dv, + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + gen_, + attn_mask, + attn_bias + ); + }else{ + bwd_ret = mha_bwd( + dout, + q, + k, + v, + ret[0], + ret[1], + dq, + dk, + dv, + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + gen_, + attn_mask, + attn_mask + // placeholder + ); + } + + dump_tensor("attn_dq", dq); + dump_tensor("attn_dk", dk); + dump_tensor("attn_dv", dv); // std::cout << "bwd Ret vec size is " << ret.size(); // for (int i = 0; i < bwd_ret.size(); i ++) { diff --git a/tests/tools/check_output.py b/tests/tools/check_output.py index 743ad17eb..de4675da5 100644 --- a/tests/tools/check_output.py +++ b/tests/tools/check_output.py @@ -1,4 +1,8 @@ +from audioop import bias +from operator import truediv +from socket import NI_NAMEREQD import numpy as np +import torch batch_size = 1 nheads = 1 @@ -7,10 +11,12 @@ max_seqlen_q_ = seq max_seqlen_k_ = seq +dtypes = np.float16 -q_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=np.float16) -k_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=np.float16) -v_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=np.float16) + +q_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=dtypes) +k_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=dtypes) +v_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), 
dtype=dtypes) cnt = 0 for i in range(batch_size * max_seqlen_k_ * max_seqlen_k_): @@ -21,17 +27,20 @@ v_cpu[i][j][k] = cnt * 0.001 cnt += 1 -bias_ref = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float32) +bias_ref = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) cnt = 0 for i in range(batch_size * max_seqlen_k_): for j in range(nheads): for k in range(max_seqlen_q_): for l in range(max_seqlen_k_): - bias_ref[i][j][k][l] = cnt * 0.001 + bias_ref[i][j][k][l] = cnt * 0.1 cnt += 1 +dout = np.ones([batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim], dtype=dtypes) + + def softmax(logit): max_value_over_last_dim = np.max(logit, axis=-1, keepdims=True) logit_sub_max_value = logit - max_value_over_last_dim @@ -56,9 +65,9 @@ def fwd(q, k, v, max_seqlen_q, bias=None): k = k.transpose(0,2,1,3) v = v.transpose(0,2,1,3) - print ("data q block 0 = {}".format(q[0, 0, :, :])) + # print ("data q block 0 = {}".format(q[0, 0, :, :])) - s = np.matmul(q, k.transpose(0,1,3,2)) + s = np.matmul(q, k.transpose(0, 1, 3, 2)) if bias is not None: s = s + bias @@ -67,11 +76,91 @@ def fwd(q, k, v, max_seqlen_q, bias=None): o = np.matmul(p, v) - o = o.transpose(0,2,1,3).reshape(batch_size * max_seqlen_q, head_num, head_dim) + # o = o.transpose(0,2,1,3).reshape(batch_size * max_seqlen_q, head_num, head_dim) + return s, p, o, q, k, v + + +def bwd(dout, q, k, v, max_seqlen_q, bias=None): + s, p, o, _, _, _ = fwd(q, k, v, max_seqlen_q=max_seqlen_q, bias=bias) + + batch_size = int(q.shape[0] / max_seqlen_q) + head_num = q.shape[1] + head_dim = q.shape[2] + + dout = dout.reshape(batch_size, max_seqlen_q, head_num, head_dim) + dout = dout.transpose(0, 2, 1, 3) + # import pdb; pdb.set_trace() + + q = q.reshape(batch_size, max_seqlen_q, head_num, head_dim) + k = k.reshape(batch_size, max_seqlen_q, head_num, head_dim) + v = v.reshape(batch_size, max_seqlen_q, head_num, head_dim) + + q = q.transpose(0, 2, 1, 3) + k = k.transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + + # get dv + dv = np.matmul(p.transpose(0, 1, 3, 2), dout) + + # get dp + dp = np.matmul(dout, v.transpose(0, 1, 3, 2)) + + # ds_{i:} = P_{i:} \dot dP_{i:} - D_{i}P_{i:} + + ds = np.zeros([batch_size, head_num, max_seqlen_q, max_seqlen_q]) + for b in range(batch_size): + for h in range(head_num): + for i in range(max_seqlen_q): + # please refer equation 4 + Di = 0.0 + for l in range(max_seqlen_q): + Di += p[b][h][i][l] * dp[b][h][i][l] + + for j in range(max_seqlen_q): + ds[b][h][i][j] = p[b][h][i][j] * (dp[b][h][i][j] - Di) + + # get dq + dq = np.matmul(ds, k) + # dq = dq.transpose(0, 2, 1, 3) + + # get dk + dk = np.matmul(ds.transpose(0, 1, 3, 2), q) + # dk = dk.transpose(0, 2, 1, 3) + + if bias is None: + dbias = None + else: + dbias = ds.reshape(-1, *bias.shape).sum(axis=0) + + return dq, dk, dv, ds, dp, dbias + + +def fwd_pt(q_pt, k_pt, v_pt, bias=None): + s = torch.matmul(q_pt, k_pt.transpose(-1, -2)) + + if bias is not None: + s = s + bias + + p = torch.nn.functional.softmax(s, dim=-1) + o = torch.matmul(p, v_pt) return s, p, o +def bwd_pt(dout, q, k, v, max_seqlen_q, bias=None): + # q is [batch * seq * seq, head, head_dim] + q_pt, k_pt, v_pt, dout_pt, bias_pt = prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=bias) + + s, p, o = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt) + + if bias is None: + dq, dk, dv = torch.autograd.grad(o, (q_pt, k_pt, v_pt), dout_pt) + return dq, dk, dv, None + else: + dq, dk, dv, dbias = torch.autograd.grad(o, (q_pt, k_pt, v_pt, bias_pt), 
dout_pt) + return dq, dk, dv, dbias + + def compute_lse(s): max_value_over_last_dim = np.max(s, axis=-1, keepdims=True) logit_sub_max_value = s - max_value_over_last_dim @@ -82,34 +171,199 @@ def compute_lse(s): return softmax_lse - -if __name__ == '__main__': - - has_bias = True - +def check_fwd_kernel(has_bias=False): if has_bias: - s, p, o = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) + s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) else: - s, p, o = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_) + s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_) # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) # attn_output = np.loadtxt("attn_output.data", delimiter=" ") - attn_output = np.genfromtxt("attn_output.data", delimiter=" ", dtype=np.float16) + attn_output = np.genfromtxt("attn_output.data", delimiter=" ", dtype=np.float32) attn_output = attn_output.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) - # attn_output = attn_output.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) + attn_output = attn_output.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) + attn_output = attn_output.transpose(0, 2, 1, 3) print ("output max error: ", np.abs(o - attn_output).max()) attn_lse = np.genfromtxt("attn_lse.data", delimiter=" ", dtype=np.float32) max_seqlen_q_pad = ((max_seqlen_q_ + 16 - 1) // 16) * 16 attn_lse = attn_lse.reshape(batch_size * max_seqlen_k_ , nheads, max_seqlen_q_pad) - print ("attn lse: ", attn_lse) + # print ("attn lse: ", attn_lse) attn_lse = attn_lse[:,:,:max_seqlen_q_] lse_ref = compute_lse(s) lse_ref = lse_ref.reshape(batch_size * max_seqlen_k_ , nheads, max_seqlen_q_) - print ("ref lse: ", lse_ref) + # print ("ref lse: ", lse_ref) print ("lse_ref shape = {}, attn_lse shape = {}".format(lse_ref.shape, attn_lse.shape)) print ("lse max error: ", np.abs(lse_ref - attn_lse).max()) +def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): + diff = np.abs(pred - gt) + + cnt = 0 + for index, x in np.ndenumerate(diff): + if x > abs_eps: + relative_diff = np.abs(x / gt[index]) + if relative_diff > relative_rps: + cnt += 1 + if verbose: + print (index, x, gt[index], relative_diff) + + if cnt > 0: + print ("not so match") + return False + else: + return True + + +def check_bwd_kernel(has_bias=False): + + if has_bias: + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) + else: + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_) + + attn_dq = np.genfromtxt("attn_dq.data", delimiter=" ", dtype=np.float32) + attn_dk = np.genfromtxt("attn_dk.data", delimiter=" ", dtype=np.float32) + attn_dv = np.genfromtxt("attn_dv.data", delimiter=" ", dtype=np.float32) + + attn_dq = attn_dq.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) + attn_dk = attn_dk.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) + attn_dv = attn_dv.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) + + attn_dq = attn_dq.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) + attn_dk = attn_dk.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) + attn_dv = attn_dv.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) + + attn_dq = attn_dq.transpose(0, 2, 1, 3) + attn_dk = attn_dk.transpose(0, 2, 1, 3) + attn_dv = attn_dv.transpose(0, 2, 1, 3) + + assert (dq.shape == attn_dq.shape), "oh dq shape didn't match" + assert (dk.shape == attn_dk.shape), "oh 
dk shape didn't match" + assert (dv.shape == attn_dv.shape), "oh dv shape didn't match" + + print ("max error in dq: ", np.abs(attn_dq - dq).max(), ) + print ("max error in dk: ", np.abs(attn_dk - dk).max(), ) + print ("max error in dv: ", np.abs(attn_dv - dv).max(), ) + + print ("same matrix check q: ", is_same_matrix(attn_dq, dq)) + print ("same matrix check k: ", is_same_matrix(attn_dk, dk)) + print ("same matrix check v: ", is_same_matrix(attn_dv, dv)) + + +def check_bwd_np(has_bias=False): + print ("==== check bwd np ====") + if has_bias: + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref) + dq_pt, dk_pt, dv_pt, dbias_pt = bwd_pt(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref) + else: + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None) + dq_pt, dk_pt, dv_pt, dbias_pt = bwd_pt(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None) + + assert (dq.shape == dq_pt.detach().cpu().numpy().shape), "oh dq shape didn't match" + assert (dk.shape == dk_pt.detach().cpu().numpy().shape), "oh dk shape didn't match" + assert (dv.shape == dv_pt.detach().cpu().numpy().shape), "oh dv shape didn't match" + if has_bias: + assert (dbias.shape == dbias_pt.detach().cpu().numpy().shape), "oh dbias shape didn't match" + + print ("max error in dq: ", np.abs( dq - dq_pt.detach().cpu().numpy() ).max()) + print ("max error in dk: ", np.abs( dk - dk_pt.detach().cpu().numpy() ).max()) + print ("max error in dv: ", np.abs( dv - dv_pt.detach().cpu().numpy() ).max()) + if has_bias: + print ("max error in dbias: ", np.abs( dbias - dbias_pt.detach().cpu().numpy() ).max()) + + return + + +def prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=None): + q_pt = torch.from_numpy(q.copy()) + k_pt = torch.from_numpy(k.copy()) + v_pt = torch.from_numpy(v.copy()) + + batch_size = int(q.shape[0] / max_seqlen_q) + head_num = q.shape[1] + head_dim = q.shape[2] + + dout_pt = torch.from_numpy(dout.copy()) + dout_pt = dout_pt.reshape(batch_size, max_seqlen_q, head_num, head_dim) + dout_pt = dout_pt.permute(0, 2, 1, 3).cuda() + + q_pt = q_pt.reshape(batch_size, max_seqlen_q, head_num, head_dim) + k_pt = k_pt.reshape(batch_size, max_seqlen_q, head_num, head_dim) + v_pt = v_pt.reshape(batch_size, max_seqlen_q, head_num, head_dim) + + q_pt = q_pt.permute(0, 2, 1, 3).cuda() + k_pt = k_pt.permute(0, 2, 1, 3).cuda() + v_pt = v_pt.permute(0, 2, 1, 3).cuda() + + if bias is not None: + bias_pt = torch.from_numpy(bias.copy()).cuda() + bias_pt.requires_grad = True + else: + bias_pt = None + + q_pt.requires_grad = True + k_pt.requires_grad = True + v_pt.requires_grad = True + + return q_pt, k_pt, v_pt, dout_pt, bias_pt + + +def check_fwd_np(has_bias=False): + print ("==== check fwd np ====") + if has_bias: + s, p, o, q, k, v = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref) + + q_pt, k_pt, v_pt, dout_pt, bias_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref) + s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias_pt) + else: + s, p, o, q, k, v = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_) + + q_pt, k_pt, v_pt, dout_pt, _ = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_) + s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt) + + def check_input(a, b): + print ("max error in input: ", np.abs(a - b).max()) + + check_input(q, q_pt.detach().cpu().numpy()) + check_input(k, q_pt.detach().cpu().numpy()) + check_input(v, q_pt.detach().cpu().numpy()) 
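A recurring source of confusion in these checkers is the layout round trip between the unpadded flash-attn buffers ([total, heads, head_dim], with cu_seqlens a plain arange since all sequences here share one length) and the [batch, heads, seqlen, head_dim] layout the references use; fwd(), check_bwd_kernel() and prepare_pt_data() all repeat the same reshape/transpose pattern. A standalone sketch of that round trip (helper names are mine):

```python
import numpy as np

def unpadded_to_bhsd(x, seqlen):
    """[batch * seqlen, heads, dim] -> [batch, heads, seqlen, dim]."""
    total, heads, dim = x.shape
    return x.reshape(total // seqlen, seqlen, heads, dim).transpose(0, 2, 1, 3)

def bhsd_to_unpadded(x):
    """[batch, heads, seqlen, dim] -> [batch * seqlen, heads, dim]."""
    batch, heads, seqlen, dim = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch * seqlen, heads, dim)
```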
+ + assert (s.shape == s_pt.detach().cpu().numpy().shape), "oh s shape didn't match" + assert (p.shape == p_pt.detach().cpu().numpy().shape), "oh p shape didn't match" + assert (o.shape == o_pt.detach().cpu().numpy().shape), "oh o shape didn't match" + + print ("max error in s: ", np.abs( s - s_pt.detach().cpu().numpy() ).max()) + print ("max error in p: ", np.abs( p - p_pt.detach().cpu().numpy() ).max()) + print ("max error in o: ", np.abs( o - o_pt.detach().cpu().numpy() ).max()) + + return + + +if __name__ == '__main__': + # print ("====test without bias====") + # has_bias = False + # check_fwd_np(has_bias=has_bias) + # check_bwd_np(has_bias=has_bias) + # print ("====test without bias====") + + # print ("====test with bias====") + # has_bias = True + # check_fwd_np(has_bias=has_bias) + # check_bwd_np(has_bias=has_bias) + # print ("====test with bias====") + + # print ("====test kernel without bias====") + # has_bias = False + # check_fwd_kernel(has_bias=has_bias) + # check_bwd_kernel(has_bias=has_bias) + + print ("====test kernel with bias====") + has_bias = True + check_fwd_kernel(has_bias=has_bias) + check_bwd_kernel(has_bias=has_bias) + + \ No newline at end of file From 14f7e085ed86fde0ce124a9f8eb45c74c6bab5bb Mon Sep 17 00:00:00 2001 From: xh Date: Thu, 1 Sep 2022 13:39:28 +0800 Subject: [PATCH 26/71] add ds save --- csrc/flash_attn/fmha_api.cpp | 31 ++- csrc/flash_attn/src/fmha.h | 3 + csrc/flash_attn/src/fmha/gmem_tile.h | 130 ++++++++++++ csrc/flash_attn/src/fmha/kernel_traits.h | 3 + .../src/fmha_dgrad_fp16_kernel_loop.sm80.cu | 2 + .../src/fmha_dgrad_kernel_1xN_loop.h | 31 ++- .../src/fmha_fprop_fp16_kernel.sm80.cu | 4 +- tests/test_forward.cu | 34 +++- tests/tools/check_output.py | 189 ++++++++++++++++-- tests/tools/rebuild_dsoftmax.py | 64 ++++++ 10 files changed, 461 insertions(+), 30 deletions(-) create mode 100644 tests/tools/rebuild_dsoftmax.py diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 11ac0e071..3c97f699d 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -173,7 +173,8 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, float softmax_scale, bool is_causal, void *attn_mask, - void *attn_bias) { + void *attn_bias, + void *attn_ds) { set_params_fprop(params, b, seqlen_q, seqlen_k, h, d, @@ -204,6 +205,7 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, // Softmax sum params.dsoftmax_sum = dsoftmax_sum_d; + params.attn_ds_ptr = attn_ds; } std::vector @@ -418,6 +420,18 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(cu_seqlens_q.is_contiguous()); TORCH_CHECK(cu_seqlens_k.is_contiguous()); + if (attn_bias.has_value()) { + TORCH_CHECK(attn_bias.value().is_cuda()); + TORCH_CHECK(attn_bias.value().dtype() == q_dtype); + TORCH_CHECK(attn_bias.value().is_contiguous()); + } + + if (attn_mask.has_value()) { + TORCH_CHECK(attn_mask.value().is_cuda()); + TORCH_CHECK(attn_mask.value().dtype() == q_dtype); + TORCH_CHECK(attn_mask.value().is_contiguous()); + } + const auto sizes = q.sizes(); const int batch_size = cu_seqlens_q.numel() - 1; @@ -442,6 +456,13 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + auto opts = q.options(); + at::Tensor ds; + if (attn_bias.has_value()) { + ds = torch::empty({batch_size, num_heads, max_seqlen_q_, max_seqlen_k_}, opts.dtype(q_dtype)); + ds.zero_(); + } + int blocksize_c = (head_size == 128 || (is_sm75 && head_size == 64)) ? 
128 : 256; int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; if( max_seqlen_k_ <= 128 ) { @@ -455,7 +476,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different. auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous(); - auto opts = q.options(); + auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); at::Tensor dq_tmp; if (loop) { dq_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); } @@ -488,7 +509,10 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size softmax_scale, is_causal, attn_mask ? attn_mask->data_ptr() : nullptr, - attn_bias ? attn_bias->data_ptr() : nullptr); + attn_bias ? attn_bias->data_ptr() : nullptr, + attn_bias ? ds.data_ptr() : nullptr); + + // used for dbias auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator()); @@ -760,6 +784,7 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size softmax_scale, is_causal, nullptr, + nullptr, nullptr); params.blockmask = static_cast(blockmask.data_ptr()); diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index 04fcba0f4..a130525de 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -79,6 +79,9 @@ struct FMHA_fprop_params : public Qkv_params { // The attn bias matrix void * __restrict__ attn_bias_ptr; + // The ds matrix + void * __restrict__ attn_ds_ptr; + // The O matrix (output). void * __restrict__ o_ptr; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index ad074e3e2..caecce6f4 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -772,6 +772,136 @@ struct Gmem_tile_mma_bias { const int tidx_; }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// attn bias struct like s, maybe later can reuse the above declaration +template< typename Cta_tile, int BYTES_PER_ELEMENT = 2> +struct Gmem_tile_mma_ds { + + using Mma_tile = fmha::Hmma_tile; + // The type of the vectors stored by each STG. + using StoreType = uint32_t; + + // The number of MMAs in the M dimension. + static constexpr int M = Mma_tile::MMAS_M; + // The number of MMAs in the N dimension. + static constexpr int N = Mma_tile::MMAS_N; + + // The number of "rows" stored per iteration of the loop. The output of 1 MMA. + static constexpr int ROWS = Cta_tile::M; + static constexpr int COLS = Cta_tile::N; + + // The size of each LDG. + // load two elements of data + static constexpr int BYTES_PER_LDG = 2 * BYTES_PER_ELEMENT; + // The size of a row in bytes. + static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT; + + // The number of LDGS needed to store a chunk of the P matrix in total. + // Tell me if has more efficient way + static constexpr int LDGS_PER_THREAD_PER_WARP = 4; + static constexpr int THREADS_PER_QUAD = 4; + static constexpr int COL_PER_MMA_PER_CTA = Cta_tile::THREADS_PER_WARP / THREADS_PER_QUAD; + + // Ctor. 
+ template< typename Params, typename Block_info > + inline __device__ Gmem_tile_mma_ds(const Params ¶ms, + // const uint32_t row_stride_in_elts, const uint32_t head_stride_in_elts, + const Block_info& binfo, const int tidx, const int loop_step_idx) + : ptr_(static_cast(params.attn_ds_ptr)) + // : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) + , actual_seqlen_q(binfo.actual_seqlen_q) + , actual_seqlen_k(binfo.actual_seqlen_k) + , tidx_(tidx) + , loop_step_idx(loop_step_idx) + { + row_stride_in_bytes = binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + + const int warp = tidx_ / Cta_tile::THREADS_PER_WARP; + const int lane = tidx_ % Cta_tile::THREADS_PER_WARP; + + // find the warp in the Cta tile + const int warp_n = (warp / Cta_tile::WARPS_M); + const int warp_m = (warp % Cta_tile::WARPS_M); + + // decompose warp into 8x4 tile + const int quad = lane / 4; + const int tid = (lane % 4) * 2; + // this col is mean the 8x4 tile's cole + + row = warp_m * Mma_tile::M_PER_MMA + quad; + static_assert(Mma_tile::M_PER_MMA == 16, + "only support sm80 m16n8k16 tensor core"); + + col = warp_n * Mma_tile::N_PER_MMA + tid; + static_assert(Mma_tile::N_PER_MMA == 16, + "only support sm80 m16n8k16 tensor core"); + + // The distance between two blocks (in bytes). + // TODO: mask is [bs, head, seq_q, seq_k] + // The block index. + uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + + // the index of bs and head dim + uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + // row_offset = (uint32_t)(row * row_stride_in_bytes); + row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); + // do we need to move col first if seklen_k > cols + ptr_ += row_offset; + } + + // Store to global memory. + inline __device__ void store(const float (&softmax)[2 * M][4 * N]) { + uint32_t preds; + + #pragma unroll + for( int mi = 0; mi < M; mi++ ) { + #pragma unroll + for( int ni = 0; ni < N; ni++ ) { + #pragma unroll + for ( int ii = 0; ii < 2; ++ii ) { + #pragma unroll + for (int jj = 0; jj < 2; ++jj ) { + float tmp00 = softmax[2 * mi + ii][4 * ni + jj * 2]; + float tmp01 = softmax[2 * mi + ii][4 * ni + jj * 2 + 1]; + uint32_t dst; + dst = fmha::float2_to_half2(tmp00, tmp01); + + const int current_row = mi * ROWS + ii * 8; + const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; + + char *ptrs = ptr_ + (uint32_t)current_row * row_stride_in_bytes + + (uint32_t)current_col * BYTES_PER_ELEMENT; + + preds = (current_row <= min(ROWS, actual_seqlen_q)) + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); + + if (preds) { + fmha::stg(ptrs, dst); + } + } + } + } + } + } + + inline __device__ void move(const int steps = 1) { + ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; + this->actual_seqlen_q -= ROWS * steps; + } + + int row; + int col; + const int loop_step_idx; + uint32_t row_stride_in_bytes; + // The pointer. 
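+    // Points at the start of this (batch, head) slice of attn_ds_ptr, already
+    // offset by the thread's starting row (see the ctor); store() adds the
+    // remaining row/column offsets and move() advances it by ROWS rows.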
+ char *ptr_; + int actual_seqlen_q; + int actual_seqlen_k; + const int tidx_; +}; + + //////////////////////////////////////////////////////////////////////////////////////////////////// template< diff --git a/csrc/flash_attn/src/fmha/kernel_traits.h b/csrc/flash_attn/src/fmha/kernel_traits.h index c5f573dea..4c898f92d 100644 --- a/csrc/flash_attn/src/fmha/kernel_traits.h +++ b/csrc/flash_attn/src/fmha/kernel_traits.h @@ -77,6 +77,9 @@ struct FMHA_kernel_traits { // Gmem_tile_mma_bias using Gmem_tile_bias = fmha::Gmem_tile_mma_bias; + // Gmem_tile_mma_ds + using Gmem_tile_ds = fmha::Gmem_tile_mma_ds; + // The shared memory tile to transpose S. using Smem_tile_st = fmha::Smem_tile_mma_transposed; diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu index 403ce9237..b28a5422e 100644 --- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu @@ -47,6 +47,8 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_ } dim3 grid(params.b, params.h); kernel<<>>(params); + printf("bwd grid size: %d %d\n", params.b, params.h); + printf("bwd block size: %d\n", Kernel_traits::THREADS); FMHA_CHECK_CUDA(cudaPeekAtLastError()); }); } diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 9f8130e37..f241def26 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -161,6 +161,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); // TODO: load fun as s + using Gmem_tile_ds = typename Kernel_traits::Gmem_tile_ds; + Gmem_tile_ds gmem_ds(params, binfo, tidx, loop_step_idx); + fmha::Mask mask(binfo, tidx, loop_step_idx); // Allocate the global memory tile loader for K. @@ -220,6 +223,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng if (!(params.attn_bias_ptr == nullptr)) { // TODO: mask move gmem_bias.move(begin); + gmem_ds.move(begin); } if (!Is_first) { @@ -481,8 +485,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng for (int jj = 0; jj < 4; jj ++) { int st_row = 2 * mi + ii; int st_col = 4 * ki + jj; - printf("bwd softmax: threadIdx=%d, l=%d, mi=%d, ki=%d, ii=%d, jj=%d, elt=%d", - threadIdx.x, mi, ki, ii, jj, softmax.elt_[st_row][st_col]); + printf("bwd softmax: threadIdx=%d, l=%d, mi=%d, ki=%d, ii=%d, jj=%d, elt=%f\n", + threadIdx.x, l, mi, ki, ii, jj, softmax.elt_[st_row][st_col]); } } } @@ -503,7 +507,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng softmax.template pack(frag_p); // Store s * dmask to smem for transpose - // how to test smem_s.store(frag_p); // Trigger the load for the next Q values. 
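For reference, the values that gmem_ds writes back in the next hunk are the gradient of the loss with respect to the softmax logits; since the bias is added to those logits, this is also the bias gradient before any reduction over broadcast dimensions. A minimal NumPy sketch of that math for one (batch, head) slice (the names p and dp follow the reference bwd() in tests/tools/check_output.py and are assumptions, not the kernel's variables):

    import numpy as np

    def dsoftmax(p, dp):
        # softmax backward: dS = P * (dP - rowsum(dP * P))
        return p * (dp - np.sum(dp * p, axis=-1, keepdims=True))

Summing these per-block dS tiles over the leading dimension is what later yields dbias on the API side, roughly ds.reshape(-1, *bias.shape).sum(0).
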
@@ -595,6 +598,28 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng softmax.template pack(frag_p); + // if constexpr (has_bias) { + if (!(params.attn_bias_ptr == nullptr)) { +#ifdef DEBUG_PRINT + if ((blockIdx.x == 0) && (blockIdx.y == 0)) { + for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { + for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { + for (int ii = 0; ii < 2; ii ++) { + for (int jj = 0; jj < 4; jj ++) { + int st_row = 2 * mi + ii; + int st_col = 4 * ki + jj; + printf("bwd dsoftmax: threadIdx=%d, l=%d, mi=%d, ki=%d, ii=%d, jj=%d, elt=%f\n", + threadIdx.x, l, mi, ki, ii, jj, softmax.elt_[st_row][st_col]); + } + } + } + } + printf("\n"); + } +#endif + gmem_ds.store(softmax.elt_); + } + // Store dp to smem for transpose smem_dp.store(frag_p); diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index d0af0c850..d86118f05 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -71,8 +71,8 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, auto kernel = &fmha_fprop_fp16_sm80_loop_kernel; dim3 grid(launch_params.params.b, launch_params.params.h); - // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // printf("block size: %d\n", Kernel_traits::THREADS); + printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + printf("block size: %d\n", Kernel_traits::THREADS); kernel<<>>( launch_params.params); FMHA_CHECK_CUDA(cudaPeekAtLastError()); diff --git a/tests/test_forward.cu b/tests/test_forward.cu index a12393991..69db09b23 100644 --- a/tests/test_forward.cu +++ b/tests/test_forward.cu @@ -329,8 +329,8 @@ void test_fwd_with_bias_mini() { } -void dump_tensor(const std::string &tensor_name, at::Tensor &tensor) { - std::string file_name = tensor_name + ".data"; +void dump_tensor(const std::string &tensor_name, at::Tensor &tensor, const std::string &label) { + std::string file_name = label + "_" + tensor_name + ".data"; std::ofstream file(file_name.c_str()); // file << tensor_name << std::endl; // file << tensor << std::endl; @@ -448,6 +448,8 @@ void test_fwd_with_bias(bool has_bias) { attn_mask, attn_bias ); + dump_tensor("attn_output", ret[0], "has_bias"); + dump_tensor("attn_lse", ret[1], "has_bias"); }else{ ret = mha_fwd( q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i @@ -467,20 +469,30 @@ void test_fwd_with_bias(bool has_bias) { attn_mask // no bias ); + dump_tensor("attn_output", ret[0], ""); + dump_tensor("attn_lse", ret[1], ""); } // ret: std::vector result = {o, softmax_lse}; // [bs * seq * seq, head, head_dim] // [1 * 2 * 2, 1, 16] - std::cout << "fwd Ret vec size is " << ret.size(); + // std::cout << "fwd Ret vec size is " << ret.size(); // for (int i = 0; i < ret.size(); i ++) { // ret[i].cpu(); // std::cout << ret[i] << std::endl; // } - dump_tensor("attn_output", ret[0]); - dump_tensor("attn_lse", ret[1]); at::Tensor dout_cpu = at::ones({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < headdim; k ++) { + dout_cpu[i][j][k] = cnt * 0.001; + cnt ++; + } + } + } at::Tensor dq_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); at::Tensor dk_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); 
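The incremental fill of dout above mirrors the matching change to tests/tools/check_output.py later in this patch, so the CUDA test and the NumPy reference consume identical gradients instead of all-ones. An equivalent vectorised form (a sketch only; the sizes are the ones used by the debugging scripts under tests/tools):

    import numpy as np

    batch_size, nheads, headdim, seqlen = 1, 1, 16, 8
    n = batch_size * seqlen * seqlen * nheads * headdim
    dout = (np.arange(n) * 0.001).reshape(batch_size * seqlen * seqlen, nheads, headdim)
    dout = dout.astype(np.float16)
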
@@ -493,6 +505,8 @@ void test_fwd_with_bias(bool has_bias) { std::vector bwd_ret; if (has_bias) { + // modify ret[1] + bwd_ret = mha_bwd( dout, q, @@ -515,6 +529,9 @@ void test_fwd_with_bias(bool has_bias) { attn_mask, attn_bias ); + dump_tensor("attn_dq", dq, "has_bias"); + dump_tensor("attn_dk", dk, "has_bias"); + dump_tensor("attn_dv", dv, "has_bias"); }else{ bwd_ret = mha_bwd( dout, @@ -539,12 +556,11 @@ void test_fwd_with_bias(bool has_bias) { attn_mask // placeholder ); + dump_tensor("attn_dq", dq, ""); + dump_tensor("attn_dk", dk, ""); + dump_tensor("attn_dv", dv, ""); } - dump_tensor("attn_dq", dq); - dump_tensor("attn_dk", dk); - dump_tensor("attn_dv", dv); - // std::cout << "bwd Ret vec size is " << ret.size(); // for (int i = 0; i < bwd_ret.size(); i ++) { // bwd_ret[i].cpu(); diff --git a/tests/tools/check_output.py b/tests/tools/check_output.py index de4675da5..af8d87f32 100644 --- a/tests/tools/check_output.py +++ b/tests/tools/check_output.py @@ -29,7 +29,6 @@ bias_ref = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) cnt = 0 - for i in range(batch_size * max_seqlen_k_): for j in range(nheads): for k in range(max_seqlen_q_): @@ -38,8 +37,14 @@ cnt += 1 +# dout = np.random.rand(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim).astype(dtype=dtypes) +cnt = 0 dout = np.ones([batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim], dtype=dtypes) - +for i in range(batch_size * max_seqlen_k_ * max_seqlen_k_): + for j in range(nheads): + for k in range(headdim): + dout[i][j][k] = cnt * 0.001 + cnt += 1 def softmax(logit): max_value_over_last_dim = np.max(logit, axis=-1, keepdims=True) @@ -162,16 +167,23 @@ def bwd_pt(dout, q, k, v, max_seqlen_q, bias=None): def compute_lse(s): + # import pdb; pdb.set_trace() + # og_dtype = s.dtype + # s = s.astype(np.float32) + max_value_over_last_dim = np.max(s, axis=-1, keepdims=True) logit_sub_max_value = s - max_value_over_last_dim exp_x = np.exp(logit_sub_max_value) softmax_lse = np.max(s, axis=-1, keepdims=True) + np.log(np.sum(exp_x, axis=-1, keepdims=True)) + + # softmax_lse = softmax_lse.astype(og_dtype) return softmax_lse def check_fwd_kernel(has_bias=False): + print ("==== check fwd kernel with np ====") if has_bias: s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) else: @@ -179,13 +191,19 @@ def check_fwd_kernel(has_bias=False): # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) # attn_output = np.loadtxt("attn_output.data", delimiter=" ") - attn_output = np.genfromtxt("attn_output.data", delimiter=" ", dtype=np.float32) + if has_bias: + prefix = "has_bias" + print ("has bias on, prefix is ", prefix) + else: + prefix = "" + + attn_output = np.genfromtxt("{}_attn_output.data".format(prefix), delimiter=" ", dtype=np.float32) attn_output = attn_output.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) attn_output = attn_output.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) attn_output = attn_output.transpose(0, 2, 1, 3) print ("output max error: ", np.abs(o - attn_output).max()) - attn_lse = np.genfromtxt("attn_lse.data", delimiter=" ", dtype=np.float32) + attn_lse = np.genfromtxt("{}_attn_lse.data".format(prefix), delimiter=" ", dtype=np.float32) max_seqlen_q_pad = ((max_seqlen_q_ + 16 - 1) // 16) * 16 attn_lse = attn_lse.reshape(batch_size * max_seqlen_k_ , nheads, max_seqlen_q_pad) # print ("attn lse: ", attn_lse) @@ -195,6 +213,8 @@ def check_fwd_kernel(has_bias=False): lse_ref = lse_ref.reshape(batch_size * 
max_seqlen_k_ , nheads, max_seqlen_q_) # print ("ref lse: ", lse_ref) + import pdb; pdb.set_trace() + print ("lse_ref shape = {}, attn_lse shape = {}".format(lse_ref.shape, attn_lse.shape)) print ("lse max error: ", np.abs(lse_ref - attn_lse).max()) @@ -219,15 +239,21 @@ def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): def check_bwd_kernel(has_bias=False): - + print ("==== check bwd kernel with np ====") if has_bias: dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) else: dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_) - attn_dq = np.genfromtxt("attn_dq.data", delimiter=" ", dtype=np.float32) - attn_dk = np.genfromtxt("attn_dk.data", delimiter=" ", dtype=np.float32) - attn_dv = np.genfromtxt("attn_dv.data", delimiter=" ", dtype=np.float32) + if has_bias: + prefix = "has_bias" + print ("has bias on, prefix is ", prefix) + else: + prefix = "" + + attn_dq = np.genfromtxt("{}_attn_dq.data".format(prefix), delimiter=" ", dtype=np.float32) + attn_dk = np.genfromtxt("{}_attn_dk.data".format(prefix), delimiter=" ", dtype=np.float32) + attn_dv = np.genfromtxt("{}_attn_dv.data".format(prefix), delimiter=" ", dtype=np.float32) attn_dq = attn_dq.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) attn_dk = attn_dk.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) @@ -343,6 +369,125 @@ def check_input(a, b): return +def parse_softmax_load(filename): + from parse import parse + format_string = 'bwd softmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' + softmax_p = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) + + with open(filename, "r") as f: + for line in f.readlines(): + # print (line) + if line.startswith("bwd softmax: "): + # print (line.strip()) + result = parse(format_string, line.strip()) + # print (result) + + tidx_ = int(result[0]) + mi = int(result[2]) + ni = int(result[3]) + ii = int(result[4]) + jj = int(result[5]) + value = float(result[6]) + + warp = tidx_ // 32 + lane = tidx_ % 32 + + warp_n = (warp // 1) + warp_m = (warp % 1) + + quad = lane // 4 + tid = (lane % 4) * 2 + + row = warp_m * 16 + quad + col = warp_n * 16 + tid + + current_row = mi * 16 + ii * 8 + row + # current_col = ni * 64 + jj * 8 + col + # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col + current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col + # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + # warp, lane, quad, tid, current_row, current_col, value)) + if (current_row < 8 and current_col < 8): + print (line.strip()) + print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + warp, lane, quad, tid, current_row, current_col, value)) + softmax_p[0, 0, current_row, current_col] = value + + return softmax_p + + +def check_softmax_p(softmax_data, has_bias=False): + if has_bias: + s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) + else: + s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_) + # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) + import pdb; pdb.set_trace() + print ("max error in p: ", np.abs(p[0, 0, :, :] - softmax_data[0, 0, :, :]).max(), ) + print ("same matrix check p: ", is_same_matrix(p[0, 0, :, :], softmax_data[0, 0, :, :])) + return + + +def parse_dsoftmax_load(filename): + from parse import parse + format_string = 'bwd dsoftmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' + 
dsoftmax = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) + + with open(filename, "r") as f: + for line in f.readlines(): + # print (line) + if line.startswith("bwd dsoftmax: "): + # print (line.strip()) + result = parse(format_string, line.strip()) + # print (result) + + tidx_ = int(result[0]) + mi = int(result[2]) + ni = int(result[3]) + ii = int(result[4]) + jj = int(result[5]) + value = float(result[6]) + + warp = tidx_ // 32 + lane = tidx_ % 32 + + warp_n = (warp // 1) + warp_m = (warp % 1) + + quad = lane // 4 + tid = (lane % 4) * 2 + + row = warp_m * 16 + quad + col = warp_n * 16 + tid + + current_row = mi * 16 + ii * 8 + row + # current_col = ni * 64 + jj * 8 + col + # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col + current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col + # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + # warp, lane, quad, tid, current_row, current_col, value)) + if (current_row < 8 and current_col < 8): + print (line.strip()) + print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + warp, lane, quad, tid, current_row, current_col, value)) + dsoftmax[0, 0, current_row, current_col] = value + + return dsoftmax + + +def check_dsoftmax_p(softmax_data, has_bias=False): + if has_bias: + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) + else: + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_) + + # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) + import pdb; pdb.set_trace() + print ("max error in p: ", np.abs(ds[0, 0, :, :] - softmax_data[0, 0, :, :]).max(), ) + print ("same matrix check p: ", is_same_matrix(ds[0, 0, :, :], softmax_data[0, 0, :, :])) + return + + if __name__ == '__main__': # print ("====test without bias====") # has_bias = False @@ -361,9 +506,27 @@ def check_input(a, b): # check_fwd_kernel(has_bias=has_bias) # check_bwd_kernel(has_bias=has_bias) - print ("====test kernel with bias====") - has_bias = True - check_fwd_kernel(has_bias=has_bias) - check_bwd_kernel(has_bias=has_bias) + # print ("====test kernel with bias====") + # has_bias = True + # check_fwd_kernel(has_bias=has_bias) + # check_bwd_kernel(has_bias=has_bias) - \ No newline at end of file + # print ("====test bwd kernel softmax without bias====") + # has_bias = False + # softmax_data = parse_softmax_load("output.log") + # check_softmax_p(softmax_data=softmax_data, has_bias=has_bias) + + # print ("====test bwd kernel softmax with bias====") + # has_bias = True + # softmax_data = parse_softmax_load("output.log") + # check_softmax_p(softmax_data=softmax_data, has_bias=has_bias) + + # print ("====test bwd kernel softmax without bias====") + # has_bias = False + # dsoftmax_data = parse_dsoftmax_load("output.log") + # check_dsoftmax_p(softmax_data=dsoftmax_data, has_bias=has_bias) + + print ("====test bwd kernel softmax with bias====") + has_bias = True + dsoftmax_data = parse_dsoftmax_load("output.log") + check_dsoftmax_p(softmax_data=dsoftmax_data, has_bias=has_bias) diff --git a/tests/tools/rebuild_dsoftmax.py b/tests/tools/rebuild_dsoftmax.py new file mode 100644 index 000000000..790f49642 --- /dev/null +++ b/tests/tools/rebuild_dsoftmax.py @@ -0,0 +1,64 @@ +from parse import parse +import sys +import numpy as np + +filename = "./output.log" +if len(sys.argv) > 1: + filename = sys.argv[1] + +# bwd softmax: threadIdx=195, l=0, mi=0, ki=1, ii=3, jj=0, elt=0.000000 +format_string = 
'bwd softmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' +batch_size = 1 +nheads = 1 +headdim = 16 +seq = 8 +seq_q = 8 +max_seqlen_q_ = seq_q +max_seqlen_k_ = seq_q + + +d_softmax = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) + +def parse_dsoftmax_load(filename): + with open(filename, "r") as f: + for line in f.readlines(): + # print (line) + if line.startswith("bwd softmax: "): + # print (line.strip()) + result = parse(format_string, line.strip()) + # print (result) + + tidx_ = int(result[0]) + mi = int(result[2]) + ni = int(result[3]) + ii = int(result[4]) + jj = int(result[5]) + value = float(result[6]) + + warp = tidx_ // 32 + lane = tidx_ % 32 + + warp_n = (warp // 1) + warp_m = (warp % 1) + + quad = lane // 4 + tid = (lane % 4) * 2 + + row = warp_m * 16 + quad + col = warp_n * 16 + tid + + current_row = mi * 16 + ii * 8 + row + # current_col = ni * 64 + jj * 8 + col + # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col + current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col + # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + # warp, lane, quad, tid, current_row, current_col, value)) + if (current_row < 8 and current_col < 8): + print (line.strip()) + print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + warp, lane, quad, tid, current_row, current_col, value)) + d_softmax[0, 0, current_row, current_col] = value + + +parse_dsoftmax_load(filename) +print ("output block 0 d_softmax: ", d_softmax[0, 0, :, :]) From debc0466c830bf266bbdb01feca060531deda80f Mon Sep 17 00:00:00 2001 From: xh Date: Thu, 1 Sep 2022 13:47:14 +0800 Subject: [PATCH 27/71] fix interface --- flash_attn/flash_attn_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 6a9a59966..04488638d 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -31,7 +31,7 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal): softmax_d = flash_attn_cuda.bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, attn_mask, None) + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, attn_mask, attn_bias) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() return dq, dk, dv, softmax_d From b8128890a6656ba4161fbce0c8d8791f1c4cef11 Mon Sep 17 00:00:00 2001 From: xh Date: Thu, 1 Sep 2022 21:10:48 +0800 Subject: [PATCH 28/71] add test --- csrc/flash_attn/fmha_api.cpp | 39 ++++++++++++++++++---------- csrc/flash_attn/src/fmha/gmem_tile.h | 2 +- tests/test_forward.cu | 4 +-- tests/tools/check_output.py | 36 ++++++++++++++++++++++++- 4 files changed, 63 insertions(+), 18 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 3c97f699d..36de75fb6 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -27,6 +27,7 @@ ******************************************************************************/ #include +#include #include #include "fmha.h" @@ -420,18 +421,6 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(cu_seqlens_q.is_contiguous()); TORCH_CHECK(cu_seqlens_k.is_contiguous()); - if 
(attn_bias.has_value()) { - TORCH_CHECK(attn_bias.value().is_cuda()); - TORCH_CHECK(attn_bias.value().dtype() == q_dtype); - TORCH_CHECK(attn_bias.value().is_contiguous()); - } - - if (attn_mask.has_value()) { - TORCH_CHECK(attn_mask.value().is_cuda()); - TORCH_CHECK(attn_mask.value().dtype() == q_dtype); - TORCH_CHECK(attn_mask.value().is_contiguous()); - } - const auto sizes = q.sizes(); const int batch_size = cu_seqlens_q.numel() - 1; @@ -456,6 +445,19 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (attn_bias.has_value()) { + TORCH_CHECK(attn_bias.value().is_cuda()); + TORCH_CHECK(attn_bias.value().dtype() == q_dtype); + TORCH_CHECK(attn_bias.value().is_contiguous()); + // check attn_bias shape + } + + if (attn_mask.has_value()) { + TORCH_CHECK(attn_mask.value().is_cuda()); + TORCH_CHECK(attn_mask.value().dtype() == q_dtype); + TORCH_CHECK(attn_mask.value().is_contiguous()); + } + auto opts = q.options(); at::Tensor ds; if (attn_bias.has_value()) { @@ -511,7 +513,6 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size attn_mask ? attn_mask->data_ptr() : nullptr, attn_bias ? attn_bias->data_ptr() : nullptr, attn_bias ? ds.data_ptr() : nullptr); - // used for dbias auto gen = at::get_generator_or_default( @@ -528,7 +529,17 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size } launch(params, stream); - return { dq, dk, dv, softmax_d }; + + std::vector result = { dq, dk, dv, softmax_d }; + at::Tensor dbias; + if (attn_bias.has_value()) { + auto size = attn_bias->sizes(); + dbias = ds.reshape({ -1, size[0], size[1], size[2], size[3] }).sum({ 0 }); + result.push_back( dbias ); + result.push_back( ds ); + + } + return result; } std::vector diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index caecce6f4..9089c2fca 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -853,6 +853,7 @@ struct Gmem_tile_mma_ds { // Store to global memory. 
inline __device__ void store(const float (&softmax)[2 * M][4 * N]) { uint32_t preds; + uint32_t dst; #pragma unroll for( int mi = 0; mi < M; mi++ ) { @@ -864,7 +865,6 @@ struct Gmem_tile_mma_ds { for (int jj = 0; jj < 2; ++jj ) { float tmp00 = softmax[2 * mi + ii][4 * ni + jj * 2]; float tmp01 = softmax[2 * mi + ii][4 * ni + jj * 2 + 1]; - uint32_t dst; dst = fmha::float2_to_half2(tmp00, tmp01); const int current_row = mi * ROWS + ii * 8; diff --git a/tests/test_forward.cu b/tests/test_forward.cu index 69db09b23..1cf290326 100644 --- a/tests/test_forward.cu +++ b/tests/test_forward.cu @@ -505,8 +505,6 @@ void test_fwd_with_bias(bool has_bias) { std::vector bwd_ret; if (has_bias) { - // modify ret[1] - bwd_ret = mha_bwd( dout, q, @@ -532,6 +530,8 @@ void test_fwd_with_bias(bool has_bias) { dump_tensor("attn_dq", dq, "has_bias"); dump_tensor("attn_dk", dk, "has_bias"); dump_tensor("attn_dv", dv, "has_bias"); + dump_tensor("attn_dbias", bwd_ret[4], "has_bias"); + dump_tensor("attn_ds", bwd_ret[5], "has_bias"); }else{ bwd_ret = mha_bwd( dout, diff --git a/tests/tools/check_output.py b/tests/tools/check_output.py index af8d87f32..7ccc60b18 100644 --- a/tests/tools/check_output.py +++ b/tests/tools/check_output.py @@ -36,6 +36,15 @@ bias_ref[i][j][k][l] = cnt * 0.1 cnt += 1 +# bias_ref = np.zeros([batch_size , nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) +# cnt = 0 +# for i in range(batch_size ): +# for j in range(nheads): +# for k in range(max_seqlen_q_): +# for l in range(max_seqlen_k_): +# bias_ref[i][j][k][l] = cnt * 0.1 +# cnt += 1 + # dout = np.random.rand(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim).astype(dtype=dtypes) cnt = 0 @@ -254,6 +263,8 @@ def check_bwd_kernel(has_bias=False): attn_dq = np.genfromtxt("{}_attn_dq.data".format(prefix), delimiter=" ", dtype=np.float32) attn_dk = np.genfromtxt("{}_attn_dk.data".format(prefix), delimiter=" ", dtype=np.float32) attn_dv = np.genfromtxt("{}_attn_dv.data".format(prefix), delimiter=" ", dtype=np.float32) + if has_bias: + attn_dbias = np.genfromtxt("{}_attn_dbias.data".format(prefix), delimiter=" ", dtype=np.float32) attn_dq = attn_dq.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) attn_dk = attn_dk.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) @@ -262,6 +273,9 @@ def check_bwd_kernel(has_bias=False): attn_dq = attn_dq.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) attn_dk = attn_dk.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) attn_dv = attn_dv.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) + + if has_bias: + attn_dbias = attn_dbias.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) attn_dq = attn_dq.transpose(0, 2, 1, 3) attn_dk = attn_dk.transpose(0, 2, 1, 3) @@ -274,11 +288,20 @@ def check_bwd_kernel(has_bias=False): print ("max error in dq: ", np.abs(attn_dq - dq).max(), ) print ("max error in dk: ", np.abs(attn_dk - dk).max(), ) print ("max error in dv: ", np.abs(attn_dv - dv).max(), ) + if has_bias: + print ("max error in dq: ", np.abs(attn_dbias - dbias).max(), ) + # print (np.abs(attn_dbias - dbias) > 0.001) + attn_ds = np.genfromtxt("{}_attn_ds.data".format(prefix), delimiter=" ", dtype=np.float32) + attn_ds = attn_ds.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) + print ("max error in ds: ", np.abs(attn_ds - ds).max(), ) + print ("same matrix check q: ", is_same_matrix(attn_dq, dq)) print ("same matrix check k: ", is_same_matrix(attn_dk, dk)) print 
("same matrix check v: ", is_same_matrix(attn_dv, dv)) - + if has_bias: + import pdb; pdb.set_trace() + print ("same matrix check dbias: ", is_same_matrix(attn_dbias, dbias)) def check_bwd_np(has_bias=False): print ("==== check bwd np ====") @@ -481,10 +504,21 @@ def check_dsoftmax_p(softmax_data, has_bias=False): else: dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_) + if has_bias: + prefix = "has_bias" + print ("has bias on, prefix is ", prefix) + else: + prefix = "" + # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) import pdb; pdb.set_trace() print ("max error in p: ", np.abs(ds[0, 0, :, :] - softmax_data[0, 0, :, :]).max(), ) print ("same matrix check p: ", is_same_matrix(ds[0, 0, :, :], softmax_data[0, 0, :, :])) + + attn_ds = np.genfromtxt("{}_attn_ds.data".format(prefix), delimiter=" ", dtype=np.float32) + attn_ds = attn_ds.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) + + print ("max error in attn ds with softmax: ", np.abs(attn_ds[0, 0, :, :] - softmax_data[0, 0, :, :]).max(), ) return From 62a2f88466da7cd2e7342ddc3c055daa4833d5bf Mon Sep 17 00:00:00 2001 From: robotcator Date: Sat, 3 Sep 2022 16:00:39 +0800 Subject: [PATCH 29/71] fix dbias --- csrc/flash_attn/fmha_api.cpp | 5 +- csrc/flash_attn/src/fmha/gmem_tile.h | 26 +++++++- .../src/fmha_dgrad_kernel_1xN_loop.h | 1 + tests/test_forward.cu | 3 +- tests/tools/check_output.py | 34 +++++++--- tests/tools/rebuild_bwd_softmax.py | 64 +++++++++++++++++++ 6 files changed, 119 insertions(+), 14 deletions(-) create mode 100644 tests/tools/rebuild_bwd_softmax.py diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 36de75fb6..654db659e 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -34,6 +34,7 @@ #ifdef DDEBUG_PRINT #include "fmha_api.h" +#include #endif #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -463,6 +464,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (attn_bias.has_value()) { ds = torch::empty({batch_size, num_heads, max_seqlen_q_, max_seqlen_k_}, opts.dtype(q_dtype)); ds.zero_(); + TORCH_CHECK(ds.is_contiguous()); } int blocksize_c = (head_size == 128 || (is_sm75 && head_size == 64)) ? 
128 : 256; @@ -533,11 +535,10 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size std::vector result = { dq, dk, dv, softmax_d }; at::Tensor dbias; if (attn_bias.has_value()) { + // compare block reduce auto size = attn_bias->sizes(); dbias = ds.reshape({ -1, size[0], size[1], size[2], size[3] }).sum({ 0 }); result.push_back( dbias ); - result.push_back( ds ); - } return result; } diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 9089c2fca..01fe1e570 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -574,7 +574,7 @@ struct Gmem_tile_mma_mask { ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + (uint32_t)current_col * BYTES_PER_ELEMENT; - preds[offset] = (current_row <= min(ROWS, actual_seqlen_q)) + preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { @@ -730,7 +730,7 @@ struct Gmem_tile_mma_bias { ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + (uint32_t)current_col * BYTES_PER_ELEMENT; - preds[offset] = (current_row <= min(ROWS, actual_seqlen_q)) + preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { @@ -845,6 +845,14 @@ struct Gmem_tile_mma_ds { // the index of bs and head dim uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; // row_offset = (uint32_t)(row * row_stride_in_bytes); +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x <= 4) && (blockIdx.y == 0)) { + printf("ds tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", + tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); + printf("ds bidb=%d, bidh=%d, param.h=%d, blockIdx.x=%d\n", binfo.bidb, binfo.bidh, params.h, blockIdx.x); + printf("\n"); + } +#endif row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); // do we need to move col first if seklen_k > cols ptr_ += row_offset; @@ -873,10 +881,22 @@ struct Gmem_tile_mma_ds { char *ptrs = ptr_ + (uint32_t)current_row * row_stride_in_bytes + (uint32_t)current_col * BYTES_PER_ELEMENT; - preds = (current_row <= min(ROWS, actual_seqlen_q)) + preds = (current_row < min(ROWS, actual_seqlen_q)) && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x <= 4) && (blockIdx.y == 0)) { + printf("ds store blockIdx.x=%d, mi=%d, ni=%d, ii=%d, jj=%d, current_row=%d, current_col=%d, float1=%f, float2=%f, begin=%p, ptrs=%p, preds=%d\n", + blockIdx.x, mi, ni, ii, jj, current_row, current_col, tmp00, tmp01, ptr_, ptrs, preds); + } +#endif if (preds) { +#ifdef DEBUG_PRINT + if ((threadIdx.x == 0) && (blockIdx.x <= 4) && (blockIdx.y == 0)) { + printf("ds store blockIdx.x=%d in, mi=%d, ni=%d, ii=%d, jj=%d, ptrs=%p, dst=%ud, data=%ud\n", + blockIdx.x, mi, ni, ii, jj, ptrs, dst, fmha::float2_to_half2(tmp00, tmp01)); + } +#endif fmha::stg(ptrs, dst); } } diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index f241def26..485e40d90 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -618,6 
+618,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng } #endif gmem_ds.store(softmax.elt_); + gmem_ds.move(); } // Store dp to smem for transpose diff --git a/tests/test_forward.cu b/tests/test_forward.cu index 1cf290326..06bf5248c 100644 --- a/tests/test_forward.cu +++ b/tests/test_forward.cu @@ -340,6 +340,7 @@ void dump_tensor(const std::string &tensor_name, at::Tensor &tensor, const std:: for (int i = 0; i < size; i ++) { file << flatten_tensor[i].item() << " "; + // file << flatten_tensor[i] << " "; } file << std::endl; } @@ -531,7 +532,7 @@ void test_fwd_with_bias(bool has_bias) { dump_tensor("attn_dk", dk, "has_bias"); dump_tensor("attn_dv", dv, "has_bias"); dump_tensor("attn_dbias", bwd_ret[4], "has_bias"); - dump_tensor("attn_ds", bwd_ret[5], "has_bias"); + // dump_tensor("attn_ds", bwd_ret[5], "has_bias"); }else{ bwd_ret = mha_bwd( dout, diff --git a/tests/tools/check_output.py b/tests/tools/check_output.py index 7ccc60b18..4244304ee 100644 --- a/tests/tools/check_output.py +++ b/tests/tools/check_output.py @@ -518,7 +518,25 @@ def check_dsoftmax_p(softmax_data, has_bias=False): attn_ds = np.genfromtxt("{}_attn_ds.data".format(prefix), delimiter=" ", dtype=np.float32) attn_ds = attn_ds.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) + attn_dbias = np.genfromtxt("{}_attn_dbias.data".format(prefix), delimiter=" ", dtype=np.float32) + attn_dbias = attn_ds.reshape(*bias_ref.shape) + print ("max error in attn ds with softmax: ", np.abs(attn_ds[0, 0, :, :] - softmax_data[0, 0, :, :]).max(), ) + print ("max error in attn ds with bwd: ", np.abs(attn_ds - ds).max(), ) + print ("max error in attn dbias with bwd: ", np.abs(attn_dbias - dbias).max(), ) + # for i in range(batch_size * max_seqlen_k_): + # for j in range(nheads): + # print ("max error in i = {}, j = {}, max_error = {} ".format(i, j, np.abs(attn_ds[i, j, :, :] - ds[i, j, :, :]).max(), )) + # print (np.abs(attn_ds[i, j, :, :] - ds[i, j, :, :]) <= 0.001) + # print ("attn_ds: ", attn_ds[i, j, :, :]) + # print ("ds: ", ds[i, j, :, :]) + + # for i in range(batch_size * max_seqlen_k_): + # for j in range(nheads): + # print ("max error in i = {}, j = {}, max_error = {} ".format(i, j, np.abs(attn_dbias[i, j, :, :] - dbias[i, j, :, :]).max(), )) + # print (np.abs(attn_dbias[i, j, :, :] - dbias[i, j, :, :]) <= 0.001) + # print ("attn_dbias: ", attn_dbias[i, j, :, :]) + # print ("dbias: ", dbias[i, j, :, :]) return @@ -540,10 +558,10 @@ def check_dsoftmax_p(softmax_data, has_bias=False): # check_fwd_kernel(has_bias=has_bias) # check_bwd_kernel(has_bias=has_bias) - # print ("====test kernel with bias====") - # has_bias = True - # check_fwd_kernel(has_bias=has_bias) - # check_bwd_kernel(has_bias=has_bias) + print ("====test kernel with bias====") + has_bias = True + check_fwd_kernel(has_bias=has_bias) + check_bwd_kernel(has_bias=has_bias) # print ("====test bwd kernel softmax without bias====") # has_bias = False @@ -560,7 +578,7 @@ def check_dsoftmax_p(softmax_data, has_bias=False): # dsoftmax_data = parse_dsoftmax_load("output.log") # check_dsoftmax_p(softmax_data=dsoftmax_data, has_bias=has_bias) - print ("====test bwd kernel softmax with bias====") - has_bias = True - dsoftmax_data = parse_dsoftmax_load("output.log") - check_dsoftmax_p(softmax_data=dsoftmax_data, has_bias=has_bias) + # print ("====test bwd kernel softmax with bias====") + # has_bias = True + # dsoftmax_data = parse_dsoftmax_load("output.log") + # check_dsoftmax_p(softmax_data=dsoftmax_data, has_bias=has_bias) diff 
--git a/tests/tools/rebuild_bwd_softmax.py b/tests/tools/rebuild_bwd_softmax.py new file mode 100644 index 000000000..790f49642 --- /dev/null +++ b/tests/tools/rebuild_bwd_softmax.py @@ -0,0 +1,64 @@ +from parse import parse +import sys +import numpy as np + +filename = "./output.log" +if len(sys.argv) > 1: + filename = sys.argv[1] + +# bwd softmax: threadIdx=195, l=0, mi=0, ki=1, ii=3, jj=0, elt=0.000000 +format_string = 'bwd softmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' +batch_size = 1 +nheads = 1 +headdim = 16 +seq = 8 +seq_q = 8 +max_seqlen_q_ = seq_q +max_seqlen_k_ = seq_q + + +d_softmax = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) + +def parse_dsoftmax_load(filename): + with open(filename, "r") as f: + for line in f.readlines(): + # print (line) + if line.startswith("bwd softmax: "): + # print (line.strip()) + result = parse(format_string, line.strip()) + # print (result) + + tidx_ = int(result[0]) + mi = int(result[2]) + ni = int(result[3]) + ii = int(result[4]) + jj = int(result[5]) + value = float(result[6]) + + warp = tidx_ // 32 + lane = tidx_ % 32 + + warp_n = (warp // 1) + warp_m = (warp % 1) + + quad = lane // 4 + tid = (lane % 4) * 2 + + row = warp_m * 16 + quad + col = warp_n * 16 + tid + + current_row = mi * 16 + ii * 8 + row + # current_col = ni * 64 + jj * 8 + col + # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col + current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col + # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + # warp, lane, quad, tid, current_row, current_col, value)) + if (current_row < 8 and current_col < 8): + print (line.strip()) + print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + warp, lane, quad, tid, current_row, current_col, value)) + d_softmax[0, 0, current_row, current_col] = value + + +parse_dsoftmax_load(filename) +print ("output block 0 d_softmax: ", d_softmax[0, 0, :, :]) From 5eb754aa5c51ba9624317bdf4aae714360f1ce1f Mon Sep 17 00:00:00 2001 From: robotcator Date: Sun, 4 Sep 2022 16:34:13 +0800 Subject: [PATCH 30/71] add bias support --- csrc/flash_attn/fmha_api.cpp | 41 +++++++++++++++---- csrc/flash_attn/src/fmha.h | 2 + csrc/flash_attn/src/fmha/gmem_tile.h | 5 ++- .../src/fmha_fprop_fp16_kernel.sm80.cu | 2 +- 4 files changed, 40 insertions(+), 10 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 654db659e..0ac91a908 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -61,7 +61,8 @@ void set_params_fprop(FMHA_fprop_params ¶ms, float softmax_scale, bool is_causal, void *attn_mask, - void *attn_bias + void *attn_bias, + int bias_mod_size ) { Data_type acc_type = DATA_TYPE_FP32; @@ -107,6 +108,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms, // attn mask & bias params.attn_mask_ptr = attn_mask; params.attn_bias_ptr = attn_bias; + params.bias_mod_size = bias_mod_size; #ifdef DEBUG_PRINT @@ -176,7 +178,8 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, bool is_causal, void *attn_mask, void *attn_bias, - void *attn_ds) { + void *attn_ds, + int bias_mod_size) { set_params_fprop(params, b, seqlen_q, seqlen_k, h, d, @@ -191,7 +194,8 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, softmax_scale, is_causal, attn_mask, - attn_bias); + attn_bias, + bias_mod_size); // Set the pointers and strides. 
params.dq_ptr = dq.data_ptr(); @@ -272,16 +276,29 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + int bias_mod_size = 0; if (attn_bias.has_value()) { TORCH_CHECK(attn_bias.value().is_cuda()); TORCH_CHECK(attn_bias.value().dtype() == q_dtype); TORCH_CHECK(attn_bias.value().is_contiguous()); + + const auto bias_sizes = attn_bias->sizes(); + // last two dimension + bias_mod_size = bias_sizes[0]; + TORCH_CHECK(bias_sizes[1] == num_heads); } if (attn_mask.has_value()) { TORCH_CHECK(attn_mask.value().is_cuda()); TORCH_CHECK(attn_mask.value().dtype() == q_dtype); TORCH_CHECK(attn_mask.value().is_contiguous()); + + const auto mask_sizes = attn_mask->sizes(); + // last two dimension + const int mask_mod_size = mask_sizes[1]; + TORCH_CHECK(mask_sizes[1] == 1 || mask_sizes[1] == num_heads); + std::cout << "mask head mode size: " << mask_mod_size << std::endl; + launch_params.params.mask_mod_size = mask_mod_size; } int blocksize_c = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256; @@ -335,7 +352,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q softmax_scale, is_causal, attn_mask ? attn_mask->data_ptr() : nullptr, - attn_bias ? attn_bias->data_ptr() : nullptr + attn_bias ? attn_bias->data_ptr() : nullptr, + bias_mod_size ); run_fmha_fp16_sm80(launch_params, /*configure=*/ true); @@ -446,11 +464,16 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + int bias_mod_size = 0; if (attn_bias.has_value()) { TORCH_CHECK(attn_bias.value().is_cuda()); TORCH_CHECK(attn_bias.value().dtype() == q_dtype); TORCH_CHECK(attn_bias.value().is_contiguous()); // check attn_bias shape + const auto bias_sizes = attn_bias->sizes(); + // last two dimension + bias_mod_size = bias_sizes[0]; + TORCH_CHECK(bias_sizes[1] == num_heads); } if (attn_mask.has_value()) { @@ -514,7 +537,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size is_causal, attn_mask ? attn_mask->data_ptr() : nullptr, attn_bias ? attn_bias->data_ptr() : nullptr, - attn_bias ? ds.data_ptr() : nullptr); + attn_bias ? 
ds.data_ptr() : nullptr, + bias_mod_size); // used for dbias auto gen = at::get_generator_or_default( @@ -647,7 +671,9 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t softmax_scale, is_causal, nullptr, - nullptr); + nullptr, + 0); + // TODO: add mask / bias launch_params.params.blockmask = static_cast(blockmask.data_ptr()); run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true); @@ -797,7 +823,8 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size is_causal, nullptr, nullptr, - nullptr); + nullptr, + 0); params.blockmask = static_cast(blockmask.data_ptr()); auto gen = at::get_generator_or_default( diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index a130525de..71420fc4c 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -75,9 +75,11 @@ struct FMHA_fprop_params : public Qkv_params { // The attn mask matrix void * __restrict__ attn_mask_ptr; + int mask_mod_size; // The attn bias matrix void * __restrict__ attn_bias_ptr; + int bias_mod_size; // The ds matrix void * __restrict__ attn_ds_ptr; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 01fe1e570..82def32d7 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -686,7 +686,8 @@ struct Gmem_tile_mma_bias { // The distance between two blocks (in bytes). // TODO: mask is [bs, head, seq_q, seq_k] // The block index. - uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + // uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + uint32_t bidx = ( binfo.bidb % params.bias_mod_size ) * params.h + binfo.bidh; // the index of bs and head dim uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; @@ -697,7 +698,7 @@ struct Gmem_tile_mma_bias { if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf("tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); - printf("bidb=%d, bidh=%d, param.h=%d\n", binfo.bidb, binfo.bidh, params.h); + printf("bidb=%d, bidh=%d, param.h=%d, bias_mod_size=%d\n", binfo.bidb, binfo.bidh, params.h, params.bias_mod_size); printf("\n"); } #endif diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index d86118f05..2e50bcc6c 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -63,7 +63,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, bool has_bias = !(launch_params.params.attn_bias_ptr == nullptr); #ifdef DEBUG_PRINT - printf ("has_attn=%d, has_bias=%d\n", has_attn, has_bias); + printf ("has_attn=%d, has_bias=%d, bias_mod_size=%d\n", has_attn, has_bias, launch_params.params.bias_mod_size); #endif // attn + bias on From baa6d1b5d88ea51a699a78652058ce85e0eaad27 Mon Sep 17 00:00:00 2001 From: robotcator Date: Mon, 5 Sep 2022 15:12:52 +0800 Subject: [PATCH 31/71] add mask shape --- csrc/flash_attn/fmha_api.cpp | 33 ++-- csrc/flash_attn/src/fmha.h | 2 +- csrc/flash_attn/src/fmha/gmem_tile.h | 3 +- tests/test_forward.cu | 222 +++++++++++++++++++++------ 4 files changed, 200 insertions(+), 60 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 0ac91a908..cb8ff45fb 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -62,7 +62,8 @@ void set_params_fprop(FMHA_fprop_params ¶ms, bool is_causal, 
void *attn_mask, void *attn_bias, - int bias_mod_size + int bias_mod_size, + int mask_head_mod_size ) { Data_type acc_type = DATA_TYPE_FP32; @@ -109,6 +110,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.attn_mask_ptr = attn_mask; params.attn_bias_ptr = attn_bias; params.bias_mod_size = bias_mod_size; + params.mask_head_mod_size = mask_head_mod_size; #ifdef DEBUG_PRINT @@ -179,7 +181,8 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, void *attn_mask, void *attn_bias, void *attn_ds, - int bias_mod_size) { + int bias_mod_size, + int mask_head_mod_size) { set_params_fprop(params, b, seqlen_q, seqlen_k, h, d, @@ -195,7 +198,8 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, is_causal, attn_mask, attn_bias, - bias_mod_size); + bias_mod_size, + mask_head_mod_size); // Set the pointers and strides. params.dq_ptr = dq.data_ptr(); @@ -288,6 +292,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q TORCH_CHECK(bias_sizes[1] == num_heads); } + int mask_head_mod_size = 0; if (attn_mask.has_value()) { TORCH_CHECK(attn_mask.value().is_cuda()); TORCH_CHECK(attn_mask.value().dtype() == q_dtype); @@ -295,10 +300,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const auto mask_sizes = attn_mask->sizes(); // last two dimension - const int mask_mod_size = mask_sizes[1]; + mask_head_mod_size = mask_sizes[1]; TORCH_CHECK(mask_sizes[1] == 1 || mask_sizes[1] == num_heads); - std::cout << "mask head mode size: " << mask_mod_size << std::endl; - launch_params.params.mask_mod_size = mask_mod_size; + std::cout << "mask head mode size: " << mask_head_mod_size << std::endl; } int blocksize_c = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256; @@ -353,7 +357,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q is_causal, attn_mask ? attn_mask->data_ptr() : nullptr, attn_bias ? attn_bias->data_ptr() : nullptr, - bias_mod_size + bias_mod_size, + mask_head_mod_size ); run_fmha_fp16_sm80(launch_params, /*configure=*/ true); @@ -476,10 +481,17 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(bias_sizes[1] == num_heads); } + int mask_head_mod_size = 0; if (attn_mask.has_value()) { TORCH_CHECK(attn_mask.value().is_cuda()); TORCH_CHECK(attn_mask.value().dtype() == q_dtype); TORCH_CHECK(attn_mask.value().is_contiguous()); + + const auto mask_sizes = attn_mask->sizes(); + // last two dimension + mask_head_mod_size = mask_sizes[1]; + TORCH_CHECK(mask_sizes[1] == 1 || mask_sizes[1] == num_heads); + std::cout << "mask head mode size: " << mask_head_mod_size << std::endl; } auto opts = q.options(); @@ -538,7 +550,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size attn_mask ? attn_mask->data_ptr() : nullptr, attn_bias ? attn_bias->data_ptr() : nullptr, attn_bias ? 
ds.data_ptr() : nullptr, - bias_mod_size); + bias_mod_size, + mask_head_mod_size); // used for dbias auto gen = at::get_generator_or_default( @@ -672,6 +685,7 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t is_causal, nullptr, nullptr, + 0, 0); // TODO: add mask / bias launch_params.params.blockmask = static_cast(blockmask.data_ptr()); @@ -824,7 +838,8 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size nullptr, nullptr, nullptr, - 0); + 0, 0); + // TODO: add support bias / mask params.blockmask = static_cast(blockmask.data_ptr()); auto gen = at::get_generator_or_default( diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index 71420fc4c..4aabd8ffe 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -75,7 +75,7 @@ struct FMHA_fprop_params : public Qkv_params { // The attn mask matrix void * __restrict__ attn_mask_ptr; - int mask_mod_size; + int mask_head_mod_size; // The attn bias matrix void * __restrict__ attn_bias_ptr; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 82def32d7..46fa99032 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -530,7 +530,8 @@ struct Gmem_tile_mma_mask { // The distance between two blocks (in bytes). // TODO: mask is [bs * seq, head, seq_q, seq_k] // The block index. - uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + // uint32_t bidx = binfo.bidb * params.h + binfo.bidh; + uint32_t bidx = binfo.bidb * params.h + (binfo.bidh % params.mask_head_mod_size); // the index of bs and head dim uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; diff --git a/tests/test_forward.cu b/tests/test_forward.cu index 06bf5248c..b435c8902 100644 --- a/tests/test_forward.cu +++ b/tests/test_forward.cu @@ -4,16 +4,34 @@ #include #include #include +#include +#include -void test_fwd_with_mask() { +void dump_tensor(const std::string &tensor_name, at::Tensor &tensor, const std::string &label) { + std::string file_name = label + "_" + tensor_name + ".data"; + std::ofstream file(file_name.c_str()); + // file << tensor_name << std::endl; + // file << tensor << std::endl; + std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; + auto flatten_tensor = tensor.flatten(); + auto size = flatten_tensor.numel(); + + for (int i = 0; i < size; i ++) { + file << flatten_tensor[i].item() << " "; + // file << flatten_tensor[i] << " "; + } + file << std::endl; +} + +void test_fwd_with_mask(int has_mask) { int batch_size = 1; int nheads = 1; int headdim = 16; - int max_seqlen_q_ = 128; - int max_seqlen_k_ = 128; + int max_seqlen_q_ = 8; + int max_seqlen_k_ = 8; - float softmax_scale = 0.1; + float softmax_scale = 1; bool zero_tensors = false; bool is_causal = false; @@ -62,26 +80,49 @@ void test_fwd_with_mask() { auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); - at::Tensor attn_mask = at::ones({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf).cuda(); - - cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < max_seqlen_q_; k ++) { - for (int l = 0; l < max_seqlen_k_; l ++) { - attn_mask[i][j][k][l] = cnt * 0.001; - cnt ++; - } - } - } - } + // at::Tensor attn_mask = at::ones({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf).cuda(); + + // cnt = 0; + // for (int i = 0; i 
< batch_size * max_seqlen_k_; i ++) { + // for (int j = 0; j < nheads; j ++) { + // for (int k = 0; k < max_seqlen_q_; k ++) { + // for (int l = 0; l < max_seqlen_k_; l ++) { + // attn_mask[i][j][k][l] = cnt * 0.001; + // cnt ++; + // } + // } + // } + // } + at::Tensor attn_mask = at::ones({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf).tril().cuda(); + c10::optional gen_; c10::optional attn_bias; // std::cout << "attn bias" << attn_bias << std::endl; - - std::vector ret = mha_fwd( + std::vector ret; + if (has_mask) { + ret = mha_fwd( + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + attn_mask, + attn_bias + ); + dump_tensor("attn_output", ret[0], "has_mask"); + dump_tensor("attn_lse", ret[1], "has_mask"); + }else{ + ret = mha_fwd( q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i @@ -95,14 +136,95 @@ void test_fwd_with_mask() { is_causal, return_softmax, gen_, - attn_mask, + attn_bias, attn_bias - ); + ); + dump_tensor("attn_output", ret[0], ""); + dump_tensor("attn_lse", ret[1], ""); + } - std::cout << "Ret vec size is " << ret.size(); - for (int i = 0; i < ret.size(); i ++) { - ret[i].cpu(); - std::cout << ret[i] << std::endl; + // std::cout << "Ret vec size is " << ret.size(); + // for (int i = 0; i < ret.size(); i ++) { + // ret[i].cpu(); + // std::cout << ret[i] << std::endl; + // } + + at::Tensor dout_cpu = at::ones({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < headdim; k ++) { + dout_cpu[i][j][k] = cnt * 0.001; + cnt ++; + } + } + } + + at::Tensor dq_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor dk_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor dv_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + auto dout = dout_cpu.cuda(); + auto dq = dq_cpu.cuda(); + auto dk = dk_cpu.cuda(); + auto dv = dv_cpu.cuda(); + std::vector bwd_ret; + + if (has_mask) { + bwd_ret = mha_bwd( + dout, + q, + k, + v, + ret[0], + ret[1], + dq, + dk, + dv, + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + gen_, + attn_mask, + attn_bias + ); + dump_tensor("attn_dq", dq, "has_mask"); + dump_tensor("attn_dk", dk, "has_mask"); + dump_tensor("attn_dv", dv, "has_mask"); + // dump_tensor("attn_ds", bwd_ret[5], "has_mask"); + }else{ + bwd_ret = mha_bwd( + dout, + q, + k, + v, + ret[0], + ret[1], + dq, + dk, + dv, + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + gen_, + attn_bias, + attn_bias + // placeholder + ); + dump_tensor("attn_dq", dq, ""); + dump_tensor("attn_dk", dk, ""); + dump_tensor("attn_dv", dv, ""); } } @@ -114,7 +236,7 @@ void test_fwd_with_mask_mini() { int max_seqlen_q_ = 2; int 
max_seqlen_k_ = 2; - float softmax_scale = 0.1; + float softmax_scale = 1.0; bool zero_tensors = false; bool is_causal = false; @@ -329,23 +451,6 @@ void test_fwd_with_bias_mini() { } -void dump_tensor(const std::string &tensor_name, at::Tensor &tensor, const std::string &label) { - std::string file_name = label + "_" + tensor_name + ".data"; - std::ofstream file(file_name.c_str()); - // file << tensor_name << std::endl; - // file << tensor << std::endl; - std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; - auto flatten_tensor = tensor.flatten(); - auto size = flatten_tensor.numel(); - - for (int i = 0; i < size; i ++) { - file << flatten_tensor[i].item() << " "; - // file << flatten_tensor[i] << " "; - } - file << std::endl; -} - - void test_fwd_with_bias(bool has_bias) { int batch_size = 1; int nheads = 1; @@ -537,8 +642,8 @@ void test_fwd_with_bias(bool has_bias) { bwd_ret = mha_bwd( dout, q, - k, - v, + k, + v, ret[0], ret[1], dq, @@ -572,11 +677,30 @@ void test_fwd_with_bias(bool has_bias) { int main(int argc, char** argv){ // test_fwd(); // test_fwd_with_bias_mini(); - bool has_bias = false; - if( argc == 2 ) { + int has_bias = 0; + int has_masked = 0; + + if ( argc >= 2 ) { std::cout << "argv: " << argv[1] << std::endl; - has_bias = true; + if (strcmp(argv[1], "has_bias") == 0) { + if (strcmp(argv[2], "true") == 0) { + has_bias = 1; + }else{ + has_bias = 0; + } + test_fwd_with_bias(has_bias); + }else if (strcmp(argv[1], "has_mask") == 0) { + if (strcmp(argv[2], "true") == 0) { + has_masked = 1; + }else{ + has_masked = 0; + } + test_fwd_with_mask(has_masked); + }else{ + has_bias = 0; + has_masked = 0; + std::cout << "no paramter found" << std::endl; + } } - test_fwd_with_bias(has_bias); return 0; } From 84e462fe9cf793e66d4d793fbf94af0b820cdfde Mon Sep 17 00:00:00 2001 From: robotcator Date: Mon, 5 Sep 2022 16:25:36 +0800 Subject: [PATCH 32/71] add test --- .../test/test_forward_without_bias_mask.py | 2 +- .../src/fmha_fprop_fp16_kernel.sm80.cu | 321 +++++++++--------- tests/test_forward.cu | 4 +- tests/tools/check_output.py | 120 ++++--- 4 files changed, 243 insertions(+), 204 deletions(-) diff --git a/benchmarks/test/test_forward_without_bias_mask.py b/benchmarks/test/test_forward_without_bias_mask.py index 3306aa885..930479cba 100644 --- a/benchmarks/test/test_forward_without_bias_mask.py +++ b/benchmarks/test/test_forward_without_bias_mask.py @@ -107,7 +107,7 @@ def _flash_attn(q, k, v, attn_mask=None): if attn_mask is not None: # import pdb; pdb.set_trace() - attn_mask = attn_mask.reshape([bs * n, no_heads, n, n]) + attn_mask = attn_mask.reshape([bs * n, no_heads, n, n]).contiguous() out = flash_attn_unpadded_func( q, diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index 2e50bcc6c..98ad93ff4 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -68,118 +68,118 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, // attn + bias on // IsDropoutConst off - auto kernel = &fmha_fprop_fp16_sm80_loop_kernel; - dim3 grid(launch_params.params.b, launch_params.params.h); + // auto kernel = &fmha_fprop_fp16_sm80_loop_kernel; + // dim3 grid(launch_params.params.b, launch_params.params.h); - printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - printf("block size: %d\n", Kernel_traits::THREADS); - kernel<<>>( - launch_params.params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); + // 
printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + // kernel<<>>( + // launch_params.params); + // FMHA_CHECK_CUDA(cudaPeekAtLastError()); - // if (has_attn) - // { - // if (has_bias) { - // // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // // https://github.com/kokkos/kokkos-kernels/issues/349 - // // https://github.com/HazyResearch/flash-attention/issues/21 - // BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { - // auto kernel = launch_params.params.is_causal - // ? (launch_params.return_softmax - // ? &fmha_fprop_fp16_sm80_loop_kernel - // : &fmha_fprop_fp16_sm80_loop_kernel) - // : (launch_params.return_softmax - // ? &fmha_fprop_fp16_sm80_loop_kernel - // : &fmha_fprop_fp16_sm80_loop_kernel); - // if( smem_size >= 48 * 1024 ) { - // FMHA_CHECK_CUDA(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // } - // dim3 grid(launch_params.params.b, launch_params.params.h); + if (has_attn) + { + if (has_bias) { + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); - // // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // // printf("block size: %d\n", Kernel_traits::THREADS); - // kernel<<>>( - // launch_params.params); - // FMHA_CHECK_CUDA(cudaPeekAtLastError()); - // }); - // }else{ - // // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // // https://github.com/kokkos/kokkos-kernels/issues/349 - // // https://github.com/HazyResearch/flash-attention/issues/21 - // BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { - // auto kernel = launch_params.params.is_causal - // ? (launch_params.return_softmax - // ? &fmha_fprop_fp16_sm80_loop_kernel - // : &fmha_fprop_fp16_sm80_loop_kernel) - // : (launch_params.return_softmax - // ? &fmha_fprop_fp16_sm80_loop_kernel - // : &fmha_fprop_fp16_sm80_loop_kernel); - // if( smem_size >= 48 * 1024 ) { - // FMHA_CHECK_CUDA(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // } - // dim3 grid(launch_params.params.b, launch_params.params.h); + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? 
&fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); - // // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // // printf("block size: %d\n", Kernel_traits::THREADS); - // kernel<<>>( - // launch_params.params); - // FMHA_CHECK_CUDA(cudaPeekAtLastError()); - // }); - // } - // }else{ - // if (has_bias) { - // // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // // https://github.com/kokkos/kokkos-kernels/issues/349 - // // https://github.com/HazyResearch/flash-attention/issues/21 - // BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { - // auto kernel = launch_params.params.is_causal - // ? (launch_params.return_softmax - // ? &fmha_fprop_fp16_sm80_loop_kernel - // : &fmha_fprop_fp16_sm80_loop_kernel) - // : (launch_params.return_softmax - // ? &fmha_fprop_fp16_sm80_loop_kernel - // : &fmha_fprop_fp16_sm80_loop_kernel); - // if( smem_size >= 48 * 1024 ) { - // FMHA_CHECK_CUDA(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // } - // dim3 grid(launch_params.params.b, launch_params.params.h); + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + } + }else{ + if (has_bias) { + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); - // // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // // printf("block size: %d\n", Kernel_traits::THREADS); - // kernel<<>>( - // launch_params.params); - // FMHA_CHECK_CUDA(cudaPeekAtLastError()); - // }); - // }else{ - // // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - // // https://github.com/kokkos/kokkos-kernels/issues/349 - // // https://github.com/HazyResearch/flash-attention/issues/21 - // BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { - // auto kernel = launch_params.params.is_causal - // ? (launch_params.return_softmax - // ? &fmha_fprop_fp16_sm80_loop_kernel - // : &fmha_fprop_fp16_sm80_loop_kernel) - // : (launch_params.return_softmax - // ? 
&fmha_fprop_fp16_sm80_loop_kernel - // : &fmha_fprop_fp16_sm80_loop_kernel); - // if( smem_size >= 48 * 1024 ) { - // FMHA_CHECK_CUDA(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // } - // dim3 grid(launch_params.params.b, launch_params.params.h); + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] { + auto kernel = launch_params.params.is_causal + ? (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel) + : (launch_params.return_softmax + ? &fmha_fprop_fp16_sm80_loop_kernel + : &fmha_fprop_fp16_sm80_loop_kernel); + if( smem_size >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + dim3 grid(launch_params.params.b, launch_params.params.h); - // // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // // printf("block size: %d\n", Kernel_traits::THREADS); - // kernel<<>>( - // launch_params.params); - // FMHA_CHECK_CUDA(cudaPeekAtLastError()); - // }); - // } - // } + // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); + // printf("block size: %d\n", Kernel_traits::THREADS); + kernel<<>>( + launch_params.params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + } + } } void run_fmha_fp16_sm80(Launch_params &launch_params, @@ -194,64 +194,61 @@ void run_fmha_fp16_sm80(Launch_params &launch_params, using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>; run_fmha_fp16_sm80_loop_(launch_params, configure); } - // else if( launch_params.params.seqlen_k == 256 ) { - // using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else { - // // TD [2022-05-15] 512 gives wrong results rn - // // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>; - // using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } + else if( launch_params.params.seqlen_k == 256 ) { + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + // TD [2022-05-15] 512 gives wrong results rn + // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>; + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } } - // debug on comments - // else if (launch_params.params.d == 32) { - // if( launch_params.params.seqlen_k == 128 ) { - // using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else if( launch_params.params.seqlen_k == 256 ) { - // using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else { - // using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; - // 
run_fmha_fp16_sm80_loop_(launch_params, configure); - // } - // } else if (launch_params.params.d == 64) { - // if( launch_params.params.seqlen_k == 128 ) { - // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else if( launch_params.params.seqlen_k >= 256 ) { - // if (dprops->major == 8 && dprops->minor >= 0) { - // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else if (dprops->major == 7 && dprops->minor == 5) { - // if (launch_params.is_dropout) { // Need to use the same block size as backward - // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else { - // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } - // } - // } - // } else if (launch_params.params.d == 128) { - // if( launch_params.params.seqlen_k == 128 ) { - // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else { - // if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) { - // // TD [2022-06-05] Keep K in registers to reduce register spilling - // // Gives about 6% speedup compared to using block size 128. - // using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else { // Need to use the same block size as backward - // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } - // } - // } - // debug on comments - + else if (launch_params.params.d == 32) { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if( launch_params.params.seqlen_k == 256 ) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + } else if (launch_params.params.d == 64) { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if( launch_params.params.seqlen_k >= 256 ) { + if (dprops->major == 8 && dprops->minor >= 0) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if (dprops->major == 7 && dprops->minor == 5) { + if (launch_params.is_dropout) { // Need to use the same block size as backward + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + } + } + } else if (launch_params.params.d == 128) { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, 
configure); + } else { + if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) { + // TD [2022-06-05] Keep K in registers to reduce register spilling + // Gives about 6% speedup compared to using block size 128. + using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { // Need to use the same block size as backward + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + } + } // if (launch_params.params.d == 64) { // // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; // // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>; diff --git a/tests/test_forward.cu b/tests/test_forward.cu index b435c8902..5dbb20cd2 100644 --- a/tests/test_forward.cu +++ b/tests/test_forward.cu @@ -94,7 +94,7 @@ void test_fwd_with_mask(int has_mask) { // } // } - at::Tensor attn_mask = at::ones({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf).tril().cuda(); + at::Tensor attn_mask = 1 - at::ones({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf).tril().cuda(); c10::optional gen_; c10::optional attn_bias; @@ -688,6 +688,7 @@ int main(int argc, char** argv){ }else{ has_bias = 0; } + std::cout << "has bias " << argv[2] << std::endl; test_fwd_with_bias(has_bias); }else if (strcmp(argv[1], "has_mask") == 0) { if (strcmp(argv[2], "true") == 0) { @@ -695,6 +696,7 @@ int main(int argc, char** argv){ }else{ has_masked = 0; } + std::cout << "has mask " << argv[2] << std::endl; test_fwd_with_mask(has_masked); }else{ has_bias = 0; diff --git a/tests/tools/check_output.py b/tests/tools/check_output.py index 4244304ee..194a63c96 100644 --- a/tests/tools/check_output.py +++ b/tests/tools/check_output.py @@ -1,9 +1,18 @@ from audioop import bias from operator import truediv -from socket import NI_NAMEREQD import numpy as np import torch +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--test_np", required=False, help="test np implementation kernel with torch", type=bool, default=False) +parser.add_argument("--has_bias", required=False, help="add bias in attention", type=bool, default=False) +parser.add_argument("--has_mask", required=False, help="add mask in attention", type=bool, default=False) +args = parser.parse_args() +print(args) + + batch_size = 1 nheads = 1 headdim = 16 @@ -13,7 +22,6 @@ dtypes = np.float16 - q_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=dtypes) k_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=dtypes) v_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=dtypes) @@ -36,6 +44,9 @@ bias_ref[i][j][k][l] = cnt * 0.1 cnt += 1 +mask_ref = np.ones([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) +mask_ref = (1 - np.tril(mask_ref)) * -1 + # bias_ref = np.zeros([batch_size , nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) # cnt = 0 # for i in range(batch_size ): @@ -65,7 +76,7 @@ def softmax(logit): return probs -def fwd(q, k, v, max_seqlen_q, bias=None): +def fwd(q, k, v, max_seqlen_q, bias=None, mask=None): batch_size = int(q.shape[0] / max_seqlen_q) head_num = q.shape[1] @@ -85,6 +96,13 @@ def fwd(q, k, v, max_seqlen_q, bias=None): if bias is not None: s = s + bias + + if mask is not None: + # s.masked_fill_(mask < 0, float('-inf')) + 
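            # Note (annotation): masked positions (mask < 0) are filled with a large
            # negative value (-999) instead of -inf, presumably to keep the fp16
            # reference finite even when a whole row is masked; after the softmax
            # below these positions receive essentially zero probability, which is
            # equivalent to applying a large-negative additive mask.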
mask_np = np.ma.masked_where(mask < 0, s) + # np.ma.set_fill_value(mask_np, float('-inf')) + np.ma.set_fill_value(mask_np, float('-999')) + s = mask_np.filled() p = softmax(s) @@ -94,8 +112,8 @@ def fwd(q, k, v, max_seqlen_q, bias=None): return s, p, o, q, k, v -def bwd(dout, q, k, v, max_seqlen_q, bias=None): - s, p, o, _, _, _ = fwd(q, k, v, max_seqlen_q=max_seqlen_q, bias=bias) +def bwd(dout, q, k, v, max_seqlen_q, bias=None, mask=None): + s, p, o, _, _, _ = fwd(q, k, v, max_seqlen_q=max_seqlen_q, bias=bias, mask=mask) batch_size = int(q.shape[0] / max_seqlen_q) head_num = q.shape[1] @@ -149,23 +167,26 @@ def bwd(dout, q, k, v, max_seqlen_q, bias=None): return dq, dk, dv, ds, dp, dbias -def fwd_pt(q_pt, k_pt, v_pt, bias=None): +def fwd_pt(q_pt, k_pt, v_pt, bias=None, mask=None): s = torch.matmul(q_pt, k_pt.transpose(-1, -2)) if bias is not None: s = s + bias + if mask is not None: + s.masked_fill_(mask < 0, float('-999')) + p = torch.nn.functional.softmax(s, dim=-1) o = torch.matmul(p, v_pt) return s, p, o -def bwd_pt(dout, q, k, v, max_seqlen_q, bias=None): +def bwd_pt(dout, q, k, v, max_seqlen_q, bias=None, mask=None): # q is [batch * seq * seq, head, head_dim] - q_pt, k_pt, v_pt, dout_pt, bias_pt = prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=bias) + q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=bias, mask=mask) - s, p, o = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt) + s, p, o = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt, mask=mask_pt) if bias is None: dq, dk, dv = torch.autograd.grad(o, (q_pt, k_pt, v_pt), dout_pt) @@ -191,18 +212,22 @@ def compute_lse(s): return softmax_lse -def check_fwd_kernel(has_bias=False): +def check_fwd_kernel(has_bias=False, has_mask=False): print ("==== check fwd kernel with np ====") if has_bias: - s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) + s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref, mask=None) + elif has_mask: + s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=None, mask=mask_ref) else: - s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_) + s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=None, mask=None) # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) # attn_output = np.loadtxt("attn_output.data", delimiter=" ") if has_bias: prefix = "has_bias" print ("has bias on, prefix is ", prefix) + elif has_mask: + prefix = "has_mask" else: prefix = "" @@ -222,8 +247,6 @@ def check_fwd_kernel(has_bias=False): lse_ref = lse_ref.reshape(batch_size * max_seqlen_k_ , nheads, max_seqlen_q_) # print ("ref lse: ", lse_ref) - import pdb; pdb.set_trace() - print ("lse_ref shape = {}, attn_lse shape = {}".format(lse_ref.shape, attn_lse.shape)) print ("lse max error: ", np.abs(lse_ref - attn_lse).max()) @@ -247,16 +270,21 @@ def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): return True -def check_bwd_kernel(has_bias=False): +def check_bwd_kernel(has_bias=False, has_mask=False): print ("==== check bwd kernel with np ====") if has_bias: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref, mask=None) + elif has_mask: + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=None, mask=mask_ref) else: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_) + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, 
max_seqlen_k_, bias=None, mask=None) if has_bias: prefix = "has_bias" print ("has bias on, prefix is ", prefix) + elif has_mask: + prefix = "has_mask" + print ("has mask on, prefix is ", prefix) else: prefix = "" @@ -291,9 +319,13 @@ def check_bwd_kernel(has_bias=False): if has_bias: print ("max error in dq: ", np.abs(attn_dbias - dbias).max(), ) # print (np.abs(attn_dbias - dbias) > 0.001) - attn_ds = np.genfromtxt("{}_attn_ds.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_ds = attn_ds.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) - print ("max error in ds: ", np.abs(attn_ds - ds).max(), ) + # attn_ds = np.genfromtxt("{}_attn_ds.data".format(prefix), delimiter=" ", dtype=np.float32) + # attn_ds = attn_ds.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) + # print ("max error in ds: ", np.abs(attn_ds - ds).max(), ) + + attn_dbias = np.genfromtxt("{}_attn_dbias.data".format(prefix), delimiter=" ", dtype=np.float32) + attn_dbias = attn_dbias.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) + print ("max error in dbias: ", np.abs(attn_dbias - dbias).max(), ) print ("same matrix check q: ", is_same_matrix(attn_dq, dq)) @@ -303,14 +335,15 @@ def check_bwd_kernel(has_bias=False): import pdb; pdb.set_trace() print ("same matrix check dbias: ", is_same_matrix(attn_dbias, dbias)) + def check_bwd_np(has_bias=False): print ("==== check bwd np ====") if has_bias: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref) - dq_pt, dk_pt, dv_pt, dbias_pt = bwd_pt(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref) + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=mask_ref) + dq_pt, dk_pt, dv_pt, dbias_pt = bwd_pt(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=mask_ref) else: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None) - dq_pt, dk_pt, dv_pt, dbias_pt = bwd_pt(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None) + dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=None) + dq_pt, dk_pt, dv_pt, dbias_pt = bwd_pt(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=None) assert (dq.shape == dq_pt.detach().cpu().numpy().shape), "oh dq shape didn't match" assert (dk.shape == dk_pt.detach().cpu().numpy().shape), "oh dk shape didn't match" @@ -327,7 +360,7 @@ def check_bwd_np(has_bias=False): return -def prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=None): +def prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=None, mask=None): q_pt = torch.from_numpy(q.copy()) k_pt = torch.from_numpy(k.copy()) v_pt = torch.from_numpy(v.copy()) @@ -354,25 +387,30 @@ def prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=None): else: bias_pt = None + if mask is not None: + mask_pt = torch.from_numpy(mask.copy()).cuda() + else: + mask_pt = None + q_pt.requires_grad = True k_pt.requires_grad = True v_pt.requires_grad = True - return q_pt, k_pt, v_pt, dout_pt, bias_pt + return q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt -def check_fwd_np(has_bias=False): +def check_fwd_np(has_bias=False, has_atten=False): print ("==== check fwd np ====") if has_bias: - s, p, o, q, k, v = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref) + s, p, o, q, k, v = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=mask_ref) - q_pt, k_pt, 
v_pt, dout_pt, bias_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref) - s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias_pt) + q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=mask_ref) + s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias_pt, mask_pt) else: - s, p, o, q, k, v = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_) + s, p, o, q, k, v = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=None) - q_pt, k_pt, v_pt, dout_pt, _ = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_) - s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt) + q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_) + s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias=None, mask=None) def check_input(a, b): print ("max error in input: ", np.abs(a - b).max()) @@ -553,16 +591,18 @@ def check_dsoftmax_p(softmax_data, has_bias=False): # check_bwd_np(has_bias=has_bias) # print ("====test with bias====") - # print ("====test kernel without bias====") - # has_bias = False + print ("====test kernel without bias====") + has_bias = args.has_bias + has_mask = args.has_mask + + check_fwd_kernel(has_bias=has_bias, has_mask=has_mask) + check_bwd_kernel(has_bias=has_bias, has_mask=has_mask) + + # print ("====test kernel with bias====") + # has_bias = True # check_fwd_kernel(has_bias=has_bias) # check_bwd_kernel(has_bias=has_bias) - print ("====test kernel with bias====") - has_bias = True - check_fwd_kernel(has_bias=has_bias) - check_bwd_kernel(has_bias=has_bias) - # print ("====test bwd kernel softmax without bias====") # has_bias = False # softmax_data = parse_softmax_load("output.log") From 81c0743db3c49b819d5f93b55219885df024ea8f Mon Sep 17 00:00:00 2001 From: robotcator Date: Tue, 6 Sep 2022 10:39:27 +0800 Subject: [PATCH 33/71] add support --- benchmarks/test/test_forward_with_bias_v2.py | 30 +++++++++++++----- csrc/flash_attn/fmha_api.cpp | 2 +- csrc/flash_attn/src/fmha/gmem_tile.h | 32 ++------------------ flash_attn/flash_attn_interface.py | 12 +++++--- setup.py | 2 +- 5 files changed, 34 insertions(+), 44 deletions(-) diff --git a/benchmarks/test/test_forward_with_bias_v2.py b/benchmarks/test/test_forward_with_bias_v2.py index 4cac9f524..34129f3ff 100644 --- a/benchmarks/test/test_forward_with_bias_v2.py +++ b/benchmarks/test/test_forward_with_bias_v2.py @@ -132,7 +132,7 @@ def _flash_attn(q, k, v, attn_mask=None, attn_bias=None): k_cu_seqlens, q_max_s, k_max_s, - attn_mask=None, + attn_mask=attn_mask, attn_bias=attn_bias, dropout_p = 0., softmax_scale = 1., # q has been scaled already @@ -181,12 +181,13 @@ def gen_attn_mask(mask, neg_inf): print ("origin shape: ", orig_tensor.shape) # [bs, seq, seq, head, c_dim] -bias = torch.ones( +bias = torch.randn( 1, 1, head, seq_q, seq_k, dtype=dtype, device=device -) * 1 +) print ("bias shape: ", bias.shape) bias_broadcast = bias.expand([bs, seq, head, seq_q, seq_k]) +bias_broadcast.requires_grad = True print ("bias_broadcast shape: ", bias_broadcast.shape) # print ("bias_broadcast: ", bias_broadcast) @@ -252,9 +253,15 @@ def gen_attn_mask(mask, neg_inf): # test backward g = torch.randn_like(output3) -dq_ref, dk_ref, dv_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) -dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) -dq, dk, dv, = torch.autograd.grad(output3, 
(normal_attn_flash, normal_attn_flash, normal_attn_flash), g) + +# dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, ), g) +# dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, ), g) +# dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash, ), g) + +dq_ref, dk_ref, dv_ref, dbias_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, bias_broadcast), g) +dq_pt, dk_pt, dv_pt, dbias_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, bias_broadcast), g) +dq, dk, dv, dbias = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash, bias_broadcast), g) + print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) @@ -268,4 +275,13 @@ def gen_attn_mask(mask, neg_inf): print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) -print ("less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file +print ("dq less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) +print ("dk less than twice error: ", ((dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()) ) +print ("dv less than twice error: ", ((dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()) ) + +if dbias is not None: + print ("dbias less than twice error: ", ((dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item()) ) + +assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() +if dbias is not None: + assert (dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item() diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index cb8ff45fb..b444d00b8 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -569,7 +569,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size launch(params, stream); - std::vector result = { dq, dk, dv, softmax_d }; + std::vector result = { softmax_d }; at::Tensor dbias; if (attn_bias.has_value()) { // compare block reduce diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 46fa99032..b736a266a 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -95,14 +95,6 @@ struct Gmem_tile_qkv { threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_LDG, LDGS); printf("\n"); } - if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("use_seqlen_q=%d\n", use_seqlen_q); - printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", - threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); - printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_LDG=%d, LDGS=%d\n", - threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_LDG, LDGS); - printf("\n"); - } #endif // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + 
binfo.bidh) * BYTES_PER_ROW; row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); @@ -114,12 +106,6 @@ struct Gmem_tile_qkv { threadIdx.x, threadIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); printf("\n"); } - // if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("use_seqlen_q=%d\n", use_seqlen_q); - // printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, row_offset=%d BYTES_PER_LDG=%d\n", - // threadIdx.x, blockIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); - // printf("\n"); - // } #endif // Assemble the final pointer. ptr += row_offset + col * BYTES_PER_LDG; @@ -256,14 +242,6 @@ struct Gmem_tile_o { threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); printf("\n"); } - // if ((threadIdx.x == 2) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("print o parameter\n"); - // printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", - // threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); - // printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_STG=%d, LDGS=%d\n", - // threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); - // printf("\n"); - // } #endif // Is that thread active on the last STG? if( HAS_INCOMPLETE_STG ) { @@ -848,7 +826,7 @@ struct Gmem_tile_mma_ds { uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; // row_offset = (uint32_t)(row * row_stride_in_bytes); #ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x <= 4) && (blockIdx.y == 0)) { + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf("ds tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); printf("ds bidb=%d, bidh=%d, param.h=%d, blockIdx.x=%d\n", binfo.bidb, binfo.bidh, params.h, blockIdx.x); @@ -887,18 +865,12 @@ struct Gmem_tile_mma_ds { && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); #ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x <= 4) && (blockIdx.y == 0)) { + if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf("ds store blockIdx.x=%d, mi=%d, ni=%d, ii=%d, jj=%d, current_row=%d, current_col=%d, float1=%f, float2=%f, begin=%p, ptrs=%p, preds=%d\n", blockIdx.x, mi, ni, ii, jj, current_row, current_col, tmp00, tmp01, ptr_, ptrs, preds); } #endif if (preds) { -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x <= 4) && (blockIdx.y == 0)) { - printf("ds store blockIdx.x=%d in, mi=%d, ni=%d, ii=%d, jj=%d, ptrs=%p, dst=%ud, data=%ud\n", - blockIdx.x, mi, ni, ii, jj, ptrs, dst, fmha::float2_to_half2(tmp00, tmp01)); - } -#endif fmha::stg(ptrs, dst); } } diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 04488638d..a618742d4 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -29,12 +29,14 @@ def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_s def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, attn_mask, attn_bias, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal): - softmax_d = flash_attn_cuda.bwd( + softmax_d, *rest = flash_attn_cuda.bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, 
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, attn_mask, attn_bias) + import pdb; pdb.set_trace() # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() - return dq, dk, dv, softmax_d + dbias = None if attn_bias is None else rest[0] + return dq, dk, dv, softmax_d, dbias class FlashAttnQKVPackedFunc(torch.autograd.Function): @@ -140,14 +142,14 @@ def backward(ctx, dout, *args): cur_rng_state = torch.cuda.get_rng_state() torch.cuda.set_rng_state(rng_state) dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - import pdb; pdb.set_trace() - _flash_attn_backward( + # import pdb; pdb.set_trace() + dq, dk, dv, softmax_d, dbias = _flash_attn_backward( dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, attn_mask, attn_bias, ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal ) if rng_state is not None: torch.cuda.set_rng_state(cur_rng_state) - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, dbias, None, None, None, None # TODO: the last two is attn_mask, attn_bias, bias need gradient diff --git a/setup.py b/setup.py index 2467be80e..bc7e36798 100644 --- a/setup.py +++ b/setup.py @@ -130,7 +130,7 @@ def append_nvcc_threads(nvcc_extra_args): "nvcc": append_nvcc_threads( [ "-O3", - # "-g", + "-t4", "-DDEBUG_PRINT", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", From 89e74b9106dd7b785115b25b44f9705c8a784f90 Mon Sep 17 00:00:00 2001 From: robotcator Date: Tue, 6 Sep 2022 08:50:24 -0400 Subject: [PATCH 34/71] fix bf16 and mask shape --- csrc/flash_attn/fmha_api.cpp | 43 ++++++++++++------- csrc/flash_attn/src/fmha.h | 1 + csrc/flash_attn/src/fmha/gmem_tile.h | 27 +++++++++--- .../src/fmha_dgrad_kernel_1xN_loop.h | 2 +- tests/tools/check_output.py | 15 ++++--- 5 files changed, 59 insertions(+), 29 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index b444d00b8..440b1370a 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -63,7 +63,8 @@ void set_params_fprop(FMHA_fprop_params ¶ms, void *attn_mask, void *attn_bias, int bias_mod_size, - int mask_head_mod_size + int mask_head_mod_size, + int mask_seq_mod_size ) { Data_type acc_type = DATA_TYPE_FP32; @@ -111,6 +112,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.attn_bias_ptr = attn_bias; params.bias_mod_size = bias_mod_size; params.mask_head_mod_size = mask_head_mod_size; + params.mask_seq_mod_size = mask_seq_mod_size; #ifdef DEBUG_PRINT @@ -182,7 +184,8 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, void *attn_bias, void *attn_ds, int bias_mod_size, - int mask_head_mod_size) { + int mask_head_mod_size, + int mask_seq_mod_size) { set_params_fprop(params, b, seqlen_q, seqlen_k, h, d, @@ -199,7 +202,8 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, attn_mask, attn_bias, bias_mod_size, - mask_head_mod_size); + mask_head_mod_size, + mask_seq_mod_size); // Set the pointers and strides. 
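+    // Note: bias_mod_size / mask_head_mod_size / mask_seq_mod_size are only forwarded
+    // to set_params_fprop above; the remaining fields are dgrad-specific and point at
+    // the caller-allocated dq/dk/dv gradient buffers set below.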
params.dq_ptr = dq.data_ptr(); @@ -293,6 +297,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q } int mask_head_mod_size = 0; + int mask_seq_mod_size = 0; if (attn_mask.has_value()) { TORCH_CHECK(attn_mask.value().is_cuda()); TORCH_CHECK(attn_mask.value().dtype() == q_dtype); @@ -301,8 +306,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const auto mask_sizes = attn_mask->sizes(); // last two dimension mask_head_mod_size = mask_sizes[1]; + mask_seq_mod_size = mask_sizes[2]; TORCH_CHECK(mask_sizes[1] == 1 || mask_sizes[1] == num_heads); - std::cout << "mask head mode size: " << mask_head_mod_size << std::endl; + TORCH_CHECK(mask_sizes[2] == 1 || mask_sizes[2] == max_seqlen_q_); } int blocksize_c = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256; @@ -358,7 +364,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q attn_mask ? attn_mask->data_ptr() : nullptr, attn_bias ? attn_bias->data_ptr() : nullptr, bias_mod_size, - mask_head_mod_size + mask_head_mod_size, + mask_seq_mod_size ); run_fmha_fp16_sm80(launch_params, /*configure=*/ true); @@ -482,6 +489,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size } int mask_head_mod_size = 0; + int mask_seq_mod_size = 0; if (attn_mask.has_value()) { TORCH_CHECK(attn_mask.value().is_cuda()); TORCH_CHECK(attn_mask.value().dtype() == q_dtype); @@ -490,8 +498,9 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const auto mask_sizes = attn_mask->sizes(); // last two dimension mask_head_mod_size = mask_sizes[1]; + mask_seq_mod_size = mask_sizes[2]; TORCH_CHECK(mask_sizes[1] == 1 || mask_sizes[1] == num_heads); - std::cout << "mask head mode size: " << mask_head_mod_size << std::endl; + TORCH_CHECK(mask_sizes[2] == 1 || mask_sizes[2] == max_seqlen_q_); } auto opts = q.options(); @@ -551,7 +560,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size attn_bias ? attn_bias->data_ptr() : nullptr, attn_bias ? 
ds.data_ptr() : nullptr, bias_mod_size, - mask_head_mod_size); + mask_head_mod_size, + mask_seq_mod_size); // used for dbias auto gen = at::get_generator_or_default( @@ -683,10 +693,11 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t p_dropout, softmax_scale, is_causal, - nullptr, - nullptr, - 0, - 0); + nullptr, // attn_mask + nullptr, // attn_bias + 0, // bias_mod_size + 0, // mask_head_mod_size + 0); // mask_seq_mod_size // TODO: add mask / bias launch_params.params.blockmask = static_cast(blockmask.data_ptr()); @@ -835,10 +846,12 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size p_dropout, softmax_scale, is_causal, - nullptr, - nullptr, - nullptr, - 0, 0); + nullptr, // attn_mask + nullptr, // attn_bias + nullptr, // attn_ds + 0, // bias_mod_size + 0, // mask_head_mod_size + 0); // mask_seq_mod_size // TODO: add support bias / mask params.blockmask = static_cast(blockmask.data_ptr()); diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index 4aabd8ffe..5294c121a 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -76,6 +76,7 @@ struct FMHA_fprop_params : public Qkv_params { // The attn mask matrix void * __restrict__ attn_mask_ptr; int mask_head_mod_size; + int mask_seq_mod_size; // The attn bias matrix void * __restrict__ attn_bias_ptr; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index b736a266a..8e4b5ced6 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -482,6 +482,7 @@ struct Gmem_tile_mma_mask { , actual_seqlen_k(binfo.actual_seqlen_k) , tidx_(tidx) , loop_step_idx(loop_step_idx) + , mask_seq_mod_size(params.mask_seq_mod_size) { row_stride_in_bytes = binfo.actual_seqlen_k * BYTES_PER_ELEMENT; @@ -512,15 +513,19 @@ struct Gmem_tile_mma_mask { uint32_t bidx = binfo.bidb * params.h + (binfo.bidh % params.mask_head_mod_size); // the index of bs and head dim - uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; - // row_offset = (uint32_t)(row * row_stride_in_bytes); - row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); + // uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + // row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); + + // to support the mask last two dimension + uint32_t row_offset = bidx * params.mask_seq_mod_size * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; + row_offset += (uint32_t)( (row % params.mask_seq_mod_size) * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf("tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); - printf("bidb=%d, bidh=%d, param.h=%d\n", binfo.bidb, binfo.bidh, params.h); + printf("bidb=%d, bidh=%d, param.h=%d, mask_head_mod_size=%d, mask_seq_mod_size=%d\n", + binfo.bidb, binfo.bidh, params.h, params.mask_head_mod_size, params.mask_seq_mod_size); printf("\n"); } #endif @@ -550,7 +555,11 @@ struct Gmem_tile_mma_mask { // const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; // 8 is actually col of half data now, for more general case ? 
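+                // mask_head_mod_size is either 1 or h, and mask_seq_mod_size is either
+                // 1 or seqlen_q (enforced by the TORCH_CHECKs in fmha_api.cpp), so the
+                // "% mask_*_mod_size" indexing either collapses that index to 0 (the
+                // mask broadcasts over the head / query-sequence dimension) or leaves
+                // it unchanged when the mask carries the full dimension.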
// the row is already in the right position - ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + + // ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + + // (uint32_t)current_col * BYTES_PER_ELEMENT; + + // to support the mask last two dimension + ptrs[offset] = ptr_ + (uint32_t)(current_row % mask_seq_mod_size) * row_stride_in_bytes + (uint32_t)current_col * BYTES_PER_ELEMENT; preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) @@ -580,7 +589,8 @@ struct Gmem_tile_mma_mask { } inline __device__ void move(const int steps = 1) { - ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; + // to support the mask last two dimension + ptr_ += (uint32_t)(ROWS % mask_seq_mod_size) * row_stride_in_bytes * steps; this->actual_seqlen_q -= ROWS * steps; } @@ -592,6 +602,7 @@ struct Gmem_tile_mma_mask { char *ptr_; int actual_seqlen_q; int actual_seqlen_k; + int mask_seq_mod_size; const int tidx_; }; @@ -839,6 +850,7 @@ struct Gmem_tile_mma_ds { } // Store to global memory. + template inline __device__ void store(const float (&softmax)[2 * M][4 * N]) { uint32_t preds; uint32_t dst; @@ -853,7 +865,8 @@ struct Gmem_tile_mma_ds { for (int jj = 0; jj < 2; ++jj ) { float tmp00 = softmax[2 * mi + ii][4 * ni + jj * 2]; float tmp01 = softmax[2 * mi + ii][4 * ni + jj * 2 + 1]; - dst = fmha::float2_to_half2(tmp00, tmp01); + // dst = fmha::float2_to_half2(tmp00, tmp01); + dst = fmha::float2_pack(tmp00, tmp01); const int current_row = mi * ROWS + ii * 8; const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 485e40d90..888eebf78 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -617,7 +617,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng printf("\n"); } #endif - gmem_ds.store(softmax.elt_); + gmem_ds.template store(softmax.elt_); gmem_ds.move(); } diff --git a/tests/tools/check_output.py b/tests/tools/check_output.py index 194a63c96..ee7d9e4dd 100644 --- a/tests/tools/check_output.py +++ b/tests/tools/check_output.py @@ -99,7 +99,8 @@ def fwd(q, k, v, max_seqlen_q, bias=None, mask=None): if mask is not None: # s.masked_fill_(mask < 0, float('-inf')) - mask_np = np.ma.masked_where(mask < 0, s) + mask_broad = np.broadcast_to(mask, s.shape) + mask_np = np.ma.masked_where(mask_broad < 0, s) # np.ma.set_fill_value(mask_np, float('-inf')) np.ma.set_fill_value(mask_np, float('-999')) s = mask_np.filled() @@ -177,6 +178,8 @@ def fwd_pt(q_pt, k_pt, v_pt, bias=None, mask=None): s.masked_fill_(mask < 0, float('-999')) p = torch.nn.functional.softmax(s, dim=-1) + # from unicore.modules import softmax_dropout + # p = softmax_dropout(s, dropout_prob=0, is_training=True, mask=mask, bias=bias) o = torch.matmul(p, v_pt) return s, p, o @@ -585,11 +588,11 @@ def check_dsoftmax_p(softmax_data, has_bias=False): # check_bwd_np(has_bias=has_bias) # print ("====test without bias====") - # print ("====test with bias====") - # has_bias = True - # check_fwd_np(has_bias=has_bias) - # check_bwd_np(has_bias=has_bias) - # print ("====test with bias====") + print ("====test with bias====") + has_bias = True + check_fwd_np(has_bias=has_bias) + check_bwd_np(has_bias=has_bias) + print ("====test with bias====") print ("====test kernel without bias====") has_bias = args.has_bias From 05505f95f23b0ef60a4954dd828b89bb92a6c126 Mon Sep 17 00:00:00 
2001 From: robotcator Date: Wed, 7 Sep 2022 14:49:27 +0800 Subject: [PATCH 35/71] fix mask head=1 shape --- benchmarks/test/test_forward_with_bias_v2.py | 2 +- csrc/flash_attn/src/fmha/gmem_tile.h | 2 +- .../src/fmha_dgrad_fp16_kernel_loop.sm80.cu | 2 + .../src/fmha_fprop_fp16_kernel.sm80.cu | 105 ++++++++---------- flash_attn/flash_attn_interface.py | 1 - 5 files changed, 51 insertions(+), 61 deletions(-) diff --git a/benchmarks/test/test_forward_with_bias_v2.py b/benchmarks/test/test_forward_with_bias_v2.py index 34129f3ff..a200df5c8 100644 --- a/benchmarks/test/test_forward_with_bias_v2.py +++ b/benchmarks/test/test_forward_with_bias_v2.py @@ -41,7 +41,7 @@ def _attention(query, key, value, mask=None, biases=None, upcast=False) -> torch value = value.float() if mask is not None: mask = mask.float() - if bias is not None: + if biases is not None: biases = biases.float() # [*, H, C_hidden, K] diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 8e4b5ced6..379dfc3bf 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -510,7 +510,7 @@ struct Gmem_tile_mma_mask { // TODO: mask is [bs * seq, head, seq_q, seq_k] // The block index. // uint32_t bidx = binfo.bidb * params.h + binfo.bidh; - uint32_t bidx = binfo.bidb * params.h + (binfo.bidh % params.mask_head_mod_size); + uint32_t bidx = binfo.bidb * params.mask_head_mod_size + (binfo.bidh % params.mask_head_mod_size); // the index of bs and head dim // uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu index b28a5422e..7d1be82bc 100644 --- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu @@ -47,8 +47,10 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_ } dim3 grid(params.b, params.h); kernel<<>>(params); +#ifdef DEBUG_PRINT printf("bwd grid size: %d %d\n", params.b, params.h); printf("bwd block size: %d\n", Kernel_traits::THREADS); +#endif FMHA_CHECK_CUDA(cudaPeekAtLastError()); }); } diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index 98ad93ff4..20d8c9c3e 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -63,21 +63,10 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, bool has_bias = !(launch_params.params.attn_bias_ptr == nullptr); #ifdef DEBUG_PRINT - printf ("has_attn=%d, has_bias=%d, bias_mod_size=%d\n", has_attn, has_bias, launch_params.params.bias_mod_size); + printf ("has_attn=%d, has_bias=%d, bias_mod_size=%d, mask_seq_mod_size=%d, mask_head_mod_size=%d\n", + has_attn, has_bias, launch_params.params.bias_mod_size, launch_params.params.mask_seq_mod_size, launch_params.params.mask_head_mod_size); #endif - // attn + bias on - // IsDropoutConst off - // auto kernel = &fmha_fprop_fp16_sm80_loop_kernel; - // dim3 grid(launch_params.params.b, launch_params.params.h); - - // printf("grid size: %d %d\n", launch_params.params.b, launch_params.params.h); - // printf("block size: %d\n", Kernel_traits::THREADS); - // kernel<<>>( - // launch_params.params); - // FMHA_CHECK_CUDA(cudaPeekAtLastError()); - - if (has_attn) { if (has_bias) { @@ -204,51 +193,51 @@ void run_fmha_fp16_sm80(Launch_params &launch_params, 
run_fmha_fp16_sm80_loop_(launch_params, configure); } } - else if (launch_params.params.d == 32) { - if( launch_params.params.seqlen_k == 128 ) { - using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else if( launch_params.params.seqlen_k == 256 ) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } - } else if (launch_params.params.d == 64) { - if( launch_params.params.seqlen_k == 128 ) { - using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else if( launch_params.params.seqlen_k >= 256 ) { - if (dprops->major == 8 && dprops->minor >= 0) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else if (dprops->major == 7 && dprops->minor == 5) { - if (launch_params.is_dropout) { // Need to use the same block size as backward - using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } - } - } - } else if (launch_params.params.d == 128) { - if( launch_params.params.seqlen_k == 128 ) { - using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else { - if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) { - // TD [2022-06-05] Keep K in registers to reduce register spilling - // Gives about 6% speedup compared to using block size 128. 
- using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } else { // Need to use the same block size as backward - using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; - run_fmha_fp16_sm80_loop_(launch_params, configure); - } - } - } + // else if (launch_params.params.d == 32) { + // if( launch_params.params.seqlen_k == 128 ) { + // using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else if( launch_params.params.seqlen_k == 256 ) { + // using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { + // using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } + // } else if (launch_params.params.d == 64) { + // if( launch_params.params.seqlen_k == 128 ) { + // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else if( launch_params.params.seqlen_k >= 256 ) { + // if (dprops->major == 8 && dprops->minor >= 0) { + // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else if (dprops->major == 7 && dprops->minor == 5) { + // if (launch_params.is_dropout) { // Need to use the same block size as backward + // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { + // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } + // } + // } + // } else if (launch_params.params.d == 128) { + // if( launch_params.params.seqlen_k == 128 ) { + // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { + // if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) { + // // TD [2022-06-05] Keep K in registers to reduce register spilling + // // Gives about 6% speedup compared to using block size 128. 
+ // using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } else { // Need to use the same block size as backward + // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + // run_fmha_fp16_sm80_loop_(launch_params, configure); + // } + // } + // } // if (launch_params.params.d == 64) { // // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; // // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>; diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index a618742d4..37b1869db 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -32,7 +32,6 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens softmax_d, *rest = flash_attn_cuda.bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, attn_mask, attn_bias) - import pdb; pdb.set_trace() # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() dbias = None if attn_bias is None else rest[0] From e435fd17dad87a7990b3baef1db267e71c14b0d7 Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 8 Sep 2022 17:20:57 +0800 Subject: [PATCH 36/71] add dump --- csrc/flash_attn/fmha_api.cpp | 35 ++++++++++++++++ csrc/flash_attn/src/fmha/gmem_tile.h | 44 ++++++++++----------- csrc/flash_attn/src/fmha/softmax.h | 11 +++++- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 2 +- 4 files changed, 67 insertions(+), 25 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 440b1370a..b3628254f 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -222,6 +222,25 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, params.attn_ds_ptr = attn_ds; } +void dump_tensor(const std::string &tensor_name, const at::Tensor &tensor, const std::string &label) { + std::string file_name = label + "_" + tensor_name + ".data"; + std::ofstream file(file_name.c_str()); + // file << tensor_name << std::endl; + // file << tensor << std::endl; + std::cout << "tensor_name stride 0: " << tensor_name << " " << tensor.stride(0) << std::endl; + std::cout << "tensor_name stride 1: " << tensor_name << " " << tensor.stride(1) << std::endl; + + std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; + auto flatten_tensor = tensor.flatten(); + auto size = flatten_tensor.numel(); + + for (int i = 0; i < size; i ++) { + file << flatten_tensor[i].item() << " "; + // file << flatten_tensor[i] << " "; + } + file << std::endl; +} + std::vector mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i @@ -311,6 +330,18 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q TORCH_CHECK(mask_sizes[2] == 1 || mask_sizes[2] == max_seqlen_q_); } +#ifdef DEBUG_PRINT + dump_tensor("input_q", q, ""); + // dump_tensor("input_k", k, ""); + // dump_tensor("input_v", v, ""); + if (attn_mask.has_value()) { + dump_tensor("input_mask", *attn_mask, ""); + } + if (attn_bias.has_value()) { + dump_tensor("input_bias", *attn_bias, ""); + } +#endif + int blocksize_c = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 
128 : 256; // Need to round max_seqlen_k to multiples of blocksize_c int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; @@ -382,6 +413,10 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q run_fmha_fp16_sm80(launch_params, /*configure=*/false); +#ifdef DEBUG_PRINT + dump_tensor("output_o", o, ""); +#endif + std::vector result = {o, softmax_lse}; if (return_softmax) {result.push_back(s);} return result; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 379dfc3bf..a24f30e4c 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -99,14 +99,14 @@ struct Gmem_tile_qkv { // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("use_seqlen_q=%d\n", use_seqlen_q); - printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d BYTES_PER_LDG=%d\n", - threadIdx.x, threadIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); - printf("\n"); - } -#endif +// #ifdef DEBUG_PRINT +// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { +// printf("use_seqlen_q=%d\n", use_seqlen_q); +// printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d BYTES_PER_LDG=%d\n", +// threadIdx.x, threadIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); +// printf("\n"); +// } +// #endif // Assemble the final pointer. ptr += row_offset + col * BYTES_PER_LDG; } @@ -233,16 +233,16 @@ struct Gmem_tile_o { row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); // Assemble the final pointer. ptr_ += row_offset + col * BYTES_PER_STG; -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("print o parameter\n"); - printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", - threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); - printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_STG=%d, LDGS=%d\n", - threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); - printf("\n"); - } -#endif +// #ifdef DEBUG_PRINT +// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { +// printf("print o parameter\n"); +// printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", +// threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); +// printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_STG=%d, LDGS=%d\n", +// threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); +// printf("\n"); +// } +// #endif // Is that thread active on the last STG? 
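
The dump_tensor helper added above writes each tensor as a single line of space-separated float values, which is what the numpy-based checker later reads back. A minimal sketch of that round trip for the "_output_o.data" dump, assuming the small test shapes used elsewhere in this series (batch 1, 1 head, head dim 16, seqlen 128); the variable names are illustrative only:

import numpy as np

# Assumed shapes; adjust to whatever was actually dumped.
batch_size, nheads, headdim, seqlen = 1, 1, 16, 128

# dump_tensor(name, t, label) writes "<label>_<name>.data"; label "" gives "_output_o.data".
o = np.genfromtxt("_output_o.data", delimiter=" ", dtype=np.float32)
o = o.reshape(batch_size * seqlen, nheads, headdim)                       # total_q x heads x dim
o = o.reshape(batch_size, seqlen, nheads, headdim).transpose(0, 2, 1, 3)  # b x heads x seq x dim
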
if( HAS_INCOMPLETE_STG ) { is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; @@ -524,8 +524,8 @@ struct Gmem_tile_mma_mask { if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf("tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); - printf("bidb=%d, bidh=%d, param.h=%d, mask_head_mod_size=%d, mask_seq_mod_size=%d\n", - binfo.bidb, binfo.bidh, params.h, params.mask_head_mod_size, params.mask_seq_mod_size); + printf("bidb=%d, bidh=%d, param.h=%d, mask_head_mod_size=%d, mask_seq_mod_size=%d, loop_step_idx=%d\n", + binfo.bidb, binfo.bidh, params.h, params.mask_head_mod_size, params.mask_seq_mod_size, loop_step_idx); printf("\n"); } #endif @@ -568,8 +568,8 @@ struct Gmem_tile_mma_mask { if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { printf("mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d, start_ptr=%p, ptrs[offset]=%p, preds[offset]=%d\n", mi, ni, ii, jj, offset, current_row, current_col, ptr_, ptrs[offset], preds[offset]); - printf("current_row=%d, current_col=%d, ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d\n", - current_row, current_col, ROWS, actual_seqlen_q, COLS, actual_seqlen_k); + printf("current_row=%d, current_col=%d, ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d, loop_step_idx=%d\n", + current_row, current_col, ROWS, actual_seqlen_q, COLS, actual_seqlen_k, loop_step_idx); printf("cond 1=%d\n", (current_row <= min(ROWS, actual_seqlen_q))); printf("cond 2=%d\n", ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k))); printf("\n"); diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index aaa38786b..86962c136 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -492,7 +492,7 @@ struct Softmax : public Softmax_base { } template - inline __device__ void apply_attn_mask(const Fragment (&mask)[MMAS_M][MMAS_N]) { + inline __device__ void apply_attn_mask(const Fragment (&mask)[MMAS_M][MMAS_N], int l = 0, int loop_step_idx = 0) { #pragma unroll for( int mi = 0; mi < MMAS_M; ++mi ) { #pragma unroll @@ -501,7 +501,14 @@ struct Softmax : public Softmax_base { for( int ni = 0; ni < MMAS_N; ++ni ) { #pragma unroll for( int jj = 0; jj < 4; ++jj ) { - if( abs(toFloat(mask[mi][ni].elt(ii * 4 + jj))) > 0 ) { + float value = toFloat(mask[mi][ni].elt(ii * 4 + jj)); +#ifdef DEBUG_PRINT + if ((blockIdx.x == 0) && (blockIdx.y == 0)) { + printf("Attnmask: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, l = %d, loop_step_idx=%d, blockIdx.x = %d\n", + threadIdx.x, threadIdx.y, mi, ni, ii, jj, float(value), l, loop_step_idx, blockIdx.x); + } +#endif + if( abs(value) > 0 ) { this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; } // this->elt_[2 * mi + ii][4 * ni + jj] += float(mask[mi][ni].elt(ii * 4 + jj)); diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 934f575db..5aa6546a2 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -594,7 +594,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i } #endif // Apply the attn mask. 
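        // Note: the extra l and loop_step_idx arguments passed below are only consumed by the
        // DEBUG_PRINT diagnostics inside apply_attn_mask; the masking math itself is unchanged.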
- softmax.apply_attn_mask(frag_mask); + softmax.apply_attn_mask(frag_mask, l, loop_step_idx); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { From 30c29a698aedbf12b066f4fc9a1412b26799b6d2 Mon Sep 17 00:00:00 2001 From: robotcator Date: Fri, 9 Sep 2022 11:02:32 +0800 Subject: [PATCH 37/71] to fix len 512 --- csrc/flash_attn/fmha_api.cpp | 4 + csrc/flash_attn/src/fmha/softmax.h | 14 +- tests/tools/check_output.py | 227 +++++++++++++++++++++++------ 3 files changed, 195 insertions(+), 50 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index b3628254f..672a24027 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -239,6 +239,10 @@ void dump_tensor(const std::string &tensor_name, const at::Tensor &tensor, const // file << flatten_tensor[i] << " "; } file << std::endl; + + std::string sfile_name = label + "_" + tensor_name + ".pt"; + std::ofstream sfile(sfile_name.c_str()); + torch::save(tensor, sfile); } std::vector diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index 86962c136..9992d0509 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -502,15 +502,15 @@ struct Softmax : public Softmax_base { #pragma unroll for( int jj = 0; jj < 4; ++jj ) { float value = toFloat(mask[mi][ni].elt(ii * 4 + jj)); + if( abs(value) > 0 ) { + this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; + } #ifdef DEBUG_PRINT if ((blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("Attnmask: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, l = %d, loop_step_idx=%d, blockIdx.x = %d\n", - threadIdx.x, threadIdx.y, mi, ni, ii, jj, float(value), l, loop_step_idx, blockIdx.x); + printf("Attnmask: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, softmax = %f, l = %d, loop_step_idx=%d, blockIdx.x = %d\n", + threadIdx.x, threadIdx.y, mi, ni, ii, jj, float(value), this->elt_[2 * mi + ii][4 * ni + jj], l, loop_step_idx, blockIdx.x); } #endif - if( abs(value) > 0 ) { - this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 
0.f : -INFINITY; - } // this->elt_[2 * mi + ii][4 * ni + jj] += float(mask[mi][ni].elt(ii * 4 + jj)); // this->elt_[2 * mi + ii][4 * ni + jj] += toFloat(mask[mi][ni].elt(ii * 4 + jj)); } @@ -534,8 +534,8 @@ struct Softmax : public Softmax_base { this->elt_[2 * mi + ii][4 * ni + jj] += value; #ifdef DEBUG_PRINT if ((blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("AttnBias: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, ldx = %d, blockIdx.x = %d\n", - threadIdx.x, threadIdx.y, mi, ni, ii, jj, float(value), l, blockIdx.x); + printf("AttnBias: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, softmax = %f, ldx = %d, blockIdx.x = %d\n", + threadIdx.x, threadIdx.y, mi, ni, ii, jj, float(value), this->elt_[2 * mi + ii][4 * ni + jj], l, blockIdx.x); } #endif } diff --git a/tests/tools/check_output.py b/tests/tools/check_output.py index ee7d9e4dd..99770dac6 100644 --- a/tests/tools/check_output.py +++ b/tests/tools/check_output.py @@ -9,6 +9,8 @@ parser.add_argument("--test_np", required=False, help="test np implementation kernel with torch", type=bool, default=False) parser.add_argument("--has_bias", required=False, help="add bias in attention", type=bool, default=False) parser.add_argument("--has_mask", required=False, help="add mask in attention", type=bool, default=False) +parser.add_argument("--seqlen", required=False, help="seqlen", type=int, default=128) + args = parser.parse_args() print(args) @@ -16,51 +18,73 @@ batch_size = 1 nheads = 1 headdim = 16 -seq = 8 +if args.seqlen is not None: + seq = args.seqlen +else: + seq = 8 + +print ("processing seqlen: {0}".format(seq)) + +bs_seq = 1 max_seqlen_q_ = seq max_seqlen_k_ = seq dtypes = np.float16 -q_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=dtypes) -k_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=dtypes) -v_cpu = np.zeros((batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim), dtype=dtypes) +q_cpu = np.zeros((batch_size * bs_seq * max_seqlen_k_, nheads, headdim), dtype=dtypes) +k_cpu = np.zeros((batch_size * bs_seq * max_seqlen_k_, nheads, headdim), dtype=dtypes) +v_cpu = np.zeros((batch_size * bs_seq * max_seqlen_k_, nheads, headdim), dtype=dtypes) cnt = 0 -for i in range(batch_size * max_seqlen_k_ * max_seqlen_k_): +for i in range(batch_size * bs_seq * max_seqlen_k_): for j in range(nheads): for k in range(headdim): - q_cpu[i][j][k] = cnt * 0.001 - k_cpu[i][j][k] = cnt * 0.001 - v_cpu[i][j][k] = cnt * 0.001 + q_cpu[i][j][k] = cnt % 10000 * 0.001 + k_cpu[i][j][k] = cnt % 10000 * 0.001 + v_cpu[i][j][k] = cnt % 10000 * 0.001 cnt += 1 -bias_ref = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) -cnt = 0 -for i in range(batch_size * max_seqlen_k_): - for j in range(nheads): - for k in range(max_seqlen_q_): - for l in range(max_seqlen_k_): - bias_ref[i][j][k][l] = cnt * 0.1 - cnt += 1 +# cost too much time when seq is large +# bias_ref = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) +# cnt = 0 +# for i in range(batch_size * max_seqlen_k_): +# for j in range(nheads): +# for k in range(max_seqlen_q_): +# for l in range(max_seqlen_k_): +# bias_ref[i][j][k][l] = cnt * 0.1 +# cnt += 1 -mask_ref = np.ones([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) -mask_ref = (1 - np.tril(mask_ref)) * -1 +# mask_ref = np.ones([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, 
max_seqlen_k_], dtype=dtypes) +# mask_ref = (1 - np.tril(mask_ref)) * -1 -# bias_ref = np.zeros([batch_size , nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) +mask_ref = np.ones([batch_size * bs_seq, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) * -1 # cnt = 0 -# for i in range(batch_size ): +# for i in range(batch_size * max_seqlen_k_): # for j in range(nheads): # for k in range(max_seqlen_q_): # for l in range(max_seqlen_k_): -# bias_ref[i][j][k][l] = cnt * 0.1 +# if l % 2 == 0: +# mask_ref[i][j][k][l] = 0 # cnt += 1 +for i in range(batch_size * bs_seq): + for j in range(1): + for k in range(1): + for l in range(max_seqlen_k_): + if l % 2 == 0: + mask_ref[i][j][k][l] = 0 + + +for i in range(batch_size * bs_seq): + for j in range(nheads): + for k in range(max_seqlen_q_): + mask_ref[i][j][k] = mask_ref[i][0][0] + # dout = np.random.rand(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim).astype(dtype=dtypes) cnt = 0 -dout = np.ones([batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim], dtype=dtypes) -for i in range(batch_size * max_seqlen_k_ * max_seqlen_k_): +dout = np.ones([batch_size * bs_seq * max_seqlen_k_, nheads, headdim], dtype=dtypes) +for i in range(batch_size * bs_seq * max_seqlen_k_): for j in range(nheads): for k in range(headdim): dout[i][j][k] = cnt * 0.001 @@ -102,7 +126,7 @@ def fwd(q, k, v, max_seqlen_q, bias=None, mask=None): mask_broad = np.broadcast_to(mask, s.shape) mask_np = np.ma.masked_where(mask_broad < 0, s) # np.ma.set_fill_value(mask_np, float('-inf')) - np.ma.set_fill_value(mask_np, float('-999')) + np.ma.set_fill_value(mask_np, float('-inf')) s = mask_np.filled() p = softmax(s) @@ -235,24 +259,134 @@ def check_fwd_kernel(has_bias=False, has_mask=False): prefix = "" attn_output = np.genfromtxt("{}_attn_output.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_output = attn_output.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) - attn_output = attn_output.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) + attn_output = attn_output.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) + attn_output = attn_output.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) attn_output = attn_output.transpose(0, 2, 1, 3) + # batch_size * bs_seq, nheads, max_seqlen_k_, headdim + print ("attn output shape: ", attn_output.shape) print ("output max error: ", np.abs(o - attn_output).max()) attn_lse = np.genfromtxt("{}_attn_lse.data".format(prefix), delimiter=" ", dtype=np.float32) max_seqlen_q_pad = ((max_seqlen_q_ + 16 - 1) // 16) * 16 - attn_lse = attn_lse.reshape(batch_size * max_seqlen_k_ , nheads, max_seqlen_q_pad) + attn_lse = attn_lse.reshape(batch_size * bs_seq , nheads, max_seqlen_q_pad) # print ("attn lse: ", attn_lse) attn_lse = attn_lse[:,:,:max_seqlen_q_] lse_ref = compute_lse(s) - lse_ref = lse_ref.reshape(batch_size * max_seqlen_k_ , nheads, max_seqlen_q_) + lse_ref = lse_ref.reshape(batch_size * bs_seq , nheads, max_seqlen_q_) # print ("ref lse: ", lse_ref) print ("lse_ref shape = {}, attn_lse shape = {}".format(lse_ref.shape, attn_lse.shape)) print ("lse max error: ", np.abs(lse_ref - attn_lse).max()) + print ("is same matrix: ", is_same_matrix(lse_ref, attn_lse)) + print ("is same matrix: ", is_same_matrix(o, attn_output)) + + # with python interface input + python_inputs = np.genfromtxt("../inputs_flash_seq{}.data".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) + python_inputs = python_inputs.reshape(batch_size, bs_seq, nheads, max_seqlen_q_, headdim) + 
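
For reference, the mask convention used throughout these checks is that 0 keeps a position and a negative value (here -1) drops it; the numpy oracle fills dropped logits with -inf before the softmax, matching the -INFINITY path in the kernel. A condensed sketch of that reference (function name chosen here for illustration):

import numpy as np

def masked_softmax_ref(scores, mask):
    # mask: 0 keeps a position, a negative value (e.g. -1) masks it out.
    scores = np.where(np.broadcast_to(mask, scores.shape) < 0, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    return probs / probs.sum(axis=-1, keepdims=True)
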
python_inputs = python_inputs.transpose(0, 1, 3, 2, 4) + python_inputs = python_inputs.reshape(batch_size * bs_seq * max_seqlen_q_, nheads, headdim) + print ("is same matrix input: ", is_same_matrix(python_inputs, q_cpu)) + + python_attn_mask = np.genfromtxt("../attn_mask_flash_seq{}.data".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) + python_attn_mask = python_attn_mask.reshape(batch_size, bs_seq, nheads, max_seqlen_q_, max_seqlen_k_) + python_attn_mask = python_attn_mask.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, max_seqlen_k_) + print ("is same matrix mask: ", is_same_matrix(python_inputs, q_cpu)) + + # flash tmp output + # out = out.reshape(*batch_dims, n, no_heads, c) + + # python_output_tmp0 = np.genfromtxt("../tmp2.data", delimiter=" ", dtype=np.float32) + # python_output_tmp0 = python_output_tmp0.reshape(batch_size, bs_seq, max_seqlen_q_, nheads, headdim) + # python_output_tmp0 = python_output_tmp0.transpose(0, 1, 3, 2, 4) + # python_output_tmp0 = python_output_tmp0.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, headdim) + + # print (python_output_tmp0.shape) + # print ("is same matrix flash output tmp1: ", is_same_matrix(o, python_output_tmp0, verbose=True)) + # print ("is same matrix flash output tmp1: ", is_same_matrix(attn_output, python_output_tmp0)) + + python_output_tmp1 = np.genfromtxt("../flash_temp1.output".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) + python_output_tmp1 = python_output_tmp1.reshape(batch_size, bs_seq, max_seqlen_q_, nheads, headdim) + python_output_tmp1 = python_output_tmp1.transpose(0, 1, 3, 2, 4) + python_output_tmp1 = python_output_tmp1.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, headdim) + + print (python_output_tmp1.shape) + print ("is same matrix flash output tmp1: ", is_same_matrix(o, python_output_tmp1, verbose=True)) + print ("is same matrix flash output tmp1: ", is_same_matrix(attn_output, python_output_tmp1, verbose=True)) + + # flash output + # [batch_size, bs_seq, seq_k, head, c_dim] + # 1, 1, 512, 1, 16 + python_output = np.genfromtxt("../output_flash_seq{}.data".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) + python_output = python_output.reshape(batch_size, bs_seq, max_seqlen_q_, nheads, headdim) + python_output = python_output.transpose(0, 1, 3, 2, 4) + python_output = python_output.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, headdim) + + print (python_output.shape) + print ("is same matrix flash output: ", is_same_matrix(o, python_output)) + print ("is same matrix flash output: ", is_same_matrix(attn_output, python_output)) + + # torch output + python_torch_output = np.genfromtxt("../output_torch_seq{}.data".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) + python_torch_output = python_torch_output.reshape(batch_size, bs_seq, nheads, max_seqlen_q_, headdim) + python_torch_output = python_torch_output.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, headdim) + + print (python_torch_output.shape) + print ("is same matrix torch output: ", is_same_matrix(o, python_torch_output)) + print ("is same matrix torch output: ", is_same_matrix(attn_output, python_torch_output)) + + + +def check_fwd_kernel_pt(has_bias=False, has_mask=False): + print ("==== check fwd kernel with np ====") + if has_bias: + q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=None) + s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt, mask=mask_pt) + elif has_mask: + q_pt, k_pt, v_pt, dout_pt, bias_pt, 
mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=mask_ref) + s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt, mask=mask_pt) + else: + q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=None) + s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt, mask=mask_pt) + + o = o_pt.detach().cpu().numpy() + s = s_pt.detach().cpu().numpy() + + # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) + + # attn_output = np.loadtxt("attn_output.data", delimiter=" ") + if has_bias: + prefix = "has_bias" + print ("has bias on, prefix is ", prefix) + elif has_mask: + prefix = "has_mask" + else: + prefix = "" + + attn_output = np.genfromtxt("{}_attn_output.data".format(prefix), delimiter=" ", dtype=np.float32) + attn_output = attn_output.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) + attn_output = attn_output.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) + attn_output = attn_output.transpose(0, 2, 1, 3) + print ("output max error: ", np.abs(o - attn_output).max()) + + attn_lse = np.genfromtxt("{}_attn_lse.data".format(prefix), delimiter=" ", dtype=np.float32) + max_seqlen_q_pad = ((max_seqlen_q_ + 16 - 1) // 16) * 16 + attn_lse = attn_lse.reshape(batch_size * bs_seq , nheads, max_seqlen_q_pad) + # print ("attn lse: ", attn_lse) + attn_lse = attn_lse[:,:,:max_seqlen_q_] + + lse_ref = compute_lse(s) + lse_ref = lse_ref.reshape(batch_size * bs_seq , nheads, max_seqlen_q_) + # print ("ref lse: ", lse_ref) + + print ("lse_ref shape = {}, attn_lse shape = {}".format(lse_ref.shape, attn_lse.shape)) + print ("lse max error: ", np.abs(lse_ref - attn_lse).max()) + + print ("is same matrix (lse): ", is_same_matrix(lse_ref, attn_lse)) + print ("is same matrix (attn_output): ", is_same_matrix(o, attn_output)) + + def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): diff = np.abs(pred - gt) @@ -297,16 +431,16 @@ def check_bwd_kernel(has_bias=False, has_mask=False): if has_bias: attn_dbias = np.genfromtxt("{}_attn_dbias.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_dq = attn_dq.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) - attn_dk = attn_dk.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) - attn_dv = attn_dv.reshape(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim) + attn_dq = attn_dq.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) + attn_dk = attn_dk.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) + attn_dv = attn_dv.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) - attn_dq = attn_dq.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) - attn_dk = attn_dk.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) - attn_dv = attn_dv.reshape(batch_size * max_seqlen_k_, max_seqlen_k_, nheads, headdim) + attn_dq = attn_dq.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) + attn_dk = attn_dk.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) + attn_dv = attn_dv.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) if has_bias: - attn_dbias = attn_dbias.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) + attn_dbias = attn_dbias.reshape(batch_size * bs_seq, nheads, max_seqlen_k_, max_seqlen_k_) attn_dq = attn_dq.transpose(0, 2, 1, 3) attn_dk = attn_dk.transpose(0, 2, 1, 3) @@ -327,7 +461,7 @@ def check_bwd_kernel(has_bias=False, 
has_mask=False): # print ("max error in ds: ", np.abs(attn_ds - ds).max(), ) attn_dbias = np.genfromtxt("{}_attn_dbias.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_dbias = attn_dbias.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) + attn_dbias = attn_dbias.reshape(batch_size * bs_seq, nheads, max_seqlen_k_, max_seqlen_k_) print ("max error in dbias: ", np.abs(attn_dbias - dbias).max(), ) @@ -371,6 +505,7 @@ def prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=None, mask=None): batch_size = int(q.shape[0] / max_seqlen_q) head_num = q.shape[1] head_dim = q.shape[2] + import pdb; pdb.set_trace() dout_pt = torch.from_numpy(dout.copy()) dout_pt = dout_pt.reshape(batch_size, max_seqlen_q, head_num, head_dim) @@ -588,18 +723,24 @@ def check_dsoftmax_p(softmax_data, has_bias=False): # check_bwd_np(has_bias=has_bias) # print ("====test without bias====") - print ("====test with bias====") - has_bias = True - check_fwd_np(has_bias=has_bias) - check_bwd_np(has_bias=has_bias) - print ("====test with bias====") + # print ("====test with bias====") + # has_bias = True + # check_fwd_np(has_bias=has_bias) + # check_bwd_np(has_bias=has_bias) + # print ("====test with bias====") + + # print ("====test kernel using torch====") + # has_bias = args.has_bias + # has_mask = args.has_mask + + # check_fwd_kernel_pt(has_bias=has_bias, has_mask=has_mask) - print ("====test kernel without bias====") + print ("====test kernel using numpy====") has_bias = args.has_bias has_mask = args.has_mask check_fwd_kernel(has_bias=has_bias, has_mask=has_mask) - check_bwd_kernel(has_bias=has_bias, has_mask=has_mask) + # check_bwd_kernel(has_bias=has_bias, has_mask=has_mask) # print ("====test kernel with bias====") # has_bias = True From f06016e869630ea3b221dd4973b2fb2d1b432440 Mon Sep 17 00:00:00 2001 From: robotcator Date: Fri, 9 Sep 2022 11:04:27 +0800 Subject: [PATCH 38/71] add test --- tests/test_forward_shape.cu | 249 ++++++++++++++++++++++++++++++++++++ tests/test_torch_capi.cpp | 60 +++++++++ 2 files changed, 309 insertions(+) create mode 100644 tests/test_forward_shape.cu create mode 100644 tests/test_torch_capi.cpp diff --git a/tests/test_forward_shape.cu b/tests/test_forward_shape.cu new file mode 100644 index 000000000..a26d7955d --- /dev/null +++ b/tests/test_forward_shape.cu @@ -0,0 +1,249 @@ +#include +#include +//#include +#include +#include +#include +#include +#include +#include + + +void dump_tensor(const std::string &tensor_name, at::Tensor &tensor, const std::string &label) { + std::string file_name = label + "_" + tensor_name + ".data"; + std::ofstream file(file_name.c_str()); + // file << tensor_name << std::endl; + // file << tensor << std::endl; + std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; + auto flatten_tensor = tensor.flatten(); + auto size = flatten_tensor.numel(); + + for (int i = 0; i < size; i ++) { + file << flatten_tensor[i].item() << " "; + // file << flatten_tensor[i] << " "; + } + file << std::endl; +} + +void test_fwd_with_mask(int seq, int has_mask=1) { + int batch_size = 1; + int nheads = 1; + int headdim = 16; + // int seq = 400; + + int bs_seq = 1; + int max_seqlen_q_ = seq; + int max_seqlen_k_ = seq; + + float softmax_scale = 1; + + bool zero_tensors = false; + bool is_causal = false; + bool return_softmax = false; + + // q -> [bs * seq, head, head_dim] + // q -> [1 * 128, 1, 16] + // block q -> [128, 16] + + // k -> [bs * seq, head, head_dim] + // k -> [1 * 128, 1, 16] + // block k -> [128, 16] + + // 
v -> [bs * seq, head, head_dim] + // v -> [1 * 128, 1, 16] + // block k -> [128, 16] + + at::Tensor q_cpu = at::zeros({batch_size * bs_seq * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor k_cpu = at::zeros({batch_size * bs_seq * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor v_cpu = at::zeros({batch_size * bs_seq * max_seqlen_k_, nheads, headdim}, at::kHalf); + + int cnt = 0; + for (int i = 0; i < batch_size * bs_seq * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < headdim; k ++) { + q_cpu[i][j][k] = (cnt % 10000) * 0.001; + k_cpu[i][j][k] = (cnt % 10000) * 0.001; + v_cpu[i][j][k] = (cnt % 10000) * 0.001; + cnt ++; + } + } + } + + auto q = q_cpu.cuda(); + auto k = k_cpu.cuda(); + auto v = v_cpu.cuda(); + + at::Tensor cu_seqlens_q_cpu = at::zeros({batch_size * bs_seq + 1}, at::kInt); + at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * bs_seq + 1}, at::kInt); + + for (int i = 0; i < batch_size * bs_seq + 1; ++i) { + cu_seqlens_q_cpu[i] = i * max_seqlen_q_; + cu_seqlens_k_cpu[i] = i * max_seqlen_k_; + } + + auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); + auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); + + at::Tensor attn_mask = at::ones({batch_size * bs_seq, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf) * -1; + + // cnt = 0; + // for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { + // for (int j = 0; j < nheads; j ++) { + // for (int k = 0; k < max_seqlen_q_; k ++) { + // for (int l = 0; l < max_seqlen_k_; l ++) { + // // attn_mask[i][j][k][l] = cnt * 0.001; + // // cnt ++; + // if (l % 2 == 0) { + // attn_mask[i][j][k][l] = 0; + // } + // cnt ++; + // } + // } + // } + // } + + for (int i = 0; i < batch_size * bs_seq; i ++) { + for (int j = 0; j < 1; j ++) { + for (int k = 0; k < 1; k ++) { + for (int l = 0; l < max_seqlen_k_; l ++) { + if (l % 2 == 0) { + attn_mask[i][0][0][l] = 0; + } + } + } + } + } + + for (int i = 0; i < batch_size * bs_seq; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < max_seqlen_q_; k ++) { + attn_mask[i][j][k] = attn_mask[i][0][0]; + } + } + } + + attn_mask = attn_mask.cuda(); + + c10::optional gen_; + c10::optional attn_bias; + + std::vector ret; + + ret = mha_fwd( + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, // False + is_causal, // False + return_softmax, // False + gen_, + attn_mask, + attn_bias + ); + dump_tensor("attn_output", ret[0], "has_mask"); + dump_tensor("attn_lse", ret[1], "has_mask"); + + + return ; + // std::cout << "Ret vec size is " << ret.size(); + // for (int i = 0; i < ret.size(); i ++) { + // ret[i].cpu(); + // std::cout << ret[i] << std::endl; + // } + + at::Tensor dout_cpu = at::ones({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + + cnt = 0; + for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { + for (int j = 0; j < nheads; j ++) { + for (int k = 0; k < headdim; k ++) { + dout_cpu[i][j][k] = cnt * 0.001; + cnt ++; + } + } + } + + at::Tensor dq_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor dk_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); + at::Tensor dv_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, 
headdim}, at::kHalf); + + auto dout = dout_cpu.cuda(); + auto dq = dq_cpu.cuda(); + auto dk = dk_cpu.cuda(); + auto dv = dv_cpu.cuda(); + std::vector bwd_ret; + + if (has_mask) { + bwd_ret = mha_bwd( + dout, + q, + k, + v, + ret[0], + ret[1], + dq, + dk, + dv, + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + gen_, + attn_mask, + attn_bias + ); + dump_tensor("attn_dq", dq, "has_mask"); + dump_tensor("attn_dk", dk, "has_mask"); + dump_tensor("attn_dv", dv, "has_mask"); + // dump_tensor("attn_ds", bwd_ret[5], "has_mask"); + }else{ + bwd_ret = mha_bwd( + dout, + q, + k, + v, + ret[0], + ret[1], + dq, + dk, + dv, + cu_seqlens_q, // b + 1 + cu_seqlens_k, // b + 1 + max_seqlen_q_, + max_seqlen_k_, + 0.0, + softmax_scale, + zero_tensors, + is_causal, + gen_, + attn_bias, + attn_bias + // placeholder + ); + dump_tensor("attn_dq", dq, ""); + dump_tensor("attn_dk", dk, ""); + dump_tensor("attn_dv", dv, ""); + } +} + +int main(int argc, char** argv){ + + if ( argc >= 2 ) { + std::cout << "argv: " << argv[1] << std::endl; + int seq = atoi(argv[1]); + + test_fwd_with_mask(seq); + + } + return 0; +} diff --git a/tests/test_torch_capi.cpp b/tests/test_torch_capi.cpp new file mode 100644 index 000000000..adfbd722b --- /dev/null +++ b/tests/test_torch_capi.cpp @@ -0,0 +1,60 @@ +#include +#include +#include +#include +#include +#include + + +torch::Tensor load_tensor(std::string filename) { + std::cout << filename << std::endl; + std::ifstream sfile(filename.c_str()); + + torch::Tensor tensor2; + torch::load(tensor2, sfile); + + // std::cout << tensor2 << std::endl; + return tensor2; +} + +int main(){ + + std::string label = ""; + std::string tensor_name = "input_mask"; + std::string sfile_name = label + "_" + tensor_name + ".pt"; + + // std::ifstream sfile(sfile_name.c_str()); + // torch::Tensor tensor2; + // torch::load(tensor2, sfile); + + torch::Tensor tensor_c = load_tensor(sfile_name); + std::cout << tensor_c << std::endl; + + + std::string python_file_name = "../" + label + "_" + tensor_name + ".pt"; + torch::Tensor tensor_python = load_tensor(python_file_name); + std::cout << tensor_python << std::endl; + + + // int batch_size = 2; + // int num_heads = 4; + // int max_seqlen_q = 8; + // int max_seqlen_k = 8; + + // auto bias = torch::ones({1, num_heads, max_seqlen_q, max_seqlen_k}); + // auto ds = torch::ones({batch_size, num_heads, max_seqlen_q, max_seqlen_k}); + // // batch_size, 1, num_heads, max_seqlen_q, max_seqlen_k + + + // auto shape = bias.sizes(); + // // auto newshape = std::vector(shape); + // // newshape.insert(newshape.begin(), -1); + // // std::cout << newshape << std::endl; + + // auto dbias = ds.reshape({-1, shape[0], shape[1], shape[2], shape[3] }).sum({0}); + + // std::cout << dbias.sizes() << std::endl; + return 0; +} + + From 5ad59e9177c7c508a5defc3315b3548da3f1f9ff Mon Sep 17 00:00:00 2001 From: robotcator Date: Fri, 9 Sep 2022 16:42:37 +0800 Subject: [PATCH 39/71] fix seqlen greater than 256 --- csrc/flash_attn/fmha_api.cpp | 2 ++ csrc/flash_attn/src/fmha/gmem_tile.h | 30 +++++++++++---------- csrc/flash_attn/src/fmha/softmax.h | 2 +- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 1 + 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 672a24027..541f414a9 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -229,6 +229,8 @@ void dump_tensor(const std::string &tensor_name, const 
at::Tensor &tensor, const // file << tensor << std::endl; std::cout << "tensor_name stride 0: " << tensor_name << " " << tensor.stride(0) << std::endl; std::cout << "tensor_name stride 1: " << tensor_name << " " << tensor.stride(1) << std::endl; + std::cout << "tensor_name stride 2: " << tensor_name << " " << tensor.stride(2) << std::endl; + std::cout << "tensor_name stride 3: " << tensor_name << " " << tensor.stride(-1) << std::endl; std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; auto flatten_tensor = tensor.flatten(); diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index a24f30e4c..2f8d6a92e 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -522,9 +522,9 @@ struct Gmem_tile_mma_mask { #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", + printf("init mask tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); - printf("bidb=%d, bidh=%d, param.h=%d, mask_head_mod_size=%d, mask_seq_mod_size=%d, loop_step_idx=%d\n", + printf("init mask bidb=%d, bidh=%d, param.h=%d, mask_head_mod_size=%d, mask_seq_mod_size=%d, loop_step_idx=%d\n", binfo.bidb, binfo.bidh, params.h, params.mask_head_mod_size, params.mask_seq_mod_size, loop_step_idx); printf("\n"); } @@ -562,16 +562,18 @@ struct Gmem_tile_mma_mask { ptrs[offset] = ptr_ + (uint32_t)(current_row % mask_seq_mod_size) * row_stride_in_bytes + (uint32_t)current_col * BYTES_PER_ELEMENT; + // preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) + // && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) - && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d, start_ptr=%p, ptrs[offset]=%p, preds[offset]=%d\n", + printf("load mask mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d, start_ptr=%p, ptrs[offset]=%p, preds[offset]=%d\n", mi, ni, ii, jj, offset, current_row, current_col, ptr_, ptrs[offset], preds[offset]); - printf("current_row=%d, current_col=%d, ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d, loop_step_idx=%d\n", - current_row, current_col, ROWS, actual_seqlen_q, COLS, actual_seqlen_k, loop_step_idx); - printf("cond 1=%d\n", (current_row <= min(ROWS, actual_seqlen_q))); - printf("cond 2=%d\n", ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k))); + printf("load mask ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d, loop_step_idx=%d, cond1=%d, cond2=%d\n", + ROWS, actual_seqlen_q, COLS, actual_seqlen_k, loop_step_idx, + (current_row <= min(ROWS, actual_seqlen_q)), + ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k))); printf("\n"); } #endif @@ -722,15 +724,15 @@ struct Gmem_tile_mma_bias { (uint32_t)current_col * BYTES_PER_ELEMENT; preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) - && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= 
actual_seqlen_k); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d, start_ptr=%p, ptrs[offset]=%p, preds[offset]=%d\n", + printf("load bias mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d, start_ptr=%p, ptrs[offset]=%p, preds[offset]=%d\n", mi, ni, ii, jj, offset, current_row, current_col, ptr_, ptrs[offset], preds[offset]); - printf("current_row=%d, current_col=%d, ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d\n", - current_row, current_col, ROWS, actual_seqlen_q, COLS, actual_seqlen_k); - printf("cond 1=%d\n", (current_row <= min(ROWS, actual_seqlen_q))); - printf("cond 2=%d\n", ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k))); + printf("load bias ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d, loop_step_idx=%d, cond1=%d, cond2=%d\n", + ROWS, actual_seqlen_q, COLS, actual_seqlen_k, loop_step_idx, + (current_row <= min(ROWS, actual_seqlen_q)), + ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k))); printf("\n"); } #endif diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index 9992d0509..a4899eda8 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -506,7 +506,7 @@ struct Softmax : public Softmax_base { this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; } #ifdef DEBUG_PRINT - if ((blockIdx.x == 0) && (blockIdx.y == 0)) { + if ((blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { printf("Attnmask: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, softmax = %f, l = %d, loop_step_idx=%d, blockIdx.x = %d\n", threadIdx.x, threadIdx.y, mi, ni, ii, jj, float(value), this->elt_[2 * mi + ii][4 * ni + jj], l, loop_step_idx, blockIdx.x); } diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 5aa6546a2..291ffb86a 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -567,6 +567,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i if (!(params.attn_mask_ptr == nullptr)) { using Frag_mask = fmha::Fragment_c; Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; + fmha::clear(frag_mask); gmem_mask.template load(frag_mask); gmem_mask.move(); From 94597de5ede7db6584833e7ea32cd9a1d695e114 Mon Sep 17 00:00:00 2001 From: robotcator Date: Fri, 9 Sep 2022 17:48:16 +0800 Subject: [PATCH 40/71] fix bias seqlen --- csrc/flash_attn/fmha_api.cpp | 17 +++++++++-------- csrc/flash_attn/src/fmha/gmem_tile.h | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 541f414a9..b02948ef8 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -233,14 +233,15 @@ void dump_tensor(const std::string &tensor_name, const at::Tensor &tensor, const std::cout << "tensor_name stride 3: " << tensor_name << " " << tensor.stride(-1) << std::endl; std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; - auto flatten_tensor = tensor.flatten(); - auto size = flatten_tensor.numel(); - - for (int i = 0; i < size; i ++) { - file << flatten_tensor[i].item() << " "; - // file << flatten_tensor[i] << " "; - } - file << std::endl; + // cost too much time + // auto flatten_tensor = tensor.flatten(); + // auto size = 
flatten_tensor.numel(); + + // for (int i = 0; i < size; i ++) { + // file << flatten_tensor[i].item() << " "; + // // file << flatten_tensor[i] << " "; + // } + // file << std::endl; std::string sfile_name = label + "_" + tensor_name + ".pt"; std::ofstream sfile(sfile_name.c_str()); diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 2f8d6a92e..3e8aeb558 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -877,7 +877,7 @@ struct Gmem_tile_mma_ds { (uint32_t)current_col * BYTES_PER_ELEMENT; preds = (current_row < min(ROWS, actual_seqlen_q)) - && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); + && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k); #ifdef DEBUG_PRINT if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { From fb7ef925d6a49e687050228bef86850b7cb1299d Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 15 Sep 2022 15:18:13 +0800 Subject: [PATCH 41/71] add constexpr --- .../src/fmha_dgrad_fp16_kernel_loop.sm80.cu | 8 ++ .../src/fmha_fprop_fp16_kernel.sm80.cu | 91 ++++++++++--------- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 65 ++----------- 3 files changed, 63 insertions(+), 101 deletions(-) diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu index 7d1be82bc..868193959 100644 --- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu @@ -27,6 +27,14 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_ // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv); bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" + + bool has_attn = !(params.attn_mask_ptr == nullptr); + bool has_bias = !(params.attn_bias_ptr == nullptr); + +#ifdef DEBUG_PRINT + printf ("has_attn=%d, has_bias=%d, bias_mod_size=%d\n", has_attn, has_bias, params.bias_mod_size); +#endif + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. 
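    // has_attn / has_bias record whether a mask or bias pointer was supplied; later in this
    // series they are promoted to compile-time template parameters (has_attn, has_bias) so the
    // kernels can prune the corresponding code paths with if constexpr.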
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { auto kernel = params.is_causal diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index 20d8c9c3e..a679ed01d 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -192,52 +192,53 @@ void run_fmha_fp16_sm80(Launch_params &launch_params, using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; run_fmha_fp16_sm80_loop_(launch_params, configure); } + } + else if (launch_params.params.d == 32) { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if( launch_params.params.seqlen_k == 256 ) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } } - // else if (launch_params.params.d == 32) { - // if( launch_params.params.seqlen_k == 128 ) { - // using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else if( launch_params.params.seqlen_k == 256 ) { - // using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else { - // using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } - // } else if (launch_params.params.d == 64) { - // if( launch_params.params.seqlen_k == 128 ) { - // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else if( launch_params.params.seqlen_k >= 256 ) { - // if (dprops->major == 8 && dprops->minor >= 0) { - // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else if (dprops->major == 7 && dprops->minor == 5) { - // if (launch_params.is_dropout) { // Need to use the same block size as backward - // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else { - // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } - // } - // } - // } else if (launch_params.params.d == 128) { - // if( launch_params.params.seqlen_k == 128 ) { - // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else { - // if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) { - // // TD [2022-06-05] Keep K in registers to reduce register spilling - // // Gives about 6% speedup compared to using block size 128. 
- // using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } else { // Need to use the same block size as backward - // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; - // run_fmha_fp16_sm80_loop_(launch_params, configure); - // } - // } - // } + else if (launch_params.params.d == 64) { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if( launch_params.params.seqlen_k >= 256 ) { + if (dprops->major == 8 && dprops->minor >= 0) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else if (dprops->major == 7 && dprops->minor == 5) { + if (launch_params.is_dropout) { // Need to use the same block size as backward + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + } + } + } else if (launch_params.params.d == 128) { + if( launch_params.params.seqlen_k == 128 ) { + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { + if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) { + // TD [2022-06-05] Keep K in registers to reduce register spilling + // Gives about 6% speedup compared to using block size 128. + using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } else { // Need to use the same block size as backward + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_fp16_sm80_loop_(launch_params, configure); + } + } + } // if (launch_params.params.d == 64) { // // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>; // // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>; diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 291ffb86a..488f59b95 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -374,20 +374,19 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_o.move(begin); gmem_o_tmp.move(begin); if (Return_softmax) { gmem_s.move(begin); } + gmem_softmax_lse.move(begin); - // if constexpr (has_attn) { - if (!(params.attn_mask_ptr == nullptr)) { + if constexpr (has_attn) { + // if (!(params.attn_mask_ptr == nullptr)) { // TODO: mask move gmem_mask.move(begin); } - // if constexpr (has_bias) { - if (!(params.attn_bias_ptr == nullptr)) { + if constexpr (has_bias) { + // if (!(params.attn_bias_ptr == nullptr)) { // TODO: bias move gmem_bias.move(begin); } - - gmem_softmax_lse.move(begin); fmha::Mask mask(binfo, tidx, loop_step_idx); @@ -409,15 +408,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_k.move(loop_step_idx); gmem_v.move(loop_step_idx); if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); } - // if constexpr (has_attn) { - // if (!(params.attn_mask_ptr == nullptr)) { - // // TODO: mask move as s, with col move - // 
gmem_mask.move(loop_step_idx * steps_og); - // } - // // if constexpr (has_bias) { - // if (!(params.attn_bias_ptr == nullptr)) { - // gmem_bias.move(loop_step_idx * steps_og); - // } } // Trigger the loads for K. @@ -509,44 +499,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i } #endif - // if constexpr (has_attn) { - // if (!(params.attn_mask_ptr == nullptr)) { - // using Frag_mask = fmha::Fragment_c; - // Frag_mask frag_mask[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - - // gmem_mask.load(frag_mask); - // // do we need sync ? - // __syncthreads(); - - // #pragma unroll - // for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - // #pragma unroll - // for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - // acc_p[mi][ni].addf(frag_mask[ni][mi]); - // } - // } - // gmem_mask.move(); - // } - - // if constexpr (has_bias) { - // if (!(params.attn_bias_ptr == nullptr)) { - // using Frag_bias = fmha::Fragment_c; - - // Frag_bias frag_bias[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - // gmem_bias.load(frag_bias); - - // __syncthreads(); - - // #pragma unroll - // for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - // #pragma unroll - // for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - // acc_p[mi][ni].addf(frag_bias[ni][mi]); - // } - // } - // gmem_bias.move(); - // } - uint4 out[Gmem_tile_o::STGS_PER_LOOP]; if (!Is_first) { gmem_o_tmp.load(out, 0); } @@ -563,8 +515,8 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Convert from the accumulator type to FP32 for Softmax. softmax.unpack_noscale(acc_p); - // if constexpr (has_attn) { - if (!(params.attn_mask_ptr == nullptr)) { + if constexpr (has_attn) { + // if (!(params.attn_mask_ptr == nullptr)) { using Frag_mask = fmha::Fragment_c; Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; fmha::clear(frag_mask); @@ -622,7 +574,8 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i #endif } - if (!(params.attn_bias_ptr == nullptr)) { + if constexpr (has_bias) { + // if (!(params.attn_bias_ptr == nullptr)) { using Frag_Bias = fmha::Fragment_c; Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; fmha::clear(frag_bias); From 4efdf9e9aecee7d7da59a10e97ecfc9104145448 Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 15 Sep 2022 20:03:11 +0800 Subject: [PATCH 42/71] add const expr for bwd --- .../src/fmha_dgrad_fp16_kernel_loop.sm80.cu | 145 ++++++++++++++---- .../src/fmha_dgrad_kernel_1xN_loop.h | 33 ++-- 2 files changed, 136 insertions(+), 42 deletions(-) diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu index 868193959..5678377a4 100644 --- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu @@ -5,9 +5,9 @@ #include "fmha.h" #include "fmha_dgrad_kernel_1xN_loop.h" -template +template __global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { - fmha::compute_dq_dk_dv_1xN(params); + fmha::compute_dq_dk_dv_1xN(params); } template @@ -35,32 +35,125 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_ printf ("has_attn=%d, has_bias=%d, bias_mod_size=%d\n", has_attn, has_bias, params.bias_mod_size); #endif - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. - BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { - auto kernel = params.is_causal - ? 
&fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel - : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; - if (params.seqlen_k == blocksize_c) { - kernel = params.is_causal - ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel - : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; - } else if (params.seqlen_k == blocksize_c * 2) { - kernel = params.is_causal - ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel - : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + if (has_attn) { + if (has_bias) { + BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { + auto kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + if (params.seqlen_k == blocksize_c) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { + auto kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + if (params.seqlen_k == blocksize_c) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); } - if( smem_size_dq_dk_dv >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + }else{ + if (has_bias) { + BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { + auto kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + if (params.seqlen_k == blocksize_c) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); + }else{ + BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { + auto kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + if (params.seqlen_k == blocksize_c) { + kernel = params.is_causal + ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = params.is_causal + ? 
&fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel + : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; + } + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); + }); } - dim3 grid(params.b, params.h); - kernel<<>>(params); -#ifdef DEBUG_PRINT - printf("bwd grid size: %d %d\n", params.b, params.h); - printf("bwd block size: %d\n", Kernel_traits::THREADS); -#endif - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - }); + } + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. +// BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { +// auto kernel = params.is_causal +// ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel +// : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; +// if (params.seqlen_k == blocksize_c) { +// kernel = params.is_causal +// ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel +// : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; +// } else if (params.seqlen_k == blocksize_c * 2) { +// kernel = params.is_causal +// ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel +// : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; +// } +// if( smem_size_dq_dk_dv >= 48 * 1024 ) { +// FMHA_CHECK_CUDA(cudaFuncSetAttribute( +// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); +// } +// dim3 grid(params.b, params.h); +// kernel<<>>(params); +// #ifdef DEBUG_PRINT +// printf("bwd grid size: %d %d\n", params.b, params.h); +// printf("bwd block size: %d\n", Kernel_traits::THREADS); +// #endif +// FMHA_CHECK_CUDA(cudaPeekAtLastError()); +// }); } void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 888eebf78..f01e37cd8 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -31,7 +31,7 @@ inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng &ph, const int loop_step_idx) { @@ -213,14 +213,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng gmem_softmax_lse.move(begin); gmem_softmax_d.move(begin); - // if constexpr (has_attn) { - if (!(params.attn_mask_ptr == nullptr)) { + if constexpr (has_attn) { + // if (!(params.attn_mask_ptr == nullptr)) { // TODO: mask move gmem_mask.move(begin); } - // if constexpr (has_attn) { - if (!(params.attn_bias_ptr == nullptr)) { + if constexpr (has_attn) { + // if (!(params.attn_bias_ptr == nullptr)) { // TODO: mask move gmem_bias.move(begin); gmem_ds.move(begin); @@ -356,8 +356,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Convert from the accumulator type to FP32 for Softmax. 
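// Editorial note, not part of the patch: `has_attn` and `has_bias` are now compile-time
// template parameters instead of runtime checks on params.attn_mask_ptr /
// params.attn_bias_ptr, so the mask and bias code paths below are compiled out entirely
// when they are not needed. The launcher above (run_fmha_dgrad_fp16_sm80_loop_) expands
// the four (has_attn, has_bias) combinations by hand because, as its comment notes,
// gcc 7 does not like nested BOOL_SWITCH. A sketch of the equivalent nested dispatch,
// assuming the kernel's template parameters follow the order
// (Kernel_traits, Is_dropout, Is_causal, has_attn, has_bias), which is not shown
// verbatim in this dump:
//
//   BOOL_SWITCH(params.attn_mask_ptr != nullptr, HasAttnConst, [&] {
//     BOOL_SWITCH(params.attn_bias_ptr != nullptr, HasBiasConst, [&] {
//       BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
//         auto kernel = params.is_causal
//           ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true,  HasAttnConst, HasBiasConst>
//           : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, HasAttnConst, HasBiasConst>;
//         // ... same smem-size attribute, grid setup, and launch as in the
//         // hand-expanded branches above ...
//       });
//     });
//   });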
softmax.unpack_noscale(acc_p); - // if constexpr (has_attn) { - if (!(params.attn_mask_ptr == nullptr)) { + if constexpr (has_attn) { + // if (!(params.attn_mask_ptr == nullptr)) { using Frag_mask = fmha::Fragment_c; Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; gmem_mask.template load(frag_mask); @@ -414,7 +414,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng #endif } - if (!(params.attn_bias_ptr == nullptr)) { + if constexpr (has_bias) { + // if (!(params.attn_bias_ptr == nullptr)) { using Frag_Bias = fmha::Fragment_c; Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; gmem_bias.template load(frag_bias); @@ -853,7 +854,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N. // This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2. -template +template inline __device__ void compute_dq_dk_dv_1xN(const Params ¶ms) { constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; @@ -869,20 +870,20 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params ¶ms) { Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); if (loop_steps == 1) { - compute_dq_dk_dv_1xN_one_iter(params, ph, 0); + compute_dq_dk_dv_1xN_one_iter(params, ph, 0); } else if (loop_steps == 2) { - compute_dq_dk_dv_1xN_one_iter(params, ph, 0); - compute_dq_dk_dv_1xN_one_iter(params, ph, 1); + compute_dq_dk_dv_1xN_one_iter(params, ph, 0); + compute_dq_dk_dv_1xN_one_iter(params, ph, 1); } else { if (params.seqlen_k == blocksize_c) { - compute_dq_dk_dv_1xN_one_iter(params, ph, 0); + compute_dq_dk_dv_1xN_one_iter(params, ph, 0); } else { const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c; - compute_dq_dk_dv_1xN_one_iter(params, ph, 0); + compute_dq_dk_dv_1xN_one_iter(params, ph, 0); for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { - compute_dq_dk_dv_1xN_one_iter(params, ph, loop_step_idx); + compute_dq_dk_dv_1xN_one_iter(params, ph, loop_step_idx); } - compute_dq_dk_dv_1xN_one_iter(params, ph, max_loop_steps - 1); + compute_dq_dk_dv_1xN_one_iter(params, ph, max_loop_steps - 1); } } } From 00d3e03399a2cee4fe329ae4ce556391ee7899eb Mon Sep 17 00:00:00 2001 From: robotcator Date: Fri, 16 Sep 2022 11:08:19 +0800 Subject: [PATCH 43/71] add benchmark --- benchmarks/correctness/attention.py | 44 +++ benchmarks/correctness/benchmark_memory.py | 128 +++++++ benchmarks/correctness/check_correct.py | 336 ++++++++++++++++++ .../correctness/check_speed_backward.py | 131 +++++++ benchmarks/correctness/check_speed_forward.py | 128 +++++++ benchmarks/correctness/flash_attention.py | 63 ++++ benchmarks/correctness/torch_attention.py | 51 +++ 7 files changed, 881 insertions(+) create mode 100644 benchmarks/correctness/attention.py create mode 100644 benchmarks/correctness/benchmark_memory.py create mode 100644 benchmarks/correctness/check_correct.py create mode 100644 benchmarks/correctness/check_speed_backward.py create mode 100644 benchmarks/correctness/check_speed_forward.py create mode 100644 benchmarks/correctness/flash_attention.py create mode 100644 benchmarks/correctness/torch_attention.py diff --git a/benchmarks/correctness/attention.py b/benchmarks/correctness/attention.py new file mode 100644 index 000000000..1f89e52f2 --- /dev/null +++ b/benchmarks/correctness/attention.py @@ -0,0 +1,44 @@ +import torch +from typing import Optional, Callable, List, 
Tuple, Sequence + +from unicore.modules import softmax_dropout + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + +def _attention(query, key, value, mask=None, bias=None, upcast=False) -> torch.Tensor: + # upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + # output back to fp16/bf16. + dtype_og = query.dtype + + if upcast: + query = query.float() + key = key.float() + value = value.float() + if mask is not None: + mask = mask.float() + if bias is not None: + bias = bias.float() + + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + + # if biases is not None: + # a += biases + + # if mask is not None: + # a.masked_fill_(mask < 0, float('-inf')) + + # a = softmax_no_cast(a, -1) + a = softmax_dropout(a, dropout_prob=0, is_training=True, mask=mask, bias=bias) + + # [*, H, Q, C_hidden] + b = torch.matmul(a, value) + + return b.to(dtype_og) diff --git a/benchmarks/correctness/benchmark_memory.py b/benchmarks/correctness/benchmark_memory.py new file mode 100644 index 000000000..4107b8012 --- /dev/null +++ b/benchmarks/correctness/benchmark_memory.py @@ -0,0 +1,128 @@ +import torch +import torch.utils.benchmark as benchmark + +from flash_attention import _flash_attn +from attention import _attention +from torch_attention import _torch_attention + +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--has_mask_bias", required=False, help="add bias in attention", type=bool, default=False) +parser.add_argument("--eval", required=False, help="test whether has backward", type=bool, default=False) + +args = parser.parse_args() +print(args) + + +def benchmark_memory(fn, inputs, mask=None, bias=None, grad=None, eval=True, desc='', verbose=False, **kwinputs): + def fwd(grad, inputs, mask=mask, bias=bias, **kwinputs): + with torch.no_grad(): + y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) + + + def fwd_bwd(grad, inputs, mask=mask, bias=bias, **kwinputs): + y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError('Grad shape does not match output shape') + y.backward(grad, retain_graph=False) + + if eval: + f = fwd + if verbose: + print ("using fwd func...") + else: + f = fwd_bwd + if verbose: + print ("using fwd and bwd func...") + + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + f(None, inputs, mask, bias) + + torch.cuda.synchronize() + mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000) + if verbose: + print(f"{desc} max memory: ", mem) + torch.cuda.empty_cache() + return mem + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +def fun(seqlen=128, verbose=False, has_bias=True, has_mask=True, eval=True): + bs = 1 + head = 4 + c_dim = 32 + seq_q = seq_k = seq_v = seqlen + dtype = torch.bfloat16 + device = "cuda" + + inputs = torch.empty((bs, seq_q, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) + inputs.requires_grad = True + if verbose: + print ("inputs shape: ", inputs.shape) + # [bs, seq, seq, head, c_dim] + + if has_bias: + bias = torch.randn( + 1, 1, head, seq_q, 
seq_k, dtype=dtype, device=device + ) + bias.requires_grad = True + if verbose: + print ("bias shape: ", bias.shape) + # [1, 1, seq, head, seq_k] + else: + bias = None + + if has_mask: + mask = gen_attn_mask( + ( + torch.rand(bs, seq_q, 1, 1, seq_k, dtype=dtype, device=device,) > 0.2 + ).type(dtype), + -3e4, + ) + if verbose: + print ("mask shape: ", mask.shape) + else: + mask = None + + print ("processing seq length: {} in eval model {} ......".format(seqlen, eval)) + + try: + m1 = benchmark_memory(_attention, inputs, mask=mask, bias=bias, eval=eval, desc='Normal Attention forward') + print (m1) + except: + print ("Normal Attention OOM") + + try: + m2 = benchmark_memory(_flash_attn, inputs, mask=mask, bias=bias, eval=eval, desc='Flash Attention forward') + print (m2) + except: + print ("Flash Attention OOM") + + +for seqlen in [2**8, 2**9, 600, 700, 800, 2**10, 1200, 1400, 2**11, 2500, 3000, 3500, 2**12]: + if args.has_mask_bias: + if not args.eval: + fun(seqlen=seqlen, eval=False) + else: + fun(seqlen=seqlen, eval=True) + else: + if not args.eval: + fun(seqlen=seqlen, has_bias=None, has_mask=None, eval=False) + else: + fun(seqlen=seqlen, has_bias=None, has_mask=None, eval=True) + diff --git a/benchmarks/correctness/check_correct.py b/benchmarks/correctness/check_correct.py new file mode 100644 index 000000000..032909f74 --- /dev/null +++ b/benchmarks/correctness/check_correct.py @@ -0,0 +1,336 @@ +import torch + +# from attention import _attention +from torch_attention import _torch_attention as _attention +from flash_attention import _flash_attn + +import numpy as np +import pytest + + +def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): + diff = np.abs(pred - gt) + + cnt = 0 + for index, x in np.ndenumerate(diff): + if x > abs_eps: + relative_diff = np.abs(x / gt[index]) + if relative_diff > relative_rps: + cnt += 1 + if verbose: + print (index, x, gt[index], relative_diff) + + if cnt > 0: + print ("not so match") + return False + else: + return True + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + +# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +# # @pytest.mark.parametrize('c_dim', [64, 32, 16]) +# @pytest.mark.parametrize('c_dim', [16]) +# @pytest.mark.parametrize('seqlen', [64, 128, 256, 512, 1024, 1536, 2048]) +def test_flash_attn_unpadded_shape1(seqlen, c_dim, dtype, device = "cuda"): + # mini + # bs = 2 + # head = 8 + # c_dim = 16 + # seq_q = seq_k = seq_v = 128 + # dtype = torch.half + # device = "cuda" + + bs = 1 + head = 1 + c_dim = c_dim + bs_seq = 1 + seq_q = seq_k = seq_v = seqlen + dtype = dtype + device = device + + inputs = torch.empty((bs, bs_seq, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) + # debug data + # inputs = torch.zeros((bs, bs_seq, seq_q, head, c_dim), dtype=dtype, device=device) + # cnt = 0 + # for i in range(bs): + # for j in range(bs_seq): + # for k in range(seq_q): + # for l in range(head): + # for m in range(c_dim): + # inputs[i][j][k][l][m] = (cnt % 10000) * 0.001 + # cnt += 1 + + # inputs = inputs.permute(0, 1, 3, 2, 4) + inputs.requires_grad = True + + print ("inputs shape: ", inputs.shape) + # [bs, seq, seq, head, c_dim] + + bias = torch.randn( + 1, 1, head, seq_q, seq_k, dtype=dtype, device=device + ) + bias.requires_grad = True + + print ("bias shape: ", bias.shape) + # [1, 1, seq, head, seq_k] + + mask = gen_attn_mask( + ( + torch.randn((bs, bs_seq, 1, 1, seq_k), 
dtype=dtype, device=device,) > 0.2 + ).type(dtype), + -3e4, + ) + # [bs, bs_seq, head, 1, seq_k] + + # debug data + # mask = torch.ones(bs, bs_seq, head, 1, seq_k, dtype=dtype, device=device,) * -1 + # for i in range(bs): + # for j in range(bs_seq): + # for k in range(head): + # for l in range(1): + # for m in range(seq_q): + # if m % 2 == 0: + # mask[i][j][k][l][m] = 0 + + # mask = mask.expand(bs, bs_seq, head, seq_q, seq_k) + print ("mask shape: ", mask.shape) + + # bias = None + # mask = None + # [bs, seq_q, 1, 1, seq_k] + + normal_attn_v1 = inputs.clone() + output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, bias=bias, mask=mask, upcast=True) + output_ref = output_ref.transpose(-2, -3) + print ("attention ref output shape: ", output_ref.shape) + + normal_attn_v2 = inputs.clone() + output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, bias=bias, mask=mask) + # be careful here + output_pt = output_pt.transpose(-2, -3) + print ("attention output shape: ", output_pt.shape) + + normal_attn_flash = inputs.clone() + output_flash = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, bias=bias, mask=mask) + print ("flash attn output shape: ", output_flash.shape) + + print (10 * "*" + "comparing forward" + 10 * "*" ) + # fp32 result + print("Output max diff: {0}".format((output_flash - output_ref).abs().max().item())) + print("Output mean diff: {0}".format((output_flash - output_ref).abs().mean().item())) + + print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) + print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) + + print("Output max diff with Pytorch: {0}".format((output_flash - output_pt).abs().max().item())) + print("Output mean diff with Pytorch: {0}".format((output_flash - output_pt).abs().mean().item())) + + # Check that FlashAttention's numerical error is at most twice the numerical error of a Pytorch implementation. 
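    # Editorial note, not part of the original test: `output_pt` is the same reference
    # attention run in the low-precision input dtype, while `output_ref` is the fp32
    # (upcast=True) run, so |output_pt - output_ref| measures plain fp16/bf16 rounding
    # error. FlashAttention is accepted when its own max deviation from the fp32
    # reference is at most twice that baseline, e.g. a PyTorch max diff of 0.002 allows
    # a FlashAttention max diff of up to 0.004.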
+ print ("less than twice error: ", (output_flash - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) + print () + + g = torch.randn_like(output_flash) + # dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, ), g) + # dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, ), g) + # dq, dk, dv, = torch.autograd.grad(output_flash, (normal_attn_flash, normal_attn_flash, normal_attn_flash, ), g) + + dq_ref, dk_ref, dv_ref, dbias_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, bias), g) + dq_pt, dk_pt, dv_pt, dbias_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, bias), g) + dq, dk, dv, dbias = torch.autograd.grad(output_flash, (normal_attn_flash, normal_attn_flash, normal_attn_flash, bias), g) + + print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) + print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) + print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) + + print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) + print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) + print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) + + print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) + print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) + print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) + + print ("dq less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) + print ("dk less than twice error: ", ((dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()) ) + print ("dv less than twice error: ", ((dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()) ) + + if bias is not None: + print ("dbias less than twice error: ", ((dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item()) ) + + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item(), "dq larger than twice error" + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item(), "dq larger than twice error" + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item(), "dq larger than twice error" + + if bias is not None: + print("Output dbias max diff: {0}".format( (dbias - dbias_ref).abs().max().item() )) + print("Pytorch dbias max diff: {0}".format( (dbias - dbias_pt).abs().max().item() )) + assert (dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item(), "dbias larger than twice error" + + + +# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +# # @pytest.mark.parametrize('c_dim', [64, 32, 16]) +# @pytest.mark.parametrize('c_dim', [16]) +# @pytest.mark.parametrize('seqlen', [64, 128, 256, 512, 1024, 1536, 2048]) +def test_flash_attn_unpadded_shape2(seqlen, c_dim, dtype, device = "cuda"): + # mini + # bs = 2 + # head = 8 + # c_dim = 16 + # seq_q = seq_k = seq_v = 128 + # dtype = torch.half + # device = "cuda" + + bs = 1 + head = 1 + c_dim = c_dim + bs_seq = 1 + seq_q = seq_k = seq_v = seqlen + dtype = dtype + device = device + + inputs = torch.empty((bs, bs_seq, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) + # debug data + 
# inputs = torch.zeros((bs, bs_seq, seq_q, head, c_dim), dtype=dtype, device=device) + # cnt = 0 + # for i in range(bs): + # for j in range(bs_seq): + # for k in range(seq_q): + # for l in range(head): + # for m in range(c_dim): + # inputs[i][j][k][l][m] = (cnt % 10000) * 0.001 + # cnt += 1 + + # inputs = inputs.permute(0, 1, 3, 2, 4) + inputs.requires_grad = True + + print ("inputs shape: ", inputs.shape) + # [bs, seq, seq, head, c_dim] + + bias = torch.randn( + 1, bs_seq, head, seq_q, seq_k, dtype=dtype, device=device + ) + bias.requires_grad = True + + print ("bias shape: ", bias.shape) + # [1, 1, seq, head, seq_k] + + mask = gen_attn_mask( + ( + torch.randn((bs, bs_seq, head, 1, seq_k), dtype=dtype, device=device,) > 0.2 + ).type(dtype), + -3e4, + ) + # [bs, bs_seq, head, 1, seq_k] + + # debug data + # mask = torch.ones(bs, bs_seq, head, 1, seq_k, dtype=dtype, device=device,) * -1 + # for i in range(bs): + # for j in range(bs_seq): + # for k in range(head): + # for l in range(1): + # for m in range(seq_q): + # if m % 2 == 0: + # mask[i][j][k][l][m] = 0 + + # mask = mask.expand(bs, bs_seq, head, seq_q, seq_k) + print ("mask shape: ", mask.shape) + + # bias = None + # mask = None + # [bs, seq_q, 1, 1, seq_k] + + # np.savetxt("inputs_flash_seq{0}.data".format(seqlen), inputs.detach().cpu().numpy().flatten(), delimiter=" ") + # if mask is not None: + # np.savetxt("attn_mask_flash_seq{0}.data".format(seqlen), mask.detach().cpu().numpy().flatten(), delimiter=" ") + + normal_attn_v1 = inputs.clone() + output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, bias=bias, mask=mask, upcast=True) + output_ref = output_ref.transpose(-2, -3) + print ("attention ref output shape: ", output_ref.shape) + + normal_attn_v2 = inputs.clone() + output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, bias=bias, mask=mask) + # be careful here + output_pt = output_pt.transpose(-2, -3) + print ("attention output shape: ", output_pt.shape) + + normal_attn_flash = inputs.clone() + output_flash = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, bias=bias, mask=mask) + print ("flash attn output shape: ", output_flash.shape) + # [bs, bs_seq, head, seq_k c_dim] + + # np.savetxt("output_torch_seq{0}.data".format(seqlen), output_pt.detach().cpu().numpy().flatten(), delimiter=" ") + # np.savetxt("output_flash_seq{0}.data".format(seqlen), output_flash.detach().cpu().numpy().flatten(), delimiter=" ") + + print (10 * "*" + "comparing forward" + 10 * "*" ) + # fp32 result + print("Output max diff: {0}".format((output_flash - output_ref).abs().max().item())) + print("Output mean diff: {0}".format((output_flash - output_ref).abs().mean().item())) + + print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) + print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) + + print("Output max diff with Pytorch: {0}".format((output_flash - output_pt).abs().max().item())) + print("Output mean diff with Pytorch: {0}".format((output_flash - output_pt).abs().mean().item())) + + # Check that FlashAttention's numerical error is at most twice the numerical error of a Pytorch implementation. 
+ print ("less than twice error: ", (output_flash - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) + print () + assert (output_flash - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() + + g = torch.randn_like(output_flash) + # dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, ), g) + # dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, ), g) + # dq, dk, dv, = torch.autograd.grad(output_flash, (normal_attn_flash, normal_attn_flash, normal_attn_flash, ), g) + + dq_ref, dk_ref, dv_ref, dbias_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, bias), g) + dq_pt, dk_pt, dv_pt, dbias_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, bias), g) + dq, dk, dv, dbias = torch.autograd.grad(output_flash, (normal_attn_flash, normal_attn_flash, normal_attn_flash, bias), g) + + print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) + print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) + print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) + + print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) + print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) + print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) + + print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) + print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) + print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) + + print ("dq less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) + print ("dk less than twice error: ", ((dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()) ) + print ("dv less than twice error: ", ((dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()) ) + + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item(), "dq larger than twice error" + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item(), "dq larger than twice error" + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item(), "dq larger than twice error" + + if bias is not None: + print("Output dbias max diff: {0}".format( (dbias - dbias_ref).abs().max().item() )) + print("Pytorch dbias max diff: {0}".format( (dbias - dbias_pt).abs().max().item() )) + assert (dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item(), "dbias larger than twice error" + + +# for dtype in [torch.float16]: +# # for dtype in [torch.float16, torch.bfloat16]: +# for c_dim in [16]: +# for seqlen in [64, 128, 256, 512]: +# print ("dtype={}, c_dim={}, seqlen={}".format(dtype, c_dim, seqlen)) +# test_flash_attn_unpadded_shape1(seqlen, c_dim, dtype) + + +# for dtype in [torch.float16]: +for dtype in [torch.float16, torch.bfloat16]: + for c_dim in [16, 32, 64]: + for seqlen in [64, 128, 256, 512, 1024, 2048]: + print ("dtype={}, c_dim={}, seqlen={}".format(dtype, c_dim, seqlen)) + test_flash_attn_unpadded_shape2(seqlen, c_dim, dtype) \ No newline at end of file diff --git a/benchmarks/correctness/check_speed_backward.py b/benchmarks/correctness/check_speed_backward.py new file mode 100644 index 
000000000..4ef74bd6d --- /dev/null +++ b/benchmarks/correctness/check_speed_backward.py @@ -0,0 +1,131 @@ +import torch +import torch.utils.benchmark as benchmark + +from flash_attention import _flash_attn +from attention import _attention +from torch_attention import _torch_attention + +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--has_mask_bias", required=False, help="add bias in attention", type=bool, default=False) + +args = parser.parse_args() +print(args) + + +def benchmark_combined(fn, inputs, mask=None, bias=None, grad=None, repeats=10, desc='', verbose=False, **kwinputs): + """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ + if verbose: + print(desc, '- Forward + Backward pass') + + def f(grad, inputs, mask=mask, bias=bias, **kwinputs): + y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError('Grad shape does not match output shape') + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt='f(grad, inputs, mask=mask, bias=bias, **kwinputs)', + globals={'f': f, 'fn': fn, 'inputs': inputs, 'mask': mask, 'bias': bias, 'grad': grad, 'kwinputs': kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_forward(fn, inputs, mask=None, bias=None, grad=None, repeats=10, desc='', verbose=False, **kwinputs): + """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ + if verbose: + print(desc, '- Forward pass with no grad') + + def f(grad, inputs, mask=mask, bias=bias, **kwinputs): + with torch.no_grad(): + y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) + + t = benchmark.Timer( + stmt='f(grad, inputs, mask=mask, bias=bias, **kwinputs)', + globals={'f': f, 'fn': fn, 'inputs': inputs, 'mask': mask, 'bias': bias, 'grad': grad, 'kwinputs': kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +def fun(seqlen=128, verbose=False, has_bias=True, has_mask=True): + bs = 1 + head = 4 + c_dim = 32 + seq_q = seq_k = seq_v = seqlen + dtype = torch.bfloat16 + device = "cuda" + + inputs = torch.empty((bs, seq_q, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) + inputs.requires_grad = True + if verbose: + print ("inputs shape: ", inputs.shape) + # [bs, seq, seq, head, c_dim] + + if has_bias: + bias = torch.randn( + 1, 1, head, seq_q, seq_k, dtype=dtype, device=device + ) + bias.requires_grad = True + if verbose: + print ("bias shape: ", bias.shape) + # [1, 1, seq, head, seq_k] + else: + bias = None + + if has_mask: + mask = gen_attn_mask( + ( + torch.rand(bs, seq_q, 1, 1, seq_k, dtype=dtype, device=device,) > 0.2 + ).type(dtype), + -3e4, + ) + if verbose: + print ("mask shape: ", mask.shape) + else: + mask = None + + print ("processing seq length: {} ......".format(seqlen)) + try: + t1, m1 = benchmark_combined(_attention, inputs, mask=mask, bias=bias, repeats=100, desc='Normal Attention forward') + # import pdb; pdb.set_trace() + # print (m1) + # raw_times / number_per_run * 1000 ms + print (m1.raw_times[0]) + except: + print ("normal attention OOM") + + try: + t2, m2 = benchmark_combined(_flash_attn, inputs, mask=mask, 
bias=bias, repeats=100, desc='Flash Attention forward') + # print (m2) + print (m2.raw_times[0]) + except: + print ("flash attention OOM") + + +for seqlen in [2**8, 2**9, 600, 700, 800, 2**10, 1200, 1400, 2**11, 2500, 3000, 3500, 2**12]: + if has_mask_bias: + fun(seqlen=seqlen) + else: + fun(seqlen=seqlen, has_bias=None, has_mask=None) + diff --git a/benchmarks/correctness/check_speed_forward.py b/benchmarks/correctness/check_speed_forward.py new file mode 100644 index 000000000..653bb00ba --- /dev/null +++ b/benchmarks/correctness/check_speed_forward.py @@ -0,0 +1,128 @@ +import torch +import torch.utils.benchmark as benchmark + +from flash_attention import _flash_attn +from attention import _attention +from torch_attention import _torch_attention + +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--has_mask_bias", required=False, help="add bias in attention", type=bool, default=False) + +args = parser.parse_args() +print(args) + + +def benchmark_combined(fn, inputs, mask=None, bias=None, grad=None, repeats=10, desc='', verbose=False, **kwinputs): + """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ + if verbose: + print(desc, '- Forward + Backward pass') + + def f(grad, inputs, mask=mask, bias=bias, **kwinputs): + y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError('Grad shape does not match output shape') + y.backward(grad, retain_graph=True) + + t = benchmark.Timer( + stmt='f(grad, inputs, mask=mask, bias=bias, **kwinputs)', + globals={'f': f, 'fn': fn, 'inputs': inputs, 'mask': mask, 'bias': bias, 'grad': grad, 'kwinputs': kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_forward(fn, inputs, mask=None, bias=None, grad=None, repeats=10, desc='', verbose=False, **kwinputs): + """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. 
""" + if verbose: + print(desc, '- Forward pass with no grad') + + def f(grad, inputs, mask=mask, bias=bias, **kwinputs): + with torch.no_grad(): + y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) + + t = benchmark.Timer( + stmt='f(grad, inputs, mask=mask, bias=bias, **kwinputs)', + globals={'f': f, 'fn': fn, 'inputs': inputs, 'mask': mask, 'bias': bias, 'grad': grad, 'kwinputs': kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +def fun(seqlen=128, verbose=False, has_bias=True, has_mask=True): + bs = 1 + head = 4 + c_dim = 32 + seq_q = seq_k = seq_v = seqlen + dtype = torch.bfloat16 + device = "cuda" + + inputs = torch.empty((bs, seq_q, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) + inputs.requires_grad = True + if verbose: + print ("inputs shape: ", inputs.shape) + # [bs, seq, seq, head, c_dim] + + if has_bias: + bias = torch.randn( + 1, 1, head, seq_q, seq_k, dtype=dtype, device=device + ) + bias.requires_grad = True + if verbose: + print ("bias shape: ", bias.shape) + # [1, 1, seq, head, seq_k] + else: + bias = None + + if has_mask: + mask = gen_attn_mask( + ( + torch.rand(bs, seq_q, 1, 1, seq_k, dtype=dtype, device=device,) > 0.2 + ).type(dtype), + -3e4, + ) + if verbose: + print ("mask shape: ", mask.shape) + else: + mask = None + + print ("processing seq length: {} ......".format(seqlen)) + try: + t1, m1 = benchmark_forward(_attention, inputs, mask=mask, bias=bias, repeats=100, desc='Normal Attention forward') + # print (m1) + print (m1.raw_times[0]) + except: + print ("normal attention OOM") + + try: + t2, m2 = benchmark_forward(_flash_attn, inputs, mask=mask, bias=bias, repeats=100, desc='Flash Attention forward') + # print (m2) + print (m2.raw_times[0]) + except: + print ("flash attention OOM") + + +for seqlen in [2**8, 2**9, 600, 700, 800, 2**10, 1200, 1400, 2**11, 2500, 3000, 3500, 2**12]: + if args.has_mask_bias: + fun(seqlen=seqlen) + else: + fun(seqlen=seqlen, has_bias=None, has_mask=None) diff --git a/benchmarks/correctness/flash_attention.py b/benchmarks/correctness/flash_attention.py new file mode 100644 index 000000000..cdf1364c5 --- /dev/null +++ b/benchmarks/correctness/flash_attention.py @@ -0,0 +1,63 @@ + +import torch +from flash_attn.flash_attn_interface import flash_attn_unpadded_func + + +def _flash_attn(q, k, v, mask=None, bias=None): + batch_dims = q.shape[:-3] + no_heads, n, c = q.shape[-3:] + dtype = q.dtype + + # [*, B, N, H, C] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + # [B_flat, N, H, C] + q = q.reshape(-1, *q.shape[-3:]) + k = k.reshape(-1, *k.shape[-3:]) + v = v.reshape(-1, *v.shape[-3:]) + + # Flattened batch size + batch_size = q.shape[0] + + # [B_flat * N, H, C] + q = q.reshape(-1, *q.shape[-2:]) + k = k.reshape(-1, *k.shape[-2:]) + v = v.reshape(-1, *v.shape[-2:]) + + q_max_s = n + q_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device + ) + + k_max_s = n + k_cu_seqlens = torch.arange( + 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device + ) + + if mask is not None: + mask_heads, tgt_len, src_len = mask.shape[-3:] + mask = mask.reshape(-1 , mask_heads, tgt_len, src_len).contiguous() + + if bias is not None: + bias_heads, tgt_len, src_len = bias.shape[-3:] + bias = bias.reshape(-1 , 
bias_heads, tgt_len, src_len).contiguous() + + out = flash_attn_unpadded_func( + q, + k, + v, + q_cu_seqlens, + k_cu_seqlens, + q_max_s, + k_max_s, + attn_mask=mask, + attn_bias=bias, + dropout_p = 0., + softmax_scale = 1., # q has been scaled already + ) + + # [*, B, N, H, C] + out = out.reshape(*batch_dims, n, no_heads, c) + return out diff --git a/benchmarks/correctness/torch_attention.py b/benchmarks/correctness/torch_attention.py new file mode 100644 index 000000000..277f45b88 --- /dev/null +++ b/benchmarks/correctness/torch_attention.py @@ -0,0 +1,51 @@ +import torch +from typing import Optional, Callable, List, Tuple, Sequence + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of + type bfloat16 + """ + d = t.dtype + s = torch.nn.functional.softmax(t, dim=dim) + return s + +def _torch_attention(query, key, value, mask=None, bias=None, upcast=False) -> torch.Tensor: + # upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + # output back to fp16/bf16. + dtype_og = query.dtype + if upcast: + query = query.float() + key = key.float() + value = value.float() + if mask is not None: + mask = mask.float() + if bias is not None: + bias = bias.float() + + # [*, H, C_hidden, K] + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + a = torch.matmul(query, key) + + if bias is not None: + a += bias + + if mask is not None: + a.masked_fill_(mask < 0, float('-inf')) + # a += mask + + a = softmax_no_cast(a, -1) + + # [*, H, Q, C_hidden] + b = torch.matmul(a, value) + + return b.to(dtype_og) From 24b55bd636cd1bffd2a077f214666ba88af28fdf Mon Sep 17 00:00:00 2001 From: robotcator Date: Fri, 16 Sep 2022 11:16:01 +0800 Subject: [PATCH 44/71] add test tools --- tests/tools/rebuild_bwd_attn_mask.py | 64 +++++++++++++++ tests/tools/rebuild_fwd_softmax.py | 113 +++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) create mode 100644 tests/tools/rebuild_bwd_attn_mask.py create mode 100644 tests/tools/rebuild_fwd_softmax.py diff --git a/tests/tools/rebuild_bwd_attn_mask.py b/tests/tools/rebuild_bwd_attn_mask.py new file mode 100644 index 000000000..790f49642 --- /dev/null +++ b/tests/tools/rebuild_bwd_attn_mask.py @@ -0,0 +1,64 @@ +from parse import parse +import sys +import numpy as np + +filename = "./output.log" +if len(sys.argv) > 1: + filename = sys.argv[1] + +# bwd softmax: threadIdx=195, l=0, mi=0, ki=1, ii=3, jj=0, elt=0.000000 +format_string = 'bwd softmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' +batch_size = 1 +nheads = 1 +headdim = 16 +seq = 8 +seq_q = 8 +max_seqlen_q_ = seq_q +max_seqlen_k_ = seq_q + + +d_softmax = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) + +def parse_dsoftmax_load(filename): + with open(filename, "r") as f: + for line in f.readlines(): + # print (line) + if line.startswith("bwd softmax: "): + # print (line.strip()) + result = parse(format_string, line.strip()) + # print (result) + + tidx_ = int(result[0]) + mi = int(result[2]) + ni = int(result[3]) + ii = int(result[4]) + jj = int(result[5]) + value = float(result[6]) + + warp = tidx_ // 32 + lane = tidx_ % 32 + + warp_n = (warp // 1) + warp_m = (warp % 1) + + quad = lane // 4 + tid 
= (lane % 4) * 2 + + row = warp_m * 16 + quad + col = warp_n * 16 + tid + + current_row = mi * 16 + ii * 8 + row + # current_col = ni * 64 + jj * 8 + col + # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col + current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col + # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + # warp, lane, quad, tid, current_row, current_col, value)) + if (current_row < 8 and current_col < 8): + print (line.strip()) + print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + warp, lane, quad, tid, current_row, current_col, value)) + d_softmax[0, 0, current_row, current_col] = value + + +parse_dsoftmax_load(filename) +print ("output block 0 d_softmax: ", d_softmax[0, 0, :, :]) diff --git a/tests/tools/rebuild_fwd_softmax.py b/tests/tools/rebuild_fwd_softmax.py new file mode 100644 index 000000000..3158ac2a9 --- /dev/null +++ b/tests/tools/rebuild_fwd_softmax.py @@ -0,0 +1,113 @@ +from parse import parse +import sys +import numpy as np + +def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): + diff = np.abs(pred - gt) + + cnt = 0 + for index, x in np.ndenumerate(diff): + if x > abs_eps: + if abs(gt[index]) < 1e-9: + relative_diff = 100 + else: + relative_diff = np.abs(x / gt[index]) + if relative_diff > relative_rps: + cnt += 1 + if verbose: + print ("index={0}, diff={1}, pred={2}, true={3}, relative_diff={4}".format( + index, x, pred[index], gt[index], relative_diff)) + + if cnt > 0: + print ("not so match") + return False + else: + return True + +filename = "./output512.log" +if len(sys.argv) > 1: + filename = sys.argv[1] + +# Attnmask: threadIdx.x = 98, threadIdx.y = 0, mi = 0, ni = 0, ii = 0, jj = 2, value = 0.000000, softmax = 0.608030, l = 0, loop_step_idx=1, blockIdx.x = 0 +format_string = 'Attnmask: threadIdx.x = {0}, threadIdx.y = {1}, mi = {2}, ni = {3}, ii = {4}, jj = {5}, value = {6}, softmax = {7}, l = {8}, loop_step_idx={9}, blockIdx.x = {10}' +batch_size = 1 +nheads = 1 +headdim = 16 +bs_seq = 1 +seq_q = 512 +max_seqlen_q_ = seq_q +max_seqlen_k_ = seq_q + +Cta_tile_p_N = 256 +Cta_tile_p_M = 16 + + +def parse_fwd_softmax_load(filename): + print ("processing... 
reconstruct from ", filename) + softmax_data = np.zeros([batch_size * bs_seq, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) + with open(filename, "r") as f: + for line in f.readlines(): + # print (line) + if line.startswith("Attnmask: "): + # print (line.strip()) + result = parse(format_string, line.strip()) + # print (result) + + tidx_ = int(result[0]) + mi = int(result[2]) + ni = int(result[3]) + ii = int(result[4]) + jj = int(result[5]) + value = float(result[6]) + softmax_elt = float(result[7]) + q_loop = int(result[8]) + k_loop = int(result[9]) + + warp = tidx_ // 32 + lane = tidx_ % 32 + # thread per warp = 32 + + warp_n = (warp // 1) + warp_m = (warp % 1) + # WARPS_M = 1 + + quad = lane // 4 + tid = (lane % 4) * 2 + + row = warp_m * 16 + quad + col = warp_n * 16 + tid + + current_row = Cta_tile_p_M * q_loop + mi * 16 + ii * 8 + row + + current_col = k_loop * Cta_tile_p_N + ni * 64 + (jj & 2) * 4 + (jj & 1) + col + # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + # warp, lane, quad, tid, current_row, current_col, value)) + if current_col > 510: + print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + warp, lane, quad, tid, current_row, current_col, value)) + print (line.strip()) + if (current_row < 16 and current_col < 512): + # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( + # warp, lane, quad, tid, current_row, current_col, value)) + # print ("") + softmax_data[0, 0, current_row, current_col] = value + + return softmax_data + + +softmax_cpp = parse_fwd_softmax_load(filename) +softmax_python = parse_fwd_softmax_load("../" + filename) + + +print (is_same_matrix(softmax_cpp, softmax_python, verbose=True)) + +for i in range(16): + print ("first part idx = {} softmax cpp = {}: ".format(i, softmax_cpp[0, 0, i, :256])) + print ("first part idx = {} softmax python = {}: ".format(i, softmax_python[0, 0, i, :256])) + + print (np.allclose(softmax_cpp[0, 0, i, :256],softmax_python[0, 0, i, :256])) + + print ("second part idx = {} softmax cpp = {}: ".format(i, softmax_cpp[0, 0, i, 256:])) + print ("second part idx = {} softmax python = {}: ".format(i, softmax_python[0, 0, i, 256:])) + + print (np.allclose(softmax_cpp[0, 0, i, 256:],softmax_python[0, 0, i, 256:])) From a0b4891c594121b9411a1ca2cb955c387366d798 Mon Sep 17 00:00:00 2001 From: robotcator Date: Fri, 16 Sep 2022 11:49:25 +0800 Subject: [PATCH 45/71] add script --- benchmarks/correctness/test_mem.sh | 5 +++++ benchmarks/correctness/test_time.sh | 5 +++++ 2 files changed, 10 insertions(+) create mode 100644 benchmarks/correctness/test_mem.sh create mode 100644 benchmarks/correctness/test_time.sh diff --git a/benchmarks/correctness/test_mem.sh b/benchmarks/correctness/test_mem.sh new file mode 100644 index 000000000..500646c8a --- /dev/null +++ b/benchmarks/correctness/test_mem.sh @@ -0,0 +1,5 @@ +python benchmarks/correctness/benchmark_memory.py --has_mask_bias=true --eval=false 2>&1 |tee has_mask_bias_train.txt +python benchmarks/correctness/benchmark_memory.py --has_mask_bias=false --eval=false 2>&1 |tee no_mask_bias_train.txt + +python benchmarks/correctness/benchmark_memory.py --has_mask_bias=true --eval=true 2>&1 |tee has_mask_bias_test.txt +python benchmarks/correctness/benchmark_memory.py --has_mask_bias=false --eval=true 2>&1 |tee no_mask_bias_test.txt \ No newline at end of file diff --git a/benchmarks/correctness/test_time.sh b/benchmarks/correctness/test_time.sh new file mode 100644 
index 000000000..92bd47d69 --- /dev/null +++ b/benchmarks/correctness/test_time.sh @@ -0,0 +1,5 @@ +python benchmarks/correctness/check_speed_forward.py --has_mask_bias=false 2>&1 |tee no_mask_bias_test.txt +python benchmarks/correctness/check_speed_forward.py --has_mask_bias=true 2>&1 |tee has_mask_bias_test.txt + +python benchmarks/correctness/check_speed_backward.py --has_mask_bias=false 2>&1 |tee no_mask_bias_train.txt +python benchmarks/correctness/check_speed_backward.py --has_mask_bias=true 2>&1 |tee has_mask_bias_train.txt \ No newline at end of file From 95d0308a17bcf7eb046a5edcc0759c0911ae9921 Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 22 Sep 2022 12:51:01 +0800 Subject: [PATCH 46/71] add cross attention --- benchmarks/correctness/check_speed_backward.py | 2 +- benchmarks/correctness/flash_attention.py | 6 ++++-- csrc/flash_attn/fmha_api.cpp | 16 ++++++++-------- csrc/flash_attn/src/fmha/softmax.h | 10 +++++++--- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/benchmarks/correctness/check_speed_backward.py b/benchmarks/correctness/check_speed_backward.py index 4ef74bd6d..05bc7e58d 100644 --- a/benchmarks/correctness/check_speed_backward.py +++ b/benchmarks/correctness/check_speed_backward.py @@ -124,7 +124,7 @@ def fun(seqlen=128, verbose=False, has_bias=True, has_mask=True): for seqlen in [2**8, 2**9, 600, 700, 800, 2**10, 1200, 1400, 2**11, 2500, 3000, 3500, 2**12]: - if has_mask_bias: + if args.has_mask_bias: fun(seqlen=seqlen) else: fun(seqlen=seqlen, has_bias=None, has_mask=None) diff --git a/benchmarks/correctness/flash_attention.py b/benchmarks/correctness/flash_attention.py index cdf1364c5..0d7d5b6d4 100644 --- a/benchmarks/correctness/flash_attention.py +++ b/benchmarks/correctness/flash_attention.py @@ -6,6 +6,8 @@ def _flash_attn(q, k, v, mask=None, bias=None): batch_dims = q.shape[:-3] no_heads, n, c = q.shape[-3:] + k_no_heads, k_n, k_c = k.shape[-3:] + dtype = q.dtype # [*, B, N, H, C] @@ -31,9 +33,9 @@ def _flash_attn(q, k, v, mask=None, bias=None): 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device ) - k_max_s = n + k_max_s = k_n k_cu_seqlens = torch.arange( - 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device + 0, (batch_size + 1) * k_n, step=k_n, dtype=torch.int32, device=k.device ) if mask is not None: diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index b02948ef8..5a3032ca5 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -234,14 +234,14 @@ void dump_tensor(const std::string &tensor_name, const at::Tensor &tensor, const std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; // cost too much time - // auto flatten_tensor = tensor.flatten(); - // auto size = flatten_tensor.numel(); - - // for (int i = 0; i < size; i ++) { - // file << flatten_tensor[i].item() << " "; - // // file << flatten_tensor[i] << " "; - // } - // file << std::endl; + auto flatten_tensor = tensor.flatten(); + auto size = flatten_tensor.numel(); + + for (int i = 0; i < size; i ++) { + file << flatten_tensor[i].item() << " "; + // file << flatten_tensor[i] << " "; + } + file << std::endl; std::string sfile_name = label + "_" + tensor_name + ".pt"; std::ofstream sfile(sfile_name.c_str()); diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index a4899eda8..3f2018ff0 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -502,9 +502,13 @@ struct Softmax : public Softmax_base { 
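// Editorial note, not part of the patch: the hunk below changes how Softmax applies the
// attention mask. Previously any nonzero mask element overwrote the corresponding logit
// with -INFINITY (treating the mask as boolean); now the mask value is simply added to
// the logit, i.e. the mask acts as an additive bias. This matches gen_attn_mask() in the
// benchmarks, which encodes masked positions as a large negative number (-3e4) and kept
// positions as 0.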
#pragma unroll for( int jj = 0; jj < 4; ++jj ) { float value = toFloat(mask[mi][ni].elt(ii * 4 + jj)); - if( abs(value) > 0 ) { - this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; - } + // if( abs(value) > 0 ) { + // this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY; + // } + // if( value < 0 ) { + // this->elt_[2 * mi + ii][4 * ni + jj] = -INFINITY; + // } + this->elt_[2 * mi + ii][4 * ni + jj] += value; #ifdef DEBUG_PRINT if ((blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { printf("Attnmask: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, softmax = %f, l = %d, loop_step_idx=%d, blockIdx.x = %d\n", From df852f5802c5197e7c5ec4f7de6648a5e12a7764 Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 22 Sep 2022 14:04:13 +0800 Subject: [PATCH 47/71] add cross attn --- .../test_template_pointwise_att.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 benchmarks/correctness/test_template_pointwise_att.py diff --git a/benchmarks/correctness/test_template_pointwise_att.py b/benchmarks/correctness/test_template_pointwise_att.py new file mode 100644 index 000000000..dcfc0e1be --- /dev/null +++ b/benchmarks/correctness/test_template_pointwise_att.py @@ -0,0 +1,58 @@ +import torch + +# from attention import _attention +from torch_attention import _torch_attention as _attention +from flash_attention import _flash_attn + +import numpy as np +import pytest + + +from flash_attn.flash_attn_interface import flash_attn_unpadded_func + + +def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): + diff = np.abs(pred - gt) + + cnt = 0 + for index, x in np.ndenumerate(diff): + if x > abs_eps: + relative_diff = np.abs(x / gt[index]) + if relative_diff > relative_rps: + cnt += 1 + if verbose: + print (index, x, gt[index], relative_diff) + + if cnt > 0: + print ("not so match") + return False + else: + return True + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +def test_attn(): + dtype = torch.half + device = "cuda" + + q = torch.randn(1, 256, 256, 4, 1, 16, dtype=dtype, device=device) + k = torch.randn(1, 256, 256, 4, 4, 16, dtype=dtype, device=device) + v = torch.randn(1, 256, 256, 4, 4, 16, dtype=dtype, device=device) + + # print ("q shape = {}, k shape = {}, v shape = {}".format(q.shape, k.shape, v.shape)) + + o = _attention(q, k, v, mask=None, bias=None) + o = o.transpose(-2, -3).contiguous() + + output_flash = _flash_attn(q, k, v, mask=None, bias=None) + + print("Output max diff: {0}".format((output_flash - o).abs().max().item())) + print (is_same_matrix(o.detach().cpu().numpy(), output_flash.detach().cpu().numpy())) + +test_attn() \ No newline at end of file From d59fa76c54e86ba0af39b838660bc3a72ecc97c9 Mon Sep 17 00:00:00 2001 From: robotcator Date: Mon, 10 Oct 2022 16:42:48 +0800 Subject: [PATCH 48/71] fix bugs --- csrc/flash_attn/fmha_api.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index 5a3032ca5..a60ed28f9 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -222,6 +222,7 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, params.attn_ds_ptr = attn_ds; } +#ifdef DDEBUG_PRINT void dump_tensor(const std::string &tensor_name, const at::Tensor &tensor, const std::string &label) { std::string file_name = label + "_" + tensor_name + ".data"; std::ofstream file(file_name.c_str()); 
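// Editorial note, not part of the patch: this "fix bugs" commit does two things. The hunk
// above and the one immediately below gate the debugging helper dump_tensor() behind an
// #ifdef/#endif pair, and the final hunk fixes a copy-paste bug in mha_fwd's input checks,
// where cu_seqlens_k's contiguity was verified twice while cu_seqlens_q's was never checked.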
@@ -247,6 +248,7 @@ void dump_tensor(const std::string &tensor_name, const at::Tensor &tensor, const std::ofstream sfile(sfile_name.c_str()); torch::save(tensor, sfile); } +#endif std::vector mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i @@ -291,7 +293,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q TORCH_CHECK(q.stride(-1) == 1); TORCH_CHECK(k.stride(-1) == 1); TORCH_CHECK(v.stride(-1) == 1); - TORCH_CHECK(cu_seqlens_k.is_contiguous()); + TORCH_CHECK(cu_seqlens_q.is_contiguous()); TORCH_CHECK(cu_seqlens_k.is_contiguous()); const auto sizes = q.sizes(); From bdc1fb3e829877844db6b5e909df3280b7e01d0a Mon Sep 17 00:00:00 2001 From: robotcator Date: Tue, 11 Oct 2022 23:21:15 +0800 Subject: [PATCH 49/71] remove test tools --- benchmarks/correctness/attention.py | 9 - .../test_template_pointwise_att.py | 58 -- benchmarks/test/test_forward_with_bias_v2.py | 287 ------- benchmarks/test/test_forward_with_mask_v2.py | 271 ------ .../test/test_forward_without_bias_mask.py | 247 ------ benchmarks/test_example.py | 213 ----- csrc/flash_attn/fmha_api.cpp | 3 +- flash_attn/{Attention.py => attention.py} | 28 +- setup.py | 4 +- tests/build.sh | 49 -- tests/fmha_api.h | 48 -- tests/test_flash_attn.py | 667 --------------- tests/test_forward.cu | 708 ---------------- tests/test_forward_shape.cu | 249 ------ tests/test_torch_capi.cpp | 60 -- tests/tools/check_output.py | 768 ------------------ tests/tools/rebuild_bwd_attn_mask.py | 64 -- tests/tools/rebuild_bwd_softmax.py | 64 -- tests/tools/rebuild_dsoftmax.py | 64 -- tests/tools/rebuild_fwd_softmax.py | 113 --- tests/tools/rebuild_mat.py | 94 --- 21 files changed, 21 insertions(+), 4047 deletions(-) delete mode 100644 benchmarks/correctness/test_template_pointwise_att.py delete mode 100644 benchmarks/test/test_forward_with_bias_v2.py delete mode 100644 benchmarks/test/test_forward_with_mask_v2.py delete mode 100644 benchmarks/test/test_forward_without_bias_mask.py delete mode 100644 benchmarks/test_example.py rename flash_attn/{Attention.py => attention.py} (63%) delete mode 100644 tests/build.sh delete mode 100644 tests/fmha_api.h delete mode 100644 tests/test_flash_attn.py delete mode 100644 tests/test_forward.cu delete mode 100644 tests/test_forward_shape.cu delete mode 100644 tests/test_torch_capi.cpp delete mode 100644 tests/tools/check_output.py delete mode 100644 tests/tools/rebuild_bwd_attn_mask.py delete mode 100644 tests/tools/rebuild_bwd_softmax.py delete mode 100644 tests/tools/rebuild_dsoftmax.py delete mode 100644 tests/tools/rebuild_fwd_softmax.py delete mode 100644 tests/tools/rebuild_mat.py diff --git a/benchmarks/correctness/attention.py b/benchmarks/correctness/attention.py index 1f89e52f2..d8a686548 100644 --- a/benchmarks/correctness/attention.py +++ b/benchmarks/correctness/attention.py @@ -10,8 +10,6 @@ def permute_final_dims(tensor: torch.Tensor, inds: List[int]): return tensor.permute(first_inds + [zero_index + i for i in inds]) def _attention(query, key, value, mask=None, bias=None, upcast=False) -> torch.Tensor: - # upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - # output back to fp16/bf16. 
dtype_og = query.dtype if upcast: @@ -29,13 +27,6 @@ def _attention(query, key, value, mask=None, bias=None, upcast=False) -> torch.T # [*, H, Q, K] a = torch.matmul(query, key) - # if biases is not None: - # a += biases - - # if mask is not None: - # a.masked_fill_(mask < 0, float('-inf')) - - # a = softmax_no_cast(a, -1) a = softmax_dropout(a, dropout_prob=0, is_training=True, mask=mask, bias=bias) # [*, H, Q, C_hidden] diff --git a/benchmarks/correctness/test_template_pointwise_att.py b/benchmarks/correctness/test_template_pointwise_att.py deleted file mode 100644 index dcfc0e1be..000000000 --- a/benchmarks/correctness/test_template_pointwise_att.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch - -# from attention import _attention -from torch_attention import _torch_attention as _attention -from flash_attention import _flash_attn - -import numpy as np -import pytest - - -from flash_attn.flash_attn_interface import flash_attn_unpadded_func - - -def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): - diff = np.abs(pred - gt) - - cnt = 0 - for index, x in np.ndenumerate(diff): - if x > abs_eps: - relative_diff = np.abs(x / gt[index]) - if relative_diff > relative_rps: - cnt += 1 - if verbose: - print (index, x, gt[index], relative_diff) - - if cnt > 0: - print ("not so match") - return False - else: - return True - - -def gen_attn_mask(mask, neg_inf): - assert neg_inf < -1e4 - attn_mask = torch.zeros_like(mask) - attn_mask[mask == 0] = neg_inf - return attn_mask - - -def test_attn(): - dtype = torch.half - device = "cuda" - - q = torch.randn(1, 256, 256, 4, 1, 16, dtype=dtype, device=device) - k = torch.randn(1, 256, 256, 4, 4, 16, dtype=dtype, device=device) - v = torch.randn(1, 256, 256, 4, 4, 16, dtype=dtype, device=device) - - # print ("q shape = {}, k shape = {}, v shape = {}".format(q.shape, k.shape, v.shape)) - - o = _attention(q, k, v, mask=None, bias=None) - o = o.transpose(-2, -3).contiguous() - - output_flash = _flash_attn(q, k, v, mask=None, bias=None) - - print("Output max diff: {0}".format((output_flash - o).abs().max().item())) - print (is_same_matrix(o.detach().cpu().numpy(), output_flash.detach().cpu().numpy())) - -test_attn() \ No newline at end of file diff --git a/benchmarks/test/test_forward_with_bias_v2.py b/benchmarks/test/test_forward_with_bias_v2.py deleted file mode 100644 index a200df5c8..000000000 --- a/benchmarks/test/test_forward_with_bias_v2.py +++ /dev/null @@ -1,287 +0,0 @@ -import torch -import torch.nn as nn - -from functools import partial -import math -from typing import Optional, Callable, List, Tuple, Sequence -import numpy as np -import deepspeed - -from flash_attn.flash_attn_interface import flash_attn_unpadded_func - - -def permute_final_dims(tensor: torch.Tensor, inds: List[int]): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - -@torch.jit.ignore -def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Softmax, but without automatic casting to fp32 when the input is of - type bfloat16 - """ - d = t.dtype - # if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): - # with torch.cuda.amp.autocast(enabled=False): - # s = torch.nn.functional.softmax(t, dim=dim) - # else: - # s = torch.nn.functional.softmax(t, dim=dim) - s = torch.nn.functional.softmax(t, dim=dim) - return s - - -def _attention(query, key, value, mask=None, biases=None, upcast=False) -> torch.Tensor: - # upcast: whether 
to cast all inputs to fp32, do all computation in fp32, then cast - # output back to fp16/bf16. - dtype_og = query.dtype - if upcast: - query = query.float() - key = key.float() - value = value.float() - if mask is not None: - mask = mask.float() - if biases is not None: - biases = biases.float() - - # [*, H, C_hidden, K] - key = permute_final_dims(key, (1, 0)) - - # [*, H, Q, K] - a = torch.matmul(query, key) - # print ("q * k: ", a) - # import pdb; pdb.set_trace() - - if biases is not None: - print ("attn_shape = {}, bias_shape = {}".format(a.shape, biases.shape)) - a += biases - # print ("after bias:", a) - - if mask is not None: - # a += mask - # import pdb; pdb.set_trace() - # please do not use add now - a.masked_fill_(mask < 0, float('-inf')) - - # print ("after mask:", a) - - a = softmax_no_cast(a, -1) - # print ("softmax :", a) - - # [*, H, Q, C_hidden] - b = torch.matmul(a, value) - # print ("p * v: ", a) - return b.to(dtype_og), a.to(dtype_og) - - -def _flash_attn(q, k, v, attn_mask=None, attn_bias=None): - batch_dims = q.shape[:-3] - no_heads, n, c = q.shape[-3:] - dtype = q.dtype - - # [*, B, N, H, C] - q = q.transpose(-2, -3) - k = k.transpose(-2, -3) - v = v.transpose(-2, -3) - - # [B_flat, N, H, C] - q = q.reshape(-1, *q.shape[-3:]) - k = k.reshape(-1, *k.shape[-3:]) - v = v.reshape(-1, *v.shape[-3:]) - - # Flattened batch size - batch_size = q.shape[0] - - # [B_flat * N, H, C] - q = q.reshape(-1, *q.shape[-2:]) - k = k.reshape(-1, *k.shape[-2:]) - v = v.reshape(-1, *v.shape[-2:]) - - q_max_s = n - q_cu_seqlens = torch.arange( - 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device - ) - - k_max_s = n - k_cu_seqlens = torch.arange( - 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device - ) - - if attn_mask is not None: - # import pdb; pdb.set_trace() - attn_mask = attn_mask.reshape([batch_size , no_heads, n, n]).contiguous() - - if attn_bias is not None: - # import pdb; pdb.set_trace() - if attn_bias.is_contiguous: - print ("attn_bias it not contiguous, stride is", attn_bias.stride()) - attn_bias = attn_bias.reshape([batch_size , no_heads, n, n]).contiguous() - # attn_bias = attn_bias.reshape([batch_size , no_heads, n, n]) - print ("attn_bias stride is", attn_bias.stride()) - - print ("check shapes q_shape = {} k_shape = {} v_shape = {}".format(q.shape, k.shape, v.shape)) - print ("check shapes q_cu_shape = {} k_cu_shape = {}".format(q_cu_seqlens.shape, k_cu_seqlens.shape)) - if attn_bias is not None: - print ("attn_bias shape = {}".format(attn_bias.shape)) - - out = flash_attn_unpadded_func( - q, - k, - v, - q_cu_seqlens, - k_cu_seqlens, - q_max_s, - k_max_s, - attn_mask=attn_mask, - attn_bias=attn_bias, - dropout_p = 0., - softmax_scale = 1., # q has been scaled already - ) - - # [*, B, N, H, C] - out = out.reshape(*batch_dims, n, no_heads, c) - return out - - -def gen_attn_mask(mask, neg_inf): - assert neg_inf < -1e4 - attn_mask = torch.zeros_like(mask) - attn_mask[mask == 0] = neg_inf - return attn_mask - - -torch.manual_seed(0) -# v2 -# bs = 1 -# seq = 128 -# head = 1 -# c_dim = 16 - -# mini -bs = 1 -seq = 128 -head = 1 -c_dim = 16 - -seq_q = seq_k = seq_v = 128 - -print (10 * "*" + "prepare data" + 10 * "*" ) -# dtype = torch.bfloat16 -dtype = torch.half -device = "cuda" - -# orig_tensor = torch.stack( -# [ (i+1) * 0.1 * torch.randn((bs, seq, head, c_dim)) for i in range(seq) ] -# ,dim = 1 -# ).to(device).to(dtype) - -orig_tensor = torch.empty((bs, seq, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) 
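The _flash_attn wrapper above flattens the [*, B, N, H, C] inputs into the unpadded (total_tokens, heads, dim) layout and derives cu_seqlens as cumulative sequence offsets, which is what flash_attn_unpadded_func consumes. A minimal standalone sketch of that bookkeeping, assuming equal-length sequences and a hypothetical helper name:

import torch

def build_unpadded_inputs(x):
    # x: (batch, n, heads, dim), every sequence has the same length n
    batch, n, heads, dim = x.shape
    x_unpad = x.reshape(batch * n, heads, dim)              # (total_tokens, H, C)
    cu_seqlens = torch.arange(0, (batch + 1) * n, step=n,   # start offset of each sequence
                              dtype=torch.int32, device=x.device)
    return x_unpad, cu_seqlens, n                           # n doubles as max_seqlen

x = torch.randn(2, 4, 1, 16)
x_unpad, cu_seqlens, max_s = build_unpadded_inputs(x)
print(x_unpad.shape, cu_seqlens.tolist(), max_s)            # torch.Size([8, 1, 16]) [0, 4, 8] 4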
-orig_tensor.requires_grad = True -# print ("tensor: ", orig_tensor) -print ("origin shape: ", orig_tensor.shape) -# [bs, seq, seq, head, c_dim] - -bias = torch.randn( - 1, 1, head, seq_q, seq_k, dtype=dtype, device=device -) - -print ("bias shape: ", bias.shape) -bias_broadcast = bias.expand([bs, seq, head, seq_q, seq_k]) -bias_broadcast.requires_grad = True -print ("bias_broadcast shape: ", bias_broadcast.shape) - -# print ("bias_broadcast: ", bias_broadcast) - -print (10 * "*" + "normal attn fp32" + 10 * "*" ) -normal_attn_v1 = orig_tensor.clone() -output_ref, softmax_output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, biases=bias_broadcast, upcast=True) -# be careful here -output_ref = output_ref.transpose(-2, -3) -print ("attention output shape: ", output_ref.shape) -print (10 * "*" + "normal attn fp32" + 10 * "*" ) -print () - - -print (10 * "*" + "normal attn fp16" + 10 * "*" ) -normal_attn_v2 = orig_tensor.clone() -output_pt, softmax_output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, biases=bias_broadcast) -# be careful here -output_pt = output_pt.transpose(-2, -3) -print ("attention output shape: ", output_pt.shape) -print (10 * "*" + "normal attn fp32" + 10 * "*" ) -print () - - -print (10 * "*" + "flash attn" + 10 * "*" ) -normal_attn_flash = orig_tensor.clone() -output3 = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, attn_bias=bias_broadcast) -# import pdb; pdb.set_trace() -print ("flash attn output shape: ", output3.shape) -print (10 * "*" + "flash attn" + 10 * "*" ) -print () - -# print ("max abs error: ", (output3 - output_ref).abs().max()) -# print ("all close at pre@.2: ", torch.allclose(output3, output_ref, atol=1e-2)) - -print (10 * "*" + "comparing forward" + 10 * "*" ) -print("Output max diff: {0}".format((output3 - output_ref).abs().max().item())) -print("Output mean diff: {0}".format((output3 - output_ref).abs().mean().item())) - -# print("Output max diff: {0}".format((output3[:,0,:,:,:] - output_ref[:,0,:,:,:]).abs().max().item())) -# print("Output max diff: {0}".format((output3[:,3,:,:,:] - output_ref[:,3,:,:,:]).abs().max().item())) - -print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) -print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) - -print("Output max diff with Pytorch: {0}".format((output3 - output_pt).abs().max().item())) -print("Output mean diff with Pytorch: {0}".format((output3 - output_pt).abs().mean().item())) - -print ("less than twice error: ", (output3 - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) -print (10 * "*" + "comparing forward" + 10 * "*" ) -print () - -# max_diff = (output3 - output_ref).abs().max().item() -# relative_diff = (output_pt - output_ref).abs().max().item() - -# for i in range(bs): -# for j in range(seq_q): -# for k in range(seq_k): -# if (output3[i, j, k, :, :] - output_ref[i, j, k, :, :]).abs().max().item() >= 2 * (relative_diff): -# print ("i={}, j={}, k={} output3={}".format(i, j, k, output3[i, j, k, :, :].data)) -# print ("i={}, j={}, k={} output_pt={}".format(i, j, k, output_ref[i, j, k, :, :].data)) - -# test backward - -g = torch.randn_like(output3) - -# dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, ), g) -# dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, ), g) -# dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, 
normal_attn_flash, normal_attn_flash, ), g) - -dq_ref, dk_ref, dv_ref, dbias_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, bias_broadcast), g) -dq_pt, dk_pt, dv_pt, dbias_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, bias_broadcast), g) -dq, dk, dv, dbias = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash, bias_broadcast), g) - - -print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) -print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) -print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) - -print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) -print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) -print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) - -print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) -print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) -print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) - -print ("dq less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) -print ("dk less than twice error: ", ((dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()) ) -print ("dv less than twice error: ", ((dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()) ) - -if dbias is not None: - print ("dbias less than twice error: ", ((dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item()) ) - -assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() -if dbias is not None: - assert (dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item() diff --git a/benchmarks/test/test_forward_with_mask_v2.py b/benchmarks/test/test_forward_with_mask_v2.py deleted file mode 100644 index 3906165f8..000000000 --- a/benchmarks/test/test_forward_with_mask_v2.py +++ /dev/null @@ -1,271 +0,0 @@ -import torch -import torch.nn as nn - -from functools import partial -import math -from typing import Optional, Callable, List, Tuple, Sequence -import numpy as np -import deepspeed - -from flash_attn.flash_attn_interface import flash_attn_unpadded_func - - -def permute_final_dims(tensor: torch.Tensor, inds: List[int]): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - -@torch.jit.ignore -def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Softmax, but without automatic casting to fp32 when the input is of - type bfloat16 - """ - d = t.dtype - # if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): - # with torch.cuda.amp.autocast(enabled=False): - # s = torch.nn.functional.softmax(t, dim=dim) - # else: - # s = torch.nn.functional.softmax(t, dim=dim) - s = torch.nn.functional.softmax(t, dim=dim) - return s - - -def _attention(query, key, value, mask=None, biases=None, upcast=False) -> torch.Tensor: - # upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - # output back to fp16/bf16. 
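The upcast flag described in the comment above is how these tests produce an fp32 reference from the same inputs: promote to float, compute, then cast the result back. A toy illustration of that convention, with a plain matmul standing in for the full attention (hypothetical helper, CUDA device as in the scripts here):

import torch

def matmul_with_upcast(a, b, upcast=True):
    # Optionally promote to fp32, do the math in fp32, then cast back to the original dtype.
    dtype_og = a.dtype
    if upcast:
        a, b = a.float(), b.float()
    return (a @ b).to(dtype_og)

a = torch.randn(4, 8, dtype=torch.float16, device="cuda")
b = torch.randn(8, 4, dtype=torch.float16, device="cuda")
ref = matmul_with_upcast(a, b, upcast=True)    # fp32 math, fp16 result
low = matmul_with_upcast(a, b, upcast=False)   # pure fp16 math
print((ref - low).abs().max().item())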
- dtype_og = query.dtype - if upcast: - query = query.float() - key = key.float() - value = value.float() - if mask is not None: - mask = mask.float() - - # [*, H, C_hidden, K] - key = permute_final_dims(key, (1, 0)) - - # [*, H, Q, K] - a = torch.matmul(query, key) - # print ("q * k: ", a) - - if biases is None: - biases = [] - for b in biases: - a += b - # print ("after bias:", a) - - if mask is not None: - # a += mask - # import pdb; pdb.set_trace() - # please do not use add now - a.masked_fill_(mask < 0, float('-inf')) - - # print ("after mask:", a) - - a = softmax_no_cast(a, -1) - # print ("softmax :", a) - - # [*, H, Q, C_hidden] - b = torch.matmul(a, value) - # print ("p * v: ", a) - return b.to(dtype_og), a.to(dtype_og) - - -def _flash_attn(q, k, v, attn_mask=None): - batch_dims = q.shape[:-3] - no_heads, n, c = q.shape[-3:] - dtype = q.dtype - - # [*, B, N, H, C] - q = q.transpose(-2, -3) - k = k.transpose(-2, -3) - v = v.transpose(-2, -3) - - # [B_flat, N, H, C] - q = q.reshape(-1, *q.shape[-3:]) - k = k.reshape(-1, *k.shape[-3:]) - v = v.reshape(-1, *v.shape[-3:]) - - # Flattened batch size - batch_size = q.shape[0] - - # [B_flat * N, H, C] - q = q.reshape(-1, *q.shape[-2:]) - k = k.reshape(-1, *k.shape[-2:]) - v = v.reshape(-1, *v.shape[-2:]) - - q_max_s = n - q_cu_seqlens = torch.arange( - 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device - ) - - k_max_s = n - k_cu_seqlens = torch.arange( - 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device - ) - - if attn_mask is not None: - # import pdb; pdb.set_trace() - attn_mask = attn_mask.reshape([bs * n, no_heads, n, n]) - attn_mask = attn_mask.contiguous() - - out = flash_attn_unpadded_func( - q, - k, - v, - q_cu_seqlens, - k_cu_seqlens, - q_max_s, - k_max_s, - attn_mask=attn_mask, - dropout_p = 0., - softmax_scale = 1., # q has been scaled already - ) - - # [*, B, N, H, C] - out = out.reshape(*batch_dims, n, no_heads, c) - return out - - -def gen_attn_mask(mask, neg_inf): - assert neg_inf < -1e4 - attn_mask = torch.zeros_like(mask) - attn_mask[mask == 0] = neg_inf - return attn_mask - - -torch.manual_seed(0) -# v2 -bs = 1 -seq = 128 -head = 1 -c_dim = 16 - -# mini -# bs = 1 -# seq = 2 -# head = 1 -# c_dim = 16 - -seq_q = seq_k = seq_v = seq - -print (10 * "*" + "prepare data" + 10 * "*" ) -dtype = torch.bfloat16 -# dtype = torch.half -device = "cuda" - -# orig_tensor = torch.stack( -# [ (i+1) * 0.1 * torch.randn((bs, seq, head, c_dim)) for i in range(seq) ] -# ,dim = 1 -# ).to(device).to(dtype) - -orig_tensor = torch.empty((bs, seq, head, seq, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) -orig_tensor.requires_grad = True -# print ("tensor: ", orig_tensor) -print ("origin shape: ", orig_tensor.shape) -# [bs, seq, seq, head, c_dim] - -mask_data = torch.rand( - bs, - seq_q, - 1, - 1, - seq_k, - dtype=dtype, - device=device, - ) - -# fake data -# mask_data[:, :, :, :, :] = 0.02 -# mask_data[:, :, :, :, 0] = 0.001 - -mask = gen_attn_mask( - ( mask_data > 0.01 ).type(dtype), - -3e4, -) -print ("mask shape: ", mask.shape) -mask_broadcast = mask.expand([bs, seq_k, head, seq_q, seq_k]) -print ("mask_broadcast shape: ", mask_broadcast.shape) - -print ("mask broadcast: ", mask_broadcast) - -print (10 * "*" + "normal attn fp32" + 10 * "*" ) -normal_attn_v1 = orig_tensor.clone() -output_ref, softmax_output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, mask=mask_broadcast, upcast=True) -# be careful here -output_ref = output_ref.transpose(-2, -3) -print ("attention output shape: 
", output_ref.shape) -print (10 * "*" + "normal attn fp32" + 10 * "*" ) -print () - - -print (10 * "*" + "normal attn fp16" + 10 * "*" ) -normal_attn_v2 = orig_tensor.clone() -output_pt, softmax_output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, mask=mask_broadcast) -# be careful here -output_pt = output_pt.transpose(-2, -3) -print ("attention output shape: ", output_pt.shape) -print (10 * "*" + "normal attn fp32" + 10 * "*" ) -print () - - -print (10 * "*" + "flash attn" + 10 * "*" ) -normal_attn_flash = orig_tensor.clone() -output3 = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, attn_mask=mask_broadcast) -# import pdb; pdb.set_trace() -print ("flash attn output shape: ", output3.shape) -print (10 * "*" + "flash attn" + 10 * "*" ) -print () - -# print ("max abs error: ", (output3 - output_ref).abs().max()) -# print ("all close at pre@.2: ", torch.allclose(output3, output_ref, atol=1e-2)) - -print (10 * "*" + "comparing forward" + 10 * "*" ) -print("Output max diff: {0}".format((output3 - output_ref).abs().max().item())) -print("Output mean diff: {0}".format((output3 - output_ref).abs().mean().item())) - -print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) -print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) - -print("Output max diff with Pytorch: {0}".format((output3 - output_pt).abs().max().item())) -print("Output mean diff with Pytorch: {0}".format((output3 - output_pt).abs().mean().item())) - -print ("less than twice error: ", (output3 - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) -print (10 * "*" + "comparing forward" + 10 * "*" ) -print () - - -# test backward - -g = torch.randn_like(output3) -dq_ref, dk_ref, dv_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) -dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) -dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) - -print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) -print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) -print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) - -print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) -print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) -print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) - -print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) -print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) -print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) - - -print("Output dQ mean diff: {0}".format( (dq - dq_ref).abs().mean().item() )) -print("Output dK mean diff: {0}".format( (dk - dk_ref).abs().mean().item() )) -print("Output dV mean diff: {0}".format( (dv - dv_ref).abs().mean().item() )) - -print("Pytorch dQ mean diff: {0}".format( (dq_pt - dq_ref).abs().mean().item() )) -print("Pytorch dK mean diff: {0}".format( (dk_pt - dk_ref).abs().mean().item() )) -print("Pytorch dV mean diff: {0}".format( (dv_pt - dv_ref).abs().mean().item() )) - -print("Output dQ mean diff with Pytorch: {0}".format( (dq - dq_pt).abs().mean().item() )) -print("Output dK mean diff with Pytorch: {0}".format( (dk - dk_pt).abs().mean().item() )) -print("Output dV mean diff with 
Pytorch: {0}".format( (dv - dv_pt).abs().mean().item() )) - -print ("less than twice in max error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file diff --git a/benchmarks/test/test_forward_without_bias_mask.py b/benchmarks/test/test_forward_without_bias_mask.py deleted file mode 100644 index 930479cba..000000000 --- a/benchmarks/test/test_forward_without_bias_mask.py +++ /dev/null @@ -1,247 +0,0 @@ -import torch -import torch.nn as nn - -from functools import partial -import math -from typing import Optional, Callable, List, Tuple, Sequence -import numpy as np -import deepspeed - -from flash_attn.flash_attn_interface import flash_attn_unpadded_func - - -def permute_final_dims(tensor: torch.Tensor, inds: List[int]): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - -@torch.jit.ignore -def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Softmax, but without automatic casting to fp32 when the input is of - type bfloat16 - """ - d = t.dtype - # if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): - # with torch.cuda.amp.autocast(enabled=False): - # s = torch.nn.functional.softmax(t, dim=dim) - # else: - # s = torch.nn.functional.softmax(t, dim=dim) - s = torch.nn.functional.softmax(t, dim=dim) - return s - - -def _attention(query, key, value, mask=None, biases=None, upcast=False) -> torch.Tensor: - # upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - # output back to fp16/bf16. - dtype_og = query.dtype - if upcast: - query = query.float() - key = key.float() - value = value.float() - if mask is not None: - mask = mask.float() - - # [*, H, C_hidden, K] - key = permute_final_dims(key, (1, 0)) - - # [*, H, Q, K] - a = torch.matmul(query, key) - # print ("q * k: ", a) - - if biases is None: - biases = [] - for b in biases: - a += b - # print ("after bias:", a) - - if mask is not None: - # a += mask - # import pdb; pdb.set_trace() - # please do not use add now - a.masked_fill_(mask < 0, float('-inf')) - - # print ("after mask:", a) - - a = softmax_no_cast(a, -1) - # print ("softmax :", a) - - # [*, H, Q, C_hidden] - b = torch.matmul(a, value) - # print ("p * v: ", a) - return b.to(dtype_og), a.to(dtype_og) - - -def _flash_attn(q, k, v, attn_mask=None): - batch_dims = q.shape[:-3] - no_heads, n, c = q.shape[-3:] - dtype = q.dtype - - # [*, B, N, H, C] - q = q.transpose(-2, -3) - k = k.transpose(-2, -3) - v = v.transpose(-2, -3) - - # [B_flat, N, H, C] - q = q.reshape(-1, *q.shape[-3:]) - k = k.reshape(-1, *k.shape[-3:]) - v = v.reshape(-1, *v.shape[-3:]) - - # Flattened batch size - batch_size = q.shape[0] - - # [B_flat * N, H, C] - q = q.reshape(-1, *q.shape[-2:]) - k = k.reshape(-1, *k.shape[-2:]) - v = v.reshape(-1, *v.shape[-2:]) - - q_max_s = n - q_cu_seqlens = torch.arange( - 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device - ) - - k_max_s = n - k_cu_seqlens = torch.arange( - 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device - ) - - if attn_mask is not None: - # import pdb; pdb.set_trace() - attn_mask = attn_mask.reshape([bs * n, no_heads, n, n]).contiguous() - - out = flash_attn_unpadded_func( - q, - k, - v, - q_cu_seqlens, - k_cu_seqlens, - q_max_s, - k_max_s, - attn_mask=attn_mask, - dropout_p = 0., - softmax_scale = 1., # q has been scaled already - ) - - # [*, B, N, H, C] - out = out.reshape(*batch_dims, n, 
no_heads, c) - return out - - -def gen_attn_mask(mask, neg_inf): - assert neg_inf < -1e4 - attn_mask = torch.zeros_like(mask) - attn_mask[mask == 0] = neg_inf - return attn_mask - - -torch.manual_seed(0) -# v2 -bs = 1 -seq = 128 -head = 1 -c_dim = 16 - -# mini -# bs = 1 -# seq = 2 -# head = 1 -# c_dim = 16 - -seq_q = seq_k = seq_v = seq - -print (10 * "*" + "prepare data" + 10 * "*" ) -dtype = torch.bfloat16 -# dtype = torch.half -device = "cuda" - -# orig_tensor = torch.stack( -# [ (i+1) * 0.1 * torch.randn((bs, seq, head, c_dim)) for i in range(seq) ] -# ,dim = 1 -# ).to(device).to(dtype) - -orig_tensor = torch.empty((bs, seq, head, seq, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) -orig_tensor.requires_grad = True -# print ("tensor: ", orig_tensor) -print ("origin shape: ", orig_tensor.shape) -# [bs, seq, seq, head, c_dim] - - -print (10 * "*" + "normal attn fp32" + 10 * "*" ) -normal_attn_v1 = orig_tensor.clone() -output_ref, softmax_output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, mask=None, upcast=True) -# be careful here -output_ref = output_ref.transpose(-2, -3) -print ("attention output shape: ", output_ref.shape) -print (10 * "*" + "normal attn fp32" + 10 * "*" ) -print () - - -print (10 * "*" + "normal attn fp16" + 10 * "*" ) -normal_attn_v2 = orig_tensor.clone() -output_pt, softmax_output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, mask=None) -# be careful here -output_pt = output_pt.transpose(-2, -3) -print ("attention output shape: ", output_pt.shape) -print (10 * "*" + "normal attn fp32" + 10 * "*" ) -print () - - -print (10 * "*" + "flash attn" + 10 * "*" ) -normal_attn_flash = orig_tensor.clone() -output3 = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, attn_mask=None) -# import pdb; pdb.set_trace() -print ("flash attn output shape: ", output3.shape) -print (10 * "*" + "flash attn" + 10 * "*" ) -print () - -# print ("max abs error: ", (output3 - output_ref).abs().max()) -# print ("all close at pre@.2: ", torch.allclose(output3, output_ref, atol=1e-2)) - -print (10 * "*" + "comparing forward" + 10 * "*" ) -print("Output max diff: {0}".format((output3 - output_ref).abs().max().item())) -print("Output mean diff: {0}".format((output3 - output_ref).abs().mean().item())) - -print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) -print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) - -print("Output max diff with Pytorch: {0}".format((output3 - output_pt).abs().max().item())) -print("Output mean diff with Pytorch: {0}".format((output3 - output_pt).abs().mean().item())) - -print ("less than twice error: ", (output3 - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) -print (10 * "*" + "comparing forward" + 10 * "*" ) -print () - - -# test backward - -g = torch.randn_like(output3) -dq_ref, dk_ref, dv_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1), g) -dq_pt, dk_pt, dv_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2), g) -dq, dk, dv, = torch.autograd.grad(output3, (normal_attn_flash, normal_attn_flash, normal_attn_flash), g) - -print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) -print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) -print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) - -print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) 
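Backward agreement in these scripts is checked by pushing one shared random cotangent g through every implementation with torch.autograd.grad instead of calling .backward(), so the resulting input gradients can be compared directly. A compact sketch of that pattern with stand-in computations:

import torch

q = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8)

out_a = q @ w                              # implementation under test (stand-in)
out_b = (q.double() @ w.double()).float()  # higher-precision reference (stand-in)

g = torch.randn_like(out_a)                # same cotangent for both backward passes
dq_a, = torch.autograd.grad(out_a, q, g)
dq_b, = torch.autograd.grad(out_b, q, g)
print((dq_a - dq_b).abs().max().item())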
-print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) -print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) - -print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) -print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) -print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) - - -print("Output dQ mean diff: {0}".format( (dq - dq_ref).abs().mean().item() )) -print("Output dK mean diff: {0}".format( (dk - dk_ref).abs().mean().item() )) -print("Output dV mean diff: {0}".format( (dv - dv_ref).abs().mean().item() )) - -print("Pytorch dQ mean diff: {0}".format( (dq_pt - dq_ref).abs().mean().item() )) -print("Pytorch dK mean diff: {0}".format( (dk_pt - dk_ref).abs().mean().item() )) -print("Pytorch dV mean diff: {0}".format( (dv_pt - dv_ref).abs().mean().item() )) - -print("Output dQ mean diff with Pytorch: {0}".format( (dq - dq_pt).abs().mean().item() )) -print("Output dK mean diff with Pytorch: {0}".format( (dk - dk_pt).abs().mean().item() )) -print("Output dV mean diff with Pytorch: {0}".format( (dv - dv_pt).abs().mean().item() )) - -print ("less than twice in max error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) \ No newline at end of file diff --git a/benchmarks/test_example.py b/benchmarks/test_example.py deleted file mode 100644 index 086d13463..000000000 --- a/benchmarks/test_example.py +++ /dev/null @@ -1,213 +0,0 @@ -import torch -import torch.nn as nn - -from functools import partial -import math -from typing import Optional, Callable, List, Tuple, Sequence -import numpy as np -import deepspeed - -from time import perf_counter_ns - -from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func - -def permute_final_dims(tensor: torch.Tensor, inds: List[int]): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - -@torch.jit.ignore -def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Softmax, but without automatic casting to fp32 when the input is of - type bfloat16 - """ - d = t.dtype - if(d is torch.bfloat16 and not deepspeed.utils.is_initialized()): - with torch.cuda.amp.autocast(enabled=False): - s = torch.nn.functional.softmax(t, dim=dim) - else: - s = torch.nn.functional.softmax(t, dim=dim) - - return s - -def _attention(query, key, value, mask=None, biases=None) -> torch.Tensor: - # [*, H, C_hidden, K] - key = permute_final_dims(key, (1, 0)) - - # import pdb; pdb.set_trace() - # [*, H, Q, K] - a = torch.matmul(query, key) - - print ("q * k: ", a) - - if biases is None: - biases = [] - for b in biases: - a += b - - print ("after bias:", a) - - if mask is not None: - a += mask - - print ("after mask:", a) - - a = softmax_no_cast(a, -1) - print ("softmax :", a) - - # [*, H, Q, C_hidden] - a = torch.matmul(a, value) - print ("p * v: ", a) - - return a - - -torch.manual_seed(0) -# v2 -bs = 1 -seq = 2 -head = 1 -c_dim = 16 - -# import pdb; pdb.set_trace() - -print (10 * "*" + "prepare data" + 10 * "*" ) -# dtype = torch.bfloat16 -dtype = torch.half -device = "cuda" - -orig_tensor = torch.stack( - [ (i+1) * 0.1 * torch.ones((bs, seq, head, c_dim)) for i in range(seq) ] - ,dim = 1 -).cuda().to(dtype) - -print ("tensor: ", orig_tensor) -print ("origin shape: ", orig_tensor.shape) -# [bs, seq, seq, head, c_dim] - -batch_size = bs 
* seq -seqlen = seq -max_s = seqlen -cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, - device=orig_tensor.device) - -print ("cu_seqlens: ", cu_seqlens) - -# [bs, seq, seq, head, c_dim] -orig_tensor = orig_tensor.permute([0, 1, 3, 2, 4]) -# [bs, seq, head, seq, c_dim] -print ("after permute: ", orig_tensor.shape) - -print (10 * "*" + "end prepare data" + 10 * "*" ) - -print (10 * "*" + "normal attn" + 10 * "*" ) -print ("normal attn: ", _attention(orig_tensor, orig_tensor, orig_tensor)) -print (10 * "*" + "end normal attn" + 10 * "*" ) - -tensor_2d_pad = orig_tensor.reshape(-1, head, c_dim) - -print (10 * "*" + "flash attn without mask" + 10 * "*" ) -output3 = flash_attn_unpadded_func( - tensor_2d_pad, - tensor_2d_pad, - tensor_2d_pad, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - dropout_p = 0., - softmax_scale = 1., # q has been scaled already -) - -print ("output3 shape: ", output3.shape) -output3 = output3.reshape((bs, seq, seq, head, c_dim)) -print ("output3: ", output3.shape) -print (10 * "*" + "end flash attn without mask" + 10 * "*" ) - -def gen_attn_mask(mask, neg_inf): - assert neg_inf < -1e4 - attn_mask = torch.zeros_like(mask) - attn_mask[mask == 0] = neg_inf - return attn_mask - -# mask = gen_attn_mask( -# ( -# # [bs, seq, h, seq, seq_k] -# # [bs, seq, 1, 1, seq_k] -# torch.ones( -# bs, -# seq, -# 1, -# 1, -# seq, -# dtype=dtype, -# device="cuda", -# ) -# > 0.2 -# ).type(dtype), -# -1e5, -# ) -# unicore mask -# torch.rand( -# n_batch, -# n_groups, -# 1, -# 1, -# last_dim, -# dtype=dtype, -# device=test_device, -# ) - -print (10 * "*" + "flash attn with mask" + 10 * "*" ) -mask = torch.randn( - bs, - seq, - 1, - 1, - seq, - dtype=dtype, - device="cuda", - ) - -# [bs, group, 1, 1, seq_k] -seq_q = seq -seq_k = seq -print ("mask: ", mask.shape) -mask_exp = mask.expand([bs, seq_q, head, seq_q, seq_k]) -print ("mask_exp: ", mask_exp.shape) -mask_batch = mask_exp.reshape((bs*seq_q, head, seq_q, seq_k)) -print ("mask_exp: ", mask_batch.shape) - -print ("mask: ", mask_batch) -print ("tensor: ", tensor_2d_pad) -print ("mask maximum number :", mask_batch.abs().max()) - -# bs * seq -# batch_size, num_heads, max_seqlen_q, max_seqlen_k -output4 = flash_attn_unpadded_func(tensor_2d_pad, - tensor_2d_pad, - tensor_2d_pad, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - # None, - attn_mask=mask_batch, - attn_bias=mask_batch, - dropout_p=0.0, - softmax_scale=1.0) - -output4 = output4.reshape((bs, seq, seq, head, c_dim)) - -print ("output4: ", output4.shape) - -print (10 * "*" + "end flash attn with mask" + 10 * "*" ) - -print (10 * "*" + "normal attn with mask" + 10 * "*" ) -print ("normal attn: ", _attention(orig_tensor, orig_tensor, orig_tensor, mask=mask)) -print (10 * "*" + "end normal attn with mask" + 10 * "*" ) - -print ("all close on output3 and output4 max error", (output3 - output4).abs().max()) -print ("all close on output3 and output4 min error", (output3 - output4).abs().min()) -print ("all close on output3 and output4 num less min error", torch.sum( (output3 - output4).abs() <=(output3 - output4).abs().min() )) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index a60ed28f9..d4da2eba5 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -222,7 +222,7 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, params.attn_ds_ptr = attn_ds; } -#ifdef DDEBUG_PRINT +#ifdef DEBUG_PRINT void dump_tensor(const std::string &tensor_name, const at::Tensor &tensor, const std::string &label) { std::string 
file_name = label + "_" + tensor_name + ".data"; std::ofstream file(file_name.c_str()); @@ -361,7 +361,6 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q } int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; bool loop = max_seqlen_k > blocksize_c; - // loop over blocks more than once ? auto opts = q.options(); diff --git a/flash_attn/Attention.py b/flash_attn/attention.py similarity index 63% rename from flash_attn/Attention.py rename to flash_attn/attention.py index 0ff736f13..2a44b9d9f 100644 --- a/flash_attn/Attention.py +++ b/flash_attn/attention.py @@ -1,14 +1,15 @@ import torch - +import torch.nn.functional as F from flash_attn.flash_attn_interface import flash_attn_unpadded_func -def flash_attn(q, k, v): - # import pdb; pdb.set_trace() +def _flash_attn(q, k, v, mask=None, bias=None): batch_dims = q.shape[:-3] no_heads, n, c = q.shape[-3:] dtype = q.dtype + k_no_heads, k_n, k_c = k.shape[-3:] + # [*, B, N, H, C] q = q.transpose(-2, -3) k = k.transpose(-2, -3) @@ -21,6 +22,7 @@ def flash_attn(q, k, v): # Flattened batch size batch_size = q.shape[0] + k_batch_size = k.shape[0] # [B_flat * N, H, C] q = q.reshape(-1, *q.shape[-2:]) @@ -32,11 +34,19 @@ def flash_attn(q, k, v): 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device ) - k_max_s = n + k_max_s = k_n k_cu_seqlens = torch.arange( - 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=k.device + 0, (k_batch_size + 1) * k_n, step=k_n, dtype=torch.int32, device=k.device ) + if mask is not None: + mask_heads, tgt_len, src_len = mask.shape[-3:] + mask = mask.reshape(-1 , mask_heads, tgt_len, src_len).contiguous() + + if bias is not None: + bias_heads, tgt_len, src_len = bias.shape[-3:] + bias = bias.reshape(-1 , bias_heads, tgt_len, src_len).contiguous() + out = flash_attn_unpadded_func( q, k, @@ -45,12 +55,12 @@ def flash_attn(q, k, v): k_cu_seqlens, q_max_s, k_max_s, + attn_mask=mask, + attn_bias=bias, dropout_p = 0., softmax_scale = 1., # q has been scaled already ) + # [*, B, N, H, C] out = out.reshape(*batch_dims, n, no_heads, c) - - out = out.to(dtype=dtype) - - return out \ No newline at end of file + return out diff --git a/setup.py b/setup.py index bc7e36798..f627b7d5f 100644 --- a/setup.py +++ b/setup.py @@ -125,13 +125,11 @@ def append_nvcc_threads(nvcc_extra_args): "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu", ], extra_compile_args={ - # "cxx": ["-O3"] + generator_flag, - "cxx": ["-O3", "-DDEBUG_PRINT"] + generator_flag, + "cxx": ["-O3"] + generator_flag, "nvcc": append_nvcc_threads( [ "-O3", "-t4", - "-DDEBUG_PRINT", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "--expt-relaxed-constexpr", diff --git a/tests/build.sh b/tests/build.sh deleted file mode 100644 index 2f11ad907..000000000 --- a/tests/build.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/bin/bash -# csrc_path=../csrc/flash_attn -# csrc_path=/workspace/openfold/single_test/flash_attn/flash-attention_v2/csrc/flash_attn -csrc_path=../csrc/flash_attn -src_file= -src_file+=test_forward.cu -src_file+=" ${csrc_path}/fmha_api.cpp" -src_file+=" ${csrc_path}/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu" -src_file+=" ${csrc_path}/src/fmha_block_fprop_fp16_kernel.sm80.cu" -src_file+=" ${csrc_path}/src/fmha_dgrad_fp16_kernel_loop.sm80.cu" -src_file+=" ${csrc_path}/src/fmha_fprop_fp16_kernel.sm80.cu" - -echo ${src_file} - -echo ${csrc_path}/ -echo ${csrc_path}/src -echo ${csrc_path}/cutlass/include - -# nvcc -o test ${src_file} \ -/usr/local/cuda-11.3/bin/nvcc -v -o test ${src_file} \ - 
--compiler-options='-Wl\,--no-as-needed' \ - -lc10 -ltorch -ltorch_cpu -lcudart -lc10_cuda -ltorch_cuda -ltorch_cuda_cu -ltorch_cuda_cpp \ - -I ./ \ - -I ${csrc_path} \ - -I ${csrc_path}/src \ - -I ${csrc_path}/cutlass/include \ - -I /opt/conda/lib/python3.7/site-packages/torch/include \ - -I /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include \ - -I /opt/conda/lib/python3.7/site-packages/torch/include/TH \ - -I /opt/conda/lib/python3.7/site-packages/torch/include/THC \ - -I /opt/conda/include \ - -I /opt/conda/include/python3.7m \ - -L /opt/conda/lib/python3.7/site-packages/torch/lib/ \ - -L /usr/local/cuda-11.3/lib64/ \ - -L /opt/conda/lib64/ \ - -L /opt/conda/lib/ \ - -g -G \ - -t 4 \ - -D_GLIBCXX_USE_CXX11_ABI=0 \ - -DDEBUG_PRINT \ - -DDEBUG_USING_NVCC \ - -gencode arch=compute_80,code=sm_80 \ - -U__CUDA_NO_HALF_OPERATORS__ \ - -U__CUDA_NO_HALF_CONVERSIONS__ \ - --expt-relaxed-constexpr \ - --expt-extended-lambda \ - --use_fast_math - - diff --git a/tests/fmha_api.h b/tests/fmha_api.h deleted file mode 100644 index a1ff42204..000000000 --- a/tests/fmha_api.h +++ /dev/null @@ -1,48 +0,0 @@ -#pragma once - -#include -#include - -#include "fmha.h" - -std::vector -mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - const int max_seqlen_q_, - const int max_seqlen_k_, - const float p_dropout, - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - const bool return_softmax, - c10::optional gen_, - const c10::optional &attn_mask, // attn_mask - const c10::optional &attn_bias // attn bias - ); - - -std::vector -mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size - const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse_, // b x h x s softmax logsumexp - at::Tensor &dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - at::Tensor &dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - at::Tensor &dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - const int max_seqlen_q_, - const int max_seqlen_k_, // max sequence length to choose the kernel - const float p_dropout, // probability to drop - const float softmax_scale, - const bool zero_tensors, - const bool is_causal, - c10::optional gen_, - const c10::optional &attn_mask, // attn_mask - const c10::optional &attn_bias // attn bias -); \ No newline at end of file diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py deleted file mode 100644 index 15afa3fc9..000000000 --- a/tests/test_flash_attn.py +++ /dev/null @@ -1,667 +0,0 @@ -import math - -import torch -import torch.nn.functional as F - -import pytest - -from einops import rearrange, repeat - -from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_qkvpacked_func, _get_block_size, flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func -from 
flash_attn.bert_padding import unpad_input, pad_input, index_first_axis - - -is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5) - - -def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'): - assert mode in ['full', 'random', 'third'] - if mode == 'full': - lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) - elif mode == 'random': - lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) - elif mode == 'third': - lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) - padding_mask = repeat(torch.arange(max_seqlen, device=device), 's -> b s', b=batch_size) < lengths - return padding_mask - - -def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None, - kvpacked=False, qkvpacked=False): - """ - Arguments: - x: (batch_size, seqlen, nheads * d) - Wqkv: nn.Linear(nheads * d, 3 * nheads * d) - query_padding_mask: (batch_size, seqlen), bool - key_padding_mask: (batch_size, seqlen), bool - """ - assert not (kvpacked and qkvpacked) - batch_size, seqlen, dim = x.shape - q, k, v = Wqkv(x).chunk(3, dim=-1) - - if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) - q_unpad = rearrange(q_unpad, 'nnz (h d) -> nnz h d', h=nheads) - output_pad_fn = lambda output_unpad: rearrange( - pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen), - 'b s (h d) -> b s h d', h=nheads - ) - else: - q_unpad = rearrange(q, 'b s (h d) -> (b s) h d', h=nheads) - cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, - device=q_unpad.device) - max_seqlen_q = seqlen - output_pad_fn = lambda output_unpad: rearrange(output_unpad, '(b s) h d -> b s h d', b=batch_size) - - if key_padding_mask is not None: - k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - k_unpad = rearrange(k_unpad, 'nnz (h d) -> nnz h d', h=nheads) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) - v_unpad = rearrange(v_unpad, 'nnz (h d) -> nnz h d', h=nheads) - else: - k_unpad = rearrange(k, 'b s (h d) -> (b s) h d', h=nheads) - v_unpad = rearrange(v, 'b s (h d) -> (b s) h d', h=nheads) - cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, - device=q_unpad.device) - max_seqlen_k = seqlen - - if qkvpacked: - assert (query_padding_mask == key_padding_mask).all() - qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) - qkv = rearrange(torch.stack([q, k, v], dim=2), 'b s t (h d) -> b s t h d', h=nheads) - if query_padding_mask is not None: - dqkv_pad_fn = lambda dqkv_unpad: rearrange( - pad_input(rearrange(dqkv_unpad, 'nnz t h d -> nnz (t h d)'), indices_q, batch_size, seqlen), - 'b s (t h d) -> b s t h d', t=3, h=nheads - ) - else: - dqkv_pad_fn = lambda dqkv_unpad: rearrange(dqkv_unpad, '(b s) t h d -> b s t h d', b=batch_size) - return (qkv_unpad.detach().requires_grad_(), cu_seqlens_q, max_seqlen_q, - qkv.detach().requires_grad_(), output_pad_fn, dqkv_pad_fn) - elif kvpacked: - kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) - q = rearrange(q, 'b s (h d) -> b s h d', h=nheads) - kv = rearrange(torch.stack([k, v], dim=2), 'b s t (h d) -> b s t h d', h=nheads) - dq_pad_fn = output_pad_fn - if key_padding_mask is not None: - dkv_pad_fn = lambda dkv_unpad: rearrange( - pad_input(rearrange(dkv_unpad, 'nnz t h d -> nnz (t h d)'), indices_k, batch_size, seqlen), - 'b s (t h d) -> b s t h 
d', t=2, h=nheads - ) - else: - dkv_pad_fn = lambda dkv_unpad: rearrange(dkv_unpad, '(b s) t h d -> b s t h d', b=batch_size) - return (q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(), - cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - q.detach().requires_grad_(), kv.detach().requires_grad_(), - output_pad_fn, dq_pad_fn, dkv_pad_fn) - else: - q, k, v = [rearrange(z, 'b s (h d) -> b s h d', h=nheads).detach().requires_grad_() - for z in [q, k, v]] - dq_pad_fn = output_pad_fn - if key_padding_mask is not None: - dk_pad_fn = lambda dk_unpad: rearrange( - pad_input(rearrange(dk_unpad, 'nnz h d -> nnz (h d)'), indices_k, batch_size, seqlen), - 'b s (h d) -> b s h d', h=nheads - ) - else: - dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, '(b s) h d -> b s h d', b=batch_size) - return (q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(), - v_unpad.detach().requires_grad_(), - cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - q, k, v, - output_pad_fn, dq_pad_fn, dk_pad_fn) - - -def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0, - dropout_mask=None, causal=False, upcast=True, reorder_ops=False): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads, head_dim) - v: (batch_size, seqlen_k, nheads, head_dim) - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) - dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) - upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - output back to fp16/bf16. - reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) - without changing the math. This is to estimate the numerical error from operation - reordering. 
- Output: - output: (batch_size, seqlen_q, nheads, head_dim) - attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout - """ - dtype_og = q.dtype - if upcast: - q, k, v = q.float(), k.float(), v.float() - seqlen_q, seqlen_k = q.shape[1], k.shape[1] - d = q.shape[-1] - if not reorder_ops: - scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k) - else: - scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d)) - if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf')) - if causal: - causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) - scores.masked_fill_(causal_mask, float('-inf')) - attention = torch.softmax(scores, dim=-1) - dropout_scaling = 1.0 / (1 - dropout_p) - # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling - # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) - if dropout_mask is not None: - attention_drop = attention.masked_fill(~dropout_mask, 0.0) - output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling) - if query_padding_mask is not None: - output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0) - attention = attention.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) - - -def attention_kvpacked_ref(q, kv, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0, - dropout_mask=None, causal=False, upcast=True, reorder_ops=False): - return attention_ref(q, kv[:, :, 0], kv[:, :, 1], query_padding_mask, - key_padding_mask, dropout_p, dropout_mask, upcast=upcast, causal=causal, - reorder_ops=reorder_ops) - - -def attention_qkvpacked_ref(qkv, key_padding_mask=None, dropout_p=0.0, - dropout_mask=None, causal=False, upcast=True, reorder_ops=False): - return attention_ref(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], key_padding_mask, - key_padding_mask, dropout_p, dropout_mask, upcast=upcast, causal=causal, - reorder_ops=reorder_ops) - - -def generate_sparsity_mask(seqlen, sparsity=0.3): - repeats = seqlen // 16 // 2 - # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'), - # torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) - # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'), - # torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) - # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1) - # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1) - nrow, ncol = seqlen // 16, seqlen // 256 - mask = torch.rand(nrow, ncol, device='cuda') < sparsity - return mask - - -def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask): - """ - Arguments: - qkv: (batch_size, seqlen, 3, nheads, head_dim) - blockmask: (seqlen / 16, seqlen / 256) - attn_mask: (batch_size, seqlen) - dropout_p: float - dropout_mask: (batch_size, nheads, seqlen, seqlen) - Output: - output: (batch_size, seqlen, nheads, head_dim) - attention: softmax after dropout - """ - q, k, v = qkv.float().unbind(dim=2) - d = qkv.shape[-1] - seqlen = qkv.shape[1] - scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k) - scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf')) - blockmask = repeat(blockmask, 's_16 s_256 -> (s_16 16) (s_256 256)') - blockmask = blockmask[:seqlen, :seqlen] - 
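The block-sparse reference above stores the mask at coarse granularity, one entry per 16 query rows and 256 key columns, and expands it to a dense (seqlen, seqlen) boolean mask with einops.repeat before applying it to the scores. A small numeric sketch of that expansion:

import torch
from einops import repeat

seqlen = 512
nrow, ncol = seqlen // 16, seqlen // 256     # coarse block grid, here (32, 2)
blockmask = torch.rand(nrow, ncol) < 0.3     # True = keep this block of scores
full = repeat(blockmask, 's_16 s_256 -> (s_16 16) (s_256 256)')[:seqlen, :seqlen]
print(blockmask.shape, full.shape)           # torch.Size([32, 2]) torch.Size([512, 512])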
scores.masked_fill_(rearrange(~blockmask, 't s -> 1 1 t s'), float('-inf')) - attention = torch.softmax(scores, dim=-1) - attention = attention.masked_fill(rearrange(~attn_mask, 'b s -> b 1 s 1'), 0.0) - attention = attention.masked_fill_(rearrange(~blockmask, 't s -> 1 1 t s'), 0.0) - attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p) - output = torch.einsum('bhts,bshd->bthd', attention_drop , v) - output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1 1'), 0) - return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype) - - -def convert_flash_attn_S_to_softmax(S, query_padding_mask, key_padding_mask, head_dim, is_dropout, - causal=False): - """FlashAttention stores the S matrix in a different way. - Arguments: - S: (batch_size, nheads, seqlen_q, seqlen_k) - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) - """ - S_flat = rearrange(S, 'b h t s -> b h (t s)') - seqlen_q, seqlen_k = S.shape[-2:] - block_size = _get_block_size(S.device, head_dim, is_dropout) - loop_steps = (seqlen_k + block_size - 1) // block_size - warps_n = 4 - mmas_n = (seqlen_k // warps_n // 16) if seqlen_k <= block_size else (block_size // warps_n // 16) - S_converted = rearrange(S_flat, 'b h (loop nsteps mmas_n warps_n eight t r c0 c1) -> b h (nsteps r eight) (loop mmas_n warps_n c0 t c1)', - loop=loop_steps, nsteps=seqlen_q // 16, mmas_n=mmas_n, warps_n=warps_n, eight=8, t=4, - r=2, c0=2, c1=2) - - # Need to zero out things not in attention_mask in case S was initialized with random values - # and some of those values aren't overwritten. - seqlen_q_og = query_padding_mask.shape[-1] - if seqlen_q_og < seqlen_q: - query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og)) - else: - query_padding_mask = query_padding_mask[:, :seqlen_q] - S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0) - seqlen_k_og = key_padding_mask.shape[-1] - if seqlen_k_og < seqlen_k: - key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og)) - else: - key_padding_mask = key_padding_mask[:, :seqlen_k] - S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), 0.0) - if causal: - causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1) - S_converted.masked_fill_(causal_mask, 0.0) - if seqlen_q_og < seqlen_q: - S_converted = S_converted[:, :, :seqlen_q_og, :] - else: - S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q)) - if seqlen_k_og < seqlen_k: - S_converted = S_converted[:, :, :, :seqlen_k_og] - else: - S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k)) - return S_converted - - -def normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask=None, key_padding_mask=None, - is_dropout=False, causal=False): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k, v: (batch_size, seqlen_k, nheads, head_dim) - key_padding_mask: (batch_size, seqlen_q) - Output: - softmax_lse: (batch_size, nheads, seqlen_q) - softmax_max: (batch_size, nheads, seqlen_q) - """ - q, k, v = q.float(), k.float(), v.float() - _, seqlen_q, _, head_dim = q.shape - seqlen_k = k.shape[1] - scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(head_dim), k) - if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf')) - if causal: - causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) - scores.masked_fill_(causal_mask, 
float('-inf')) - block_size = _get_block_size(scores.device, head_dim, is_dropout) - scores_block = scores.split(block_size, dim=-1) - lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) - lcse_block = torch.logcumsumexp(lse_block, dim=-1).unbind(dim=-1) - scores_max_block = ([torch.amax(scores_block[0], dim=-1)] - + [torch.maximum(torch.amax(s, dim=-1), lcse) - for s, lcse in zip(scores_block[1:], lcse_block[:-1])]) - attn_unnorm_block = attn_unnorm.split(block_size, dim=-1) - attn_norm = torch.cat([a / rearrange(torch.exp(lcse_block[-1] - m), 'b h s -> b h s 1') - for a, m in zip(attn_unnorm_block, scores_max_block)], dim=-1) - if query_padding_mask is not None: - attn_norm.masked_fill_(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), 0.0) - return attn_norm.to(dtype=attn_unnorm.dtype) - - -def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False): - """ - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop. - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) - """ - batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape - dropped = ~dropout_mask - if query_padding_mask is not None: - dropped.masked_fill_(rearrange(~query_padding_mask, 'b s -> b 1 s 1'), False) - if key_padding_mask is not None: - dropped.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), False) - if causal: - causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, - device=dropout_mask.device), 1) - dropped.masked_fill_(causal_mask, False) - dropped_total = dropped.sum() - query_lengths = (query_padding_mask.sum(dim=-1) if query_padding_mask is not None - else torch.full((batch_size,), seqlen_q, device=dropout_mask.device)) - key_lengths = (key_padding_mask.sum(dim=-1) if key_padding_mask is not None - else torch.full((batch_size,), seqlen_k, device=dropout_mask.device)) - if not causal: - numel_per_batch = query_lengths * key_lengths - else: - numel_per_batch = torch.where( - query_lengths <= key_lengths, - query_lengths * (query_lengths + 1) / 2, - query_lengths * key_lengths - (key_lengths * (key_lengths - 1) / 2) - ) - return dropped_total / (numel_per_batch.sum() * nheads) - - -@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('d', [128, 64, 32, 16]) -# @pytest.mark.parametrize('d', [64]) -@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype): - if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: - pytest.skip() # Reference implementation OOM - device = 'cuda' - # if dtype == torch.float16: - # rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3) - # else: # torch.bfloat16 - # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3) - # set seed - torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) - - key_padding_mask = generate_random_padding_mask(seqlen, 
batch_size, device, mode='random') - # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') - - qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( - x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True - ) - - output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func( - qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal - ) - output = output_pad_fn(output_unpad) - S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - dropout_mask = S_dmask_converted >= 0 - attn_unnorm = S_dmask_converted.abs() - attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], - key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal) - dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask, - causal=causal).item() - - output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, - causal=causal) - output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, - causal=causal, upcast=False, reorder_ops=True) - print(f'Actual dropout fraction: {dropout_fraction}') - print(f'Output max diff: {(output - output_ref).abs().max().item()}') - print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') - print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') - print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') - print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - - if not (is_sm75 and d == 128): - g = torch.randn_like(output) - dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g) - dqkv = dqkv_pad_fn(dqkv_unpad) - dqkv_ref, = torch.autograd.grad(output_ref, qkv, g) - dqkv_pt, = torch.autograd.grad(output_pt, qkv, g) - print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') - print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') - print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') - print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}') - print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') - print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') - print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') - print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}') - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. 
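The acceptance criterion spelled out in the comment above treats a plain low-precision PyTorch run as the noise floor: the kernel's maximum error against the fp32 reference must stay within twice that run's error. As a standalone helper (hypothetical name), the check that follows amounts to:

import torch

def within_twice_pytorch_error(out_kernel, out_pt, out_ref, factor=2.0):
    # max |kernel - fp32 ref| must not exceed `factor` times max |low-precision pytorch - fp32 ref|
    kernel_err = (out_kernel - out_ref).abs().max().item()
    pt_err = (out_pt - out_ref).abs().max().item()
    return kernel_err <= factor * pt_err

ref = torch.randn(8, 8)     # stands in for the fp32 reference output
pt = ref + 2e-3             # emulate the fp16/bf16 PyTorch error
kernel = ref + 1e-3         # emulate a kernel error of the same order
print(within_twice_pytorch_error(kernel, pt, ref))   # True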
- assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol) - if dropout_p == 0.0: - assert dropout_mask.all() - else: - assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - - if not (is_sm75 and d == 128): - assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() - # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('d', [128, 64, 32, 16]) -# @pytest.mark.parametrize('d', [64]) -@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype): - if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: - pytest.skip() # Reference implementation OOM - device = 'cuda' - # if dtype == torch.float16: - # rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3) - # else: # torch.bfloat16 - # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3) - # set seed - torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) - - query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - - (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, kv, - output_pad_fn, dq_pad_fn, dkv_pad_fn) = generate_qkv( - x, Wqkv, nheads, query_padding_mask, key_padding_mask, kvpacked=True - ) - - output_unpad, sm_lse, S_dmask = flash_attn_unpadded_kvpacked_func( - q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, return_attn_probs=True, causal=causal - ) - output = output_pad_fn(output_unpad) - S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - dropout_mask = S_dmask_converted >= 0 - attn_unnorm = S_dmask_converted.abs() - attn = normalize_flash_attn_S(attn_unnorm, q, kv[:, :, 0], kv[:, :, 1], - query_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal) - dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, key_padding_mask, - causal=causal) - - output_ref, attn_ref = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask, - dropout_p, dropout_mask, causal=causal) - output_pt, attn_pt = attention_kvpacked_ref(q, kv, query_padding_mask, key_padding_mask, - dropout_p, dropout_mask, causal=causal, - upcast=False, reorder_ops=True) - print(f'Actual dropout fraction: {dropout_fraction}') - print(f'Output max diff: {(output - output_ref).abs().max().item()}') - print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') - print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') - 
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') - print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - - if not (is_sm75 and d == 128): - g = torch.randn_like(output) - dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g) - dq = dq_pad_fn(dq_unpad) - dkv = dkv_pad_fn(dkv_unpad) - dq_ref, dkv_ref, = torch.autograd.grad(output_ref, (q, kv), g) - dq_pt, dkv_pt = torch.autograd.grad(output_pt, (q, kv), g) - print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') - print(f'dK max diff: {(dkv[:, :, 0] - dkv_ref[:, :, 0]).abs().max().item()}') - print(f'dV max diff: {(dkv[:, :, 1] - dkv_ref[:, :, 1]).abs().max().item()}') - print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') - print(f'dK Pytorch max diff: {(dkv_pt[:, :, 0] - dkv_ref[:, :, 0]).abs().max().item()}') - print(f'dV Pytorch max diff: {(dkv_pt[:, :, 1] - dkv_ref[:, :, 1]).abs().max().item()}') - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol) - if dropout_p == 0.0: - assert dropout_mask.all() - else: - assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - - if not (is_sm75 and d == 128): - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() - assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item() - # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol) - # assert torch.allclose(dkv, dkv_ref, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('d', [128, 64, 32, 16]) -# @pytest.mark.parametrize('d', [64]) -@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype): - if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: - pytest.skip() # Reference implementation OOM - device = 'cuda' - # if dtype == torch.float16: - # rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3) - # else: # torch.bfloat16 - # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3) - # set seed - torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) - - query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - - (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v, - output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv( - x, Wqkv, nheads, query_padding_mask, key_padding_mask - ) - - output_unpad, sm_lse, S_dmask 
= flash_attn_unpadded_func( - q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, return_attn_probs=True, causal=causal - ) - output = output_pad_fn(output_unpad) - S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - dropout_mask = S_dmask_converted >= 0 - attn_unnorm = S_dmask_converted.abs() - attn = normalize_flash_attn_S(attn_unnorm, q, k, v, query_padding_mask, key_padding_mask, - dropout_p > 0.0, causal=causal) - dropout_fraction = get_dropout_fraction(dropout_mask, query_padding_mask, key_padding_mask, - causal=causal) - - output_ref, attn_ref = attention_ref(q, k, v, query_padding_mask, key_padding_mask, - dropout_p, dropout_mask, causal=causal) - output_pt, attn_pt = attention_ref(q, k, v, query_padding_mask, key_padding_mask, - dropout_p, dropout_mask, causal=causal, - upcast=False, reorder_ops=True) - print(f'Actual dropout fraction: {dropout_fraction}') - print(f'Output max diff: {(output - output_ref).abs().max().item()}') - print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') - print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') - print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') - print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - - if not (is_sm75 and d == 128): - g = torch.randn_like(output) - dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g) - dq = dq_pad_fn(dq_unpad) - dk = dk_pad_fn(dk_unpad) - dv = dk_pad_fn(dv_unpad) - dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g) - dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g) - print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') - print(f'dK max diff: {(dk - dk_ref).abs().max().item()}') - print(f'dV max diff: {(dv - dv_ref).abs().max().item()}') - print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') - print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}') - print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}') - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. 
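Further down, this test also requires the measured dropout fraction to stay within 1% of dropout_p. The denominator comes from get_dropout_fraction above: without a causal mask each (batch, head) contributes seqlen_q * seqlen_k valid positions; with a causal mask and seqlen_q == seqlen_k == n it contributes n(n+1)/2, since query row i keeps keys 0..i. A small pure-Python check of those counts (illustrative only):

def causal_numel(seqlen_q, seqlen_k):
    # Query i attends keys 0..i (clipped to seqlen_k); the complement is the
    # torch.triu(..., 1) region that get_dropout_fraction zeroes out.
    return sum(min(i + 1, seqlen_k) for i in range(seqlen_q))

assert causal_numel(8, 8) == 8 * 9 // 2                  # q == k
assert causal_numel(4, 8) == 4 * 5 // 2                  # q <= k branch
assert causal_numel(8, 4) == 8 * 4 - 4 * 3 // 2          # q >  k branch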
- assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() - # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol) - if dropout_p == 0.0: - assert dropout_mask.all() - else: - assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - - if not (is_sm75 and d == 128): - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() - assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() - assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() - # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol) - # assert torch.allclose(dk, dk_ref, rtol=rtol, atol=atol) - # assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('d', [128, 64, 32, 16]) -# @pytest.mark.parametrize('d', [64]) -@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [128]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.0]) -def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): - if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: - pytest.skip() # Reference implementation OOM - device = 'cuda' - # set seed - torch.random.manual_seed(0) - batch_size = 32 - nheads = 4 - x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) - Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) - - query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random') - - (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v, - output_pad_fn, dq_pad_fn, dk_pad_fn) = generate_qkv( - x, Wqkv, nheads, query_padding_mask, key_padding_mask - ) - - torch.random.manual_seed(0) - output_unpad_0, sm_lse_0, S_dmask_0 = flash_attn_unpadded_func( - q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, return_attn_probs=True, causal=causal - ) - S_dmask_converted_0 = convert_flash_attn_S_to_softmax( - S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - - if not (is_sm75 and d == 128): - g = torch.randn_like(output_unpad_0) - dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0, - (q_unpad, k_unpad, v_unpad), g) - - for _ in range(10): - torch.random.manual_seed(0) - output_unpad, sm_lse, S_dmask = flash_attn_unpadded_func( - q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, return_attn_probs=True, causal=causal - ) - S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - ) - assert torch.equal(output_unpad, output_unpad_0) - # sm_lse has some parts that are uninitialized from torch.empty - # assert torch.equal(sm_lse, sm_lse_0) - assert torch.equal(S_dmask_converted, S_dmask_converted_0) - - if not (is_sm75 and d == 128): 
- dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad, - (q_unpad, k_unpad, v_unpad), g) - assert torch.equal(dq_unpad, dq_unpad_0) - assert torch.equal(dk_unpad, dk_unpad_0) - assert torch.equal(dv_unpad, dv_unpad_0) diff --git a/tests/test_forward.cu b/tests/test_forward.cu deleted file mode 100644 index 5dbb20cd2..000000000 --- a/tests/test_forward.cu +++ /dev/null @@ -1,708 +0,0 @@ -#include -#include -//#include -#include -#include -#include -#include -#include - - -void dump_tensor(const std::string &tensor_name, at::Tensor &tensor, const std::string &label) { - std::string file_name = label + "_" + tensor_name + ".data"; - std::ofstream file(file_name.c_str()); - // file << tensor_name << std::endl; - // file << tensor << std::endl; - std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; - auto flatten_tensor = tensor.flatten(); - auto size = flatten_tensor.numel(); - - for (int i = 0; i < size; i ++) { - file << flatten_tensor[i].item() << " "; - // file << flatten_tensor[i] << " "; - } - file << std::endl; -} - -void test_fwd_with_mask(int has_mask) { - int batch_size = 1; - int nheads = 1; - int headdim = 16; - int max_seqlen_q_ = 8; - int max_seqlen_k_ = 8; - - float softmax_scale = 1; - - bool zero_tensors = false; - bool is_causal = false; - bool return_softmax = false; - - // q -> [bs * seq, head, head_dim] - // q -> [1 * 128, 1, 16] - // block q -> [128, 16] - - // k -> [bs * seq, head, head_dim] - // k -> [1 * 128, 1, 16] - // block k -> [128, 16] - - // v -> [bs * seq, head, head_dim] - // v -> [1 * 128, 1, 16] - // block k -> [128, 16] - - at::Tensor q_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor k_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor v_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - int cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < headdim; k ++) { - q_cpu[i][j][k] = cnt * 0.001; - k_cpu[i][j][k] = cnt * 0.001; - v_cpu[i][j][k] = cnt * 0.001; - cnt ++; - } - } - } - - auto q = q_cpu.cuda(); - auto k = k_cpu.cuda(); - auto v = v_cpu.cuda(); - - at::Tensor cu_seqlens_q_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); - at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); - - for (int i = 0; i < batch_size * max_seqlen_k_ + 1; ++i) { - cu_seqlens_q_cpu[i] = i * max_seqlen_q_; - cu_seqlens_k_cpu[i] = i * max_seqlen_k_; - } - - auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); - auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); - - // at::Tensor attn_mask = at::ones({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf).cuda(); - - // cnt = 0; - // for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { - // for (int j = 0; j < nheads; j ++) { - // for (int k = 0; k < max_seqlen_q_; k ++) { - // for (int l = 0; l < max_seqlen_k_; l ++) { - // attn_mask[i][j][k][l] = cnt * 0.001; - // cnt ++; - // } - // } - // } - // } - - at::Tensor attn_mask = 1 - at::ones({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf).tril().cuda(); - - c10::optional gen_; - c10::optional attn_bias; - - // std::cout << "attn bias" << attn_bias << std::endl; - std::vector ret; - if (has_mask) { - ret = mha_fwd( - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - k, // total_k x num_heads 
x head_size, total_k := \sum_{i=0}^{b} s_i - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - return_softmax, - gen_, - attn_mask, - attn_bias - ); - dump_tensor("attn_output", ret[0], "has_mask"); - dump_tensor("attn_lse", ret[1], "has_mask"); - }else{ - ret = mha_fwd( - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - return_softmax, - gen_, - attn_bias, - attn_bias - ); - dump_tensor("attn_output", ret[0], ""); - dump_tensor("attn_lse", ret[1], ""); - } - - // std::cout << "Ret vec size is " << ret.size(); - // for (int i = 0; i < ret.size(); i ++) { - // ret[i].cpu(); - // std::cout << ret[i] << std::endl; - // } - - at::Tensor dout_cpu = at::ones({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < headdim; k ++) { - dout_cpu[i][j][k] = cnt * 0.001; - cnt ++; - } - } - } - - at::Tensor dq_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor dk_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor dv_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - auto dout = dout_cpu.cuda(); - auto dq = dq_cpu.cuda(); - auto dk = dk_cpu.cuda(); - auto dv = dv_cpu.cuda(); - std::vector bwd_ret; - - if (has_mask) { - bwd_ret = mha_bwd( - dout, - q, - k, - v, - ret[0], - ret[1], - dq, - dk, - dv, - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - gen_, - attn_mask, - attn_bias - ); - dump_tensor("attn_dq", dq, "has_mask"); - dump_tensor("attn_dk", dk, "has_mask"); - dump_tensor("attn_dv", dv, "has_mask"); - // dump_tensor("attn_ds", bwd_ret[5], "has_mask"); - }else{ - bwd_ret = mha_bwd( - dout, - q, - k, - v, - ret[0], - ret[1], - dq, - dk, - dv, - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - gen_, - attn_bias, - attn_bias - // placeholder - ); - dump_tensor("attn_dq", dq, ""); - dump_tensor("attn_dk", dk, ""); - dump_tensor("attn_dv", dv, ""); - } -} - - -void test_fwd_with_mask_mini() { - int batch_size = 1; - int nheads = 1; - int headdim = 16; - int max_seqlen_q_ = 2; - int max_seqlen_k_ = 2; - - float softmax_scale = 1.0; - - bool zero_tensors = false; - bool is_causal = false; - bool return_softmax = false; - - // q -> [bs * seq, head, head_dim] - // q -> [1 * 128, 1, 16] - // block q -> [128, 16] - - // k -> [bs * seq, head, head_dim] - // k -> [1 * 128, 1, 16] - // block k -> [128, 16] - - // v -> [bs * seq, head, head_dim] - // v -> [1 * 128, 1, 16] - // block k -> [128, 16] - - at::Tensor q_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor k_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor v_cpu = at::zeros({batch_size * max_seqlen_k_ * 
max_seqlen_k_, nheads, headdim}, at::kHalf); - - int cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < headdim; k ++) { - q_cpu[i][j][k] = cnt * 0.001; - k_cpu[i][j][k] = cnt * 0.001; - v_cpu[i][j][k] = cnt * 0.001; - cnt ++; - } - } - } - - auto q = q_cpu.cuda(); - auto k = k_cpu.cuda(); - auto v = v_cpu.cuda(); - - at::Tensor cu_seqlens_q_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); - at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); - - for (int i = 0; i < batch_size * max_seqlen_k_ + 1; ++i) { - cu_seqlens_q_cpu[i] = i * max_seqlen_q_; - cu_seqlens_k_cpu[i] = i * max_seqlen_k_; - } - - auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); - auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); - - at::Tensor attn_mask_cpu = at::zeros({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf); - - cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < max_seqlen_q_; k ++) { - for (int l = 0; l < max_seqlen_k_; l ++) { - // if (l == 0) attn_mask[i][j][k][l] = -INFINITY; - if (l == 0) attn_mask_cpu[i][j][k][l] = -3e4; - else attn_mask_cpu[i][j][k][l] = 0; - - attn_mask_cpu[i][j][k][l] = -3e4; - printf("i=%d, j=%d, k=%d, l=%d attn_mask=%f\n", i, j, k, l, attn_mask_cpu[i][j][k][l]); - } - } - } - } - - auto attn_mask = attn_mask_cpu.cuda(); - - c10::optional gen_; - c10::optional attn_bias; - - // std::cout << "attn bias: " << attn_bias << std::endl; - - std::vector ret = mha_fwd( - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - return_softmax, - gen_, - attn_mask, - attn_bias - ); - - // ret: std::vector result = {o, softmax_lse}; - // [bs * seq * seq, head, head_dim] - // [1 * 2 * 2, 1, 16] - std::cout << "Ret vec size is " << ret.size(); - for (int i = 0; i < ret.size(); i ++) { - ret[i].cpu(); - std::cout << ret[i] << std::endl; - } -} - - -void test_fwd_with_bias_mini() { - int batch_size = 1; - int nheads = 1; - int headdim = 16; - int max_seqlen_q_ = 2; - int max_seqlen_k_ = 2; - - float softmax_scale = 0.1; - - bool zero_tensors = false; - bool is_causal = false; - bool return_softmax = false; - - // q -> [bs * seq, head, head_dim] - // q -> [1 * 128, 1, 16] - // block q -> [128, 16] - - // k -> [bs * seq, head, head_dim] - // k -> [1 * 128, 1, 16] - // block k -> [128, 16] - - // v -> [bs * seq, head, head_dim] - // v -> [1 * 128, 1, 16] - // block k -> [128, 16] - - at::Tensor q_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor k_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor v_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - int cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < headdim; k ++) { - q_cpu[i][j][k] = cnt * 0.001; - k_cpu[i][j][k] = cnt * 0.001; - v_cpu[i][j][k] = cnt * 0.001; - cnt ++; - } - } - } - - auto q = q_cpu.cuda(); - auto k = k_cpu.cuda(); - auto v = v_cpu.cuda(); - - at::Tensor cu_seqlens_q_cpu = 
at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); - at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); - - for (int i = 0; i < batch_size * max_seqlen_k_ + 1; ++i) { - cu_seqlens_q_cpu[i] = i * max_seqlen_q_; - cu_seqlens_k_cpu[i] = i * max_seqlen_k_; - } - - auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); - auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); - - at::Tensor attn_bias_cpu = at::zeros({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf); - - cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < max_seqlen_q_; k ++) { - for (int l = 0; l < max_seqlen_k_; l ++) { - // if (l == 0) attn_mask[i][j][k][l] = -INFINITY; - if (l == 0) attn_bias_cpu[i][j][k][l] = -3e4; - else attn_bias_cpu[i][j][k][l] = 0; - - attn_bias_cpu[i][j][k][l] = 100; - printf("i=%d, j=%d, k=%d, l=%d attn_bias=%f\n", i, j, k, l, attn_bias_cpu[i][j][k][l]); - // std::cout << "i=" << i << ", j=" << j << ", k=" << k << ", l" - // << l << << ", attn_bias=" << attn_bias_cpu[i][j][k][l] << std::endl; - } - } - } - } - - auto attn_bias = attn_bias_cpu.cuda(); - - c10::optional gen_; - c10::optional attn_mask; - - // std::cout << attn_mask << std::endl; - - std::vector ret = mha_fwd( - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - return_softmax, - gen_, - attn_mask, - attn_bias - ); - - // ret: std::vector result = {o, softmax_lse}; - // [bs * seq * seq, head, head_dim] - // [1 * 2 * 2, 1, 16] - std::cout << "Ret vec size is " << ret.size(); - for (int i = 0; i < ret.size(); i ++) { - ret[i].cpu(); - std::cout << ret[i] << std::endl; - } -} - - -void test_fwd_with_bias(bool has_bias) { - int batch_size = 1; - int nheads = 1; - int headdim = 16; - int seq = 8; - int max_seqlen_q_ = seq; - int max_seqlen_k_ = seq; - - float softmax_scale = 1; - - bool zero_tensors = false; - bool is_causal = false; - bool return_softmax = false; - - // q -> [bs * seq, head, head_dim] - // q -> [1 * 128, 1, 16] - // block q -> [128, 16] - - // k -> [bs * seq, head, head_dim] - // k -> [1 * 128, 1, 16] - // block k -> [128, 16] - - // v -> [bs * seq, head, head_dim] - // v -> [1 * 128, 1, 16] - // block k -> [128, 16] - - at::Tensor q_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor k_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor v_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - int cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < headdim; k ++) { - q_cpu[i][j][k] = cnt * 0.001; - k_cpu[i][j][k] = cnt * 0.001; - v_cpu[i][j][k] = cnt * 0.001; - cnt ++; - } - } - } - - auto q = q_cpu.cuda(); - auto k = k_cpu.cuda(); - auto v = v_cpu.cuda(); - - at::Tensor cu_seqlens_q_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); - at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * max_seqlen_k_ + 1}, at::kInt); - - for (int i = 0; i < batch_size * max_seqlen_k_ + 1; ++i) { - cu_seqlens_q_cpu[i] = i * max_seqlen_q_; - cu_seqlens_k_cpu[i] = i * max_seqlen_k_; - } - - 
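The loop above builds the cu_seqlens arrays as cumulative token offsets: entry b marks where sequence b starts in the packed (total_tokens, nheads, headdim) q/k/v layout, and the final entry equals the total token count. With equal-length sequences this is just an arithmetic progression, as in this small Python sketch (shapes illustrative):

import torch

batch, seqlen = 4, 8
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen, dtype=torch.int32)
assert cu_seqlens.tolist() == [0, 8, 16, 24, 32]
# Sequence b occupies rows cu_seqlens[b]:cu_seqlens[b + 1] of the packed q/k/v.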
auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); - auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); - - at::Tensor attn_bias_cpu = at::zeros({batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf); - - cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < max_seqlen_q_; k ++) { - for (int l = 0; l < max_seqlen_k_; l ++) { - // if (l == 0) attn_mask[i][j][k][l] = -INFINITY; - // if (l == 0) attn_bias_cpu[i][j][k][l] = -3e4; - // else attn_bias_cpu[i][j][k][l] = 0; - - // attn_bias_cpu[i][j][k][l] = 0; - attn_bias_cpu[i][j][k][l] = cnt * 0.1; - cnt ++; - // printf("i=%d, j=%d, k=%d, l=%d attn_bias=%f\n", i, j, k, l, attn_bias_cpu[i][j][k][l]); - // std::cout << "i=" << i << ", j=" << j << ", k=" << k << ", l" - // << l << << ", attn_bias=" << attn_bias_cpu[i][j][k][l] << std::endl; - } - } - } - } - - auto attn_bias = attn_bias_cpu.cuda(); - - c10::optional gen_; - c10::optional attn_mask; - std::vector ret ; - - if (has_bias) { - ret = mha_fwd( - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - return_softmax, - gen_, - attn_mask, - attn_bias - ); - dump_tensor("attn_output", ret[0], "has_bias"); - dump_tensor("attn_lse", ret[1], "has_bias"); - }else{ - ret = mha_fwd( - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - return_softmax, - gen_, - attn_mask, - attn_mask - // no bias - ); - dump_tensor("attn_output", ret[0], ""); - dump_tensor("attn_lse", ret[1], ""); - } - - // ret: std::vector result = {o, softmax_lse}; - // [bs * seq * seq, head, head_dim] - // [1 * 2 * 2, 1, 16] - // std::cout << "fwd Ret vec size is " << ret.size(); - // for (int i = 0; i < ret.size(); i ++) { - // ret[i].cpu(); - // std::cout << ret[i] << std::endl; - // } - - at::Tensor dout_cpu = at::ones({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < headdim; k ++) { - dout_cpu[i][j][k] = cnt * 0.001; - cnt ++; - } - } - } - - at::Tensor dq_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor dk_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor dv_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - auto dout = dout_cpu.cuda(); - auto dq = dq_cpu.cuda(); - auto dk = dk_cpu.cuda(); - auto dv = dv_cpu.cuda(); - std::vector bwd_ret; - - if (has_bias) { - bwd_ret = mha_bwd( - dout, - q, - k, - v, - ret[0], - ret[1], - dq, - dk, - dv, - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - gen_, - attn_mask, - attn_bias - ); - dump_tensor("attn_dq", dq, "has_bias"); - dump_tensor("attn_dk", dk, "has_bias"); - dump_tensor("attn_dv", dv, "has_bias"); - 
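The attn_dbias tensor dumped just below is later checked against a NumPy reference that treats the bias as broadcast over the leading batch dimension, so its gradient is the attention-score gradient dS summed over that broadcast axis (dbias = ds.reshape(-1, *bias.shape).sum(axis=0) in check_output.py). A short sketch of that reduction, with illustrative shapes:

import numpy as np

batch, heads, sq, sk = 8, 1, 8, 8
ds = np.random.rand(batch, heads, sq, sk).astype(np.float32)    # dL/dS
bias_shape = (1, heads, sq, sk)                                  # bias broadcast over batch
dbias = ds.reshape(-1, *bias_shape).sum(axis=0)                  # undo the broadcast
assert dbias.shape == bias_shape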
dump_tensor("attn_dbias", bwd_ret[4], "has_bias"); - // dump_tensor("attn_ds", bwd_ret[5], "has_bias"); - }else{ - bwd_ret = mha_bwd( - dout, - q, - k, - v, - ret[0], - ret[1], - dq, - dk, - dv, - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - gen_, - attn_mask, - attn_mask - // placeholder - ); - dump_tensor("attn_dq", dq, ""); - dump_tensor("attn_dk", dk, ""); - dump_tensor("attn_dv", dv, ""); - } - - // std::cout << "bwd Ret vec size is " << ret.size(); - // for (int i = 0; i < bwd_ret.size(); i ++) { - // bwd_ret[i].cpu(); - // std::cout << bwd_ret[i] << std::endl; - // } -} - -int main(int argc, char** argv){ - // test_fwd(); - // test_fwd_with_bias_mini(); - int has_bias = 0; - int has_masked = 0; - - if ( argc >= 2 ) { - std::cout << "argv: " << argv[1] << std::endl; - if (strcmp(argv[1], "has_bias") == 0) { - if (strcmp(argv[2], "true") == 0) { - has_bias = 1; - }else{ - has_bias = 0; - } - std::cout << "has bias " << argv[2] << std::endl; - test_fwd_with_bias(has_bias); - }else if (strcmp(argv[1], "has_mask") == 0) { - if (strcmp(argv[2], "true") == 0) { - has_masked = 1; - }else{ - has_masked = 0; - } - std::cout << "has mask " << argv[2] << std::endl; - test_fwd_with_mask(has_masked); - }else{ - has_bias = 0; - has_masked = 0; - std::cout << "no paramter found" << std::endl; - } - } - return 0; -} diff --git a/tests/test_forward_shape.cu b/tests/test_forward_shape.cu deleted file mode 100644 index a26d7955d..000000000 --- a/tests/test_forward_shape.cu +++ /dev/null @@ -1,249 +0,0 @@ -#include -#include -//#include -#include -#include -#include -#include -#include -#include - - -void dump_tensor(const std::string &tensor_name, at::Tensor &tensor, const std::string &label) { - std::string file_name = label + "_" + tensor_name + ".data"; - std::ofstream file(file_name.c_str()); - // file << tensor_name << std::endl; - // file << tensor << std::endl; - std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; - auto flatten_tensor = tensor.flatten(); - auto size = flatten_tensor.numel(); - - for (int i = 0; i < size; i ++) { - file << flatten_tensor[i].item() << " "; - // file << flatten_tensor[i] << " "; - } - file << std::endl; -} - -void test_fwd_with_mask(int seq, int has_mask=1) { - int batch_size = 1; - int nheads = 1; - int headdim = 16; - // int seq = 400; - - int bs_seq = 1; - int max_seqlen_q_ = seq; - int max_seqlen_k_ = seq; - - float softmax_scale = 1; - - bool zero_tensors = false; - bool is_causal = false; - bool return_softmax = false; - - // q -> [bs * seq, head, head_dim] - // q -> [1 * 128, 1, 16] - // block q -> [128, 16] - - // k -> [bs * seq, head, head_dim] - // k -> [1 * 128, 1, 16] - // block k -> [128, 16] - - // v -> [bs * seq, head, head_dim] - // v -> [1 * 128, 1, 16] - // block k -> [128, 16] - - at::Tensor q_cpu = at::zeros({batch_size * bs_seq * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor k_cpu = at::zeros({batch_size * bs_seq * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor v_cpu = at::zeros({batch_size * bs_seq * max_seqlen_k_, nheads, headdim}, at::kHalf); - - int cnt = 0; - for (int i = 0; i < batch_size * bs_seq * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < headdim; k ++) { - q_cpu[i][j][k] = (cnt % 10000) * 0.001; - k_cpu[i][j][k] = (cnt % 10000) * 0.001; - v_cpu[i][j][k] = (cnt % 10000) * 0.001; - cnt ++; - } - } - } - - auto q = q_cpu.cuda(); - auto k = 
k_cpu.cuda(); - auto v = v_cpu.cuda(); - - at::Tensor cu_seqlens_q_cpu = at::zeros({batch_size * bs_seq + 1}, at::kInt); - at::Tensor cu_seqlens_k_cpu = at::zeros({batch_size * bs_seq + 1}, at::kInt); - - for (int i = 0; i < batch_size * bs_seq + 1; ++i) { - cu_seqlens_q_cpu[i] = i * max_seqlen_q_; - cu_seqlens_k_cpu[i] = i * max_seqlen_k_; - } - - auto cu_seqlens_q = cu_seqlens_q_cpu.cuda(); - auto cu_seqlens_k = cu_seqlens_k_cpu.cuda(); - - at::Tensor attn_mask = at::ones({batch_size * bs_seq, nheads, max_seqlen_q_, max_seqlen_k_}, at::kHalf) * -1; - - // cnt = 0; - // for (int i = 0; i < batch_size * max_seqlen_k_; i ++) { - // for (int j = 0; j < nheads; j ++) { - // for (int k = 0; k < max_seqlen_q_; k ++) { - // for (int l = 0; l < max_seqlen_k_; l ++) { - // // attn_mask[i][j][k][l] = cnt * 0.001; - // // cnt ++; - // if (l % 2 == 0) { - // attn_mask[i][j][k][l] = 0; - // } - // cnt ++; - // } - // } - // } - // } - - for (int i = 0; i < batch_size * bs_seq; i ++) { - for (int j = 0; j < 1; j ++) { - for (int k = 0; k < 1; k ++) { - for (int l = 0; l < max_seqlen_k_; l ++) { - if (l % 2 == 0) { - attn_mask[i][0][0][l] = 0; - } - } - } - } - } - - for (int i = 0; i < batch_size * bs_seq; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < max_seqlen_q_; k ++) { - attn_mask[i][j][k] = attn_mask[i][0][0]; - } - } - } - - attn_mask = attn_mask.cuda(); - - c10::optional gen_; - c10::optional attn_bias; - - std::vector ret; - - ret = mha_fwd( - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, // False - is_causal, // False - return_softmax, // False - gen_, - attn_mask, - attn_bias - ); - dump_tensor("attn_output", ret[0], "has_mask"); - dump_tensor("attn_lse", ret[1], "has_mask"); - - - return ; - // std::cout << "Ret vec size is " << ret.size(); - // for (int i = 0; i < ret.size(); i ++) { - // ret[i].cpu(); - // std::cout << ret[i] << std::endl; - // } - - at::Tensor dout_cpu = at::ones({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - cnt = 0; - for (int i = 0; i < batch_size * max_seqlen_k_ * max_seqlen_k_; i ++) { - for (int j = 0; j < nheads; j ++) { - for (int k = 0; k < headdim; k ++) { - dout_cpu[i][j][k] = cnt * 0.001; - cnt ++; - } - } - } - - at::Tensor dq_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor dk_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - at::Tensor dv_cpu = at::zeros({batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim}, at::kHalf); - - auto dout = dout_cpu.cuda(); - auto dq = dq_cpu.cuda(); - auto dk = dk_cpu.cuda(); - auto dv = dv_cpu.cuda(); - std::vector bwd_ret; - - if (has_mask) { - bwd_ret = mha_bwd( - dout, - q, - k, - v, - ret[0], - ret[1], - dq, - dk, - dv, - cu_seqlens_q, // b + 1 - cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - gen_, - attn_mask, - attn_bias - ); - dump_tensor("attn_dq", dq, "has_mask"); - dump_tensor("attn_dk", dk, "has_mask"); - dump_tensor("attn_dv", dv, "has_mask"); - // dump_tensor("attn_ds", bwd_ret[5], "has_mask"); - }else{ - bwd_ret = mha_bwd( - dout, - q, - k, - v, - ret[0], - ret[1], - dq, - dk, - dv, - cu_seqlens_q, // b + 1 
- cu_seqlens_k, // b + 1 - max_seqlen_q_, - max_seqlen_k_, - 0.0, - softmax_scale, - zero_tensors, - is_causal, - gen_, - attn_bias, - attn_bias - // placeholder - ); - dump_tensor("attn_dq", dq, ""); - dump_tensor("attn_dk", dk, ""); - dump_tensor("attn_dv", dv, ""); - } -} - -int main(int argc, char** argv){ - - if ( argc >= 2 ) { - std::cout << "argv: " << argv[1] << std::endl; - int seq = atoi(argv[1]); - - test_fwd_with_mask(seq); - - } - return 0; -} diff --git a/tests/test_torch_capi.cpp b/tests/test_torch_capi.cpp deleted file mode 100644 index adfbd722b..000000000 --- a/tests/test_torch_capi.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include -#include -#include -#include -#include -#include - - -torch::Tensor load_tensor(std::string filename) { - std::cout << filename << std::endl; - std::ifstream sfile(filename.c_str()); - - torch::Tensor tensor2; - torch::load(tensor2, sfile); - - // std::cout << tensor2 << std::endl; - return tensor2; -} - -int main(){ - - std::string label = ""; - std::string tensor_name = "input_mask"; - std::string sfile_name = label + "_" + tensor_name + ".pt"; - - // std::ifstream sfile(sfile_name.c_str()); - // torch::Tensor tensor2; - // torch::load(tensor2, sfile); - - torch::Tensor tensor_c = load_tensor(sfile_name); - std::cout << tensor_c << std::endl; - - - std::string python_file_name = "../" + label + "_" + tensor_name + ".pt"; - torch::Tensor tensor_python = load_tensor(python_file_name); - std::cout << tensor_python << std::endl; - - - // int batch_size = 2; - // int num_heads = 4; - // int max_seqlen_q = 8; - // int max_seqlen_k = 8; - - // auto bias = torch::ones({1, num_heads, max_seqlen_q, max_seqlen_k}); - // auto ds = torch::ones({batch_size, num_heads, max_seqlen_q, max_seqlen_k}); - // // batch_size, 1, num_heads, max_seqlen_q, max_seqlen_k - - - // auto shape = bias.sizes(); - // // auto newshape = std::vector(shape); - // // newshape.insert(newshape.begin(), -1); - // // std::cout << newshape << std::endl; - - // auto dbias = ds.reshape({-1, shape[0], shape[1], shape[2], shape[3] }).sum({0}); - - // std::cout << dbias.sizes() << std::endl; - return 0; -} - - diff --git a/tests/tools/check_output.py b/tests/tools/check_output.py deleted file mode 100644 index 99770dac6..000000000 --- a/tests/tools/check_output.py +++ /dev/null @@ -1,768 +0,0 @@ -from audioop import bias -from operator import truediv -import numpy as np -import torch - -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument("--test_np", required=False, help="test np implementation kernel with torch", type=bool, default=False) -parser.add_argument("--has_bias", required=False, help="add bias in attention", type=bool, default=False) -parser.add_argument("--has_mask", required=False, help="add mask in attention", type=bool, default=False) -parser.add_argument("--seqlen", required=False, help="seqlen", type=int, default=128) - -args = parser.parse_args() -print(args) - - -batch_size = 1 -nheads = 1 -headdim = 16 -if args.seqlen is not None: - seq = args.seqlen -else: - seq = 8 - -print ("processing seqlen: {0}".format(seq)) - -bs_seq = 1 -max_seqlen_q_ = seq -max_seqlen_k_ = seq - -dtypes = np.float16 - -q_cpu = np.zeros((batch_size * bs_seq * max_seqlen_k_, nheads, headdim), dtype=dtypes) -k_cpu = np.zeros((batch_size * bs_seq * max_seqlen_k_, nheads, headdim), dtype=dtypes) -v_cpu = np.zeros((batch_size * bs_seq * max_seqlen_k_, nheads, headdim), dtype=dtypes) - -cnt = 0 -for i in range(batch_size * bs_seq * max_seqlen_k_): - for j in range(nheads): - for k 
in range(headdim): - q_cpu[i][j][k] = cnt % 10000 * 0.001 - k_cpu[i][j][k] = cnt % 10000 * 0.001 - v_cpu[i][j][k] = cnt % 10000 * 0.001 - cnt += 1 - -# cost too much time when seq is large -# bias_ref = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) -# cnt = 0 -# for i in range(batch_size * max_seqlen_k_): -# for j in range(nheads): -# for k in range(max_seqlen_q_): -# for l in range(max_seqlen_k_): -# bias_ref[i][j][k][l] = cnt * 0.1 -# cnt += 1 - -# mask_ref = np.ones([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) -# mask_ref = (1 - np.tril(mask_ref)) * -1 - -mask_ref = np.ones([batch_size * bs_seq, nheads, max_seqlen_q_, max_seqlen_k_], dtype=dtypes) * -1 -# cnt = 0 -# for i in range(batch_size * max_seqlen_k_): -# for j in range(nheads): -# for k in range(max_seqlen_q_): -# for l in range(max_seqlen_k_): -# if l % 2 == 0: -# mask_ref[i][j][k][l] = 0 -# cnt += 1 - -for i in range(batch_size * bs_seq): - for j in range(1): - for k in range(1): - for l in range(max_seqlen_k_): - if l % 2 == 0: - mask_ref[i][j][k][l] = 0 - - -for i in range(batch_size * bs_seq): - for j in range(nheads): - for k in range(max_seqlen_q_): - mask_ref[i][j][k] = mask_ref[i][0][0] - - -# dout = np.random.rand(batch_size * max_seqlen_k_ * max_seqlen_k_, nheads, headdim).astype(dtype=dtypes) -cnt = 0 -dout = np.ones([batch_size * bs_seq * max_seqlen_k_, nheads, headdim], dtype=dtypes) -for i in range(batch_size * bs_seq * max_seqlen_k_): - for j in range(nheads): - for k in range(headdim): - dout[i][j][k] = cnt * 0.001 - cnt += 1 - -def softmax(logit): - max_value_over_last_dim = np.max(logit, axis=-1, keepdims=True) - logit_sub_max_value = logit - max_value_over_last_dim - - exp_x = np.exp(logit_sub_max_value) - - probs = exp_x / np.sum(exp_x, axis=-1, keepdims=True) - return probs - - -def fwd(q, k, v, max_seqlen_q, bias=None, mask=None): - - batch_size = int(q.shape[0] / max_seqlen_q) - head_num = q.shape[1] - head_dim = q.shape[2] - - q = q.reshape(batch_size, max_seqlen_q, head_num, head_dim) - k = k.reshape(batch_size, max_seqlen_q, head_num, head_dim) - v = v.reshape(batch_size, max_seqlen_q, head_num, head_dim) - - q = q.transpose(0,2,1,3) - k = k.transpose(0,2,1,3) - v = v.transpose(0,2,1,3) - - # print ("data q block 0 = {}".format(q[0, 0, :, :])) - - s = np.matmul(q, k.transpose(0, 1, 3, 2)) - - if bias is not None: - s = s + bias - - if mask is not None: - # s.masked_fill_(mask < 0, float('-inf')) - mask_broad = np.broadcast_to(mask, s.shape) - mask_np = np.ma.masked_where(mask_broad < 0, s) - # np.ma.set_fill_value(mask_np, float('-inf')) - np.ma.set_fill_value(mask_np, float('-inf')) - s = mask_np.filled() - - p = softmax(s) - - o = np.matmul(p, v) - - # o = o.transpose(0,2,1,3).reshape(batch_size * max_seqlen_q, head_num, head_dim) - return s, p, o, q, k, v - - -def bwd(dout, q, k, v, max_seqlen_q, bias=None, mask=None): - s, p, o, _, _, _ = fwd(q, k, v, max_seqlen_q=max_seqlen_q, bias=bias, mask=mask) - - batch_size = int(q.shape[0] / max_seqlen_q) - head_num = q.shape[1] - head_dim = q.shape[2] - - dout = dout.reshape(batch_size, max_seqlen_q, head_num, head_dim) - dout = dout.transpose(0, 2, 1, 3) - # import pdb; pdb.set_trace() - - q = q.reshape(batch_size, max_seqlen_q, head_num, head_dim) - k = k.reshape(batch_size, max_seqlen_q, head_num, head_dim) - v = v.reshape(batch_size, max_seqlen_q, head_num, head_dim) - - q = q.transpose(0, 2, 1, 3) - k = k.transpose(0, 2, 1, 3) - v = v.transpose(0, 2, 1, 3) - - # get dv - 
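The backward reference below follows the standard attention gradients: dV = P^T dO, dP = dO V^T, and dS_ij = P_ij * (dP_ij - D_i) with D_i = sum_l P_il * dP_il (the "equation 4" the comments cite). The nested loops that follow compute dS element by element; an equivalent vectorized NumPy sketch, with illustrative names:

import numpy as np

def dsoftmax(p, dp):
    # dS = P * (dP - D), with D_i = sum_l P_il * dP_il broadcast over each row.
    # p and dp have shape (..., seqlen_q, seqlen_k).
    D = np.sum(p * dp, axis=-1, keepdims=True)
    return p * (dp - D)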
dv = np.matmul(p.transpose(0, 1, 3, 2), dout) - - # get dp - dp = np.matmul(dout, v.transpose(0, 1, 3, 2)) - - # ds_{i:} = P_{i:} \dot dP_{i:} - D_{i}P_{i:} - - ds = np.zeros([batch_size, head_num, max_seqlen_q, max_seqlen_q]) - for b in range(batch_size): - for h in range(head_num): - for i in range(max_seqlen_q): - # please refer equation 4 - Di = 0.0 - for l in range(max_seqlen_q): - Di += p[b][h][i][l] * dp[b][h][i][l] - - for j in range(max_seqlen_q): - ds[b][h][i][j] = p[b][h][i][j] * (dp[b][h][i][j] - Di) - - # get dq - dq = np.matmul(ds, k) - # dq = dq.transpose(0, 2, 1, 3) - - # get dk - dk = np.matmul(ds.transpose(0, 1, 3, 2), q) - # dk = dk.transpose(0, 2, 1, 3) - - if bias is None: - dbias = None - else: - dbias = ds.reshape(-1, *bias.shape).sum(axis=0) - - return dq, dk, dv, ds, dp, dbias - - -def fwd_pt(q_pt, k_pt, v_pt, bias=None, mask=None): - s = torch.matmul(q_pt, k_pt.transpose(-1, -2)) - - if bias is not None: - s = s + bias - - if mask is not None: - s.masked_fill_(mask < 0, float('-999')) - - p = torch.nn.functional.softmax(s, dim=-1) - # from unicore.modules import softmax_dropout - # p = softmax_dropout(s, dropout_prob=0, is_training=True, mask=mask, bias=bias) - - o = torch.matmul(p, v_pt) - return s, p, o - - -def bwd_pt(dout, q, k, v, max_seqlen_q, bias=None, mask=None): - # q is [batch * seq * seq, head, head_dim] - q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=bias, mask=mask) - - s, p, o = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt, mask=mask_pt) - - if bias is None: - dq, dk, dv = torch.autograd.grad(o, (q_pt, k_pt, v_pt), dout_pt) - return dq, dk, dv, None - else: - dq, dk, dv, dbias = torch.autograd.grad(o, (q_pt, k_pt, v_pt, bias_pt), dout_pt) - return dq, dk, dv, dbias - - -def compute_lse(s): - # import pdb; pdb.set_trace() - # og_dtype = s.dtype - # s = s.astype(np.float32) - - max_value_over_last_dim = np.max(s, axis=-1, keepdims=True) - logit_sub_max_value = s - max_value_over_last_dim - - exp_x = np.exp(logit_sub_max_value) - - softmax_lse = np.max(s, axis=-1, keepdims=True) + np.log(np.sum(exp_x, axis=-1, keepdims=True)) - - # softmax_lse = softmax_lse.astype(og_dtype) - return softmax_lse - - -def check_fwd_kernel(has_bias=False, has_mask=False): - print ("==== check fwd kernel with np ====") - if has_bias: - s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref, mask=None) - elif has_mask: - s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=None, mask=mask_ref) - else: - s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=None, mask=None) - # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) - - # attn_output = np.loadtxt("attn_output.data", delimiter=" ") - if has_bias: - prefix = "has_bias" - print ("has bias on, prefix is ", prefix) - elif has_mask: - prefix = "has_mask" - else: - prefix = "" - - attn_output = np.genfromtxt("{}_attn_output.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_output = attn_output.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) - attn_output = attn_output.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) - attn_output = attn_output.transpose(0, 2, 1, 3) - # batch_size * bs_seq, nheads, max_seqlen_k_, headdim - print ("attn output shape: ", attn_output.shape) - print ("output max error: ", np.abs(o - attn_output).max()) - - attn_lse = np.genfromtxt("{}_attn_lse.data".format(prefix), delimiter=" ", dtype=np.float32) - max_seqlen_q_pad = ((max_seqlen_q_ + 16 - 1) // 16) * 16 - attn_lse = 
attn_lse.reshape(batch_size * bs_seq , nheads, max_seqlen_q_pad) - # print ("attn lse: ", attn_lse) - attn_lse = attn_lse[:,:,:max_seqlen_q_] - - lse_ref = compute_lse(s) - lse_ref = lse_ref.reshape(batch_size * bs_seq , nheads, max_seqlen_q_) - # print ("ref lse: ", lse_ref) - - print ("lse_ref shape = {}, attn_lse shape = {}".format(lse_ref.shape, attn_lse.shape)) - print ("lse max error: ", np.abs(lse_ref - attn_lse).max()) - - print ("is same matrix: ", is_same_matrix(lse_ref, attn_lse)) - print ("is same matrix: ", is_same_matrix(o, attn_output)) - - # with python interface input - python_inputs = np.genfromtxt("../inputs_flash_seq{}.data".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) - python_inputs = python_inputs.reshape(batch_size, bs_seq, nheads, max_seqlen_q_, headdim) - python_inputs = python_inputs.transpose(0, 1, 3, 2, 4) - python_inputs = python_inputs.reshape(batch_size * bs_seq * max_seqlen_q_, nheads, headdim) - print ("is same matrix input: ", is_same_matrix(python_inputs, q_cpu)) - - python_attn_mask = np.genfromtxt("../attn_mask_flash_seq{}.data".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) - python_attn_mask = python_attn_mask.reshape(batch_size, bs_seq, nheads, max_seqlen_q_, max_seqlen_k_) - python_attn_mask = python_attn_mask.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, max_seqlen_k_) - print ("is same matrix mask: ", is_same_matrix(python_inputs, q_cpu)) - - # flash tmp output - # out = out.reshape(*batch_dims, n, no_heads, c) - - # python_output_tmp0 = np.genfromtxt("../tmp2.data", delimiter=" ", dtype=np.float32) - # python_output_tmp0 = python_output_tmp0.reshape(batch_size, bs_seq, max_seqlen_q_, nheads, headdim) - # python_output_tmp0 = python_output_tmp0.transpose(0, 1, 3, 2, 4) - # python_output_tmp0 = python_output_tmp0.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, headdim) - - # print (python_output_tmp0.shape) - # print ("is same matrix flash output tmp1: ", is_same_matrix(o, python_output_tmp0, verbose=True)) - # print ("is same matrix flash output tmp1: ", is_same_matrix(attn_output, python_output_tmp0)) - - python_output_tmp1 = np.genfromtxt("../flash_temp1.output".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) - python_output_tmp1 = python_output_tmp1.reshape(batch_size, bs_seq, max_seqlen_q_, nheads, headdim) - python_output_tmp1 = python_output_tmp1.transpose(0, 1, 3, 2, 4) - python_output_tmp1 = python_output_tmp1.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, headdim) - - print (python_output_tmp1.shape) - print ("is same matrix flash output tmp1: ", is_same_matrix(o, python_output_tmp1, verbose=True)) - print ("is same matrix flash output tmp1: ", is_same_matrix(attn_output, python_output_tmp1, verbose=True)) - - # flash output - # [batch_size, bs_seq, seq_k, head, c_dim] - # 1, 1, 512, 1, 16 - python_output = np.genfromtxt("../output_flash_seq{}.data".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) - python_output = python_output.reshape(batch_size, bs_seq, max_seqlen_q_, nheads, headdim) - python_output = python_output.transpose(0, 1, 3, 2, 4) - python_output = python_output.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, headdim) - - print (python_output.shape) - print ("is same matrix flash output: ", is_same_matrix(o, python_output)) - print ("is same matrix flash output: ", is_same_matrix(attn_output, python_output)) - - # torch output - python_torch_output = np.genfromtxt("../output_torch_seq{}.data".format(max_seqlen_q_), delimiter=" ", dtype=np.float32) - python_torch_output = 
python_torch_output.reshape(batch_size, bs_seq, nheads, max_seqlen_q_, headdim) - python_torch_output = python_torch_output.reshape(batch_size * bs_seq, nheads, max_seqlen_q_, headdim) - - print (python_torch_output.shape) - print ("is same matrix torch output: ", is_same_matrix(o, python_torch_output)) - print ("is same matrix torch output: ", is_same_matrix(attn_output, python_torch_output)) - - - -def check_fwd_kernel_pt(has_bias=False, has_mask=False): - print ("==== check fwd kernel with np ====") - if has_bias: - q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=None) - s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt, mask=mask_pt) - elif has_mask: - q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=mask_ref) - s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt, mask=mask_pt) - else: - q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=None) - s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias=bias_pt, mask=mask_pt) - - o = o_pt.detach().cpu().numpy() - s = s_pt.detach().cpu().numpy() - - # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) - - # attn_output = np.loadtxt("attn_output.data", delimiter=" ") - if has_bias: - prefix = "has_bias" - print ("has bias on, prefix is ", prefix) - elif has_mask: - prefix = "has_mask" - else: - prefix = "" - - attn_output = np.genfromtxt("{}_attn_output.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_output = attn_output.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) - attn_output = attn_output.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) - attn_output = attn_output.transpose(0, 2, 1, 3) - print ("output max error: ", np.abs(o - attn_output).max()) - - attn_lse = np.genfromtxt("{}_attn_lse.data".format(prefix), delimiter=" ", dtype=np.float32) - max_seqlen_q_pad = ((max_seqlen_q_ + 16 - 1) // 16) * 16 - attn_lse = attn_lse.reshape(batch_size * bs_seq , nheads, max_seqlen_q_pad) - # print ("attn lse: ", attn_lse) - attn_lse = attn_lse[:,:,:max_seqlen_q_] - - lse_ref = compute_lse(s) - lse_ref = lse_ref.reshape(batch_size * bs_seq , nheads, max_seqlen_q_) - # print ("ref lse: ", lse_ref) - - print ("lse_ref shape = {}, attn_lse shape = {}".format(lse_ref.shape, attn_lse.shape)) - print ("lse max error: ", np.abs(lse_ref - attn_lse).max()) - - print ("is same matrix (lse): ", is_same_matrix(lse_ref, attn_lse)) - print ("is same matrix (attn_output): ", is_same_matrix(o, attn_output)) - - - -def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): - diff = np.abs(pred - gt) - - cnt = 0 - for index, x in np.ndenumerate(diff): - if x > abs_eps: - relative_diff = np.abs(x / gt[index]) - if relative_diff > relative_rps: - cnt += 1 - if verbose: - print (index, x, gt[index], relative_diff) - - if cnt > 0: - print ("not so match") - return False - else: - return True - - -def check_bwd_kernel(has_bias=False, has_mask=False): - print ("==== check bwd kernel with np ====") - if has_bias: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref, mask=None) - elif has_mask: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=None, mask=mask_ref) - else: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=None, mask=None) - - if has_bias: - 
prefix = "has_bias" - print ("has bias on, prefix is ", prefix) - elif has_mask: - prefix = "has_mask" - print ("has mask on, prefix is ", prefix) - else: - prefix = "" - - attn_dq = np.genfromtxt("{}_attn_dq.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_dk = np.genfromtxt("{}_attn_dk.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_dv = np.genfromtxt("{}_attn_dv.data".format(prefix), delimiter=" ", dtype=np.float32) - if has_bias: - attn_dbias = np.genfromtxt("{}_attn_dbias.data".format(prefix), delimiter=" ", dtype=np.float32) - - attn_dq = attn_dq.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) - attn_dk = attn_dk.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) - attn_dv = attn_dv.reshape(batch_size * bs_seq * max_seqlen_k_, nheads, headdim) - - attn_dq = attn_dq.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) - attn_dk = attn_dk.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) - attn_dv = attn_dv.reshape(batch_size * bs_seq, max_seqlen_k_, nheads, headdim) - - if has_bias: - attn_dbias = attn_dbias.reshape(batch_size * bs_seq, nheads, max_seqlen_k_, max_seqlen_k_) - - attn_dq = attn_dq.transpose(0, 2, 1, 3) - attn_dk = attn_dk.transpose(0, 2, 1, 3) - attn_dv = attn_dv.transpose(0, 2, 1, 3) - - assert (dq.shape == attn_dq.shape), "oh dq shape didn't match" - assert (dk.shape == attn_dk.shape), "oh dk shape didn't match" - assert (dv.shape == attn_dv.shape), "oh dv shape didn't match" - - print ("max error in dq: ", np.abs(attn_dq - dq).max(), ) - print ("max error in dk: ", np.abs(attn_dk - dk).max(), ) - print ("max error in dv: ", np.abs(attn_dv - dv).max(), ) - if has_bias: - print ("max error in dq: ", np.abs(attn_dbias - dbias).max(), ) - # print (np.abs(attn_dbias - dbias) > 0.001) - # attn_ds = np.genfromtxt("{}_attn_ds.data".format(prefix), delimiter=" ", dtype=np.float32) - # attn_ds = attn_ds.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) - # print ("max error in ds: ", np.abs(attn_ds - ds).max(), ) - - attn_dbias = np.genfromtxt("{}_attn_dbias.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_dbias = attn_dbias.reshape(batch_size * bs_seq, nheads, max_seqlen_k_, max_seqlen_k_) - print ("max error in dbias: ", np.abs(attn_dbias - dbias).max(), ) - - - print ("same matrix check q: ", is_same_matrix(attn_dq, dq)) - print ("same matrix check k: ", is_same_matrix(attn_dk, dk)) - print ("same matrix check v: ", is_same_matrix(attn_dv, dv)) - if has_bias: - import pdb; pdb.set_trace() - print ("same matrix check dbias: ", is_same_matrix(attn_dbias, dbias)) - - -def check_bwd_np(has_bias=False): - print ("==== check bwd np ====") - if has_bias: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=mask_ref) - dq_pt, dk_pt, dv_pt, dbias_pt = bwd_pt(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=mask_ref) - else: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=None) - dq_pt, dk_pt, dv_pt, dbias_pt = bwd_pt(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=None) - - assert (dq.shape == dq_pt.detach().cpu().numpy().shape), "oh dq shape didn't match" - assert (dk.shape == dk_pt.detach().cpu().numpy().shape), "oh dk shape didn't match" - assert (dv.shape == dv_pt.detach().cpu().numpy().shape), "oh dv shape didn't match" - if has_bias: - assert (dbias.shape == dbias_pt.detach().cpu().numpy().shape), "oh dbias 
shape didn't match" - - print ("max error in dq: ", np.abs( dq - dq_pt.detach().cpu().numpy() ).max()) - print ("max error in dk: ", np.abs( dk - dk_pt.detach().cpu().numpy() ).max()) - print ("max error in dv: ", np.abs( dv - dv_pt.detach().cpu().numpy() ).max()) - if has_bias: - print ("max error in dbias: ", np.abs( dbias - dbias_pt.detach().cpu().numpy() ).max()) - - return - - -def prepare_pt_data(dout, q, k, v, max_seqlen_q, bias=None, mask=None): - q_pt = torch.from_numpy(q.copy()) - k_pt = torch.from_numpy(k.copy()) - v_pt = torch.from_numpy(v.copy()) - - batch_size = int(q.shape[0] / max_seqlen_q) - head_num = q.shape[1] - head_dim = q.shape[2] - import pdb; pdb.set_trace() - - dout_pt = torch.from_numpy(dout.copy()) - dout_pt = dout_pt.reshape(batch_size, max_seqlen_q, head_num, head_dim) - dout_pt = dout_pt.permute(0, 2, 1, 3).cuda() - - q_pt = q_pt.reshape(batch_size, max_seqlen_q, head_num, head_dim) - k_pt = k_pt.reshape(batch_size, max_seqlen_q, head_num, head_dim) - v_pt = v_pt.reshape(batch_size, max_seqlen_q, head_num, head_dim) - - q_pt = q_pt.permute(0, 2, 1, 3).cuda() - k_pt = k_pt.permute(0, 2, 1, 3).cuda() - v_pt = v_pt.permute(0, 2, 1, 3).cuda() - - if bias is not None: - bias_pt = torch.from_numpy(bias.copy()).cuda() - bias_pt.requires_grad = True - else: - bias_pt = None - - if mask is not None: - mask_pt = torch.from_numpy(mask.copy()).cuda() - else: - mask_pt = None - - q_pt.requires_grad = True - k_pt.requires_grad = True - v_pt.requires_grad = True - - return q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt - - -def check_fwd_np(has_bias=False, has_atten=False): - print ("==== check fwd np ====") - if has_bias: - s, p, o, q, k, v = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=mask_ref) - - q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=bias_ref, mask=mask_ref) - s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias_pt, mask_pt) - else: - s, p, o, q, k, v = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_, bias=None, mask=None) - - q_pt, k_pt, v_pt, dout_pt, bias_pt, mask_pt = prepare_pt_data(dout, q_cpu, k_cpu, v_cpu, max_seqlen_q=max_seqlen_q_) - s_pt, p_pt, o_pt = fwd_pt(q_pt, k_pt, v_pt, bias=None, mask=None) - - def check_input(a, b): - print ("max error in input: ", np.abs(a - b).max()) - - check_input(q, q_pt.detach().cpu().numpy()) - check_input(k, q_pt.detach().cpu().numpy()) - check_input(v, q_pt.detach().cpu().numpy()) - - assert (s.shape == s_pt.detach().cpu().numpy().shape), "oh s shape didn't match" - assert (p.shape == p_pt.detach().cpu().numpy().shape), "oh p shape didn't match" - assert (o.shape == o_pt.detach().cpu().numpy().shape), "oh o shape didn't match" - - print ("max error in s: ", np.abs( s - s_pt.detach().cpu().numpy() ).max()) - print ("max error in p: ", np.abs( p - p_pt.detach().cpu().numpy() ).max()) - print ("max error in o: ", np.abs( o - o_pt.detach().cpu().numpy() ).max()) - - return - - -def parse_softmax_load(filename): - from parse import parse - format_string = 'bwd softmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' - softmax_p = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) - - with open(filename, "r") as f: - for line in f.readlines(): - # print (line) - if line.startswith("bwd softmax: "): - # print (line.strip()) - result = parse(format_string, line.strip()) - # print (result) - - tidx_ = int(result[0]) - mi = int(result[2]) - ni = int(result[3]) - ii = 
int(result[4]) - jj = int(result[5]) - value = float(result[6]) - - warp = tidx_ // 32 - lane = tidx_ % 32 - - warp_n = (warp // 1) - warp_m = (warp % 1) - - quad = lane // 4 - tid = (lane % 4) * 2 - - row = warp_m * 16 + quad - col = warp_n * 16 + tid - - current_row = mi * 16 + ii * 8 + row - # current_col = ni * 64 + jj * 8 + col - # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col - current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col - # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - # warp, lane, quad, tid, current_row, current_col, value)) - if (current_row < 8 and current_col < 8): - print (line.strip()) - print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - warp, lane, quad, tid, current_row, current_col, value)) - softmax_p[0, 0, current_row, current_col] = value - - return softmax_p - - -def check_softmax_p(softmax_data, has_bias=False): - if has_bias: - s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) - else: - s, p, o, _, _, _ = fwd(q_cpu, k_cpu, v_cpu, max_seqlen_k_) - # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) - import pdb; pdb.set_trace() - print ("max error in p: ", np.abs(p[0, 0, :, :] - softmax_data[0, 0, :, :]).max(), ) - print ("same matrix check p: ", is_same_matrix(p[0, 0, :, :], softmax_data[0, 0, :, :])) - return - - -def parse_dsoftmax_load(filename): - from parse import parse - format_string = 'bwd dsoftmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' - dsoftmax = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) - - with open(filename, "r") as f: - for line in f.readlines(): - # print (line) - if line.startswith("bwd dsoftmax: "): - # print (line.strip()) - result = parse(format_string, line.strip()) - # print (result) - - tidx_ = int(result[0]) - mi = int(result[2]) - ni = int(result[3]) - ii = int(result[4]) - jj = int(result[5]) - value = float(result[6]) - - warp = tidx_ // 32 - lane = tidx_ % 32 - - warp_n = (warp // 1) - warp_m = (warp % 1) - - quad = lane // 4 - tid = (lane % 4) * 2 - - row = warp_m * 16 + quad - col = warp_n * 16 + tid - - current_row = mi * 16 + ii * 8 + row - # current_col = ni * 64 + jj * 8 + col - # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col - current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col - # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - # warp, lane, quad, tid, current_row, current_col, value)) - if (current_row < 8 and current_col < 8): - print (line.strip()) - print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - warp, lane, quad, tid, current_row, current_col, value)) - dsoftmax[0, 0, current_row, current_col] = value - - return dsoftmax - - -def check_dsoftmax_p(softmax_data, has_bias=False): - if has_bias: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_, bias=bias_ref) - else: - dq, dk, dv, ds, dp, dbias = bwd(dout, q_cpu, k_cpu, v_cpu, max_seqlen_k_) - - if has_bias: - prefix = "has_bias" - print ("has bias on, prefix is ", prefix) - else: - prefix = "" - - # print ("q * k = p'shape = {} p = {}".format(p.shape, p)) - import pdb; pdb.set_trace() - print ("max error in p: ", np.abs(ds[0, 0, :, :] - softmax_data[0, 0, :, :]).max(), ) - print ("same matrix check p: ", is_same_matrix(ds[0, 0, :, :], softmax_data[0, 0, :, :])) - - attn_ds = np.genfromtxt("{}_attn_ds.data".format(prefix), delimiter=" ", 
dtype=np.float32) - attn_ds = attn_ds.reshape(batch_size * max_seqlen_k_, nheads, max_seqlen_k_, max_seqlen_k_) - - attn_dbias = np.genfromtxt("{}_attn_dbias.data".format(prefix), delimiter=" ", dtype=np.float32) - attn_dbias = attn_ds.reshape(*bias_ref.shape) - - print ("max error in attn ds with softmax: ", np.abs(attn_ds[0, 0, :, :] - softmax_data[0, 0, :, :]).max(), ) - print ("max error in attn ds with bwd: ", np.abs(attn_ds - ds).max(), ) - print ("max error in attn dbias with bwd: ", np.abs(attn_dbias - dbias).max(), ) - # for i in range(batch_size * max_seqlen_k_): - # for j in range(nheads): - # print ("max error in i = {}, j = {}, max_error = {} ".format(i, j, np.abs(attn_ds[i, j, :, :] - ds[i, j, :, :]).max(), )) - # print (np.abs(attn_ds[i, j, :, :] - ds[i, j, :, :]) <= 0.001) - # print ("attn_ds: ", attn_ds[i, j, :, :]) - # print ("ds: ", ds[i, j, :, :]) - - # for i in range(batch_size * max_seqlen_k_): - # for j in range(nheads): - # print ("max error in i = {}, j = {}, max_error = {} ".format(i, j, np.abs(attn_dbias[i, j, :, :] - dbias[i, j, :, :]).max(), )) - # print (np.abs(attn_dbias[i, j, :, :] - dbias[i, j, :, :]) <= 0.001) - # print ("attn_dbias: ", attn_dbias[i, j, :, :]) - # print ("dbias: ", dbias[i, j, :, :]) - return - - -if __name__ == '__main__': - # print ("====test without bias====") - # has_bias = False - # check_fwd_np(has_bias=has_bias) - # check_bwd_np(has_bias=has_bias) - # print ("====test without bias====") - - # print ("====test with bias====") - # has_bias = True - # check_fwd_np(has_bias=has_bias) - # check_bwd_np(has_bias=has_bias) - # print ("====test with bias====") - - # print ("====test kernel using torch====") - # has_bias = args.has_bias - # has_mask = args.has_mask - - # check_fwd_kernel_pt(has_bias=has_bias, has_mask=has_mask) - - print ("====test kernel using numpy====") - has_bias = args.has_bias - has_mask = args.has_mask - - check_fwd_kernel(has_bias=has_bias, has_mask=has_mask) - # check_bwd_kernel(has_bias=has_bias, has_mask=has_mask) - - # print ("====test kernel with bias====") - # has_bias = True - # check_fwd_kernel(has_bias=has_bias) - # check_bwd_kernel(has_bias=has_bias) - - # print ("====test bwd kernel softmax without bias====") - # has_bias = False - # softmax_data = parse_softmax_load("output.log") - # check_softmax_p(softmax_data=softmax_data, has_bias=has_bias) - - # print ("====test bwd kernel softmax with bias====") - # has_bias = True - # softmax_data = parse_softmax_load("output.log") - # check_softmax_p(softmax_data=softmax_data, has_bias=has_bias) - - # print ("====test bwd kernel softmax without bias====") - # has_bias = False - # dsoftmax_data = parse_dsoftmax_load("output.log") - # check_dsoftmax_p(softmax_data=dsoftmax_data, has_bias=has_bias) - - # print ("====test bwd kernel softmax with bias====") - # has_bias = True - # dsoftmax_data = parse_dsoftmax_load("output.log") - # check_dsoftmax_p(softmax_data=dsoftmax_data, has_bias=has_bias) diff --git a/tests/tools/rebuild_bwd_attn_mask.py b/tests/tools/rebuild_bwd_attn_mask.py deleted file mode 100644 index 790f49642..000000000 --- a/tests/tools/rebuild_bwd_attn_mask.py +++ /dev/null @@ -1,64 +0,0 @@ -from parse import parse -import sys -import numpy as np - -filename = "./output.log" -if len(sys.argv) > 1: - filename = sys.argv[1] - -# bwd softmax: threadIdx=195, l=0, mi=0, ki=1, ii=3, jj=0, elt=0.000000 -format_string = 'bwd softmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' -batch_size = 1 -nheads = 1 -headdim = 16 -seq = 8 -seq_q = 8 
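Note: the rebuild_* helpers in tests/tools all repeat the same index arithmetic to map one debug record (threadIdx, mi, ni, ii, jj) back to a (row, col) element of the S = Q*K^T tile. A standalone sketch of that decode follows; it assumes WARPS_M = 1 with warps tiled along N and the 16x16-per-warp fragment layout these scripts hard-code, and the name mma_to_coord is illustrative only (the backward-pass tools use the (jj & 2) * 8 column variant, the forward-pass ones use (jj & 2) * 4).

def mma_to_coord(tidx, mi, ni, ii, jj, col_stride8=True):
    """Decode a 'bwd softmax' / 'bwd dsoftmax' print record into (row, col).

    Each warp owns a 16x16 fragment; a quad of 4 threads covers one row and
    each thread holds two consecutive columns per (ii, jj) step.
    """
    warp, lane = tidx // 32, tidx % 32
    warp_m, warp_n = warp % 1, warp // 1      # WARPS_M = 1, warps laid out along N
    quad, tid = lane // 4, (lane % 4) * 2
    row = warp_m * 16 + quad
    col = warp_n * 16 + tid
    current_row = mi * 16 + ii * 8 + row
    jj_off = (jj & 2) * (8 if col_stride8 else 4) + (jj & 1)
    current_col = ni * 64 + jj_off + col
    return current_row, current_col

# e.g. mma_to_coord(195, 0, 0, 3, 0) == (24, 102)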
-max_seqlen_q_ = seq_q -max_seqlen_k_ = seq_q - - -d_softmax = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) - -def parse_dsoftmax_load(filename): - with open(filename, "r") as f: - for line in f.readlines(): - # print (line) - if line.startswith("bwd softmax: "): - # print (line.strip()) - result = parse(format_string, line.strip()) - # print (result) - - tidx_ = int(result[0]) - mi = int(result[2]) - ni = int(result[3]) - ii = int(result[4]) - jj = int(result[5]) - value = float(result[6]) - - warp = tidx_ // 32 - lane = tidx_ % 32 - - warp_n = (warp // 1) - warp_m = (warp % 1) - - quad = lane // 4 - tid = (lane % 4) * 2 - - row = warp_m * 16 + quad - col = warp_n * 16 + tid - - current_row = mi * 16 + ii * 8 + row - # current_col = ni * 64 + jj * 8 + col - # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col - current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col - # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - # warp, lane, quad, tid, current_row, current_col, value)) - if (current_row < 8 and current_col < 8): - print (line.strip()) - print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - warp, lane, quad, tid, current_row, current_col, value)) - d_softmax[0, 0, current_row, current_col] = value - - -parse_dsoftmax_load(filename) -print ("output block 0 d_softmax: ", d_softmax[0, 0, :, :]) diff --git a/tests/tools/rebuild_bwd_softmax.py b/tests/tools/rebuild_bwd_softmax.py deleted file mode 100644 index 790f49642..000000000 --- a/tests/tools/rebuild_bwd_softmax.py +++ /dev/null @@ -1,64 +0,0 @@ -from parse import parse -import sys -import numpy as np - -filename = "./output.log" -if len(sys.argv) > 1: - filename = sys.argv[1] - -# bwd softmax: threadIdx=195, l=0, mi=0, ki=1, ii=3, jj=0, elt=0.000000 -format_string = 'bwd softmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' -batch_size = 1 -nheads = 1 -headdim = 16 -seq = 8 -seq_q = 8 -max_seqlen_q_ = seq_q -max_seqlen_k_ = seq_q - - -d_softmax = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) - -def parse_dsoftmax_load(filename): - with open(filename, "r") as f: - for line in f.readlines(): - # print (line) - if line.startswith("bwd softmax: "): - # print (line.strip()) - result = parse(format_string, line.strip()) - # print (result) - - tidx_ = int(result[0]) - mi = int(result[2]) - ni = int(result[3]) - ii = int(result[4]) - jj = int(result[5]) - value = float(result[6]) - - warp = tidx_ // 32 - lane = tidx_ % 32 - - warp_n = (warp // 1) - warp_m = (warp % 1) - - quad = lane // 4 - tid = (lane % 4) * 2 - - row = warp_m * 16 + quad - col = warp_n * 16 + tid - - current_row = mi * 16 + ii * 8 + row - # current_col = ni * 64 + jj * 8 + col - # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col - current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col - # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - # warp, lane, quad, tid, current_row, current_col, value)) - if (current_row < 8 and current_col < 8): - print (line.strip()) - print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - warp, lane, quad, tid, current_row, current_col, value)) - d_softmax[0, 0, current_row, current_col] = value - - -parse_dsoftmax_load(filename) -print ("output block 0 d_softmax: ", d_softmax[0, 0, :, :]) diff --git a/tests/tools/rebuild_dsoftmax.py 
b/tests/tools/rebuild_dsoftmax.py deleted file mode 100644 index 790f49642..000000000 --- a/tests/tools/rebuild_dsoftmax.py +++ /dev/null @@ -1,64 +0,0 @@ -from parse import parse -import sys -import numpy as np - -filename = "./output.log" -if len(sys.argv) > 1: - filename = sys.argv[1] - -# bwd softmax: threadIdx=195, l=0, mi=0, ki=1, ii=3, jj=0, elt=0.000000 -format_string = 'bwd softmax: threadIdx={}, l={}, mi={}, ki={}, ii={}, jj={}, elt={}' -batch_size = 1 -nheads = 1 -headdim = 16 -seq = 8 -seq_q = 8 -max_seqlen_q_ = seq_q -max_seqlen_k_ = seq_q - - -d_softmax = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) - -def parse_dsoftmax_load(filename): - with open(filename, "r") as f: - for line in f.readlines(): - # print (line) - if line.startswith("bwd softmax: "): - # print (line.strip()) - result = parse(format_string, line.strip()) - # print (result) - - tidx_ = int(result[0]) - mi = int(result[2]) - ni = int(result[3]) - ii = int(result[4]) - jj = int(result[5]) - value = float(result[6]) - - warp = tidx_ // 32 - lane = tidx_ % 32 - - warp_n = (warp // 1) - warp_m = (warp % 1) - - quad = lane // 4 - tid = (lane % 4) * 2 - - row = warp_m * 16 + quad - col = warp_n * 16 + tid - - current_row = mi * 16 + ii * 8 + row - # current_col = ni * 64 + jj * 8 + col - # current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col - current_col = ni * 64 + (jj & 2) * 8 + (jj & 1) + col - # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - # warp, lane, quad, tid, current_row, current_col, value)) - if (current_row < 8 and current_col < 8): - print (line.strip()) - print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - warp, lane, quad, tid, current_row, current_col, value)) - d_softmax[0, 0, current_row, current_col] = value - - -parse_dsoftmax_load(filename) -print ("output block 0 d_softmax: ", d_softmax[0, 0, :, :]) diff --git a/tests/tools/rebuild_fwd_softmax.py b/tests/tools/rebuild_fwd_softmax.py deleted file mode 100644 index 3158ac2a9..000000000 --- a/tests/tools/rebuild_fwd_softmax.py +++ /dev/null @@ -1,113 +0,0 @@ -from parse import parse -import sys -import numpy as np - -def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): - diff = np.abs(pred - gt) - - cnt = 0 - for index, x in np.ndenumerate(diff): - if x > abs_eps: - if abs(gt[index]) < 1e-9: - relative_diff = 100 - else: - relative_diff = np.abs(x / gt[index]) - if relative_diff > relative_rps: - cnt += 1 - if verbose: - print ("index={0}, diff={1}, pred={2}, true={3}, relative_diff={4}".format( - index, x, pred[index], gt[index], relative_diff)) - - if cnt > 0: - print ("not so match") - return False - else: - return True - -filename = "./output512.log" -if len(sys.argv) > 1: - filename = sys.argv[1] - -# Attnmask: threadIdx.x = 98, threadIdx.y = 0, mi = 0, ni = 0, ii = 0, jj = 2, value = 0.000000, softmax = 0.608030, l = 0, loop_step_idx=1, blockIdx.x = 0 -format_string = 'Attnmask: threadIdx.x = {0}, threadIdx.y = {1}, mi = {2}, ni = {3}, ii = {4}, jj = {5}, value = {6}, softmax = {7}, l = {8}, loop_step_idx={9}, blockIdx.x = {10}' -batch_size = 1 -nheads = 1 -headdim = 16 -bs_seq = 1 -seq_q = 512 -max_seqlen_q_ = seq_q -max_seqlen_k_ = seq_q - -Cta_tile_p_N = 256 -Cta_tile_p_M = 16 - - -def parse_fwd_softmax_load(filename): - print ("processing... 
reconstruct from ", filename) - softmax_data = np.zeros([batch_size * bs_seq, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) - with open(filename, "r") as f: - for line in f.readlines(): - # print (line) - if line.startswith("Attnmask: "): - # print (line.strip()) - result = parse(format_string, line.strip()) - # print (result) - - tidx_ = int(result[0]) - mi = int(result[2]) - ni = int(result[3]) - ii = int(result[4]) - jj = int(result[5]) - value = float(result[6]) - softmax_elt = float(result[7]) - q_loop = int(result[8]) - k_loop = int(result[9]) - - warp = tidx_ // 32 - lane = tidx_ % 32 - # thread per warp = 32 - - warp_n = (warp // 1) - warp_m = (warp % 1) - # WARPS_M = 1 - - quad = lane // 4 - tid = (lane % 4) * 2 - - row = warp_m * 16 + quad - col = warp_n * 16 + tid - - current_row = Cta_tile_p_M * q_loop + mi * 16 + ii * 8 + row - - current_col = k_loop * Cta_tile_p_N + ni * 64 + (jj & 2) * 4 + (jj & 1) + col - # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - # warp, lane, quad, tid, current_row, current_col, value)) - if current_col > 510: - print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - warp, lane, quad, tid, current_row, current_col, value)) - print (line.strip()) - if (current_row < 16 and current_col < 512): - # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - # warp, lane, quad, tid, current_row, current_col, value)) - # print ("") - softmax_data[0, 0, current_row, current_col] = value - - return softmax_data - - -softmax_cpp = parse_fwd_softmax_load(filename) -softmax_python = parse_fwd_softmax_load("../" + filename) - - -print (is_same_matrix(softmax_cpp, softmax_python, verbose=True)) - -for i in range(16): - print ("first part idx = {} softmax cpp = {}: ".format(i, softmax_cpp[0, 0, i, :256])) - print ("first part idx = {} softmax python = {}: ".format(i, softmax_python[0, 0, i, :256])) - - print (np.allclose(softmax_cpp[0, 0, i, :256],softmax_python[0, 0, i, :256])) - - print ("second part idx = {} softmax cpp = {}: ".format(i, softmax_cpp[0, 0, i, 256:])) - print ("second part idx = {} softmax python = {}: ".format(i, softmax_python[0, 0, i, 256:])) - - print (np.allclose(softmax_cpp[0, 0, i, 256:],softmax_python[0, 0, i, 256:])) diff --git a/tests/tools/rebuild_mat.py b/tests/tools/rebuild_mat.py deleted file mode 100644 index 1fdf98e0a..000000000 --- a/tests/tools/rebuild_mat.py +++ /dev/null @@ -1,94 +0,0 @@ -from parse import parse -import sys -import numpy as np - -filename = "./output.log" -if len(sys.argv) > 1: - filename = sys.argv[1] - -# AttnBias: threadIdx.x = 0, threadIdx.y = 0, mi = 0, ni = 0, ii = 0, jj = 0, value = 0.000000 -format_string = 'AttnBias: threadIdx.x = {}, threadIdx.y = {}, mi = {}, ni = {}, ii = {}, jj = {}, value = {}, ldx = {}, blockIdx.x = {}' -batch_size = 1 -nheads = 1 -headdim = 16 -seq = 8 -seq_q = 8 -max_seqlen_q_ = seq_q -max_seqlen_k_ = seq_q - - -mask_ref = np.zeros([batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_], dtype=np.float16) -cnt = 0 - -for i in range(batch_size * max_seqlen_k_): - for j in range(nheads): - for k in range(max_seqlen_q_): - for l in range(max_seqlen_k_): - mask_ref[i][j][k][l] = cnt * 0.001 - cnt += 1 - -# mask = np.zeros([1, 1, max_seqlen_q_, max_seqlen_k_], dtype=np.float32) -# batch_size * max_seqlen_k_, nheads, max_seqlen_q_, max_seqlen_k_ -mask = np.zeros([batch_size * max_seqlen_k_, nheads, 16, 128], dtype=np.float16) - - -def 
parse_bias_load(filename): - with open(filename, "r") as f: - for line in f.readlines(): - # print (line) - if line.startswith("AttnBias:"): - # print (line.strip()) - - result = parse(format_string, line.strip()) - print (result) - # import pdb; pdb.set_trace() - # if result[0] == 0: - # print (result[0], result[1], result[2], result[3], result[4], result[5], result[6]) - tidx_ = int(result[0]) - mi = int(result[2]) - ni = int(result[3]) - ii = int(result[4]) - jj = int(result[5]) - value = float(result[6]) - block_idx = int(result[8]) - - warp = tidx_ // 32 - lane = tidx_ % 32 - - warp_n = (warp // 1) - warp_m = (warp % 1) - - quad = lane // 4 - tid = (lane % 4) * 2 - - row = warp_m * 16 + quad - col = warp_n * 16 + tid - - current_row = mi * 16 + ii * 8 + row - # current_col = ni * 64 + jj * 8 + col - current_col = ni * 64 + (jj & 2) * 4 + (jj & 1) + col - - # if (current_row < 8 and current_col < 8): - # print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_row={}, value={}".format( - # warp, lane, quad, tid, current_row, current_col, value)) - # mask[0, 0, current_row, current_col] = value - print ("warp={}, lane={}, quad={}, tid={}, current_row={}, current_col={}, value={}".format( - warp, lane, quad, tid, current_row, current_col, value)) - mask[block_idx, 0, current_row, current_col] = value - - -def check(mask, mask_ref, block_idx=0): - flag = True - bs, nheads, max_seqlen_q_, max_seqlen_k_ = mask_ref.shape - for i in range(max_seqlen_q_): - for j in range(max_seqlen_k_): - if (abs(mask[0, 0, i, j] - mask_ref[block_idx, 0, i, j]) > 1e-3): - print ("False in block_idx = {}, i = {}, j = {}, mask = {}, mask_ref = {}".format(block_idx, - i, j, mask[0, 0, i, j] - mask_ref[block_idx, 0, i, j])) - flag = False - return flag - -parse_bias_load(filename) - -# block_idx = 1 -# print (check(mask, mask_ref, block_idx)) \ No newline at end of file From 2543703ed7035aaf0cd61ba5ffc2b6910b7485f2 Mon Sep 17 00:00:00 2001 From: robotcator Date: Tue, 11 Oct 2022 23:31:26 +0800 Subject: [PATCH 50/71] clean fmha_api.cpp --- csrc/flash_attn/fmha_api.cpp | 73 +----------------------------------- 1 file changed, 1 insertion(+), 72 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index d4da2eba5..3f8a65b38 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -32,10 +32,6 @@ #include "fmha.h" -#ifdef DDEBUG_PRINT -#include "fmha_api.h" -#include -#endif #define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") @@ -114,27 +110,6 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.mask_head_mod_size = mask_head_mod_size; params.mask_seq_mod_size = mask_seq_mod_size; - -#ifdef DEBUG_PRINT - printf("========================================\n"); - printf("params.q_row_stride_in_elts = %d \n", params.q_row_stride_in_elts); - printf("params.k_row_stride_in_elts = %d \n", params.k_row_stride_in_elts); - printf("params.v_row_stride_in_elts = %d \n", params.v_row_stride_in_elts); - printf("params.q_head_stride_in_elts = %d \n", params.q_head_stride_in_elts); - printf("params.k_head_stride_in_elts = %d \n", params.k_head_stride_in_elts); - printf("params.v_head_stride_in_elts = %d \n", params.v_head_stride_in_elts); - printf("params.h = %d \n", params.h); - printf("params.b = %d \n", params.b); - printf("params.seqlen_q (max seq) = %d \n", params.seqlen_q); - printf("params.seqlen_k (max seq) = %d \n", params.seqlen_k); - printf("params.d = %d \n", params.d); - printf("params.o_row_stride_in_elts = %d \n", params.o_row_stride_in_elts); - printf("params.o_head_stride_in_elts = %d \n", params.o_head_stride_in_elts); - printf("params.s_stride_in_bytes = %d \n", params.s_stride_in_bytes); - printf("========================================\n"); -#endif - - // Set the different scale values. // const float scale_bmm1 = 1.f / sqrtf(d); const float scale_bmm1 = softmax_scale; @@ -222,34 +197,6 @@ void set_params_dgrad(FMHA_dgrad_params ¶ms, params.attn_ds_ptr = attn_ds; } -#ifdef DEBUG_PRINT -void dump_tensor(const std::string &tensor_name, const at::Tensor &tensor, const std::string &label) { - std::string file_name = label + "_" + tensor_name + ".data"; - std::ofstream file(file_name.c_str()); - // file << tensor_name << std::endl; - // file << tensor << std::endl; - std::cout << "tensor_name stride 0: " << tensor_name << " " << tensor.stride(0) << std::endl; - std::cout << "tensor_name stride 1: " << tensor_name << " " << tensor.stride(1) << std::endl; - std::cout << "tensor_name stride 2: " << tensor_name << " " << tensor.stride(2) << std::endl; - std::cout << "tensor_name stride 3: " << tensor_name << " " << tensor.stride(-1) << std::endl; - - std::cout << "tensor_name size: " << tensor_name << " " << tensor.sizes() << std::endl; - // cost too much time - auto flatten_tensor = tensor.flatten(); - auto size = flatten_tensor.numel(); - - for (int i = 0; i < size; i ++) { - file << flatten_tensor[i].item() << " "; - // file << flatten_tensor[i] << " "; - } - file << std::endl; - - std::string sfile_name = label + "_" + tensor_name + ".pt"; - std::ofstream sfile(sfile_name.c_str()); - torch::save(tensor, sfile); -} -#endif - std::vector mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i @@ -339,18 +286,6 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q TORCH_CHECK(mask_sizes[2] == 1 || mask_sizes[2] == max_seqlen_q_); } -#ifdef DEBUG_PRINT - dump_tensor("input_q", q, ""); - // dump_tensor("input_k", k, ""); - // dump_tensor("input_v", v, ""); - if (attn_mask.has_value()) { - dump_tensor("input_mask", *attn_mask, ""); - } - if (attn_bias.has_value()) { - dump_tensor("input_bias", *attn_bias, ""); - } -#endif - int blocksize_c = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 
128 : 256; // Need to round max_seqlen_k to multiples of blocksize_c int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c; @@ -421,10 +356,6 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q run_fmha_fp16_sm80(launch_params, /*configure=*/false); -#ifdef DEBUG_PRINT - dump_tensor("output_o", o, ""); -#endif - std::vector result = {o, softmax_lse}; if (return_softmax) {result.push_back(s);} return result; @@ -915,12 +846,10 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size return { dq, dk, dv, softmax_d }; } -#if !defined(DEBUG_USING_NVCC) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "Fused Multi-head Self-attention"; m.def("fwd", &mha_fwd, "Forward pass"); m.def("bwd", &mha_bwd, "Backward pass"); m.def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)"); m.def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)"); -} -#endif \ No newline at end of file +} \ No newline at end of file From bf68f9050680881e678415258e9b5fd99124221f Mon Sep 17 00:00:00 2001 From: robotcator Date: Tue, 11 Oct 2022 23:34:05 +0800 Subject: [PATCH 51/71] clean fmha_dgrad_fp16_kernel_loop.sm80.cu --- .../src/fmha_dgrad_fp16_kernel_loop.sm80.cu | 30 ------------------- 1 file changed, 30 deletions(-) diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu index 5678377a4..3aa0cc472 100644 --- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu @@ -31,10 +31,6 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_ bool has_attn = !(params.attn_mask_ptr == nullptr); bool has_bias = !(params.attn_bias_ptr == nullptr); -#ifdef DEBUG_PRINT - printf ("has_attn=%d, has_bias=%d, bias_mod_size=%d\n", has_attn, has_bias, params.bias_mod_size); -#endif - if (has_attn) { if (has_bias) { BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { @@ -128,32 +124,6 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_ }); } } - // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. -// BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { -// auto kernel = params.is_causal -// ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel -// : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; -// if (params.seqlen_k == blocksize_c) { -// kernel = params.is_causal -// ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel -// : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; -// } else if (params.seqlen_k == blocksize_c * 2) { -// kernel = params.is_causal -// ? 
&fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel -// : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel; -// } -// if( smem_size_dq_dk_dv >= 48 * 1024 ) { -// FMHA_CHECK_CUDA(cudaFuncSetAttribute( -// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); -// } -// dim3 grid(params.b, params.h); -// kernel<<>>(params); -// #ifdef DEBUG_PRINT -// printf("bwd grid size: %d %d\n", params.b, params.h); -// printf("bwd block size: %d\n", Kernel_traits::THREADS); -// #endif -// FMHA_CHECK_CUDA(cudaPeekAtLastError()); -// }); } void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { From 4f37437cbcce05cc3e35ba048d3132e148eab7eb Mon Sep 17 00:00:00 2001 From: robotcator Date: Tue, 11 Oct 2022 23:38:03 +0800 Subject: [PATCH 52/71] clean fmha_dgrad_kernel_1xN_loop.h --- .../src/fmha_dgrad_kernel_1xN_loop.h | 142 +----------------- 1 file changed, 4 insertions(+), 138 deletions(-) diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index f01e37cd8..2a501dd49 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -153,13 +153,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; // conctructor Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); - // TODO: load fun as s // Allocate the global memory tile loader for bias. using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; // conctructor Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); - // TODO: load fun as s using Gmem_tile_ds = typename Kernel_traits::Gmem_tile_ds; Gmem_tile_ds gmem_ds(params, binfo, tidx, loop_step_idx); @@ -214,14 +212,10 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng gmem_softmax_d.move(begin); if constexpr (has_attn) { - // if (!(params.attn_mask_ptr == nullptr)) { - // TODO: mask move gmem_mask.move(begin); } if constexpr (has_attn) { - // if (!(params.attn_bias_ptr == nullptr)) { - // TODO: mask move gmem_bias.move(begin); gmem_ds.move(begin); } @@ -250,7 +244,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Commit the data for Q, dO, and V to shared memory. gmem_q.commit(gemm_q_k.smem_q); gmem_do.commit(smem_do); - // D_sum + if (Is_first) { dot_do_o( gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx @@ -363,55 +357,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng gmem_mask.template load(frag_mask); gmem_mask.move(); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - // 1st row - 4 elements per row. - float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. 
- float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; - - printf("before attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); - printf("before attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); - } - } - printf("\n"); - } -#endif // Apply the attn mask. softmax.apply_attn_mask(frag_mask); - -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - // 1st row - 4 elements per row. - float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. - float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; - - printf("after attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); - printf("after attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); - } - } - printf("\n"); - } -#endif } if constexpr (has_bias) { @@ -421,55 +368,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng gmem_bias.template load(frag_bias); gmem_bias.move(); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - // 1st row - 4 elements per row. - float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. - float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; - - printf("before attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); - printf("before attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); - } - } - printf("\n"); - } -#endif // Apply the attn mask. softmax.apply_attn_bias(frag_bias, l); - -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - // 1st row - 4 elements per row. - float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. 
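For reference, the probabilities this backward loop reconstructs via softmax.scale_apply_exp(p_lse, params.scale_bmm1f) can be written out in NumPy as below. This is a sketch of the math only (scores plus mask/bias, then exp(x - lse) with the log-sum-exp saved by the forward pass); it assumes the softmax scale is already folded into q (softmax_scale = 1, as in benchmarks/test_example.py) and ignores dropout.

import numpy as np

def recompute_probs(q, k, lse, mask=None, bias=None):
    """q, k: (heads, seqlen, d); lse: (heads, seqlen_q) row-wise logsumexp
    saved by the forward pass; mask/bias broadcastable to (heads, q, k)."""
    s = np.einsum('hqd,hkd->hqk', q, k)
    if mask is not None:
        s = s + mask
    if bias is not None:
        s = s + bias
    # exp(x - (max + log(sum exp(x - max)))) == exp(x - max) / sum == softmax(x)
    return np.exp(s - lse[..., None])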
- float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; - - printf("after attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); - printf("after attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); - } - } - printf("\n"); - } -#endif } // Apply the mask. @@ -478,23 +378,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // softmax.apply_exp(p_lse); // exp (x - (max+log(sum))) = exp(x - max) / sum softmax.template scale_apply_exp(p_lse, params.scale_bmm1f); -#ifdef DEBUG_PRINT - if ((blockIdx.x == 0) && (blockIdx.y == 0)) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - for (int ii = 0; ii < 2; ii ++) { - for (int jj = 0; jj < 4; jj ++) { - int st_row = 2 * mi + ii; - int st_col = 4 * ki + jj; - printf("bwd softmax: threadIdx=%d, l=%d, mi=%d, ki=%d, ii=%d, jj=%d, elt=%f\n", - threadIdx.x, l, mi, ki, ii, jj, softmax.elt_[st_row][st_col]); - } - } - } - } - printf("\n"); - } -#endif + if (Is_dropout) { // softmax.apply_dropout(ph, params.p_dropout_in_uint); // softmax.template apply_dropout(ph, params.p_dropout_in_uint); @@ -522,7 +406,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // __syncthreads(); // } - // what's meaning? fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; #pragma unroll for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) { @@ -599,25 +482,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng softmax.template pack(frag_p); - // if constexpr (has_bias) { - if (!(params.attn_bias_ptr == nullptr)) { -#ifdef DEBUG_PRINT - if ((blockIdx.x == 0) && (blockIdx.y == 0)) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - for (int ii = 0; ii < 2; ii ++) { - for (int jj = 0; jj < 4; jj ++) { - int st_row = 2 * mi + ii; - int st_col = 4 * ki + jj; - printf("bwd dsoftmax: threadIdx=%d, l=%d, mi=%d, ki=%d, ii=%d, jj=%d, elt=%f\n", - threadIdx.x, l, mi, ki, ii, jj, softmax.elt_[st_row][st_col]); - } - } - } - } - printf("\n"); - } -#endif + if constexpr (has_bias) { + // if (!(params.attn_bias_ptr == nullptr)) { gmem_ds.template store(softmax.elt_); gmem_ds.move(); } From 60e27d85f605d9b1bbece494f9db5648c409940a Mon Sep 17 00:00:00 2001 From: robotcator Date: Tue, 11 Oct 2022 23:41:32 +0800 Subject: [PATCH 53/71] clean fmha_fprop_fp16_kernel.sm80.cu --- csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu | 7 ------- 1 file changed, 7 deletions(-) diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index 84f0cc082..66cd1583d 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -62,11 +62,6 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, bool has_attn = !(launch_params.params.attn_mask_ptr == nullptr); bool has_bias = !(launch_params.params.attn_bias_ptr == nullptr); -#ifdef DEBUG_PRINT - printf ("has_attn=%d, has_bias=%d, bias_mod_size=%d, mask_seq_mod_size=%d, mask_head_mod_size=%d\n", - has_attn, has_bias, launch_params.params.bias_mod_size, launch_params.params.mask_seq_mod_size, launch_params.params.mask_head_mod_size); -#endif - if (has_attn) { if (has_bias) { @@ -178,8 +173,6 
@@ void run_fmha_fp16_sm80(Launch_params &launch_params, auto dprops = at::cuda::getCurrentDeviceProperties(); if (launch_params.params.d == 16) { if( launch_params.params.seqlen_k == 128 ) { - // int S, int D, int STEP, int WARPS_M, int WARPS_N, - // D is [hidden_dim] using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>; run_fmha_fp16_sm80_loop_(launch_params, configure); } From 96b83bfbd9c955a13278a62f41977b81dcdff127 Mon Sep 17 00:00:00 2001 From: robotcator Date: Tue, 11 Oct 2022 23:48:33 +0800 Subject: [PATCH 54/71] clean fmha_fprop_kernel_1xN.h --- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 212 +------------------- 1 file changed, 2 insertions(+), 210 deletions(-) diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 488f59b95..14da9c5b7 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -174,7 +174,6 @@ struct Gemm_Q_K : public Gemm_Q_K_base; -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // Cta_tile_p - printf("Cta_tile_p::M = %d, Cta_tile_p::N = %d, Cta_tile_p::K = %d\n", - Cta_tile_p::M, Cta_tile_p::N, Cta_tile_p::K); - printf("Cta_tile_p::WARPS_M = %d, Cta_tile_p::WARPS_N = %d, Cta_tile_p::WARPS_K = %d\n", - Cta_tile_p::WARPS_M, Cta_tile_p::WARPS_N, Cta_tile_p::WARPS_K); - printf("Cta_tile_p::WARPS_PER_CTA = %d, Cta_tile_p::THREADS_PER_WARP = %d, Cta_tile_p::THREADS_PER_CTA = %d\n", - Cta_tile_p::WARPS_PER_CTA, Cta_tile_p::THREADS_PER_WARP, Cta_tile_p::THREADS_PER_CTA); - printf("\n"); - - // Cta_tile_o - printf("Cta_tile_o::M = %d, Cta_tile_o::N = %d, Cta_tile_o::K = %d\n", - Cta_tile_o::M, Cta_tile_o::N, Cta_tile_o::K); - printf("Cta_tile_o::WARPS_M = %d, Cta_tile_o::WARPS_N = %d, Cta_tile_o::WARPS_K = %d\n", - Cta_tile_o::WARPS_M, Cta_tile_o::WARPS_N, Cta_tile_o::WARPS_K); - printf("Cta_tile_o::WARPS_PER_CTA = %d, Cta_tile_o::THREADS_PER_WARP = %d, Cta_tile_o::THREADS_PER_CTA = %d\n", - Cta_tile_o::WARPS_PER_CTA, Cta_tile_o::THREADS_PER_WARP, Cta_tile_o::THREADS_PER_CTA); - printf("\n"); - - // Mma_tile_p - printf("Mma_tile_p::MMAS_M = %d, Mma_tile_p::MMAS_N = %d, Mma_tile_p::MMAS_K = %d\n", - Mma_tile_p::MMAS_M, Mma_tile_p::MMAS_N, Mma_tile_p::MMAS_K); - // The number of elements computed with a single CTA-MMA. 
- printf("Mma_tile_p::M_PER_MMA_PER_CTA = %d, Mma_tile_p::N_PER_MMA_PER_CTA = %d, Mma_tile_p::K_PER_MMA_PER_CTA = %d\n", - Mma_tile_p::M_PER_MMA_PER_CTA, Mma_tile_p::N_PER_MMA_PER_CTA, Mma_tile_p::K_PER_MMA_PER_CTA); - printf("\n"); - - // Mma_tile_o - printf("Mma_tile_o::MMAS_M = %d, Mma_tile_o::MMAS_N = %d, Mma_tile_o::MMAS_K = %d\n", - Mma_tile_o::MMAS_M, Mma_tile_o::MMAS_N, Mma_tile_o::MMAS_K); - printf("Mma_tile_o::M_PER_MMA_PER_CTA = %d, Mma_tile_o::N_PER_MMA_PER_CTA = %d, Mma_tile_o::K_PER_MMA_PER_CTA = %d\n", - Mma_tile_o::M_PER_MMA_PER_CTA, Mma_tile_o::N_PER_MMA_PER_CTA, Mma_tile_o::K_PER_MMA_PER_CTA); - printf("\n"); - - // Gmem_tile_q - printf("Gmem_tile_q::BYTES_PER_ELEMENT = %d, Gmem_tile_q::ROWS = %d, Gmem_tile_q::COLS = %d, Gmem_tile_q::LDGS = %d, Gmem_tile_q::THREADS_PER_ROW = %d, Gmem_tile_q::ROWS_PER_LDG=%d\n", - Gmem_tile_q::BYTES_PER_ELEMENT, Gmem_tile_q::ROWS, Gmem_tile_q::COLS, Gmem_tile_q::LDGS, - Gmem_tile_q::THREADS_PER_ROW, Gmem_tile_q::ROWS_PER_LDG); - printf("\n"); - - // Gmem_tile_k - printf("Gmem_tile_k::BYTES_PER_ELEMENT = %d, Gmem_tile_k::ROWS = %d, Gmem_tile_k::COLS = %d, Gmem_tile_k::LDGS = %d, Gmem_tile_q::THREADS_PER_ROW = %d, Gmem_tile_q::ROWS_PER_LDG=%d\n", - Gmem_tile_k::BYTES_PER_ELEMENT, Gmem_tile_k::ROWS, Gmem_tile_k::COLS, Gmem_tile_k::LDGS, - Gmem_tile_q::THREADS_PER_ROW, Gmem_tile_q::ROWS_PER_LDG); - printf("\n"); - - // Gmem_tile_v - printf("Gmem_tile_v::BYTES_PER_ELEMENT = %d, Gmem_tile_v::ROWS = %d, Gmem_tile_v::COLS = %d, Gmem_tile_v::LDGS = %d, Gmem_tile_q::THREADS_PER_ROW = %d, Gmem_tile_q::ROWS_PER_LDG=%d\n", - Gmem_tile_v::BYTES_PER_ELEMENT, Gmem_tile_v::ROWS, Gmem_tile_v::COLS, Gmem_tile_v::LDGS, - Gmem_tile_q::THREADS_PER_ROW, Gmem_tile_q::ROWS_PER_LDG); - printf("\n"); - - // Gmem_tile_o - printf("Gmem_tile_o::ROWS = %d, Gmem_tile_o::COLS = %d, Gmem_tile_o::STGS = %d, Gmem_tile_o::STGS_PER_LOOP = %d\n", - Gmem_tile_o::ROWS, Gmem_tile_o::COLS, Gmem_tile_o::STGS, Gmem_tile_o::STGS_PER_LOOP); - printf("\n"); - - // Gmem_tile_s - printf("Gmem_tile_s::M = %d, Gmem_tile_s::N = %d\n", - Gmem_tile_s::M, Gmem_tile_s::N); - printf("\n"); - - // Gmem_softmax_sum - printf("Gmem_softmax_sum::MMAS_M = %d, Gmem_softmax_sum::ROWS = %d\n", - Gmem_softmax_sum::MMAS_M, Gmem_softmax_sum::ROWS); - printf("\n"); - - // Gemm1 - printf("Gemm1::SHARE_SMEM_FOR_K_AND_V = %d, Gemm1::SMEM_OFFSET_O = %d, Gemm1::SMEM_OFFSET_SOFTMAX = %d, Gemm1::SMEM_OFFSET_V = %d, Gemm1::SMEM_OFFSET_V = %d\n", - Gemm1::SHARE_SMEM_FOR_K_AND_V, Gemm1::SMEM_OFFSET_O, Gemm1::SMEM_OFFSET_SOFTMAX, Gemm1::SMEM_OFFSET_V, Gemm1::SMEM_OFFSET_V); - printf("\n"); - - // Softmax - printf("Softmax::WARPS_M = %d, Softmax::WARPS_N = %d, Softmax::MMAS_M = %d, Softmax::MMAS_N = %d\n", - Softmax::WARPS_M, Softmax::WARPS_N, Softmax::MMAS_M, Softmax::MMAS_N); - printf("\n"); - } -#endif - // Shared memory. extern __shared__ char smem_[]; @@ -353,14 +272,12 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; // conctructor Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); - // TODO: load fun as s // bool has_bias = !(params.attn_bias_ptr == nullptr); // Allocate the global memory tile loader for bias. 
using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; // conctructor Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); - // TODO: load fun as s Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); @@ -378,13 +295,11 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i if constexpr (has_attn) { // if (!(params.attn_mask_ptr == nullptr)) { - // TODO: mask move gmem_mask.move(begin); } if constexpr (has_bias) { // if (!(params.attn_bias_ptr == nullptr)) { - // TODO: bias move gmem_bias.move(begin); } @@ -481,23 +396,11 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Do this part of P = Q * K^T. gemm_q_k(acc_p); - // TODO acc_p += mask, index like gmem_s.store(frag_p, mask); // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { // printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1)); // } -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) { - for (int ii = 0; ii < Mma_tile_p::MMAS_M; ii ++) { - for (int jj = 0; jj < Mma_tile_p::MMAS_N; jj ++) { - for (int kk = 0; kk < acc_p[ii][jj].NUM_ELTS; kk ++) { - printf("ii=%d, jj=%d, kk=%d, acc_p=%.6f\n", ii, jj, kk, acc_p[ii][jj].elt(kk)); - } - } - } - } -#endif uint4 out[Gmem_tile_o::STGS_PER_LOOP]; if (!Is_first) { gmem_o_tmp.load(out, 0); } @@ -523,55 +426,8 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_mask.template load(frag_mask); gmem_mask.move(); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - // 1st row - 4 elements per row. - float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. - float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; - - printf("before attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); - printf("before attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); - } - } - printf("\n"); - } -#endif // Apply the attn mask. softmax.apply_attn_mask(frag_mask, l, loop_step_idx); - -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - // 1st row - 4 elements per row. - float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. 
- float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; - - printf("after attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); - printf("after attn mask softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); - } - } - printf("\n"); - } -#endif } if constexpr (has_bias) { @@ -582,55 +438,8 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_bias.template load(frag_bias); gmem_bias.move(); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - // 1st row - 4 elements per row. - float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. - float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; - - printf("before attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); - printf("before attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); - } - } - printf("\n"); - } -#endif // Apply the attn mask. softmax.apply_attn_bias(frag_bias, l); - -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - for( int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi ) { - for( int ki = 0; ki < Mma_tile_p::MMAS_N; ++ki ) { - // 1st row - 4 elements per row. - float tmp_00 = softmax.elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = softmax.elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = softmax.elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = softmax.elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. - float tmp_10 = softmax.elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = softmax.elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = softmax.elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = softmax.elt_[2 * mi + 1][4 * ki + 3]; - - printf("after attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_00, tmp_01, tmp_02, tmp_03); - printf("after attn bias softmax: mi=%d, ki=%d, %f %f %f %f\n", mi, ki, tmp_10, tmp_11, tmp_12, tmp_13); - } - } - printf("\n"); - } -#endif } // Apply the mask. @@ -664,15 +473,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i } softmax.template reduce_max(p_max); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - // can we print the tile row? - for (int i = 0; i < Mma_tile_p::MMAS_M * 2; i ++) { - printf("i=%d, p_max=%f\n", i, p_max[i]); - } - printf("\n"); - } -#endif + // if ((threadIdx.x == 0) && (l == 38)) { // printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]); // } @@ -685,7 +486,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Compute the exponential value. 
// softmax.apply_exp(p_max); - // Compute: exp(p - p_max) softmax.scale_apply_exp(p_max, params.scale_bmm1f); // if (!Is_first) { @@ -706,15 +506,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // softmax.reduce_sum(p_sum); softmax.reduce_sum_before_sync_(p_sum); // softmax.template reduce_sum_before_sync_(p_sum); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - // can we print the tile row? - for (int i = 0; i < Mma_tile_p::MMAS_M * 2; i ++) { - printf("i=%d, p_max=%f\n", i, p_sum[i]); - } - printf("\n"); - } -#endif + // float p_sum_log[Mma_tile_p::MMAS_M * 2]; // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) { // float sum = p_sum[mi]; From 9c1cb91fcd4b8a7a1920498b7d1c30a989974373 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 00:02:18 +0800 Subject: [PATCH 55/71] cleangmem_tile.h --- csrc/flash_attn/src/fmha/gmem_tile.h | 85 +--------------------------- 1 file changed, 3 insertions(+), 82 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 3e8aeb558..9238a6cb8 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -86,27 +86,10 @@ struct Gmem_tile_qkv { // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes); // Add the block index. -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("use_seqlen_q=%d\n", use_seqlen_q); - printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", - threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); - printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_LDG=%d, LDGS=%d\n", - threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_LDG, LDGS); - printf("\n"); - } -#endif + // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); -// #ifdef DEBUG_PRINT -// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { -// printf("use_seqlen_q=%d\n", use_seqlen_q); -// printf("threadIdx.x=%d, threadIdx.x=%d, blockIdx.y=%d, row_offset=%d BYTES_PER_LDG=%d\n", -// threadIdx.x, threadIdx.x, blockIdx.y, row_offset, BYTES_PER_LDG); -// printf("\n"); -// } -// #endif // Assemble the final pointer. ptr += row_offset + col * BYTES_PER_LDG; } @@ -233,16 +216,7 @@ struct Gmem_tile_o { row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); // Assemble the final pointer. 
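The row_offset arithmetic in these tile loaders addresses the unpadded (total_tokens, num_heads, head_size) layout: a token's global row is cu_seqlens[b] + row (binfo.sum_s), with the head offset added on top of the row stride. The same addressing in element units, for a contiguous tensor (names illustrative):

def qkv_element_offset(cu_seqlens, b, row, h, col, num_heads, head_size):
    """Flat element index of x[cu_seqlens[b] + row, h, col] for an unpadded
    (total_tokens, num_heads, head_size) tensor; mirrors
    (sum_s + row) * row_stride + bidh * head_stride + col."""
    row_stride = num_heads * head_size
    head_stride = head_size
    return (cu_seqlens[b] + row) * row_stride + h * head_stride + col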
ptr_ += row_offset + col * BYTES_PER_STG; -// #ifdef DEBUG_PRINT -// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { -// printf("print o parameter\n"); -// printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, tidx=%d, row=%d, col=%d, THREADS_PER_ROW=%d\n", -// threadIdx.x, blockIdx.x, blockIdx.y, tidx, row, col, THREADS_PER_ROW); -// printf("threadIdx.x=%d, blockIdx.x=%d, blockIdx.y=%d, row_offset=%d ROWs=%d, COLS=%d, BYTES_PER_ROW=%d, THREADS_PER_ROW=%d, ROWS_PER_STG=%d, LDGS=%d\n", -// threadIdx.x, blockIdx.x, blockIdx.y, row_offset, ROWS, COLS, BYTES_PER_ROW, THREADS_PER_ROW, ROWS_PER_STG, STGS); -// printf("\n"); -// } -// #endif + // Is that thread active on the last STG? if( HAS_INCOMPLETE_STG ) { is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; @@ -520,16 +494,6 @@ struct Gmem_tile_mma_mask { uint32_t row_offset = bidx * params.mask_seq_mod_size * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; row_offset += (uint32_t)( (row % params.mask_seq_mod_size) * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("init mask tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", - tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); - printf("init mask bidb=%d, bidh=%d, param.h=%d, mask_head_mod_size=%d, mask_seq_mod_size=%d, loop_step_idx=%d\n", - binfo.bidb, binfo.bidh, params.h, params.mask_head_mod_size, params.mask_seq_mod_size, loop_step_idx); - printf("\n"); - } -#endif - // do we need to move col first if seklen_k > cols ptr_ += row_offset; } @@ -566,17 +530,6 @@ struct Gmem_tile_mma_mask { // && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("load mask mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, current_row=%d, current_col=%d, start_ptr=%p, ptrs[offset]=%p, preds[offset]=%d\n", - mi, ni, ii, jj, offset, current_row, current_col, ptr_, ptrs[offset], preds[offset]); - printf("load mask ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d, loop_step_idx=%d, cond1=%d, cond2=%d\n", - ROWS, actual_seqlen_q, COLS, actual_seqlen_k, loop_step_idx, - (current_row <= min(ROWS, actual_seqlen_q)), - ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k))); - printf("\n"); - } -#endif } } @@ -686,14 +639,6 @@ struct Gmem_tile_mma_bias { // row_offset = (uint32_t)(row * row_stride_in_bytes); row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", - tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); - printf("bidb=%d, bidh=%d, param.h=%d, bias_mod_size=%d\n", binfo.bidb, binfo.bidh, params.h, params.bias_mod_size); - printf("\n"); - } -#endif // do we need to move col first if seklen_k > cols ptr_ += row_offset; } @@ -725,17 +670,6 @@ struct Gmem_tile_mma_bias { preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("load bias mi=%d, ni=%d, ii=%d, jj=%d, offset=%d, 
current_row=%d, current_col=%d, start_ptr=%p, ptrs[offset]=%p, preds[offset]=%d\n", - mi, ni, ii, jj, offset, current_row, current_col, ptr_, ptrs[offset], preds[offset]); - printf("load bias ROWS=%d, actual_seqlen_q=%d, COLS=%d, actual_seqlen_k=%d, loop_step_idx=%d, cond1=%d, cond2=%d\n", - ROWS, actual_seqlen_q, COLS, actual_seqlen_k, loop_step_idx, - (current_row <= min(ROWS, actual_seqlen_q)), - ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k))); - printf("\n"); - } -#endif } } @@ -838,14 +772,7 @@ struct Gmem_tile_mma_ds { // the index of bs and head dim uint32_t row_offset = bidx * binfo.actual_seqlen_q * binfo.actual_seqlen_k * BYTES_PER_ELEMENT; // row_offset = (uint32_t)(row * row_stride_in_bytes); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("ds tid_=%d, warp=%d, lane=%d, warp_n=%d, warp_m=%d, quad=%d, tid=%d, row=%d, col=%d\n", - tidx_, warp, lane, warp_n, warp_m, quad, tid, row, col); - printf("ds bidb=%d, bidh=%d, param.h=%d, blockIdx.x=%d\n", binfo.bidb, binfo.bidh, params.h, blockIdx.x); - printf("\n"); - } -#endif + row_offset += (uint32_t)(row * binfo.actual_seqlen_k * BYTES_PER_ELEMENT); // do we need to move col first if seklen_k > cols ptr_ += row_offset; @@ -879,12 +806,6 @@ struct Gmem_tile_mma_ds { preds = (current_row < min(ROWS, actual_seqlen_q)) && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k); -#ifdef DEBUG_PRINT - if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("ds store blockIdx.x=%d, mi=%d, ni=%d, ii=%d, jj=%d, current_row=%d, current_col=%d, float1=%f, float2=%f, begin=%p, ptrs=%p, preds=%d\n", - blockIdx.x, mi, ni, ii, jj, current_row, current_col, tmp00, tmp01, ptr_, ptrs, preds); - } -#endif if (preds) { fmha::stg(ptrs, dst); } From 4d20f5928d9aa72843538884ab8ab6ccfd6b3a71 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 00:03:58 +0800 Subject: [PATCH 56/71] clean softmax.h --- csrc/flash_attn/src/fmha/softmax.h | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index 3f2018ff0..277e6be32 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -502,21 +502,7 @@ struct Softmax : public Softmax_base { #pragma unroll for( int jj = 0; jj < 4; ++jj ) { float value = toFloat(mask[mi][ni].elt(ii * 4 + jj)); - // if( abs(value) > 0 ) { - // this->elt_[2 * mi + ii][4 * ni + jj] = zero ? 
0.f : -INFINITY; - // } - // if( value < 0 ) { - // this->elt_[2 * mi + ii][4 * ni + jj] = -INFINITY; - // } this->elt_[2 * mi + ii][4 * ni + jj] += value; -#ifdef DEBUG_PRINT - if ((blockIdx.x == 0) && (blockIdx.y == 0) && l == 0) { - printf("Attnmask: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, softmax = %f, l = %d, loop_step_idx=%d, blockIdx.x = %d\n", - threadIdx.x, threadIdx.y, mi, ni, ii, jj, float(value), this->elt_[2 * mi + ii][4 * ni + jj], l, loop_step_idx, blockIdx.x); - } -#endif - // this->elt_[2 * mi + ii][4 * ni + jj] += float(mask[mi][ni].elt(ii * 4 + jj)); - // this->elt_[2 * mi + ii][4 * ni + jj] += toFloat(mask[mi][ni].elt(ii * 4 + jj)); } } } @@ -536,12 +522,6 @@ struct Softmax : public Softmax_base { for( int jj = 0; jj < 4; ++jj ) { float value = toFloat(bias[mi][ni].elt(ii * 4 + jj)); this->elt_[2 * mi + ii][4 * ni + jj] += value; -#ifdef DEBUG_PRINT - if ((blockIdx.x == 0) && (blockIdx.y == 0)) { - printf("AttnBias: threadIdx.x = %d, threadIdx.y = %d, mi = %d, ni = %d, ii = %d, jj = %d, value = %f, softmax = %f, ldx = %d, blockIdx.x = %d\n", - threadIdx.x, threadIdx.y, mi, ni, ii, jj, float(value), this->elt_[2 * mi + ii][4 * ni + jj], l, blockIdx.x); - } -#endif } } } From ede0a965a47c86cf75da20f031729d7c9788b376 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 00:08:40 +0800 Subject: [PATCH 57/71] restore test_flash_attn.py --- tests/test_flash_attn.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 15afa3fc9..192f78b2d 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -12,6 +12,7 @@ is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5) +is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0) def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'): @@ -331,6 +332,7 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('causal', [False, True]) +# @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize('d', [128, 64, 32, 16]) # @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) @@ -385,7 +387,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype): print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 g = torch.randn_like(output) dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g) dqkv = dqkv_pad_fn(dqkv_unpad) @@ -411,7 +413,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype): else: assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol) @@ -476,7 +478,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype): print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') 
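
Note on the two paths cleaned up above: apply_attn_mask and apply_attn_bias both reduce to adding a per-element term to the scaled attention scores before the exponential; the mask carries 0 for kept positions and a large negative value (e.g. -3e4) for dropped ones, while the bias is an arbitrary additive tensor. A minimal PyTorch sketch of that semantics (ref_scores is an illustrative helper, not code from this repository):

    import torch

    def ref_scores(scores, mask=None, bias=None):
        # scores: [..., seq_q, seq_k], already multiplied by softmax_scale
        dtype = scores.dtype
        if bias is not None:
            scores = scores + bias    # attn_bias: plain additive term
        if mask is not None:
            scores = scores + mask    # attn_mask: 0 to keep, large negative to drop
        return torch.softmax(scores.float(), dim=-1).to(dtype)
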
- if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 g = torch.randn_like(output) dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g) dq = dq_pad_fn(dq_unpad) @@ -501,7 +503,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype): else: assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item() # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol) @@ -568,7 +570,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype): print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 g = torch.randn_like(output) dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g) dq = dq_pad_fn(dq_unpad) @@ -594,7 +596,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype): else: assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() @@ -640,7 +642,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal ) - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 g = torch.randn_like(output_unpad_0) dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0, (q_unpad, k_unpad, v_unpad), g) @@ -659,9 +661,9 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): # assert torch.equal(sm_lse, sm_lse_0) assert torch.equal(S_dmask_converted, S_dmask_converted_0) - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad, (q_unpad, k_unpad, v_unpad), g) assert torch.equal(dq_unpad, dq_unpad_0) assert torch.equal(dk_unpad, dk_unpad_0) - assert torch.equal(dv_unpad, dv_unpad_0) + assert torch.equal(dv_unpad, dv_unpad_0) \ No newline at end of file From 2993caef49d45a270ed59e4cf20367721ba8a9dd Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 00:28:23 +0800 Subject: [PATCH 58/71] clean gmem_tile.h --- csrc/flash_attn/src/fmha/gmem_tile.h | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 9238a6cb8..268af915b 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -410,9 +410,6 @@ struct Gmem_tile_mma_s : public Base { //////////////////////////////////////////////////////////////////////////////////////////////////// -// attn mask struct like s, maybe later can reuse the above declaration -// template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > -// struct Gmem_tile_mma_mask : public Base { template< typename Cta_tile, int BYTES_PER_ELEMENT = 2> struct 
Gmem_tile_mma_mask { @@ -516,18 +513,10 @@ struct Gmem_tile_mma_mask { int offset = ii * 2 + jj; const int current_row = mi * ROWS + ii * 8; const int current_col = loop_step_idx * Cta_tile::N + ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; - // const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + jj * 8 + col; - // 8 is actually col of half data now, for more general case ? - // the row is already in the right position - // ptrs[offset] = ptr_ + (uint32_t)current_row * row_stride_in_bytes + - // (uint32_t)current_col * BYTES_PER_ELEMENT; - // to support the mask last two dimension ptrs[offset] = ptr_ + (uint32_t)(current_row % mask_seq_mod_size) * row_stride_in_bytes + (uint32_t)current_col * BYTES_PER_ELEMENT; - // preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) - // && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= min(COLS, actual_seqlen_k)); preds[offset] = (current_row < min(ROWS, actual_seqlen_q)) && ((current_col + BYTES_PER_LDG / BYTES_PER_ELEMENT) <= actual_seqlen_k); } @@ -562,7 +551,6 @@ struct Gmem_tile_mma_mask { }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// attn bias struct like s, maybe later can reuse the above declaration template< typename Cta_tile, int BYTES_PER_ELEMENT = 2> struct Gmem_tile_mma_bias { @@ -701,7 +689,6 @@ struct Gmem_tile_mma_bias { //////////////////////////////////////////////////////////////////////////////////////////////////// -// attn bias struct like s, maybe later can reuse the above declaration template< typename Cta_tile, int BYTES_PER_ELEMENT = 2> struct Gmem_tile_mma_ds { From 1be0e940413218af635ce70223e7bb83f0a85bf7 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 10:41:52 +0800 Subject: [PATCH 59/71] fix fmha_fprop_kernel_1xN.h --- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 28 ++++++++++----------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 14da9c5b7..845062cc2 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -267,17 +267,19 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); - // bool has_attn = !(params.attn_mask_ptr == nullptr); - // Allocate the global memory tile loader for mask. - using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; - // conctructor - Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); - - // bool has_bias = !(params.attn_bias_ptr == nullptr); - // Allocate the global memory tile loader for bias. - using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; - // conctructor - Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); + if constexpr (has_attn) { + // Allocate the global memory tile loader for mask. + using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; + // conctructor + Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); + } + + if constexpr (has_bias) { + // Allocate the global memory tile loader for bias. 
+ using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; + // conctructor + Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); + } Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); @@ -294,12 +296,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_softmax_lse.move(begin); if constexpr (has_attn) { - // if (!(params.attn_mask_ptr == nullptr)) { gmem_mask.move(begin); } if constexpr (has_bias) { - // if (!(params.attn_bias_ptr == nullptr)) { gmem_bias.move(begin); } @@ -419,7 +419,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i softmax.unpack_noscale(acc_p); if constexpr (has_attn) { - // if (!(params.attn_mask_ptr == nullptr)) { using Frag_mask = fmha::Fragment_c; Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; fmha::clear(frag_mask); @@ -431,7 +430,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i } if constexpr (has_bias) { - // if (!(params.attn_bias_ptr == nullptr)) { using Frag_Bias = fmha::Fragment_c; Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; fmha::clear(frag_bias); From 32f6cd12d02bbdc93b4c46e5a689786d09c3c14e Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 10:45:51 +0800 Subject: [PATCH 60/71] fix fmha_dgrad_kernel_1xN_loop.h --- .../src/fmha_dgrad_kernel_1xN_loop.h | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 2a501dd49..d1f2dd6aa 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -148,19 +148,22 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); - // bool has_attn = !(params.attn_mask_ptr == nullptr); - // Allocate the global memory tile loader for mask. - using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; - // conctructor - Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); + if constexpr (has_attn) { + // Allocate the global memory tile loader for mask. + using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; + // conctructor + Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); + } - // Allocate the global memory tile loader for bias. - using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; - // conctructor - Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); + if constexpr (has_bias) { + // Allocate the global memory tile loader for bias. + using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; + // conctructor + Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); - using Gmem_tile_ds = typename Kernel_traits::Gmem_tile_ds; - Gmem_tile_ds gmem_ds(params, binfo, tidx, loop_step_idx); + using Gmem_tile_ds = typename Kernel_traits::Gmem_tile_ds; + Gmem_tile_ds gmem_ds(params, binfo, tidx, loop_step_idx); + } fmha::Mask mask(binfo, tidx, loop_step_idx); @@ -215,7 +218,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng gmem_mask.move(begin); } - if constexpr (has_attn) { + if constexpr (has_bias) { gmem_bias.move(begin); gmem_ds.move(begin); } @@ -351,7 +354,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Convert from the accumulator type to FP32 for Softmax. 
softmax.unpack_noscale(acc_p); if constexpr (has_attn) { - // if (!(params.attn_mask_ptr == nullptr)) { using Frag_mask = fmha::Fragment_c; Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; gmem_mask.template load(frag_mask); @@ -362,7 +364,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng } if constexpr (has_bias) { - // if (!(params.attn_bias_ptr == nullptr)) { using Frag_Bias = fmha::Fragment_c; Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; gmem_bias.template load(frag_bias); From 30e42536306206c5e93b378db10e4d3842bdee40 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 10:51:03 +0800 Subject: [PATCH 61/71] rename has_attn to has_attn_mask, has_bias to has_attn_bias --- csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu index bbe36647d..66b62b53f 100644 --- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu @@ -60,12 +60,12 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, const int smem_size = fmha::get_dynamic_smem_size() + (loop_steps > 1 ? smem_size_softmax_lse : 0); - bool has_attn = !(launch_params.params.attn_mask_ptr == nullptr); - bool has_bias = !(launch_params.params.attn_bias_ptr == nullptr); + bool has_attn_mask = !(launch_params.params.attn_mask_ptr == nullptr); + bool has_attn_bias = !(launch_params.params.attn_bias_ptr == nullptr); - if (has_attn) + if (has_attn_mask) { - if (has_bias) { + if (has_attn_bias) { // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. // https://github.com/kokkos/kokkos-kernels/issues/349 // https://github.com/HazyResearch/flash-attention/issues/21 @@ -115,7 +115,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params &launch_params, }); } }else{ - if (has_bias) { + if (has_attn_bias) { // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. // https://github.com/kokkos/kokkos-kernels/issues/349 // https://github.com/HazyResearch/flash-attention/issues/21 From e8a376ea34ceae32e19940cc54c15da6f9d97a60 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 10:53:26 +0800 Subject: [PATCH 62/71] fix fmha_fprop_kernel_1xN.h --- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 845062cc2..91d4e0c1f 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -198,7 +198,7 @@ constexpr size_t get_dynamic_smem_size(){ return Gemm_Q_K::SMEM_BYTES; } -template +template inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -267,14 +267,14 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); - if constexpr (has_attn) { + if constexpr (has_attn_mask) { // Allocate the global memory tile loader for mask. 
using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; // conctructor Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); } - if constexpr (has_bias) { + if constexpr (has_attn_bias) { // Allocate the global memory tile loader for bias. using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; // conctructor @@ -295,11 +295,11 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i if (Return_softmax) { gmem_s.move(begin); } gmem_softmax_lse.move(begin); - if constexpr (has_attn) { + if constexpr (has_attn_mask) { gmem_mask.move(begin); } - if constexpr (has_bias) { + if constexpr (has_attn_bias) { gmem_bias.move(begin); } @@ -418,7 +418,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Convert from the accumulator type to FP32 for Softmax. softmax.unpack_noscale(acc_p); - if constexpr (has_attn) { + if constexpr (has_attn_mask) { using Frag_mask = fmha::Fragment_c; Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; fmha::clear(frag_mask); @@ -429,7 +429,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i softmax.apply_attn_mask(frag_mask, l, loop_step_idx); } - if constexpr (has_bias) { + if constexpr (has_attn_bias) { using Frag_Bias = fmha::Fragment_c; Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; fmha::clear(frag_bias); From 806e15648e37a0f08474676c324db6e3d35b1fa8 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 10:56:15 +0800 Subject: [PATCH 63/71] rename has_attn to has_attn_mask, has_bias to has_attn_bias --- .../src/fmha_dgrad_fp16_kernel_loop.sm80.cu | 10 +++++----- .../flash_attn/src/fmha_dgrad_kernel_1xN_loop.h | 17 ++++++++--------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu index b0f5f337c..2028f7211 100644 --- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu @@ -29,11 +29,11 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_ bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" - bool has_attn = !(params.attn_mask_ptr == nullptr); - bool has_bias = !(params.attn_bias_ptr == nullptr); + bool has_attn_mask = !(params.attn_mask_ptr == nullptr); + bool has_attn_bias = !(params.attn_bias_ptr == nullptr); - if (has_attn) { - if (has_bias) { + if (has_attn_mask) { + if (has_attn_bias) { BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { auto kernel = params.is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel @@ -79,7 +79,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_ }); } }else{ - if (has_bias) { + if (has_attn_bias) { BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { auto kernel = params.is_causal ? 
&fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index d1f2dd6aa..28149d32d 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -31,7 +31,7 @@ inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng &ph, const int loop_step_idx) { @@ -148,14 +148,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); - if constexpr (has_attn) { + if constexpr (has_attn_mask) { // Allocate the global memory tile loader for mask. using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; // conctructor Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); } - if constexpr (has_bias) { + if constexpr (has_attn_bias) { // Allocate the global memory tile loader for bias. using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; // conctructor @@ -214,11 +214,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng gmem_softmax_lse.move(begin); gmem_softmax_d.move(begin); - if constexpr (has_attn) { + if constexpr (has_attn_mask) { gmem_mask.move(begin); } - if constexpr (has_bias) { + if constexpr (has_attn_bias) { gmem_bias.move(begin); gmem_ds.move(begin); } @@ -353,7 +353,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Convert from the accumulator type to FP32 for Softmax. softmax.unpack_noscale(acc_p); - if constexpr (has_attn) { + if constexpr (has_attn_mask) { using Frag_mask = fmha::Fragment_c; Frag_mask frag_mask[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; gmem_mask.template load(frag_mask); @@ -363,7 +363,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng softmax.apply_attn_mask(frag_mask); } - if constexpr (has_bias) { + if constexpr (has_attn_bias) { using Frag_Bias = fmha::Fragment_c; Frag_Bias frag_bias[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; gmem_bias.template load(frag_bias); @@ -483,8 +483,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng softmax.template pack(frag_p); - if constexpr (has_bias) { - // if (!(params.attn_bias_ptr == nullptr)) { + if constexpr (has_attn_bias) { gmem_ds.template store(softmax.elt_); gmem_ds.move(); } From 15ade00d6b14f7abf33b90ddf69cb70976a19a55 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 11:07:28 +0800 Subject: [PATCH 64/71] remove useless benchmark code --- benchmarks/correctness/attention.py | 35 -- benchmarks/correctness/benchmark_memory.py | 128 ------- benchmarks/correctness/check_correct.py | 336 ------------------ .../correctness/check_speed_backward.py | 131 ------- benchmarks/correctness/check_speed_forward.py | 128 ------- benchmarks/correctness/flash_attention.py | 65 ---- benchmarks/correctness/test_mem.sh | 5 - benchmarks/correctness/test_time.sh | 5 - benchmarks/correctness/torch_attention.py | 51 --- build.sh | 10 - 10 files changed, 894 deletions(-) delete mode 100644 benchmarks/correctness/attention.py delete mode 100644 benchmarks/correctness/benchmark_memory.py delete mode 100644 benchmarks/correctness/check_correct.py delete mode 100644 benchmarks/correctness/check_speed_backward.py 
delete mode 100644 benchmarks/correctness/check_speed_forward.py delete mode 100644 benchmarks/correctness/flash_attention.py delete mode 100644 benchmarks/correctness/test_mem.sh delete mode 100644 benchmarks/correctness/test_time.sh delete mode 100644 benchmarks/correctness/torch_attention.py delete mode 100644 build.sh diff --git a/benchmarks/correctness/attention.py b/benchmarks/correctness/attention.py deleted file mode 100644 index d8a686548..000000000 --- a/benchmarks/correctness/attention.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -from typing import Optional, Callable, List, Tuple, Sequence - -from unicore.modules import softmax_dropout - - -def permute_final_dims(tensor: torch.Tensor, inds: List[int]): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - -def _attention(query, key, value, mask=None, bias=None, upcast=False) -> torch.Tensor: - dtype_og = query.dtype - - if upcast: - query = query.float() - key = key.float() - value = value.float() - if mask is not None: - mask = mask.float() - if bias is not None: - bias = bias.float() - - # [*, H, C_hidden, K] - key = permute_final_dims(key, (1, 0)) - - # [*, H, Q, K] - a = torch.matmul(query, key) - - a = softmax_dropout(a, dropout_prob=0, is_training=True, mask=mask, bias=bias) - - # [*, H, Q, C_hidden] - b = torch.matmul(a, value) - - return b.to(dtype_og) diff --git a/benchmarks/correctness/benchmark_memory.py b/benchmarks/correctness/benchmark_memory.py deleted file mode 100644 index 4107b8012..000000000 --- a/benchmarks/correctness/benchmark_memory.py +++ /dev/null @@ -1,128 +0,0 @@ -import torch -import torch.utils.benchmark as benchmark - -from flash_attention import _flash_attn -from attention import _attention -from torch_attention import _torch_attention - -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument("--has_mask_bias", required=False, help="add bias in attention", type=bool, default=False) -parser.add_argument("--eval", required=False, help="test whether has backward", type=bool, default=False) - -args = parser.parse_args() -print(args) - - -def benchmark_memory(fn, inputs, mask=None, bias=None, grad=None, eval=True, desc='', verbose=False, **kwinputs): - def fwd(grad, inputs, mask=mask, bias=bias, **kwinputs): - with torch.no_grad(): - y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) - - - def fwd_bwd(grad, inputs, mask=mask, bias=bias, **kwinputs): - y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) - if type(y) is tuple: - y = y[0] - if grad is None: - grad = torch.randn_like(y) - else: - if grad.shape != y.shape: - raise RuntimeError('Grad shape does not match output shape') - y.backward(grad, retain_graph=False) - - if eval: - f = fwd - if verbose: - print ("using fwd func...") - else: - f = fwd_bwd - if verbose: - print ("using fwd and bwd func...") - - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() - - f(None, inputs, mask, bias) - - torch.cuda.synchronize() - mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000) - if verbose: - print(f"{desc} max memory: ", mem) - torch.cuda.empty_cache() - return mem - - -def gen_attn_mask(mask, neg_inf): - assert neg_inf < -1e4 - attn_mask = torch.zeros_like(mask) - attn_mask[mask == 0] = neg_inf - return attn_mask - - -def fun(seqlen=128, verbose=False, has_bias=True, has_mask=True, eval=True): - bs = 1 - head = 4 - c_dim = 32 - seq_q = seq_k = seq_v = 
seqlen - dtype = torch.bfloat16 - device = "cuda" - - inputs = torch.empty((bs, seq_q, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) - inputs.requires_grad = True - if verbose: - print ("inputs shape: ", inputs.shape) - # [bs, seq, seq, head, c_dim] - - if has_bias: - bias = torch.randn( - 1, 1, head, seq_q, seq_k, dtype=dtype, device=device - ) - bias.requires_grad = True - if verbose: - print ("bias shape: ", bias.shape) - # [1, 1, seq, head, seq_k] - else: - bias = None - - if has_mask: - mask = gen_attn_mask( - ( - torch.rand(bs, seq_q, 1, 1, seq_k, dtype=dtype, device=device,) > 0.2 - ).type(dtype), - -3e4, - ) - if verbose: - print ("mask shape: ", mask.shape) - else: - mask = None - - print ("processing seq length: {} in eval model {} ......".format(seqlen, eval)) - - try: - m1 = benchmark_memory(_attention, inputs, mask=mask, bias=bias, eval=eval, desc='Normal Attention forward') - print (m1) - except: - print ("Normal Attention OOM") - - try: - m2 = benchmark_memory(_flash_attn, inputs, mask=mask, bias=bias, eval=eval, desc='Flash Attention forward') - print (m2) - except: - print ("Flash Attention OOM") - - -for seqlen in [2**8, 2**9, 600, 700, 800, 2**10, 1200, 1400, 2**11, 2500, 3000, 3500, 2**12]: - if args.has_mask_bias: - if not args.eval: - fun(seqlen=seqlen, eval=False) - else: - fun(seqlen=seqlen, eval=True) - else: - if not args.eval: - fun(seqlen=seqlen, has_bias=None, has_mask=None, eval=False) - else: - fun(seqlen=seqlen, has_bias=None, has_mask=None, eval=True) - diff --git a/benchmarks/correctness/check_correct.py b/benchmarks/correctness/check_correct.py deleted file mode 100644 index 032909f74..000000000 --- a/benchmarks/correctness/check_correct.py +++ /dev/null @@ -1,336 +0,0 @@ -import torch - -# from attention import _attention -from torch_attention import _torch_attention as _attention -from flash_attention import _flash_attn - -import numpy as np -import pytest - - -def is_same_matrix(pred, gt, abs_eps=0.01, relative_rps=0.03, verbose=False): - diff = np.abs(pred - gt) - - cnt = 0 - for index, x in np.ndenumerate(diff): - if x > abs_eps: - relative_diff = np.abs(x / gt[index]) - if relative_diff > relative_rps: - cnt += 1 - if verbose: - print (index, x, gt[index], relative_diff) - - if cnt > 0: - print ("not so match") - return False - else: - return True - - -def gen_attn_mask(mask, neg_inf): - assert neg_inf < -1e4 - attn_mask = torch.zeros_like(mask) - attn_mask[mask == 0] = neg_inf - return attn_mask - -# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) -# # @pytest.mark.parametrize('c_dim', [64, 32, 16]) -# @pytest.mark.parametrize('c_dim', [16]) -# @pytest.mark.parametrize('seqlen', [64, 128, 256, 512, 1024, 1536, 2048]) -def test_flash_attn_unpadded_shape1(seqlen, c_dim, dtype, device = "cuda"): - # mini - # bs = 2 - # head = 8 - # c_dim = 16 - # seq_q = seq_k = seq_v = 128 - # dtype = torch.half - # device = "cuda" - - bs = 1 - head = 1 - c_dim = c_dim - bs_seq = 1 - seq_q = seq_k = seq_v = seqlen - dtype = dtype - device = device - - inputs = torch.empty((bs, bs_seq, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) - # debug data - # inputs = torch.zeros((bs, bs_seq, seq_q, head, c_dim), dtype=dtype, device=device) - # cnt = 0 - # for i in range(bs): - # for j in range(bs_seq): - # for k in range(seq_q): - # for l in range(head): - # for m in range(c_dim): - # inputs[i][j][k][l][m] = (cnt % 10000) * 0.001 - # cnt += 1 - - # inputs = inputs.permute(0, 1, 3, 2, 4) - 
inputs.requires_grad = True - - print ("inputs shape: ", inputs.shape) - # [bs, seq, seq, head, c_dim] - - bias = torch.randn( - 1, 1, head, seq_q, seq_k, dtype=dtype, device=device - ) - bias.requires_grad = True - - print ("bias shape: ", bias.shape) - # [1, 1, seq, head, seq_k] - - mask = gen_attn_mask( - ( - torch.randn((bs, bs_seq, 1, 1, seq_k), dtype=dtype, device=device,) > 0.2 - ).type(dtype), - -3e4, - ) - # [bs, bs_seq, head, 1, seq_k] - - # debug data - # mask = torch.ones(bs, bs_seq, head, 1, seq_k, dtype=dtype, device=device,) * -1 - # for i in range(bs): - # for j in range(bs_seq): - # for k in range(head): - # for l in range(1): - # for m in range(seq_q): - # if m % 2 == 0: - # mask[i][j][k][l][m] = 0 - - # mask = mask.expand(bs, bs_seq, head, seq_q, seq_k) - print ("mask shape: ", mask.shape) - - # bias = None - # mask = None - # [bs, seq_q, 1, 1, seq_k] - - normal_attn_v1 = inputs.clone() - output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, bias=bias, mask=mask, upcast=True) - output_ref = output_ref.transpose(-2, -3) - print ("attention ref output shape: ", output_ref.shape) - - normal_attn_v2 = inputs.clone() - output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, bias=bias, mask=mask) - # be careful here - output_pt = output_pt.transpose(-2, -3) - print ("attention output shape: ", output_pt.shape) - - normal_attn_flash = inputs.clone() - output_flash = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, bias=bias, mask=mask) - print ("flash attn output shape: ", output_flash.shape) - - print (10 * "*" + "comparing forward" + 10 * "*" ) - # fp32 result - print("Output max diff: {0}".format((output_flash - output_ref).abs().max().item())) - print("Output mean diff: {0}".format((output_flash - output_ref).abs().mean().item())) - - print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) - print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) - - print("Output max diff with Pytorch: {0}".format((output_flash - output_pt).abs().max().item())) - print("Output mean diff with Pytorch: {0}".format((output_flash - output_pt).abs().mean().item())) - - # Check that FlashAttention's numerical error is at most twice the numerical error of a Pytorch implementation. 
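
The comment above states the acceptance rule used both in these (now removed) correctness checks and in the restored tests/test_flash_attn.py: the fused kernel may deviate from the fp32 reference by at most twice the deviation of the same computation run in fp16/bf16 PyTorch. As a small helper (illustrative name; arguments are torch tensors):

    def within_twice_ref_error(out_flash, out_pt, out_ref):
        # out_ref: fp32 reference; out_pt: identical math in fp16/bf16 PyTorch
        flash_err = (out_flash - out_ref).abs().max().item()
        pt_err = (out_pt - out_ref).abs().max().item()
        return flash_err <= 2 * pt_err
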
- print ("less than twice error: ", (output_flash - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) - print () - - g = torch.randn_like(output_flash) - # dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, ), g) - # dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, ), g) - # dq, dk, dv, = torch.autograd.grad(output_flash, (normal_attn_flash, normal_attn_flash, normal_attn_flash, ), g) - - dq_ref, dk_ref, dv_ref, dbias_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, bias), g) - dq_pt, dk_pt, dv_pt, dbias_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, bias), g) - dq, dk, dv, dbias = torch.autograd.grad(output_flash, (normal_attn_flash, normal_attn_flash, normal_attn_flash, bias), g) - - print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) - print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) - print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) - - print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) - print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) - print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) - - print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) - print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) - print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) - - print ("dq less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) - print ("dk less than twice error: ", ((dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()) ) - print ("dv less than twice error: ", ((dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()) ) - - if bias is not None: - print ("dbias less than twice error: ", ((dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item()) ) - - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item(), "dq larger than twice error" - assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item(), "dq larger than twice error" - assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item(), "dq larger than twice error" - - if bias is not None: - print("Output dbias max diff: {0}".format( (dbias - dbias_ref).abs().max().item() )) - print("Pytorch dbias max diff: {0}".format( (dbias - dbias_pt).abs().max().item() )) - assert (dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item(), "dbias larger than twice error" - - - -# @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) -# # @pytest.mark.parametrize('c_dim', [64, 32, 16]) -# @pytest.mark.parametrize('c_dim', [16]) -# @pytest.mark.parametrize('seqlen', [64, 128, 256, 512, 1024, 1536, 2048]) -def test_flash_attn_unpadded_shape2(seqlen, c_dim, dtype, device = "cuda"): - # mini - # bs = 2 - # head = 8 - # c_dim = 16 - # seq_q = seq_k = seq_v = 128 - # dtype = torch.half - # device = "cuda" - - bs = 1 - head = 1 - c_dim = c_dim - bs_seq = 1 - seq_q = seq_k = seq_v = seqlen - dtype = dtype - device = device - - inputs = torch.empty((bs, bs_seq, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) - # debug data - 
# inputs = torch.zeros((bs, bs_seq, seq_q, head, c_dim), dtype=dtype, device=device) - # cnt = 0 - # for i in range(bs): - # for j in range(bs_seq): - # for k in range(seq_q): - # for l in range(head): - # for m in range(c_dim): - # inputs[i][j][k][l][m] = (cnt % 10000) * 0.001 - # cnt += 1 - - # inputs = inputs.permute(0, 1, 3, 2, 4) - inputs.requires_grad = True - - print ("inputs shape: ", inputs.shape) - # [bs, seq, seq, head, c_dim] - - bias = torch.randn( - 1, bs_seq, head, seq_q, seq_k, dtype=dtype, device=device - ) - bias.requires_grad = True - - print ("bias shape: ", bias.shape) - # [1, 1, seq, head, seq_k] - - mask = gen_attn_mask( - ( - torch.randn((bs, bs_seq, head, 1, seq_k), dtype=dtype, device=device,) > 0.2 - ).type(dtype), - -3e4, - ) - # [bs, bs_seq, head, 1, seq_k] - - # debug data - # mask = torch.ones(bs, bs_seq, head, 1, seq_k, dtype=dtype, device=device,) * -1 - # for i in range(bs): - # for j in range(bs_seq): - # for k in range(head): - # for l in range(1): - # for m in range(seq_q): - # if m % 2 == 0: - # mask[i][j][k][l][m] = 0 - - # mask = mask.expand(bs, bs_seq, head, seq_q, seq_k) - print ("mask shape: ", mask.shape) - - # bias = None - # mask = None - # [bs, seq_q, 1, 1, seq_k] - - # np.savetxt("inputs_flash_seq{0}.data".format(seqlen), inputs.detach().cpu().numpy().flatten(), delimiter=" ") - # if mask is not None: - # np.savetxt("attn_mask_flash_seq{0}.data".format(seqlen), mask.detach().cpu().numpy().flatten(), delimiter=" ") - - normal_attn_v1 = inputs.clone() - output_ref = _attention(normal_attn_v1, normal_attn_v1, normal_attn_v1, bias=bias, mask=mask, upcast=True) - output_ref = output_ref.transpose(-2, -3) - print ("attention ref output shape: ", output_ref.shape) - - normal_attn_v2 = inputs.clone() - output_pt = _attention(normal_attn_v2, normal_attn_v2, normal_attn_v2, bias=bias, mask=mask) - # be careful here - output_pt = output_pt.transpose(-2, -3) - print ("attention output shape: ", output_pt.shape) - - normal_attn_flash = inputs.clone() - output_flash = _flash_attn(normal_attn_flash, normal_attn_flash, normal_attn_flash, bias=bias, mask=mask) - print ("flash attn output shape: ", output_flash.shape) - # [bs, bs_seq, head, seq_k c_dim] - - # np.savetxt("output_torch_seq{0}.data".format(seqlen), output_pt.detach().cpu().numpy().flatten(), delimiter=" ") - # np.savetxt("output_flash_seq{0}.data".format(seqlen), output_flash.detach().cpu().numpy().flatten(), delimiter=" ") - - print (10 * "*" + "comparing forward" + 10 * "*" ) - # fp32 result - print("Output max diff: {0}".format((output_flash - output_ref).abs().max().item())) - print("Output mean diff: {0}".format((output_flash - output_ref).abs().mean().item())) - - print("Pytorch max diff: {0}".format((output_pt - output_ref).abs().max().item())) - print("Pytorch mean diff: {0}".format((output_pt - output_ref).abs().mean().item())) - - print("Output max diff with Pytorch: {0}".format((output_flash - output_pt).abs().max().item())) - print("Output mean diff with Pytorch: {0}".format((output_flash - output_pt).abs().mean().item())) - - # Check that FlashAttention's numerical error is at most twice the numerical error of a Pytorch implementation. 
- print ("less than twice error: ", (output_flash - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()) - print () - assert (output_flash - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() - - g = torch.randn_like(output_flash) - # dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, ), g) - # dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, ), g) - # dq, dk, dv, = torch.autograd.grad(output_flash, (normal_attn_flash, normal_attn_flash, normal_attn_flash, ), g) - - dq_ref, dk_ref, dv_ref, dbias_ref = torch.autograd.grad(output_ref, (normal_attn_v1, normal_attn_v1, normal_attn_v1, bias), g) - dq_pt, dk_pt, dv_pt, dbias_pt = torch.autograd.grad(output_pt, (normal_attn_v2, normal_attn_v2, normal_attn_v2, bias), g) - dq, dk, dv, dbias = torch.autograd.grad(output_flash, (normal_attn_flash, normal_attn_flash, normal_attn_flash, bias), g) - - print("Output dQ max diff: {0}".format( (dq - dq_ref).abs().max().item() )) - print("Output dK max diff: {0}".format( (dk - dk_ref).abs().max().item() )) - print("Output dV max diff: {0}".format( (dv - dv_ref).abs().max().item() )) - - print("Pytorch dQ max diff: {0}".format( (dq_pt - dq_ref).abs().max().item() )) - print("Pytorch dK max diff: {0}".format( (dk_pt - dk_ref).abs().max().item() )) - print("Pytorch dV max diff: {0}".format( (dv_pt - dv_ref).abs().max().item() )) - - print("Output dQ max diff with Pytorch: {0}".format( (dq - dq_pt).abs().max().item() )) - print("Output dK max diff with Pytorch: {0}".format( (dk - dk_pt).abs().max().item() )) - print("Output dV max diff with Pytorch: {0}".format( (dv - dv_pt).abs().max().item() )) - - print ("dq less than twice error: ", ((dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()) ) - print ("dk less than twice error: ", ((dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()) ) - print ("dv less than twice error: ", ((dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()) ) - - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item(), "dq larger than twice error" - assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item(), "dq larger than twice error" - assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item(), "dq larger than twice error" - - if bias is not None: - print("Output dbias max diff: {0}".format( (dbias - dbias_ref).abs().max().item() )) - print("Pytorch dbias max diff: {0}".format( (dbias - dbias_pt).abs().max().item() )) - assert (dbias - dbias_ref).abs().max().item() <= 2 * (dbias_pt - dbias_ref).abs().max().item(), "dbias larger than twice error" - - -# for dtype in [torch.float16]: -# # for dtype in [torch.float16, torch.bfloat16]: -# for c_dim in [16]: -# for seqlen in [64, 128, 256, 512]: -# print ("dtype={}, c_dim={}, seqlen={}".format(dtype, c_dim, seqlen)) -# test_flash_attn_unpadded_shape1(seqlen, c_dim, dtype) - - -# for dtype in [torch.float16]: -for dtype in [torch.float16, torch.bfloat16]: - for c_dim in [16, 32, 64]: - for seqlen in [64, 128, 256, 512, 1024, 2048]: - print ("dtype={}, c_dim={}, seqlen={}".format(dtype, c_dim, seqlen)) - test_flash_attn_unpadded_shape2(seqlen, c_dim, dtype) \ No newline at end of file diff --git a/benchmarks/correctness/check_speed_backward.py b/benchmarks/correctness/check_speed_backward.py deleted file mode 100644 index 
05bc7e58d..000000000 --- a/benchmarks/correctness/check_speed_backward.py +++ /dev/null @@ -1,131 +0,0 @@ -import torch -import torch.utils.benchmark as benchmark - -from flash_attention import _flash_attn -from attention import _attention -from torch_attention import _torch_attention - -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument("--has_mask_bias", required=False, help="add bias in attention", type=bool, default=False) - -args = parser.parse_args() -print(args) - - -def benchmark_combined(fn, inputs, mask=None, bias=None, grad=None, repeats=10, desc='', verbose=False, **kwinputs): - """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ - if verbose: - print(desc, '- Forward + Backward pass') - - def f(grad, inputs, mask=mask, bias=bias, **kwinputs): - y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) - if type(y) is tuple: - y = y[0] - if grad is None: - grad = torch.randn_like(y) - else: - if grad.shape != y.shape: - raise RuntimeError('Grad shape does not match output shape') - y.backward(grad, retain_graph=True) - - t = benchmark.Timer( - stmt='f(grad, inputs, mask=mask, bias=bias, **kwinputs)', - globals={'f': f, 'fn': fn, 'inputs': inputs, 'mask': mask, 'bias': bias, 'grad': grad, 'kwinputs': kwinputs}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m - - -def benchmark_forward(fn, inputs, mask=None, bias=None, grad=None, repeats=10, desc='', verbose=False, **kwinputs): - """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ - if verbose: - print(desc, '- Forward pass with no grad') - - def f(grad, inputs, mask=mask, bias=bias, **kwinputs): - with torch.no_grad(): - y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) - - t = benchmark.Timer( - stmt='f(grad, inputs, mask=mask, bias=bias, **kwinputs)', - globals={'f': f, 'fn': fn, 'inputs': inputs, 'mask': mask, 'bias': bias, 'grad': grad, 'kwinputs': kwinputs}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m - - -def gen_attn_mask(mask, neg_inf): - assert neg_inf < -1e4 - attn_mask = torch.zeros_like(mask) - attn_mask[mask == 0] = neg_inf - return attn_mask - - -def fun(seqlen=128, verbose=False, has_bias=True, has_mask=True): - bs = 1 - head = 4 - c_dim = 32 - seq_q = seq_k = seq_v = seqlen - dtype = torch.bfloat16 - device = "cuda" - - inputs = torch.empty((bs, seq_q, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) - inputs.requires_grad = True - if verbose: - print ("inputs shape: ", inputs.shape) - # [bs, seq, seq, head, c_dim] - - if has_bias: - bias = torch.randn( - 1, 1, head, seq_q, seq_k, dtype=dtype, device=device - ) - bias.requires_grad = True - if verbose: - print ("bias shape: ", bias.shape) - # [1, 1, seq, head, seq_k] - else: - bias = None - - if has_mask: - mask = gen_attn_mask( - ( - torch.rand(bs, seq_q, 1, 1, seq_k, dtype=dtype, device=device,) > 0.2 - ).type(dtype), - -3e4, - ) - if verbose: - print ("mask shape: ", mask.shape) - else: - mask = None - - print ("processing seq length: {} ......".format(seqlen)) - try: - t1, m1 = benchmark_combined(_attention, inputs, mask=mask, bias=bias, repeats=100, desc='Normal Attention forward') - # import pdb; pdb.set_trace() - # print (m1) - # raw_times / number_per_run * 1000 ms - print (m1.raw_times[0]) - except: - print ("normal attention OOM") - - try: - t2, m2 = benchmark_combined(_flash_attn, inputs, mask=mask, 
bias=bias, repeats=100, desc='Flash Attention forward') - # print (m2) - print (m2.raw_times[0]) - except: - print ("flash attention OOM") - - -for seqlen in [2**8, 2**9, 600, 700, 800, 2**10, 1200, 1400, 2**11, 2500, 3000, 3500, 2**12]: - if args.has_mask_bias: - fun(seqlen=seqlen) - else: - fun(seqlen=seqlen, has_bias=None, has_mask=None) - diff --git a/benchmarks/correctness/check_speed_forward.py b/benchmarks/correctness/check_speed_forward.py deleted file mode 100644 index 653bb00ba..000000000 --- a/benchmarks/correctness/check_speed_forward.py +++ /dev/null @@ -1,128 +0,0 @@ -import torch -import torch.utils.benchmark as benchmark - -from flash_attention import _flash_attn -from attention import _attention -from torch_attention import _torch_attention - -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument("--has_mask_bias", required=False, help="add bias in attention", type=bool, default=False) - -args = parser.parse_args() -print(args) - - -def benchmark_combined(fn, inputs, mask=None, bias=None, grad=None, repeats=10, desc='', verbose=False, **kwinputs): - """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ - if verbose: - print(desc, '- Forward + Backward pass') - - def f(grad, inputs, mask=mask, bias=bias, **kwinputs): - y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) - if type(y) is tuple: - y = y[0] - if grad is None: - grad = torch.randn_like(y) - else: - if grad.shape != y.shape: - raise RuntimeError('Grad shape does not match output shape') - y.backward(grad, retain_graph=True) - - t = benchmark.Timer( - stmt='f(grad, inputs, mask=mask, bias=bias, **kwinputs)', - globals={'f': f, 'fn': fn, 'inputs': inputs, 'mask': mask, 'bias': bias, 'grad': grad, 'kwinputs': kwinputs}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m - - -def benchmark_forward(fn, inputs, mask=None, bias=None, grad=None, repeats=10, desc='', verbose=False, **kwinputs): - """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. 
""" - if verbose: - print(desc, '- Forward pass with no grad') - - def f(grad, inputs, mask=mask, bias=bias, **kwinputs): - with torch.no_grad(): - y = fn(inputs, inputs, inputs, mask=mask, bias=bias, **kwinputs) - - t = benchmark.Timer( - stmt='f(grad, inputs, mask=mask, bias=bias, **kwinputs)', - globals={'f': f, 'fn': fn, 'inputs': inputs, 'mask': mask, 'bias': bias, 'grad': grad, 'kwinputs': kwinputs}, - num_threads=torch.get_num_threads(), - ) - m = t.timeit(repeats) - if verbose: - print(m) - return t, m - - -def gen_attn_mask(mask, neg_inf): - assert neg_inf < -1e4 - attn_mask = torch.zeros_like(mask) - attn_mask[mask == 0] = neg_inf - return attn_mask - - -def fun(seqlen=128, verbose=False, has_bias=True, has_mask=True): - bs = 1 - head = 4 - c_dim = 32 - seq_q = seq_k = seq_v = seqlen - dtype = torch.bfloat16 - device = "cuda" - - inputs = torch.empty((bs, seq_q, head, seq_q, c_dim), dtype=dtype, device=device).normal_(mean=0, std=.5) - inputs.requires_grad = True - if verbose: - print ("inputs shape: ", inputs.shape) - # [bs, seq, seq, head, c_dim] - - if has_bias: - bias = torch.randn( - 1, 1, head, seq_q, seq_k, dtype=dtype, device=device - ) - bias.requires_grad = True - if verbose: - print ("bias shape: ", bias.shape) - # [1, 1, seq, head, seq_k] - else: - bias = None - - if has_mask: - mask = gen_attn_mask( - ( - torch.rand(bs, seq_q, 1, 1, seq_k, dtype=dtype, device=device,) > 0.2 - ).type(dtype), - -3e4, - ) - if verbose: - print ("mask shape: ", mask.shape) - else: - mask = None - - print ("processing seq length: {} ......".format(seqlen)) - try: - t1, m1 = benchmark_forward(_attention, inputs, mask=mask, bias=bias, repeats=100, desc='Normal Attention forward') - # print (m1) - print (m1.raw_times[0]) - except: - print ("normal attention OOM") - - try: - t2, m2 = benchmark_forward(_flash_attn, inputs, mask=mask, bias=bias, repeats=100, desc='Flash Attention forward') - # print (m2) - print (m2.raw_times[0]) - except: - print ("flash attention OOM") - - -for seqlen in [2**8, 2**9, 600, 700, 800, 2**10, 1200, 1400, 2**11, 2500, 3000, 3500, 2**12]: - if args.has_mask_bias: - fun(seqlen=seqlen) - else: - fun(seqlen=seqlen, has_bias=None, has_mask=None) diff --git a/benchmarks/correctness/flash_attention.py b/benchmarks/correctness/flash_attention.py deleted file mode 100644 index 0d7d5b6d4..000000000 --- a/benchmarks/correctness/flash_attention.py +++ /dev/null @@ -1,65 +0,0 @@ - -import torch -from flash_attn.flash_attn_interface import flash_attn_unpadded_func - - -def _flash_attn(q, k, v, mask=None, bias=None): - batch_dims = q.shape[:-3] - no_heads, n, c = q.shape[-3:] - k_no_heads, k_n, k_c = k.shape[-3:] - - dtype = q.dtype - - # [*, B, N, H, C] - q = q.transpose(-2, -3) - k = k.transpose(-2, -3) - v = v.transpose(-2, -3) - - # [B_flat, N, H, C] - q = q.reshape(-1, *q.shape[-3:]) - k = k.reshape(-1, *k.shape[-3:]) - v = v.reshape(-1, *v.shape[-3:]) - - # Flattened batch size - batch_size = q.shape[0] - - # [B_flat * N, H, C] - q = q.reshape(-1, *q.shape[-2:]) - k = k.reshape(-1, *k.shape[-2:]) - v = v.reshape(-1, *v.shape[-2:]) - - q_max_s = n - q_cu_seqlens = torch.arange( - 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device - ) - - k_max_s = k_n - k_cu_seqlens = torch.arange( - 0, (batch_size + 1) * k_n, step=k_n, dtype=torch.int32, device=k.device - ) - - if mask is not None: - mask_heads, tgt_len, src_len = mask.shape[-3:] - mask = mask.reshape(-1 , mask_heads, tgt_len, src_len).contiguous() - - if bias is not None: - bias_heads, tgt_len, src_len = 
bias.shape[-3:] - bias = bias.reshape(-1 , bias_heads, tgt_len, src_len).contiguous() - - out = flash_attn_unpadded_func( - q, - k, - v, - q_cu_seqlens, - k_cu_seqlens, - q_max_s, - k_max_s, - attn_mask=mask, - attn_bias=bias, - dropout_p = 0., - softmax_scale = 1., # q has been scaled already - ) - - # [*, B, N, H, C] - out = out.reshape(*batch_dims, n, no_heads, c) - return out diff --git a/benchmarks/correctness/test_mem.sh b/benchmarks/correctness/test_mem.sh deleted file mode 100644 index 500646c8a..000000000 --- a/benchmarks/correctness/test_mem.sh +++ /dev/null @@ -1,5 +0,0 @@ -python benchmarks/correctness/benchmark_memory.py --has_mask_bias=true --eval=false 2>&1 |tee has_mask_bias_train.txt -python benchmarks/correctness/benchmark_memory.py --has_mask_bias=false --eval=false 2>&1 |tee no_mask_bias_train.txt - -python benchmarks/correctness/benchmark_memory.py --has_mask_bias=true --eval=true 2>&1 |tee has_mask_bias_test.txt -python benchmarks/correctness/benchmark_memory.py --has_mask_bias=false --eval=true 2>&1 |tee no_mask_bias_test.txt \ No newline at end of file diff --git a/benchmarks/correctness/test_time.sh b/benchmarks/correctness/test_time.sh deleted file mode 100644 index 92bd47d69..000000000 --- a/benchmarks/correctness/test_time.sh +++ /dev/null @@ -1,5 +0,0 @@ -python benchmarks/correctness/check_speed_forward.py --has_mask_bias=false 2>&1 |tee no_mask_bias_test.txt -python benchmarks/correctness/check_speed_forward.py --has_mask_bias=true 2>&1 |tee has_mask_bias_test.txt - -python benchmarks/correctness/check_speed_backward.py --has_mask_bias=false 2>&1 |tee no_mask_bias_train.txt -python benchmarks/correctness/check_speed_backward.py --has_mask_bias=true 2>&1 |tee has_mask_bias_train.txt \ No newline at end of file diff --git a/benchmarks/correctness/torch_attention.py b/benchmarks/correctness/torch_attention.py deleted file mode 100644 index 277f45b88..000000000 --- a/benchmarks/correctness/torch_attention.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -from typing import Optional, Callable, List, Tuple, Sequence - - -def permute_final_dims(tensor: torch.Tensor, inds: List[int]): - zero_index = -1 * len(inds) - first_inds = list(range(len(tensor.shape[:zero_index]))) - return tensor.permute(first_inds + [zero_index + i for i in inds]) - -@torch.jit.ignore -def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: - """ - Softmax, but without automatic casting to fp32 when the input is of - type bfloat16 - """ - d = t.dtype - s = torch.nn.functional.softmax(t, dim=dim) - return s - -def _torch_attention(query, key, value, mask=None, bias=None, upcast=False) -> torch.Tensor: - # upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - # output back to fp16/bf16. 
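
The wrapper deleted above also shows how the extended interface is driven from Python: q, k and v are flattened to (total_tokens, heads, head_dim), cu_seqlens arrays mark the sequence boundaries, and the new attn_mask / attn_bias tensors are passed as keyword arguments, broadcastable over heads and query positions (the kernel's mask_head_mod_size / mask_seq_mod_size handle the reduced dimensions). A compressed usage sketch along the same lines; sizes and tensor contents are invented for illustration:

    import torch
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func

    batch, seqlen, heads, c = 2, 128, 4, 32
    q = torch.randn(batch * seqlen, heads, c, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen,
                              dtype=torch.int32, device="cuda")

    # additive tensors; masked positions hold a large negative value (e.g. -3e4)
    mask = torch.zeros(batch, 1, 1, seqlen, device="cuda", dtype=torch.float16)
    bias = torch.randn(batch, heads, seqlen, seqlen, device="cuda", dtype=torch.float16)

    out = flash_attn_unpadded_func(
        q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen,
        attn_mask=mask, attn_bias=bias,
        dropout_p=0.0, softmax_scale=1.0,
    )  # out: (batch * seqlen, heads, c)
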
- dtype_og = query.dtype - if upcast: - query = query.float() - key = key.float() - value = value.float() - if mask is not None: - mask = mask.float() - if bias is not None: - bias = bias.float() - - # [*, H, C_hidden, K] - key = permute_final_dims(key, (1, 0)) - - # [*, H, Q, K] - a = torch.matmul(query, key) - - if bias is not None: - a += bias - - if mask is not None: - a.masked_fill_(mask < 0, float('-inf')) - # a += mask - - a = softmax_no_cast(a, -1) - - # [*, H, Q, C_hidden] - b = torch.matmul(a, value) - - return b.to(dtype_og) diff --git a/build.sh b/build.sh deleted file mode 100644 index f8309df33..000000000 --- a/build.sh +++ /dev/null @@ -1,10 +0,0 @@ - -#rm -rf build flash_attn_cuda.cpython-37m-x86_64-linux-gnu.so - -start=`date +%s` -#CXX="/usr/lib/ccache/c++" -python setup.py build -j 8 develop 2>&1 | tee build.log -end=`date +%s` - -runtime=$((end-start)) -echo ${runtime} From d663cf5e40d3013acc913ef34c2597825a47ed54 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 11:22:41 +0800 Subject: [PATCH 65/71] add declaration --- .../src/fmha_dgrad_kernel_1xN_loop.h | 28 ++++++++----------- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 22 ++++++--------- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 28149d32d..d8960259f 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -148,22 +148,18 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); - if constexpr (has_attn_mask) { - // Allocate the global memory tile loader for mask. - using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; - // conctructor - Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); - } - - if constexpr (has_attn_bias) { - // Allocate the global memory tile loader for bias. - using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; - // conctructor - Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); - - using Gmem_tile_ds = typename Kernel_traits::Gmem_tile_ds; - Gmem_tile_ds gmem_ds(params, binfo, tidx, loop_step_idx); - } + // Allocate the global memory tile loader for mask. + using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; + // conctructor + Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); + + // Allocate the global memory tile loader for bias. + using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; + using Gmem_tile_ds = typename Kernel_traits::Gmem_tile_ds; + + // conctructor + Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); + Gmem_tile_ds gmem_ds(params, binfo, tidx, loop_step_idx); fmha::Mask mask(binfo, tidx, loop_step_idx); diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 91d4e0c1f..4bef99f09 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -267,19 +267,15 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); - if constexpr (has_attn_mask) { - // Allocate the global memory tile loader for mask. 
- using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; - // conctructor - Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); - } - - if constexpr (has_attn_bias) { - // Allocate the global memory tile loader for bias. - using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; - // conctructor - Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); - } + // Allocate the global memory tile loader for mask. + using Gmem_tile_mask = typename Kernel_traits::Gmem_tile_mask; + // conctructor + Gmem_tile_mask gmem_mask(params, binfo, tidx, loop_step_idx); + + // Allocate the global memory tile loader for bias. + using Gmem_tile_bias = typename Kernel_traits::Gmem_tile_bias; + // conctructor + Gmem_tile_bias gmem_bias(params, binfo, tidx, loop_step_idx); Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); From de4f2cc8505a262c0df87e0aceb976aa3b69737c Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 16:15:47 +0800 Subject: [PATCH 66/71] remove useless comments --- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 4bef99f09..9930cdaa6 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -52,7 +52,6 @@ struct Gemm_Q_K_base { using Mma_tile_p = fmha::Hmma_tile; static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; - // ? __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx) : smem_q(smem_ptr_q, tidx) @@ -260,7 +259,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Allocate the global memory tile loader for Q. Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true); - // Allocate the global memory tile loader for O. Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); @@ -422,7 +420,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i gmem_mask.move(); // Apply the attn mask. - softmax.apply_attn_mask(frag_mask, l, loop_step_idx); + softmax.apply_attn_mask(frag_mask); } if constexpr (has_attn_bias) { From 29578380bc1db56665048d55584e0b690885921d Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 16:40:44 +0800 Subject: [PATCH 67/71] remove useless comments --- csrc/flash_attn/src/fmha/gmem_tile.h | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 268af915b..953f335b1 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -359,7 +359,6 @@ struct Gmem_tile_mma_sd { template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > struct Gmem_tile_mma_s : public Base { - // mma matrix multiply // The number of mmas in the vertical dimension. static constexpr int M = Base::MMAS_M; // The number of mmas in the horizontal dimension. 
@@ -470,12 +469,10 @@ struct Gmem_tile_mma_mask { // this col is mean the 8x4 tile's cole row = warp_m * Mma_tile::M_PER_MMA + quad; - static_assert(Mma_tile::M_PER_MMA == 16, - "only support sm80 m16n8k16 tensor core"); + static_assert(Mma_tile::M_PER_MMA == 16); col = warp_n * Mma_tile::N_PER_MMA + tid; - static_assert(Mma_tile::N_PER_MMA == 16, - "only support sm80 m16n8k16 tensor core"); + static_assert(Mma_tile::N_PER_MMA == 16); // The distance between two blocks (in bytes). // TODO: mask is [bs * seq, head, seq_q, seq_k] @@ -609,12 +606,10 @@ struct Gmem_tile_mma_bias { // this col is mean the 8x4 tile's cole row = warp_m * Mma_tile::M_PER_MMA + quad; - static_assert(Mma_tile::M_PER_MMA == 16, - "only support sm80 m16n8k16 tensor core"); + static_assert(Mma_tile::M_PER_MMA == 16); col = warp_n * Mma_tile::N_PER_MMA + tid; - static_assert(Mma_tile::N_PER_MMA == 16, - "only support sm80 m16n8k16 tensor core"); + static_assert(Mma_tile::N_PER_MMA == 16); // The distance between two blocks (in bytes). // TODO: mask is [bs, head, seq_q, seq_k] From 39fa9d440fb3f39f8cfc03c8d4991be301ba2ec0 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 17:37:14 +0800 Subject: [PATCH 68/71] add timeout --- .github/workflows/publish.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ca412cef2..6e9f25f76 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -13,6 +13,7 @@ on: publish jobs: + timeout-minutes: 360 release: name: Create Release runs-on: ubuntu-latest From 0bb403ef23b0fae709a06eae0c7c2ca14d5f04ae Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 17:44:05 +0800 Subject: [PATCH 69/71] add default timeout for build wheel --- .github/workflows/publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 6e9f25f76..28db34189 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -13,7 +13,6 @@ on: publish jobs: - timeout-minutes: 360 release: name: Create Release runs-on: ubuntu-latest @@ -35,6 +34,7 @@ jobs: wheel: name: Build Wheel runs-on: ${{ matrix.os }} + timeout-minutes: 360 needs: release strategy: From 184991b0a7b8518b0bff448d976bd2960e996970 Mon Sep 17 00:00:00 2001 From: robotcator Date: Wed, 12 Oct 2022 19:03:19 +0800 Subject: [PATCH 70/71] remove timeout --- .github/workflows/publish.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 28db34189..3a0a32cc3 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -34,7 +34,6 @@ jobs: wheel: name: Build Wheel runs-on: ${{ matrix.os }} - timeout-minutes: 360 needs: release strategy: @@ -128,4 +127,4 @@ jobs: upload_url: ${{ steps.get_current_release.outputs.upload_url }} asset_path: ./${{env.wheel_name}} asset_name: ${{env.wheel_name}} - asset_content_type: application/* \ No newline at end of file + asset_content_type: application/* From 3384115b51d1a110aea1a466480036c8e56fd932 Mon Sep 17 00:00:00 2001 From: robotcator Date: Thu, 13 Oct 2022 09:07:56 +0800 Subject: [PATCH 71/71] reduce build worker for workflow oom --- .github/workflows/publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3a0a32cc3..ddfdaba16 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -112,7 
+112,7 @@ jobs: export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH export CUDA_INSTALL_DIR=/usr/local/cuda-11.3$CUDA_INSTALL_DIR pip install wheel - python setup.py bdist_wheel --dist-dir=dist + MAX_JOBS=1 python setup.py bdist_wheel --dist-dir=dist tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }} wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") ls dist/*whl |xargs -I {} mv {} ${wheel_name}
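
For reference, a minimal end-to-end sketch of the mask/bias path added by this series, following the calling convention of the removed benchmarks/correctness/flash_attention.py wrapper. The attn_mask / attn_bias keyword names, the [batch, heads, seq_q, seq_k] shapes, and the additive-mask convention (0 where attention is allowed, a large negative value where it is masked, as in gen_attn_mask above) are taken from that wrapper and the benchmark scripts; they are illustrative assumptions, not the only supported layout.

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

bs, seq, head, c_dim = 1, 128, 16, 32
dtype, device = torch.bfloat16, "cuda"

# Packed [total_tokens, heads, head_dim] inputs, one fixed-length sequence per batch entry.
q = torch.randn(bs * seq, head, c_dim, dtype=dtype, device=device)
k = torch.randn(bs * seq, head, c_dim, dtype=dtype, device=device)
v = torch.randn(bs * seq, head, c_dim, dtype=dtype, device=device)
cu_seqlens = torch.arange(0, (bs + 1) * seq, step=seq, dtype=torch.int32, device=device)

# Additive mask: 0 keeps a position, -3e4 masks it (same convention as gen_attn_mask).
keep = torch.rand(bs, head, seq, seq, device=device) > 0.2
attn_mask = torch.zeros(bs, head, seq, seq, dtype=dtype, device=device)
attn_mask[~keep] = -3e4

# Additive bias with the same assumed [batch, heads, seq_q, seq_k] layout.
attn_bias = torch.randn(bs, head, seq, seq, dtype=dtype, device=device)

out = flash_attn_unpadded_func(
    q, k, v,
    cu_seqlens, cu_seqlens,
    seq, seq,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    dropout_p=0.,
    softmax_scale=1. / c_dim ** 0.5,  # use 1. if q is pre-scaled, as in the benchmarks
)
print(out.shape)  # [bs * seq, head, c_dim]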