Skip to content

Commit

Permalink
Replace TritonGEN Shuffle with GPU shuffle (intel#2128)
Browse files Browse the repository at this point in the history
The next step in replacing parts of TritonGEN with the `gpu-to-llvm-spv`
pass.
Shuffles are 1-2-1 conversion, but a bit of extra code is needed because
of the limit on the number types supported by the `gpu-to-llvm-spv`
pass, which follows the types supported by the OpenCL SPIR-V Environment
Specification.
  • Loading branch information
FMarno authored Sep 9, 2024
1 parent d017d06 commit 31bd963
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 213 deletions.
58 changes: 0 additions & 58 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -196,64 +196,6 @@ module attributes {

// -----

// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xordj(f64, i32) -> f64 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorfj(f32, i32) -> f32 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorDhj(f16, i32) -> f16 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorlj(i64, i32) -> i64 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorsj(i16, i32) -> i16 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorcj(i8, i32) -> i8 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}
// CHECK-DAG: llvm.func spir_funccc @_Z17sub_group_shuffleij(i32, i32) -> i32 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}
// CHECK-DAG: llvm.func spir_funccc @_Z22sub_group_shuffle_downij(i32, i32) -> i32 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}
// CHECK-DAG: llvm.func spir_funccc @_Z20sub_group_shuffle_upij(i32, i32) -> i32 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorij(i32, i32) -> i32 attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn"]}

llvm.func @triton_gen.sub_group_shuffle() {
// CHECK-LABEL: triton_gen.sub_group_shuffle
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorij([[ZERO]], [[ZERO]]) {{.*}} : (i32, i32) -> i32
// CHECK: llvm.call spir_funccc @_Z20sub_group_shuffle_upij([[ZERO]], [[ZERO]]) {{.*}} : (i32, i32) -> i32
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downij([[ZERO]], [[ZERO]]) {{.*}} : (i32, i32) -> i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleij([[ZERO]], [[ZERO]]) {{.*}} : (i32, i32) -> i32
%1 = triton_gen.sub_group_shuffle xor %0, %0 : i32
%2 = triton_gen.sub_group_shuffle up %0, %0 : i32
%3 = triton_gen.sub_group_shuffle down %0, %0 : i32
%4 = triton_gen.sub_group_shuffle idx %0, %0 : i32

// CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj([[ZERO1]], [[ZERO]]) {{.*}} : (i8, i32) -> i8
%5 = llvm.mlir.constant(0 : i8) : i8
%6 = triton_gen.sub_group_shuffle xor %5, %0 : i8

// CHECK: [[ZERO2:%.*]] = llvm.mlir.constant(0 : i16) : i16
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorsj([[ZERO2]], [[ZERO]]) {{.*}} : (i16, i32) -> i16
%7 = llvm.mlir.constant(0 : i16) : i16
%8 = triton_gen.sub_group_shuffle xor %7, %0 : i16

// CHECK: [[ZERO3:%.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorlj([[ZERO3]], [[ZERO]]) {{.*}} : (i64, i32) -> i64
%9 = llvm.mlir.constant(0 : i64) : i64
%10 = triton_gen.sub_group_shuffle xor %9, %0 : i64

// CHECK: [[ZERO4:%.*]] = llvm.mlir.constant(0.000000e+00 : f16) : f16
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorDhj([[ZERO4]], [[ZERO]]) {{.*}} : (f16, i32) -> f16
%11 = llvm.mlir.constant(0.0 : f16) : f16
%12 = triton_gen.sub_group_shuffle xor %11, %0 : f16

// CHECK: [[ZERO5:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xorfj([[ZERO5]], [[ZERO]]) {{.*}} : (f32, i32) -> f32
%13 = llvm.mlir.constant(0.0 : f32) : f32
%14 = triton_gen.sub_group_shuffle xor %13, %0 : f32

// CHECK: [[ZERO6:%.*]] = llvm.mlir.constant(0.000000e+00 : f64) : f64
// CHECK: llvm.call spir_funccc @_Z21sub_group_shuffle_xordj([[ZERO6]], [[ZERO]]) {{.*}} : (f64, i32) -> f64
%15 = llvm.mlir.constant(0.0 : f64) : f64
%16 = triton_gen.sub_group_shuffle xor %15, %0 : f64
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {passthrough = ["convergent", "nofree", "nounwind", "willreturn", ["memory", "0"]]}

llvm.func @triton_gen.dpas.i8(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
Expand Down
32 changes: 0 additions & 32 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -113,38 +113,6 @@ module attributes {

// -----

llvm.func @triton_gen.sub_group_shuffle() {
// CHECK-LABEL: triton_gen.sub_group_shuffle
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: %1 = triton_gen.sub_group_shuffle xor %0, %0 : i32
%1 = triton_gen.sub_group_shuffle xor %0, %0 : i32
// CHECK: %2 = triton_gen.sub_group_shuffle up %0, %0 : i32
%2 = triton_gen.sub_group_shuffle up %0, %0 : i32
// CHECK: %3 = triton_gen.sub_group_shuffle down %0, %0 : i32
%3 = triton_gen.sub_group_shuffle down %0, %0 : i32
// CHECK: %4 = triton_gen.sub_group_shuffle idx %0, %0 : i32
%4 = triton_gen.sub_group_shuffle idx %0, %0 : i32
%5 = llvm.mlir.constant(0 : i8) : i8
// CHECK: %6 = triton_gen.sub_group_shuffle xor %5, %0 : i8
%6 = triton_gen.sub_group_shuffle xor %5, %0 : i8
%7 = llvm.mlir.constant(0 : i16) : i16
// CHECK: %8 = triton_gen.sub_group_shuffle xor %7, %0 : i16
%8 = triton_gen.sub_group_shuffle xor %7, %0 : i16
%9 = llvm.mlir.constant(0 : i64) : i64
// CHECK: %10 = triton_gen.sub_group_shuffle xor %9, %0 : i64
%10 = triton_gen.sub_group_shuffle xor %9, %0 : i64
%11 = llvm.mlir.constant(0.0 : f16) : f16
// CHECK: %12 = triton_gen.sub_group_shuffle xor %11, %0 : f16
%12 = triton_gen.sub_group_shuffle xor %11, %0 : f16
%13 = llvm.mlir.constant(0.0 : f32) : f32
// CHECK: %14 = triton_gen.sub_group_shuffle xor %13, %0 : f32
%14 = triton_gen.sub_group_shuffle xor %13, %0 : f32
%15 = llvm.mlir.constant(0.0 : f64) : f64
// CHECK: %16 = triton_gen.sub_group_shuffle xor %15, %0 : f64
%16 = triton_gen.sub_group_shuffle xor %15, %0 : f64
llvm.return
}

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// CHECK: llvm.func @triton_gen.dpas(%arg0: vector<8xi32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) {
// CHECK-NEXT: %0 = triton_gen.dpas %arg0, %arg1, %arg2 {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,6 @@ def TritonGEN_ScanKindAttr : I32EnumAttr<"ScanKind", "TritonGEN subgroup scan ki
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Enum attribute of the different subgroup shuffle kinds.
def TritonGEN_ShflKindAttr : I32EnumAttr<"ShflKind", "TritonGEN subgroup shuffle kind",
[
I32EnumAttrCase<"XOR", 0, "xor">,
I32EnumAttrCase<"UP", 1, "up">,
I32EnumAttrCase<"DOWN", 2, "down">,
I32EnumAttrCase<"IDX", 3, "idx">
]> {
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Enum attribute of the different floating-point rounding modes.
def TritonGEN_RoundingModeAttr : I32EnumAttr<"RoundingMode",
"TritonGEN floating-point rounding mode",
Expand Down
25 changes: 0 additions & 25 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -203,31 +203,6 @@ def TritonGEN_SubGroupScanOp : TritonGEN_Op<"sub_group_scan", [
}];
}

def TritonGEN_SubGroupShuffleOp : TritonGEN_Op<"sub_group_shuffle", [
AllTypesMatch<["res", "value"]>]>,
Results<(outs SignlessIntegerOrFloatLike:$res)>,
Arguments<(ins SignlessIntegerOrFloatLike:$value,
I32:$mask,
TritonGEN_ShflKindAttr:$kind)> {
let summary = "Subgroup shuffle";

let description = [{
The `triton_gen.sub_group_shuffle` operation is invoked by different work
items with different values, given by $value. Different work items have
different subgroup local IDs. The shuffle kind, $kind, is given to determine
how to calculate the associated subgroup local ID. It returns the associated
$value for the work item with subgroup local ID equal to:
- $kind == xor, the current invocation’s subgroup local ID xor’ed with $mask.
- $kind == up, the current invocation’s subgroup local ID - $mask.
- $kind == down, the current invocation’s subgroup local ID + $mask.
- $kind == idx, the subgroup local ID $mask.
}];

let assemblyFormat = [{
$kind $value `,` $mask attr-dict `:` type($value)
}];
}

//===----------------------------------------------------------------------===//
// Matrix operations
//===----------------------------------------------------------------------===//
Expand Down
92 changes: 11 additions & 81 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,10 +804,9 @@ struct TritonGENNamedBarrierWaitLowering

struct TritonSubGroupBase {
protected:
template <typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, TritonGEN::SubGroupReduceOp, TritonGEN::SubGroupScanOp,
TritonGEN::SubGroupShuffleOp>::value>>
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, TritonGEN::SubGroupReduceOp,
TritonGEN::SubGroupScanOp>::value>>
static Value extend(OpType op, Value val, Type type,
ConversionPatternRewriter &rewriter) {
Location loc = op.getLoc();
Expand All @@ -817,23 +816,14 @@ struct TritonSubGroupBase {
TritonGEN::SubGroupScanOp>::value) {
if (type.isInteger() && bitWidth < 8)
val = zext(i8_ty, val);
} else if constexpr (std::is_same_v<OpType, TritonGEN::SubGroupShuffleOp>) {
if (bitWidth < 8) {
if (!type.isInteger())
val = bitcast(val, int_ty(bitWidth));
val = zext(i8_ty, val);
} else if (isa<BFloat16Type>(type)) {
val = bitcast(val, i16_ty);
}
}

return val;
}

template <typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, TritonGEN::SubGroupReduceOp, TritonGEN::SubGroupScanOp,
TritonGEN::SubGroupShuffleOp>::value>>
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, TritonGEN::SubGroupReduceOp,
TritonGEN::SubGroupScanOp>::value>>
static Value truncate(OpType op, Value val, Type type,
ConversionPatternRewriter &rewriter) {
Location loc = op.getLoc();
Expand All @@ -844,14 +834,6 @@ struct TritonSubGroupBase {
if (type.isInteger() && bitWidth < 8)
val = trunc(type, val);
return val;
} else if constexpr (std::is_same_v<OpType, TritonGEN::SubGroupShuffleOp>) {
if (bitWidth < 8) {
val = trunc(int_ty(bitWidth), val);
if (!type.isInteger())
val = bitcast(val, type);
} else if (isa<BFloat16Type>(type)) {
val = bitcast(val, type);
}
}

return val;
Expand Down Expand Up @@ -962,58 +944,6 @@ struct TritonSubGroupScanLowering
}
};

struct TritonSubGroupShuffleLowering
: public ConvertOpToLLVMPattern<TritonGEN::SubGroupShuffleOp>,
public TritonSubGroupBase {
using ConvertOpToLLVMPattern<
TritonGEN::SubGroupShuffleOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(TritonGEN::SubGroupShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value val = op.getValue();
auto origTy = val.getType();
val = TritonSubGroupBase::extend(op, op.getValue(), origTy, rewriter);
Value value = val;
Value mask = op.getMask();
TritonGEN::ShflKind kind = op.getKind();

StringRef func;
switch (kind) {
case TritonGEN::ShflKind::XOR:
func = "sub_group_shuffle_xor";
break;
case TritonGEN::ShflKind::UP:
func = "sub_group_shuffle_up";
break;
case TritonGEN::ShflKind::DOWN:
func = "sub_group_shuffle_down";
break;
case TritonGEN::ShflKind::IDX:
func = "sub_group_shuffle";
break;
}
std::string fnName = intel::mangle(func, {value.getType(), i32_ty},
/*isUnsigned=*/{false, true});

intel::AttributeList attrs =
createFunctionAttributes({{llvm::Attribute::NoUnwind, std::nullopt},
{llvm::Attribute::WillReturn, std::nullopt},
{llvm::Attribute::NoFree, std::nullopt},
{llvm::Attribute::Convergent, std::nullopt}},
rewriter.getContext());

Value result = createDeviceFunctionCall(rewriter, fnName, value.getType(),
{value.getType(), mask.getType()},
{value, mask}, attrs)
.getResult();

result = TritonSubGroupBase::truncate(op, result, origTy, rewriter);
rewriter.replaceOp(op, result);
return success();
}
};

//===----------------------------------------------------------------------===//
// Matrix operations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1449,11 +1379,11 @@ void mlir::triton::populateTritonGENToLLVMConversionPatterns(
TritonGENBarrierLowering, TritonGENSplitBarrierSignalLowering,
TritonGENSplitBarrierWaitLowering, TritonGENNamedBarrierSignalLowering,
TritonGENNamedBarrierWaitLowering, TritonSubGroupReduceLowering,
TritonSubGroupScanLowering, TritonSubGroupShuffleLowering,
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering,
TritonSIMDBlockReadLowering, TritonSIMDBlockWriteLowering>(
converter, patternBenefitPreferTritonGENLowering);
TritonSubGroupScanLowering, TritonMatrixDPASLowering,
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
TritonMatrix2DBlockPrefetchLowering, TritonSIMDBlockReadLowering,
TritonSIMDBlockWriteLowering>(converter,
patternBenefitPreferTritonGENLowering);
}

void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {
Expand Down
59 changes: 53 additions & 6 deletions third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,61 @@
#include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

using namespace mlir;
using namespace mlir::triton;

namespace mlir::LLVM::intel {

static Type findShuffleType(RewriterBase &rewriter, Type valType) {
if (valType.isBF16())
return rewriter.getI16Type();

unsigned bitWidth = valType.getIntOrFloatBitWidth();
if (bitWidth < 8)
return rewriter.getI8Type();

assert((valType.isInteger(8) || valType.isInteger(16) ||
valType.isInteger(32) || valType.isInteger(64) || valType.isF16() ||
valType.isF32() || valType.isF64()) &&
"Invalid Shuffle Type");
return valType;
}

static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val,
Value i, TritonGEN::ShflKind mode) {
Type type = val.getType();
return rewriter.create<TritonGEN::SubGroupShuffleOp>(loc, type, val, i, mode);
Value i, mlir::gpu::ShuffleMode mode) {
Type valType = val.getType();
Type shuffleType = findShuffleType(rewriter, valType);

const unsigned bitWidth = valType.getIntOrFloatBitWidth();
if (shuffleType != valType) {
assert(shuffleType.isInteger() &&
"expected to bitcast to an integer for unsupported shuffles");
if (!valType.isInteger()) {
val = bitcast(val, int_ty(bitWidth));
}
if (bitWidth < shuffleType.getIntOrFloatBitWidth()) {
val = zext(shuffleType, val);
}
}

int width = TritonGEN::getSubgroupSize(i.getDefiningOp());
Value widthConstant = i32_val(width);
Value result =
rewriter.create<mlir::gpu::ShuffleOp>(loc, val, i, widthConstant, mode)
.getShuffleResult();

if (shuffleType != valType) {
if (bitWidth < shuffleType.getIntOrFloatBitWidth()) {
result = trunc(int_ty(bitWidth), result);
}
if (!valType.isInteger()) {
result = bitcast(result, valType);
}
}

return result;
}

Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Expand All @@ -36,19 +82,20 @@ Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,

Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) {
return shuffleCommon(loc, rewriter, val, i32_val(i),
TritonGEN::ShflKind::XOR);
mlir::gpu::ShuffleMode::XOR);
}

Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) {
return shuffleCommon(loc, rewriter, val, i32_val(i), TritonGEN::ShflKind::UP);
return shuffleCommon(loc, rewriter, val, i32_val(i),
mlir::gpu::ShuffleMode::UP);
}

Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) {
return shuffleIdx(loc, rewriter, val, i32_val(i));
}

Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) {
return shuffleCommon(loc, rewriter, val, i, TritonGEN::ShflKind::IDX);
return shuffleCommon(loc, rewriter, val, i, mlir::gpu::ShuffleMode::IDX);
}

Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
Expand Down

0 comments on commit 31bd963

Please sign in to comment.