[TE/JAX] Enabling CudaGraph for custom calls with FFI (#1228)
* register CmdBufferCompatible traits via C++ API

* renamed FFI_Traits

* use register_ffi_target()

---------

Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng authored Oct 17, 2024
1 parent 8e97c8d commit 12f30ea
Showing 4 changed files with 10 additions and 7 deletions.
7 changes: 3 additions & 4 deletions transformer_engine/jax/cpp_extensions/custom_call.py
@@ -5,8 +5,8 @@
 from dataclasses import dataclass
 from enum import IntEnum
 
-from jax.lib import xla_client
 from jax.interpreters import mlir
+import jax.extend as jex
 
 from transformer_engine import transformer_engine_jax
@@ -30,12 +30,11 @@ class CustomCallAPIVersion(IntEnum):
 for _name, _value in transformer_engine_jax.registrations().items():
     if _name.endswith("_ffi"):
         if is_ffi_enabled():
-            # COMMAND_BUFFER_COMPATIBLE i.e. cudaGraph enabled by default
-            xla_client.register_custom_call_target(
+            jex.ffi.register_ffi_target(
                 _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
             )
     else:
-        xla_client.register_custom_call_target(
+        jex.ffi.register_ffi_target(
             _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
         )

6 changes: 4 additions & 2 deletions transformer_engine/jax/csrc/extensions/activation.cpp
@@ -126,7 +126,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                                   .Ctx<FFI_Stream_Type>()  // stream
                                   .Arg<Buffer_Type>()      // input
                                   .Ret<Buffer_Type>()      // output
-                                  .Attr<int64_t>("act_enum"));
+                                  .Attr<int64_t>("act_enum"),
+                              FFI_CudaGraph_Traits);
 
 void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
   auto *input = buffers[0];
@@ -276,7 +277,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI,
                                   .Arg<Buffer_Type>()  // input
                                   .Arg<Buffer_Type>()  // act_input
                                   .Ret<Buffer_Type>()  // output
-                                  .Attr<int64_t>("act_enum"));
+                                  .Attr<int64_t>("act_enum"),
+                              FFI_CudaGraph_Traits);
 
 pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                         DType in_dtype, DType out_dtype) {
1 change: 1 addition & 0 deletions transformer_engine/jax/csrc/extensions/ffi.h
@@ -17,6 +17,7 @@ using Result_Type = xla::ffi::Result<xla::ffi::AnyBuffer>;
 using Error_Type = xla::ffi::Error;
 using FFI = xla::ffi::Ffi;
 using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;
+constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible};
 
 DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type);
 Error_Type ffi_with_cuda_error_check();
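For readers unfamiliar with XLA's FFI traits mechanism, the sketch below shows how the FFI_CudaGraph_Traits constant added above is consumed by a handler definition. It is a minimal illustration, not code from this commit: the Echo handler and its copy body are hypothetical, the include path is assumed, and it relies on the TE/JAX helper types declared in ffi.h (Buffer_Type, Result_Type, FFI_Stream_Type, ffi_with_cuda_error_check) plus xla::ffi::AnyBuffer's untyped_data()/size_bytes() accessors.

#include "extensions/ffi.h"  // TE/JAX helper types above (include path assumed)

namespace transformer_engine {
namespace jax {

// Hypothetical handler: copy input to output on the XLA-provided stream.
// A kCmdBufferCompatible handler should only enqueue async work on this
// stream (no host synchronization, no per-call allocation), so that XLA
// can replay the call inside a captured CUDA graph.
Error_Type EchoFFI(cudaStream_t stream, Buffer_Type input, Result_Type output) {
  cudaMemcpyAsync(output->untyped_data(), input.untyped_data(),
                  input.size_bytes(), cudaMemcpyDeviceToDevice, stream);
  return ffi_with_cuda_error_check();
}

// The trailing traits argument marks the call command-buffer compatible,
// mirroring ActLuHandler, DActLuHandler, and CastTransposeHandler above.
XLA_FFI_DEFINE_HANDLER_SYMBOL(EchoHandler, EchoFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>(),     // output
                              FFI_CudaGraph_Traits);

}  // namespace jax
}  // namespace transformer_engine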
3 changes: 2 additions & 1 deletion transformer_engine/jax/csrc/extensions/transpose.cpp
@@ -120,7 +120,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI,
                                   .Ret<Buffer_Type>()      // input_cast
                                   .Ret<Buffer_Type>()      // input_cast_trans
                                   .Ret<Buffer_Type>()      // amax_out
-                                  .Attr<int64_t>("transpose_axis"));
+                                  .Attr<int64_t>("transpose_axis"),
+                              FFI_CudaGraph_Traits);
 
 pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                     DType in_dtype, DType out_dtype) {
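These handler symbols reach the Python registration loop in custom_call.py through transformer_engine_jax.registrations(), which maps each name ending in "_ffi" to a capsule that jex.ffi.register_ffi_target() accepts. The sketch below illustrates that path under stated assumptions: WrapFFIHandler and the dict keys are invented names, not verbatim from this repo, and the handler symbols are assumed to be declared in scope.

#include <pybind11/pybind11.h>
#include "xla/ffi/api/c_api.h"  // XLA_FFI_Handler (function type)

// Illustrative only: wrap a handler symbol in a PyCapsule so the Python loop
// in custom_call.py can hand it to jex.ffi.register_ffi_target().
pybind11::capsule WrapFFIHandler(XLA_FFI_Handler *handler) {
  return pybind11::capsule(reinterpret_cast<void *>(handler));
}

pybind11::dict Registrations() {
  pybind11::dict dict;
  // Names ending in "_ffi" are what custom_call.py keys on to choose the FFI
  // api_version; these key strings are assumed for illustration.
  dict["te_act_lu_ffi"] = WrapFFIHandler(ActLuHandler);
  dict["te_cast_transpose_ffi"] = WrapFFIHandler(CastTransposeHandler);
  return dict;
}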
