From 2ff88c1368de06e7458cf7ebe99aa6d2c7be4ec6 Mon Sep 17 00:00:00 2001 From: Thomas Date: Tue, 29 Aug 2023 17:36:34 -0700 Subject: [PATCH] [BACKEND] Extend hoisting of convert op above ext ops (#2206) Handle more cases of hoisting convert above ext op. If there are multiple ext op in the slice but only one requires inserting a convert we can still apply the optimization. --- .../Transforms/RemoveLayoutConversions.cpp | 85 ++++++++++++------- test/TritonGPU/combine.mlir | 22 ++++- 2 files changed, 74 insertions(+), 33 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 2a38b03d2e14..894914e5f97f 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -748,6 +748,25 @@ static void rewriteSlice(SetVector &slice, rewriteSlice(slice, layout, convertOp, mapping); } +static LogicalResult getRematerializableSlice( + Value root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation = nullptr) { + LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding, + layout, stopPropagation); + if (result.failed() || slice.empty()) + return failure(); + + // Check if all the operations in the slice can be rematerialized. + for (Value v : slice) { + if (Operation *op = v.getDefiningOp()) { + if (!canBeRemat(op)) + return failure(); + } + } + return success(); +} + static void backwardRematerialization(ConvertLayoutOp convertOp) { // we don't want to rematerialize any conversion to/from shared if (triton::gpu::isSharedEncoding(convertOp.getResult()) || @@ -759,22 +778,16 @@ static void backwardRematerialization(ConvertLayoutOp convertOp) { if (targetType.getEncoding().isa()) return; - // 1. Take a backward slice of all the tensor dependencies. + // 1. Take a backward slice of all the tensor dependencies that can be + // rematerialized. SetVector slice; DenseMap layout; - LogicalResult result = getConvertBackwardSlice( - convertOp.getOperand(), slice, targetType.getEncoding(), layout); - if (result.failed() || slice.empty()) + LogicalResult result = getRematerializableSlice( + convertOp.getOperand(), targetType.getEncoding(), slice, layout); + if (result.failed()) return; - // 2. Check if all the operations in the slice can be rematerialized. - for (Value v : slice) { - if (Operation *op = v.getDefiningOp()) { - if (!canBeRemat(op)) - return; - } - } - // 3. Rewrite the slice. + // 2. Rewrite the slice. rewriteSlice(slice, layout, convertOp); } @@ -791,32 +804,44 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { if (targetType.getEncoding().isa()) return; - // 1. Take a backward slice of all the tensor dependencies. - SetVector slice; - DenseMap layout; auto isExtOp = [](Operation *op) { return isa(op); }; - // Get a backward slice but don't go past ext ops - LogicalResult result = getConvertBackwardSlice( - convertOp.getOperand(), slice, targetType.getEncoding(), layout, isExtOp); - if (result.failed() || slice.empty()) + // 1. Take a backward slice of all the tensor dependencies. + SetVector slice; + DenseMap layout; + LogicalResult result = getRematerializableSlice( + convertOp.getOperand(), targetType.getEncoding(), slice, layout, isExtOp); + if (result.failed()) return; + Operation *extOp = nullptr; - // 2. Check if all the operations in the slice can be rematerialized. - for (Value v : slice) { - if (Operation *op = v.getDefiningOp()) { - if (!canBeRemat(op)) - return; - if (isExtOp(op)) { - // Only apply it if there is a single ext op otherwise we would have to - // duplicate the convert. - if (extOp != nullptr) - return; - extOp = op; + unsigned sliceSize = slice.size(); + for (unsigned i = 0; i < sliceSize; i++) { + Value v = slice[i]; + Operation *op = v.getDefiningOp(); + if (!op) + continue; + if (isExtOp(op)) { + SetVector tempSlice; + DenseMap tempLayout; + LogicalResult result = getRematerializableSlice( + op->getOperand(0), layout[v], tempSlice, tempLayout); + // If we can rematerialize the rest of the ext slice we can ignore this + // ext as it won't need a convert. + if (result.succeeded()) { + slice.insert(tempSlice.begin(), tempSlice.end()); + layout.insert(tempLayout.begin(), tempLayout.end()); + continue; } + // Only apply it if there is a single ext op otherwise we would have to + // duplicate the convert. + if (extOp != nullptr) + return; + extOp = op; } } + if (extOp == nullptr) return; // Move the convert before the ext op and rewrite the slice. diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 92f33e0c257f..ab670e375775 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -85,12 +85,28 @@ tt.func @hoist_above_ext(%arg0: tensor<1024xf16, #layout0>, %arg1: f32) -> tenso // CHECK-NOT: triton_gpu.convert_layout // CHECK: tt.return %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> - %1 = tt.splat %arg1 : (f32) -> tensor<1024xf32, #layout1> - %2 = triton_gpu.convert_layout %0 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1> - %3 = arith.addf %1, %2 : tensor<1024xf32, #layout1> + %1 = tt.splat %arg1 : (f32) -> tensor<1024xf32, #layout0> + %2 = arith.addf %0, %1 : tensor<1024xf32, #layout0> + %3 = triton_gpu.convert_layout %2 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1> tt.return %3 : tensor<1024xf32, #layout1> } +// CHECK-LABEL: hoist_above_ext2 +tt.func @hoist_above_ext2(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tensor<1024xf32, #layout1> { +// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: arith.extf %[[CVT]] +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return + %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> + %1 = tt.splat %arg1 : (f16) -> tensor<1024xf16, #layout0> + %2 = arith.extf %1 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0> + %3 = arith.addf %0, %2 : tensor<1024xf32, #layout0> + %4 = triton_gpu.convert_layout %3 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1> + tt.return %4 : tensor<1024xf32, #layout1> +} + + + // CHECK-LABEL: if tt.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { // CHECK-NOT: triton_gpu.convert_layout