Skip to content

Commit

Permalink
Adding simplified HAL dispatch methods.
Browse files Browse the repository at this point in the history
These combine push constants and push descriptor sets into the dispatch
calls as in practice we have a near 1:1 relationship anyway. Pipeline
layouts are still used in HAL interfaces to allow the compiler to map
the information but are otherwise not used by the new ops.

The `--iree-hal-experimental-dispatch2` flag enables emitting the new ops
though no targets currently implement them. Since executables no longer
require pipeline layouts in this simplified model the
`--iree-hal-experimental-executable-create2` flag can be used to stop
passing them.

Progress on #18154.
  • Loading branch information
benvanik committed Aug 10, 2024
1 parent 5a48912 commit 6e58a51
Show file tree
Hide file tree
Showing 30 changed files with 1,639 additions and 123 deletions.
96 changes: 27 additions & 69 deletions compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,

// Check the usage of each binding at each dispatch site.
struct DescriptorInfo {
bool isIndirect = false;
DescriptorFlags flags = DescriptorFlags::None;
};
SmallVector<DescriptorInfo> descriptorInfos(bindingCount);
Expand All @@ -142,12 +141,18 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,
// Opt into indirect descriptors when dynamic values are used from
// execution regions that may be executed more than once.
if (!isRegionExecutedOnce) {
auto resource = dispatchOp.getResources()[i];
Value resource = dispatchOp.getResources()[i];
if (auto blockArg = dyn_cast<BlockArgument>(resource)) {
if (blockArg.getOwner()->getParentOp() == parentOp) {
resource = parentOp.getResourceOperands()[blockArg.getArgNumber()];
}
}
switch (categorizeValue(resource)) {
default:
case ValueOrigin::Unknown:
case ValueOrigin::MutableGlobal:
descriptorInfo.isIndirect |= true;
descriptorInfo.flags =
descriptorInfo.flags | IREE::HAL::DescriptorFlags::Indirect;
break;
case ValueOrigin::LocalConstant:
case ValueOrigin::ImmutableGlobal:
Expand All @@ -173,74 +178,27 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,
pipelineLayout.pushConstantCount = operandCount;
pipelineLayout.resourceMap.resize(bindingCount);

// Today we use one or two sets based on the composition of bindings we have:
// we try to put everything in a directly referenced set 0 and spill over any
// indirectly referenced values into the second set.
//
// HACK: the Vulkan HAL implementation currently cannot handle multiple
// descriptor sets. Ouch. To preserve existing behavior we only use a single
// set and mark the whole thing as indirect if any bindings are indirect.
const bool forceSingleSet = true;
if (forceSingleSet) {
DescriptorSetLayout setLayout;
setLayout.ordinal = 0;
setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
setLayout.bindings.reserve(bindingCount);
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
if (descriptorInfo.isIndirect) {
setLayout.flags =
setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect;
}
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = setLayout.bindings.size();
setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
setBinding.flags = descriptorInfo.flags;
setLayout.bindings.push_back(setBinding);
pipelineLayout.resourceMap[i] =
std::make_pair(setLayout.ordinal, setBinding.ordinal);
}
pipelineLayout.setLayouts.push_back(setLayout);
} else {
DescriptorSetLayout directSetLayout;
directSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
directSetLayout.bindings.reserve(bindingCount);
DescriptorSetLayout indirectSetLayout;
indirectSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::Indirect;
indirectSetLayout.bindings.reserve(bindingCount);

// Ordinals relative to the owning set.
SmallVector<unsigned> bindingSetOrdinals(bindingCount);
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
auto &setLayout =
descriptorInfo.isIndirect ? indirectSetLayout : directSetLayout;
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = setLayout.bindings.size();
setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
setBinding.flags = descriptorInfo.flags;
setLayout.bindings.push_back(setBinding);
bindingSetOrdinals[i] = setBinding.ordinal;
}
unsigned nextSetOrdinal = 0;
if (!directSetLayout.bindings.empty()) {
directSetLayout.ordinal = nextSetOrdinal++;
pipelineLayout.setLayouts.push_back(directSetLayout);
}
if (!indirectSetLayout.bindings.empty()) {
indirectSetLayout.ordinal = nextSetOrdinal++;
pipelineLayout.setLayouts.push_back(indirectSetLayout);
}

// Map each resource to its set/binding ordinals.
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
auto &setLayout =
descriptorInfo.isIndirect ? indirectSetLayout : directSetLayout;
pipelineLayout.resourceMap[i] =
std::make_pair(setLayout.ordinal, bindingSetOrdinals[i]);
// TODO(#18154): simplify binding setup.
DescriptorSetLayout setLayout;
setLayout.ordinal = 0;
setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
setLayout.bindings.reserve(bindingCount);
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
if (allEnumBitsSet(descriptorInfo.flags,
IREE::HAL::DescriptorFlags::Indirect)) {
setLayout.flags =
setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect;
}
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = setLayout.bindings.size();
setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
setBinding.flags = descriptorInfo.flags;
setLayout.bindings.push_back(setBinding);
pipelineLayout.resourceMap[i] =
std::make_pair(setLayout.ordinal, setBinding.ordinal);
}
pipelineLayout.setLayouts.push_back(setLayout);

LLVM_DEBUG({
auto executableOp = exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
Expand Down
10 changes: 0 additions & 10 deletions compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@
namespace mlir::iree_compiler::IREE::HAL {

ValueOrigin categorizeValue(Value value) {
// If this is a captured argument of an execution region then look up to the
// SSA value that was captured.
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (auto closureOp = dyn_cast<IREE::Util::ClosureOpInterface>(
blockArg.getOwner()->getParentOp())) {
return categorizeValue(
closureOp.getClosureOperands()[blockArg.getArgNumber()]);
}
}

// If we wanted to pull in entire IR slices this would have to use a
// worklist (selects of globals based on globals, etc). For now this analysis
// only looks at the value provided.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,162 @@ class CommandBufferDispatchIndirectOpConversion
mutable IREE::VM::ImportOp importOp;
};

class CommandBufferDispatch2OpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferDispatch2Op> {
public:
CommandBufferDispatch2OpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
: OpConversionPattern(typeConverter, context) {
importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
assert(importOp);
}

LogicalResult
matchAndRewrite(IREE::HAL::CommandBufferDispatch2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();

auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
Value zeroI32 = rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc());

auto flags = adaptor.getFlagsAttr()
? rewriter
.create<IREE::VM::ConstI64Op>(
op.getLoc(), adaptor.getFlagsAttr().getInt())
.getResult()
: rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc())
.getResult();
SmallVector<Value, 8> callOperands = {
adaptor.getCommandBuffer(),
adaptor.getExecutable(),
castToImportType(adaptor.getEntryPoint(), i32Type, rewriter),
castToImportType(adaptor.getWorkgroupX(), i32Type, rewriter),
castToImportType(adaptor.getWorkgroupY(), i32Type, rewriter),
castToImportType(adaptor.getWorkgroupZ(), i32Type, rewriter),
flags,
};
SmallVector<int16_t, 5> segmentSizes = {
/*command_buffer=*/-1,
/*executable=*/-1,
/*entry_point=*/-1,
/*workgroup_x=*/-1,
/*workgroup_y=*/-1,
/*workgroup_z=*/-1,
/*flags=*/-1,
/*constants=*/static_cast<int16_t>(adaptor.getConstants().size()),
/*bindings=*/
static_cast<int16_t>(adaptor.getBindingBuffers().size()),
};
llvm::append_range(callOperands, adaptor.getConstants());
for (auto [bindingBufferOrSlot, bindingOffset, bindingLength] :
llvm::zip_equal(adaptor.getBindingBuffers(),
adaptor.getBindingOffsets(),
adaptor.getBindingLengths())) {
callOperands.push_back(zeroI32);
auto [bindingBufferSlot, bindingBuffer] =
splitBufferSlot(op.getLoc(), bindingBufferOrSlot, rewriter);
callOperands.push_back(bindingBufferSlot);
callOperands.push_back(bindingBuffer);
callOperands.push_back(
castToImportType(bindingOffset, i64Type, rewriter));
callOperands.push_back(
castToImportType(bindingLength, i64Type, rewriter));
}

auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes,
importType.getInputs(), callOperands);
copyImportAttrs(importOp, callOp);
return success();
}

private:
mutable IREE::VM::ImportOp importOp;
};

class CommandBufferDispatch2IndirectOpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferDispatch2IndirectOp> {
public:
CommandBufferDispatch2IndirectOpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
: OpConversionPattern(typeConverter, context) {
importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
assert(importOp);
}

LogicalResult
matchAndRewrite(IREE::HAL::CommandBufferDispatch2IndirectOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto importType = importOp.getFunctionType();

auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
Value zeroI32 = rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc());

auto [workgroupsBufferSlot, workgroupsBuffer] =
splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter);
auto flags = adaptor.getFlagsAttr()
? rewriter
.create<IREE::VM::ConstI64Op>(
op.getLoc(), adaptor.getFlagsAttr().getInt())
.getResult()
: rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc())
.getResult();
SmallVector<Value, 8> callOperands = {
adaptor.getCommandBuffer(),
adaptor.getExecutable(),
castToImportType(adaptor.getEntryPoint(), i32Type, rewriter),
workgroupsBufferSlot,
workgroupsBuffer,
castToImportType(adaptor.getWorkgroupsOffset(), i64Type, rewriter),
flags,
};
SmallVector<int16_t, 5> segmentSizes = {
/*command_buffer=*/-1,
/*executable=*/-1,
/*entry_point=*/-1,
/*workgroups_buffer_slot=*/-1,
/*workgroups_buffer=*/-1,
/*workgroups_offset=*/-1,
/*flags=*/-1,
/*constants=*/static_cast<int16_t>(adaptor.getConstants().size()),
/*bindings=*/
static_cast<int16_t>(adaptor.getBindingBuffers().size()),
};
llvm::append_range(callOperands, adaptor.getConstants());
for (auto [bindingBufferOrSlot, bindingOffset, bindingLength] :
llvm::zip_equal(adaptor.getBindingBuffers(),
adaptor.getBindingOffsets(),
adaptor.getBindingLengths())) {
callOperands.push_back(zeroI32);
auto [bindingBufferSlot, bindingBuffer] =
splitBufferSlot(op.getLoc(), bindingBufferOrSlot, rewriter);
callOperands.push_back(bindingBufferSlot);
callOperands.push_back(bindingBuffer);
callOperands.push_back(
castToImportType(bindingOffset, i64Type, rewriter));
callOperands.push_back(
castToImportType(bindingLength, i64Type, rewriter));
}

auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes,
importType.getInputs(), callOperands);
copyImportAttrs(importOp, callOp);
return success();
}

private:
mutable IREE::VM::ImportOp importOp;
};

} // namespace

void populateHALCommandBufferToVMPatterns(MLIRContext *context,
Expand Down Expand Up @@ -468,6 +624,11 @@ void populateHALCommandBufferToVMPatterns(MLIRContext *context,
patterns.insert<CommandBufferDispatchIndirectOpConversion>(
context, importSymbols, typeConverter,
"hal.command_buffer.dispatch.indirect");
patterns.insert<CommandBufferDispatch2OpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.dispatch2");
patterns.insert<CommandBufferDispatch2IndirectOpConversion>(
context, importSymbols, typeConverter,
"hal.command_buffer.dispatch2.indirect");
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,62 @@ class ExecutableCreateOpConversion
mutable IREE::VM::ImportOp importOp;
};

class ExecutableCreate2OpConversion
: public OpConversionPattern<IREE::HAL::ExecutableCreate2Op> {
public:
ExecutableCreate2OpConversion(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
StringRef importName)
: OpConversionPattern(context) {
importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
assert(importOp);
}

LogicalResult
matchAndRewrite(IREE::HAL::ExecutableCreate2Op createOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Materialize vm.rodata for the binary.
auto executableBinaryOp =
SymbolTable::lookupNearestSymbolFrom<IREE::HAL::ExecutableBinaryOp>(
createOp, createOp.getExecutableTarget());
auto executableOp = executableBinaryOp.getOperation()
->getParentOfType<IREE::HAL::ExecutableOp>();
std::string rodataName = sanitizeSymbolName(
(executableOp.getName() + "_" + executableBinaryOp.getName()).str());
auto rodataOp = rewriter.create<IREE::VM::RodataInlineOp>(
executableBinaryOp.getLoc(),
IREE::VM::RefType::get(rewriter.getType<IREE::VM::BufferType>()),
rewriter.getStringAttr(rodataName), executableBinaryOp.getData(),
rewriter.getI64IntegerAttr(16), executableBinaryOp.getMimeTypeAttr());

// Get format string as a rodata blob.
auto executableFormatStr = rewriter.create<IREE::VM::RodataInlineOp>(
createOp.getLoc(), executableBinaryOp.getFormatAttr());

// Pack constants, if any.
auto constantBuffer = createPackedConstantBuffer(
createOp.getLoc(), adaptor.getConstants(), rewriter);

SmallVector<Value, 8> callOperands = {
adaptor.getDevice(),
executableFormatStr,
rodataOp,
constantBuffer,
};
auto importType = importOp.getFunctionType();
auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
createOp, SymbolRefAttr::get(importOp), importType.getResults(),
callOperands);
copyImportAttrs(importOp, callOp);

return success();
}

private:
mutable IREE::VM::ImportOp importOp;
};

} // namespace

void populateHALExecutableToVMPatterns(MLIRContext *context,
Expand All @@ -162,6 +218,8 @@ void populateHALExecutableToVMPatterns(MLIRContext *context,

patterns.insert<ExecutableCreateOpConversion>(
context, importSymbols, typeConverter, "hal.executable.create");
patterns.insert<ExecutableCreate2OpConversion>(
context, importSymbols, typeConverter, "hal.executable.create2");

patterns.insert<VMImportOpConversion<IREE::HAL::DescriptorSetLayoutCreateOp>>(
context, importSymbols, typeConverter,
Expand Down
Loading

0 comments on commit 6e58a51

Please sign in to comment.