Skip to content

Commit

Permalink
[AMD] Add basic instruction scheduling control (#4770)
Browse files Browse the repository at this point in the history
LLVM AMDGPU backend supports special intrinsics
(https://llvm.org/docs/AMDGPUUsage.html#llvm-ir-intrinsics)
as hints to influence instruction scheduling. This PR
adds basic scaffolding for utilizing those intrinsics to better
control instructions generated from the backend. It is meant
to only target `tt.dot` operations which are often the
most intensive ones and may demand fine-tuning to achieve
better performance.

Facilities added here are experimental and we need to iterate
on it until to a good state.
  • Loading branch information
ravil-mobile authored Sep 25, 2024
1 parent 16c5b26 commit 4348109
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 0 deletions.
9 changes: 9 additions & 0 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ class HIPOptions:
max_num_imprecise_acc_default: int = 0
backend_name: str = 'hip'

# The following option provides hints to the AMDGPU backend regarding instruction scheduling
# for all `tt.dot` operations in a kernel. The "default" variant preserves the default
# instruction scheduling of the AMDGPU backend which aims at maximizing occupancy.
# The option is experimental and may change at any time regarding its semantics and/or may
# be gone entirely anytime.
instruction_sched_variant: str = 'default'

def __post_init__(self):
default_libdir = Path(__file__).parent / 'lib'
extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
Expand Down Expand Up @@ -174,6 +181,7 @@ def make_ttgir(mod, metadata, options):
if options.num_stages == 0:
amd.passes.ttgpuir.add_stream_pipeline(pm)
passes.common.add_canonicalizer(pm)
amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
Expand Down Expand Up @@ -221,6 +229,7 @@ def make_llir(src, metadata, options):
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
# This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
Expand Down
20 changes: 20 additions & 0 deletions third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,24 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "TritonAMDGPUDialect.td"
include "TritonAMDGPUAttrDefs.td"

class TT_AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])> {
}

def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
let summary = "A placeholder op for instruction scheduling hints within a basic block";
let description = [{
A placeholder op for instruction scheduling hints applied to instructions within
a basic block where the placeholder op is located. This op is primarily intended
to be used to adjust instruction scheduling inside the resulting main loop
of a `tt.dot` operation. It's easier to identify dot ops at a high level and, thus,
to mark intended scheduling regions. The hint ops are eventually lowered
into LLVM AMDGPU instruction scheduling primitives, which are meant to control
how different kinds of instructions (valu/mfma, global/shared memory, etc.) should
interleave for better instruction level parallelism.
}];

let assemblyFormat = [{attr-dict}];
}

#endif
4 changes: 4 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz);
std::unique_ptr<OperationPass<ModuleOp>> createConvertBuiltinFuncToLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>>
createInsertInstructionSchedHintsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createLowerInstructionSchedHintsPass(std::string variant);

#define GEN_PASS_REGISTRATION
#include "TritonAMDGPUToLLVM/Passes.h.inc"
Expand Down
20 changes: 20 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,24 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul

}

def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()";

let dependentDialects = ["mlir::LLVM::LLVMDialect"];
}

def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
let summary = "Lower instruction scheduling hints to LLVM intrinsics";
let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")";

let dependentDialects = ["mlir::LLVM::LLVMDialect"];

let options = [
Option<"variant", "variant", "std::string", /*default*/"\"default\"",
"instruction scheduling variant">,
];
}


#endif
2 changes: 2 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ add_triton_library(TritonAMDGPUToLLVM
OptimizeLDSUsage.cpp
OptimizeLDSUtility.cpp
SPMDOpToLLVM.cpp
SchedInstructions.cpp

DEPENDS
TritonAMDGPUConversionPassIncGen

LINK_LIBS PUBLIC
TritonGPUToLLVM
TritonAMDGPUIR
)
205 changes: 205 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#include "TritonAMDGPUToLLVM/Passes.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir::triton {
#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS
#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS
#include "TritonAMDGPUToLLVM/Passes.h.inc"
} // namespace mlir::triton

using namespace mlir;

namespace {

// The bitmask that encodes kinds of the instructions from AMD ISA.
// The bitmask is used for providing instruction scheduling hints.
enum InstructionKindMask {
NONE = 0x0000000,
ALL_ALU = 0x00000001,
VALU = 0x00000002,
SALU = 0x00000004,
MFMA = 0x00000008,
ALL_VMEM = 0x00000010,
VMEM_READ = 0x00000020,
VMEM_WRITE = 0x00000040,
ALL_DS = 0x00000080,
DS_READ = 0x00000100,
DS_WRITE = 0x00000200
};

// Create an intrinsic to control how different instruction kinds should
// interleave for better ILP.
void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
InstructionKindMask maskValue, int sizeValue,
int groupIdValue) {
MLIRContext *ctx = rewriter.getContext();
auto intrinsicName = str_attr("llvm.amdgcn.sched.group.barrier");

Value mask =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
Value size =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(sizeValue));
Value groupId = LLVM::createConstantI32(loc, rewriter,
static_cast<int32_t>(groupIdValue));

LLVM::FastmathFlagsAttr defaultFlags{};
rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
ValueRange{mask, size, groupId},
defaultFlags);
}

// Insert intrinsic that controls the types of instructions that may be
// allowed to cross the intrinsic during instruction scheduling
Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc,
int64_t maskValue) {
MLIRContext *ctx = rewriter.getContext();
auto intrinsicName = str_attr("llvm.amdgcn.sched.barrier");
LLVM::FastmathFlagsAttr defaultFlags{};

Value mask =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
return rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
ValueRange{mask}, defaultFlags);
}

// Insert an experimental intrinsic for instruction group level parallelism.
// The intrinsic takes a value that specifies the strategy.
Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) {
MLIRContext *ctx = rewriter.getContext();
auto intrinsicName = str_attr("llvm.amdgcn.iglp.opt");
LLVM::FastmathFlagsAttr defaultFlags{};
Value iglpValue =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(value));
return rewriter.create<LLVM::CallIntrinsicOp>(
loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags);
}

struct InstructionSchedHintsRewriter
: public OpRewritePattern<triton::amdgpu::InstructionSchedHint> {

InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant)
: OpRewritePattern(ctx) {
std::transform(variant.begin(), variant.end(), variant.begin(),
[](unsigned char c) { return std::tolower(c); });

this->schedulingType = llvm::StringSwitch<SchedulingType>(variant)
.Case("default", SchedulingType::NONE)
.Case("iglp0", SchedulingType::IGLP0)
.Case("iglp1", SchedulingType::IGLP1)
.Default(SchedulingType::UNKNOWN);
}

enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN };

LogicalResult
matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint,
PatternRewriter &rewriter) const override {

if (this->schedulingType == SchedulingType::UNKNOWN) {
llvm::dbgs()
<< "[" << getDebugName() << "]: "
<< "unknown instruction scheduling variant has been provided\n";
return mlir::failure();
}

// The switch controls whether instructions are allowed to cross the basic
// block boundaries at the very top and at the very bottom. Note, this is
// not supposed to be used together with IGLP OPT according to the AMDGPU
// backend documentation.
const bool limitSchedulingRange =
!(schedulingType == SchedulingType::IGLP0 ||
schedulingType == SchedulingType::IGLP1);
Location loc = instructionSchedHint->getLoc();
Block *block = instructionSchedHint->getBlock();
if (limitSchedulingRange) {
rewriter.setInsertionPointToStart(block);
createSchedBarrier(rewriter, loc, InstructionKindMask::NONE);
}

rewriter.setInsertionPoint(block, std::prev(block->end()));

switch (schedulingType) {
case SchedulingType::IGLP0:
[[fallthrough]];
case SchedulingType::IGLP1: {
createIglpOpt(rewriter, loc, static_cast<int>(schedulingType) - 1);
break;
}
case SchedulingType::NONE:
[[fallthrough]];
default: {
break;
}
}

if (limitSchedulingRange)
createSchedBarrier(rewriter, loc, InstructionKindMask::NONE);

rewriter.eraseOp(instructionSchedHint);
return mlir::success();
}

private:
SchedulingType schedulingType;
};

struct LowerInstructionSchedHints
: public triton::impl::LowerInstructionSchedHintsBase<
LowerInstructionSchedHints> {

explicit LowerInstructionSchedHints(std::string variant) {
this->variant = variant;
}

void runOnOperation() override {
MLIRContext *ctx = &getContext();
ModuleOp mod = getOperation();

ConversionTarget target(*ctx);
target.addLegalDialect<LLVM::LLVMDialect>();
target.addIllegalOp<triton::amdgpu::InstructionSchedHint>();

RewritePatternSet patterns(ctx);
patterns.add<InstructionSchedHintsRewriter>(ctx, this->variant);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
signalPassFailure();
}
}
};

struct InsertInstructionSchedHints
: public triton::impl::InsertInstructionSchedHintsBase<
InsertInstructionSchedHints> {
void runOnOperation() override {
MLIRContext *ctx = &getContext();
ModuleOp mod = getOperation();

mod->walk([ctx](triton::DotOp dot) {
if (dyn_cast<mlir::scf::ForOp>(dot->getParentOp())) {
mlir::OpBuilder rewriter(ctx);
rewriter.setInsertionPointAfter(dot);
rewriter.create<triton::amdgpu::InstructionSchedHint>(dot->getLoc());
}
});
}
};
} // namespace

namespace mlir::triton {
std::unique_ptr<OperationPass<ModuleOp>>
createLowerInstructionSchedHintsPass(std::string variant) {
return std::make_unique<LowerInstructionSchedHints>(variant);
}

std::unique_ptr<OperationPass<ModuleOp>>
createInsertInstructionSchedHintsPass() {
return std::make_unique<InsertInstructionSchedHints>();
}
} // namespace mlir::triton
2 changes: 2 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
Expand Down Expand Up @@ -57,6 +58,7 @@ class TritonLLVMConversionTarget : public ConversionTarget {
addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
addLegalOp<triton::amdgpu::InstructionSchedHint>();
}
};

Expand Down
7 changes: 7 additions & 0 deletions third_party/amd/python/triton_amd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
pm.addPass(createConvertBuiltinFuncToLLVMPass());
});
m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) {
pm.addPass(createInsertInstructionSchedHintsPass());
});
m.def("lower_instruction_sched_hints",
[](mlir::PassManager &pm, std::string variant) {
pm.addPass(createLowerInstructionSchedHintsPass(variant));
});
m.def("add_decompose_unsupported_conversions", [](mlir::PassManager &pm,
const std::string &arch) {
pm.addPass(
Expand Down

0 comments on commit 4348109

Please sign in to comment.