diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c0b84f53c..98d7033fb 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4286,8 +4286,9 @@ def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.con actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) k = kernel[(1, )](input, actual, shape[0], shape[1]) - assert k.asm['ttgir'].count( - 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + if not is_xpu(): + assert k.asm['ttgir'].count( + 'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) diff --git a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir b/test/TritonIntelGPU/rewrite-tensor-pointer.mlir index 761c82717..596f29fea 100644 --- a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir +++ b/test/TritonIntelGPU/rewrite-tensor-pointer.mlir @@ -44,10 +44,10 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > %23:3 = scf.for %arg10 = %c0_i32 to %arg6 step %c32_i32 iter_args(%arg11 = %cst, %arg12 = %18, %arg13 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { - // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> - // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> - %28 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr> - %29 = tt.load %arg13 {boundaryCheck = array} : !tt.ptr> + // CHECK: tt.load {{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>> + // CHECK: tt.load {{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>> + %28 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + %29 = tt.load %arg13 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]> // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> @@ -59,8 +59,8 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa %25 = arith.extsi %arg9 : i32 to i64 // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : > %26 = tt.make_tensor_ptr %arg3, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array} : > - // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr> - %27 = tt.load %26 {boundaryCheck = array} : !tt.ptr> + // CHECK: tt.load {{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + %27 = tt.load %26 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> %28 = arith.addf %23#0, %27 : tensor<256x256xf32, #dpas> %29 = arith.truncf %28 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas> @@ -125,10 +125,10 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { - // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> - // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> - %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr> - %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr> + // CHECK: tt.load {{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>> + // CHECK: tt.load {{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>> + %28 = tt.load %arg11 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> + %29 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr> // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]> // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> @@ -335,3 +335,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 tt.return } } + +// ----- + +// COM: Case 5: +// COM: Check that a make tensor ptr with no loads is properly removed +// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + // CHECK: @matmul_kernel_with_block_pointers + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + // CHECK-NOT: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : >> + %11 = arith.muli %8, %c256_i32 : i32 + // CHECK-NOT: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array} : >> + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #dpas>, !tt.ptr>>, !tt.ptr>>) : i32 { + // CHECK-NOT: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> + // CHECK-NOT: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> + %19 = tt.advance %arg5, [%c0_i32, %c32_i32] : >> + %20 = tt.advance %arg6, [%c32_i32, %c0_i32] : >> + scf.yield %arg4, %19, %20 : tensor<256x256xf32, #dpas>, !tt.ptr>>, !tt.ptr>> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas> + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp index 8361675b5..9a0b5e4f9 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp @@ -71,6 +71,12 @@ struct TritonIntelGPUMaterializeBlockPointerPass return; } + if (fastChangeDim == rank - 2 && + tensorType.getElementTypeBitWidth() == 8) { + // TODO: column major layout w/ fp8 has performance regression + return; + } + if (fastChangeDim >= (rank - 2)) { // HW 2D block read instruction only supports contiguous access. Value fastChangeStride = strides[fastChangeDim]; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp index 801982320..0857ecba0 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp @@ -33,7 +33,8 @@ namespace { /// - it does not have Dpas layout or Dot layout (with Dpas layout as parent) /// - its pitch is not divisible by Qword bitwidth /// - it is not contiguous in memory -bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) { +bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByStoreOp, + const bool isUsedByBlockLoadOp) { LDBG("Considering removal of: " << op); if (!op->getParentOfType()->hasAttr( ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) { @@ -45,61 +46,19 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) { LDBG("Op ptr type: " << ptrType); auto tensorType = cast(ptrType.getPointeeType()); LDBG("Op tensor type: " << tensorType); - - if (!ttgi::hasDotDpasEncoding(tensorType) && - !(isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) { - LDBG("Marked for removal: tensor doesn't have DPAS layout and is not used " - "by load or store op with DPAS layout"); - return true; - } - - TypedValue base = op.getBase(); - Operation::operand_range shape = op.getShape(); - unsigned rank = shape.size(); - assert(rank > 1 && "Expecting tensor with rank > 1"); - Operation::operand_range strides = op.getStrides(); - Operation::operand_range offsets = op.getOffsets(); - ArrayRef order = op.getOrder(); - ArrayRef tensorShape = tensorType.getShape(); - - int fastChangeDim = -1; - for (size_t i = 0; i < strides.size(); ++i) { - if (ttgi::isConstant(strides[i], 1)) { - fastChangeDim = i; - break; - } - } - - LDBG("fastChangeDim: " << fastChangeDim); - if (fastChangeDim < 0) { - LDBG("Marked for removal: fast changing dimension not found"); - return true; - } - - LDBG("Tensor type element type bit width: " - << tensorType.getElementTypeBitWidth()); - if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) { - // TODO: column major layout w/ fp8 has performance regression - LDBG("Marked for removal: column major layout with fp8 element type"); - return true; - } - - // HW 2D block read instruction has restriction on pitch divisibility - if (fastChangeDim >= (rank - 2)) { - auto pitch = strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1]; - LDBG("Pitch: " << pitch); - // Across Intel platforms, the strictest pitch restriction is to be a - // multiple of OWord(128 bits). - if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) { - LDBG("Marked for removal: cannot use block read/write instructions"); - return true; - } - + LDBG("Used by store op? " << isUsedByStoreOp); + LDBG("Used by block load op? " << isUsedByBlockLoadOp); + + LDBG("hasDpasEncoding: " << ttgi::hasDpasEncoding(tensorType)); + if (isUsedByBlockLoadOp || + (isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType))) { + LDBG("Tensor has DPAS layout or is used by load/store op with DPAS layout, " + "skipping removal"); return false; } - LDBG("Marked for removal: fall-trough"); - + LDBG("Marked for removal: make tensor ptr op is not used by block load op or " + "by store op with DPAS layout"); return true; } @@ -715,28 +674,73 @@ class TritonIntelGPURewriteTensorPointerPass void runOnOperation() override { ModuleOp mod = getOperation(); - auto usedByLoadOrStoreOp = [](Value val) { - return llvm::any_of(val.getUsers(), [](Operation *user) { - return isa(user); - }); - }; + DenseSet tensorPointersToRemove; + mod.walk([&](tt::MakeTensorPtrOp makeTensorPtrOp) { + tensorPointersToRemove.insert(makeTensorPtrOp); + DenseSet workingSet; - auto markTensorPointerForRemoval = - [this](Value val, bool isUsedByLoadOrStoreOp = false) { - if (tt::isTensorPointerType(val.getType())) { - tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val); - if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp)) - valueToRemove.insert(val); + LDBG("Considering: " << makeTensorPtrOp); + Value result = makeTensorPtrOp.getResult(); + for (auto user : result.getUsers()) { + workingSet.insert(user); + } + while (!workingSet.empty()) { + auto crtOpItr = workingSet.begin(); + auto crtOp = *crtOpItr; + LDBG("Processing op: " << *crtOp); + if (isa(crtOp)) { + if (!shouldRemove( + makeTensorPtrOp, + /*isUsedByStoreOp=*/isa(crtOp), + /*isBlockLoad=*/ + isa(crtOp) && + crtOp->hasAttr( + ttgi::TritonIntelGPUDialect::getBlockIOAttrName()))) { + tensorPointersToRemove.erase(makeTensorPtrOp); + return WalkResult::advance(); } - }; + } else if (auto forOp = dyn_cast(crtOp)) { + for (auto [arg, blockArg] : + llvm::zip(forOp.getInitArgs(), + forOp.getBody()->getArguments().drop_front( + forOp.getNumInductionVars()))) { + if (arg == makeTensorPtrOp) { + // add users of block arg + for (auto user : blockArg.getUsers()) { + workingSet.insert(user); + } + } + } + } else if (crtOp->getNumResults() > 0) { + // TODO: should we handle more than one result? + auto crtOpResult = crtOp->getResult(0); + LDBG("Not a load store and not a loop, adding users to working " + "set."); + for (auto user : crtOpResult.getUsers()) { + workingSet.insert(user); + } + } + workingSet.erase(crtOpItr); + } + return WalkResult::advance(); + }); + + auto markTensorPointerForRemoval = [this, + &tensorPointersToRemove](Value val) { + if (tt::isTensorPointerType(val.getType())) { + tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val); + if (tensorPointersToRemove.count(makeTensorPtrOp)) { + valueToRemove.insert(val); + } + } + }; mod.walk([&](Operation *op) { if (isa(op)) { Value result = op->getResult(0); - markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result)); + markTensorPointerForRemoval(result); } else if (isa(op)) { - markTensorPointerForRemoval(op->getOperand(0), - isa(op)); + markTensorPointerForRemoval(op->getOperand(0)); } else if (auto forOp = dyn_cast(op)) { for (auto arg : forOp.getInitArgs()) markTensorPointerForRemoval(arg); @@ -752,7 +756,7 @@ class TritonIntelGPURewriteTensorPointerPass else { DBGS() << "Values to remove: "; for (auto val : valueToRemove) - DBGS() << val; + DBGS() << val << "\n"; } });