-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AMD] Add basic instruction scheduling control (#4770)
LLVM AMDGPU backend supports special intrinsics (https://llvm.org/docs/AMDGPUUsage.html#llvm-ir-intrinsics) as hints to influence instruction scheduling. This PR adds basic scaffolding for utilizing those intrinsics to better control instructions generated from the backend. It is meant to only target `tt.dot` operations which are often the most intensive ones and may demand fine-tuning to achieve better performance. Facilities added here are experimental and we need to iterate on it until to a good state.
- Loading branch information
1 parent
16c5b26
commit 4348109
Showing
8 changed files
with
269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
205 changes: 205 additions & 0 deletions
205
third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
#include "TritonAMDGPUToLLVM/Passes.h" | ||
|
||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" | ||
#include "triton/Conversion/TritonGPUToLLVM/Utility.h" | ||
#include "triton/Dialect/Triton/IR/Dialect.h" | ||
|
||
namespace mlir::triton { | ||
#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS | ||
#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS | ||
#include "TritonAMDGPUToLLVM/Passes.h.inc" | ||
} // namespace mlir::triton | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
// The bitmask that encodes kinds of the instructions from AMD ISA. | ||
// The bitmask is used for providing instruction scheduling hints. | ||
enum InstructionKindMask { | ||
NONE = 0x0000000, | ||
ALL_ALU = 0x00000001, | ||
VALU = 0x00000002, | ||
SALU = 0x00000004, | ||
MFMA = 0x00000008, | ||
ALL_VMEM = 0x00000010, | ||
VMEM_READ = 0x00000020, | ||
VMEM_WRITE = 0x00000040, | ||
ALL_DS = 0x00000080, | ||
DS_READ = 0x00000100, | ||
DS_WRITE = 0x00000200 | ||
}; | ||
|
||
// Create an intrinsic to control how different instruction kinds should | ||
// interleave for better ILP. | ||
void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc, | ||
InstructionKindMask maskValue, int sizeValue, | ||
int groupIdValue) { | ||
MLIRContext *ctx = rewriter.getContext(); | ||
auto intrinsicName = str_attr("llvm.amdgcn.sched.group.barrier"); | ||
|
||
Value mask = | ||
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue)); | ||
Value size = | ||
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(sizeValue)); | ||
Value groupId = LLVM::createConstantI32(loc, rewriter, | ||
static_cast<int32_t>(groupIdValue)); | ||
|
||
LLVM::FastmathFlagsAttr defaultFlags{}; | ||
rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName, | ||
ValueRange{mask, size, groupId}, | ||
defaultFlags); | ||
} | ||
|
||
// Insert intrinsic that controls the types of instructions that may be | ||
// allowed to cross the intrinsic during instruction scheduling | ||
Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc, | ||
int64_t maskValue) { | ||
MLIRContext *ctx = rewriter.getContext(); | ||
auto intrinsicName = str_attr("llvm.amdgcn.sched.barrier"); | ||
LLVM::FastmathFlagsAttr defaultFlags{}; | ||
|
||
Value mask = | ||
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue)); | ||
return rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName, | ||
ValueRange{mask}, defaultFlags); | ||
} | ||
|
||
// Insert an experimental intrinsic for instruction group level parallelism. | ||
// The intrinsic takes a value that specifies the strategy. | ||
Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { | ||
MLIRContext *ctx = rewriter.getContext(); | ||
auto intrinsicName = str_attr("llvm.amdgcn.iglp.opt"); | ||
LLVM::FastmathFlagsAttr defaultFlags{}; | ||
Value iglpValue = | ||
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(value)); | ||
return rewriter.create<LLVM::CallIntrinsicOp>( | ||
loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags); | ||
} | ||
|
||
struct InstructionSchedHintsRewriter | ||
: public OpRewritePattern<triton::amdgpu::InstructionSchedHint> { | ||
|
||
InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant) | ||
: OpRewritePattern(ctx) { | ||
std::transform(variant.begin(), variant.end(), variant.begin(), | ||
[](unsigned char c) { return std::tolower(c); }); | ||
|
||
this->schedulingType = llvm::StringSwitch<SchedulingType>(variant) | ||
.Case("default", SchedulingType::NONE) | ||
.Case("iglp0", SchedulingType::IGLP0) | ||
.Case("iglp1", SchedulingType::IGLP1) | ||
.Default(SchedulingType::UNKNOWN); | ||
} | ||
|
||
enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN }; | ||
|
||
LogicalResult | ||
matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, | ||
PatternRewriter &rewriter) const override { | ||
|
||
if (this->schedulingType == SchedulingType::UNKNOWN) { | ||
llvm::dbgs() | ||
<< "[" << getDebugName() << "]: " | ||
<< "unknown instruction scheduling variant has been provided\n"; | ||
return mlir::failure(); | ||
} | ||
|
||
// The switch controls whether instructions are allowed to cross the basic | ||
// block boundaries at the very top and at the very bottom. Note, this is | ||
// not supposed to be used together with IGLP OPT according to the AMDGPU | ||
// backend documentation. | ||
const bool limitSchedulingRange = | ||
!(schedulingType == SchedulingType::IGLP0 || | ||
schedulingType == SchedulingType::IGLP1); | ||
Location loc = instructionSchedHint->getLoc(); | ||
Block *block = instructionSchedHint->getBlock(); | ||
if (limitSchedulingRange) { | ||
rewriter.setInsertionPointToStart(block); | ||
createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); | ||
} | ||
|
||
rewriter.setInsertionPoint(block, std::prev(block->end())); | ||
|
||
switch (schedulingType) { | ||
case SchedulingType::IGLP0: | ||
[[fallthrough]]; | ||
case SchedulingType::IGLP1: { | ||
createIglpOpt(rewriter, loc, static_cast<int>(schedulingType) - 1); | ||
break; | ||
} | ||
case SchedulingType::NONE: | ||
[[fallthrough]]; | ||
default: { | ||
break; | ||
} | ||
} | ||
|
||
if (limitSchedulingRange) | ||
createSchedBarrier(rewriter, loc, InstructionKindMask::NONE); | ||
|
||
rewriter.eraseOp(instructionSchedHint); | ||
return mlir::success(); | ||
} | ||
|
||
private: | ||
SchedulingType schedulingType; | ||
}; | ||
|
||
struct LowerInstructionSchedHints | ||
: public triton::impl::LowerInstructionSchedHintsBase< | ||
LowerInstructionSchedHints> { | ||
|
||
explicit LowerInstructionSchedHints(std::string variant) { | ||
this->variant = variant; | ||
} | ||
|
||
void runOnOperation() override { | ||
MLIRContext *ctx = &getContext(); | ||
ModuleOp mod = getOperation(); | ||
|
||
ConversionTarget target(*ctx); | ||
target.addLegalDialect<LLVM::LLVMDialect>(); | ||
target.addIllegalOp<triton::amdgpu::InstructionSchedHint>(); | ||
|
||
RewritePatternSet patterns(ctx); | ||
patterns.add<InstructionSchedHintsRewriter>(ctx, this->variant); | ||
|
||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
struct InsertInstructionSchedHints | ||
: public triton::impl::InsertInstructionSchedHintsBase< | ||
InsertInstructionSchedHints> { | ||
void runOnOperation() override { | ||
MLIRContext *ctx = &getContext(); | ||
ModuleOp mod = getOperation(); | ||
|
||
mod->walk([ctx](triton::DotOp dot) { | ||
if (dyn_cast<mlir::scf::ForOp>(dot->getParentOp())) { | ||
mlir::OpBuilder rewriter(ctx); | ||
rewriter.setInsertionPointAfter(dot); | ||
rewriter.create<triton::amdgpu::InstructionSchedHint>(dot->getLoc()); | ||
} | ||
}); | ||
} | ||
}; | ||
} // namespace | ||
|
||
namespace mlir::triton { | ||
std::unique_ptr<OperationPass<ModuleOp>> | ||
createLowerInstructionSchedHintsPass(std::string variant) { | ||
return std::make_unique<LowerInstructionSchedHints>(variant); | ||
} | ||
|
||
std::unique_ptr<OperationPass<ModuleOp>> | ||
createInsertInstructionSchedHintsPass() { | ||
return std::make_unique<InsertInstructionSchedHints>(); | ||
} | ||
} // namespace mlir::triton |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters