From b831b175c313ed8f105cb5a91446b4a414e37ade Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Tue, 24 Sep 2024 09:34:32 -0700 Subject: [PATCH] [XLA:GPU] Add support for the explicit algorithm=BF16_BF16_F32 in Triton when the input is F32. It is the case that was not covered when BF16_BF16_F32_X3 was introduced. We enable F32 input in algorithm_util.cc. But the default behavior led to F32_F32_F32 triton that was slower than the cuBLAS with ~21ms. I.e. it was not faster despite lower precision and at the same time the fusion was forbidden due to "Pure matmul". With the explicit truncation the F32 input to BF16 in the triton emitter we could reach the latency ~4ms which is way better than F32_F32_F32 (~21ms), and BF16_BF16_F32_X3 (~13ms), and BF16_BF16_F32_X6 (~18ms), but it is still slower that the clear dot for BF16 arguments (1.53ms). PiperOrigin-RevId: 678283878 --- xla/service/algorithm_util.cc | 12 ++++-- xla/service/gpu/dot_algorithm_support_test.cc | 19 ++++++--- xla/service/gpu/fusions/triton/BUILD | 1 + .../fusions/triton/triton_fusion_emitter.cc | 12 ++++++ ...riton_fusion_emitter_device_legacy_test.cc | 40 +++++++++++++++++++ .../fusions/triton/triton_support_legacy.cc | 6 ++- xla/service/gpu/transforms/gemm_fusion.cc | 7 ++-- 7 files changed, 83 insertions(+), 14 deletions(-) diff --git a/xla/service/algorithm_util.cc b/xla/service/algorithm_util.cc index 4e03705673090..8eec061e84e13 100644 --- a/xla/service/algorithm_util.cc +++ b/xla/service/algorithm_util.cc @@ -174,9 +174,15 @@ bool IsSupportedDotAlgorithmOnGpu( return input_storage_type == F16 && (output_storage_type == F16 || output_storage_type == F32); case PrecisionConfig::ALG_DOT_BF16_BF16_F32: - return (is_cuda_ge_ampere || is_rocm_mi100_and_above) && - input_storage_type == BF16 && - (output_storage_type == BF16 || output_storage_type == F32); + if (!is_cuda_ge_ampere && !is_rocm_mi100_and_above) return false; + switch (input_storage_type) { + case BF16: + return output_storage_type == BF16 || output_storage_type == F32; + case F32: + return output_storage_type == F32; + default: + return false; + } case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: return (is_cuda_ge_ampere || is_rocm_mi100_and_above) && diff --git a/xla/service/gpu/dot_algorithm_support_test.cc b/xla/service/gpu/dot_algorithm_support_test.cc index bb17f1a6ed58f..f731049a8f6f6 100644 --- a/xla/service/gpu/dot_algorithm_support_test.cc +++ b/xla/service/gpu/dot_algorithm_support_test.cc @@ -173,10 +173,10 @@ TEST_P(DotAlgorithmSupportTest, AlgorithmIsSupportedFromCudaCapability) { if (params.backend_restriction == BackendRestriction::kTritonOnly) { MatchOptimizedHlo(hlo_text, R"( - ;CHECK: ENTRY - ;CHECK: ROOT - ;CHECK-SAME: kCustom - ;CHECK-SAME: "triton_gemm_config" + ;CHECK: ENTRY + ;CHECK: ROOT + ;CHECK-SAME: kCustom + ;CHECK-SAME: "triton_gemm_config" )"); } } else { @@ -215,7 +215,7 @@ INSTANTIATE_TEST_SUITE_P(DotF16F16F32Tests, DotAlgorithmSupportTest, Values(Sizes{32, 32}, Sizes{16, 2})), TestParamsToString); -INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest, +INSTANTIATE_TEST_SUITE_P(DotBF16ForBf16Bf16F32Tests, DotAlgorithmSupportTest, Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(BF16), Values(BF16, F32), Values(CC(8, 0)), @@ -224,8 +224,15 @@ INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest, Values(Sizes{32, 32}, Sizes{16, 2})), TestParamsToString); -INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest, +INSTANTIATE_TEST_SUITE_P(DotF32ForBf16Bf16F32Tests, DotAlgorithmSupportTest, + Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(F32), + Values(F32), Values(CC(8, 0)), + Values(SemanticVersion{6, 0, 0}), + Values(BackendRestriction::kTritonOnly), + Values(Sizes{32, 32}, Sizes{16, 2})), + TestParamsToString); +INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest, Combine(Values(PC::ALG_DOT_BF16_BF16_F32_X3, PC::ALG_DOT_BF16_BF16_F32_X6), Values(F32), Values(F32), Values(CC(8, 0)), diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD index bf14cd4dd3fb5..7018af701cc22 100644 --- a/xla/service/gpu/fusions/triton/BUILD +++ b/xla/service/gpu/fusions/triton/BUILD @@ -433,6 +433,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@tsl//tsl/platform:tensor_float_32_utils", ], ) diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index e9974d3ce1584..4a4245c159067 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -2658,6 +2658,18 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth ? mt::InputPrecision::TF32 : mt::InputPrecision::IEEE; + + // Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32. + if (dot_instr->precision_config().algorithm() == + PrecisionConfig::ALG_DOT_BF16_BF16_F32) { + if (dot_instr->operand(0)->shape().element_type() == F32) { + dot_input_lhs = Cast(b, dot_input_lhs, b.getBF16Type()); + } + if (dot_instr->operand(1)->shape().element_type() == F32) { + dot_input_rhs = Cast(b, dot_input_rhs, b.getBF16Type()); + } + } + // For fp8 matmuls, disable accumulator promotion, as it's what cublas // does. It may make sense to enable frequent accumulator promotion at // higher matmul precisions set in the config. diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 7ac574504294a..0ea4702d7445a 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -5134,6 +5134,46 @@ CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 /*arel=*/1e-5})); } +class TritonBF16BF16F32GemmTest : public TritonTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + return debug_options; + } + + protected: + void SetUp() override { + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } + } +}; + +TEST_F(TritonBF16BF16F32GemmTest, WorkWithF32InputAndAlgorithm_BF16_BF16_F32) { + const std::string kHloText = R"( + HloModule t + + ENTRY main { + lhs = f32[32,64]{1,0} parameter(0) + rhs = f32[64,16]{1,0} parameter(1) + ROOT dot = f32[32,16]{1,0} dot(lhs, rhs), + algorithm=dot_bf16_bf16_f32, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = + R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(ok); +} + // This test could be modified to allow TF32 once this bug is fixed. // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. TEST_F(TritonTest, NoTF32For8BitOrLessWithF32) { diff --git a/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/xla/service/gpu/fusions/triton/triton_support_legacy.cc index 8280accb99e10..802fed51f4d20 100644 --- a/xla/service/gpu/fusions/triton/triton_support_legacy.cc +++ b/xla/service/gpu/fusions/triton/triton_support_legacy.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -274,8 +275,9 @@ CodegenDecision CanTritonHandleGEMM( } else { if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), gpu_version)) { - return CodegenDecision::Forbid( - "Unsupported algorithm on the current device(s)."); + return CodegenDecision::Forbid(absl::StrFormat( + "Unsupported algorithm on the current device(s): %s", + PrecisionConfig::Algorithm_Name(dot.precision_config().algorithm()))); } } diff --git a/xla/service/gpu/transforms/gemm_fusion.cc b/xla/service/gpu/transforms/gemm_fusion.cc index 1f6d41698aeaa..1934f59c48e24 100644 --- a/xla/service/gpu/transforms/gemm_fusion.cc +++ b/xla/service/gpu/transforms/gemm_fusion.cc @@ -740,6 +740,7 @@ absl::StatusOr CreateDotFusion( dot.precision_config().algorithm(); if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || + algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32 || dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() || dot.sparse_operands()) { return Decision::Allow(); @@ -757,9 +758,9 @@ absl::StatusOr CreateDotFusion( } return absl::OkStatus(); }); - if (is_pure_matmul) { - return Decision::NotProfitable("Pure Matmul"); - } + + if (is_pure_matmul) return Decision::NotProfitable("Pure Matmul"); + return Decision::Allow(); }