Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify pass pipeline to allow lowering tt.load to 2DBlockRead #1061

Merged
merged 11 commits into from
May 10, 2024
8 changes: 3 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Value ptr = getMemAccessPtr(curr);
if (!ptr)
return;
// We only convert `tensor<tt.ptr<>>` or `tt.ptr<tensor<>>` load/store
bool isPtrTensor = false, isTensorPointer = false;
// We only convert `tensor<tt.ptr<>>` load/store
etiotto marked this conversation as resolved.
Show resolved Hide resolved
bool isPtrTensor = false;
if (auto tensorType = dyn_cast<RankedTensorType>(ptr.getType()))
isPtrTensor = isa<PointerType>(tensorType.getElementType());
if (auto ptrType = dyn_cast<PointerType>(ptr.getType()))
isTensorPointer = isa<RankedTensorType>(ptrType.getPointeeType());
if (!isPtrTensor && !isTensorPointer)
if (!isPtrTensor)
return;
auto mod = curr->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
Expand Down
6 changes: 4 additions & 2 deletions python/test/unit/operators/test_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device):
pytest.xfail("Flash attention bfloat16 not supported in interpreter mode")

if device == "xpu":
if D_HEAD != 16:
pytest.skip("FIXME: Enable larger problem sizes when tl.dot uses DPAS")
if D_HEAD > 32:
pytest.skip("FIXME: results precision issue")

# Pytorch does not support Half data type for matmul operation hence the skip
if device == 'cpu':
Expand Down Expand Up @@ -61,6 +61,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device):
tri_dq, q.grad = q.grad.clone(), None
# compare
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
if device == "xpu" and dtype != torch.bfloat16:
atol = 7e-2
torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_out), dim=0),
torch.nn.functional.normalize(torch.flatten(tri_out), dim=0), atol=atol, rtol=0)
torch.testing.assert_close(torch.nn.functional.normalize(torch.flatten(ref_dv), dim=0),
Expand Down
26 changes: 5 additions & 21 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,6 @@ class XPUBackend(BaseBackend):
# Experimental pass pipeline for kernels using block pointers.
class Experimental:

@staticmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Now the experimental pass has the same implementation as the default path for make_ttir.

def make_ttir(mod, metadata, opt):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.common.add_inliner(pm)
passes.ttir.add_combine(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_reorder_broadcast(pm)
passes.common.add_cse(pm)
passes.common.add_licm(pm)
passes.common.add_symbol_dce(pm)
pm.run(mod)
return mod

@staticmethod
def make_ttgir(mod, metadata, opt, device_arch):
pm = ir.pass_manager(mod.context)
Expand Down Expand Up @@ -150,13 +136,9 @@ def load_dialects(self, ctx):

@staticmethod
def make_ttir(mod, metadata, opt):
if XPUOptions.isBlockPtrEnabled:
return XPUBackend.Experimental.make_ttir(mod, metadata, opt)

pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.common.add_inliner(pm)
passes.ttir.add_rewrite_tensor_pointer(pm)
passes.ttir.add_combine(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_reorder_broadcast(pm)
Expand All @@ -175,13 +157,15 @@ def make_ttgir(mod, metadata, opt, device_arch):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.ttir.add_convert_to_ttgpuir(pm, f"xpu:{device_arch}", opt.num_warps, opt.threads_per_warp, opt.num_ctas)

# optimize TTGIR
passes.ttgpuir.add_coalesce(pm)
intel.passes.ttgpuir.add_accelerate_matmul(pm, device_arch)
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm, device_arch)

intel.passes.ttgpuir.add_accelerate_matmul(pm, device_arch)
passes.ttgpuir.add_coalesce(pm)
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
passes.common.add_cse(pm)
passes.ttgpuir.add_prefetch(pm)
Expand Down
162 changes: 160 additions & 2 deletions third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
etiotto marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,82 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
}
}

static SmallVector<SmallVector<unsigned>>
emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
RankedTensorType type) {
auto dpasLayout = dotLayout.getParent().dyn_cast<DpasEncodingAttr>();
if (!dpasLayout) {
llvm::errs() << "dotLayout: " << dotLayout << "\n";
llvm_unreachable("unsupported parent layout in emitOffsetForDotOpLayout");
}

ArrayRef<int64_t> shape = type.getShape();
SmallVector<SmallVector<unsigned>> offsets;
SmallVector<int64_t> shapePerCTA = triton::gpu::getShapePerCTA(type);

unsigned opIdx = dotLayout.getOpIdx();
SmallVector<int64_t> numReps =
dpasLayout.getDPASRepetitions(shapePerCTA, opIdx);
SmallVector<unsigned> warpShape =
(opIdx == 0) ? dpasLayout.getShapeA() : dpasLayout.getShapeB();

unsigned warpSize = triton::gpu::getWarpSize(dpasLayout);
unsigned numElemPerInstPerThread = product<unsigned>(warpShape) / warpSize;

unsigned systolicDepth = dpasLayout.getSystolicDepth();
unsigned repeatCount = dpasLayout.getRepeatCount();
unsigned executionSize = dpasLayout.getExecutionSize();
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();

unsigned rowsPerWarp, numElemPerInstPerRowPerThread;
switch (opIdx) {
case 0: {
assert((opsPerChannel == 1 || opsPerChannel == 2 || opsPerChannel == 4) &&
"invalid opsPerChannel number.");
SmallVector<unsigned> shapeA = dpasLayout.getShapeA();
// Unlike the operand B, to pack the value to i16 for scalar bit width
// <=16.
unsigned packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
unsigned packedColNum = shapeA[1] / packedOpsPerLane;
if (warpSize < packedColNum)
llvm::report_fatal_error(
"DpasEncodingAttr sub-group size could not "
"be smaller than the threads required per row for A operand.");

rowsPerWarp = warpSize / packedColNum;
numElemPerInstPerRowPerThread = packedOpsPerLane;
} break;
case 1: {
if (warpSize < executionSize)
llvm::report_fatal_error(
"DpasEncodingAttr sub-group size could not "
"be smaller than the execution size for B operand.");

rowsPerWarp = warpSize / executionSize;
rowsPerWarp = rowsPerWarp * opsPerChannel;
numElemPerInstPerRowPerThread = 1;
} break;
}

SmallVector<unsigned> shapePerCTATile =
triton::gpu::getShapePerCTATile(dotLayout);
int64_t numRepOuter = numReps[opIdx];
int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
for (int dimOuter = 0; dimOuter < numRepOuter; ++dimOuter)
for (int k = 0; k < numRepK; ++k)
for (unsigned elemId = 0; elemId < numElemPerInstPerThread; ++elemId) {
uint32_t repRowIndex = shapePerCTATile[0] * (opIdx == 0 ? dimOuter : k);
uint32_t repColIndex = shapePerCTATile[1] * (opIdx == 0 ? k : dimOuter);
uint32_t elemRowIndex =
(elemId / numElemPerInstPerRowPerThread) * rowsPerWarp;
uint32_t elemColIndex = elemId % numElemPerInstPerRowPerThread;
offsets.push_back(
{repRowIndex + elemRowIndex, repColIndex + elemColIndex});
}

return offsets;
}

static SmallVector<SmallVector<unsigned>>
emitOffsetForDpasLayout(const DpasEncodingAttr &dpasLayout,
RankedTensorType type) {
Expand All @@ -186,6 +262,85 @@ emitOffsetForDpasLayout(const DpasEncodingAttr &dpasLayout,
// -----------------------------------------------------------------------
// Dpas layout indices
// -----------------------------------------------------------------------
static SmallVector<Value>
emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
const DotOperandEncodingAttr &dotLayout,
RankedTensorType type) {
auto dpasLayout = dotLayout.getParent().dyn_cast<DpasEncodingAttr>();
if (!dpasLayout) {
llvm::errs() << "dotLayout: " << dotLayout << "\n";
llvm_unreachable(
"unsupported parent layout in emitBaseIndexForDotOpLayout");
}

Value threadId = getThreadId(rewriter, loc);
unsigned warpSize = triton::gpu::getWarpSize(dpasLayout);
Value warpId = udiv(threadId, i32_val(warpSize));
Value laneId = urem(threadId, i32_val(warpSize));

const SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
SmallVector<unsigned> order = triton::gpu::getOrder(dpasLayout);
SmallVector<int64_t> shapePerCTA = triton::gpu::getShapePerCTA(type);

unsigned opIdx = dotLayout.getOpIdx();
SmallVector<unsigned> warpShape =
(opIdx == 0) ? dpasLayout.getShapeA() : dpasLayout.getShapeB();
SmallVector<int64_t> numReps =
dpasLayout.getDPASRepetitions(shapePerCTA, opIdx);
SmallVector<Value> multiDimWarpId =
mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order);

Value rowWarpId =
urem(multiDimWarpId[0],
i32_val(mlir::ceil<unsigned>(shapePerCTA[0], warpShape[0])));
Value colWarpId =
urem(multiDimWarpId[1],
i32_val(mlir::ceil<unsigned>(shapePerCTA[1], warpShape[1])));
Value rowWarpOffset = mul(rowWarpId, i32_val(warpShape[0]));
Value colWarpOffset = mul(colWarpId, i32_val(warpShape[1]));

// Compute the 2-dim coordinates of the first element in the warp operated
// own by this thread.
unsigned systolicDepth = dpasLayout.getSystolicDepth();
unsigned repeatCount = dpasLayout.getRepeatCount();
unsigned executionSize = dpasLayout.getExecutionSize();
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();

Value laneRowIndex, laneColIndex;
switch (opIdx) {
case 0: {
assert((opsPerChannel == 1 || opsPerChannel == 2 || opsPerChannel == 4) &&
"invalid opsPerChannel number.");
SmallVector<unsigned> shapeA = dpasLayout.getShapeA();
// Unlike the operand B, to pack the value to i16 for scalar bit width
// <=16.
unsigned packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
unsigned packedColNum = shapeA[1] / packedOpsPerLane;
if (warpSize < packedColNum)
llvm::report_fatal_error(
"DpasEncodingAttr sub-group size could not "
"be smaller than the threads required per row for A operand.");

laneRowIndex = udiv(laneId, i32_val(packedColNum));
laneColIndex = urem(laneId, i32_val(packedColNum));
laneColIndex = mul(laneColIndex, i32_val(packedOpsPerLane));
} break;
case 1: {
if (warpSize < executionSize)
llvm::report_fatal_error(
"DpasEncodingAttr sub-group size could not "
"be smaller than the execution size for B operand.");

laneRowIndex = udiv(laneId, i32_val(executionSize));
laneRowIndex = mul(laneRowIndex, i32_val(opsPerChannel));
laneColIndex = urem(laneId, i32_val(executionSize));
} break;
}

SmallVector<Value> multiDimBase = {add(laneRowIndex, rowWarpOffset),
add(laneColIndex, colWarpOffset)};
return multiDimBase;
}

static SmallVector<Value>
emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
Expand Down Expand Up @@ -279,6 +434,8 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
result.erase(result.begin() + sliceLayout.getDim());
// CTAOffset has been added in emitBaseIndexForLayout of parentLayout
return result;
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
result = emitBaseIndexForDotOpLayout(loc, rewriter, dotLayout, type);
} else {
return mlir::emitBaseIndexForLayoutImpl(loc, rewriter, target, layout, type,
withCTAOffset);
Expand Down Expand Up @@ -324,9 +481,10 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,

inline SmallVector<SmallVector<unsigned>>
emitOffsetForLayout(Attribute layout, RankedTensorType type) {
if (auto dpasLayout = layout.dyn_cast<DpasEncodingAttr>()) {
if (auto dpasLayout = layout.dyn_cast<DpasEncodingAttr>())
return emitOffsetForDpasLayout(dpasLayout, type);
}
if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
return emitOffsetForDotOpLayout(dotLayout, type);
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
return ::intel::emitOffsetForSliceLayout(sliceLayout, type);
return mlir::emitOffsetForLayout(layout, type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"

#include "llvm/Support/Debug.h"

#include <stack>

using namespace mlir;
Expand All @@ -16,6 +18,8 @@ namespace mlir::triton::gpu::intel {
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
} // namespace mlir::triton::gpu::intel

#define DEBUG_TYPE "tritonintelgpu-rewrite-tensor-pointer"

namespace {

/// Check if given value is divisible by the divisor.
Expand Down Expand Up @@ -346,6 +350,7 @@ class TritonIntelGPURewriteTensorPointerPass
std::stack<Operation *> &eraser) {
if (!valueToRemove.count(op.getResult()))
return nullptr;

// Save info for later use
auto ptrType = cast<tt::PointerType>(op.getType());
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
Expand Down Expand Up @@ -734,6 +739,16 @@ class TritonIntelGPURewriteTensorPointerPass
}
});

LLVM_DEBUG({
if (valueToRemove.empty())
llvm::dbgs() << "No tensor pointer to remove\n";
else {
llvm::dbgs() << "Values to remove: \n";
for (auto val : valueToRemove)
llvm::dbgs() << val << "\n";
}
});

// NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because
// MLIR does not support one-multiple value mapping. For example, if we use
// `ConversionPatternRewriter`, we can not make a type converter, which
Expand Down