From b7381623652bdcfd67ad2cf836f503c9d634c567 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 31 Mar 2021 21:16:02 -0700 Subject: [PATCH] Changing HAL dialect syntax to express all types. (#5239) The previous HAL ops inferred the types of the values they were working on (such as `!hal.buffer` or `!hal.device`); this prevented the specialization of those types required for buffer analysis and static device feature detection. The new syntax uses `op_name<%value : !hal.type>` on the op name, indicating that the op is templated on the given `%value`. Parameters are now mostly encoded in named parens, as in linalg, to remove a lot of the parsing ambiguity that existed when they were comma-separated. Future changes for allocation will use a `!hal.buffer` and changes for device feature detection will use a `!hal.device<@id>`. Other types like `!hal.command_buffer` may also be specialized per-device. There's some partially updated enum support in here that will be improved in the follow-ups; the enums will move into the type specifiers and many of the enums used on ops will go away as well. 
--- .../compiler/Bindings/SIP/Transforms/Passes.h | 5 +- .../Conversion/LinalgToLLVM/ConvertToLLVM.cpp | 12 +- .../test/hal_interface_bindings.mlir | 4 +- .../materialize_launch_configuration.mlir | 4 +- .../test/matmul_vectorization.mlir | 4 +- .../test/tile_and_distribute.mlir | 4 +- .../Conversion/LinalgToNVVM/ConvertToNVVM.cpp | 2 +- .../test/distribute_to_thread.mlir | 2 +- .../LinalgToNVVM/test/pipeline_test.mlir | 2 +- .../LinalgToSPIRV/ConvertToSPIRVPass.cpp | 12 +- .../SplitDispatchFunctionPass.cpp | 2 +- .../test/batch_matmul_vectorization.mlir | 4 +- .../concretize_tile_among_workgroups.mlir | 4 +- .../LinalgToSPIRV/test/convert_to_gpu.mlir | 18 +- .../LinalgToSPIRV/test/convert_to_spirv.mlir | 8 +- .../test/elementwise_vectorization.mlir | 6 +- .../test/fold-gpu-procid-uses.mlir | 12 +- .../test/linalg_tile_and_fuse.mlir | 12 +- .../materialize_launch_configuration.mlir | 2 +- .../materialize_launch_configuration2.mlir | 2 +- .../test/matmul_fused_vectorization.mlir | 2 +- .../test/matmul_vectorization.mlir | 2 +- .../test/memref_vecrotization.mlir | 6 +- .../LinalgToSPIRV/test/pipeline_test.mlir | 6 +- .../test/pipeline_test_cooperative_mat.mlir | 2 +- .../test/split_dispatch_function.mlir | 42 +- .../test/tile_and_vectorize_conv.mlir | 4 +- .../test/tile_and_vectorize_matmul.mlir | 2 +- .../test/workgroup_memory_promotion.mlir | 4 +- .../HAL/Conversion/ConversionTarget.cpp | 9 +- .../Dialect/HAL/Conversion/FlowToHAL/BUILD | 1 - .../HAL/Conversion/FlowToHAL/CMakeLists.txt | 1 - .../Conversion/FlowToHAL/ConvertFlowToHAL.cpp | 6 - .../Conversion/FlowToHAL/ConvertStreamOps.cpp | 47 +- .../Conversion/FlowToHAL/ConvertTensorOps.cpp | 54 +- .../Conversion/FlowToHAL/test/stream_ops.mlir | 149 ++- .../Conversion/FlowToHAL/test/tensor_ops.mlir | 86 +- .../HALToHAL/ConvertConstantOps.cpp | 3 +- .../HALToHAL/test/constant_ops.mlir | 8 +- .../HALToVM/test/allocator_ops.mlir | 8 +- .../Conversion/HALToVM/test/buffer_ops.mlir | 38 +- 
.../HALToVM/test/command_buffer_ops.mlir | 75 +- .../Conversion/HALToVM/test/constant_ops.mlir | 17 +- .../Conversion/HALToVM/test/device_ops.mlir | 2 +- .../HALToVM/test/executable_ops.mlir | 4 +- .../Conversion/IREEToHAL/ConvertIREEToHAL.cpp | 25 +- .../IREEToHAL/test/shape_constants.mlir | 11 +- .../HAL/Conversion/StandardToHAL/BUILD | 5 + .../Conversion/StandardToHAL/CMakeLists.txt | 5 + .../StandardToHAL/ConvertConstantOps.cpp | 92 ++ .../ConvertShapeOps.cpp} | 4 +- .../StandardToHAL/ConvertStandardToHAL.cpp | 17 + .../HAL/Conversion/StandardToHAL/test/BUILD | 5 +- .../StandardToHAL/test/CMakeLists.txt | 1 + .../StandardToHAL/test/constant_ops.mlir | 27 + .../Dialect/HAL/Conversion/TypeConverter.cpp | 15 +- .../Dialect/HAL/Conversion/TypeConverter.h | 2 +- iree/compiler/Dialect/HAL/IR/BUILD | 28 +- iree/compiler/Dialect/HAL/IR/CMakeLists.txt | 20 +- iree/compiler/Dialect/HAL/IR/HALBase.td | 87 +- iree/compiler/Dialect/HAL/IR/HALInterfaces.td | 82 ++ iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp | 109 +- iree/compiler/Dialect/HAL/IR/HALOps.cpp | 453 +++----- iree/compiler/Dialect/HAL/IR/HALOps.td | 985 +++++++++--------- iree/compiler/Dialect/HAL/IR/HALTypes.cpp | 294 +++++- iree/compiler/Dialect/HAL/IR/HALTypes.h | 43 +- iree/compiler/Dialect/HAL/IR/test/BUILD | 1 + .../Dialect/HAL/IR/test/CMakeLists.txt | 1 + .../HAL/IR/test/allocator_op_folding.mlir | 36 + .../Dialect/HAL/IR/test/allocator_ops.mlir | 152 ++- .../Dialect/HAL/IR/test/attributes.mlir | 6 +- .../Dialect/HAL/IR/test/buffer_folding.mlir | 32 +- .../Dialect/HAL/IR/test/buffer_ops.mlir | 105 +- .../HAL/IR/test/buffer_view_folding.mlir | 35 +- .../Dialect/HAL/IR/test/buffer_view_ops.mlir | 36 +- .../HAL/IR/test/command_buffer_folding.mlir | 52 +- .../HAL/IR/test/command_buffer_ops.mlir | 214 ++-- .../HAL/IR/test/descriptor_set_ops.mlir | 19 +- .../Dialect/HAL/IR/test/device_ops.mlir | 25 +- .../Dialect/HAL/IR/test/executable_ops.mlir | 48 +- .../Dialect/HAL/IR/test/experimental_ops.mlir | 6 +- 
.../Dialect/HAL/IR/test/semaphore_ops.mlir | 30 +- .../HAL/Target/SPIRVCommon/SPIRVTarget.cpp | 2 +- .../Dialect/HAL/Target/TargetBackend.cpp | 8 +- .../Dialect/HAL/Target/TargetBackend.h | 24 +- .../HAL/Target/VMLA/test/i1_types.mlir | 2 +- .../Dialect/HAL/Target/VMLA/test/linking.mlir | 72 +- .../HAL/Target/VMLA/test/smoketest.mlir | 8 +- .../Target/VulkanSPIRV/test/smoketest.mlir | 2 +- .../MaterializeConstantPoolBuffers.cpp | 58 +- .../HAL/Transforms/MaterializeInterfaces.cpp | 13 +- .../Transforms/MaterializeResourceCaches.cpp | 3 +- iree/compiler/Dialect/HAL/Transforms/Passes.h | 2 +- .../HAL/Transforms/PublicAbiGeneration.cpp | 58 +- .../test/benchmark_batch_dispatches.mlir | 36 +- .../Transforms/test/cse_variable_loads.mlir | 125 +-- .../test/identify_constant_pools.mlir | 8 +- .../test/inline_device_switches.mlir | 16 +- .../materialize_constant_pool_buffers.mlir | 35 +- .../test/materialize_interfaces.mlir | 14 +- .../test/materialize_resource_caches.mlir | 79 +- .../test/memoize_device_queries.mlir | 10 +- .../test/pack_constant_pool_storage.mlir | 16 +- .../propagate_constant_workgroup_info.mlir | 2 +- .../test/public_abi_generation.mlir | 28 +- .../test/resolve_entry_point_ordinals.mlir | 166 +-- iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp | 57 +- iree/compiler/Dialect/HAL/Utils/TypeUtils.h | 10 +- .../Conversion/ConversionPatterns.cpp | 6 +- .../Conversion/test/convert_hal_to_vm.mlir | 8 +- .../Conversion/test/convert_to_hal.mlir | 6 +- .../Conversion/HALToVMLA/ConvertHALToVMLA.cpp | 6 +- iree/modules/check/test/success.mlir | 8 +- iree/modules/tensorlist/tensorlist_test.mlir | 80 +- .../simple_embedding/simple_embedding_test.cc | 41 +- 115 files changed, 2750 insertions(+), 2066 deletions(-) create mode 100644 iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertConstantOps.cpp rename iree/compiler/Dialect/HAL/Conversion/{FlowToHAL/ConvertShapeQueryOps.cpp => StandardToHAL/ConvertShapeOps.cpp} (96%) create mode 100644 
iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/constant_ops.mlir create mode 100644 iree/compiler/Dialect/HAL/IR/HALInterfaces.td create mode 100644 iree/compiler/Dialect/HAL/IR/test/allocator_op_folding.mlir diff --git a/iree/compiler/Bindings/SIP/Transforms/Passes.h b/iree/compiler/Bindings/SIP/Transforms/Passes.h index 81994d4c4dc6..ce6e766e0c37 100644 --- a/iree/compiler/Bindings/SIP/Transforms/Passes.h +++ b/iree/compiler/Bindings/SIP/Transforms/Passes.h @@ -48,7 +48,10 @@ std::unique_ptr> createMaterializeReflectionAttrsPass(); // Register all Passes //===----------------------------------------------------------------------===// -inline void registerPasses() { createMaterializeReflectionAttrsPass(); } +inline void registerPasses() { + registerTransformPassPipeline(); + createMaterializeReflectionAttrsPass(); +} } // namespace SIP } // namespace IREE diff --git a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp index 8d2883fa0c0d..e1550dcb90d7 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp @@ -505,9 +505,9 @@ class ConvertHALInterfaceBindingSubspanOp : public ConvertToLLVMPattern { cast(op).queryBindingOp(); IREE::HAL::InterfaceBindingSubspanOpAdaptor newOperands(operands); MemRefType memRefType = op->getResult(0).getType().cast(); - auto memRefDesc = - abi.loadBinding(op->getLoc(), interfaceBindingOp.binding(), - newOperands.byte_offset(), memRefType, rewriter); + auto memRefDesc = abi.loadBinding( + op->getLoc(), interfaceBindingOp.binding().getZExtValue(), + newOperands.byte_offset(), memRefType, rewriter); rewriter.replaceOp(op, {memRefDesc}); return success(); } @@ -532,9 +532,9 @@ class ConvertLegacyPlaceholderOp : public ConvertToLLVMPattern { SymbolTable::lookupNearestSymbolFrom( op, op->getAttrOfType("binding"))); MemRefType memRefType = op->getResult(0).getType().cast(); - auto memRefDesc = - 
abi.loadBinding(op->getLoc(), interfaceBindingOp.binding(), - /*baseOffset=*/{}, memRefType, rewriter); + auto memRefDesc = abi.loadBinding( + op->getLoc(), interfaceBindingOp.binding().getZExtValue(), + /*baseOffset=*/{}, memRefType, rewriter); rewriter.replaceOp(op, {memRefDesc}); return success(); } diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_bindings.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_bindings.mlir index 2f03d71b80ac..a6ceebf45369 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_bindings.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/hal_interface_bindings.mlir @@ -21,7 +21,7 @@ func @binding_ptrs() { "test.sink"(%memref) : (memref) -> () return } -hal.interface @io attributes {push_constants = 2 : i32, sym_visibility = "private"} { +hal.interface @io attributes {push_constants = 2 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write" } @@ -57,7 +57,7 @@ func @tie_shape() { "test.sink"(%tied_memref) : (memref) -> () return } -hal.interface @io attributes {push_constants = 2 : i32, sym_visibility = "private"} { +hal.interface @io attributes {push_constants = 2 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write" } diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir index 4be633bd6b41..acc5f2c9af1d 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/materialize_launch_configuration.mlir @@ -8,7 +8,7 @@ hal.executable @matmul_tensors attributes {sym_visibility = 
"private"} { } hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_tensors attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -96,7 +96,7 @@ hal.executable @add attributes {sym_visibility = "private"} { } hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @add attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir index 98e9489e67f9..1ed1e3c3068b 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir @@ -9,7 +9,7 @@ hal.executable @dynamic_matmul attributes {sym_visibility = "private"} { } hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_128x128x128 attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -90,7 +90,7 @@ hal.executable @dynamic_matmul_i8_i8_i32 attributes {sym_visibility = "private"} } hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_i8_i8_i32_128x128x128 attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir index 
3da07ad80ff3..4bf05388b827 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir +++ b/iree/compiler/Conversion/LinalgToLLVM/test/tile_and_distribute.mlir @@ -9,7 +9,7 @@ // } // hal.executable.target @llvm_aot, filter="dylib*" { // hal.executable.entry_point @dynamic_matmul attributes { -// interface = @legacy_io, ordinal = 0 : i32, +// interface = @legacy_io, ordinal = 0 : index, // signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, // !flow.dispatch.tensor) -> ()} // module { @@ -57,7 +57,7 @@ hal.executable @static_matmul attributes {sym_visibility = "private"} { } hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @static_matmul attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { diff --git a/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp index 75708aef2887..12f7ccf51fc7 100644 --- a/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp +++ b/iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp @@ -112,7 +112,7 @@ class ConvertIREEBindingOp : public ConvertToLLVMPattern { op, op->getAttrOfType("binding")); auto interfaceBindingOp = cast(symbol); Value llvmBufferBasePtr = - llvmFuncOp.getArgument(interfaceBindingOp.binding()); + llvmFuncOp.getArgument(interfaceBindingOp.binding().getZExtValue()); if (memrefType.hasStaticShape()) { auto desc = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), memrefType, llvmBufferBasePtr); diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/distribute_to_thread.mlir b/iree/compiler/Conversion/LinalgToNVVM/test/distribute_to_thread.mlir index 90a304a1cc33..dafd5953ee2c 100644 --- a/iree/compiler/Conversion/LinalgToNVVM/test/distribute_to_thread.mlir +++ b/iree/compiler/Conversion/LinalgToNVVM/test/distribute_to_thread.mlir 
@@ -2,7 +2,7 @@ hal.executable @add_dispatch_0 attributes {sym_visibility = "private"} { hal.executable.target @cuda, filter="cuda" { - hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} + hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { func @add_dispatch_0() { %c0 = constant 0 : index diff --git a/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir index 9d6a7fd3a645..864a7f2c594b 100644 --- a/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir +++ b/iree/compiler/Conversion/LinalgToNVVM/test/pipeline_test.mlir @@ -9,7 +9,7 @@ hal.executable @simpleMath_ex_dispatch_0 { hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" } hal.executable.target @cuda, filter="cuda" { - hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} + hal.executable.entry_point @add_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { func @add_dispatch_0() { %c0 = constant 0 : index diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp index 7b55c8e8657f..a41d51b9eb33 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToSPIRVPass.cpp @@ -158,9 +158,9 @@ IREE::HAL::InterfaceBindingOp getBindingOp(Operation *op) { } /// Returns the (set, binding) pair for the given placeholder op. 
-std::pair getPlaceholderSetAndBinding(Operation *op) { +std::pair getPlaceholderSetAndBinding(Operation *op) { IREE::HAL::InterfaceBindingOp bindingOp = getBindingOp(op); - return {bindingOp.set(), bindingOp.binding()}; + return {bindingOp.set().getSExtValue(), bindingOp.binding().getSExtValue()}; } /// Returns the set of resources that should be marked as aliased in SPIR-V. @@ -259,8 +259,8 @@ struct InterfaceOpConverter final : public OpConversionPattern { // placeholder op's pointer address as the `id`. spirv::GlobalVariableOp varOp = insertResourceVariable( interfaceOp.getLoc(), convertedType, - reinterpret_cast(interfaceOp.getOperation()), bindingOp.set(), - bindingOp.binding(), + reinterpret_cast(interfaceOp.getOperation()), + bindingOp.set().getZExtValue(), bindingOp.binding().getZExtValue(), aliasedResources.contains(interfaceOp.getOperation()), *moduleOp.getBody(), rewriter); @@ -484,8 +484,10 @@ LogicalResult HALInterfaceLoadConstantConverter::matchAndRewrite( auto halInterfaceOps = llvm::to_vector<1>(moduleOp.getOps()); assert(halInterfaceOps.size() == 1); + assert(halInterfaceOps.front().push_constants().hasValue()); - unsigned elementCount = *halInterfaceOps.front().push_constants(); + uint64_t elementCount = + (*halInterfaceOps.front().push_constants()).getZExtValue(); unsigned offset = loadOp.offset().getZExtValue(); // The following function generates SPIR-V ops with i32 types. 
So it does type diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp index a5e5c109c475..eaf6fa10debd 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp +++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp @@ -283,7 +283,7 @@ LogicalResult SplitDispatchFunctionPass::splitDispatchFunction( builder.clone(*oldEntryPointOp.getOperation())); clonedEntryPointOp.sym_nameAttr(builder.getStringAttr(newFnName)); clonedEntryPointOp.ordinalAttr( - builder.getI32IntegerAttr(static_cast(entryPoints.size()))); + builder.getIndexAttr(static_cast(entryPoints.size()))); entryPoints.push_back(builder.getSymbolRefAttr(clonedEntryPointOp)); } diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir index c44336f0ab90..f8b559dc5c16 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir @@ -8,7 +8,7 @@ hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private" } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @batch_matmul_static_shape attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -298,7 +298,7 @@ hal.executable @batch_matmul_fused_fillop attributes {sym_visibility = "private" } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @batch_matmul_fused_fillop attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { diff 
--git a/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir index 140dd0aae13d..9e174cb709f9 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/concretize_tile_among_workgroups.mlir @@ -8,7 +8,7 @@ hal.executable @conv2d_static_shape attributes {sym_visibility = "private"} { } hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @conv2d_static_shape attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @conv2d_static_shape() { @@ -119,7 +119,7 @@ hal.executable @matmul_dynamic_shape attributes {sym_visibility = "private"} { } hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @matmul_dynamic_shape attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @matmul_dynamic_shape() { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir index b57bec847e69..29b3dedd564b 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir @@ -10,7 +10,7 @@ // } // hal.executable.target @vulkan, filter="vulkan*" { // hal.executable.entry_point @parallel_4D attributes { -// interface = @legacy_io, ordinal = 0 : i32, +// interface = @legacy_io, ordinal = 0 : index, // signature = 
(!flow.dispatch.tensor, !flow.dispatch.tensor, // !flow.dispatch.tensor) -> ()} // module attributes { @@ -89,7 +89,7 @@ hal.executable @parallel_4D_static attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @parallel_4D_static attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -168,7 +168,7 @@ hal.executable @scalar_add attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @scalar_add attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -222,7 +222,7 @@ hal.executable @reduce_sum attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @reduce_sum attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -295,7 +295,7 @@ hal.executable @matmul attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @matmul attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -367,7 +367,7 @@ hal.executable @conv_1d attributes {sym_visibility = "private"} { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vulkan_spirv, filter="vulkan*" { - hal.executable.entry_point @conv_1d attributes {interface = @legacy_io, ordinal = 0 : 
i32, signature = (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>} + hal.executable.entry_point @conv_1d attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<3x8x1xf32>, tensor<3x1x1xf32>) -> tensor<3x6x1xf32>} module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { func @conv_1d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} { %cst = constant 0.000000e+00 : f32 @@ -426,7 +426,7 @@ hal.executable @conv_no_padding attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @conv_no_padding attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -542,7 +542,7 @@ hal.executable @conv_3d attributes {sym_visibility = "private"} { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vulkan_spirv, filter="vulkan*" { - hal.executable.entry_point @conv_3d attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>} + hal.executable.entry_point @conv_3d attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<2x8x8x8x3xf32>, tensor<2x2x2x3x2xf32>) -> tensor<2x7x7x7x2xf32>} module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { 
func @conv_3d() attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}} { %cst = constant 0.000000e+00 : f32 @@ -603,7 +603,7 @@ module { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vulkan, filter="vulkan*" { - hal.executable.entry_point @pooling_nhwc_max attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} { + hal.executable.entry_point @pooling_nhwc_max attributes {interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} { ^bb0(%arg0: index, %arg1: index, %arg2: index): // no predecessors %c4 = constant 4 : index %c1 = constant 1 : index diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir index c6e2272d7178..4bb442417b8b 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_spirv.mlir @@ -14,7 +14,7 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} @@ -61,7 +61,7 @@ hal.executable @elementwise_static_shape attributes {sym_visibility = "private"} // Negative test as we currently don't support vectorization when there is a // transpose. 
// CHECK-LABEL: func @elementwise_transpose -// CHECK-NOT: vector.transfer_read +// CHECK-NOT: vector.transfer_read // CHECK: linalg.generic hal.executable @elementwise_transpose attributes {sym_visibility = "private"} { hal.interface @legacy_io { @@ -71,7 +71,7 @@ hal.executable @elementwise_transpose attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @elementwise_transpose attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir index 8adf600f5a96..4baaf841af65 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/fold-gpu-procid-uses.mlir @@ -5,7 +5,7 @@ hal.executable @fold_block_id attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @fold_block_id attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = () -> ()} { ^bb0(%arg0 : index, %arg1 : index, %arg2 : index): %x = constant 112: index @@ -39,7 +39,7 @@ hal.executable @fold_interface_workgroup_id attributes {sym_visibility = "privat } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @fold_interface_workgroup_id attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = () -> ()} { ^bb0(%arg0 : index, %arg1 : index, %arg2 : index): %x = constant 112: index @@ -73,7 +73,7 @@ hal.executable @fold_thread_id attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @fold_thread_id attributes { - interface = 
@legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = () -> ()} module { func @fold_thread_id() -> (index, index, index) @@ -102,7 +102,7 @@ hal.executable @does_not_fold_mod attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @does_not_fold_mod attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = () -> ()} module { func @does_not_fold_mod() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} { @@ -123,7 +123,7 @@ hal.executable @does_not_fold_div attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @does_not_fold_div attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = () -> ()} module { func @does_not_fold_div() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} { @@ -144,7 +144,7 @@ hal.executable @does_not_fold_symbol_mul_symbol attributes {sym_visibility = "pr } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @does_not_fold_symbol_mul_symbol attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = () -> ()} module { func @does_not_fold_symbol_mul_symbol() -> index attributes {spv.entry_point_abi = {local_size = dense<[8, 2, 1]> : vector<3xi32>}} { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir index cf4502911f58..b6c1e0166c6a 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir @@ -9,7 +9,7 @@ hal.executable @conv_no_padding attributes {sym_visibility = "private"} { } 
hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @conv_no_padding attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -76,7 +76,7 @@ hal.executable @matmul attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -143,7 +143,7 @@ hal.executable @pooling_nhwc_max attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @pooling_nhwc_max attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -207,7 +207,7 @@ hal.executable @matmul_fusion attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_fusion attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -276,7 +276,7 @@ hal.executable @conv_no_padding_fusion attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @conv_no_padding_fusion attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -346,7 +346,7 @@ hal.executable @three_op_fusion attributes {sym_visibility = 
"private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @three_op_fusion attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir index 07cec3eab73c..1617b52e172b 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration.mlir @@ -8,7 +8,7 @@ hal.executable @matmul_tensors attributes {sym_visibility = "private"} { } hal.executable.target @llvm_aot, filter="dylib*" { hal.executable.entry_point @matmul_tensors attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir index 91bbf2c3c394..c91ef9c3a7b4 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/materialize_launch_configuration2.mlir @@ -8,7 +8,7 @@ hal.executable @add attributes {sym_visibility = "private"} { } hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @add attributes { - interface = 
@legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, SwiftShader:CPU, {cooperative_matrix_properties_nv = [], max_compute_shared_memory_size = 16384 : i32, max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>, subgroup_size = 4 : i32}>} { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir index 968a9bf6724d..5b03e914994f 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir @@ -8,7 +8,7 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir index 4becaee0e3df..b0803a0762d3 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir @@ -9,7 +9,7 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module 
attributes { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir index cbd18c2a0d3f..42b179a6c806 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/memref_vecrotization.mlir @@ -54,7 +54,7 @@ func @resource_copy() { return } -hal.interface @legacy_io attributes {push_constants = 5 : i32, sym_visibility = "private"} { +hal.interface @legacy_io attributes {push_constants = 5 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write" } @@ -80,7 +80,7 @@ func @resource_copy_f16() { return } -hal.interface @legacy_io attributes {push_constants = 5 : i32, sym_visibility = "private"} { +hal.interface @legacy_io attributes {push_constants = 5 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write" } @@ -106,7 +106,7 @@ func @resource_copy_8xf16() { return } -hal.interface @legacy_io attributes {push_constants = 5 : i32, sym_visibility = "private"} { +hal.interface @legacy_io attributes {push_constants = 5 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=1, binding=2, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=3, binding=4, type="StorageBuffer", access="Write" } diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir index 165a7f3b4f02..9e03a7e53a01 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test.mlir @@ -8,7 +8,7 @@ hal.executable @matmul_static_shape attributes {sym_visibility 
= "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -67,7 +67,7 @@ hal.executable @matmul_fill_fused attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_fill_fused attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -129,7 +129,7 @@ hal.executable @matmul_add_fused attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_add_fused attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir index d928cc6e3828..3957e6867e3a 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/pipeline_test_cooperative_mat.mlir @@ -9,7 +9,7 @@ hal.executable @matmul_static_shape attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_static_shape attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { diff --git 
a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir index 3b3cc38beb21..bd222a2a220b 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir @@ -8,7 +8,7 @@ hal.executable @kernel_fusable_fill_conv1d_ops attributes {sym_visiblity = "priv } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_conv1d_ops attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -36,7 +36,7 @@ hal.executable @kernel_fusable_fill_conv1d_ops attributes {sym_visiblity = "priv outs(%ts2 : memref) return } - hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} { + hal.interface @legacy_io attributes {push_constants = 1 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" @@ -54,7 +54,7 @@ hal.executable @kernel_fusable_fill_conv2d_ops attributes {sym_visiblity = "priv } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_conv2d_ops attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -82,7 +82,7 @@ hal.executable @kernel_fusable_fill_conv2d_ops attributes {sym_visiblity = "priv outs(%ts2 : memref) return } - hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} { + hal.interface @legacy_io 
attributes {push_constants = 1 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" @@ -101,7 +101,7 @@ hal.executable @kernel_fusable_fill_conv3d_ops attributes {sym_visiblity = "priv } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_conv3d_ops attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -129,7 +129,7 @@ hal.executable @kernel_fusable_fill_conv3d_ops attributes {sym_visiblity = "priv outs(%ts2 : memref) return } - hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} { + hal.interface @legacy_io attributes {push_constants = 1 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" @@ -148,7 +148,7 @@ hal.executable @kernel_fusable_fill_matmul_ops attributes {sym_visiblity = "priv } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_matmul_ops attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -176,7 +176,7 @@ hal.executable @kernel_fusable_fill_matmul_ops attributes {sym_visiblity = "priv outs(%ts3 : memref) return } - hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} { + hal.interface @legacy_io 
attributes {push_constants = 1 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" @@ -195,7 +195,7 @@ hal.executable @kernel_fusable_pooling attributes {sym_visiblity = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_pooling attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -235,7 +235,7 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} // CHECK: hal.executable.entry_point @kernel_dispatch_0 @@ -283,7 +283,7 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { linalg.fill(%ts2, %cst) : memref, f32 return } - hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} { + hal.interface @legacy_io attributes {push_constants = 1 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" @@ -302,7 +302,7 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal 
= 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} // CHECK: hal.executable.entry_point @kernel_dispatch_0 @@ -363,7 +363,7 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { outs(%ts2 : memref) return } - hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} { + hal.interface @legacy_io attributes {push_constants = 1 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" @@ -384,7 +384,7 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} // CHECK-NOT: hal.entry_point_schedule @@ -426,7 +426,7 @@ hal.executable @kernel attributes {sym_visiblity = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -464,7 +464,7 @@ hal.executable @subview_interleaved attributes {sym_visiblity = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @subview_interleaved attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { func @subview_interleaved() { @@ -515,7 +515,7 @@ hal.executable @reshape_interleaved attributes {sym_visiblity = 
"private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @reshape_interleaved attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -575,7 +575,7 @@ hal.executable @predict_ex_dispatch_0 attributes {sym_visiblity = "private"} { } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @predict_ex_dispatch_0 attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -595,7 +595,7 @@ hal.executable @predict_ex_dispatch_0 attributes {sym_visiblity = "private"} { } return } - hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} { + hal.interface @legacy_io attributes {push_constants = 1 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" @@ -636,7 +636,7 @@ hal.executable @kernel_fusable_fill_matmul_generic_ops attributes {sym_visiblity } hal.executable.target @vulkan, filter="vulkan*" { hal.executable.entry_point @kernel_fusable_fill_matmul_generic_ops attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module { @@ -681,7 +681,7 @@ hal.executable @kernel_fusable_fill_matmul_generic_ops attributes {sym_visiblity } return } - hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} { + hal.interface @legacy_io attributes 
{push_constants = 1 : index, sym_visibility = "private"} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @arg2, set=0, binding=1, type="StorageBuffer", access="Read" diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir index d37146c2d5da..c884f0610b2b 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_conv.mlir @@ -8,7 +8,7 @@ hal.executable @conv_static_shape_f32 attributes {sym_visibility = "private"} { } hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @conv_static_shape_f32 attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @conv_static_shape_f32() { @@ -96,7 +96,7 @@ hal.executable @depthwise_conv_static_shape_f32 attributes {sym_visibility = "pr } hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @depthwise_conv_static_shape_f32 attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @depthwise_conv_static_shape_f32() { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir index 4095948d9b4a..54cf7e29c9ee 100644 --- 
a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir @@ -8,7 +8,7 @@ hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} } hal.executable.target @vulkan_spirv, filter="vulkan*" { hal.executable.entry_point @matmul_static_shape_f16 attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes {spv.target_env = #spv.target_env<#spv.vce, ARM:IntegratedGPU, {}>} { func @matmul_static_shape_f16() { diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir index ceba3ff6d084..dab99eeb21d4 100644 --- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir +++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir @@ -9,7 +9,7 @@ hal.executable @matmul_tile attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @matmul_tile attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { @@ -67,7 +67,7 @@ hal.executable @conv_no_padding_tile attributes {sym_visibility = "private"} { } hal.executable.target @vulkan, filter="dylib*" { hal.executable.entry_point @conv_no_padding_tile attributes { - interface = @legacy_io, ordinal = 0 : i32, + interface = @legacy_io, ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, !flow.dispatch.tensor) -> ()} module attributes { diff --git a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp index 
02f7411f2985..811df15ffa66 100644 --- a/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp @@ -75,7 +75,7 @@ LogicalResult HALConversionTarget::applyDefaultBufferRewrite( for (auto srcDstOperand : llvm::zip(srcOp->getOperands(), operands)) { auto srcOperand = std::get<0>(srcDstOperand); auto dstOperand = std::get<1>(srcDstOperand); - if (HALTypeConverter::shouldConvertToHalBuffer(srcOperand.getType())) { + if (HALTypeConverter::shouldConvertToBuffer(srcOperand.getType())) { // Create the buffer view that we'll pass to the function. // Note that we expect this to be CSE'd if there are multiple calls // using the same buffer. @@ -95,7 +95,7 @@ LogicalResult HALConversionTarget::applyDefaultBufferRewrite( } } for (auto resultType : srcOp->getResultTypes()) { - if (HALTypeConverter::shouldConvertToHalBuffer(resultType)) { + if (HALTypeConverter::shouldConvertToBuffer(resultType)) { state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext())); } else { // Normal pass-through result. 
@@ -114,9 +114,10 @@ LogicalResult HALConversionTarget::applyDefaultBufferRewrite( Type resultType; Value resultValue; std::tie(resultType, resultValue) = resultTypeValue; - if (HALTypeConverter::shouldConvertToHalBuffer(resultType)) { + if (HALTypeConverter::shouldConvertToBuffer(resultType)) { results.push_back(rewriter.createOrFold( - srcOp->getLoc(), resultValue)); + srcOp->getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), + resultValue)); } else { results.push_back(resultValue); } diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD index 5fc818974ac9..51ca8976eb42 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/BUILD @@ -22,7 +22,6 @@ cc_library( name = "FlowToHAL", srcs = [ "ConvertFlowToHAL.cpp", - "ConvertShapeQueryOps.cpp", "ConvertStreamOps.cpp", "ConvertTensorOps.cpp", "ConvertVariableOps.cpp", diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/CMakeLists.txt index 1159f459099f..449d404fcb5a 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/CMakeLists.txt @@ -17,7 +17,6 @@ iree_cc_library( "ConvertFlowToHAL.h" SRCS "ConvertFlowToHAL.cpp" - "ConvertShapeQueryOps.cpp" "ConvertStreamOps.cpp" "ConvertTensorOps.cpp" "ConvertVariableOps.cpp" diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp index cec7acefff1a..f9a65f4571b4 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertFlowToHAL.cpp @@ -36,11 +36,6 @@ void populateFlowVariableToHALPatterns(MLIRContext *context, OwningRewritePatternList &patterns, TypeConverter &converter); -// Populates only the std.dim and std.rank 
conversion patterns. -void populateHalBufferViewShapePatterns(MLIRContext *context, - OwningRewritePatternList &patterns, - TypeConverter &converter); - void setupFlowToHALLegality(MLIRContext *context, ConversionTarget &conversionTarget, TypeConverter &typeConverter) { @@ -54,7 +49,6 @@ void populateFlowToHALPatterns(MLIRContext *context, populateFlowStreamToHALPatterns(context, patterns, typeConverter); populateFlowTensorToHALPatterns(context, patterns, typeConverter); populateFlowVariableToHALPatterns(context, patterns, typeConverter); - populateHalBufferViewShapePatterns(context, patterns, typeConverter); } } // namespace iree_compiler diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp index ad61d2843f6e..6eaaea3c74eb 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp @@ -110,11 +110,11 @@ static Value allocateOutputBuffer(Value streamValue, Value externalValue, loc, allocator, *shape, elementType.getValue()) .getResult(); - auto buffer = - rewriter - .create(loc, allocator, memoryTypes, - bufferUsage, allocationSize) - .getResult(); + auto buffer = rewriter + .create( + loc, IREE::HAL::BufferType::get(rewriter.getContext()), + allocator, memoryTypes, bufferUsage, allocationSize) + .getResult(); return buffer; } @@ -183,11 +183,11 @@ static Value allocateTransientBuffer(Value streamValue, Value allocator, loc, allocator, *shape, elementType.getValue()) .getResult(); - auto buffer = - rewriter - .create(loc, allocator, memoryTypes, - bufferUsage, allocationSize) - .getResult(); + auto buffer = rewriter + .create( + loc, IREE::HAL::BufferType::get(rewriter.getContext()), + allocator, memoryTypes, bufferUsage, allocationSize) + .getResult(); return buffer; } @@ -331,14 +331,17 @@ static void recordPushConstants(Value device, Value commandBuffer, return; } - 
uint64_t maxPushConstants = interfaceOp.push_constants().getValueOr(0); + uint64_t maxPushConstants = + interfaceOp.push_constants().hasValue() + ? interfaceOp.push_constants().getValue().getZExtValue() + : 0; (void)maxPushConstants; assert(pushConstantValues.size() <= maxPushConstants && "uniform buffer spilling not yet implemented"); rewriter.create( dispatchOp.getLoc(), commandBuffer, executableLayout, - rewriter.getI32IntegerAttr(0), pushConstantValues); + rewriter.getIndexAttr(0), pushConstantValues); } static LogicalResult recordPushBindings(Value device, Value commandBuffer, @@ -353,7 +356,9 @@ static LogicalResult recordPushBindings(Value device, Value commandBuffer, SmallVector bindings; auto zeroOffset = rewriter.createOrFold(dispatchOp.getLoc(), 0); - auto pushBinding = [&](Value tensorValue) -> LogicalResult { + auto pushBinding = + [&](Value tensorValue, + IREE::HAL::MemoryAccessBitfield accessType) -> LogicalResult { auto &bufferRange = bufferSet.rangeMap[tensorValue]; assert(bufferRange.buffer && "buffer not preallocated"); auto value = IREE::HAL::TensorRewriteAdaptor::getChecked( @@ -374,7 +379,8 @@ static LogicalResult recordPushBindings(Value device, Value commandBuffer, LLVM_DEBUG(llvm::dbgs() << " + OPERAND(" << it.index() << "): " << it.value() << "\n"); if (it.value().getType().isa()) { - if (failed(pushBinding(it.value()))) { + if (failed( + pushBinding(it.value(), IREE::HAL::MemoryAccessBitfield::Read))) { return failure(); } } @@ -385,7 +391,8 @@ static LogicalResult recordPushBindings(Value device, Value commandBuffer, if (dispatchOp.getTiedResultOperandIndex(it.index())) { LLVM_DEBUG(llvm::dbgs() << " TIED TO OPERAND; SKIP\n"); } else { - if (failed(pushBinding(it.value()))) { + if (failed(pushBinding(it.value(), + IREE::HAL::MemoryAccessBitfield::DiscardWrite))) { return failure(); } } @@ -449,8 +456,6 @@ static LogicalResult recordDispatch(Value device, Value commandBuffer, } // TODO(benvanik): support extended push constants. 
dispatchState.basePushConstantOffset = 0; - dispatchState.operands = operandAdaptors; - dispatchState.results = resultAdaptors; // Ask each target backend to record their dispatch logic. IREE::HAL::DeviceSwitchRewriter switchRewriter(dispatchOp.getLoc(), @@ -473,8 +478,8 @@ static LogicalResult recordDispatch(Value device, Value commandBuffer, rewriter.createOrFold( dispatchOp.getLoc(), IREE::HAL::ExecutableLayoutType::get(device.getContext()), device, - interfaceOp.getExecutableSetLayoutsAttr(), - interfaceOp.push_constantsAttr()); + interfaceOp.push_constantsAttr(), + interfaceOp.getExecutableSetLayoutsAttr()); // Setup push constants for any dynamic values we need to pass across at // runtime. @@ -719,7 +724,9 @@ class ExStreamFragmentOpConversion // information attached to the stream. auto commandBuffer = rewriter.createOrFold( - streamOp.getLoc(), device, mode, category); + streamOp.getLoc(), + IREE::HAL::CommandBufferType::get(rewriter.getContext()), device, + mode, category); rewriter.create(streamOp.getLoc(), commandBuffer); diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp index 591a1b19d8ce..e7f8872a604c 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertTensorOps.cpp @@ -32,55 +32,6 @@ namespace mlir { namespace iree_compiler { namespace { -class ConstantTensorOpConversion - : public OpConversionPattern { - public: - ConstantTensorOpConversion(MLIRContext *ctx, TypeConverter &converter) - : OpConversionPattern(ctx) {} - - LogicalResult matchAndRewrite( - mlir::ConstantOp constantOp, llvm::ArrayRef newOperands, - ConversionPatternRewriter &rewriter) const override { - if (!constantOp.getType().isa()) return failure(); - - auto device = - rewriter.createOrFold(constantOp.getLoc()); - auto allocator = rewriter.createOrFold( - constantOp.getLoc(), device); - - // 
TODO(benvanik): compute from SSA use-def chain uses. - IREE::HAL::MemoryTypeBitfield memoryTypes = - IREE::HAL::MemoryTypeBitfield::DeviceLocal | - IREE::HAL::MemoryTypeBitfield::HostVisible; - IREE::HAL::BufferUsageBitfield bufferUsage = - IREE::HAL::BufferUsageBitfield::All | - IREE::HAL::BufferUsageBitfield::Constant; - - auto elementsAttr = constantOp.getValue().cast(); - auto elementsTy = elementsAttr.getType().cast(); - - // Expand boolean elements to the minimum bit widht supported by the HAL - // (8-bits). - // To improve memory bandwidth and increase computae we should prefer to - // pack 1-bit tensors into wider storage before this lossy conversion. For - // example bitwise ops on 8x32xi1 can be converted to ops on tensor<8xi32>. - if (elementsTy.getElementType().isInteger(1)) { - elementsAttr = - elementsAttr.mapValues(rewriter.getIntegerType(8), - llvm::function_ref( - [](const APInt &val) -> APInt { - return APInt(8, val.getBoolValue()); - })); - } - - auto buffer = rewriter.createOrFold( - constantOp.getLoc(), allocator, memoryTypes, bufferUsage, elementsAttr); - - rewriter.replaceOp(constantOp, {buffer}); - return success(); - } -}; - class TensorLoadOpConversion : public OpConversionPattern { public: @@ -162,9 +113,8 @@ class TensorTraceOpConversion void populateFlowTensorToHALPatterns(MLIRContext *context, OwningRewritePatternList &patterns, TypeConverter &converter) { - patterns.insert(context, - converter); + patterns.insert(context, converter); } } // namespace iree_compiler diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir index e183db82c4cf..1d2e30528305 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir @@ -8,7 +8,7 @@ hal.executable @ex0 { hal.executable.target @vmla, filter="vmla" { hal.executable.entry_point @entry0 attributes { interface 
= @interface, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (tensor<128xf32>) -> tensor<128xf32> } module {} @@ -16,28 +16,43 @@ hal.executable @ex0 { } // CHECK-LABEL: func @multipleDispatches -func @multipleDispatches(%arg0: tensor<128xf32>) -> tensor<128xf32> { +// CHECK-SAME: %[[INPUT_BUF:.+]]: !hal.buffer +func @multipleDispatches(%input: tensor<128xf32>) -> tensor<128xf32> { // CHECK-DAG: %[[C0:.+]] = constant 0 // CHECK-DAG: %[[C128:.+]] = constant 128 %cst = constant 128 : index - // CHECK: %[[RET_BUF:.+]] = hal.allocator.allocate {{.+}}, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch" - // CHECK: %[[TMP_BUF:.+]] = hal.allocator.allocate {{.+}}, "DeviceVisible|DeviceLocal", "Transfer|Dispatch" - // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, OneShot, "Transfer|Dispatch" - // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %0 = flow.ex.stream.fragment(%cst, %arg0) : (index, tensor<128xf32>) -> tensor<128xf32> = + // CHECK: %[[RET_BUF:.+]] = hal.allocator.allocate + // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("Transfer|Mapping|Dispatch") + // CHECK-SAME: : !hal.buffer{%c512} + // CHECK: %[[TMP_BUF:.+]] = hal.allocator.allocate + // CHECK-SAME: type("DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("Transfer|Dispatch") + // CHECK-SAME: : !hal.buffer{%c512} + // CHECK: %[[CMD:.+]] = hal.command_buffer.create + // CHECK-SAME: mode(OneShot) + // CHECK-SAME: categories("Transfer|Dispatch") + // CHECK-NEXT: hal.command_buffer.begin<%[[CMD]] + %0 = flow.ex.stream.fragment(%cst, %input) : (index, tensor<128xf32>) -> tensor<128xf32> = (%arg1: index, %arg2: tensor<128xf32>) -> tensor<128xf32> { // CHECK-DAG: %[[EXE_LAYOUT:.+]] = hal.executable_layout.lookup - // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set = %c0, bindings = [%c0 = (%arg0, %c0, %c512), %c1 = (%[[TMP_BUF]], %c0, %c512)] - // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, 
@ex0::@vmla::@entry0, workgroup_xyz + // CHECK: hal.command_buffer.push_descriptor_set + // CHECK-SAME: layout(%[[EXE_LAYOUT]] : !hal.executable_layout)[%c0] + // CHECK-SAME: bindings([ + // CHECK-NEXT: %c0 = (%[[INPUT_BUF]] : !hal.buffer)[%c0, %c512], + // CHECK-NEXT: %c1 = (%[[TMP_BUF]] : !hal.buffer)[%c0, %c512] + // CHECK: hal.command_buffer.dispatch.symbol + // CHECK-SAME: target(@ex0::@vmla::@entry0) // CHECK: hal.command_buffer.execution_barrier %1 = flow.dispatch @ex0::@entry0[%arg1](%arg2) : (tensor<128xf32>) -> tensor<128xf32> // CHECK: hal.command_buffer.push_descriptor_set - // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz + // CHECK: hal.command_buffer.dispatch.symbol + // CHECK-SAME: target(@ex0::@vmla::@entry0) // CHECK: hal.command_buffer.execution_barrier %2 = flow.dispatch @ex0::@entry0[%arg1](%1) : (tensor<128xf32>) -> tensor<128xf32> flow.return %2 : tensor<128xf32> } - // CHECK: hal.command_buffer.end %[[CMD]] + // CHECK: hal.command_buffer.end<%[[CMD]] // CHECK-NEXT: hal.ex.submit_and_wait {{.+}}, %[[CMD]] // CHECK-NEXT: return %[[RET_BUF]] return %0 : tensor<128xf32> @@ -55,12 +70,15 @@ func @tensorSlice(%arg0 : tensor<5x24x48xf32>) -> tensor<3x24x48xf32> { %c48 = constant 48 : index // CHECK: %[[RET_BUF:.+]] = hal.allocator.allocate // CHECK: %[[CMD:.+]] = hal.command_buffer.create - // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] + // CHECK-NEXT: hal.command_buffer.begin<%[[CMD]] %2 = flow.ex.stream.fragment(%arg0, %c0, %c2, %c3, %c24, %c48) : (tensor<5x24x48xf32>, index, index, index, index, index) -> tensor<3x24x48xf32> = (%arg2 : tensor<5x24x48xf32>, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<3x24x48xf32> { - // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[SBUF]], %c9216, %[[RET_BUF]], %c0, %c13824 + // CHECK-NEXT: hal.command_buffer.copy_buffer<%[[CMD]] + // CHECK-SAME: source(%[[SBUF]] : !hal.buffer)[%c9216] + // CHECK-SAME: 
target(%[[RET_BUF]] : !hal.buffer)[%c0] + // CHECK-SAME: length(%c13824) %slice = flow.tensor.slice %arg2[%arg4, %arg3, %arg3 for %arg5, %arg6, %arg7] : tensor<5x24x48xf32> -> tensor<3x24x48xf32> flow.return %slice : tensor<3x24x48xf32> @@ -77,17 +95,23 @@ func @tensorUpdate(%arg0 : tensor<1x1x10xf32>, %arg1 : tensor<5x1x10xf32>) -> te %c1 = constant 1 : index // CHECK: %[[RET_BUF:.+]] = hal.allocator.allocate // CHECK: %[[CMD:.+]] = hal.command_buffer.create - // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] + // CHECK-NEXT: hal.command_buffer.begin<%[[CMD]] %0 = flow.ex.stream.fragment(%arg0, %arg1, %c4, %c1) : (tensor<1x1x10xf32>, tensor<5x1x10xf32>, index, index) -> tensor<5x1x10xf32> = (%arg2: tensor<1x1x10xf32>, %arg3: tensor<5x1x10xf32>, %arg4: index, %arg5: index) -> tensor<5x1x10xf32> { - // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[TBUF]], %c0, %[[RET_BUF]], %c0, %c200 + // CHECK-NEXT: hal.command_buffer.copy_buffer + // CHECK-SAME: source(%[[TBUF]] : !hal.buffer)[%c0] + // CHECK-SAME: target(%[[RET_BUF]] : !hal.buffer)[%c0] + // CHECK-SAME: length(%c200) // CHECK: hal.command_buffer.execution_barrier %clone = flow.tensor.clone %arg3 : tensor<5x1x10xf32> - // CHECK-NEXT: hal.command_buffer.copy_buffer %[[CMD]], %[[UBUF]], %c0, %[[RET_BUF]], %c204, %c40 + // CHECK-NEXT: hal.command_buffer.copy_buffer + // CHECK-SAME: source(%[[UBUF]] : !hal.buffer)[%c0] + // CHECK-SAME: target(%[[RET_BUF]] : !hal.buffer)[%c204] + // CHECK-SAME: length(%c40) %1 = flow.tensor.update %arg2, %clone[%arg4, %arg5, %arg5] : tensor<1x1x10xf32> -> tensor<5x1x10xf32> flow.return %1 : tensor<5x1x10xf32> } - // CHECK: hal.command_buffer.end %[[CMD]] + // CHECK: hal.command_buffer.end<%[[CMD]] // CHECK: return %[[RET_BUF]] return %0 : tensor<5x1x10xf32> } @@ -95,14 +119,14 @@ func @tensorUpdate(%arg0 : tensor<1x1x10xf32>, %arg1 : tensor<5x1x10xf32>) -> te // ----- hal.executable @ex0 { - hal.interface @interface attributes {push_constants = 2 : i32} { + hal.interface 
@interface attributes {push_constants = 2 : index} { hal.interface.binding @s0b0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @s0b1, set=0, binding=1, type="StorageBuffer", access="Read|Write" } hal.executable.target @vmla, filter="vmla" { hal.executable.entry_point @entry0 attributes { interface = @interface, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (tensor, index) -> tensor } module {} @@ -118,7 +142,10 @@ func @dispatchWithShapeTies(%arg0: tensor, %bs : index) -> tensor{%cst}, index) -> tensor{%cst} = @@ -141,7 +168,7 @@ hal.executable @ex attributes {sym_visibility = "private"} { hal.executable.target @tgt, filter="dylib-llvm-aot" { hal.executable.entry_point @entry attributes { interface = @legacy_io, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () } module {} @@ -149,33 +176,39 @@ hal.executable @ex attributes {sym_visibility = "private"} { } // CHECK-LABEL: func @static_tiled_dispatch -func @static_tiled_dispatch(%arg0: tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> { +// CHECK-SAME: %[[INPUT:.+]]: !hal.buffer +func @static_tiled_dispatch(%input: tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> { %c1024 = constant 1024 : index %c512 = constant 512 : index - // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, OneShot, "Transfer|Dispatch" - // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] - %1 = flow.ex.stream.fragment(%arg0, %c1024, %c512) : (tensor<7x4x24xf32>, index, index) -> tensor<4x7x1024xf32> = + // CHECK: %[[CMD:.+]] = hal.command_buffer.create + // CHECK-NEXT: hal.command_buffer.begin<%[[CMD]] + %1 = flow.ex.stream.fragment(%input, %c1024, %c512) : (tensor<7x4x24xf32>, index, index) -> tensor<4x7x1024xf32> = (%arg3: tensor<7x4x24xf32>, %arg6: index, %arg7: index) -> tensor<4x7x1024xf32> { - // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %c2688), %c1 = (%buffer, %c0, 
%c114688)] - // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex::@tgt::@entry, workgroup_xyz + // CHECK: hal.command_buffer.push_descriptor_set + // CHECK-SAME: layout(%executable_layout : !hal.executable_layout)[%c0] + // CHECK-SAME: bindings([ + // CHECK-NEXT: %c0 = (%[[INPUT]] : !hal.buffer)[%c0, %c2688], + // CHECK-NEXT: %c1 = (%{{.+}} : !hal.buffer)[%c0, %c114688] + // CHECK: hal.command_buffer.dispatch.symbol + // CHECK-SAME: target(@ex::@tgt::@entry) %0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3) : (tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> flow.return %0 : tensor<4x7x1024xf32> } - // CHECK: hal.command_buffer.end %[[CMD]] + // CHECK: hal.command_buffer.end<%[[CMD]] return %1 : tensor<4x7x1024xf32> } // ----- hal.executable @ex attributes {sym_visibility = "private"} { - hal.interface @legacy_io attributes {push_constants = 4 : i32} { + hal.interface @legacy_io attributes {push_constants = 4 : index} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" } hal.executable.target @tgt, filter="dylib-llvm-aot" { hal.executable.entry_point @entry attributes { interface = @legacy_io, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, index, index, index, index) -> () } module {} @@ -183,15 +216,23 @@ hal.executable @ex attributes {sym_visibility = "private"} { } // CHECK-LABEL: func @dynamic_tiled_dispatch +// CHECK-SAME: %[[INPUT:.+]]: !hal.buffer func @dynamic_tiled_dispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: index) -> tensor { %c1024 = constant 1024 : index %c512 = constant 512 : index - // CHECK: %[[CMD:.+]] = hal.command_buffer.create {{.+}}, OneShot, "Transfer|Dispatch" - // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] + // CHECK: %[[CMD:.+]] = hal.command_buffer.create + // CHECK-NEXT: hal.command_buffer.begin<%[[CMD]] %2 = 
flow.ex.stream.fragment(%arg0, %arg1, %arg2, %c1024, %c512) : (tensor<7x?x24x?xf32>{%arg1, %arg2}, index, index, index, index) -> tensor{%arg2, %arg1} = (%arg3: tensor<7x?x24x?xf32>, %arg4: index, %arg5: index, %arg6: index, %arg7: index) -> tensor { - // CHECK: hal.command_buffer.push_constants %[[CMD]], %executable_layout, offset = 0, values = [%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] : i32 - // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %9), %c1 = (%buffer, %c0, %12)] + // CHECK: hal.command_buffer.push_constants<%[[CMD]] + // CHECK-SAME: layout(%executable_layout + // CHECK-SAME: offset(0) + // CHECK-SAME: values([%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}]) : i32, i32, i32, i32 + // CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]] + // CHECK-SAME: layout(%executable_layout : !hal.executable_layout)[%c0] + // CHECK-SAME: bindings([ + // CHECK-NEXT: %c0 = (%[[INPUT]] : !hal.buffer)[%c0, %9], + // CHECK-NEXT: %c1 = (%{{.+}} : !hal.buffer)[%c0, %12] // CHECK: #hal.device.match.id<"dylib*">( // CHECK-SAME: %[[CMD_INNER:.+]] = %cmd : !hal.command_buffer, @@ -200,6 +241,9 @@ func @dynamic_tiled_dispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: i // CHECK-SAME: %[[COUNT_Z_INNER:.+]] = %c512 : index // This makes me so sad. + // If you are improving folding/canonicalization of these ops and come + // across this feel free to remove it all. 
And let me know so I can buy you + // a drink :) // CHECK: %[[C1:.+]] = constant 1 // CHECK-NEXT: %[[COUNT_X_TMP:.+]] = addi %[[COUNT_X_INNER]], %[[C1]] // CHECK-NEXT: %[[COUNT_X:.+]] = subi %[[COUNT_X_TMP]], %[[C1]] @@ -208,12 +252,13 @@ func @dynamic_tiled_dispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: i // CHECK-NEXT: %[[COUNT_Z_TMP:.+]] = addi %[[COUNT_Z_INNER]], %[[C1]] // CHECK-NEXT: %[[COUNT_Z:.+]] = subi %[[COUNT_Z_TMP]], %[[C1]] - // CHECK: hal.command_buffer.dispatch.symbol %[[CMD_INNER]], @ex::@tgt::@entry, workgroup_xyz = - // CHECK-SAME: [%[[COUNT_X]], %[[COUNT_Y]], %[[COUNT_Z]]] + // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD_INNER]] + // CHECK-SAME: target(@ex::@tgt::@entry) + // CHECK-SAME: workgroups([%[[COUNT_X]], %[[COUNT_Y]], %[[COUNT_Z]]]) %6 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3, %arg4, %arg5, %arg5, %arg4) : (tensor<7x?x24x?xf32>{%arg4, %arg5}, index, index, index, index) -> tensor{%arg5, %arg4} flow.return %6 : tensor } - // CHECK: hal.command_buffer.end %[[CMD]] + // CHECK: hal.command_buffer.end<%[[CMD]] return %2 : tensor } @@ -227,7 +272,7 @@ hal.executable @pad_dispatch_0 attributes {sym_visibility = "private"} { hal.executable.target @tgt, filter="dylib-llvm-aot" { hal.executable.entry_point @pad_dispatch_0 attributes { interface = @interface_io, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () } module {} @@ -242,7 +287,7 @@ hal.executable @pad_dispatch_1 attributes {sym_visibility = "private"} { hal.executable.target @tgt, filter="dylib-llvm-aot" { hal.executable.entry_point @pad_dispatch_1 attributes { interface = @interface_io, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () } module {} @@ -252,17 +297,33 @@ hal.executable @pad_dispatch_1 attributes {sym_visibility = "private"} { // CHECK-LABEL: func @dispatch_tied_buffer // CHECK-SAME: (%[[FILL:.+]]: !hal.buffer, %[[INPUT:.+]]: 
!hal.buffer) func @dispatch_tied_buffer(%fill: tensor, %input: tensor<2x3xi32>) -> tensor<3x9xi32> { - // CHECK: %[[OUTPUT:.+]] = hal.allocator.allocate %allocator, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch" + // CHECK: %[[OUTPUT:.+]] = hal.allocator.allocate + // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("Transfer|Mapping|Dispatch") %0 = flow.ex.stream.fragment(%fill, %input) : (tensor, tensor<2x3xi32>) -> tensor<3x9xi32> = (%arg0: tensor, %arg1: tensor<2x3xi32>) -> tensor<3x9xi32> { %c9 = constant 9 : index %c3 = constant 3 : index %c1 = constant 1 : index - // CHECK: %[[LAYOUT0:.+]] = hal.executable_layout.lookup %{{.+}}, set_layouts = {{\[\[}}#hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write|Discard">]] - // CHECK: hal.command_buffer.push_descriptor_set %{{.+}}, %[[LAYOUT0]], set = %{{.+}}, bindings = [%c0 = (%[[FILL]], %c0, %c4), %c1 = (%[[OUTPUT]], %c0, %c108)] + // CHECK: %[[LAYOUT0:.+]] = hal.executable_layout.lookup + // CHECK-SAME: layouts([ + // CHECK-SAME: #hal.descriptor_set_layout_binding<0, "StorageBuffer", R>, + // CHECK-SAME: #hal.descriptor_set_layout_binding<1, "StorageBuffer", DW> + // CHECK: hal.command_buffer.push_descriptor_set + // CHECK-SAME: layout(%[[LAYOUT0]] : !hal.executable_layout)[%{{.+}}] + // CHECK-SAME: bindings([ + // CHECK-NEXT: %c0 = (%[[FILL]] : !hal.buffer)[%c0, %c4], + // CHECK-NEXT: %c1 = (%[[OUTPUT]] : !hal.buffer)[%c0, %c108] %3 = flow.dispatch @pad_dispatch_0::@pad_dispatch_0[%c9, %c3, %c1](%arg0) : (tensor) -> tensor<3x9xi32> - // CHECK: %[[LAYOUT1:.+]] = hal.executable_layout.lookup %{{.+}}, set_layouts = {{\[\[}}#hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Read|Write">]] - // CHECK: hal.command_buffer.push_descriptor_set %{{.+}}, %[[LAYOUT1]], set = %{{.+}}, bindings = [%c0 = (%[[INPUT]], %c0, 
%c24), %c1 = (%[[OUTPUT]], %c0, %c108)] + // CHECK: %[[LAYOUT1:.+]] = hal.executable_layout.lookup + // CHECK-SAME: layouts([ + // CHECK-SAME: #hal.descriptor_set_layout_binding<0, "StorageBuffer", R>, + // CHECK-SAME: #hal.descriptor_set_layout_binding<1, "StorageBuffer", RW> + // CHECK: hal.command_buffer.push_descriptor_set + // CHECK-SAME: layout(%[[LAYOUT1]] : !hal.executable_layout)[%{{.+}}] + // CHECK-SAME: bindings([ + // CHECK-NEXT: %c0 = (%[[INPUT]] : !hal.buffer)[%c0, %c24], + // CHECK-NEXT: %c1 = (%[[OUTPUT]] : !hal.buffer)[%c0, %c108] %4 = flow.dispatch @pad_dispatch_1::@pad_dispatch_1[%c9, %c3, %c1](%arg1, %3) : (tensor<2x3xi32>, tensor<3x9xi32>) -> %3 flow.return %4 : tensor<3x9xi32> } diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/tensor_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/tensor_ops.mlir index c327af0293e3..ced0c86c97c7 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/tensor_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/tensor_ops.mlir @@ -1,63 +1,42 @@ // RUN: iree-opt -split-input-file -iree-convert-to-hal %s | IreeFileCheck %s -// CHECK-LABEL: @constantTensor -func @constantTensor() { - // CHECK-NEXT: %dev = hal.ex.shared_device - // CHECK-NEXT: %allocator = hal.device.allocator %dev - // CHECK-NEXT: %cbuffer = hal.allocator.allocate.const %allocator, {{.+}} = dense<[1, 2]> : tensor<2xi32> - %0 = constant dense<[1, 2]> : tensor<2xi32> - return -} - -// ----- - -// CHECK-LABEL: @constantTensor1 -func @constantTensor1() { - // CHECK-NEXT: %dev = hal.ex.shared_device - // CHECK-NEXT: %allocator = hal.device.allocator %dev - // CHECK-NEXT: %cbuffer = hal.allocator.allocate.const %allocator, {{.+}} = dense<[1, 0]> : tensor<2xi8> - %0 = constant dense<[1, 0]> : tensor<2xi1> - return -} - -// ----- - // CHECK-LABEL: @tensorLoad -func @tensorLoad(%arg0 : tensor<2x3xi32>) { +// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer +func @tensorLoad(%tensor : tensor<2x3xi32>) { // 
CHECK-DAG: %[[C0:.+]] = constant 0 : index // CHECK-DAG: %[[C1:.+]] = constant 1 : index // CHECK-DAG: %[[C2:.+]] = constant 2 : index // CHECK-DAG: %[[C3:.+]] = constant 3 : index %i0 = constant 0 : index %i1 = constant 1 : index - // CHECK: %[[OFF:.+]] = hal.allocator.compute_offset %allocator, shape = [ - // CHECK-SAME: %[[C2]], %[[C3]] - // CHECK-SAME: ], element_type = %c16777248_i32, indices = [ - // CHECK-SAME: %[[C0]], %[[C1]] - // CHECK-SAME: ] - // CHECK-NEXT: = hal.buffer.load %arg0[ - // CHECK-SAME: %[[OFF]] - // CHECK-SAME: ] : i32 - %0 = flow.tensor.load %arg0[%i0, %i1] : tensor<2x3xi32> + // CHECK: %[[OFF:.+]] = hal.allocator.compute_offset<%allocator : !hal.allocator> + // CHECK-SAME: indices([%[[C0]], %[[C1]]]) + // CHECK-SAME: shape([%[[C2]], %[[C3]]]) + // CHECK-SAME: type(%c16777248_i32) + // CHECK-NEXT: = hal.buffer.load<%[[BUFFER]] : !hal.buffer>[%[[OFF]]] : i32 + %0 = flow.tensor.load %tensor[%i0, %i1] : tensor<2x3xi32> return } // ----- // CHECK-LABEL: @tensorLoad1 -func @tensorLoad1(%arg0 : tensor) { - // CHECK: %[[OFF:.+]] = hal.allocator.compute_offset %allocator, shape = [], element_type = %c16777217_i32, indices = [] - // CHECK-NEXT: = hal.buffer.load %arg0[ - // CHECK-SAME: %[[OFF]] - // CHECK-SAME: ] : i1 - %0 = flow.tensor.load %arg0 : tensor +// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer +func @tensorLoad1(%tensor : tensor) { + // CHECK: %[[OFF:.+]] = hal.allocator.compute_offset<%allocator : !hal.allocator> + // CHECK-SAME: indices([]) + // CHECK-SAME: shape([]) + // CHECK-SAME: type(%c16777217_i32) + // CHECK-NEXT: = hal.buffer.load<%[[BUFFER]] : !hal.buffer>[%[[OFF]]] : i1 + %0 = flow.tensor.load %tensor : tensor return } // ----- // CHECK-LABEL: @tensorStore -func @tensorStore(%arg0 : tensor<2x3xi32>) { +// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer +func @tensorStore(%tensor : tensor<2x3xi32>) { // CHECK-DAG: %[[C0:.+]] = constant 0 : index // CHECK-DAG: %[[C1:.+]] = constant 1 : index // CHECK-DAG: %[[C9:.+]] = constant 9 : i32 @@ 
-66,28 +45,27 @@ func @tensorStore(%arg0 : tensor<2x3xi32>) { %i0 = constant 0 : index %i1 = constant 1 : index %c9 = constant 9 : i32 - // CHECK: %[[OFF:.+]] = hal.allocator.compute_offset %allocator, shape = [ - // CHECK-SAME: %[[C2]], %[[C3]] - // CHECK-SAME: ], element_type = %c16777248_i32, indices = [ - // CHECK-SAME: %[[C0]], %[[C1]] - // CHECK-SAME: ] - // CHECK-NEXT: hal.buffer.store %[[C9]], %arg0[ - // CHECK-SAME: %[[OFF]] - // CHECK-SAME: ] : i32 - flow.tensor.store %c9, %arg0[%i0, %i1] : tensor<2x3xi32> + // CHECK: %[[OFF:.+]] = hal.allocator.compute_offset<%allocator : !hal.allocator> + // CHECK-SAME: indices([%[[C0]], %[[C1]]]) + // CHECK-SAME: shape([%[[C2]], %[[C3]]]) + // CHECK-SAME: type(%c16777248_i32) + // CHECK-NEXT: hal.buffer.store<%[[BUFFER]] : !hal.buffer>[%[[OFF]]] value(%[[C9]] : i32) + flow.tensor.store %c9, %tensor[%i0, %i1] : tensor<2x3xi32> return } // ----- // CHECK-LABEL: @tensorStore1 -func @tensorStore1(%arg0 : tensor) { +// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer +func @tensorStore1(%tensor : tensor) { // CHECK-DAG: %[[C1:.+]] = constant true %c1 = constant true - // CHECK: %[[OFF:.+]] = hal.allocator.compute_offset %allocator, shape = [], element_type = %c16777217_i32, indices = [] - // CHECK-NEXT: hal.buffer.store %[[C1]], %arg0[ - // CHECK-SAME: %[[OFF]] - // CHECK-SAME: ] : i1 - flow.tensor.store %c1, %arg0 : tensor + // CHECK: %[[OFF:.+]] = hal.allocator.compute_offset<%allocator : !hal.allocator> + // CHECK-SAME: indices([]) + // CHECK-SAME: shape([]) + // CHECK-SAME: type(%c16777217_i32) + // CHECK-NEXT: hal.buffer.store<%[[BUFFER]] : !hal.buffer>[%[[OFF]]] value(%[[C1]] : i1) + flow.tensor.store %c1, %tensor : tensor return } diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertConstantOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertConstantOps.cpp index 68516a5438a1..b91d993a9390 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertConstantOps.cpp +++ 
b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/ConvertConstantOps.cpp @@ -38,8 +38,7 @@ class ConstantSubspanConversion auto lengthValue = rewriter.createOrFold( op.getLoc(), op.runtime_range().lengthAttr()); rewriter.replaceOpWithNewOp( - op, IREE::HAL::BufferType::get(rewriter.getContext()), bufferValue, - offsetValue, lengthValue); + op, bufferValue.getType(), bufferValue, offsetValue, lengthValue); return success(); } }; diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/constant_ops.mlir index 1e430845f03b..944485ad6ce2 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/constant_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/HALToHAL/test/constant_ops.mlir @@ -2,10 +2,10 @@ // CHECK-LABEL: func @constant_subspan func @constant_subspan() { - // CHECK-DAG: [[BUFFER:%.+]] = hal.variable.load @pool_buffer : !hal.buffer - // CHECK-DAG: [[OFFSET:%.+]] = constant 123 : index - // CHECK-DAG: [[LENGTH:%.+]] = constant 16 : index - // CHECK-NEXT: = hal.buffer.subspan [[BUFFER]], [[OFFSET]], [[LENGTH]] : !hal.buffer + // CHECK-DAG: %[[BUFFER:.+]] = hal.variable.load @pool_buffer : !hal.buffer + // CHECK-DAG: %[[OFFSET:.+]] = constant 123 : index + // CHECK-DAG: %[[LENGTH:.+]] = constant 16 : index + // CHECK-NEXT: = hal.buffer.subspan<%[[BUFFER]] : !hal.buffer>[%[[OFFSET]], %[[LENGTH]]] : !hal.buffer %cst0 = hal.constant.subspan @pool_buffer[#hal.byte_range<123, 16>] : tensor<4xf32> return } diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir index d663255a8cb2..f021df5c82b6 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/allocator_ops.mlir @@ -6,7 +6,7 @@ func @allocatorComputeSizeFoldsAway(%arg0 : !hal.allocator) -> index { // CHECK-NOT: hal.allocator.compute_size %c1024 = 
constant 1024 : index %c32_i32 = constant 32 : i32 - %0 = hal.allocator.compute_size %arg0, shape=[%c1024, %c1024], element_type=%c32_i32 + %0 = hal.allocator.compute_size<%arg0 : !hal.allocator> shape([%c1024, %c1024]) type(%c32_i32) : index return %0 : index } @@ -15,8 +15,8 @@ func @allocatorComputeSizeFoldsAway(%arg0 : !hal.allocator) -> index { // CHECK-LABEL: @allocatorAllocate func @allocatorAllocate(%arg0 : !hal.allocator) -> !hal.buffer { %c1024 = constant 1024 : index - // CHECK: %ref = vm.call @hal.allocator.allocate(%arg0, %c6, %c15, %c1024) : (!vm.ref, i32, i32, i32) -> !vm.ref - %0 = hal.allocator.allocate %arg0, "HostLocal", "All", %c1024 : !hal.buffer + // CHECK: %ref = vm.call @hal.allocator.allocate(%arg0, %c6, %c14, %c1024) : (!vm.ref, i32, i32, i32) -> !vm.ref + %0 = hal.allocator.allocate<%arg0 : !hal.allocator> type("HostLocal") usage("All") : !hal.buffer{%c1024} return %0 : !hal.buffer } @@ -27,6 +27,6 @@ func @allocatorMapByteBuffer(%arg0 : !hal.allocator, %arg1 : !iree.byte_buffer) %offset = constant 128 : index %length = constant 256 : index // CHECK: = vm.call @hal.allocator.wrap.byte_buffer(%arg0, %c6, %c2, %arg1, %c128, %c256) : (!vm.ref, i32, i32, !vm.ref, i32, i32) -> !vm.ref - %buffer = hal.allocator.map %arg0, "HostVisible|HostCoherent", Transfer, %arg1[%offset, %length] : !iree.byte_buffer -> !hal.buffer + %buffer = hal.allocator.map<%arg0 : !hal.allocator> source(%arg1 : !iree.byte_buffer)[%offset, %length] type("HostVisible|HostCoherent") usage(Transfer) : !hal.buffer return %buffer : !hal.buffer } diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_ops.mlir index 386774338886..246d0b136dd0 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/buffer_ops.mlir @@ -3,9 +3,9 @@ // CHECK-LABEL: @buffer_subspan func @buffer_subspan(%arg0 : !hal.buffer) -> !hal.buffer { 
%c42 = constant 42 : index - %c42_0 = constant 42 : index - // CHECK: %ref = vm.call @hal.buffer.subspan(%arg0, %c42, %c42_0) : (!vm.ref, i32, i32) -> !vm.ref - %buffer = hal.buffer.subspan %arg0, %c42, %c42_0 : !hal.buffer + %c43 = constant 43 : index + // CHECK: %ref = vm.call @hal.buffer.subspan(%arg0, %c42, %c43) : (!vm.ref, i32, i32) -> !vm.ref + %buffer = hal.buffer.subspan<%arg0 : !hal.buffer>[%c42, %c43] : !hal.buffer return %buffer : !hal.buffer } @@ -14,10 +14,10 @@ func @buffer_subspan(%arg0 : !hal.buffer) -> !hal.buffer { // CHECK-LABEL: @buffer_fill func @buffer_fill(%arg0 : !hal.buffer) { %c42 = constant 42 : index - %c42_0 = constant 42 : index - %c42_1 = constant 42 : i32 - // CHECK: vm.call @hal.buffer.fill(%arg0, %c42, %c42_0, %c42_1) : (!vm.ref, i32, i32, i32) -> () - hal.buffer.fill %arg0, %c42, %c42_0, %c42_1 + %c43 = constant 43 : index + %c123 = constant 123 : i32 + // CHECK: vm.call @hal.buffer.fill(%arg0, %c42, %c43, %c123) : (!vm.ref, i32, i32, i32) -> () + hal.buffer.fill<%arg0 : !hal.buffer>[%c42, %c43] pattern(%c123 : i32) return } @@ -26,12 +26,14 @@ func @buffer_fill(%arg0 : !hal.buffer) { // CHECK-LABEL: @buffer_load func @buffer_load(%arg0 : !hal.buffer) -> (i8, i16, i32) { %c42 = constant 42 : index + %c43 = constant 43 : index + %c44 = constant 44 : index // CHECK: %0 = vm.call @hal.buffer.load(%arg0, %c42, %c1) : (!vm.ref, i32, i32) -> i32 - %0 = hal.buffer.load %arg0[%c42] : i8 - // CHECK: %1 = vm.call @hal.buffer.load(%arg0, %c42, %c2) : (!vm.ref, i32, i32) -> i32 - %1 = hal.buffer.load %arg0[%c42] : i16 - // CHECK: %2 = vm.call @hal.buffer.load(%arg0, %c42, %c4) : (!vm.ref, i32, i32) -> i32 - %2 = hal.buffer.load %arg0[%c42] : i32 + %0 = hal.buffer.load<%arg0 : !hal.buffer>[%c42] : i8 + // CHECK: %1 = vm.call @hal.buffer.load(%arg0, %c43, %c2) : (!vm.ref, i32, i32) -> i32 + %1 = hal.buffer.load<%arg0 : !hal.buffer>[%c43] : i16 + // CHECK: %2 = vm.call @hal.buffer.load(%arg0, %c44, %c4) : (!vm.ref, i32, i32) -> i32 + %2 = 
hal.buffer.load<%arg0 : !hal.buffer>[%c44] : i32 return %0, %1, %2 : i8, i16, i32 } @@ -40,11 +42,13 @@ func @buffer_load(%arg0 : !hal.buffer) -> (i8, i16, i32) { // CHECK-LABEL: @buffer_store func @buffer_store(%arg0 : !hal.buffer, %arg1 : i8, %arg2 : i16, %arg3 : i32) { %c42 = constant 42 : index + %c43 = constant 43 : index + %c44 = constant 44 : index // CHECK: vm.call @hal.buffer.store(%arg1, %arg0, %c42, %c1) : (i32, !vm.ref, i32, i32) -> () - hal.buffer.store %arg1, %arg0[%c42] : i8 - // CHECK: vm.call @hal.buffer.store(%arg2, %arg0, %c42, %c2) : (i32, !vm.ref, i32, i32) -> () - hal.buffer.store %arg2, %arg0[%c42] : i16 - // CHECK: vm.call @hal.buffer.store(%arg3, %arg0, %c42, %c4) : (i32, !vm.ref, i32, i32) -> () - hal.buffer.store %arg3, %arg0[%c42] : i32 + hal.buffer.store<%arg0 : !hal.buffer>[%c42] value(%arg1 : i8) + // CHECK: vm.call @hal.buffer.store(%arg2, %arg0, %c43, %c2) : (i32, !vm.ref, i32, i32) -> () + hal.buffer.store<%arg0 : !hal.buffer>[%c43] value(%arg2 : i16) + // CHECK: vm.call @hal.buffer.store(%arg3, %arg0, %c44, %c4) : (i32, !vm.ref, i32, i32) -> () + hal.buffer.store<%arg0 : !hal.buffer>[%c44] value(%arg3 : i32) return } diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index f46de67749f5..c7c8174c8c8e 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir @@ -1,53 +1,70 @@ // RUN: iree-opt -split-input-file -iree-convert-hal-to-vm %s | IreeFileCheck %s // CHECK-LABEL: @command_buffer_create -func @command_buffer_create(%arg0 : !hal.device) { +func @command_buffer_create(%arg0: !hal.device) { // CHECK: %ref = vm.call @hal.command_buffer.create(%arg0, %c1, %c3) : (!vm.ref, i32, i32) -> !vm.ref - %cmd = hal.command_buffer.create %arg0, "OneShot", "Transfer|Dispatch" : !hal.command_buffer + %cmd = 
hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer return } // ----- // CHECK-LABEL: @command_buffer_begin_end -func @command_buffer_begin_end(%arg0 : !hal.command_buffer) { +func @command_buffer_begin_end(%arg0: !hal.command_buffer) { // CHECK: vm.call @hal.command_buffer.begin(%arg0) : (!vm.ref) -> () - hal.command_buffer.begin %arg0 + hal.command_buffer.begin<%arg0 : !hal.command_buffer> // CHECK: vm.call @hal.command_buffer.end(%arg0) : (!vm.ref) -> () - hal.command_buffer.end %arg0 + hal.command_buffer.end<%arg0 : !hal.command_buffer> return } // ----- // CHECK-LABEL: @command_buffer_execution_barrier -func @command_buffer_execution_barrier(%arg0 : !hal.command_buffer, %arg1 : !hal.buffer) { +func @command_buffer_execution_barrier( + %arg0: !hal.command_buffer, + %arg1: !hal.buffer +) { // CHECK: vm.call @hal.command_buffer.execution_barrier(%arg0, %c1, %c2, %zero) : (!vm.ref, i32, i32, i32) - hal.command_buffer.execution_barrier %arg0, "CommandIssue", "CommandProcess", "None" + hal.command_buffer.execution_barrier<%arg0 : !hal.command_buffer> + source("CommandIssue") + target("CommandProcess") + flags("None") return } // ----- // CHECK-LABEL: @command_buffer_fill_buffer -func @command_buffer_fill_buffer(%arg0 : !hal.command_buffer, %arg1 : !hal.buffer) { +func @command_buffer_fill_buffer( + %arg0: !hal.command_buffer, + %arg1: !hal.buffer +) { %c100 = constant 100 : index %c200 = constant 200 : index %c300 = constant 300 : i32 // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %c300) : (!vm.ref, !vm.ref, i32, i32, i32) -> () - hal.command_buffer.fill_buffer %arg0, %arg1, %c100, %c200, %c300 + hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer> + target(%arg1 : !hal.buffer)[%c100, %c200] + pattern(%c300 : i32) return } // ----- // CHECK-LABEL: @command_buffer_copy_buffer -func @command_buffer_copy_buffer(%arg0 : !hal.command_buffer, %arg1 : !hal.buffer) 
{ +func @command_buffer_copy_buffer( + %arg0: !hal.command_buffer, + %arg1: !hal.buffer +) { %c100 = constant 100 : index %c200 = constant 200 : index %c300 = constant 300 : index // CHECK: vm.call @hal.command_buffer.copy_buffer(%arg0, %arg1, %c100, %arg1, %c200, %c300) : (!vm.ref, !vm.ref, i32, !vm.ref, i32, i32) -> () - hal.command_buffer.copy_buffer %arg0, %arg1, %c100, %arg1, %c200, %c300 + hal.command_buffer.copy_buffer<%arg0 : !hal.command_buffer> + source(%arg1 : !hal.buffer)[%c100] + target(%arg1 : !hal.buffer)[%c200] + length(%c300) return } @@ -55,27 +72,38 @@ func @command_buffer_copy_buffer(%arg0 : !hal.command_buffer, %arg1 : !hal.buffe // CHECK-LABEL: @command_buffer_bind_descriptor_set func @command_buffer_bind_descriptor_set( - %arg0 : !hal.command_buffer, - %arg1 : !hal.executable_layout, - %arg2 : !hal.descriptor_set) { + %arg0: !hal.command_buffer, + %arg1: !hal.executable_layout, + %arg2: !hal.descriptor_set +) { %c0 = constant 0 : index %c100 = constant 100 : index // CHECK: vm.call.variadic @hal.command_buffer.bind_descriptor_set(%arg0, %arg1, %zero, %arg2, []) : (!vm.ref, !vm.ref, i32, !vm.ref, i32 ...) - hal.command_buffer.bind_descriptor_set %arg0, %arg1, set = %c0, %arg2 + hal.command_buffer.bind_descriptor_set<%arg0 : !hal.command_buffer> + layout(%arg1 : !hal.executable_layout)[%c0] + set(%arg2 : !hal.descriptor_set) // CHECK: vm.call.variadic @hal.command_buffer.bind_descriptor_set(%arg0, %arg1, %zero, %arg2, [%c100]) : (!vm.ref, !vm.ref, i32, !vm.ref, i32 ...) 
- hal.command_buffer.bind_descriptor_set %arg0, %arg1, set = %c0, %arg2, offsets = [%c100] + hal.command_buffer.bind_descriptor_set<%arg0 : !hal.command_buffer> + layout(%arg1 : !hal.executable_layout)[%c0] + set(%arg2 : !hal.descriptor_set) + offsets([%c100]) return } // ----- // CHECK-LABEL: @command_buffer_dispatch -func @command_buffer_dispatch(%arg0 : !hal.command_buffer, %arg1 : !hal.executable) { +func @command_buffer_dispatch( + %arg0: !hal.command_buffer, + %arg1: !hal.executable +) { %c100 = constant 100 : index %c200 = constant 200 : index %c300 = constant 300 : index // CHECK: vm.call @hal.command_buffer.dispatch(%arg0, %arg1, %zero, %c100, %c200, %c300) : (!vm.ref, !vm.ref, i32, i32, i32, i32) -> () - hal.command_buffer.dispatch %arg0, %arg1, entry_point=0, workgroup_xyz=[%c100, %c200, %c300] + hal.command_buffer.dispatch<%arg0 : !hal.command_buffer> + target(%arg1 : !hal.executable)[0] + workgroups([%c100, %c200, %c300]) return } @@ -83,11 +111,14 @@ func @command_buffer_dispatch(%arg0 : !hal.command_buffer, %arg1 : !hal.executab // CHECK-LABEL: @command_buffer_dispatch_indirect func @command_buffer_dispatch_indirect( - %arg0 : !hal.command_buffer, - %arg1 : !hal.executable, - %arg2 : !hal.buffer) { + %arg0: !hal.command_buffer, + %arg1: !hal.executable, + %arg2: !hal.buffer +) { %c100 = constant 100 : index // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%arg0, %arg1, %zero, %arg2, %c100) : (!vm.ref, !vm.ref, i32, !vm.ref, i32) -> () - hal.command_buffer.dispatch.indirect %arg0, %arg1, entry_point=0, workgroups=%arg2[%c100] + hal.command_buffer.dispatch.indirect<%arg0 : !hal.command_buffer> + target(%arg1 : !hal.executable)[0] + workgroups(%arg2 : !hal.buffer)[%c100] return } diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir index 74c1ecebc89d..8cd90a04aa23 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir +++ 
b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/constant_ops.mlir @@ -18,11 +18,14 @@ func private @pool_storage0_buffer_initializer() -> !hal.buffer { %c0 = constant 0 : index %c16 = constant 16 : index %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator + %allocator = hal.device.allocator<%dev : !hal.device> : !hal.allocator // CHECK: [[STORAGE_REF:%.+]] = vm.const.ref.rodata @pool_storage0 : !vm.ref %storage = hal.constant_storage.lookup @pool::@_storage0 : !iree.byte_buffer // CHECK: = vm.call @hal.allocator.wrap.byte_buffer({{.+}}, %c22, %c15, [[STORAGE_REF]], %zero, %c16) - %mapped = hal.allocator.map %allocator, "HostVisible|HostCoherent|DeviceVisible", "Constant|Transfer|Mapping|Dispatch", %storage[%c0, %c16] : !iree.byte_buffer -> !hal.buffer + %mapped = hal.allocator.map<%allocator : !hal.allocator> + source(%storage : !iree.byte_buffer)[%c0, %c16] + type("HostVisible|HostCoherent|DeviceVisible") + usage("Constant|Transfer|Mapping|Dispatch") : !hal.buffer return %mapped : !hal.buffer } @@ -41,12 +44,14 @@ func private @pool_splats_initializer() -> !hal.buffer { %c32 = constant 32 : index %c1234567890_i32 = constant 1234567890 : i32 %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator + %allocator = hal.device.allocator<%dev : !hal.device> : !hal.allocator // CHECK: [[BUFFER:%.+]] = vm.call @hal.allocator.allocate({{.+}}, %c50, %c15, %c64) - %buffer = hal.allocator.allocate %allocator, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch", %c64 : !hal.buffer + %buffer = hal.allocator.allocate<%allocator : !hal.allocator> + type("HostVisible|DeviceVisible|DeviceLocal") + usage("Constant|Transfer|Mapping|Dispatch") : !hal.buffer{%c64} // CHECK: vm.call @hal.buffer.fill([[BUFFER]], %zero, %c4, %c1065353216) - hal.buffer.fill %buffer, %c0, %c4, %c1065353216_i32 + hal.buffer.fill<%buffer : !hal.buffer>[%c0, %c4] pattern(%c1065353216_i32 : 
i32) // CHECK: vm.call @hal.buffer.fill([[BUFFER]], %c32, %c32, %c1234567890) - hal.buffer.fill %buffer, %c32, %c32, %c1234567890_i32 + hal.buffer.fill<%buffer : !hal.buffer>[%c32, %c32] pattern(%c1234567890_i32 : i32) return %buffer : !hal.buffer } diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir index 498917ab6989..4286064c49b6 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir @@ -3,6 +3,6 @@ // CHECK-LABEL: @device_allocator func @device_allocator(%arg0 : !hal.device) -> !hal.allocator { // CHECK: %ref = vm.call @hal.device.allocator(%arg0) : (!vm.ref) -> !vm.ref - %allocator = hal.device.allocator %arg0 : !hal.allocator + %allocator = hal.device.allocator<%arg0 : !hal.device> : !hal.allocator return %allocator : !hal.allocator } diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir index a18b2c95f04a..6f06a74fac22 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir @@ -30,12 +30,12 @@ func @executableCreate( // CHECK: %[[EX0:.+]] = vm.call.variadic @hal.executable.create( // CHECK-SAME: %[[DEV]], %c1230128453, %[[BIN0]], [%[[LAYOUT0]], %[[LAYOUT1]]] // CHECK-SAME: ) : (!vm.ref, i32, !vm.ref, !vm.ref ...) 
-> !vm.ref - %0 = hal.executable.create %device, @exe::@binary1, layouts = [%layout0, %layout1] : !hal.executable + %0 = hal.executable.create device(%device : !hal.device) target(@exe::@binary1) layouts([%layout0, %layout1]) : !hal.executable // CHECK: %[[BIN1:.+]] = vm.const.ref.rodata @_exe_binary2_binary_spirv : !vm.ref // CHECK: %[[EX1:.+]] = vm.call.variadic @hal.executable.create( // CHECK-SAME: %[[DEV]], %c1397773893, %[[BIN1]], [%[[LAYOUT1]], %[[LAYOUT0]]] // CHECK-SAME: ) : (!vm.ref, i32, !vm.ref, !vm.ref ...) -> !vm.ref - %1 = hal.executable.create %device, @exe::@binary2, layouts = [%layout1, %layout0] : !hal.executable + %1 = hal.executable.create device(%device : !hal.device) target(@exe::@binary2) layouts([%layout1, %layout0]) : !hal.executable // CHECK: vm.return %[[EX0]], %[[EX1]] return %0, %1 : !hal.executable, !hal.executable } diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.cpp b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.cpp index f3493e170c06..718071699198 100644 --- a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/ConvertIREEToHAL.cpp @@ -16,6 +16,7 @@ #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/IREE/IR/IREEOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -48,9 +49,27 @@ class DynamicShapeConstantOpConversion IREE::HAL::BufferUsageBitfield::All | IREE::HAL::BufferUsageBitfield::Constant; - auto view = rewriter.createOrFold( - constantOp.getLoc(), allocator, memoryTypes, bufferUsage, - constantOp.value()); + auto shapedType = constantOp.value().getType(); + auto elementType = + IREE::HAL::getElementTypeValue(shapedType.getElementType()); + if (!elementType.hasValue()) { + return rewriter.notifyMatchFailure(constantOp, "unhandled element type"); + } + + auto buffer = 
rewriter.createOrFold( + constantOp.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), + allocator, memoryTypes, bufferUsage, constantOp.value()); + + SmallVector shape; + if (shapedType.getRank() >= 1) { + for (auto dim : shapedType.getShape()) { + shape.push_back(rewriter.createOrFold( + constantOp.getLoc(), dim)); + } + } + + auto view = rewriter.createOrFold( + constantOp.getLoc(), buffer, elementType.getValue(), shape); rewriter.replaceOpWithNewOp(constantOp, view); return success(); diff --git a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/shape_constants.mlir b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/shape_constants.mlir index a6af40636acf..96102a58de2e 100644 --- a/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/shape_constants.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/IREEToHAL/test/shape_constants.mlir @@ -2,10 +2,13 @@ // CHECK-LABEL: @dynamic_shape_constant func @dynamic_shape_constant() { - // CHECK: %dev = hal.ex.shared_device - // CHECK: %allocator = hal.device.allocator %dev - // CHECK: %view = hal.buffer_view.const %allocator, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch" : !hal.buffer_view = dense<2> : tensor<2xi32> - // CHECK: %[[RES:.+]] = iree.do_not_optimize(%view) : !hal.buffer_view + // CHECK: %[[ALLOCATOR:.+]] = hal.device.allocator + // CHECK-NEXT: %[[BUFFER:.+]] = hal.allocator.constant<%[[ALLOCATOR]] : !hal.allocator> + // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("Constant|Transfer|Mapping|Dispatch") + // CHECK-SAME: : !hal.buffer = dense<2> : tensor<2xi32> + // CHECK: %[[VIEW:.+]] = hal.buffer_view.create %[[BUFFER]], element_type = %c16777248_i32, shape = [%c2] : !hal.buffer -> !hal.buffer_view + // CHECK-NEXT: %[[RET:.+]] = iree.do_not_optimize(%[[VIEW]]) : !hal.buffer_view %c = iree.dynamic_shape_constant dense<2> : tensor<2xi32> -> tensor return } diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD 
b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD index 89b798013bd1..187cde51da5e 100644 --- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/BUILD @@ -21,6 +21,8 @@ package( cc_library( name = "StandardToHAL", srcs = [ + "ConvertConstantOps.cpp", + "ConvertShapeOps.cpp", "ConvertStandardToHAL.cpp", "ConvertStructuralOps.cpp", ], @@ -33,9 +35,12 @@ cc_library( "//iree/compiler/Dialect/HAL/IR:HALDialect", "//iree/compiler/Dialect/HAL/Target", "//iree/compiler/Dialect/HAL/Utils", + "//iree/compiler/Dialect/Shape/IR", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt index a92e59bfddbd..0c36bb3572a9 100644 --- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/CMakeLists.txt @@ -16,12 +16,16 @@ iree_cc_library( HDRS "ConvertStandardToHAL.h" SRCS + "ConvertConstantOps.cpp" + "ConvertShapeOps.cpp" "ConvertStandardToHAL.cpp" "ConvertStructuralOps.cpp" DEPS LLVMSupport MLIRIR + MLIRMemRef MLIRPass + MLIRShape MLIRStandard MLIRTransforms iree::compiler::Dialect::HAL::Conversion @@ -29,6 +33,7 @@ iree_cc_library( iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::HAL::Target iree::compiler::Dialect::HAL::Utils + iree::compiler::Dialect::Shape::IR PUBLIC ) diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertConstantOps.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertConstantOps.cpp new file mode 100644 index 000000000000..43aa7576aee1 --- /dev/null +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertConstantOps.cpp @@ -0,0 +1,92 @@ +// 
Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertStandardToHAL.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { +namespace { + +class ConstantTensorOpConversion + : public OpConversionPattern { + public: + ConstantTensorOpConversion(MLIRContext *ctx, TypeConverter &converter) + : OpConversionPattern(ctx) {} + + LogicalResult matchAndRewrite( + mlir::ConstantOp constantOp, llvm::ArrayRef newOperands, + ConversionPatternRewriter &rewriter) const override { + if (!constantOp.getType().isa()) return failure(); + + auto device = + rewriter.createOrFold(constantOp.getLoc()); + auto allocator = rewriter.createOrFold( + constantOp.getLoc(), device); + + // TODO(benvanik): compute from SSA use-def chain uses. 
+ IREE::HAL::MemoryTypeBitfield memoryTypes = + IREE::HAL::MemoryTypeBitfield::DeviceLocal | + IREE::HAL::MemoryTypeBitfield::HostVisible; + IREE::HAL::BufferUsageBitfield bufferUsage = + IREE::HAL::BufferUsageBitfield::All | + IREE::HAL::BufferUsageBitfield::Constant; + + auto elementsAttr = constantOp.getValue().cast(); + auto elementsTy = elementsAttr.getType().cast(); + + // Expand boolean elements to the minimum bit width supported by the HAL + // (8-bits). + // To improve memory bandwidth and increase compute we should prefer to + // pack 1-bit tensors into wider storage before this lossy conversion. For + // example bitwise ops on 8x32xi1 can be converted to ops on tensor<8xi32>. + if (elementsTy.getElementType().isInteger(1)) { + elementsAttr = + elementsAttr.mapValues(rewriter.getIntegerType(8), + llvm::function_ref( + [](const APInt &val) -> APInt { + return APInt(8, val.getBoolValue()); + })); + } + + auto buffer = rewriter.createOrFold( + constantOp.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), + allocator, memoryTypes, bufferUsage, elementsAttr); + + rewriter.replaceOp(constantOp, {buffer}); + return success(); + } +}; + +} // namespace + +void populateStandardConstantToHALPatterns(MLIRContext *context, + OwningRewritePatternList &patterns, + TypeConverter &converter) { + patterns.insert(context, converter); +} + +} // namespace iree_compiler +} // namespace mlir diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertShapeQueryOps.cpp b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp similarity index 96% rename from iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertShapeQueryOps.cpp rename to iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp index 7b09018b1190..6843d77c8c77 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertShapeQueryOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/ConvertShapeOps.cpp @@ -65,7 +65,7 @@ class 
BackingBufferBufferViewDimPattern Optional index = dimOp.getConstantIndex(); assert(index.hasValue() && "expect constant index in `std.dim` operation"); - auto dimIndex = rewriter.getI32IntegerAttr(index.getValue()); + auto dimIndex = rewriter.getIndexAttr(index.getValue()); rewriter.replaceOpWithNewOp( dimOp, dimOp.getResult().getType(), adaptor.getBufferView(), dimIndex); return success(); @@ -96,7 +96,7 @@ class BackingBufferBufferViewRankPattern : public OpConversionPattern { } // namespace -void populateHalBufferViewShapePatterns(MLIRContext *context, +void populateStandardShapeToHALPatterns(MLIRContext *context, OwningRewritePatternList &patterns, TypeConverter &converter) { patterns.insert(); + conversionTarget.addIllegalOp(); } void populateStandardToHALPatterns(MLIRContext *context, OwningRewritePatternList &patterns, TypeConverter &typeConverter) { + populateStandardConstantToHALPatterns(context, patterns, typeConverter); + populateStandardShapeToHALPatterns(context, patterns, typeConverter); populateStandardStructuralToHALPatterns(context, patterns, typeConverter); } diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD index 51722bc38906..fc87c97b6704 100644 --- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/BUILD @@ -24,7 +24,10 @@ package( iree_lit_test_suite( name = "lit", srcs = enforce_glob( - ["structural_ops.mlir"], + [ + "constant_ops.mlir", + "structural_ops.mlir", + ], include = ["*.mlir"], ), data = [ diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt index aec20652b290..1609d044bd36 100644 --- a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/CMakeLists.txt @@ -14,6 +14,7 @@ 
iree_lit_test_suite( NAME lit SRCS + "constant_ops.mlir" "structural_ops.mlir" DATA iree::tools::IreeFileCheck diff --git a/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/constant_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/constant_ops.mlir new file mode 100644 index 000000000000..cb56984a790e --- /dev/null +++ b/iree/compiler/Dialect/HAL/Conversion/StandardToHAL/test/constant_ops.mlir @@ -0,0 +1,27 @@ +// RUN: iree-opt -split-input-file -iree-convert-to-hal %s | IreeFileCheck %s + +// CHECK-LABEL: @constantTensor +func @constantTensor() { + // CHECK: %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + // CHECK: %cbuffer = hal.allocator.constant<%allocator : !hal.allocator> + // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("Constant|Transfer|Mapping|Dispatch") + // CHECK-SAME: : !hal.buffer + // CHECK-SAME: = dense<[1, 2]> : tensor<2xi32> + %0 = constant dense<[1, 2]> : tensor<2xi32> + return +} + +// ----- + +// CHECK-LABEL: @constantTensor1 +func @constantTensor1() { + // CHECK: %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + // CHECK: %cbuffer = hal.allocator.constant<%allocator : !hal.allocator> + // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("Constant|Transfer|Mapping|Dispatch") + // CHECK-SAME: : !hal.buffer + // CHECK-SAME: = dense<[1, 0]> : tensor<2xi8> + %0 = constant dense<[1, 0]> : tensor<2xi1> + return +} diff --git a/iree/compiler/Dialect/HAL/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/HAL/Conversion/TypeConverter.cpp index 0ccbaf092731..c2a25c24115d 100644 --- a/iree/compiler/Dialect/HAL/Conversion/TypeConverter.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/TypeConverter.cpp @@ -14,7 +14,9 @@ #include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" +#include 
"iree/compiler/Dialect/HAL/Utils/TypeUtils.h" #include "iree/compiler/Dialect/IREE/IR/IREETypes.h" namespace mlir { @@ -23,6 +25,7 @@ namespace iree_compiler { HALTypeConverter::HALTypeConverter( ArrayRef conversionInterfaces) : conversionInterfaces(conversionInterfaces.vec()) { + // Custom conversion interfaces for external dialects. addConversion([this](Type type, SmallVectorImpl &results) { for (auto *conversionInterface : this->conversionInterfaces) { if (succeeded(conversionInterface->convertType(type, results))) { @@ -32,18 +35,22 @@ HALTypeConverter::HALTypeConverter( results.push_back(type); return success(); }); + + // Tensors become buffers by default. + // TODO(benvanik): make them buffer views instead? then they carry shape but + // are memory type erased which is not good. addConversion([](TensorType type) -> Optional { // HAL only should be concerned with numeric values. - if (HALTypeConverter::shouldConvertToHalBuffer(type)) { + if (HALTypeConverter::shouldConvertToBuffer(type)) { // TODO(benvanik): composite-type conversion (buffer + dynamic dims). return IREE::HAL::BufferType::get(type.getContext()); } - return llvm::None; }); + + // Recursively handle pointer target types (we want to convert + // ptr> to ptr>, for example). addConversion([this](IREE::PtrType type) -> Type { - // Recursively handle pointer target types (we want to convert ptr to - // ptr, for example). auto targetType = convertType(type.getTargetType()); if (!targetType) { return Type(); diff --git a/iree/compiler/Dialect/HAL/Conversion/TypeConverter.h b/iree/compiler/Dialect/HAL/Conversion/TypeConverter.h index edb356b9ad0a..3f0f225c24c9 100644 --- a/iree/compiler/Dialect/HAL/Conversion/TypeConverter.h +++ b/iree/compiler/Dialect/HAL/Conversion/TypeConverter.h @@ -31,7 +31,7 @@ class HALTypeConverter : public TypeConverter { // TODO(benvanik): signature conversion for output buffers. 
- static bool shouldConvertToHalBuffer(Type type) { + static bool shouldConvertToBuffer(Type type) { if (TensorType tensor_type = type.template dyn_cast()) { return tensor_type.getElementType().isIntOrFloat(); } diff --git a/iree/compiler/Dialect/HAL/IR/BUILD b/iree/compiler/Dialect/HAL/IR/BUILD index 01ba698853b4..beadb9054bda 100644 --- a/iree/compiler/Dialect/HAL/IR/BUILD +++ b/iree/compiler/Dialect/HAL/IR/BUILD @@ -29,6 +29,7 @@ filegroup( srcs = enforce_glob( [ "HALBase.td", + "HALInterfaces.td", "HALOps.td", ], include = ["*.td"], @@ -50,14 +51,17 @@ cc_library( textual_hdrs = [ "HALEnums.cpp.inc", "HALEnums.h.inc", + "HALOpInterfaces.cpp.inc", + "HALOpInterfaces.h.inc", "HALOps.cpp.inc", "HALOps.h.inc", - "HALOpInterface.cpp.inc", - "HALOpInterface.h.inc", "HALStructs.cpp.inc", "HALStructs.h.inc", + "HALTypeInterfaces.cpp.inc", + "HALTypeInterfaces.h.inc", ], deps = [ + ":HALInterfacesGen", ":HALOpsGen", ":HALStructsGen", ":HALTypesGen", @@ -92,6 +96,24 @@ cc_library( ], ) +gentbl( + name = "HALInterfacesGen", + tbl_outs = [ + ("-gen-op-interface-decls", "HALOpInterfaces.h.inc"), + ("-gen-op-interface-defs", "HALOpInterfaces.cpp.inc"), + ("-gen-type-interface-decls", "HALTypeInterfaces.h.inc"), + ("-gen-type-interface-defs", "HALTypeInterfaces.cpp.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HALInterfaces.td", + td_srcs = [ + ":td_files", + "//iree/compiler/Dialect/IREE/IR:td_files", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:StdOpsTdFiles", + ], +) + gentbl( name = "HALOpsGen", tbl_outs = [ @@ -129,8 +151,6 @@ gentbl( tbl_outs = [ ("-gen-enum-decls", "HALEnums.h.inc"), ("-gen-enum-defs", "HALEnums.cpp.inc"), - ("-gen-op-interface-decls", "HALOpInterface.h.inc"), - ("-gen-op-interface-defs", "HALOpInterface.cpp.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "HALBase.td", diff --git a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt index 
2627c126fd90..a72e03d5ec56 100644 --- a/iree/compiler/Dialect/HAL/IR/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/IR/CMakeLists.txt @@ -20,12 +20,14 @@ iree_cc_library( TEXTUAL_HDRS "HALEnums.cpp.inc" "HALEnums.h.inc" - "HALOpInterface.cpp.inc" - "HALOpInterface.h.inc" + "HALOpInterfaces.cpp.inc" + "HALOpInterfaces.h.inc" "HALOps.cpp.inc" "HALOps.h.inc" "HALStructs.cpp.inc" "HALStructs.h.inc" + "HALTypeInterfaces.cpp.inc" + "HALTypeInterfaces.h.inc" SRCS "HALOpFolders.cpp" "HALOps.cpp" @@ -65,6 +67,18 @@ iree_cc_library( PUBLIC ) +iree_tablegen_library( + NAME + HALInterfacesGen + TD_FILE + "HALInterfaces.td" + OUTS + -gen-op-interface-decls HALOpInterfaces.h.inc + -gen-op-interface-defs HALOpInterfaces.cpp.inc + -gen-type-interface-decls HALTypeInterfaces.h.inc + -gen-type-interface-defs HALTypeInterfaces.cpp.inc +) + iree_tablegen_library( NAME HALOpsGen @@ -95,8 +109,6 @@ iree_tablegen_library( OUTS -gen-enum-decls HALEnums.h.inc -gen-enum-defs HALEnums.cpp.inc - -gen-op-interface-decls HALOpInterface.h.inc - -gen-op-interface-defs HALOpInterface.cpp.inc ) iree_tablegen_doc( diff --git a/iree/compiler/Dialect/HAL/IR/HALBase.td b/iree/compiler/Dialect/HAL/IR/HALBase.td index db0d962f39d8..97e82e87785a 100644 --- a/iree/compiler/Dialect/HAL/IR/HALBase.td +++ b/iree/compiler/Dialect/HAL/IR/HALBase.td @@ -60,14 +60,14 @@ def HAL_MemoryModelAttr : let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; } -def HAL_MemoryType_None : BitEnumAttrCase<"None", 0x0000>; -def HAL_MemoryType_Transient : BitEnumAttrCase<"Transient", 0x0001>; -def HAL_MemoryType_HostVisible : BitEnumAttrCase<"HostVisible", 0x0002>; -def HAL_MemoryType_HostCoherent : BitEnumAttrCase<"HostCoherent", 0x0004>; -def HAL_MemoryType_HostCached : BitEnumAttrCase<"HostCached", 0x0008>; -def HAL_MemoryType_HostLocal : BitEnumAttrCase<"HostLocal", 0x0006>; -def HAL_MemoryType_DeviceVisible : BitEnumAttrCase<"DeviceVisible", 0x0010>; -def HAL_MemoryType_DeviceLocal : BitEnumAttrCase<"DeviceLocal", 
0x0030>; +def HAL_MemoryType_None : BitEnumAttrCase<"None", 0x0000>; // ? +def HAL_MemoryType_Transient : BitEnumAttrCase<"Transient", 0x0001>; // T +def HAL_MemoryType_HostVisible : BitEnumAttrCase<"HostVisible", 0x0002>; // h +def HAL_MemoryType_HostCoherent : BitEnumAttrCase<"HostCoherent", 0x0004>; // c +def HAL_MemoryType_HostCached : BitEnumAttrCase<"HostCached", 0x0008>; // C +def HAL_MemoryType_HostLocal : BitEnumAttrCase<"HostLocal", 0x0006>; // H +def HAL_MemoryType_DeviceVisible : BitEnumAttrCase<"DeviceVisible", 0x0010>; // d +def HAL_MemoryType_DeviceLocal : BitEnumAttrCase<"DeviceLocal", 0x0030>; // D def HAL_MemoryTypeBitfieldAttr : BitEnumAttr<"MemoryTypeBitfield", "valid MemoryType", [ HAL_MemoryType_None, @@ -82,13 +82,13 @@ def HAL_MemoryTypeBitfieldAttr : let cppNamespace = "mlir::iree_compiler::IREE::HAL"; } -def HAL_MemoryAccess_None : BitEnumAttrCase<"None", 0x0000>; -def HAL_MemoryAccess_Read : BitEnumAttrCase<"Read", 0x0001>; -def HAL_MemoryAccess_Write : BitEnumAttrCase<"Write", 0x0002>; -def HAL_MemoryAccess_Discard : BitEnumAttrCase<"Discard", 0x0004>; +def HAL_MemoryAccess_None : BitEnumAttrCase<"None", 0x0000>; // ? 
+def HAL_MemoryAccess_Read : BitEnumAttrCase<"Read", 0x0001>; // R +def HAL_MemoryAccess_Write : BitEnumAttrCase<"Write", 0x0002>; // W +def HAL_MemoryAccess_Discard : BitEnumAttrCase<"Discard", 0x0004>; // D def HAL_MemoryAccess_DiscardWrite : BitEnumAttrCase<"DiscardWrite", 0x0006>; -def HAL_MemoryAccess_MayAlias : BitEnumAttrCase<"MayAlias", 0x0008>; -def HAL_MemoryAccess_All : BitEnumAttrCase<"All", 0x0007>; +def HAL_MemoryAccess_MayAlias : BitEnumAttrCase<"MayAlias", 0x0008>; // A +def HAL_MemoryAccess_All : BitEnumAttrCase<"All", 0x0007>; def HAL_MemoryAccessBitfieldAttr : BitEnumAttr<"MemoryAccessBitfield", "valid MemoryAccess", [ HAL_MemoryAccess_None, @@ -102,12 +102,12 @@ def HAL_MemoryAccessBitfieldAttr : let cppNamespace = "mlir::iree_compiler::IREE::HAL"; } -def HAL_BufferUsage_None : BitEnumAttrCase<"None", 0x0000>; -def HAL_BufferUsage_Constant : BitEnumAttrCase<"Constant", 0x0001>; -def HAL_BufferUsage_Transfer : BitEnumAttrCase<"Transfer", 0x0002>; -def HAL_BufferUsage_Mapping : BitEnumAttrCase<"Mapping", 0x0004>; -def HAL_BufferUsage_Dispatch : BitEnumAttrCase<"Dispatch", 0x0008>; -def HAL_BufferUsage_All : BitEnumAttrCase<"All", 0x000F>; +def HAL_BufferUsage_None : BitEnumAttrCase<"None", 0x0000>; // ? 
+def HAL_BufferUsage_Constant : BitEnumAttrCase<"Constant", 0x0001>; // C +def HAL_BufferUsage_Transfer : BitEnumAttrCase<"Transfer", 0x0002>; // T +def HAL_BufferUsage_Mapping : BitEnumAttrCase<"Mapping", 0x0004>; // M +def HAL_BufferUsage_Dispatch : BitEnumAttrCase<"Dispatch", 0x0008>; // D +def HAL_BufferUsage_All : BitEnumAttrCase<"All", 0x000E>; def HAL_BufferUsageBitfieldAttr : BitEnumAttr<"BufferUsageBitfield", "valid BufferUsage", [ HAL_BufferUsage_None, @@ -413,8 +413,11 @@ def HAL_ObjectType : AnyTypeOf<[ HAL_Semaphore, ]>; -def HAL_OrdinalAttr : SignlessIntegerAttrBase< - I32, "32-bit integer ordinal attribute">; +def HAL_BufferType : AnyTypeOf<[ + HAL_Buffer, +]>; + +def HAL_OrdinalAttr : IREE_IndexAttrBase<"size_t">; def HAL_ExecutableDataAttr : SignlessIntElementsAttr<8>; @@ -424,6 +427,7 @@ def HAL_ElementTypeAttr : SignlessIntegerAttrBase< def HAL_DeviceSize : TypeAlias; def HAL_DeviceSizeAttr : IREE_IndexAttrBase<"iree_device_size_t">; +def HAL_DeviceSizes : Variadic; def HAL_HostSize : TypeAlias; def HAL_HostSizeAttr : IREE_IndexAttrBase<"size_t">; @@ -436,6 +440,12 @@ def HAL_VariableRefAttr : AliasedSymbolRefAttr; def HAL_VariableType : AnyTypeOf<[HAL_PrimitiveType, AnyVector, HAL_ObjectType]>; def HAL_VariablePtr : PtrOf; +def HAL_IndexAttr : IREE_IndexAttrBase<"index">; +def HAL_IndexArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getIndexArrayAttr($0)"; +} + def HAL_Dim : TypeAlias; def HAL_Dims : Variadic; def HAL_Shape : Variadic; @@ -503,31 +513,6 @@ def HAL_ByteRangeAttr : let cppNamespace = "mlir::iree_compiler::IREE::HAL"; } -def HAL_MemoryBarrier : NamedTupleOf<[ - NamedTupleElement<0, "source_scope", I32>, - NamedTupleElement<1, "target_scope", I32> - ], "MemoryBarrier"> { - let description = [{ - MemoryBarrier struct that can be passed to the command buffer barrier - operations. 
- }]; -} -def HAL_MemoryBarrierList : TupleOf<[HAL_MemoryBarrier]>; - -def HAL_BufferBarrier : NamedTupleOf<[ - NamedTupleElement<0, "source_scope", I32>, - NamedTupleElement<1, "target_scope", I32>, - NamedTupleElement<2, "buffer", HAL_Buffer>, - NamedTupleElement<3, "offset", HAL_DeviceSize>, - NamedTupleElement<4, "length", HAL_DeviceSize> - ], "BufferBarrier"> { - let description = [{ - BufferBarrier struct that can be passed to the command buffer barrier - operations. - }]; -} -def HAL_BufferBarrierList : TupleOf<[HAL_BufferBarrier]>; - def HAL_DescriptorSetLayoutBindingAttr : IREE_StructAttr<"descriptor_set_layout_binding", "DescriptorSetLayoutBindingAttr", @@ -584,14 +569,8 @@ def HAL_DeviceMatchMemoryModelAttr : IREE_StructAttr< // Base HAL op classes //===----------------------------------------------------------------------===// -def HAL_OpInterface : OpInterface<"HALOp"> { - let description = [{ - Interface for HAL ops. - }]; -} - class HAL_Op traits = []> : - Op { + Op { let parser = [{ return parse$cppClass(parser, &result); }]; let printer = [{ return print$cppClass(p, *this); }]; } diff --git a/iree/compiler/Dialect/HAL/IR/HALInterfaces.td b/iree/compiler/Dialect/HAL/IR/HALInterfaces.td new file mode 100644 index 000000000000..40f5cb11371b --- /dev/null +++ b/iree/compiler/Dialect/HAL/IR/HALInterfaces.td @@ -0,0 +1,82 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef IREE_DIALECT_HAL_INTERFACES +#define IREE_DIALECT_HAL_INTERFACES + +include "iree/compiler/Dialect/IREE/IR/IREEBase.td" + +//===----------------------------------------------------------------------===// +// IREE::HAL::SizeAwareOpInterface +//===----------------------------------------------------------------------===// + +def HAL_InferTypeSize : TypeInterface<"InferTypeSizeInterface"> { + let description = [{ + Allows types to be queried for their size by inserting the required logic + when required. + }]; + + let methods = [ + InterfaceMethod< + [{Builds an expression computing the size of the value.}], + "Value", "inferSizeFromValue", (ins "Location":$loc, + "Value":$value, + "OpBuilder &":$builder) + >, + ]; +} + +def HAL_SizeAwareType : TypeInterface<"SizeAwareTypeInterface"> { + let description = [{ + Denotes that a type is size-aware and must always have a size value + associated with it in the IR. See `SizeAwareOp` for more information. + }]; + + let methods = [ + InterfaceMethod< + [{Returns a size for the given sized value.}], + "Value", "getSize", (ins "Value":$value) + >, + ]; +} + +def HAL_SizeAwareOp : OpInterface<"SizeAwareOpInterface"> { + let description = [{ + An operation that is able to provide size values for all size-aware operands + and results. 
+ }]; + + let methods = [ + InterfaceMethod< + [{Returns a size for the given sized operand index.}], + "Value", "getOperandSize", (ins "unsigned":$idx) + >, + InterfaceMethod< + [{Returns a size for the given sized result index.}], + "Value", "getResultSize", (ins "unsigned":$idx) + >, + InterfaceMethod< + [{Returns a size for the given sized result value.}], + "Value", "getResultSizeFromValue", (ins "Value":$value), + /*defaultImplementation=*/[{ + for (unsigned i = 0; i < $_self->getNumResults(); ++i) { + if ($_self->getResult(i) == value) return $_self.getResultSize(i); + } + return {}; + }] + >, + ]; +} + +#endif // IREE_DIALECT_HAL_INTERFACES diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index 505d065a3047..6566a9a1fa15 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp @@ -293,31 +293,62 @@ void AllocatorComputeRangeOp::getCanonicalizationPatterns( namespace { -/// Expands hal.allocator.allocate.const to an allocation and data write. -struct ExpandAllocatorAllocateConstOp - : public OpRewritePattern { +/// Expands hal.allocator.constant to an allocation and data write. +struct ExpandAllocatorConstantOp + : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AllocatorAllocateConstOp op, PatternRewriter &rewriter) const override { + auto shapedType = op.value().getType(); + auto elementType = + IREE::HAL::getElementTypeValue(shapedType.getElementType()); + if (!elementType.hasValue()) { + return rewriter.notifyMatchFailure(op, "unhandled element type"); + } + + // TODO(benvanik): compute from SSA use-def chain uses. 
+ IREE::HAL::MemoryTypeBitfield memoryTypes = + IREE::HAL::MemoryTypeBitfield::DeviceLocal | + IREE::HAL::MemoryTypeBitfield::HostVisible; + IREE::HAL::BufferUsageBitfield bufferUsage = + IREE::HAL::BufferUsageBitfield::All | + IREE::HAL::BufferUsageBitfield::Constant; + Type bufferType = IREE::HAL::BufferType::get(rewriter.getContext()); + auto hostBuffer = rewriter.createOrFold( op.getLoc(), IREE::ByteBufferType::get(rewriter.getContext()), op.value()); auto zero = rewriter.createOrFold(op.getLoc(), 0); auto neg1 = rewriter.createOrFold(op.getLoc(), -1); auto deviceBuffer = rewriter.createOrFold( - op.getLoc(), op.allocator(), op.memory_types(), op.buffer_usage(), + op.getLoc(), bufferType, op.allocator(), memoryTypes, bufferUsage, hostBuffer, zero, neg1); - rewriter.replaceOp(op, {deviceBuffer}); + + if (op.result().getType().isa()) { + // Wrap in a buffer view. + SmallVector shape; + if (shapedType.getRank() >= 1) { + for (auto dim : shapedType.getShape()) { + shape.push_back( + rewriter.createOrFold(op.getLoc(), dim)); + } + } + auto bufferView = rewriter.createOrFold( + op.getLoc(), deviceBuffer, elementType.getValue(), shape); + rewriter.replaceOp(op, {bufferView}); + } else { + rewriter.replaceOp(op, {deviceBuffer}); + } return success(); } }; } // namespace -void AllocatorAllocateConstOp::getCanonicalizationPatterns( +void AllocatorConstantOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -337,13 +368,13 @@ struct SkipBufferAllocatorOp : public OpRewritePattern { op.buffer().getDefiningOp())) { rewriter.replaceOp(op, allocateOp.allocator()); return success(); - } else if (auto allocateOp = dyn_cast_or_null( + } else if (auto allocateOp = dyn_cast_or_null( op.buffer().getDefiningOp())) { rewriter.replaceOp(op, allocateOp.allocator()); return success(); } else if (auto subspanOp 
= dyn_cast_or_null( op.buffer().getDefiningOp())) { - rewriter.replaceOpWithNewOp(op, + rewriter.replaceOpWithNewOp(op, op.result().getType(), subspanOp.source_buffer()); return success(); } @@ -364,45 +395,6 @@ void BufferAllocatorOp::getCanonicalizationPatterns( namespace { -/// Expands hal.buffer_view.const to an allocation and buffer view wrapper. -struct ExpandBufferViewConstOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BufferViewConstOp op, - PatternRewriter &rewriter) const override { - auto shapedType = op.value().getType(); - auto elementType = getElementTypeValue(shapedType.getElementType()); - if (!elementType.hasValue()) { - return failure(); - } - - auto buffer = rewriter.createOrFold( - op.getLoc(), op.allocator(), op.memory_types(), op.buffer_usage(), - op.value()); - - SmallVector shape; - if (shapedType.getRank() >= 1) { - for (auto dim : shapedType.getShape()) { - shape.push_back( - rewriter.createOrFold(op.getLoc(), dim)); - } - } - - rewriter.replaceOpWithNewOp( - op, buffer, elementType.getValue(), shape); - return success(); - } -}; - -} // namespace - -void BufferViewConstOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -namespace { - /// Expands a hal.buffer_view.subview op into range computation and creation /// ops. This allows for greater opportunity to CSE/bypass/etc the buffer view /// operations. 
@@ -416,7 +408,8 @@ struct ExpandBufferViewSubviewOp op.getLoc(), op.buffer_view(), op.indices(), op.lengths()); auto bufferValue = rewriter.createOrFold( - op.getLoc(), op.buffer_view()); + op.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), + op.buffer_view()); auto subspanValue = rewriter.createOrFold( op.getLoc(), bufferValue.getType(), bufferValue, computeRangeOp.offset(), computeRangeOp.length()); @@ -473,9 +466,10 @@ struct ExpandBufferViewComputeOffsetOp LogicalResult matchAndRewrite(BufferViewComputeOffsetOp op, PatternRewriter &rewriter) const override { auto bufferValue = rewriter.createOrFold( - op.getLoc(), op.buffer_view()); - auto allocatorValue = - rewriter.createOrFold(op.getLoc(), bufferValue); + op.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), + op.buffer_view()); + auto allocatorValue = rewriter.createOrFold( + op.getLoc(), AllocatorType::get(rewriter.getContext()), bufferValue); int rank = op.indices().size(); SmallVector dimTypes(rank, rewriter.getIndexType()); auto dimsOp = rewriter.create(op.getLoc(), dimTypes, @@ -507,9 +501,10 @@ struct ExpandBufferViewComputeRangeOp LogicalResult matchAndRewrite(BufferViewComputeRangeOp op, PatternRewriter &rewriter) const override { auto bufferValue = rewriter.createOrFold( - op.getLoc(), op.buffer_view()); - auto allocatorValue = - rewriter.createOrFold(op.getLoc(), bufferValue); + op.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), + op.buffer_view()); + auto allocatorValue = rewriter.createOrFold( + op.getLoc(), AllocatorType::get(rewriter.getContext()), bufferValue); int rank = op.indices().size(); SmallVector dimTypes(rank, rewriter.getIndexType()); auto dimsOp = rewriter.create(op.getLoc(), dimTypes, @@ -542,7 +537,7 @@ struct ExpandBufferViewDimsOp : public OpRewritePattern { for (unsigned i = 0; i < op.getNumResults(); ++i) { newDimValues.push_back(rewriter.createOrFold( op.getLoc(), rewriter.getIndexType(), op.buffer_view(), - 
rewriter.getI32IntegerAttr(i))); + rewriter.getIndexAttr(i))); } rewriter.replaceOp(op, {newDimValues}); return success(); diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 09287c112e9c..4aafca4927bd 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -57,52 +57,134 @@ static LogicalResult parseEnumAttr(OpAsmParser &parser, StringRef attrName, } //===----------------------------------------------------------------------===// -// hal.ex.shared_device +// custom<SizeAwareType> //===----------------------------------------------------------------------===// +// type{%size} -void ExSharedDeviceOp::getAsmResultNames( - function_ref setNameFn) { - setNameFn(result(), "dev"); +static ParseResult parseSizeAwareType(OpAsmParser &parser, Type &type, + OpAsmParser::OperandType &size) { + if (failed(parser.parseType(type)) || failed(parser.parseLBrace()) || + failed(parser.parseOperand(size)) || failed(parser.parseRBrace())) { + return failure(); + } + return success(); +} + +static void printSizeAwareType(OpAsmPrinter &p, Operation *op, Type type, + Value size) { + p.printType(type); + p << "{"; + p.printOperand(size); + p << "}"; } //===----------------------------------------------------------------------===// -// hal.make_memory_barrier +// custom<SizeAwareTypeList> //===----------------------------------------------------------------------===// +// (type{%size0}, type, type{%size1}) -void MakeMemoryBarrierOp::build(OpBuilder &builder, OperationState &state, - IREE::HAL::AccessScopeBitfield sourceScope, - IREE::HAL::AccessScopeBitfield targetScope) { - state.addAttribute("source_scope", builder.getI32IntegerAttr( - static_cast(sourceScope))); - state.addAttribute("target_scope", builder.getI32IntegerAttr( - static_cast(targetScope))); - state.addTypes({MemoryBarrierType::get(builder.getContext())}); +static ParseResult parseSizeAwareTypeList( + OpAsmParser &parser, SmallVectorImpl &types, + SmallVectorImpl 
&sizes) { + do { + Type type; + if (failed(parser.parseType(type))) return failure(); + if (type.isa()) { + OpAsmParser::OperandType size; + if (failed(parser.parseLBrace()) || failed(parser.parseOperand(size)) || + failed(parser.parseRBrace())) { + return failure(); + } + sizes.push_back(size); + } + types.push_back(type); + } while (succeeded(parser.parseOptionalComma())); + return success(); } -void MakeMemoryBarrierOp::getAsmResultNames( - function_ref setNameFn) { - setNameFn(result(), "memory_barrier"); +static void printSizeAwareTypeList(OpAsmPrinter &p, Operation *op, + TypeRange types, OperandRange sizes) { + int sizeIndex = 0; + llvm::interleaveComma(types, p, [&](Type type) { + p.printType(type); + if (type.isa()) { + p << "{"; + p.printOperand(sizes[sizeIndex++]); + p << "}"; + } + }); } //===----------------------------------------------------------------------===// -// hal.make_buffer_barrier +// custom<DescriptorSetBindings>($binding_ordinals, +// $binding_buffers, +// type($binding_buffers), +// $binding_offsets, +// $binding_lengths) //===----------------------------------------------------------------------===// -void MakeBufferBarrierOp::build(OpBuilder &builder, OperationState &state, - IREE::HAL::AccessScopeBitfield sourceScope, - IREE::HAL::AccessScopeBitfield targetScope, - Value buffer, Value offset, Value length) { - state.addAttribute("source_scope", builder.getI32IntegerAttr( - static_cast(sourceScope))); - state.addAttribute("target_scope", builder.getI32IntegerAttr( - static_cast(targetScope))); - state.addOperands({buffer, offset, length}); - state.addTypes({BufferBarrierType::get(builder.getContext())}); +static ParseResult parseDescriptorSetBindings( + OpAsmParser &parser, SmallVectorImpl &ordinals, + SmallVectorImpl &buffers, + SmallVectorImpl &bufferTypes, + SmallVectorImpl &bufferOffsets, + SmallVectorImpl &bufferLengths) { + do { + OpAsmParser::OperandType ordinal; + OpAsmParser::OperandType buffer; + Type bufferType; + OpAsmParser::OperandType 
bufferOffset; + OpAsmParser::OperandType bufferLength; + if (failed(parser.parseOperand(ordinal)) || failed(parser.parseEqual()) || + failed(parser.parseLParen()) || failed(parser.parseOperand(buffer)) || + failed(parser.parseColonType(bufferType)) || + failed(parser.parseRParen()) || failed(parser.parseLSquare()) || + failed(parser.parseOperand(bufferOffset)) || + failed(parser.parseComma()) || + failed(parser.parseOperand(bufferLength)) || + failed(parser.parseRSquare())) { + return failure(); + } + ordinals.push_back(ordinal); + buffers.push_back(buffer); + bufferTypes.push_back(bufferType); + bufferOffsets.push_back(bufferOffset); + bufferLengths.push_back(bufferLength); + } while (succeeded(parser.parseOptionalComma())); + return success(); +} + +static void printDescriptorSetBindings(OpAsmPrinter &p, Operation *op, + ValueRange ordinals, ValueRange buffers, + TypeRange bufferTypes, + ValueRange bufferOffsets, + ValueRange bufferLengths) { + llvm::interleaveComma( + llvm::zip(ordinals, buffers, bufferTypes, bufferOffsets, bufferLengths), + p, [&](std::tuple it) { + p.printNewline(); + p << " "; + p.printOperand(std::get<0>(it)); + p << " = ("; + p.printOperand(std::get<1>(it)); + p << " : "; + p.printType(std::get<2>(it)); + p << ")["; + p.printOperand(std::get<3>(it)); + p << ", "; + p.printOperand(std::get<4>(it)); + p << "]"; + }); + p.printNewline(); } -void MakeBufferBarrierOp::getAsmResultNames( +//===----------------------------------------------------------------------===// +// hal.ex.shared_device +//===----------------------------------------------------------------------===// + +void ExSharedDeviceOp::getAsmResultNames( function_ref setNameFn) { - setNameFn(result(), "buffer_barrier"); + setNameFn(result(), "device"); } //===----------------------------------------------------------------------===// @@ -440,43 +522,20 @@ void AllocatorComputeRangeOp::getAsmResultNames( // hal.allocator.allocate 
//===----------------------------------------------------------------------===// -void AllocatorAllocateOp::build(OpBuilder &builder, OperationState &state, - Value allocator, - IREE::HAL::MemoryTypeBitfield memoryTypes, - IREE::HAL::BufferUsageBitfield bufferUsage, - Value allocationSize) { - state.addOperands({allocator, allocationSize}); - state.addAttribute("memory_types", builder.getI32IntegerAttr( - static_cast(memoryTypes))); - state.addAttribute("buffer_usage", builder.getI32IntegerAttr( - static_cast(bufferUsage))); - state.addTypes({BufferType::get(builder.getContext())}); -} - void AllocatorAllocateOp::getAsmResultNames( function_ref setNameFn) { setNameFn(result(), "buffer"); } +Value AllocatorAllocateOp::getOperandSize(unsigned idx) { return {}; } + +Value AllocatorAllocateOp::getResultSize(unsigned idx) { return result_size(); } + //===----------------------------------------------------------------------===// -// hal.allocator.allocate.const +// hal.allocator.constant //===----------------------------------------------------------------------===// -void AllocatorAllocateConstOp::build(OpBuilder &builder, OperationState &state, - Value allocator, - IREE::HAL::MemoryTypeBitfield memoryTypes, - IREE::HAL::BufferUsageBitfield bufferUsage, - ElementsAttr value) { - state.addOperands({allocator}); - state.addAttribute("memory_types", builder.getI32IntegerAttr( - static_cast(memoryTypes))); - state.addAttribute("buffer_usage", builder.getI32IntegerAttr( - static_cast(bufferUsage))); - state.addAttribute("value", value); - state.addTypes({BufferType::get(builder.getContext())}); -} - -void AllocatorAllocateConstOp::getAsmResultNames( +void AllocatorConstantOp::getAsmResultNames( function_ref setNameFn) { setNameFn(result(), "cbuffer"); } @@ -485,33 +544,21 @@ void AllocatorAllocateConstOp::getAsmResultNames( // hal.allocator.map //===----------------------------------------------------------------------===// -void AllocatorMapOp::build(OpBuilder &builder, 
OperationState &state, - Value allocator, - IREE::HAL::MemoryTypeBitfield memoryTypes, - IREE::HAL::BufferUsageBitfield bufferUsage, - Value source, Value offset, Value length) { - state.addOperands({allocator, source, offset, length}); - state.addAttribute("memory_types", builder.getI32IntegerAttr( - static_cast(memoryTypes))); - state.addAttribute("buffer_usage", builder.getI32IntegerAttr( - static_cast(bufferUsage))); - state.addTypes({BufferType::get(builder.getContext())}); -} - void AllocatorMapOp::getAsmResultNames( function_ref setNameFn) { setNameFn(result(), "mapped"); } +Value AllocatorMapOp::getOperandSize(unsigned idx) { return {}; } + +Value AllocatorMapOp::getResultSize(unsigned idx) { return length(); } + //===----------------------------------------------------------------------===// -// hal.buffer.allocator //===----------------------------------------------------------------------===// -void BufferAllocatorOp::build(OpBuilder &builder, OperationState &state, - Value buffer) { - state.addOperands({buffer}); - state.addTypes({AllocatorType::get(builder.getContext())}); -} +//===----------------------------------------------------------------------===// +// hal.buffer.allocator +//===----------------------------------------------------------------------===// void BufferAllocatorOp::getAsmResultNames( function_ref setNameFn) { @@ -527,27 +574,17 @@ void BufferSubspanOp::getAsmResultNames( setNameFn(result(), "buffer"); } +Value BufferSubspanOp::getOperandSize(unsigned idx) { return length(); } + +Value BufferSubspanOp::getResultSize(unsigned idx) { return length(); } + //===----------------------------------------------------------------------===// -// hal.buffer_view.const +// hal.buffer.length //===----------------------------------------------------------------------===// -void BufferViewConstOp::build(OpBuilder &builder, OperationState &state, - Value allocator, - IREE::HAL::MemoryTypeBitfield memoryTypes, - IREE::HAL::BufferUsageBitfield 
bufferUsage, - ElementsAttr value) { - state.addOperands({allocator}); - state.addAttribute("memory_types", builder.getI32IntegerAttr( - static_cast(memoryTypes))); - state.addAttribute("buffer_usage", builder.getI32IntegerAttr( - static_cast(bufferUsage))); - state.addAttribute("value", value); - state.addTypes({BufferViewType::get(builder.getContext())}); -} - -void BufferViewConstOp::getAsmResultNames( +void BufferLengthOp::getAsmResultNames( function_ref setNameFn) { - setNameFn(result(), "view"); + setNameFn(result(), "len"); } //===----------------------------------------------------------------------===// @@ -588,12 +625,6 @@ void BufferViewSubviewOp::getAsmResultNames( // hal.buffer_view.buffer //===----------------------------------------------------------------------===// -void BufferViewBufferOp::build(OpBuilder &builder, OperationState &state, - Value bufferView) { - state.addOperands({bufferView}); - state.addTypes({BufferType::get(builder.getContext())}); -} - void BufferViewBufferOp::getAsmResultNames( function_ref setNameFn) { setNameFn(result(), "buffer"); @@ -653,19 +684,6 @@ void BufferViewComputeRangeOp::getAsmResultNames( // hal.command_buffer.create //===----------------------------------------------------------------------===// -void CommandBufferCreateOp::build( - OpBuilder &builder, OperationState &state, Value device, - IREE::HAL::CommandBufferModeBitfield modes, - IREE::HAL::CommandCategoryBitfield commandCategories) { - state.addOperands({device}); - state.addAttribute("modes", - builder.getI32IntegerAttr(static_cast(modes))); - state.addAttribute( - "command_categories", - builder.getI32IntegerAttr(static_cast(commandCategories))); - state.addTypes({CommandBufferType::get(builder.getContext())}); -} - void CommandBufferCreateOp::getAsmResultNames( function_ref setNameFn) { setNameFn(result(), "cmd"); @@ -751,157 +769,6 @@ void CommandBufferPushDescriptorSetOp::build( state.addOperands(bindingLengths); } -static ParseResult 
parseDescriptorSetBindings(OpAsmParser &parser, - OperationState *result) { - auto indexType = parser.getBuilder().getIndexType(); - SmallVector bindingOrdinals; - SmallVector bindingBuffers; - SmallVector bindingOffsets; - SmallVector bindingLengths; - do { - NamedAttrList attrList; - OpAsmParser::OperandType ordinal; - OpAsmParser::OperandType buffer; - OpAsmParser::OperandType bufferOffset; - OpAsmParser::OperandType bufferLength; - if (failed(parser.parseOperand(ordinal)) || - failed(parser.resolveOperand(ordinal, indexType, bindingOrdinals)) || - failed(parser.parseEqual()) || failed(parser.parseLParen()) || - failed(parser.parseOperand(buffer)) || - failed(parser.resolveOperand( - buffer, BufferType::get(result->getContext()), bindingBuffers)) || - failed(parser.parseComma()) || - failed(parser.parseOperand(bufferOffset)) || - failed( - parser.resolveOperand(bufferOffset, indexType, bindingOffsets)) || - failed(parser.parseComma()) || - failed(parser.parseOperand(bufferLength)) || - failed( - parser.resolveOperand(bufferLength, indexType, bindingLengths)) || - failed(parser.parseRParen())) { - return failure(); - } - } while (succeeded(parser.parseOptionalComma())); - result->addOperands(bindingOrdinals); - result->addOperands(bindingBuffers); - result->addOperands(bindingOffsets); - result->addOperands(bindingLengths); - return success(); -} - -static ParseResult parseCommandBufferPushDescriptorSetOp( - OpAsmParser &parser, OperationState *result) { - OpAsmParser::OperandType commandBuffer; - OpAsmParser::OperandType executableLayout; - OpAsmParser::OperandType set; - auto operandsLoc = parser.getCurrentLocation(); - if (failed(parser.parseOperand(commandBuffer)) || - failed(parser.parseComma()) || - failed(parser.parseOperand(executableLayout)) || - failed(parser.parseComma()) || failed(parser.parseKeyword("set")) || - failed(parser.parseEqual()) || failed(parser.parseOperand(set)) || - failed(parser.parseComma()) || - failed(parser.resolveOperands( - 
ArrayRef{ - commandBuffer, - executableLayout, - set, - }, - ArrayRef{ - CommandBufferType::get(result->getContext()), - ExecutableLayoutType::get(result->getContext()), - IndexType::get(result->getContext()), - }, - operandsLoc, result->operands)) || - failed(parser.parseKeyword("bindings")) || failed(parser.parseEqual()) || - failed(parser.parseLSquare()) || - failed(parseDescriptorSetBindings(parser, result)) || - failed(parser.parseRSquare()) || - failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) { - return failure(); - } - return success(); -} - -template -static void printDescriptorSetBindings(OpAsmPrinter &p, T op) { - for (int i = 0; i < op.binding_ordinals().size(); ++i) { - p.printOperand(op.binding_ordinals()[i]); - p << " = ("; - p.printOperand(op.binding_buffers()[i]); - p << ", "; - p.printOperand(op.binding_offsets()[i]); - p << ", "; - p.printOperand(op.binding_lengths()[i]); - p << ")"; - if (i < op.binding_ordinals().size() - 1) p << ", "; - } -} - -static void printCommandBufferPushDescriptorSetOp( - OpAsmPrinter &p, CommandBufferPushDescriptorSetOp op) { - p << op.getOperationName() << ' '; - p.printOperand(op.command_buffer()); - p << ", "; - p.printOperand(op.executable_layout()); - p << ", set = "; - p.printOperand(op.set()); - p << ", bindings = ["; - printDescriptorSetBindings(p, op); - p << "]"; - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{ - "set", - "bindings", - }); -} - -//===----------------------------------------------------------------------===// -// hal.command_buffer.bind_descriptor_set -//===----------------------------------------------------------------------===// - -void CommandBufferBindDescriptorSetOp::build(OpBuilder &builder, - OperationState &state, - Value commandBuffer, - Value executableLayout, - int64_t set, Value descriptorSet, - ValueRange dynamicOffsets) { - build(builder, state, commandBuffer, executableLayout, - builder.createOrFold(state.location, set), - descriptorSet, 
dynamicOffsets); -} - -void CommandBufferBindDescriptorSetOp::build(OpBuilder &builder, - OperationState &state, - Value commandBuffer, - Value executableLayout, Value set, - Value descriptorSet, - ValueRange dynamicOffsets) { - state.addOperands({commandBuffer, executableLayout, set, descriptorSet}); - state.addOperands(dynamicOffsets); -} - -//===----------------------------------------------------------------------===// -// hal.command_buffer.dispatch.symbol -//===----------------------------------------------------------------------===// - -void CommandBufferDispatchSymbolOp::build( - OpBuilder &builder, OperationState &state, Value commandBuffer, - IREE::HAL::ExecutableEntryPointOp entryPoint, Value workgroupX, - Value workgroupY, Value workgroupZ) { - state.addOperands({commandBuffer, workgroupX, workgroupY, workgroupZ}); - // Construct Executable::Target::EntryPoint nested reference. - StringRef executableOpSymName = - entryPoint->getParentOp() - ->getParentOp() - ->getAttrOfType(SymbolTable::getSymbolAttrName()) - .getValue(); - state.addAttribute("entry_point", - builder.getSymbolRefAttr( - executableOpSymName, - {builder.getSymbolRefAttr(entryPoint->getParentOp()), - builder.getSymbolRefAttr(entryPoint)})); -} - //===----------------------------------------------------------------------===// // hal.constant_pool //===----------------------------------------------------------------------===// @@ -995,47 +862,6 @@ void DescriptorSetCreateOp::build( state.addOperands(bindingLengths); } -static ParseResult parseDescriptorSetCreateOp(OpAsmParser &parser, - OperationState *result) { - OpAsmParser::OperandType device; - OpAsmParser::OperandType setLayout; - auto operandsLoc = parser.getCurrentLocation(); - if (failed(parser.parseOperand(device)) || failed(parser.parseComma()) || - failed(parser.parseOperand(setLayout)) || failed(parser.parseComma()) || - failed(parser.resolveOperands( - ArrayRef{ - device, - setLayout, - }, - ArrayRef{ - 
DeviceType::get(result->getContext()), - DescriptorSetLayoutType::get(result->getContext()), - }, - operandsLoc, result->operands)) || - failed(parser.parseKeyword("bindings")) || failed(parser.parseEqual()) || - failed(parser.parseLSquare()) || - failed(parseDescriptorSetBindings(parser, result)) || - failed(parser.parseRSquare()) || - failed(parser.parseOptionalAttrDictWithKeyword(result->attributes))) { - return failure(); - } - return success(); -} - -static void printDescriptorSetCreateOp(OpAsmPrinter &p, - DescriptorSetCreateOp op) { - p << op.getOperationName() << ' '; - p.printOperand(op.device()); - p << ", "; - p.printOperand(op.set_layout()); - p << ", bindings = ["; - printDescriptorSetBindings(p, op); - p << "]"; - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{ - "bindings", - }); -} - void DescriptorSetCreateOp::getAsmResultNames( function_ref setNameFn) { setNameFn(result(), "descriptor_set"); @@ -1091,10 +917,10 @@ static ParseResult parseDeviceSwitchOp(OpAsmParser &parser, OperationState *result) { OpAsmParser::OperandType device; Type deviceType; - if (failed(parser.parseLParen()) || failed(parser.parseOperand(device)) || + if (failed(parser.parseLess()) || failed(parser.parseOperand(device)) || failed(parser.parseColonType(deviceType)) || failed(parser.resolveOperand(device, deviceType, result->operands)) || - failed(parser.parseRParen()) || + failed(parser.parseGreater()) || failed(parser.parseOptionalArrowTypeList(result->types))) { return failure(); } @@ -1150,11 +976,11 @@ static ParseResult parseDeviceSwitchOp(OpAsmParser &parser, } static void printDeviceSwitchOp(OpAsmPrinter &p, DeviceSwitchOp op) { - p << op.getOperationName() << "("; + p << op.getOperationName() << "<"; p.printOperand(op.device()); p << " : "; p.printType(op.device().getType()); - p << ")"; + p << ">"; p.printOptionalArrowTypeList(op.getResultTypes()); p << "\n"; p.getStream().indent(4); @@ -1542,8 +1368,8 @@ ArrayAttr 
InterfaceOp::getExecutableSetLayoutsAttr() { Builder builder(getContext()); SmallVector, 4> setAttrs; for (auto bindingOp : getBlock().getOps()) { - int set = bindingOp.set(); - int binding = bindingOp.binding(); + unsigned set = bindingOp.set().getZExtValue(); + unsigned binding = bindingOp.binding().getZExtValue(); if (set >= setAttrs.size()) setAttrs.resize(set + 1); auto &bindingAttrs = setAttrs[set]; if (binding >= bindingAttrs.size()) bindingAttrs.resize(binding + 1); @@ -1582,13 +1408,12 @@ static ParseResult parseInterfaceBindingOp(OpAsmParser &parser, result->attributes)) || failed(parser.parseComma()) || failed(parser.parseKeyword("set")) || failed(parser.parseEqual()) || - failed(parser.parseAttribute(setAttr, - parser.getBuilder().getIntegerType(32), + failed(parser.parseAttribute(setAttr, parser.getBuilder().getIndexType(), "set", result->attributes)) || failed(parser.parseComma()) || failed(parser.parseKeyword("binding")) || failed(parser.parseEqual()) || failed(parser.parseAttribute(bindingAttr, - parser.getBuilder().getIntegerType(32), + parser.getBuilder().getIndexType(), "binding", result->attributes)) || failed(parser.parseComma()) || failed(parser.parseKeyword("type")) || failed(parser.parseEqual()) || diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td index 8ac0bfa507cf..3f1c7fa78db9 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -16,6 +16,7 @@ #define IREE_DIALECT_HAL_OPS include "iree/compiler/Dialect/HAL/IR/HALBase.td" +include "iree/compiler/Dialect/HAL/IR/HALInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -59,71 +60,6 @@ def HAL_ExSubmitAndWaitOp : HAL_Op<"ex.submit_and_wait", [YieldPoint]> { let assemblyFormat = "$device `,` $command_buffer attr-dict"; } -//===----------------------------------------------------------------------===// -// HAL struct 
definition ops -//===----------------------------------------------------------------------===// - -def HAL_MakeMemoryBarrierOp : HAL_MakeTupleOp<"make_memory_barrier", [ - DeclareOpInterfaceMethods, - ]> { - let summary = [{temporary memory barrier allocation operation}]; - let description = [{ - Allocates a temporary MemoryBarrier struct that can be passed to the - command buffer barrier operations. - }]; - - let arguments = (ins - HAL_AccessScopeBitfieldAttr:$source_scope, - HAL_AccessScopeBitfieldAttr:$target_scope - ); - let results = (outs - HAL_MemoryBarrier:$result - ); - - let assemblyFormat = [{ - $source_scope `,` $target_scope attr-dict-with-keyword `:` type($result) - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "IREE::HAL::AccessScopeBitfield":$sourceScope, - "IREE::HAL::AccessScopeBitfield":$targetScope)>, - ]; -} - -def HAL_MakeBufferBarrierOp : HAL_MakeTupleOp<"make_buffer_barrier", [ - DeclareOpInterfaceMethods, - ]> { - let summary = [{temporary buffer barrier allocation operation}]; - let description = [{ - Allocates a temporary BufferBarrier struct that can be passed to the - command buffer barrier operations. 
- }]; - - let arguments = (ins - HAL_AccessScopeBitfieldAttr:$source_scope, - HAL_AccessScopeBitfieldAttr:$target_scope, - HAL_Buffer:$buffer, - HAL_DeviceSize:$offset, - HAL_DeviceSize:$length - ); - let results = (outs - HAL_BufferBarrier:$result - ); - - let assemblyFormat = [{ - $source_scope `,` $target_scope `,` operands attr-dict-with-keyword `:` - type($result) - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "IREE::HAL::AccessScopeBitfield":$sourceScope, - "IREE::HAL::AccessScopeBitfield":$targetScope, "Value":$buffer, - "Value":$offset, "Value":$length)>, - ]; -} - //===----------------------------------------------------------------------===// // Global variables //===----------------------------------------------------------------------===// @@ -150,15 +86,33 @@ def HAL_VariableOp : HAL_Op<"variable", [ let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "StringRef":$name, "bool":$isMutable, "Type":$type, - "Optional":$initializer, "Optional":$initialValue, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilder<(ins "StringRef":$name, "bool":$isMutable, - "mlir::FuncOp":$initializer, CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilder<(ins "StringRef":$name, "bool":$isMutable, "Type":$type, - "Attribute":$initialValue, CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilder<(ins "StringRef":$name, "bool":$isMutable, "Type":$type, - CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins + "StringRef":$name, + "bool":$isMutable, + "Type":$type, + "Optional":$initializer, + "Optional":$initialValue, + CArg<"ArrayRef", "{}">:$attrs + )>, + OpBuilder<(ins + "StringRef":$name, + "bool":$isMutable, + "mlir::FuncOp":$initializer, + CArg<"ArrayRef", "{}">:$attrs + )>, + OpBuilder<(ins + "StringRef":$name, + "bool":$isMutable, + "Type":$type, + "Attribute":$initialValue, + CArg<"ArrayRef", "{}">:$attrs + )>, + OpBuilder<(ins + "StringRef":$name, + "bool":$isMutable, + "Type":$type, + CArg<"ArrayRef", "{}">:$attrs + )>, ]; let verifier = [{ return 
verifyVariableOp(*this); }]; @@ -180,7 +134,9 @@ def HAL_VariableAddressOp : HAL_PureOp<"variable.address"> { HAL_VariablePtr:$result ); - let assemblyFormat = "$variable attr-dict `:` type($result)"; + let assemblyFormat = [{ + $variable attr-dict `:` type($result) + }]; } def HAL_VariableLoadOp : HAL_Op<"variable.load", [ @@ -199,7 +155,9 @@ def HAL_VariableLoadOp : HAL_Op<"variable.load", [ HAL_VariableType:$result ); - let assemblyFormat = "$variable attr-dict `:` type($result)"; + let assemblyFormat = [{ + $variable attr-dict `:` type($result) + }]; let verifier = [{ return verifyVariableLoadOp(*this); }]; } @@ -217,7 +175,9 @@ def FLOW_VariableLoadIndirectOp : HAL_Op<"variable.load.indirect"> { HAL_VariableType:$result ); - let assemblyFormat = "$variable attr-dict `:` type($variable) `->` type($result)"; + let assemblyFormat = [{ + $variable attr-dict `:` type($variable) `->` type($result) + }]; let verifier = [{ return verifyVariableLoadIndirectOp(*this); }]; @@ -235,7 +195,9 @@ def HAL_VariableStoreOp : HAL_Op<"variable.store"> { HAL_VariableRefAttr:$variable ); - let assemblyFormat = "$value `,` $variable attr-dict `:` type($value)"; + let assemblyFormat = [{ + $value `,` $variable attr-dict `:` type($value) + }]; let verifier = [{ return verifyVariableStoreOp(*this); }]; @@ -253,7 +215,9 @@ def HAL_VariableStoreIndirectOp : HAL_Op<"variable.store.indirect"> { HAL_VariablePtr:$variable ); - let assemblyFormat = "$value `,` $variable attr-dict `:` type($value) `->` type($variable)"; + let assemblyFormat = [{ + $value `,` $variable attr-dict `:` type($value) `->` type($variable) + }]; let verifier = [{ return verifyVariableStoreIndirectOp(*this); }]; @@ -299,7 +263,7 @@ def HAL_CheckSuccessOp : HAL_Op<"check_success"> { } //===----------------------------------------------------------------------===// -// iree::hal::Allocator +// !hal.allocator / iree_hal_allocator_t //===----------------------------------------------------------------------===// def 
HAL_AllocatorComputeSizeOp : HAL_PureOp<"allocator.compute_size", [ @@ -321,8 +285,11 @@ def HAL_AllocatorComputeSizeOp : HAL_PureOp<"allocator.compute_size", [ ); let assemblyFormat = [{ - $allocator `,` `shape` `=` `[` $shape `]` `,` `element_type` `=` - $element_type attr-dict + `<` $allocator `:` type($allocator) `>` + `shape` `(` `[` $shape `]` `)` + `type` `(` $element_type `)` + `:` type($result) + attr-dict-with-keyword }]; let builders = [ @@ -357,8 +324,12 @@ def HAL_AllocatorComputeOffsetOp : HAL_PureOp<"allocator.compute_offset", [ ); let assemblyFormat = [{ - $allocator `,` `shape` `=` `[` $shape `]` `,` `element_type` `=` - $element_type `,` `indices` `=` `[` $indices `]` attr-dict + `<` $allocator `:` type($allocator) `>` + `indices` `(` `[` $indices `]` `)` + `shape` `(` `[` $shape `]` `)` + `type` `(` $element_type `)` + `:` type($offset) + attr-dict-with-keyword }]; let builders = [ @@ -395,9 +366,13 @@ def HAL_AllocatorComputeRangeOp : HAL_PureOp<"allocator.compute_range", [ ); let assemblyFormat = [{ - $allocator `,` `shape` `=` `[` $shape `]` `,` `element_type` `=` - $element_type `,` `indices` `=` `[` $indices `]` `,` `lengths` `=` `[` - $lengths `]` attr-dict + `<` $allocator `:` type($allocator) `>` + `indices` `(` `[` $indices `]` `)` + `lengths` `(` `[` $lengths `]` `)` + `shape` `(` `[` $shape `]` `)` + `type` `(` $element_type `)` + `:` type($offset) `,` type($length) + attr-dict-with-keyword }]; let builders = [ @@ -412,6 +387,7 @@ def HAL_AllocatorComputeRangeOp : HAL_PureOp<"allocator.compute_range", [ def HAL_AllocatorAllocateOp : HAL_Op<"allocator.allocate", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, ]> { let summary = [{empty buffer allocation operation}]; let description = [{ @@ -424,26 +400,23 @@ def HAL_AllocatorAllocateOp : HAL_Op<"allocator.allocate", [ HAL_Allocator:$allocator, HAL_MemoryTypeBitfieldAttr:$memory_types, HAL_BufferUsageBitfieldAttr:$buffer_usage, - HAL_DeviceSize:$allocation_size + 
HAL_DeviceSize:$result_size ); let results = (outs HAL_Buffer:$result ); + // TODO(benvanik): change type/usage to ref params. let assemblyFormat = [{ - $allocator `,` $memory_types `,` $buffer_usage `,` $allocation_size - attr-dict-with-keyword `:` type($result) + `<` $allocator `:` type($allocator) `>` + `type` `(` $memory_types `)` + `usage` `(` $buffer_usage `)` + `:` custom(type($result), $result_size) + attr-dict-with-keyword }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$allocator, - "IREE::HAL::MemoryTypeBitfield":$memoryTypes, - "IREE::HAL::BufferUsageBitfield":$bufferUsage, "Value":$allocationSize)>, - ]; } -def HAL_AllocatorAllocateConstOp : HAL_Op<"allocator.allocate.const", [ +def HAL_AllocatorConstantOp : HAL_Op<"allocator.constant", [ DeclareOpInterfaceMethods, ]> { let summary = [{constant buffer allocation operation}]; @@ -461,26 +434,24 @@ def HAL_AllocatorAllocateConstOp : HAL_Op<"allocator.allocate.const", [ ElementsAttr:$value ); let results = (outs - HAL_Buffer:$result + AnyTypeOf<[HAL_Buffer, HAL_BufferView]>:$result ); + // TODO(benvanik): change type/usage to ref params. 
let assemblyFormat = [{ - $allocator `,` $memory_types `,` $buffer_usage attr-dict-with-keyword `:` - type($result) `=` $value + `<` $allocator `:` type($allocator) `>` + `type` `(` $memory_types `)` + `usage` `(` $buffer_usage `)` + `:` type($result) `=` $value + attr-dict-with-keyword }]; - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$allocator, - "IREE::HAL::MemoryTypeBitfield":$memoryTypes, - "IREE::HAL::BufferUsageBitfield":$bufferUsage, "ElementsAttr":$value)>, - ]; - let hasCanonicalizer = 1; } def HAL_AllocatorMapOp : HAL_Op<"allocator.map", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, ]> { let summary = [{allocator-supported host buffer wrapping operation}]; let description = [{ @@ -502,23 +473,19 @@ def HAL_AllocatorMapOp : HAL_Op<"allocator.map", [ HAL_Buffer:$result ); + // TODO(benvanik): change type/usage to ref params. let assemblyFormat = [{ - $allocator `,` $memory_types `,` $buffer_usage `,` - $source `[` $offset `,` $length `]` attr-dict-with-keyword - `:` type($source) `->` type($result) + `<` $allocator `:` type($allocator) `>` + `source` `(` $source `:` type($source) `)` `` `[` $offset `,` $length `]` + `type` `(` $memory_types `)` + `usage` `(` $buffer_usage `)` + `:` type($result) + attr-dict-with-keyword }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$allocator, - "IREE::HAL::MemoryTypeBitfield":$memoryTypes, - "IREE::HAL::BufferUsageBitfield":$bufferUsage, "Value":$source, - "Value":$offset, "Value":$length)>, - ]; } //===----------------------------------------------------------------------===// -// iree::hal::Buffer +// !hal.buffer / iree_hal_buffer_t //===----------------------------------------------------------------------===// def HAL_BufferAllocatorOp : HAL_PureOp<"buffer.allocator", [ @@ -530,108 +497,114 @@ def HAL_BufferAllocatorOp : HAL_PureOp<"buffer.allocator", [ }]; let arguments = (ins - HAL_Buffer:$buffer + HAL_BufferType:$buffer ); let results = 
(outs HAL_Allocator:$result ); - let assemblyFormat = "$buffer `:` type($result) attr-dict"; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$buffer)>, - ]; + let assemblyFormat = [{ + `<` $buffer `:` type($buffer) `>` + `:` type($result) + attr-dict-with-keyword + }]; let hasCanonicalizer = 1; } -// TODO(benvanik): clone buffer op. - def HAL_BufferSubspanOp : HAL_PureOp<"buffer.subspan", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + ]> { let summary = [{buffer subspan operation}]; let description = [{ Returns a reference to a subspan of the buffer. }]; let arguments = (ins - HAL_Buffer:$source_buffer, + HAL_BufferType:$source_buffer, HAL_DeviceSize:$source_offset, HAL_DeviceSize:$length ); let results = (outs - HAL_Buffer:$result + HAL_BufferType:$result ); - let assemblyFormat = "operands attr-dict `:` type($result)"; -} - -def HAL_BufferFillOp : HAL_Op<"buffer.fill"> { - let summary = [{buffer fill operation}]; - let description = [{ - Fills the target buffer with the given repeating value. + let assemblyFormat = [{ + `<` $source_buffer `:` type($source_buffer) `>` + `` `[` $source_offset `,` $length `]` + `:` type($result) + attr-dict-with-keyword }]; - - let arguments = (ins - HAL_Buffer:$target_buffer, - HAL_DeviceSize:$target_offset, - HAL_DeviceSize:$length, - I32:$pattern - ); - - let assemblyFormat = "operands attr-dict"; } -def HAL_BufferReadDataOp : HAL_Op<"buffer.read_data"> { - let summary = [{buffer-to-heap read operation}]; +def HAL_BufferLengthOp : HAL_PureOp<"buffer.length", [ + DeclareOpInterfaceMethods, + ]> { + let summary = [{buffer byte length accessor}]; let description = [{ - Reads a block of byte data from the resource at the given offset. + Returns the allocated size of a buffer in bytes. + May be less than the underlying buffer allocation if this is a subspan or + view into another buffer. 
}]; let arguments = (ins - HAL_Buffer:$source_buffer, - HAL_DeviceSize:$source_offset, - MutableByteBufferType:$target_buffer, - HAL_DeviceSize:$target_offset, - HAL_DeviceSize:$length + HAL_BufferType:$buffer + ); + let results = (outs + HAL_DeviceSize:$result ); - let assemblyFormat = "operands attr-dict `:` type($target_buffer)"; + let assemblyFormat = [{ + `<` $buffer `:` type($buffer) `>` + `:` type($result) + attr-dict-with-keyword + }]; } -def HAL_BufferWriteDataOp : HAL_Op<"buffer.write_data"> { - let summary = [{heap-to-buffer write operation}]; +def HAL_BufferFillOp : HAL_Op<"buffer.fill"> { + let summary = [{buffer fill operation}]; let description = [{ - Writes a block of byte data into the resource at the given offset. + Fills the target buffer with the given repeating value. }]; let arguments = (ins - HAL_HostBuffer:$source_buffer, - HAL_DeviceSize:$source_offset, - HAL_Buffer:$target_buffer, + HAL_BufferType:$target_buffer, HAL_DeviceSize:$target_offset, - HAL_DeviceSize:$length + HAL_DeviceSize:$length, + I32:$pattern ); - let assemblyFormat = "operands attr-dict `:` type($source_buffer)"; + let assemblyFormat = [{ + `<` $target_buffer `:` type($target_buffer) `>` + `` `[` $target_offset `,` $length `]` + `pattern` `(` $pattern `:` type($pattern) `)` + attr-dict-with-keyword + }]; } -def HAL_BufferCopyDataOp : HAL_Op<"buffer.copy_data"> { +def HAL_BufferCopyOp : HAL_Op<"buffer.copy"> { let summary = [{buffer-to-buffer copy operation}]; let description = [{ Copies data from the provided source_buffer into the buffer. 
}]; let arguments = (ins - HAL_Buffer:$source_buffer, + HAL_BufferType:$source_buffer, HAL_DeviceSize:$source_offset, - HAL_Buffer:$target_buffer, + HAL_BufferType:$target_buffer, HAL_DeviceSize:$target_offset, HAL_DeviceSize:$length ); - let assemblyFormat = "operands attr-dict"; + let assemblyFormat = [{ + `source` `(` $source_buffer `:` type($source_buffer) `)` + `` `[` $source_offset `]` + `target` `(` $target_buffer `:` type($target_buffer) `)` + `` `[` $target_offset `]` + `length` `(` $length `)` + attr-dict-with-keyword + }]; } def HAL_BufferLoadOp : HAL_PureOp<"buffer.load"> { @@ -641,7 +614,7 @@ def HAL_BufferLoadOp : HAL_PureOp<"buffer.load"> { }]; let arguments = (ins - HAL_Buffer:$source_buffer, + HAL_BufferType:$source_buffer, HAL_DeviceSize:$source_offset ); let results = (outs @@ -649,7 +622,10 @@ def HAL_BufferLoadOp : HAL_PureOp<"buffer.load"> { ); let assemblyFormat = [{ - $source_buffer `[` $source_offset `]` `:` type($result) attr-dict + `<` $source_buffer `:` type($source_buffer) `>` + `` `[` $source_offset `]` + `:` type($result) + attr-dict-with-keyword }]; } @@ -661,53 +637,22 @@ def HAL_BufferStoreOp : HAL_Op<"buffer.store"> { let arguments = (ins AnyTypeOf<[HAL_PrimitiveType, AnyVector]>:$value, - HAL_Buffer:$target_buffer, + HAL_BufferType:$target_buffer, HAL_DeviceSize:$target_offset ); let assemblyFormat = [{ - $value `,` $target_buffer `[` $target_offset `]` `:` type($value) attr-dict + `<` $target_buffer `:` type($target_buffer) `>` + `` `[` $target_offset `]` + `value` `(` $value `:` type($value) `)` + attr-dict-with-keyword }]; } //===----------------------------------------------------------------------===// -// iree::hal::BufferView +// !hal.buffer_view / iree_hal_buffer_view_t //===----------------------------------------------------------------------===// -def HAL_BufferViewConstOp : HAL_PureOp<"buffer_view.const", [ - DeclareOpInterfaceMethods, - ]> { - let summary = [{buffer view constant initializer}]; - let description = [{ 
- Pseudo-op for allocating a constant buffer view. Expands to a buffer - allocation and a buffer view wrapper. - }]; - - let arguments = (ins - HAL_Allocator:$allocator, - HAL_MemoryTypeBitfieldAttr:$memory_types, - HAL_BufferUsageBitfieldAttr:$buffer_usage, - ElementsAttr:$value - ); - let results = (outs - HAL_BufferView:$result - ); - - let assemblyFormat = [{ - $allocator `,` $memory_types `,` $buffer_usage `:` type($result) - attr-dict-with-keyword `=` $value - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$allocator, - "IREE::HAL::MemoryTypeBitfield":$memoryTypes, - "IREE::HAL::BufferUsageBitfield":$bufferUsage, "ElementsAttr":$value)>, - ]; - - let hasCanonicalizer = 1; -} - def HAL_BufferViewCreateOp : HAL_PureOp<"buffer_view.create", [ DeclareOpInterfaceMethods, ]> { @@ -720,7 +665,7 @@ def HAL_BufferViewCreateOp : HAL_PureOp<"buffer_view.create", [ }]; let arguments = (ins - HAL_Buffer:$buffer, + HAL_BufferType:$buffer, HAL_ElementType:$element_type, HAL_Shape:$shape ); @@ -729,16 +674,24 @@ def HAL_BufferViewCreateOp : HAL_PureOp<"buffer_view.create", [ ); let assemblyFormat = [{ - $buffer `,` `element_type` `=` $element_type `,` `shape` `=` `[` $shape `]` - `:` type($result) attr-dict + $buffer `,` + `element_type` `=` $element_type `,` + `shape` `=` `[` $shape `]` + attr-dict `:` type($buffer) `->` type($result) }]; let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "Value":$buffer, "int32_t":$elementType, - "ValueRange":$shape)>, - OpBuilder<(ins "Value":$buffer, "Value":$elementType, - "ValueRange":$shape)>, + OpBuilder<(ins + "Value":$buffer, + "int32_t":$elementType, + "ValueRange":$shape + )>, + OpBuilder<(ins + "Value":$buffer, + "Value":$elementType, + "ValueRange":$shape + )>, ]; } @@ -762,8 +715,10 @@ def HAL_BufferViewSubviewOp : HAL_PureOp<"buffer_view.subview", [ ); let assemblyFormat = [{ - $buffer_view `,` `indices` `=` `[` $indices `]` `,` `lengths` `=` `[` - $lengths `]` `:` type($result) 
attr-dict + $buffer_view `,` + `indices` `=` `[` $indices `]` `,` + `lengths` `=` `[` $lengths `]` + attr-dict `:` type($result) }]; let hasCanonicalizer = 1; @@ -781,15 +736,12 @@ def HAL_BufferViewBufferOp : HAL_PureOp<"buffer_view.buffer", [ HAL_BufferView:$buffer_view ); let results = (outs - HAL_Buffer:$result + HAL_BufferType:$result ); - let assemblyFormat = "$buffer_view `:` type($result) attr-dict"; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$bufferView)>, - ]; + let assemblyFormat = [{ + $buffer_view attr-dict `:` type($result) + }]; let hasCanonicalizer = 1; } @@ -809,7 +761,9 @@ def HAL_BufferViewByteLengthOp : HAL_PureOp<"buffer_view.byte_length", [ HAL_DeviceSize:$result ); - let assemblyFormat = "$buffer_view `:` type($result) attr-dict"; + let assemblyFormat = [{ + $buffer_view attr-dict `:` type($result) + }]; let skipDefaultBuilders = 1; let builders = [ @@ -835,7 +789,9 @@ def HAL_BufferViewComputeOffsetOp : HAL_PureOp<"buffer_view.compute_offset", [ ); let assemblyFormat = [{ - $buffer_view `,` `indices` `=` `[` $indices `]` attr-dict + $buffer_view `,` + `indices` `=` `[` $indices `]` + attr-dict `:` type($offset) }]; let skipDefaultBuilders = 1; @@ -867,14 +823,19 @@ def HAL_BufferViewComputeRangeOp : HAL_PureOp<"buffer_view.compute_range", [ ); let assemblyFormat = [{ - $buffer_view `,` `indices` `=` `[` $indices `]` `,` `lengths` `=` `[` - $lengths `]` attr-dict + $buffer_view `,` + `indices` `=` `[` $indices `]` `,` + `lengths` `=` `[` $lengths `]` + attr-dict `:` type($offset) `,` type($length) }]; let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "Value":$bufferView, "ValueRange":$indices, - "ValueRange":$lengths)>, + OpBuilder<(ins + "Value":$bufferView, + "ValueRange":$indices, + "ValueRange":$lengths + )>, ]; let hasCanonicalizer = 1; @@ -893,7 +854,9 @@ def HAL_BufferViewElementTypeOp : HAL_PureOp<"buffer_view.element_type"> { HAL_ElementType:$result ); - let assemblyFormat = 
[{$buffer_view attr-dict `:` type($result)}]; + let assemblyFormat = [{ + $buffer_view attr-dict `:` type($result) + }]; } def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> { @@ -906,10 +869,12 @@ def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> { HAL_BufferView:$buffer_view ); let results = (outs - Index:$result + HAL_Dim:$result ); - let assemblyFormat = [{$buffer_view attr-dict `:` type($result)}]; + let assemblyFormat = [{ + $buffer_view attr-dict `:` type($result) + }]; } def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim"> { @@ -920,13 +885,15 @@ def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim"> { let arguments = (ins HAL_BufferView:$buffer_view, - I32Attr:$index + IndexAttr:$index ); let results = (outs - Index:$result + HAL_Dim:$result ); - let assemblyFormat = [{$buffer_view `,` $index attr-dict `:` type($result)}]; + let assemblyFormat = [{ + $buffer_view `,` $index attr-dict `:` type($result) + }]; } def HAL_BufferViewDimsOp : HAL_PureOp<"buffer_view.dims"> { @@ -942,7 +909,9 @@ def HAL_BufferViewDimsOp : HAL_PureOp<"buffer_view.dims"> { Variadic:$result ); - let assemblyFormat = [{$buffer_view attr-dict `:` type($result)}]; + let assemblyFormat = [{ + $buffer_view attr-dict `:` type($result) + }]; let hasCanonicalizer = 1; } @@ -960,11 +929,13 @@ def HAL_BufferViewTraceOp : HAL_Op<"buffer_view.trace", []> { Variadic:$operands ); - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let assemblyFormat = [{ + attr-dict ($operands^ `:` type($operands))? 
+ }]; } //===----------------------------------------------------------------------===// -// iree::hal::CommandBuffer +// !hal.command_buffer / iree_hal_command_buffer_t //===----------------------------------------------------------------------===// def HAL_CommandBufferCreateOp : HAL_Op<"command_buffer.create", [ @@ -985,16 +956,12 @@ def HAL_CommandBufferCreateOp : HAL_Op<"command_buffer.create", [ ); let assemblyFormat = [{ - $device `,` $modes `,` $command_categories attr-dict-with-keyword `:` - type($result) + `device` `(` $device `:` type($device) `)` + `mode` `(` $modes `)` + `categories` `(` $command_categories `)` + `:` type($result) + attr-dict-with-keyword }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$device, - "IREE::HAL::CommandBufferModeBitfield":$modes, - "IREE::HAL::CommandCategoryBitfield":$commandCategories)>, - ]; } def HAL_CommandBufferBeginOp : HAL_Op<"command_buffer.begin"> { @@ -1008,7 +975,10 @@ def HAL_CommandBufferBeginOp : HAL_Op<"command_buffer.begin"> { HAL_CommandBuffer:$command_buffer ); - let assemblyFormat = "$command_buffer attr-dict"; + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + attr-dict-with-keyword + }]; } def HAL_CommandBufferEndOp : HAL_Op<"command_buffer.end"> { @@ -1021,7 +991,10 @@ def HAL_CommandBufferEndOp : HAL_Op<"command_buffer.end"> { HAL_CommandBuffer:$command_buffer ); - let assemblyFormat = "$command_buffer attr-dict"; + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + attr-dict-with-keyword + }]; } def HAL_CommandBufferDeviceOp : HAL_PureOp<"command_buffer.device"> { @@ -1037,7 +1010,11 @@ def HAL_CommandBufferDeviceOp : HAL_PureOp<"command_buffer.device"> { HAL_Device:$device ); - let assemblyFormat = "$command_buffer attr-dict `:` type($device)"; + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + `:` type($device) + attr-dict-with-keyword + }]; let hasCanonicalizer = 1; } @@ -1056,6 
+1033,14 @@ def HAL_CommandBufferExecutionBarrierOp : HAL_Op<"command_buffer.execution_barri HAL_ExecutionStageBitfieldAttr:$target_stage_mask, HAL_ExecutionBarrierFlagBitfieldAttr:$flags ); + + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + `source` `(` $source_stage_mask `)` + `target` `(` $target_stage_mask `)` + `flags` `(` $flags `)` + attr-dict-with-keyword + }]; } // TODO(benvanik): event ops. @@ -1068,13 +1053,19 @@ def HAL_CommandBufferFillBufferOp : HAL_Op<"command_buffer.fill_buffer"> { let arguments = (ins HAL_CommandBuffer:$command_buffer, - HAL_Buffer:$target_buffer, + HAL_BufferType:$target_buffer, HAL_DeviceSize:$target_offset, HAL_DeviceSize:$length, I32:$pattern ); - let assemblyFormat = "operands attr-dict"; + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + `target` `(` $target_buffer `:` type($target_buffer) `)` + `` `[` $target_offset `,` $length `]` + `pattern` `(` $pattern `:` type($pattern) `)` + attr-dict-with-keyword + }]; } // TODO(benvanik): update buffer op. 
@@ -1087,14 +1078,22 @@ def HAL_CommandBufferCopyBufferOp : HAL_Op<"command_buffer.copy_buffer"> { let arguments = (ins HAL_CommandBuffer:$command_buffer, - HAL_Buffer:$source_buffer, + HAL_BufferType:$source_buffer, HAL_DeviceSize:$source_offset, - HAL_Buffer:$target_buffer, + HAL_BufferType:$target_buffer, HAL_DeviceSize:$target_offset, HAL_DeviceSize:$length ); - let assemblyFormat = "operands attr-dict"; + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + `source` `(` $source_buffer `:` type($source_buffer) `)` + `` `[` $source_offset `]` + `target` `(` $target_buffer `:` type($target_buffer) `)` + `` `[` $target_offset `]` + `length` `(` $length `)` + attr-dict-with-keyword + }]; } def HAL_CommandBufferPushConstantsOp : @@ -1106,27 +1105,22 @@ def HAL_CommandBufferPushConstantsOp : Push constants are always 4-byte values and treated as opaque, meaning that they may be bit-casted floats, bit-packed booleans, etc. - - ```mlir - hal.command_buffer.push_constants %cmd, %exe_layout, - offset = 0, - values = [%value0, %value1] : i32 - hal.command_buffer.push_constants %cmd, %exe_layout, - offset = 2, - values = [%value2, %value3] : i32 - ``` }]; let arguments = (ins HAL_CommandBuffer:$command_buffer, HAL_ExecutableLayout:$executable_layout, - I32Attr:$offset, + IndexAttr:$offset, Variadic:$values ); let assemblyFormat = [{ - $command_buffer `,` $executable_layout `,` `offset` `=` $offset `,` - `values` `=` `[` $values `]` `:` `i32` attr-dict-with-keyword + `<` $command_buffer `:` type($command_buffer) `>` + `layout` `(` $executable_layout `:` type($executable_layout) `)` + `offset` `(` $offset `)` + `values` `(` `[` $values `]` `)` + `:` type($values) + attr-dict-with-keyword }]; } @@ -1137,14 +1131,6 @@ def HAL_CommandBufferPushDescriptorSetOp : let summary = [{command buffer descriptor set push binding operation}]; let description = [{ Pushes an inline-defined descriptor set to the command buffer. 
- - ```mlir - hal.command_buffer.push_descriptor_set %cmd, %executable_layout, set = %c0, bindings = [ - %c0 = (%buffer_0, %buffer_offset_0, %buffer_length_0), - %c1 = (%buffer_1, %buffer_offset_1, %buffer_length_1), - %c2 = (%buffer_2, %buffer_offset_2, %buffer_length_2) - ] - ``` }]; let arguments = (ins @@ -1152,11 +1138,25 @@ def HAL_CommandBufferPushDescriptorSetOp : HAL_ExecutableLayout:$executable_layout, Index:$set, Variadic:$binding_ordinals, - Variadic:$binding_buffers, + Variadic:$binding_buffers, Variadic:$binding_offsets, Variadic:$binding_lengths ); + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + `layout` `(` $executable_layout `:` type($executable_layout) `)` + `` `[` $set `]` + `bindings` `(` `[` + custom($binding_ordinals, + $binding_buffers, + type($binding_buffers), + $binding_offsets, + $binding_lengths) + `]` `)` + attr-dict-with-keyword + }]; + let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "Value":$commandBuffer, "Value":$executableLayout, @@ -1185,34 +1185,19 @@ def HAL_CommandBufferBindDescriptorSetOp : ); let assemblyFormat = [{ - $command_buffer `,` $executable_layout `,` `set` `=` $set `,` - $descriptor_set (`,` `offsets` `=` `[` $dynamic_offsets^ `]`)? + `<` $command_buffer `:` type($command_buffer) `>` + `layout` `(` $executable_layout `:` type($executable_layout) `)` + `` `[` $set `]` + `set` `(` $descriptor_set `:` type($descriptor_set) `)` + (`offsets` `(` `[` $dynamic_offsets^ `]` `)`)? 
attr-dict-with-keyword }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$commandBuffer, "Value":$executableLayout, - "int64_t":$set, "Value":$descriptorSet, - CArg<"ValueRange", "{}">:$dynamicOffsets)>, - OpBuilder<(ins "Value":$commandBuffer, "Value":$executableLayout, - "Value":$set, "Value":$descriptorSet, - CArg<"ValueRange", "{}">:$dynamicOffsets)>, - ]; } def HAL_CommandBufferDispatchSymbolOp : HAL_Op<"command_buffer.dispatch.symbol"> { let summary = [{command buffer dispatch recording operation, using symbolref}]; let description = [{ Dispatches an execution request, using a nested symbol reference to the entry point. - - ```mlir - %x = constant 128 : index - %y = constant 32 : index - %z = constant 1 : index - hal.command_buffer.dispatch.symbol %cmd, @executable::@target::@entry, - workgroup_xyz = [%x, %y, %z] - ``` }]; let arguments = (ins @@ -1224,32 +1209,21 @@ def HAL_CommandBufferDispatchSymbolOp : HAL_Op<"command_buffer.dispatch.symbol"> ); let assemblyFormat = [{ - $command_buffer `,` $entry_point `,` - `workgroup_xyz` `=` `[` $workgroup_x `,` $workgroup_y `,` $workgroup_z `]` - attr-dict + `<` $command_buffer `:` type($command_buffer) `>` + `target` `(` $entry_point `)` + `workgroups` `(` `[` + $workgroup_x `,` + $workgroup_y `,` + $workgroup_z + `]` `)` + attr-dict-with-keyword }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Value":$commandBuffer, - "IREE::HAL::ExecutableEntryPointOp":$entryPoint, "Value":$workgroupX, - "Value":$workgroupY, "Value":$workgroupZ)>, - ]; } def HAL_CommandBufferDispatchOp : HAL_Op<"command_buffer.dispatch"> { let summary = [{command buffer dispatch recording operation}]; let description = [{ Dispatches an execution request. 
- - ```mlir - %x = constant 128 : index - %y = constant 32 : index - %z = constant 1 : index - hal.command_buffer.dispatch %cmd, %executable, - entry_point = 0, - workgroup_xyz = [%x, %y, %z] - ``` }]; let arguments = (ins @@ -1262,9 +1236,15 @@ def HAL_CommandBufferDispatchOp : HAL_Op<"command_buffer.dispatch"> { ); let assemblyFormat = [{ - $command_buffer `,` $executable `,` `entry_point` `=` $entry_point `,` - `workgroup_xyz` `=` `[` $workgroup_x `,` $workgroup_y `,` $workgroup_z `]` - attr-dict + `<` $command_buffer `:` type($command_buffer) `>` + `target` `(` $executable `:` type($executable) `)` + `` `[` $entry_point `]` + `workgroups` `(` `[` + $workgroup_x `,` + $workgroup_y `,` + $workgroup_z + `]` `)` + attr-dict-with-keyword }]; } @@ -1283,13 +1263,16 @@ def HAL_CommandBufferDispatchIndirectSymbolOp : HAL_Op<"command_buffer.dispatch. let arguments = (ins HAL_CommandBuffer:$command_buffer, SymbolRefAttr:$entry_point, - HAL_Buffer:$workgroups_buffer, + HAL_BufferType:$workgroups_buffer, HAL_DeviceSize:$workgroups_offset ); let assemblyFormat = [{ - $command_buffer `,` $entry_point `,` - `workgroups` `=` $workgroups_buffer `[` $workgroups_offset `]` attr-dict + `<` $command_buffer `:` type($command_buffer) `>` + `target` `(` $entry_point `)` + `workgroups` `(` $workgroups_buffer `:` type($workgroups_buffer) `)` + `` `[` $workgroups_offset `]` + attr-dict-with-keyword }]; } @@ -1298,25 +1281,23 @@ def HAL_CommandBufferDispatchIndirectOp : HAL_Op<"command_buffer.dispatch.indire let description = [{ Dispatches an execution request with the dispatch parameters loaded from the given buffer. 
- - ```mlir - hal.command_buffer.dispatch.indirect %cmd, %executable, - entry_point = 0, - workgroups = %buffer[%offset] - ``` }]; let arguments = (ins HAL_CommandBuffer:$command_buffer, HAL_Executable:$executable, HAL_OrdinalAttr:$entry_point, - HAL_Buffer:$workgroups_buffer, + HAL_BufferType:$workgroups_buffer, HAL_DeviceSize:$workgroups_offset ); let assemblyFormat = [{ - $command_buffer `,` $executable `,` `entry_point` `=` $entry_point `,` - `workgroups` `=` $workgroups_buffer `[` $workgroups_offset `]` attr-dict + `<` $command_buffer `:` type($command_buffer) `>` + `target` `(` $executable `:` type($executable) `)` + `` `[` $entry_point `]` + `workgroups` `(` $workgroups_buffer `:` type($workgroups_buffer) `)` + `` `[` $workgroups_offset `]` + attr-dict-with-keyword }]; } @@ -1349,8 +1330,10 @@ def HAL_ConstantPoolOp : HAL_Op<"constant_pool", [ let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "StringRef":$name, - "BufferConstraintsAttr":$bufferConstraints)>, + OpBuilder<(ins + "StringRef":$name, + "BufferConstraintsAttr":$bufferConstraints + )>, ]; let extraClassDeclaration = [{ @@ -1382,7 +1365,7 @@ def HAL_ConstantPoolValueOp : HAL_Op<"constant_pool.value", [ ); let assemblyFormat = [{ - $sym_name attr-dict `=` $value + $sym_name `=` $value attr-dict-with-keyword }]; } @@ -1407,9 +1390,10 @@ def HAL_ConstantPoolSpanOp : HAL_Op<"constant_pool.span", [ ); let assemblyFormat = [{ - $sym_name `:` $tensor_type attr-dict + $sym_name `:` $tensor_type `=` $storage_buffer `[` $storage_range `]` (`->` $runtime_buffer^ `[` $runtime_range `]`)? + attr-dict-with-keyword }]; } @@ -1431,8 +1415,9 @@ def HAL_ConstantPoolSplatOp : HAL_Op<"constant_pool.splat", [ ); let assemblyFormat = [{ - $sym_name attr-dict `=` $value + $sym_name `=` $value (`->` $runtime_buffer^ `[` $runtime_range `]`)? 
+ attr-dict-with-keyword }]; } @@ -1453,7 +1438,9 @@ def HAL_ConstantPoolLoadOp : HAL_PureOp<"constant_pool.load", [ TypeAlias:$result ); - let assemblyFormat = "$constant attr-dict `:` type($result)"; + let assemblyFormat = [{ + $constant `:` type($result) attr-dict-with-keyword + }]; let skipDefaultBuilders = 1; let builders = [ @@ -1482,7 +1469,7 @@ def HAL_ConstantStorageOp : HAL_Op<"constant_storage", [ ); let assemblyFormat = [{ - $sym_name attr-dict-with-keyword `=` $value + $sym_name `=` $value attr-dict-with-keyword }]; } @@ -1503,7 +1490,7 @@ def HAL_ConstantStorageLookupOp : ); let assemblyFormat = [{ - $constant `:` type($result) attr-dict + $constant `:` type($result) attr-dict-with-keyword }]; } @@ -1525,23 +1512,13 @@ def HAL_ConstantSubspanOp : HAL_PureOp<"constant.subspan", [ ); let assemblyFormat = [{ - $runtime_buffer `[` $runtime_range `]` `:` type($result) attr-dict + $runtime_buffer `[` $runtime_range `]` `:` type($result) + attr-dict-with-keyword }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "Type":$resultType, "SymbolRefAttr":$runtimeBuffer, - "ByteRangeAttr":$runtimeRange), - [{ - $_state.addTypes({resultType}); - $_state.addAttribute("runtime_buffer", runtimeBuffer); - $_state.addAttribute("runtime_range", runtimeRange); - }]>, - ]; } //===----------------------------------------------------------------------===// -// iree::hal::DescriptorSet +// !hal.descriptor_set / iree_hal_descriptor_set_layout_t //===----------------------------------------------------------------------===// def HAL_DescriptorSetCreateOp : HAL_PureOp<"descriptor_set.create", [ @@ -1557,7 +1534,7 @@ def HAL_DescriptorSetCreateOp : HAL_PureOp<"descriptor_set.create", [ HAL_Device:$device, HAL_DescriptorSetLayout:$set_layout, Variadic:$binding_ordinals, - Variadic:$binding_buffers, + Variadic:$binding_buffers, Variadic:$binding_offsets, Variadic:$binding_lengths ); @@ -1565,15 +1542,30 @@ def HAL_DescriptorSetCreateOp : 
HAL_PureOp<"descriptor_set.create", [ HAL_DescriptorSet:$result ); - let skipDefaultBuilders = 1; + let assemblyFormat = [{ + `device` `(` $device `:` type($device) `)` + `layout` `(` $set_layout `:` type($set_layout) `)` + `bindings` `(` `[` + custom($binding_ordinals, + $binding_buffers, + type($binding_buffers), + $binding_offsets, + $binding_lengths) + `]` `)` + attr-dict-with-keyword + }]; + let builders = [ - OpBuilder<(ins "Value":$device, "Value":$setLayout, - "ArrayRef":$bindings)>, + OpBuilder<(ins + "Value":$device, + "Value":$setLayout, + "ArrayRef":$bindings + )>, ]; } //===----------------------------------------------------------------------===// -// iree::hal::DescriptorSetLayout +// !hal.descriptor_set_layout / iree_hal_descriptor_set_layout_t //===----------------------------------------------------------------------===// def HAL_DescriptorSetLayoutCreateOp : @@ -1586,13 +1578,6 @@ def HAL_DescriptorSetLayoutCreateOp : The same descriptor set layout may be shared with many different executable layouts and by doing so some runtime binding overhead when switching between executables that use the same set layouts can be reduced. 
- - ```mlir - %layout = hal.descriptor_set_layout.create %device, "PushOnly", bindings = [ - #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, - #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write"> - ] : !hal.descriptor_set_layout - ``` }]; let arguments = (ins @@ -1605,7 +1590,11 @@ def HAL_DescriptorSetLayoutCreateOp : ); let assemblyFormat = [{ - $device `,` $usage_type `,` `bindings` `=` $bindings attr-dict `:` type($result) + `device` `(` $device `:` type($device) `)` + `usage` `(` $usage_type `)` + `bindings` `(` $bindings `)` + `:` type($result) + attr-dict-with-keyword }]; } @@ -1616,13 +1605,6 @@ def HAL_DescriptorSetLayoutLookupOp : HAL_PureOp<"descriptor_set_layout.lookup", let description = [{ Used during conversion to provide a placeholder for a globally cached and possibly lazy-initialized descriptor set layout. - - ```mlir - %layout = hal.descriptor_set_layout.lookup %device, "PushOnly", bindings = [ - #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, - #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write"> - ] : !hal.descriptor_set_layout - ``` }]; let arguments = (ins @@ -1635,12 +1617,16 @@ def HAL_DescriptorSetLayoutLookupOp : HAL_PureOp<"descriptor_set_layout.lookup", ); let assemblyFormat = [{ - $device `,` $usage_type `,` `bindings` `=` $bindings attr-dict `:` type($result) + `device` `(` $device `:` type($device) `)` + `usage` `(` $usage_type `)` + `bindings` `(` $bindings `)` + `:` type($result) + attr-dict-with-keyword }]; } //===----------------------------------------------------------------------===// -// iree::hal::Device +// !hal.device / iree_hal_device_t //===----------------------------------------------------------------------===// def HAL_DeviceAllocatorOp : HAL_PureOp<"device.allocator", [ @@ -1659,7 +1645,9 @@ def HAL_DeviceAllocatorOp : HAL_PureOp<"device.allocator", [ HAL_Allocator:$result ); - let assemblyFormat = "$device attr-dict `:` type($result)"; + let assemblyFormat = [{ 
+ `<` $device `:` type($device) `>` `:` type($result) attr-dict-with-keyword + }]; let skipDefaultBuilders = 1; let builders = [ @@ -1711,7 +1699,7 @@ def HAL_DeviceSwitchOp : HAL_Op<"device.switch", [IsolatedFromAbove]> { %c1 = constant 1 : i32 %c2 = constant 2 : i32 %device = ... : !hal.device - %0 = hal.device.switch(%device : !hal.device) -> i32 + %0 = hal.device.switch<%device : !hal.device> -> i32 #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) { hal.return %c1a : i32 }, @@ -1735,19 +1723,23 @@ def HAL_DeviceSwitchOp : HAL_Op<"device.switch", [IsolatedFromAbove]> { let regions = (region VariadicRegion:$condition_regions); - let extraClassDeclaration = [{ - /// Returns the index of the args() operand in the Operation operands list. - unsigned mapArgOperandToOpOperand(unsigned i) { return i + 1; } - }]; - let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "TypeRange":$resultTypes, "Value":$device, + OpBuilder<(ins + "TypeRange":$resultTypes, + "Value":$device, "ArrayRef":$conditions, "ArrayRef>":$conditionArgs, - CArg<"ArrayRef", "{}">:$attributes)>, + CArg<"ArrayRef", "{}">:$attributes + )>, ]; + let extraClassDeclaration = [{ + /// Returns the index of the args() operand in the Operation operands list. + unsigned mapArgOperandToOpOperand(unsigned i) { return i + 1; } + }]; + + let verifier = [{ return verifyDeviceSwitchOp(*this); }]; } @@ -1761,7 +1753,9 @@ def HAL_ReturnOp : HAL_Op<"return", [Terminator]> { Variadic:$operands ); - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let assemblyFormat = [{ + ($operands^ `:` type($operands))? attr-dict + }]; let builders = [ OpBuilder<(ins), @@ -1782,7 +1776,8 @@ def HAL_DeviceMatchIDOp : HAL_PureOp<"device.match.id"> { device is not known at compile-time. 
```mlir - %is_match = hal.device.match.id %device, pattern = ["vulkan-*"] : (!hal.device) -> i1 + %is_match = hal.device.match.id<%device : !hal.device> + pattern("vulkan-*") : i1 ``` }]; @@ -1795,8 +1790,10 @@ def HAL_DeviceMatchIDOp : HAL_PureOp<"device.match.id"> { ); let assemblyFormat = [{ - $device `,` `pattern` `=` `[` $pattern `]` attr-dict - `:` `(` type($device) `)` `->` type($result) + `<` $device `:` type($device) `>` + `pattern` `(` $pattern `)` + `:` type($result) + attr-dict-with-keyword }]; } @@ -1808,7 +1805,8 @@ def HAL_DeviceMatchMemoryModelOp : HAL_PureOp<"device.match.memory_model"> { device is not known at compile-time. ```mlir - %is_match = hal.device.match.memory_model %device, memory_model = "Unified" : (!hal.device) -> i1 + %is_match = hal.device.match.memory_model<%device : !hal.device> + value("Unified") : i1 ``` }]; @@ -1821,13 +1819,15 @@ def HAL_DeviceMatchMemoryModelOp : HAL_PureOp<"device.match.memory_model"> { ); let assemblyFormat = [{ - $device `,` `model` `=` `[` $model `]` attr-dict - `:` `(` type($device) `)` `->` type($result) + `<` $device `:` type($device) `>` + `value` `(` $model `)` + `:` type($result) + attr-dict-with-keyword }]; } //===----------------------------------------------------------------------===// -// iree::hal::Executable +// !hal.executable / iree_hal_executable_t //===----------------------------------------------------------------------===// def HAL_ExecutableOp : HAL_Op<"executable", [ @@ -1902,18 +1902,26 @@ def HAL_ExecutableEntryPointOp : HAL_Op<"executable.entry_point", [ let regions = (region VariadicRegion>:$workgroup_count_region); let builders = [ - OpBuilder< - (ins "::llvm::StringRef":$sym_name, "::llvm::APInt":$ordinal, - "::llvm::StringRef":$interface,"::mlir::Type":$signature, - "::mlir::ArrayAttr":$workgroup_size), - [{build($_builder, $_state, sym_name, ordinal, interface, signature, - workgroup_size, 0);}]>, - OpBuilder< - (ins "::mlir::StringAttr":$sym_name, 
"::mlir::IntegerAttr":$ordinal, - "::mlir::FlatSymbolRefAttr":$interface, "::mlir::TypeAttr":$signature, - "::mlir::ArrayAttr":$workgroup_size), - [{build($_builder, $_state, sym_name, ordinal, interface, signature, - workgroup_size, 0);}]> + OpBuilder<(ins + "::llvm::StringRef":$sym_name, + "::llvm::APInt":$ordinal, + "::llvm::StringRef":$interface, + "::mlir::Type":$signature, + "::mlir::ArrayAttr":$workgroup_size + ), [{ + build($_builder, $_state, sym_name, ordinal, interface, signature, + workgroup_size, 0); + }]>, + OpBuilder<(ins + "::mlir::StringAttr":$sym_name, + "::mlir::IntegerAttr":$ordinal, + "::mlir::FlatSymbolRefAttr":$interface, + "::mlir::TypeAttr":$signature, + "::mlir::ArrayAttr":$workgroup_size + ), [{ + build($_builder, $_state, sym_name, ordinal, interface, signature, + workgroup_size, 0); + }]> ]; let verifier = [{ return verifyExecutableEntryPointOp(*this); }]; @@ -2004,8 +2012,16 @@ def HAL_ExecutableBinaryOp : HAL_Op<"executable.binary", [ let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "StringRef":$name, "uint32_t":$format, "std::vector":$data)>, - OpBuilder<(ins "StringRef":$name, "uint32_t":$format, "DenseIntElementsAttr":$data)>, + OpBuilder<(ins + "StringRef":$name, + "uint32_t":$format, + "std::vector":$data + )>, + OpBuilder<(ins + "StringRef":$name, + "uint32_t":$format, + "DenseIntElementsAttr":$data + )>, ]; let extraClassDeclaration = [{ @@ -2042,11 +2058,6 @@ def HAL_ExecutableCreateOp : HAL_PureOp<"executable.create", [ (such as when JITing/etc). As the cache is internally synchronized callers can issue preparation requests from multiple threads - even for the same executables - and calls will block until preparation completes. 
- - ```mlir - %exe = hal.executable.create %device, @executable::@target, - layouts = [%exe_layout_0] : !hal.executable - ``` }]; let arguments = (ins @@ -2059,9 +2070,11 @@ def HAL_ExecutableCreateOp : HAL_PureOp<"executable.create", [ ); let assemblyFormat = [{ - $device `,` $executable_target `,` - `layouts` `=` `[` $layouts `]` - attr-dict-with-keyword `:` type($result) + `device` `(` $device `:` type($device) `)` + `target` `(` $executable_target `)` + `layouts` `(` `[` $layouts `]` `)` + `:` type($result) + attr-dict-with-keyword }]; } @@ -2082,7 +2095,12 @@ def HAL_ExecutableLookupOp : HAL_PureOp<"executable.lookup", [ HAL_Executable:$result ); - let assemblyFormat = "$device `,` $executable attr-dict `:` type($result)"; + let assemblyFormat = [{ + `device` `(` $device `:` type($device) `)` + `executable` `(` $executable `)` + `:` type($result) + attr-dict-with-keyword + }]; let skipDefaultBuilders = 1; let builders = [ @@ -2096,7 +2114,7 @@ def HAL_ExecutableLookupOp : HAL_PureOp<"executable.lookup", [ } //===----------------------------------------------------------------------===// -// iree::hal::Executable Interfaces +// hal.interface //===----------------------------------------------------------------------===// def HAL_InterfaceOp : HAL_Op<"interface", [ @@ -2130,15 +2148,17 @@ def HAL_InterfaceOp : HAL_Op<"interface", [ let arguments = (ins StrAttr:$sym_name, - OptionalAttr:$push_constants + OptionalAttr:$push_constants ); let regions = (region SizedRegion<1>:$body); let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "StringRef":$name, - CArg<"IntegerAttr", "{}">:$pushConstants)>, + OpBuilder<(ins + "StringRef":$name, + CArg<"IntegerAttr", "{}">:$pushConstants + )>, ]; let extraClassDeclaration = [{ @@ -2185,8 +2205,8 @@ def HAL_InterfaceBindingOp : HAL_Op<"interface.binding", [ let arguments = (ins StrAttr:$sym_name, - I32Attr:$set, - I32Attr:$binding, + IndexAttr:$set, + IndexAttr:$binding, HAL_DescriptorTypeAttr:$type, 
HAL_MemoryAccessBitfieldAttr:$access ); @@ -2213,7 +2233,9 @@ def HAL_InterfaceWorkgroupIDOp : HAL_PureOp<"interface.workgroup.id", [ let arguments = (ins IndexAttr:$dimension); let results = (outs HAL_Dim:$result); - let assemblyFormat = "`[` $dimension `]` attr-dict `:` type($result)"; + let assemblyFormat = [{ + `[` $dimension `]` attr-dict `:` type($result) + }]; } def HAL_InterfaceWorkgroupCountOp : HAL_PureOp<"interface.workgroup.count", [ @@ -2238,7 +2260,9 @@ def HAL_InterfaceWorkgroupCountOp : HAL_PureOp<"interface.workgroup.count", [ let arguments = (ins IndexAttr:$dimension); let results = (outs HAL_Dim:$result); - let assemblyFormat = "`[` $dimension `]` attr-dict `:` type($result)"; + let assemblyFormat = [{ + `[` $dimension `]` attr-dict `:` type($result) + }]; } def HAL_InterfaceWorkgroupSizeOp : HAL_PureOp<"interface.workgroup.size", [ @@ -2263,7 +2287,9 @@ def HAL_InterfaceWorkgroupSizeOp : HAL_PureOp<"interface.workgroup.size", [ let arguments = (ins IndexAttr:$dimension); let results = (outs HAL_Dim:$result); - let assemblyFormat = "`[` $dimension `]` attr-dict `:` type($result)"; + let assemblyFormat = [{ + `[` $dimension `]` attr-dict `:` type($result) + }]; } def HAL_InterfaceLoadConstantOp : HAL_PureOp<"interface.load.constant"> { @@ -2378,7 +2404,8 @@ def HAL_InterfaceStoreTensorOp : HAL_Op<"interface.store.tensor"> { ); let assemblyFormat = [{ - $operand `,` $binding `,` `offset` `=` $offset attr-dict `:` type($operand) + $operand `,` $binding `,` `offset` `=` $offset + attr-dict `:` type($operand) }]; let extraClassDeclaration = [{ @@ -2388,9 +2415,10 @@ def HAL_InterfaceStoreTensorOp : HAL_Op<"interface.store.tensor"> { }]; } -def HAL_InterfaceLoadTensorTileOp : HAL_PureOp< - "interface.load.tensor.tile", - [AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> { +def HAL_InterfaceLoadTensorTileOp : HAL_PureOp<"interface.load.tensor.tile", [ + AttrSizedOperandSegments, + OffsetSizeAndStrideOpInterface, + ]> { let summary = [{loads a 
tensor tile from an executable IO binding}]; let description = [{ Loads a tensor tile value from an executable IO binding at the given @@ -2437,10 +2465,13 @@ def HAL_InterfaceLoadTensorTileOp : HAL_PureOp< ); let assemblyFormat = [{ - $binding `,` `base_offset` `=` $base_offset `,` `offsets` `=` - custom($offsets, $static_offsets) - `,` `sizes` `=` custom($sizes, $static_sizes) - `,` `strides` `=` + $binding `,` + `base_offset` `=` $base_offset `,` + `offsets` `=` + custom($offsets, $static_offsets) `,` + `sizes` `=` + custom($sizes, $static_sizes) `,` + `strides` `=` custom($strides, $static_strides) attr-dict `:` type($result) }]; @@ -2477,9 +2508,10 @@ def HAL_InterfaceLoadTensorTileOp : HAL_PureOp< }]; } -def HAL_InterfaceStoreTensorTileOp : HAL_Op< - "interface.store.tensor.tile", - [AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> { +def HAL_InterfaceStoreTensorTileOp : HAL_Op<"interface.store.tensor.tile", [ + AttrSizedOperandSegments, + OffsetSizeAndStrideOpInterface, + ]> { let summary = [{stores a tensor tile in an executable IO binding}]; let description = [{ Stores a tensor value into an executable IO binding. 
This is a pseudo op @@ -2524,10 +2556,13 @@ def HAL_InterfaceStoreTensorTileOp : HAL_Op< ); let assemblyFormat = [{ - $operand `,` $binding `,` `base_offset` `=` $base_offset `,` `offsets` `=` - custom($offsets, $static_offsets) - `,` `sizes` `=` custom($sizes, $static_sizes) - `,` `strides` `=` + $operand `,` $binding `,` + `base_offset` `=` $base_offset `,` + `offsets` `=` + custom($offsets, $static_offsets) `,` + `sizes` `=` + custom($sizes, $static_sizes) `,` + `strides` `=` custom($strides, $static_strides) attr-dict `:` type($operand) }]; @@ -2565,7 +2600,7 @@ def HAL_InterfaceStoreTensorTileOp : HAL_Op< } //===----------------------------------------------------------------------===// -// iree::hal::ExecutableLayout +// !hal.executable_layout / iree_hal_executable_layout_t //===----------------------------------------------------------------------===// def HAL_ExecutableLayoutCreateOp : HAL_PureOp<"executable_layout.create", [ @@ -2580,30 +2615,24 @@ def HAL_ExecutableLayoutCreateOp : HAL_PureOp<"executable_layout.create", [ is often worth the cost to allow a small number of unused bindings in one executable such that it can share layouts with others that will be scheduled adjacent to it. - - ```mlir - %set0 = hal.descriptor_set_layout.create ... - %set1 = hal.descriptor_set_layout.create ... - %layout = hal.executable_layout.create %device, - push_constants = 3, - set_layouts = [%set0, %set1] : !hal.executable_layout - ``` }]; let arguments = (ins HAL_Device:$device, - I32Attr:$push_constants, + IndexAttr:$push_constants, Variadic:$set_layouts ); let results = (outs HAL_ExecutableLayout:$result ); + // TODO(benvanik): include descriptor set layout types. 
let assemblyFormat = [{ - $device `,` - `push_constants` `=` $push_constants `,` - `set_layouts` `=` `[` $set_layouts `]` - attr-dict-with-keyword `:` type($result) + `device` `(` $device `:` type($device) `)` + `push_constants` `(` $push_constants `)` + `layouts` `(` `[` $set_layouts `]` `)` + `:` type($result) + attr-dict-with-keyword }]; } @@ -2614,43 +2643,30 @@ def HAL_ExecutableLayoutLookupOp : HAL_PureOp<"executable_layout.lookup", [ let description = [{ Used during conversion to provide a placeholder for a globally cached and possibly lazy-initialized executable layout. - - ```mlir - %layout = hal.executable_layout.lookup %device, set_layouts = [ - [ - #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, - #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write"> - ] - ] : !hal.executable_layout - ``` }]; let arguments = (ins HAL_Device:$device, + OptionalAttr:$push_constants, // TODO(benvanik): replace with a nested typed attr that works. // Array of HAL_DescriptorSetLayoutBindingArrayAttr. - ArrayAttr:$set_layouts, - OptionalAttr:$push_constants + ArrayAttr:$set_layouts ); let results = (outs HAL_ExecutableLayout:$result ); let assemblyFormat = [{ - $device `,` `set_layouts` `=` $set_layouts - (`,` `push_constants` `=` $push_constants^)? - attr-dict-with-keyword `:` type($result) + `device` `(` $device `:` type($device) `)` + (`push_constants` `(` $push_constants^ `)`)? + `layouts` `(` $set_layouts `)` + `:` type($result) + attr-dict-with-keyword }]; } //===----------------------------------------------------------------------===// -// iree::hal::RingBuffer -//===----------------------------------------------------------------------===// - -// TODO(benvanik): ring buffer. 
- -//===----------------------------------------------------------------------===// -// iree::hal::Semaphore +// !hal.semaphore / iree_hal_semaphore_t //===----------------------------------------------------------------------===// def HAL_SemaphoreCreateOp : HAL_Op<"semaphore.create", [ @@ -2670,8 +2686,10 @@ def HAL_SemaphoreCreateOp : HAL_Op<"semaphore.create", [ ); let assemblyFormat = [{ - $device `,` `initial_value` `=` $initial_value - attr-dict-with-keyword `:` type($result) + `device` `(` $device `:` type($device) `)` + `initial` `(` $initial_value `)` + `:` type($result) + attr-dict-with-keyword }]; } @@ -2694,7 +2712,9 @@ def HAL_SemaphoreQueryOp : HAL_Op<"semaphore.query"> { ); let assemblyFormat = [{ - $semaphore attr-dict-with-keyword `:` type($status) `,` type($value) + `<` $semaphore `:` type($semaphore) `>` + `:` type($status) `,` type($value) + attr-dict-with-keyword }]; } @@ -2711,7 +2731,9 @@ def HAL_SemaphoreSignalOp : HAL_Op<"semaphore.signal"> { ); let assemblyFormat = [{ - $semaphore `,` `value` `=` $new_value attr-dict-with-keyword + `<` $semaphore `:` type($semaphore) `>` + `value` `(` $new_value `)` + attr-dict-with-keyword }]; } @@ -2729,7 +2751,9 @@ def HAL_SemaphoreFailOp : HAL_Op<"semaphore.fail"> { ); let assemblyFormat = [{ - $semaphore `,` `status` `=` $status attr-dict-with-keyword + `<` $semaphore `:` type($semaphore) `>` + `status` `(` $status `)` + attr-dict-with-keyword }]; } @@ -2752,7 +2776,10 @@ def HAL_SemaphoreAwaitOp : HAL_Op<"semaphore.await", [YieldPoint]> { ); let assemblyFormat = [{ - $semaphore `,` `min_value` `=` $min_value attr-dict-with-keyword `:` type($status) + `<` $semaphore `:` type($semaphore) `>` + `until` `(` $min_value `)` + `:` type($status) + attr-dict-with-keyword }]; } diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index 5451d23ae577..24d1967f03f0 100644 --- a/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALTypes.cpp 
@@ -15,6 +15,8 @@ #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/IREE/IR/IREEOps.h" #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Builders.h" @@ -58,6 +60,243 @@ static LogicalResult parseEnumAttr(DialectAsmParser &parser, StringRef attrName, return success(); } +template +static LogicalResult parseOptionalEnumAttr(DialectAsmParser &parser, + StringRef attrName, AttrType &attr) { + if (succeeded(parser.parseOptionalQuestion())) { + // Special case `?` to indicate any/none/undefined/etc. + attr = AttrType::get(parser.getBuilder().getContext(), 0); + return success(); + } + return parseEnumAttr(parser, attrName, attr); +} + +static LogicalResult parseMemoryType(DialectAsmParser &parser, + Attribute &attr) { + if (succeeded(parser.parseOptionalQuestion())) { + attr = parser.getBuilder().getI32IntegerAttr(0); + return success(); + } + + StringRef fullString; + if (succeeded(parser.parseOptionalString(&fullString))) { + auto symbolized = symbolizeEnum(fullString); + if (!symbolized.hasValue()) { + return parser.emitError(parser.getCurrentLocation()) + << "failed to parse memory type enum value"; + } + attr = parser.getBuilder().getI32IntegerAttr( + static_cast(symbolized.getValue())); + return success(); + } + + StringRef shortString; + if (failed(parser.parseKeyword(&shortString))) { + return parser.emitError(parser.getCurrentLocation()) + << "failed to find memory type short string"; + } + MemoryTypeBitfield memoryType = MemoryTypeBitfield::None; + for (char c : shortString) { + switch (c) { + case 'T': + memoryType = memoryType | MemoryTypeBitfield::Transient; + break; + case 'h': + memoryType = memoryType | MemoryTypeBitfield::HostVisible; + break; + case 'H': + memoryType = memoryType | MemoryTypeBitfield::HostLocal; + break; + case 'c': + memoryType = memoryType | 
MemoryTypeBitfield::HostCoherent; + break; + case 'C': + memoryType = memoryType | MemoryTypeBitfield::HostCached; + break; + case 'd': + memoryType = memoryType | MemoryTypeBitfield::DeviceVisible; + break; + case 'D': + memoryType = memoryType | MemoryTypeBitfield::DeviceLocal; + break; + default: + return parser.emitError(parser.getCurrentLocation()) + << "unknown memory type short-form char: " << c; + } + } + attr = + parser.getBuilder().getI32IntegerAttr(static_cast(memoryType)); + return success(); +} + +static void printMemoryType(DialectAsmPrinter &printer, + MemoryTypeBitfield memoryType) { + if (memoryType == MemoryTypeBitfield::None) { + printer << '?'; + return; + } + if (allEnumBitsSet(memoryType, MemoryTypeBitfield::Transient)) { + printer << 't'; + } + if (allEnumBitsSet(memoryType, MemoryTypeBitfield::HostLocal)) { + printer << 'H'; + } else if (allEnumBitsSet(memoryType, MemoryTypeBitfield::HostVisible)) { + printer << 'h'; + } + if (allEnumBitsSet(memoryType, MemoryTypeBitfield::HostCoherent)) { + printer << 'c'; + } + if (allEnumBitsSet(memoryType, MemoryTypeBitfield::HostCached)) { + printer << 'C'; + } + if (allEnumBitsSet(memoryType, MemoryTypeBitfield::DeviceLocal)) { + printer << 'D'; + } else if (allEnumBitsSet(memoryType, MemoryTypeBitfield::DeviceVisible)) { + printer << 'd'; + } +} + +static LogicalResult parseMemoryAccess(DialectAsmParser &parser, + Attribute &attr) { + if (succeeded(parser.parseOptionalQuestion())) { + attr = parser.getBuilder().getI32IntegerAttr(0); + return success(); + } + + StringRef fullString; + if (succeeded(parser.parseOptionalString(&fullString))) { + auto symbolized = symbolizeEnum(fullString); + if (!symbolized.hasValue()) { + return parser.emitError(parser.getCurrentLocation()) + << "failed to parse memory access enum value"; + } + attr = parser.getBuilder().getI32IntegerAttr( + static_cast(symbolized.getValue())); + return success(); + } + + StringRef shortString; + if 
(failed(parser.parseKeyword(&shortString))) { + return parser.emitError(parser.getCurrentLocation()) + << "failed to find memory access short string"; + } + MemoryAccessBitfield memoryAccess = MemoryAccessBitfield::None; + for (char c : shortString) { + switch (c) { + case 'R': + memoryAccess = memoryAccess | MemoryAccessBitfield::Read; + break; + case 'W': + memoryAccess = memoryAccess | MemoryAccessBitfield::Write; + break; + case 'D': + memoryAccess = memoryAccess | MemoryAccessBitfield::Discard; + break; + case 'A': + memoryAccess = memoryAccess | MemoryAccessBitfield::MayAlias; + break; + default: + return parser.emitError(parser.getCurrentLocation()) + << "unknown memory access short-form char: " << c; + } + } + attr = + parser.getBuilder().getI32IntegerAttr(static_cast(memoryAccess)); + return success(); +} + +static void printMemoryAccess(DialectAsmPrinter &printer, + MemoryAccessBitfield memoryAccess) { + if (memoryAccess == MemoryAccessBitfield::None) { + printer << '?'; + return; + } + if (allEnumBitsSet(memoryAccess, MemoryAccessBitfield::Read)) { + printer << 'R'; + } + if (allEnumBitsSet(memoryAccess, MemoryAccessBitfield::Discard)) { + printer << 'D'; + } + if (allEnumBitsSet(memoryAccess, MemoryAccessBitfield::Write)) { + printer << 'W'; + } + if (allEnumBitsSet(memoryAccess, MemoryAccessBitfield::MayAlias)) { + printer << 'A'; + } +} + +static LogicalResult parseBufferUsage(DialectAsmParser &parser, + Attribute &attr) { + if (succeeded(parser.parseOptionalQuestion())) { + attr = parser.getBuilder().getI32IntegerAttr(0); + return success(); + } + + StringRef fullString; + if (succeeded(parser.parseOptionalString(&fullString))) { + auto symbolized = symbolizeEnum(fullString); + if (!symbolized.hasValue()) { + return parser.emitError(parser.getCurrentLocation()) + << "failed to parse buffer usage enum value"; + } + attr = parser.getBuilder().getI32IntegerAttr( + static_cast(symbolized.getValue())); + return success(); + } + + StringRef shortString; + 
if (failed(parser.parseKeyword(&shortString))) { + return parser.emitError(parser.getCurrentLocation()) + << "failed to find buffer usage short string"; + } + BufferUsageBitfield usage = BufferUsageBitfield::None; + for (char c : shortString) { + switch (c) { + case 'C': + usage = usage | BufferUsageBitfield::Constant; + break; + case 'T': + usage = usage | BufferUsageBitfield::Transfer; + break; + case 'M': + usage = usage | BufferUsageBitfield::Mapping; + break; + case 'D': + usage = usage | BufferUsageBitfield::Dispatch; + break; + default: + return parser.emitError(parser.getCurrentLocation()) + << "unknown buffer usage short-form char: " << c; + } + } + attr = parser.getBuilder().getI32IntegerAttr(static_cast(usage)); + return success(); +} + +static void printBufferUsage(DialectAsmPrinter &printer, + BufferUsageBitfield usage) { + if (usage == BufferUsageBitfield::None) { + printer << '?'; + return; + } + if (allEnumBitsSet(usage, BufferUsageBitfield::Constant)) { + printer << 'C'; + } + if (allEnumBitsSet(usage, BufferUsageBitfield::Transfer)) { + printer << 'T'; + } + if (allEnumBitsSet(usage, BufferUsageBitfield::Mapping)) { + printer << 'M'; + } + if (allEnumBitsSet(usage, BufferUsageBitfield::Dispatch)) { + printer << 'D'; + } +} + +//===----------------------------------------------------------------------===// +// Element types +//===----------------------------------------------------------------------===// + // Keep these in sync with iree/hal/api.h namespace { enum class NumericalType : uint32_t { @@ -126,6 +365,53 @@ Value getElementByteCount(Location loc, Value elementType, OpBuilder &builder) { c8); } +//===----------------------------------------------------------------------===// +// Size-aware type utils +//===----------------------------------------------------------------------===// + +// Returns the SSA value containing the size of the given |value|. 
+static Value lookupValueSize(Value value) { + assert(value.getType().isa()); + + auto definingOp = value.getDefiningOp(); + if (!definingOp) { + return {}; // Not yet implemented. + } + + // Skip do-not-optimize ops. + if (auto dnoOp = dyn_cast(definingOp)) { + return lookupValueSize(dnoOp.getOperand(0)); + } + + // Query size from the size-aware op that defined the value, as it knows how + // to get/build the right value. + unsigned resultIndex = -1; + for (unsigned i = 0; i < definingOp->getNumResults(); ++i) { + if (definingOp->getResult(i) == value) { + resultIndex = i; + break; + } + } + assert(resultIndex != -1 && "result not in results"); + auto sizeAwareOp = dyn_cast(definingOp); + if (!sizeAwareOp) return {}; + return sizeAwareOp.getResultSize(resultIndex); +} + +//===----------------------------------------------------------------------===// +// Object types +//===----------------------------------------------------------------------===// + +Value BufferType::inferSizeFromValue(Location loc, Value value, + OpBuilder &builder) const { + return builder.createOrFold(loc, builder.getIndexType(), + value); +} +Value BufferViewType::inferSizeFromValue(Location loc, Value value, + OpBuilder &builder) const { + return builder.createOrFold(loc, value); +} + //===----------------------------------------------------------------------===// // Struct types //===----------------------------------------------------------------------===// @@ -310,8 +596,7 @@ Attribute DescriptorSetLayoutBindingAttr::parse(DialectAsmParser &p) { if (failed(p.parseLess()) || failed(p.parseAttribute(bindingAttr, b.getIntegerType(32))) || failed(p.parseComma()) || failed(parseEnumAttr(p, "type", typeAttr)) || - failed(p.parseComma()) || - failed(parseEnumAttr(p, "access", accessAttr)) || + failed(p.parseComma()) || failed(parseMemoryAccess(p, accessAttr)) || failed(p.parseGreater())) { return {}; } @@ -323,7 +608,7 @@ void DescriptorSetLayoutBindingAttr::print(DialectAsmPrinter &p) const { 
os << getKindName() << "<"; os << binding() << ", "; os << "\"" << stringifyDescriptorType(type()) << "\", "; - os << "\"" << stringifyMemoryAccessBitfield(access()) << "\""; + printMemoryAccess(p, access()); os << ">"; } @@ -420,7 +705,8 @@ void DeviceMatchMemoryModelAttr::print(DialectAsmPrinter &p) const { os << "\">"; } -#include "iree/compiler/Dialect/HAL/IR/HALOpInterface.cpp.inc" +#include "iree/compiler/Dialect/HAL/IR/HALOpInterfaces.cpp.inc" +#include "iree/compiler/Dialect/HAL/IR/HALTypeInterfaces.cpp.inc" void HALDialect::registerAttributes() { addAttributes +inline bool allEnumBitsSet(T value, T required) { + return (static_cast(value) & static_cast(required)) == + static_cast(required); +} + //===----------------------------------------------------------------------===// // Object types //===----------------------------------------------------------------------===// @@ -69,15 +76,20 @@ class AllocatorType : public Type::TypeBase { using Base::Base; }; -class BufferType : public Type::TypeBase { +class BufferType : public Type::TypeBase { public: using Base::Base; + + Value inferSizeFromValue(Location loc, Value value, OpBuilder &builder) const; }; -class BufferViewType - : public Type::TypeBase { +class BufferViewType : public Type::TypeBase { public: using Base::Base; + + Value inferSizeFromValue(Location loc, Value value, OpBuilder &builder) const; }; class CommandBufferType @@ -157,29 +169,6 @@ class BufferConstraintsAdaptor { BufferConstraintsAttr bufferConstraints_; }; -class BufferBarrierType { - public: - static TupleType get(MLIRContext *context) { - return TupleType::get(context, { - IntegerType::get(context, 32), - IntegerType::get(context, 32), - BufferType::get(context), - IndexType::get(context), - IndexType::get(context), - }); - } -}; - -class MemoryBarrierType { - public: - static TupleType get(MLIRContext *context) { - return TupleType::get(context, { - IntegerType::get(context, 32), - IntegerType::get(context, 32), - }); - } -}; - // 
A tuple containing runtime values for a descriptor set binding: // using DescriptorSetBindingValue = std::tuple; diff --git a/iree/compiler/Dialect/HAL/IR/test/BUILD b/iree/compiler/Dialect/HAL/IR/test/BUILD index a4e7c0ba2d58..cbc1232093b6 100644 --- a/iree/compiler/Dialect/HAL/IR/test/BUILD +++ b/iree/compiler/Dialect/HAL/IR/test/BUILD @@ -25,6 +25,7 @@ iree_lit_test_suite( name = "lit", srcs = enforce_glob( [ + "allocator_op_folding.mlir", "allocator_ops.mlir", "attributes.mlir", "buffer_folding.mlir", diff --git a/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt index 172247e7f333..6878008f9a15 100644 --- a/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt @@ -14,6 +14,7 @@ iree_lit_test_suite( NAME lit SRCS + "allocator_op_folding.mlir" "allocator_ops.mlir" "attributes.mlir" "buffer_folding.mlir" diff --git a/iree/compiler/Dialect/HAL/IR/test/allocator_op_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/allocator_op_folding.mlir new file mode 100644 index 000000000000..159841a263f4 --- /dev/null +++ b/iree/compiler/Dialect/HAL/IR/test/allocator_op_folding.mlir @@ -0,0 +1,36 @@ +// RUN: iree-opt -split-input-file -canonicalize -cse %s | iree-opt -split-input-file | IreeFileCheck %s + +// CHECK-LABEL: @allocator_constant_buffer +// CHECK-SAME: %[[ALLOCATOR:.+]]: !hal.allocator +func @allocator_constant_buffer(%allocator: !hal.allocator) -> !hal.buffer { + // CHECK: %[[RODATA:.+]] = iree.byte_buffer.constant : !iree.byte_buffer = dense<123> : tensor<4x4xi32> + // CHECK-NEXT: %[[BUFFER:.+]] = hal.allocator.map<%[[ALLOCATOR]] : !hal.allocator> + // CHECK-SAME: source(%[[RODATA]] : !iree.byte_buffer)[%c0, %c-1] + // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("Constant|Transfer|Mapping|Dispatch") + // CHECK-SAME: : !hal.buffer + %ref = hal.allocator.constant<%allocator : !hal.allocator> + type(DeviceLocal) usage(Transfer) : 
!hal.buffer = + dense<123> : tensor<4x4xi32> + // CHECK-NEXT: return %[[BUFFER]] + return %ref : !hal.buffer +} + +// ----- + +// CHECK-LABEL: @allocator_constant_buffer_view +// CHECK-SAME: %[[ALLOCATOR:.+]]: !hal.allocator +func @allocator_constant_buffer_view(%allocator: !hal.allocator) -> !hal.buffer_view { + // CHECK: %[[RODATA:.+]] = iree.byte_buffer.constant : !iree.byte_buffer = dense<123> : tensor<4x4xi32> + // CHECK-NEXT: %[[BUFFER:.+]] = hal.allocator.map<%[[ALLOCATOR]] : !hal.allocator> + // CHECK-SAME: source(%[[RODATA]] : !iree.byte_buffer)[%c0, %c-1] + // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("Constant|Transfer|Mapping|Dispatch") + // CHECK-SAME: : !hal.buffer + // CHECK-NEXT: %[[VIEW:.+]] = hal.buffer_view.create %[[BUFFER]], element_type = %c16777248_i32, shape = [%c4, %c4] : !hal.buffer -> !hal.buffer_view + %ref = hal.allocator.constant<%allocator : !hal.allocator> + type(DeviceLocal) usage(Transfer) : !hal.buffer_view = + dense<123> : tensor<4x4xi32> + // CHECK-NEXT: return %[[VIEW]] + return %ref : !hal.buffer_view +} diff --git a/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir index c7e7c1e50e87..567c1d2ca100 100644 --- a/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/allocator_ops.mlir @@ -1,86 +1,126 @@ -// Tests printing and parsing of hal.allocator ops. 
- -// RUN: iree-opt -allow-unregistered-dialect -split-input-file %s | iree-opt -allow-unregistered-dialect -split-input-file | IreeFileCheck %s +// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s // CHECK-LABEL: @allocator_compute_size -func @allocator_compute_size() -> index { - %0 = "test_hal.allocator"() : () -> !hal.allocator - %1:2 = "test_hal.shape"() : () -> (index, index) - %c32_i32 = constant 32 : i32 - // CHECK: %[[SZ:.+]] = hal.allocator.compute_size %0, shape = [%1#0, %1#1], element_type = %c32_i32 - %sz = hal.allocator.compute_size %0, shape = [%1#0, %1#1], element_type = %c32_i32 - // CHECK-NEXT: return %[[SZ]] +func @allocator_compute_size(%arg0: !hal.allocator) -> index { + // CHECK-DAG: %[[DIM0:.+]] = constant 100 + %dim0 = constant 100 : index + // CHECK-DAG: %[[DIM1:.+]] = constant 200 + %dim1 = constant 200 : index + // CHECK-DAG: %[[TYPE:.+]] = constant 32 + %type = constant 32 : i32 + // CHECK: %[[SIZE:.+]] = hal.allocator.compute_size<%arg0 : !hal.allocator> + // CHECK-SAME: shape([%[[DIM0]], %[[DIM1]]]) + // CHECK-SAME: type(%[[TYPE]]) : index + %sz = hal.allocator.compute_size<%arg0 : !hal.allocator> + shape([%dim0, %dim1]) + type(%type) : index + // CHECK-NEXT: return %[[SIZE]] return %sz : index } // ----- // CHECK-LABEL: @allocator_compute_offset -func @allocator_compute_offset() -> index { - %0 = "test_hal.allocator"() : () -> !hal.allocator - %1:2 = "test_hal.shape"() : () -> (index, index) - %2:2 = "test_hal.indices"() : () -> (index, index) - %c32_i32 = constant 32 : i32 - // CHECK: %off = hal.allocator.compute_offset %0, shape = [%1#0, %1#1], element_type = %c32_i32, indices = [%2#0, %2#1] - %off = hal.allocator.compute_offset %0, shape = [%1#0, %1#1], element_type = %c32_i32, indices = [%2#0, %2#1] +func @allocator_compute_offset(%arg0: !hal.allocator) -> index { + // CHECK-DAG: %[[IDX0:.+]] = constant 10 + %idx0 = constant 10 : index + // CHECK-DAG: %[[IDX1:.+]] = constant 20 + %idx1 = constant 
20 : index + // CHECK-DAG: %[[DIM0:.+]] = constant 100 + %dim0 = constant 100 : index + // CHECK-DAG: %[[DIM1:.+]] = constant 200 + %dim1 = constant 200 : index + // CHECK-DAG: %[[TYPE:.+]] = constant 32 + %type = constant 32 : i32 + // CHECK: %[[OFFSET:.+]] = hal.allocator.compute_offset<%arg0 : !hal.allocator> + // CHECK-SAME: indices([%[[IDX0]], %[[IDX1]]]) + // CHECK-SAME: shape([%[[DIM0]], %[[DIM1]]]) + // CHECK-SAME: type(%[[TYPE]]) : index + %off = hal.allocator.compute_offset<%arg0 : !hal.allocator> + indices([%idx0, %idx1]) + shape([%dim0, %dim1]) + type(%type) : index + // CHECK-NEXT: return %[[OFFSET]] return %off : index } // ----- // CHECK-LABEL: @allocator_compute_range -func @allocator_compute_range() -> (index, index) { - %0 = "test_hal.allocator"() : () -> !hal.allocator - %1:2 = "test_hal.shape"() : () -> (index, index) - %2:2 = "test_hal.indices"() : () -> (index, index) - %3:2 = "test_hal.lengths"() : () -> (index, index) - %c32_i32 = constant 32 : i32 - // CHECK: %off, %len = hal.allocator.compute_range %0, shape = [%1#0, %1#1], element_type = %c32_i32, indices = [%2#0, %2#1], lengths = [%3#0, %3#1] - %off, %len = hal.allocator.compute_range %0, shape = [%1#0, %1#1], element_type = %c32_i32, indices = [%2#0, %2#1], lengths=[%3#0, %3#1] +func @allocator_compute_range(%arg0: !hal.allocator) -> (index, index) { + // CHECK-DAG: %[[IDX0:.+]] = constant 10 + %idx0 = constant 10 : index + // CHECK-DAG: %[[IDX1:.+]] = constant 20 + %idx1 = constant 20 : index + // CHECK-DAG: %[[LEN0:.+]] = constant 11 + %len0 = constant 11 : index + // CHECK-DAG: %[[LEN1:.+]] = constant 21 + %len1 = constant 21 : index + // CHECK-DAG: %[[DIM0:.+]] = constant 100 + %dim0 = constant 100 : index + // CHECK-DAG: %[[DIM1:.+]] = constant 200 + %dim1 = constant 200 : index + // CHECK-DAG: %[[TYPE:.+]] = constant 32 + %type = constant 32 : i32 + // CHECK: = hal.allocator.compute_range<%arg0 : !hal.allocator> + // CHECK-SAME: indices([%[[IDX0]], %[[IDX1]]]) + // CHECK-SAME: 
lengths([%[[LEN0]], %[[LEN1]]]) + // CHECK-SAME: shape([%[[DIM0]], %[[DIM1]]]) + // CHECK-SAME: type(%[[TYPE]]) : index, index + %off, %len = hal.allocator.compute_range<%arg0 : !hal.allocator> + indices([%idx0, %idx1]) + lengths([%len0, %len1]) + shape([%dim0, %dim1]) + type(%type) : index, index return %off, %len : index, index } // ----- // CHECK-LABEL: @allocator_allocate -func @allocator_allocate() -> !hal.buffer { - // CHECK-DAG: %[[C123:.+]] = constant 123 - %0 = constant 123 : index - // CHECK-DAG: %[[AL:.+]] = "test_hal.allocator" - %1 = "test_hal.allocator"() : () -> !hal.allocator - // CHECK: %[[CB:.+]] = hal.allocator.allocate %[[AL]], "HostVisible|HostCoherent", Transfer, %[[C123]] : !hal.buffer - %buffer = hal.allocator.allocate %1, "HostVisible|HostCoherent", Transfer, %0 : !hal.buffer - // CHECK-NEXT: return %[[CB]] - return %buffer : !hal.buffer +// CHECK-SAME: (%[[ALLOCATOR:.+]]: !hal.allocator) +func @allocator_allocate(%allocator: !hal.allocator) { + // CHECK-DAG: %[[SIZE:.+]] = constant 123 + %size = constant 123 : index + // CHECK: %[[REF:.+]] = hal.allocator.allocate<%[[ALLOCATOR]] : !hal.allocator> + // CHECK-SAME: type("HostVisible|HostCoherent") + // CHECK-SAME: usage(Transfer) + // CHECK-SAME: : !hal.buffer{%[[SIZE]]} + %ref = hal.allocator.allocate<%allocator : !hal.allocator> + type(HostLocal) usage(Transfer) : !hal.buffer{%size} + return } // ----- -// CHECK-LABEL: @allocator_allocate_const -func @allocator_allocate_const() -> !hal.buffer { - // CHECK-DAG: %[[AL:.+]] = "test_hal.allocator" - %allocator = "test_hal.allocator"() : () -> !hal.allocator - // CHECK: %[[CB:.+]] = hal.allocator.allocate.const %[[AL]], "HostVisible|HostCoherent", Transfer : !hal.buffer = dense<123> : tensor<4x4xi32> - %buffer = hal.allocator.allocate.const %allocator, "HostVisible|HostCoherent", Transfer : !hal.buffer = dense<123> : tensor<4x4xi32> - // CHECK-NEXT: return %[[CB]] - return %buffer : !hal.buffer +// CHECK-LABEL: @allocator_constant_buffer +// 
CHECK-SAME: %[[ALLOCATOR:.+]]: !hal.allocator +func @allocator_constant_buffer(%allocator: !hal.allocator) { + // CHECK: %[[REF:.+]] = hal.allocator.constant<%[[ALLOCATOR]] : !hal.allocator> + // CHECK-SAME: type("DeviceVisible|DeviceLocal") + // CHECK-SAME: usage(Transfer) + // CHECK-SAME: : !hal.buffer = dense<123> : tensor<4x4xi32> + %ref = hal.allocator.constant<%allocator : !hal.allocator> + type(DeviceLocal) usage(Transfer) : !hal.buffer = + dense<123> : tensor<4x4xi32> + return } // ----- // CHECK-LABEL: @allocator_map_byte_buffer -func @allocator_map_byte_buffer() -> !hal.buffer { - // CHECK-DAG: [[SOURCE:%.+]] = "test_hal.immutable_data" - %source = "test_hal.immutable_data"() : () -> !iree.byte_buffer - // CHECK-DAG: [[OFFSET:%.+]] = "test_hal.offset" - %offset = "test_hal.offset"() : () -> index - // CHECK-DAG: [[LENGTH:%.+]] = "test_hal.length" - %length = "test_hal.length"() : () -> index - // CHECK-DAG: [[AL:%.+]] = "test_hal.allocator" - %allocator = "test_hal.allocator"() : () -> !hal.allocator - // CHECK: = hal.allocator.map [[AL]], "HostVisible|HostCoherent", Transfer, [[SOURCE]][ - // CHECK-SAME: [[OFFSET]], [[LENGTH]] - // CHECK-SAME: ] : !iree.byte_buffer -> !hal.buffer - %buffer = hal.allocator.map %allocator, "HostVisible|HostCoherent", Transfer, %source[%offset, %length] : !iree.byte_buffer -> !hal.buffer - return %buffer : !hal.buffer +func @allocator_map_byte_buffer(%arg0: !hal.allocator, %arg1: !iree.byte_buffer) { + // CHECK-DAG: %[[OFFSET:.+]] = constant 100 + %offset = constant 100 : index + // CHECK-DAG: %[[LENGTH:.+]] = constant 200 + %length = constant 200 : index + // CHECK: = hal.allocator.map<%arg0 : !hal.allocator> + // CHECK-SAME: source(%arg1 : !iree.byte_buffer)[%[[OFFSET]], %[[LENGTH]]] + // CHECK-SAME: type("DeviceVisible|DeviceLocal") + // CHECK-SAME: usage(Transfer) + // CHECK-SAME: : !hal.buffer + %ref = hal.allocator.map<%arg0 : !hal.allocator> + source(%arg1 : !iree.byte_buffer)[%offset, %length] + type(DeviceLocal) 
usage(Transfer) : !hal.buffer + return } diff --git a/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/iree/compiler/Dialect/HAL/IR/test/attributes.mlir index 8cc05d0114b5..43b259edf6cb 100644 --- a/iree/compiler/Dialect/HAL/IR/test/attributes.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/attributes.mlir @@ -10,6 +10,8 @@ // CHECK-LABEL: descriptor_set_layout_binding.basic "descriptor_set_layout_binding.basic"() { - // CHECK: dslb = #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read|MayAlias"> - dslb = #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read|MayAlias"> + // CHECK: dslb0 = #hal.descriptor_set_layout_binding<0, "StorageBuffer", RA> + dslb0 = #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read|MayAlias">, + // CHECK: dslb1 = #hal.descriptor_set_layout_binding<0, "StorageBuffer", RA> + dslb1 = #hal.descriptor_set_layout_binding<0, "StorageBuffer", RA> } : () -> () diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir index 229ee0e78b4e..f6a749df2ad3 100644 --- a/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir @@ -1,30 +1,30 @@ -// Tests folding and canonicalization of HAL buffer ops. 
- -// RUN: iree-opt -allow-unregistered-dialect -split-input-file -canonicalize %s | iree-opt -allow-unregistered-dialect -split-input-file | IreeFileCheck %s +// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s // CHECK-LABEL: @skip_buffer_allocator -func @skip_buffer_allocator() -> !hal.allocator { - // CHECK-DAG: %[[AL:.+]] = "test_hal.allocator" - %0 = "test_hal.allocator"() : () -> !hal.allocator +// CHECK-SAME: (%[[ALLOCATOR:.+]]: !hal.allocator) +func @skip_buffer_allocator(%allocator: !hal.allocator) -> !hal.allocator { %sz = constant 4 : index - %buffer = hal.allocator.allocate %0, "HostVisible|HostCoherent", Transfer, %sz : !hal.buffer - %1 = hal.buffer.allocator %buffer : !hal.allocator - // CHECK: return %[[AL]] + %buffer = hal.allocator.allocate<%allocator : !hal.allocator> + type("HostVisible|HostCoherent") + usage(Transfer) : !hal.buffer{%sz} + %1 = hal.buffer.allocator<%buffer : !hal.buffer> : !hal.allocator + // CHECK: return %[[ALLOCATOR]] return %1 : !hal.allocator } // ----- // CHECK-LABEL: @skip_subspan_buffer_allocator -func @skip_subspan_buffer_allocator() -> !hal.allocator { +// CHECK-SAME: (%[[ALLOCATOR:.+]]: !hal.allocator) +func @skip_subspan_buffer_allocator(%allocator: !hal.allocator) -> !hal.allocator { %c0 = constant 0 : index %c184 = constant 184 : index %c384 = constant 384 : index - // CHECK-DAG: %[[AL:.+]] = "test_hal.allocator" - %allocator = "test_hal.allocator"() : () -> !hal.allocator - %source_buffer = hal.allocator.allocate %allocator, "HostVisible|HostCoherent", Transfer, %c384 : !hal.buffer - %span_buffer = hal.buffer.subspan %source_buffer, %c0, %c184 : !hal.buffer - %1 = hal.buffer.allocator %span_buffer : !hal.allocator - // CHECK: return %[[AL]] + %source_buffer = hal.allocator.allocate<%allocator : !hal.allocator> + type("HostVisible|HostCoherent") + usage(Transfer) : !hal.buffer{%c384} + %span_buffer = hal.buffer.subspan<%source_buffer : !hal.buffer>[%c0, %c184] : 
!hal.buffer + %1 = hal.buffer.allocator<%span_buffer : !hal.buffer> : !hal.allocator + // CHECK: return %[[ALLOCATOR]] return %1 : !hal.allocator } diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir index 8407f7b5c94c..18a04fc4f234 100644 --- a/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir @@ -1,84 +1,78 @@ // Tests printing and parsing of hal.buffer ops. -// RUN: iree-opt -allow-unregistered-dialect -split-input-file %s | iree-opt -allow-unregistered-dialect -split-input-file | IreeFileCheck %s +// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s // CHECK-LABEL: @buffer_allocator -func @buffer_allocator() -> !hal.allocator { - %0 = "test_hal.buffer"() : () -> !hal.buffer - // CHECK: %allocator = hal.buffer.allocator %0 : !hal.allocator - %allocator = hal.buffer.allocator %0 : !hal.allocator +func @buffer_allocator(%arg0: !hal.buffer) -> !hal.allocator { + // CHECK: %allocator = hal.buffer.allocator<%arg0 : !hal.buffer> : !hal.allocator + %allocator = hal.buffer.allocator<%arg0 : !hal.buffer> : !hal.allocator return %allocator : !hal.allocator } // ----- // CHECK-LABEL: @buffer_subspan -func @buffer_subspan() -> !hal.buffer { - %0 = "test_hal.buffer"() : () -> !hal.buffer - %1 = "test_hal.device_size"() : () -> index - %2 = "test_hal.device_size"() : () -> index - // CHECK: %buffer = hal.buffer.subspan %0, %1, %2 : !hal.buffer - %buffer = hal.buffer.subspan %0, %1, %2 : !hal.buffer +func @buffer_subspan(%arg0: !hal.buffer) -> !hal.buffer { + // CHECK-DAG: %[[OFFSET:.+]] = constant 100 + %offset = constant 100 : index + // CHECK-DAG: %[[LENGTH:.+]] = constant 200 + %length = constant 200 : index + // CHECK: %buffer = hal.buffer.subspan<%arg0 : !hal.buffer>[%[[OFFSET]], %[[LENGTH]]] : !hal.buffer + %buffer = hal.buffer.subspan<%arg0 : !hal.buffer>[%offset, %length] : !hal.buffer return %buffer : !hal.buffer } // 
----- -// CHECK-LABEL: @buffer_fill -func @buffer_fill(%arg0 : !hal.buffer) { - %0 = "test_hal.device_size"() : () -> index - %1 = "test_hal.device_size"() : () -> index - %2 = "test_hal.pattern"() : () -> i32 - // CHECK: hal.buffer.fill %arg0, %0, %1, %2 - hal.buffer.fill %arg0, %0, %1, %2 - return -} - -// ----- - -// CHECK-LABEL: @buffer_read_data -func @buffer_read_data(%arg0 : !hal.buffer) { - %0 = "test_hal.device_size"() : () -> index - %1 = "test_hal.mutable_data"() : () -> !iree.mutable_byte_buffer - %2 = "test_hal.device_size"() : () -> index - %3 = "test_hal.device_size"() : () -> index - // CHECK: hal.buffer.read_data %arg0, %0, %1, %2, %3 : !iree.mutable_byte_buffer - hal.buffer.read_data %arg0, %0, %1, %2, %3 : !iree.mutable_byte_buffer - return +// CHECK-LABEL: @buffer_length +func @buffer_length(%arg0: !hal.buffer) -> index { + // CHECK: hal.buffer.length<%arg0 : !hal.buffer> : index + %length = hal.buffer.length<%arg0 : !hal.buffer> : index + return %length : index } // ----- -// CHECK-LABEL: @buffer_write_data -func @buffer_write_data(%arg0 : !hal.buffer) { - %0 = "test_hal.mutable_data"() : () -> !iree.mutable_byte_buffer - %1 = "test_hal.device_size"() : () -> index - %2 = "test_hal.device_size"() : () -> index - %3 = "test_hal.device_size"() : () -> index - // CHECK: hal.buffer.write_data %0, %1, %arg0, %2, %3 : !iree.mutable_byte_buffer - hal.buffer.write_data %0, %1, %arg0, %2, %3 : !iree.mutable_byte_buffer +// CHECK-LABEL: @buffer_fill +func @buffer_fill(%arg0: !hal.buffer) { + // CHECK-DAG: %[[OFFSET:.+]] = constant 100 + %offset = constant 100 : index + // CHECK-DAG: %[[LENGTH:.+]] = constant 200 + %length = constant 200 : index + // CHECK-DAG: %[[PATTERN:.+]] = constant 42 + %pattern = constant 42 : i32 + // CHECK: hal.buffer.fill<%arg0 : !hal.buffer>[%[[OFFSET]], %[[LENGTH]]] pattern(%[[PATTERN]] : i32) + hal.buffer.fill<%arg0 : !hal.buffer>[%offset, %length] pattern(%pattern : i32) return } // ----- -// CHECK-LABEL: @buffer_copy_data 
-func @buffer_copy_data(%arg0 : !hal.buffer, %arg1 : !hal.buffer) { - %0 = "test_hal.device_size"() : () -> index - %1 = "test_hal.device_size"() : () -> index - %2 = "test_hal.device_size"() : () -> index - // CHECK: hal.buffer.copy_data %arg0, %0, %arg1, %1, %2 - hal.buffer.copy_data %arg0, %0, %arg1, %1, %2 +// CHECK-LABEL: @buffer_copy +func @buffer_copy(%arg0: !hal.buffer, %arg1: !hal.buffer) { + // CHECK-DAG: %[[SRC_OFFSET:.+]] = constant 100 + %src_offset = constant 100 : index + // CHECK-DAG: %[[DST_OFFSET:.+]] = constant 200 + %dst_offset = constant 200 : index + // CHECK-DAG: %[[LENGTH:.+]] = constant 300 + %length = constant 300 : index + // CHECK: hal.buffer.copy source(%arg0 : !hal.buffer)[%[[SRC_OFFSET]]] + // CHECK-SAME: target(%arg1 : !hal.buffer)[%[[DST_OFFSET]]] + // CHECK-SAME: length(%[[LENGTH]]) + hal.buffer.copy source(%arg0 : !hal.buffer)[%src_offset] + target(%arg1 : !hal.buffer)[%dst_offset] + length(%length) return } // ----- // CHECK-LABEL: @buffer_load -func @buffer_load(%arg0 : !hal.buffer) -> i32 { - %0 = "test_hal.device_size"() : () -> index - // CHECK: %[[VAL:.+]] = hal.buffer.load %arg0[%0] : i32 - %1 = hal.buffer.load %arg0[%0] : i32 +func @buffer_load(%arg0: !hal.buffer) -> i32 { + // CHECK-DAG: %[[SRC_OFFSET:.+]] = constant 100 + %src_offset = constant 100 : index + // CHECK: %[[VAL:.+]] = hal.buffer.load<%arg0 : !hal.buffer>[%[[SRC_OFFSET]]] : i32 + %1 = hal.buffer.load<%arg0 : !hal.buffer>[%src_offset] : i32 // CHECK-NEXT: return %[[VAL]] return %1 : i32 } @@ -86,9 +80,10 @@ func @buffer_load(%arg0 : !hal.buffer) -> i32 { // ----- // CHECK-LABEL: @buffer_store -func @buffer_store(%arg0 : i32, %arg1 : !hal.buffer) { - %0 = "test_hal.device_size"() : () -> index - // CHECK: hal.buffer.store %arg0, %arg1[%0] : i32 - hal.buffer.store %arg0, %arg1[%0] : i32 +func @buffer_store(%arg0: !hal.buffer, %arg1: i32) { + // CHECK-DAG: %[[DST_OFFSET:.+]] = constant 100 + %dst_offset = constant 100 : index + // CHECK: hal.buffer.store<%arg0 : 
!hal.buffer>[%[[DST_OFFSET]]] value(%arg1 : i32) + hal.buffer.store<%arg0 : !hal.buffer>[%dst_offset] value(%arg1 : i32) return } diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir index 634fecd87016..dc40968f7c32 100644 --- a/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/buffer_view_folding.mlir @@ -1,18 +1,5 @@ // RUN: iree-opt -allow-unregistered-dialect -split-input-file -canonicalize -cse %s | iree-opt -allow-unregistered-dialect -split-input-file | IreeFileCheck %s -// CHECK-LABEL: func @expand_buffer_view_const -func @expand_buffer_view_const() -> !hal.buffer_view { - %0 = "test_hal.allocator"() : () -> !hal.allocator - // CHECK: [[CONST:%.+]] = iree.byte_buffer.constant : !iree.byte_buffer = dense<[4, 1, 2]> : tensor<3xi32> - // CHECK: [[BUFFER:%.+]] = hal.allocator.map {{.+}}, "HostVisible|HostCoherent", Transfer, [[CONST]][%c0, %c-1] : !iree.byte_buffer -> !hal.buffer - // CHECK: [[VIEW:%.+]] = hal.buffer_view.create [[BUFFER]], element_type = %c16777248_i32, shape = [%c3] : !hal.buffer_view - %view = hal.buffer_view.const %0, "HostVisible|HostCoherent", Transfer : !hal.buffer_view = dense<[4, 1, 2]> : tensor<3xi32> - // CHECK-NEXT: return [[VIEW]] - return %view : !hal.buffer_view -} - -// ----- - // CHECK-LABEL: func @expand_buffer_view_subview func @expand_buffer_view_subview( // CHECK-SAME: %[[VIEW:.+]]: !hal.buffer_view, @@ -24,11 +11,11 @@ func @expand_buffer_view_subview( // CHECK: %[[ELEMENT_TYPE:.+]] = hal.buffer_view.element_type %[[VIEW]] : i32 // << A BUNCH OF MATH >> // CHECK: %[[BUFFER:.+]] = hal.buffer_view.buffer %[[VIEW]] : !hal.buffer - // CHECK-NEXT: %[[SUBSPAN:.+]] = hal.buffer.subspan %[[BUFFER]], %{{.+}}, %{{.+}} : !hal.buffer + // CHECK-NEXT: %[[SUBSPAN:.+]] = hal.buffer.subspan<%[[BUFFER]] : !hal.buffer>[%{{.+}}, %{{.+}}] : !hal.buffer // CHECK: %[[SUBVIEW:.+]] = hal.buffer_view.create // 
CHECK-SAME: %[[SUBSPAN]], // CHECK-SAME: element_type = %[[ELEMENT_TYPE]], - // CHECK-SAME: shape = [%[[LENGTH0]], %[[LENGTH1]]] : !hal.buffer_view + // CHECK-SAME: shape = [%[[LENGTH0]], %[[LENGTH1]]] : !hal.buffer -> !hal.buffer_view %subview = hal.buffer_view.subview %view, indices = [%index0, %index1], lengths = [%length0, %length1] : !hal.buffer_view @@ -39,15 +26,15 @@ func @expand_buffer_view_subview( // ----- // CHECK-LABEL: func @skip_buffer_view_buffer -func @skip_buffer_view_buffer() -> !hal.buffer { - // CHECK: %[[BUFFER:.+]] = "test_hal.buffer" - %0 = "test_hal.buffer"() : () -> !hal.buffer - %1:2 = "test_hal.shape"() : () -> (index, index) +// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer +func @skip_buffer_view_buffer(%buffer : !hal.buffer) -> !hal.buffer { + %c10 = constant 10 : index + %c11 = constant 11 : index %c32 = constant 32 : i32 - %2 = hal.buffer_view.create %0, element_type = %c32, shape = [%1#0, %1#1] : !hal.buffer_view - %3 = hal.buffer_view.buffer %2 : !hal.buffer + %view = hal.buffer_view.create %buffer, element_type = %c32, shape = [%c10, %c11] : !hal.buffer -> !hal.buffer_view + %view_buffer = hal.buffer_view.buffer %view : !hal.buffer // CHECK: return %[[BUFFER]] - return %3 : !hal.buffer + return %view_buffer : !hal.buffer } // ----- @@ -67,7 +54,7 @@ func @buffer_view_compute_offset(%arg0 : !hal.buffer_view) -> index { // CHECK: %[[T5:.+]] = subi %[[T4]], %c1 : index // CHECK: %[[T6:.+]] = divi_unsigned %[[T5]], %c8 : index // CHECK: %[[T7:.+]] = muli %[[T1]], %[[T6]] : index - %off = hal.buffer_view.compute_offset %arg0, indices = [%0#0, %0#1] + %off = hal.buffer_view.compute_offset %arg0, indices = [%0#0, %0#1] : index // CHECK: return %[[T7]] return %off : index } @@ -86,7 +73,7 @@ func @buffer_view_compute_range(%arg0 : !hal.buffer_view) -> (index, index) { // CHECK: = hal.buffer_view.dim %[[VIEW]], 1 : index // CHECK: = hal.buffer_view.element_type %[[VIEW]] : i32 // << A BUNCH OF MATH >> - %off, %len = 
hal.buffer_view.compute_range %arg0, indices = [%0#0, %0#1], lengths = [%1#0, %1#1] + %off, %len = hal.buffer_view.compute_range %arg0, indices = [%0#0, %0#1], lengths = [%1#0, %1#1] : index, index // CHECK: return return %off, %len : index, index } diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir index 138babc888dd..65b8fc0c87db 100644 --- a/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/buffer_view_ops.mlir @@ -1,32 +1,18 @@ -// Tests printing and parsing of hal.buffer_view ops. - // RUN: iree-opt -allow-unregistered-dialect -split-input-file %s | iree-opt -allow-unregistered-dialect -split-input-file | IreeFileCheck %s -// ----- - -// CHECK-LABEL: @buffer_view_const -func @buffer_view_const() -> !hal.buffer_view { - %0 = "test_hal.allocator"() : () -> !hal.allocator - // CHECK: %view = hal.buffer_view.const %0, "HostVisible|HostCoherent", Transfer : !hal.buffer_view = dense<[4, 1, 2]> : tensor<3xi32> - %view = hal.buffer_view.const %0, "HostVisible|HostCoherent", Transfer : !hal.buffer_view = dense<[4, 1, 2]> : tensor<3xi32> - return %view : !hal.buffer_view -} - -// ----- - // CHECK-LABEL: @buffer_view_create -func @buffer_view_create(%arg0 : !hal.buffer) -> !hal.buffer_view { +func @buffer_view_create(%arg0: !hal.buffer) -> !hal.buffer_view { %c32 = constant 32 : i32 %0:2 = "test_hal.shape"() : () -> (index, index) - // CHECK: %view = hal.buffer_view.create %arg0, element_type = %c32_i32, shape = [%0#0, %0#1] : !hal.buffer_view - %view = hal.buffer_view.create %arg0, element_type = %c32, shape = [%0#0, %0#1] : !hal.buffer_view + // CHECK: %view = hal.buffer_view.create %arg0, element_type = %c32_i32, shape = [%0#0, %0#1] : !hal.buffer -> !hal.buffer_view + %view = hal.buffer_view.create %arg0, element_type = %c32, shape = [%0#0, %0#1] : !hal.buffer -> !hal.buffer_view return %view : !hal.buffer_view } // ----- // CHECK-LABEL: 
@buffer_view_subview -func @buffer_view_subview(%arg0 : !hal.buffer_view) -> !hal.buffer_view { +func @buffer_view_subview(%arg0: !hal.buffer_view) -> !hal.buffer_view { %0:2 = "test_hal.indices"() : () -> (index, index) %1:2 = "test_hal.lengths"() : () -> (index, index) // CHECK: %view = hal.buffer_view.subview %arg0, indices = [%0#0, %0#1], lengths = [%1#0, %1#1] : !hal.buffer_view @@ -37,7 +23,7 @@ func @buffer_view_subview(%arg0 : !hal.buffer_view) -> !hal.buffer_view { // ----- // CHECK-LABEL: @buffer_view_buffer -func @buffer_view_buffer(%arg0 : !hal.buffer_view) -> !hal.buffer { +func @buffer_view_buffer(%arg0: !hal.buffer_view) -> !hal.buffer { // CHECK: %buffer = hal.buffer_view.buffer %arg0 : !hal.buffer %buffer = hal.buffer_view.buffer %arg0 : !hal.buffer return %buffer : !hal.buffer @@ -46,7 +32,7 @@ func @buffer_view_buffer(%arg0 : !hal.buffer_view) -> !hal.buffer { // ----- // CHECK-LABEL: @buffer_view_byte_length -func @buffer_view_byte_length(%arg0 : !hal.buffer_view) -> index { +func @buffer_view_byte_length(%arg0: !hal.buffer_view) -> index { // CHECK: %len = hal.buffer_view.byte_length %arg0 : index %len = hal.buffer_view.byte_length %arg0 : index return %len : index @@ -55,28 +41,28 @@ func @buffer_view_byte_length(%arg0 : !hal.buffer_view) -> index { // ----- // CHECK-LABEL: @buffer_view_compute_offset -func @buffer_view_compute_offset(%arg0 : !hal.buffer_view) -> index { +func @buffer_view_compute_offset(%arg0: !hal.buffer_view) -> index { %0:2 = "test_hal.indices"() : () -> (index, index) // CHECK: %off = hal.buffer_view.compute_offset %arg0, indices = [%0#0, %0#1] - %off = hal.buffer_view.compute_offset %arg0, indices = [%0#0, %0#1] + %off = hal.buffer_view.compute_offset %arg0, indices = [%0#0, %0#1] : index return %off : index } // ----- // CHECK-LABEL: @buffer_view_compute_range -func @buffer_view_compute_range(%arg0 : !hal.buffer_view) -> (index, index) { +func @buffer_view_compute_range(%arg0: !hal.buffer_view) -> (index, index) { %0:2 
= "test_hal.indices"() : () -> (index, index) %1:2 = "test_hal.lengths"() : () -> (index, index) // CHECK: %off, %len = hal.buffer_view.compute_range %arg0, indices = [%0#0, %0#1], lengths = [%1#0, %1#1] - %off, %len = hal.buffer_view.compute_range %arg0, indices = [%0#0, %0#1], lengths = [%1#0, %1#1] + %off, %len = hal.buffer_view.compute_range %arg0, indices = [%0#0, %0#1], lengths = [%1#0, %1#1] : index, index return %off, %len : index, index } // ----- // CHECK-LABEL: @buffer_view_shape_queries -func @buffer_view_shape_queries(%arg0 : !hal.buffer_view) -> (index, index, index, index) { +func @buffer_view_shape_queries(%arg0: !hal.buffer_view) -> (index, index, index, index) { // CHECK: %{{.+}} = hal.buffer_view.rank %arg0 : index %0 = hal.buffer_view.rank %arg0 : index // CHECK: %{{.+}} = hal.buffer_view.dim %arg0, 0 : index diff --git a/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir index 84f8ed045456..54b7c922548e 100644 --- a/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir @@ -2,13 +2,18 @@ // CHECK-LABEL: @skip_command_buffer_device func @skip_command_buffer_device() -> !hal.executable { + // CHECK: %[[DEVICE:.+]] = hal.ex.shared_device %dev = hal.ex.shared_device : !hal.device - %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer + %cmd = hal.command_buffer.create device(%dev : !hal.device) + mode(OneShot) + categories("Transfer|Dispatch") : !hal.command_buffer // CHECK-NOT: hal.command_buffer.device - // CHECK: %[[EXECUTABLE:.+]] = hal.executable.lookup %dev, @executable_name : !hal.executable - %0 = hal.command_buffer.device %cmd : !hal.device - %exe = hal.executable.lookup %0, @executable_name : !hal.executable + // CHECK: = hal.executable.lookup device(%[[DEVICE]] : !hal.device) + // CHECK-SAME: executable(@executable_name) : !hal.executable + %0 = 
hal.command_buffer.device<%cmd : !hal.command_buffer> : !hal.device + %exe = hal.executable.lookup device(%0 : !hal.device) + executable(@executable_name) : !hal.executable return %exe : !hal.executable } @@ -16,11 +21,13 @@ func @skip_command_buffer_device() -> !hal.executable { // ----- // CHECK-LABEL: @fold_buffer_subspan_into_push_descriptor_set -// CHECK-SAME: [[BASE_BUFFER:%[a-z0-9]+]]: !hal.buffer +// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[LAYOUT:.+]]: !hal.executable_layout, +// CHECK-SAME: %[[BASE_BUFFER:.+]]: !hal.buffer func @fold_buffer_subspan_into_push_descriptor_set( - %cmd : !hal.command_buffer, - %layout : !hal.executable_layout, - %buffer : !hal.buffer + %cmd: !hal.command_buffer, + %layout: !hal.executable_layout, + %buffer: !hal.buffer ) { %c0 = constant 0 : index %c1 = constant 1 : index @@ -30,18 +37,21 @@ func @fold_buffer_subspan_into_push_descriptor_set( %c8000 = constant 8000 : index %c262140 = constant 262140 : index %c262144 = constant 262144 : index - %subspan = hal.buffer.subspan %buffer, %c4096, %c262144 : !hal.buffer - // CHECK: hal.command_buffer.push_descriptor_set {{.+}}, set = %c0, bindings = [ - hal.command_buffer.push_descriptor_set %cmd, %layout, set = %c0, bindings = [ - // 0 + 4096: - // CHECK-SAME: %c0 = ([[BASE_BUFFER]], %c4096, %c8000) - %c0 = (%subspan, %c0, %c8000), - // 4096 + 4: - // CHECK-SAME: %c1 = ([[BASE_BUFFER]], %c4100, %c262140) - %c1 = (%subspan, %c4, %c262140), - // No change: - // CHECK-SAME: %c2 = ([[BASE_BUFFER]], %c4096, %c262144) - %c2 = (%buffer, %c4096, %c262144) - ] + %subspan = hal.buffer.subspan<%buffer : !hal.buffer>[%c4096, %c262144] : !hal.buffer + // CHECK: hal.command_buffer.push_descriptor_set + // CHECK-SAME: bindings([ + hal.command_buffer.push_descriptor_set<%cmd : !hal.command_buffer> + layout(%layout : !hal.executable_layout)[%c0] + bindings([ + // 0 + 4096: + // CHECK-NEXT: %c0 = (%[[BASE_BUFFER]] : !hal.buffer)[%c4096, %c8000] + %c0 = (%subspan :
!hal.buffer)[%c0, %c8000], + // 4096 + 4: + // CHECK-NEXT: %c1 = (%[[BASE_BUFFER]] : !hal.buffer)[%c4100, %c262140] + %c1 = (%subspan : !hal.buffer)[%c4, %c262140], + // No change: + // CHECK-NEXT: %c2 = (%[[BASE_BUFFER]] : !hal.buffer)[%c4096, %c262144] + %c2 = (%buffer : !hal.buffer)[%c4096, %c262144] + ]) return } diff --git a/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir index 9730bfec10fa..2acd67eeff4f 100644 --- a/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir @@ -1,137 +1,187 @@ -// Tests printing and parsing of hal.command_buffer ops. - -// RUN: iree-opt -allow-unregistered-dialect -split-input-file %s | iree-opt -allow-unregistered-dialect -split-input-file | IreeFileCheck %s - -// CHECK-LABEL: @make_memory_barrier -func @make_memory_barrier() -> tuple { - // CHECK: %memory_barrier = hal.make_memory_barrier "HostRead|HostWrite", "MemoryRead|MemoryWrite" : tuple - %memory_barrier = hal.make_memory_barrier "HostRead|HostWrite", "MemoryRead|MemoryWrite" : tuple - return %memory_barrier : tuple -} - -// ----- - -// CHECK-LABEL: @make_buffer_barrier -func @make_buffer_barrier(%arg0 : !hal.buffer) -> tuple { - %0 = "test_hal.offset"() : () -> index - %1 = "test_hal.length"() : () -> index - // CHECK: %buffer_barrier = hal.make_buffer_barrier "HostRead|HostWrite", "MemoryRead|MemoryWrite", %arg0, %0, %1 : tuple - %buffer_barrier = hal.make_buffer_barrier "HostRead|HostWrite", "MemoryRead|MemoryWrite", %arg0, %0, %1 : tuple - return %buffer_barrier : tuple -} - -// ----- +// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s // CHECK-LABEL: @command_buffer_create -func @command_buffer_create(%arg0 : !hal.device) { - // CHECK: %cmd = hal.command_buffer.create %arg0, OneShot, "Transfer|Dispatch" : !hal.command_buffer - %cmd = hal.command_buffer.create %arg0, OneShot, "Transfer|Dispatch" : 
!hal.command_buffer +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) +func @command_buffer_create(%device: !hal.device) { + // CHECK: %cmd = hal.command_buffer.create + // CHECK-SAME: device(%[[DEVICE]] : !hal.device) + // CHECK-SAME: mode(OneShot) + // CHECK-SAME: categories("Transfer|Dispatch") : !hal.command_buffer + %cmd = hal.command_buffer.create device(%device : !hal.device) + mode(OneShot) + categories("Transfer|Dispatch") : !hal.command_buffer return } // ----- // CHECK-LABEL: @command_buffer_begin_end -func @command_buffer_begin_end(%arg0 : !hal.command_buffer) { - // CHECK: hal.command_buffer.begin %arg0 - hal.command_buffer.begin %arg0 - // CHECK: hal.command_buffer.end %arg0 - hal.command_buffer.end %arg0 +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer) +func @command_buffer_begin_end(%cmd: !hal.command_buffer) { + // CHECK: hal.command_buffer.begin<%[[CMD]] : !hal.command_buffer> + hal.command_buffer.begin<%cmd : !hal.command_buffer> + // CHECK: hal.command_buffer.end<%[[CMD]] : !hal.command_buffer> + hal.command_buffer.end<%cmd : !hal.command_buffer> return } // ----- // CHECK-LABEL: @command_buffer_device -func @command_buffer_device(%arg0 : !hal.command_buffer) { - // CHECK: %0 = hal.command_buffer.device %arg0 : !hal.device - %0 = hal.command_buffer.device %arg0 : !hal.device +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer) +func @command_buffer_device(%cmd: !hal.command_buffer) { + // CHECK: %0 = hal.command_buffer.device<%[[CMD]] : !hal.command_buffer> : !hal.device + %0 = hal.command_buffer.device<%cmd : !hal.command_buffer> : !hal.device return } // ----- // CHECK-LABEL: @command_buffer_execution_barrier -func @command_buffer_execution_barrier(%arg0 : !hal.command_buffer) { - // CHECK: hal.command_buffer.execution_barrier %arg0, "CommandIssue", "CommandProcess", "None" - hal.command_buffer.execution_barrier %arg0, "CommandIssue", "CommandProcess", "None" +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer) +func 
@command_buffer_execution_barrier(%cmd: !hal.command_buffer) { + // CHECK: hal.command_buffer.execution_barrier<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: source(CommandIssue) + // CHECK-SAME: target(CommandProcess) + // CHECK-SAME: flags("None") + hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> + source(CommandIssue) + target(CommandProcess) + flags(None) return } // ----- // CHECK-LABEL: @command_buffer_fill_buffer -func @command_buffer_fill_buffer(%arg0 : !hal.command_buffer) { - %0 = "test_hal.buffer"() : () -> !hal.buffer - %1 = "test_hal.offset"() : () -> index - %2 = "test_hal.length"() : () -> index - %3 = "test_hal.pattern"() : () -> i32 - // CHECK: hal.command_buffer.fill_buffer %arg0, %0, %1, %2, %3 - hal.command_buffer.fill_buffer %arg0, %0, %1, %2, %3 +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer, +// CHECK-SAME: %[[OFFSET:.+]]: index, %[[LENGTH:.+]]: index, +// CHECK-SAME: %[[PATTERN:.+]]: i32) +func @command_buffer_fill_buffer( + %cmd: !hal.command_buffer, + %buffer: !hal.buffer, + %offset: index, + %length: index, + %pattern: i32 + ) { + // CHECK: hal.command_buffer.fill_buffer<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(%[[BUFFER]] : !hal.buffer)[%[[OFFSET]], %[[LENGTH]]] + // CHECK-SAME: pattern(%[[PATTERN]] : i32) + hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> + target(%buffer : !hal.buffer)[%offset, %length] + pattern(%pattern : i32) return } // ----- // CHECK-LABEL: @command_buffer_copy_buffer -func @command_buffer_copy_buffer(%arg0 : !hal.command_buffer) { - %0 = "test_hal.buffer"() : () -> !hal.buffer - %1 = "test_hal.source_offset"() : () -> index - %2 = "test_hal.target_offset"() : () -> index - %3 = "test_hal.length"() : () -> index - // CHECK: hal.command_buffer.copy_buffer %arg0, %0, %1, %0, %2, %3 - hal.command_buffer.copy_buffer %arg0, %0, %1, %0, %2, %3 +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[BUFFER:.+]]: 
!hal.buffer, +// CHECK-SAME: %[[SRC_OFFSET:.+]]: index, %[[DST_OFFSET:.+]]: index, +// CHECK-SAME: %[[LENGTH:.+]]: index) +func @command_buffer_copy_buffer( + %cmd: !hal.command_buffer, + %buffer: !hal.buffer, + %src_offset: index, + %dst_offset: index, + %length: index + ) { + // CHECK: hal.command_buffer.copy_buffer<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: source(%[[BUFFER]] : !hal.buffer)[%[[SRC_OFFSET]]] + // CHECK-SAME: target(%[[BUFFER]] : !hal.buffer)[%[[DST_OFFSET]]] + // CHECK-SAME: length(%[[LENGTH]]) + hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer> + source(%buffer : !hal.buffer)[%src_offset] + target(%buffer : !hal.buffer)[%dst_offset] + length(%length) return } // ----- // CHECK-LABEL: @command_buffer_bind_descriptor_set -func @command_buffer_bind_descriptor_set(%arg0 : !hal.command_buffer) { - %0 = "test_hal.executable_layout"() : () -> !hal.executable_layout - %1 = "test_hal.descriptor_set"() : () -> !hal.descriptor_set - %2 = "test_hal.offset"() : () -> index +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[LAYOUT:.+]]: !hal.executable_layout, +// CHECK-SAME: %[[SET:.+]]: !hal.descriptor_set, +// CHECK-SAME: %[[OFFSET:.+]]: index) +func @command_buffer_bind_descriptor_set( + %cmd: !hal.command_buffer, + %layout: !hal.executable_layout, + %set: !hal.descriptor_set, + %offset: index + ) { + // CHECK: %[[SET_IDX:.+]] = constant 0 %c0 = constant 0 : index - // CHECK: hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1 - hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1 - // CHECK-NEXT: hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1, offsets = [%2] - hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1, offsets = [%2] + // CHECK: hal.command_buffer.bind_descriptor_set<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: layout(%[[LAYOUT]] : !hal.executable_layout)[%[[SET_IDX]]] + // CHECK-SAME: set(%[[SET]] : !hal.descriptor_set) + 
hal.command_buffer.bind_descriptor_set<%cmd : !hal.command_buffer> + layout(%layout : !hal.executable_layout)[%c0] + set(%set : !hal.descriptor_set) + // CHECK: hal.command_buffer.bind_descriptor_set<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: layout(%[[LAYOUT]] : !hal.executable_layout)[%[[SET_IDX]]] + // CHECK-SAME: set(%[[SET]] : !hal.descriptor_set) + // CHECK-SAME: offsets([%[[OFFSET]]]) + hal.command_buffer.bind_descriptor_set<%cmd : !hal.command_buffer> + layout(%layout : !hal.executable_layout)[%c0] + set(%set : !hal.descriptor_set) + offsets([%offset]) return } // ----- -// CHECK-LABEL: @command_buffer_dispatch -func @command_buffer_dispatch(%arg0 : !hal.command_buffer) { - hal.executable @ex { - hal.executable.target @backend, filter="backend" { - hal.executable.entry_point @entry0 attributes { - interface = @interface, ordinal = 0 : i32, signature = (tensor) -> tensor - } +hal.executable @ex { + hal.executable.target @backend, filter="backend" { + hal.executable.entry_point @entry0 attributes { + interface = @interface, ordinal = 0 : index, signature = (tensor) -> tensor } } - %0 = "test_hal.workgroup_x"() : () -> index - %1 = "test_hal.workgroup_y"() : () -> index - %2 = "test_hal.workgroup_z"() : () -> index - // CHECK: hal.command_buffer.dispatch.symbol %arg0, @ex::@backend::@entry0, workgroup_xyz = [%0, %1, %2] - hal.command_buffer.dispatch.symbol %arg0, @ex::@backend::@entry0, workgroup_xyz = [%0, %1, %2] +} + +// CHECK-LABEL: @command_buffer_dispatch +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index) +func @command_buffer_dispatch( + %cmd: !hal.command_buffer, + %x: index, + %y: index, + %z: index + ) { + // CHECK: hal.command_buffer.dispatch.symbol<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(@ex::@backend::@entry0) + // CHECK-SAME: workgroups([%[[X]], %[[Y]], %[[Z]]]) + hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> + 
target(@ex::@backend::@entry0) + workgroups([%x, %y, %z]) return } // ----- -// CHECK-LABEL: @command_buffer_dispatch_indirect -func @command_buffer_dispatch_indirect(%arg0 : !hal.command_buffer) { - hal.executable @ex { - hal.executable.target @backend, filter="backend" { - hal.executable.entry_point @entry0 attributes { - interface = @interface, ordinal = 0 : i32, signature = (tensor) -> tensor - } +hal.executable @ex { + hal.executable.target @backend, filter="backend" { + hal.executable.entry_point @entry0 attributes { + interface = @interface, ordinal = 0 : index, signature = (tensor) -> tensor } } - %0 = "test_hal.buffer"() : () -> !hal.buffer - %1 = "test_hal.offset"() : () -> index - // CHECK: hal.command_buffer.dispatch.indirect.symbol %arg0, @ex::@backend::@entry0, workgroups = %0[%1] - hal.command_buffer.dispatch.indirect.symbol %arg0, @ex::@backend::@entry0, workgroups = %0[%1] +} + +// CHECK-LABEL: @command_buffer_dispatch_indirect +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer, +// CHECK-SAME: %[[OFFSET:.+]]: index) +func @command_buffer_dispatch_indirect( + %cmd: !hal.command_buffer, + %buffer: !hal.buffer, + %offset: index) { + // CHECK: hal.command_buffer.dispatch.indirect.symbol<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(@ex::@backend::@entry0) + // CHECK-SAME: workgroups(%[[BUFFER]] : !hal.buffer)[%[[OFFSET]]] + hal.command_buffer.dispatch.indirect.symbol<%cmd : !hal.command_buffer> + target(@ex::@backend::@entry0) + workgroups(%buffer : !hal.buffer)[%offset] return } diff --git a/iree/compiler/Dialect/HAL/IR/test/descriptor_set_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/descriptor_set_ops.mlir index 99250107a46b..519626dce9f9 100644 --- a/iree/compiler/Dialect/HAL/IR/test/descriptor_set_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/descriptor_set_ops.mlir @@ -1,13 +1,20 @@ -// Tests printing and parsing of hal.descriptor_set ops. 
- // RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s // CHECK-LABEL: @descriptor_set_layout_create -func @descriptor_set_layout_create(%arg0 : !hal.device) { - // CHECK: hal.descriptor_set_layout.create %arg0, PushOnly, bindings = [#hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write">] : !hal.descriptor_set_layout - %descriptor_set_layout = hal.descriptor_set_layout.create %arg0, PushOnly, bindings = [ +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) +func @descriptor_set_layout_create(%device: !hal.device) { + // CHECK: = hal.descriptor_set_layout.create + // CHECK-SAME: device(%[[DEVICE]] : !hal.device) + // CHECK-SAME: usage(PushOnly) + // CHECK-SAME: bindings([ + // CHECK-SAME: #hal.descriptor_set_layout_binding<0, "StorageBuffer", R>, + // CHECK-SAME: #hal.descriptor_set_layout_binding<1, "StorageBuffer", W> + // CHECK-SAME: ]) : !hal.descriptor_set_layout + %0 = hal.descriptor_set_layout.create device(%device : !hal.device) + usage(PushOnly) + bindings([ #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write"> - ] : !hal.descriptor_set_layout + ]) : !hal.descriptor_set_layout return } diff --git a/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir index 4eebdf0fcd9a..a66e0ea2df11 100644 --- a/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir @@ -1,27 +1,26 @@ -// RUN: iree-opt -allow-unregistered-dialect -split-input-file %s | iree-opt -allow-unregistered-dialect -split-input-file | IreeFileCheck %s +// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s // CHECK-LABEL: @device_allocator -func @device_allocator() -> !hal.allocator { - %0 = "test_hal.device"() : () -> !hal.device - // CHECK: %allocator = hal.device.allocator %0 : !hal.allocator - 
%allocator = hal.device.allocator %0 : !hal.allocator +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) +func @device_allocator(%device: !hal.device) -> !hal.allocator { + // CHECK: %allocator = hal.device.allocator<%[[DEVICE]] : !hal.device> : !hal.allocator + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator return %allocator : !hal.allocator } // ----- // CHECK-LABEL: @device_switch -func @device_switch() -> i32 { +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) +func @device_switch(%device: !hal.device) -> i32 { // CHECK-DAG: %[[C0:.+]] = constant 0 %c0 = constant 0 : i32 // CHECK-DAG: %[[C1:.+]] = constant 1 %c1 = constant 1 : i32 // CHECK-DAG: %[[C2:.+]] = constant 2 %c2 = constant 2 : i32 - // CHECK-DAG: %[[DEVICE:.+]] = "test_hal.device" - %device = "test_hal.device"() : () -> !hal.device - // CHECK: = hal.device.switch(%[[DEVICE]] : !hal.device) -> i32 - %0 = hal.device.switch(%device : !hal.device) -> i32 + // CHECK: = hal.device.switch<%[[DEVICE]] : !hal.device> -> i32 + %0 = hal.device.switch<%device : !hal.device> -> i32 // CHECK-NEXT: #hal.device.match.id<"vulkan-v1.?-*">(%[[C1A:.+]] = %[[C1]] : i32) { #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) { // CHECK-NEXT: hal.return %[[C1A]] : i32 @@ -46,9 +45,9 @@ func @device_switch() -> i32 { // ----- // CHECK-LABEL: @device_matchers -// CHECK-SAME: %[[DEVICE:.+]]: !hal.device +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) func @device_matchers(%device : !hal.device) -> i1 { - // CHECK: = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-*"] : (!hal.device) -> i1 - %0 = hal.device.match.id %device, pattern = ["vulkan-*"] : (!hal.device) -> i1 + // CHECK: = hal.device.match.id<%[[DEVICE]] : !hal.device> pattern("vulkan-*") : i1 + %0 = hal.device.match.id<%device : !hal.device> pattern("vulkan-*") : i1 return %0 : i1 } diff --git a/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir index c0ade489ef57..f1b13f985034 100644 
--- a/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir @@ -1,6 +1,4 @@ -// Tests printing and parsing of executable/structural ops. - -// RUN: iree-opt -allow-unregistered-dialect -split-input-file %s | iree-opt -allow-unregistered-dialect -split-input-file | IreeFileCheck %s +// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s // CHECK-LABEL: @ex hal.executable @ex { @@ -8,12 +6,12 @@ hal.executable @ex { hal.executable.target @backend, filter="backend" { // CHECK-DAG: hal.executable.entry_point @entry0 attributes { // CHECK-SAME: interface = @interface - // CHECK-SAME: ordinal = 0 : i32 + // CHECK-SAME: ordinal = 0 : index // CHECK-SAME: signature = (tensor<4xf32>) -> tensor<4xf32> // CHECK-SAME: workgroup_size = [4 : index, 1 : index, 1 : index] hal.executable.entry_point @entry0 attributes { interface = @interface, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (tensor<4xf32>) -> tensor<4xf32>, workgroup_size = [4 : index, 1 : index, 1 : index] } @@ -42,12 +40,12 @@ hal.executable @ex_with_workgroup_count_region { hal.executable.target @backend, filter="backend" { // CHECK-DAG: hal.executable.entry_point @entry0 attributes { // CHECK-SAME: interface = @interface - // CHECK-SAME: ordinal = 0 : i32 + // CHECK-SAME: ordinal = 0 : index // CHECK-SAME: signature = (tensor<4xf32>) -> tensor<4xf32> // CHECK-SAME: workgroup_size = [4 : index, 1 : index, 1 : index] hal.executable.entry_point @entry0 attributes { interface = @interface, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (tensor<4xf32>) -> tensor<4xf32>, workgroup_size = [4 : index, 1 : index, 1 : index] } { @@ -87,12 +85,7 @@ hal.executable @ex_with_source { // CHECK-NEXT: func @dispatch0 func @dispatch0(%arg0: memref<4xf32>, %arg1: memref<4xf32>) attributes { iree.executable.export, - iree.ordinal = 0 : i32} { - %0 = "iree_ll_interp.alloc_heap"() : () -> memref<4xf32> - 
"iree_ll_interp.add_f"(%arg0, %arg0, %0) : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () - %1 = "iree_ll_interp.constant"() {value = dense<0> : tensor<1xi64>} : () -> memref<1xi64> - %2 = "iree_ll_interp.constant"() {value = dense<4> : tensor<1xi64>} : () -> memref<1xi64> - "iree_ll_interp.dynamic_copy"(%0, %1, %arg1, %1, %2) : (memref<4xf32>, memref<1xi64>, memref<4xf32>, memref<1xi64>, memref<1xi64>) -> () + iree.ordinal = 0 : index} { return } } @@ -105,17 +98,34 @@ hal.executable @ex_with_source { // CHECK-SAME: %[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[LAYOUT0:.+]]: !hal.executable_layout, // CHECK-SAME: %[[LAYOUT1:.+]]: !hal.executable_layout -func @executable_create(%device : !hal.device, %layout0 : !hal.executable_layout, %layout1 : !hal.executable_layout) { - // CHECK: = hal.executable.create %[[DEVICE]], @exe::@binary1, layouts = [%[[LAYOUT0]], %[[LAYOUT1]]] : !hal.executable - %0 = hal.executable.create %device, @exe::@binary1, layouts = [%layout0, %layout1] : !hal.executable +func @executable_create(%device: !hal.device, + %layout0: !hal.executable_layout, + %layout1: !hal.executable_layout) { + // CHECK: = hal.executable.create + // CHECK-SAME: device(%[[DEVICE]] : !hal.device) + // CHECK-SAME: target(@exe::@binary1) + // CHECK-SAME: layouts([%[[LAYOUT0]], %[[LAYOUT1]]]) : !hal.executable + %0 = hal.executable.create device(%device : !hal.device) + target(@exe::@binary1) + layouts([%layout0, %layout1]) : !hal.executable return } // ----- // CHECK-LABEL: @executable_layout_create -func @executable_layout_create(%arg0 : !hal.device, %arg1 : !hal.descriptor_set_layout) { - // CHECK: hal.executable_layout.create %arg0, push_constants = 1, set_layouts = [%arg1] : !hal.executable_layout - %executable_layout = hal.executable_layout.create %arg0, push_constants = 1, set_layouts = [%arg1] : !hal.executable_layout +// CHECK-SAME: %[[DEVICE:.+]]: !hal.device, +// CHECK-SAME: %[[LAYOUT0:.+]]: !hal.descriptor_set_layout, +// CHECK-SAME: %[[LAYOUT1:.+]]: 
!hal.descriptor_set_layout +func @executable_layout_create(%device: !hal.device, + %layout0: !hal.descriptor_set_layout, + %layout1: !hal.descriptor_set_layout) { + // CHECK: hal.executable_layout.create + // CHECK-SAME: device(%[[DEVICE]] : !hal.device) + // CHECK-SAME: push_constants(1) + // CHECK-SAME: layouts([%[[LAYOUT0]], %[[LAYOUT1]]]) : !hal.executable_layout + %0 = hal.executable_layout.create device(%device : !hal.device) + push_constants(1) + layouts([%layout0, %layout1]) : !hal.executable_layout return } diff --git a/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir index 6ffa2037c75d..835651e6f925 100644 --- a/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/experimental_ops.mlir @@ -4,9 +4,9 @@ // CHECK-LABEL: @shared_device func @shared_device() -> !hal.device { - // CHECK: %dev = hal.ex.shared_device : !hal.device - %dev = hal.ex.shared_device : !hal.device - return %dev : !hal.device + // CHECK: %device = hal.ex.shared_device : !hal.device + %device = hal.ex.shared_device : !hal.device + return %device : !hal.device } // ----- diff --git a/iree/compiler/Dialect/HAL/IR/test/semaphore_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/semaphore_ops.mlir index ede4d14d811b..a43fc3e97bb8 100644 --- a/iree/compiler/Dialect/HAL/IR/test/semaphore_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/semaphore_ops.mlir @@ -4,9 +4,9 @@ // CHECK-LABEL: @semaphore_create func @semaphore_create(%arg0 : !hal.device) -> !hal.semaphore { - %c0 = std.constant 0 : index - // CHECK: %semaphore = hal.semaphore.create %arg0, initial_value = %c0 : !hal.semaphore - %semaphore = hal.semaphore.create %arg0, initial_value = %c0 : !hal.semaphore + %c0 = constant 0 : index + // CHECK: %semaphore = hal.semaphore.create device(%arg0 : !hal.device) initial(%c0) : !hal.semaphore + %semaphore = hal.semaphore.create device(%arg0 : !hal.device) initial(%c0) : !hal.semaphore return 
%semaphore : !hal.semaphore } @@ -14,8 +14,8 @@ func @semaphore_create(%arg0 : !hal.device) -> !hal.semaphore { // CHECK-LABEL: @semaphore_query func @semaphore_query(%arg0 : !hal.semaphore) { - // CHECK: = hal.semaphore.query %arg0 : i32, index - %status, %value = hal.semaphore.query %arg0 : i32, index + // CHECK: = hal.semaphore.query<%arg0 : !hal.semaphore> : i32, index + %status, %value = hal.semaphore.query<%arg0 : !hal.semaphore> : i32, index return } @@ -23,9 +23,9 @@ func @semaphore_query(%arg0 : !hal.semaphore) { // CHECK-LABEL: @semaphore_signal func @semaphore_signal(%arg0 : !hal.semaphore) { - %c0 = std.constant 0 : index - // CHECK: hal.semaphore.signal %arg0, value = %c0 - hal.semaphore.signal %arg0, value = %c0 + %c0 = constant 0 : index + // CHECK: hal.semaphore.signal<%arg0 : !hal.semaphore> value(%c0) + hal.semaphore.signal<%arg0 : !hal.semaphore> value(%c0) return } @@ -33,9 +33,10 @@ func @semaphore_signal(%arg0 : !hal.semaphore) { // CHECK-LABEL: @semaphore_fail func @semaphore_fail(%arg0 : !hal.semaphore) { - %c0 = std.constant 0 : i32 - // CHECK: hal.semaphore.fail %arg0, status = %c0 - hal.semaphore.fail %arg0, status = %c0 + // CHECK: %[[C0:.+]] = constant 0 + %c0 = constant 0 : i32 + // CHECK: hal.semaphore.fail<%arg0 : !hal.semaphore> status(%[[C0]]) + hal.semaphore.fail<%arg0 : !hal.semaphore> status(%c0) return } @@ -43,8 +44,9 @@ func @semaphore_fail(%arg0 : !hal.semaphore) { // CHECK-LABEL: @semaphore_await func @semaphore_await(%arg0 : !hal.semaphore) { - %c0 = std.constant 0 : index - // CHECK: = hal.semaphore.await %arg0, min_value = %c0 : i32 - %0 = hal.semaphore.await %arg0, min_value = %c0 : i32 + // CHECK: %[[C0:.+]] = constant 0 + %c0 = constant 0 : index + // CHECK: = hal.semaphore.await<%arg0 : !hal.semaphore> until(%[[C0]]) : i32 + %0 = hal.semaphore.await<%arg0 : !hal.semaphore> until(%c0) : i32 return } diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp 
b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp index 580ccc6544dc..65c035d4ab7e 100644 --- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.cpp @@ -154,7 +154,7 @@ LogicalResult SPIRVTargetBackend::recordDispatch( int32_t entryPointOrdinal = entryPoint.index(); rewriter.create( loc, commandBuffer, executable, - rewriter.getI32IntegerAttr(entryPointOrdinal), workgroupCount[0], + rewriter.getIndexAttr(entryPointOrdinal), workgroupCount[0], workgroupCount[1], workgroupCount[2]); if (entryPoint.index() + 1 != entryPoints.size()) { recordFullExecutionBarrier(commandBuffer, loc, rewriter); diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp index 1995001e2b88..25a8ece9e794 100644 --- a/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp +++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.cpp @@ -192,7 +192,7 @@ LogicalResult TargetBackend::linkExecutablesInto( auto newEntryPointOp = linkedTargetBuilder.create( entryPointOp.getLoc(), entryPointOp.sym_nameAttr(), - builder.getI32IntegerAttr(nextEntryPointOrdinal++), + builder.getIndexAttr(nextEntryPointOrdinal++), builder.getSymbolRefAttr(linkedInterfaceOp.getName()), entryPointOp.signatureAttr(), ArrayAttr{}); @@ -343,11 +343,15 @@ LogicalResult TargetBackend::recordDispatch( } auto builder = OpBuilder::atBlockBegin(&entryBlock); + auto entryPointSymRef = builder.getSymbolRefAttr( + dispatchState.executableOp.getName(), + {builder.getSymbolRefAttr(dispatchState.entryPointOp->getParentOp()), + builder.getSymbolRefAttr(dispatchState.entryPointOp)}); auto remappedWorkgroupCount = calculateDispatchWorkgroupCount( loc, dispatchState.executableOp, dispatchState.entryPointOp, originalWorkgroupCount, builder); builder.create( - loc, commandBuffer, dispatchState.entryPointOp, remappedWorkgroupCount[0], + loc, commandBuffer, entryPointSymRef, remappedWorkgroupCount[0], 
remappedWorkgroupCount[1], remappedWorkgroupCount[2]); builder.create(loc); diff --git a/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/iree/compiler/Dialect/HAL/Target/TargetBackend.h index ad2e01978156..75e423b5d1d6 100644 --- a/iree/compiler/Dialect/HAL/Target/TargetBackend.h +++ b/iree/compiler/Dialect/HAL/Target/TargetBackend.h @@ -203,28 +203,6 @@ class TargetBackend { // must follow. Note that backend-specific push constants must have been // allocated during `extractInterface`. int basePushConstantOffset = 0; - - // Dispatch operands in a form accessible as hal.buffer/hal.buffer_view. - // Note that any introduced scheduling dependency (such as a write of an - // operand/result prior to the dispatch) must be handled appropriately, such - // as by inserting a `hal.command_buffer.barrier`. - // - // Operands are 1:1 the flow.dispatch operands, meaning that if there were - // operands that were not tensor/buffer types they will be None. - // - // NOTE: some operands/results may alias (as indicated by the interface). - ArrayRef> operands; - - // Dispatch results with allocated buffers. - // Note that any introduced scheduling dependency (such as a write of an - // operand/result prior to the dispatch) must be handled appropriately, such - // as by inserting a `hal.command_buffer.barrier`. - // - // Results are 1:1 the flow.dispatch results, meaning that if there were - // results that were not tensor/buffer types they will be None. - // - // NOTE: some operands/results may alias (as indicated by the interface). - ArrayRef> results; }; // Records a dispatch to a command buffer given the dispatch state. @@ -267,7 +245,7 @@ class TargetBackend { // hal.executable.target @target, filter="target-backend" { // hal.executable.entry_point @main attributes { // interface = @main_io, - // ordinal = 0 : i32, + // ordinal = 0 : index, // signature = (tensor<4xf32>) -> tensor<4xf32> // } // module { ... 
} diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir index 16e8a4583354..b12931045663 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/i1_types.mlir @@ -17,7 +17,7 @@ func @i1_op_usage(%arg0: tensor<4xi1>) -> tensor<4xi1> { // CHECK: hal.executable.target @vmla // CHECK: hal.executable.entry_point @i1_op_usage_ex_dispatch_0 attributes { // CHECK-SAME: interface = @legacy_io -// CHECK-SAME: ordinal = 0 : i32 +// CHECK-SAME: ordinal = 0 : index // CHECK-SAME: signature = (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> flow.executable @i1_op_usage_ex_dispatch_0 attributes {sym_visibility = "private"} { flow.dispatch.entry @i1_op_usage_ex_dispatch_0 diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir index 2a8da68128c5..cc7c4e12378e 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/linking.mlir @@ -8,7 +8,7 @@ module { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vmla, filter="vmla" { - hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} + hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} module { vm.module @module { vm.func @dispatch_0(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { @@ -26,7 +26,7 @@ module { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vmla, filter="vmla" { - hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> 
tensor<1x1xf32>} + hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} module { vm.module @module { vm.func @dispatch_1(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { @@ -45,7 +45,7 @@ module { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vmla, filter="vmla" { - hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} + hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} module { vm.module @module { vm.func @dispatch_2(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32) { @@ -57,12 +57,12 @@ module { } } func @main() -> () { - %dev = hal.ex.shared_device : !hal.device - %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer + %device = hal.ex.shared_device : !hal.device + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer %c1 = constant 1 : index - hal.command_buffer.dispatch.symbol %cmd, @dispatch_0::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch.symbol %cmd, @dispatch_1::@vmla::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch.symbol %cmd, @dispatch_2::@vmla::@dispatch_2, workgroup_xyz = [%c1, %c1, %c1] + hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_0::@vmla::@dispatch_0) workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@dispatch_1::@vmla::@dispatch_1) workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> 
target(@dispatch_2::@vmla::@dispatch_2) workgroups([%c1, %c1, %c1]) return } } @@ -84,9 +84,9 @@ module { // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { -// CHECK-NEXT: hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} -// CHECK-NEXT: hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io_0, ordinal = 1 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} -// CHECK-NEXT: hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io_1, ordinal = 2 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} +// CHECK-NEXT: hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} +// CHECK-NEXT: hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io_0, ordinal = 1 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} +// CHECK-NEXT: hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io_1, ordinal = 2 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} // CHECK-NEXT: module { // CHECK-NEXT: vm.module @linked_module { // CHECK-NEXT: vm.func @dispatch_0(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { @@ -107,9 +107,9 @@ module { // CHECK-NEXT: } // // CHECK: func @main() { -// CHECK: hal.command_buffer.dispatch.symbol %cmd, @vmla_linked_1::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1] -// CHECK-NEXT: hal.command_buffer.dispatch.symbol %cmd, @vmla_linked_1::@vmla::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1] -// CHECK-NEXT: hal.command_buffer.dispatch.symbol %cmd, @vmla_linked_1::@vmla::@dispatch_2, workgroup_xyz = [%c1, %c1, 
%c1] +// CHECK: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_0) workgroups([%c1, %c1, %c1]) +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_1) workgroups([%c1, %c1, %c1]) +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_2) workgroups([%c1, %c1, %c1]) // CHECK-NEXT: return // CHECK-NEXT: } @@ -123,7 +123,7 @@ module { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vmla, filter="vmla" { - hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} + hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} module { vm.module @module { vm.func @dispatch_0(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { @@ -144,7 +144,7 @@ module { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vmla, filter="vmla" { - hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>) -> tensor<1x1xf32>} + hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<1x1xf32>) -> tensor<1x1xf32>} module { vm.module @module { vm.func @dispatch_1(%arg0: !vm.ref, %arg1: i32, %arg2: i32) { @@ -160,19 +160,19 @@ module { } } func @main() -> () { - %dev = hal.ex.shared_device : !hal.device - %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer - hal.device.switch(%dev : !hal.device) + %device = hal.ex.shared_device : !hal.device + %cmd = hal.command_buffer.create device(%device : !hal.device) 
mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer + hal.device.switch<%device : !hal.device> #hal.device.match.id<"vmla">(%arg1 = %cmd : !hal.command_buffer) { %c1 = constant 1 : index - hal.command_buffer.dispatch.symbol %arg1, @dispatch_0::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch.symbol %arg1, @dispatch_1::@vmla::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1] + hal.command_buffer.dispatch.symbol<%arg1 : !hal.command_buffer> target(@dispatch_0::@vmla::@dispatch_0) workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch.symbol<%arg1 : !hal.command_buffer> target(@dispatch_1::@vmla::@dispatch_1) workgroups([%c1, %c1, %c1]) hal.return }, #hal.device.match.id<"othertarget">(%arg1 = %cmd : !hal.command_buffer) { %c1 = constant 1 : index - hal.command_buffer.dispatch.symbol %arg1, @dispatch_0::@otherdispatch::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch.symbol %arg1, @dispatch_1::@otherdispatch::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1] + hal.command_buffer.dispatch.symbol<%arg1 : !hal.command_buffer> target(@dispatch_0::@otherdispatch::@dispatch_0) workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch.symbol<%arg1 : !hal.command_buffer> target(@dispatch_1::@otherdispatch::@dispatch_1) workgroups([%c1, %c1, %c1]) hal.return } return @@ -191,8 +191,8 @@ module { // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { -// CHECK-NEXT: hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} -// CHECK-NEXT: hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io_1, ordinal = 1 : i32, signature = (tensor<1x1xf32>) -> tensor<1x1xf32>} +// CHECK-NEXT: hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io_0, 
ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} +// CHECK-NEXT: hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io_1, ordinal = 1 : index, signature = (tensor<1x1xf32>) -> tensor<1x1xf32>} // CHECK-NEXT: module { // CHECK-NEXT: vm.module @linked_module { // CHECK-NEXT: vm.func @dispatch_0(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { @@ -217,17 +217,17 @@ module { // CHECK: hal.executable.target @othertarget, filter="othertarget" // // CHECK: func @main() { -// CHECK: hal.device.switch(%dev : !hal.device) +// CHECK: hal.device.switch<%device : !hal.device> // CHECK-NEXT: #hal.device.match.id<"vmla">(%arg0 = %cmd : !hal.command_buffer) { // CHECK-NEXT: %c1 = constant 1 : index -// CHECK-NEXT: hal.command_buffer.dispatch.symbol %arg0, @vmla_linked_1::@vmla::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1] -// CHECK-NEXT: hal.command_buffer.dispatch.symbol %arg0, @vmla_linked_1::@vmla::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1] +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%arg0 : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_0) workgroups([%c1, %c1, %c1]) +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%arg0 : !hal.command_buffer> target(@vmla_linked_1::@vmla::@dispatch_1) workgroups([%c1, %c1, %c1]) // CHECK-NEXT: hal.return // CHECK-NEXT: }, // CHECK-NEXT: #hal.device.match.id<"othertarget">(%arg0 = %cmd : !hal.command_buffer) { // CHECK-NEXT: %c1 = constant 1 : index -// CHECK-NEXT: hal.command_buffer.dispatch.symbol %arg0, @dispatch_0::@otherdispatch::@dispatch_0, workgroup_xyz = [%c1, %c1, %c1] -// CHECK-NEXT: hal.command_buffer.dispatch.symbol %arg0, @dispatch_1::@otherdispatch::@dispatch_1, workgroup_xyz = [%c1, %c1, %c1] +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%arg0 : !hal.command_buffer> target(@dispatch_0::@otherdispatch::@dispatch_0) workgroups([%c1, %c1, %c1]) +// CHECK-NEXT: hal.command_buffer.dispatch.symbol<%arg0 : !hal.command_buffer> 
target(@dispatch_1::@otherdispatch::@dispatch_1) workgroups([%c1, %c1, %c1]) // CHECK-NEXT: hal.return // CHECK-NEXT: } // CHECK-NEXT: return @@ -243,33 +243,33 @@ module { hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vmla, filter="vmla" { - hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} + hal.executable.entry_point @dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} module { vm.module @module {} } } } hal.executable @dispatch_1 attributes {sym_visibility = "private"} { - hal.interface @legacy_io attributes {push_constants = 2 : i32} { + hal.interface @legacy_io attributes {push_constants = 2 : index} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" } hal.executable.target @vmla, filter="vmla" { - hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} + hal.executable.entry_point @dispatch_1 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} module { vm.module @module {} } } } hal.executable @dispatch_2 attributes {sym_visibility = "private"} { - hal.interface @legacy_io attributes {push_constants = 2 : i32} { + hal.interface @legacy_io attributes {push_constants = 2 : index} { hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", 
access="Write|Discard" } hal.executable.target @vmla, filter="vmla" { - hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} + hal.executable.entry_point @dispatch_2 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>} module { vm.module @module {} } @@ -287,7 +287,7 @@ module { // CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } -// CHECK-NEXT: hal.interface @legacy_io_1 attributes {push_constants = 2 : i32} { +// CHECK-NEXT: hal.interface @legacy_io_1 attributes {push_constants = 2 : index} { // CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" // CHECK-NEXT: hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read" // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard" diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir index a1211b5e5cec..a51425284052 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir @@ -18,7 +18,7 @@ flow.executable @simpleMath_ex_dispatch_0 { // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { -// CHECK-NEXT: hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4xf32>) -> tensor<4xf32>} +// CHECK-NEXT: hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index, 
signature = (tensor<4xf32>) -> tensor<4xf32>} // CHECK-NEXT: module { // CHECK-NEXT: vm.module @module { // CHECK-NEXT: vm.func @simpleMath_rgn_dispatch_0(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { @@ -56,12 +56,12 @@ flow.executable @shaped_dispatch { } // CHECK-LABEL: hal.executable @shaped_dispatch -// CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 1 : i32} { +// CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 1 : index} { // CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { -// CHECK-NEXT: hal.executable.entry_point @entry attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x?xf32>, index) -> tensor<4x?xf32>} +// CHECK-NEXT: hal.executable.entry_point @entry attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<4x?xf32>, index) -> tensor<4x?xf32>} // CHECK-NEXT: module { // CHECK-NEXT: vm.module @module { // CHECK-NEXT: vm.func @entry(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { @@ -103,7 +103,7 @@ flow.executable @reduction_ex_dispatch_0 { // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { -// CHECK-NEXT: hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>} +// CHECK-NEXT: hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : index, signature = (tensor<4x8xf32>) -> tensor<4xf32>} // CHECK-NEXT: module { // CHECK-NEXT: vm.module @module { // CHECK-NEXT: vm.rodata @reduction_ex_dispatch_0_const dense<0.000000e+00> : tensor<1xf32> diff --git 
a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/smoketest.mlir index 74d0c4a466c3..d4f4566dfe41 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/smoketest.mlir +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/test/smoketest.mlir @@ -2,7 +2,7 @@ flow.executable @simpleMath_ex_dispatch_0 { flow.dispatch.entry @simpleMath_rgn_dispatch_0 attributes { - workload = 4 : index + workload = 4 : index } module { func @simpleMath_rgn_dispatch_0(%arg0: tensor<4xf32>) -> tensor<4xf32> { diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp index f654edf9ba31..999ab05b3c36 100644 --- a/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeConstantPoolBuffers.cpp @@ -86,9 +86,9 @@ class MaterializeConstantPoolBuffersPass auto *context = poolOp.getContext(); auto variableName = (poolOp.getName() + storageOp.getName() + "_buffer").str(); + auto variableType = IREE::HAL::BufferType::get(context); auto variableOp = OpBuilder(context).create( - storageOp.getLoc(), variableName, /*isMutable=*/false, - IREE::HAL::BufferType::get(context)); + storageOp.getLoc(), variableName, /*isMutable=*/false, variableType); moduleSymbolTable.insert(variableOp, insertionPoint); variableOp.setPrivate(); @@ -136,10 +136,6 @@ class MaterializeConstantPoolBuffersPass // Today we always map the buffer directly. We should be using a device // switch to schedule the upload if needed. // TODO(benvanik): allocate based on usage tracking. 
- auto memoryTypes = IREE::HAL::MemoryTypeBitfield::HostLocal | - IREE::HAL::MemoryTypeBitfield::DeviceVisible; - auto bufferUsage = IREE::HAL::BufferUsageBitfield::Constant | - IREE::HAL::BufferUsageBitfield::All; auto sourceValue = funcBuilder.createOrFold( storageOp.getLoc(), IREE::ByteBufferType::get(context), @@ -153,9 +149,13 @@ class MaterializeConstantPoolBuffersPass bufferConstraints.min_buffer_range_alignment()); auto lengthValue = funcBuilder.createOrFold( storageOp.getLoc(), runtimeLength); + auto memoryType = IREE::HAL::MemoryTypeBitfield::DeviceLocal | + IREE::HAL::MemoryTypeBitfield::HostVisible; + auto bufferUsage = IREE::HAL::BufferUsageBitfield::Constant | + IREE::HAL::BufferUsageBitfield::All; auto bufferValue = funcBuilder.createOrFold( - storageOp.getLoc(), allocatorValue, memoryTypes, bufferUsage, - sourceValue, offsetValue, lengthValue); + storageOp.getLoc(), IREE::HAL::BufferType::get(context), allocatorValue, + memoryType, bufferUsage, sourceValue, offsetValue, lengthValue); funcBuilder.create(storageOp.getLoc(), bufferValue); return initializerFunc; @@ -168,6 +168,14 @@ class MaterializeConstantPoolBuffersPass SymbolTable &moduleSymbolTable, Block::iterator insertionPoint) { auto *context = poolOp.getContext(); + + // TODO(benvanik): we don't need host-visible here as we could require that + // all reads go through staging. When we want to support devices with + // unmappable memory we'll need to adjust this. Usage analysis on whether + // the buffer is ever read back or only used on device will help determine + // things. 
+ auto variableType = IREE::HAL::BufferType::get(context); + auto variableLoc = FusedLoc::get(context, llvm::to_vector<8>(llvm::map_range( splatOps, [](ConstantPoolSplatOp splatOp) { @@ -175,8 +183,7 @@ class MaterializeConstantPoolBuffersPass }))); auto variableName = (poolOp.getName() + "_splats").str(); auto variableOp = OpBuilder(context).create( - variableLoc, variableName, /*isMutable=*/false, - IREE::HAL::BufferType::get(context)); + variableLoc, variableName, /*isMutable=*/false, variableType); moduleSymbolTable.insert(variableOp, insertionPoint); variableOp.setPrivate(); @@ -241,19 +248,26 @@ class MaterializeConstantPoolBuffersPass deviceValue); // Allocate buffer with empty contents. - // TODO(benvanik): allocate based on usage tracking. - auto memoryTypes = IREE::HAL::MemoryTypeBitfield::DeviceLocal | - IREE::HAL::MemoryTypeBitfield::HostVisible; + auto memoryType = IREE::HAL::MemoryTypeBitfield::DeviceLocal | + IREE::HAL::MemoryTypeBitfield::HostVisible; auto bufferUsage = IREE::HAL::BufferUsageBitfield::Constant | IREE::HAL::BufferUsageBitfield::All; auto allocationSizeValue = funcBuilder.createOrFold( variableLoc, bufferLength); auto bufferValue = funcBuilder.createOrFold( - variableLoc, allocatorValue, memoryTypes, bufferUsage, - allocationSizeValue); + variableLoc, IREE::HAL::BufferType::get(context), allocatorValue, + memoryType, bufferUsage, allocationSizeValue); - // Fill the buffer (memset). - // TODO(benvanik): do this via a command buffer/DMA to keep host moving. + // Fill the buffers (memset). + // We do this with a command buffer so that we can allow the device to + // fill them in asynchronously and without memory mapping. 
+ auto commandBufferValue = + funcBuilder.createOrFold( + variableLoc, IREE::HAL::CommandBufferType::get(context), + deviceValue, IREE::HAL::CommandBufferModeBitfield::OneShot, + IREE::HAL::CommandCategoryBitfield::Transfer); + funcBuilder.create(variableLoc, + commandBufferValue); for (auto splatOp : splatOps) { auto offsetValue = funcBuilder.createOrFold( splatOp.getLoc(), splatOp.runtime_rangeAttr().offsetAttr()); @@ -263,10 +277,14 @@ class MaterializeConstantPoolBuffersPass splatOp.value().cast().getSplatValue()); auto patternValue = funcBuilder.createOrFold( variableLoc, static_cast(pattern), 32); - funcBuilder.create(splatOp.getLoc(), bufferValue, - offsetValue, lengthValue, - patternValue); + funcBuilder.create( + splatOp.getLoc(), commandBufferValue, bufferValue, offsetValue, + lengthValue, patternValue); } + funcBuilder.create(variableLoc, + commandBufferValue); + funcBuilder.create(variableLoc, deviceValue, + commandBufferValue); funcBuilder.create(variableLoc, bufferValue); diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index 75f20a079366..e519377b13e7 100644 --- a/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -78,7 +78,8 @@ static llvm::Optional declareInterfaceIO( int bindingOrdinal = nextBindingOrdinal++; auto bindingName = "arg" + std::to_string(inputType.index()); interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/bindingOrdinal, + interfaceLoc, bindingName, /*set=*/APInt(64, 0), + /*binding=*/APInt(64, bindingOrdinal), IREE::HAL::DescriptorType::StorageBuffer, IREE::HAL::MemoryAccessBitfield::Read); } else if (auto tensorType = @@ -106,7 +107,8 @@ static llvm::Optional declareInterfaceIO( std::string bindingName = std::string(prefix) + std::to_string(bindingOrdinal); interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, 
/*binding=*/bindingOrdinal, + interfaceLoc, bindingName, /*set=*/APInt(64, 0), + /*binding=*/APInt(64, bindingOrdinal), IREE::HAL::DescriptorType::StorageBuffer, memoryAccess); } else if (auto indexType = inputType.value().dyn_cast()) { ++pushConstantCount; @@ -132,7 +134,8 @@ static llvm::Optional declareInterfaceIO( auto bindingName = "ret" + std::to_string(outputType.index()); if (outputType.value().isa()) { interfaceBuilder.create( - interfaceLoc, bindingName, /*set=*/0, /*binding=*/bindingOrdinal, + interfaceLoc, bindingName, /*set=*/APInt(64, 0), + /*binding=*/APInt(64, bindingOrdinal), IREE::HAL::DescriptorType::StorageBuffer, IREE::HAL::MemoryAccessBitfield::DiscardWrite); } else { @@ -145,7 +148,7 @@ static llvm::Optional declareInterfaceIO( if (pushConstantCount > 0) { interfaceOp->setAttr("push_constants", - interfaceBuilder.getI32IntegerAttr(pushConstantCount)); + interfaceBuilder.getIndexAttr(pushConstantCount)); } return interfaceOp; @@ -398,7 +401,7 @@ static LogicalResult declareEntryPointOps( builder.create( dispatchEntryOp.getLoc(), builder.getStringAttr(dispatchEntryOp.function_ref()), - builder.getI32IntegerAttr(nextOrdinal++), + builder.getIndexAttr(nextOrdinal++), builder.getSymbolRefAttr(interfaceOp), TypeAttr::get(sourceFuncOp.getType()), ArrayAttr{}); } diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index c45754fc6c4b..66043eee1ce7 100644 --- a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp @@ -125,8 +125,7 @@ class MaterializeResourceCachesPass IntegerAttr pushConstantsAttr) { // Push constants are optional but we always provide the value. 
if (!pushConstantsAttr) { - pushConstantsAttr = - IntegerAttr::get(IntegerType::get(loc.getContext(), 32), 0); + pushConstantsAttr = IntegerAttr::get(IndexType::get(loc.getContext()), 0); } // We key the layout cache on all attributes that compose an executable diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.h b/iree/compiler/Dialect/HAL/Transforms/Passes.h index 99c10f314995..5b28fc05928f 100644 --- a/iree/compiler/Dialect/HAL/Transforms/Passes.h +++ b/iree/compiler/Dialect/HAL/Transforms/Passes.h @@ -52,7 +52,7 @@ void registerHALTransformPassPipeline(); // Conversion //===----------------------------------------------------------------------===// -// Convert input flow/std/etc dialects to the IREE HAL dialect. +// Converts input flow/std/etc dialects to the IREE HAL dialect. std::unique_ptr> createConvertToHALPass(); //===----------------------------------------------------------------------===// diff --git a/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp b/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp index 4f582c4dbc12..b9c68080e9a4 100644 --- a/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp @@ -35,28 +35,28 @@ namespace { using mlir::iree_compiler::IREE::SIP::RawSignatureParser; using mlir::iree_compiler::IREE::SIP::AbiConstants::ScalarType; -Type mapScalarType(MLIRContext *ctx, ScalarType scalarType) { +Type mapScalarType(MLIRContext *context, ScalarType scalarType) { switch (scalarType) { case ScalarType::kIeeeFloat32: - return FloatType::getF32(ctx); + return FloatType::getF32(context); case ScalarType::kIeeeFloat64: - return FloatType::getF64(ctx); + return FloatType::getF64(context); case ScalarType::kIeeeFloat16: - return FloatType::getF16(ctx); + return FloatType::getF16(context); case ScalarType::kGoogleBfloat16: - return FloatType::getBF16(ctx); + return FloatType::getBF16(context); case ScalarType::kSint32: case ScalarType::kUint32: 
- return IntegerType::get(ctx, 32); + return IntegerType::get(context, 32); case ScalarType::kSint64: case ScalarType::kUint64: - return IntegerType::get(ctx, 64); + return IntegerType::get(context, 64); case ScalarType::kSint16: case ScalarType::kUint16: - return IntegerType::get(ctx, 16); + return IntegerType::get(context, 16); case ScalarType::kSint8: case ScalarType::kUint8: - return IntegerType::get(ctx, 8); + return IntegerType::get(context, 8); default: return nullptr; } @@ -65,7 +65,7 @@ Type mapScalarType(MLIRContext *ctx, ScalarType scalarType) { LogicalResult mapRawAbiTypes( Location loc, SmallVectorImpl &descs, SmallVectorImpl &types) { - auto *ctx = loc.getContext(); + auto *context = loc.getContext(); auto bufferViewType = HAL::BufferViewType::get(loc.getContext()); for (auto &d : descs) { switch (d.type) { @@ -80,7 +80,7 @@ LogicalResult mapRawAbiTypes( return emitError(loc) << "unsupported ABI type: " << dstr; } case RawSignatureParser::Type::kScalar: { - auto t = mapScalarType(ctx, d.scalar.type); + auto t = mapScalarType(context, d.scalar.type); if (!t) { std::string dstr; d.ToString(dstr); @@ -101,7 +101,7 @@ LogicalResult generateAsynchronousBody( SmallVectorImpl &inputDescs, SmallVectorImpl &resultTypes, SmallVectorImpl &resultDescs) { - auto *ctx = funcOp.getContext(); + auto *context = funcOp.getContext(); auto loc = funcOp.getLoc(); Block *entryBlock = funcOp.addEntryBlock(); OpBuilder builder = OpBuilder::atBlockEnd(entryBlock); @@ -124,8 +124,8 @@ LogicalResult generateAsynchronousBody( case RawSignatureParser::Type::kBuffer: { // Pass the backing buffer. // TODO(laurenzo): Validate shape. - callOperands.push_back( - builder.create(loc, blockArg)); + callOperands.push_back(builder.create( + loc, IREE::HAL::BufferType::get(context), blockArg)); // Now, each dynamic dim is passed individually. for (auto dim : llvm::enumerate(input.value().dims)) { @@ -139,7 +139,7 @@ LogicalResult generateAsynchronousBody( // at a time as needed. 
auto dimValue = builder.create( loc, builder.getIndexType(), blockArg, - builder.getI32IntegerAttr(dim.index())); + builder.getIndexAttr(dim.index())); callOperands.push_back(dimValue); } break; @@ -195,7 +195,8 @@ LogicalResult generateAsynchronousBody( } // Determine element type. - Type mappedScalarType = mapScalarType(ctx, output.value().scalar.type); + Type mappedScalarType = + mapScalarType(context, output.value().scalar.type); auto elementType = getElementTypeValue(mappedScalarType); if (!elementType) { return emitError(loc) @@ -276,7 +277,7 @@ LogicalResult generateRawAbiFunctions(OpBuilder &moduleBuilder, StringRef exportName, DictionaryAttr reflection, StringRef signatureSr) { - auto ctx = rawCalleeFuncOp.getContext(); + auto context = rawCalleeFuncOp.getContext(); auto loc = rawCalleeFuncOp.getLoc(); StringRef signature(signatureSr.data(), signatureSr.size()); @@ -314,13 +315,13 @@ LogicalResult generateRawAbiFunctions(OpBuilder &moduleBuilder, // Prefix with wait semaphore and its value. // TODO(scotttodd): SemaphoreValue wrapper for single {semaphore, value} // TODO(scotttodd): SemaphoreList wrapper for list of SemaphoreValues - asyncInputTypes.push_back(HAL::SemaphoreType::get(ctx)); + asyncInputTypes.push_back(HAL::SemaphoreType::get(context)); asyncInputTypes.push_back(moduleBuilder.getIndexType()); for (const auto &inputType : inputTypes) { asyncInputTypes.push_back(inputType); } // Postfix with signal semaphore and its value. 
- asyncInputTypes.push_back(HAL::SemaphoreType::get(ctx)); + asyncInputTypes.push_back(HAL::SemaphoreType::get(context)); asyncInputTypes.push_back(moduleBuilder.getIndexType()); // TODO(scotttodd): populate async export attributes @@ -329,9 +330,9 @@ LogicalResult generateRawAbiFunctions(OpBuilder &moduleBuilder, SmallVector asyncExportAttrs; asyncExportAttrs.push_back(moduleBuilder.getNamedAttr( "iree.module.export", - StringAttr::get(ctx, (exportName + "$async").str()))); + StringAttr::get(context, (exportName + "$async").str()))); - auto asyncType = FunctionType::get(ctx, asyncInputTypes, resultTypes); + auto asyncType = FunctionType::get(context, asyncInputTypes, resultTypes); auto asyncName = (rawCalleeFuncOp.getName() + "$async").str(); auto asyncFuncOp = moduleBuilder.create(loc, asyncName, asyncType, asyncExportAttrs); @@ -349,9 +350,9 @@ LogicalResult generateRawAbiFunctions(OpBuilder &moduleBuilder, syncExportAttrs.push_back( moduleBuilder.getNamedAttr("iree.reflection", reflection)); syncExportAttrs.push_back( - moduleBuilder.getNamedAttr("iree.abi.stub", UnitAttr::get(ctx))); + moduleBuilder.getNamedAttr("iree.abi.stub", UnitAttr::get(context))); - auto syncType = FunctionType::get(ctx, inputTypes, resultTypes); + auto syncType = FunctionType::get(context, inputTypes, resultTypes); auto syncName = (rawCalleeFuncOp.getName() + "$sync").str(); auto syncFuncOp = moduleBuilder.create(loc, syncName, syncType, syncExportAttrs); @@ -397,14 +398,14 @@ class PublicABIGenerationPass : public PassWrapper> { public: void runOnOperation() override { - auto *ctx = &getContext(); + auto *context = &getContext(); for (auto &op : getOperation().getBody()->getOperations()) { if (auto funcOp = dyn_cast(op)) { // Skip functions we generate. if (funcOp->getAttr("iree.abi.stub")) continue; - // Any function marked for export, we redirect to export with a - // '$raw' suffix and then generate ABI wrappers with the original name. 
+ // Any function marked for export we make private and expose via + // generated ABI wrappers with the original name. Optional exportName = getFuncOpExportName(funcOp); if (!exportName) continue; auto reflection = funcOp->getAttr("iree.reflection") @@ -412,10 +413,9 @@ class PublicABIGenerationPass if (!reflection) continue; // Rename and remove reflection (it will go on the ABI entry point). - funcOp->setAttr("iree.module.export", - StringAttr::get(ctx, (*exportName + "$raw").str())); + funcOp->removeAttr("iree.module.export"); funcOp->removeAttr("iree.reflection"); - funcOp->setAttr("noinline", UnitAttr::get(ctx)); + funcOp->setAttr("noinline", UnitAttr::get(context)); if (reflection) { if (failed(generateAbiFunctions(funcOp, *exportName, reflection))) { diff --git a/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir b/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir index 45b479e11b36..e13e70a7bb7d 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/benchmark_batch_dispatches.mlir @@ -1,24 +1,24 @@ // RUN: iree-opt -split-input-file -test-iree-hal-benchmark-batch-dispatches-2-times %s | IreeFileCheck %s -hal.variable @_executable_0 : !hal.executable -func @multiple_reads_no_writes() { - %0 = hal.variable.load @_executable_0 : !hal.executable - %1 = hal.variable.load @_executable_0 : !hal.executable - %2 = hal.variable.load @_executable_0 : !hal.executable +hal.variable @_executable : !hal.executable + +// CHECK-LABEL: @multiple_reads_no_writes +// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer +func @multiple_reads_no_writes(%cmd : !hal.command_buffer) { + // CHECK: %[[EXE:.+]] = hal.variable.load @_executable + %exe = hal.variable.load @_executable : !hal.executable %c1 = constant 1 : index - %dev = hal.ex.shared_device : !hal.device - %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer - 
hal.command_buffer.begin %cmd - hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.end %cmd + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[0] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[1] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[2] workgroups([%c1, %c1, %c1]) + return } -// CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] -// CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] -// CHECK: hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] -// CHECK: hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] -// CHECK: hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] -// CHECK: hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] + +// CHECK: hal.command_buffer.dispatch<%[[CMD:.+]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[0] workgroups([%c1, %c1, %c1]) +// CHECK: hal.command_buffer.dispatch<%[[CMD:.+]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[0] workgroups([%c1, %c1, %c1]) +// CHECK: hal.command_buffer.dispatch<%[[CMD:.+]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[1] workgroups([%c1, %c1, %c1]) +// CHECK: hal.command_buffer.dispatch<%[[CMD:.+]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[1] workgroups([%c1, %c1, %c1]) +// CHECK: hal.command_buffer.dispatch<%[[CMD:.+]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[2] workgroups([%c1, %c1, %c1]) +// CHECK: 
hal.command_buffer.dispatch<%[[CMD:.+]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[2] workgroups([%c1, %c1, %c1]) diff --git a/iree/compiler/Dialect/HAL/Transforms/test/cse_variable_loads.mlir b/iree/compiler/Dialect/HAL/Transforms/test/cse_variable_loads.mlir index 71d1d0b6e619..c6b49911a475 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/cse_variable_loads.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/cse_variable_loads.mlir @@ -1,27 +1,26 @@ // RUN: iree-opt -split-input-file -iree-hal-cse-variable-loads %s | IreeFileCheck %s -// CHECK: hal.variable @_executable_0 : !hal.executable -hal.variable @_executable_0 : !hal.executable +// CHECK: hal.variable @_executable : !hal.executable +hal.variable @_executable : !hal.executable + // CHECK-LABEL: @multiple_reads_no_writes -func @multiple_reads_no_writes() { - // CHECK-NEXT: %0 = hal.variable.load @_executable_0 : !hal.executable - %0 = hal.variable.load @_executable_0 : !hal.executable +// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer +func @multiple_reads_no_writes(%cmd : !hal.command_buffer) { + // CHECK-NEXT: %[[EXE:.+]] = hal.variable.load @_executable : !hal.executable + %exe0 = hal.variable.load @_executable : !hal.executable // CHECK-NOT: hal.variable.load - %1 = hal.variable.load @_executable_0 : !hal.executable + %exe1 = hal.variable.load @_executable : !hal.executable // CHECK-NOT: hal.variable.load - %2 = hal.variable.load @_executable_0 : !hal.executable + %exe2 = hal.variable.load @_executable : !hal.executable %c1 = constant 1 : index - %dev = hal.ex.shared_device : !hal.device - %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer - hal.command_buffer.begin %cmd - // CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - // CHECK-NEXT: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = 
[%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - // CHECK-NEXT: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.end %cmd + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[0] + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe0 : !hal.executable)[0] workgroups([%c1, %c1, %c1]) + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[1] + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe1 : !hal.executable)[1] workgroups([%c1, %c1, %c1]) + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[2] + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe2 : !hal.executable)[2] workgroups([%c1, %c1, %c1]) + return } @@ -31,28 +30,21 @@ func @multiple_reads_no_writes() { hal.variable @_executable_0 : !hal.executable // CHECK: hal.variable @_executable_1 : !hal.executable hal.variable @_executable_1 : !hal.executable -// CHECK: hal.variable @_executable_2 : !hal.executable -hal.variable @_executable_2 : !hal.executable + // CHECK-LABEL: @different_variables_are_not_eliminated -func @different_variables_are_not_eliminated() { - // CHECK-NEXT: %0 = hal.variable.load @_executable_0 : !hal.executable - %0 = hal.variable.load @_executable_0 : !hal.executable - // CHECK-NEXT: %1 = hal.variable.load @_executable_1 : !hal.executable - %1 = hal.variable.load @_executable_1 : !hal.executable - // CHECK-NEXT: %2 = hal.variable.load @_executable_2 : !hal.executable - %2 = hal.variable.load @_executable_2 : !hal.executable +// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer +func @different_variables_are_not_eliminated(%cmd : !hal.command_buffer) { + // CHECK-NEXT: %[[EXE0:.+]] = 
hal.variable.load @_executable_0 : !hal.executable + %exe0 = hal.variable.load @_executable_0 : !hal.executable + // CHECK-NEXT: %[[EXE1:.+]] = hal.variable.load @_executable_1 : !hal.executable + %exe1 = hal.variable.load @_executable_1 : !hal.executable %c1 = constant 1 : index - %dev = hal.ex.shared_device : !hal.device - %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer - hal.command_buffer.begin %cmd - // CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - // CHECK-NEXT: hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - // CHECK-NEXT: hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %2, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.end %cmd + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXE0]] : !hal.executable)[0] + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe0 : !hal.executable)[0] workgroups([%c1, %c1, %c1]) + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXE1]] : !hal.executable)[1] + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe1 : !hal.executable)[1] workgroups([%c1, %c1, %c1]) + return } @@ -60,60 +52,57 @@ func @different_variables_are_not_eliminated() { hal.executable @exe {} -// CHECK: hal.variable @_executable_0 mutable : !hal.executable -hal.variable @_executable_0 mutable : !hal.executable +// CHECK: hal.variable @_executable mutable : !hal.executable +hal.variable @_executable mutable : !hal.executable + // CHECK-LABEL: @writes_prevent_cse -func @writes_prevent_cse() { - // CHECK-NEXT: %0 = hal.variable.load @_executable_0 : !hal.executable 
- %0 = hal.variable.load @_executable_0 : !hal.executable - // CHECK-NEXT: %1 = hal.variable.load @_executable_0 : !hal.executable - %1 = hal.variable.load @_executable_0 : !hal.executable +// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer +func @writes_prevent_cse(%cmd : !hal.command_buffer) { + // CHECK-NEXT: %[[EXE0:.+]] = hal.variable.load @_executable : !hal.executable + %exe0 = hal.variable.load @_executable : !hal.executable + // CHECK-NEXT: %[[EXE1:.+]] = hal.variable.load @_executable : !hal.executable + %exe1 = hal.variable.load @_executable : !hal.executable %dev = hal.ex.shared_device : !hal.device - %exe = hal.executable.lookup %dev, @exe : !hal.executable - // CHECK: hal.variable.store %exe, @_executable_0 : !hal.executable - hal.variable.store %exe, @_executable_0 : !hal.executable + %exe = hal.executable.lookup device(%dev : !hal.device) executable(@exe) : !hal.executable + // CHECK: hal.variable.store %exe, @_executable : !hal.executable + hal.variable.store %exe, @_executable : !hal.executable %c1 = constant 1 : index - %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer - hal.command_buffer.begin %cmd - // CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - // CHECK-NEXT: hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.end %cmd + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXE0]] : !hal.executable)[0] + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe0 : !hal.executable)[0] workgroups([%c1, %c1, %c1]) + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXE1]] : !hal.executable)[1] + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> 
target(%exe1 : !hal.executable)[1] workgroups([%c1, %c1, %c1]) + return } // ----- -// CHECK: hal.variable @_executable_0 : !hal.executable -hal.variable @_executable_0 : !hal.executable +// CHECK: hal.variable @_executable : !hal.executable +hal.variable @_executable : !hal.executable + // CHECK-LABEL: @reads_in_blocks -func @reads_in_blocks() { +// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer +func @reads_in_blocks(%cmd : !hal.command_buffer) { // load should be hoisted to the entry block - // CHECK-NEXT: %0 = hal.variable.load @_executable_0 : !hal.executable + // CHECK-NEXT: %[[EXE:.+]] = hal.variable.load @_executable : !hal.executable %c1 = constant 1 : index - %dev = hal.ex.shared_device : !hal.device - %cmd = hal.command_buffer.create %dev, "OneShot", "Transfer|Dispatch" : !hal.command_buffer - hal.command_buffer.begin %cmd - %i1 = constant 1 : i1 cond_br %i1, ^bb1, ^bb2 ^bb1: // CHECK-NOT: hal.variable.load - %0 = hal.variable.load @_executable_0 : !hal.executable - // CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] + %exe0 = hal.variable.load @_executable : !hal.executable + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[0] + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe0 : !hal.executable)[0] workgroups([%c1, %c1, %c1]) br ^bb3 ^bb2: // CHECK-NOT: hal.variable.load - %1 = hal.variable.load @_executable_0 : !hal.executable - // CHECK: hal.command_buffer.dispatch %cmd, %0, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] - hal.command_buffer.dispatch %cmd, %1, entry_point = 0, workgroup_xyz = [%c1, %c1, %c1] + %exe1 = hal.variable.load @_executable : !hal.executable + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> target(%[[EXE]] : !hal.executable)[1] + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe1 : 
!hal.executable)[1] workgroups([%c1, %c1, %c1]) br ^bb3 ^bb3: - hal.command_buffer.end %cmd return } diff --git a/iree/compiler/Dialect/HAL/Transforms/test/identify_constant_pools.mlir b/iree/compiler/Dialect/HAL/Transforms/test/identify_constant_pools.mlir index 56a34744d0a8..31453bb7cfb8 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/identify_constant_pools.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/identify_constant_pools.mlir @@ -2,11 +2,11 @@ // CHECK: hal.constant_pool @_const_pool attributes // CHECK-SAME: buffer_constraints = #hal.buffer_constraints -// CHECK-NEXT: hal.constant_pool.value @cst0 {{.+}} = dense<1.000000e+00> : tensor<1xf32> +// CHECK-NEXT: hal.constant_pool.value @cst0 = dense<1.000000e+00> : tensor<1xf32> flow.variable @cst0 dense<1.000000e+00> : tensor<1xf32> -// CHECK-NEXT: hal.constant_pool.value @cst1 {{.+}} = dense<[2.100000e+00, 3.200000e+00, 4.300000e+00, 5.400000e+00]> : tensor<4xf32> +// CHECK-NEXT: hal.constant_pool.value @cst1 = dense<[2.100000e+00, 3.200000e+00, 4.300000e+00, 5.400000e+00]> : tensor<4xf32> flow.variable @cst1 dense<[2.1, 3.2, 4.3, 5.4]> : tensor<4xf32> -// CHECK-NEXT: hal.constant_pool.value @cst2 {{.+}} = dense<[6, 7, 8]> : tensor<3xi8> +// CHECK-NEXT: hal.constant_pool.value @cst2 = dense<[6, 7, 8]> : tensor<3xi8> flow.variable @cst2 dense<[6, 7, 8]> : tensor<3xi8> // CHECK-LABEL: func @immutable_variables @@ -23,7 +23,7 @@ func @immutable_variables() -> (tensor<1xf32>, tensor<4xf32>, tensor<3xi8>) { // ----- // CHECK: hal.constant_pool @_const_pool_init -// CHECK-NEXT: hal.constant_pool.value @variable_0 {{.+}} = dense<3.000000e+00> : tensor<128xf32> +// CHECK-NEXT: hal.constant_pool.value @variable_0 = dense<3.000000e+00> : tensor<128xf32> // CHECK: flow.variable @variable_0 mutable init(@variable_0_initializer) flow.variable @variable_0 mutable dense<3.0> : tensor<128xf32> diff --git a/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir 
b/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir index 6f713dcf0eb1..c41a37f415af 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/inline_device_switches.mlir @@ -12,15 +12,15 @@ func @simple_constants(%device : !hal.device, %arg : i32) -> i32 { %c2 = constant 2 : i32 // CHECK-DAG: %[[C3:.+]] = constant 3 // CHECK-DAG: %[[C4:.+]] = constant 4 - %0 = hal.device.switch(%device : !hal.device) -> i32 - // CHECK-NEXT: %[[IS0:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1 + %0 = hal.device.switch<%device : !hal.device> -> i32 + // CHECK-NEXT: %[[IS0:.+]] = hal.device.match.id<%[[DEVICE]] : !hal.device> pattern("vulkan-v1.?-*") : i1 // CHECK-NEXT: cond_br %[[IS0]], ^bb3(%[[C1]] : i32), ^bb1 #hal.device.match.id<"vulkan-v1.?-*">(%c1a = %c1 : i32) { hal.return %c1a : i32 }, // CHECK-NEXT: ^bb1: - // CHECK-NEXT: %[[IS1L:.+]] = hal.device.match.id %arg0, pattern = ["vmla"] : (!hal.device) -> i1 - // CHECK-NEXT: %[[IS1R:.+]] = hal.device.match.id %arg0, pattern = ["vulkan-*"] : (!hal.device) -> i1 + // CHECK-NEXT: %[[IS1L:.+]] = hal.device.match.id<%[[DEVICE]] : !hal.device> pattern("vmla") : i1 + // CHECK-NEXT: %[[IS1R:.+]] = hal.device.match.id<%[[DEVICE]] : !hal.device> pattern("vulkan-*") : i1 // CHECK-NEXT: %[[IS1:.+]] = or %[[IS1L]], %[[IS1R]] : i1 // CHECK-NEXT: cond_br %[[IS1]], ^bb2, ^bb3(%[[C0]] : i32) // CHECK-NEXT: ^bb2: @@ -49,8 +49,8 @@ func @simple_constants(%device : !hal.device, %arg : i32) -> i32 { // CHECK-LABEL: @no_results // CHECK-SAME: %[[DEVICE:.+]]: !hal.device func @no_results(%device : !hal.device) { - hal.device.switch(%device : !hal.device) - // CHECK-NEXT: %[[IS0:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1 + hal.device.switch<%device : !hal.device> + // CHECK-NEXT: %[[IS0:.+]] = hal.device.match.id<%[[DEVICE]] : !hal.device> pattern("vulkan-v1.?-*") : i1 // 
CHECK-NEXT: cond_br %[[IS0]], ^bb1, ^bb2 // CHECK-NEXT: ^bb1: // CHECK-NEXT: "some.op_a"() @@ -60,8 +60,8 @@ func @no_results(%device : !hal.device) { hal.return }, // CHECK-NEXT: ^bb2: - // CHECK-NEXT: %[[IS1L:.+]] = hal.device.match.id %arg0, pattern = ["vmla"] : (!hal.device) -> i1 - // CHECK-NEXT: %[[IS1R:.+]] = hal.device.match.id %arg0, pattern = ["vulkan-*"] : (!hal.device) -> i1 + // CHECK-NEXT: %[[IS1L:.+]] = hal.device.match.id<%[[DEVICE]] : !hal.device> pattern("vmla") : i1 + // CHECK-NEXT: %[[IS1R:.+]] = hal.device.match.id<%[[DEVICE]] : !hal.device> pattern("vulkan-*") : i1 // CHECK-NEXT: %[[IS1:.+]] = or %[[IS1L]], %[[IS1R]] : i1 // CHECK-NEXT: cond_br %[[IS1]], ^bb3, ^bb4 // CHECK-NEXT: ^bb3: diff --git a/iree/compiler/Dialect/HAL/Transforms/test/materialize_constant_pool_buffers.mlir b/iree/compiler/Dialect/HAL/Transforms/test/materialize_constant_pool_buffers.mlir index 3232083f1c03..66052c1969aa 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/materialize_constant_pool_buffers.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/materialize_constant_pool_buffers.mlir @@ -11,8 +11,10 @@ hal.constant_pool @dense_variable_init attributes {buffer_constraints = #hal.buf // CHECK: hal.variable @dense_variable_init_storage_buffer init(@dense_variable_init_storage_buffer_initializer) : !hal.buffer // CHECK-NEXT: func private @dense_variable_init_storage_buffer_initializer() -> !hal.buffer -// CHECK: [[STORAGE:%.+]] = hal.constant_storage.lookup @dense_variable_init::@_storage : !iree.byte_buffer -// CHECK: = hal.allocator.map {{.+}} [[STORAGE]][%c0, %c768] : !iree.byte_buffer -> !hal.buffer +// CHECK: %[[STORAGE:.+]] = hal.constant_storage.lookup @dense_variable_init::@_storage : !iree.byte_buffer +// CHECK: = hal.allocator.map<%allocator : !hal.allocator> +// CHECK-SAME: source(%[[STORAGE]] : !iree.byte_buffer)[%c0, %c768] +// CHECK-SAME: : !hal.buffer // ----- @@ -26,9 +28,15 @@ hal.constant_pool @splat_variable_init attributes {buffer_constraints 
= #hal.buf // CHECK: hal.variable @splat_variable_init_splats init(@splat_variable_init_splats_initializer) : !hal.buffer // CHECK-NEXT: func private @splat_variable_init_splats_initializer() -> !hal.buffer -// CHECK: [[BUFFER:%.+]] = hal.allocator.allocate {{.+}} %c64 : !hal.buffer -// CHECK: hal.buffer.fill [[BUFFER]], %c0, %c4, %c1065353216_i32 -// CHECK: hal.buffer.fill [[BUFFER]], %c32, %c32_0, %c1234567890_i32 +// CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%allocator : !hal.allocator> +// CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") +// CHECK-SAME: usage("Constant|Transfer|Mapping|Dispatch") : !hal.buffer{%c64} +// CHECK: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> +// CHECK-SAME: target(%[[BUFFER]] : !hal.buffer)[%c0, %c4] +// CHECK-SAME: pattern(%c1065353216_i32 : i32) +// CHECK: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> +// CHECK-SAME: target(%[[BUFFER]] : !hal.buffer)[%c32, %c32_0] +// CHECK-SAME: pattern(%c1234567890_i32 : i32) // ----- @@ -48,14 +56,21 @@ hal.constant_pool @pool attributes {buffer_constraints = #hal.buffer_constraints // CHECK: hal.variable @pool_storage0_buffer init(@pool_storage0_buffer_initializer) : !hal.buffer // CHECK-NEXT: func private @pool_storage0_buffer_initializer() -> !hal.buffer -// CHECK: [[STORAGE:%.+]] = hal.constant_storage.lookup @pool::@_storage0 : !iree.byte_buffer -// CHECK: = hal.allocator.map {{.+}} [[STORAGE]][%c0, %c16] : !iree.byte_buffer -> !hal.buffer +// CHECK: %[[STORAGE:.+]] = hal.constant_storage.lookup @pool::@_storage0 : !iree.byte_buffer +// CHECK: = hal.allocator.map<%allocator : !hal.allocator> +// CHECK-SAME: source(%[[STORAGE]] : !iree.byte_buffer)[%c0, %c16] +// CHECK-SAME: : !hal.buffer // CHECK: hal.variable @pool_storage1_buffer init(@pool_storage1_buffer_initializer) : !hal.buffer // CHECK-NEXT: func private @pool_storage1_buffer_initializer() -> !hal.buffer // CHECK: hal.variable @pool_splats init(@pool_splats_initializer) : !hal.buffer // 
CHECK-NEXT: func private @pool_splats_initializer() -> !hal.buffer -// CHECK: [[BUFFER:%.+]] = hal.allocator.allocate %allocator, "HostVisible|DeviceVisible|DeviceLocal", "Constant|Transfer|Mapping|Dispatch", %c64 : !hal.buffer -// CHECK: hal.buffer.fill [[BUFFER]], %c0, %c4, %c1065353216_i32 -// CHECK: hal.buffer.fill [[BUFFER]], %c32, %c32_0, %c1234567890_i32 +// CHECK: %[[BUFFER:.+]] = hal.allocator.allocate<%allocator : !hal.allocator> +// CHECK-SAME: : !hal.buffer{%c64} +// CHECK: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> +// CHECK-SAME: target(%[[BUFFER]] : !hal.buffer)[%c0, %c4] +// CHECK-SAME: pattern(%c1065353216_i32 : i32) +// CHECK: hal.command_buffer.fill_buffer<%cmd : !hal.command_buffer> +// CHECK-SAME: target(%[[BUFFER]] : !hal.buffer)[%c32, %c32_0] +// CHECK-SAME: pattern(%c1234567890_i32 : i32) diff --git a/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir b/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir index b90afde7bafc..53a693a1c82a 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir @@ -8,7 +8,7 @@ // CHECK-DAG: hal.executable.target @vmla, filter="vmla" { // CHECK-DAG: hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes { // CHECK-SAME: interface = @legacy_io, -// CHECK-SAME: ordinal = 0 : i32, +// CHECK-SAME: ordinal = 0 : index, // CHECK-SAME: signature = (tensor<4xf32>) -> tensor<4xf32> // CHECK-SAME: } flow.executable @simpleMath_ex_dispatch_0 { @@ -44,7 +44,7 @@ flow.executable @simpleMath_ex_dispatch_0 { // CHECK-DAG: hal.executable.target @vmla, filter="vmla" { // CHECK-DAG: hal.executable.entry_point @bools_rgn_dispatch_0 attributes { // CHECK-SAME: interface = @legacy_io, -// CHECK-SAME: ordinal = 0 : i32, +// CHECK-SAME: ordinal = 0 : index, // CHECK-SAME: signature = (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> // CHECK-SAME: } flow.executable 
@bools_ex_dispatch_0 { @@ -78,7 +78,7 @@ flow.executable @bools_ex_dispatch_0 { // ----- // CHECK-LABEL: hal.executable @shaped_dispatch -// CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 2 : i32} { +// CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 2 : index} { // CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } @@ -119,7 +119,7 @@ flow.executable @static_tiled_dispatch { // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { // CHECK-NEXT: hal.executable.entry_point @entry attributes { // CHECK-SAME: interface = @legacy_io, - // CHECK-SAME: ordinal = 0 : i32, + // CHECK-SAME: ordinal = 0 : index, // CHECK-SAME: signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () // CHECK-SAME: } flow.dispatch.entry @entry attributes { @@ -148,7 +148,7 @@ flow.executable @static_tiled_dispatch { // ----- // CHECK-LABEL: hal.executable @dynamic_tiled_dispatch -// CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 4 : i32} { +// CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 4 : index} { // CHECK-NEXT: hal.interface.binding @ro0, set=0, binding=0, type="StorageBuffer", access="Read" // CHECK-NEXT: hal.interface.binding @wo1, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } @@ -156,7 +156,7 @@ flow.executable @dynamic_tiled_dispatch { // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { // CHECK-NEXT: hal.executable.entry_point @entry attributes { // CHECK-SAME: interface = @legacy_io, - // CHECK-SAME: ordinal = 0 : i32, + // CHECK-SAME: ordinal = 0 : index, // CHECK-SAME: signature = (!flow.dispatch.tensor, !flow.dispatch.tensor, index, index, index, index) -> () // CHECK-SAME: } flow.dispatch.entry @entry attributes { @@ -211,7 +211,7 @@ flow.executable @workgroup_infos { // CHECK-NEXT: 
hal.executable.target @vmla, filter="vmla" { // CHECK-NEXT: hal.executable.entry_point @entry attributes { // CHECK-SAME: interface = @legacy_io, - // CHECK-SAME: ordinal = 0 : i32, + // CHECK-SAME: ordinal = 0 : index, // CHECK-SAME: signature = (!flow.dispatch.tensor, !flow.dispatch.tensor) -> () // CHECK-SAME: } flow.dispatch.entry @entry attributes { diff --git a/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir index 7b8b11c91b2e..e8b381559a02 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir @@ -2,18 +2,26 @@ // CHECK: hal.variable @_descriptor_set_layout_0 init(@_descriptor_set_layout_0_initializer) : !hal.descriptor_set_layout // CHECK-NEXT: func private @_descriptor_set_layout_0_initializer() -> !hal.descriptor_set_layout { -// CHECK-NEXT: %dev = hal.ex.shared_device : !hal.device -// CHECK-NEXT: %descriptor_set_layout = hal.descriptor_set_layout.create %dev, PushOnly, bindings = [#hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write">] : !hal.descriptor_set_layout +// CHECK-NEXT: %device = hal.ex.shared_device : !hal.device +// CHECK-NEXT: %descriptor_set_layout = hal.descriptor_set_layout.create +// CHECK-SAME: device(%device : !hal.device) +// CHECK-SAME: usage(PushOnly) +// CHECK-SAME: bindings([ +// CHECK-SAME: #hal.descriptor_set_layout_binding<0, "StorageBuffer", R>, +// CHECK-SAME: #hal.descriptor_set_layout_binding<1, "StorageBuffer", W> +// CHECK-SAME: ]) : !hal.descriptor_set_layout // CHECK-NEXT: return %descriptor_set_layout : !hal.descriptor_set_layout // CHECK-NEXT: } // CHECK-LABEL: @descriptorSetLayoutLookup -func @descriptorSetLayoutLookup(%arg0 : !hal.device) -> !hal.descriptor_set_layout { +func @descriptorSetLayoutLookup(%device : !hal.device) -> 
!hal.descriptor_set_layout { // CHECK-NEXT: %[[LAYOUT:.+]] = hal.variable.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout - %0 = hal.descriptor_set_layout.lookup %arg0, PushOnly, bindings = [ + %0 = hal.descriptor_set_layout.lookup device(%device : !hal.device) + usage(PushOnly) + bindings([ #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write"> - ] : !hal.descriptor_set_layout + ]) : !hal.descriptor_set_layout // CHECK-NEXT: return %[[LAYOUT]] return %0 : !hal.descriptor_set_layout } @@ -24,21 +32,25 @@ func @descriptorSetLayoutLookup(%arg0 : !hal.device) -> !hal.descriptor_set_layo // CHECK: hal.variable @_executable_layout_0 init(@_executable_layout_0_initializer) : !hal.executable_layout // CHECK-NEXT: func private @_executable_layout_0_initializer() -> !hal.executable_layout { -// CHECK-NEXT: %0 = hal.variable.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout -// CHECK-NEXT: %dev = hal.ex.shared_device : !hal.device -// CHECK-NEXT: %executable_layout = hal.executable_layout.create %dev, push_constants = 0, set_layouts = [%0] : !hal.executable_layout +// CHECK-NEXT: %[[SET0:.+]] = hal.variable.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout +// CHECK-NEXT: %device = hal.ex.shared_device : !hal.device +// CHECK-NEXT: %executable_layout = hal.executable_layout.create +// CHECK-SAME: device(%device : !hal.device) +// CHECK-SAME: push_constants(0) +// CHECK-SAME: layouts([%[[SET0]]]) : !hal.executable_layout // CHECK-NEXT: return %executable_layout : !hal.executable_layout // CHECK-NEXT: } // CHECK-LABEL: @exeLayoutLookup -func @exeLayoutLookup(%arg0 : !hal.device) -> !hal.executable_layout { +func @exeLayoutLookup(%device : !hal.device) -> !hal.executable_layout { // CHECK: %[[LAYOUT:.+]] = hal.variable.load @_executable_layout_0 : !hal.executable_layout - %0 = hal.executable_layout.lookup %arg0, set_layouts = [ + %0 = hal.executable_layout.lookup 
device(%device : !hal.device) + layouts([ [ #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write"> ] - ] : !hal.executable_layout + ]) : !hal.executable_layout // CHECK-NEXT: return %[[LAYOUT]] return %0 : !hal.executable_layout } @@ -50,17 +62,21 @@ func @exeLayoutLookup(%arg0 : !hal.device) -> !hal.executable_layout { // CHECK: hal.variable @_executable_layout_0 init(@_executable_layout_0_initializer) : !hal.executable_layout // CHECK-NEXT: func private @_executable_layout_0_initializer() -> !hal.executable_layout { -// CHECK-NEXT: %0 = hal.variable.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout -// CHECK-NEXT: %1 = hal.variable.load @_descriptor_set_layout_1 : !hal.descriptor_set_layout -// CHECK-NEXT: %dev = hal.ex.shared_device : !hal.device -// CHECK-NEXT: %executable_layout = hal.executable_layout.create %dev, push_constants = 0, set_layouts = [%0, %1] : !hal.executable_layout +// CHECK-NEXT: %[[SET0:.+]] = hal.variable.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout +// CHECK-NEXT: %[[SET1:.+]] = hal.variable.load @_descriptor_set_layout_1 : !hal.descriptor_set_layout +// CHECK-NEXT: %device = hal.ex.shared_device : !hal.device +// CHECK-NEXT: %executable_layout = hal.executable_layout.create +// CHECK-SAME: device(%device : !hal.device) +// CHECK-SAME: push_constants(0) +// CHECK-SAME: layouts([%[[SET0]], %[[SET1]]]) : !hal.executable_layout // CHECK-NEXT: return %executable_layout : !hal.executable_layout // CHECK-NEXT: } // CHECK-LABEL: @sharedLayoutLookup -func @sharedLayoutLookup(%arg0 : !hal.device) -> !hal.executable_layout { +func @sharedLayoutLookup(%device : !hal.device) -> !hal.executable_layout { // CHECK: %[[LAYOUT:.+]] = hal.variable.load @_executable_layout_0 : !hal.executable_layout - %0 = hal.executable_layout.lookup %arg0, set_layouts = [ + %0 = hal.executable_layout.lookup device(%device : !hal.device) + layouts([ [ 
#hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write"> @@ -69,18 +85,20 @@ func @sharedLayoutLookup(%arg0 : !hal.device) -> !hal.executable_layout { #hal.descriptor_set_layout_binding<0, "UniformBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "UniformBuffer", "Write"> ] - ] : !hal.executable_layout + ]) : !hal.executable_layout // CHECK-NEXT: return %[[LAYOUT]] return %0 : !hal.executable_layout } // CHECK: @otherDescriptorSetLayoutLookup -func @otherDescriptorSetLayoutLookup(%arg0 : !hal.device) -> !hal.descriptor_set_layout { +func @otherDescriptorSetLayoutLookup(%device : !hal.device) -> !hal.descriptor_set_layout { // CHECK: %[[LAYOUT:.+]] = hal.variable.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout - %0 = hal.descriptor_set_layout.lookup %arg0, PushOnly, bindings = [ + %0 = hal.descriptor_set_layout.lookup device(%device : !hal.device) + usage(PushOnly) + bindings([ #hal.descriptor_set_layout_binding<0, "StorageBuffer", "Read">, #hal.descriptor_set_layout_binding<1, "StorageBuffer", "Write"> - ] : !hal.descriptor_set_layout + ]) : !hal.descriptor_set_layout // CHECK-NEXT: return %[[LAYOUT]] return %0 : !hal.descriptor_set_layout } @@ -102,19 +120,19 @@ hal.executable @exe { hal.executable.target @vmla, filter="vmla" { hal.executable.entry_point @entry0 attributes { interface = @interface0, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (tensor<4xf32>) -> tensor<4xf32>, workgroup_size = [32 : index, 1 : index, 1 : index] } hal.executable.entry_point @entry0_alias attributes { interface = @interface0, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (tensor<4xf32>) -> tensor<4xf32>, workgroup_size = [32 : index, 1 : index, 1 : index] } hal.executable.entry_point @entry1 attributes { interface = @interface1, - ordinal = 1 : i32, + ordinal = 1 : index, signature = (tensor<4xf32>, tensor<8xf32>) -> tensor<4xf32>, workgroup_size = [32 : index, 1 : index, 1 
: index] } @@ -129,12 +147,16 @@ hal.executable @exe { // CHECK: hal.variable @_executable_exe init(@_executable_exe_initializer) : !hal.executable // CHECK: func private @_executable_exe_initializer() -> !hal.executable { // CHECK: %[[IN_DEV:.+]] = hal.ex.shared_device : !hal.device -// CHECK: %[[RET:.+]] = hal.device.switch(%[[IN_DEV]] : !hal.device) -> !hal.executable +// CHECK: %[[RET:.+]] = hal.device.switch<%[[IN_DEV]] : !hal.device> -> !hal.executable // CHECK: #hal.device.match.id<"vmla">(%[[DEV:.+]] = %[[IN_DEV]] : !hal.device) { // CHECK: %[[LAYOUT0:.+]] = hal.variable.load @_executable_layout_0 : !hal.executable_layout // CHECK: %[[LAYOUT0_2:.+]] = hal.variable.load @_executable_layout_0 : !hal.executable_layout // CHECK: %[[LAYOUT1:.+]] = hal.variable.load @_executable_layout_1 : !hal.executable_layout -// CHECK: %[[EXE:.+]] = hal.executable.create %[[DEV]], @exe::@vmla, layouts = [%[[LAYOUT0]], %[[LAYOUT0_2]], %[[LAYOUT1]]] : !hal.executable +// CHECK: %[[EXE:.+]] = hal.executable.create +// CHECK-SAME: device(%[[DEV]] : !hal.device) +// CHECK-SAME: target(@exe::@vmla) +// CHECK-SAME: layouts([%[[LAYOUT0]], %[[LAYOUT0_2]], %[[LAYOUT1]]]) +// CHECK-SAME: : !hal.executable // CHECK: hal.return %[[EXE]] : !hal.executable // CHECK: }, // CHECK: #hal.match.always() { @@ -145,9 +167,10 @@ hal.executable @exe { // CHECK: } // CHECK-LABEL: @exeLookup -func @exeLookup(%arg0 : !hal.device) -> !hal.executable { +func @exeLookup(%device : !hal.device) -> !hal.executable { // CHECK: %[[EXE:.+]] = hal.variable.load @_executable_exe : !hal.executable - %0 = hal.executable.lookup %arg0, @exe : !hal.executable + %0 = hal.executable.lookup device(%device : !hal.device) + executable(@exe) : !hal.executable // CHECK-NEXT: return %[[EXE]] return %0 : !hal.executable } diff --git a/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir b/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir index b6b733925112..95a3efab950a 100644 --- 
a/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/memoize_device_queries.mlir @@ -3,7 +3,7 @@ // CHECK: hal.variable @_device_match_id_0 init(@_device_match_id_0_initializer) : i1 // CHECK: func private @_device_match_id_0_initializer() -> i1 // CHECK-NEXT: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device -// CHECK-NEXT: %[[IS_MATCH:.+]] = hal.device.match.id %[[DEVICE]], pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1 +// CHECK-NEXT: %[[IS_MATCH:.+]] = hal.device.match.id<%[[DEVICE]] : !hal.device> pattern("vulkan-v1.?-*") : i1 // CHECK-NEXT: return %[[IS_MATCH]] : i1 // CHECK: hal.variable @_device_match_id_1 @@ -12,12 +12,12 @@ // CHECK-LABEL: func @device_matchers func @device_matchers(%device : !hal.device) { // CHECK-NEXT: = hal.variable.load @_device_match_id_0 : i1 - %0 = hal.device.match.id %device, pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1 + %0 = hal.device.match.id<%device : !hal.device> pattern("vulkan-v1.?-*") : i1 // CHECK-NEXT: = hal.variable.load @_device_match_id_0 : i1 - %1 = hal.device.match.id %device, pattern = ["vulkan-v1.?-*"] : (!hal.device) -> i1 + %1 = hal.device.match.id<%device : !hal.device> pattern("vulkan-v1.?-*") : i1 // CHECK-NEXT: = hal.variable.load @_device_match_id_1 : i1 - %2 = hal.device.match.id %device, pattern = ["vulkan-v2.?-*"] : (!hal.device) -> i1 + %2 = hal.device.match.id<%device : !hal.device> pattern("vulkan-v2.?-*") : i1 // CHECK-NEXT: = hal.variable.load @_device_match_id_2 : i1 - %3 = hal.device.match.id %device, pattern = ["vulkan-*"] : (!hal.device) -> i1 + %3 = hal.device.match.id<%device : !hal.device> pattern("vulkan-*") : i1 return } diff --git a/iree/compiler/Dialect/HAL/Transforms/test/pack_constant_pool_storage.mlir b/iree/compiler/Dialect/HAL/Transforms/test/pack_constant_pool_storage.mlir index 469a94cb2b67..60d2db3cb774 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/pack_constant_pool_storage.mlir +++ 
b/iree/compiler/Dialect/HAL/Transforms/test/pack_constant_pool_storage.mlir @@ -7,14 +7,14 @@ hal.constant_pool @pool attributes { max_buffer_range = 134217728, min_buffer_range_alignment = 4> } { - // CHECK-DAG: hal.constant_pool.splat @cst0 {{.+}} = dense<1.000000e+00> : tensor<1xf32> + // CHECK-DAG: hal.constant_pool.splat @cst0 = dense<1.000000e+00> : tensor<1xf32> hal.constant_pool.value @cst0 = dense<1.000000e+00> : tensor<1xf32> - // CHECK-DAG: hal.constant_pool.span @cst1 : tensor<4xf32> {{.+}} = @_storage[#hal.byte_range<0, 16>] + // CHECK-DAG: hal.constant_pool.span @cst1 : tensor<4xf32> = @_storage[#hal.byte_range<0, 16>] hal.constant_pool.value @cst1 = dense<[2.1, 3.2, 4.3, 5.4]> : tensor<4xf32> - // CHECK-DAG: hal.constant_pool.span @cst2 : tensor<3xi8> {{.+}} = @_storage[#hal.byte_range<32, 3>] + // CHECK-DAG: hal.constant_pool.span @cst2 : tensor<3xi8> = @_storage[#hal.byte_range<32, 3>] hal.constant_pool.value @cst2 = dense<[6, 7, 8]> : tensor<3xi8> - // CHECK: hal.constant_storage @_storage {{.+}} = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 7, 8, 0]> : vector<36xi8> + // CHECK: hal.constant_storage @_storage = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 7, 8, 0]> : vector<36xi8> } // ----- @@ -26,11 +26,11 @@ hal.constant_pool @multi_storage attributes { max_buffer_range = 134217728, min_buffer_range_alignment = 1> } { - // CHECK-DAG: hal.constant_pool.span @cst0 : tensor<4xf32> {{.+}} = @_storage[#hal.byte_range<0, 16>] + // CHECK-DAG: hal.constant_pool.span @cst0 : tensor<4xf32> = @_storage[#hal.byte_range<0, 16>] hal.constant_pool.value @cst0 = dense<[2.1, 3.2, 4.3, 5.4]> : tensor<4xf32> - // CHECK-DAG: hal.constant_pool.span @cst1 : tensor<3xi8> {{.+}} = @_storage_0[#hal.byte_range<0, 3>] + // CHECK-DAG: hal.constant_pool.span @cst1 : tensor<3xi8> = 
@_storage_0[#hal.byte_range<0, 3>] hal.constant_pool.value @cst1 = dense<[6, 7, 8]> : tensor<3xi8> - // CHECK-NEXT: hal.constant_storage @_storage {{.+}} = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64]> : vector<16xi8> - // CHECK-NEXT: hal.constant_storage @_storage_0 {{.+}} = dense<[6, 7, 8]> : vector<3xi8> + // CHECK-NEXT: hal.constant_storage @_storage = dense<[102, 102, 6, 64, -51, -52, 76, 64, -102, -103, -119, 64, -51, -52, -84, 64]> : vector<16xi8> + // CHECK-NEXT: hal.constant_storage @_storage_0 = dense<[6, 7, 8]> : vector<3xi8> } diff --git a/iree/compiler/Dialect/HAL/Transforms/test/propagate_constant_workgroup_info.mlir b/iree/compiler/Dialect/HAL/Transforms/test/propagate_constant_workgroup_info.mlir index 6efe13257322..71f9c65f6b81 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/propagate_constant_workgroup_info.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/propagate_constant_workgroup_info.mlir @@ -8,7 +8,7 @@ hal.executable @exe { hal.executable.target @target, filter="target" { hal.executable.entry_point @entry attributes { interface = @interface, - ordinal = 0 : i32, + ordinal = 0 : index, signature = (tensor<4xf32>) -> tensor<4xf32>, workgroup_size = [32 : index, 4 : index, 8 : index] } diff --git a/iree/compiler/Dialect/HAL/Transforms/test/public_abi_generation.mlir b/iree/compiler/Dialect/HAL/Transforms/test/public_abi_generation.mlir index c9c7cc898836..802d6072a6bd 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/public_abi_generation.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/public_abi_generation.mlir @@ -12,27 +12,24 @@ func @noReflectionExport(%arg0 : tensor<4xf32>) -> tensor<4xf32> // CHECK-LABEL: @staticTwoArg // Note: reflection matches signature: // (%arg0 : tensor<4x4xi64>, %arg1 : tensor<5x6xi64>) -> tensor<5x6xi64> -// Original func should be rewritten to export with $raw suffix with no -// reflection metadata. 
-// CHECK-SAME: iree.module.export = "staticTwoArg$raw" // A new function with $async suffix based on buffer_view with wait and signal // semaphore arguments should be generated. // CHECK: func @staticTwoArg$async(%[[ARG0:.+]]: !hal.semaphore, %[[ARG1:.+]]: index, %[[ARG2:.+]]: !hal.buffer_view, %[[ARG3:.+]]: !hal.buffer_view, %[[ARG4:.+]]: !hal.semaphore, %[[ARG5:.+]]: index) // CHECK-SAME: attributes // CHECK-SAME: iree.module.export = "staticTwoArg$async" -func @staticTwoArg(%arg0 : !hal.buffer, %arg1 : !hal.buffer) -> !hal.buffer +func @staticTwoArg(%arg0: !hal.buffer, %arg1: !hal.buffer) -> !hal.buffer attributes {iree.module.export, iree.reflection = {f = "I19!B7!t7d4d4B7!t7d5d6R10!B7!t7d5d6", fv = "1"}} { - // CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await %[[ARG0]], min_value = %[[ARG1]] : i32 + // CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await<%[[ARG0]] : !hal.semaphore> until(%[[ARG1]]) : i32 // CHECK-DAG: hal.check_success %[[WAITRESULT]] // CHECK-DAG: %[[BUFFER0:.+]] = hal.buffer_view.buffer %[[ARG2]] : !hal.buffer // CHECK-DAG: %[[BUFFER1:.+]] = hal.buffer_view.buffer %[[ARG3]] : !hal.buffer // CHECK-DAG: %[[R0:.+]] = call @staticTwoArg(%[[BUFFER0]], %[[BUFFER1]]) // CHECK-DAG: %[[C5:.+]] = constant 5 : index // CHECK-DAG: %[[C6:.+]] = constant 6 : index - // CHECK-DAG: %[[VIEW:.+]] = hal.buffer_view.create %[[R0]], element_type = %c16777280_i32, shape = [%[[C5]], %[[C6]]] : !hal.buffer_view - // CHECK-DAG: hal.semaphore.signal %[[ARG4]], value = %[[ARG5]] + // CHECK-DAG: %[[VIEW:.+]] = hal.buffer_view.create %[[R0]], element_type = %c16777280_i32, shape = [%[[C5]], %[[C6]]] : !hal.buffer -> !hal.buffer_view + // CHECK-DAG: hal.semaphore.signal<%[[ARG4]] : !hal.semaphore> value(%[[ARG5]]) // CHECK: return %[[VIEW]] return %arg1 : !hal.buffer } @@ -46,9 +43,9 @@ func @staticTwoArg(%arg0 : !hal.buffer, %arg1 : !hal.buffer) -> !hal.buffer // CHECK-DAG: %[[C0:.+]] = constant 0 : index // CHECK-DAG: %[[C1:.+]] = constant 1 : index // CHECK-DAG: 
%[[DEVICE:.+]] = hal.ex.shared_device : !hal.device -// CHECK-DAG: %[[SEMAPHORE:.+]] = hal.semaphore.create %[[DEVICE]], initial_value = %[[C0]] : !hal.semaphore +// CHECK-DAG: %[[SEMAPHORE:.+]] = hal.semaphore.create device(%[[DEVICE]] : !hal.device) initial(%[[C0]]) : !hal.semaphore // CHECK-DAG: %[[RESULT:.+]] = call @staticTwoArg$async(%[[SEMAPHORE]], %[[C0]], %[[ARG0]], %[[ARG1]], %[[SEMAPHORE]], %[[C1]]) : (!hal.semaphore, index, !hal.buffer_view, !hal.buffer_view, !hal.semaphore, index) -> !hal.buffer_view -// CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await %[[SEMAPHORE]], min_value = %[[C1]] : i32 +// CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await<%[[SEMAPHORE]] : !hal.semaphore> until(%[[C1]]) : i32 // CHECK-DAG: hal.check_success %[[WAITRESULT]] // CHECK: return %[[RESULT]] : !hal.buffer_view @@ -57,22 +54,19 @@ func @staticTwoArg(%arg0 : !hal.buffer, %arg1 : !hal.buffer) -> !hal.buffer // CHECK-LABEL: @dynamicTwoDims // Note: reflection matches signature: // (%arg0 : tensor) -> tensor -// Original func should be rewritten to export with $raw suffix with no -// reflection metadata. -// CHECK-SAME: iree.module.export = "dynamicTwoDims$raw" // A new function with $async suffix based on buffer_view with wait and signal // semaphore arguments should be generated. 
// CHECK: func @dynamicTwoDims$async(%[[ARG0:.+]]: !hal.semaphore, %[[ARG1:.+]]: index, %[[ARG2:.+]]: !hal.buffer_view, %[[ARG3:.+]]: !hal.semaphore, %[[ARG4:.+]]: index) // CHECK-SAME: attributes // CHECK-SAME: iree.module.export = "dynamicTwoDims$async" -// CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await %[[ARG0]], min_value = %[[ARG1]] : i32 +// CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await<%[[ARG0]] : !hal.semaphore> until(%[[ARG1]]) : i32 // CHECK-DAG: hal.check_success %[[WAITRESULT]] // CHECK-DAG: %[[BUFFER:.+]] = hal.buffer_view.buffer %[[ARG2]] : !hal.buffer // CHECK-DAG: %[[DIM0:.+]] = hal.buffer_view.dim %[[ARG2]], 0 : index // CHECK-DAG: %[[DIM1:.+]] = hal.buffer_view.dim %[[ARG2]], 1 : index // CHECK-DAG: %[[RESULT:.+]]:3 = call @dynamicTwoDims(%[[BUFFER]], %[[DIM0]], %[[DIM1]]) -// CHECK-DAG: %[[RESULT_VIEW:.+]] = hal.buffer_view.create %[[RESULT]]#0, element_type = %c50331680_i32, shape = [%[[RESULT]]#1, %[[RESULT]]#2] : !hal.buffer_view -// CHECK-DAG: hal.semaphore.signal %[[ARG3]], value = %[[ARG4]] +// CHECK-DAG: %[[RESULT_VIEW:.+]] = hal.buffer_view.create %[[RESULT]]#0, element_type = %c50331680_i32, shape = [%[[RESULT]]#1, %[[RESULT]]#2] : !hal.buffer -> !hal.buffer_view +// CHECK-DAG: hal.semaphore.signal<%[[ARG3]] : !hal.semaphore> value(%[[ARG4]]) // CHECK: return %[[RESULT_VIEW]] // A new function with $sync suffix based on buffer_view should be generated. // It should wrap the $async function. 
@@ -84,9 +78,9 @@ func @staticTwoArg(%arg0 : !hal.buffer, %arg1 : !hal.buffer) -> !hal.buffer // CHECK-DAG: %[[C0:.+]] = constant 0 : index // CHECK-DAG: %[[C1:.+]] = constant 1 : index // CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device -// CHECK-DAG: %[[SEMAPHORE:.+]] = hal.semaphore.create %[[DEVICE]], initial_value = %[[C0]] : !hal.semaphore +// CHECK-DAG: %[[SEMAPHORE:.+]] = hal.semaphore.create device(%[[DEVICE]] : !hal.device) initial(%[[C0]]) : !hal.semaphore // CHECK-DAG: %[[RESULT:.+]] = call @dynamicTwoDims$async(%[[SEMAPHORE]], %[[C0]], %[[ARG0]], %[[SEMAPHORE]], %[[C1]]) : (!hal.semaphore, index, !hal.buffer_view, !hal.semaphore, index) -> !hal.buffer_view -// CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await %[[SEMAPHORE]], min_value = %[[C1]] : i32 +// CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await<%[[SEMAPHORE]] : !hal.semaphore> until(%[[C1]]) : i32 // CHECK-DAG: hal.check_success %[[WAITRESULT]] // CHECK: return %[[RESULT]] : !hal.buffer_view func @dynamicTwoDims(%arg0 : !hal.buffer, %arg1 : index, %arg2 : index) -> (!hal.buffer, index, index) diff --git a/iree/compiler/Dialect/HAL/Transforms/test/resolve_entry_point_ordinals.mlir b/iree/compiler/Dialect/HAL/Transforms/test/resolve_entry_point_ordinals.mlir index 68861a413c02..0b76d529b8f1 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/resolve_entry_point_ordinals.mlir +++ b/iree/compiler/Dialect/HAL/Transforms/test/resolve_entry_point_ordinals.mlir @@ -1,93 +1,113 @@ -// RUN: iree-opt -allow-unregistered-dialect -split-input-file -iree-hal-resolve-entry-point-ordinals %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -iree-hal-resolve-entry-point-ordinals %s | IreeFileCheck %s -// CHECK: module { -module { - hal.executable @exe { - hal.interface @interface { - hal.interface.binding @s0b0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @s0b1, set=0, binding=1, type="StorageBuffer", access="Read|Write" - } - hal.executable.target 
@target, filter="target" { - hal.executable.entry_point @entry attributes { - interface = @interface, - ordinal = 0 : i32, - signature = (tensor<4xf32>) -> tensor<4xf32>, - workgroup_size = [32 : index, 1 : index, 1 : index] - } +hal.executable @exe { + hal.interface @interface { + hal.interface.binding @s0b0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @s0b1, set=0, binding=1, type="StorageBuffer", access="Read|Write" + } + hal.executable.target @target, filter="target" { + hal.executable.entry_point @entry attributes { + interface = @interface, + ordinal = 0 : index, + signature = (tensor<4xf32>) -> tensor<4xf32>, + workgroup_size = [32 : index, 1 : index, 1 : index] } } +} - func @dispatch_with_nested_references() { - %cmd = "test_hal.command_buffer"() : () -> !hal.command_buffer - %x = "test_hal.workgroup_x"() : () -> index - %y = "test_hal.workgroup_y"() : () -> index - %z = "test_hal.workgroup_z"() : () -> index - // CHECK: %[[DEVICE:.+]] = hal.command_buffer.device %0 - // CHECK: %[[EXE:.+]] = hal.executable.lookup %[[DEVICE]], @exe - // CHECK: hal.command_buffer.dispatch %0, %[[EXE]], entry_point = 0, workgroup_xyz = [%1, %2, %3] - hal.command_buffer.dispatch.symbol %cmd, @exe::@target::@entry, workgroup_xyz = [%x, %y, %z] - return - } +// CHECK-LABEL: @dispatch_with_nested_references +// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer +func @dispatch_with_nested_references(%cmd : !hal.command_buffer) { + %c10 = constant 10 : index + %c11 = constant 11 : index + %c12 = constant 12 : index + // CHECK: %[[DEVICE:.+]] = hal.command_buffer.device<%[[CMD]] + // CHECK: %[[EXE:.+]] = hal.executable.lookup + // CHECK-SAME: device(%[[DEVICE]] : !hal.device) + // CHECK-SAME: executable(@exe) : !hal.executable + // CHECK: hal.command_buffer.dispatch<%[[CMD]] + // CHECK-SAME: target(%[[EXE]] : !hal.executable)[0] + // CHECK-SAME: workgroups([%c10, %c11, %c12]) + hal.command_buffer.dispatch.symbol<%cmd : !hal.command_buffer> + 
target(@exe::@target::@entry) + workgroups([%c10, %c11, %c12]) + return } // ----- -// CHECK: module { -module { - func @dispatch_already_using_ordinals() { - %cmd = "test_hal.command_buffer"() : () -> !hal.command_buffer - %exe = "test_hal.executable"() : () -> !hal.executable - %x = "test_hal.workgroup_x"() : () -> index - %y = "test_hal.workgroup_y"() : () -> index - %z = "test_hal.workgroup_z"() : () -> index - // CHECK: hal.command_buffer.dispatch %0, %1, entry_point = 2, workgroup_xyz = [%2, %3, %4] - hal.command_buffer.dispatch %cmd, %exe, entry_point = 2, workgroup_xyz = [%x, %y, %z] - return - } +// CHECK-LABEL: @dispatch_already_using_ordinals +func @dispatch_already_using_ordinals( + // CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer + %cmd: !hal.command_buffer, + // CHECK-SAME: %[[EXE:.+]]: !hal.executable + %exe: !hal.executable +) { + %c10 = constant 10 : index + %c11 = constant 11 : index + %c12 = constant 12 : index + // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(%[[EXE]] : !hal.executable)[2] + // CHECK-SAME: workgroups([%c10, %c11, %c12]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> + target(%exe : !hal.executable)[2] + workgroups([%c10, %c11, %c12]) + return } // ----- -// CHECK: module { -module { - hal.executable @exe { - hal.interface @interface { - hal.interface.binding @s0b0, set=0, binding=0, type="StorageBuffer", access="Read" - hal.interface.binding @s0b1, set=0, binding=1, type="StorageBuffer", access="Read|Write" - } - hal.executable.target @target, filter="target" { - hal.executable.entry_point @entry attributes { - interface = @interface, - ordinal = 0 : i32, - signature = (tensor<4xf32>) -> tensor<4xf32>, - workgroup_size = [32 : index, 1 : index, 1 : index] - } +hal.executable @exe { + hal.interface @interface { + hal.interface.binding @s0b0, set=0, binding=0, type="StorageBuffer", access="Read" + hal.interface.binding @s0b1, set=0, binding=1, type="StorageBuffer", 
access="Read|Write" + } + hal.executable.target @target, filter="target" { + hal.executable.entry_point @entry attributes { + interface = @interface, + ordinal = 0 : index, + signature = (tensor<4xf32>) -> tensor<4xf32>, + workgroup_size = [32 : index, 1 : index, 1 : index] } } +} - func @dispatch_indirect_with_nested_references() { - %cmd = "test_hal.command_buffer"() : () -> !hal.command_buffer - %buffer = "test_hal.buffer"() : () -> !hal.buffer - %offset = "test_hal.offset"() : () -> index - // CHECK: %[[DEVICE:.+]] = hal.command_buffer.device %0 - // CHECK: %[[EXE:.+]] = hal.executable.lookup %[[DEVICE]], @exe - // CHECK: hal.command_buffer.dispatch.indirect %0, %[[EXE]], entry_point = 0, workgroups = %1[%2] - hal.command_buffer.dispatch.indirect.symbol %cmd, @exe::@target::@entry, workgroups = %buffer[%offset] - return - } +// CHECK-LABEL: @dispatch_indirect_with_nested_references +func @dispatch_indirect_with_nested_references( + // CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer + %cmd: !hal.command_buffer, + // CHECK-SAME: %[[BUF:.+]]: !hal.buffer + %buf: !hal.buffer +) { + %c10 = constant 10 : index + // CHECK: %[[DEVICE:.+]] = hal.command_buffer.device<%[[CMD]] + // CHECK: %[[EXE:.+]] = hal.executable.lookup device(%[[DEVICE]] : !hal.device) executable(@exe) + // CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(%[[EXE]] : !hal.executable)[0] + // CHECK-SAME: workgroups(%[[BUF]] : !hal.buffer)[%c10] + hal.command_buffer.dispatch.indirect.symbol<%cmd : !hal.command_buffer> + target(@exe::@target::@entry) + workgroups(%buf : !hal.buffer)[%c10] + return } // ----- -// CHECK: module { -module { - func @dispatch_indirect_already_using_ordinals() { - %cmd = "test_hal.command_buffer"() : () -> !hal.command_buffer - %exe = "test_hal.executable"() : () -> !hal.executable - %buffer = "test_hal.buffer"() : () -> !hal.buffer - %offset = "test_hal.offset"() : () -> index - // CHECK: hal.command_buffer.dispatch.indirect 
%0, %1, entry_point = 0, workgroups = %2[%3] - hal.command_buffer.dispatch.indirect %cmd, %exe, entry_point = 0, workgroups = %buffer[%offset] - return - } +// CHECK-LABEL: @dispatch_indirect_already_using_ordinals +func @dispatch_indirect_already_using_ordinals( + // CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer + %cmd: !hal.command_buffer, + // CHECK-SAME: %[[EXE:.+]]: !hal.executable + %exe: !hal.executable, + // CHECK-SAME: %[[BUF:.+]]: !hal.buffer + %buf: !hal.buffer +) { + %c10 = constant 10 : index + // CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(%[[EXE]] : !hal.executable)[0] + // CHECK-SAME: workgroups(%[[BUF]] : !hal.buffer)[%c10] + hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> + target(%exe : !hal.executable)[0] + workgroups(%buf : !hal.buffer)[%c10] + return } diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp index 3e9edda05cd0..55ae9c640b51 100644 --- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp +++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.cpp @@ -60,15 +60,16 @@ SmallVector getStaticShapeDims(Location loc, ShapedType shapedType, return shape; } -llvm::Optional> getShapeDims( - Location loc, Value shapedValue, ConversionPatternRewriter &rewriter) { +llvm::Optional> getShapeDims(Location loc, + Value shapedValue, + OpBuilder &builder) { ShapedType shapedType = shapedValue.getType().cast(); if (shapedType.hasStaticShape()) { - return getStaticShapeDims(loc, shapedType, rewriter); + return getStaticShapeDims(loc, shapedType, builder); } else { // Dynamic shape lookup. 
Value rsValue = Shape::buildOrFindRankedShapeForValue( - loc, shapedValue, rewriter.getIndexType(), rewriter); + loc, shapedValue, builder.getIndexType(), builder); if (!rsValue) { return llvm::None; } @@ -76,17 +77,50 @@ llvm::Optional> getShapeDims( // Note that in the following, we require that the dims resolve // to discrete SSA values, which in a stream, will be block args. if (failed(Shape::getRankedDimsFromRankedShape( - loc, rsValue, /*createIntermediateOps=*/true, dims, rewriter))) { + loc, rsValue, /*createIntermediateOps=*/true, dims, builder))) { return llvm::None; } return dims; } } +Value getValueSize(Location loc, Value value, OpBuilder &builder) { + // Function arguments are special as we always have to query. + auto definingOp = value.getDefiningOp(); + if (!definingOp) { + return builder.createOrFold( + loc, builder.getIndexType(), value); + } + + if (auto awareOp = dyn_cast_or_null(definingOp)) { + return awareOp.getResultSizeFromValue(value); + } + + auto type = value.getType(); + if (auto awareType = type.dyn_cast()) { + auto sizeValue = awareType.getSize(value); + if (sizeValue) return sizeValue; + } + if (auto inferType = type.dyn_cast()) { + return inferType.inferSizeFromValue(loc, value, builder); + } + + auto elementType = IREE::HAL::getElementTypeValue( + value.getType().cast().getElementType()); + if (!elementType) return {}; + auto shape = IREE::HAL::getShapeDims(loc, value, builder); + if (!shape) return {}; + auto deviceValue = builder.createOrFold(loc); + auto allocatorValue = + builder.createOrFold(loc, deviceValue); + return builder.createOrFold( + loc, allocatorValue, *shape, elementType.getValue()); +} + // static bool TensorRewriteAdaptor::isValidNewType(Type newType) { - return newType.isa() || - newType.isa(); + return newType.isa() || + newType.isa(); } // static @@ -126,8 +160,8 @@ llvm::Optional TensorRewriteAdaptor::getChecked( } Value TensorRewriteAdaptor::getAllocator() { - return rewriter_.createOrFold(loc_, - 
getBuffer()); + return rewriter_.createOrFold( + loc_, AllocatorType::get(rewriter_.getContext()), getBuffer()); } bool TensorRewriteAdaptor::isBufferView() { @@ -136,8 +170,8 @@ bool TensorRewriteAdaptor::isBufferView() { Value TensorRewriteAdaptor::getBuffer() { if (isBufferView()) { - return rewriter_.createOrFold(loc_, - newValue_); + return rewriter_.createOrFold( + loc_, IREE::HAL::BufferType::get(rewriter_.getContext()), newValue_); } else { return newValue_; } @@ -174,6 +208,7 @@ IntegerAttr TensorRewriteAdaptor::getElementTypeAttr() { llvm::Optional> TensorRewriteAdaptor::getShapeDims() { return IREE::HAL::getShapeDims(loc_, oldValue_, rewriter_); } + llvm::Optional> TensorRewriteAdaptor::getShapeDims( ConversionPatternRewriter &rewriter) { return IREE::HAL::getShapeDims(loc_, oldValue_, rewriter); diff --git a/iree/compiler/Dialect/HAL/Utils/TypeUtils.h b/iree/compiler/Dialect/HAL/Utils/TypeUtils.h index 6c396a4c0077..d17444227ec9 100644 --- a/iree/compiler/Dialect/HAL/Utils/TypeUtils.h +++ b/iree/compiler/Dialect/HAL/Utils/TypeUtils.h @@ -47,8 +47,14 @@ SmallVector getStaticShapeDims(Location loc, ShapedType shapedType, OpBuilder &builder); // Returns an array of i32 values representing the shape of the |shapedValue|. -llvm::Optional> getShapeDims( - Location loc, Value shapedValue, ConversionPatternRewriter &rewriter); +llvm::Optional> getShapeDims(Location loc, + Value shapedValue, + OpBuilder &builder); + +// Returns the size of |value| as an index type. +// The returned value may either be produced at the current insertion site or +// pulled from a dominating block/block argument. +Value getValueSize(Location loc, Value value, OpBuilder &builder); // An adaptor used for tensor->buffer rewrites. 
// This abstracts the source and destination types to allow for implicit diff --git a/iree/compiler/Dialect/Modules/TensorList/Conversion/ConversionPatterns.cpp b/iree/compiler/Dialect/Modules/TensorList/Conversion/ConversionPatterns.cpp index ade41bfc8b1b..d26bf435cd4a 100644 --- a/iree/compiler/Dialect/Modules/TensorList/Conversion/ConversionPatterns.cpp +++ b/iree/compiler/Dialect/Modules/TensorList/Conversion/ConversionPatterns.cpp @@ -97,7 +97,8 @@ class ConcatOpConversion newOperands[0]); auto bufferOp = rewriter.createOrFold( - newConcatOp.getLoc(), newConcatOp); + newConcatOp.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), + newConcatOp); rewriter.replaceOp(concatOp, bufferOp); return success(); @@ -128,7 +129,8 @@ class StackOpConversion allocator, newOperands[0], operand1); auto bufferOp = rewriter.createOrFold( - stackOp.getLoc(), newStackOp); + stackOp.getLoc(), IREE::HAL::BufferType::get(rewriter.getContext()), + newStackOp); rewriter.replaceOp(stackOp, bufferOp); return success(); diff --git a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir index 5974d4508905..3f011a473f14 100644 --- a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir +++ b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_hal_to_vm.mlir @@ -32,8 +32,8 @@ func @SetItem(%list: !tensorlist.list, %index: !hal.buffer_view, %item: !hal.buf // CHECK-LABEL: @Stack func @Stack(%list: !tensorlist.list, %num_elements: !hal.buffer_view) -> !hal.buffer_view { - %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator + %device = hal.ex.shared_device : !hal.device + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator // CHECK: vm.call @tensorlist.stack %0 = "tensorlist.Stack"(%allocator, %list, %num_elements) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) 
-> !hal.buffer_view return %0 : !hal.buffer_view @@ -43,8 +43,8 @@ func @Stack(%list: !tensorlist.list, %num_elements: !hal.buffer_view) -> !hal.bu // CHECK-LABEL: @Concat func @Concat(%list: !tensorlist.list) -> !hal.buffer_view { - %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator + %device = hal.ex.shared_device : !hal.device + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator // CHECK: vm.call @tensorlist.concat %0 = "tensorlist.Concat"(%allocator, %list) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view return %0 : !hal.buffer_view diff --git a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_to_hal.mlir b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_to_hal.mlir index a75a32ad3648..b4159ba2e27e 100644 --- a/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_to_hal.mlir +++ b/iree/compiler/Dialect/Modules/TensorList/Conversion/test/convert_to_hal.mlir @@ -56,8 +56,7 @@ func @GetItem(%arg0: !tensorlist.list, %arg1: tensor) -> tensor { // CHECK: @Stack func @Stack(%arg0: !tensorlist.list, %arg1: tensor) -> tensor<1xf32> { - // CHECK-DAG: [[DEV:%.+]] = hal.ex.shared_device - // CHECK-DAG: [[ALL:%.+]] = hal.device.allocator [[DEV]] + // CHECK-DAG: [[ALL:%.+]] = hal.device.allocator // CHECK-DAG: [[VIEW:%.+]] = hal.buffer_view.create %arg1 // CHECK-DAG: [[RES:%.+]] = "tensorlist.Stack"([[ALL]], %arg0, [[VIEW]]) // CHECK-DAG: [[BUF:%.+]] = hal.buffer_view.buffer [[RES]] @@ -71,8 +70,7 @@ func @Stack(%arg0: !tensorlist.list, %arg1: tensor) -> tensor<1xf32> { // CHECK: @Concat func @Concat(%arg0: !tensorlist.list) -> tensor<1xf32> { - // CHECK: [[DEV:%.+]] = hal.ex.shared_device : !hal.device - // CHECK: [[ALL:%.+]] = hal.device.allocator [[DEV]] + // CHECK: [[ALL:%.+]] = hal.device.allocator // CHECK: [[RES:%.+]] = "tensorlist.Concat"([[ALL]], %arg0) // CHECK: [[BUF:%.+]] = hal.buffer_view.buffer [[RES]] %0 = 
"tensorlist.Concat.Tensor"(%arg0) : (!tensorlist.list) -> tensor<1xf32> diff --git a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp index 17b65f65de66..6a1d27f7e57a 100644 --- a/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp +++ b/iree/compiler/Dialect/VMLA/Conversion/HALToVMLA/ConvertHALToVMLA.cpp @@ -91,7 +91,8 @@ struct InterfaceLoadTensorOpConversion IREE::HAL::InterfaceLoadTensorOp::Adaptor newOperands(operands); auto bufferOp = rewriter.create( loadOp.getLoc(), IREE::VMLA::BufferType::get(loadOp.getContext()), - interfaceArg, bindingOp.set(), bindingOp.binding()); + interfaceArg, bindingOp.set().getZExtValue(), + bindingOp.binding().getZExtValue()); auto byteLengthValue = VMLAConversionTarget::getBufferLength( loadOp.getLoc(), loadOp.result(), typeConverter, rewriter); if (!byteLengthValue) return failure(); @@ -123,7 +124,8 @@ struct InterfaceStoreTensorOpConversion IREE::HAL::InterfaceStoreTensorOp::Adaptor newOperands(operands); auto bufferOp = rewriter.create( storeOp.getLoc(), IREE::VMLA::BufferType::get(storeOp.getContext()), - interfaceArg, bindingOp.set(), bindingOp.binding()); + interfaceArg, bindingOp.set().getZExtValue(), + bindingOp.binding().getZExtValue()); auto zeroValue = rewriter.createOrFold(storeOp.getLoc(), 0); diff --git a/iree/modules/check/test/success.mlir b/iree/modules/check/test/success.mlir index bf17af4897fd..84cd8392d55e 100644 --- a/iree/modules/check/test/success.mlir +++ b/iree/modules/check/test/success.mlir @@ -14,9 +14,11 @@ func @expect_false() attributes { iree.module.export } { } func @expect_all_true() attributes {iree.module.export} { - %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator - %all_true = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<1> : tensor<2x2xi32> + %device = hal.ex.shared_device : 
!hal.device + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + %all_true = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<1> : tensor<2x2xi32> check.expect_all_true(%all_true) : !hal.buffer_view return } diff --git a/iree/modules/tensorlist/tensorlist_test.mlir b/iree/modules/tensorlist/tensorlist_test.mlir index 6425b1b0fa96..ebf2fce8adcf 100644 --- a/iree/modules/tensorlist/tensorlist_test.mlir +++ b/iree/modules/tensorlist/tensorlist_test.mlir @@ -1,9 +1,15 @@ func @identity_through_set_item_get_item(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} { - %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator - %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<1> : tensor - %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32> - %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor + %device = hal.ex.shared_device : !hal.device + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + %0 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<1> : tensor + %1 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<[]> : tensor<0xi32> + %2 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<0> : tensor %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list %5 = 
"tensorlist.GetItem"(%4, %2) : (!tensorlist.list, !hal.buffer_view) -> !hal.buffer_view @@ -11,11 +17,17 @@ func @identity_through_set_item_get_item(%arg0: !hal.buffer_view) -> !hal.buffer } func @identity_through_set_item_get_item_2D(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} { - %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator - %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<1> : tensor - %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[1, 1]> : tensor<2xi32> - %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor + %device = hal.ex.shared_device : !hal.device + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + %0 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<1> : tensor + %1 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<[1, 1]> : tensor<2xi32> + %2 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<0> : tensor %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list %stacked = "tensorlist.Stack"(%allocator, %4, %0) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view @@ -23,20 +35,28 @@ func @identity_through_set_item_get_item_2D(%arg0: !hal.buffer_view) -> !hal.buf } func @identity_through_concat(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} { - %dev = hal.ex.shared_device : !hal.device - 
%allocator = hal.device.allocator %dev : !hal.allocator - %element_shape = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32> + %device = hal.ex.shared_device : !hal.device + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + %element_shape = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<[]> : tensor<0xi32> %list = "tensorlist.FromTensor"(%arg0) : (!hal.buffer_view) -> !tensorlist.list %concat = "tensorlist.Concat"(%allocator, %list) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view return %concat : !hal.buffer_view } func @concat_appends_empty(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} { - %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator - %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<2> : tensor - %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[1]> : tensor<1xi32> - %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor + %device = hal.ex.shared_device : !hal.device + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + %0 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<2> : tensor + %1 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<[1]> : tensor<1xi32> + %2 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<0> : tensor %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list %4 = 
"tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list %concat = "tensorlist.Concat"(%allocator, %4) : (!hal.allocator, !tensorlist.list) -> !hal.buffer_view @@ -44,20 +64,28 @@ func @concat_appends_empty(%arg0: !hal.buffer_view) -> !hal.buffer_view attribut } func @identity_through_stack(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} { - %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator - %num_elements = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<2> : tensor + %device = hal.ex.shared_device : !hal.device + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + %num_elements = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<2> : tensor %list = "tensorlist.FromTensor"(%arg0) : (!hal.buffer_view) -> !tensorlist.list %stacked = "tensorlist.Stack"(%allocator, %list, %num_elements) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view return %stacked : !hal.buffer_view } func @stack_appends_empty(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.module.export, iree.abi.none} { - %dev = hal.ex.shared_device : !hal.device - %allocator = hal.device.allocator %dev : !hal.allocator - %0 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<2> : tensor - %1 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<[]> : tensor<0xi32> - %2 = hal.buffer_view.const %allocator, "HostLocal|DeviceVisible", "All" : !hal.buffer_view = dense<0> : tensor + %device = hal.ex.shared_device : !hal.device + %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator + %0 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") 
usage("All") : !hal.buffer_view = + dense<2> : tensor + %1 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<[]> : tensor<0xi32> + %2 = hal.allocator.constant<%allocator : !hal.allocator> + type("HostLocal|DeviceVisible") usage("All") : !hal.buffer_view = + dense<0> : tensor %3 = "tensorlist.Reserve"(%1, %0) { element_type = 50331680 : i32} : (!hal.buffer_view, !hal.buffer_view) -> !tensorlist.list %4 = "tensorlist.SetItem"(%3, %2, %arg0) : (!tensorlist.list, !hal.buffer_view, !hal.buffer_view) -> !tensorlist.list %stacked = "tensorlist.Stack"(%allocator, %4, %0) : (!hal.allocator, !tensorlist.list, !hal.buffer_view) -> !hal.buffer_view diff --git a/iree/samples/simple_embedding/simple_embedding_test.cc b/iree/samples/simple_embedding/simple_embedding_test.cc index b82ae6458df2..9d165550755e 100644 --- a/iree/samples/simple_embedding/simple_embedding_test.cc +++ b/iree/samples/simple_embedding/simple_embedding_test.cc @@ -109,9 +109,9 @@ TEST_P(SimpleEmbeddingTest, RunOnce) { iree_vm_module_release(bytecode_module); // Lookup the entry point function. - // Note that we use the "raw" variant which operates on pure type/shape + // Note that we use the synchronous variant which operates on pure type/shape // erased buffers. - const char kMainFunctionName[] = "module.simple_mul$raw"; + const char kMainFunctionName[] = "module.simple_mul"; iree_vm_function_t main_function; IREE_ASSERT_OK(iree_vm_context_resolve_function( context, iree_make_cstring_view(kMainFunctionName), &main_function)) @@ -142,15 +142,30 @@ TEST_P(SimpleEmbeddingTest, RunOnce) { IREE_ASSERT_OK(iree_hal_buffer_fill(arg1_buffer, 0, IREE_WHOLE_BUFFER, &kFloat2, sizeof(float))); + // Wrap buffers in shaped buffer views. 
+ iree_hal_dim_t shape[1] = {kElementCount}; + iree_hal_buffer_view_t* arg0_buffer_view = nullptr; + iree_hal_buffer_view_t* arg1_buffer_view = nullptr; + IREE_ASSERT_OK(iree_hal_buffer_view_create( + arg0_buffer, IREE_HAL_ELEMENT_TYPE_FLOAT_32, shape, IREE_ARRAYSIZE(shape), + &arg0_buffer_view)); + IREE_ASSERT_OK(iree_hal_buffer_view_create( + arg1_buffer, IREE_HAL_ELEMENT_TYPE_FLOAT_32, shape, IREE_ARRAYSIZE(shape), + &arg1_buffer_view)); + iree_hal_buffer_release(arg0_buffer); + iree_hal_buffer_release(arg1_buffer); + // Setup call inputs with our buffers. // TODO(benvanik): make a macro/magic. vm::ref inputs; IREE_ASSERT_OK(iree_vm_list_create(/*element_type=*/nullptr, 2, iree_allocator_system(), &inputs)); - auto arg0_buffer_ref = iree_hal_buffer_move_ref(arg0_buffer); - auto arg1_buffer_ref = iree_hal_buffer_move_ref(arg1_buffer); - IREE_ASSERT_OK(iree_vm_list_push_ref_move(inputs.get(), &arg0_buffer_ref)); - IREE_ASSERT_OK(iree_vm_list_push_ref_move(inputs.get(), &arg1_buffer_ref)); + auto arg0_buffer_view_ref = iree_hal_buffer_view_move_ref(arg0_buffer_view); + auto arg1_buffer_view_ref = iree_hal_buffer_view_move_ref(arg1_buffer_view); + IREE_ASSERT_OK( + iree_vm_list_push_ref_move(inputs.get(), &arg0_buffer_view_ref)); + IREE_ASSERT_OK( + iree_vm_list_push_ref_move(inputs.get(), &arg1_buffer_view_ref)); // Prepare outputs list to accept the results from the invocation. vm::ref outputs; @@ -165,17 +180,17 @@ TEST_P(SimpleEmbeddingTest, RunOnce) { // Get the result buffers from the invocation. IREE_LOG(INFO) << "Retrieving results..."; - auto* ret_buffer = - reinterpret_cast(iree_vm_list_get_ref_deref( - outputs.get(), 0, iree_hal_buffer_get_descriptor())); - ASSERT_NE(nullptr, ret_buffer); + auto* ret_buffer_view = + reinterpret_cast(iree_vm_list_get_ref_deref( + outputs.get(), 0, iree_hal_buffer_view_get_descriptor())); + ASSERT_NE(nullptr, ret_buffer_view); // Read back the results and ensure we got the right values. 
IREE_LOG(INFO) << "Reading back results..."; iree_hal_buffer_mapping_t mapped_memory; - IREE_ASSERT_OK(iree_hal_buffer_map_range(ret_buffer, - IREE_HAL_MEMORY_ACCESS_READ, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); + IREE_ASSERT_OK(iree_hal_buffer_map_range( + iree_hal_buffer_view_buffer(ret_buffer_view), IREE_HAL_MEMORY_ACCESS_READ, + 0, IREE_WHOLE_BUFFER, &mapped_memory)); ASSERT_THAT(absl::Span( reinterpret_cast(mapped_memory.contents.data), mapped_memory.contents.data_length / sizeof(float)),