Skip to content

Commit

Permalink
Adding simplified HAL dispatch methods.
Browse files Browse the repository at this point in the history
These combine push constants and push descriptor sets into the dispatch
calls as in practice we have a near 1:1 relationship anyway. Pipeline
layouts are still used in HAL interfaces to allow the compiler to map
the information but are otherwise not used by the new ops.

The `--iree-hal-experimental-dispatch2` flag enables emitting the new ops
though no targets currently implement them. Since executables no longer
require pipeline layouts in this simplified model the
`--iree-hal-experimental-executable-create2` flag can be used to stop
passing them.

Progress on #18154.

Signed-off-by: Ben Vanik <[email protected]>
  • Loading branch information
benvanik committed Aug 12, 2024
1 parent 552bb63 commit 59bf417
Show file tree
Hide file tree
Showing 68 changed files with 3,335 additions and 210 deletions.
20 changes: 14 additions & 6 deletions compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,21 @@ class LLVMCPUTargetBackend final : public TargetBackend {
llvmFunc->addParamAttr(i, align16);
}

// Optionally entry points may specify that they require workgroup local
LibraryBuilder::DispatchAttrs dispatchAttrs = {0};

// Entry points may optionally specify that they require workgroup local
// memory. We fetch that value here and plumb it through so the runtime
// knows how much memory to reserve and pass in.
int64_t localMemorySize = exportOp.getWorkgroupLocalMemory()
.value_or(APInt(64, 0))
.getSExtValue();
dispatchAttrs.localMemorySize = exportOp.getWorkgroupLocalMemory()
.value_or(APInt(64, 0))
.getSExtValue();

// Specify the constant and binding information used to validate
// dispatches.
// TODO(#18189): pack per-binding information bitfields.
dispatchAttrs.constantCount = exportOp.getLayout().getPushConstants();
dispatchAttrs.bindingCount =
exportOp.getLayout().getSetLayout(0).getBindings().size();

LibraryBuilder::SourceLocation sourceLocation;
if (options.debugLevel >= 1) {
Expand All @@ -417,8 +426,7 @@ class LLVMCPUTargetBackend final : public TargetBackend {
}
libraryBuilder.addExport(exportOp.getName(), std::move(sourceLocation),
std::move(stageLocations), /*tag=*/"",
LibraryBuilder::DispatchAttrs{localMemorySize},
llvmFunc);
dispatchAttrs, llvmFunc);
}

// Embed source files (if present).
Expand Down
15 changes: 10 additions & 5 deletions compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,22 @@ makeDispatchFunctionType(llvm::LLVMContext &context) {

// %struct.iree_hal_executable_dispatch_attrs_v0_t = type {
// i16,
// i16
// i8,
// i8
// }
static llvm::StructType *makeDispatchAttrsType(llvm::LLVMContext &context) {
if (auto *existingType = llvm::StructType::getTypeByName(
context, "iree_hal_executable_dispatch_attrs_v0_t")) {
return existingType;
}
auto *i8Type = llvm::IntegerType::getInt8Ty(context);
auto *i16Type = llvm::IntegerType::getInt16Ty(context);
auto *type =
llvm::StructType::create(context,
{
i16Type,
i16Type,
i8Type,
i8Type,
},
"iree_hal_executable_dispatch_attrs_v0_t",
/*isPacked=*/false);
Expand Down Expand Up @@ -502,7 +505,7 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) {
bool hasNonDefaultAttrs = llvm::any_of(exports, [](const auto &dispatch) {
return !dispatch.attrs.isDefault();
});
if (!hasNonDefaultAttrs) {
if (hasNonDefaultAttrs) {
SmallVector<llvm::Constant *> exportAttrValues;
for (auto dispatch : exports) {
exportAttrValues.push_back(llvm::ConstantStruct::get(
Expand All @@ -513,8 +516,10 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) {
i16Type, roundUpToAlignment(dispatch.attrs.localMemorySize,
kWorkgroupLocalMemoryPageSize) /
kWorkgroupLocalMemoryPageSize),
// reserved=
llvm::ConstantInt::get(i16Type, 0),
// constant_count=
llvm::ConstantInt::get(i8Type, dispatch.attrs.constantCount),
// binding_count=
llvm::ConstantInt::get(i8Type, dispatch.attrs.bindingCount),
}));
}
exportAttrs = createArrayConstant(libraryName + "_attrs", dispatchAttrsType,
Expand Down
10 changes: 8 additions & 2 deletions compiler/plugins/target/LLVMCPU/LibraryBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,22 @@ class LibraryBuilder {
UNDEFINED = 4u,
};

// IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
// IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
static const int64_t kWorkgroupLocalMemoryPageSize = 4096;

// iree_hal_executable_dispatch_attrs_v0_t
struct DispatchAttrs {
// Required workgroup local memory size, in bytes.
int64_t localMemorySize = 0;
// Total number of 32-bit constants used by the dispatch.
uint8_t constantCount = 0;
// Total number of bindings used by the dispatch.
uint8_t bindingCount = 0;

// True if all values are default and the attributes may be omitted.
constexpr bool isDefault() const { return localMemorySize == 0; }
constexpr bool isDefault() const {
return localMemorySize == 0 && constantCount == 0 && bindingCount == 0;
}
};

// iree_hal_executable_source_location_v0_t
Expand Down
22 changes: 21 additions & 1 deletion compiler/plugins/target/VMVX/VMVXTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ class VMVXTargetBackend final : public TargetBackend {
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
// Add reflection information used at runtime specific to the HAL interface.
SymbolTable symbolTable(variantOp.getInnerModule());
auto vmModule =
*variantOp.getInnerModule().getOps<IREE::VM::ModuleOp>().begin();
SymbolTable symbolTable(vmModule);
for (auto exportOp : variantOp.getBlock().getOps<ExecutableExportOp>()) {
auto funcOp = symbolTable.lookup<IREE::VM::FuncOp>(exportOp.getName());

Expand All @@ -127,6 +129,24 @@ class VMVXTargetBackend final : public TargetBackend {
if (localMemorySizeAttr) {
funcOp.setReflectionAttr("local_memory", localMemorySizeAttr);
}

// Specify the constant and binding information used to validate
// dispatches.
// TODO(#18189): pack per-binding information bitfields.
if (auto layoutAttr = exportOp.getLayout()) {
int64_t constantCount = layoutAttr.getPushConstants();
if (constantCount > 0) {
funcOp.setReflectionAttr("constant_count",
executableBuilder.getI8IntegerAttr(
static_cast<uint8_t>(constantCount)));
}
size_t bindingCount = layoutAttr.getSetLayout(0).getBindings().size();
if (bindingCount > 0) {
funcOp.setReflectionAttr("binding_count",
executableBuilder.getI8IntegerAttr(
static_cast<uint8_t>(bindingCount)));
}
}
}

// Serialize the VM module to bytes and embed it directly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ addSet3IfNeeded(IREE::HAL::PipelineLayoutAttr originalAttr) {
SmallVector<IREE::HAL::DescriptorSetBindingAttr> bindingAttrs;
bindingAttrs.push_back(IREE::HAL::DescriptorSetBindingAttr::get(
originalAttr.getContext(), 0, IREE::HAL::DescriptorType::UniformBuffer,
std::nullopt));
IREE::HAL::DescriptorFlags::None));
setLayoutAttrs.push_back(IREE::HAL::DescriptorSetLayoutAttr::get(
originalAttr.getContext(), 3, bindingAttrs, std::nullopt));
return IREE::HAL::PipelineLayoutAttr::get(originalAttr.getContext(),
Expand Down
99 changes: 28 additions & 71 deletions compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ assumeExportLayout(IREE::HAL::PipelineLayoutAttr layoutAttr) {
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = bindingAttr.getOrdinal();
setBinding.type = bindingAttr.getType();
setBinding.flags =
bindingAttr.getFlags().value_or(IREE::HAL::DescriptorFlags::None);
setBinding.flags = bindingAttr.getFlags();
setLayout.bindings[setBinding.ordinal] = setBinding;
pipelineLayout.resourceMap.emplace_back(setLayout.ordinal,
setBinding.ordinal);
Expand Down Expand Up @@ -123,7 +122,6 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,

// Check the usage of each binding at each dispatch site.
struct DescriptorInfo {
bool isIndirect = false;
DescriptorFlags flags = DescriptorFlags::None;
};
SmallVector<DescriptorInfo> descriptorInfos(bindingCount);
Expand All @@ -142,12 +140,18 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,
// Opt into indirect descriptors when dynamic values are used from
// execution regions that may be executed more than once.
if (!isRegionExecutedOnce) {
auto resource = dispatchOp.getResources()[i];
Value resource = dispatchOp.getResources()[i];
if (auto blockArg = dyn_cast<BlockArgument>(resource)) {
if (blockArg.getOwner()->getParentOp() == parentOp) {
resource = parentOp.getResourceOperands()[blockArg.getArgNumber()];
}
}
switch (categorizeValue(resource)) {
default:
case ValueOrigin::Unknown:
case ValueOrigin::MutableGlobal:
descriptorInfo.isIndirect |= true;
descriptorInfo.flags =
descriptorInfo.flags | IREE::HAL::DescriptorFlags::Indirect;
break;
case ValueOrigin::LocalConstant:
case ValueOrigin::ImmutableGlobal:
Expand All @@ -173,74 +177,27 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,
pipelineLayout.pushConstantCount = operandCount;
pipelineLayout.resourceMap.resize(bindingCount);

// Today we use one or two sets based on the composition of bindings we have:
// we try to put everything in a directly referenced set 0 and spill over any
// indirectly referenced values into the second set.
//
// HACK: the Vulkan HAL implementation currently cannot handle multiple
// descriptor sets. Ouch. To preserve existing behavior we only use a single
// set and mark the whole thing as indirect if any bindings are indirect.
const bool forceSingleSet = true;
if (forceSingleSet) {
DescriptorSetLayout setLayout;
setLayout.ordinal = 0;
setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
setLayout.bindings.reserve(bindingCount);
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
if (descriptorInfo.isIndirect) {
setLayout.flags =
setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect;
}
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = setLayout.bindings.size();
setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
setBinding.flags = descriptorInfo.flags;
setLayout.bindings.push_back(setBinding);
pipelineLayout.resourceMap[i] =
std::make_pair(setLayout.ordinal, setBinding.ordinal);
}
pipelineLayout.setLayouts.push_back(setLayout);
} else {
DescriptorSetLayout directSetLayout;
directSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
directSetLayout.bindings.reserve(bindingCount);
DescriptorSetLayout indirectSetLayout;
indirectSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::Indirect;
indirectSetLayout.bindings.reserve(bindingCount);

// Ordinals relative to the owning set.
SmallVector<unsigned> bindingSetOrdinals(bindingCount);
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
auto &setLayout =
descriptorInfo.isIndirect ? indirectSetLayout : directSetLayout;
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = setLayout.bindings.size();
setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
setBinding.flags = descriptorInfo.flags;
setLayout.bindings.push_back(setBinding);
bindingSetOrdinals[i] = setBinding.ordinal;
}
unsigned nextSetOrdinal = 0;
if (!directSetLayout.bindings.empty()) {
directSetLayout.ordinal = nextSetOrdinal++;
pipelineLayout.setLayouts.push_back(directSetLayout);
}
if (!indirectSetLayout.bindings.empty()) {
indirectSetLayout.ordinal = nextSetOrdinal++;
pipelineLayout.setLayouts.push_back(indirectSetLayout);
}

// Map each resource to its set/binding ordinals.
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
auto &setLayout =
descriptorInfo.isIndirect ? indirectSetLayout : directSetLayout;
pipelineLayout.resourceMap[i] =
std::make_pair(setLayout.ordinal, bindingSetOrdinals[i]);
// TODO(#18154): simplify binding setup.
DescriptorSetLayout setLayout;
setLayout.ordinal = 0;
setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
setLayout.bindings.reserve(bindingCount);
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
if (allEnumBitsSet(descriptorInfo.flags,
IREE::HAL::DescriptorFlags::Indirect)) {
setLayout.flags =
setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect;
}
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = setLayout.bindings.size();
setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
setBinding.flags = descriptorInfo.flags;
setLayout.bindings.push_back(setBinding);
pipelineLayout.resourceMap[i] =
std::make_pair(setLayout.ordinal, setBinding.ordinal);
}
pipelineLayout.setLayouts.push_back(setLayout);

LLVM_DEBUG({
auto executableOp = exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
Expand Down
10 changes: 0 additions & 10 deletions compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@
namespace mlir::iree_compiler::IREE::HAL {

ValueOrigin categorizeValue(Value value) {
// If this is a captured argument of an execution region then look up to the
// SSA value that was captured.
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (auto closureOp = dyn_cast<IREE::Util::ClosureOpInterface>(
blockArg.getOwner()->getParentOp())) {
return categorizeValue(
closureOp.getClosureOperands()[blockArg.getArgNumber()]);
}
}

// If we wanted to pull in entire IR slices this would have to use a
// worklist (selects of globals based on globals, etc). For now this analysis
// only looks at the value provided.
Expand Down
Loading

0 comments on commit 59bf417

Please sign in to comment.