[JAX] XLA Custom Calls with FFI for FusedAttnFwd, Quantize, Transpose, ActLuFP8, LayerNormForwardFP8FFI, and LayerNormBackwardFFI (#1263)

* Add TransposeFFI, test passed

Signed-off-by: Hua Huang <[email protected]>

* Add ActLuFP8FFI; fix TransposeFFI

Signed-off-by: Hua Huang <[email protected]>

* Add QuantizeFFI

Signed-off-by: Hua Huang <[email protected]>

* Add FusedAttnForwardFFI and some unit tests

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor fix

Signed-off-by: Hua Huang <[email protected]>

* Add LayerNormForwardFP8FFI & LayerNormBackwardFFI

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revise FusedAttnForwardFFI()

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add FFI_CudaGraph_Traits

All tests passed, ready for merge

Signed-off-by: Hua Huang <[email protected]>

* Bug fix for FFI data type mismatch

Also add a safeguard at the entry of the FFI function

Signed-off-by: Hua Huang <[email protected]>

---------

Signed-off-by: Hua Huang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
huanghua1994 and pre-commit-ci[bot] authored Oct 24, 2024
1 parent 20c7529 commit 18c2234
Showing 14 changed files with 790 additions and 136 deletions.
109 changes: 106 additions & 3 deletions tests/jax/test_custom_call_compute.py
@@ -3,9 +3,8 @@
# See LICENSE for license information.

from contextlib import nullcontext
import functools
import operator
from typing import Callable, List, Sequence, Union
import os

import jax
import jax.numpy as jnp
@@ -14,12 +13,17 @@
from jax import jit, value_and_grad
from flax import linen as nn

from utils import assert_allclose
from utils import assert_allclose, assert_tree_like_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax.cpp_extensions.transpose import (
_jax_transpose,
_jax_cast_transpose,
)
from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
from transformer_engine.jax import cpp_extensions as tex


@@ -746,3 +750,102 @@ def ref_func(x, y, gamma, beta, zero_centered_gamma):
assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
if beta is not None:
assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)


@pytest.mark.parametrize(
"in_dtype",
[
pytest.param(jnp.float32, id="input_float32"),
pytest.param(jnp.float16, id="input_float16"),
pytest.param(jnp.bfloat16, id="input_bfloat16"),
],
)
@pytest.mark.parametrize(
"input_shape, transpose_axis",
[
pytest.param((16, 16), 1, id="(16, 16)-1"),
pytest.param((256, 128), 1, id="(256, 128)-1"),
pytest.param((128, 512), 1, id="(128, 512)-1"),
pytest.param((64, 16, 4, 256), 1, id="(64, 16, 4, 256)-1"),
pytest.param((64, 16, 4, 256), 2, id="(64, 16, 4, 256)-2"),
pytest.param((64, 16, 4, 256), 3, id="(64, 16, 4, 256)-3"),
],
)
class TestTranspose:
def test_transpose(self, in_dtype, input_shape, transpose_axis):
key = jax.random.PRNGKey(0)
input_tensor = jax.random.uniform(key, input_shape, in_dtype)
static_axis_boundary = -1
jax_output = _jax_transpose(input_tensor, static_axis_boundary, transpose_axis)
os.environ["NVTE_JAX_WITH_FFI"] = "0"
noffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
assert_allclose(jax_output, noffi_output)
assert_allclose(noffi_output, ffi_output)

@pytest.mark.parametrize(
"out_dtype",
[
pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
],
)
def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
amax = jnp.zeros(1, jnp.float32)
scale = jnp.ones(1, jnp.float32)
scale_inv = jnp.ones(1, jnp.float32)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
static_axis_boundary = -1
jax_output = _jax_cast_transpose(
input, scale, amax, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "0"
noffi_output = tex.cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.cast_transpose(
input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
"input_shape",
[
pytest.param((256, 128), id="(256, 128)"),
pytest.param((128, 512, 8), id="(128, 512, 8)"),
],
)
@pytest.mark.parametrize(
"in_dtype",
[
pytest.param(jnp.float32, id="input_float32"),
pytest.param(jnp.float16, id="input_float16"),
pytest.param(jnp.bfloat16, id="input_bfloat16"),
],
)
@pytest.mark.parametrize(
"out_dtype",
[
pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
],
)
def test_quantize(input_shape, in_dtype, out_dtype):
amax = jnp.zeros(1, jnp.float32)
scale = jnp.ones(1, jnp.float32)
scale_inv = jnp.ones(1, jnp.float32)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
jax_output = _jax_cast_fp8(input, scale, amax, out_dtype)
os.environ["NVTE_JAX_WITH_FFI"] = "0"
noffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)
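Each comparison above flips NVTE_JAX_WITH_FFI in place and leaves it set to "1" after the test returns. A small context manager along the following lines would keep the toggle scoped to one call and restore the previous value even when an assertion fails; this is only a sketch, and the ffi_lowering_enabled name is made up for illustration:

import os
from contextlib import contextmanager

@contextmanager
def ffi_lowering_enabled(enabled: bool):
    """Temporarily set NVTE_JAX_WITH_FFI and restore the prior value on exit."""
    previous = os.environ.get("NVTE_JAX_WITH_FFI")
    os.environ["NVTE_JAX_WITH_FFI"] = "1" if enabled else "0"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop("NVTE_JAX_WITH_FFI", None)
        else:
            os.environ["NVTE_JAX_WITH_FFI"] = previous

A test would then run each lowering path under its own scope, e.g. with ffi_lowering_enabled(False): followed by the same tex.cast_fp8 call under ffi_lowering_enabled(True). The variable only matters at lowering time, so the wrapper must enclose the call that triggers compilation.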
2 changes: 1 addition & 1 deletion tests/jax/test_layer.py
@@ -67,7 +67,7 @@ def enable_fused_attn():
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_HIDDEN_DROPOUT: 0,
_KEY_OF_ATTENTION_DROPOUT: 0,
_KEY_OF_ATTENTION_DROPOUT: 0.0,
_KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: "layernorm",
66 changes: 36 additions & 30 deletions transformer_engine/jax/cpp_extensions/activation.py
@@ -383,37 +383,43 @@ def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape

hidden_size = ir_x_shape[-1]
batch_shape = ir_x_shape[:-2]
batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

opaque = transformer_engine_jax.pack_common_descriptor(
(batch_size, hidden_size),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum,
)
if is_ffi_enabled():
name = "te_act_lu_fp8_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
ctx, x, amax, scale, scale_inv, act_enum=act_enum
)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape

out = custom_caller(
ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
)
hidden_size = ir_x_shape[-1]
batch_shape = ir_x_shape[:-2]
batch_size = reduce(operator.mul, batch_shape)
out_shape = batch_shape + [hidden_size]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
]
operands = [x, amax, scale, scale_inv]
operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

opaque = transformer_engine_jax.pack_common_descriptor(
(batch_size, hidden_size),
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
act_enum,
)

out = custom_caller(
ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
)

return out
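Both branches above produce the same primitive outputs; only the lowering differs, and the choice is made by is_ffi_enabled(), which the new tests drive through the NVTE_JAX_WITH_FFI environment variable. A minimal sketch of such a gate, assuming the variable defaults to the FFI path (the actual helper lives elsewhere in cpp_extensions and may differ):

import os

def is_ffi_enabled() -> bool:
    # Use the XLA FFI lowering unless NVTE_JAX_WITH_FFI is explicitly set to "0";
    # the legacy opaque-descriptor custom call remains available as a fallback.
    return os.getenv("NVTE_JAX_WITH_FFI", "1") == "1"

Note that both paths pass operand_output_aliases={1: 1}, which aliases the amax operand (index 1) to the second output so the kernel can update it in place.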

112 changes: 80 additions & 32 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -15,6 +15,7 @@
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import (
@@ -33,6 +34,7 @@
te_dtype_to_jax_dtype,
get_padded_spec,
get_cudnn_version,
is_ffi_enabled,
)
from ..sharding import (
global_mesh_resource,
@@ -352,14 +354,6 @@ def lowering(
"""
Fused attention fwd lowering rules
"""
operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, seed]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
@@ -376,31 +370,85 @@

wkspace_aval = ctx.avals_out[-1]

opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)
if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
head_dim=head_dim,
max_segments_per_seq=config.max_segments_per_seq,
wkspace_size=wkspace_aval.size,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type),
mask_type=int(config.attn_mask_type),
qkv_layout=int(config.qkv_layout),
dtype=int(jax_dtype_to_te_dtype(q_aval.dtype)),
wkspace_dtype=int(jax_dtype_to_te_dtype(wkspace_aval.dtype)),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=config.window_size[0],
window_size_right=config.window_size[1],
)
else:
operands = [
q,
k,
v,
bias,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
seed,
]
operand_shapes = map(lambda x: x.type.shape, operands)
out_types = [
ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
for output in ctx.avals_out
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

opaque = transformer_engine_jax.pack_fused_attn_descriptor(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
head_dim,
config.max_segments_per_seq,
wkspace_aval.size,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type,
config.attn_mask_type,
config.qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
config.is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
config.window_size[0],
config.window_size[1],
)

out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

return out
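In the FFI branch, every field that the legacy path packed into the opaque descriptor is passed instead as a typed call attribute, with explicit float(...)/int(...) casts so the Python-side attribute types line up with what the C++ handler declares; this is presumably the kind of mismatch the "Bug fix for FFI data type mismatch" commit above refers to. The "te_*_ffi" target names used by ffi.ffi_lowering also have to be registered with XLA once at import time. In this repository that happens through the transformer_engine_jax extension; the sketch below only illustrates the general mechanism, and the registrations() helper name is assumed for illustration:

from jax.extend import ffi
from transformer_engine import transformer_engine_jax

# Hypothetical: each entry maps a target name such as "te_fused_attn_forward_ffi"
# to a PyCapsule wrapping the C++ XLA FFI handler exported by the extension.
for name, capsule in transformer_engine_jax.registrations().items():
    ffi.register_ffi_target(name, capsule, platform="CUDA")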
