[BACKEND] Extend hoisting of convert op above ext ops (#2206)
Handle more cases of hoisting a convert above an ext op. If there are
multiple ext ops in the slice but only one of them requires inserting a
convert, we can still apply the optimization.
ThomasRaoux authored Aug 30, 2023
1 parent 9f3b631 commit 2ff88c1
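
For illustration, here is a minimal sketch of the kind of IR this commit now covers, adapted from the hoist_above_ext2 test added below (the function name and the #layout0/#layout1 encodings are placeholders, as in the test file). Two ext ops feed the same add, but only the ext of %arg0 needs a convert hoisted above it; the slice behind the other ext (a splat) can simply be rematerialized in the target layout.

tt.func @two_ext_one_convert(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tensor<1024xf32, #layout1> {
  %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0>
  %1 = tt.splat %arg1 : (f16) -> tensor<1024xf16, #layout0>
  %2 = arith.extf %1 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0>
  %3 = arith.addf %0, %2 : tensor<1024xf32, #layout0>
  %4 = triton_gpu.convert_layout %3 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1>
  tt.return %4 : tensor<1024xf32, #layout1>
}

After the pass, the convert_layout is expected to sit directly on %arg0 before the first arith.extf, and no second convert is needed for the splat-fed ext; previously a second ext op in the slice would have blocked the hoist entirely.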
Showing 2 changed files with 74 additions and 33 deletions.
lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp (55 additions, 30 deletions)
@@ -748,6 +748,25 @@ static void rewriteSlice(SetVector<Value> &slice,
rewriteSlice(slice, layout, convertOp, mapping);
}

+static LogicalResult getRematerializableSlice(
+    Value root, Attribute rootEncoding, SetVector<Value> &slice,
+    DenseMap<Value, Attribute> &layout,
+    std::function<bool(Operation *)> stopPropagation = nullptr) {
+  LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding,
+                                                 layout, stopPropagation);
+  if (result.failed() || slice.empty())
+    return failure();
+
+  // Check if all the operations in the slice can be rematerialized.
+  for (Value v : slice) {
+    if (Operation *op = v.getDefiningOp()) {
+      if (!canBeRemat(op))
+        return failure();
+    }
+  }
+  return success();
+}
+
static void backwardRematerialization(ConvertLayoutOp convertOp) {
// we don't want to rematerialize any conversion to/from shared
if (triton::gpu::isSharedEncoding(convertOp.getResult()) ||
@@ -759,22 +778,16 @@ static void backwardRematerialization(ConvertLayoutOp convertOp) {
if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return;

-  // 1. Take a backward slice of all the tensor dependencies.
+  // 1. Take a backward slice of all the tensor dependencies that can be
+  // rematerialized.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
-  LogicalResult result = getConvertBackwardSlice(
-      convertOp.getOperand(), slice, targetType.getEncoding(), layout);
-  if (result.failed() || slice.empty())
+  LogicalResult result = getRematerializableSlice(
+      convertOp.getOperand(), targetType.getEncoding(), slice, layout);
+  if (result.failed())
return;

-  // 2. Check if all the operations in the slice can be rematerialized.
-  for (Value v : slice) {
-    if (Operation *op = v.getDefiningOp()) {
-      if (!canBeRemat(op))
-        return;
-    }
-  }
-  // 3. Rewrite the slice.
+  // 2. Rewrite the slice.
rewriteSlice(slice, layout, convertOp);
}

@@ -791,32 +804,44 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) {
if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return;

-  // 1. Take a backward slice of all the tensor dependencies.
-  SetVector<Value> slice;
-  DenseMap<Value, Attribute> layout;
auto isExtOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op);
};
-  // Get a backward slice but don't go past ext ops
-  LogicalResult result = getConvertBackwardSlice(
-      convertOp.getOperand(), slice, targetType.getEncoding(), layout, isExtOp);
-  if (result.failed() || slice.empty())
+  // 1. Take a backward slice of all the tensor dependencies.
+  SetVector<Value> slice;
+  DenseMap<Value, Attribute> layout;
+  LogicalResult result = getRematerializableSlice(
+      convertOp.getOperand(), targetType.getEncoding(), slice, layout, isExtOp);
+  if (result.failed())
return;

Operation *extOp = nullptr;
-  // 2. Check if all the operations in the slice can be rematerialized.
-  for (Value v : slice) {
-    if (Operation *op = v.getDefiningOp()) {
-      if (!canBeRemat(op))
-        return;
-      if (isExtOp(op)) {
-        // Only apply it if there is a single ext op otherwise we would have to
-        // duplicate the convert.
-        if (extOp != nullptr)
-          return;
-        extOp = op;
+  unsigned sliceSize = slice.size();
+  for (unsigned i = 0; i < sliceSize; i++) {
+    Value v = slice[i];
+    Operation *op = v.getDefiningOp();
+    if (!op)
+      continue;
+    if (isExtOp(op)) {
+      SetVector<Value> tempSlice;
+      DenseMap<Value, Attribute> tempLayout;
+      LogicalResult result = getRematerializableSlice(
+          op->getOperand(0), layout[v], tempSlice, tempLayout);
+      // If we can rematerialize the rest of the ext slice we can ignore this
+      // ext as it won't need a convert.
+      if (result.succeeded()) {
+        slice.insert(tempSlice.begin(), tempSlice.end());
+        layout.insert(tempLayout.begin(), tempLayout.end());
+        continue;
+      }
+      // Only apply it if there is a single ext op otherwise we would have to
+      // duplicate the convert.
+      if (extOp != nullptr)
+        return;
+      extOp = op;
+    }
}

if (extOp == nullptr)
return;
// Move the convert before the ext op and rewrite the slice.
test/TritonGPU/combine.mlir (19 additions, 3 deletions)
@@ -85,12 +85,28 @@ tt.func @hoist_above_ext(%arg0: tensor<1024xf16, #layout0>, %arg1: f32) -> tensor<1024xf32, #layout1> {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: tt.return
%0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0>
-  %1 = tt.splat %arg1 : (f32) -> tensor<1024xf32, #layout1>
-  %2 = triton_gpu.convert_layout %0 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1>
-  %3 = arith.addf %1, %2 : tensor<1024xf32, #layout1>
+  %1 = tt.splat %arg1 : (f32) -> tensor<1024xf32, #layout0>
+  %2 = arith.addf %0, %1 : tensor<1024xf32, #layout0>
+  %3 = triton_gpu.convert_layout %2 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1>
tt.return %3 : tensor<1024xf32, #layout1>
}

+// CHECK-LABEL: hoist_above_ext2
+tt.func @hoist_above_ext2(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tensor<1024xf32, #layout1> {
+  // CHECK: %[[CVT:.+]] = triton_gpu.convert_layout
+  // CHECK: arith.extf %[[CVT]]
+  // CHECK-NOT: triton_gpu.convert_layout
+  // CHECK: tt.return
+  %0 = arith.extf %arg0 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0>
+  %1 = tt.splat %arg1 : (f16) -> tensor<1024xf16, #layout0>
+  %2 = arith.extf %1 : tensor<1024xf16, #layout0> to tensor<1024xf32, #layout0>
+  %3 = arith.addf %0, %2 : tensor<1024xf32, #layout0>
+  %4 = triton_gpu.convert_layout %3 : (tensor<1024xf32, #layout0>) -> tensor<1024xf32, #layout1>
+  tt.return %4 : tensor<1024xf32, #layout1>
+}



// CHECK-LABEL: if
tt.func @if(%arg0: i32, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
// CHECK-NOT: triton_gpu.convert_layout
