Skip to content

Commit

Permalink
[Codegen] Query #iree_gpu.target for shared memory limit (iree-org#18184)
Browse files Browse the repository at this point in the history

Signed-off-by: nithinsubbiah <[email protected]>
  • Loading branch information
nithinsubbiah authored Aug 10, 2024
1 parent e36aa78 commit b06bf6a
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static int shapedTypeStaticSize(
}

/// Returns success if the total shared memory allocation size is less than the
/// limit set by limit.
/// limit.
static LogicalResult checkGPUAllocationSize(
mlir::FunctionOpInterface funcOp, unsigned limit,
std::function<unsigned(mlir::FunctionOpInterface)> getIndexBitwidth) {
Expand Down Expand Up @@ -93,15 +93,14 @@ class GPUCheckResourceUsagePass final
: public impl::GPUCheckResourceUsagePassBase<GPUCheckResourceUsagePass> {
public:
explicit GPUCheckResourceUsagePass(
std::function<unsigned(mlir::FunctionOpInterface)> getSharedMemoryLimit,
std::function<unsigned(mlir::FunctionOpInterface)> getIndexBitwidth)
: getSharedMemoryLimit(getSharedMemoryLimit),
getIndexBitwidth(getIndexBitwidth) {}
: getIndexBitwidth(getIndexBitwidth) {}

void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();
IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
unsigned limit =
getSharedMemoryLimit ? getSharedMemoryLimit(funcOp) : 64 * 1024;
target ? target.getWgp().getMaxWorkgroupMemoryBytes() : 64 * 1024;
if (failed(checkGPUAllocationSize(funcOp, limit,
getIndexBitwidth
? getIndexBitwidth
Expand All @@ -111,18 +110,15 @@ class GPUCheckResourceUsagePass final
}

private:
std::function<unsigned(mlir::FunctionOpInterface)> getSharedMemoryLimit;
std::function<unsigned(mlir::FunctionOpInterface)> getIndexBitwidth;
};

} // namespace

std::unique_ptr<InterfacePass<FunctionOpInterface>>
createGPUCheckResourceUsagePass(
std::function<unsigned(mlir::FunctionOpInterface)> getSharedMemoryLimit,
std::function<unsigned(mlir::FunctionOpInterface)> getIndexBitwidth) {
return std::make_unique<GPUCheckResourceUsagePass>(getSharedMemoryLimit,
getIndexBitwidth);
return std::make_unique<GPUCheckResourceUsagePass>(getIndexBitwidth);
}

} // namespace mlir::iree_compiler
2 changes: 0 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ LogicalResult gpuDistributeSharedMemoryCopy(mlir::FunctionOpInterface funcOp);
// get the index size.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createGPUCheckResourceUsagePass(
std::function<unsigned(mlir::FunctionOpInterface)> getSharedMemoryLimit =
nullptr,
std::function<unsigned(mlir::FunctionOpInterface)> getIndexBitwidth =
nullptr);

Expand Down
7 changes: 1 addition & 6 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -992,13 +992,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
// Run checks on shared memory usage.
funcPassManager
.addPass([&]() {
auto getSharedMemoryLimit = [](mlir::FunctionOpInterface entryPoint) {
IREE::GPU::TargetAttr target = getGPUTargetAttr(entryPoint);
return target.getWgp().getMaxWorkgroupMemoryBytes();
};
auto getIndexBitwidth = [](mlir::FunctionOpInterface) { return 64; };
return createGPUCheckResourceUsagePass(getSharedMemoryLimit,
getIndexBitwidth);
return createGPUCheckResourceUsagePass(getIndexBitwidth);
})
// SCF -> CF
.addPass(createConvertSCFToCFPass)
Expand Down
11 changes: 2 additions & 9 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,11 @@ static void addMemRefLoweringPasses(OpPassManager &modulePassManager) {

.addPass(createPadDynamicAlloc);

// Check to make sure we are not exceeding shared memory usage limit.
auto getSharedMemoryLimit = [](mlir::FunctionOpInterface fn) {
IREE::GPU::TargetAttr target = getGPUTargetAttr(fn);
return target.getWgp().getMaxWorkgroupMemoryBytes();
};
// TODO: query this from the target.
auto getIndexBitwidth = [](mlir::FunctionOpInterface) { return 32; };
funcPassManager
.addPass([&]() {
return createGPUCheckResourceUsagePass(getSharedMemoryLimit,
getIndexBitwidth);
})
.addPass(
[&]() { return createGPUCheckResourceUsagePass(getIndexBitwidth); })

// Fold load/store from/to subview ops into the original memref when
// possible. In SPIR-V we don't use memref descriptor so it's not possible
Expand Down

0 comments on commit b06bf6a

Please sign in to comment.