Skip to content

Commit

Permalink
Removing descriptor set layouts from HAL IR and simplifying bindings.
Browse files Browse the repository at this point in the history
* Renamed `push_constants` to `constants` (as there is no longer a
  `push_constants` API)
* Dropped `#hal.descriptor_set.layout`
* Removed ordinal from `#hal.descriptor_set.binding` (as ordinals are
  now implicit)
* Renamed `#hal.descriptor_set.binding` to `#hal.pipeline.binding`
* Removed `set` from `hal.interface.binding.subspan`
* Removed `#hal.interface.binding` and the spooky action at a distance
  `hal.interface.binding` attr now that ordinals are implicit

Progress on #18154.
  • Loading branch information
benvanik committed Aug 26, 2024
1 parent 1efeaca commit 1ec2e5c
Show file tree
Hide file tree
Showing 296 changed files with 6,919 additions and 9,287 deletions.
5 changes: 2 additions & 3 deletions compiler/plugins/target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,10 +672,9 @@ class CUDATargetBackend final : public TargetBackend {
}

auto layoutAttr = exportOp.getLayoutAttr();
uint32_t constantCount =
static_cast<uint32_t>(layoutAttr.getPushConstants());
uint32_t constantCount = static_cast<uint32_t>(layoutAttr.getConstants());
SmallVector<iree_hal_cuda_BindingBits_enum_t> bindingFlags;
for (auto bindingAttr : layoutAttr.getSetLayout(0).getBindings()) {
for (auto bindingAttr : layoutAttr.getBindings()) {
iree_hal_cuda_BindingBits_enum_t flags = 0;
if (allEnumBitsSet(bindingAttr.getFlags(),
IREE::HAL::DescriptorFlags::ReadOnly)) {
Expand Down
5 changes: 2 additions & 3 deletions compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,8 @@ class LLVMCPUTargetBackend final : public TargetBackend {
// Specify the constant and binding information used to validate
// dispatches.
if (auto layoutAttr = exportOp.getLayout()) {
dispatchAttrs.constantCount = layoutAttr.getPushConstants();
dispatchAttrs.bindingCount =
layoutAttr.getSetLayout(0).getBindings().size();
dispatchAttrs.constantCount = layoutAttr.getConstants();
dispatchAttrs.bindingCount = layoutAttr.getBindings().size();
}

LibraryBuilder::SourceLocation sourceLocation;
Expand Down
5 changes: 2 additions & 3 deletions compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,9 @@ class MetalSPIRVTargetBackend : public TargetBackend {
};

auto layoutAttr = exportOp.getLayoutAttr();
uint32_t constantCount =
static_cast<uint32_t>(layoutAttr.getPushConstants());
uint32_t constantCount = static_cast<uint32_t>(layoutAttr.getConstants());
SmallVector<iree_hal_metal_BindingBits_enum_t> bindingFlags;
for (auto bindingAttr : layoutAttr.getSetLayout(0).getBindings()) {
for (auto bindingAttr : layoutAttr.getBindings()) {
iree_hal_metal_BindingBits_enum_t flags = 0;
if (allEnumBitsSet(bindingAttr.getFlags(),
IREE::HAL::DescriptorFlags::ReadOnly)) {
Expand Down
5 changes: 2 additions & 3 deletions compiler/plugins/target/ROCM/ROCMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,9 @@ class ROCMTargetBackend final : public TargetBackend {
}

auto layoutAttr = exportOp.getLayoutAttr();
uint32_t constantCount =
static_cast<uint32_t>(layoutAttr.getPushConstants());
uint32_t constantCount = static_cast<uint32_t>(layoutAttr.getConstants());
SmallVector<iree_hal_hip_BindingBits_enum_t> bindingFlags;
for (auto bindingAttr : layoutAttr.getSetLayout(0).getBindings()) {
for (auto bindingAttr : layoutAttr.getBindings()) {
iree_hal_hip_BindingBits_enum_t flags = 0;
if (allEnumBitsSet(bindingAttr.getFlags(),
IREE::HAL::DescriptorFlags::ReadOnly)) {
Expand Down
4 changes: 2 additions & 2 deletions compiler/plugins/target/VMVX/VMVXTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ class VMVXTargetBackend final : public TargetBackend {
// Specify the constant and binding information used to validate
// dispatches.
if (auto layoutAttr = exportOp.getLayout()) {
int64_t constantCount = layoutAttr.getPushConstants();
int64_t constantCount = layoutAttr.getConstants();
if (constantCount > 0) {
funcOp.setReflectionAttr("constant_count",
executableBuilder.getI8IntegerAttr(
static_cast<uint8_t>(constantCount)));
}
size_t bindingCount = layoutAttr.getSetLayout(0).getBindings().size();
size_t bindingCount = layoutAttr.getBindings().size();
if (bindingCount > 0) {
funcOp.setReflectionAttr("binding_count",
executableBuilder.getI8IntegerAttr(
Expand Down
9 changes: 4 additions & 5 deletions compiler/plugins/target/VMVX/test/smoketest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@ stream.executable public @add_dispatch_0 {
// CHECK-LABEL: hal.executable public @add_dispatch_0
// CHECK-NEXT: hal.executable.variant public @vmvx_bytecode_fb target(<"vmvx", "vmvx-bytecode-fb">) {
// CHECK-NEXT: hal.executable.export public @add_dispatch_0 ordinal(0)
// CHECK-SAME: layout(#hal.pipeline.layout<push_constants = 0, sets = [
// CHECK-SAME: <0, bindings = [
// CHECK-SAME: <0, storage_buffer>,
// CHECK-SAME: <1, storage_buffer>,
// CHECK-SAME: <2, storage_buffer>
// CHECK-SAME: layout(#hal.pipeline.layout<bindings = [
// CHECK-SAME: #hal.pipeline.binding<storage_buffer>,
// CHECK-SAME: #hal.pipeline.binding<storage_buffer>,
// CHECK-SAME: #hal.pipeline.binding<storage_buffer>
// CHECK: module attributes {vm.toplevel} {
// CHECK-NEXT: vm.module public @module {
// CHECK-NEXT: vm.func private @add_dispatch_0(
Expand Down
28 changes: 16 additions & 12 deletions compiler/plugins/target/VulkanSPIRV/VulkanSPIRVTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ struct VulkanSPIRVTargetOptions {
};
} // namespace

using DescriptorSetLayout = std::pair<unsigned, ArrayRef<PipelineBindingAttr>>;

static std::tuple<iree_hal_vulkan_DescriptorSetLayoutDef_vec_ref_t,
iree_hal_vulkan_PipelineLayoutDef_vec_ref_t,
DenseMap<IREE::HAL::PipelineLayoutAttr, uint32_t>>
createPipelineLayoutDefs(ArrayRef<IREE::HAL::ExecutableExportOp> exportOps,
FlatbufferBuilder &fbb) {
DenseMap<IREE::HAL::DescriptorSetLayoutAttr, size_t> descriptorSetLayoutMap;
DenseMap<DescriptorSetLayout, size_t> descriptorSetLayoutMap;
DenseMap<IREE::HAL::PipelineLayoutAttr, uint32_t> pipelineLayoutMap;
SmallVector<iree_hal_vulkan_DescriptorSetLayoutDef_ref_t>
descriptorSetLayoutRefs;
Expand All @@ -77,18 +79,20 @@ createPipelineLayoutDefs(ArrayRef<IREE::HAL::ExecutableExportOp> exportOps,
continue; // already present
}

// Currently only one descriptor set on the compiler side. We could
// partition it by binding type (direct vs indirect, etc).
SmallVector<uint32_t> descriptorSetLayoutOrdinals;
for (auto descriptorSetLayoutAttr : pipelineLayoutAttr.getSetLayouts()) {
auto it = descriptorSetLayoutMap.find(descriptorSetLayoutAttr);
if (it != descriptorSetLayoutMap.end()) {
descriptorSetLayoutOrdinals.push_back(it->second);
continue;
}

auto descriptorSetLayout =
DescriptorSetLayout(0, pipelineLayoutAttr.getBindings());
auto it = descriptorSetLayoutMap.find(descriptorSetLayout);
if (it != descriptorSetLayoutMap.end()) {
descriptorSetLayoutOrdinals.push_back(it->second);
} else {
SmallVector<iree_hal_vulkan_DescriptorSetLayoutBindingDef_ref_t>
bindingRefs;
for (auto bindingAttr : descriptorSetLayoutAttr.getBindings()) {
uint32_t ordinal = static_cast<uint32_t>(bindingAttr.getOrdinal());
for (auto [i, bindingAttr] :
llvm::enumerate(pipelineLayoutAttr.getBindings())) {
uint32_t ordinal = static_cast<uint32_t>(i);
iree_hal_vulkan_VkDescriptorType_enum_t descriptorType = 0;
switch (bindingAttr.getType()) {
case IREE::HAL::DescriptorType::UniformBuffer:
Expand All @@ -107,7 +111,7 @@ createPipelineLayoutDefs(ArrayRef<IREE::HAL::ExecutableExportOp> exportOps,
auto bindingsRef = fbb.createOffsetVecDestructive(bindingRefs);

descriptorSetLayoutOrdinals.push_back(descriptorSetLayoutRefs.size());
descriptorSetLayoutMap[descriptorSetLayoutAttr] =
descriptorSetLayoutMap[descriptorSetLayout] =
descriptorSetLayoutRefs.size();
descriptorSetLayoutRefs.push_back(
iree_hal_vulkan_DescriptorSetLayoutDef_create(fbb, bindingsRef));
Expand All @@ -116,7 +120,7 @@ createPipelineLayoutDefs(ArrayRef<IREE::HAL::ExecutableExportOp> exportOps,
fbb.createInt32Vec(descriptorSetLayoutOrdinals);

iree_hal_vulkan_PushConstantRange_vec_ref_t pushConstantRangesRef = 0;
if (int64_t pushConstantCount = pipelineLayoutAttr.getPushConstants()) {
if (int64_t pushConstantCount = pipelineLayoutAttr.getConstants()) {
SmallVector<iree_hal_vulkan_PushConstantRange> pushConstantRanges;
iree_hal_vulkan_PushConstantRange range0;
range0.stage_flags = 0x00000020u; // VK_SHADER_STAGE_COMPUTE_BIT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ static bool canSetsBeMerged(Value v1, Value v2, BufferizationPlan &plan) {
if (!v1InterfaceBinding || !v2InterfaceBinding) {
return true;
}
if (v1InterfaceBinding.getSet() != v2InterfaceBinding.getSet() ||
v1InterfaceBinding.getBinding() != v2InterfaceBinding.getBinding() ||
if (v1InterfaceBinding.getBinding() != v2InterfaceBinding.getBinding() ||
v1InterfaceBinding.getByteOffset() !=
v2InterfaceBinding.getByteOffset()) {
// If the set, binding or offsets are different, map these to different
Expand Down
Loading

0 comments on commit 1ec2e5c

Please sign in to comment.