From 12f30ead326e88d87c7de496137272d8fb462d34 Mon Sep 17 00:00:00 2001
From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Date: Thu, 17 Oct 2024 08:30:02 -0700
Subject: [PATCH] [TE/JAX] Enabling CudaGraph for custom calls with FFI (#1228)

* register CmdBufferCompatible traits via C++ API

* renamed FFI_Traits

* use register_ffi_target()

---------

Signed-off-by: Phuong Nguyen
---
 transformer_engine/jax/cpp_extensions/custom_call.py  | 7 +++----
 transformer_engine/jax/csrc/extensions/activation.cpp | 6 ++++--
 transformer_engine/jax/csrc/extensions/ffi.h          | 1 +
 transformer_engine/jax/csrc/extensions/transpose.cpp  | 3 ++-
 4 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py
index 8e58ed3bed..1075030a0d 100644
--- a/transformer_engine/jax/cpp_extensions/custom_call.py
+++ b/transformer_engine/jax/cpp_extensions/custom_call.py
@@ -5,8 +5,8 @@
 from dataclasses import dataclass
 from enum import IntEnum
 
-from jax.lib import xla_client
 from jax.interpreters import mlir
+import jax.extend as jex
 
 from transformer_engine import transformer_engine_jax
 
@@ -30,12 +30,11 @@ class CustomCallAPIVersion(IntEnum):
 for _name, _value in transformer_engine_jax.registrations().items():
     if _name.endswith("_ffi"):
         if is_ffi_enabled():
-            # COMMAND_BUFFER_COMPATIBLE i.e. cudaGraph enabled by default
-            xla_client.register_custom_call_target(
+            jex.ffi.register_ffi_target(
                 _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
             )
     else:
-        xla_client.register_custom_call_target(
+        jex.ffi.register_ffi_target(
             _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
         )
 
diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp
index 1e8998b365..2baba48acf 100644
--- a/transformer_engine/jax/csrc/extensions/activation.cpp
+++ b/transformer_engine/jax/csrc/extensions/activation.cpp
@@ -126,7 +126,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                                   .Ctx()  // stream
                                   .Arg()  // input
                                   .Ret()  // output
-                                  .Attr("act_enum"));
+                                  .Attr("act_enum"),
+                              FFI_CudaGraph_Traits);
 
 void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
   auto *input = buffers[0];
@@ -276,7 +277,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI,
                                   .Arg()  // input
                                   .Arg()  // act_input
                                   .Ret()  // output
-                                  .Attr("act_enum"));
+                                  .Attr("act_enum"),
+                              FFI_CudaGraph_Traits);
 
 pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                         DType in_dtype, DType out_dtype) {
diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h
index 77132c3fca..729e8e60e3 100644
--- a/transformer_engine/jax/csrc/extensions/ffi.h
+++ b/transformer_engine/jax/csrc/extensions/ffi.h
@@ -17,6 +17,7 @@
 using Result_Type = xla::ffi::Result;
 using Error_Type = xla::ffi::Error;
 using FFI = xla::ffi::Ffi;
 using FFI_Stream_Type = xla::ffi::PlatformStream;
+constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible};
 
 DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type);
 Error_Type ffi_with_cuda_error_check();
diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp
index 7a2e31312a..1d1957e0bf 100644
--- a/transformer_engine/jax/csrc/extensions/transpose.cpp
+++ b/transformer_engine/jax/csrc/extensions/transpose.cpp
@@ -120,7 +120,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI,
                                   .Ret()  // input_cast
                                   .Ret()  // input_cast_trans
                                   .Ret()  // amax_out
-                                  .Attr("transpose_axis"));
+                                  .Attr("transpose_axis"),
+                              FFI_CudaGraph_Traits);
 
 pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                     DType in_dtype, DType out_dtype) {
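
For reference, the sketch below shows the end-to-end pattern this patch enables: an XLA FFI handler defined with an extra traits argument so XLA may capture the custom call into a CUDA graph (command buffer). It is a minimal, self-contained example, not code from this patch: the handler MyCopyImpl, the device-to-device copy, and the alias kCudaGraphTraits are illustrative stand-ins for Transformer Engine's handlers and its FFI_CudaGraph_Traits, and the Ctx/Arg/Ret template parameters are spelled out explicitly here; XLA_FFI_DEFINE_HANDLER_SYMBOL, xla::ffi::Traits::kCmdBufferCompatible, and the Bind() pattern are the public XLA FFI API used in the diffs above.

#include <string>

#include <cuda_runtime.h>

#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Same shape as the patch's FFI_CudaGraph_Traits: marking a handler
// kCmdBufferCompatible tells XLA it is safe to record into a command buffer.
constexpr auto kCudaGraphTraits = {ffi::Traits::kCmdBufferCompatible};

// Illustrative op: asynchronous device-to-device copy of input into output.
static ffi::Error MyCopyImpl(cudaStream_t stream, ffi::AnyBuffer input,
                             ffi::Result<ffi::AnyBuffer> output) {
  cudaError_t err = cudaMemcpyAsync(output->untyped_data(), input.untyped_data(),
                                    input.size_bytes(), cudaMemcpyDeviceToDevice, stream);
  if (err != cudaSuccess) {
    return ffi::Error(ffi::ErrorCode::kInternal, cudaGetErrorString(err));
  }
  return ffi::Error::Success();
}

// The trailing traits argument is what this patch adds to each TE handler;
// without it the custom call still runs, but programs containing it cannot be
// captured into a CUDA graph.
XLA_FFI_DEFINE_HANDLER_SYMBOL(MyCopyHandler, MyCopyImpl,
                              ffi::Ffi::Bind()
                                  .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
                                  .Arg<ffi::AnyBuffer>()                     // input
                                  .Ret<ffi::AnyBuffer>(),                    // output
                              kCudaGraphTraits);

On the Python side, such a handler would be registered the same way the patch registers Transformer Engine's targets, e.g. jex.ffi.register_ffi_target("my_copy_ffi", capsule, platform="CUDA", api_version=1) where "my_copy_ffi" and capsule are placeholders; no additional flag is needed for CUDA-graph capture once the trait is set in C++.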