
Merge remote-tracking branch 'origin/main' into egor/gemm_trans
Egor-Krivov committed Oct 8, 2024
2 parents eae3600 + 2202ca7 commit 8cdd681
Showing 23 changed files with 569 additions and 209 deletions.
2 changes: 2 additions & 0 deletions include/triton/Analysis/Utility.h
@@ -203,6 +203,8 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool atomicNeedsSharedMemory(Value result);

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
76 changes: 73 additions & 3 deletions lib/Analysis/Utility.cpp
@@ -543,6 +543,75 @@ bool supportMMA(Value value, int version) {
(elemTy.isInteger(8) && version >= 2);
}

bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (blockedLayout == nullptr || dotOperandLayout == nullptr)
return false;
auto parentLayout =
dyn_cast<BlockedEncodingAttr>(dotOperandLayout.getParent());
if (parentLayout == nullptr)
return false;
auto opShape = srcTy.getShape();
auto rank = opShape.size();

int kDim = dotOperandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
int nonKDim = dotOperandLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
auto ctaLayout = blockedLayout.getCTALayout();

// The following logic checks that a source blocked layout matches a
// destination dot operand layout. This means that a tensor in the source
// layout could be converted into the destination layout without any data
// movement between registers or threads.
//
// It is considered a match if
// 1) Each thread in the source layout holds a whole copy of all elements
// along the K dimension of the tensor.
// 2) The distribution of data along all other non-K dimensions (Batch/M/N)
// matches between the source and destination parent layouts.
//
// The first condition comes from a property of the dot operand layout with a
// Blocked parent: the size per thread along the K dimension equals the size
// of the tensor along K. The second condition comes from another property:
// the dot operand layout inherits its non-K dimensions from its parent
// layout.
//
// clang-format off
//
// For example, the following conversion is a no-op:
// tensor<128x32xf16, #blocked<{sizePerThread = [2, 32], threadsPerWarp = [32, 1]}>>
// ->
// tensor<128x32xf16, #dot_op<{opIdx=0, parent=#blocked<{sizePerThread = [2, 8], threadsPerWarp = [32, 1]}>>>
//
// clang-format on
bool ctaLayoutCompatible =
ctaLayout.getCTASplitNum()[kDim] == 1 &&
blockedLayout.getCTALayout() == parentLayout.getCTALayout();
bool threadHoldsWholeKDim =
blockedLayout.getSizePerThread()[kDim] == opShape[kDim];
bool nonKDimCompatible =
blockedLayout.getOrder() == parentLayout.getOrder() &&
blockedLayout.getSizePerThread()[nonKDim] ==
parentLayout.getSizePerThread()[nonKDim] &&
blockedLayout.getThreadsPerWarp()[nonKDim] ==
parentLayout.getThreadsPerWarp()[nonKDim] &&
blockedLayout.getWarpsPerCTA()[nonKDim] ==
parentLayout.getWarpsPerCTA()[nonKDim];
bool matrixDimsCompatible =
ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible;
if (rank == 2)
return matrixDimsCompatible;

// Additional check for the batch dimension if it is present.
assert(rank == 3);
bool bDimCompatible =
blockedLayout.getSizePerThread()[0] ==
parentLayout.getSizePerThread()[0] &&
blockedLayout.getThreadsPerWarp()[0] ==
parentLayout.getThreadsPerWarp()[0] &&
blockedLayout.getWarpsPerCTA()[0] == parentLayout.getWarpsPerCTA()[0];
return matrixDimsCompatible && bDimCompatible;
}

bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
@@ -632,13 +701,14 @@ bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
}

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut` and
// `isMfmaToDotShortcut`) once they're fully subsumed by the linear-layout
// checks.
// TODO(jlebar): Remove these special cases (`isMmaToDotShortcut`,
// `isBlockedToDotShortcut` and `isMfmaToDotShortcut`) once they're fully
// subsumed by the linear-layout checks.
// TODO(Keren): We didn't check `cvtNeedsWarpShuffle` here because it's not
// supported yet in Triton's backend.
return !cvtReordersRegisters(srcTy, dstTy) &&
!triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) &&
!isBlockedToDotShortcut(srcTy, dstTy) &&
!isMmaToDotShortcut(srcTy, dstTy) &&
!isMfmaToDotShortcut(srcTy, dstTy);
}
32 changes: 32 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -232,6 +232,36 @@ struct ConvertLayoutOpConversion
const TargetInfoBase &targetInfo;
};

struct ConvertLayoutOpBlockedToDotOpShortcutConversion
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
const TargetInfoBase &targetInfo;
explicit ConvertLayoutOpBlockedToDotOpShortcutConversion(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

LogicalResult
matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = op.getContext();

const auto &shape = op.getType().getShape();
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
auto dstDotEncoding = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
if (!dstDotEncoding)
return failure();
if (!isa<BlockedEncodingAttr>(srcTy.getEncoding()) ||
!isa<BlockedEncodingAttr>(dstDotEncoding.getParent()))
return failure();
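// The shortcut below only applies when the conversion needs no data
// movement; if shared memory is still required, fall back to the generic
// lowering.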
if (cvtNeedsSharedMemory(srcTy, dstTy))
return failure();
rewriter.replaceOp(op, adaptor.getSrc());
return success();
}
};

struct ConvertLayoutOpUsingLinearLayoutsConversion
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
const TargetInfoBase &targetInfo;
@@ -657,5 +687,7 @@ void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
// one left.
mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
typeConverter, targetInfo, patterns, benefit.getBenefit() + 1);
patterns.add<ConvertLayoutOpBlockedToDotOpShortcutConversion>(
typeConverter, targetInfo, benefit);
patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
}
@@ -83,6 +83,8 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
OpBuilder builder(cvtOp);
auto srcType = cast<RankedTensorType>(cvtOp.getSrc().getType());
auto dstType = cast<RankedTensorType>(cvtOp.getType());
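// Conversions that need no shared memory are lowered directly and do not
// need to be decomposed here.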
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
auto srcBlocked =
dyn_cast<triton::gpu::BlockedEncodingAttr>(srcType.getEncoding());
auto dstDotOp =
1 change: 1 addition & 0 deletions lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
@@ -413,6 +413,7 @@ class RewriteTensorPointerPass
auto newForOp = builder.create<scf::ForOp>(op.getLoc(), op.getLowerBound(),
op.getUpperBound(), op.getStep(),
newIterOperands);
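// Preserve the original loop's attributes on the rewritten scf.for.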
newForOp->setAttrs(op->getAttrs());

// Create value mapping. Note that for tensor pointers, we use identity
// mapping. It may refer to a value in the old loop, but we will rewrite it
4 changes: 4 additions & 0 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -1,5 +1,6 @@
#include <vector>

#include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -827,6 +828,9 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
return dotOperandMfmaToLinearLayout(*this, shape);
}
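// Dot operands with an Intel DPAS parent use the Intel-specific conversion.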
if (auto dpasLayout = llvm::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
return dotOperandDpasToLinearLayout(*this, shape);
}

return std::nullopt;
}
18 changes: 2 additions & 16 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -42,22 +42,8 @@ class TritonGPUReduceDataDuplicationPass
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
if (!dstDotOp)
return;
if (auto srcMmaEncoding =
dyn_cast<triton::gpu::NvidiaMmaEncodingAttr>(srcEncoding)) {

if (srcMmaEncoding.getVersionMajor() != 2 ||
(srcMmaEncoding.getWarpsPerCTA()[1] == 1 &&
dstDotOp.getParent() == srcMmaEncoding))
return;
}
if (auto srcMfmaEncoding =
dyn_cast<triton::gpu::AMDMfmaEncodingAttr>(srcEncoding)) {

if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 &&
srcMfmaEncoding.getIsTransposed() &&
dstDotOp.getParent() == srcMfmaEncoding)
return;
}
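// Nothing to do if the conversion does not go through shared memory.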
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
auto srcOrder = triton::gpu::getOrder(srcEncoding);
auto rank = srcOrder.size();
SmallVector<unsigned> sharedOrder;
3 changes: 2 additions & 1 deletion python/test/unit/language/test_core.py
@@ -3375,7 +3375,8 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
if is_hip():
# hip does not support tf32 precision, so use ieee for all tests
input_precision = "ieee"
if "gfx11" in triton.runtime.driver.active.get_current_target().arch:
arch = triton.runtime.driver.active.get_current_target().arch
if "gfx11" in arch or "gfx12" in arch:
if in_dtype_str == "float32":
pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d")
if out_dtype_str == "float16":
104 changes: 88 additions & 16 deletions test/Conversion/amd/decompose-unsupported-conversions.mlir
@@ -1,33 +1,105 @@
// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions=arch=gfx1130 | FileCheck %s
// RUN: triton-opt %s --split-input-file --decompose-unsupported-amd-conversions | FileCheck %s

// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
// CHECK: wmma_to_wmma_dot_op
// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}>
// CHECK-LABEL: wmma_to_wmma_dot_op
#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) {
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[WMMA]]> -> tensor<16x16xf16, #[[BLOCKED]]>
// CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]>
// CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
tt.return
}
}

// -----

// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[SHARED:.+]] = #triton_gpu.shared<{{.*}}>
// CHECK: wmma_to_wmma_dot3d_op
// CHECK: #[[$BLOCKED:.+]] = #triton_gpu.blocked<{{.*}}>
// CHECK: #[[$WMMA:.+]] = #triton_gpu.amd_wmma<{{.*}}>
// CHECK: #[[$SHARED:.+]] = #triton_gpu.shared<{{.*}}>
// CHECK-LABEL: wmma_to_wmma_dot3d_op
#mma = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) {
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[WMMA]]> -> tensor<2x16x16xf16, #[[BLOCKED]]>
// CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[SHARED]], #triton_gpu.shared_memory>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[WMMA]], kWidth = 16}>>
// CHECK: %[[SRC_BLOCKED:.+]] = triton_gpu.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]>
// CHECK-NEXT: %[[INT_SHARED:.+]] = triton_gpu.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !tt.memdesc<2x16x16xf16, #[[$SHARED]], #triton_gpu.shared_memory>
// CHECK-NEXT: %[[DST_DOT_OP:.+]] = triton_gpu.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
%0 = triton_gpu.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
tt.return
}
}

// -----

// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx1130
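// Each thread already holds all 32 elements along K (sizePerThread = [1, 32])
// and the dot operand's parent is the same blocked layout, so no data
// movement is needed.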
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx1130", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func @blocked_to_dot_op_shortcut_gfx1130(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.local_alloc
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.local_alloc
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
tt.return
}
}

// -----

// CHECK-LABEL: blocked_to_dot_op_shortcut_gfx940
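// With 64-wide waves (threadsPerWarp = [32, 2]) each thread still holds all
// 32 elements along K, so the conversion is a no-op.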
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @blocked_to_dot_op_shortcut_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.local_alloc
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: triton_gpu.local_alloc
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
tt.return
}
}

// -----

// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_elems_gfx940
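// Each thread holds only 16 of the 32 elements along K (sizePerThread = [1, 16]),
// so the conversion must go through shared memory.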
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @neg_blocked_to_dot_op_incompatible_elems_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: triton_gpu.local_alloc
// CHECK: triton_gpu.local_load
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>
tt.return
}
}

// -----

// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_threads_gfx940
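// threadsPerWarp along M differs between the source (#blocked, 32) and the
// dot operand's parent (#blocked1, 16), so the layouts do not match.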
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 4], warpsPerCTA = [2, 2], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @neg_blocked_to_dot_op_incompatible_threads_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: triton_gpu.local_alloc
// CHECK: triton_gpu.local_load
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
tt.return
}
}

// -----

// CHECK-LABEL: neg_blocked_to_dot_op_incompatible_warp_gfx940
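// warpsPerCTA along M differs between the source (#blocked, 2) and the dot
// operand's parent (#blocked1, 4), so the layouts do not match.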
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx940", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: triton_gpu.local_alloc
// CHECK: triton_gpu.local_load
%0 = triton_gpu.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
tt.return
}
}