sync to flash attention kernel 2.5.9 and add document of how to write custom op (#757)

* sync to flash attention kernel 2.5.9

* support users to overload GetMayInplace and ReleaseMayInplace

* Undo the change for pybind11 dependency
jslhcl authored Jul 10, 2024
1 parent b436d09 commit 95d65e4
Showing 12 changed files with 1,315 additions and 1,006 deletions.
2 changes: 0 additions & 2 deletions cmake/ext_cuda.cmake
@@ -30,8 +30,6 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no

add_compile_definitions(USE_CUDA)

set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use
set(OCOS_USE_FLASH_ATTENTION OFF)
if (OCOS_USE_FLASH_ATTENTION)
message(STATUS "Enable flash attention")
add_compile_definitions(OCOS_USE_FLASH_ATTENTION)
60 changes: 60 additions & 0 deletions docs/How_to_write_custom_op.md
@@ -0,0 +1,60 @@
# How to write custom ops

Custom ops are built on the ONNXRuntime-extensions API, in particular the **OrtLiteCustomOp** and **Tensor** classes. C++ template metaprogramming is used heavily under the hood to give custom op authors great flexibility in the number, type, and order of parameters.

## Basic scenario

There are two ways to write a custom op: as a function or as a structure.

### Custom op in the form of a function

If your kernel is simple, you can use this option and just provide a function that computes the customized kernel. The function can have an arbitrary number of inputs and outputs. Mandatory inputs should be declared with one of the following types:

```C++
const Ort::Custom::Tensor<T>&
// or
const Ort::Custom::Tensor<T>*
```

Optional inputs should be declared with the following type:

```C++
std::optional<const Ort::Custom::Tensor<T>*>
```

If the kernel needs to run on a CUDA GPU, the function can also accept a pointer to **CUDAKernelContext**, from which you can retrieve the CUDA stream and other CUDA resources.

The function must return **OrtStatusPtr**.

Please refer to [negpos_def.h](https://github.com/microsoft/onnxruntime-extensions/blob/main/operators/math/cuda/negpos_def.h) for an example and [tensor_tuple.inc](https://github.com/microsoft/onnxruntime-extensions/blob/main/include/custom_op/tensor_tuple.inc) for the full set of supported parameter types.
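
Below is a minimal, hedged sketch of a function-form custom op. The kernel name `KernelScale`, the include, and the `Data()`, `Shape()`, `NumberOfElement()`, and `Allocate()` accessors are illustrative assumptions based on the Tensor class described above; check negpos_def.h for the exact API.

```C++
// Sketch only: a CPU kernel that scales a mandatory float tensor by an
// optional scalar tensor (defaults to 2.0f when the optional input is absent).
#include <cstdint>
#include <optional>
#include "ocos.h"  // assumed umbrella header for Ort::Custom::Tensor

OrtStatusPtr KernelScale(const Ort::Custom::Tensor<float>& input,                 // mandatory input
                         std::optional<const Ort::Custom::Tensor<float>*> scale,  // optional input
                         Ort::Custom::Tensor<float>& output) {                    // output
  const float s = (scale.has_value() && *scale) ? (*scale)->Data()[0] : 2.0f;
  const float* in = input.Data();
  float* out = output.Allocate(input.Shape());  // allocate the output with the input's shape
  for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
    out[i] = in[i] * s;
  }
  return nullptr;  // nullptr signals success
}
```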

### Custom op in the form of a structure

If the kernel is more complicated or the custom op has extra properties, use this option: provide a C++ structure and store those properties as the structure's member variables. In addition, the structure needs to provide the following member functions:

```C++
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) // This function initializes the properties of the custom op

OrtStatusPtr Compute(...) const // This function computes the customized kernel.
```
The parameters of the Compute function follow the same rules as in the function form described above.
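
Below is a hedged sketch of a structure-form custom op. The attribute name `"scale"` and the tensor accessor names are illustrative assumptions; `KernelInfoGetAttribute_float` is part of the ONNX Runtime C API.

```C++
// Sketch only: a kernel that multiplies its input by a "scale" node attribute
// read at model-attach time. Assumes the onnxruntime-extensions headers are included.
struct ScaleKernel {
  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
    // Cache the node attribute as a member; returns an error status on failure.
    return api.KernelInfoGetAttribute_float(&info, "scale", &scale_);
  }

  OrtStatusPtr Compute(const Ort::Custom::Tensor<float>& input,
                       Ort::Custom::Tensor<float>& output) const {
    const float* in = input.Data();
    float* out = output.Allocate(input.Shape());
    for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
      out[i] = in[i] * scale_;
    }
    return nullptr;
  }

  float scale_{1.0f};
};
```
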
## Advanced scenario
In some cases you need more control over the parameters. You then have to use the structure form and implement additional member functions such as:
```C++
// By default all inputs are reported as OrtMemType::OrtMemTypeDefault. Provide your own
// implementation to specify whether the i-th input should live in CPU or GPU memory.
static OrtMemType GetInputMemoryType(size_t input_index)
// You can specify that input i may share memory with output j by allocating
// two arrays of the same length for the pointers input_index and output_index, and
// then setting (*input_index)[k] = i and (*output_index)[k] = j.
// The return value is the length of the allocated arrays.
static size_t GetMayInplace(int** input_index, int** output_index)
// Release the arrays allocated in GetMayInplace().
static void ReleaseMayInplace(int* input_index, int* output_index)
```
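
For reference, here is a hedged sketch of a structure-form custom op that implements these hooks. Keeping input 1 on CPU and aliasing input 0 with output 0 are purely illustrative choices, not requirements.

```C++
// Sketch only: advanced hooks on a structure-form custom op.
struct InplaceKernel {
  // Keep input 1 (e.g. a small shape/parameter tensor) on CPU; everything else uses the default.
  static OrtMemType GetInputMemoryType(size_t input_index) {
    return input_index == 1 ? OrtMemTypeCPUInput : OrtMemTypeDefault;
  }

  // Declare one in-place pair: input 0 may share memory with output 0.
  static size_t GetMayInplace(int** input_index, int** output_index) {
    *input_index = new int[1]{0};
    *output_index = new int[1]{0};
    return 1;
  }

  // Free the arrays allocated in GetMayInplace().
  static void ReleaseMayInplace(int* input_index, int* output_index) {
    delete[] input_index;
    delete[] output_index;
  }

  OrtStatusPtr OnModelAttach(const OrtApi&, const OrtKernelInfo&) { return nullptr; }

  OrtStatusPtr Compute(const Ort::Custom::Tensor<float>& x,
                       Ort::Custom::Tensor<float>& y) const {
    // Kernel body elided.
    return nullptr;
  }
};
```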
7 changes: 7 additions & 0 deletions include/custom_op/custom_op_lite.h
@@ -886,6 +886,13 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return INPUT_OUTPUT_OPTIONAL;
};
#endif

#if ORT_API_VERSION >= 18
OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
return 0;
};
OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {};
#endif
}

const std::string op_name_;
25 changes: 25 additions & 0 deletions include/op_def_struct.h
@@ -106,6 +106,18 @@ struct CustomOp_defined_getInputMemoryType : std::false_type {};
template <typename T>
struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>> : std::true_type {};

template <typename T, typename = void>
struct CustomOp_defined_getMayInplace : std::false_type {};

template <typename T>
struct CustomOp_defined_getMayInplace<T, std::void_t<decltype(&T::GetMayInplace)>> : std::true_type {};

template <typename T, typename = void>
struct CustomOp_defined_releaseMayInplace : std::false_type {};

template <typename T>
struct CustomOp_defined_releaseMayInplace<T, std::void_t<decltype(&T::ReleaseMayInplace)>> : std::true_type {};

template <typename CustomOpKernel>
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
using ComputeFunction = decltype(&CustomOpKernel::Compute);
@@ -192,6 +204,19 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
};
}

#if ORT_API_VERSION >= 18
if constexpr (CustomOp_defined_getMayInplace<CustomOpKernel>::value) {
OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t {
return CustomOpKernel::GetMayInplace(input_index, output_index);
};
}
if constexpr (CustomOp_defined_releaseMayInplace<CustomOpKernel>::value) {
OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void {
CustomOpKernel::ReleaseMayInplace(input_index, output_index);
};
}
#endif

OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
if (api == nullptr) {
5 changes: 4 additions & 1 deletion onnxruntime_extensions/__init__.py
@@ -10,7 +10,6 @@

__author__ = "Microsoft"


from ._version import __version__
from ._ocos import get_library_path
from ._ocos import Opdef, PyCustomOpDef
@@ -66,6 +65,10 @@ def _unimplemented(*args, **kwargs):
gen_processing_models = _unimplemented
OrtPyFunction = _unimplemented
ort_inference = _unimplemented
PyOrtFunction = _unimplemented
optimize_model = _unimplemented
make_onnx_model = _unimplemented
ONNXRuntimeError = _unimplemented

else:
__all__ += _offline_api
10 changes: 10 additions & 0 deletions operators/cuda/attention_lib/flash_attention/flash.h
@@ -87,6 +87,13 @@ struct Flash_fwd_params : public Qkv_params {
// The indices to index into the KV cache.
int* __restrict__ cache_batch_idx = nullptr;

// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;

float rp_dropout;

// Local window size
int window_size_left = -1;
int window_size_right = -1;
@@ -102,6 +109,9 @@ struct Flash_fwd_params : public Qkv_params {

int num_splits = 0; // For split-KV version

void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;

const cudaDeviceProp* dprops = nullptr;
};

29 changes: 22 additions & 7 deletions operators/cuda/attention_lib/flash_attention/flash_api.cc
@@ -32,7 +32,9 @@ void set_params_fprop(Flash_fwd_params& params,
bool is_bf16,
bool kv_bsnh = true,
int window_size_left = -1,
int window_size_right = -1) {
int window_size_right = -1,
bool paged_KV = false,
int page_block_size = -1) {
// Set the pointers and strides.
params.q_ptr = q;
params.k_ptr = k;
@@ -64,8 +66,8 @@

if (cu_seqlens_q_d == nullptr) {
params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
params.k_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0)
params.v_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0)
params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
} else {
params.q_batch_stride = 0;
@@ -99,6 +101,10 @@ void set_params_fprop(Flash_fwd_params& params,
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;

params.rp_dropout = 1.f;
params.alibi_slopes_ptr = nullptr;
params.alibi_slopes_batch_stride = 0;

// In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates
// local and causal, meaning when we have local window size
params.is_causal = is_causal;
@@ -349,8 +355,8 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
@@ -374,7 +380,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size,
bool is_rotary_interleaved,
bool is_packed_qkv) {
bool is_packed_qkv,
int32_t* block_table, // batch_size x max_num_blocks_per_seq
int32_t max_num_blocks_per_seq,
int32_t page_block_size) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -398,7 +407,9 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
is_bf16,
past_bsnh,
local_window_size,
is_causal ? 0 : -1);
is_causal ? 0 : -1,
block_table != nullptr,
page_block_size);
params.dprops = &dprops;

if (k_new != nullptr && v_new != nullptr) {
Expand Down Expand Up @@ -454,6 +465,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.oaccum_ptr = nullptr;
}

params.block_table = block_table;
params.block_table_batch_stride = max_num_blocks_per_seq;
params.page_block_size = page_block_size;

// Only split kernel supports appending to KV cache
run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);

9 changes: 6 additions & 3 deletions operators/cuda/attention_lib/flash_attention/flash_api.h
@@ -53,8 +53,8 @@ OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops,
OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
void* k, // batch_size x seqlen_k_new x num_heads_k x head_size
void* v, // batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
@@ -78,7 +78,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size = -1,
bool is_rotary_interleaved = false,
bool is_packed_qkv = false);
bool is_packed_qkv = false,
int32_t* block_table = nullptr, // batch_size x max_num_blocks_per_seq
int32_t max_num_blocks_per_seq = -1,
int32_t page_block_size = 1);

size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);

