Skip to content

Commit

Permalink
Use block load attribute to remove duplicate logic from MaterializeBl…
Browse files Browse the repository at this point in the history
…ockPointer pass (#2420)

The logic in `shouldRemove` in the `RewriteTensorPointer` pass
duplicates the same logic in `MaterializeBlockPointer`:
https://github.com/intel/intel-xpu-backend-for-triton/blob/main/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp#L50

This duplication is necessary for the Matrix transpose multiplication
case because the block pointer is defined outside a `scf.for` loop, but
the load is inside the loop. The previous logic in
`RewriteTensorPointer` could not "see" into the `scf.for` loop block and
decided to remove the tensor pointer even though its result was used by
a block load. 

This commit changes the algorithm:

First, walk the tree and look for MakeTensorPtr ops. For each MakeTensorPtr op, we do a search to find load/store users of the op. If we have a store associated with DPAS layout, or a block load, then we do not mark the MakeTensorPtr op for removal. Otherwise, we mark it for removal.

Next, we make a pass through all the ops again and make sure we removal all MakeTensorPtr-related ops for each MakeTensorPtr marked for removal (tt.advance, rewrite the loads, etc).
 
Close #2380
  • Loading branch information
alexbaden authored Oct 10, 2024
1 parent a2a3100 commit 734e33c
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 81 deletions.
5 changes: 3 additions & 2 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4286,8 +4286,9 @@ def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.con
actual = torch.zeros(expected.shape, dtype=torch.int32, device=device)

k = kernel[(1, )](input, actual, shape[0], shape[1])
assert k.asm['ttgir'].count(
'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization"
if not is_xpu():
assert k.asm['ttgir'].count(
'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization"

np.testing.assert_equal(to_numpy(expected), to_numpy(actual))

Expand Down
68 changes: 58 additions & 10 deletions test/TritonIntelGPU/rewrite-tensor-pointer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #dot1>>
%23:3 = scf.for %arg10 = %c0_i32 to %arg6 step %c32_i32 iter_args(%arg11 = %cst, %arg12 = %18, %arg13 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%28 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg13 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%28 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg13 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #dot1>>
// CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
Expand All @@ -59,8 +59,8 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
%25 = arith.extsi %arg9 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #[[DPAS]]>>
%26 = tt.make_tensor_ptr %arg3, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #dpas>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
%27 = tt.load %26 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #dpas>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
%27 = tt.load %26 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #dpas>>
%28 = arith.addf %23#0, %27 : tensor<256x256xf32, #dpas>
%29 = arith.truncf %28 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>

Expand Down Expand Up @@ -125,10 +125,10 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #dot1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #dot1>>
// CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
Expand Down Expand Up @@ -335,3 +335,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
tt.return
}
}

// -----

// COM: Case 5:
// COM: Check that a make tensor ptr with no loads is properly removed
// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
// CHECK: @matmul_kernel_with_block_pointers
%c4_i32 = arith.constant 4 : i32
%c256_i32 = arith.constant 256 : i32
%c1024_i64 = arith.constant 1024 : i64
%c5120_i64 = arith.constant 5120 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c4096_i64 = arith.constant 4096 : i64
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c5120_i32 = arith.constant 5120 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
%0 = tt.get_program_id x : i32
%1 = arith.divsi %0, %c64_i32 : i32
%2 = arith.muli %1, %c4_i32 : i32
%3 = arith.subi %c4_i32, %2 : i32
%4 = arith.minsi %3, %c4_i32 : i32
%5 = arith.remsi %0, %4 : i32
%6 = arith.addi %2, %5 : i32
%7 = arith.remsi %0, %c64_i32 : i32
%8 = arith.divsi %7, %4 : i32
%9 = arith.muli %6, %c256_i32 : i32
// CHECK-NOT: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
%10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
%11 = arith.muli %8, %c256_i32 : i32
// CHECK-NOT: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
%13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>) : i32 {
// CHECK-NOT: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK-NOT: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%19 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
%20 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
scf.yield %arg4, %19, %20 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
}
%14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
%15 = arith.truncf %13#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
tt.return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ struct TritonIntelGPUMaterializeBlockPointerPass
return;
}

if (fastChangeDim == rank - 2 &&
tensorType.getElementTypeBitWidth() == 8) {
// TODO: column major layout w/ fp8 has performance regression
return;
}

if (fastChangeDim >= (rank - 2)) {
// HW 2D block read instruction only supports contiguous access.
Value fastChangeStride = strides[fastChangeDim];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ namespace {
/// - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
/// - its pitch is not divisible by Qword bitwidth
/// - it is not contiguous in memory
bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByStoreOp,
const bool isUsedByBlockLoadOp) {
LDBG("Considering removal of: " << op);
if (!op->getParentOfType<ModuleOp>()->hasAttr(
ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
Expand All @@ -45,61 +46,19 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
LDBG("Op ptr type: " << ptrType);
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
LDBG("Op tensor type: " << tensorType);

if (!ttgi::hasDotDpasEncoding(tensorType) &&
!(isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) {
LDBG("Marked for removal: tensor doesn't have DPAS layout and is not used "
"by load or store op with DPAS layout");
return true;
}

TypedValue<triton::PointerType> base = op.getBase();
Operation::operand_range shape = op.getShape();
unsigned rank = shape.size();
assert(rank > 1 && "Expecting tensor with rank > 1");
Operation::operand_range strides = op.getStrides();
Operation::operand_range offsets = op.getOffsets();
ArrayRef<int32_t> order = op.getOrder();
ArrayRef<int64_t> tensorShape = tensorType.getShape();

int fastChangeDim = -1;
for (size_t i = 0; i < strides.size(); ++i) {
if (ttgi::isConstant(strides[i], 1)) {
fastChangeDim = i;
break;
}
}

LDBG("fastChangeDim: " << fastChangeDim);
if (fastChangeDim < 0) {
LDBG("Marked for removal: fast changing dimension not found");
return true;
}

LDBG("Tensor type element type bit width: "
<< tensorType.getElementTypeBitWidth());
if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) {
// TODO: column major layout w/ fp8 has performance regression
LDBG("Marked for removal: column major layout with fp8 element type");
return true;
}

// HW 2D block read instruction has restriction on pitch divisibility
if (fastChangeDim >= (rank - 2)) {
auto pitch = strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
LDBG("Pitch: " << pitch);
// Across Intel platforms, the strictest pitch restriction is to be a
// multiple of OWord(128 bits).
if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) {
LDBG("Marked for removal: cannot use block read/write instructions");
return true;
}

LDBG("Used by store op? " << isUsedByStoreOp);
LDBG("Used by block load op? " << isUsedByBlockLoadOp);

LDBG("hasDpasEncoding: " << ttgi::hasDpasEncoding(tensorType));
if (isUsedByBlockLoadOp ||
(isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType))) {
LDBG("Tensor has DPAS layout or is used by load/store op with DPAS layout, "
"skipping removal");
return false;
}

LDBG("Marked for removal: fall-trough");

LDBG("Marked for removal: make tensor ptr op is not used by block load op or "
"by store op with DPAS layout");
return true;
}

Expand Down Expand Up @@ -715,28 +674,73 @@ class TritonIntelGPURewriteTensorPointerPass
void runOnOperation() override {
ModuleOp mod = getOperation();

auto usedByLoadOrStoreOp = [](Value val) {
return llvm::any_of(val.getUsers(), [](Operation *user) {
return isa<tt::LoadOp, tt::StoreOp>(user);
});
};
DenseSet<Operation *> tensorPointersToRemove;
mod.walk([&](tt::MakeTensorPtrOp makeTensorPtrOp) {
tensorPointersToRemove.insert(makeTensorPtrOp);
DenseSet<Operation *> workingSet;

auto markTensorPointerForRemoval =
[this](Value val, bool isUsedByLoadOrStoreOp = false) {
if (tt::isTensorPointerType(val.getType())) {
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp))
valueToRemove.insert(val);
LDBG("Considering: " << makeTensorPtrOp);
Value result = makeTensorPtrOp.getResult();
for (auto user : result.getUsers()) {
workingSet.insert(user);
}
while (!workingSet.empty()) {
auto crtOpItr = workingSet.begin();
auto crtOp = *crtOpItr;
LDBG("Processing op: " << *crtOp);
if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
if (!shouldRemove(
makeTensorPtrOp,
/*isUsedByStoreOp=*/isa<tt::StoreOp>(crtOp),
/*isBlockLoad=*/
isa<tt::LoadOp>(crtOp) &&
crtOp->hasAttr(
ttgi::TritonIntelGPUDialect::getBlockIOAttrName()))) {
tensorPointersToRemove.erase(makeTensorPtrOp);
return WalkResult::advance();
}
};
} else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
for (auto [arg, blockArg] :
llvm::zip(forOp.getInitArgs(),
forOp.getBody()->getArguments().drop_front(
forOp.getNumInductionVars()))) {
if (arg == makeTensorPtrOp) {
// add users of block arg
for (auto user : blockArg.getUsers()) {
workingSet.insert(user);
}
}
}
} else if (crtOp->getNumResults() > 0) {
// TODO: should we handle more than one result?
auto crtOpResult = crtOp->getResult(0);
LDBG("Not a load store and not a loop, adding users to working "
"set.");
for (auto user : crtOpResult.getUsers()) {
workingSet.insert(user);
}
}
workingSet.erase(crtOpItr);
}
return WalkResult::advance();
});

auto markTensorPointerForRemoval = [this,
&tensorPointersToRemove](Value val) {
if (tt::isTensorPointerType(val.getType())) {
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
if (tensorPointersToRemove.count(makeTensorPtrOp)) {
valueToRemove.insert(val);
}
}
};

mod.walk([&](Operation *op) {
if (isa<tt::MakeTensorPtrOp>(op)) {
Value result = op->getResult(0);
markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result));
markTensorPointerForRemoval(result);
} else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
markTensorPointerForRemoval(op->getOperand(0),
isa<tt::LoadOp, tt::StoreOp>(op));
markTensorPointerForRemoval(op->getOperand(0));
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
for (auto arg : forOp.getInitArgs())
markTensorPointerForRemoval(arg);
Expand All @@ -752,7 +756,7 @@ class TritonIntelGPURewriteTensorPointerPass
else {
DBGS() << "Values to remove: ";
for (auto val : valueToRemove)
DBGS() << val;
DBGS() << val << "\n";
}
});

Expand Down

0 comments on commit 734e33c

Please sign in to comment.