[XLA:GPU] Add support for the explicit algorithm=BF16_BF16_F32 in Triton when the input is F32.

This covers a case that was missed when BF16_BF16_F32_X3 was introduced.

We enable F32 inputs for this algorithm in algorithm_util.cc. But the default behavior lowered such dots to an F32_F32_F32 Triton kernel that ran in ~21ms, slower than cuBLAS, so the lower requested precision bought no speedup; at the same time, the fusion was rejected as not profitable ("Pure Matmul").

With explicit truncation of the F32 inputs to BF16 in the Triton emitter, we reach a latency of ~4ms. That is far better than F32_F32_F32 (~21ms), BF16_BF16_F32_X3 (~13ms), and BF16_BF16_F32_X6 (~18ms), but still slower than a plain dot with BF16 arguments (~1.53ms).
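
For context, here is a minimal sketch (not part of this change, and simplified from what the emitter actually generates) of what a BF16_BF16_F32 dot computes when the inputs are F32: each operand is narrowed to BF16 before the multiply, while products are accumulated in F32. The helper names below are hypothetical, and a real cast may round to nearest even rather than truncate as shown.

#include <cstdint>
#include <cstring>

// Hypothetical helper: narrow an F32 value to BF16 precision by keeping the
// high 16 bits of its bit pattern (BF16 shares F32's sign and exponent; only
// the mantissa is shortened).
float TruncateToBF16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;  // drop the low 16 mantissa bits
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

// BF16_BF16_F32 semantics for one dot product: BF16 operands, F32 accumulator.
float DotBF16BF16F32(const float* lhs, const float* rhs, int n) {
  float acc = 0.0f;  // accumulation stays in full F32 precision
  for (int i = 0; i < n; ++i) {
    acc += TruncateToBF16(lhs[i]) * TruncateToBF16(rhs[i]);
  }
  return acc;
}

This also explains the latency ordering above: a single BF16 pass trades mantissa precision for speed, while the _X3 and _X6 variants recover precision by combining three or six such passes at correspondingly higher cost.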

PiperOrigin-RevId: 678283878
loislo authored and Google-ML-Automation committed Sep 24, 2024
1 parent 794d2a1 commit b831b17
Showing 7 changed files with 83 additions and 14 deletions.
12 changes: 9 additions & 3 deletions xla/service/algorithm_util.cc
@@ -174,9 +174,15 @@ bool IsSupportedDotAlgorithmOnGpu(
return input_storage_type == F16 &&
(output_storage_type == F16 || output_storage_type == F32);
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
-      return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
-             input_storage_type == BF16 &&
-             (output_storage_type == BF16 || output_storage_type == F32);
+      if (!is_cuda_ge_ampere && !is_rocm_mi100_and_above) return false;
+      switch (input_storage_type) {
+        case BF16:
+          return output_storage_type == BF16 || output_storage_type == F32;
+        case F32:
+          return output_storage_type == F32;
+        default:
+          return false;
+      }
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
return (is_cuda_ge_ampere || is_rocm_mi100_and_above) &&
19 changes: 13 additions & 6 deletions xla/service/gpu/dot_algorithm_support_test.cc
@@ -173,10 +173,10 @@ TEST_P(DotAlgorithmSupportTest, AlgorithmIsSupportedFromCudaCapability) {

if (params.backend_restriction == BackendRestriction::kTritonOnly) {
MatchOptimizedHlo(hlo_text, R"(
-      ;CHECK: ENTRY
-      ;CHECK: ROOT
-      ;CHECK-SAME: kCustom
-      ;CHECK-SAME: "triton_gemm_config"
+;CHECK: ENTRY
+;CHECK: ROOT
+;CHECK-SAME: kCustom
+;CHECK-SAME: "triton_gemm_config"
)");
}
} else {
@@ -215,7 +215,7 @@ INSTANTIATE_TEST_SUITE_P(DotF16F16F32Tests, DotAlgorithmSupportTest,
Values(Sizes{32, 32}, Sizes{16, 2})),
TestParamsToString);

-INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest,
+INSTANTIATE_TEST_SUITE_P(DotBF16ForBf16Bf16F32Tests, DotAlgorithmSupportTest,
Combine(Values(PC::ALG_DOT_BF16_BF16_F32),
Values(BF16), Values(BF16, F32),
Values(CC(8, 0)),
@@ -224,8 +224,15 @@ INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest,
Values(Sizes{32, 32}, Sizes{16, 2})),
TestParamsToString);

-INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest,
+INSTANTIATE_TEST_SUITE_P(DotF32ForBf16Bf16F32Tests, DotAlgorithmSupportTest,
+                         Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(F32),
+                                 Values(F32), Values(CC(8, 0)),
+                                 Values(SemanticVersion{6, 0, 0}),
+                                 Values(BackendRestriction::kTritonOnly),
+                                 Values(Sizes{32, 32}, Sizes{16, 2})),
+                         TestParamsToString);
+
+INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest,
Combine(Values(PC::ALG_DOT_BF16_BF16_F32_X3,
PC::ALG_DOT_BF16_BF16_F32_X6),
Values(F32), Values(F32), Values(CC(8, 0)),
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/triton/BUILD
@@ -433,6 +433,7 @@ cc_library(
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@tsl//tsl/platform:tensor_float_32_utils",
],
)
12 changes: 12 additions & 0 deletions xla/service/gpu/fusions/triton/triton_fusion_emitter.cc
@@ -2658,6 +2658,18 @@ absl::Status EmitMatMul(mlir::OpBuilder builder,
IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth
? mt::InputPrecision::TF32
: mt::InputPrecision::IEEE;

+  // Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32.
+  if (dot_instr->precision_config().algorithm() ==
+      PrecisionConfig::ALG_DOT_BF16_BF16_F32) {
+    if (dot_instr->operand(0)->shape().element_type() == F32) {
+      dot_input_lhs = Cast(b, dot_input_lhs, b.getBF16Type());
+    }
+    if (dot_instr->operand(1)->shape().element_type() == F32) {
+      dot_input_rhs = Cast(b, dot_input_rhs, b.getBF16Type());
+    }
+  }

// For fp8 matmuls, disable accumulator promotion, as it's what cublas
// does. It may make sense to enable frequent accumulator promotion at
// higher matmul precisions set in the config.
@@ -5134,6 +5134,46 @@ CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32
/*arel=*/1e-5}));
}

+class TritonBF16BF16F32GemmTest : public TritonTest {
+ public:
+  DebugOptions GetDebugOptionsForTest() override {
+    DebugOptions debug_options = TritonTest::GetDebugOptionsForTest();
+    // Do not fall back to cuBLAS; we are testing Triton.
+    debug_options.set_xla_gpu_cublas_fallback(false);
+    // Do not autotune split-k by default, since autotuning prevents
+    // deterministic matching of the optimized HLO.
+    debug_options.set_xla_gpu_enable_split_k_autotuning(false);
+    return debug_options;
+  }
+
+ protected:
+  void SetUp() override {
+    if (!SupportsBF16(GpuComputeComp())) {
+      GTEST_SKIP() << "BF16 not supported.";
+    }
+  }
+};
+
+TEST_F(TritonBF16BF16F32GemmTest, WorkWithF32InputAndAlgorithm_BF16_BF16_F32) {
+  const std::string kHloText = R"(
+HloModule t
+ENTRY main {
+  lhs = f32[32,64]{1,0} parameter(0)
+  rhs = f32[64,16]{1,0} parameter(1)
+  ROOT dot = f32[32,16]{1,0} dot(lhs, rhs),
+      algorithm=dot_bf16_bf16_f32,
+      lhs_contracting_dims={1},
+      rhs_contracting_dims={0}
+}
+)";
+  const std::string pattern =
+      R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
+  TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
+  EXPECT_TRUE(ok);
+}

// This test could be modified to allow TF32 once this bug is fixed.
// TODO(b/320659359) Allow TF32 for 8-bit or less types with F32.
TEST_F(TritonTest, NoTF32For8BitOrLessWithF32) {
6 changes: 4 additions & 2 deletions xla/service/gpu/fusions/triton/triton_support_legacy.cc
@@ -20,6 +20,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/log/check.h"
#include "absl/strings/str_format.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
@@ -274,8 +275,9 @@ CodegenDecision CanTritonHandleGEMM(
} else {
if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(),
gpu_version)) {
-      return CodegenDecision::Forbid(
-          "Unsupported algorithm on the current device(s).");
+      return CodegenDecision::Forbid(absl::StrFormat(
+          "Unsupported algorithm on the current device(s): %s",
+          PrecisionConfig::Algorithm_Name(dot.precision_config().algorithm())));
}
}

7 changes: 4 additions & 3 deletions xla/service/gpu/transforms/gemm_fusion.cc
@@ -740,6 +740,7 @@ absl::StatusOr<Decision> CreateDotFusion(
dot.precision_config().algorithm();
if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 ||
+      algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32 ||
dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() ||
dot.sparse_operands()) {
return Decision::Allow();
@@ -757,9 +758,9 @@
}
return absl::OkStatus();
});
-  if (is_pure_matmul) {
-    return Decision::NotProfitable("Pure Matmul");
-  }
+
+  if (is_pure_matmul) return Decision::NotProfitable("Pure Matmul");
+
return Decision::Allow();
}
