diff --git a/xla/service/algorithm_util.cc b/xla/service/algorithm_util.cc
index 4e03705673090..8eec061e84e13 100644
--- a/xla/service/algorithm_util.cc
+++ b/xla/service/algorithm_util.cc
@@ -174,9 +174,15 @@ bool IsSupportedDotAlgorithmOnGpu(
       return input_storage_type == F16 &&
             (output_storage_type == F16 || output_storage_type == F32);
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
-      return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
-             input_storage_type == BF16 &&
-             (output_storage_type == BF16 || output_storage_type == F32);
+      if (!is_cuda_ge_ampere && !is_rocm_mi100_and_above) return false;
+      switch (input_storage_type) {
+        case BF16:
+          return output_storage_type == BF16 || output_storage_type == F32;
+        case F32:
+          return output_storage_type == F32;
+        default:
+          return false;
+      }
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
     case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
       return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
diff --git a/xla/service/gpu/dot_algorithm_support_test.cc b/xla/service/gpu/dot_algorithm_support_test.cc
index bb17f1a6ed58f..f731049a8f6f6 100644
--- a/xla/service/gpu/dot_algorithm_support_test.cc
+++ b/xla/service/gpu/dot_algorithm_support_test.cc
@@ -173,10 +173,10 @@ TEST_P(DotAlgorithmSupportTest, AlgorithmIsSupportedFromCudaCapability) {
 
     if (params.backend_restriction == BackendRestriction::kTritonOnly) {
       MatchOptimizedHlo(hlo_text, R"(
-    ;CHECK: ENTRY
-    ;CHECK: ROOT
-    ;CHECK-SAME: kCustom
-    ;CHECK-SAME: "triton_gemm_config"
+  ;CHECK: ENTRY
+  ;CHECK: ROOT
+  ;CHECK-SAME: kCustom
+  ;CHECK-SAME: "triton_gemm_config"
   )");
     }
   } else {
@@ -215,7 +215,7 @@ INSTANTIATE_TEST_SUITE_P(DotF16F16F32Tests, DotAlgorithmSupportTest,
                                  Values(Sizes{32, 32}, Sizes{16, 2})),
                          TestParamsToString);
 
-INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest,
+INSTANTIATE_TEST_SUITE_P(DotBF16ForBf16Bf16F32Tests, DotAlgorithmSupportTest,
                          Combine(Values(PC::ALG_DOT_BF16_BF16_F32),
                                  Values(BF16), Values(BF16, F32),
                                  Values(CC(8, 0)),
@@ -224,8 +224,16 @@ INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest,
                                  Values(Sizes{32, 32}, Sizes{16, 2})),
                          TestParamsToString);
 
-INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest,
+INSTANTIATE_TEST_SUITE_P(DotF32ForBf16Bf16F32Tests, DotAlgorithmSupportTest,
+                         Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(F32),
+                                 Values(F32), Values(CC(8, 0)),
+                                 Values(SemanticVersion{6, 0, 0}),
+                                 Values(BackendRestriction::kTritonOnly),
+                                 Values(Sizes{32, 32}, Sizes{16, 2})),
+                         TestParamsToString);
+
+INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest,
                          Combine(Values(PC::ALG_DOT_BF16_BF16_F32_X3,
                                         PC::ALG_DOT_BF16_BF16_F32_X6),
                                  Values(F32), Values(F32),
                                  Values(CC(8, 0)),
diff --git a/xla/service/gpu/fusions/triton/BUILD b/xla/service/gpu/fusions/triton/BUILD
index bf14cd4dd3fb5..7018af701cc22 100644
--- a/xla/service/gpu/fusions/triton/BUILD
+++ b/xla/service/gpu/fusions/triton/BUILD
@@ -433,6 +433,7 @@ cc_library(
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
         "@tsl//tsl/platform:tensor_float_32_utils",
     ],
 )
diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
index e9974d3ce1584..4a4245c159067 100644
--- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
+++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -2658,6 +2658,18 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
       IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth
           ? mt::InputPrecision::TF32
           : mt::InputPrecision::IEEE;
+
+  // Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32.
+  if (dot_instr->precision_config().algorithm() ==
+      PrecisionConfig::ALG_DOT_BF16_BF16_F32) {
+    if (dot_instr->operand(0)->shape().element_type() == F32) {
+      dot_input_lhs = Cast(b, dot_input_lhs, b.getBF16Type());
+    }
+    if (dot_instr->operand(1)->shape().element_type() == F32) {
+      dot_input_rhs = Cast(b, dot_input_rhs, b.getBF16Type());
+    }
+  }
+
   // For fp8 matmuls, disable accumulator promotion, as it's what cublas
   // does. It may make sense to enable frequent accumulator promotion at
   // higher matmul precisions set in the config.
diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
index 7ac574504294a..0ea4702d7445a 100644
--- a/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
+++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc
@@ -5134,6 +5134,46 @@ CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32
                      /*arel=*/1e-5}));
 }
 
+class TritonBF16BF16F32GemmTest : public TritonTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = TritonTest::GetDebugOptionsForTest();
+    // Do not fall back to cuBLAS; we are testing Triton.
+    debug_options.set_xla_gpu_cublas_fallback(false);
+    // Disable split-k autotuning, since it prevents deterministic matching
+    // of the optimized HLO.
+    debug_options.set_xla_gpu_enable_split_k_autotuning(false);
+    return debug_options;
+  }
+
+ protected:
+  void SetUp() override {
+    if (!SupportsBF16(GpuComputeComp())) {
+      GTEST_SKIP() << "BF16 not supported.";
+    }
+  }
+};
+
+TEST_F(TritonBF16BF16F32GemmTest, WorksWithF32InputAndAlgorithm_BF16_BF16_F32) {
+  const std::string kHloText = R"(
+  HloModule t
+
+  ENTRY main {
+    lhs = f32[32,64]{1,0} parameter(0)
+    rhs = f32[64,16]{1,0} parameter(1)
+    ROOT dot = f32[32,16]{1,0} dot(lhs, rhs),
+        algorithm=dot_bf16_bf16_f32,
+        lhs_contracting_dims={1},
+        rhs_contracting_dims={0}
+  }
+)";
+  const std::string pattern =
+      R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
+  TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
+  EXPECT_TRUE(ok);
+}
+
 // This test could be modified to allow TF32 once this bug is fixed.
 // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32.
 TEST_F(TritonTest, NoTF32For8BitOrLessWithF32) {
diff --git a/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/xla/service/gpu/fusions/triton/triton_support_legacy.cc
index 8280accb99e10..802fed51f4d20 100644
--- a/xla/service/gpu/fusions/triton/triton_support_legacy.cc
+++ b/xla/service/gpu/fusions/triton/triton_support_legacy.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -274,8 +275,9 @@ CodegenDecision CanTritonHandleGEMM( } else { if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), gpu_version)) { - return CodegenDecision::Forbid( - "Unsupported algorithm on the current device(s)."); + return CodegenDecision::Forbid(absl::StrFormat( + "Unsupported algorithm on the current device(s): %s", + PrecisionConfig::Algorithm_Name(dot.precision_config().algorithm()))); } } diff --git a/xla/service/gpu/transforms/gemm_fusion.cc b/xla/service/gpu/transforms/gemm_fusion.cc index 1f6d41698aeaa..1934f59c48e24 100644 --- a/xla/service/gpu/transforms/gemm_fusion.cc +++ b/xla/service/gpu/transforms/gemm_fusion.cc @@ -740,6 +740,7 @@ absl::StatusOr CreateDotFusion( dot.precision_config().algorithm(); if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || + algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32 || dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() || dot.sparse_operands()) { return Decision::Allow(); @@ -757,9 +758,9 @@ absl::StatusOr CreateDotFusion( } return absl::OkStatus(); }); - if (is_pure_matmul) { - return Decision::NotProfitable("Pure Matmul"); - } + + if (is_pure_matmul) return Decision::NotProfitable("Pure Matmul"); + return Decision::Allow(); }