diff --git a/CMakeLists.txt b/CMakeLists.txt index bc487287485cc..7bffd0ba1417f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -115,7 +115,6 @@ endif() # List of all HAL drivers to be built by default: set(IREE_ALL_HAL_DRIVERS DyLib - Metal VMLA Vulkan ) @@ -126,9 +125,6 @@ if(IREE_HAL_DRIVERS_TO_BUILD STREQUAL "all") # For Apple platforms we need to use Metal instead of Vulkan. if(APPLE) list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Vulkan) - else() - # And Metal isn't available on non-Apple platforms for sure. - list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Metal) endif() endif() message(STATUS "Building HAL drivers: ${IREE_HAL_DRIVERS_TO_BUILD}") diff --git a/bindings/java/com/google/iree/native/context_wrapper.cc b/bindings/java/com/google/iree/native/context_wrapper.cc index a277b3fda0e05..74f2ee50e8ed1 100644 --- a/bindings/java/com/google/iree/native/context_wrapper.cc +++ b/bindings/java/com/google/iree/native/context_wrapper.cc @@ -106,7 +106,7 @@ Status ContextWrapper::InvokeFunction(const FunctionWrapper& function_wrapper, IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create( input_buffer, /*shape=*/&input_element_count, /*shape_rank=*/1, IREE_HAL_ELEMENT_TYPE_FLOAT_32, - iree_allocator_system(), &input_buffer_view)); + &input_buffer_view)); iree_hal_buffer_release(input_buffer); // Marshal the input buffer views through the input VM variant list. @@ -132,13 +132,9 @@ Status ContextWrapper::InvokeFunction(const FunctionWrapper& function_wrapper, reinterpret_cast(iree_vm_list_get_ref_deref( outputs.get(), 0, iree_hal_buffer_view_get_descriptor())); auto* output_buffer = iree_hal_buffer_view_buffer(output_buffer_view); - iree_hal_mapped_memory_t mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map(output_buffer, - IREE_HAL_MEMORY_ACCESS_READ, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); - memcpy(output, mapped_memory.contents.data, - mapped_memory.contents.data_length); - iree_hal_buffer_unmap(output_buffer, &mapped_memory); + // TODO(jennik): this is unsafe - we don't know the size of output ptr here! 
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_read_data( + output_buffer, 0, output, iree_hal_buffer_byte_length(output_buffer))); return OkStatus(); } diff --git a/bindings/python/pyiree/rt/function_abi.cc b/bindings/python/pyiree/rt/function_abi.cc index 7ce19de37b308..2621501806242 100644 --- a/bindings/python/pyiree/rt/function_abi.cc +++ b/bindings/python/pyiree/rt/function_abi.cc @@ -590,7 +590,8 @@ void FunctionAbi::RawUnpack(absl::Span descs, throw RaiseValueError( "Could not deref result buffer view (wrong type?)"); } - iree_hal_buffer* raw_buffer = iree_hal_buffer_view_buffer(buffer_view); + iree_hal_buffer_t* raw_buffer = + iree_hal_buffer_view_buffer(buffer_view); if (!raw_buffer) { throw RaiseValueError("Could not deref result buffer (wrong type?)"); } @@ -675,10 +676,10 @@ void FunctionAbi::AllocateResults(absl::Span descs, kScalarTypeToHalElementType[static_cast( desc.scalar.type)]); iree_hal_buffer_view_t* buffer_view; - CheckApiStatus(iree_hal_buffer_view_create( - raw_buffer, dims.data(), dims.size(), element_type, - iree_allocator_system(), &buffer_view), - "Error allocating buffer_view"); + CheckApiStatus( + iree_hal_buffer_view_create(raw_buffer, dims.data(), dims.size(), + element_type, &buffer_view), + "Error allocating buffer_view"); iree_hal_buffer_release(raw_buffer); iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view); @@ -763,10 +764,10 @@ void FunctionAbi::PackBuffer(const RawSignatureParser::Description& desc, absl::InlinedVector dims(py_view.ndim); std::copy(py_view.shape, py_view.shape + py_view.ndim, dims.begin()); iree_hal_buffer_view_t* buffer_view; - CheckApiStatus(iree_hal_buffer_view_create( - raw_buffer, dims.data(), dims.size(), element_type, - iree_allocator_system(), &buffer_view), - "Error allocating buffer_view"); + CheckApiStatus( + iree_hal_buffer_view_create(raw_buffer, dims.data(), dims.size(), + element_type, &buffer_view), + "Error allocating buffer_view"); iree_hal_buffer_release(raw_buffer); iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view); CheckApiStatus(iree_vm_list_push_ref_move(f_args.raw_ptr(), &buffer_view_ref), diff --git a/bindings/python/pyiree/rt/hal.cc b/bindings/python/pyiree/rt/hal.cc index 858e29c4b37ba..a1ba69e0ddd5d 100644 --- a/bindings/python/pyiree/rt/hal.cc +++ b/bindings/python/pyiree/rt/hal.cc @@ -24,7 +24,7 @@ namespace { class HalMappedMemory { public: - HalMappedMemory(iree_hal_mapped_memory_t mapped_memory, + HalMappedMemory(iree_hal_buffer_mapping_t mapped_memory, iree_hal_buffer_view_t* bv) : mapped_memory_(mapped_memory), bv_(bv) { iree_hal_buffer_view_retain(bv_); @@ -32,7 +32,7 @@ class HalMappedMemory { ~HalMappedMemory() { if (bv_) { iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_); - IREE_CHECK_OK(iree_hal_buffer_unmap(buffer, &mapped_memory_)); + iree_hal_buffer_unmap_range(&mapped_memory_); iree_hal_buffer_view_release(bv_); } } @@ -44,10 +44,10 @@ class HalMappedMemory { static HalMappedMemory Create(HalBufferView& bv) { iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr()); iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer); - iree_hal_mapped_memory_t mapped_memory; - CheckApiStatus(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ, - 0 /* element_offset */, byte_length, - &mapped_memory), + iree_hal_buffer_mapping_t mapped_memory; + CheckApiStatus(iree_hal_buffer_map_range( + buffer, IREE_HAL_MEMORY_ACCESS_READ, + 0 /* element_offset */, byte_length, &mapped_memory), "Could not map memory"); return 
HalMappedMemory(mapped_memory, bv.raw_ptr()); } @@ -81,7 +81,7 @@ class HalMappedMemory { } private: - iree_hal_mapped_memory_t mapped_memory_; + iree_hal_buffer_mapping_t mapped_memory_; iree_hal_buffer_view_t* bv_; }; @@ -167,9 +167,6 @@ void SetupHalBindings(pybind11::module m) { py::class_(m, "MappedMemory", py::buffer_protocol()) .def_buffer(&HalMappedMemory::ToBufferInfo); py::class_(m, "HalBuffer") - .def_static("allocate_heap", &HalBuffer::AllocateHeapBuffer, - py::arg("memory_type"), py::arg("usage"), - py::arg("allocation_size")) .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"), py::arg("byte_length")) .def("create_view", &HalBuffer::CreateView, py::arg("shape"), diff --git a/bindings/python/pyiree/rt/hal.h b/bindings/python/pyiree/rt/hal.h index beee3b1627a0e..8411b39de0c7e 100644 --- a/bindings/python/pyiree/rt/hal.h +++ b/bindings/python/pyiree/rt/hal.h @@ -92,18 +92,6 @@ class HalBufferView class HalBuffer : public ApiRefCounted { public: - static HalBuffer AllocateHeapBuffer(int32_t memory_type, int32_t usage, - iree_host_size_t allocation_size) { - iree_hal_buffer_t* buffer = nullptr; - CheckApiStatus( - iree_hal_heap_buffer_allocate( - static_cast(memory_type), - static_cast(usage), allocation_size, - iree_allocator_system(), iree_allocator_system(), &buffer), - "Error allocating heap buffer"); - return HalBuffer::CreateRetained(buffer); - } - iree_device_size_t byte_length() const { return iree_hal_buffer_byte_length(raw_ptr()); } @@ -121,7 +109,7 @@ class HalBuffer : public ApiRefCounted { IREE_HAL_ELEMENT_TYPE_NONE, element_size * 8); CheckApiStatus( iree_hal_buffer_view_create(raw_ptr(), shape.s.data(), shape.s.size(), - element_type, iree_allocator_system(), &bv), + element_type, &bv), "Error creating buffer view"); return HalBufferView::CreateRetained(bv); } diff --git a/bindings/python/pyiree/rt/hal_test.py b/bindings/python/pyiree/rt/hal_test.py index 12c046081aba9..cd830268dba2f 100644 --- a/bindings/python/pyiree/rt/hal_test.py +++ b/bindings/python/pyiree/rt/hal_test.py @@ -25,27 +25,6 @@ def testEnums(self): logging.info("MemoryType: %s", rt.MemoryType) logging.info("HOST_VISIBLE: %s", int(rt.MemoryType.HOST_VISIBLE)) - def testAllocateHeap(self): - b = rt.HalBuffer.allocate_heap(memory_type=int(rt.MemoryType.HOST_LOCAL), - usage=int(rt.BufferUsage.ALL), - allocation_size=4096) - self.assertIsNot(b, None) - b.fill_zero(0, 4096) - shape = rt.Shape([1, 1024]) - unused_bv = b.create_view(shape, 4) - - def testStrideCalculation(self): - b = rt.HalBuffer.allocate_heap(memory_type=int(rt.MemoryType.HOST_LOCAL), - usage=int(rt.BufferUsage.ALL), - allocation_size=4096) - self.assertIsNot(b, None) - b.fill_zero(0, 4096) - shape = rt.Shape([16, 1, 8, 4, 2]) - bv = b.create_view(shape, 4) - self.assertEqual( - np.array(bv.map()).strides, - (1 * 8 * 4 * 2 * 4, 8 * 4 * 2 * 4, 4 * 2 * 4, 2 * 4, 4)) - if __name__ == "__main__": absltest.main() diff --git a/bindings/python/pyiree/rt/host_types.cc b/bindings/python/pyiree/rt/host_types.cc index de456eb518aa4..997585aa16c63 100644 --- a/bindings/python/pyiree/rt/host_types.cc +++ b/bindings/python/pyiree/rt/host_types.cc @@ -119,15 +119,14 @@ class PyMappedMemory { } }; - PyMappedMemory(Description desc, iree_hal_mapped_memory_t mapped_memory, + PyMappedMemory(Description desc, iree_hal_buffer_mapping_t mapped_memory, HalBuffer buffer) : desc_(std::move(desc)), mapped_memory_(mapped_memory), buf_(std::move(buffer)) {} ~PyMappedMemory() { if (buf_) { - CheckApiStatus(iree_hal_buffer_unmap(buf_.raw_ptr(), 
&mapped_memory_), - "Error unmapping memory"); + iree_hal_buffer_unmap_range(&mapped_memory_); } } PyMappedMemory(PyMappedMemory&& other) @@ -139,8 +138,8 @@ class PyMappedMemory { HalBuffer buffer) { iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer.raw_ptr()); - iree_hal_mapped_memory_t mapped_memory; - CheckApiStatus(iree_hal_buffer_map( + iree_hal_buffer_mapping_t mapped_memory; + CheckApiStatus(iree_hal_buffer_map_range( buffer.raw_ptr(), IREE_HAL_MEMORY_ACCESS_READ, 0 /* element_offset */, byte_length, &mapped_memory), "Could not map memory"); @@ -160,7 +159,7 @@ class PyMappedMemory { private: Description desc_; - iree_hal_mapped_memory_t mapped_memory_; + iree_hal_buffer_mapping_t mapped_memory_; HalBuffer buf_; }; diff --git a/build_tools/bazel/iree.bazelrc b/build_tools/bazel/iree.bazelrc index 618a1f989db0b..368e26204c363 100644 --- a/build_tools/bazel/iree.bazelrc +++ b/build_tools/bazel/iree.bazelrc @@ -159,26 +159,45 @@ build:macos_clang --per_file_copt=tensorflow,iree_tf_compiler@-Wno-range-loop-an # https://github.com/google/sanitizers/wiki/AddressSanitizer ############################################################################### -# Turn on asan. Some toolchains make use of the asan feature and we'll directly -# set the appropriate opts. +# ASAN (address sanitizer) +# https://clang.llvm.org/docs/AddressSanitizer.html +build:asan --config=sanitizer build:asan --features=asan build:asan --copt=-fsanitize=address +build:asan --copt=-fsanitize-address-use-after-scope build:asan --linkopt=-fsanitize=address +build:asan --cc_output_directory_tag=asan +build:asan --copt=-DADDRESS_SANITIZER + +# MSAN (memory sanitizer) +# https://clang.llvm.org/docs/MemorySanitizer.html +build:msan --config=sanitizer +build:msan --features=msan +build:msan --copt=-fsanitize=memory +build:msan --copt=-fsanitize-memory-track-origins +build:msan --linkopt=-fsanitize=memory +build:msan --cc_output_directory_tag=msan +build:msan --copt=-DMEMORY_SANITIZER + +# TSAN (thread sanitizer) +# https://clang.llvm.org/docs/ThreadSanitizer.html +build:tsan --config=sanitizer +build:tsan --features=tsan +build:tsan --copt=-fsanitize=thread +build:tsan --linkopt=-fsanitize=thread +build:tsan --cc_output_directory_tag=tsan +build:tsan --copt=-DTHREAD_SANITIZER # Don't strip debug info -build:asan --strip=never +build:sanitizer --strip=never # Ignore settings of `linkopts = ["-static"]` which can screw up the sanitizer. # We don't use this in IREE (that's what linkstatic is for), but it could show # up in dependencies. 
-build:asan --force_ignore_dash_static -# asan tests tend to take longer, so increase the timeouts -build:asan --test_timeout=120,600,1800,-1 -# Make the outputs easy to find -build:asan --cc_output_directory_tag=asan +build:sanitizer --force_ignore_dash_static +# sanitizer tests tend to take longer, so increase the timeouts +build:sanitizer --test_timeout=120,600,1800,-1 # Get better stack traces -build:asan --copt=-fno-omit-frame-pointer -# This macro define is used by absl -build:asan --copt=-DADDRESS_SANITIZER +build:sanitizer --copt=-fno-omit-frame-pointer ############################################################################### # Architecture specific options ############################################################################### diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py index e192a3d741835..465f215ba79fb 100644 --- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py +++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py @@ -61,7 +61,7 @@ "@renderdoc_api//:renderdoc_app": ["renderdoc_api::renderdoc_app"], "@pffft": ["pffft"], "@cpuinfo//:cpuinfo": ["cpuinfo"], - "@half//:half": ["half"], + "@half//:includes": [], "@vulkan_memory_allocator//:impl_header_only": ["vulkan_memory_allocator"], } diff --git a/build_tools/cmake/iree_macros.cmake b/build_tools/cmake/iree_macros.cmake index 22d4ce94c3a32..7f1c85a44faaf 100644 --- a/build_tools/cmake/iree_macros.cmake +++ b/build_tools/cmake/iree_macros.cmake @@ -306,10 +306,10 @@ function(iree_add_test_environment_properties TEST_NAME) # # Tests which only depend on a compiler target backend or a runtime HAL # driver, but not both, should generally use a different method of filtering. - if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV} OR NOT ${IREE_HAL_DRIVER_VULKAN}) + if(NOT "${IREE_TARGET_BACKEND_VULKAN-SPIRV}" OR NOT "${IREE_HAL_DRIVER_VULKAN}") set_property(TEST ${TEST_NAME} APPEND PROPERTY ENVIRONMENT "IREE_VULKAN_DISABLE=1") endif() - if(NOT ${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT} OR NOT ${IREE_HAL_DRIVER_DYLIB}) + if(NOT "${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}" OR NOT "${IREE_HAL_DRIVER_DYLIB}") set_property(TEST ${TEST_NAME} APPEND PROPERTY ENVIRONMENT "IREE_LLVMAOT_DISABLE=1") endif() endfunction() diff --git a/docs/design_docs/metal_hal_driver.md b/docs/design_docs/metal_hal_driver.md deleted file mode 100644 index 064e2bb3e1c85..0000000000000 --- a/docs/design_docs/metal_hal_driver.md +++ /dev/null @@ -1,371 +0,0 @@ -# Metal HAL Driver - -This document lists technical details regarding the Metal HAL driver. Note that -the Metal HAL driver is a work in progress; this document is expected to be -updated along the way. - -IREE provides a [Hardware Abstraction Layer (HAL)][iree-hal] as a common -interface to different compute accelerators. IREE HAL's design draws inspiration -from modern GPU architecture and APIs; so implementing a HAL driver using modern -GPU APIs is generally straightforward. This applies to the Metal HAL driver. - -## Overall Design Choices - -### Metal Versions - -The Metal HAL driver expects Metal 2+. Metal 2 introduces useful features like -argument buffers, performance shaders, and others that can improve performance -and make the IREE HAL implementation simpler. Metal 2 was released in late 2017 and is -supported since macOS High Sierra and iOS 11; it is already the dominant version -([macOS][macos-version-share], [iOS][ios-version-share]). - -### Programming Languages and Libraries - -The Metal HAL driver lives under the [`iree/hal/metal/`][iree-metal] directory.
-Header (`.h`) and implementation (`.mm`) files are put adjacent to each other. - -The Metal framework only exposes Objective-C or Swift programming language APIs. -The Metal HAL driver needs to inherit from common HAL abstraction classes, which are -C++. So we use [Objective-C++][objcxx] for implementing the Metal HAL driver. -The headers try to stay with pure C/C++ syntax when possible, except for -`#import <Metal/Metal.h>` and using Metal `id<MTL*>` types. - -### Object Lifetime Management - -Objective-C uses reference counting for tracking object lifetime and managing memory. This -is traditionally done manually by sending `retain` and `release` messages to -Objective-C objects. Modern Objective-C allows developers to opt in to -[Automatic Reference Counting][objc-arc] to let the compiler automatically -deduce and insert `retain`/`release` where possible to simplify the burden of -manual management. - -We don't use ARC in the Metal HAL driver given that IREE has its own object -[refcount][iree-refptr] and lifetime management mechanism. Metal HAL GPU objects -are tracked with that to be consistent with others. Each Metal HAL GPU object -`retain`s the underlying Metal `id<MTL*>` object on construction and `release`s -on destruction. - -## GPU Objects - -Metal is one of the main modern GPU APIs that provide more explicit control over -the hardware. The mapping between IREE HAL classes and Metal protocols is -relatively straightforward: - -IREE HAL Class | Metal Protocol -:----------------------------------------: | :------------: -[`hal::Driver`][hal-driver] | N/A -[`hal::Device`][hal-device] | [`MTLDevice`][mtl-device] -[`hal::CommandQueue`][hal-command-queue] | [`MTLCommandQueue`][mtl-command-queue] -[`hal::CommandBuffer`][hal-command-buffer] | [`MTLCommandBuffer`][mtl-command-buffer] -[`hal::Semaphore`][hal-semaphore] | [`MTLSharedEvent`][mtl-shared-event] -[`hal::Allocator`][hal-allocator] | N/A -[`hal::Buffer`][hal-buffer] | [`MTLBuffer`][mtl-buffer] -[`hal::Executable`][hal-executable] | [`MTLLibrary`][mtl-library] -[`hal::ExecutableCache`][hal-executable-cache] | N/A -[`hal::DescriptorSetLayout`][hal-descriptor-set-layout] | N/A -[`hal::DescriptorSet`][hal-descriptor-set] | N/A -[`hal::ExecutableLayout`][hal-executable-layout] | N/A - -In the following subsections, we go over each pair to provide more details. - -### Driver - -There is no native driver abstraction in Metal. IREE's Metal HAL driver still -provides a [`hal::metal::MetalDriver`][metal-driver] subclass inheriting from the -common [`hal::Driver`][hal-driver] class. `hal::metal::MetalDriver` just -`retain`s all available Metal devices in the system during its lifetime to -provide a similar interface as other HAL drivers. - -### Device - -[`hal::metal::MetalDevice`][metal-device] inherits [`hal::Device`][hal-device] -to provide the interface to a Metal GPU device by wrapping an `id<MTLDevice>`. Upon -construction, `hal::metal::MetalDevice` creates and retains one queue for both -dispatch and transfer during its lifetime. - -Metal requires command buffers to be created from a `MTLCommandQueue`. In IREE -HAL, command buffers are directly created from the `hal::Device`. -`hal::metal::MetalDevice` chooses the proper queue to create the command buffer -under the hood. - -### Command queue - -IREE HAL command queue follows Vulkan for modelling submission. Specifically, -`hal::CommandQueue::Submit()` takes a `SubmissionBatch`, which contains a list -of waiting `hal::Semaphore`s, a list of command buffers, and a list of signaling -`hal::Semaphore`s.
There is no direct mapping in Metal; so -[`hal::metal::MetalCommandQueue`][metal-command-queue] performs the submission -in three steps: - -1. Create a new `MTLCommandBuffer` to `encodeWaitForEvent:value:` for all - waiting `hal::Semaphore`s and commit this command buffer. -1. Commit all command buffers in the `SubmissionBatch`. -1. Create a new `MTLCommandBuffer` to `encodeSignalEvent:value:` for all - signaling `hal::Semaphore`s and commit this command buffer. - -There is also no direct `WaitIdle()` for -[`MTLCommandQueue`][mtl-command-queue]s. `hal::metal::MetalCommandQueue` -implements `WaitIdle()` by committing an empty `MTLCommandBuffer` and -registering a completion handler for it to signal a semaphore that wakes the current -thread, which is put to sleep by waiting on the semaphore. - -### Command buffer - -In Metal, commands are recorded into a command buffer with four different kinds -of [command encoders][mtl-command-encoder]: `MTLRenderCommandEncoder`, -`MTLComputeCommandEncoder`, `MTLBlitCommandEncoder`, and -`MTLParallelRenderCommandEncoder`. Each encoder has its own create/end call. -There is no overall begin/end call for the whole command buffer. So even though -[`hal::metal::MetalCommandBuffer`][metal-command-buffer] implements an overall -`Begin()`/`End()` call, under the hood it may create a new command encoder for a -specific API call. - -### Timeline semaphore - -[`hal::Semaphore`][hal-semaphore] allows host->device, device->host, host->host, -and device->device synchronization. It maps to the Vulkan timeline semaphore. In -the Metal world, the counterpart would be [`MTLSharedEvent`][mtl-shared-event]. Most -of the `hal::Semaphore` APIs are simple to implement in -[`MetalSharedEvent`][metal-shared-event], with `Wait()` as an exception. A -listener is registered on the `MTLSharedEvent` with -`notifyListener:atValue:block:` to signal a semaphore that wakes the current -thread, which is put to sleep by waiting on the semaphore. - -### Allocator - -At the moment the Metal HAL driver just has a very simple -[`hal::Allocator`][hal-allocator] implementation. It wraps an `MTLDevice` -and redirects all allocation requests to the `MTLDevice`, with no paging, pooling, or -slab allocation. This is only meant to get started. In the future we should have a -better memory allocation library, probably by layering the -[Vulkan Memory Allocator][vma] on top of [`MTLHeap`][mtl-heap]. - -### Buffer - -IREE [`hal::Buffer`][hal-buffer] maps to Metal `MTLBuffer`. See -[Memory Management](#memory-management) for more details. - -### Executable - -IREE [`hal::Executable`][hal-executable] represents a GPU program archive with -a driver-defined format. It maps naturally to Metal [`MTLLibrary`][mtl-library]. -An entry point in a `MTLLibrary` is a [`MTLFunction`][mtl-function]. We define -[`hal::metal::MetalKernelLibrary`][metal-kernel-library] to wrap around a -`MTLLibrary`, its `MTLFunction`s, and also `MTLComputePipelineState` objects -constructed from `MTLFunction`s (see the sketch below). - -### Executable cache - -IREE [`hal::ExecutableCache`][hal-executable-cache] models a cache of -prepared GPU executables for a particular device. At the moment the Metal -HAL driver does not perform any caching of GPU programs; it simply reads the -program from the FlatBuffer and hands it over to the Metal driver. - -### DescriptorSetLayout, DescriptorSet, ExecutableLayout - -See [Resource descriptors](#resource-descriptors) for more details.
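For illustration, here is a minimal Objective-C++ sketch of the path described in the Executable section above: compiling MSL source into a `MTLLibrary` and building a `MTLComputePipelineState` from one of its `MTLFunction`s. The variable names (`device`, `msl_source`, `entry_point_name`) are hypothetical, and this is a conceptual sketch rather than the actual `MetalKernelLibrary` implementation:

```objc
#import <Metal/Metal.h>

// Compile MSL source at run-time into an opaque library.
NSError* error = nil;
id<MTLLibrary> library = [device newLibraryWithSource:msl_source
                                              options:nil
                                                error:&error];
// Look up one entry point (a MTLFunction) in the library.
id<MTLFunction> function = [library newFunctionWithName:entry_point_name];
// Build the compute pipeline state used later for dispatching.
id<MTLComputePipelineState> pipeline =
    [device newComputePipelineStateWithFunction:function error:&error];
```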
- -## Compute Pipeline - -### Shader/kernel compilation - -Metal has [Metal Shading Language (MSL)][msl-spec] for authoring graphics -shaders and compute kernels. MSL source code can be directly consumed by the -Metal framework at run-time; it can also be compiled first into an opaque -library using [command-line tools][msl-cl-library] at build-time. - -IREE uses compilers to compile ML models expressed with high-level op semantics -down to a GPU-native source format. This is also the case for the Metal HAL -driver. Metal does not provide an open intermediate language; we reuse the -[SPIR-V code generation pipeline][spirv-codegen] and then cross compile the -generated SPIR-V into MSL source with [SPIRV-Cross][spirv-cross]. This is -actually a fairly common practice for targeting multiple GPU APIs in the graphics -programming world. For example, the Vulkan implementation on macOS/iOS, -[MoltenVK][moltenvk], does the same for shaders/kernels. The path -is actually quite robust, as demonstrated by various games on top of MoltenVK. - -Therefore, in IREE, we have a [`MetalSPIRVTargetBackend`][metal-spirv-target], -which pulls in the normal MHLO to Linalg and Linalg to SPIR-V passes to form -the compilation pipeline. The difference would be to provide a suitable -SPIR-V target environment to drive the compilation, which one can derive from -the Metal GPU families to target. (Not implemented yet; TODO for the future.) -The serialization step differs from -[`VulkanSPIRVTargetBackend`][vulkan-spirv-target] too: following the normal -SPIR-V serialization step, we additionally need to invoke SPIRV-Cross to -cross compile the generated SPIR-V into MSL, and then compile and/or serialize -the MSL source/library. - -IREE uses [FlatBuffer][flatbuffer] to encode the whole workload module, -including both the GPU shaders/kernels (called executables in IREE terminology) and -the CPU scheduling logic. The GPU executables are embedded as part of the module's -FlatBuffer, which is [`mmap`][mmap]ped when IREE runs. - -For the Metal HAL driver, it means we need to embed the MSL kernels inside the -module FlatBuffer. Right now we just encode the MSL source strings and compile -them at Metal run-time. In the future this should be changed to allow encoding -the library instead. - -### Workgroup/threadgroup size - -When dispatching a compute kernel in Metal, we need to specify the number of -thread groups in the grid and the number of threads in each thread group. Both are -3-D vectors. IREE HAL, which follows Vulkan, calls them workgroup count and -workgroup size, respectively. - -In the Vulkan programming model, workgroup count and workgroup size are specified at -different places: the former is given when invoking -[`vkCmdDispatch()`][vulkan-cmd-dispatch], while the latter is encoded in the -dispatched SPIR-V code. This split does not match the Metal model, where we -specify both in the API with `dispatchThreads:threadsPerThreadgroup:`. - -As discussed in [shader/kernel compilation](#shader-kernel-compilation), MSL kernels -are cross compiled from SPIR-V code and then embedded in the module FlatBuffer. -The module FlatBuffer provides us with a way to convey the threadgroup/workgroup size -information extracted from the SPIR-V code. We encode an additional 3-D vector -for each entry point and use it as the threadgroup size when later dispatching -the `MTLFunction` corresponding to the entry point. - -### Resource descriptors - -A descriptor is an opaque handle pointing to a resource that is accessed in -the compute kernel.
IREE's HAL is inspired by the Vulkan API; it models several -concepts related to GPU resource management explicitly: - -* [`hal::DescriptorSetLayout`][hal-descriptor-set-layout]: a schema for - describing an array of descriptor bindings. Each descriptor binding specifies - the resource type, access mode, and other information. -* [`hal::DescriptorSet`][hal-descriptor-set]: a concrete set of resources that - gets bound to a compute pipeline in a batch. It must match the - `DescriptorSetLayout` describing its layout. `DescriptorSet` can be thought of as - the "object" from the `DescriptorSetLayout` "class". -* [`hal::ExecutableLayout`][hal-executable-layout]: a schema for describing all - the resources accessed by a compute pipeline. It includes zero or more - `DescriptorSetLayout`s and (optional) push constants. - -One can create `DescriptorSetLayout`, `DescriptorSet`, and `ExecutableLayout` -objects beforehand to avoid incurring overhead during tight computing loops -and also amortize costs by sharing these objects. However, this doesn't fully -match Metal's paradigm. - -In the Metal framework, the closest concept to `DescriptorSet` would be the [argument -buffer][mtl-argument-buffer]. There is no direct correspondence to -`DescriptorSetLayout` and `ExecutableLayout`. Rather, the layout is implicitly -encoded in Metal shaders as MSL structs. The APIs for creating argument buffers -do not encourage early creation without pipelines: one typically creates them -for each `MTLFunction`. Besides, unlike Vulkan where different descriptor sets -can have the same binding number, in Metal even if we have multiple argument -buffers, the indices for resources are in the same namespace and are typically -assigned sequentially. That means we need to remap `DescriptorSet`s with a set -number greater than zero by applying an offset to each of their bindings. - -All of this means it's better to defer the creation of the argument buffer -until the point of compute pipeline creation and dispatch. Therefore, although -the Metal HAL driver provides the implementation for `DescriptorSet` -(i.e., `hal::metal::MetalArgumentBuffer`), `DescriptorSetLayout` (i.e., -`hal::metal::MetalArgumentBufferLayout`), and `ExecutableLayout` (i.e., -`hal::metal::MetalPipelineArgumentBufferLayout`), they are just containers -holding the information up until the [command buffer -dispatch](#command-buffer-dispatch) time. - -With the above said, the overall idea is still to map one descriptor set to one -argument buffer. It just means we need to condense and remap the bindings. - -### Command buffer dispatch - -`MetalCommandBuffer::Dispatch()` performs the following steps with the current -active `MTLComputeCommandEncoder`: - -1. Bind the `MTLComputePipelineState` for the current entry function queried - from `MetalKernelLibrary`. -1. For each bound descriptor set at set #`S`: - 1. Create a [`MTLArgumentEncoder`][mtl-argument-encoder] for encoding an - associated argument `MTLBuffer`. - 1. For each bound resource buffer at binding #`B` in this descriptor set, - encode it to the argument buffer index #`B` with - `setBuffer:offset:atIndex:` and inform the `MTLComputeCommandEncoder` - that the dispatch will use this resource with `useResource:usage:`. - 1. Set the argument `MTLBuffer` to buffer index #`S`. -1. Dispatch with `dispatchThreadgroups:threadsPerThreadgroup:` (see the sketch below).
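To make the steps above concrete, here is a minimal Objective-C++ sketch of encoding one descriptor set as an argument buffer and dispatching. All names (`command_buffer`, `pipeline`, `function`, `argument_buffer`, `resource_buffer`, `set_number`, `binding_number`) and the example workgroup dimensions are illustrative assumptions, not the actual `MetalCommandBuffer` code:

```objc
#import <Metal/Metal.h>

id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
[encoder setComputePipelineState:pipeline];

// Encode one bound resource into the argument buffer for this set.
id<MTLArgumentEncoder> arg_encoder =
    [function newArgumentEncoderWithBufferIndex:set_number];
[arg_encoder setArgumentBuffer:argument_buffer offset:0];
[arg_encoder setBuffer:resource_buffer offset:0 atIndex:binding_number];
[encoder useResource:resource_buffer usage:MTLResourceUsageRead];

// Bind the argument buffer itself at the set's index, then dispatch using the
// workgroup count from the API call and the workgroup size extracted from the
// SPIR-V code and carried in the module FlatBuffer.
[encoder setBuffer:argument_buffer offset:0 atIndex:set_number];
[encoder dispatchThreadgroups:MTLSizeMake(4, 1, 1)
        threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
[encoder endEncoding];
```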
- -(TODO: condense and remap bindings) - -## Memory Management - -### Storage type - -Metal provides four [`MTLStorageMode`][mtl-storage-mode] options: - -* `MTLStorageModeShared`: The resource is stored in system memory and is - accessible to both the CPU and the GPU. -* `MTLStorageModeManaged`: The CPU and GPU may maintain separate copies of the - resource, and any changes must be explicitly synchronized. -* `MTLStorageModePrivate`: The resource can be accessed only by the GPU. -* `MTLStorageModeMemoryless`: The resource’s contents can be accessed only by the - GPU and only exist temporarily during a render pass. - -Among them, `MTLStorageModeManaged` is only available on macOS. - -IREE HAL defines several [`MemoryType`][hal-buffer]s. They need to map to the -above storage modes: - -* If `kDeviceLocal` but not `kHostVisible`, `MTLStorageModePrivate` is chosen. -* If `kDeviceLocal` and `kHostVisible`: - * If macOS, `MTLStorageModeManaged` can be chosen. - * Otherwise, `MTLStorageModeShared` is chosen. -* If not `kDeviceLocal` but `kDeviceVisible`, `MTLStorageModeShared` is chosen. -* If not `kDeviceLocal` and not `kDeviceVisible`, `MTLStorageModeShared` is - chosen. (TODO: We should probably use a host buffer here.) - -IREE HAL also allows creating buffers with the `kHostCoherent` bit. This may still -be backed by `MTLStorageModeManaged` `MTLBuffer`s on macOS. To respect the -`kHostCoherent` protocol, the Metal HAL driver will perform necessary -`Invalidate`/`Flush` operations automatically under the hood. - -[macos-version-share]: https://gs.statcounter.com/macos-version-market-share/desktop/worldwide -[ios-version-share]: https://developer.apple.com/support/app-store/ -[iree-hal]: https://github.com/google/iree/tree/main/iree/hal -[iree-metal]: https://github.com/google/iree/tree/main/iree/hal/metal -[iree-refptr]: https://github.com/google/iree/blob/main/iree/base/ref_ptr.h -[hal-allocator]: https://github.com/google/iree/blob/main/iree/hal/allocator.h -[hal-buffer]: https://github.com/google/iree/blob/main/iree/hal/buffer.h -[hal-command-queue]: https://github.com/google/iree/blob/main/iree/hal/command_queue.h -[hal-command-buffer]: https://github.com/google/iree/blob/main/iree/hal/command_buffer.h -[hal-descriptor-set]: https://github.com/google/iree/blob/main/iree/hal/descriptor_set.h -[hal-descriptor-set-layout]: https://github.com/google/iree/blob/main/iree/hal/descriptor_set_layout.h -[hal-executable-layout]: https://github.com/google/iree/blob/main/iree/hal/executable_layout.h -[hal-device]: https://github.com/google/iree/blob/main/iree/hal/device.h -[hal-driver]: https://github.com/google/iree/blob/main/iree/hal/driver.h -[hal-executable]: https://github.com/google/iree/blob/main/iree/hal/executable.h -[hal-executable-cache]: https://github.com/google/iree/blob/main/iree/hal/executable_cache.h -[hal-semaphore]: https://github.com/google/iree/blob/main/iree/hal/semaphore.h -[metal-command-queue]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_command_queue.h -[metal-command-buffer]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_command_buffer.h -[metal-device]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_device.h -[metal-driver]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_driver.h -[metal-kernel-library]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_kernel_library.h -[metal-shared-event]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_shared_event.h -[metal-spirv-target]:
https://github.com/google/iree/tree/hal-metal/iree/compiler/Dialect/HAL/Target/MetalSPIRV -[mtl-argument-buffer]: https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc -[mtl-argument-encoder]: https://developer.apple.com/documentation/metal/mtlargumentencoder?language=objc -[mtl-buffer]: https://developer.apple.com/documentation/metal/mtlbuffer?language=objc -[mtl-command-buffer]: https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc -[mtl-command-encoder]: https://developer.apple.com/documentation/metal/mtlcommandencoder?language=objc -[mtl-command-queue]: https://developer.apple.com/documentation/metal/mtlcommandqueue?language=objc -[mtl-device]: https://developer.apple.com/documentation/metal/mtldevice?language=objc -[mtl-function]: https://developer.apple.com/documentation/metal/mtlfunction?language=objc -[mtl-heap]: https://developer.apple.com/documentation/metal/mtlheap?language=objc -[mtl-library]: https://developer.apple.com/documentation/metal/mtllibrary?language=objc -[mtl-shared-event]: https://developer.apple.com/documentation/metal/mtlsharedevent?language=objc -[mtl-storage-mode]: https://developer.apple.com/documentation/metal/mtlstoragemode?language=objc -[msl-spec]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf -[msl-cl-library]: https://developer.apple.com/documentation/metal/libraries/building_a_library_with_metal_s_command-line_tools?language=objc -[objc-arc]: https://en.wikipedia.org/wiki/Automatic_Reference_Counting -[objcxx]: https://en.wikipedia.org/wiki/Objective-C#Objective-C++ -[flatbuffer]: https://google.github.io/flatbuffers/ -[mmap]: https://en.wikipedia.org/wiki/Mmap -[moltenvk]: https://github.com/KhronosGroup/MoltenVK -[spirv-codegen]: https://google.github.io/iree/design-docs/codegen-passes -[spirv-cross]: https://github.com/KhronosGroup/SPIRV-Cross -[vma]: https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator -[vulkan-spirv-target]: https://github.com/google/iree/tree/hal-metal/iree/compiler/Dialect/HAL/Target/VulkanSPIRV -[vulkan-cmd-dispatch]: https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/vkCmdDispatch.html diff --git a/docs/design_docs/simple_ir_walkthrough.md b/docs/design_docs/simple_ir_walkthrough.md index 68f51a895292a..5d7bb75f81368 100644 --- a/docs/design_docs/simple_ir_walkthrough.md +++ b/docs/design_docs/simple_ir_walkthrough.md @@ -414,10 +414,8 @@ class SimpleMulModule : public iree::vm::Module { // Matches IR: // %0 = "iree_ll_seq.alloc_heap"() : () -> memref<4xf32> ASSIGN_OR_RETURN(auto result, device->allocator()->Allocate( - iree::hal::MemoryType::kHostLocal | - iree::hal::MemoryType::kDeviceVisible, - iree::hal::BufferUsage::kDispatch | - iree::hal::BufferUsage::kMapping)); + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + IREE_HAL_BUFFER_USAGE_DISPATCH | IREE_HAL_BUFFER_USAGE_MAPPING)); auto result_view = iree::hal::BufferView( std::move(result), {4}, sizeof(float)); @@ -468,8 +466,8 @@ class SimpleMulModule : public iree::vm::Module { // Matches IR: // iree_ll_seq.static_dispatch ... 
ASSIGN_OR_RETURN(auto cmd, device->CreateCommandBuffer( - iree::hal::CommandBufferMode::kOneShot, - iree::hal::CommandCategory::kDispatch)); + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH)); RETURN_IF_ERROR(cmd->Begin()); iree::hal::DispatchRequest dispatch_request; dispatch_request.executable = device_executable(device, 0); diff --git a/iree/base/BUILD b/iree/base/BUILD index d4362934fad76..bedf00bc2c430 100644 --- a/iree/base/BUILD +++ b/iree/base/BUILD @@ -47,21 +47,25 @@ cc_library( hdrs = [ "alignment.h", "atomics.h", - "bitfield.h", "debugging.h", "math.h", "memory.h", "target_platform.h", ], deps = [ + "//iree/base/internal:atomics", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/types:span", # bitfield.h ], ) +cc_library( + name = "target_platform", + hdrs = ["target_platform.h"], +) + cc_test( - name = "bitfield_test", - srcs = ["bitfield_test.cc"], + name = "atomics_test", + srcs = ["atomics_test.cc"], deps = [ ":core_headers", "//iree/testing:gtest", @@ -407,24 +411,6 @@ cc_test( ], ) -cc_library( - name = "time", - hdrs = ["time.h"], - deps = [ - ":api", - ], -) - -cc_test( - name = "time_test", - srcs = ["time_test.cc"], - deps = [ - ":time", - "//iree/testing:gtest", - "//iree/testing:gtest_main", - ], -) - cc_library( name = "threading", srcs = [ diff --git a/iree/base/CMakeLists.txt b/iree/base/CMakeLists.txt index 6620197119d4b..cad0757d53994 100644 --- a/iree/base/CMakeLists.txt +++ b/iree/base/CMakeLists.txt @@ -35,26 +35,22 @@ iree_cc_library( HDRS "alignment.h" "atomics.h" - "bitfield.h" "debugging.h" "math.h" "memory.h" "target_platform.h" DEPS absl::core_headers - absl::span + iree::base::internal::atomics PUBLIC ) -iree_cc_test( +iree_cc_library( NAME - bitfield_test - SRCS - "bitfield_test.cc" - DEPS - ::core_headers - iree::testing::gtest - iree::testing::gtest_main + target_platform + HDRS + "target_platform.h" + PUBLIC ) iree_cc_library( @@ -106,6 +102,17 @@ iree_cc_test( iree::testing::gtest_main ) +iree_cc_test( + NAME + atomics_test + SRCS + "atomics_test.cc" + DEPS + ::core_headers + iree::testing::gtest + iree::testing::gtest_main +) + iree_cc_library( NAME dynamic_library @@ -438,27 +445,6 @@ iree_cc_test( iree::testing::gtest_main ) -iree_cc_library( - NAME - time - HDRS - "time.h" - DEPS - ::api - PUBLIC -) - -iree_cc_test( - NAME - time_test - SRCS - "time_test.cc" - DEPS - ::time - iree::testing::gtest - iree::testing::gtest_main -) - iree_cc_library( NAME threading diff --git a/iree/base/api.h b/iree/base/api.h index cf5909cd862ef..ada2dec13a16b 100644 --- a/iree/base/api.h +++ b/iree/base/api.h @@ -206,7 +206,7 @@ typedef size_t iree_host_size_t; // Size, in bytes, of a buffer on devices. typedef uint64_t iree_device_size_t; // Whole length of the underlying buffer. -#define IREE_WHOLE_BUFFER (iree_device_size_t(-1)) +#define IREE_WHOLE_BUFFER ((iree_device_size_t)(-1)) // TODO(benvanik): switch to static_cast/reinterpret_cast when in C++. // TODO(benvanik): see if we can shove in static_asserts somehow? @@ -232,6 +232,11 @@ static inline iree_host_size_t iree_math_align(iree_host_size_t value, #define iree_min(lhs, rhs) ((lhs) <= (rhs) ? (lhs) : (rhs)) #define iree_max(lhs, rhs) ((lhs) <= (rhs) ? (rhs) : (lhs)) +// Returns true if any bit from |rhs| is set in |lhs|. +#define iree_any_bit_set(lhs, rhs) (((lhs) & (rhs)) != 0) +// Returns true iff all bits from |rhs| are set in |lhs|. 
+#define iree_all_bits_set(lhs, rhs) (((lhs) & (rhs)) == (rhs)) + //===----------------------------------------------------------------------===// // Byte buffers and memory utilities //===----------------------------------------------------------------------===// @@ -377,7 +382,7 @@ iree_string_view_append_to_buffer(iree_string_view_t source_value, #if !defined(IREE_STATUS_MODE) #ifdef NDEBUG // Release mode: just source location. -#define IREE_STATUS_MODE 1 +#define IREE_STATUS_MODE 2 #else // Debug mode: annotations and stack traces. #define IREE_STATUS_MODE 3 @@ -446,7 +451,7 @@ typedef enum { // meaning `return iree_status_from_code(IREE_STATUS_INTERNAL);` (etc) is valid, // though not as useful as constructing via iree_make_status (which captures // additional info). -typedef void* iree_status_t; +typedef struct iree_status_handle_t* iree_status_t; // Returns an iree_status_t from the an iree_status_code_t. #define iree_status_from_code(code) \ diff --git a/iree/base/atomic_slist.c b/iree/base/atomic_slist.c index c495c7ecde089..025eebaaa5c8e 100644 --- a/iree/base/atomic_slist.c +++ b/iree/base/atomic_slist.c @@ -59,7 +59,10 @@ void iree_atomic_slist_push_unsafe(iree_atomic_slist_t* list, iree_atomic_slist_entry_t* iree_atomic_slist_pop(iree_atomic_slist_t* list) { iree_slim_mutex_lock(&list->mutex); iree_atomic_slist_entry_t* entry = list->head; - list->head = entry ? entry->next : NULL; + if (entry != NULL) { + list->head = entry->next; + entry->next = NULL; + } iree_slim_mutex_unlock(&list->mutex); return entry; } diff --git a/iree/base/atomics.h b/iree/base/atomics.h index 2c49d00a522cf..cd5008c8a7f41 100644 --- a/iree/base/atomics.h +++ b/iree/base/atomics.h @@ -57,229 +57,30 @@ extern "C" { #define iree_hardware_constructive_interference_size 64 //============================================================================== -// Atomics using the Win32 Interlocked* APIs +// C11-compatible atomic operations //============================================================================== -#if defined(IREE_COMPILER_MSVC) - -typedef enum iree_memory_order_e { - iree_memory_order_relaxed, - iree_memory_order_consume, - iree_memory_order_acquire, - iree_memory_order_release, - iree_memory_order_acq_rel, - iree_memory_order_seq_cst, -} iree_memory_order_t; - -#define IREE_ATOMIC_VAR_INIT(value) \ - { (value) } - -typedef struct { - int32_t __val; -} iree_atomic_int32_t; -typedef struct { - int64_t __val; -} iree_atomic_int64_t; -// typedef __declspec(align(16)) struct { -// uint64_t __val[2]; -// } iree_atomic_int128_t; - -#define iree_atomic_load_int32(object, order) \ - InterlockedExchangeAdd((volatile LONG*)object, 0) -#define iree_atomic_store_int32(object, desired, order) \ - InterlockedExchange((volatile LONG*)object, desired) -#define iree_atomic_fetch_add_int32(object, operand, order) \ - InterlockedExchangeAdd((volatile LONG*)object, operand) -#define iree_atomic_fetch_sub_int32(object, operand, order) \ - InterlockedExchangeAdd((volatile LONG*)object, -((int32_t)(operand))) -#define iree_atomic_fetch_and_int32(object, operand, order) \ - InterlockedAnd((volatile LONG*)object, operand) -#define iree_atomic_fetch_or_int32(object, operand, order) \ - InterlockedOr((volatile LONG*)object, operand) -#define iree_atomic_fetch_xor_int32(object, operand, order) \ - InterlockedXor((volatile LONG*)object, operand) -#define iree_atomic_exchange_int32(object, desired, order) \ - InterlockedExchange((volatile LONG*)object, desired) -#define 
iree_atomic_compare_exchange_strong_int32(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int32_impl( \ - (volatile iree_atomic_int32_t*)(object), (int32_t*)(expected), \ - (int32_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_int32 \ - iree_atomic_compare_exchange_strong_int32 - -#define iree_atomic_load_int64(object, order) \ - InterlockedExchangeAdd64((volatile LONG64*)object, 0) -#define iree_atomic_store_int64(object, desired, order) \ - InterlockedExchange64((volatile LONG64*)object, desired) -#define iree_atomic_fetch_add_int64(object, operand, order) \ - InterlockedExchangeAdd64((volatile LONG64*)object, operand) -#define iree_atomic_fetch_sub_int64(object, operand, order) \ - InterlockedExchangeAdd64((volatile LONG64*)object, -(operand)) -#define iree_atomic_fetch_and_int64(object, operand, order) \ - InterlockedAnd64((volatile LONG64*)object, operand) -#define iree_atomic_fetch_or_int64(object, operand, order) \ - InterlockedOr64((volatile LONG64*)object, operand) -#define iree_atomic_fetch_xor_int64(object, operand, order) \ - InterlockedXor64((volatile LONG64*)object, operand) -#define iree_atomic_exchange_int64(object, desired, order) \ - InterlockedExchange64((volatile LONG64*)object, desired) -#define iree_atomic_compare_exchange_strong_int64(object, expected, desired, \ - order_succ, order_fail) \ - iree_atomic_compare_exchange_strong_int64_impl( \ - (volatile iree_atomic_int64_t*)(object), (int64_t*)(expected), \ - (int64_t)(desired), (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_int64 \ - iree_atomic_compare_exchange_strong_int64 - -#define iree_atomic_thread_fence(order) MemoryBarrier() - -static inline bool iree_atomic_compare_exchange_strong_int32_impl( - volatile iree_atomic_int32_t* object, int32_t* expected, int32_t desired, - iree_memory_order_t order_succ, iree_memory_order_t order_fail) { - int32_t expected_value = *expected; - int32_t old_value = InterlockedCompareExchange((volatile LONG*)object, - desired, expected_value); - if (old_value == expected_value) { - return true; - } else { - *expected = old_value; - return false; - } -} +// We expose support for int32_t, int64_t, and intptr_t (which aliases one of +// int32_t or int64_t). This limits what we need to port and it's really all +// that's needed anyway. -static inline bool iree_atomic_compare_exchange_strong_int64_impl( - volatile iree_atomic_int64_t* object, int64_t* expected, int64_t desired, - iree_memory_order_t order_succ, iree_memory_order_t order_fail) { - int64_t expected_value = *expected; - int64_t old_value = InterlockedCompareExchange64((volatile LONG64*)object, - desired, expected_value); - if (old_value == expected_value) { - return true; - } else { - *expected = old_value; - return false; - } -} +#if defined(IREE_COMPILER_MSVC) -#define iree_atomic_thread_fence(order) MemoryBarrier() +// Atomics using the Win32 Interlocked* APIs. 
+#include "iree/base/internal/atomics_msvc.h" -//============================================================================== -// C11 atomics using Clang builtins -//============================================================================== #elif defined(IREE_COMPILER_CLANG) -typedef enum iree_memory_order_e { - iree_memory_order_relaxed = __ATOMIC_RELAXED, - iree_memory_order_consume = __ATOMIC_CONSUME, - iree_memory_order_acquire = __ATOMIC_ACQUIRE, - iree_memory_order_release = __ATOMIC_RELEASE, - iree_memory_order_acq_rel = __ATOMIC_ACQ_REL, - iree_memory_order_seq_cst = __ATOMIC_SEQ_CST, -} iree_memory_order_t; - -#define IREE_ATOMIC_VAR_INIT(value) (value) - -typedef _Atomic int32_t iree_atomic_int32_t; -typedef _Atomic int64_t iree_atomic_int64_t; -// TODO(#3453): check for __int128 support before using -// typedef _Atomic __int128 iree_atomic_int128_t; +// C11 atomics using Clang builtins. +#include "iree/base/internal/atomics_clang.h" -#define iree_atomic_load_auto(object, order) \ - __c11_atomic_load((object), (order)) -#define iree_atomic_store_auto(object, desired, order) \ - __c11_atomic_store((object), (desired), (order)) -#define iree_atomic_fetch_add_auto(object, operand, order) \ - __c11_atomic_fetch_add((object), (operand), (order)) -#define iree_atomic_fetch_sub_auto(object, operand, order) \ - __c11_atomic_fetch_sub((object), (operand), (order)) -#define iree_atomic_fetch_and_auto(object, operand, order) \ - __c11_atomic_fetch_and((object), (operand), (order)) -#define iree_atomic_fetch_or_auto(object, operand, order) \ - __c11_atomic_fetch_or((object), (operand), (order)) -#define iree_atomic_fetch_xor_auto(object, operand, order) \ - __c11_atomic_fetch_xor((object), (operand), (order)) -#define iree_atomic_exchange_auto(object, operand, order) \ - __c11_atomic_exchange((object), (operand), (order)) -#define iree_atomic_compare_exchange_strong_auto(object, expected, desired, \ - order_succ, order_fail) \ - __c11_atomic_compare_exchange_strong((object), (expected), (desired), \ - (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_auto(object, expected, desired, \ - order_succ, order_fail) \ - __c11_atomic_compare_exchange_weak((object), (expected), (desired), \ - (order_succ), (order_fail)) - -#define iree_atomic_thread_fence(order) __c11_atomic_thread_fence(order) - -//============================================================================== -// Atomics for GCC (compatible with both C and C++) -//============================================================================== #elif defined(IREE_COMPILER_GCC) -typedef enum iree_memory_order_e { - iree_memory_order_relaxed = __ATOMIC_RELAXED, - iree_memory_order_consume = __ATOMIC_CONSUME, - iree_memory_order_acquire = __ATOMIC_ACQUIRE, - iree_memory_order_release = __ATOMIC_RELEASE, - iree_memory_order_acq_rel = __ATOMIC_ACQ_REL, - iree_memory_order_seq_cst = __ATOMIC_SEQ_CST, -} iree_memory_order_t; +// Atomics for GCC (compatible with both C and C++). +#include "iree/base/internal/atomics_gcc.h" -#define IREE_ATOMIC_VAR_INIT(value) (value) - -typedef int32_t iree_atomic_int32_t; -typedef int64_t iree_atomic_int64_t; -// typedef __int128 iree_atomic_int128_t; - -#ifdef __cplusplus -// Equiv to C++ auto keyword in C++ mode. -#define __iree_auto_type auto -#else -// Only defined in C mode. 
-#define __iree_auto_type __auto_type -#endif - -#define iree_atomic_load_auto(object, order) \ - __extension__({ \ - __iree_auto_type __atomic_load_ptr = (object); \ - __typeof__(*__atomic_load_ptr) __atomic_load_tmp; \ - __atomic_load(__atomic_load_ptr, &__atomic_load_tmp, (order)); \ - __atomic_load_tmp; \ - }) -#define iree_atomic_store_auto(object, desired, order) \ - __extension__({ \ - __iree_auto_type __atomic_store_ptr = (object); \ - __typeof__(*__atomic_store_ptr) __atomic_store_tmp = (desired); \ - __atomic_store(__atomic_store_ptr, &__atomic_store_tmp, (order)); \ - }) -#define iree_atomic_fetch_add_auto(object, operand, order) \ - __atomic_fetch_add((object), (operand), (order)) -#define iree_atomic_fetch_sub_auto(object, operand, order) \ - __atomic_fetch_sub((object), (operand), (order)) -#define iree_atomic_fetch_and_auto(object, operand, order) \ - __atomic_fetch_and((object), (operand), (order)) -#define iree_atomic_fetch_or_auto(object, operand, order) \ - __atomic_fetch_or((object), (operand), (order)) -#define iree_atomic_fetch_xor_auto(object, operand, order) \ - __atomic_fetch_xor((object), (operand), (order)) -#define iree_atomic_exchange_auto(object, operand, order) \ - __atomic_exchange_n((object), (operand), (order)) -#define iree_atomic_compare_exchange_strong_auto(object, expected, desired, \ - order_succ, order_fail) \ - __atomic_compare_exchange_n(object, expected, desired, /*weak=*/false, \ - (order_succ), (order_fail)) -#define iree_atomic_compare_exchange_weak_auto(object, expected, desired, \ - order_succ, order_fail) \ - __atomic_compare_exchange_n(object, expected, desired, /*weak=*/true, \ - (order_succ), (order_fail)) - -#define iree_atomic_thread_fence(order) __atomic_thread_fence(order) - -//============================================================================== -// Unsupported architecture -//============================================================================== #else +// Unsupported architecture. 
#error Compiler does not have supported C11-style atomics #endif // IREE_COMPILER_* @@ -313,35 +114,17 @@ typedef int64_t iree_atomic_int64_t; #define iree_atomic_compare_exchange_weak_int64 \ iree_atomic_compare_exchange_weak_auto -#endif // iree_atomic_load_auto - -//============================================================================== -// Pointer-width atomics -//============================================================================== +#define iree_atomic_load_intptr iree_atomic_load_auto +#define iree_atomic_store_intptr iree_atomic_store_auto +#define iree_atomic_fetch_add_intptr iree_atomic_fetch_add_auto +#define iree_atomic_fetch_sub_intptr iree_atomic_fetch_sub_auto +#define iree_atomic_exchange_intptr iree_atomic_exchange_auto +#define iree_atomic_compare_exchange_strong_intptr \ + iree_atomic_compare_exchange_strong_auto +#define iree_atomic_compare_exchange_weak_intptr \ + iree_atomic_compare_exchange_weak_auto -#if defined(IREE_PTR_SIZE_32) -typedef iree_atomic_int32_t iree_atomic_ptr_t; -#define iree_atomic_load_ptr iree_atomic_load_int32 -#define iree_atomic_store_ptr iree_atomic_store_int32 -#define iree_atomic_fetch_add_ptr iree_atomic_fetch_add_int32 -#define iree_atomic_fetch_sub_ptr iree_atomic_fetch_sub_int32 -#define iree_atomic_exchange_ptr iree_atomic_exchange_int32 -#define iree_atomic_compare_exchange_strong_ptr \ - iree_atomic_compare_exchange_strong_int32 -#define iree_atomic_compare_exchange_weak_ptr \ - iree_atomic_compare_exchange_weak_int32 -#else -typedef iree_atomic_int64_t iree_atomic_ptr_t; -#define iree_atomic_load_ptr iree_atomic_load_int64 -#define iree_atomic_store_ptr iree_atomic_store_int64 -#define iree_atomic_fetch_add_ptr iree_atomic_fetch_add_int64 -#define iree_atomic_fetch_sub_ptr iree_atomic_fetch_sub_int64 -#define iree_atomic_exchange_ptr iree_atomic_exchange_int64 -#define iree_atomic_compare_exchange_strong_ptr \ - iree_atomic_compare_exchange_strong_int64 -#define iree_atomic_compare_exchange_weak_ptr \ - iree_atomic_compare_exchange_weak_int64 -#endif // IREE_PTR_SIZE_32 +#endif // iree_atomic_load_auto //============================================================================== // Reference count atomics diff --git a/iree/base/atomics_test.cc b/iree/base/atomics_test.cc new file mode 100644 index 0000000000000..b57d0ddb22602 --- /dev/null +++ b/iree/base/atomics_test.cc @@ -0,0 +1,107 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/base/atomics.h" + +#include "iree/testing/gtest.h" + +namespace { + +// NOTE: these tests are just to ensure we correctly compile the macros across +// our supported toolchains: they don't verify that the memory semantics are +// correct (as that would be difficult and is really the toolchain's job). 
+ +TEST(AtomicPtr, LoadStore) { + intptr_t ptr_0 = 0x0; + intptr_t ptr_1 = 0x1; + iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); + EXPECT_EQ(ptr_0, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + iree_atomic_store_intptr(&value, ptr_1, iree_memory_order_seq_cst); + EXPECT_EQ(ptr_1, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); +} + +TEST(AtomicPtr, AddSub) { + intptr_t ptr_0 = 0x0; + intptr_t ptr_1 = 0x1; + intptr_t ptr_2 = 0x2; + iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); + EXPECT_EQ(ptr_0, iree_atomic_fetch_add_intptr(&value, ptr_1, + iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, iree_atomic_fetch_add_intptr(&value, ptr_1, + iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_2, iree_atomic_fetch_sub_intptr(&value, ptr_1, + iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, iree_atomic_fetch_sub_intptr(&value, ptr_1, + iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); +} + +TEST(AtomicPtr, Exchange) { + intptr_t ptr_0 = 0x0; + intptr_t ptr_1 = 0x1; + intptr_t ptr_2 = 0x2; + iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); + EXPECT_EQ(ptr_0, iree_atomic_exchange_intptr(&value, ptr_1, + iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, iree_atomic_exchange_intptr(&value, ptr_2, + iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_2, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); +} + +TEST(AtomicPtr, CompareExchange) { + intptr_t ptr_0 = 0x0; + intptr_t ptr_1 = 0x1; + intptr_t ptr_2 = 0x2; + iree_atomic_intptr_t value = IREE_ATOMIC_VAR_INIT(ptr_0); + intptr_t ptr_expected = NULL; + + // OK: value == ptr_0, CAS(ptr_0 -> ptr_1) + iree_atomic_store_intptr(&value, ptr_0, iree_memory_order_seq_cst); + ptr_expected = ptr_0; + EXPECT_TRUE(iree_atomic_compare_exchange_strong_intptr( + &value, &ptr_expected, ptr_1, iree_memory_order_seq_cst, + iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, ptr_expected); + EXPECT_EQ(ptr_1, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + + // OK: value == ptr_1, CAS(ptr_1 -> ptr_2) + iree_atomic_store_intptr(&value, ptr_1, iree_memory_order_seq_cst); + ptr_expected = ptr_1; + EXPECT_TRUE(iree_atomic_compare_exchange_strong_intptr( + &value, &ptr_expected, ptr_2, iree_memory_order_seq_cst, + iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_1, ptr_expected); + EXPECT_EQ(ptr_2, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); + + // FAIL: value == ptr_0, CAS(ptr_1 -> ptr_2) + iree_atomic_store_intptr(&value, ptr_0, iree_memory_order_seq_cst); + ptr_expected = ptr_1; + EXPECT_FALSE(iree_atomic_compare_exchange_strong_intptr( + &value, &ptr_expected, ptr_2, iree_memory_order_seq_cst, + iree_memory_order_seq_cst)); + EXPECT_EQ(ptr_0, ptr_expected); + EXPECT_EQ(ptr_0, iree_atomic_load_intptr(&value, iree_memory_order_seq_cst)); +} + +TEST(AtomicRefCount, IncDec) { + iree_atomic_ref_count_t count; + iree_atomic_ref_count_init(&count); + EXPECT_EQ(1, iree_atomic_ref_count_inc(&count)); + EXPECT_EQ(2, iree_atomic_ref_count_inc(&count)); + EXPECT_EQ(3, iree_atomic_ref_count_dec(&count)); + EXPECT_EQ(2, iree_atomic_ref_count_dec(&count)); + EXPECT_EQ(1, iree_atomic_ref_count_dec(&count)); +} + +} // namespace diff --git a/iree/base/bitfield.h b/iree/base/bitfield.h deleted file mode 100644 index bd6eb832637b1..0000000000000 --- a/iree/base/bitfield.h +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance 
with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utility to enable bit operators on enum classes treated as bitfields. -// -// To use define an enum class with valid bitmask values and an underlying type -// then use the macro to enable support: -// enum class MyBitfield : uint32_t { -// kFoo = 1 << 0, -// kBar = 1 << 1, -// }; -// IREE_BITFIELD(MyBitfield); -// MyBitfield value = ~(MyBitfield::kFoo | MyBitfield::kBar); -// -// AnyBitSet is provided as a way to quickly test if any of the given bits are -// set: -// if (AnyBitSet(value)) { /* one or more bits are set */ } -// -// If testing for equality it's recommended that AllBitsSet is used to ensure -// that combined values are handled properly: -// if (AllBitsSet(value, MyBitfield::kSomeSetOfFlags)) { /* all bits set */ } - -#ifndef IREE_BASE_BITFIELD_H_ -#define IREE_BASE_BITFIELD_H_ - -#include -#include -#include -#include - -#include "absl/types/span.h" - -namespace iree { - -#define IREE_BITFIELD(enum_class) \ - inline enum_class operator|(enum_class lhs, enum_class rhs) { \ - using enum_type = typename std::underlying_type::type; \ - return static_cast(static_cast(lhs) | \ - static_cast(rhs)); \ - } \ - inline enum_class& operator|=(enum_class& lhs, enum_class rhs) { \ - using enum_type = typename std::underlying_type::type; \ - lhs = static_cast(static_cast(lhs) | \ - static_cast(rhs)); \ - return lhs; \ - } \ - inline enum_class operator&(enum_class lhs, enum_class rhs) { \ - using enum_type = typename std::underlying_type::type; \ - return static_cast(static_cast(lhs) & \ - static_cast(rhs)); \ - } \ - inline enum_class& operator&=(enum_class& lhs, enum_class rhs) { \ - using enum_type = typename std::underlying_type::type; \ - lhs = static_cast(static_cast(lhs) & \ - static_cast(rhs)); \ - return lhs; \ - } \ - inline enum_class operator^(enum_class lhs, enum_class rhs) { \ - using enum_type = typename std::underlying_type::type; \ - return static_cast(static_cast(lhs) ^ \ - static_cast(rhs)); \ - } \ - inline enum_class& operator^=(enum_class& lhs, enum_class rhs) { \ - using enum_type = typename std::underlying_type::type; \ - lhs = static_cast(static_cast(lhs) ^ \ - static_cast(rhs)); \ - return lhs; \ - } \ - inline enum_class operator~(enum_class lhs) { \ - using enum_type = typename std::underlying_type::type; \ - return static_cast(~static_cast(lhs)); \ - } \ - inline bool AnyBitSet(enum_class lhs) { \ - using enum_type = typename std::underlying_type::type; \ - return static_cast(lhs) != 0; \ - } \ - inline bool AllBitsSet(enum_class lhs, enum_class rhs) { \ - return (lhs & rhs) == rhs; \ - } - -// Appends the formatted contents of the given bitfield value to a stream. -// -// Processes values in the order of the mapping table provided and will only -// use each bit once. Use this to prioritize combined flags over split ones. 
-template -void FormatBitfieldValue( - std::ostringstream* stream, T value, - const absl::Span> mappings) { - T remaining_bits = value; - int i = 0; - for (const auto& mapping : mappings) { - if ((remaining_bits & mapping.first) == mapping.first) { - if (i > 0) { - *stream << "|"; - } - *stream << mapping.second; - remaining_bits &= ~mapping.first; - ++i; - } - } - using enum_type = typename std::underlying_type::type; - if (remaining_bits != static_cast(0)) { - if (i > 0) { - *stream << "|"; - } - *stream << std::hex << static_cast(remaining_bits) << "h"; - } -} - -// Returns a string with the formatted contents of the given bitfield value. -// -// Usage: -// MyValue my_value = MyValue::kA | MyValue::kB; -// std::string string_value = FormatBitfieldValue(my_value, { -// {MyValue::kA, "kA"}, -// {MyValue::kB, "kB"}, -// }); -// // string_value contains 'kA|kB' -template -std::string FormatBitfieldValue( - T value, absl::Span> mappings) { - std::ostringstream stream; - FormatBitfieldValue(&stream, value, mappings); - return stream.str(); -} - -} // namespace iree - -#endif // IREE_BASE_BITFIELD_H_ diff --git a/iree/base/bitfield_test.cc b/iree/base/bitfield_test.cc deleted file mode 100644 index 1da2c49b1a9cd..0000000000000 --- a/iree/base/bitfield_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/bitfield.h" - -#include -#include - -#include "iree/testing/gtest.h" - -namespace iree { - -// NOTE: define here so that we don't get internal linkage warnings. -enum class MyValue : uint32_t { - kNone = 0, - kA = 1 << 0, - kB = 1 << 1, - kAll = kA | kB, -}; -IREE_BITFIELD(MyValue); - -namespace { - -// Tests general usage. -TEST(BitfieldTest, FormatBitfieldValue) { - std::vector> mappings = { - {MyValue::kA, "kA"}, - {MyValue::kB, "kB"}, - }; - EXPECT_EQ("", - FormatBitfieldValue(MyValue::kNone, absl::MakeConstSpan(mappings))); - EXPECT_EQ("kA", - FormatBitfieldValue(MyValue::kA, absl::MakeConstSpan(mappings))); - EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB, - absl::MakeConstSpan(mappings))); -} - -// Tests that empty mapping tables are fine. -TEST(BitfieldTest, FormatBitfieldValueEmpty) { - EXPECT_EQ("", FormatBitfieldValue(MyValue::kNone, {})); -} - -// Tests that values not found in the mappings are still displayed. -TEST(BitfieldTest, FormatBitfieldValueUnhandledValues) { - EXPECT_EQ("kA|2h", FormatBitfieldValue(MyValue::kA | MyValue::kB, - { - {MyValue::kA, "kA"}, - })); -} - -// Tests priority order in the mapping table. -TEST(BitfieldTest, FormatBitfieldValuePriority) { - // No priority, will do separate. - EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB, - { - {MyValue::kA, "kA"}, - {MyValue::kB, "kB"}, - {MyValue::kAll, "kAll"}, - })); - - // Priority on the combined flag, use that instead. 
- EXPECT_EQ("kAll", FormatBitfieldValue(MyValue::kA | MyValue::kB, - { - {MyValue::kAll, "kAll"}, - {MyValue::kA, "kA"}, - {MyValue::kB, "kB"}, - })); -} - -} // namespace -} // namespace iree diff --git a/iree/base/debugging.h b/iree/base/debugging.h index 538ba08804ad6..4f47d134a4083 100644 --- a/iree/base/debugging.h +++ b/iree/base/debugging.h @@ -29,6 +29,13 @@ extern "C" { #define IREE_ATTRIBUTE_ALWAYS_INLINE #endif // IREE_COMPILER_* +//===----------------------------------------------------------------------===// +// Debugger interaction +//===----------------------------------------------------------------------===// +// NOTE: in general it's not a good idea to change program behavior when running +// under a debugger as that then makes it harder to reproduce and successfully +// debug issues that happen without the debugger. + // Forces a break into an attached debugger. // May be ignored if no debugger is attached or raise a signal that gives the // option to attach a debugger. @@ -56,6 +63,14 @@ IREE_ATTRIBUTE_ALWAYS_INLINE static inline void iree_debug_break() { #endif // IREE_PLATFORM_WINDOWS } +//===----------------------------------------------------------------------===// +// IREE_ASSERT macros +//===----------------------------------------------------------------------===// +// These are no-oped in builds with NDEBUG defined (by default anything but +// `-c dbg`/`-DCMAKE_BUILD_TYPE=Debug`). As with normal assert() ensure that +// side-effecting behavior is avoided as the expression will not be evaluated +// when the asserts are removed! + #if !defined(NDEBUG) #define IREE_ASSERT(expr, ...) \ { \ @@ -78,6 +93,54 @@ IREE_ATTRIBUTE_ALWAYS_INLINE static inline void iree_debug_break() { #define IREE_ASSERT_GT(lhs, rhs, ...) IREE_ASSERT_CMP(lhs, >=, rhs, __VA_ARGS__) #define IREE_ASSERT_GE(lhs, rhs, ...) IREE_ASSERT_CMP(lhs, >, rhs, __VA_ARGS__) +//===----------------------------------------------------------------------===// +// Sanitizer interfaces +//===----------------------------------------------------------------------===// +// These provide hints to the various -fsanitize= features that help us indicate +// what our code is doing to prevent false positives and gain additional +// coverage. By default the sanitizers try to hook platform features like +// mutexes and threads and our own implementations of those aren't automatically +// picked up. In addition, specific uses of memory like arenas can thwart tools +// like ASAN that try to detect accesses to freed memory because we are never +// actually malloc()'ing and free()'ing and need to tell ASAN when blocks of +// memory come into/outof the pool. +// +// The documentation on these interfaces is pretty sparse but it's possible to +// find usage examples of the hooks in the compiler-provided hooks themselves. +// +// The headers can be viewed here: +// https://github.com/llvm/llvm-project/tree/main/compiler-rt/include/sanitizer +// And common interceptors here: +// https://github.com/llvm/llvm-project/blob/main/compiler-rt/lib/tsan/rtl/tsan_interceptors_posix.cpp +// +// NOTE: don't assume the presence of a sanitizer implies clang+llvm+x86! GCC +// supports all of the sanitizers and MSVC supports ASAN and almost all of them +// can be used on non-x86 platforms. 
+
+#if defined(IREE_SANITIZER_ADDRESS)
+#include <sanitizer/lsan_interface.h>
+#endif  // IREE_SANITIZER_ADDRESS
+
+#if defined(IREE_SANITIZER_MEMORY)
+// #include <sanitizer/msan_interface.h>
+#endif  // IREE_SANITIZER_MEMORY
+
+#if defined(IREE_SANITIZER_THREAD)
+// #include <sanitizer/tsan_interface.h>
+#endif  // IREE_SANITIZER_THREAD
+
+// Suppresses leak detection false-positives in a region. May be nested.
+// Do not use this for any IREE-owned code: fix your leaks! This is useful when
+// third-party libraries or system calls may create false positives or just be
+// leaky, such as GPU drivers and shader compilers (which are notoriously bad).
+#if defined(IREE_SANITIZER_ADDRESS)
+#define IREE_LEAK_CHECK_DISABLE_PUSH() __lsan_disable()
+#define IREE_LEAK_CHECK_DISABLE_POP() __lsan_enable()
+#else
+#define IREE_LEAK_CHECK_DISABLE_PUSH()
+#define IREE_LEAK_CHECK_DISABLE_POP()
+#endif  // IREE_SANITIZER_ADDRESS
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/iree/base/internal/BUILD b/iree/base/internal/BUILD
index 2fe99ee2fbbd4..d1fdac49462f1 100644
--- a/iree/base/internal/BUILD
+++ b/iree/base/internal/BUILD
@@ -20,6 +20,18 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+cc_library(
+    name = "atomics",
+    hdrs = [
+        "atomics_clang.h",
+        "atomics_gcc.h",
+        "atomics_msvc.h",
+    ],
+    deps = [
+        "//iree/base:target_platform",
+    ],
+)
+
 cc_library(
     name = "file_handle_win32",
     srcs = ["file_handle_win32.cc"],
diff --git a/iree/base/internal/CMakeLists.txt b/iree/base/internal/CMakeLists.txt
index f9975d704d963..e0853bfe190e5 100644
--- a/iree/base/internal/CMakeLists.txt
+++ b/iree/base/internal/CMakeLists.txt
@@ -14,6 +14,18 @@
 
 iree_add_all_subdirs()
 
+iree_cc_library(
+  NAME
+    atomics
+  HDRS
+    "atomics_clang.h"
+    "atomics_gcc.h"
+    "atomics_msvc.h"
+  DEPS
+    iree::base::target_platform
+  PUBLIC
+)
+
 iree_cc_library(
   NAME
     file_handle_win32
diff --git a/iree/base/internal/atomics_clang.h b/iree/base/internal/atomics_clang.h
new file mode 100644
index 0000000000000..e263995eededb
--- /dev/null
+++ b/iree/base/internal/atomics_clang.h
@@ -0,0 +1,81 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
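A minimal usage sketch for the leak-check macros defined in debugging.h above; `third_party_driver_init` is a placeholder for any external call known to leak (a GPU driver's lazy global init, say):

    // Scope the suppression as tightly as possible around the leaky call.
    IREE_LEAK_CHECK_DISABLE_PUSH();
    third_party_driver_init();  // hypothetical leaky third-party call
    IREE_LEAK_CHECK_DISABLE_POP();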
+ +#ifndef IREE_BASE_INTERNAL_ATOMICS_CLANG_H_ +#define IREE_BASE_INTERNAL_ATOMICS_CLANG_H_ + +#include +#include +#include +#include + +#include "iree/base/target_platform.h" + +#if defined(IREE_COMPILER_CLANG) + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum iree_memory_order_e { + iree_memory_order_relaxed = __ATOMIC_RELAXED, + iree_memory_order_consume = __ATOMIC_CONSUME, + iree_memory_order_acquire = __ATOMIC_ACQUIRE, + iree_memory_order_release = __ATOMIC_RELEASE, + iree_memory_order_acq_rel = __ATOMIC_ACQ_REL, + iree_memory_order_seq_cst = __ATOMIC_SEQ_CST, +} iree_memory_order_t; + +#define IREE_ATOMIC_VAR_INIT(value) (value) + +typedef _Atomic int32_t iree_atomic_int32_t; +typedef _Atomic int64_t iree_atomic_int64_t; +// TODO(#3453): check for __int128 support before using +// typedef _Atomic __int128 iree_atomic_int128_t; +typedef _Atomic intptr_t iree_atomic_intptr_t; + +#define iree_atomic_load_auto(object, order) \ + __c11_atomic_load((object), (order)) +#define iree_atomic_store_auto(object, desired, order) \ + __c11_atomic_store((object), (desired), (order)) +#define iree_atomic_fetch_add_auto(object, operand, order) \ + __c11_atomic_fetch_add((object), (operand), (order)) +#define iree_atomic_fetch_sub_auto(object, operand, order) \ + __c11_atomic_fetch_sub((object), (operand), (order)) +#define iree_atomic_fetch_and_auto(object, operand, order) \ + __c11_atomic_fetch_and((object), (operand), (order)) +#define iree_atomic_fetch_or_auto(object, operand, order) \ + __c11_atomic_fetch_or((object), (operand), (order)) +#define iree_atomic_fetch_xor_auto(object, operand, order) \ + __c11_atomic_fetch_xor((object), (operand), (order)) +#define iree_atomic_exchange_auto(object, operand, order) \ + __c11_atomic_exchange((object), (operand), (order)) +#define iree_atomic_compare_exchange_strong_auto(object, expected, desired, \ + order_succ, order_fail) \ + __c11_atomic_compare_exchange_strong((object), (expected), (desired), \ + (order_succ), (order_fail)) +#define iree_atomic_compare_exchange_weak_auto(object, expected, desired, \ + order_succ, order_fail) \ + __c11_atomic_compare_exchange_weak((object), (expected), (desired), \ + (order_succ), (order_fail)) + +#define iree_atomic_thread_fence(order) __c11_atomic_thread_fence(order) + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // IREE_COMPILER_CLANG + +#endif // IREE_BASE_INTERNAL_ATOMICS_CLANG_H_ diff --git a/iree/base/internal/atomics_gcc.h b/iree/base/internal/atomics_gcc.h new file mode 100644 index 0000000000000..d60047fc660af --- /dev/null +++ b/iree/base/internal/atomics_gcc.h @@ -0,0 +1,97 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
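To make the memory orders in the clang header above concrete, here is a minimal release/acquire handoff sketch built only from the macros it defines; the globals are illustrative:

    static iree_atomic_int32_t g_ready = IREE_ATOMIC_VAR_INIT(0);
    static int32_t g_payload = 0;

    void producer(void) {
      g_payload = 42;  // plain store, made visible by the release below
      iree_atomic_store_auto(&g_ready, 1, iree_memory_order_release);
    }

    int32_t consumer(void) {
      // The acquire load pairs with the release store above.
      while (!iree_atomic_load_auto(&g_ready, iree_memory_order_acquire)) {
      }
      return g_payload;  // guaranteed to observe 42
    }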
+ +#ifndef IREE_BASE_INTERNAL_ATOMICS_GCC_H_ +#define IREE_BASE_INTERNAL_ATOMICS_GCC_H_ + +#include +#include +#include +#include + +#include "iree/base/target_platform.h" + +#if defined(IREE_COMPILER_GCC) + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum iree_memory_order_e { + iree_memory_order_relaxed = __ATOMIC_RELAXED, + iree_memory_order_consume = __ATOMIC_CONSUME, + iree_memory_order_acquire = __ATOMIC_ACQUIRE, + iree_memory_order_release = __ATOMIC_RELEASE, + iree_memory_order_acq_rel = __ATOMIC_ACQ_REL, + iree_memory_order_seq_cst = __ATOMIC_SEQ_CST, +} iree_memory_order_t; + +#define IREE_ATOMIC_VAR_INIT(value) (value) + +typedef int32_t iree_atomic_int32_t; +typedef int64_t iree_atomic_int64_t; +// typedef __int128 iree_atomic_int128_t; +typedef intptr_t iree_atomic_intptr_t; + +#ifdef __cplusplus +// Equiv to C++ auto keyword in C++ mode. +#define __iree_auto_type auto +#else +// Only defined in C mode. +#define __iree_auto_type __auto_type +#endif + +#define iree_atomic_load_auto(object, order) \ + __extension__({ \ + __iree_auto_type __atomic_load_ptr = (object); \ + __typeof__(*__atomic_load_ptr) __atomic_load_tmp; \ + __atomic_load(__atomic_load_ptr, &__atomic_load_tmp, (order)); \ + __atomic_load_tmp; \ + }) +#define iree_atomic_store_auto(object, desired, order) \ + __extension__({ \ + __iree_auto_type __atomic_store_ptr = (object); \ + __typeof__(*__atomic_store_ptr) __atomic_store_tmp = (desired); \ + __atomic_store(__atomic_store_ptr, &__atomic_store_tmp, (order)); \ + }) +#define iree_atomic_fetch_add_auto(object, operand, order) \ + __atomic_fetch_add((object), (operand), (order)) +#define iree_atomic_fetch_sub_auto(object, operand, order) \ + __atomic_fetch_sub((object), (operand), (order)) +#define iree_atomic_fetch_and_auto(object, operand, order) \ + __atomic_fetch_and((object), (operand), (order)) +#define iree_atomic_fetch_or_auto(object, operand, order) \ + __atomic_fetch_or((object), (operand), (order)) +#define iree_atomic_fetch_xor_auto(object, operand, order) \ + __atomic_fetch_xor((object), (operand), (order)) +#define iree_atomic_exchange_auto(object, operand, order) \ + __atomic_exchange_n((object), (operand), (order)) +#define iree_atomic_compare_exchange_strong_auto(object, expected, desired, \ + order_succ, order_fail) \ + __atomic_compare_exchange_n(object, expected, desired, /*weak=*/false, \ + (order_succ), (order_fail)) +#define iree_atomic_compare_exchange_weak_auto(object, expected, desired, \ + order_succ, order_fail) \ + __atomic_compare_exchange_n(object, expected, desired, /*weak=*/true, \ + (order_succ), (order_fail)) + +#define iree_atomic_thread_fence(order) __atomic_thread_fence(order) + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // IREE_COMPILER_GCC + +#endif // IREE_BASE_INTERNAL_ATOMICS_GCC_H_ diff --git a/iree/base/internal/atomics_msvc.h b/iree/base/internal/atomics_msvc.h new file mode 100644 index 0000000000000..865690f816361 --- /dev/null +++ b/iree/base/internal/atomics_msvc.h @@ -0,0 +1,190 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
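Note the difference from the clang path: under GCC the atomic typedefs above are plain integers and the `__atomic_*` builtins take ordinary pointers, while the statement expression in `iree_atomic_load_auto` exists so the macro yields a value while evaluating its `object` argument exactly once. A tiny illustrative use:

    iree_atomic_int64_t counter = IREE_ATOMIC_VAR_INIT(0);
    // As with the C11 equivalents, fetch_add returns the previous value.
    int64_t previous =
        iree_atomic_fetch_add_auto(&counter, 1, iree_memory_order_relaxed);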
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_ATOMICS_MSVC_H_ +#define IREE_BASE_INTERNAL_ATOMICS_MSVC_H_ + +#include +#include +#include +#include + +#include "iree/base/target_platform.h" + +#if defined(IREE_COMPILER_MSVC) + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum iree_memory_order_e { + iree_memory_order_relaxed, + iree_memory_order_consume, + iree_memory_order_acquire, + iree_memory_order_release, + iree_memory_order_acq_rel, + iree_memory_order_seq_cst, +} iree_memory_order_t; + +#define IREE_ATOMIC_VAR_INIT(value) \ + { (value) } + +typedef struct { + int32_t __val; +} iree_atomic_int32_t; +typedef struct { + int64_t __val; +} iree_atomic_int64_t; +// typedef __declspec(align(16)) struct { +// uint64_t __val[2]; +// } iree_atomic_int128_t; +typedef struct { + intptr_t __val; +} iree_atomic_intptr_t; + +#define iree_atomic_load_int32(object, order) \ + InterlockedExchangeAdd((volatile LONG*)object, 0) +#define iree_atomic_store_int32(object, desired, order) \ + InterlockedExchange((volatile LONG*)object, desired) +#define iree_atomic_fetch_add_int32(object, operand, order) \ + InterlockedExchangeAdd((volatile LONG*)object, operand) +#define iree_atomic_fetch_sub_int32(object, operand, order) \ + InterlockedExchangeAdd((volatile LONG*)object, -((int32_t)(operand))) +#define iree_atomic_fetch_and_int32(object, operand, order) \ + InterlockedAnd((volatile LONG*)object, operand) +#define iree_atomic_fetch_or_int32(object, operand, order) \ + InterlockedOr((volatile LONG*)object, operand) +#define iree_atomic_fetch_xor_int32(object, operand, order) \ + InterlockedXor((volatile LONG*)object, operand) +#define iree_atomic_exchange_int32(object, desired, order) \ + InterlockedExchange((volatile LONG*)object, desired) +#define iree_atomic_compare_exchange_strong_int32(object, expected, desired, \ + order_succ, order_fail) \ + iree_atomic_compare_exchange_strong_int32_impl( \ + (volatile iree_atomic_int32_t*)(object), (int32_t*)(expected), \ + (int32_t)(desired), (order_succ), (order_fail)) +#define iree_atomic_compare_exchange_weak_int32 \ + iree_atomic_compare_exchange_strong_int32 + +#define iree_atomic_load_int64(object, order) \ + InterlockedExchangeAdd64((volatile LONG64*)object, 0) +#define iree_atomic_store_int64(object, desired, order) \ + InterlockedExchange64((volatile LONG64*)object, (LONG64)desired) +#define iree_atomic_fetch_add_int64(object, operand, order) \ + InterlockedExchangeAdd64((volatile LONG64*)object, (LONG64)operand) +#define iree_atomic_fetch_sub_int64(object, operand, order) \ + InterlockedExchangeAdd64((volatile LONG64*)object, -(operand)) +#define iree_atomic_fetch_and_int64(object, operand, order) \ + InterlockedAnd64((volatile LONG64*)object, operand) +#define iree_atomic_fetch_or_int64(object, operand, order) \ + InterlockedOr64((volatile LONG64*)object, operand) +#define iree_atomic_fetch_xor_int64(object, operand, order) \ + InterlockedXor64((volatile LONG64*)object, operand) +#define iree_atomic_exchange_int64(object, desired, order) \ + InterlockedExchange64((volatile LONG64*)object, desired) +#define iree_atomic_compare_exchange_strong_int64(object, expected, desired, \ + order_succ, order_fail) \ + iree_atomic_compare_exchange_strong_int64_impl( \ + (volatile iree_atomic_int64_t*)(object), (int64_t*)(expected), \ + (int64_t)(desired), (order_succ), (order_fail)) +#define iree_atomic_compare_exchange_weak_int64 \ + 
iree_atomic_compare_exchange_strong_int64
+
+#define iree_atomic_thread_fence(order) MemoryBarrier()
+
+static inline bool iree_atomic_compare_exchange_strong_int32_impl(
+    volatile iree_atomic_int32_t* object, int32_t* expected, int32_t desired,
+    iree_memory_order_t order_succ, iree_memory_order_t order_fail) {
+  int32_t expected_value = *expected;
+  int32_t old_value = InterlockedCompareExchange((volatile LONG*)object,
+                                                 desired, expected_value);
+  if (old_value == expected_value) {
+    return true;
+  } else {
+    *expected = old_value;
+    return false;
+  }
+}
+
+static inline bool iree_atomic_compare_exchange_strong_int64_impl(
+    volatile iree_atomic_int64_t* object, int64_t* expected, int64_t desired,
+    iree_memory_order_t order_succ, iree_memory_order_t order_fail) {
+  int64_t expected_value = *expected;
+  int64_t old_value = InterlockedCompareExchange64((volatile LONG64*)object,
+                                                   desired, expected_value);
+  if (old_value == expected_value) {
+    return true;
+  } else {
+    *expected = old_value;
+    return false;
+  }
+}
+
+// There are no pointer-width atomic ops in MSVC so we need to specialize based
+// on the pointer size.
+#if defined(IREE_PTR_SIZE_32)
+#define iree_atomic_load_intptr(object, order) \
+  (intptr_t) iree_atomic_load_int32((iree_atomic_int32_t*)(object), (order))
+#define iree_atomic_store_intptr(object, desired, order)           \
+  (intptr_t) iree_atomic_store_int32((iree_atomic_int32_t*)(object), \
+                                     (int32_t)(desired), (order))
+#define iree_atomic_fetch_add_intptr(object, operand, order)           \
+  (intptr_t) iree_atomic_fetch_add_int32((iree_atomic_int32_t*)(object), \
+                                         (int32_t)(operand), (order))
+#define iree_atomic_fetch_sub_intptr(object, operand, order)           \
+  (intptr_t) iree_atomic_fetch_sub_int32((iree_atomic_int32_t*)(object), \
+                                         (int32_t)(operand), (order))
+#define iree_atomic_exchange_intptr(object, desired, order)           \
+  (intptr_t) iree_atomic_exchange_int32((iree_atomic_int32_t*)(object), \
+                                        (int32_t)(desired), (order))
+#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \
+                                                   order_succ, order_fail)    \
+  iree_atomic_compare_exchange_strong_int32(                                  \
+      (iree_atomic_int32_t*)(object), (int32_t*)(expected),                   \
+      (int32_t)(desired), (order_succ), (order_fail))
+#define iree_atomic_compare_exchange_weak_intptr \
+  iree_atomic_compare_exchange_strong_intptr
+#else
+#define iree_atomic_load_intptr(object, order) \
+  (intptr_t) iree_atomic_load_int64((iree_atomic_int64_t*)(object), (order))
+#define iree_atomic_store_intptr(object, desired, order)           \
+  (intptr_t) iree_atomic_store_int64((iree_atomic_int64_t*)(object), \
+                                     (int64_t)(desired), (order))
+#define iree_atomic_fetch_add_intptr(object, operand, order)           \
+  (intptr_t) iree_atomic_fetch_add_int64((iree_atomic_int64_t*)(object), \
+                                         (int64_t)(operand), (order))
+#define iree_atomic_fetch_sub_intptr(object, operand, order)           \
+  (intptr_t) iree_atomic_fetch_sub_int64((iree_atomic_int64_t*)(object), \
+                                         (int64_t)(operand), (order))
+#define iree_atomic_exchange_intptr(object, desired, order)           \
+  (intptr_t) iree_atomic_exchange_int64((iree_atomic_int64_t*)(object), \
+                                        (int64_t)(desired), (order))
+#define iree_atomic_compare_exchange_strong_intptr(object, expected, desired, \
+                                                   order_succ, order_fail)    \
+  iree_atomic_compare_exchange_strong_int64(                                  \
+      (iree_atomic_int64_t*)(object), (int64_t*)(expected),                   \
+      (int64_t)(desired), (order_succ), (order_fail))
+#define iree_atomic_compare_exchange_weak_intptr \
+  iree_atomic_compare_exchange_strong_intptr
+#endif  // IREE_PTR_SIZE_32
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif  // IREE_COMPILER_MSVC
+
+#endif  // IREE_BASE_INTERNAL_ATOMICS_MSVC_H_
diff --git a/iree/base/internal/file_io_win32.cc b/iree/base/internal/file_io_win32.cc
index 3bf4b0e17ee2a..03183ee42eddf 100644
--- a/iree/base/internal/file_io_win32.cc
+++ b/iree/base/internal/file_io_win32.cc
@@ -18,6 +18,8 @@
 
 #include
 
+#include <atomic>
+
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "iree/base/file_io.h"
@@ -120,6 +122,8 @@ StatusOr GetTempFile(absl::string_view base_name) {
       file_path::JoinPaths(temp_path, base_name) + "XXXXXX";
   if (::_mktemp(&template_path[0]) != nullptr) {
+    static std::atomic<int> next_id{0};
+    template_path += std::to_string(next_id++);
     return template_path;  // Should have been modified by _mktemp.
   } else {
     return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC)
diff --git a/iree/base/math.h b/iree/base/math.h
index c268f6ed002ba..a1d9aceb62164 100644
--- a/iree/base/math.h
+++ b/iree/base/math.h
@@ -408,7 +408,7 @@ static inline uint64_t iree_prng_xoroshiro128starstar_next_uint64(
 // capitalize on it!
 typedef iree_alignas(iree_max_align_t) struct {
   uint8_t value[16];  // first to ensure alignment
-  uint8_t remaining;  // number of remaining valid values in the state
+  int8_t remaining;   // number of remaining valid values in the state
 } iree_prng_minilcg128_state_t;
 
 #define IREE_PRNG_MINILCG_INIT_MUL_CONSTANT 13
@@ -431,7 +431,7 @@ static inline void iree_prng_minilcg128_initialize(
 
 static inline uint8_t iree_prng_minilcg128_next_uint8(
     iree_prng_minilcg128_state_t* state) {
-  if (IREE_UNLIKELY(--state->remaining == 0)) {
+  if (IREE_UNLIKELY(--state->remaining < 0)) {
 #if defined(IREE_ARCH_ARM_64)
     uint8x16_t kmul = vdupq_n_u8(IREE_PRNG_MINILCG_NEXT_MUL_CONSTANT);
     uint8x16_t kadd = vdupq_n_u8(IREE_PRNG_MINILCG_NEXT_ADD_CONSTANT);
@@ -442,9 +442,9 @@ static inline uint8_t iree_prng_minilcg128_next_uint8(
           IREE_PRNG_MINILCG_NEXT_ADD_CONSTANT;
     }
 #endif  // IREE_ARCH_ARM_64
-    state->remaining = 16;
+    state->remaining = 15;
   }
-  return state->value[16 - state->remaining + 1];
+  return state->value[16 - state->remaining - 1];
 }
 
 #endif  // IREE_BASE_MATH_H_
diff --git a/iree/base/math_test.cc b/iree/base/math_test.cc
index 3ed3b50ad7605..ba6a95fba8b5f 100644
--- a/iree/base/math_test.cc
+++ b/iree/base/math_test.cc
@@ -227,25 +227,25 @@ TEST(PRNG, MiniLcg128) {
   iree_prng_minilcg128_state_t state;
 
   iree_prng_minilcg128_initialize(/*seed=*/0ull, &state);
-  EXPECT_EQ(111u, iree_prng_minilcg128_next_uint8(&state));
+  EXPECT_EQ(21u, iree_prng_minilcg128_next_uint8(&state));
   for (int i = 0; i < 100; ++i) {
     iree_prng_minilcg128_next_uint8(&state);
   }
-  EXPECT_EQ(212u, iree_prng_minilcg128_next_uint8(&state));
+  EXPECT_EQ(18u, iree_prng_minilcg128_next_uint8(&state));
 
   iree_prng_minilcg128_initialize(/*seed=*/1ull, &state);
-  EXPECT_EQ(198u, iree_prng_minilcg128_next_uint8(&state));
+  EXPECT_EQ(20u, iree_prng_minilcg128_next_uint8(&state));
   for (int i = 0; i < 100; ++i) {
     iree_prng_minilcg128_next_uint8(&state);
   }
-  EXPECT_EQ(135u, iree_prng_minilcg128_next_uint8(&state));
+  EXPECT_EQ(13u, iree_prng_minilcg128_next_uint8(&state));
 
   iree_prng_minilcg128_initialize(/*seed=*/UINT64_MAX, &state);
-  EXPECT_EQ(12u, iree_prng_minilcg128_next_uint8(&state));
+  EXPECT_EQ(234u, iree_prng_minilcg128_next_uint8(&state));
   for (int i = 0; i < 100; ++i) {
     iree_prng_minilcg128_next_uint8(&state);
   }
-  EXPECT_EQ(229u, iree_prng_minilcg128_next_uint8(&state));
+  EXPECT_EQ(59u,
iree_prng_minilcg128_next_uint8(&state)); } } // namespace diff --git a/iree/base/memory.h b/iree/base/memory.h index dcd4757e6ed9e..1cfa389612513 100644 --- a/iree/base/memory.h +++ b/iree/base/memory.h @@ -113,21 +113,4 @@ ABSL_MUST_USE_RESULT Cleanup MakeCleanup(F&& f) { } // namespace iree -#if defined(__has_feature) -#if __has_feature(address_sanitizer) -#define IREE_CONFIG_ASAN 1 -#endif // __has_feature(address_sanitizer) -#endif // __has_feature - -// If you see these macros being used it means that the code between is not -// really under our control and not a leak we would be able to prevent. -#if defined(IREE_CONFIG_ASAN) -#include -#define IREE_DISABLE_LEAK_CHECKS() __lsan_disable() -#define IREE_ENABLE_LEAK_CHECKS() __lsan_enable() -#else -#define IREE_DISABLE_LEAK_CHECKS() -#define IREE_ENABLE_LEAK_CHECKS() -#endif // IREE_CONFIG_ASAN - #endif // IREE_BASE_MEMORY_H_ diff --git a/iree/base/synchronization.h b/iree/base/synchronization.h index 304a634f5fa02..d02d9b052a3e0 100644 --- a/iree/base/synchronization.h +++ b/iree/base/synchronization.h @@ -55,9 +55,14 @@ #define IREE_PTR_GUARDED_BY(x) #endif // __cplusplus +// NOTE: we only support futex when not using tsan as we need to add annotations +// for tsan to understand what we are doing. +// https://github.com/llvm-mirror/compiler-rt/blob/master/include/sanitizer/tsan_interface.h #if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_EMSCRIPTEN) || \ defined(IREE_PLATFORM_LINUX) || defined(IREE_PLATFORM_WINDOWS) +#if !defined(IREE_SANITIZER_THREAD) #define IREE_PLATFORM_HAS_FUTEX 1 +#endif // !IREE_SANITIZER_THREAD #endif // IREE_PLATFORM_* #if defined(IREE_PLATFORM_APPLE) diff --git a/iree/base/target_platform.h b/iree/base/target_platform.h index 7700ee1d43041..666f3a53842f6 100644 --- a/iree/base/target_platform.h +++ b/iree/base/target_platform.h @@ -41,6 +41,10 @@ // IREE_COMPILER_GCC_COMPAT // IREE_COMPILER_MSVC // +// IREE_SANITIZER_ADDRESS +// IREE_SANITIZER_MEMORY +// IREE_SANITIZER_THREAD +// // IREE_PLATFORM_ANDROID // IREE_PLATFORM_ANDROID_EMULATOR // IREE_PLATFORM_APPLE (IOS | MACOS) @@ -140,6 +144,18 @@ #error Unrecognized compiler. 
#endif // compiler versions +#if defined(__has_feature) +#if __has_feature(address_sanitizer) +#define IREE_SANITIZER_ADDRESS 1 +#endif // __has_feature(address_sanitizer) +#if __has_feature(memory_sanitizer) +#define IREE_SANITIZER_MEMORY 1 +#endif // __has_feature(memory_sanitizer) +#if __has_feature(thread_sanitizer) +#define IREE_SANITIZER_THREAD 1 +#endif // __has_feature(thread_sanitizer) +#endif // defined(__has_feature) + //============================================================================== // IREE_PLATFORM_ANDROID //============================================================================== diff --git a/iree/base/threading_darwin.c b/iree/base/threading_darwin.c index a9c0d972c7bd6..902e743172d83 100644 --- a/iree/base/threading_darwin.c +++ b/iree/base/threading_darwin.c @@ -107,7 +107,7 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, pthread_attr_t thread_attr; pthread_attr_init(&thread_attr); - pthread_attr_setdetachstate(&thread_attr, PTHREAD_CREATE_DETACHED); + pthread_attr_setdetachstate(&thread_attr, PTHREAD_CREATE_JOINABLE); if (params.stack_size) { pthread_attr_setstacksize(&thread_attr, params.stack_size); } @@ -117,6 +117,11 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, iree_thread_qos_class_for_priority_class(params.priority_class); pthread_attr_set_qos_class_np(&thread_attr, qos_class, 0); + // Retain the thread for the thread itself; this way if the caller immediately + // releases the iree_thread_t handle the thread won't explode. + iree_thread_retain(thread); + *out_thread = thread; + // Create the thread either suspended or running as the user requested. int rc; if (params.create_suspended) { @@ -132,7 +137,9 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, } pthread_attr_destroy(&thread_attr); if (rc != 0) { - iree_allocator_free(allocator, thread); + iree_thread_release(thread); // for self + iree_thread_release(thread); // for caller + *out_thread = NULL; IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_INTERNAL, "thread creation failed with %d", rc); @@ -143,15 +150,21 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, iree_thread_request_affinity(thread, params.initial_affinity); } - // Retain the thread for the thread itself; this way if the caller immediately - // releases the iree_thread_t handle the thread won't explode. 
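The retain-for-self pattern being moved earlier in the darwin implementation above means a caller may drop its handle as soon as creation succeeds; a rough usage sketch (parameters abbreviated; `my_entry`, `my_arg`, and `params` are hypothetical):

    iree_thread_t* thread = NULL;
    IREE_RETURN_IF_ERROR(iree_thread_create(my_entry, my_arg, params,
                                            iree_allocator_system(), &thread));
    // Safe even while the thread is still starting or running: the thread owns
    // a reference to itself, so the underlying object outlives our handle.
    iree_thread_release(thread);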
- iree_thread_retain(thread); - IREE_TRACE_ZONE_END(z0); - *out_thread = thread; return iree_ok_status(); } +static void iree_thread_delete(iree_thread_t* thread) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_thread_resume(thread); + pthread_join(thread->handle, NULL); + + iree_allocator_free(thread->allocator, thread); + + IREE_TRACE_ZONE_END(z0); +} + void iree_thread_retain(iree_thread_t* thread) { if (thread) { iree_atomic_ref_count_inc(&thread->ref_count); @@ -160,7 +173,7 @@ void iree_thread_retain(iree_thread_t* thread) { void iree_thread_release(iree_thread_t* thread) { if (thread && iree_atomic_ref_count_dec(&thread->ref_count) == 1) { - iree_allocator_free(thread->allocator, thread); + iree_thread_delete(thread); } } diff --git a/iree/base/threading_pthreads.c b/iree/base/threading_pthreads.c index 9527d41c4ce24..cea28d692d190 100644 --- a/iree/base/threading_pthreads.c +++ b/iree/base/threading_pthreads.c @@ -67,16 +67,19 @@ static bool iree_thread_resumed_predicate(void* arg) { typedef int (*pthread_setname_np_fn_t)(pthread_t thread, const char* name); +static pthread_setname_np_fn_t iree_pthread_setname_np_fn = NULL; +static void iree_thread_try_query_setname_fn(void) { + iree_pthread_setname_np_fn = + (pthread_setname_np_fn_t)dlsym(RTLD_DEFAULT, "pthread_setname_np"); +} + static int iree_thread_set_name(pthread_t handle, const char* name) { IREE_TRACE_ZONE_BEGIN(z0); - static pthread_setname_np_fn_t pthread_setname_np_fn = NULL; - if (!pthread_setname_np_fn) { - pthread_setname_np_fn = - (pthread_setname_np_fn_t)dlsym(RTLD_DEFAULT, "pthread_setname_np"); - } + static iree_once_flag fn_query_flag = IREE_ONCE_FLAG_INIT; + iree_call_once(&fn_query_flag, iree_thread_try_query_setname_fn); int rc; - if (pthread_setname_np_fn) { - rc = pthread_setname_np_fn(handle, name); + if (iree_pthread_setname_np_fn) { + rc = iree_pthread_setname_np_fn(handle, name); } else { rc = EINVAL; } @@ -147,7 +150,7 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, pthread_attr_t thread_attr; pthread_attr_init(&thread_attr); - pthread_attr_setdetachstate(&thread_attr, PTHREAD_CREATE_DETACHED); + pthread_attr_setdetachstate(&thread_attr, PTHREAD_CREATE_JOINABLE); if (params.stack_size) { pthread_attr_setstacksize(&thread_attr, params.stack_size); } @@ -155,6 +158,7 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, // Retain the thread for the thread itself; this way if the caller immediately // releases the iree_thread_t handle the thread won't explode. iree_thread_retain(thread); + *out_thread = thread; // Unfortunately we can't create the thread suspended (no API). 
This means // that we are likely to incur some thrashing here as the thread gets spun up @@ -170,7 +174,9 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, } pthread_attr_destroy(&thread_attr); if (rc != 0) { - iree_allocator_free(allocator, thread); + iree_thread_release(thread); // for self + iree_thread_release(thread); // for caller + *out_thread = NULL; IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_INTERNAL, "thread creation failed with %d", rc); @@ -184,7 +190,6 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, } IREE_TRACE_ZONE_END(z0); - *out_thread = thread; return iree_ok_status(); } @@ -192,6 +197,7 @@ static void iree_thread_delete(iree_thread_t* thread) { IREE_TRACE_ZONE_BEGIN(z0); iree_thread_resume(thread); + pthread_join(thread->handle, NULL); iree_notification_deinitialize(&thread->suspend_barrier); iree_thread_override_list_deinitialize(&thread->qos_override_list); diff --git a/iree/base/threading_win32.c b/iree/base/threading_win32.c index dbd9cd310019e..835e6f915ca5d 100644 --- a/iree/base/threading_win32.c +++ b/iree/base/threading_win32.c @@ -160,6 +160,11 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, params.priority_class, thread->allocator, &thread->qos_override_list); + // Retain the thread for the thread itself; this way if the caller immediately + // releases the iree_thread_t handle the thread won't explode. + iree_thread_retain(thread); + *out_thread = thread; + // Create the thread either suspended or running as the user requested. { IREE_TRACE_ZONE_BEGIN_NAMED(z1, "CreateThread"); @@ -169,7 +174,9 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, IREE_TRACE_ZONE_END(z1); } if (thread->handle == INVALID_HANDLE_VALUE) { - iree_allocator_free(allocator, thread); + iree_thread_release(thread); // for self + iree_thread_release(thread); // for caller + *out_thread = NULL; IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_INTERNAL, "thread creation failed with %lu", GetLastError()); @@ -187,12 +194,7 @@ iree_status_t iree_thread_create(iree_thread_entry_t entry, void* entry_arg, iree_thread_request_affinity(thread, params.initial_affinity); } - // Retain the thread for the thread itself; this way if the caller immediately - // releases the iree_thread_t handle the thread won't explode. - iree_thread_retain(thread); - IREE_TRACE_ZONE_END(z0); - *out_thread = thread; return iree_ok_status(); } @@ -201,6 +203,7 @@ static void iree_thread_delete(iree_thread_t* thread) { iree_thread_resume(thread); + WaitForSingleObject(thread->handle, INFINITE); CloseHandle(thread->handle); iree_thread_override_list_deinitialize(&thread->qos_override_list); iree_allocator_free(thread->allocator, thread); diff --git a/iree/base/time.h b/iree/base/time.h deleted file mode 100644 index e51a5b12e8525..0000000000000 --- a/iree/base/time.h +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_TIME_H_ -#define IREE_BASE_TIME_H_ - -#include -#include - -#include "iree/base/api.h" - -namespace iree { -namespace impl { -template -class ChronoType { - public: - ChronoType() : value_() {} - explicit ChronoType(const T& value) : value_(value) {} - explicit ChronoType(T&& value) noexcept( - std::is_nothrow_move_constructible::value) - : value_(std::move(value)) {} - - explicit operator T&() noexcept { return value_; } - explicit operator const T&() const noexcept { return value_; } - - friend void swap(ChronoType& a, ChronoType& b) noexcept { - using std::swap; - swap(static_cast(a), static_cast(b)); - } - - friend inline bool operator==(const ChronoType& lhs, const ChronoType& rhs) { - return lhs.value_ == rhs.value_; - } - friend inline bool operator!=(const ChronoType& lhs, const ChronoType& rhs) { - return !(lhs == rhs); - } - friend inline bool operator<(const ChronoType& lhs, const ChronoType& rhs) { - return lhs.value_ < rhs.value_; - } - friend inline bool operator>(const ChronoType& lhs, const ChronoType& rhs) { - return rhs < lhs; - } - friend inline bool operator<=(const ChronoType& lhs, const ChronoType& rhs) { - return !(lhs > rhs); - } - friend inline bool operator>=(const ChronoType& lhs, const ChronoType& rhs) { - return !(lhs < rhs); - } - - friend ChronoType& operator+=(ChronoType& lhs, const ChronoType& rhs) { - static_cast(lhs) += static_cast(rhs); - return lhs; - } - friend ChronoType operator+(const ChronoType& lhs, const ChronoType& rhs) { - return ChronoType(static_cast(lhs) + static_cast(rhs)); - } - - friend ChronoType& operator-=(ChronoType& lhs, const ChronoType& rhs) { - static_cast(lhs) -= static_cast(rhs); - return lhs; - } - friend ChronoType operator-(const ChronoType& lhs, const ChronoType& rhs) { - return ChronoType(static_cast(lhs) - static_cast(rhs)); - } - - private: - T value_; -}; -} // namespace impl - -struct Duration : public impl::ChronoType { - using ChronoType::ChronoType; - explicit operator uint64_t() const noexcept { - if (static_cast(*this) == IREE_DURATION_INFINITE) { - return UINT64_MAX; - } - int64_t relative_ns = static_cast(*this); - return relative_ns <= 0 ? 
0 : static_cast(relative_ns); - } -}; - -static inline Duration InfiniteDuration() { - return Duration(IREE_DURATION_INFINITE); -} -static inline Duration ZeroDuration() { return Duration(IREE_DURATION_ZERO); } - -struct Time : public impl::ChronoType { - using ChronoType::ChronoType; - friend Duration operator+(const Time& lhs, const Time& rhs) { - if (static_cast(lhs) == IREE_TIME_INFINITE_FUTURE || - static_cast(rhs) == IREE_TIME_INFINITE_FUTURE) { - return InfiniteDuration(); - } else if (static_cast(lhs) == IREE_TIME_INFINITE_PAST || - static_cast(rhs) == IREE_TIME_INFINITE_PAST) { - return ZeroDuration(); - } - return Duration(static_cast(lhs) + - static_cast(rhs)); - } - friend Duration operator-(const Time& lhs, const Time& rhs) { - if (static_cast(lhs) == IREE_TIME_INFINITE_FUTURE || - static_cast(rhs) == IREE_TIME_INFINITE_FUTURE) { - return InfiniteDuration(); - } else if (static_cast(lhs) == IREE_TIME_INFINITE_PAST || - static_cast(rhs) == IREE_TIME_INFINITE_PAST) { - return ZeroDuration(); - } - return Duration(static_cast(lhs) - - static_cast(rhs)); - } -}; - -static inline Time InfinitePast() { return Time(IREE_TIME_INFINITE_PAST); } -static inline Time InfiniteFuture() { return Time(IREE_TIME_INFINITE_FUTURE); } - -static inline Duration Milliseconds(int64_t millis) { - return Duration(millis * 1000000ull); -} - -// Returns the current system time in unix nanoseconds. -// Depending on the system architecture and power mode this time may have a -// very coarse granularity (on the order of microseconds to milliseconds). -// -// The system timer may not be monotonic; users should ensure when comparing -// times they check for negative values in case the time moves backwards. -static inline Time Now() { return Time(iree_time_now()); } - -// Converts a relative timeout duration to an absolute deadline time. -// This handles the special cases of IREE_DURATION_ZERO and -// IREE_DURATION_INFINITE to avoid extraneous time queries. -static inline Time RelativeTimeoutToDeadlineNanos(Duration timeout_ns) { - return Time(iree_relative_timeout_to_deadline_ns( - static_cast(timeout_ns))); -} - -static inline Duration DeadlineToRelativeTimeoutNanos(Time deadline_ns) { - if (deadline_ns == InfiniteFuture()) { - return InfiniteDuration(); - } else if (deadline_ns == InfinitePast()) { - return ZeroDuration(); - } else { - return Duration(static_cast(deadline_ns - Now())); - } -} - -} // namespace iree - -#endif // IREE_BASE_TIME_H_ diff --git a/iree/base/time_test.cc b/iree/base/time_test.cc deleted file mode 100644 index 114cd4efa16e5..0000000000000 --- a/iree/base/time_test.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "iree/base/time.h" - -#include "iree/testing/gtest.h" - -namespace iree { -namespace { - -TEST(Time, DurationComparisons) { - EXPECT_TRUE(Milliseconds(123) == Milliseconds(123)); - EXPECT_FALSE(Milliseconds(123) == Milliseconds(456)); - EXPECT_FALSE(Milliseconds(123) != Milliseconds(123)); - EXPECT_TRUE(Milliseconds(123) != Milliseconds(456)); - - EXPECT_TRUE(Milliseconds(123) < Milliseconds(456)); - EXPECT_FALSE(Milliseconds(123) > Milliseconds(456)); - EXPECT_FALSE(Milliseconds(123) > Milliseconds(123)); - EXPECT_FALSE(Milliseconds(123) < Milliseconds(123)); - - EXPECT_TRUE(Milliseconds(123) <= Milliseconds(123)); - EXPECT_TRUE(Milliseconds(123) >= Milliseconds(123)); - EXPECT_TRUE(Milliseconds(123) <= Milliseconds(456)); - EXPECT_FALSE(Milliseconds(123) >= Milliseconds(456)); -} - -TEST(Time, DurationArithmetic) { - EXPECT_EQ(Milliseconds(150), Milliseconds(100) + Milliseconds(50)); - EXPECT_EQ(Milliseconds(50), Milliseconds(100) - Milliseconds(50)); -} - -} // namespace -} // namespace iree diff --git a/iree/base/tracing.h b/iree/base/tracing.h index 24340df28a1a6..b591c298c575a 100644 --- a/iree/base/tracing.h +++ b/iree/base/tracing.h @@ -203,6 +203,14 @@ void iree_tracing_set_thread_name_impl(const char* name); typedef struct ___tracy_source_location_data iree_tracing_location_t; +#ifdef __cplusplus +#define iree_tracing_make_zone_ctx(zone_id) \ + TracyCZoneCtx { zone_id, 1 } +#else +#define iree_tracing_make_zone_ctx(zone_id) \ + (TracyCZoneCtx) { zone_id, 1 } +#endif // __cplusplus + ABSL_MUST_USE_RESULT iree_zone_id_t iree_tracing_zone_begin_impl(const iree_tracing_location_t* src_loc, const char* name, size_t name_length); @@ -309,13 +317,12 @@ enum { name, name_length) // Sets the dynamic color of the zone to an XXBBGGRR value. -#define IREE_TRACE_ZONE_SET_COLOR(zone_id, color_xbgr) \ - ___tracy_emit_zone_color((struct ___tracy_c_zone_context){zone_id, 1}, \ - color_xbgr); +#define IREE_TRACE_ZONE_SET_COLOR(zone_id, color_xbgr) \ + ___tracy_emit_zone_color(iree_tracing_make_zone_ctx(zone_id), color_xbgr); // Appends an integer value to the parent zone. May be called multiple times. #define IREE_TRACE_ZONE_APPEND_VALUE(zone_id, value) \ - ___tracy_emit_zone_value((struct ___tracy_c_zone_context){zone_id, 1}, value); + ___tracy_emit_zone_value(iree_tracing_make_zone_ctx(zone_id), value); // Appends a string value to the parent zone. May be called multiple times. // The |value| string will be copied into the trace buffer. @@ -326,13 +333,13 @@ enum { (__VA_ARGS__) #define IREE_TRACE_ZONE_APPEND_TEXT_CSTRING(zone_id, value) \ IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(zone_id, value, strlen(value)) -#define IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(zone_id, value, value_length) \ - ___tracy_emit_zone_text((struct ___tracy_c_zone_context){zone_id, 1}, value, \ +#define IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(zone_id, value, value_length) \ + ___tracy_emit_zone_text(iree_tracing_make_zone_ctx(zone_id), value, \ value_length) // Ends the current zone. Must be passed the |zone_id| from the _BEGIN. #define IREE_TRACE_ZONE_END(zone_id) \ - ___tracy_emit_zone_end((struct ___tracy_c_zone_context){zone_id, 1}) + ___tracy_emit_zone_end(iree_tracing_make_zone_ctx(zone_id)) // Ends the current zone before returning on a failure. // Sugar for IREE_TRACE_ZONE_END+IREE_RETURN_IF_ERROR. 
diff --git a/iree/base/wait_handle_posix.c b/iree/base/wait_handle_posix.c index 088f6cc2a2ca8..32cf8a6d490ec 100644 --- a/iree/base/wait_handle_posix.c +++ b/iree/base/wait_handle_posix.c @@ -95,6 +95,7 @@ static iree_status_t iree_wait_primitive_create_pipe( iree_status_t iree_wait_primitive_create_native( bool initial_state, iree_wait_handle_t* out_handle) { + memset(out_handle, 0, sizeof(*out_handle)); #if defined(IREE_HAVE_WAIT_TYPE_EVENTFD) // Always prefer eventfd when present; they rock. return iree_wait_primitive_create_eventfd(initial_state, out_handle); diff --git a/iree/base/wait_handle_win32.c b/iree/base/wait_handle_win32.c index cf91401ef8ca5..8b34c85a2dbb3 100644 --- a/iree/base/wait_handle_win32.c +++ b/iree/base/wait_handle_win32.c @@ -422,6 +422,7 @@ iree_status_t iree_wait_one(iree_wait_handle_t* handle, iree_status_t iree_event_initialize(bool initial_state, iree_event_t* out_event) { + memset(out_event, 0, sizeof(*out_event)); iree_wait_primitive_value_t value; memset(&value, 0, sizeof(value)); value.win32.handle = diff --git a/iree/compiler/Dialect/HAL/IR/HALBase.td b/iree/compiler/Dialect/HAL/IR/HALBase.td index addf6ddff048d..ce858408d6ee0 100644 --- a/iree/compiler/Dialect/HAL/IR/HALBase.td +++ b/iree/compiler/Dialect/HAL/IR/HALBase.td @@ -445,7 +445,7 @@ def HAL_ElementTypeAttr : SignlessIntegerAttrBase< I32, "element type attribute">; def HAL_DeviceSize : TypeAlias; -def HAL_DeviceSizeAttr : IREE_IndexAttrBase<"device_size_t">; +def HAL_DeviceSizeAttr : IREE_IndexAttrBase<"iree_device_size_t">; def HAL_HostSize : TypeAlias; def HAL_HostSizeAttr : IREE_IndexAttrBase<"size_t">; diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD index 8a4365ec0bb7b..3224513f971ec 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/BUILD +++ b/iree/compiler/Dialect/HAL/Target/LLVM/BUILD @@ -22,7 +22,7 @@ package( iree_cmake_extra_content( content = """ -if(NOT ${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}) +if(NOT "${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}") return() endif() """, diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt index 4b55865dc82b4..ed3cdd843110a 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/LLVM/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
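The `memset` additions to the wait handle constructors above follow a common defensive pattern: zero the out-param before any fallible work so early-error paths never hand the caller uninitialized fields. A generic sketch (`iree_example_t` and `step_that_may_fail` are hypothetical):

    iree_status_t iree_example_initialize(iree_example_t* out_example) {
      memset(out_example, 0, sizeof(*out_example));
      IREE_RETURN_IF_ERROR(step_that_may_fail(out_example));
      return iree_ok_status();
    }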
-if(NOT ${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}) +if(NOT "${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}") return() endif() diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir b/iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir index ed951d7e686a2..9e07f043a3679 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir +++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/binary_op.mlir @@ -11,7 +11,7 @@ flow.executable @simpleMath_ex_dispatch_0 { } } -// CHECK-LABEL: hal.executable @binary_op_linked_llvm_aot +// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0 // CHECK-DAG: hal.executable.binary attributes { // CHECK-SAME: data = dense // CHECK-SAME: format = 1145850178 : i32} { diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir b/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir index 8d00275e913e7..47c89d4743fd8 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir +++ b/iree/compiler/Dialect/HAL/Target/LLVM/test/matmul_op.mlir @@ -11,7 +11,7 @@ flow.executable @simpleMath_ex_dispatch_0 { } } -// CHECK-LABEL: hal.executable @matmul_op_linked_llvm_aot +// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0 // CHECK-DAG: hal.executable.binary attributes { // CHECK-SAME: data = dense // CHECK-SAME: format = 1145850178 : i32} { diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/BUILD index b911d084c689e..bef612323a3fd 100644 --- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/BUILD +++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/BUILD @@ -12,12 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content") + package( default_visibility = ["//visibility:public"], features = ["layering_check"], licenses = ["notice"], # Apache 2.0 ) +iree_cmake_extra_content( + content = """ +if(NOT "${IREE_TARGET_BACKEND_METAL-SPIRV}") + return() +endif() +""", +) + cc_library( name = "MetalSPIRV", srcs = ["MetalSPIRVTarget.cpp"], diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt index 2934ce9acb12a..e78415cad7957 100644 --- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-if(NOT ${IREE_TARGET_BACKEND_METAL-SPIRV}) +if(NOT "${IREE_TARGET_BACKEND_METAL-SPIRV}") return() endif() diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD index 49f5ccb9dbbe4..370a7dd231012 100644 --- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD +++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD @@ -22,7 +22,8 @@ package( iree_cmake_extra_content( content = """ -if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV} AND NOT ${IREE_TARGET_BACKEND_METAL-SPIRV}) +if(NOT "${IREE_TARGET_BACKEND_VULKAN-SPIRV}" AND + NOT "${IREE_TARGET_BACKEND_METAL-SPIRV}") return() endif() """, diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt index 3d005759f16db..62f1631501ba3 100644 --- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV} AND NOT ${IREE_TARGET_BACKEND_METAL-SPIRV}) +if(NOT "${IREE_TARGET_BACKEND_VULKAN-SPIRV}" AND + NOT "${IREE_TARGET_BACKEND_METAL-SPIRV}") return() endif() diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD b/iree/compiler/Dialect/HAL/Target/VMLA/BUILD index 7276df9034645..c3fde9d39563e 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/BUILD +++ b/iree/compiler/Dialect/HAL/Target/VMLA/BUILD @@ -22,7 +22,7 @@ package( iree_cmake_extra_content( content = """ -if(NOT ${IREE_TARGET_BACKEND_VMLA}) +if(NOT "${IREE_TARGET_BACKEND_VMLA}") return() endif() """, diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt index 8f6194d5f192b..b666d40076dd3 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/VMLA/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-if(NOT ${IREE_TARGET_BACKEND_VMLA}) +if(NOT "${IREE_TARGET_BACKEND_VMLA}") return() endif() diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir index 771b568398b8b..a1211b5e5ceca 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir +++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir @@ -12,15 +12,15 @@ flow.executable @simpleMath_ex_dispatch_0 { } } -// CHECK-LABEL: hal.executable @linked_vmla -// CHECK-NEXT: hal.interface @legacy_io_0 { +// CHECK-LABEL: hal.executable @simpleMath_ex_dispatch_0 +// CHECK-NEXT: hal.interface @legacy_io { // CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { -// CHECK-NEXT: hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4xf32>) -> tensor<4xf32>} +// CHECK-NEXT: hal.executable.entry_point @simpleMath_rgn_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4xf32>) -> tensor<4xf32>} // CHECK-NEXT: module { -// CHECK-NEXT: vm.module @linked_module { +// CHECK-NEXT: vm.module @module { // CHECK-NEXT: vm.func @simpleMath_rgn_dispatch_0(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { // CHECK-DAG: %zero = vm.const.i32.zero : i32 // CHECK-DAG: %c16 = vm.const.i32 16 : i32 @@ -55,15 +55,15 @@ flow.executable @shaped_dispatch { } } -// CHECK-LABEL: hal.executable @linked_vmla -// CHECK-NEXT: hal.interface @legacy_io_0 attributes {push_constants = 1 : i32} { +// CHECK-LABEL: hal.executable @shaped_dispatch +// CHECK-NEXT: hal.interface @legacy_io attributes {push_constants = 1 : i32} { // CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { -// CHECK-NEXT: hal.executable.entry_point @entry attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4x?xf32>, index) -> tensor<4x?xf32>} +// CHECK-NEXT: hal.executable.entry_point @entry attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x?xf32>, index) -> tensor<4x?xf32>} // CHECK-NEXT: module { -// CHECK-NEXT: vm.module @linked_module { +// CHECK-NEXT: vm.module @module { // CHECK-NEXT: vm.func @entry(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { // CHECK-DAG: %zero = vm.const.i32.zero : i32 // CHECK-DAG: %c16 = vm.const.i32 16 : i32 @@ -97,15 +97,15 @@ flow.executable @reduction_ex_dispatch_0 { } } -// CHECK-LABEL: hal.executable @linked_vmla -// CHECK-NEXT: hal.interface @legacy_io_0 { +// CHECK-LABEL: hal.executable @reduction_ex_dispatch_0 +// CHECK-NEXT: hal.interface @legacy_io { // CHECK-NEXT: hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read" // CHECK-NEXT: hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard" // CHECK-NEXT: } // CHECK-NEXT: hal.executable.target @vmla, filter="vmla" { -// CHECK-NEXT: hal.executable.entry_point @reduction_ex_dispatch_0 attributes {interface = @legacy_io_0, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>} +// CHECK-NEXT: hal.executable.entry_point 
@reduction_ex_dispatch_0 attributes {interface = @legacy_io, ordinal = 0 : i32, signature = (tensor<4x8xf32>) -> tensor<4xf32>} // CHECK-NEXT: module { -// CHECK-NEXT: vm.module @linked_module { +// CHECK-NEXT: vm.module @module { // CHECK-NEXT: vm.rodata @reduction_ex_dispatch_0_const dense<0.000000e+00> : tensor<1xf32> // CHECK-NEXT: vm.func @reduction_ex_dispatch_0(%arg0: !vm.ref, %arg1: i32, %arg2: i32, %arg3: i32) { // CHECK-DAG: %zero = vm.const.i32.zero : i32 diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD index 4062aa6722b1b..92792ebb87f43 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD @@ -22,7 +22,7 @@ package( iree_cmake_extra_content( content = """ -if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}) +if(NOT "${IREE_TARGET_BACKEND_VULKAN-SPIRV}") return() endif() """, diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt index c230f04f25938..7188ea1443054 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}) +if(NOT "${IREE_TARGET_BACKEND_VULKAN-SPIRV}") return() endif() diff --git a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index 521807dbb0057..2bf7223eb9752 100644 --- a/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp @@ -140,7 +140,18 @@ class MaterializeResourceCachesPass VariableOp defineExecutableLayoutOp(Location loc, ArrayAttr setLayoutsArrayAttr, IntegerAttr pushConstantsAttr) { - auto existingIt = executableLayoutCache_.find(setLayoutsArrayAttr); + // Push constants are optional but we always provide the value. + if (!pushConstantsAttr) { + pushConstantsAttr = + IntegerAttr::get(IntegerType::get(loc.getContext(), 32), 0); + } + + // We key the layout cache on all attributes that compose an executable + // layout. + auto cacheKey = ArrayAttr::get({setLayoutsArrayAttr, pushConstantsAttr}, + loc.getContext()); + + auto existingIt = executableLayoutCache_.find(cacheKey); if (existingIt != executableLayoutCache_.end()) { return existingIt->second; } @@ -163,7 +174,7 @@ class MaterializeResourceCachesPass loc, symbolName, /*isMutable=*/false, layoutType, StringRef(initializerName), llvm::None); variableOp.setPrivate(); - executableLayoutCache_.try_emplace(setLayoutsArrayAttr, variableOp); + executableLayoutCache_.try_emplace(cacheKey, variableOp); auto initializerOp = moduleBuilder.create( loc, initializerName, moduleBuilder.getFunctionType({}, {layoutType})); @@ -178,10 +189,6 @@ class MaterializeResourceCachesPass setLayoutValues.push_back(setLayoutValue); } auto deviceValue = blockBuilder.createOrFold(loc); - // Push constants are optional but we always provide the value. 
- if (!pushConstantsAttr) { - pushConstantsAttr = blockBuilder.getI32IntegerAttr(0); - } auto layoutValue = blockBuilder.createOrFold( loc, layoutType, deviceValue, setLayoutValues, pushConstantsAttr); blockBuilder.create(loc, layoutValue); diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 27a87629f542a..5d522a21f887f 100644 --- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -38,7 +38,7 @@ struct TransformOptions : public PassPipelineOptions<TransformOptions> { Option<bool> linkExecutables{ *this, "link-executables", llvm::cl::desc("Whether to link hal.executable ops together."), - llvm::cl::init(true)}; + llvm::cl::init(false)}; }; } // namespace diff --git a/iree/compiler/Dialect/VMLA/README.md b/iree/compiler/Dialect/VMLA/README.md index c5832ada7eb36..cf59c545186d2 100644 --- a/iree/compiler/Dialect/VMLA/README.md +++ b/iree/compiler/Dialect/VMLA/README.md @@ -33,9 +33,9 @@ TLDR: [VMLAToVM](/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/). 5. Add the runtime C++ kernel thunk to [vmla_module.cc](/iree/hal/vmla/vmla_module.cc). -6. Declare the kernel in [op_kernels.h](/iree/hal/vmla/op_kernels.h) and add a +6. Declare the kernel in [op_kernels.h](/iree/modules/vmla/op_kernels.h) and add a reference implementation in - [op_kernels_generic.h](/iree/hal/vmla/op_kernels_generic.h). + [op_kernels_generic.h](/iree/modules/vmla/op_kernels_generic.h). ### Declaring the Op @@ -139,7 +139,7 @@ There are some helpers such as `IREE_VMLA_BINARY_OP` that match the equivalents in the tablegen file, such that the thunk for your op can usually be just a single line. The thunks in this file just call one of the kernels defined in the -[op_kernels.h](/iree/hal/vmla/op_kernels.h) file. These kernels are designed to +[op_kernels.h](/iree/modules/vmla/op_kernels.h) file. These kernels are designed to be standalone from the VM code and take effectively just pointers and lists of values. The job of the `vmla_module.cc` thunk is to unwrap the VM arguments and pass them to these functions. @@ -153,7 +153,7 @@ conversion. This ensures that we can optimize things on the compiler-side instead of forcing the runtime to deal with things. Finally, implement the kernel in -[op_kernels_generic.h](/iree/hal/vmla/op_kernels_generic.h). Try to keep it +[op_kernels_generic.h](/iree/modules/vmla/op_kernels_generic.h). Try to keep it simple and readable. These are reference kernels and don't need to be fast; however, all of our tests use them, so they shouldn't be so slow as to prevent tests from running in a reasonable time. Use your judgement or be @@ -161,7 +161,7 @@ willing to have someone file a bug telling you to make them faster if they are terribly slow :) Tests for the kernels can be added to -[op_kernels_test.cc](/iree/hal/vmla/op_kernels_test.cc). The thunks in +[op_kernels_test.cc](/iree/modules/vmla/op_kernels_test.cc). The thunks in `vmla_module.cc` are best tested via end-to-end tests using `iree-run-mlir`, as what you really want to ensure is that the compiler is emitting calls that match the runtime side, and the only way to do that is to actually compile and run.
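To make the README's recipe concrete, here is a minimal sketch of a reference kernel in the style the text describes: standalone from the VM, taking plain spans of values that the `vmla_module.cc` thunk unwraps from VM arguments. The `Add` kernel and its exact signature are illustrative assumptions, not the actual contents of op_kernels_generic.h.

```c++
// Hypothetical reference kernel sketch in the op_kernels_generic.h style.
// The kernel knows nothing about the VM: the vmla_module.cc thunk is
// responsible for unwrapping VM arguments into these spans.
#include "absl/types/span.h"
#include "iree/base/status.h"

namespace iree {
namespace hal {
namespace vmla {
namespace kernels {

struct Add {
  template <typename T>
  static Status Execute(absl::Span<const T> lhs_buffer,
                        absl::Span<const T> rhs_buffer,
                        absl::Span<T> dst_buffer) {
    for (size_t i = 0; i < dst_buffer.size(); ++i) {
      dst_buffer[i] = lhs_buffer[i] + rhs_buffer[i];
    }
    return OkStatus();
  }
};

}  // namespace kernels
}  // namespace vmla
}  // namespace hal
}  // namespace iree
```

Keeping the kernel this small is exactly what the README asks for: readable first, and only fast enough that tests run in reasonable time.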
diff --git a/iree/compiler/Dialect/Vulkan/Utils/test/BUILD b/iree/compiler/Dialect/Vulkan/Utils/test/BUILD index e83bb2fca7ebd..29aa3a5b0bf83 100644 --- a/iree/compiler/Dialect/Vulkan/Utils/test/BUILD +++ b/iree/compiler/Dialect/Vulkan/Utils/test/BUILD @@ -23,7 +23,7 @@ package( iree_cmake_extra_content( content = """ -if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}) +if(NOT "${IREE_TARGET_BACKEND_VULKAN-SPIRV}") return() endif() """, diff --git a/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt b/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt index 50453b19443f1..a4f0d04767c37 100644 --- a/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt +++ b/iree/compiler/Dialect/Vulkan/Utils/test/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV}) +if(NOT "${IREE_TARGET_BACKEND_VULKAN-SPIRV}") return() endif() diff --git a/iree/hal/BUILD b/iree/hal/BUILD index 367ed3ef94e57..93579136f8db7 100644 --- a/iree/hal/BUILD +++ b/iree/hal/BUILD @@ -30,171 +30,71 @@ package( cc_library( name = "api", srcs = [ - "api.c", - "api.cc", - ], - hdrs = [ - "api.h", - "api_detail.h", - ], - visibility = ["//visibility:public"], - deps = [ - ":hal", - ":heap_buffer", - "//iree/base:api", - "//iree/base:core_headers", - "//iree/base:ref_ptr", - "//iree/base:synchronization", - "//iree/base:threading", - "//iree/base:tracing", - "//iree/hal/host:host_local_allocator", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@half//:includes", - ], -) - -cc_test( - name = "api_string_util_test", - srcs = ["api_string_util_test.cc"], - deps = [ - ":api", - "//iree/base:core_headers", - "//iree/base:status", - "//iree/testing:gtest", - "//iree/testing:gtest_main", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - ], -) - -#===------------------------------------------------------------------------===# -# Implementation -#===------------------------------------------------------------------------===# -# TODO(benvanik): rename to :cc and expose via an api_cc.h. 
- -cc_library( - name = "hal", - srcs = [ - "allocator.cc", - "buffer.cc", - "command_buffer.cc", - "deferred_buffer.cc", - "executable_cache.cc", - ], - hdrs = [ + "allocator.c", "allocator.h", + "allocator_heap.c", + "buffer.c", "buffer.h", + "buffer_heap.c", + "buffer_view.c", + "buffer_view.cc", + "buffer_view.h", + "command_buffer.c", "command_buffer.h", - "command_queue.h", - "debug_capture_manager.h", - "deferred_buffer.h", + "command_buffer_validation.c", + "descriptor_set.c", "descriptor_set.h", + "descriptor_set_layout.c", "descriptor_set_layout.h", + "detail.h", + "device.c", "device.h", - "device_info.h", - "device_placement.h", + "driver.c", "driver.h", + "driver_registry.c", + "driver_registry.h", + "event.c", "event.h", + "executable.c", "executable.h", + "executable_cache.c", "executable_cache.h", - "executable_format.h", + "executable_layout.c", "executable_layout.h", - "executable_spec.h", "resource.h", + "semaphore.c", "semaphore.h", - "stack_trace.h", + "string_util.cc", + "string_util.h", + ], + hdrs = [ + "api.h", ], + visibility = ["//visibility:public"], deps = [ + "//iree/base:api", "//iree/base:core_headers", - "//iree/base:logging", "//iree/base:ref_ptr", - "//iree/base:status", - "//iree/base:time", + "//iree/base:synchronization", + "//iree/base:threading", "//iree/base:tracing", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@half//:includes", ], ) cc_test( - name = "buffer_test", - srcs = [ - "buffer_mapping_test.cc", - "buffer_test.cc", - ], + name = "string_util_test", + srcs = ["string_util_test.cc"], deps = [ - ":hal", - ":heap_buffer", + ":api", + "//iree/base:core_headers", "//iree/base:status", "//iree/testing:gtest", "//iree/testing:gtest_main", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "deferred_buffer_test", - srcs = ["deferred_buffer_test.cc"], - deps = [ - ":hal", - ":heap_buffer", - "//iree/hal/testing:mock_allocator", - "//iree/testing:gtest", - "//iree/testing:gtest_main", - "@com_google_absl//absl/memory", - ], -) - -#===------------------------------------------------------------------------===# -# Debugging utilities and tools -#===------------------------------------------------------------------------===# - -cc_library( - name = "command_buffer_validation", - srcs = ["command_buffer_validation.cc"], - hdrs = ["command_buffer_validation.h"], - deps = [ - ":hal", - "//iree/base:logging", - "//iree/base:status", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", ], ) - -#===------------------------------------------------------------------------===# -# Internal device management and driver registry -#===------------------------------------------------------------------------===# -# TODO(benvanik): port these to C and merge into main API. 
- -cc_library( - name = "device_manager", - srcs = ["device_manager.cc"], - hdrs = ["device_manager.h"], - deps = [ - ":hal", - ":heap_buffer", - "//iree/base:status", - "//iree/base:time", - "//iree/base:tracing", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - ], -) - -#===------------------------------------------------------------------------===# -# Internal implementation details -#===------------------------------------------------------------------------===# - -cc_library( - name = "heap_buffer", - srcs = ["heap_buffer.cc"], - hdrs = ["heap_buffer.h"], - deps = [ - ":hal", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal/host:host_buffer", - ], -) diff --git a/iree/hal/CMakeLists.txt b/iree/hal/CMakeLists.txt index 438ce0fa6f22c..b042643448ac1 100644 --- a/iree/hal/CMakeLists.txt +++ b/iree/hal/CMakeLists.txt @@ -19,157 +19,67 @@ iree_cc_library( api HDRS "api.h" - "api_detail.h" SRCS - "api.c" - "api.cc" - DEPS - ::hal - ::heap_buffer - absl::inlined_vector - absl::span - absl::strings - iree::base::api - iree::base::core_headers - iree::base::ref_ptr - iree::base::synchronization - iree::base::threading - iree::base::tracing - iree::hal::host::host_local_allocator - PUBLIC -) - -iree_cc_test( - NAME - api_string_util_test - SRCS - "api_string_util_test.cc" - DEPS - ::api - absl::inlined_vector - absl::strings - iree::base::core_headers - iree::base::status - iree::testing::gtest - iree::testing::gtest_main -) - -iree_cc_library( - NAME - hal - HDRS + "allocator.c" "allocator.h" + "allocator_heap.c" + "buffer.c" "buffer.h" + "buffer_heap.c" + "buffer_view.c" + "buffer_view.cc" + "buffer_view.h" + "command_buffer.c" "command_buffer.h" - "command_queue.h" - "debug_capture_manager.h" - "deferred_buffer.h" + "command_buffer_validation.c" + "descriptor_set.c" "descriptor_set.h" + "descriptor_set_layout.c" "descriptor_set_layout.h" + "detail.h" + "device.c" "device.h" - "device_info.h" - "device_placement.h" + "driver.c" "driver.h" + "driver_registry.c" + "driver_registry.h" + "event.c" "event.h" + "executable.c" "executable.h" + "executable_cache.c" "executable_cache.h" - "executable_format.h" + "executable_layout.c" "executable_layout.h" - "executable_spec.h" "resource.h" + "semaphore.c" "semaphore.h" - "stack_trace.h" - SRCS - "allocator.cc" - "buffer.cc" - "command_buffer.cc" - "deferred_buffer.cc" - "executable_cache.cc" + "string_util.cc" + "string_util.h" DEPS + absl::inlined_vector absl::span absl::strings + iree::base::api iree::base::core_headers - iree::base::logging iree::base::ref_ptr - iree::base::status - iree::base::time + iree::base::synchronization + iree::base::threading iree::base::tracing PUBLIC ) iree_cc_test( NAME - buffer_test + string_util_test SRCS - "buffer_mapping_test.cc" - "buffer_test.cc" + "string_util_test.cc" DEPS - ::hal - ::heap_buffer - absl::span + ::api + absl::inlined_vector + absl::strings + iree::base::core_headers iree::base::status iree::testing::gtest iree::testing::gtest_main ) - -iree_cc_test( - NAME - deferred_buffer_test - SRCS - "deferred_buffer_test.cc" - DEPS - ::hal - ::heap_buffer - absl::memory - iree::hal::testing::mock_allocator - iree::testing::gtest - iree::testing::gtest_main -) - -iree_cc_library( - NAME - command_buffer_validation - HDRS - "command_buffer_validation.h" - SRCS - "command_buffer_validation.cc" - DEPS - ::hal - absl::strings - iree::base::logging - iree::base::status - PUBLIC -) - -iree_cc_library( - NAME - device_manager - HDRS - "device_manager.h" - SRCS - 
"device_manager.cc" - DEPS - ::hal - ::heap_buffer - absl::span - absl::synchronization - iree::base::status - iree::base::time - iree::base::tracing - PUBLIC -) - -iree_cc_library( - NAME - heap_buffer - HDRS - "heap_buffer.h" - SRCS - "heap_buffer.cc" - DEPS - ::hal - iree::base::status - iree::base::tracing - iree::hal::host::host_buffer - PUBLIC -) diff --git a/iree/hal/allocator.c b/iree/hal/allocator.c new file mode 100644 index 0000000000000..f555a30390ee4 --- /dev/null +++ b/iree/hal/allocator.c @@ -0,0 +1,70 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/allocator.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" + +#define _VTABLE_DISPATCH(allocator, method_name) \ + IREE_HAL_VTABLE_DISPATCH(allocator, iree_hal_allocator, method_name) + +IREE_HAL_API_RETAIN_RELEASE(allocator); + +IREE_API_EXPORT iree_allocator_t IREE_API_CALL +iree_hal_allocator_host_allocator(const iree_hal_allocator_t* allocator) { + IREE_ASSERT_ARGUMENT(allocator); + return _VTABLE_DISPATCH(allocator, host_allocator)(allocator); +} + +IREE_API_EXPORT iree_hal_buffer_compatibility_t +iree_hal_allocator_query_buffer_compatibility( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, + iree_hal_buffer_usage_t intended_usage, + iree_device_size_t allocation_size) { + IREE_ASSERT_ARGUMENT(allocator); + return _VTABLE_DISPATCH(allocator, query_buffer_compatibility)( + allocator, memory_type, allowed_usage, intended_usage, allocation_size); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_allocate_buffer( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size, + iree_hal_buffer_t** out_buffer) { + IREE_ASSERT_ARGUMENT(allocator); + IREE_ASSERT_ARGUMENT(out_buffer); + *out_buffer = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(allocator, allocate_buffer)( + allocator, memory_type, allowed_usage, allocation_size, out_buffer); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_wrap_buffer( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data, + iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer) { + IREE_ASSERT_ARGUMENT(allocator); + IREE_ASSERT_ARGUMENT(out_buffer); + *out_buffer = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(allocator, wrap_buffer)( + allocator, memory_type, allowed_access, allowed_usage, data, + data_allocator, out_buffer); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/allocator.cc b/iree/hal/allocator.cc deleted file mode 100644 index af219b3d32360..0000000000000 --- a/iree/hal/allocator.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under 
the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/allocator.h" - -#include -#include -#include -#include - -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { - -bool Allocator::CanUseBuffer(Buffer* buffer, - BufferUsageBitfield intended_usage) const { - return CanUseBufferLike(buffer->allocator(), buffer->memory_type(), - buffer->usage(), intended_usage); -} - -StatusOr<ref_ptr<Buffer>> Allocator::AllocateConstant( - BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) { - if (AnyBitSet(source_buffer->usage() & BufferUsage::kConstant) && - CanUseBuffer(source_buffer.get(), buffer_usage)) { - // Buffer can be used directly by the device. - return source_buffer; - } - - IREE_TRACE_SCOPE0("Allocator::AllocateConstant"); - - // We need to map so we can copy into it. - buffer_usage |= BufferUsage::kMapping; - // It will be constant after we write it. - buffer_usage |= BufferUsage::kConstant; - - MemoryTypeBitfield memory_type = - MemoryType::kDeviceLocal | MemoryType::kHostVisible; - IREE_ASSIGN_OR_RETURN( - auto device_buffer, - Allocate(memory_type, buffer_usage, source_buffer->byte_length())); - IREE_ASSIGN_OR_RETURN(auto source_mapping, - source_buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - IREE_RETURN_IF_ERROR(device_buffer->WriteData(0, source_mapping.data(), - source_mapping.byte_length())); - return device_buffer; -} - -StatusOr<ref_ptr<Buffer>> Allocator::Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - const void* data, - size_t data_length) { - return WrapMutable(memory_type, MemoryAccess::kRead, buffer_usage, - const_cast<void*>(data), data_length); -} - -StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, void* data, size_t data_length) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Allocator does not support wrapping host memory"; -} - -} // namespace hal -} // namespace iree diff --git a/iree/hal/allocator.h b/iree/hal/allocator.h index 11f65ec53c216..446e2a3e6b112 100644 --- a/iree/hal/allocator.h +++ b/iree/hal/allocator.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,125 +15,162 @@ #ifndef IREE_HAL_ALLOCATOR_H_ #define IREE_HAL_ALLOCATOR_H_ -#include -#include +#include +#include -#include "absl/types/span.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" +#include "iree/base/api.h" #include "iree/hal/buffer.h" +#include "iree/hal/resource.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// + +// A bitfield indicating compatible behavior for buffers in an allocator.
+enum iree_hal_buffer_compatibility_e { + // Indicates (in the absence of other bits) the buffer is not compatible with + // the allocator or device at all. Any attempts to use the buffer for any + // usage will fail. This will happen if the buffer is device-local to another + // device without peering and not visible to the host. + IREE_HAL_BUFFER_COMPATIBILITY_NONE = 0u, + + // Indicates that the allocator could allocate new buffers of this type and + // usage natively. Allocations with the queried parameters may still fail due + // to runtime conditions (out of memory, fragmentation, etc) but are otherwise + // valid. + IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE = 1u << 0, + + // Indicates that the buffer can be used as a transfer source or target on the + // device queue (such as being the source or target of a DMA operation, + // etc). If not set then the buffer may still be usable for + // iree_hal_buffer_copy_data but not with queued operations. + IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER = 1u << 1, + + // Indicates that the buffer can be used as an input/output to a dispatch. + IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH = 1u << 2, +}; +typedef uint32_t iree_hal_buffer_compatibility_t; + +//===----------------------------------------------------------------------===// +// iree_hal_allocator_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_allocator_s iree_hal_allocator_t; + +// Retains the given |allocator| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_allocator_retain(iree_hal_allocator_t* allocator); + +// Releases the given |allocator| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_allocator_release(iree_hal_allocator_t* allocator); -namespace iree { -namespace hal { +// Returns the host allocator used for allocating host objects. +IREE_API_EXPORT iree_allocator_t IREE_API_CALL +iree_hal_allocator_host_allocator(const iree_hal_allocator_t* allocator); -// Allocates buffers for a particular device memory space. +// Returns a bitmask indicating what operations with buffers of the given type +// are available on the allocator. // -// Buffers allocated are only guaranteed to work with the driver that the -// allocator services. Any attempt to use buffers on drivers they were not -// allocated from must first be checked with CanUseBuffer. +// For buffers allocated from the given allocator it's expected that the result +// will always be non-NONE. For buffers that originate from another allocator +// there may be limited support for cross-device usage. // -// Thread-safe. -class Allocator : public RefObject<Allocator> { - public: - virtual ~Allocator() = default; - - // Returns true if the device can use the given buffer for the provided usage. - // For buffers allocated from this allocator it's expected that the result - // will always be true. For buffers that originate from another allocator - // there may be limited support for cross-device usage. - // - // Returning false indicates that the buffer must be transferred externally - // into a buffer compatible with the device this allocator services. - bool CanUseBuffer(Buffer* buffer, BufferUsageBitfield intended_usage) const; - virtual bool CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const = 0; - - // Returns true if the allocator can allocate a buffer with the given - // attributes.
- virtual bool CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const = 0; - - // Adjusts allocation parameters to be compatible with the allocator. - // Certain allocators may require particular memory types to function. By - // adjusting the parameters prior to allocation callers can be sure they are - // able to successfully Allocate a buffer later on with the same parameters. - virtual Status MakeCompatible(MemoryTypeBitfield* memory_type, - BufferUsageBitfield* buffer_usage) const { - return OkStatus(); - } - - // Allocates a buffer from the allocator. - // Fails if the memory type requested for the given usage cannot be serviced. - // Callers can use CanAllocate to decide their memory use strategy. - // - // The memory type of the buffer returned may differ from the requested value - // if the device can provide more functionality; for example, if requesting - // MemoryType::kHostVisible but the memory is really host cached you may get - // a buffer back with MemoryType::kHostVisible | MemoryType::kHostCached. The - // only requirement is that the buffer satisfy the required bits. - virtual StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) = 0; - - // Allocates a buffer from the allocator for use as a constant value. - // The provided |source_buffer| may be returned if the device can use it - // directly and otherwise will be copied. - virtual StatusOr<ref_ptr<Buffer>> AllocateConstant( - BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer); - - // Wraps an existing host heap allocation in a buffer. - // Ownership of the host allocation remains with the caller and the memory - // must remain valid for so long as the Buffer may be in use. - // Will have MemoryType::kHostLocal in most cases and may not be usable - // by the device. - // - // The inference optimizer makes assumptions about buffer aliasing based on - // Buffer instances and because of this wrapping the same host buffer in - // multiple Buffers will create potential memory aliasing issues that can be - // difficult to track down. There's no checking as to whether a host buffer - // has already been wrapped so it's best for callers to ensure this is never - // possible (the simplest way being to never use Wrap and always just allocate - // new Buffers). - // - // Fails if the allocator cannot access host memory in this way.
- StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - const void* data, size_t data_length); - virtual StatusOr<ref_ptr<Buffer>> WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, void* data, size_t data_length); - template <typename T> - StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - absl::Span<const T> data); - template <typename T> - StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, - absl::Span<T> data); -}; - -// Inline functions and template definitions follow: - -template <typename T> -StatusOr<ref_ptr<Buffer>> Allocator::Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - absl::Span<const T> data) { - return Wrap(memory_type, buffer_usage, data.data(), data.size() * sizeof(T)); -} - -template <typename T> -StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, absl::Span<T> data) { - return WrapMutable(memory_type, allowed_access, buffer_usage, data.data(), - data.size() * sizeof(T)); -} - -} // namespace hal -} // namespace iree +// Returning IREE_HAL_BUFFER_COMPATIBILITY_NONE indicates that the buffer must +// be transferred externally into a buffer compatible with the device the +// allocator services. +IREE_API_EXPORT iree_hal_buffer_compatibility_t +iree_hal_allocator_query_buffer_compatibility( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, + iree_hal_buffer_usage_t intended_usage, iree_device_size_t allocation_size); + +// Allocates a buffer from the allocator. +// Fails if the memory type requested for the given usage cannot be serviced. +// Callers can use iree_hal_allocator_query_buffer_compatibility to decide +// their memory use strategy. +// +// The memory type of the buffer returned may differ from the requested value +// if the device can provide more functionality; for example, if requesting +// IREE_HAL_MEMORY_TYPE_HOST_VISIBLE but the memory is really host cached you +// may get a buffer back with IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | +// IREE_HAL_MEMORY_TYPE_HOST_CACHED. The only requirement is that the buffer +// satisfy the required bits. +// +// Fails if it is not possible to allocate and satisfy all placements for the +// requested |allowed_usage|. +// |out_buffer| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_allocate_buffer( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size, + iree_hal_buffer_t** out_buffer); + +// Wraps an existing host allocation in a buffer. +// |data_allocator| will be used to free the memory when the buffer is +// destroyed. iree_allocator_null() can be passed to indicate the buffer does +// not own the data. +// +// Fails if the allocator cannot access host memory in this way. +// |out_buffer| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_wrap_buffer( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data, + iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer); + +//===----------------------------------------------------------------------===// +// iree_hal_heap_allocator_t +//===----------------------------------------------------------------------===// + +// Creates a host-local heap allocator that can be used when buffers are +// required that will not interact with a real hardware device (such as those +// used in file IO or tests). Buffers allocated with this will not be compatible +// with real device allocators and will likely incur a copy (or failure) if +// used. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_create_heap( + iree_string_view_t identifier, iree_allocator_t host_allocator, + iree_hal_allocator_t** out_allocator); + +//===----------------------------------------------------------------------===// +// iree_hal_allocator_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE + + void(IREE_API_PTR* destroy)(iree_hal_allocator_t* allocator); + + iree_allocator_t(IREE_API_PTR* host_allocator)( + const iree_hal_allocator_t* allocator); + + iree_hal_buffer_compatibility_t(IREE_API_PTR* query_buffer_compatibility)( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, + iree_hal_buffer_usage_t intended_usage, + iree_device_size_t allocation_size); + + iree_status_t(IREE_API_PTR* allocate_buffer)( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size, + iree_hal_buffer_t** out_buffer); + + iree_status_t(IREE_API_PTR* wrap_buffer)( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data, + iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer); +} iree_hal_allocator_vtable_t; + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_allocator_destroy(iree_hal_allocator_t* allocator); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_ALLOCATOR_H_ diff --git a/iree/hal/allocator_heap.c b/iree/hal/allocator_heap.c new file mode 100644 index 0000000000000..5a223351f08b8 --- /dev/null +++ b/iree/hal/allocator_heap.c @@ -0,0 +1,167 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
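Before the heap implementation that follows, a usage sketch may help show what the compatibility query above is for. This is a hypothetical caller (the `device_allocator` variable and the sizes are assumptions), using only names declared in this header and in this patch:

```c++
// Check whether a device-visible transfer buffer can be used directly on the
// queue; if not, the data must be staged through a compatible buffer first.
iree_hal_buffer_compatibility_t compatibility =
    iree_hal_allocator_query_buffer_compatibility(
        device_allocator, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
        /*allowed_usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER,
        /*intended_usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER,
        /*allocation_size=*/1024);
if (!iree_all_bits_set(compatibility,
                       IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER)) {
  // Not queue-compatible: allocate a new buffer with
  // iree_hal_allocator_allocate_buffer and copy the data into it instead.
}
```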
+ +#include "iree/base/tracing.h" +#include "iree/hal/allocator.h" +#include "iree/hal/detail.h" + +typedef struct iree_hal_heap_allocator_s { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + iree_string_view_t identifier; +} iree_hal_heap_allocator_t; + +static const iree_hal_allocator_vtable_t iree_hal_heap_allocator_vtable; + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_create_heap( + iree_string_view_t identifier, iree_allocator_t host_allocator, + iree_hal_allocator_t** out_allocator) { + IREE_ASSERT_ARGUMENT(out_allocator); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_heap_allocator_t* allocator = NULL; + iree_host_size_t total_size = sizeof(*allocator) + identifier.size; + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&allocator); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_heap_allocator_vtable, + &allocator->resource); + allocator->host_allocator = host_allocator; + iree_string_view_append_to_buffer( + identifier, &allocator->identifier, + (char*)allocator + total_size - identifier.size); + *out_allocator = (iree_hal_allocator_t*)allocator; + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static void iree_hal_heap_allocator_destroy( + iree_hal_allocator_t* base_allocator) { + iree_hal_heap_allocator_t* allocator = + (iree_hal_heap_allocator_t*)base_allocator; + iree_allocator_t host_allocator = allocator->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, allocator); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_allocator_t iree_hal_heap_allocator_host_allocator( + const iree_hal_allocator_t* base_allocator) { + iree_hal_heap_allocator_t* allocator = + (iree_hal_heap_allocator_t*)base_allocator; + return allocator->host_allocator; +} + +static iree_hal_buffer_compatibility_t +iree_hal_heap_allocator_query_buffer_compatibility( + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, + iree_hal_buffer_usage_t intended_usage, + iree_device_size_t allocation_size) { + // Disallow usage not permitted by the buffer itself. Since we then use this + // to determine compatibility below we'll naturally set the right compat flags + // based on what's both allowed and intended. + intended_usage &= allowed_usage; + + // All buffers can be allocated on the heap. + iree_hal_buffer_compatibility_t compatibility = + IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE; + + // Buffers can only be used on the queue if they are device visible. + // This is not a strict requirement of heap buffers but matches devices that + // have discrete memory spaces (remoting/sandboxed, GPUs, etc) and makes it + // much easier to find issues of buffer definition with local devices that + // will cause issues when used with real devices. + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + } + if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH; + } + } + + return compatibility; +} + +static iree_status_t iree_hal_heap_allocator_make_compatible( + iree_hal_memory_type_t* memory_type, + iree_hal_memory_access_t* allowed_access, + iree_hal_buffer_usage_t* allowed_usage) { + // Always ensure we are host-visible. 
+ *memory_type |= IREE_HAL_MEMORY_TYPE_HOST_VISIBLE; + + // Host currently uses mapping to copy buffers, which is done a lot. + // We could probably remove this mutation by preventing copies in those cases. + *allowed_usage |= IREE_HAL_BUFFER_USAGE_MAPPING; + + // TODO(benvanik): check if transfer is still required for DMA copy source. + *allowed_usage |= IREE_HAL_BUFFER_USAGE_TRANSFER; + + return iree_ok_status(); +} + +static iree_status_t iree_hal_heap_allocator_allocate_buffer( + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size, + iree_hal_buffer_t** out_buffer) { + iree_hal_heap_allocator_t* allocator = + (iree_hal_heap_allocator_t*)base_allocator; + + // Coerce options into those required for use by heap-based devices. + iree_hal_memory_access_t allowed_access = IREE_HAL_MEMORY_ACCESS_ALL; + IREE_RETURN_IF_ERROR(iree_hal_heap_allocator_make_compatible( + &memory_type, &allowed_access, &allowed_usage)); + + iree_byte_span_t data = iree_make_byte_span(NULL, allocation_size); + if (allocation_size > 0) { + // Zero-length buffers are valid but we don't want to try to malloc them. + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + allocator->host_allocator, allocation_size, (void**)&data.data)); + } + iree_status_t status = iree_hal_heap_buffer_wrap( + base_allocator, memory_type, allowed_access, allowed_usage, + allocation_size, data, allocator->host_allocator, out_buffer); + if (!iree_status_is_ok(status)) { + iree_allocator_free(allocator->host_allocator, data.data); + } + return status; +} + +static iree_status_t iree_hal_heap_allocator_wrap_buffer( + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data, + iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer) { + // Coerce options into those required for use by heap-based devices. + IREE_RETURN_IF_ERROR(iree_hal_heap_allocator_make_compatible( + &memory_type, &allowed_access, &allowed_usage)); + + return iree_hal_heap_buffer_wrap(base_allocator, memory_type, allowed_access, + allowed_usage, data.data_length, data, + data_allocator, out_buffer); +} + +static const iree_hal_allocator_vtable_t iree_hal_heap_allocator_vtable = { + .destroy = iree_hal_heap_allocator_destroy, + .host_allocator = iree_hal_heap_allocator_host_allocator, + .query_buffer_compatibility = + iree_hal_heap_allocator_query_buffer_compatibility, + .allocate_buffer = iree_hal_heap_allocator_allocate_buffer, + .wrap_buffer = iree_hal_heap_allocator_wrap_buffer, +}; diff --git a/iree/hal/api.cc b/iree/hal/api.cc deleted file mode 100644 index 65b310f8a1033..0000000000000 --- a/iree/hal/api.cc +++ /dev/null @@ -1,2042 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
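With the heap allocator's vtable wired up as above, end-to-end use is short. A minimal sketch follows (the `kData` array, the error handling via IREE_CHECK_OK, and the use of the iree_make_cstring_view helper from iree/base are assumptions for illustration, not code from this patch):

```c++
// Create a host-local heap allocator, wrap caller-owned memory, and clean up.
// iree_allocator_null() keeps ownership of kData with the caller.
iree_hal_allocator_t* heap_allocator = NULL;
IREE_CHECK_OK(iree_hal_allocator_create_heap(
    iree_make_cstring_view("heap"), iree_allocator_system(), &heap_allocator));

static float kData[4] = {0.0f, 1.0f, 2.0f, 3.0f};
iree_hal_buffer_t* buffer = NULL;
IREE_CHECK_OK(iree_hal_allocator_wrap_buffer(
    heap_allocator, IREE_HAL_MEMORY_TYPE_HOST_LOCAL,
    IREE_HAL_MEMORY_ACCESS_READ, IREE_HAL_BUFFER_USAGE_TRANSFER,
    iree_make_byte_span(kData, sizeof(kData)), iree_allocator_null(), &buffer));

iree_hal_buffer_release(buffer);
iree_hal_allocator_release(heap_allocator);
```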
- -#include "iree/hal/api.h" - -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/strings/ascii.h" -#include "absl/strings/match.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "absl/strings/strip.h" -#include "absl/types/span.h" -#include "iree/base/api.h" -#include "iree/base/memory.h" -#include "iree/base/tracing.h" -#include "iree/hal/api_detail.h" -#include "iree/hal/buffer.h" -#include "iree/hal/command_buffer.h" -#include "iree/hal/device.h" -#include "iree/hal/driver.h" -#include "iree/hal/heap_buffer.h" -#include "iree/hal/host/host_local_allocator.h" -#include "iree/hal/semaphore.h" -#include "third_party/half/half.hpp" - -namespace iree { -namespace hal { - -// Defines the iree_hal__retain/_release methods. -#define IREE_HAL_API_RETAIN_RELEASE(type_name, cc_type) \ - IREE_API_EXPORT void iree_hal_##type_name##_retain( \ - iree_hal_##type_name##_t* type_name) { \ - auto* handle = reinterpret_cast(type_name); \ - if (handle) handle->AddReference(); \ - } \ - IREE_API_EXPORT void iree_hal_##type_name##_release( \ - iree_hal_##type_name##_t* type_name) { \ - auto* handle = reinterpret_cast(type_name); \ - if (handle) handle->ReleaseReference(); \ - } - -//===----------------------------------------------------------------------===// -// Utilities -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_shape( - iree_string_view_t value, iree_host_size_t shape_capacity, - iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank) { - IREE_ASSERT_ARGUMENT(out_shape_rank); - *out_shape_rank = 0; - - auto str_value = absl::string_view(value.data, value.size); - if (str_value.empty()) { - return iree_ok_status(); // empty shape - } - - absl::InlinedVector dims; - for (auto dim_str : absl::StrSplit(str_value, 'x')) { - int dim_value = 0; - if (!absl::SimpleAtoi(dim_str, &dim_value)) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "shape[%zu] invalid value '%.*s' of '%.*s'", - dims.size(), (int)dim_str.size(), dim_str.data(), - (int)value.size, value.data); - } - if (dim_value < 0) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "shape[%zu] unsupported value %d of '%.*s'", - dims.size(), dim_value, (int)value.size, - value.data); - } - dims.push_back(dim_value); - } - if (out_shape_rank) { - *out_shape_rank = dims.size(); - } - if (dims.size() > shape_capacity) { - return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); - } - if (out_shape) { - std::memcpy(out_shape, dims.data(), dims.size() * sizeof(*out_shape)); - } - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_format_shape(const iree_hal_dim_t* shape, iree_host_size_t shape_rank, - iree_host_size_t buffer_capacity, char* buffer, - iree_host_size_t* out_buffer_length) { - if (out_buffer_length) { - *out_buffer_length = 0; - } - iree_host_size_t buffer_length = 0; - for (iree_host_size_t i = 0; i < shape_rank; ++i) { - int n = std::snprintf(buffer ? buffer + buffer_length : nullptr, - buffer ? buffer_capacity - buffer_length : 0, - (i < shape_rank - 1) ? 
"%dx" : "%d", shape[i]); - if (n < 0) { - return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, - "snprintf failed to write dimension %zu", i); - } else if (buffer && n >= buffer_capacity - buffer_length) { - buffer = nullptr; - } - buffer_length += n; - } - if (out_buffer_length) { - *out_buffer_length = buffer_length; - } - return buffer ? iree_ok_status() - : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element_type( - iree_string_view_t value, iree_hal_element_type_t* out_element_type) { - IREE_ASSERT_ARGUMENT(out_element_type); - *out_element_type = IREE_HAL_ELEMENT_TYPE_NONE; - - auto str_value = absl::string_view(value.data, value.size); - - iree_hal_numerical_type_t numerical_type = IREE_HAL_NUMERICAL_TYPE_UNKNOWN; - if (absl::StartsWith(str_value, "i")) { - numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED; - str_value.remove_prefix(1); - } else if (absl::StartsWith(str_value, "u")) { - numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED; - str_value.remove_prefix(1); - } else if (absl::StartsWith(str_value, "f")) { - numerical_type = IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE; - str_value.remove_prefix(1); - } else if (absl::StartsWith(str_value, "x") || - absl::StartsWith(str_value, "*")) { - numerical_type = IREE_HAL_NUMERICAL_TYPE_UNKNOWN; - str_value.remove_prefix(1); - } else { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "unhandled element type prefix in '%.*s'", - (int)value.size, value.data); - } - - uint32_t bit_count = 0; - if (!absl::SimpleAtoi(str_value, &bit_count) || bit_count > 0xFFu) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "out of range bit count in '%.*s'", (int)value.size, - value.data); - } - - *out_element_type = iree_hal_make_element_type(numerical_type, bit_count); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element_type( - iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity, - char* buffer, iree_host_size_t* out_buffer_length) { - if (out_buffer_length) { - *out_buffer_length = 0; - } - const char* prefix; - switch (iree_hal_element_numerical_type(element_type)) { - case IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED: - prefix = "i"; - break; - case IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED: - prefix = "u"; - break; - case IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE: - prefix = "f"; - break; - default: - prefix = "*"; - break; - } - int n = std::snprintf( - buffer, buffer_capacity, "%s%d", prefix, - static_cast(iree_hal_element_bit_count(element_type))); - if (n < 0) { - return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "snprintf failed"); - } - if (out_buffer_length) { - *out_buffer_length = n; - } - return n >= buffer_capacity ? iree_status_from_code(IREE_STATUS_OUT_OF_RANGE) - : iree_ok_status(); -} - -// Parses a string of two character pairs representing hex numbers into bytes. 
-static void iree_hal_hex_string_to_bytes(const char* from, uint8_t* to, - ptrdiff_t num) { - /* clang-format off */ - static constexpr char kHexValue[256] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, // '0'..'9' - 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'A'..'F' - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'a'..'f' - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - }; - /* clang-format on */ - for (int i = 0; i < num; i++) { - to[i] = (kHexValue[from[i * 2] & 0xFF] << 4) + - (kHexValue[from[i * 2 + 1] & 0xFF]); - } -} - -// Parses a signal element string, assuming that the caller has validated that -// |out_data| has enough storage space for the parsed element data. -static iree_status_t iree_hal_parse_element_unsafe( - iree_string_view_t data_str, iree_hal_element_type_t element_type, - uint8_t* out_data) { - switch (element_type) { - case IREE_HAL_ELEMENT_TYPE_SINT_8: { - int32_t temp = 0; - if (!absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size), - &temp) || - temp > INT8_MAX) { - return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - } - *reinterpret_cast(out_data) = static_cast(temp); - return iree_ok_status(); - } - case IREE_HAL_ELEMENT_TYPE_UINT_8: { - uint32_t temp = 0; - if (!absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size), - &temp) || - temp > UINT8_MAX) { - return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - } - *reinterpret_cast(out_data) = static_cast(temp); - return iree_ok_status(); - } - case IREE_HAL_ELEMENT_TYPE_SINT_16: { - int32_t temp = 0; - if (!absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size), - &temp) || - temp > INT16_MAX) { - return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - } - *reinterpret_cast(out_data) = static_cast(temp); - return iree_ok_status(); - } - case IREE_HAL_ELEMENT_TYPE_UINT_16: { - uint32_t temp = 0; - if (!absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size), - &temp) || - temp > UINT16_MAX) { - return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - } - *reinterpret_cast(out_data) = static_cast(temp); - return iree_ok_status(); - } - case IREE_HAL_ELEMENT_TYPE_SINT_32: - return absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size), - reinterpret_cast(out_data)) - ? iree_ok_status() - : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - case IREE_HAL_ELEMENT_TYPE_UINT_32: - return absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size), - reinterpret_cast(out_data)) - ? iree_ok_status() - : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - case IREE_HAL_ELEMENT_TYPE_SINT_64: - return absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size), - reinterpret_cast(out_data)) - ? 
iree_ok_status() - : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - case IREE_HAL_ELEMENT_TYPE_UINT_64: - return absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size), - reinterpret_cast(out_data)) - ? iree_ok_status() - : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - case IREE_HAL_ELEMENT_TYPE_FLOAT_16: { - float temp = 0; - if (!absl::SimpleAtof(absl::string_view(data_str.data, data_str.size), - &temp)) { - return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - } - *reinterpret_cast(out_data) = - half_float::detail::float2half(temp); - return iree_ok_status(); - } - case IREE_HAL_ELEMENT_TYPE_FLOAT_32: - return absl::SimpleAtof(absl::string_view(data_str.data, data_str.size), - reinterpret_cast(out_data)) - ? iree_ok_status() - : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - case IREE_HAL_ELEMENT_TYPE_FLOAT_64: - return absl::SimpleAtod(absl::string_view(data_str.data, data_str.size), - reinterpret_cast(out_data)) - ? iree_ok_status() - : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); - default: { - // Treat any unknown format as binary. - iree_host_size_t element_size = iree_hal_element_byte_count(element_type); - if (data_str.size != element_size * 2) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "binary hex element count mismatch: buffer " - "length=%zu < expected=%zu", - data_str.size, element_size * 2); - } - iree_hal_hex_string_to_bytes(data_str.data, out_data, element_size); - return iree_ok_status(); - } - } -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element( - iree_string_view_t data_str, iree_hal_element_type_t element_type, - iree_byte_span_t data_ptr) { - iree_host_size_t element_size = iree_hal_element_byte_count(element_type); - if (data_ptr.data_length < element_size) { - return iree_make_status( - IREE_STATUS_INVALID_ARGUMENT, - "output data buffer overflow: data_length=%zu < element_size=%zu", - data_ptr.data_length, element_size); - } - return iree_hal_parse_element_unsafe(data_str, element_type, data_ptr.data); -} - -// Converts a sequence of bytes into hex number strings. 
-static void iree_hal_bytes_to_hex_string(const uint8_t* src, char* dest, - ptrdiff_t num) { - static constexpr char kHexTable[513] = - "000102030405060708090A0B0C0D0E0F" - "101112131415161718191A1B1C1D1E1F" - "202122232425262728292A2B2C2D2E2F" - "303132333435363738393A3B3C3D3E3F" - "404142434445464748494A4B4C4D4E4F" - "505152535455565758595A5B5C5D5E5F" - "606162636465666768696A6B6C6D6E6F" - "707172737475767778797A7B7C7D7E7F" - "808182838485868788898A8B8C8D8E8F" - "909192939495969798999A9B9C9D9E9F" - "A0A1A2A3A4A5A6A7A8A9AAABACADAEAF" - "B0B1B2B3B4B5B6B7B8B9BABBBCBDBEBF" - "C0C1C2C3C4C5C6C7C8C9CACBCCCDCECF" - "D0D1D2D3D4D5D6D7D8D9DADBDCDDDEDF" - "E0E1E2E3E4E5E6E7E8E9EAEBECEDEEEF" - "F0F1F2F3F4F5F6F7F8F9FAFBFCFDFEFF"; - for (auto src_ptr = src; src_ptr != (src + num); ++src_ptr, dest += 2) { - const char* hex_p = &kHexTable[*src_ptr * 2]; - std::copy(hex_p, hex_p + 2, dest); - } -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element( - iree_const_byte_span_t data, iree_hal_element_type_t element_type, - iree_host_size_t buffer_capacity, char* buffer, - iree_host_size_t* out_buffer_length) { - iree_host_size_t element_size = iree_hal_element_byte_count(element_type); - if (data.data_length < element_size) { - return iree_make_status( - IREE_STATUS_OUT_OF_RANGE, - "data buffer underflow: data_length=%zu < element_size=%zu", - data.data_length, element_size); - } - int n = 0; - switch (element_type) { - case IREE_HAL_ELEMENT_TYPE_SINT_8: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIi8, - *reinterpret_cast(data.data)); - break; - case IREE_HAL_ELEMENT_TYPE_UINT_8: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu8, - *reinterpret_cast(data.data)); - break; - case IREE_HAL_ELEMENT_TYPE_SINT_16: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIi16, - *reinterpret_cast(data.data)); - break; - case IREE_HAL_ELEMENT_TYPE_UINT_16: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu16, - *reinterpret_cast(data.data)); - break; - case IREE_HAL_ELEMENT_TYPE_SINT_32: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIi32, - *reinterpret_cast(data.data)); - break; - case IREE_HAL_ELEMENT_TYPE_UINT_32: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu32, - *reinterpret_cast(data.data)); - break; - case IREE_HAL_ELEMENT_TYPE_SINT_64: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIi64, - *reinterpret_cast(data.data)); - break; - case IREE_HAL_ELEMENT_TYPE_UINT_64: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu64, - *reinterpret_cast(data.data)); - break; - case IREE_HAL_ELEMENT_TYPE_FLOAT_16: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G", - half_float::detail::half2float( - *reinterpret_cast(data.data))); - break; - case IREE_HAL_ELEMENT_TYPE_FLOAT_32: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G", - *reinterpret_cast(data.data)); - break; - case IREE_HAL_ELEMENT_TYPE_FLOAT_64: - n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G", - *reinterpret_cast(data.data)); - break; - default: { - // Treat any unknown format as binary. 
- n = 2 * (int)element_size; - if (buffer && buffer_capacity > n) { - iree_hal_bytes_to_hex_string(data.data, buffer, element_size); - buffer[n] = 0; - } - } - } - if (n < 0) { - return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "snprintf failed"); - } else if (buffer && n >= buffer_capacity) { - buffer = nullptr; - } - if (out_buffer_length) { - *out_buffer_length = n; - } - return buffer ? iree_ok_status() - : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_buffer_elements( - iree_string_view_t data_str, iree_hal_element_type_t element_type, - iree_byte_span_t data_ptr) { - IREE_TRACE_SCOPE0("iree_hal_parse_buffer_elements"); - iree_host_size_t element_size = iree_hal_element_byte_count(element_type); - iree_host_size_t element_capacity = data_ptr.data_length / element_size; - if (iree_string_view_is_empty(data_str)) { - memset(data_ptr.data, 0, data_ptr.data_length); - return iree_ok_status(); - } - size_t src_i = 0; - size_t dst_i = 0; - size_t token_start = std::string::npos; - while (src_i < data_str.size) { - char c = data_str.data[src_i++]; - bool is_separator = - absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']'; - if (token_start == std::string::npos) { - if (!is_separator) { - token_start = src_i - 1; - } - continue; - } else if (token_start != std::string::npos && !is_separator) { - continue; - } - if (dst_i >= element_capacity) { - return iree_make_status( - IREE_STATUS_OUT_OF_RANGE, - "output data buffer overflow: element_capacity=%zu < dst_i=%zu+", - element_capacity, dst_i); - } - IREE_RETURN_IF_ERROR(iree_hal_parse_element_unsafe( - iree_string_view_t{data_str.data + token_start, - src_i - 2 - token_start + 1}, - element_type, data_ptr.data + dst_i * element_size)); - ++dst_i; - token_start = std::string::npos; - } - if (token_start != std::string::npos) { - if (dst_i >= element_capacity) { - return iree_make_status( - IREE_STATUS_OUT_OF_RANGE, - "output data overflow: element_capacity=%zu < dst_i=%zu", - element_capacity, dst_i); - } - IREE_RETURN_IF_ERROR(iree_hal_parse_element_unsafe( - iree_string_view_t{data_str.data + token_start, - data_str.size - token_start}, - element_type, data_ptr.data + dst_i * element_size)); - ++dst_i; - } - if (dst_i == 1 && element_capacity > 1) { - // Splat the single value we got to the entire buffer. - uint8_t* p = data_ptr.data + element_size; - for (int i = 1; i < element_capacity; ++i, p += element_size) { - memcpy(p, data_ptr.data, element_size); - } - } else if (dst_i < element_capacity) { - return iree_make_status( - IREE_STATUS_OUT_OF_RANGE, - "input data string underflow: dst_i=%zu < element_capacity=%zu", dst_i, - element_capacity); - } - return iree_ok_status(); -} - -static iree_status_t iree_hal_format_buffer_elements_recursive( - iree_const_byte_span_t data, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - iree_host_size_t* max_element_count, iree_host_size_t buffer_capacity, - char* buffer, iree_host_size_t* out_buffer_length) { - iree_host_size_t buffer_length = 0; - auto append_char = [&](char c) { - if (buffer) { - if (buffer_length < buffer_capacity - 1) { - buffer[buffer_length] = c; - buffer[buffer_length + 1] = '\0'; - } else { - buffer = nullptr; - } - } - ++buffer_length; - }; - - if (shape_rank == 0) { - // Scalar value; recurse to get on to the leaf dimension path. 
- const iree_hal_dim_t one = 1; - return iree_hal_format_buffer_elements_recursive( - data, &one, 1, element_type, max_element_count, buffer_capacity, buffer, - out_buffer_length); - } else if (shape_rank > 1) { - // Nested dimension; recurse into the next innermost dimension. - iree_hal_dim_t dim_length = 1; - for (iree_host_size_t i = 1; i < shape_rank; ++i) { - dim_length *= shape[i]; - } - iree_device_size_t dim_stride = - dim_length * iree_hal_element_byte_count(element_type); - if (data.data_length < dim_stride * shape[0]) { - return iree_make_status( - IREE_STATUS_OUT_OF_RANGE, - "input data underflow: data_length=%zu < expected=%zu", - data.data_length, - static_cast(dim_stride * shape[0])); - } - iree_const_byte_span_t subdata; - subdata.data = data.data; - subdata.data_length = dim_stride; - for (iree_hal_dim_t i = 0; i < shape[0]; ++i) { - append_char('['); - iree_host_size_t actual_length = 0; - iree_status_t status = iree_hal_format_buffer_elements_recursive( - subdata, shape + 1, shape_rank - 1, element_type, max_element_count, - buffer ? buffer_capacity - buffer_length : 0, - buffer ? buffer + buffer_length : nullptr, &actual_length); - buffer_length += actual_length; - if (iree_status_is_out_of_range(status)) { - buffer = nullptr; - } else if (!iree_status_is_ok(status)) { - return status; - } - subdata.data += dim_stride; - append_char(']'); - } - } else { - // Leaf dimension; output data. - iree_host_size_t max_count = - std::min(*max_element_count, static_cast(shape[0])); - iree_device_size_t element_stride = - iree_hal_element_byte_count(element_type); - if (data.data_length < max_count * element_stride) { - return iree_make_status( - IREE_STATUS_OUT_OF_RANGE, - "input data underflow; data_length=%zu < expected=%zu", - data.data_length, - static_cast(max_count * element_stride)); - } - *max_element_count -= max_count; - iree_const_byte_span_t subdata; - subdata.data = data.data; - subdata.data_length = element_stride; - for (iree_hal_dim_t i = 0; i < max_count; ++i) { - if (i > 0) append_char(' '); - iree_host_size_t actual_length = 0; - iree_status_t status = iree_hal_format_element( - subdata, element_type, buffer ? buffer_capacity - buffer_length : 0, - buffer ? buffer + buffer_length : nullptr, &actual_length); - subdata.data += element_stride; - buffer_length += actual_length; - if (iree_status_is_out_of_range(status)) { - buffer = nullptr; - } else if (!iree_status_is_ok(status)) { - return status; - } - } - if (max_count < shape[0]) { - append_char('.'); - append_char('.'); - append_char('.'); - } - } - if (out_buffer_length) { - *out_buffer_length = buffer_length; - } - return buffer ? 
iree_ok_status() - : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_buffer_elements( - iree_const_byte_span_t data, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - iree_host_size_t max_element_count, iree_host_size_t buffer_capacity, - char* buffer, iree_host_size_t* out_buffer_length) { - IREE_TRACE_SCOPE0("iree_hal_format_buffer_elements"); - if (out_buffer_length) { - *out_buffer_length = 0; - } - if (buffer && buffer_capacity) { - buffer[0] = '\0'; - } - return iree_hal_format_buffer_elements_recursive( - data, shape, shape_rank, element_type, &max_element_count, - buffer_capacity, buffer, out_buffer_length); -} - -//===----------------------------------------------------------------------===// -// iree::hal::Allocator -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(allocator, Allocator); - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_allocator_create_host_local(iree_allocator_t allocator, - iree_hal_allocator** out_allocator) { - IREE_TRACE_SCOPE0("iree_hal_allocator_create_host_local"); - IREE_ASSERT_ARGUMENT(out_allocator); - *out_allocator = - reinterpret_cast(new host::HostLocalAllocator()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_compute_size( - const iree_hal_allocator_t* allocator, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - iree_device_size_t* out_allocation_size) { - IREE_ASSERT_ARGUMENT(allocator); - IREE_ASSERT_ARGUMENT(shape); - IREE_ASSERT_ARGUMENT(out_allocation_size); - *out_allocation_size = 0; - - // TODO(benvanik): layout/padding. - iree_device_size_t byte_length = iree_hal_element_byte_count(element_type); - for (int i = 0; i < shape_rank; ++i) { - byte_length *= shape[i]; - } - *out_allocation_size = byte_length; - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_compute_offset( - const iree_hal_allocator_t* allocator, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - const iree_hal_dim_t* indices, iree_host_size_t indices_count, - iree_device_size_t* out_offset) { - IREE_ASSERT_ARGUMENT(allocator); - IREE_ASSERT_ARGUMENT(shape); - IREE_ASSERT_ARGUMENT(indices); - IREE_ASSERT_ARGUMENT(out_offset); - *out_offset = 0; - if (shape_rank != indices_count) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "shape rank/indices mismatch: %zu != %zu", - shape_rank, indices_count); - } - - // TODO(benvanik): layout/padding. 
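// Aside: since iree_hal_allocator_compute_size above is a plain row-major
// product (no layout/padding yet, per the TODO), expected sizes are easy to
// verify by hand. An illustrative sketch; the allocator may come from
// iree_hal_allocator_create_host_local as shown above.
#include <assert.h>

#include "iree/hal/api.h"

static void check_dense_size(iree_hal_allocator_t* allocator) {
  const iree_hal_dim_t shape[2] = {2, 3};  // 2x3 matrix of f32.
  iree_device_size_t size = 0;
  IREE_CHECK_OK(iree_hal_allocator_compute_size(
      allocator, shape, /*shape_rank=*/2, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
      &size));
  assert(size == 2 * 3 * sizeof(float));  // 24 bytes.
}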
- iree_device_size_t offset = 0; - for (iree_host_size_t i = 0; i < indices_count; ++i) { - if (indices[i] >= shape[i]) { - return iree_make_status(IREE_STATUS_OUT_OF_RANGE, - "index[%zu] out of bounds: %d >= %d", i, - indices[i], shape[i]); - } - iree_device_size_t axis_offset = indices[i]; - for (iree_host_size_t j = i + 1; j < shape_rank; ++j) { - axis_offset *= shape[j]; - } - offset += axis_offset; - } - offset *= iree_hal_element_byte_count(element_type); - - *out_offset = offset; - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_compute_range( - const iree_hal_allocator_t* allocator, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, - const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, - iree_device_size_t* out_start_offset, iree_device_size_t* out_length) { - IREE_ASSERT_ARGUMENT(allocator); - IREE_ASSERT_ARGUMENT(shape); - IREE_ASSERT_ARGUMENT(start_indices); - IREE_ASSERT_ARGUMENT(lengths); - IREE_ASSERT_ARGUMENT(out_start_offset); - IREE_ASSERT_ARGUMENT(out_length); - *out_start_offset = 0; - *out_length = 0; - if (indices_count != lengths_count) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "indices/lengths mismatch: %zu != %zu", - indices_count, lengths_count); - } - if (shape_rank != indices_count) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "shape rank/indices mismatch: %zu != %zu", - shape_rank, indices_count); - } - - // TODO(benvanik): layout/padding. - absl::InlinedVector end_indices(shape_rank); - iree_device_size_t element_size = iree_hal_element_byte_count(element_type); - iree_device_size_t subspan_length = element_size; - for (int i = 0; i < lengths_count; ++i) { - subspan_length *= lengths[i]; - end_indices[i] = start_indices[i] + lengths[i] - 1; - } - - iree_device_size_t start_byte_offset = 0; - IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_offset( - allocator, shape, shape_rank, element_type, start_indices, indices_count, - &start_byte_offset)); - iree_device_size_t end_byte_offset = 0; - IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_offset( - allocator, shape, shape_rank, element_type, end_indices.data(), - end_indices.size(), &end_byte_offset)); - - // Non-contiguous regions not yet implemented. Will be easier to detect when - // we have strides. 
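// Aside: the offset loop above is standard row-major linearization: each
// index is scaled by the product of the dimensions inner to it, and the
// element byte count is applied once at the end. An illustrative check with
// hand-computed values: indices {1, 2, 3} into a 2x3x4 f32 buffer land at
// (1*(3*4) + 2*4 + 3) * 4 = 92 bytes.
#include <assert.h>

#include "iree/hal/api.h"

static void check_row_major_offset(iree_hal_allocator_t* allocator) {
  const iree_hal_dim_t shape[3] = {2, 3, 4};
  const iree_hal_dim_t indices[3] = {1, 2, 3};
  iree_device_size_t offset = 0;
  IREE_CHECK_OK(iree_hal_allocator_compute_offset(
      allocator, shape, /*shape_rank=*/3, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
      indices, /*indices_count=*/3, &offset));
  assert(offset == 92);
}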
- auto offset_length = end_byte_offset - start_byte_offset + element_size; - if (subspan_length != offset_length) { - return iree_make_status( - IREE_STATUS_UNIMPLEMENTED, - "non-contiguous range region computation not implemented"); - } - - *out_start_offset = start_byte_offset; - *out_length = subspan_length; - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_allocate_buffer( - iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, iree_host_size_t allocation_size, - iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_allocator_allocate_buffer"); - IREE_ASSERT_ARGUMENT(allocator); - IREE_ASSERT_ARGUMENT(out_buffer); - *out_buffer = nullptr; - - auto* handle = reinterpret_cast(allocator); - IREE_ASSIGN_OR_RETURN( - auto buffer, - handle->Allocate(static_cast(memory_type), - static_cast(buffer_usage), - allocation_size)); - - *out_buffer = reinterpret_cast(buffer.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_wrap_buffer( - iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t buffer_usage, iree_byte_span_t data, - iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_allocator_wrap_buffer"); - IREE_ASSERT_ARGUMENT(allocator); - IREE_ASSERT_ARGUMENT(out_buffer); - *out_buffer = nullptr; - - auto* handle = reinterpret_cast(allocator); - IREE_ASSIGN_OR_RETURN( - auto buffer, - handle->WrapMutable(static_cast(memory_type), - static_cast(allowed_access), - static_cast(buffer_usage), - data.data, data.data_length)); - - *out_buffer = reinterpret_cast(buffer.release()); - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::Buffer -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(buffer, Buffer); - -IREE_API_EXPORT iree_status_t iree_hal_buffer_subspan( - iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_buffer_subspan"); - IREE_ASSERT_ARGUMENT(buffer); - IREE_ASSERT_ARGUMENT(out_buffer); - *out_buffer = nullptr; - - auto handle = add_ref(reinterpret_cast(buffer)); - IREE_ASSIGN_OR_RETURN(auto new_handle, - Buffer::Subspan(handle, byte_offset, byte_length)); - - *out_buffer = reinterpret_cast(new_handle.release()); - - return iree_ok_status(); -} - -IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL -iree_hal_buffer_allocator(const iree_hal_buffer_t* buffer) { - IREE_ASSERT_ARGUMENT(buffer); - const auto* handle = reinterpret_cast(buffer); - return reinterpret_cast(handle->allocator()); -} - -IREE_API_EXPORT iree_device_size_t -iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer) { - IREE_ASSERT_ARGUMENT(buffer); - const auto* handle = reinterpret_cast(buffer); - return handle->byte_length(); -} - -IREE_API_EXPORT iree_status_t -iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_zero"); - IREE_ASSERT_ARGUMENT(buffer); - auto* handle = reinterpret_cast(buffer); - return handle->Fill8(byte_offset, byte_length, 0); -} - -IREE_API_EXPORT iree_status_t -iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t 
byte_length, const void* pattern, - iree_host_size_t pattern_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_fill"); - IREE_ASSERT_ARGUMENT(buffer); - IREE_ASSERT_ARGUMENT(pattern); - auto* handle = reinterpret_cast(buffer); - return handle->Fill(byte_offset, byte_length, pattern, pattern_length); -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_read_data( - iree_hal_buffer_t* buffer, iree_device_size_t source_offset, - void* target_buffer, iree_device_size_t data_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_read_data"); - IREE_ASSERT_ARGUMENT(buffer); - IREE_ASSERT_ARGUMENT(target_buffer); - auto* handle = reinterpret_cast(buffer); - return handle->ReadData(source_offset, target_buffer, data_length); -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_write_data( - iree_hal_buffer_t* buffer, iree_device_size_t target_offset, - const void* source_buffer, iree_device_size_t data_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_write_data"); - IREE_ASSERT_ARGUMENT(buffer); - IREE_ASSERT_ARGUMENT(source_buffer); - auto* handle = reinterpret_cast(buffer); - return handle->WriteData(target_offset, source_buffer, data_length); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_copy_data( - iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, - iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, - iree_device_size_t data_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_copy_data"); - IREE_ASSERT_ARGUMENT(source_buffer); - IREE_ASSERT_ARGUMENT(target_buffer); - auto* handle = reinterpret_cast(target_buffer); - return handle->CopyData(target_offset, - reinterpret_cast(source_buffer), - source_offset, data_length); -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_map( - iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access, - iree_device_size_t byte_offset, iree_device_size_t byte_length, - iree_hal_mapped_memory_t* out_mapped_memory) { - IREE_TRACE_SCOPE0("iree_hal_buffer_map"); - IREE_ASSERT_ARGUMENT(buffer); - IREE_ASSERT_ARGUMENT(out_mapped_memory); - - std::memset(out_mapped_memory, 0, sizeof(*out_mapped_memory)); - - auto* buffer_handle = reinterpret_cast(buffer); - IREE_ASSIGN_OR_RETURN(auto mapping, - buffer_handle->MapMemory( - static_cast(memory_access), - byte_offset, byte_length)); - - static_assert(sizeof(iree_hal_mapped_memory_t::reserved) >= - sizeof(MappedMemory), - "C mapped memory struct must have large enough storage for the " - "matching C++ struct"); - auto* mapping_storage = - reinterpret_cast*>(out_mapped_memory->reserved); - *mapping_storage = std::move(mapping); - - out_mapped_memory->contents = {mapping_storage->unsafe_data(), - mapping_storage->size()}; - - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_unmap( - iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory) { - IREE_TRACE_SCOPE0("iree_hal_buffer_unmap"); - IREE_ASSERT_ARGUMENT(buffer); - IREE_ASSERT_ARGUMENT(mapped_memory); - auto* mapping = - reinterpret_cast*>(mapped_memory->reserved); - mapping->reset(); - std::memset(mapped_memory, 0, sizeof(*mapped_memory)); - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::HeapBuffer -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, - iree_host_size_t allocation_size, iree_allocator_t contents_allocator, -
iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate"); - IREE_ASSERT_ARGUMENT(out_buffer); - *out_buffer = nullptr; - - if (!allocation_size) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "allocation size must be > 0"); - } - - auto handle = HeapBuffer::Allocate( - static_cast(memory_type), - static_cast(usage), allocation_size); - - *out_buffer = reinterpret_cast( - static_cast(handle.release())); - - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, - iree_hal_memory_access_t allowed_access, iree_byte_span_t contents, - iree_allocator_t contents_allocator, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate_copy"); - IREE_ASSERT_ARGUMENT(out_buffer); - - *out_buffer = nullptr; - - if (!contents.data || !contents.data_length) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "no contents specified (0 length)"); - } - - auto handle = HeapBuffer::AllocateCopy( - static_cast(usage), - static_cast(allowed_access), contents.data, - contents.data_length); - - *out_buffer = reinterpret_cast(handle.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap( - iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t usage, iree_byte_span_t contents, - iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_heap_buffer_wrap"); - IREE_ASSERT_ARGUMENT(out_buffer); - - *out_buffer = nullptr; - - if (!contents.data || !contents.data_length) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "no contents specified (0 length)"); - } - - auto handle = - HeapBuffer::WrapMutable(static_cast(memory_type), - static_cast(allowed_access), - static_cast(usage), - contents.data, contents.data_length); - - *out_buffer = reinterpret_cast(handle.release()); - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::BufferView -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(buffer_view, iree_hal_buffer_view); - -IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create( - iree_hal_buffer_t* buffer, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_create"); - IREE_ASSERT_ARGUMENT(buffer); - IREE_ASSERT_ARGUMENT(out_buffer_view); - - *out_buffer_view = nullptr; - if (shape_rank > 0 && !shape) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "no shape dimensions specified"); - } - - // Allocate and initialize the iree_hal_buffer_view struct. - // Note that we have the dynamically-sized shape dimensions on the end.
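// Aside: the "dynamically-sized shape dimensions on the end" comment above
// refers to the single-allocation flexible-array pattern. A simplified,
// self-contained stand-in (the demo_* names are illustrative, not IREE API):
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

typedef int32_t demo_dim_t;

typedef struct {
  size_t shape_rank;
  demo_dim_t shape[];  // Trailing storage allocated together with the header.
} demo_view_t;

static demo_view_t* demo_view_create(const demo_dim_t* shape,
                                     size_t shape_rank) {
  // One malloc covers the header plus shape_rank trailing dims, so a single
  // free() releases everything, mirroring the buffer view code below.
  demo_view_t* view =
      (demo_view_t*)malloc(sizeof(*view) + sizeof(demo_dim_t) * shape_rank);
  if (!view) return NULL;
  view->shape_rank = shape_rank;
  memcpy(view->shape, shape, sizeof(demo_dim_t) * shape_rank);
  return view;
}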
- iree_hal_buffer_view* buffer_view = nullptr; - IREE_RETURN_IF_ERROR(iree_allocator_malloc( - allocator, sizeof(*buffer_view) + sizeof(iree_hal_dim_t) * shape_rank, - reinterpret_cast(&buffer_view))); - new (buffer_view) iree_hal_buffer_view(); - buffer_view->allocator = allocator; - buffer_view->buffer = buffer; - iree_hal_buffer_retain(buffer_view->buffer); - buffer_view->element_type = element_type; - buffer_view->byte_length = - iree_hal_element_byte_count(buffer_view->element_type); - buffer_view->shape_rank = shape_rank; - for (iree_host_size_t i = 0; i < shape_rank; ++i) { - buffer_view->shape[i] = shape[i]; - buffer_view->byte_length *= shape[i]; - } - - *out_buffer_view = buffer_view; - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_subview( - const iree_hal_buffer_view_t* buffer_view, - const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, - const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) { - IREE_ASSERT_ARGUMENT(out_buffer_view); - - // NOTE: we rely on the compute range call to do parameter validation. - iree_device_size_t start_offset = 0; - iree_device_size_t subview_length = 0; - IREE_RETURN_IF_ERROR(iree_hal_buffer_view_compute_range( - buffer_view, start_indices, indices_count, lengths, lengths_count, - &start_offset, &subview_length)); - - iree_hal_buffer_t* subview_buffer = nullptr; - IREE_RETURN_IF_ERROR(iree_hal_buffer_subspan(buffer_view->buffer, - start_offset, subview_length, - allocator, &subview_buffer)); - - iree_status_t result = iree_hal_buffer_view_create( - subview_buffer, lengths, lengths_count, buffer_view->element_type, - allocator, out_buffer_view); - iree_hal_buffer_release(subview_buffer); - return result; -} - -IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_view_buffer( - const iree_hal_buffer_view_t* buffer_view) { - IREE_ASSERT_ARGUMENT(buffer_view); - return buffer_view->buffer; -} - -IREE_API_EXPORT iree_host_size_t IREE_API_CALL -iree_hal_buffer_view_shape_rank(const iree_hal_buffer_view_t* buffer_view) { - IREE_ASSERT_ARGUMENT(buffer_view); - return buffer_view->shape_rank; -} - -IREE_API_EXPORT iree_host_size_t IREE_API_CALL iree_hal_buffer_view_shape_dim( - const iree_hal_buffer_view_t* buffer_view, iree_host_size_t index) { - IREE_ASSERT_ARGUMENT(buffer_view); - if (index >= buffer_view->shape_rank) { - return 0; - } - return buffer_view->shape[index]; -} - -IREE_API_EXPORT iree_host_size_t -iree_hal_buffer_view_element_count(const iree_hal_buffer_view_t* buffer_view) { - IREE_ASSERT_ARGUMENT(buffer_view); - iree_host_size_t element_count = 1; - for (iree_host_size_t i = 0; i < buffer_view->shape_rank; ++i) { - element_count *= buffer_view->shape[i]; - } - return element_count; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape( - const iree_hal_buffer_view_t* buffer_view, iree_host_size_t rank_capacity, - iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank) { - IREE_ASSERT_ARGUMENT(buffer_view); - IREE_ASSERT_ARGUMENT(out_shape); - if (out_shape_rank) { - *out_shape_rank = buffer_view->shape_rank; - } - if (rank_capacity < buffer_view->shape_rank) { - // Not an error; just a size query.
- return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); - } - - for (iree_host_size_t i = 0; i < buffer_view->shape_rank; ++i) { - out_shape[i] = buffer_view->shape[i]; - } - - return iree_ok_status(); -} - -IREE_API_EXPORT iree_hal_element_type_t IREE_API_CALL -iree_hal_buffer_view_element_type(const iree_hal_buffer_view_t* buffer_view) { - IREE_ASSERT_ARGUMENT(buffer_view); - return buffer_view->element_type; -} - -IREE_API_EXPORT iree_host_size_t -iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view) { - IREE_ASSERT_ARGUMENT(buffer_view); - return iree_hal_element_byte_count(buffer_view->element_type); -} - -IREE_API_EXPORT iree_device_size_t IREE_API_CALL -iree_hal_buffer_view_byte_length(const iree_hal_buffer_view_t* buffer_view) { - IREE_ASSERT_ARGUMENT(buffer_view); - return buffer_view->byte_length; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_offset( - const iree_hal_buffer_view_t* buffer_view, const iree_hal_dim_t* indices, - iree_host_size_t indices_count, iree_device_size_t* out_offset) { - IREE_ASSERT_ARGUMENT(buffer_view); - return iree_hal_allocator_compute_offset( - iree_hal_buffer_allocator(buffer_view->buffer), buffer_view->shape, - buffer_view->shape_rank, buffer_view->element_type, indices, - indices_count, out_offset); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_range( - const iree_hal_buffer_view_t* buffer_view, - const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, - const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, - iree_device_size_t* out_start_offset, iree_device_size_t* out_length) { - IREE_ASSERT_ARGUMENT(buffer_view); - return iree_hal_allocator_compute_range( - iree_hal_buffer_allocator(buffer_view->buffer), buffer_view->shape, - buffer_view->shape_rank, buffer_view->element_type, start_indices, - indices_count, lengths, lengths_count, out_start_offset, out_length); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_parse( - iree_string_view_t value, iree_hal_allocator_t* buffer_allocator, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_parse"); - IREE_ASSERT_ARGUMENT(buffer_allocator); - - // Strip whitespace that may come along (linefeeds/etc). - auto string_view = - absl::StripAsciiWhitespace(absl::string_view(value.data, value.size)); - string_view = absl::StripPrefix(string_view, "\""); - string_view = absl::StripSuffix(string_view, "\""); - if (string_view.empty()) { - // Empty lines are invalid; need at least the shape/type information. - *out_buffer_view = nullptr; - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "empty string input"); - } - - // The part of the string corresponding to the shape, e.g. 1x2x3. - absl::string_view shape_str; - // The part of the string corresponding to the type, e.g. f32 - absl::string_view type_str; - // The part of the string corresponding to the buffer data, e.g. 1 2 3 4 5 6 - absl::string_view data_str; - - absl::string_view shape_and_type_str; - auto equal_index = string_view.find('='); - if (equal_index == std::string::npos) { - // Treat a lack of = as defaulting the data to zeros. - shape_and_type_str = string_view; - } else { - shape_and_type_str = string_view.substr(0, equal_index); - data_str = string_view.substr(equal_index + 1); - } - auto last_x_index = shape_and_type_str.rfind('x'); - if (last_x_index == std::string::npos) { - // Scalar. 
- type_str = shape_and_type_str; - } else { - // Has a shape. - shape_str = shape_and_type_str.substr(0, last_x_index); - type_str = shape_and_type_str.substr(last_x_index + 1); - } - - // AxBxC... - absl::InlinedVector shape(6); - iree_host_size_t shape_rank = 0; - iree_status_t shape_result = - iree_hal_parse_shape({shape_str.data(), shape_str.length()}, shape.size(), - shape.data(), &shape_rank); - if (iree_status_is_ok(shape_result)) { - shape.resize(shape_rank); - } else if (iree_status_is_out_of_range(shape_result)) { - shape.resize(shape_rank); - IREE_RETURN_IF_ERROR( - iree_hal_parse_shape({shape_str.data(), shape_str.length()}, - shape.size(), shape.data(), &shape_rank)); - } else { - return shape_result; - } - - // f32, i32, etc - iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE; - IREE_RETURN_IF_ERROR(iree_hal_parse_element_type( - {type_str.data(), type_str.length()}, &element_type)); - - // Allocate the buffer we will parse into from the provided allocator. - iree_device_size_t buffer_length = 0; - IREE_RETURN_IF_ERROR(iree_hal_allocator_compute_size( - buffer_allocator, shape.data(), shape.size(), element_type, - &buffer_length)); - iree_hal_buffer_t* buffer = nullptr; - IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( - buffer_allocator, - IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - buffer_length, &buffer)); - - iree_status_t status; - - // Parse the elements directly into the buffer. - iree_hal_mapped_memory_t mapped_buffer; - status = iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0, - buffer_length, &mapped_buffer); - if (!iree_status_is_ok(status)) { - iree_hal_buffer_release(buffer); - return status; - } - status = iree_hal_parse_buffer_elements({data_str.data(), data_str.length()}, - element_type, mapped_buffer.contents); - iree_hal_buffer_unmap(buffer, &mapped_buffer); - if (!iree_status_is_ok(status)) { - iree_hal_buffer_release(buffer); - return status; - } - - // Wrap and pass ownership of the buffer to the buffer view. - status = - iree_hal_buffer_view_create(buffer, shape.data(), shape.size(), - element_type, allocator, out_buffer_view); - iree_hal_buffer_release(buffer); - return status; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_format( - const iree_hal_buffer_view_t* buffer_view, - iree_host_size_t max_element_count, iree_host_size_t buffer_capacity, - char* buffer, iree_host_size_t* out_buffer_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_format"); - IREE_ASSERT_ARGUMENT(buffer_view); - if (out_buffer_length) { - *out_buffer_length = 0; - } - if (buffer && buffer_capacity) { - buffer[0] = 0; - } - - iree_status_t status; - iree_host_size_t buffer_length = 0; - auto append_char = [&](char c) { - if (buffer) { - if (buffer_length < buffer_capacity - 1) { - buffer[buffer_length] = c; - buffer[buffer_length + 1] = '\0'; - } else { - buffer = nullptr; - } - } - ++buffer_length; - }; - - if (buffer_view->shape_rank > 0) { - // Shape: 1x2x3 - iree_host_size_t shape_length = 0; - status = iree_hal_format_shape(buffer_view->shape, buffer_view->shape_rank, - buffer ? buffer_capacity - buffer_length : 0, - buffer ? 
buffer + buffer_length : nullptr, - &shape_length); - buffer_length += shape_length; - if (iree_status_is_out_of_range(status)) { - status = iree_status_ignore(status); - buffer = nullptr; - } else if (!iree_status_is_ok(status)) { - return status; - } - - // Separator: x - append_char('x'); - } - - // Element type: f32 - iree_host_size_t element_type_length = 0; - status = iree_hal_format_element_type( - buffer_view->element_type, buffer ? buffer_capacity - buffer_length : 0, - buffer ? buffer + buffer_length : nullptr, &element_type_length); - buffer_length += element_type_length; - if (iree_status_is_out_of_range(status)) { - status = iree_status_ignore(status); - buffer = nullptr; - } else if (!iree_status_is_ok(status)) { - return status; - } - - // Separator: = - append_char('='); - - // Buffer contents: 0 1 2 3 ... - iree_hal_mapped_memory_t mapped_buffer; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map(buffer_view->buffer, - IREE_HAL_MEMORY_ACCESS_READ, 0, - IREE_WHOLE_BUFFER, &mapped_buffer)); - iree_host_size_t elements_length = 0; - status = iree_hal_format_buffer_elements( - iree_const_byte_span_t{mapped_buffer.contents.data, - mapped_buffer.contents.data_length}, - buffer_view->shape, buffer_view->shape_rank, buffer_view->element_type, - max_element_count, buffer ? buffer_capacity - buffer_length : 0, - buffer ? buffer + buffer_length : nullptr, &elements_length); - buffer_length += elements_length; - iree_hal_buffer_unmap(buffer_view->buffer, &mapped_buffer); - if (iree_status_is_out_of_range(status)) { - status = iree_status_ignore(status); - buffer = nullptr; - } else if (!iree_status_is_ok(status)) { - return status; - } - - if (out_buffer_length) { - *out_buffer_length = buffer_length; - } - return buffer ? iree_ok_status() - : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); -} - -//===----------------------------------------------------------------------===// -// iree::hal::CommandBuffer -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(command_buffer, CommandBuffer); - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_create( - iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories, iree_allocator_t allocator, - iree_hal_command_buffer_t** out_command_buffer) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_create"); - IREE_ASSERT_ARGUMENT(device); - IREE_ASSERT_ARGUMENT(out_command_buffer); - *out_command_buffer = nullptr; - auto* handle = reinterpret_cast(device); - - IREE_ASSIGN_OR_RETURN( - auto command_buffer, - handle->CreateCommandBuffer( - static_cast(mode), - static_cast(command_categories))); - - *out_command_buffer = - reinterpret_cast(command_buffer.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t -iree_hal_command_buffer_begin(iree_hal_command_buffer_t* command_buffer) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_begin"); - IREE_ASSERT_ARGUMENT(command_buffer); - auto* handle = reinterpret_cast(command_buffer); - return handle->Begin(); -} - -IREE_API_EXPORT iree_status_t -iree_hal_command_buffer_end(iree_hal_command_buffer_t* command_buffer) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_end"); - IREE_ASSERT_ARGUMENT(command_buffer); - auto* handle = reinterpret_cast(command_buffer); - return handle->End(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_execution_barrier( - iree_hal_command_buffer_t* command_buffer, - iree_hal_execution_stage_t source_stage_mask, - 
iree_hal_execution_stage_t target_stage_mask, - iree_host_size_t memory_barrier_count, - const iree_hal_memory_barrier_t* memory_barriers, - iree_host_size_t buffer_barrier_count, - const iree_hal_buffer_barrier_t* buffer_barriers) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_execution_barrier"); - IREE_ASSERT_ARGUMENT(command_buffer); - auto* handle = reinterpret_cast(command_buffer); - // TODO(benvanik): refactor the C++ types to use the C types for storage so - // that we can safely map between the two. For now assume size equality - // is layout equality (as compilers aren't allowed to reorder structs). - static_assert(sizeof(MemoryBarrier) == sizeof(iree_hal_memory_barrier_t), - "Expecting identical layout"); - static_assert(sizeof(BufferBarrier) == sizeof(iree_hal_buffer_barrier_t), - "Expecting identical layout"); - return handle->ExecutionBarrier( - static_cast(source_stage_mask), - static_cast(target_stage_mask), - absl::MakeConstSpan( - reinterpret_cast(memory_barriers), - memory_barrier_count), - absl::MakeConstSpan( - reinterpret_cast(buffer_barriers), - buffer_barrier_count)); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_fill_buffer( - iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* target_buffer, - iree_device_size_t target_offset, iree_device_size_t length, - const void* pattern, iree_host_size_t pattern_length) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_fill_buffer"); - IREE_ASSERT_ARGUMENT(command_buffer); - IREE_ASSERT_ARGUMENT(target_buffer); - auto* handle = reinterpret_cast(command_buffer); - return handle->FillBuffer(reinterpret_cast(target_buffer), - target_offset, length, pattern, pattern_length); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_update_buffer(iree_hal_command_buffer_t* command_buffer, - const void* source_buffer, - iree_host_size_t source_offset, - iree_hal_buffer_t* target_buffer, - iree_device_size_t target_offset, - iree_device_size_t length) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_update_buffer"); - IREE_ASSERT_ARGUMENT(command_buffer); - IREE_ASSERT_ARGUMENT(source_buffer); - IREE_ASSERT_ARGUMENT(target_buffer); - auto* handle = reinterpret_cast(command_buffer); - return handle->UpdateBuffer(source_buffer, source_offset, - reinterpret_cast(target_buffer), - target_offset, length); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_copy_buffer( - iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer, - iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer, - iree_device_size_t target_offset, iree_device_size_t length) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_copy_buffer"); - IREE_ASSERT_ARGUMENT(command_buffer); - auto* handle = reinterpret_cast(command_buffer); - return handle->CopyBuffer( - reinterpret_cast(source_buffer), source_offset, - reinterpret_cast(target_buffer), target_offset, length); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_push_constants( - iree_hal_command_buffer_t* command_buffer, - iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset, - const void* values, iree_host_size_t values_length) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_push_constants"); - IREE_ASSERT_ARGUMENT(command_buffer); - IREE_ASSERT_ARGUMENT(executable_layout); - IREE_ASSERT_ARGUMENT(values); - if (values_length == 0) { - return iree_ok_status(); - } - if ((values_length % 4) != 0) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "invalid alignment 
%zu, must be 4-byte aligned", - values_length); - } - auto* handle = reinterpret_cast(command_buffer); - return handle->PushConstants( - reinterpret_cast(executable_layout), offset, - absl::MakeConstSpan(reinterpret_cast(values), - values_length / sizeof(uint32_t))); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_push_descriptor_set( - iree_hal_command_buffer_t* command_buffer, - iree_hal_executable_layout_t* executable_layout, int32_t set, - iree_host_size_t binding_count, - const iree_hal_descriptor_set_binding_t* bindings) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_push_descriptor_set"); - IREE_ASSERT_ARGUMENT(command_buffer); - IREE_ASSERT_ARGUMENT(executable_layout); - auto* handle = reinterpret_cast(command_buffer); - if (binding_count && !bindings) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "bindings/binding_count mismatch"); - } - - // TODO(benvanik): refactor the C++ types to use the C types for storage so - // that we can safely map between the two. For now assume size equality - // is layout equality (as compilers aren't allowed to reorder structs). - static_assert(sizeof(DescriptorSet::Binding) == - sizeof(iree_hal_descriptor_set_binding_t), - "Expecting identical layout"); - return handle->PushDescriptorSet( - reinterpret_cast(executable_layout), set, - absl::MakeConstSpan( - reinterpret_cast(bindings), - binding_count)); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_bind_descriptor_set( - iree_hal_command_buffer_t* command_buffer, - iree_hal_executable_layout_t* executable_layout, int32_t set, - iree_hal_descriptor_set_t* descriptor_set, - iree_host_size_t dynamic_offset_count, - const iree_device_size_t* dynamic_offsets) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_bind_descriptor_set"); - IREE_ASSERT_ARGUMENT(command_buffer); - IREE_ASSERT_ARGUMENT(executable_layout); - IREE_ASSERT_ARGUMENT(descriptor_set); - auto* handle = reinterpret_cast(command_buffer); - if (dynamic_offset_count && !dynamic_offsets) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "dynamic_offset_count/dynamic_offsets mismatch"); - } - static_assert(sizeof(iree_device_size_t) == sizeof(device_size_t), - "Device sizes must match"); - return handle->BindDescriptorSet( - reinterpret_cast(executable_layout), set, - reinterpret_cast(descriptor_set), - absl::MakeConstSpan(dynamic_offsets, dynamic_offset_count)); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_dispatch( - iree_hal_command_buffer_t* command_buffer, - iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_dispatch"); - IREE_ASSERT_ARGUMENT(command_buffer); - IREE_ASSERT_ARGUMENT(executable); - auto* handle = reinterpret_cast(command_buffer); - return handle->Dispatch(reinterpret_cast(executable), - entry_point, {workgroup_x, workgroup_y, workgroup_z}); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_dispatch_indirect( - iree_hal_command_buffer_t* command_buffer, - iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_t* workgroups_buffer, - iree_device_size_t workgroups_offset) { - IREE_TRACE_SCOPE0("iree_hal_command_buffer_dispatch_indirect"); - IREE_ASSERT_ARGUMENT(command_buffer); - IREE_ASSERT_ARGUMENT(executable); - IREE_ASSERT_ARGUMENT(workgroups_buffer); - auto* handle = reinterpret_cast(command_buffer); - return handle->DispatchIndirect( - 
reinterpret_cast(executable), entry_point, - reinterpret_cast(workgroups_buffer), workgroups_offset); -} - -//===----------------------------------------------------------------------===// -// iree::hal::DescriptorSet -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(descriptor_set, DescriptorSet); - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_descriptor_set_create( - iree_hal_device_t* device, iree_hal_descriptor_set_layout_t* set_layout, - iree_host_size_t binding_count, - const iree_hal_descriptor_set_binding_t* bindings, - iree_allocator_t allocator, - iree_hal_descriptor_set_t** out_descriptor_set) { - IREE_TRACE_SCOPE0("iree_hal_descriptor_set_create"); - IREE_ASSERT_ARGUMENT(device); - IREE_ASSERT_ARGUMENT(set_layout); - IREE_ASSERT_ARGUMENT(out_descriptor_set); - *out_descriptor_set = nullptr; - if (binding_count && !bindings) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "binding_count/bindings mismatch"); - } - auto* handle = reinterpret_cast(device); - - // TODO(benvanik): refactor the C++ types to use the C types for storage so - // that we can safely map between the two. For now assume size equality - // is layout equality (as compilers aren't allowed to reorder structs). - static_assert(sizeof(DescriptorSet::Binding) == - sizeof(iree_hal_descriptor_set_binding_t), - "Expecting identical layout"); - IREE_ASSIGN_OR_RETURN( - auto descriptor_set, - handle->CreateDescriptorSet( - reinterpret_cast(set_layout), - absl::MakeConstSpan( - reinterpret_cast(bindings), - binding_count))); - - *out_descriptor_set = - reinterpret_cast(descriptor_set.release()); - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::DescriptorSetLayout -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(descriptor_set_layout, DescriptorSetLayout); - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_descriptor_set_layout_create( - iree_hal_device_t* device, - iree_hal_descriptor_set_layout_usage_type_t usage_type, - iree_host_size_t binding_count, - const iree_hal_descriptor_set_layout_binding_t* bindings, - iree_allocator_t allocator, - iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { - IREE_TRACE_SCOPE0("iree_hal_descriptor_set_layout_create"); - IREE_ASSERT_ARGUMENT(device); - IREE_ASSERT_ARGUMENT(out_descriptor_set_layout); - *out_descriptor_set_layout = nullptr; - if (binding_count && !bindings) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "binding_count/bindings mismatch"); - } - auto* handle = reinterpret_cast(device); - - // TODO(benvanik): refactor the C++ types to use the C types for storage so - // that we can safely map between the two. For now assume size equality - // is layout equality (as compilers aren't allowed to reorder structs). 
- static_assert(sizeof(DescriptorSetLayout::Binding) == - sizeof(iree_hal_descriptor_set_layout_binding_t), - "Expecting identical layout"); - IREE_ASSIGN_OR_RETURN( - auto descriptor_set_layout, - handle->CreateDescriptorSetLayout( - static_cast(usage_type), - absl::MakeConstSpan( - reinterpret_cast(bindings), - binding_count))); - - *out_descriptor_set_layout = - reinterpret_cast( - descriptor_set_layout.release()); - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::Device -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(device, Device); - -IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL -iree_hal_device_allocator(iree_hal_device_t* device) { - IREE_ASSERT_ARGUMENT(device); - auto* handle = reinterpret_cast(device); - return reinterpret_cast(handle->allocator()); -} - -IREE_API_EXPORT iree_string_view_t IREE_API_CALL -iree_hal_device_id(iree_hal_device_t* device) { - IREE_ASSERT_ARGUMENT(device); - auto* handle = reinterpret_cast(device); - const auto& id = handle->info().id(); - return iree_string_view_t{id.data(), id.size()}; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_device_queue_submit( - iree_hal_device_t* device, iree_hal_command_category_t command_categories, - uint64_t queue_affinity, iree_host_size_t batch_count, - const iree_hal_submission_batch_t* batches) { - IREE_TRACE_SCOPE0("iree_hal_device_queue_submit"); - IREE_ASSERT_ARGUMENT(device); - auto* handle = reinterpret_cast(device); - if (batch_count > 0 && !batches) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "batch_count/batches mismatch"); - } - - // We need to allocate storage to marshal in the semaphores. Ideally we'd - // change the C++ API to make this 1:1 with a reinterpret_cast, however that - // makes the C API more difficult. Bleh. - iree_host_size_t total_semaphore_count = 0; - for (iree_host_size_t i = 0; i < batch_count; ++i) { - total_semaphore_count += batches[i].wait_semaphores.count; - total_semaphore_count += batches[i].signal_semaphores.count; - } - absl::InlinedVector semaphore_values( - total_semaphore_count); - absl::InlinedVector dst_batches(batch_count); - iree_host_size_t base_semaphore_index = 0; - for (iree_host_size_t i = 0; i < batch_count; ++i) { - const auto& src_batch = batches[i]; - auto& dst_batch = dst_batches[i]; - for (iree_host_size_t j = 0; j < src_batch.wait_semaphores.count; ++j) { - semaphore_values[base_semaphore_index + j] = { - reinterpret_cast(src_batch.wait_semaphores.semaphores[j]), - src_batch.wait_semaphores.payload_values[j]}; - } - dst_batch.wait_semaphores = - absl::MakeConstSpan(&semaphore_values[base_semaphore_index], - src_batch.wait_semaphores.count); - base_semaphore_index += src_batch.wait_semaphores.count; - dst_batch.command_buffers = - iree::ReinterpretSpan(absl::MakeConstSpan( - src_batch.command_buffers, src_batch.command_buffer_count)); - for (iree_host_size_t j = 0; j < src_batch.signal_semaphores.count; ++j) { - semaphore_values[base_semaphore_index + j] = { - reinterpret_cast( - src_batch.signal_semaphores.semaphores[j]), - src_batch.signal_semaphores.payload_values[j]}; - } - dst_batch.signal_semaphores = - absl::MakeConstSpan(&semaphore_values[base_semaphore_index], - src_batch.signal_semaphores.count); - base_semaphore_index += src_batch.signal_semaphores.count; - } - - // For now we always go to the first compute queue. 
TBD cleanup pending the - // device modeling in the IR as to how we really want to handle this. We'll - // want to use queue_affinity in a way that ensures we have some control over - // things on the compiler side and may require that devices are declared by - // the number and types of queues they support. - uint64_t queue_index = queue_affinity % handle->dispatch_queues().size(); - auto* command_queue = handle->dispatch_queues()[queue_index]; - return command_queue->Submit(dst_batches); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_device_wait_semaphores_with_deadline( - iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, - const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns) { - IREE_TRACE_SCOPE0("iree_hal_device_wait_semaphores_with_deadline"); - IREE_ASSERT_ARGUMENT(device); - if (!semaphore_list || semaphore_list->count == 0) return iree_ok_status(); - auto* handle = reinterpret_cast(device); - - absl::InlinedVector semaphore_values( - semaphore_list->count); - for (int i = 0; i < semaphore_list->count; ++i) { - semaphore_values[i] = { - reinterpret_cast(semaphore_list->semaphores[i]), - semaphore_list->payload_values[i]}; - } - - switch (wait_mode) { - case IREE_HAL_WAIT_MODE_ALL: { - return handle->WaitAllSemaphores(semaphore_values, Time(deadline_ns)); - } - case IREE_HAL_WAIT_MODE_ANY: { - IREE_ASSIGN_OR_RETURN( - int wake_index, - handle->WaitAnySemaphore(semaphore_values, Time(deadline_ns))); - (void)wake_index; - return iree_ok_status(); - } - default: { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "unhandled wait_mode"); - } - } -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_device_wait_semaphores_with_timeout( - iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, - const iree_hal_semaphore_list_t* semaphore_list, - iree_duration_t timeout_ns) { - iree_time_t deadline_ns = iree_relative_timeout_to_deadline_ns(timeout_ns); - return iree_hal_device_wait_semaphores_with_deadline( - device, wait_mode, semaphore_list, deadline_ns); -} - -//===----------------------------------------------------------------------===// -// iree::hal::Driver -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(driver, Driver); - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_driver_query_available_devices( - iree_hal_driver_t* driver, iree_allocator_t allocator, - iree_hal_device_info_t** out_device_infos, - iree_host_size_t* out_device_info_count) { - IREE_TRACE_SCOPE0("iree_hal_driver_query_available_devices"); - IREE_ASSERT_ARGUMENT(driver); - IREE_ASSERT_ARGUMENT(out_device_infos); - IREE_ASSERT_ARGUMENT(out_device_info_count); - *out_device_info_count = 0; - auto* handle = reinterpret_cast(driver); - - IREE_ASSIGN_OR_RETURN(auto device_infos, handle->EnumerateAvailableDevices()); - size_t total_string_size = 0; - for (const auto& device_info : device_infos) { - total_string_size += device_info.name().size(); - } - - *out_device_info_count = device_infos.size(); - iree_hal_device_info_t* device_info_storage = nullptr; - IREE_RETURN_IF_ERROR(iree_allocator_malloc( - allocator, - device_infos.size() * sizeof(*device_info_storage) + total_string_size, - (void**)&device_info_storage)); - - char* p = reinterpret_cast(device_info_storage) + - device_infos.size() * sizeof(*device_info_storage); - for (int i = 0; i < device_infos.size(); ++i) { - const auto& device_info = device_infos[i]; - device_info_storage[i].device_id = device_info.device_id(); - - size_t 
name_size = device_info.name().size(); - std::memcpy(p, device_info.name().c_str(), name_size); - device_info_storage[i].name = iree_string_view_t{p, name_size}; - p += name_size; - } - - *out_device_infos = device_info_storage; - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_driver_create_device( - iree_hal_driver_t* driver, iree_hal_device_id_t device_id, - iree_allocator_t allocator, iree_hal_device_t** out_device) { - IREE_TRACE_SCOPE0("iree_hal_driver_create_device"); - IREE_ASSERT_ARGUMENT(driver); - IREE_ASSERT_ARGUMENT(out_device); - *out_device = nullptr; - auto* handle = reinterpret_cast(driver); - - IREE_ASSIGN_OR_RETURN(auto device, handle->CreateDevice(device_id)); - - *out_device = reinterpret_cast(device.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_driver_create_default_device(iree_hal_driver_t* driver, - iree_allocator_t allocator, - iree_hal_device_t** out_device) { - IREE_TRACE_SCOPE0("iree_hal_driver_create_default_device"); - IREE_ASSERT_ARGUMENT(driver); - IREE_ASSERT_ARGUMENT(out_device); - *out_device = nullptr; - auto* handle = reinterpret_cast(driver); - IREE_ASSIGN_OR_RETURN(auto device, handle->CreateDefaultDevice()); - *out_device = reinterpret_cast(device.release()); - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::Executable -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(executable, Executable); - -//===----------------------------------------------------------------------===// -// iree::hal::ExecutableCache -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(executable_cache, ExecutableCache); - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_cache_create( - iree_hal_device_t* device, iree_string_view_t identifier, - iree_allocator_t allocator, - iree_hal_executable_cache_t** out_executable_cache) { - IREE_TRACE_SCOPE0("iree_hal_executable_cache_create"); - IREE_ASSERT_ARGUMENT(device); - IREE_ASSERT_ARGUMENT(out_executable_cache); - *out_executable_cache = nullptr; - - auto* handle = reinterpret_cast(device); - auto executable_cache = handle->CreateExecutableCache(); - *out_executable_cache = reinterpret_cast( - executable_cache.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT bool IREE_API_CALL iree_hal_executable_cache_can_prepare_format( - iree_hal_executable_cache_t* executable_cache, - iree_hal_executable_format_t format) { - IREE_ASSERT_ARGUMENT(executable_cache); - auto* handle = reinterpret_cast(executable_cache); - return handle->CanPrepareFormat(format); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_executable_cache_prepare_executable( - iree_hal_executable_cache_t* executable_cache, - iree_hal_executable_layout_t* executable_layout, - iree_hal_executable_caching_mode_t caching_mode, - iree_const_byte_span_t executable_data, iree_allocator_t allocator, - iree_hal_executable_t** out_executable) { - IREE_TRACE_SCOPE0("iree_hal_executable_cache_prepare_executable"); - IREE_ASSERT_ARGUMENT(executable_cache); - IREE_ASSERT_ARGUMENT(executable_layout); - IREE_ASSERT_ARGUMENT(out_executable); - *out_executable = nullptr; - auto* handle = reinterpret_cast(executable_cache); - - ExecutableSpec spec; - spec.executable_data = {executable_data.data, executable_data.data_length}; - IREE_ASSIGN_OR_RETURN( - auto executable, - 
handle->PrepareExecutable( - reinterpret_cast(executable_layout), - static_cast(caching_mode), spec)); - - *out_executable = - reinterpret_cast(executable.release()); - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::ExecutableLayout -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(executable_layout, ExecutableLayout); - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_layout_create( - iree_hal_device_t* device, iree_host_size_t set_layout_count, - iree_hal_descriptor_set_layout_t** set_layouts, - iree_host_size_t push_constants, iree_allocator_t allocator, - iree_hal_executable_layout_t** out_executable_layout) { - IREE_TRACE_SCOPE0("iree_hal_executable_layout_create"); - IREE_ASSERT_ARGUMENT(device); - IREE_ASSERT_ARGUMENT(out_executable_layout); - *out_executable_layout = nullptr; - if (set_layout_count && !set_layouts) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "set_layout_count/set_layouts mismatch"); - } - - auto* handle = reinterpret_cast(device); - IREE_ASSIGN_OR_RETURN( - auto executable_layout, - handle->CreateExecutableLayout( - absl::MakeConstSpan( - reinterpret_cast(set_layouts), - set_layout_count), - push_constants)); - - *out_executable_layout = reinterpret_cast( - executable_layout.release()); - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::Semaphore -//===----------------------------------------------------------------------===// - -IREE_HAL_API_RETAIN_RELEASE(semaphore, Semaphore); - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_semaphore_create( - iree_hal_device_t* device, uint64_t initial_value, - iree_allocator_t allocator, iree_hal_semaphore_t** out_semaphore) { - IREE_TRACE_SCOPE0("iree_hal_semaphore_create"); - IREE_ASSERT_ARGUMENT(device); - IREE_ASSERT_ARGUMENT(out_semaphore); - *out_semaphore = nullptr; - - auto* handle = reinterpret_cast(device); - IREE_ASSIGN_OR_RETURN(auto semaphore, handle->CreateSemaphore(initial_value)); - - *out_semaphore = reinterpret_cast(semaphore.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_semaphore_query(iree_hal_semaphore_t* semaphore, uint64_t* out_value) { - IREE_ASSERT_ARGUMENT(semaphore); - IREE_ASSERT_ARGUMENT(out_value); - *out_value = 0; - - auto* handle = reinterpret_cast(semaphore); - IREE_ASSIGN_OR_RETURN(uint64_t value, handle->Query()); - *out_value = value; - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_semaphore_signal(iree_hal_semaphore_t* semaphore, uint64_t new_value) { - IREE_TRACE_SCOPE0("iree_hal_semaphore_signal"); - IREE_ASSERT_ARGUMENT(semaphore); - auto* handle = reinterpret_cast(semaphore); - return handle->Signal(new_value); -} - -IREE_API_EXPORT void IREE_API_CALL -iree_hal_semaphore_fail(iree_hal_semaphore_t* semaphore, iree_status_t status) { - IREE_TRACE_SCOPE0("iree_hal_semaphore_fail"); - IREE_ASSERT_ARGUMENT(semaphore); - auto* handle = reinterpret_cast(semaphore); - handle->Fail(std::move(status)); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_semaphore_wait_with_deadline(iree_hal_semaphore_t* semaphore, - uint64_t value, iree_time_t deadline_ns) { - IREE_TRACE_SCOPE0("iree_hal_semaphore_wait_with_deadline"); - IREE_ASSERT_ARGUMENT(semaphore); - auto* handle = reinterpret_cast(semaphore); - return handle->Wait(value, 
Time(deadline_ns)); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_semaphore_wait_with_timeout(iree_hal_semaphore_t* semaphore, - uint64_t value, - iree_duration_t timeout_ns) { - IREE_TRACE_SCOPE0("iree_hal_semaphore_wait_with_timeout"); - IREE_ASSERT_ARGUMENT(semaphore); - auto* handle = reinterpret_cast(semaphore); - return handle->Wait(value, Duration(timeout_ns)); -} - -} // namespace hal -} // namespace iree diff --git a/iree/hal/api.h b/iree/hal/api.h index a94310014603c..19ecd4b0f4260 100644 --- a/iree/hal/api.h +++ b/iree/hal/api.h @@ -17,1484 +17,20 @@ #ifndef IREE_HAL_API_H_ #define IREE_HAL_API_H_ -#include -#include - -#include "iree/base/api.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -//===----------------------------------------------------------------------===// -// Types and Enums -//===----------------------------------------------------------------------===// - -typedef struct iree_hal_allocator iree_hal_allocator_t; -typedef struct iree_hal_buffer iree_hal_buffer_t; -typedef struct iree_hal_buffer_view iree_hal_buffer_view_t; -typedef struct iree_hal_command_buffer iree_hal_command_buffer_t; -typedef struct iree_hal_descriptor_set iree_hal_descriptor_set_t; -typedef struct iree_hal_descriptor_set_layout iree_hal_descriptor_set_layout_t; -typedef struct iree_hal_device iree_hal_device_t; -typedef struct iree_hal_driver iree_hal_driver_t; -typedef struct iree_hal_driver_registry_s iree_hal_driver_registry_t; -typedef struct iree_hal_executable iree_hal_executable_t; -typedef struct iree_hal_executable_cache iree_hal_executable_cache_t; -typedef struct iree_hal_executable_layout iree_hal_executable_layout_t; -typedef struct iree_hal_semaphore iree_hal_semaphore_t; - -// Reference to a buffer's mapped memory. -typedef struct { - // Contents of the buffer. Behavior is undefined if an access is performed - // whose type was not specified during mapping. - iree_byte_span_t contents; - - // Used internally - do not modify. - uint64_t reserved[8]; -} iree_hal_mapped_memory_t; - -// A bitfield specifying properties for a memory type. -enum iree_hal_memory_type_e { - IREE_HAL_MEMORY_TYPE_NONE = 0u, - - // Memory is lazily allocated by the device and only exists transiently. - // This is the optimal mode for memory used only within a single command - // buffer. Transient buffers, even if they have - // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE set, should be treated as device-local - // and opaque as they may have no memory attached to them outside of the time - // they are being evaluated on devices. - // - // This flag can be treated as a hint in most cases; allocating a buffer with - // it set _may_ return the same as if it had not been set. Certain allocation - // routines may use the hint to more tightly control reuse or defer wiring the - // memory. - IREE_HAL_MEMORY_TYPE_TRANSIENT = 1u << 0, - - // Memory allocated with this type can be mapped for host access using - // iree_hal_buffer_map. - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE = 1u << 1, - - // The host cache management commands MappedMemory::Flush and - // MappedMemory::Invalidate are not needed to flush host writes - // to the device or make device writes visible to the host, respectively. - IREE_HAL_MEMORY_TYPE_HOST_COHERENT = 1u << 2, - - // Memory allocated with this type is cached on the host. Host memory - // accesses to uncached memory are slower than to cached memory, however - // uncached memory is always host coherent.
-  // Memory allocated with this type is cached on the host. Host memory
-  // accesses to uncached memory are slower than to cached memory, however
-  // uncached memory is always host coherent. MappedMemory::Flush must be used
-  // to ensure the device has visibility into any changes made on the host and
-  // Invalidate must be used to ensure the host has visibility into any changes
-  // made on the device.
-  IREE_HAL_MEMORY_TYPE_HOST_CACHED = 1u << 3,
-
-  // Memory is accessible as normal host allocated memory.
-  IREE_HAL_MEMORY_TYPE_HOST_LOCAL =
-      IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_COHERENT,
-
-  // Memory allocated with this type is visible to the device for execution.
-  // Being device visible does not mean the same thing as
-  // IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL. Though an allocation may be visible to
-  // the device and therefore usable for execution, it may require expensive
-  // mapping or implicit transfers.
-  IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE = 1u << 4,
-
-  // Memory allocated with this type is the most efficient for device access.
-  // Devices may support using memory that is not device local via
-  // IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE but doing so can incur non-trivial
-  // performance penalties. Device local memory, on the other hand, is
-  // guaranteed to be fast for all operations.
-  IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL =
-      IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | (1u << 5),
-};
-typedef uint32_t iree_hal_memory_type_t;
-
-// A bitfield specifying how memory will be accessed in a mapped memory region.
-enum iree_hal_memory_access_e {
-  // Memory is not mapped.
-  IREE_HAL_MEMORY_ACCESS_NONE = 0u,
-  // Memory will be read.
-  // If a buffer is only mapped for reading it may still be possible to write to
-  // it but the results will be undefined (as it may present coherency issues).
-  IREE_HAL_MEMORY_ACCESS_READ = 1u << 0,
-  // Memory will be written.
-  // If a buffer is only mapped for writing it may still be possible to read
-  // from it but the results will be undefined or incredibly slow (as it may
-  // be mapped by the driver as uncached).
-  IREE_HAL_MEMORY_ACCESS_WRITE = 1u << 1,
-  // Memory will be discarded prior to mapping.
-  // The existing contents will be undefined after mapping and must be written
-  // to ensure validity.
-  IREE_HAL_MEMORY_ACCESS_DISCARD = 1u << 2,
-  // Memory will be discarded and completely overwritten in a single operation.
-  IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE =
-      IREE_HAL_MEMORY_ACCESS_WRITE | IREE_HAL_MEMORY_ACCESS_DISCARD,
-  // A flag that can be applied to any access type to indicate that the buffer
-  // storage being accessed may alias with other accesses occurring concurrently
-  // within or across operations. The lack of the flag indicates that the access
-  // is guaranteed not to alias (ala C's `restrict` keyword).
-  IREE_HAL_MEMORY_ACCESS_MAY_ALIAS = 1u << 3,
-  // Memory may have any operation performed on it.
-  IREE_HAL_MEMORY_ACCESS_ALL = IREE_HAL_MEMORY_ACCESS_READ |
-                               IREE_HAL_MEMORY_ACCESS_WRITE |
-                               IREE_HAL_MEMORY_ACCESS_DISCARD,
-};
-typedef uint32_t iree_hal_memory_access_t;
-
-// Bitfield that defines how a buffer is intended to be used.
-// Usage allows the driver to appropriately place the buffer for more
-// efficient operations of the specified types.
-enum iree_hal_buffer_usage_e {
-  IREE_HAL_BUFFER_USAGE_NONE = 0u,
-
-  // The buffer, once defined, will not be mapped or updated again.
-  // This should be used for uniform parameter values such as runtime
-  // constants for executables. Doing so may allow drivers to inline values or
-  // represent them in command buffers more efficiently (avoiding memory reads
-  // or swapping, etc).
- IREE_HAL_BUFFER_USAGE_CONSTANT = 1u << 0, - - // The buffer can be used as the source or target of a transfer command - // (CopyBuffer, UpdateBuffer, etc). - // - // If |IREE_HAL_BUFFER_USAGE_MAPPING| is not specified drivers may safely - // assume that the host may never need visibility of this buffer as all - // accesses will happen via command buffers. - IREE_HAL_BUFFER_USAGE_TRANSFER = 1u << 1, - - // The buffer can be mapped by the host application for reading and writing. - // - // As mapping may require placement in special address ranges or system - // calls to enable visibility the driver can use the presence (or lack of) - // this flag to perform allocation-type setup and avoid initial mapping - // overhead. - IREE_HAL_BUFFER_USAGE_MAPPING = 1u << 2, - - // The buffer can be provided as an input or output to an executable. - // Buffers of this type may be directly used by drivers during dispatch. - IREE_HAL_BUFFER_USAGE_DISPATCH = 1u << 3, - - // Buffer may be used for any operation. - IREE_HAL_BUFFER_USAGE_ALL = IREE_HAL_BUFFER_USAGE_TRANSFER | - IREE_HAL_BUFFER_USAGE_MAPPING | - IREE_HAL_BUFFER_USAGE_DISPATCH, -}; -typedef uint32_t iree_hal_buffer_usage_t; - -// An opaque driver-specific handle to identify different devices. -typedef uintptr_t iree_hal_device_id_t; - -// Describes an enumerated HAL device. -typedef struct { - // Opaque handle used by drivers. Not valid across driver instances. - iree_hal_device_id_t device_id; - // Name of the device as returned by the API. - iree_string_view_t name; -} iree_hal_device_info_t; - -// An opaque factory-specific handle to identify different drivers. -typedef uint64_t iree_hal_driver_id_t; - -#define IREE_HAL_DRIVER_ID_INVALID 0ull - -// Describes a driver providing device enumeration and creation. -// The lifetime of memory referenced by this structure (such as strings) is -// dependent on where it originated. -// -// * When using iree_hal_driver_registry_enumerate the driver info is copied -// into memory owned by the caller. -// * When queried from a live driver with iree_hal_driver_info the memory is -// only guaranteed to live for as long as the driver is. -// * When enumerating via factories the information may be valid only while the -// driver registry lock is held. -typedef struct { - IREE_API_UNSTABLE - - // Opaque handle used by factories. Unique across all factories. - iree_hal_driver_id_t driver_id; - - // Canonical name of the driver as used in command lines, documentation, etc. - // Examples: 'metal', 'vulkan' - iree_string_view_t driver_name; - - // Full human-readable name of the driver for display. - // Examples: 'Vulkan 1.2 (NVIDIA)'. - iree_string_view_t full_name; - - // TODO(benvanik): version information; useful if wanting to expose multiple - // versions that may have completely different implementations (like vulkan - // 1.0, 1.1, and 1.2) but allow a nice sort/selection process. - // TODO(benvanik): triple, feature flags, etc. -} iree_hal_driver_info_t; - -// A bitfield specifying the mode of operation for a command buffer. -enum iree_hal_command_buffer_mode_e { - // Command buffer will be submitted once and never used again. - // This may enable in-place patching of command buffers that reduce overhead - // when it's known that command buffers will not be reused. - IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT = 1u << 0, -}; -typedef uint32_t iree_hal_command_buffer_mode_t; - -// A bitfield specifying the category of commands in a command queue. 
-enum iree_hal_command_category_e {
-  // Command is considered a transfer operation (memcpy, etc).
-  IREE_HAL_COMMAND_CATEGORY_TRANSFER = 1u << 0,
-  // Command is considered a dispatch operation (dispatch/execute).
-  IREE_HAL_COMMAND_CATEGORY_DISPATCH = 1u << 1,
-  // Commands may be of any type.
-  // Using this value may prevent optimizations and if possible callers should
-  // always specify the strictest set possible (for example, only transfer
-  // commands to ensure they get placed on a DMA queue).
-  IREE_HAL_COMMAND_CATEGORY_ANY =
-      IREE_HAL_COMMAND_CATEGORY_TRANSFER | IREE_HAL_COMMAND_CATEGORY_DISPATCH,
-};
-typedef uint32_t iree_hal_command_category_t;
-
-// Specifies the type of a descriptor in a descriptor set.
-enum iree_hal_descriptor_type_e {
-  IREE_HAL_DESCRIPTOR_TYPE_UNIFORM_BUFFER = 6u,
-  IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER = 7u,
-  IREE_HAL_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC = 8u,
-  IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC = 9u,
-};
-typedef uint32_t iree_hal_descriptor_type_t;
-
-// Specifies a descriptor set binding.
-// The range specified by [offset, offset + length) will be made available to
-// executables on the given binding. If the descriptor type is dynamic then the
-// range will be [offset + dynamic_offset, offset + dynamic_offset + length).
-//
-// The IREE HAL buffer type may internally be offset; such offset is applied
-// here as if it were the base address of the buffer. Note that the offset will
-// be applied at the time the binding is recorded into the command buffer.
-//
-// Maps to VkDescriptorSetBinding.
-typedef struct {
-  // The binding number of this entry; it corresponds to a resource of the
-  // same binding number in the executable interface.
-  int32_t binding;
-  // Buffer bound to the binding number.
-  // May be nullptr if the binding is not used by the executable.
-  iree_hal_buffer_t* buffer;
-  // Offset, in bytes, into the buffer that the binding starts at.
-  // If the descriptor type is dynamic this will be added to the dynamic
-  // offset provided during binding.
-  iree_device_size_t offset;
-  // Length, in bytes, of the buffer that is available to the executable.
-  // This can be IREE_WHOLE_BUFFER, however note that if the entire buffer
-  // contents are larger than supported by the device (~128MiB, usually) this
-  // will fail. If the descriptor type is dynamic this will be used for all
-  // ranges regardless of offset.
-  iree_device_size_t length;
-} iree_hal_descriptor_set_binding_t;
-
-// Specifies the usage type of the descriptor set.
-enum iree_hal_descriptor_set_layout_usage_type_e {
-  // Descriptor set will be initialized once and never changed.
-  IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE = 0u,
-  // Descriptor set is never created and instead used with push descriptors.
-  IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_PUSH_ONLY = 1u,
-};
-typedef uint32_t iree_hal_descriptor_set_layout_usage_type_t;
-
-// Specifies a descriptor set layout binding.
-//
-// Maps to VkDescriptorSetLayoutBinding.
-typedef struct {
-  // The binding number of this entry; it corresponds to a resource of the
-  // same binding number in the executable interface.
-  int32_t binding;
-  // Specifies which type of resource descriptors are used for this binding.
-  iree_hal_descriptor_type_t type;
-  // Specifies the memory access performed by the executables.
-  iree_hal_memory_access_t access;
-} iree_hal_descriptor_set_layout_binding_t;
-
-// An identifier for executable formats used to query support.
-typedef uint32_t iree_hal_executable_format_t; - -// Defines how the executable cache performs preparation. -enum iree_hal_executable_caching_mode_e { - // Allows the cache to reference the provided executable_data after it has - // prepared the executable. Callers must ensure the data remains valid for the - // lifetime of the cache. If memory mapping constant executable data from - // disk this can be used to avoid copies. - IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA = 1u << 0, - // Allows the prepared executable to be cached persistently (on disk/etc). - // Enable for any executable that is likely to be used in future runs. - // Note that not all caches support persistent serialization and this is just - // a hint. - IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_PERSISTENT_CACHING = 1u << 1, - // Allows the cache to optimize the executable as much as it can. - // This may cause preparation to take significantly longer while (hopefully) - // improving runtime performance. Avoid for one-shot executables. - IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION = 1u << 2, - // Enables Executable debugging methods if supported by the device and - // executable. This may disable certain optimizations or retain additional - // data to allow disassembly, stepping, etc. - // - // Device must support the DeviceFeature::kDebugging feature and executables - // must support the ExecutableFeature::kDebugging feature. - IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_DEBUGGING = 1u << 3, - // Enables Executable coverage if supported by the device and executable. - // Depending on the optimization mode this may produce partial coverage - // results (for example, when certain source operations were optimized away). - // - // Device must support the DeviceFeature::kCoverage feature and executables - // must support the ExecutableFeature::kCoverage feature. - IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_COVERAGE = 1u << 4, - // Enables Executable profiling if supported by the device and executable. - // Depending on the optimization mode this may produce partial profiling - // results. Profiling attribution (whether to the entire executable or - // specific operations) depends on the implementation. - // - // Device must support the DeviceFeature::kProfiling feature and executables - // must support the ExecutableFeature::kProfiling feature. - IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_PROFILING = 1u << 5, - // Default caching mode. - IREE_HAL_EXECUTABLE_CACHING_MODE_DEFAULT = - IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_PERSISTENT_CACHING | - IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION, -}; -typedef uint32_t iree_hal_executable_caching_mode_t; - -// Bitfield specifying which execution stage a barrier should start/end at. -// -// Maps to VkPipelineStageFlagBits. -enum iree_hal_execution_stage_e { - // Top of the pipeline when commands are initially issued by the device. - IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE = 1u << 0, - // Stage of the pipeline when dispatch parameter data is consumed. - IREE_HAL_EXECUTION_STAGE_COMMAND_PROCESS = 1u << 1, - // Stage where dispatch commands execute. - IREE_HAL_EXECUTION_STAGE_DISPATCH = 1u << 2, - // Stage where transfer (copy/clear/fill/etc) commands execute. - IREE_HAL_EXECUTION_STAGE_TRANSFER = 1u << 3, - // Final stage in the pipeline when commands are retired on the device. - IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE = 1u << 4, - // Pseudo-stage for read/writes by the host. Not executed on device. 
- IREE_HAL_EXECUTION_STAGE_HOST = 1u << 5, -}; -typedef uint32_t iree_hal_execution_stage_t; - -// Bitfield specifying which scopes will access memory and how. -// -// Maps to VkAccessFlagBits. -enum iree_hal_access_scope_e { - // Read access to indirect command data as part of an indirect dispatch. - IREE_HAL_ACCESS_SCOPE_INDIRECT_COMMAND_READ = 1u << 0, - // Constant uniform buffer reads by the device. - IREE_HAL_ACCESS_SCOPE_CONSTANT_READ = 1u << 1, - // Storage buffer reads by dispatch commands. - IREE_HAL_ACCESS_SCOPE_DISPATCH_READ = 1u << 2, - // Storage buffer writes by dispatch commands. - IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE = 1u << 3, - // Source of a transfer operation. - IREE_HAL_ACCESS_SCOPE_TRANSFER_READ = 1u << 4, - // Target of a transfer operation. - IREE_HAL_ACCESS_SCOPE_TRANSFER_WRITE = 1u << 5, - // Read operation by the host through mapped memory. - IREE_HAL_ACCESS_SCOPE_HOST_READ = 1u << 6, - // Write operation by the host through mapped memory. - IREE_HAL_ACCESS_SCOPE_HOST_WRITE = 1u << 7, - // External/non-specific read. - IREE_HAL_ACCESS_SCOPE_MEMORY_READ = 1u << 8, - // External/non-specific write. - IREE_HAL_ACCESS_SCOPE_MEMORY_WRITE = 1u << 9, -}; -typedef uint32_t iree_hal_access_scope_t; - -// Defines a global memory barrier. -// These are cheaper to encode than buffer-specific barriers but may cause -// stalls and bubbles in device pipelines if applied too broadly. Prefer them -// over equivalently large sets of buffer-specific barriers (such as when -// completely changing execution contexts). -// -// Maps to VkMemoryBarrier. -typedef struct { - // All access scopes prior-to the barrier (inclusive). - iree_hal_access_scope_t source_scope; - // All access scopes following the barrier (inclusive). - iree_hal_access_scope_t target_scope; -} iree_hal_memory_barrier_t; - -// Defines a memory barrier that applies to a range of a specific buffer. -// Use of these (vs. global memory barriers) provides fine-grained execution -// ordering to device command processors and allows for more aggressive -// reordering. -// -// Maps to VkBufferMemoryBarrier. -typedef struct { - // All access scopes prior-to the barrier (inclusive). - iree_hal_access_scope_t source_scope; - // All access scopes following the barrier (inclusive). - iree_hal_access_scope_t target_scope; - // Buffer the barrier is restricted to. - // The barrier will apply to the entire physical device allocation. - iree_hal_buffer_t* buffer; - // Relative offset/length within |buffer| (which may itself be mapped into the - // device allocation at an offset). - iree_device_size_t offset; - iree_device_size_t length; -} iree_hal_buffer_barrier_t; - -// A list of semaphores and their corresponding payloads. -// When signaling each semaphore will be set to the new payload value provided. -// When waiting each semaphore must reach or exceed the payload value. -typedef struct { - iree_host_size_t count; - iree_hal_semaphore_t** semaphores; - uint64_t* payload_values; -} iree_hal_semaphore_list_t; - -// A single batch of command buffers submitted to a device queue. -// All of the wait semaphores must reach or exceed the given payload value prior -// to the batch beginning execution. Each command buffer begins execution in the -// order it is present in the list, though note that the command buffers -// execute concurrently and require internal synchronization via events if there -// are any dependencies between them. 
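
// Example (illustrative sketch; editorial, not part of this change): using the
// barrier types defined above to order a dispatch that writes a buffer before
// a transfer that reads it, assuming an open command buffer;
// iree_hal_command_buffer_execution_barrier is declared later in this header.
static iree_status_t example_dispatch_to_transfer_barrier(
    iree_hal_command_buffer_t* command_buffer) {
  // Make dispatch writes visible to subsequent transfer reads.
  const iree_hal_memory_barrier_t barrier = {
      /*source_scope=*/IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE,
      /*target_scope=*/IREE_HAL_ACCESS_SCOPE_TRANSFER_READ,
  };
  return iree_hal_command_buffer_execution_barrier(
      command_buffer,
      /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH,
      /*target_stage_mask=*/IREE_HAL_EXECUTION_STAGE_TRANSFER,
      /*memory_barrier_count=*/1, &barrier,
      /*buffer_barrier_count=*/0, NULL);
}
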
Only after all command buffers have -// completed will the signal semaphores be updated to the provided payload -// values. -// -// Matches Vulkan's VkSubmitInfo: -// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkSubmitInfo.html -// Note that as the HAL only models timeline semaphores we take the payload -// values directly in this struct; see: -// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkTimelineSemaphoreSubmitInfo.html -typedef struct { - // Semaphores to wait on prior to executing any command buffer. - iree_hal_semaphore_list_t wait_semaphores; - - // Command buffers to execute, in order. - iree_host_size_t command_buffer_count; - iree_hal_command_buffer_t** command_buffers; - - // Semaphores to signal once all command buffers have completed execution. - iree_hal_semaphore_list_t signal_semaphores; -} iree_hal_submission_batch_t; - -// Defines how a multi-wait operation treats the results of multiple semaphores. -enum iree_hal_wait_mode_e { - // Waits for all semaphores to reach or exceed their specified values. - IREE_HAL_WAIT_MODE_ALL = 0, - // Waits for one or more semaphores to reach or exceed their specified values. - IREE_HAL_WAIT_MODE_ANY = 1, -}; -typedef uint8_t iree_hal_wait_mode_t; - -// Keep these in sync with iree/compiler/Dialect/HAL/IR/HALTypes.cpp - -enum iree_hal_numerical_type_e { - IREE_HAL_NUMERICAL_TYPE_UNKNOWN = 0x00u, - IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED = 0x01u, - IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED = 0x02u, - // TODO(benvanik): specialize with semantics from APFloat. - IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE = 0x03u, -}; -typedef uint8_t iree_hal_numerical_type_t; - -#define IREE_HAL_ELEMENT_TYPE_VALUE(numerical_type, bit_count) \ - (((uint32_t)(numerical_type) << 24) | (uint32_t)(bit_count)) - -#define iree_hal_make_element_type(numerical_type, bit_count) \ - (iree_hal_element_type_t)( \ - IREE_HAL_ELEMENT_TYPE_VALUE(numerical_type, bit_count)) -#define iree_hal_element_numerical_type(element_type) \ - (iree_hal_numerical_type_t)((uint32_t)(element_type) >> 24) -#define iree_hal_element_bit_count(element_type) (size_t)((element_type)&0xFF) -#define iree_hal_element_byte_count(element_type) \ - ((iree_hal_element_bit_count(element_type) + 8 - 1) / 8) - -// Defines the element type of a buffer in a standard format. -// -// Composed as a 32-bit bitfield to allow for opaque data types. Use -// iree_hal_make_element_type to make a bitfield with the appropriate ordering. 
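
// Worked example (editorial; not part of this change), following directly from
// the macros above:
//   iree_hal_element_bit_count(IREE_HAL_ELEMENT_TYPE_FLOAT_32)  == 32
//   iree_hal_element_byte_count(IREE_HAL_ELEMENT_TYPE_FLOAT_32) == 4
//   iree_hal_element_numerical_type(IREE_HAL_ELEMENT_TYPE_UINT_8)
//       == IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED
// An opaque 24-bit element type, not among the predefined values below, could
// be composed with iree_hal_make_element_type(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 24).
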
-// -// MSB ----------------------------------------------- LSB -// [numerical type] [reserved] [reserved] [number of bits] -// -// clang-format off -enum iree_hal_element_type_e { - IREE_HAL_ELEMENT_TYPE_NONE = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 0), // NOLINT - IREE_HAL_ELEMENT_TYPE_OPAQUE_8 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 8), // NOLINT - IREE_HAL_ELEMENT_TYPE_OPAQUE_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 16), // NOLINT - IREE_HAL_ELEMENT_TYPE_OPAQUE_32 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 32), // NOLINT - IREE_HAL_ELEMENT_TYPE_OPAQUE_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 64), // NOLINT - IREE_HAL_ELEMENT_TYPE_SINT_8 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 8), // NOLINT - IREE_HAL_ELEMENT_TYPE_UINT_8 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, 8), // NOLINT - IREE_HAL_ELEMENT_TYPE_SINT_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 16), // NOLINT - IREE_HAL_ELEMENT_TYPE_UINT_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, 16), // NOLINT - IREE_HAL_ELEMENT_TYPE_SINT_32 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 32), // NOLINT - IREE_HAL_ELEMENT_TYPE_UINT_32 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, 32), // NOLINT - IREE_HAL_ELEMENT_TYPE_SINT_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 64), // NOLINT - IREE_HAL_ELEMENT_TYPE_UINT_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, 64), // NOLINT - IREE_HAL_ELEMENT_TYPE_FLOAT_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE, 16), // NOLINT - IREE_HAL_ELEMENT_TYPE_FLOAT_32 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE, 32), // NOLINT - IREE_HAL_ELEMENT_TYPE_FLOAT_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE, 64), // NOLINT -}; -typedef uint32_t iree_hal_element_type_t; -// clang-format on - -// A dimension within a shape. -typedef int32_t iree_hal_dim_t; - -//===----------------------------------------------------------------------===// -// Utilities -//===----------------------------------------------------------------------===// - -// Parses a serialized set of shape dimensions using the canonical shape format -// (the same as produced by iree_hal_format_shape). -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_shape( - iree_string_view_t value, iree_host_size_t shape_capacity, - iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank); - -// Converts shape dimensions into a `4x5x6` format. -// -// Follows the standard API string formatting rules. See iree/base/api.h. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_format_shape(const iree_hal_dim_t* shape, iree_host_size_t shape_rank, - iree_host_size_t buffer_capacity, char* buffer, - iree_host_size_t* out_buffer_length); - -// Parses a serialized iree_hal_element_type_t and sets |out_element_type| if -// it is valid. The format is the same as produced by -// iree_hal_format_element_type. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element_type( - iree_string_view_t value, iree_hal_element_type_t* out_element_type); - -// Converts an iree_hal_element_type_t enum value to a canonical string -// representation, like `IREE_HAL_ELEMENT_TYPE_FLOAT_16` to `f16`. 
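
// Example (illustrative sketch; editorial, not part of this change): parsing
// the canonical string form documented above, assuming iree_make_cstring_view
// from iree/base/api.h.
static iree_status_t example_parse_element_type(void) {
  iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
  // "f32" parses to IREE_HAL_ELEMENT_TYPE_FLOAT_32 per the format above.
  return iree_hal_parse_element_type(iree_make_cstring_view("f32"),
                                     &element_type);
}
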
-// |buffer_capacity| defines the size of |buffer| in bytes and
-// |out_buffer_length| will return the string length in characters.
-//
-// Follows the standard API string formatting rules. See iree/base/api.h.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element_type(
-    iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity,
-    char* buffer, iree_host_size_t* out_buffer_length);
-
-// Parses a serialized element of |element_type| to its in-memory form.
-// |data_ptr| must be at least large enough to contain the bytes of the element.
-// For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT_32 will write the
-// 4-byte float value of 1.2 to |data_ptr|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element(
-    iree_string_view_t data_str, iree_hal_element_type_t element_type,
-    iree_byte_span_t data_ptr);
-
-// Converts a single element of |element_type| to a string.
-//
-// |buffer_capacity| defines the size of |buffer| in bytes and
-// |out_buffer_length| will return the string length in characters. Returns
-// IREE_STATUS_OUT_OF_RANGE if the buffer capacity is insufficient to hold the
-// formatted elements and |out_buffer_length| will contain the required size.
-//
-// Follows the standard API string formatting rules. See iree/base/api.h.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element(
-    iree_const_byte_span_t data, iree_hal_element_type_t element_type,
-    iree_host_size_t buffer_capacity, char* buffer,
-    iree_host_size_t* out_buffer_length);
-
-// Parses a serialized set of elements of the given |element_type|.
-// The resulting parsed data is written to |data_ptr|, which must be at least
-// large enough to contain the parsed elements. The format is the same as
-// produced by iree_hal_format_buffer_elements. Supports additional inputs of
-// empty to denote a 0 fill and a single element to denote a splat.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_buffer_elements(
-    iree_string_view_t data_str, iree_hal_element_type_t element_type,
-    iree_byte_span_t data_ptr);
-
-// Converts a shaped buffer of |element_type| elements to a string.
-// This will include []'s to denote each dimension, for example for a shape of
-// 2x3 the elements will be formatted as `[1 2 3][4 5 6]`.
-//
-// |max_element_count| can be used to limit the total number of elements printed
-// when the count may be large. Elided elements will be replaced with `...`.
-//
-// |buffer_capacity| defines the size of |buffer| in bytes and
-// |out_buffer_length| will return the string length in characters. Returns
-// IREE_STATUS_OUT_OF_RANGE if the buffer capacity is insufficient to hold the
-// formatted elements and |out_buffer_length| will contain the required size.
-//
-// Follows the standard API string formatting rules. See iree/base/api.h.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_buffer_elements(
-    iree_const_byte_span_t data, const iree_hal_dim_t* shape,
-    iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
-    iree_host_size_t max_element_count, iree_host_size_t buffer_capacity,
-    char* buffer, iree_host_size_t* out_buffer_length);
-
-//===----------------------------------------------------------------------===//
-// iree::hal::Allocator
-//===----------------------------------------------------------------------===//
-
-// Creates a host-local heap allocator that can be used when buffers are
-// required that will not interact with a real hardware device (such as those
-// used in file IO or tests).
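
// Example (illustrative sketch; editorial, not part of this change): the
// two-pass pattern the formatting routines above document, shown for
// iree_hal_format_element and assuming iree_make_const_byte_span from
// iree/base/api.h; |capacity| is caller-chosen.
static iree_status_t example_format_f32(float value, char* out_string,
                                        iree_host_size_t capacity) {
  iree_const_byte_span_t data =
      iree_make_const_byte_span(&value, sizeof(value));
  iree_host_size_t length = 0;
  iree_status_t status = iree_hal_format_element(
      data, IREE_HAL_ELEMENT_TYPE_FLOAT_32, capacity, out_string, &length);
  if (iree_status_code(status) == IREE_STATUS_OUT_OF_RANGE) {
    // |length| now holds the required capacity; the caller could retry here
    // with a buffer of at least that size.
  }
  return status;
}
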
Buffers allocated with this will not be compatible -// with real device allocators and will likely incur a copy if used. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_allocator_create_host_local(iree_allocator_t allocator, - iree_hal_allocator_t** out_allocator); - -// Retains the given |allocator| for the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_allocator_retain(iree_hal_allocator_t* allocator); - -// Releases the given |allocator| from the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_allocator_release(iree_hal_allocator_t* allocator); - -// Calculates the allocation size of a buffer. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_compute_size( - const iree_hal_allocator_t* allocator, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - iree_device_size_t* out_allocation_size); - -// Calculates a byte offset into a buffer at the given indices. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_compute_offset( - const iree_hal_allocator_t* allocator, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - const iree_hal_dim_t* indices, size_t indices_count, - iree_device_size_t* out_offset); - -// Calculates a byte range into a buffer of the given contiguous range. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_compute_range( - const iree_hal_allocator_t* allocator, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, - const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, - iree_device_size_t* out_start_offset, iree_device_size_t* out_length); - -// Allocates a buffer from the allocator. -// Fails if the memory type requested for the given usage cannot be serviced. -// Callers can use iree_hal_allocator_can_allocate to decide their memory use -// strategy. -// -// The memory type of the buffer returned may differ from the requested value -// if the device can provide more functionality; for example, if requesting -// MemoryType::kHostVisible but the memory is really host cached you may get -// a buffer back with MemoryType::kHostVisible | MemoryType::kHostCached. The -// only requirement is that the buffer satisfy the required bits. -// -// Fails if it is not possible to allocate and satisfy all placements for the -// requested |buffer_usage|. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_allocate_buffer( - iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, iree_host_size_t allocation_size, - iree_hal_buffer_t** out_buffer); - -// Wraps an existing host allocation in a buffer. -// Ownership of the allocation remains with the caller and the memory must -// remain valid for so long as the buffer may be in use. -// -// Fails if the allocator cannot access host memory in this way. -// |out_buffer| must be released by the caller. 
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_allocator_wrap_buffer( - iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t buffer_usage, iree_byte_span_t data, - iree_hal_buffer_t** out_buffer); - -//===----------------------------------------------------------------------===// -// iree::hal::Buffer -//===----------------------------------------------------------------------===// - -// Returns a reference to a subspan of the |buffer|. -// If |byte_length| is IREE_WHOLE_BUFFER the remaining bytes in the buffer after -// |byte_offset| (possibly 0) will be selected. -// -// The parent buffer will remain alive for the lifetime of the subspan -// returned. If the subspan is a small portion this may cause additional -// memory to remain allocated longer than required. -// -// Returns the given |buffer| if the requested span covers the entire range. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_subspan( - iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer); - -// Retains the given |buffer| for the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_buffer_retain(iree_hal_buffer_t* buffer); - -// Releases the given |buffer| from the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_buffer_release(iree_hal_buffer_t* buffer); - -// Returns the allocator this buffer was allocated from. -IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL -iree_hal_buffer_allocator(const iree_hal_buffer_t* buffer); - -// Returns the size in bytes of the buffer. -IREE_API_EXPORT iree_device_size_t IREE_API_CALL -iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer); - -// Sets a range of the buffer to binary zero. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length); - -// Sets a range of the buffer to the given value. -// Only |pattern_length| values with 1, 2, or 4 bytes are supported. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length, const void* pattern, - iree_host_size_t pattern_length); - -// Reads a block of data from the buffer at the given offset. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_read_data( - iree_hal_buffer_t* buffer, iree_device_size_t source_offset, - void* target_buffer, iree_device_size_t data_length); - -// Writes a block of byte data into the buffer at the given offset. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_write_data( - iree_hal_buffer_t* buffer, iree_device_size_t target_offset, - const void* source_buffer, iree_device_size_t data_length); - -// Copies data from the provided |source_buffer| into the |target_buffer|. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_copy_data( - iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, - iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, - iree_device_size_t data_length); - -// Maps the buffer to be accessed as a host pointer into |out_mapped_memory|. 
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_map( - iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access, - iree_device_size_t byte_offset, iree_device_size_t byte_length, - iree_hal_mapped_memory_t* out_mapped_memory); - -// Unmaps the buffer as was previously mapped to |mapped_memory|. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_unmap( - iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory); - -//===----------------------------------------------------------------------===// -// iree::hal::HeapBuffer -//===----------------------------------------------------------------------===// - -// Allocates a zeroed host heap buffer of the given size. -// The buffer contents will be allocated with |contents_allocator| while -// |allocator| is used for the iree_hal_buffer_t. -// -// Returns a buffer allocated with malloc that may not be usable by devices -// without copies. |memory_type| should be set to -// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, - iree_host_size_t allocation_size, iree_allocator_t contents_allocator, - iree_allocator_t allocator, iree_hal_buffer_t** out_buffer); - -// Allocates a host heap buffer with a copy of the given data. -// The buffer contents will be allocated with |contents_allocator| while -// |allocator| is used for the iree_hal_buffer_t. -// -// Returns a buffer allocated with malloc that may not be usable by devices -// without copies. |memory_type| should be set to -// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, - iree_hal_memory_access_t allowed_access, iree_byte_span_t contents, - iree_allocator_t contents_allocator, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer); - -// Wraps an existing host heap allocation in a buffer. -// Ownership of the host allocation remains with the caller and the memory -// must remain valid for so long as the iree_hal_buffer_t may be in use. -// -// Returns a buffer allocated with malloc that may not be usable by devices -// without copies. |memory_type| should be set to -// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap( - iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t usage, iree_byte_span_t contents, - iree_allocator_t allocator, iree_hal_buffer_t** out_buffer); - -// TODO(benvanik): add a wrap that takes an allocator just for the buffer. - -//===----------------------------------------------------------------------===// -// iree::hal::BufferView -//===----------------------------------------------------------------------===// - -// Creates a buffer view with the given |buffer|. -// |out_buffer_view| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_create( - iree_hal_buffer_t* buffer, const iree_hal_dim_t* shape, - iree_host_size_t shape_rank, iree_hal_element_type_t element_type, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view); - -// Creates a buffer view referencing a subview of the given |buffer_view|. 
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_subview( - const iree_hal_buffer_view_t* buffer_view, - const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, - const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view); - -// Retains the given |buffer_view| for the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view); - -// Releases the given |buffer_view| from the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view); - -// Returns the buffer underlying the buffer view. -// The caller must retain the returned buffer if they want to continue using it. -IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL -iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view); - -// Returns the rank of the shape associated with the buffer view. -IREE_API_EXPORT iree_host_size_t IREE_API_CALL -iree_hal_buffer_view_shape_rank(const iree_hal_buffer_view_t* buffer_view); - -// Returns the value of the given dimension. -IREE_API_EXPORT iree_host_size_t IREE_API_CALL iree_hal_buffer_view_shape_dim( - const iree_hal_buffer_view_t* buffer_view, iree_host_size_t index); - -// Returns the dimensions of the shape in |out_shape| and its rank in -// |out_shape_rank|. |rank_capacity| indicates the number of dimensions -// available in the |out_shape| buffer. If there is not enough capacity to store -// all of the dimensions IREE_STATUS_OUT_OF_RANGE is returned. -// |out_shape_rank| can be omitted if the rank is already known. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape( - const iree_hal_buffer_view_t* buffer_view, iree_host_size_t rank_capacity, - iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank); - -// Returns the total number of elements stored in the view. -IREE_API_EXPORT iree_host_size_t -iree_hal_buffer_view_element_count(const iree_hal_buffer_view_t* buffer_view); - -// Returns the element type of the buffer. -IREE_API_EXPORT iree_hal_element_type_t IREE_API_CALL -iree_hal_buffer_view_element_type(const iree_hal_buffer_view_t* buffer_view); - -// Returns the size of each element in the buffer view in bytes. -// Note that not all buffers are contiguous or densely packed. -IREE_API_EXPORT iree_host_size_t IREE_API_CALL -iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view); - -// Returns the total size of the specified view in bytes. -// Note that not all buffers are contiguous or densely packed. -IREE_API_EXPORT iree_device_size_t IREE_API_CALL -iree_hal_buffer_view_byte_length(const iree_hal_buffer_view_t* buffer_view); - -// Calculates a byte offset into the |buffer_view| at the given indices. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_offset( - const iree_hal_buffer_view_t* buffer_view, const iree_hal_dim_t* indices, - iree_host_size_t indices_count, iree_device_size_t* out_offset); - -// Calculates a byte range into the |buffer_view| of the given contiguous range. 
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_range( - const iree_hal_buffer_view_t* buffer_view, - const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, - const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, - iree_device_size_t* out_start_offset, iree_device_size_t* out_length); - -// Parses a serialized set of buffer elements in the canonical tensor format -// (the same as produced by iree_hal_buffer_view_format). The underlying buffer -// will be allocated with |buffer_allocator| as a host-local/device-visible -// buffer. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_parse( - iree_string_view_t value, iree_hal_allocator_t* buffer_allocator, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view); - -// Converts buffer view elements into a fully-specified string-form format like -// `2x4xi16=[[1 2][3 4]]`. -// -// |max_element_count| can be used to limit the total number of elements printed -// when the count may be large. Elided elements will be replaced with `...`. -// -// |buffer_capacity| defines the size of |buffer| in bytes and -// |out_buffer_length| will return the string length in characters. Returns -// IREE_STATUS_OUT_OF_RANGE if the buffer capacity is insufficient to hold the -// formatted elements and |out_buffer_length| will contain the required size. -// -// Follows the standard API string formatting rules. See iree/base/api.h. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_format( - const iree_hal_buffer_view_t* buffer_view, - iree_host_size_t max_element_count, iree_host_size_t buffer_capacity, - char* buffer, iree_host_size_t* out_buffer_length); - -//===----------------------------------------------------------------------===// -// iree::hal::CommandBuffer -//===----------------------------------------------------------------------===// - -// Creates a command buffer ready to begin recording, possibly reusing an -// existing one from the |device| pool. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_create( - iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories, iree_allocator_t allocator, - iree_hal_command_buffer_t** out_command_buffer); - -// Retains the given |command_buffer| for the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_command_buffer_retain(iree_hal_command_buffer_t* command_buffer); - -// Releases the given |command_buffer| from the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_command_buffer_release(iree_hal_command_buffer_t* command_buffer); - -// Resets and begins recording into the command buffer, clearing all -// previously recorded contents. -// The command buffer must not be in-flight. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_begin(iree_hal_command_buffer_t* command_buffer); - -// Ends recording into the command buffer. -// This must be called prior to submitting the command buffer for execution. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_end(iree_hal_command_buffer_t* command_buffer); - -// Defines a memory dependency between commands recorded before and after the -// barrier. One or more memory or buffer barriers can be specified to indicate -// between which stages or buffers the dependencies exist. 
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_command_buffer_execution_barrier(
-    iree_hal_command_buffer_t* command_buffer,
-    iree_hal_execution_stage_t source_stage_mask,
-    iree_hal_execution_stage_t target_stage_mask,
-    iree_host_size_t memory_barrier_count,
-    const iree_hal_memory_barrier_t* memory_barriers,
-    iree_host_size_t buffer_barrier_count,
-    const iree_hal_buffer_barrier_t* buffer_barriers);
-
-// Fills the target buffer with the given repeating value.
-// Expects that |pattern_length| is one of 1, 2, or 4 and that the offset and
-// length are aligned to the natural alignment of the value.
-// The target buffer must be compatible with the devices owned by this
-// device queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_fill_buffer(
-    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* target_buffer,
-    iree_device_size_t target_offset, iree_device_size_t length,
-    const void* pattern, iree_host_size_t pattern_length);
-
-// Updates a range of the given target buffer from the source host memory.
-// The source host memory is copied immediately into the command buffer and
-// occupies command buffer space. It is strongly recommended that large buffer
-// updates are performed via iree_hal_command_buffer_copy_buffer where there is
-// the possibility of a zero-copy path.
-// The |source_buffer| may be released by the caller immediately after this
-// call returns.
-// The |target_buffer| must be compatible with the devices owned by this
-// device queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_command_buffer_update_buffer(iree_hal_command_buffer_t* command_buffer,
-                                      const void* source_buffer,
-                                      iree_host_size_t source_offset,
-                                      iree_hal_buffer_t* target_buffer,
-                                      iree_device_size_t target_offset,
-                                      iree_device_size_t length);
-
-// Copies a range of one buffer to another.
-// Both buffers must be compatible with the devices owned by this device
-// queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER. Though the source
-// and target buffer may be the same the ranges must not overlap (as with
-// memcpy).
-//
-// This can be used to perform device->host, host->device, and device->device
-// copies.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_copy_buffer(
-    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer,
-    iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer,
-    iree_device_size_t target_offset, iree_device_size_t length);
-
-// Pushes an inline set of constants that can be accessed by subsequent
-// dispatches using a compatible executable layout.
-//
-// Push constants are always 4-byte values and treated as opaque, meaning that
-// they may be bit-casted floats, bit-packed booleans, etc.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_command_buffer_push_constants(
-    iree_hal_command_buffer_t* command_buffer,
-    iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset,
-    const void* values, iree_host_size_t values_length);
-
-// Pushes a descriptor set and associates it with |set|.
-// This uses an internal ringbuffer inside of the command buffer to avoid the
-// need for creating and binding descriptor sets and managing their lifetime.
-//
-// The descriptor set will remain bound and valid so long as the executable
-// layouts used by dispatches are compatible (same descriptor layouts and push
-// constant sizes).
-IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_push_descriptor_set( - iree_hal_command_buffer_t* command_buffer, - iree_hal_executable_layout_t* executable_layout, int32_t set, - iree_host_size_t binding_count, - const iree_hal_descriptor_set_binding_t* bindings); - -// Binds a descriptor set to the given |set| matching that used in the -// executable layout interface. -// -// The descriptor set will remain bound and valid so long as the executable -// layouts used by dispatches are compatible (same descriptor layouts and push -// constant sizes). -// -// If any dynamic descriptor types are defined in the descriptor set layout then -// the dynamic offsets must be provided. These offsets will be added to the base -// offset of the descriptor layout binding. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_bind_descriptor_set( - iree_hal_command_buffer_t* command_buffer, - iree_hal_executable_layout_t* executable_layout, int32_t set, - iree_hal_descriptor_set_t* descriptor_set, - iree_host_size_t dynamic_offset_count, - const iree_device_size_t* dynamic_offsets); - -// Dispatches an execution request. -// The request may execute overlapped with any other transfer operation or -// dispatch made within the same barrier-defined sequence. -// -// The executable specified must be registered for use with the device driver -// owning this queue. It must not be unregistered until all requests that use -// it have completed. -// -// Fails if the queue does not support dispatch operations (as indicated by -// can_dispatch). -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_dispatch( - iree_hal_command_buffer_t* command_buffer, - iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z); - -// Dispatches an execution request with deferred workgroup counts. -// This is the same as iree_hal_command_buffer_dispatch but the workgroup counts -// are read from the given |workgroups_buffer| at offset |workgroups_offset| as -// 3 uint32_t XYZ values before performing the dispatch. This allows prior -// dispatches within the command sequence to populate the workgroup counts. -// -// The buffer must have been allocated with IREE_HAL_BUFFER_USAGE_DISPATCH and -// be of IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_command_buffer_dispatch_indirect( - iree_hal_command_buffer_t* command_buffer, - iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_t* workgroups_buffer, iree_device_size_t workgroups_offset); - -//===----------------------------------------------------------------------===// -// iree::hal::DescriptorSet -//===----------------------------------------------------------------------===// - -// Creates a descriptor set of the given layout and bindings. -// Descriptor sets are immutable and retain their bindings. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_descriptor_set_create( - iree_hal_device_t* device, iree_hal_descriptor_set_layout_t* set_layout, - iree_host_size_t binding_count, - const iree_hal_descriptor_set_binding_t* bindings, - iree_allocator_t allocator, iree_hal_descriptor_set_t** out_descriptor_set); - -// Retains the given |set| for the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_descriptor_set_retain(iree_hal_descriptor_set_t* descriptor_set); - -// Releases the given |set| from the caller. 
-IREE_API_EXPORT void IREE_API_CALL -iree_hal_descriptor_set_release(iree_hal_descriptor_set_t* descriptor_set); - -//===----------------------------------------------------------------------===// -// iree::hal::DescriptorSetLayout -//===----------------------------------------------------------------------===// - -// Creates a descriptor set layout with the given bindings. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_descriptor_set_layout_create( - iree_hal_device_t* device, - iree_hal_descriptor_set_layout_usage_type_t usage_type, - iree_host_size_t binding_count, - const iree_hal_descriptor_set_layout_binding_t* bindings, - iree_allocator_t allocator, - iree_hal_descriptor_set_layout_t** out_descriptor_set_layout); - -// Retains the given |descriptor_set_layout| for the caller. -IREE_API_EXPORT void IREE_API_CALL iree_hal_descriptor_set_layout_retain( - iree_hal_descriptor_set_layout_t* descriptor_set_layout); - -// Releases the given |descriptor_set_layout| from the caller. -IREE_API_EXPORT void IREE_API_CALL iree_hal_descriptor_set_layout_release( - iree_hal_descriptor_set_layout_t* descriptor_set_layout); - -//===----------------------------------------------------------------------===// -// iree::hal::Device -//===----------------------------------------------------------------------===// - -// Retains the given |device| for the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_device_retain(iree_hal_device_t* device); - -// Releases the given |device| from the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_device_release(iree_hal_device_t* device); - -// Returns a reference to the allocator of the device that can be used for -// allocating buffers. -IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL -iree_hal_device_allocator(iree_hal_device_t* device); - -// Returns the device identifier. -// This identifier may vary based on the runtime device type; for example, a -// Vulkan device may return `vulkan-v1.1` or `vulkan-v1.2-spec1`. -IREE_API_EXPORT iree_string_view_t IREE_API_CALL -iree_hal_device_id(iree_hal_device_t* device); - -// Submits one or more batches of work to a device queue. -// -// The queue is selected based on the flags set in |command_categories| and the -// |queue_affinity|. As the number of available queues can vary the -// |queue_affinity| is used to hash into the available queues for the required -// categories. For example if 2 queues support transfer commands and the -// affinity is 5 the resulting queue could be index hash(5)=1. The affinity can -// thus be treated as just a way to indicate whether two submissions must be -// placed on to the same queue. Note that the exact hashing function is -// implementation dependent. -// -// The submission behavior matches Vulkan's vkQueueSubmit, with each batch -// executing its command buffers in the order they are defined but allowing the -// command buffers to complete out-of-order. See: -// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/vkQueueSubmit.html -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_device_queue_submit( - iree_hal_device_t* device, iree_hal_command_category_t command_categories, - uint64_t queue_affinity, iree_host_size_t batch_count, - const iree_hal_submission_batch_t* batches); - -// Blocks the caller until the semaphores reach or exceed the specified payload -// values or the |deadline_ns| elapses. All semaphores in |semaphore_list| must -// be created from this device (or be imported into it). 
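
// Example (illustrative sketch; editorial, not part of this change):
// submitting one command buffer and blocking until its signal semaphore is
// reached, assuming valid handles created from |device|;
// iree_hal_device_wait_semaphores_with_deadline is declared just below.
static iree_status_t example_submit_and_wait(
    iree_hal_device_t* device, iree_hal_command_buffer_t* command_buffer,
    iree_hal_semaphore_t* semaphore, iree_time_t deadline_ns) {
  uint64_t signal_value = 1ull;
  const iree_hal_submission_batch_t batch = {
      /*wait_semaphores=*/{0, NULL, NULL},
      /*command_buffer_count=*/1,
      /*command_buffers=*/&command_buffer,
      /*signal_semaphores=*/{1, &semaphore, &signal_value},
  };
  IREE_RETURN_IF_ERROR(iree_hal_device_queue_submit(
      device, IREE_HAL_COMMAND_CATEGORY_ANY, /*queue_affinity=*/0,
      /*batch_count=*/1, &batch));
  const iree_hal_semaphore_list_t semaphore_list = {
      /*count=*/1, &semaphore, &signal_value};
  return iree_hal_device_wait_semaphores_with_deadline(
      device, IREE_HAL_WAIT_MODE_ALL, &semaphore_list, deadline_ns);
}
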
-// -// |wait_mode| can be used to decide when the wait will proceed; whether *all* -// semaphores in |semaphore_list| must be signaled or whether *any* (one or -// more) can be signaled before an early return. -// -// Returns success if the wait is successful and semaphores have been signaled -// satisfying the |wait_mode|. -// -// Returns DEADLINE_EXCEEDED if the |deadline_ns| elapses without the -// |wait_mode| being satisfied. Note that even on success only a subset of the -// semaphores may have been signaled and each can be queried to see which ones. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_device_wait_semaphores_with_deadline( - iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, - const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns); - -// Blocks the caller until the semaphores reach or exceed the specified payload -// values or the |timeout_ns| elapses. -// A relative-time version of iree_hal_device_wait_semaphores_with_deadline -// using the relative nanoseconds from the time the call is made. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_device_wait_semaphores_with_timeout( - iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, - const iree_hal_semaphore_list_t* semaphore_list, - iree_duration_t timeout_ns); - -//===----------------------------------------------------------------------===// -// iree::hal::Driver -//===----------------------------------------------------------------------===// - -// Retains the given |driver| for the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_driver_retain(iree_hal_driver_t* driver); - -// Releases the given |driver| from the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_driver_release(iree_hal_driver_t* driver); - -// Queries available devices and returns them as a list. -// The provided |allocator| will be used to allocate the returned list and after -// the caller is done with it |out_device_infos| must be freed with that same -// allocator by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_driver_query_available_devices( - iree_hal_driver_t* driver, iree_allocator_t allocator, - iree_hal_device_info_t** out_device_infos, - iree_host_size_t* out_device_info_count); - -// Creates a device as queried with iree_hal_driver_query_available_devices. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_driver_create_device( - iree_hal_driver_t* driver, iree_hal_device_id_t device_id, - iree_allocator_t allocator, iree_hal_device_t** out_device); - -// Creates the driver-defined "default" device. This may simply be the first -// device enumerated. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_driver_create_default_device(iree_hal_driver_t* driver, - iree_allocator_t allocator, - iree_hal_device_t** out_device); - -//===----------------------------------------------------------------------===// -// iree_hal_driver_registry_t -//===----------------------------------------------------------------------===// - -// Factory interface used for driver enumeration and creation. -// The factory is designed to in many cases live in rodata by not requiring any -// real code or processing when the driver is statically known to be available. -// When drivers may be dynamically available based on system configuration a -// factory can discover them and provide them during enumeration. 
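
// Example (illustrative sketch; editorial, not part of this change):
// enumerating and creating a device with the driver API above; note that the
// returned info list must be freed with the same allocator, per the
// documentation.
static iree_status_t example_create_first_device(
    iree_hal_driver_t* driver, iree_hal_device_t** out_device) {
  iree_hal_device_info_t* device_infos = NULL;
  iree_host_size_t device_info_count = 0;
  IREE_RETURN_IF_ERROR(iree_hal_driver_query_available_devices(
      driver, iree_allocator_system(), &device_infos, &device_info_count));
  iree_status_t status =
      device_info_count > 0
          ? iree_hal_driver_create_device(driver, device_infos[0].device_id,
                                          iree_allocator_system(), out_device)
          : iree_make_status(IREE_STATUS_NOT_FOUND, "no devices available");
  iree_allocator_free(iree_allocator_system(), device_infos);
  return status;
}
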
-// -// Delay-loaded drivers that may require non-trivial setup time (such as those -// implemented in dynamic libraries or over RPC) can be speculatively enumerated -// by a factory and then rely on the try_create to actually perform the slow -// work once the user has explicitly signaled that they are willing to pay the -// cost (and deal with the consequences). -// -// WARNING: this API is unstable until the HAL is fully ported. Do not use. -typedef struct { - // TODO(benvanik): version field. - IREE_API_UNSTABLE - - // User-defined pointer passed to all functions. - void* self; - - // Queries the list of available drivers provided by the factory, if any. - // |out_driver_infos| will be populated with a *reference* to factory data - // structures (such as the driver name) that callers may choose to clone if - // needed. - // - // Implementers must make their factory enumeration results immutable for the - // duration they are registered, though the behavior of try_create is allowed - // to change call-to-call. If a factory needs to mutate its set of enumerated - // devices then it must do so by first unregistering itself and re-registering - // only after the changes have been made. - // - // Called with the driver registry lock held; may be called from any thread. - iree_status_t(IREE_API_PTR* enumerate)( - void* self, const iree_hal_driver_info_t** out_driver_infos, - iree_host_size_t* out_driver_info_count); - - // Tries to create a driver as previously queried with enumerate. - // |driver_id| is the opaque ID returned from enumeration; note that there may - // be a significant amount of time between enumeration and creation and the - // driver registry lock may have been release between then. - // - // Delay-loaded drivers may still fail here if - for example - required system - // resources are unavailable or permission is denied. - // - // Called with the driver registry lock held; may be called from any thread. - iree_status_t(IREE_API_PTR* try_create)(void* self, - iree_hal_driver_id_t driver_id, - iree_allocator_t allocator, - iree_hal_driver_t** out_driver); -} iree_hal_driver_factory_t; - -// Returns the default per-process driver registry. -// In simple applications this is usually where you want to go to register and -// create drivers. More sophisticated applications that want tighter control -// over the visibility of drivers to certain callers such as when dealing with -// requests from multiple users may choose to allocate their own registries and -// manage their lifetime as desired. -// -// TODO(benvanik): remove global registry and make callers manage always. We can -// provide helpers to make that easier to do, but there's really no benefit to -// having this be global like it is. Alternatively, this can be opt-in thanks to -// LTO: if a user doesn't call this then the default registry is never -// allocated. -IREE_API_EXPORT iree_hal_driver_registry_t* IREE_API_CALL -iree_hal_driver_registry_default(); - -// Registers a driver factory to serve future queries/requests for drivers. -// See iree_hal_driver_registry_t for more information. -// -// Thread-safe. The factory is not retained and must be kept alive by the caller -// until it is unregistered (or the application terminates). -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_driver_registry_register_factory( - iree_hal_driver_registry_t* registry, - const iree_hal_driver_factory_t* factory); - -// Unregisters a driver factory. 
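For illustration, the smallest possible factory serves a single statically-known driver. This is only a sketch of the (explicitly unstable) interface above: my_driver_create() is a hypothetical backend constructor, and the iree_hal_driver_info_t field names, as well as the iree_string_view_literal() initializer helper, are assumptions about declarations that live elsewhere.

static iree_status_t my_factory_enumerate(
    void* self, const iree_hal_driver_info_t** out_driver_infos,
    iree_host_size_t* out_driver_info_count) {
  (void)self;
  // Lives in rodata and stays immutable while the factory is registered.
  static const iree_hal_driver_info_t driver_infos[1] = {
      {.driver_id = 1, .driver_name = iree_string_view_literal("my_driver")},
  };
  *out_driver_infos = driver_infos;
  *out_driver_info_count = 1;
  return iree_ok_status();
}

static iree_status_t my_factory_try_create(void* self,
                                           iree_hal_driver_id_t driver_id,
                                           iree_allocator_t allocator,
                                           iree_hal_driver_t** out_driver) {
  (void)self;
  if (driver_id != 1) {
    return iree_make_status(IREE_STATUS_NOT_FOUND,
                            "factory does not provide driver %d",
                            (int)driver_id);
  }
  // Hypothetical constructor; this is where delay-loaded setup would happen.
  return my_driver_create(allocator, out_driver);
}

static const iree_hal_driver_factory_t my_factory = {
    .self = NULL,
    .enumerate = my_factory_enumerate,
    .try_create = my_factory_try_create,
};

-// Unregisters a driver factory.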
-// Unregistering a factory only prevents new drivers from being created; -// existing drivers may remain live even after unregistering. Factories can -// expect that no new drivers will be created via the factory after the call -// returns. -// -// Thread-safe. As the factory is not retained by the registry the caller must -// release its memory (if needed) after this call returns. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_driver_registry_unregister_factory( - iree_hal_driver_registry_t* registry, - const iree_hal_driver_factory_t* factory); - -// Enumerates all drivers from registered factories and returns them as a list. -// The provided |allocator| will be used to allocate the returned list and after -// the caller is done with it |out_driver_infos| must be freed with that same -// allocator by the caller. -// -// The set of drivers returned should be considered the superset of those that -// may be available for successful creation as it's possible that delay-loaded -// drivers may fail even if they appear in this list. -// -// Thread-safe. Note that the factory may be unregistered between the query -// completing and any attempt to instantiate the driver. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_driver_registry_enumerate( - iree_hal_driver_registry_t* registry, iree_allocator_t allocator, - iree_hal_driver_info_t** out_driver_infos, - iree_host_size_t* out_driver_info_count); - -// Attempts to create a driver registered with the driver registry by a specific -// ID as returned during enumeration in iree_hal_driver_info_t::driver_id. -// This can be used to specify the exact driver to create in cases where there -// may be multiple factories providing drivers with the same name. -// -// Thread-safe. May block the caller if the driver is delay-loaded and needs to -// perform additional loading/verification/etc before returning. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_driver_registry_try_create( - iree_hal_driver_registry_t* registry, iree_hal_driver_id_t driver_id, - iree_allocator_t allocator, iree_hal_driver_t** out_driver); - -// Attempts to create a driver registered with the given canonical driver name. -// Effectively enumerate + find by name + try_create if found. Factories are -// searched in most-recently-added order such that it's possible to override -// drivers with newer registrations when multiple factories provide the same -// driver name. -// -// Thread-safe. May block the caller if the driver is delay-loaded and needs to -// perform additional loading/verification/etc before returning. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_driver_registry_try_create_by_name( - iree_hal_driver_registry_t* registry, iree_string_view_t driver_name, - iree_allocator_t allocator, iree_hal_driver_t** out_driver); - -//===----------------------------------------------------------------------===// -// iree::hal::Executable -//===----------------------------------------------------------------------===// - -// Retains the given |executable| for the caller. -IREE_API_EXPORT void IREE_API_CALL -iree_hal_executable_retain(iree_hal_executable_t* executable); - -// Releases the given |executable| from the caller. 
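Typical registration and creation over these entry points then looks like the following sketch, reusing my_factory from above; iree_make_cstring_view() is assumed to be the usual string view helper from the base API.

iree_hal_driver_registry_t* registry = iree_hal_driver_registry_default();
IREE_RETURN_IF_ERROR(
    iree_hal_driver_registry_register_factory(registry, &my_factory));

// Enumerate + find by name + try_create; may block on delay-loaded drivers.
iree_hal_driver_t* driver = NULL;
IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create_by_name(
    registry, iree_make_cstring_view("my_driver"), iree_allocator_system(),
    &driver));

iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_driver_create_default_device(
    driver, iree_allocator_system(), &device));
iree_hal_driver_release(driver);  // release once done creating devices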
-IREE_API_EXPORT void IREE_API_CALL
-iree_hal_executable_release(iree_hal_executable_t* executable);
-
-//===----------------------------------------------------------------------===//
-// iree::hal::ExecutableCache
-//===----------------------------------------------------------------------===//
-
-// Creates an executable cache using the given identifier.
-// The identifier is provided to the backing cache API as a way to partition
-// caches between different groups of executables (from different modules, etc).
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_cache_create(
-    iree_hal_device_t* device, iree_string_view_t identifier,
-    iree_allocator_t allocator,
-    iree_hal_executable_cache_t** out_executable_cache);
-
-// Retains the given |executable_cache| for the caller.
-IREE_API_EXPORT void IREE_API_CALL
-iree_hal_executable_cache_retain(iree_hal_executable_cache_t* executable_cache);
-
-// Releases the given |executable_cache| from the caller.
-IREE_API_EXPORT void IREE_API_CALL iree_hal_executable_cache_release(
-    iree_hal_executable_cache_t* executable_cache);
-
-// Returns true if the executable cache can prepare the given executable input
-// format. Preparation may still fail if the particular version or features
-// required by the executable are not supported.
-IREE_API_EXPORT bool IREE_API_CALL iree_hal_executable_cache_can_prepare_format(
-    iree_hal_executable_cache_t* executable_cache,
-    iree_hal_executable_format_t format);
-
-// Prepares an executable for use.
-// The provided |executable_data| will be used to either look up a previously
-// prepared executable in the cache or prepare a new one.
-//
-// Depending on the driver, preparation may take a non-trivial amount of time
-// (such as when JITing/etc). As the cache is internally synchronized callers
-// can issue preparation requests from multiple threads - even for the same
-// executables - and calls will block until preparation completes.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_executable_cache_prepare_executable(
-    iree_hal_executable_cache_t* executable_cache,
-    iree_hal_executable_layout_t* executable_layout,
-    iree_hal_executable_caching_mode_t caching_mode,
-    iree_const_byte_span_t executable_data, iree_allocator_t allocator,
-    iree_hal_executable_t** out_executable);
-
-//===----------------------------------------------------------------------===//
-// iree::hal::ExecutableLayout
-//===----------------------------------------------------------------------===//
-
-// Creates an executable layout composed of the given descriptor set layouts.
-// The returned executable layout can be used by multiple executables with the
-// same compatible resource binding layouts.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_layout_create(
-    iree_hal_device_t* device, iree_host_size_t set_layout_count,
-    iree_hal_descriptor_set_layout_t** set_layouts,
-    iree_host_size_t push_constants, iree_allocator_t allocator,
-    iree_hal_executable_layout_t** out_executable_layout);
-
-// Retains the given |executable_layout| for the caller.
-IREE_API_EXPORT void IREE_API_CALL iree_hal_executable_layout_retain(
-    iree_hal_executable_layout_t* executable_layout);
-
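A typical preparation flow over these APIs, sketched with placeholders: |device|, |descriptor_set_layout|, and |executable_data| are assumed to exist, and IREE_HAL_EXECUTABLE_CACHING_MODE_DEFAULT is a stand-in for whichever caching-mode bits the caller actually wants (the enum is declared elsewhere).

iree_hal_executable_cache_t* cache = NULL;
IREE_RETURN_IF_ERROR(iree_hal_executable_cache_create(
    device, iree_make_cstring_view("module0"), iree_allocator_system(),
    &cache));

// Layouts are shareable across executables with compatible bindings.
iree_hal_executable_layout_t* layout = NULL;
IREE_RETURN_IF_ERROR(iree_hal_executable_layout_create(
    device, /*set_layout_count=*/1, &descriptor_set_layout,
    /*push_constants=*/0, iree_allocator_system(), &layout));

// Blocks until the executable is prepared (possibly JITed) or fetched from
// the cache; safe to call concurrently from multiple threads.
iree_hal_executable_t* executable = NULL;
IREE_RETURN_IF_ERROR(iree_hal_executable_cache_prepare_executable(
    cache, layout, IREE_HAL_EXECUTABLE_CACHING_MODE_DEFAULT, executable_data,
    iree_allocator_system(), &executable));

iree_hal_executable_layout_release(layout);
iree_hal_executable_cache_release(cache);

-// Releases the given |executable_layout| from the caller.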
-IREE_API_EXPORT void IREE_API_CALL iree_hal_executable_layout_release(
-    iree_hal_executable_layout_t* executable_layout);
-
-//===----------------------------------------------------------------------===//
-// iree::hal::Semaphore
-//===----------------------------------------------------------------------===//
-
-// Creates a semaphore that can be used with command queues owned by this
-// device. To use the semaphores with other devices or instances they must
-// first be exported.
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_semaphore_create(
-    iree_hal_device_t* device, uint64_t initial_value,
-    iree_allocator_t allocator, iree_hal_semaphore_t** out_semaphore);
-
-// Retains the given |semaphore| for the caller.
-IREE_API_EXPORT void IREE_API_CALL
-iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore);
-
-// Releases the given |semaphore| from the caller.
-IREE_API_EXPORT void IREE_API_CALL
-iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore);
-
-// Queries the current payload of the semaphore and stores the result in
-// |out_value|. As the payload is monotonically increasing it is guaranteed that
-// the value is at least equal to the previous result of an
-// iree_hal_semaphore_query call and coherent with any waits for a
-// specified value via iree_hal_device_wait_semaphores_with_deadline.
-//
-// Returns the status at the time the method is called without blocking and as
-// such is only valid after a semaphore has been signaled. The same failure
-// status will be returned regardless of when in the timeline the error
-// occurred.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_semaphore_query(iree_hal_semaphore_t* semaphore, uint64_t* out_value);
-
-// Signals the |semaphore| to the given payload value.
-// The call is ignored if the current payload value exceeds |new_value|.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_semaphore_signal(iree_hal_semaphore_t* semaphore, uint64_t new_value);
-
-// Signals the |semaphore| with a failure. The |status| will be returned from
-// iree_hal_semaphore_query and iree_hal_semaphore_signal for the lifetime
-// of the semaphore.
-IREE_API_EXPORT void IREE_API_CALL
-iree_hal_semaphore_fail(iree_hal_semaphore_t* semaphore, iree_status_t status);
-
-// Blocks the caller until the semaphore reaches or exceeds the specified
-// payload value or the |deadline_ns| elapses.
-//
-// Returns success if the wait is successful and the semaphore has met or
-// exceeded the required payload value.
-//
-// Returns DEADLINE_EXCEEDED if the |deadline_ns| elapses without the semaphore
-// reaching the required value. If an asynchronous failure occurred this will
-// return the failure status that was set immediately.
-IREE_API_EXPORT iree_status_t IREE_API_CALL
-iree_hal_semaphore_wait_with_deadline(iree_hal_semaphore_t* semaphore,
-                                      uint64_t value, iree_time_t deadline_ns);
-
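A sketch of the timeline pattern these functions enable, using the relative-time wait declared just below; |device| is assumed to exist, and iree_status_is_deadline_exceeded() and iree_status_ignore() are assumed helpers from the base status API.

iree_hal_semaphore_t* semaphore = NULL;
IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(
    device, /*initial_value=*/0ull, iree_allocator_system(), &semaphore));

// Some executor signals the timeline as work retires; values only advance.
IREE_RETURN_IF_ERROR(iree_hal_semaphore_signal(semaphore, 1ull));

// Wait up to 1ms for timepoint 2.
iree_status_t wait_status = iree_hal_semaphore_wait_with_timeout(
    semaphore, /*value=*/2ull, /*timeout_ns=*/1000000);
if (iree_status_is_deadline_exceeded(wait_status)) {
  // Timed out; see how far the timeline actually advanced.
  uint64_t current_value = 0;
  IREE_RETURN_IF_ERROR(iree_hal_semaphore_query(semaphore, &current_value));
}
iree_status_ignore(wait_status);
iree_hal_semaphore_release(semaphore);

-// Blocks the caller until the semaphore reaches or exceeds the specified
-// payload value or the |timeout_ns| elapses.
-// A relative-time version of iree_hal_semaphore_wait_with_deadline using the
-// relative nanoseconds from the time the call is made.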
-IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_semaphore_wait_with_timeout(iree_hal_semaphore_t* semaphore, - uint64_t value, - iree_duration_t timeout_ns); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus +#include "iree/hal/allocator.h" +#include "iree/hal/buffer.h" +#include "iree/hal/buffer_view.h" +#include "iree/hal/command_buffer.h" +#include "iree/hal/descriptor_set.h" +#include "iree/hal/descriptor_set_layout.h" +#include "iree/hal/device.h" +#include "iree/hal/driver.h" +#include "iree/hal/driver_registry.h" +#include "iree/hal/event.h" +#include "iree/hal/executable.h" +#include "iree/hal/executable_cache.h" +#include "iree/hal/executable_layout.h" +#include "iree/hal/semaphore.h" +#include "iree/hal/string_util.h" #endif // IREE_HAL_API_H_ diff --git a/iree/hal/api_detail.h b/iree/hal/api_detail.h deleted file mode 100644 index 480ad13d9de94..0000000000000 --- a/iree/hal/api_detail.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Additional definitions for internal users of the api. This should only -// be included from internal implementation files. - -#ifndef IREE_HAL_API_DETAIL_H_ -#define IREE_HAL_API_DETAIL_H_ - -#include "iree/base/ref_ptr.h" -#include "iree/hal/api.h" - -// In the API, buffer views are ref objects, and this allows parts of the -// API outside of the HAL to work with them (such as the HAL module). -struct iree_hal_buffer_view final - : public iree::RefObject { - iree_allocator_t allocator; - iree_hal_buffer_t* buffer = nullptr; - iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE; - iree_device_size_t byte_length = 0; - iree_host_size_t shape_rank = 0; - iree_hal_dim_t shape[]; - - static void Delete(iree_hal_buffer_view* ptr) { - iree_hal_buffer_release(ptr->buffer); - iree_allocator_free(ptr->allocator, ptr); - } -}; - -#endif diff --git a/iree/hal/buffer.c b/iree/hal/buffer.c new file mode 100644 index 0000000000000..8219c125cf8b6 --- /dev/null +++ b/iree/hal/buffer.c @@ -0,0 +1,721 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "iree/hal/buffer.h"
+
+#include "iree/base/tracing.h"
+#include "iree/hal/allocator.h"
+#include "iree/hal/detail.h"
+
+#define _VTABLE_DISPATCH(buffer, method_name) \
+  IREE_HAL_VTABLE_DISPATCH(buffer, iree_hal_buffer, method_name)
+
+//===----------------------------------------------------------------------===//
+// Subspan indirection buffer
+//===----------------------------------------------------------------------===//
+
+static const iree_hal_buffer_vtable_t iree_hal_subspan_buffer_vtable;
+
+static iree_status_t iree_hal_subspan_buffer_create(
+    iree_hal_buffer_t* allocated_buffer, iree_device_size_t byte_offset,
+    iree_device_size_t byte_length, iree_hal_buffer_t** out_buffer) {
+  IREE_ASSERT_ARGUMENT(allocated_buffer);
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_buffer_t* buffer = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      iree_hal_allocator_host_allocator(allocated_buffer->allocator),
+      sizeof(*buffer), (void**)&buffer);
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_initialize(&iree_hal_subspan_buffer_vtable,
+                                 &buffer->resource);
+    buffer->allocator = allocated_buffer->allocator;
+    buffer->allocated_buffer = allocated_buffer;
+    iree_hal_buffer_retain(buffer->allocated_buffer);
+    buffer->allocation_size = allocated_buffer->allocation_size;
+    buffer->byte_offset = byte_offset;
+    buffer->byte_length = byte_length;
+    buffer->memory_type = allocated_buffer->memory_type;
+    buffer->allowed_access = allocated_buffer->allowed_access;
+    buffer->allowed_usage = allocated_buffer->allowed_usage;
+    *out_buffer = buffer;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  // Propagate the allocation result; returning OK unconditionally here would
+  // hide a failed iree_allocator_malloc from the caller.
+  return status;
+}
+
+static void iree_hal_subspan_buffer_destroy(iree_hal_buffer_t* base_buffer) {
+  iree_allocator_t host_allocator = iree_hal_allocator_host_allocator(
+      iree_hal_buffer_allocator(base_buffer));
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_buffer_release(base_buffer->allocated_buffer);
+  iree_allocator_free(host_allocator, base_buffer);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_subspan_buffer_map_range(
+    iree_hal_buffer_t* buffer, iree_hal_mapping_mode_t mapping_mode,
+    iree_hal_memory_access_t memory_access,
+    iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
+    void** out_data_ptr) {
+  return _VTABLE_DISPATCH(buffer->allocated_buffer, map_range)(
+      buffer->allocated_buffer, mapping_mode, memory_access, local_byte_offset,
+      local_byte_length, out_data_ptr);
+}
+
+static void iree_hal_subspan_buffer_unmap_range(
+    iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length, void* data_ptr) {
+  _VTABLE_DISPATCH(buffer->allocated_buffer, unmap_range)
+  (buffer->allocated_buffer, local_byte_offset, local_byte_length, data_ptr);
+}
+
+static iree_status_t iree_hal_subspan_buffer_invalidate_range(
+    iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length) {
+  return _VTABLE_DISPATCH(buffer->allocated_buffer, invalidate_range)(
+      buffer->allocated_buffer, local_byte_offset, local_byte_length);
+}
+
+static iree_status_t iree_hal_subspan_buffer_flush_range(
+    iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length) {
+  return _VTABLE_DISPATCH(buffer->allocated_buffer, flush_range)(
+      buffer->allocated_buffer, local_byte_offset, local_byte_length);
+}
+
+static const iree_hal_buffer_vtable_t iree_hal_subspan_buffer_vtable = {
+    .destroy = iree_hal_subspan_buffer_destroy,
+
.map_range = iree_hal_subspan_buffer_map_range, + .unmap_range = iree_hal_subspan_buffer_unmap_range, + .invalidate_range = iree_hal_subspan_buffer_invalidate_range, + .flush_range = iree_hal_subspan_buffer_flush_range, +}; + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_t +//===----------------------------------------------------------------------===// + +IREE_HAL_API_RETAIN_RELEASE(buffer); + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_validate_memory_type( + iree_hal_memory_type_t actual_memory_type, + iree_hal_memory_type_t expected_memory_type) { + if (IREE_UNLIKELY( + !iree_all_bits_set(actual_memory_type, expected_memory_type))) { + // Missing one or more bits. + return iree_make_status( + IREE_STATUS_PERMISSION_DENIED, + "buffer memory type is not compatible with the requested operation; " + "buffer has %s, operation requires %s", + iree_hal_memory_type_string(actual_memory_type), + iree_hal_memory_type_string(expected_memory_type)); + } + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_validate_access( + iree_hal_memory_access_t allowed_memory_access, + iree_hal_memory_access_t required_memory_access) { + if (IREE_UNLIKELY(!iree_any_bit_set( + required_memory_access, + IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE))) { + // No actual access bits defined. + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "memory access must specify one or more of _READ or _WRITE"); + } else if (IREE_UNLIKELY(!iree_all_bits_set(allowed_memory_access, + required_memory_access))) { + // Bits must match exactly. + return iree_make_status( + IREE_STATUS_PERMISSION_DENIED, + "buffer does not support the requested access " + "type; buffer allows %s, operation requires %s", + iree_hal_memory_access_string(allowed_memory_access), + iree_hal_memory_access_string(required_memory_access)); + } + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_validate_usage(iree_hal_buffer_usage_t allowed_usage, + iree_hal_buffer_usage_t required_usage) { + if (IREE_UNLIKELY(!iree_all_bits_set(allowed_usage, required_usage))) { + // Missing one or more bits. + return iree_make_status( + IREE_STATUS_PERMISSION_DENIED, + "requested usage was not specified when the buffer was allocated; " + "buffer allows %s, operation requires %s", + iree_hal_buffer_usage_string(allowed_usage), + iree_hal_buffer_usage_string(required_usage)); + } + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_validate_range( + iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length) { + // Check if the start of the range runs off the end of the buffer. + if (IREE_UNLIKELY(byte_offset > iree_hal_buffer_byte_length(buffer))) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "attempted to access an address off the end of the valid buffer range " + "(offset=%zu, length=%zu, buffer byte_length=%zu)", + byte_offset, byte_length, iree_hal_buffer_byte_length(buffer)); + } + + if (byte_length == 0) { + // Fine to have a zero length. + return iree_ok_status(); + } + + // Check if the end runs over the allocation. 
+ iree_device_size_t end = byte_offset + byte_length; + if (IREE_UNLIKELY(end > iree_hal_buffer_byte_length(buffer))) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "attempted to access an address outside of the valid buffer range " + "(offset=%zu, length=%zu, end(inc)=%zu, buffer byte_length=%zu)", + byte_offset, byte_length, end - 1, iree_hal_buffer_byte_length(buffer)); + } + + return iree_ok_status(); +} + +static iree_status_t iree_hal_buffer_calculate_range( + iree_device_size_t base_offset, iree_device_size_t max_length, + iree_device_size_t offset, iree_device_size_t length, + iree_device_size_t* out_adjusted_offset, + iree_device_size_t* out_adjusted_length) { + // Check if the start of the range runs off the end of the buffer. + if (IREE_UNLIKELY(offset > max_length)) { + *out_adjusted_offset = 0; + if (out_adjusted_length) *out_adjusted_length = 0; + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "attempted to access an address off the end of the valid buffer " + "range (offset=%zu, length=%zu, buffer byte_length=%zu)", + offset, length, max_length); + } + + // Handle length as IREE_WHOLE_BUFFER by adjusting it (if allowed). + if (IREE_UNLIKELY(length == IREE_WHOLE_BUFFER) && + IREE_UNLIKELY(!out_adjusted_length)) { + *out_adjusted_offset = 0; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "IREE_WHOLE_BUFFER may only be used with buffer " + "ranges, not external pointer ranges"); + } + + // Calculate the real ranges adjusted for our region within the allocation. + iree_device_size_t adjusted_offset = base_offset + offset; + iree_device_size_t adjusted_length = + length == IREE_WHOLE_BUFFER ? max_length - offset : length; + if (adjusted_length == 0) { + // Fine to have a zero length. + *out_adjusted_offset = adjusted_offset; + if (out_adjusted_length) *out_adjusted_length = adjusted_length; + return iree_ok_status(); + } + + // Check if the end runs over the allocation. + iree_device_size_t end = offset + adjusted_length - 1; + if (IREE_UNLIKELY(end >= max_length)) { + *out_adjusted_offset = 0; + if (out_adjusted_length) *out_adjusted_length = 0; + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "attempted to access an address outside of the valid buffer " + "range (offset=%zu, adjusted_length=%zu, end=%zu, buffer " + "byte_length=%zu)", + offset, adjusted_length, end, max_length); + } + + *out_adjusted_offset = adjusted_offset; + if (out_adjusted_length) *out_adjusted_length = adjusted_length; + return iree_ok_status(); +} + +IREE_API_EXPORT iree_hal_buffer_overlap_t IREE_API_CALL +iree_hal_buffer_test_overlap(iree_hal_buffer_t* lhs_buffer, + iree_device_size_t lhs_offset, + iree_device_size_t lhs_length, + iree_hal_buffer_t* rhs_buffer, + iree_device_size_t rhs_offset, + iree_device_size_t rhs_length) { + if (iree_hal_buffer_allocated_buffer(lhs_buffer) != + iree_hal_buffer_allocated_buffer(rhs_buffer)) { + // Not even the same buffers. + return IREE_HAL_BUFFER_OVERLAP_DISJOINT; + } + // Resolve offsets into the underlying allocation. + iree_device_size_t lhs_alloc_offset = + iree_hal_buffer_byte_offset(lhs_buffer) + lhs_offset; + iree_device_size_t rhs_alloc_offset = + iree_hal_buffer_byte_offset(rhs_buffer) + rhs_offset; + iree_device_size_t lhs_alloc_length = + lhs_length == IREE_WHOLE_BUFFER + ? iree_hal_buffer_byte_length(lhs_buffer) - lhs_offset + : lhs_length; + iree_device_size_t rhs_alloc_length = + rhs_length == IREE_WHOLE_BUFFER + ? 
iree_hal_buffer_byte_length(rhs_buffer) - rhs_offset + : rhs_length; + if (!lhs_alloc_length || !rhs_alloc_length) { + return IREE_HAL_BUFFER_OVERLAP_DISJOINT; + } + if (lhs_alloc_offset == rhs_alloc_offset && + lhs_alloc_length == rhs_alloc_length) { + return IREE_HAL_BUFFER_OVERLAP_COMPLETE; + } + return lhs_alloc_offset + lhs_alloc_length > rhs_alloc_offset && + rhs_alloc_offset + rhs_alloc_length > lhs_alloc_offset + ? IREE_HAL_BUFFER_OVERLAP_PARTIAL + : IREE_HAL_BUFFER_OVERLAP_DISJOINT; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_subspan( + iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length, iree_hal_buffer_t** out_buffer) { + IREE_ASSERT_ARGUMENT(buffer); + IREE_ASSERT_ARGUMENT(out_buffer); + *out_buffer = NULL; + + // Fast path: if we are requesting the whole buffer (usually via + // IREE_WHOLE_BUFFER) then we can just return the buffer itself. + IREE_RETURN_IF_ERROR(iree_hal_buffer_calculate_range( + iree_hal_buffer_byte_offset(buffer), iree_hal_buffer_byte_length(buffer), + byte_offset, byte_length, &byte_offset, &byte_length)); + if (byte_offset == 0 && byte_length == iree_hal_buffer_byte_length(buffer)) { + iree_hal_buffer_retain(buffer); + *out_buffer = buffer; + return iree_ok_status(); + } + + // To avoid heavy nesting of subspans that just add indirection we go to the + // parent buffer directly. If we wanted better accounting (to track where + // buffers came from) we'd want to avoid this but I'm not sure that's worth + // the super deep indirection that could arise. + iree_hal_buffer_t* allocated_buffer = + iree_hal_buffer_allocated_buffer(buffer); + if (allocated_buffer != buffer) { + return iree_hal_buffer_subspan(allocated_buffer, byte_offset, byte_length, + out_buffer); + } + + return iree_hal_subspan_buffer_create(buffer, byte_offset, byte_length, + out_buffer); +} + +IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL +iree_hal_buffer_allocator(const iree_hal_buffer_t* buffer) { + IREE_ASSERT_ARGUMENT(buffer); + return buffer->allocator; +} + +IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL +iree_hal_buffer_allocated_buffer(const iree_hal_buffer_t* buffer) { + IREE_ASSERT_ARGUMENT(buffer); + return buffer->allocated_buffer; +} + +IREE_API_EXPORT iree_device_size_t IREE_API_CALL +iree_hal_buffer_allocation_size(const iree_hal_buffer_t* buffer) { + IREE_ASSERT_ARGUMENT(buffer); + return buffer->allocation_size; +} + +IREE_API_EXPORT iree_device_size_t IREE_API_CALL +iree_hal_buffer_byte_offset(const iree_hal_buffer_t* buffer) { + IREE_ASSERT_ARGUMENT(buffer); + return buffer->byte_offset; +} + +IREE_API_EXPORT iree_device_size_t IREE_API_CALL +iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer) { + IREE_ASSERT_ARGUMENT(buffer); + return buffer->byte_length; +} + +IREE_API_EXPORT +iree_hal_memory_type_t IREE_API_CALL +iree_hal_buffer_memory_type(const iree_hal_buffer_t* buffer) { + IREE_ASSERT_ARGUMENT(buffer); + return buffer->memory_type; +} + +IREE_API_EXPORT +iree_hal_memory_access_t IREE_API_CALL +iree_hal_buffer_allowed_access(const iree_hal_buffer_t* buffer) { + IREE_ASSERT_ARGUMENT(buffer); + return buffer->allowed_access; +} + +IREE_API_EXPORT +iree_hal_buffer_usage_t IREE_API_CALL +iree_hal_buffer_allowed_usage(const iree_hal_buffer_t* buffer) { + IREE_ASSERT_ARGUMENT(buffer); + return buffer->allowed_usage; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length) { + const uint8_t 
zero = 0; + return iree_hal_buffer_fill(buffer, byte_offset, byte_length, &zero, 1); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length, const void* pattern, + iree_host_size_t pattern_length) { + IREE_ASSERT_ARGUMENT(buffer); + IREE_ASSERT_ARGUMENT(pattern); + + if (IREE_UNLIKELY(pattern_length != 1 && pattern_length != 2 && + pattern_length != 4)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "fill patterns must be 1, 2, or 4 bytes (got %zu)", + pattern_length); + } + + if (byte_length == 0) { + return iree_ok_status(); // No-op. + } + + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_buffer_mapping_t target_mapping; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_hal_buffer_map_range(buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, + byte_offset, byte_length, &target_mapping)); + if (byte_length == IREE_WHOLE_BUFFER) { + byte_length = target_mapping.contents.data_length; + } + + if (IREE_UNLIKELY((byte_offset % pattern_length) != 0) || + IREE_UNLIKELY((byte_length % pattern_length) != 0)) { + iree_hal_buffer_unmap_range(&target_mapping); + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "attempting to fill a range with %zu byte values " + "that is not aligned (offset=%zu, length=%zu)", + pattern_length, byte_offset, byte_length); + } + + const uint32_t zero_32 = 0; + if (memcmp(pattern, &zero_32, pattern_length) == 0) { + // We can turn all-zero values into single-byte fills as that can be much + // faster on devices (doing a fill8 vs fill32). + pattern_length = 1; + } + + iree_status_t status = iree_ok_status(); + void* data_ptr = target_mapping.contents.data; + switch (pattern_length) { + case 1: { + uint8_t* data = (uint8_t*)data_ptr; + uint8_t value_bits = *(const uint8_t*)(pattern); + memset(data, value_bits, byte_length); + break; + } + case 2: { + uint16_t* data = (uint16_t*)data_ptr; + uint16_t value_bits = *(const uint16_t*)(pattern); + for (iree_device_size_t i = 0; i < byte_length / sizeof(uint16_t); ++i) { + data[i] = value_bits; + } + break; + } + case 4: { + uint32_t* data = (uint32_t*)data_ptr; + uint32_t value_bits = *(const uint32_t*)(pattern); + for (iree_device_size_t i = 0; i < byte_length / sizeof(uint32_t); ++i) { + data[i] = value_bits; + } + break; + } + default: + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "unsupported fill pattern length: %zu", + pattern_length); + break; + } + + if (iree_status_is_ok(status) && + !iree_all_bits_set(iree_hal_buffer_memory_type(buffer), + IREE_HAL_MEMORY_TYPE_HOST_COHERENT)) { + status = iree_hal_buffer_flush_range(&target_mapping, 0, IREE_WHOLE_BUFFER); + } + + iree_hal_buffer_unmap_range(&target_mapping); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_read_data( + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + void* target_buffer, iree_device_size_t data_length) { + if (data_length == 0) { + return iree_ok_status(); // No-op. 
+ } + IREE_ASSERT_ARGUMENT(source_buffer); + IREE_ASSERT_ARGUMENT(target_buffer); + + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_buffer_mapping_t source_mapping; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_hal_buffer_map_range(source_buffer, IREE_HAL_MEMORY_ACCESS_READ, + source_offset, data_length, &source_mapping)); + + memcpy(target_buffer, source_mapping.contents.data, data_length); + + iree_hal_buffer_unmap_range(&source_mapping); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_write_data( + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + const void* source_buffer, iree_device_size_t data_length) { + if (data_length == 0) { + return iree_ok_status(); // No-op. + } + IREE_ASSERT_ARGUMENT(target_buffer); + IREE_ASSERT_ARGUMENT(source_buffer); + + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_buffer_mapping_t target_mapping; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_buffer_map_range( + target_buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, + target_offset, data_length, &target_mapping)); + + memcpy(target_mapping.contents.data, source_buffer, data_length); + + iree_status_t status = iree_ok_status(); + if (!iree_all_bits_set(iree_hal_buffer_memory_type(target_buffer), + IREE_HAL_MEMORY_TYPE_HOST_COHERENT)) { + status = iree_hal_buffer_flush_range(&target_mapping, 0, IREE_WHOLE_BUFFER); + } + + iree_hal_buffer_unmap_range(&target_mapping); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_copy_data( + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t data_length) { + if (data_length == 0) { + return iree_ok_status(); // No-op. + } + IREE_ASSERT_ARGUMENT(source_buffer); + IREE_ASSERT_ARGUMENT(target_buffer); + + // Check for overlap - like memcpy we require that the two ranges don't have + // any overlap - because we use memcpy below! + if (iree_hal_buffer_test_overlap(source_buffer, source_offset, data_length, + target_buffer, target_offset, data_length) != + IREE_HAL_BUFFER_OVERLAP_DISJOINT) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "source and target ranges must not overlap within the same buffer"); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + // Map source, which may have IREE_WHOLE_BUFFER length. + iree_hal_buffer_mapping_t source_mapping; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_hal_buffer_map_range(source_buffer, IREE_HAL_MEMORY_ACCESS_READ, + source_offset, data_length, &source_mapping)); + + // Map target, which may also have IREE_WHOLE_BUFFER length. + iree_hal_buffer_mapping_t target_mapping; + iree_status_t status = iree_hal_buffer_map_range( + target_buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, target_offset, + data_length, &target_mapping); + if (!iree_status_is_ok(status)) { + iree_hal_buffer_unmap_range(&source_mapping); + IREE_TRACE_ZONE_END(z0); + return status; + } + + // Adjust the data length based on the min we have. + iree_device_size_t adjusted_data_length = 0; + if (data_length == IREE_WHOLE_BUFFER) { + // Whole buffer copy requested - that could mean either, so take the min. + adjusted_data_length = iree_min(source_mapping.contents.data_length, + target_mapping.contents.data_length); + } else { + // Specific length requested - validate that we have matching lengths. 
+    IREE_ASSERT_EQ(source_mapping.contents.data_length,
+                   target_mapping.contents.data_length);
+    adjusted_data_length = target_mapping.contents.data_length;
+  }
+
+  // Elide zero length copies, making sure both mappings are released first.
+  if (adjusted_data_length == 0) {
+    iree_hal_buffer_unmap_range(&source_mapping);
+    iree_hal_buffer_unmap_range(&target_mapping);
+    IREE_TRACE_ZONE_END(z0);
+    return iree_ok_status();
+  }
+
+  memcpy(target_mapping.contents.data, source_mapping.contents.data,
+         adjusted_data_length);
+
+  if (!iree_all_bits_set(iree_hal_buffer_memory_type(target_buffer),
+                         IREE_HAL_MEMORY_TYPE_HOST_COHERENT)) {
+    status =
+        iree_hal_buffer_flush_range(&target_mapping, 0, adjusted_data_length);
+  }
+
+  iree_hal_buffer_unmap_range(&source_mapping);
+  iree_hal_buffer_unmap_range(&target_mapping);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+//===----------------------------------------------------------------------===//
+// Mapping / iree_hal_buffer_mapping_impl_t
+//===----------------------------------------------------------------------===//
+
+typedef struct {
+  // Must be first (as in iree_hal_buffer_mapping_t).
+  // Stores both the offset data pointer and the byte_length of the mapping.
+  iree_byte_span_t contents;
+  // Buffer providing the backing storage for the mapping; not retained, so
+  // the caller must keep the buffer alive for the lifetime of the mapping.
+  iree_hal_buffer_t* backing_buffer;
+  // Byte offset within the buffer where the mapped data begins.
+  iree_device_size_t byte_offset;
+  // Used for validation only.
+  iree_hal_memory_access_t allowed_access;
+  uint32_t reserved0;  // unused
+  uint64_t reserved1;  // unused
+} iree_hal_buffer_mapping_impl_t;
+
+// We overlay the impl onto the external iree_hal_buffer_mapping_t struct;
+// ensure we match the fields that are exposed.
+static_assert(sizeof(iree_hal_buffer_mapping_impl_t) <=
+                  sizeof(iree_hal_buffer_mapping_t),
+              "buffer mapping impl must fit inside the external struct");
+static_assert(offsetof(iree_hal_buffer_mapping_impl_t, contents) ==
+                  offsetof(iree_hal_buffer_mapping_t, contents),
+              "contents byte span must match the external struct offset");
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_map_range(
+    iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access,
+    iree_device_size_t byte_offset, iree_device_size_t byte_length,
+    iree_hal_buffer_mapping_t* out_buffer_mapping) {
+  IREE_ASSERT_ARGUMENT(buffer);
+  IREE_ASSERT_ARGUMENT(out_buffer_mapping);
+  memset(out_buffer_mapping, 0, sizeof(*out_buffer_mapping));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
+      iree_hal_buffer_memory_type(buffer), IREE_HAL_MEMORY_TYPE_HOST_VISIBLE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(buffer), memory_access));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(buffer), IREE_HAL_BUFFER_USAGE_MAPPING));
+
+  iree_hal_buffer_mapping_impl_t* buffer_mapping =
+      (iree_hal_buffer_mapping_impl_t*)out_buffer_mapping;
+  buffer_mapping->backing_buffer = buffer;
+  buffer_mapping->allowed_access = memory_access;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_calculate_range(
+      iree_hal_buffer_byte_offset(buffer), iree_hal_buffer_byte_length(buffer),
+      byte_offset, byte_length, &buffer_mapping->byte_offset,
+      &buffer_mapping->contents.data_length));
+
+  // TODO(benvanik): add mode arg to the HAL API.
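+  // NOTE: only scoped mappings are exposed through this entry point today:
+  // each successful iree_hal_buffer_map_range must be balanced by an
+  // iree_hal_buffer_unmap_range on the same mapping before the buffer is
+  // destroyed.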
+ iree_hal_mapping_mode_t mapping_mode = IREE_HAL_MAPPING_MODE_SCOPED; + + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(buffer, map_range)( + buffer, mapping_mode, buffer_mapping->allowed_access, + buffer_mapping->byte_offset, buffer_mapping->contents.data_length, + (void**)&buffer_mapping->contents.data); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_unmap_range(iree_hal_buffer_mapping_t* base_buffer_mapping) { + IREE_ASSERT_ARGUMENT(base_buffer_mapping); + iree_hal_buffer_mapping_impl_t* buffer_mapping = + (iree_hal_buffer_mapping_impl_t*)base_buffer_mapping; + iree_hal_buffer_t* buffer = buffer_mapping->backing_buffer; + IREE_TRACE_ZONE_BEGIN(z0); + _VTABLE_DISPATCH(buffer, unmap_range) + (buffer, buffer_mapping->byte_offset, buffer_mapping->contents.data_length, + buffer_mapping->contents.data); + IREE_TRACE_ZONE_END(z0); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_invalidate_range( + iree_hal_buffer_mapping_t* base_buffer_mapping, + iree_device_size_t byte_offset, iree_device_size_t byte_length) { + IREE_ASSERT_ARGUMENT(base_buffer_mapping); + iree_hal_buffer_mapping_impl_t* buffer_mapping = + (iree_hal_buffer_mapping_impl_t*)base_buffer_mapping; + iree_hal_buffer_t* buffer = buffer_mapping->backing_buffer; + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access( + buffer_mapping->allowed_access, IREE_HAL_MEMORY_ACCESS_READ)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_calculate_range( + buffer_mapping->byte_offset, buffer_mapping->contents.data_length, + byte_offset, byte_length, &byte_offset, &byte_length)); + return _VTABLE_DISPATCH(buffer, invalidate_range)(buffer, byte_offset, + byte_length); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_flush_range( + iree_hal_buffer_mapping_t* base_buffer_mapping, + iree_device_size_t byte_offset, iree_device_size_t byte_length) { + IREE_ASSERT_ARGUMENT(base_buffer_mapping); + iree_hal_buffer_mapping_impl_t* buffer_mapping = + (iree_hal_buffer_mapping_impl_t*)base_buffer_mapping; + iree_hal_buffer_t* buffer = buffer_mapping->backing_buffer; + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access( + buffer_mapping->allowed_access, IREE_HAL_MEMORY_ACCESS_WRITE)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_calculate_range( + buffer_mapping->byte_offset, buffer_mapping->contents.data_length, + byte_offset, byte_length, &byte_offset, &byte_length)); + return _VTABLE_DISPATCH(buffer, flush_range)(buffer, byte_offset, + byte_length); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_mapping_subspan( + iree_hal_buffer_mapping_t* base_buffer_mapping, + iree_hal_memory_access_t memory_access, iree_device_size_t byte_offset, + iree_device_size_t byte_length, iree_byte_span_t* out_span) { + IREE_ASSERT_ARGUMENT(base_buffer_mapping); + iree_hal_buffer_mapping_impl_t* buffer_mapping = + (iree_hal_buffer_mapping_impl_t*)base_buffer_mapping; + IREE_ASSERT_ARGUMENT(out_span); + memset(out_span, 0, sizeof(*out_span)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access( + buffer_mapping->allowed_access, memory_access)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_calculate_range( + 0, buffer_mapping->contents.data_length, byte_offset, byte_length, + &byte_offset, &out_span->data_length)); + out_span->data = buffer_mapping->contents.data + byte_offset; + return iree_ok_status(); +} diff --git a/iree/hal/buffer.cc b/iree/hal/buffer.cc deleted file mode 100644 index f7989524effad..0000000000000 --- a/iree/hal/buffer.cc +++ /dev/null @@ 
-1,551 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/buffer.h" - -#include -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -#if HAS_IREE_BUFFER_DEBUG_NAME -namespace { -// Used for diagnostic purposes only as a default buffer name. -std::atomic next_buffer_id_{0}; -} // namespace -#endif // HAS_IREE_BUFFER_DEBUG_NAME - -std::string MemoryTypeString(MemoryTypeBitfield memory_type) { - return FormatBitfieldValue(memory_type, - { - // Combined: - {MemoryType::kHostLocal, "kHostLocal"}, - {MemoryType::kDeviceLocal, "kDeviceLocal"}, - // Separate: - {MemoryType::kTransient, "kTransient"}, - {MemoryType::kHostVisible, "kHostVisible"}, - {MemoryType::kHostCoherent, "kHostCoherent"}, - {MemoryType::kHostCached, "kHostCached"}, - {MemoryType::kDeviceVisible, "kDeviceVisible"}, - }); -} - -std::string MemoryAccessString(MemoryAccessBitfield memory_access) { - return FormatBitfieldValue(memory_access, - { - // Combined: - {MemoryAccess::kAll, "kAll"}, - {MemoryAccess::kDiscardWrite, "kDiscardWrite"}, - // Separate: - {MemoryAccess::kRead, "kRead"}, - {MemoryAccess::kWrite, "kWrite"}, - {MemoryAccess::kDiscard, "kDiscard"}, - {MemoryAccess::kMayAlias, "kMayAlias"}, - }); -} - -std::string BufferUsageString(BufferUsageBitfield buffer_usage) { - return FormatBitfieldValue(buffer_usage, - { - // Combined: - {BufferUsage::kAll, "kAll"}, - // Separate: - {BufferUsage::kConstant, "kConstant"}, - {BufferUsage::kTransfer, "kTransfer"}, - {BufferUsage::kMapping, "kMapping"}, - {BufferUsage::kDispatch, "kDispatch"}, - }); -} - -// Special router for buffers that just reference other buffers. -// We keep this out of the base Buffer so that it's a bit easier to track -// delegation. 
-class SubspanBuffer : public Buffer { - public: - SubspanBuffer(ref_ptr parent_buffer, device_size_t byte_offset, - device_size_t byte_length) - : Buffer(parent_buffer->allocator(), parent_buffer->memory_type(), - parent_buffer->allowed_access(), parent_buffer->usage(), - parent_buffer->allocation_size(), byte_offset, byte_length) { - allocated_buffer_ = parent_buffer.get(); - parent_buffer_ = std::move(parent_buffer); - } - - protected: - Status FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) override { - return parent_buffer_->FillImpl(byte_offset, byte_length, pattern, - pattern_length); - } - - Status ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) override { - return parent_buffer_->ReadDataImpl(source_offset, data, data_length); - } - - Status WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) override { - return parent_buffer_->WriteDataImpl(target_offset, data, data_length); - } - - Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) override { - return parent_buffer_->CopyDataImpl(target_offset, source_buffer, - source_offset, data_length); - } - - Status MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) override { - return parent_buffer_->MapMemoryImpl(mapping_mode, memory_access, - local_byte_offset, local_byte_length, - out_data); - } - - Status UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) override { - return parent_buffer_->UnmapMemoryImpl(local_byte_offset, local_byte_length, - data); - } - - Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override { - return parent_buffer_->InvalidateMappedMemoryImpl(local_byte_offset, - local_byte_length); - } - - Status FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override { - return parent_buffer_->FlushMappedMemoryImpl(local_byte_offset, - local_byte_length); - } -}; - -// static -StatusOr> Buffer::Subspan(const ref_ptr& buffer, - device_size_t byte_offset, - device_size_t byte_length) { - IREE_RETURN_IF_ERROR(buffer->CalculateRange(byte_offset, byte_length, - &byte_offset, &byte_length)); - if (byte_offset == 0 && byte_length == buffer->byte_length()) { - // Asking for the same buffer. - return add_ref(buffer); - } - - // To avoid heavy nesting of subspans that just add indirection we go to the - // parent buffer directly. If we wanted better accounting (to track where - // buffers came from) we'd want to avoid this but I'm not sure that's worth - // the super deep indirection that could arise. - if (buffer->allocated_buffer() != buffer.get()) { - IREE_CHECK(buffer->parent_buffer_); - return Buffer::Subspan(buffer->parent_buffer_, byte_offset, byte_length); - } else { - return {make_ref(add_ref(buffer), byte_offset, byte_length)}; - } -} - -// static -Buffer::Overlap Buffer::TestOverlap( - Buffer* lhs_buffer, device_size_t lhs_offset, device_size_t lhs_length, - Buffer* rhs_buffer, device_size_t rhs_offset, device_size_t rhs_length) { - if (lhs_buffer->allocated_buffer() != rhs_buffer->allocated_buffer()) { - // Not even the same buffers. - return Overlap::kDisjoint; - } - // Resolve offsets into the underlying allocation. 
- device_size_t lhs_alloc_offset = lhs_buffer->byte_offset() + lhs_offset; - device_size_t rhs_alloc_offset = rhs_buffer->byte_offset() + rhs_offset; - device_size_t lhs_alloc_length = lhs_length == kWholeBuffer - ? lhs_buffer->byte_length() - lhs_offset - : lhs_length; - device_size_t rhs_alloc_length = rhs_length == kWholeBuffer - ? rhs_buffer->byte_length() - rhs_offset - : rhs_length; - if (!lhs_alloc_length || !rhs_alloc_length) { - return Overlap::kDisjoint; - } - if (lhs_alloc_offset == rhs_alloc_offset && - lhs_alloc_length == rhs_alloc_length) { - return Overlap::kComplete; - } - return lhs_alloc_offset + lhs_alloc_length > rhs_alloc_offset && - rhs_alloc_offset + rhs_alloc_length > lhs_alloc_offset - ? Overlap::kPartial - : Overlap::kDisjoint; -} - -// static -bool Buffer::DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset, - device_size_t lhs_length, Buffer* rhs_buffer, - device_size_t rhs_offset, device_size_t rhs_length) { - return TestOverlap(lhs_buffer, lhs_offset, lhs_length, rhs_buffer, rhs_offset, - rhs_length) != Overlap::kDisjoint; -} - -Buffer::Buffer(Allocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size, device_size_t byte_offset, - device_size_t byte_length) - : allocated_buffer_(const_cast(this)), - allocator_(allocator), - memory_type_(memory_type), - allowed_access_(allowed_access), - usage_(usage), - allocation_size_(allocation_size), - byte_offset_(byte_offset), - byte_length_(byte_length) { -#if HAS_IREE_BUFFER_DEBUG_NAME - // Default name for logging. - // It'd be nice to defer this until it's required but that would require - // synchronization or something. - const char* debug_name_prefix = ""; - if ((memory_type_ & MemoryType::kHostLocal) == MemoryType::kHostLocal) { - debug_name_prefix = "host_buffer_"; - } else if ((memory_type_ & MemoryType::kDeviceLocal) == - MemoryType::kDeviceLocal) { - // TODO(benvanik): include allocator ID to differentiate devices. - debug_name_prefix = "device_buffer_"; - } - debug_name_ = absl::StrCat(debug_name_prefix, next_buffer_id_++); -#endif // HAS_IREE_BUFFER_DEBUG_NAME -} - -Buffer* Buffer::allocated_buffer() const noexcept { - Buffer* allocated_buffer = allocated_buffer_; - while (allocated_buffer != this && - allocated_buffer != allocated_buffer->allocated_buffer()) { - allocated_buffer = allocated_buffer->allocated_buffer(); - } - return allocated_buffer; -} - -std::string Buffer::DebugString() const { - std::ostringstream stream; - stream << allocated_buffer()->debug_name() << "[" - << (allocation_size() == kWholeBuffer - ? "?" 
- : std::to_string(allocation_size())) - << "]."; - if (AnyBitSet(memory_type() & MemoryType::kTransient)) stream << "Z"; - if ((memory_type() & MemoryType::kHostLocal) == MemoryType::kHostLocal) { - stream << "h"; - } else { - if (AnyBitSet(memory_type() & MemoryType::kHostVisible)) stream << "v"; - if (AnyBitSet(memory_type() & MemoryType::kHostCoherent)) stream << "x"; - if (AnyBitSet(memory_type() & MemoryType::kHostCached)) stream << "c"; - } - if ((memory_type() & MemoryType::kDeviceLocal) == MemoryType::kDeviceLocal) { - stream << "D"; - } else { - if (AnyBitSet(memory_type() & MemoryType::kDeviceVisible)) stream << "V"; - } - stream << "."; - if (AnyBitSet(usage() & BufferUsage::kConstant)) stream << "c"; - if (AnyBitSet(usage() & BufferUsage::kTransfer)) stream << "t"; - if (AnyBitSet(usage() & BufferUsage::kMapping)) stream << "m"; - if (AnyBitSet(usage() & BufferUsage::kDispatch)) stream << "d"; - if (byte_offset_ || byte_length_ != allocation_size_) { - stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1) - << ")"; - } - return stream.str(); -} - -std::string Buffer::DebugStringShort() const { - // TODO(benvanik): figure out what's most useful here. Maybe a long variant? - std::ostringstream stream; - stream << allocated_buffer()->debug_name() << "[" - << (allocation_size() == kWholeBuffer - ? "?" - : std::to_string(allocation_size())) - << "]"; - if (byte_offset_ || byte_length_ != allocation_size_) { - stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1) - << ")"; - } - return stream.str(); -} - -Status Buffer::ValidateCompatibleMemoryType( - MemoryTypeBitfield memory_type) const { - if ((memory_type_ & memory_type) != memory_type) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer memory type is not compatible with the requested " - "operation; buffer has " - << MemoryTypeString(memory_type_) << ", operation requires " - << MemoryTypeString(memory_type); - } - return OkStatus(); -} - -Status Buffer::ValidateAccess(MemoryAccessBitfield memory_access) const { - if (!AnyBitSet(memory_access & - (MemoryAccess::kRead | MemoryAccess::kWrite))) { - // No actual access bits defined. - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Memory access must specify one or more of kRead or kWrite"; - } else if ((allowed_access_ & memory_access) != memory_access) { - // Bits must match exactly. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "The buffer does not support the requested access type; buffer " - "allows " - << MemoryAccessString(allowed_access_) << ", operation requires " - << MemoryAccessString(memory_access); - } - return OkStatus(); -} - -Status Buffer::ValidateUsage(BufferUsageBitfield usage) const { - if ((usage_ & usage) != usage) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Requested usage was not specified when the buffer was " - "allocated; buffer allows " - << BufferUsageString(usage_) << ", operation requires " - << BufferUsageString(usage); - } - return OkStatus(); -} - -Status Buffer::CalculateRange(device_size_t base_offset, - device_size_t max_length, device_size_t offset, - device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length) { - // Check if the start of the range runs off the end of the buffer. 
- if (offset > max_length) { - *out_adjusted_offset = 0; - if (out_adjusted_length) *out_adjusted_length = 0; - return OutOfRangeErrorBuilder(IREE_LOC) - << "Attempted to access an address off the end of the valid buffer " - "range (offset=" - << offset << ", length=" << length - << ", buffer byte_length=" << max_length << ")"; - } - - // Handle length as kWholeBuffer by adjusting it (if allowed). - if (length == kWholeBuffer && !out_adjusted_length) { - *out_adjusted_offset = 0; - return InvalidArgumentErrorBuilder(IREE_LOC) - << "kWholeBuffer may only be used with buffer ranges, not external " - "pointer ranges"; - } - - // Calculate the real ranges adjusted for our region within the allocation. - device_size_t adjusted_offset = base_offset + offset; - device_size_t adjusted_length = - length == kWholeBuffer ? max_length - offset : length; - if (adjusted_length == 0) { - // Fine to have a zero length. - *out_adjusted_offset = adjusted_offset; - if (out_adjusted_length) *out_adjusted_length = adjusted_length; - return OkStatus(); - } - - // Check if the end runs over the allocation. - device_size_t end = offset + adjusted_length - 1; - if (end >= max_length) { - *out_adjusted_offset = 0; - if (out_adjusted_length) *out_adjusted_length = 0; - return OutOfRangeErrorBuilder(IREE_LOC) - << "Attempted to access an address outside of the valid buffer " - "range (offset=" - << offset << ", adjusted_length=" << adjusted_length - << ", end=" << end << ", buffer byte_length=" << max_length << ")"; - } - - *out_adjusted_offset = adjusted_offset; - if (out_adjusted_length) *out_adjusted_length = adjusted_length; - return OkStatus(); -} - -Status Buffer::CalculateRange(device_size_t offset, device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length) const { - return CalculateRange(byte_offset_, byte_length_, offset, length, - out_adjusted_offset, out_adjusted_length); -} - -Status Buffer::CalculateLocalRange(device_size_t max_length, - device_size_t offset, device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length) { - return CalculateRange(0, max_length, offset, length, out_adjusted_offset, - out_adjusted_length); -} - -Status Buffer::Fill(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) { - // If not host visible we'll need to issue command buffers. - IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - IREE_RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); - IREE_RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - IREE_RETURN_IF_ERROR( - CalculateRange(byte_offset, byte_length, &byte_offset, &byte_length)); - if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Fill patterns must be 1, 2, or 4 bytes"; - } - if ((byte_offset % pattern_length) != 0 || - (byte_length % pattern_length) != 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Attempting to fill a range with " << pattern_length - << " byte values that is not " - "aligned (offset=" - << byte_offset << ", length=" << byte_length << ")"; - } - if (byte_length == 0) { - return OkStatus(); // No-op. - } - const uint32_t kZero = 0; - if (std::memcmp(pattern, &kZero, pattern_length) == 0) { - // We can turn all-zero values into single-byte fills as that can be much - // faster on devices (doing a fill8 vs fill32). 
- pattern_length = 1; - } - return FillImpl(byte_offset, byte_length, pattern, pattern_length); -} - -Status Buffer::ReadData(device_size_t source_offset, void* data, - device_size_t data_length) { - // If not host visible we'll need to issue command buffers. - IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - IREE_RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead)); - IREE_RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - IREE_RETURN_IF_ERROR( - CalculateRange(source_offset, data_length, &source_offset)); - if (data_length == 0) { - return OkStatus(); // No-op. - } - return ReadDataImpl(source_offset, data, data_length); -} - -Status Buffer::WriteData(device_size_t target_offset, const void* data, - device_size_t data_length) { - // If not host visible we'll need to issue command buffers. - IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - IREE_RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); - IREE_RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - IREE_RETURN_IF_ERROR( - CalculateRange(target_offset, data_length, &target_offset)); - if (data_length == 0) { - return OkStatus(); // No-op. - } - return WriteDataImpl(target_offset, data, data_length); -} - -Status Buffer::CopyData(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) { - IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - IREE_RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); - IREE_RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - IREE_RETURN_IF_ERROR( - source_buffer->ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - IREE_RETURN_IF_ERROR(source_buffer->ValidateAccess(MemoryAccess::kRead)); - IREE_RETURN_IF_ERROR(source_buffer->ValidateUsage(BufferUsage::kMapping)); - - // We need to validate both buffers. - device_size_t source_data_length = data_length; - device_size_t target_data_length = data_length; - device_size_t adjusted_source_offset; - IREE_RETURN_IF_ERROR(source_buffer->CalculateRange( - source_offset, source_data_length, &adjusted_source_offset, - &source_data_length)); - IREE_RETURN_IF_ERROR(CalculateRange(target_offset, target_data_length, - &target_offset, &target_data_length)); - device_size_t adjusted_data_length; - if (data_length == kWholeBuffer) { - // Whole buffer copy requested - that could mean either, so take the min. - adjusted_data_length = std::min(source_data_length, target_data_length); - } else { - // Specific length requested - validate that we have matching lengths. - IREE_CHECK_EQ(source_data_length, target_data_length); - adjusted_data_length = source_data_length; - } - - // Elide zero length copies. - if (adjusted_data_length == 0) { - return OkStatus(); - } - - // Check for overlap. 
- if (this == source_buffer && - adjusted_source_offset <= target_offset + adjusted_data_length && - target_offset <= adjusted_source_offset + adjusted_data_length) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Source and target ranges overlap within the same buffer"; - } - - return CopyDataImpl(target_offset, source_buffer, source_offset, - adjusted_data_length); -} - -Status Buffer::MapMemory(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t* byte_offset, device_size_t* byte_length, - void** out_data) { - IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - IREE_RETURN_IF_ERROR(ValidateAccess(memory_access)); - IREE_RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - IREE_RETURN_IF_ERROR( - CalculateRange(*byte_offset, *byte_length, byte_offset, byte_length)); - *out_data = nullptr; - return MapMemoryImpl(mapping_mode, memory_access, *byte_offset, *byte_length, - out_data); -} - -Status Buffer::UnmapMemory(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) { - IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - IREE_RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - // NOTE: local_byte_offset/local_byte_length are already adjusted. - return UnmapMemoryImpl(local_byte_offset, local_byte_length, data); -} - -Status Buffer::InvalidateMappedMemory(device_size_t local_byte_offset, - device_size_t local_byte_length) { - IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - if (AnyBitSet(memory_type_ & MemoryType::kHostCoherent)) { - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer memory type is coherent and invalidation is not required"; - } - IREE_RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - // NOTE: local_byte_offset/local_byte_length are already adjusted. - return InvalidateMappedMemoryImpl(local_byte_offset, local_byte_length); -} - -Status Buffer::FlushMappedMemory(device_size_t local_byte_offset, - device_size_t local_byte_length) { - IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible | - MemoryType::kHostCached)); - IREE_RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - // NOTE: local_byte_offset/local_byte_length are already adjusted. - return FlushMappedMemoryImpl(local_byte_offset, local_byte_length); -} - -} // namespace hal -} // namespace iree diff --git a/iree/hal/buffer.h b/iree/hal/buffer.h index a7e8c01c5d711..a49de94c6183e 100644 --- a/iree/hal/buffer.h +++ b/iree/hal/buffer.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,107 +12,50 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Allocated memory buffer wrapper type and utilities. -// -// Buffers are the basic unit of memory used by the inference system. They may -// be allocated such that they are accessible from the host (normal C++ code -// running on the main CPU), a particular device (such as an accelerator) or -// family of devices, or from some mix of all of those. -// -// The type of memory a buffer is allocated within has implications on it's -// performance and lifetime. 
For example if an application attempts to use a -// host-allocated buffer (MemoryType::kHostLocal) on an accelerator with -// discrete memory the accelerator may either be unable to access the memory or -// take a non-trivial performance hit when attempting to do so (involving -// setting up kernel mappings, doing DMA transfers, etc). Likewise, trying to -// access a device-allocated buffer (MemoryType::kDeviceLocal) may incur similar -// overhead or not be possible at all. This may be due to restrictions in the -// memory visibility, address spaces, mixed endianness or pointer widths, -// and other weirdness. -// -// The memory types (defined by a bitfield of MemoryType values) that a -// particular context (host or device) may use vary from device to device and -// must be queried by the application when allocating buffers. It's strongly -// recommended that the most specific memory type be set as possible. For -// example allocating a buffer with MemoryType::kHostCoherent even when it will -// never be used in a way that requires coherency may occupy address space -// reservations or memory mapping that would otherwise not be needed. -// -// As buffers may sometimes not be accessible from the host the base Buffer type -// does not allow for direct void* access and instead buffers must be either -// manipulated using utility functions (such as ReadData or WriteData) or by -// mapping them into a host-accessible address space via MapMemory. Buffer must -// be unmapped before any command may use it. -// -// Buffers may map (roughly) 1:1 with an allocation either from the host heap or -// a device. Buffer::Subspan can be used to reference subspans of buffers like -// absl::Span - though unlike absl::Span the returned Buffer holds a reference -// to the parent buffer. - #ifndef IREE_HAL_BUFFER_H_ #define IREE_HAL_BUFFER_H_ -#include -#include -#include -#include -#include +#include +#include -#include "absl/types/span.h" -#include "iree/base/bitfield.h" -#include "iree/base/logging.h" -#include "iree/base/status.h" +#include "iree/base/api.h" #include "iree/hal/resource.h" -// Only enable debug names in non-opt modes (unless the user forces it on). -#if !defined(NDEBUG) && !defined(HAS_IREE_BUFFER_DEBUG_NAME) -#define HAS_IREE_BUFFER_DEBUG_NAME 1 -#endif // !NDEBUG - -namespace iree { - -// std::size_t equivalent that is the size as used on device. -// As the device may have a larger memory address space than the host we treat -// all byte offsets as this type instead of the host-specified size_t. -using device_size_t = uint64_t; - -// When used as a length value in functions causes the length to be the entire -// remaining buffer from the specified offset. -constexpr device_size_t kWholeBuffer = ~0ull; +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -} // namespace iree +typedef struct iree_hal_allocator_s iree_hal_allocator_t; -namespace iree { -namespace hal { - -class Allocator; -template -class MappedMemory; +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// // A bitfield specifying properties for a memory type. -enum class MemoryType : uint32_t { - kNone = 0, +enum iree_hal_memory_type_e { + IREE_HAL_MEMORY_TYPE_NONE = 0u, // Memory is lazily allocated by the device and only exists transiently. // This is the optimal mode for memory used only within a single command - // buffer. 
Transient buffers, even if they have kHostVisible set, should be - treated as device-local and opaque as they may have no memory attached to - them outside of the time they are being evaluated on devices. + // buffer. Transient buffers, even if they have + // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE set, should be treated as device-local + // and opaque as they may have no memory attached to them outside of the time + // they are being evaluated on devices. // // This flag can be treated as a hint in most cases; allocating a buffer with // it set _may_ return the same as if it had not been set. Certain allocation // routines may use the hint to more tightly control reuse or defer wiring the // memory. - kTransient = 1 << 0, + IREE_HAL_MEMORY_TYPE_TRANSIENT = 1u << 0, // Memory allocated with this type can be mapped for host access using - // Buffer::MapMemory. - kHostVisible = 1 << 1, + // iree_hal_buffer_map_range. + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE = 1u << 1, // The host cache management commands iree_hal_buffer_flush_range and // iree_hal_buffer_invalidate_range are not needed to flush host writes // to the device or make device writes visible to the host, respectively. - kHostCoherent = 1 << 2, + IREE_HAL_MEMORY_TYPE_HOST_COHERENT = 1u << 2, // Memory allocated with this type is cached on the host. Host memory // accesses to uncached memory are slower than to cached memory, however @@ -120,85 +63,81 @@ enum class MemoryType : uint32_t { // to ensure the device has visibility into any changes made on the host and // Invalidate must be used to ensure the host has visibility into any changes // made on the device. - kHostCached = 1 << 3, + IREE_HAL_MEMORY_TYPE_HOST_CACHED = 1u << 3, // Memory is accessible as normal host allocated memory. - kHostLocal = kHostVisible | kHostCoherent, + IREE_HAL_MEMORY_TYPE_HOST_LOCAL = + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_COHERENT, // Memory allocated with this type is visible to the device for execution. - // Being device visible does not mean the same thing as kDeviceLocal. Though - // an allocation may be visible to the device and therefore useable for - // execution it may require expensive mapping or implicit transfers. - kDeviceVisible = 1 << 4, + // Being device visible does not mean the same thing as + // IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL. Though an allocation may be visible to + // the device and therefore usable for execution, it may require expensive + // mapping or implicit transfers. + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE = 1u << 4, // Memory allocated with this type is the most efficient for device access. // Devices may support using memory that is not device local via - // kDeviceVisible but doing so can incur non-trivial performance penalties. - // Device local memory, on the other hand, is guaranteed to be fast for all - // operations. - kDeviceLocal = kDeviceVisible | (1 << 5), + // IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE but doing so can incur non-trivial + // performance penalties. Device local memory, on the other hand, is + // guaranteed to be fast for all operations. + IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL = + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | (1u << 5), }; -IREE_BITFIELD(MemoryType); -using MemoryTypeBitfield = MemoryType; -std::string MemoryTypeString(MemoryTypeBitfield memory_type); +typedef uint32_t iree_hal_memory_type_t; // A bitfield specifying how memory will be accessed in a mapped memory region. -enum class MemoryAccess : uint32_t { +enum iree_hal_memory_access_e { // Memory is not mapped.
- kNone = 0, - + IREE_HAL_MEMORY_ACCESS_NONE = 0u, // Memory will be read. // If a buffer is only mapped for reading it may still be possible to write to // it but the results will be undefined (as it may present coherency issues). - kRead = 1 << 0, - + IREE_HAL_MEMORY_ACCESS_READ = 1u << 0, // Memory will be written. // If a buffer is only mapped for writing it may still be possible to read // from it but the results will be undefined or incredibly slow (as it may // be mapped by the driver as uncached). - kWrite = 1 << 1, - + IREE_HAL_MEMORY_ACCESS_WRITE = 1u << 1, // Memory will be discarded prior to mapping. // The existing contents will be undefined after mapping and must be written // to ensure validity. - kDiscard = 1 << 2, - + IREE_HAL_MEMORY_ACCESS_DISCARD = 1u << 2, // Memory will be discarded and completely overwritten in a single operation. - kDiscardWrite = kWrite | kDiscard, - + IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE = + IREE_HAL_MEMORY_ACCESS_WRITE | IREE_HAL_MEMORY_ACCESS_DISCARD, // A flag that can be applied to any access type to indicate that the buffer // storage being accessed may alias with other accesses occurring concurrently // within or across operations. The lack of the flag indicates that the access // is guaranteed not to alias (ala C's `restrict` keyword). - kMayAlias = 1 << 3, - + IREE_HAL_MEMORY_ACCESS_MAY_ALIAS = 1u << 3, // Memory may have any operation performed on it. - kAll = kRead | kWrite | kDiscard, + IREE_HAL_MEMORY_ACCESS_ALL = IREE_HAL_MEMORY_ACCESS_READ | + IREE_HAL_MEMORY_ACCESS_WRITE | + IREE_HAL_MEMORY_ACCESS_DISCARD, }; -IREE_BITFIELD(MemoryAccess); -using MemoryAccessBitfield = MemoryAccess; -std::string MemoryAccessString(MemoryAccessBitfield memory_access); +typedef uint32_t iree_hal_memory_access_t; // Bitfield that defines how a buffer is intended to be used. // Usage allows the driver to appropriately place the buffer for more // efficient operations of the specified types. -enum class BufferUsage { - kNone = 0, +enum iree_hal_buffer_usage_e { + IREE_HAL_BUFFER_USAGE_NONE = 0u, // The buffer, once defined, will not be mapped or updated again. // This should be used for uniform parameter values such as runtime // constants for executables. Doing so may allow drivers to inline values or // represent them in command buffers more efficiently (avoiding memory reads // or swapping, etc). - kConstant = 1 << 0, + IREE_HAL_BUFFER_USAGE_CONSTANT = 1u << 0, // The buffer can be used as the source or target of a transfer command // (CopyBuffer, UpdateBuffer, etc). // - // If |kMapping| is not specified drivers may safely assume that the host - // may never need visibility of this buffer as all accesses will happen via - // command buffers. - kTransfer = 1 << 1, + // If |IREE_HAL_BUFFER_USAGE_MAPPING| is not specified drivers may safely + // assume that the host may never need visibility of this buffer as all + // accesses will happen via command buffers. + IREE_HAL_BUFFER_USAGE_TRANSFER = 1u << 1, // The buffer can be mapped by the host application for reading and writing. // @@ -206,706 +145,405 @@ enum class BufferUsage { // calls to enable visibility the driver can use the presence (or lack of) // this flag to perform allocation-type setup and avoid initial mapping // overhead. - kMapping = 1 << 2, + IREE_HAL_BUFFER_USAGE_MAPPING = 1u << 2, // The buffer can be provided as an input or output to an executable. // Buffers of this type may be directly used by drivers during dispatch. 
- kDispatch = 1 << 3, + IREE_HAL_BUFFER_USAGE_DISPATCH = 1u << 3, // Buffer may be used for any operation. - kAll = kTransfer | kMapping | kDispatch, + IREE_HAL_BUFFER_USAGE_ALL = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_MAPPING | + IREE_HAL_BUFFER_USAGE_DISPATCH, }; -IREE_BITFIELD(BufferUsage); -using BufferUsageBitfield = BufferUsage; -std::string BufferUsageString(BufferUsageBitfield buffer_usage); - -// A memory buffer. -// Buffers have a specific memory_type that is used to describe the capabilities -// and behavior of the backing memory of the buffer. Buffers may be any mix of -// host-accessible, host-coherent, or device-accessible for various usages. -// Depending on these memory types the buffers may be mapped for access on the -// host as memory though certain restrictions may be imposed. -// -// See MemoryType for more information about the types and what operations they -// support. -class Buffer : public Resource { - public: - // Returns a reference to a subspan of the buffer. - // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after - // |byte_offset| (possibly 0) will be selected. - // - // The parent buffer will remain alive for the lifetime of the subspan - // returned. If the subspan is a small portion this may cause additional - // memory to remain allocated longer than required. - // - // Returns the given |buffer| if the requested span covers the entire range. - static StatusOr> Subspan(const ref_ptr& buffer, - device_size_t byte_offset, - device_size_t byte_length); - - // Overlap test results. - enum class Overlap { - // No overlap between the two buffers. - kDisjoint, - // Partial overlap between the two buffers. - kPartial, - // Complete overlap between the two buffers (they are the same). - kComplete, - }; - - // Tests whether the given buffers overlap, including support for subspans. - // kWholeBuffer may be used for |lhs_length| and/or |rhs_length| to use the - // lengths of those buffers, respectively. - static Overlap TestOverlap(Buffer* lhs_buffer, device_size_t lhs_offset, - device_size_t lhs_length, Buffer* rhs_buffer, - device_size_t rhs_offset, - device_size_t rhs_length); - - // Returns true if the two buffer ranges overlap at all. - static bool DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset, - device_size_t lhs_length, Buffer* rhs_buffer, - device_size_t rhs_offset, device_size_t rhs_length); - - // Disallow copies (as copying requires real work). - Buffer(const Buffer&) = delete; - Buffer& operator=(const Buffer&) = delete; - - ~Buffer() override = default; - -#if HAS_IREE_BUFFER_DEBUG_NAME - // Optionally populated name useful for logging a persistent name for the - // buffer. - absl::string_view debug_name() const { return debug_name_; } - void set_debug_name(std::string debug_name) { - debug_name_ = std::move(debug_name); - } -#else - absl::string_view debug_name() const { return ""; } - void set_debug_name(std::string debug_name) {} -#endif // HAS_IREE_BUFFER_DEBUG_NAME - - // Memory allocator this buffer was allocated from. - // May be nullptr if the buffer has no particular allocator and should be - // assumed to be allocated from the host heap. - constexpr Allocator* allocator() const { - return allocated_buffer_ == this ? allocator_ - : allocated_buffer_->allocator(); - } - - // Memory type this buffer is allocated from. - MemoryTypeBitfield memory_type() const { return memory_type_; } - - // Memory access operations allowed on the buffer. 
- MemoryAccessBitfield allowed_access() const { return allowed_access_; } - - // Bitfield describing how the buffer is to be used. - BufferUsageBitfield usage() const { return usage_; } - - // Returns the underlying buffer that represents the allocated memory for the - // Buffer. In most cases this is the buffer itself but for buffer subspan - // references it will point to the parent buffer. - Buffer* allocated_buffer() const noexcept; - - // Size of the resource memory allocation in bytes. - // This may be rounded up from the originally requested size or the ideal - // size for the resource based on device restrictions. - constexpr device_size_t allocation_size() const { - return allocated_buffer_ == this ? allocation_size_ - : allocated_buffer_->allocation_size(); - } - - // Range within the underlying allocation this buffer occupies. - // For buffers that map 1:1 with an allocation this should be - // [0, allocation_size()), however may still differ if the allocation needed - // to be aligned. - // - // The offset is most often manipulated by Subspan, however it's important to - // note that the offset may not be what was passed to Subspan as it refers to - // the offset in the original ancestor buffer, not the buffer from which the - // subspan was taken. - constexpr device_size_t byte_offset() const noexcept { return byte_offset_; } - constexpr device_size_t byte_length() const noexcept { return byte_length_; } - - // TODO(benvanik): add debug_name. - - // Returns a longer debug string describing the buffer and its attributes. - std::string DebugString() const override; - // Returns a short debug string describing the buffer. - std::string DebugStringShort() const override; - - // Sets a range of the buffer to the given value. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. - // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after - // |byte_offset| (possibly 0) will be filled. - // - // The |byte_offset| and |byte_length| must be aligned to the size of the fill - // value. Multi-byte values will be written in host order for host buffers and - // device order for device buffers. - // - // Only |pattern_length| values with 1, 2, or 4 bytes are supported. - // - // Fails if the write could not be performed; either the bounds are out of - // range or the memory type does not support writing in this way. - Status Fill(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length); - template - Status Fill8(device_size_t byte_offset, device_size_t byte_length, T value); - template - Status Fill16(device_size_t byte_offset, device_size_t byte_length, T value); - template - Status Fill32(device_size_t byte_offset, device_size_t byte_length, T value); - template - Status Fill8(T value); - template - Status Fill16(T value); - template - Status Fill32(T value); - - // Reads a block of byte data from the resource at the given offset. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. - // - // Fails if the read could not be performed; either the bounds are out of - // range or the memory type does not support reading in this way. - Status ReadData(device_size_t source_offset, void* data, - device_size_t data_length); - - // Writes a block of byte data into the resource at the given offset. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. 
- // - // Fails if the write could not be performed; either the bounds are out of - // range or the memory type does not support writing in this way. - Status WriteData(device_size_t target_offset, const void* data, - device_size_t data_length); - - // Copies data from the provided source_buffer into the buffer. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. - // The source and destination may be the same buffer but the ranges must not - // overlap (a la memcpy). - // - // Fails if the write could not be performed; either the bounds are out of - // range or the memory type does not support writing in this way. - Status CopyData(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, device_size_t data_length); - Status CopyData(device_size_t target_offset, Buffer* source_buffer) { - return CopyData(target_offset, source_buffer, 0, kWholeBuffer); - } - - // Maps the resource memory for direct access from the host. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. - // - // If MemoryType::kHostCoherent was not specified then explicit - // Invalidate and Flush calls must be used to control visibility of the data - // on the device. If MemoryType::kHostCached is not set callers must not - // attempt to read from the mapped memory as doing so may produce undefined - // results and/or ultra slow reads. - // - // If the MemoryAccess::kDiscard bit is set when mapping for writes the caller - // guarantees that they will be overwriting all data in the mapped range. This - // is used as a hint to the device that the prior contents are no longer - // required and can enable optimizations that save on synchronization and - // readback. Note however that it is strictly a hint and the contents are not - // guaranteed to be zeroed during mapping. - // - // This allows mapping the memory as a C++ type. Care must be taken to ensure - // the data layout in C++ matches the expected data layout in the executables - // that consume this data. For simple primitives like uint8_t or float this is - // usually not a problem however struct packing may have many restrictions. - // - // The returned mapping should be unmapped when it is no longer required. - // Unmapping does not implicitly flush. - // - // Fails if the memory could not be mapped due to mapping exhaustion, invalid - // arguments, or unsupported memory types. - // - // Example: - // IREE_ASSIGN_OR_RETURN(auto mapping, buffer->MapForRead()); - // mapping[5].foo = 3; - // std::memcpy(mapping.data(), source_data, mapping.size()); - // mapping.reset(); - template - StatusOr> MapMemory( - MemoryAccessBitfield memory_access, device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer); - - protected: - template - friend class MappedMemory; - - // Defines the mode of a MapMemory operation. - enum class MappingMode { - // The call to MapMemory will always be matched with UnmapMemory. - kScoped, - }; - - Buffer(Allocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size, device_size_t byte_offset, - device_size_t byte_length); - - // Allows subclasses to override the allowed access bits. - // This should only be done when known safe by the allocation scheme. 
- void set_allowed_access(MemoryAccessBitfield allowed_access) { - allowed_access_ = allowed_access; - } - - // Sets a range of the buffer to the given value. - // State and parameters have already been validated. For the >8bit variants - // the offset and length have already been validated to be aligned to the - // natural alignment of the type. - virtual Status FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, - device_size_t pattern_length) = 0; - - // Reads a block of byte data from the resource at the given offset. - // State and parameters have already been validated. - virtual Status ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) = 0; - - // Writes a block of byte data into the resource at the given offset. - // State and parameters have already been validated. - virtual Status WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) = 0; - - // Copies a block of byte data into the resource at the given offset. - // State and parameters have already been validated. - virtual Status CopyDataImpl(device_size_t target_offset, - Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) = 0; - - // Maps memory directly. - // The output data pointer will be properly aligned to the start of the data. - // |local_byte_offset| and |local_byte_length| are the adjusted values that - // should map into the local space of the buffer. - // - // Fails if the memory could not be mapped (invalid access type, invalid - // range, or unsupported memory type). - // State and parameters have already been validated. - virtual Status MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) = 0; - - // Unmaps previously mapped memory. - // No-op if the memory is not mapped. As this is often used in destructors - // we can't rely on failures here propagating with anything but - // IREE_CHECK/IREE_DCHECK. State and parameters have already been validated. - virtual Status UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, - void* data) = 0; - - // Invalidates ranges of non-coherent memory from the host caches. - // Use this before reading from non-coherent memory. - // This guarantees that device writes to the memory ranges provided are - // visible on the host. - // This is only required for memory types without kHostCoherent set. - // State and parameters have already been validated. - virtual Status InvalidateMappedMemoryImpl( - device_size_t local_byte_offset, device_size_t local_byte_length) = 0; - - // Flushes ranges of non-coherent memory from the host caches. - // Use this after writing to non-coherent memory. - // This guarantees that host writes to the memory ranges provided are made - // available for device access. - // This is only required for memory types without kHostCoherent set. - // State and parameters have already been validated. - virtual Status FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) = 0; - - // Validates the given buffer range and adjusts the offset and length if the - // provided length is kWholeBuffer or the buffer is offset within its - // allocation. This calculates the range in the given domain without adjusting - // to any particular buffer base offsets. 
- static Status CalculateLocalRange(device_size_t max_length, - device_size_t offset, device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length); - - private: - friend class Allocator; - - // This is not great and deserves cleanup. - friend class DeferredBuffer; - friend class SubspanBuffer; - friend class HeapBuffer; - - // Maps memory directly. - // The byte offset and byte length may be adjusted for device alignment. - // The output data pointer will be properly aligned to the start of the data. - // Fails if the memory could not be mapped (invalid access type, invalid - // range, or unsupported memory type). - Status MapMemory(MappingMode mapping_mode, MemoryAccessBitfield memory_access, - device_size_t* byte_offset, device_size_t* byte_length, - void** out_data); - - // Unmaps previously mapped memory. - // No-op if the memory is not mapped. As this is often used in destructors - // we can't rely on failures here propagating with anything but - // IREE_CHECK/IREE_DCHECK. - Status UnmapMemory(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data); - - // Invalidates ranges of non-coherent memory from the host caches. - // Use this before reading from non-coherent memory. - // This guarantees that device writes to the memory ranges provided are - // visible on the host. - // This is only required for memory types without kHostCoherent set. - Status InvalidateMappedMemory(device_size_t local_byte_offset, - device_size_t local_byte_length); - - // Flushes ranges of non-coherent memory from the host caches. - // Use this after writing to non-coherent memory. - // This guarantees that host writes to the memory ranges provided are made - // available for device access. - // This is only required for memory types without kHostCoherent set. - Status FlushMappedMemory(device_size_t local_byte_offset, - device_size_t local_byte_length); - - // Returns a failure if the memory type the buffer was allocated from is not - // compatible with the given type. - Status ValidateCompatibleMemoryType(MemoryTypeBitfield memory_type) const; - // Returns a failure if the buffer memory type or usage disallows the given - // access type. - Status ValidateAccess(MemoryAccessBitfield memory_access) const; - // Returns a failure if the buffer was not allocated for the given usage. - Status ValidateUsage(BufferUsageBitfield usage) const; - // Validates the given buffer range and optionally adjusts the offset and - // length if the provided length is kWholeBuffer or the buffer is offset - // within its allocation. - static Status CalculateRange(device_size_t base_offset, - device_size_t max_length, device_size_t offset, - device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length = nullptr); - Status CalculateRange(device_size_t offset, device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length = nullptr) const; - - // Points to either this or parent_buffer_.get(). - Buffer* allocated_buffer_ = nullptr; - - Allocator* allocator_ = nullptr; - MemoryTypeBitfield memory_type_ = MemoryType::kNone; - MemoryAccessBitfield allowed_access_ = MemoryAccess::kNone; - BufferUsageBitfield usage_ = BufferUsage::kNone; - - device_size_t allocation_size_ = 0; - device_size_t byte_offset_ = 0; - device_size_t byte_length_ = 0; - -#if HAS_IREE_BUFFER_DEBUG_NAME - // Friendly name for the buffer used in DebugString. May be set by the app or - // auto generated. 
- std::string debug_name_; -#endif // HAS_IREE_BUFFER_DEBUG_NAME - - // Defined when this buffer is a subspan of another buffer. - ref_ptr parent_buffer_; +typedef uint32_t iree_hal_buffer_usage_t; + +// Buffer overlap testing results. +enum iree_hal_buffer_overlap_e { + // No overlap between the two buffers. + IREE_HAL_BUFFER_OVERLAP_DISJOINT = 0, + // Partial overlap between the two buffers. + IREE_HAL_BUFFER_OVERLAP_PARTIAL, + // Complete overlap between the two buffers (they are the same). + IREE_HAL_BUFFER_OVERLAP_COMPLETE, +}; +typedef uint8_t iree_hal_buffer_overlap_t; + +enum iree_hal_mapping_mode_e { + // The mapping is kept alive only for a matched map/unmap pair. + IREE_HAL_MAPPING_MODE_SCOPED = 1u << 0, + // The mapping may persist beyond a single scoped use; distinct bit values + // let call sites and implementations tell the modes apart. + IREE_HAL_MAPPING_MODE_PERSISTENT = 1u << 1, +}; +typedef uint32_t iree_hal_mapping_mode_t; + +// Reference to a buffer's mapped memory. +typedef struct { + // Contents of the buffer. Behavior is undefined if an access is performed + // whose type was not specified during mapping. + // + // The bytes available may be greater than what was requested if platform + // alignment rules require it. Only memory defined by the given span may be + // accessed. + iree_byte_span_t contents; + + // Used internally - do not modify. + uint64_t reserved[4]; +} iree_hal_buffer_mapping_t; + +// TODO(benvanik): replace with tables for iree_string_builder_*. +#define iree_hal_memory_type_string(...) "TODO" +// // Combined: +// {IREE_HAL_MEMORY_TYPE_HOST_LOCAL, "HOST_LOCAL"}, +// {IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, "DEVICE_LOCAL"}, +// // Separate: +// {IREE_HAL_MEMORY_TYPE_TRANSIENT, "TRANSIENT"}, +// {IREE_HAL_MEMORY_TYPE_HOST_VISIBLE, "HOST_VISIBLE"}, +// {IREE_HAL_MEMORY_TYPE_HOST_COHERENT, "HOST_COHERENT"}, +// {IREE_HAL_MEMORY_TYPE_HOST_CACHED, "HOST_CACHED"}, +// {IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, "DEVICE_VISIBLE"}, +#define iree_hal_memory_access_string(...) "TODO" +// // Combined: +// {IREE_HAL_MEMORY_ACCESS_ALL, "ALL"}, +// {IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, "DISCARD_WRITE"}, +// // Separate: +// {IREE_HAL_MEMORY_ACCESS_READ, "READ"}, +// {IREE_HAL_MEMORY_ACCESS_WRITE, "WRITE"}, +// {IREE_HAL_MEMORY_ACCESS_DISCARD, "DISCARD"}, +// {IREE_HAL_MEMORY_ACCESS_MAY_ALIAS, "MAY_ALIAS"}, +#define iree_hal_buffer_usage_string(...) "TODO" +// // Combined: +// {IREE_HAL_BUFFER_USAGE_ALL, "ALL"}, +// // Separate: +// {IREE_HAL_BUFFER_USAGE_CONSTANT, "CONSTANT"}, +// {IREE_HAL_BUFFER_USAGE_TRANSFER, "TRANSFER"}, +// {IREE_HAL_BUFFER_USAGE_MAPPING, "MAPPING"}, +// {IREE_HAL_BUFFER_USAGE_DISPATCH, "DISPATCH"}, + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_t +//===----------------------------------------------------------------------===// -// A memory mapping RAII object. -// The mapping will stay active until it is reset and will retain the buffer. -template -class MappedMemory { - public: - using unspecified_bool_type = const T* MappedMemory::*; - - MappedMemory() = default; - MappedMemory(MemoryAccessBitfield access, ref_ptr buffer, - device_size_t byte_offset, device_size_t byte_length, - device_size_t element_size, T* data); - - // Allow moving but disallow copying as the mapping is stateful. - MappedMemory(MappedMemory&& rhs) noexcept; - MappedMemory& operator=(MappedMemory&& rhs) noexcept; - MappedMemory(const MappedMemory&) = delete; - MappedMemory& operator=(const MappedMemory&) = delete; - - ~MappedMemory(); - - // The buffer resource that this mapping references. - const ref_ptr& buffer() const noexcept { return buffer_; } - // Offset, in bytes, into the resource allocation.
- // This value is *informative only*, as it may vary from device to device. - device_size_t byte_offset() const noexcept { return byte_offset_; } - // Length, in bytes, of the resource mapping. - // This may be larger than the originally requested length due to alignment. - // This value is *informative only*, as it may vary from device to device. - device_size_t byte_length() const noexcept { return byte_length_; } - - // True if the mapping is empty. - bool empty() const noexcept { return element_size_ == 0; } - // The size of the mapping as requested in elements. - size_t size() const noexcept { return static_cast(element_size_); } - - // Returns a read-only pointer to the mapped memory. - // This will be nullptr if the mapping failed or the mapping is not readable. - const T* data() const noexcept; - absl::Span contents() const noexcept { return {data(), size()}; } - - // Returns a mutable pointer to the mapped memory. - // This will be nullptr if the mapping failed or the mapping is not writable. - // If the mapping was not made with read access it may still be possible to - // read from this memory but behavior is undefined. - T* mutable_data() noexcept; - absl::Span mutable_contents() noexcept { return {mutable_data(), size()}; } - - // Returns a raw pointer to the mapped data without any access checks. - T* unsafe_data() const noexcept { return data_; } - - // Equivalent to absl::Span::subspan(). - // May return a 0-length span. - // Fails if the buffer is not mapped or not mapped for the requested access. - StatusOr> Subspan( - device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer) const noexcept; - StatusOr> MutableSubspan( - device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer) noexcept; - - // Accesses an element in the mapped memory. - // Must be called with a valid index in [0, size()). - const T& operator[](device_size_t i) const noexcept { return data_[i]; } - - // Invalidates a range of non-coherent elements from the host caches. - Status Invalidate(device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer) const; - - // Flushes a range of non-coherent elements from the host caches. - Status Flush(device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer); - - // Unmaps the mapped memory. - // The memory will not be implicitly flushed when unmapping. - void reset(); - - private: - Status ValidateAccess(MemoryAccessBitfield memory_access) const; - Status CalculateDataRange(device_size_t element_offset, - device_size_t element_length, - device_size_t* out_adjusted_element_offset, - device_size_t* out_adjusted_element_length) const; - - MemoryAccessBitfield access_ = MemoryAccess::kNone; - ref_ptr buffer_; - device_size_t byte_offset_ = 0; - device_size_t byte_length_ = 0; - device_size_t element_size_ = 0; - T* data_ = nullptr; +// Allocated memory buffer wrapper type and utilities. +// +// Buffers are the basic unit of memory used by the inference system. They may +// be allocated such that they are accessible from the host (normal C++ code +// running on the main CPU), a particular device (such as an accelerator) or +// family of devices, or from some mix of all of those. +// +// The type of memory a buffer is allocated within has implications on its +// performance and lifetime.
For example, if an application attempts to use a +// host-allocated buffer (IREE_HAL_MEMORY_TYPE_HOST_LOCAL) on an accelerator +// with discrete memory the accelerator may either be unable to access the +// memory or take a non-trivial performance hit when attempting to do so +// (involving setting up kernel mappings, doing DMA transfers, etc). Likewise, +// trying to access a device-allocated buffer +// (IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL) may incur similar overhead or not be +// possible at all. This may be due to restrictions in the memory visibility, +// address spaces, mixed endianness or pointer widths, and other weirdness. +// +// The memory types (defined by a bitfield of iree_hal_memory_type_t values) +// that a particular context (host or device) may use vary from device to device +// and must be queried by the application when allocating buffers. It's strongly +// recommended that the most specific memory type possible be set. For +// example allocating a buffer with IREE_HAL_MEMORY_TYPE_HOST_COHERENT even when +// it will never be used in a way that requires coherency may occupy address +// space reservations or memory mapping that would otherwise not be needed. +// +// As buffers may sometimes not be accessible from the host, the base +// iree_hal_buffer_t type does not allow for direct void* access and instead +// buffers must be either manipulated using utility functions (such as +// iree_hal_buffer_read_data or iree_hal_buffer_write_data) or by mapping them +// into a host-accessible address space via iree_hal_buffer_map_range. A buffer +// must be unmapped before any command may use it. +// +// Buffers may map (roughly) 1:1 with an allocation either from the host heap or +// a device. iree_hal_buffer_subspan can be used to reference subspans of +// buffers like absl::Span - though unlike absl::Span the returned buffer holds +// a reference to the parent buffer. +typedef struct iree_hal_buffer_s iree_hal_buffer_t; + +// Returns success iff the buffer was allocated with the given memory type. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_validate_memory_type( + iree_hal_memory_type_t actual_memory_type, + iree_hal_memory_type_t expected_memory_type); + +// Returns success iff the buffer allows the requested access. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_validate_access( + iree_hal_memory_access_t allowed_memory_access, + iree_hal_memory_access_t required_memory_access); + +// Returns success iff the buffer usage allows the given usage type. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_validate_usage(iree_hal_buffer_usage_t allowed_usage, + iree_hal_buffer_usage_t required_usage); + +// Returns success iff the given byte range falls within the valid buffer. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_validate_range( + iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length); + +// Tests whether the given buffers overlap, including support for subspans. +// IREE_WHOLE_BUFFER may be used for |lhs_length| and/or |rhs_length| to use the +// lengths of those buffers, respectively. +IREE_API_EXPORT iree_hal_buffer_overlap_t IREE_API_CALL +iree_hal_buffer_test_overlap(iree_hal_buffer_t* lhs_buffer, + iree_device_size_t lhs_offset, + iree_device_size_t lhs_length, + iree_hal_buffer_t* rhs_buffer, + iree_device_size_t rhs_offset, + iree_device_size_t rhs_length); + +// Returns a reference to a subspan of the |buffer|.
+// If |byte_length| is IREE_WHOLE_BUFFER the remaining bytes in the buffer after +// |byte_offset| (possibly 0) will be selected. +// +// The parent buffer will remain alive for the lifetime of the subspan +// returned. If the subspan is a small portion this may cause additional +// memory to remain allocated longer than required. +// +// Returns the given |buffer| if the requested span covers the entire range. +// |out_buffer| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_subspan( + iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length, iree_hal_buffer_t** out_buffer); + +// Retains the given |buffer| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_retain(iree_hal_buffer_t* buffer); + +// Releases the given |buffer| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_release(iree_hal_buffer_t* buffer); + +// Returns the allocator this buffer was allocated from. +IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL +iree_hal_buffer_allocator(const iree_hal_buffer_t* buffer); + +// Returns a pointer to the buffer containing the actual allocation. +// The buffer represents a span of the allocated bytes defined by byte_offset +// and byte_length. If the provided buffer *is* the allocated buffer then the +// returned value will be the provided buffer pointer. +IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL +iree_hal_buffer_allocated_buffer(const iree_hal_buffer_t* buffer); + +// Returns the size of the resource memory allocation in bytes. +// This may be rounded up from the originally requested size or the ideal +// size for the resource based on device restrictions. +IREE_API_EXPORT iree_device_size_t IREE_API_CALL +iree_hal_buffer_allocation_size(const iree_hal_buffer_t* buffer); + +// Returns the offset in bytes of the buffer within its allocated_buffer. +IREE_API_EXPORT iree_device_size_t IREE_API_CALL +iree_hal_buffer_byte_offset(const iree_hal_buffer_t* buffer); + +// Returns the size in bytes of the buffer. +IREE_API_EXPORT iree_device_size_t IREE_API_CALL +iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer); + +// Returns the memory type the buffer was allocated with. +IREE_API_EXPORT +iree_hal_memory_type_t IREE_API_CALL +iree_hal_buffer_memory_type(const iree_hal_buffer_t* buffer); + +// Returns the allowed memory access modes. +// These may be more strict than the underlying allocation, for example when the +// buffer is exposing read-only memory that may be in mutable pages. +IREE_API_EXPORT +iree_hal_memory_access_t IREE_API_CALL +iree_hal_buffer_allowed_access(const iree_hal_buffer_t* buffer); + +// Returns the allowed buffer usage modes. +IREE_API_EXPORT +iree_hal_buffer_usage_t IREE_API_CALL +iree_hal_buffer_allowed_usage(const iree_hal_buffer_t* buffer); + +// Sets a range of the buffer to binary zero. +// +// Requires that the buffer has the IREE_HAL_BUFFER_USAGE_MAPPING bit set. +// The byte range in |buffer| will be flushed if needed. +// +// It is strongly recommended that buffer operations are performed on transfer +// queues; using this synchronous function may incur additional cache flushes +// and synchronous blocking behavior and is not supported on all buffer types. +// See iree_hal_command_buffer_fill_buffer. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length); + +// Sets a range of the buffer to the given value. 
+// Only |pattern_length| values with 1, 2, or 4 bytes are supported. +// +// Requires that the buffer has the IREE_HAL_BUFFER_USAGE_MAPPING bit set. +// The byte range in |buffer| will be flushed if needed. +// +// It is strongly recommended that buffer operations are performed on transfer +// queues; using this synchronous function may incur additional cache flushes +// and synchronous blocking behavior and is not supported on all buffer types. +// See iree_hal_command_buffer_fill_buffer. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length, const void* pattern, + iree_host_size_t pattern_length); + +// Reads a block of data from the buffer at the given offset. +// +// Requires that the buffer has the IREE_HAL_BUFFER_USAGE_MAPPING bit set. +// +// It is strongly recommended that buffer operations are performed on transfer +// queues; using this synchronous function may incur additional cache flushes +// and synchronous blocking behavior and is not supported on all buffer types. +// See iree_hal_command_buffer_copy_buffer. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_read_data( + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + void* target_buffer, iree_device_size_t data_length); + +// Writes a block of byte data into the buffer at the given offset. +// +// Requires that the buffer has the IREE_HAL_BUFFER_USAGE_MAPPING bit set. +// The byte range in |target_buffer| will be flushed if needed. +// +// It is strongly recommended that buffer operations are performed on transfer +// queues; using this synchronous function may incur additional cache flushes +// and synchronous blocking behavior and is not supported on all buffer types. +// See iree_hal_command_buffer_update_buffer and +// iree_hal_command_buffer_copy_buffer. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_write_data( + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + const void* source_buffer, iree_device_size_t data_length); + +// Copies data from the provided |source_buffer| into the |target_buffer|. +// +// Requires that both buffers have the IREE_HAL_BUFFER_USAGE_MAPPING bit set. +// The byte range in |target_buffer| will be flushed if needed. +// +// It is strongly recommended that buffer operations are performed on transfer +// queues; using this synchronous function may incur additional cache flushes +// and synchronous blocking behavior and is not supported on all buffer types. +// See iree_hal_command_buffer_copy_buffer. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_copy_data( + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t data_length); + +// Maps the buffer to be accessed as a host pointer into |out_buffer_mapping|. +// The byte offset and byte length may be adjusted for device alignment. +// The output data pointer will be properly aligned to the start of the data. +// Fails if the memory could not be mapped (invalid access type, invalid +// range, or unsupported memory type). +// +// Requires that the buffer has the IREE_HAL_BUFFER_USAGE_MAPPING bit set. +// If the buffer is not IREE_HAL_MEMORY_TYPE_HOST_COHERENT then the caller must +// invalidate the byte range they want to access to update the visibility of the +// mapped memory. 
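+//
+// Example (an illustrative sketch only; |host_ptr| and |length| are
+// caller-provided and error handling is abbreviated):
+//   iree_hal_buffer_mapping_t mapping;
+//   IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+//       buffer, IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0, length,
+//       &mapping));
+//   // For non-host-coherent memory, make device writes visible first:
+//   IREE_RETURN_IF_ERROR(
+//       iree_hal_buffer_invalidate_range(&mapping, 0, length));
+//   memcpy(host_ptr, mapping.contents.data, mapping.contents.data_length);
+//   iree_hal_buffer_unmap_range(&mapping);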
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_map_range( + iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access, + iree_device_size_t byte_offset, iree_device_size_t byte_length, + iree_hal_buffer_mapping_t* out_buffer_mapping); + +// Unmaps the buffer memory that was previously mapped into |buffer_mapping|. +// +// If the buffer is not IREE_HAL_MEMORY_TYPE_HOST_COHERENT then the caller must +// flush the byte range they want to make available to other threads/devices. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_unmap_range(iree_hal_buffer_mapping_t* buffer_mapping); + +// Invalidates ranges of non-coherent memory from the host caches. +// This guarantees that device writes to the memory ranges provided are +// visible on the host. Use before reading from non-coherent memory. +// +// Only required for memory types without IREE_HAL_MEMORY_TYPE_HOST_COHERENT. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_invalidate_range( + iree_hal_buffer_mapping_t* buffer_mapping, iree_device_size_t byte_offset, + iree_device_size_t byte_length); + +// Flushes ranges of non-coherent memory from the host caches. +// This guarantees that host writes to the memory ranges provided are available +// for device access. Use after writing to non-coherent memory. +// +// Only required for memory types without IREE_HAL_MEMORY_TYPE_HOST_COHERENT. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_flush_range( + iree_hal_buffer_mapping_t* buffer_mapping, iree_device_size_t byte_offset, + iree_device_size_t byte_length); + +// Calculates and returns a byte subspan range within a buffer mapping. +// The byte range provided is local to the mapping. May return a 0-length span. +// IREE_WHOLE_BUFFER can be used for |byte_length|. +// +// Note that the access requirements of the mapping still hold: if the memory is +// not host coherent and writable then the caller must use the +// iree_hal_buffer_invalidate_range and iree_hal_buffer_flush_range functions to +// ensure memory is in the expected state. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_mapping_subspan( + iree_hal_buffer_mapping_t* buffer_mapping, + iree_hal_memory_access_t memory_access, iree_device_size_t byte_offset, + iree_device_size_t byte_length, iree_byte_span_t* out_span); + +//===----------------------------------------------------------------------===// +// iree_hal_heap_buffer_t +//===----------------------------------------------------------------------===// + +// Wraps an existing host allocation in a buffer. +// When the buffer is destroyed the provided |data_allocator| will be used to +// free |data|. Pass iree_allocator_null() to wrap without ownership semantics. +// +// |out_buffer| must be released by the caller.
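+//
+// Example (an illustrative sketch only; assumes an existing |allocator| and
+// the iree_make_byte_span helper from iree/base/api.h):
+//   float data[4] = {0.0f, 1.0f, 2.0f, 3.0f};
+//   iree_hal_buffer_t* buffer = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_heap_buffer_wrap(
+//       allocator, IREE_HAL_MEMORY_TYPE_HOST_LOCAL,
+//       IREE_HAL_MEMORY_ACCESS_ALL, IREE_HAL_BUFFER_USAGE_ALL,
+//       /*allocation_size=*/sizeof(data),
+//       iree_make_byte_span(data, sizeof(data)), iree_allocator_null(),
+//       &buffer));
+//   // ... use the buffer; the wrapped |data| is not freed on release ...
+//   iree_hal_buffer_release(buffer);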
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size, + iree_byte_span_t data, iree_allocator_t data_allocator, + iree_hal_buffer_t** out_buffer); + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE + + void(IREE_API_PTR* destroy)(iree_hal_buffer_t* buffer); + + iree_status_t(IREE_API_PTR* map_range)(iree_hal_buffer_t* buffer, + iree_hal_mapping_mode_t mapping_mode, + iree_hal_memory_access_t memory_access, + iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length, + void** out_data_ptr); + + void(IREE_API_PTR* unmap_range)(iree_hal_buffer_t* buffer, + iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length, + void* data_ptr); + + iree_status_t(IREE_API_PTR* invalidate_range)( + iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length); + + iree_status_t(IREE_API_PTR* flush_range)( + iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length); +} iree_hal_buffer_vtable_t; + +struct iree_hal_buffer_s { + iree_hal_resource_t resource; + + iree_hal_allocator_t* allocator; + + iree_hal_buffer_t* allocated_buffer; + iree_device_size_t allocation_size; + iree_device_size_t byte_offset; + iree_device_size_t byte_length; + + iree_hal_memory_type_t memory_type; + iree_hal_memory_access_t allowed_access; + iree_hal_buffer_usage_t allowed_usage; }; -// Inline functions and template definitions follow: - -template -Status Buffer::Fill8(device_size_t byte_offset, device_size_t byte_length, - T value) { - auto sized_value = reinterpret_cast(&value); - return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); -} - -template -Status Buffer::Fill16(device_size_t byte_offset, device_size_t byte_length, - T value) { - auto sized_value = reinterpret_cast(&value); - return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); -} - -template -Status Buffer::Fill32(device_size_t byte_offset, device_size_t byte_length, - T value) { - auto sized_value = reinterpret_cast(&value); - return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); -} - -template -Status Buffer::Fill8(T value) { - return Fill8(0, kWholeBuffer, value); -} - -template -Status Buffer::Fill16(T value) { - return Fill16(0, kWholeBuffer, value); -} - -template -Status Buffer::Fill32(T value) { - return Fill32(0, kWholeBuffer, value); -} - -template -StatusOr> Buffer::MapMemory(MemoryAccessBitfield memory_access, - device_size_t element_offset, - device_size_t element_length) { - device_size_t byte_offset = element_offset * sizeof(T); - device_size_t byte_length = element_length == kWholeBuffer - ? 
-  device_size_t byte_length = element_length == kWholeBuffer
-                                  ? kWholeBuffer
-                                  : element_length * sizeof(T);
-  void* data = nullptr;
-  IREE_RETURN_IF_ERROR(MapMemory(MappingMode::kScoped, memory_access,
-                                 &byte_offset, &byte_length, &data));
-  return MappedMemory<T>{
-      memory_access, add_ref(this), byte_offset,
-      byte_length,   byte_length / sizeof(T), static_cast<T*>(data)};
-}
-
-template <typename T>
-MappedMemory<T>::MappedMemory(MemoryAccessBitfield access,
-                              ref_ptr<Buffer> buffer, device_size_t byte_offset,
-                              device_size_t byte_length,
-                              device_size_t element_size, T* data)
-    : access_(access),
-      buffer_(std::move(buffer)),
-      byte_offset_(byte_offset),
-      byte_length_(byte_length),
-      element_size_(element_size),
-      data_(data) {}
-
-template <typename T>
-MappedMemory<T>::MappedMemory(MappedMemory<T>&& rhs) noexcept
-    : access_(rhs.access_),
-      buffer_(std::move(rhs.buffer_)),
-      byte_offset_(rhs.byte_offset_),
-      byte_length_(rhs.byte_length_),
-      element_size_(rhs.element_size_),
-      data_(rhs.data_) {
-  rhs.access_ = MemoryAccess::kNone;
-  rhs.buffer_.reset();
-  rhs.byte_offset_ = 0;
-  rhs.byte_length_ = 0;
-  rhs.element_size_ = 0;
-  rhs.data_ = nullptr;
-}
-
-template <typename T>
-MappedMemory<T>& MappedMemory<T>::operator=(MappedMemory<T>&& rhs) noexcept {
-  if (this != &rhs) {
-    reset();
-    access_ = rhs.access_;
-    buffer_ = std::move(rhs.buffer_);
-    byte_offset_ = rhs.byte_offset_;
-    byte_length_ = rhs.byte_length_;
-    element_size_ = rhs.element_size_;
-    data_ = rhs.data_;
-
-    rhs.access_ = MemoryAccess::kNone;
-    rhs.buffer_.reset();
-    rhs.byte_offset_ = 0;
-    rhs.byte_length_ = 0;
-    rhs.element_size_ = 0;
-    rhs.data_ = nullptr;
-  }
-  return *this;
-}
-
-template <typename T>
-MappedMemory<T>::~MappedMemory() {
-  // Unmap (if needed) - note that we can't fail gracefully here :(
-  reset();
-}
-
-template <typename T>
-const T* MappedMemory<T>::data() const noexcept {
-  if (!data_ || !AnyBitSet(access_ & MemoryAccess::kRead)) {
-    return nullptr;
-  }
-  return data_;
-}
-
-template <typename T>
-T* MappedMemory<T>::mutable_data() noexcept {
-  if (!data_ || !AnyBitSet(access_ & MemoryAccess::kWrite)) {
-    return nullptr;
-  }
-  return data_;
-}
-
-template <typename T>
-Status MappedMemory<T>::ValidateAccess(
-    MemoryAccessBitfield memory_access) const {
-  if (!data_) {
-    return FailedPreconditionErrorBuilder(IREE_LOC) << "Buffer is not mapped";
-  } else if (!AnyBitSet(access_ & memory_access)) {
-    return PermissionDeniedErrorBuilder(IREE_LOC)
-           << "Buffer is not mapped for the desired access";
-  }
-  return OkStatus();
-}
-
-template <typename T>
-Status MappedMemory<T>::CalculateDataRange(
-    device_size_t element_offset, device_size_t element_length,
-    device_size_t* out_adjusted_element_offset,
-    device_size_t* out_adjusted_element_length) const {
-  IREE_RETURN_IF_ERROR(Buffer::CalculateLocalRange(
-      element_size_ * sizeof(T), element_offset * sizeof(T),
-      element_length == kWholeBuffer ? kWholeBuffer
-                                     : element_length * sizeof(T),
-      out_adjusted_element_offset, out_adjusted_element_length));
-  *out_adjusted_element_offset /= sizeof(T);
-  *out_adjusted_element_length /= sizeof(T);
-  return OkStatus();
-}
-
-template <typename T>
-inline StatusOr<absl::Span<const T>> MappedMemory<T>::Subspan(
-    device_size_t element_offset, device_size_t element_length) const noexcept {
-  IREE_RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead));
-  IREE_RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
-                                          &element_offset, &element_length));
-  return absl::Span<const T>(data_ + element_offset, element_length);
-}
-
-template <typename T>
-inline StatusOr<absl::Span<T>> MappedMemory<T>::MutableSubspan(
-    device_size_t element_offset, device_size_t element_length) noexcept {
-  IREE_RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
-  IREE_RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
-                                          &element_offset, &element_length));
-  return absl::Span<T>(data_ + element_offset, element_length);
-}
-
-template <typename T>
-Status MappedMemory<T>::Invalidate(device_size_t element_offset,
-                                   device_size_t element_length) const {
-  IREE_RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead));
-  IREE_RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
-                                          &element_offset, &element_length));
-  if (!element_length) return OkStatus();
-  return buffer_->InvalidateMappedMemory(
-      byte_offset_ + element_offset * sizeof(T), element_length * sizeof(T));
-}
-
-template <typename T>
-Status MappedMemory<T>::Flush(device_size_t element_offset,
-                              device_size_t element_length) {
-  IREE_RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite));
-  IREE_RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length,
-                                          &element_offset, &element_length));
-  if (!element_length) return OkStatus();
-  return buffer_->FlushMappedMemory(byte_offset_ + element_offset * sizeof(T),
-                                    element_length * sizeof(T));
-}
-
-template <typename T>
-void MappedMemory<T>::reset() {
-  if (!buffer_) return;
-  // TODO(benvanik): better handling of errors? may be fine to always warn.
-  buffer_->UnmapMemory(byte_offset_, byte_length_, data_).IgnoreError();
-  buffer_.reset();
-  access_ = MemoryAccess::kNone;
-  byte_offset_ = 0;
-  byte_length_ = 0;
-  element_size_ = 0;
-  data_ = nullptr;
-}
-
-}  // namespace hal
-}  // namespace iree
+IREE_API_EXPORT void IREE_API_CALL
+iree_hal_buffer_destroy(iree_hal_buffer_t* buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus

 #endif  // IREE_HAL_BUFFER_H_
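That completes the buffer.h side of the port: mapping is now a plain struct plus five C entry points instead of the RAII MappedMemory template deleted above. As a minimal caller-side sketch, not part of this patch (read_back is a hypothetical helper; it assumes |buffer| is host-visible and mappable, and that iree_hal_buffer_mapping_t exposes its mapped bytes as a |contents| byte span the way the old iree_hal_mapped_memory_t did):

#include <string.h>        // memcpy
#include "iree/hal/api.h"  // aggregate HAL C API header

// Hypothetical helper: read back |length| bytes from a mappable HAL buffer.
static iree_status_t read_back(iree_hal_buffer_t* buffer, void* host_data,
                               iree_device_size_t length) {
  iree_hal_buffer_mapping_t mapping;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
      buffer, IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0, length,
      &mapping));
  // Device writes to non-host-coherent memory must be made visible before
  // the host reads them; this is a no-op fence for host-local heap buffers.
  iree_status_t status = iree_hal_buffer_invalidate_range(&mapping, 0, length);
  if (iree_status_is_ok(status)) {
    memcpy(host_data, mapping.contents.data, mapping.contents.data_length);
  }
  iree_hal_buffer_unmap_range(&mapping);
  return status;
}

Because iree_hal_buffer_unmap_range returns void, there is no longer an unmap failure to swallow the way MappedMemory<T>::reset() had to IgnoreError() above.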
diff --git a/iree/hal/buffer_heap.c b/iree/hal/buffer_heap.c
new file mode 100644
index 0000000000000..ca2140f63ef70
--- /dev/null
+++ b/iree/hal/buffer_heap.c
@@ -0,0 +1,122 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string.h>  // memset
+
+#include "iree/base/atomics.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/allocator.h"
+#include "iree/hal/buffer.h"
+#include "iree/hal/detail.h"
+
+typedef struct iree_hal_heap_buffer_s {
+  iree_hal_buffer_t base;
+
+  iree_byte_span_t data;
+  iree_allocator_t data_allocator;
+} iree_hal_heap_buffer_t;
+
+static const iree_hal_buffer_vtable_t iree_hal_heap_buffer_vtable;
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap(
+    iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
+    iree_hal_memory_access_t allowed_access,
+    iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size,
+    iree_byte_span_t data, iree_allocator_t data_allocator,
+    iree_hal_buffer_t** out_buffer) {
+  IREE_ASSERT_ARGUMENT(allocator);
+  IREE_ASSERT_ARGUMENT(out_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_heap_buffer_t* buffer = NULL;
+  iree_status_t status =
+      iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator),
+                            sizeof(*buffer), (void**)&buffer);
+  if (iree_status_is_ok(status)) {
+    iree_hal_resource_initialize(&iree_hal_heap_buffer_vtable,
+                                 &buffer->base.resource);
+    buffer->base.allocator = allocator;
+    buffer->base.allocated_buffer = &buffer->base;
+    buffer->base.allocation_size = allocation_size;
+    buffer->base.byte_offset = 0;
+    buffer->base.byte_length = data.data_length;
+    buffer->base.memory_type = memory_type;
+    buffer->base.allowed_access = allowed_access;
+    buffer->base.allowed_usage = allowed_usage;
+    buffer->data = data;
+    buffer->data_allocator = data_allocator;
+    *out_buffer = &buffer->base;
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+static void iree_hal_heap_buffer_destroy(iree_hal_buffer_t* base_buffer) {
+  iree_hal_heap_buffer_t* buffer = (iree_hal_heap_buffer_t*)base_buffer;
+  iree_allocator_t host_allocator = iree_hal_allocator_host_allocator(
+      iree_hal_buffer_allocator(base_buffer));
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_allocator_free(buffer->data_allocator, buffer->data.data);
+  iree_allocator_free(host_allocator, buffer);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+static iree_status_t iree_hal_heap_buffer_map_range(
+    iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode,
+    iree_hal_memory_access_t memory_access,
+    iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length,
+    void** out_data_ptr) {
+  iree_hal_heap_buffer_t* buffer = (iree_hal_heap_buffer_t*)base_buffer;
+  *out_data_ptr = buffer->data.data + local_byte_offset;
+
+  // If we mapped for discard, scribble over the bytes. This is not a mandated
+  // behavior but it will make debugging issues easier. Alternatively for
+  // heap buffers we could reallocate them such that ASAN yells, but that
+  // would only work if the entire buffer was discarded.
+#ifndef NDEBUG
+  if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) {
+    memset(*out_data_ptr, 0xCD, local_byte_length);
+  }
+#endif  // !NDEBUG
+
+  return iree_ok_status();
+}
+
+static void iree_hal_heap_buffer_unmap_range(
+    iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset,
+    iree_device_size_t local_byte_length, void* data_ptr) {
+  // No-op here as we always have the pointer.
+} + +static iree_status_t iree_hal_heap_buffer_invalidate_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length) { + iree_atomic_thread_fence(iree_memory_order_acquire); + return iree_ok_status(); +} + +static iree_status_t iree_hal_heap_buffer_flush_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length) { + iree_atomic_thread_fence(iree_memory_order_release); + return iree_ok_status(); +} + +static const iree_hal_buffer_vtable_t iree_hal_heap_buffer_vtable = { + .destroy = iree_hal_heap_buffer_destroy, + .map_range = iree_hal_heap_buffer_map_range, + .unmap_range = iree_hal_heap_buffer_unmap_range, + .invalidate_range = iree_hal_heap_buffer_invalidate_range, + .flush_range = iree_hal_heap_buffer_flush_range, +}; diff --git a/iree/hal/buffer_mapping_test.cc b/iree/hal/buffer_mapping_test.cc deleted file mode 100644 index a3d7edf6881ae..0000000000000 --- a/iree/hal/buffer_mapping_test.cc +++ /dev/null @@ -1,547 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tests for the MemoryMapping RAII wrapper. -// This uses a mock buffer implementation such that it is only testing -// MemoryMapping and not any real underlying memory mapping behavior. 
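These deleted tests covered the C++ MappedMemory wrapper, whose host-buffer duties the buffer_heap.c implementation above now takes over. A hedged sketch of the new wrap path (device_allocator stands in for a real iree_hal_allocator_t obtained elsewhere; the memory type, access, and usage flags are chosen purely for illustration):

#include "iree/hal/api.h"  // aggregate HAL C API header

// Sketch only: wrap caller-owned memory in a heap buffer. Passing
// iree_allocator_null() means destroy will not attempt to free |values|;
// static storage keeps the wrapped bytes alive past this call.
static iree_status_t wrap_example(iree_hal_allocator_t* device_allocator,
                                  iree_hal_buffer_t** out_buffer) {
  static float values[4] = {0.0f, 1.0f, 2.0f, 3.0f};
  return iree_hal_heap_buffer_wrap(
      device_allocator, IREE_HAL_MEMORY_TYPE_HOST_LOCAL,
      IREE_HAL_MEMORY_ACCESS_ALL, IREE_HAL_BUFFER_USAGE_ALL,
      /*allocation_size=*/sizeof(values),
      iree_make_byte_span(values, sizeof(values)), iree_allocator_null(),
      out_buffer);
}

The caller still owns the returned handle and releases it with iree_hal_buffer_release(), which frees only the wrapper, never the wrapped bytes.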
- -#include -#include -#include - -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/hal/buffer.h" -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" - -namespace iree { -namespace hal { -class Allocator; - -namespace { - -using ::testing::_; -using ::testing::DoAll; -using ::testing::Return; -using ::testing::SetArgPointee; - -static void* const kValidPtr = reinterpret_cast(0xBEEFCAFEF00D1234ull); - -class MockBuffer : public Buffer { - public: - using MappingMode = Buffer::MappingMode; - - MockBuffer(Allocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size) - : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, - 0, allocation_size) {} - - MOCK_METHOD(Status, FillImpl, - (device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length), - (override)); - - MOCK_METHOD(Status, ReadDataImpl, - (device_size_t source_offset, void* data, - device_size_t data_length), - (override)); - MOCK_METHOD(Status, WriteDataImpl, - (device_size_t target_offset, const void* data, - device_size_t data_length), - (override)); - MOCK_METHOD(Status, CopyDataImpl, - (device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, device_size_t data_length), - (override)); - - MOCK_METHOD(Status, MapMemoryImpl, - (MappingMode mapping_mode, MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, device_size_t local_byte_length, - void** out_data), - (override)); - MOCK_METHOD(Status, UnmapMemoryImpl, - (device_size_t local_byte_offset, device_size_t local_byte_length, - void* data), - (override)); - MOCK_METHOD(Status, InvalidateMappedMemoryImpl, - (device_size_t local_byte_offset, - device_size_t local_byte_length), - (override)); - MOCK_METHOD(Status, FlushMappedMemoryImpl, - (device_size_t local_byte_offset, - device_size_t local_byte_length), - (override)); -}; - -TEST(MemoryMappingTest, MapWholeBuffer) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mapping, - buffer->MapMemory(MemoryAccess::kRead)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mapping.reset(); -} - -TEST(MemoryMappingTest, MapPartialBuffer) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 4, 12, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, buffer->MapMemory(MemoryAccess::kRead, 4, 12)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(4, 12, kValidPtr)) - .WillOnce(Return(OkStatus())); - mapping.reset(); -} - -TEST(MemoryMappingTest, EmptyHandle) { - MappedMemory mm_a; - MappedMemory mm_b; - mm_a = std::move(mm_b); - EXPECT_EQ(nullptr, mm_a.buffer()); - EXPECT_EQ(0, mm_a.byte_offset()); - EXPECT_EQ(0, mm_a.byte_length()); - EXPECT_TRUE(mm_a.empty()); - EXPECT_EQ(0, mm_a.size()); - EXPECT_EQ(nullptr, mm_a.data()); - EXPECT_EQ(nullptr, mm_a.mutable_data()); - EXPECT_TRUE(IsFailedPrecondition(mm_a.Subspan().status())); - 
EXPECT_TRUE(IsFailedPrecondition(mm_a.MutableSubspan().status())); - EXPECT_TRUE(IsFailedPrecondition(mm_a.Invalidate())); - EXPECT_TRUE(IsFailedPrecondition(mm_a.Flush())); - mm_a.reset(); -} - -TEST(MemoryMappingTest, MoveHandle) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kAll, 128); - - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_a, - buffer->MapMemory(MemoryAccess::kRead)); - - // Should be able to move the handle around without having any calls. - auto mm_b = std::move(mm_a); - mm_a = std::move(mm_b); - mm_b = std::move(mm_a); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_b.reset(); -} - -TEST(MemoryMappingTest, ReadOnlyAccess) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kRead, - BufferUsage::kAll, 128); - - // Should succeed to map for reading. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory(MemoryAccess::kRead)); - - // Non-mutable access is fine. - EXPECT_EQ(kValidPtr, mm_r.data()); - IREE_ASSERT_OK_AND_ASSIGN(auto span, mm_r.Subspan()); - (void)span; - - // Read-only mappings should not be able to get mutable access. - EXPECT_EQ(nullptr, mm_r.mutable_data()); - EXPECT_TRUE(IsPermissionDenied(mm_r.MutableSubspan().status())); - - // Read-only mappings should not be able to call Flush. - EXPECT_TRUE(IsPermissionDenied(mm_r.Flush())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); - - // Should fail to map for writing. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory(MemoryAccess::kWrite).status())); -} - -TEST(MemoryMappingTest, ReadWriteAccess) { - auto buffer = make_ref(nullptr, MemoryType::kHostLocal, - MemoryAccess::kRead | MemoryAccess::kWrite, - BufferUsage::kAll, 128); - - // Should succeed to map for reading and/or writing. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead | MemoryAccess::kWrite, - 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_rw, - buffer->MapMemory(MemoryAccess::kRead | MemoryAccess::kWrite)); - - // Everything valid. - EXPECT_EQ(kValidPtr, mm_rw.data()); - IREE_ASSERT_OK_AND_ASSIGN(auto span, mm_rw.Subspan()); - EXPECT_EQ(kValidPtr, mm_rw.mutable_data()); - IREE_ASSERT_OK_AND_ASSIGN(span, mm_rw.MutableSubspan()); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_rw.reset(); - - // Should fail to map for discard. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory(MemoryAccess::kDiscardWrite).status())); -} - -TEST(MemoryMappingTest, WriteOnlyAccess) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, - MemoryAccess::kWrite, BufferUsage::kAll, 128); - - // Should succeed to map for writing. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory(MemoryAccess::kWrite)); - - // Mutable access is valid. 
- EXPECT_EQ(kValidPtr, mm_w.mutable_data()); - IREE_ASSERT_OK_AND_ASSIGN(auto span, mm_w.MutableSubspan()); - (void)span; - - // Write-only mappings should not be able to get non-mutable access. - EXPECT_EQ(nullptr, mm_w.data()); - EXPECT_TRUE(IsPermissionDenied(mm_w.Subspan().status())); - - // Write-only mappings should not be able to call Invalidate. - EXPECT_TRUE(IsPermissionDenied(mm_w.Invalidate())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); - - // Should fail to map for reading. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory(MemoryAccess::kRead).status())); - - // Should fail to map for discard. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory(MemoryAccess::kDiscardWrite).status())); -} - -TEST(MemoryMappingTest, WriteDiscardAccess) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, - MemoryAccess::kDiscardWrite, BufferUsage::kAll, 128); - - // Should succeed to map for writing with discard. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kDiscardWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_dw, buffer->MapMemory(MemoryAccess::kDiscardWrite)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_dw.reset(); - - // Should also be ok to map for just writing. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory(MemoryAccess::kWrite)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); - - // Should fail to map for reading. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory(MemoryAccess::kRead).status())); -} - -TEST(MemoryMappingTest, Subspan) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory(MemoryAccess::kRead)); - - // Request some valid ranges and ensure the byte offsets are correct. - IREE_ASSERT_OK_AND_ASSIGN(auto ss, mm_r.Subspan()); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_EQ(128, ss.size()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, 2)); - EXPECT_EQ(static_cast(kValidPtr) + 100, ss.data()); - EXPECT_EQ(2, ss.size()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, kWholeBuffer)); - EXPECT_EQ(static_cast(kValidPtr) + 100, ss.data()); - EXPECT_EQ(28, ss.size()); - - // Zero length ranges are fine. 
- IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(0, 0)); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_TRUE(ss.empty()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, 0)); - EXPECT_EQ(static_cast(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, kWholeBuffer)); - EXPECT_EQ(static_cast(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, SubspanOutOfRange) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory(MemoryAccess::kRead)); - - // Try some invalid ranges that would overrun the span. - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 0).status())); - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 2).status())); - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, kWholeBuffer).status())); - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(100, 1234).status())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, MutableSubspan) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory(MemoryAccess::kWrite)); - - // Request some valid ranges and ensure the byte offsets are correct. - IREE_ASSERT_OK_AND_ASSIGN(auto ss, mm_w.MutableSubspan()); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_EQ(128, ss.size()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, 2)); - EXPECT_EQ(static_cast(kValidPtr) + 100, ss.data()); - EXPECT_EQ(2, ss.size()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, kWholeBuffer)); - EXPECT_EQ(static_cast(kValidPtr) + 100, ss.data()); - EXPECT_EQ(28, ss.size()); - - // Zero length ranges are fine. - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(0, 0)); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_TRUE(ss.empty()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, 0)); - EXPECT_EQ(static_cast(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, kWholeBuffer)); - EXPECT_EQ(static_cast(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, MutableSubspanOutOfRange) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory(MemoryAccess::kWrite)); - - // Try some invalid ranges that would overrun the span. 
- EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 0).status())); - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 2).status())); - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, kWholeBuffer).status())); - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(100, 1234).status())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, ElementOperator) { - auto buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory(MemoryAccess::kRead)); - - // Just verify we are getting the expected pointer back. - EXPECT_EQ(kValidPtr, &mm_r[0]); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, Invalidate) { - auto buffer = - make_ref(nullptr, MemoryType::kHostVisible, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory(MemoryAccess::kRead)); - - // Invalidate a few ways. - EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(0, 128)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_r.Invalidate()); - EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 2)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_r.Invalidate(100, 2)); - EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 28)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_r.Invalidate(100, kWholeBuffer)); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, InvalidateOutOfRange) { - auto buffer = - make_ref(nullptr, MemoryType::kHostVisible, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory(MemoryAccess::kRead)); - - // Try to invalidate invalid ranges. - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 0))); - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 12345))); - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, kWholeBuffer))); - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1, 1234))); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, InvalidateBadMode) { - // Invalidate is not required on coherent memory. 
- auto coherent_buffer = - make_ref(nullptr, MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kAll, 128); - EXPECT_CALL(*coherent_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_r, coherent_buffer->MapMemory(MemoryAccess::kRead)); - EXPECT_TRUE(IsPermissionDenied(mm_r.Invalidate())); - EXPECT_CALL(*coherent_buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, Flush) { - auto buffer = make_ref( - nullptr, MemoryType::kHostVisible | MemoryType::kHostCached, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory(MemoryAccess::kWrite)); - - // Flush a few ways. - EXPECT_CALL(*buffer, FlushMappedMemoryImpl(0, 128)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_w.Flush()); - EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 2)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_w.Flush(100, 2)); - EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 28)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_w.Flush(100, kWholeBuffer)); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, FlushOutOfRange) { - auto buffer = make_ref( - nullptr, MemoryType::kHostVisible | MemoryType::kHostCached, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory(MemoryAccess::kWrite)); - - // Try to flush invalid ranges. - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 0))); - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 12345))); - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, kWholeBuffer))); - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1, 1234))); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, FlushBadMode) { - // Flush is not required on uncached memory. - auto uncached_buffer = - make_ref(nullptr, MemoryType::kHostVisible, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*uncached_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_w, uncached_buffer->MapMemory(MemoryAccess::kWrite)); - EXPECT_TRUE(IsPermissionDenied(mm_w.Flush())); - EXPECT_CALL(*uncached_buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -} // namespace -} // namespace hal -} // namespace iree diff --git a/iree/hal/buffer_test.cc b/iree/hal/buffer_test.cc deleted file mode 100644 index bf5853927831d..0000000000000 --- a/iree/hal/buffer_test.cc +++ /dev/null @@ -1,1013 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tests for the shared buffer functionality and host heap buffers. -// This does not test device-specific buffer implementations; see the device -// code for associated tests. - -#include "iree/hal/buffer.h" - -#include - -#include "iree/hal/heap_buffer.h" -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" - -namespace iree { -namespace hal { -namespace { - -using ::testing::_; -using ::testing::ElementsAre; -using ::testing::Eq; -using ::testing::Not; - -TEST(BufferTest, Allocate) { - auto buffer = - HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 14); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(MemoryAccess::kAll, buffer->allowed_access()); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - - // We don't currently do any padding on the host. - // Other implementations may differ. - EXPECT_LE(14, buffer->allocation_size()); - EXPECT_EQ(0, buffer->byte_offset()); - EXPECT_EQ(14, buffer->byte_length()); - - // Data should be zeroed by default. - std::vector zero_data(buffer->allocation_size()); - std::vector actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(zero_data)); -} - -TEST(BufferTest, AllocateZeroLength) { - auto buffer = - HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 0); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - EXPECT_EQ(0, buffer->allocation_size()); -} - -TEST(BufferTest, AllocateCopy) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_LE(src_data.size(), buffer->allocation_size()); - - // Data should have been copied. - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data and ensure it is not reflected in the buffer. 
- src_data[0] = 0x88; - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Not(Eq(src_data))); -} - -TEST(BufferTest, AllocateCopyZeroLength) { - std::vector src_data; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(0, buffer->allocation_size()); -} - -TEST(BufferTest, AllocateCopyTyped) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - absl::MakeConstSpan(src_data)); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - EXPECT_LE(src_data.size() * sizeof(int32_t), buffer->allocation_size()); - - // Data should have been copied. - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), - actual_data.size() * sizeof(int32_t))); - EXPECT_THAT(actual_data, Eq(src_data)); -} - -TEST(BufferTest, WrapConstant) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::Wrap(MemoryType::kHostLocal, - BufferUsage::kTransfer | BufferUsage::kMapping, - absl::MakeConstSpan(src_data)); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - EXPECT_EQ(src_data.size(), buffer->allocation_size()); - - // src_data and buffer should match after the wrapping. - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data directly. - src_data[0] = 123; - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Attempts to modify the buffer should fail. - std::vector new_data = {3, 2, 1, 0}; - EXPECT_TRUE(IsPermissionDenied( - buffer->WriteData(0, new_data.data(), new_data.size()))); -} - -TEST(BufferTest, WrapMutable) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::WrapMutable( - MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kTransfer | BufferUsage::kMapping, absl::MakeSpan(src_data)); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - EXPECT_EQ(src_data.size(), buffer->allocation_size()); - - // src_data and buffer should match after the wrapping. - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data directly. - src_data[0] = 123; - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data via the Buffer and ensure reflected in src_data. - std::vector new_data = {3, 2, 1, 0}; - IREE_EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size())); - EXPECT_THAT(src_data, Eq(new_data)); -} - -TEST(BufferTest, WrapExternal) { - // This is not fully supported yet, but does let us verify that the validation - // of memory types is working. 
- std::vector src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::Wrap(MemoryType::kDeviceLocal, BufferUsage::kAll, - absl::MakeConstSpan(src_data)); - EXPECT_EQ(MemoryType::kDeviceLocal, buffer->memory_type()); - - // Should fail (for now) as the buffer is not host visible. - EXPECT_TRUE(IsPermissionDenied(buffer->Fill8(0, kWholeBuffer, 0x99u))); -} - -TEST(BufferTest, DoesOverlap) { - std::vector src_data = {0, 1, 2, 3}; - auto parent_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - - // A buffer should overlap with itself. - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, - parent_buffer.get(), 1, 1)); - EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, - parent_buffer.get(), 0, 1)); - - // Zero length buffers never overlap. - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 1, 1, - parent_buffer.get(), 1, 0)); - - // Subspans should offset within their allocation. - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer_0, - Buffer::Subspan(parent_buffer, 1, 2)); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer_1, - Buffer::Subspan(parent_buffer, 2, 2)); - EXPECT_FALSE(Buffer::DoesOverlap(subspan_buffer_0.get(), 0, 1, - subspan_buffer_1.get(), 0, 1)); - EXPECT_TRUE(Buffer::DoesOverlap(subspan_buffer_0.get(), 1, 1, - subspan_buffer_1.get(), 0, 1)); - - // Mixing subspans and normal buffers. - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, - subspan_buffer_0.get(), 0, 1)); - EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 1, 2, - subspan_buffer_0.get(), 1, 1)); - - // Independent buffers should not be able to overlap. - auto other_buffer = HeapBuffer::Allocate(BufferUsage::kAll, 128); - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, kWholeBuffer, - other_buffer.get(), 0, kWholeBuffer)); -} - -TEST(BufferTest, Subspan) { - std::vector src_data = {0, 1, 2, 3}; - auto parent_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(parent_buffer); - - // Create a subspan of the buffer. - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, - Buffer::Subspan(parent_buffer, 1, 2)); - ASSERT_TRUE(subspan_buffer); - EXPECT_EQ(1, subspan_buffer->byte_offset()); - EXPECT_EQ(2, subspan_buffer->byte_length()); - - // Modifications to either buffer should appear in the other. - IREE_EXPECT_OK(subspan_buffer->Fill8(1, kWholeBuffer, 0xFFu)); - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK( - parent_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xFF, 3)); - - // Subspans should be able to create subspans. - // NOTE: offset is from the original buffer. - IREE_ASSERT_OK_AND_ASSIGN(auto subsubspan_buffer, - Buffer::Subspan(subspan_buffer, 1, 1)); - ASSERT_TRUE(subsubspan_buffer); - EXPECT_EQ(2, subsubspan_buffer->byte_offset()); - EXPECT_EQ(1, subsubspan_buffer->byte_length()); - - // Zero length subspans are fine. - IREE_ASSERT_OK_AND_ASSIGN(auto zero_subspan_buffer, - Buffer::Subspan(parent_buffer, 0, 0)); - ASSERT_TRUE(zero_subspan_buffer); - EXPECT_EQ(0, zero_subspan_buffer->byte_offset()); - EXPECT_EQ(0, zero_subspan_buffer->byte_length()); - - // Subspan with kWholeBuffer should get the remaining size (or zero). 
- IREE_ASSERT_OK_AND_ASSIGN(auto whole_subspan_buffer, - Buffer::Subspan(parent_buffer, 1, kWholeBuffer)); - ASSERT_TRUE(whole_subspan_buffer); - EXPECT_EQ(1, whole_subspan_buffer->byte_offset()); - EXPECT_EQ(3, whole_subspan_buffer->byte_length()); - - // Zero length subspans are fine. - IREE_ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, 0)); - IREE_ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, kWholeBuffer)); -} - -TEST(BufferTest, SubspanIdentity) { - std::vector src_data = {0, 1, 2, 3}; - auto parent_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - - // Asking for a subspan of the entire buffer should return the same buffer. - // Mostly an optimization. - EXPECT_EQ(parent_buffer.get(), - Buffer::Subspan(parent_buffer, 0, kWholeBuffer).value().get()); - EXPECT_EQ(parent_buffer.get(), - Buffer::Subspan(parent_buffer, 0, 4).value().get()); -} - -TEST(BufferTest, SubspanOutOfRange) { - std::vector src_data = {0, 1, 2, 3}; - auto parent_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(parent_buffer); - - // Create a subspan of the buffer. - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, - Buffer::Subspan(parent_buffer, 1, 2)); - ASSERT_TRUE(subspan_buffer); - EXPECT_EQ(1, subspan_buffer->byte_offset()); - EXPECT_EQ(2, subspan_buffer->byte_length()); - - // Try to make subspans from invalid ranges. - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 5, 0).status())); - EXPECT_TRUE( - IsOutOfRange(Buffer::Subspan(parent_buffer, 5, kWholeBuffer).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 4, 1).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 0, 123).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 1, 2).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 0, 44).status())); -} - -TEST(BufferTest, Fill8) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5); - ASSERT_TRUE(buffer); - - // Data should be zeroed by default. - std::vector actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0)); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u)); - - // Verify data. - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); - - // Zero fills are fine. - IREE_EXPECT_OK(buffer->Fill8(0, 0, 0x44u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); - - // Fill the remaining parts of the buffer by using kWholeBuffer. - IREE_EXPECT_OK(buffer->Fill8(2, kWholeBuffer, 0x55u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x55, 0x55, 0x55)); - - // Fill a small region of the buffer. - IREE_EXPECT_OK(buffer->Fill8(1, 1, 0x66u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x66, 0x55, 0x55, 0x55)); - - // Whole buffer helper. 
- IREE_EXPECT_OK(buffer->Fill8(0x99u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x99, 0x99, 0x99, 0x99, 0x99)); -} - -TEST(BufferTest, Fill8OutOfRange) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5); - ASSERT_TRUE(buffer); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u)); - - // Try to fill with invalid ranges. - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 444, 0x44u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 1, 0x44u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u))); - - // Ensure nothing happened with the bad ranges. - std::vector actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); -} - -TEST(BufferTest, Fill8BadMode) { - // Fail to fill buffers not supporting mapping. - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE( - IsPermissionDenied(nonmapping_buffer->Fill8(0, kWholeBuffer, 0x99u))); - - // Fail to fill constant buffers. - std::vector const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE( - IsPermissionDenied(constant_buffer->Fill8(0, kWholeBuffer, 0x99u))); -} - -TEST(BufferTest, Fill8Subspan) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5); - ASSERT_TRUE(buffer); - - // Test on subspan. - std::vector actual_data(buffer->allocation_size()); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 3)); - IREE_EXPECT_OK(subspan_buffer->Fill8(2, kWholeBuffer, 0xDDu)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0xDD, 0)); -} - -TEST(BufferTest, Fill16) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Data should be zeroed by default. - std::vector actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0)); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill16(0, 4, 0x1122u)); - - // Verify data. - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0)); - - // Zero fills are fine. - IREE_EXPECT_OK(buffer->Fill16(0, 0, 0x5566u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0)); - - // Fill the remaining parts of the buffer by using kWholeBuffer. - auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8); - IREE_EXPECT_OK(aligned_buffer->Fill16(4, kWholeBuffer, 0x5566u)); - std::vector aligned_actual_data(aligned_buffer->allocation_size()); - IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0, 0, 0, 0, 0x66, 0x55, 0x66, 0x55)); - - // Whole buffer helper. 
- IREE_EXPECT_OK(aligned_buffer->Fill16(0x5566u)); - IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0x66, 0x55, 0x66, 0x55, 0x66, 0x55, 0x66, 0x55)); -} - -TEST(BufferTest, Fill16OutOfRange) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Try to fill with invalid ranges. - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 444, 0x5566u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 4, 0x5566u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u))); -} - -TEST(BufferTest, Fill16Unaligned) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Try to fill with unaligned ranges. - EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(1, 4, 0x5566u))); - EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(0, 5, 0x5566u))); -} - -TEST(BufferTest, Fill16BadMode) { - // Fail to fill buffers not supporting mapping. - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE( - IsPermissionDenied(nonmapping_buffer->Fill16(0, kWholeBuffer, 0x99AAu))); - - // Fail to fill constant buffers. - std::vector const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE( - IsPermissionDenied(constant_buffer->Fill16(0, kWholeBuffer, 0x99AAu))); -} - -TEST(BufferTest, Fill16Subspan) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill16(0, 4, 0x1122u)); - - // Test on subspan. - std::vector actual_data(buffer->allocation_size()); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 2, 4)); - IREE_EXPECT_OK(subspan_buffer->Fill16(2, kWholeBuffer, 0xAABBu)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x22, 0x11, 0x22, 0x11, 0xBB, 0xAA, 0, 0, 0)); -} - -TEST(BufferTest, Fill32) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Data should be zeroed by default. - std::vector actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0)); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u)); - - // Verify data. - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0)); - - // Zero fills are fine. - IREE_EXPECT_OK(buffer->Fill32(0, 0, 0x55667788u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0)); - - // Fill the remaining parts of the buffer by using kWholeBuffer. - auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8); - IREE_EXPECT_OK(aligned_buffer->Fill32(4, kWholeBuffer, 0x55667788u)); - std::vector aligned_actual_data(aligned_buffer->allocation_size()); - IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0, 0, 0, 0, 0x88, 0x77, 0x66, 0x55)); - - // Whole buffer helper. 
- IREE_EXPECT_OK(aligned_buffer->Fill32(0x55667788u)); - IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0x88, 0x77, 0x66, 0x55, 0x88, 0x77, 0x66, 0x55)); -} - -TEST(BufferTest, Fill32OutOfRange) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Try to fill with invalid ranges. - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 444, 0x55667788u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 4, 0x55667788u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u))); -} - -TEST(BufferTest, Fill32Unaligned) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Try to fill with unaligned ranges. - EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(1, 4, 0x55667788u))); - EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(0, 5, 0x55667788u))); -} - -TEST(BufferTest, Fill32BadMode) { - // Fail to fill buffers not supporting mapping. - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu))); - - // Fail to fill constant buffers. - std::vector const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE(IsPermissionDenied( - constant_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu))); -} - -TEST(BufferTest, Fill32Subspan) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u)); - - // Test on subspan. - std::vector actual_data(buffer->allocation_size()); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 4, 4)); - IREE_EXPECT_OK(subspan_buffer->Fill32(0, kWholeBuffer, 0xAABBCCDDu)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x44, 0x33, 0x22, 0x11, 0xDD, 0xCC, 0xBB, 0xAA, 0)); -} - -TEST(BufferTest, ReadData) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Read the data back. - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Reading zero bytes is valid. - std::vector zero_data(0); - IREE_EXPECT_OK(buffer->ReadData(1, zero_data.data(), 0)); - - // Read a portion of the data. - std::vector partial_data(2); - IREE_EXPECT_OK(buffer->ReadData(1, partial_data.data(), 2)); - EXPECT_THAT(partial_data, ElementsAre(1, 2)); -} - -TEST(BufferTest, ReadDataOutOfRange) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Try to read out of range. 
- std::vector partial_data(2); - EXPECT_TRUE(IsOutOfRange(buffer->ReadData(0, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 1))); - EXPECT_TRUE(IsInvalidArgument( - buffer->ReadData(0, partial_data.data(), kWholeBuffer))); -} - -TEST(BufferTest, ReadDataBadMode) { - // Fail to read buffers not supporting mapping. - std::vector actual_data(1); - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->ReadData(0, actual_data.data(), 1))); -} - -TEST(BufferTest, ReadDataSubspan) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Test on subspan. - std::vector subspan_data(1); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2)); - IREE_EXPECT_OK(subspan_buffer->ReadData(1, subspan_data.data(), 1)); - EXPECT_THAT(subspan_data, ElementsAre(2)); -} - -TEST(BufferTest, WriteData) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Read the data back - should still match. - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Write over the entire buffer. - std::vector new_data = {10, 20, 30, 40}; - IREE_EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size())); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(new_data)); - - // Writing zero bytes is valid. - std::vector zero_data; - IREE_EXPECT_OK(buffer->WriteData(0, zero_data.data(), 0)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(new_data)); - - // Write over a portion of the buffer. - std::vector partial_data = {99}; - IREE_EXPECT_OK( - buffer->WriteData(1, partial_data.data(), partial_data.size())); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(10, 99, 30, 40)); -} - -TEST(BufferTest, WriteDataOutOfRange) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Try to write out of range. - std::vector partial_data = {99}; - EXPECT_TRUE(IsOutOfRange(buffer->WriteData(0, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 1))); - EXPECT_TRUE(IsInvalidArgument( - buffer->WriteData(0, partial_data.data(), kWholeBuffer))); -} - -TEST(BufferTest, WriteDataBadMode) { - std::vector actual_data(4); - - // Fail to write buffers not supporting mapping. - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->WriteData(0, actual_data.data(), 1))); - - // Fail to write to constant buffers. 
- std::vector const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kTransfer, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE( - IsPermissionDenied(constant_buffer->WriteData(0, actual_data.data(), 2))); -} - -TEST(BufferTest, WriteDataSubspan) { - std::vector src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Test on subspan. - std::vector subspan_data = {0xAA}; - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2)); - IREE_EXPECT_OK(subspan_buffer->WriteData(1, subspan_data.data(), 1)); - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xAA, 3)); -} - -TEST(BufferTest, CopyData) { - std::vector src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Copy of length 0 should not change the dest buffer. - IREE_EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 0, 0)); - std::vector actual_data(dst_data.size()); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(dst_data)); - - // Copy a subrange of the buffer. - IREE_EXPECT_OK(dst_buffer->CopyData(1, src_buffer.get(), 2, 2)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 3, 4)); - - // Copy the entire buffer using kWholeBuffer. This will adjust sizes - // to ensure that the min buffer is taken. We test both src and dst buffer - // offset/length calculations (note that some may end up as 0 copies). - IREE_EXPECT_OK(dst_buffer->CopyData(3, src_buffer.get(), 0, kWholeBuffer)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 0, 1)); - IREE_EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 2, kWholeBuffer)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(2, 3, 3, 0, 1)); - IREE_EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 3, kWholeBuffer)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 1)); - IREE_EXPECT_OK(dst_buffer->CopyData(4, src_buffer.get(), 0, kWholeBuffer)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 0)); -} - -TEST(BufferTest, CopyDataOutOfRange) { - std::vector src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Try to copy out of range of source and dest. 
- EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 0, 1))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(4, src_buffer.get(), 0, 4))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 1))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 0, 123))); - EXPECT_TRUE( - IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 123, 123))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 0))); -} - -TEST(BufferTest, CopyDataOverlapping) { - std::vector src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Test overlap. Non-overlapping regions should be fine, otherwise fail. - std::vector actual_data(dst_data.size()); - IREE_EXPECT_OK(dst_buffer->CopyData(0, dst_buffer.get(), 4, 1)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(4, 1, 2, 3, 4)); - EXPECT_TRUE( - IsInvalidArgument(dst_buffer->CopyData(2, dst_buffer.get(), 0, 3))); - EXPECT_TRUE( - IsInvalidArgument(dst_buffer->CopyData(0, dst_buffer.get(), 0, 3))); -} - -TEST(BufferTest, CopyDataBadMode) { - // Both source and target buffers must support mapping. - auto nonmapping_src_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - auto nonmapping_dst_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE(IsPermissionDenied(nonmapping_dst_buffer->CopyData( - 0, nonmapping_src_buffer.get(), 0, kWholeBuffer))); - EXPECT_TRUE(IsPermissionDenied(nonmapping_src_buffer->CopyData( - 0, nonmapping_dst_buffer.get(), 0, kWholeBuffer))); - - // Fail to copy into to constant buffers. - std::vector const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kTransfer, - absl::MakeConstSpan(const_data)); - std::vector src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - EXPECT_TRUE(IsPermissionDenied( - constant_buffer->CopyData(0, src_buffer.get(), 0, kWholeBuffer))); -} - -TEST(BufferTest, CopyDataSubspan) { - std::vector src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Test on subspan. - std::vector actual_data(dst_data.size()); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_src_buffer, - Buffer::Subspan(src_buffer, 1, 3)); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_dst_buffer, - Buffer::Subspan(dst_buffer, 2, 3)); - IREE_EXPECT_OK( - subspan_dst_buffer->CopyData(1, subspan_src_buffer.get(), 1, 2)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 2, 2, 3)); -} - -// NOTE: more tests related specifically to MappedMemory are in -// buffer_mapping_test.cc. 
This tests the MapMemory operation and enough to -// ensure the memory was mapped to the correct range and the HostBuffer and -// SubspanBuffer work as intended for basic usage. -TEST(BufferTest, MapMemory) { - std::vector src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // 0-length mappings are valid. - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, buffer->MapMemory(MemoryAccess::kRead, 0, 0)); - EXPECT_TRUE(mapping.empty()); - EXPECT_EQ(0, mapping.size()); - EXPECT_EQ(0, mapping.byte_length()); - EXPECT_NE(nullptr, mapping.data()); - IREE_ASSERT_OK_AND_ASSIGN(auto span, mapping.Subspan()); - EXPECT_TRUE(span.empty()); - mapping.reset(); - - // Map the whole buffer for reading. - IREE_ASSERT_OK_AND_ASSIGN(mapping, buffer->MapMemory( - MemoryAccess::kRead, 0, kWholeBuffer)); - EXPECT_EQ(src_data.size(), mapping.size()); - IREE_ASSERT_OK_AND_ASSIGN(span, mapping.Subspan()); - EXPECT_THAT(span, ElementsAre(0, 1, 2, 3, 4, 5, 6)); - mapping.reset(); - - // Map a portion of the buffer for reading. - IREE_ASSERT_OK_AND_ASSIGN( - mapping, buffer->MapMemory(MemoryAccess::kRead, 1, 2)); - EXPECT_EQ(2, mapping.size()); - IREE_ASSERT_OK_AND_ASSIGN(span, mapping.Subspan()); - EXPECT_THAT(span, ElementsAre(1, 2)); - mapping.reset(); -} - -TEST(BufferTest, MapMemoryNonByte) { - std::vector src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Map the buffer as non-byte values. - // Note that we'll round down to the number of valid elements at the - // alignment. - IREE_ASSERT_OK_AND_ASSIGN(auto mapping16, - buffer->MapMemory(MemoryAccess::kRead)); - EXPECT_EQ(3, mapping16.size()); - EXPECT_LE(6, mapping16.byte_length()); - IREE_ASSERT_OK_AND_ASSIGN(auto span16, mapping16.Subspan()); - EXPECT_THAT(span16, ElementsAre(0x0100, 0x0302, 0x0504)); - mapping16.reset(); -} - -TEST(BufferTest, MapMemoryOutOfRange) { - std::vector src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Test invalid mapping ranges. - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory(MemoryAccess::kRead, 0, 123).status())); - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory(MemoryAccess::kRead, 5, 1231).status())); - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory(MemoryAccess::kRead, 6, kWholeBuffer) - .status())); - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory(MemoryAccess::kRead, 1236, 1).status())); -} - -TEST(BufferTest, MapMemoryBadMode) { - std::vector src_data = {0, 1, 2, 3, 4, 5, 6}; - auto read_buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, - src_data.data(), src_data.size()); - ASSERT_TRUE(read_buffer); - - // Test mapping the read-only buffer for writing. 
- EXPECT_TRUE(IsPermissionDenied( - read_buffer->MapMemory(MemoryAccess::kWrite).status())); - EXPECT_TRUE(IsPermissionDenied( - read_buffer->MapMemory(MemoryAccess::kDiscardWrite).status())); - EXPECT_TRUE(IsPermissionDenied( - read_buffer - ->MapMemory(MemoryAccess::kRead | MemoryAccess::kDiscard) - .status())); - EXPECT_TRUE(IsInvalidArgument( - read_buffer->MapMemory(MemoryAccess::kNone).status())); -} - -TEST(BufferTest, MapMemoryWrite) { - std::vector src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Map and modify the data. We should see it when we read back. - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, buffer->MapMemory(MemoryAccess::kWrite, 1, 2)); - auto mutable_data = mapping.mutable_data(); - mutable_data[0] = 0xAA; - mutable_data[1] = 0xBB; - mapping.reset(); - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0xAA, 0xBB, 3, 4, 5, 6)); -} - -TEST(BufferTest, MapMemoryDiscard) { - std::vector src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Map for discard. Note that we can't really rely on the value of the data - // so we just trust that it's been discarded. It's a hint, anyway. We can be - // sure that the data we didn't want to discard is the same though. - std::vector actual_data(src_data.size()); - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, - buffer->MapMemory(MemoryAccess::kDiscardWrite, 1, 2)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, _, _, 3, 4, 5, 6)); - mapping.reset(); -} - -TEST(BufferTest, MapMemorySubspan) { - std::vector src_data = {0, 1, 2, 3, 4, 5, 6}; - auto parent_buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll, - src_data.data(), src_data.size()); - ASSERT_TRUE(parent_buffer); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, - Buffer::Subspan(parent_buffer, 1, 3)); - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, - subspan_buffer->MapMemory(MemoryAccess::kDiscardWrite, 1, 2)); - auto* mutable_data = mapping.mutable_data(); - mutable_data[0] = 0xCC; - mutable_data[1] = 0xDD; - mapping.reset(); - - std::vector actual_data(src_data.size()); - IREE_EXPECT_OK( - parent_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xCC, 0xDD, 4, 5, 6)); - - // Just here to make coverage happy; they are currently no-ops on the host. - // buffer_mapping_test.cc contains tests that ensure they are called - // correctly. 
- std::vector external_data = {0, 1, 2, 3, 4}; - auto external_buffer = HeapBuffer::WrapMutable( - MemoryType::kHostVisible | MemoryType::kHostCached, MemoryAccess::kAll, - BufferUsage::kAll, absl::MakeSpan(external_data)); - IREE_ASSERT_OK_AND_ASSIGN(auto external_subspan_buffer, - Buffer::Subspan(external_buffer, 0, 1)); - IREE_ASSERT_OK_AND_ASSIGN( - mapping, external_subspan_buffer->MapMemory(MemoryAccess::kAll)); - IREE_EXPECT_OK(mapping.Invalidate()); - IREE_EXPECT_OK(mapping.Flush()); -} - -} // namespace -} // namespace hal -} // namespace iree diff --git a/iree/hal/buffer_view.c b/iree/hal/buffer_view.c new file mode 100644 index 0000000000000..00875c5265463 --- /dev/null +++ b/iree/hal/buffer_view.c @@ -0,0 +1,325 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/buffer_view.h" + +#include "iree/base/tracing.h" +#include "iree/hal/allocator.h" +#include "iree/hal/detail.h" + +struct iree_hal_buffer_view_s { + iree_atomic_ref_count_t ref_count; + iree_hal_buffer_t* buffer; + iree_hal_element_type_t element_type; + iree_device_size_t byte_length; + iree_host_size_t shape_rank; + iree_hal_dim_t shape[]; +}; + +IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create( + iree_hal_buffer_t* buffer, const iree_hal_dim_t* shape, + iree_host_size_t shape_rank, iree_hal_element_type_t element_type, + iree_hal_buffer_view_t** out_buffer_view) { + IREE_ASSERT_ARGUMENT(buffer); + IREE_ASSERT_ARGUMENT(out_buffer_view); + + *out_buffer_view = NULL; + if (IREE_UNLIKELY(shape_rank > 0 && !shape)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "no shape dimensions specified"); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_t host_allocator = + iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(buffer)); + + // Allocate and initialize the iree_hal_buffer_view_t struct. + // Note that we have the dynamically-sized shape dimensions on the end. 
+ iree_hal_buffer_view_t* buffer_view = NULL; + iree_status_t status = iree_allocator_malloc( + host_allocator, + sizeof(*buffer_view) + sizeof(iree_hal_dim_t) * shape_rank, + (void**)&buffer_view); + if (iree_status_is_ok(status)) { + iree_atomic_ref_count_init(&buffer_view->ref_count); + buffer_view->buffer = buffer; + iree_hal_buffer_retain(buffer_view->buffer); + buffer_view->element_type = element_type; + buffer_view->byte_length = + iree_hal_element_byte_count(buffer_view->element_type); + buffer_view->shape_rank = shape_rank; + for (iree_host_size_t i = 0; i < shape_rank; ++i) { + buffer_view->shape[i] = shape[i]; + buffer_view->byte_length *= shape[i]; + } + *out_buffer_view = buffer_view; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view) { + if (IREE_LIKELY(buffer_view)) { + iree_atomic_ref_count_inc(&buffer_view->ref_count); + } +} + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view) { + if (IREE_LIKELY(buffer_view) && + iree_atomic_ref_count_dec(&buffer_view->ref_count) == 1) { + iree_hal_buffer_view_destroy(buffer_view); + } +} + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_view_destroy(iree_hal_buffer_view_t* buffer_view) { + iree_allocator_t host_allocator = iree_hal_allocator_host_allocator( + iree_hal_buffer_allocator(buffer_view->buffer)); + iree_hal_buffer_release(buffer_view->buffer); + iree_allocator_free(host_allocator, buffer_view); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_subview( + const iree_hal_buffer_view_t* buffer_view, + const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, + const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, + iree_hal_buffer_view_t** out_buffer_view) { + IREE_ASSERT_ARGUMENT(out_buffer_view); + + // NOTE: we rely on the compute range call to do parameter validation. 
+  iree_device_size_t start_offset = 0;
+  iree_device_size_t subview_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_view_compute_range(
+      buffer_view, start_indices, indices_count, lengths, lengths_count,
+      &start_offset, &subview_length));
+
+  iree_hal_buffer_t* subview_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_subspan(
+      buffer_view->buffer, start_offset, subview_length, &subview_buffer));
+
+  iree_status_t status =
+      iree_hal_buffer_view_create(subview_buffer, lengths, lengths_count,
+                                  buffer_view->element_type, out_buffer_view);
+  iree_hal_buffer_release(subview_buffer);
+  return status;
+}
+
+IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_view_buffer(
+    const iree_hal_buffer_view_t* buffer_view) {
+  IREE_ASSERT_ARGUMENT(buffer_view);
+  return buffer_view->buffer;
+}
+
+IREE_API_EXPORT iree_host_size_t IREE_API_CALL
+iree_hal_buffer_view_shape_rank(const iree_hal_buffer_view_t* buffer_view) {
+  IREE_ASSERT_ARGUMENT(buffer_view);
+  return buffer_view->shape_rank;
+}
+
+IREE_API_EXPORT const iree_hal_dim_t* IREE_API_CALL
+iree_hal_buffer_view_shape_dims(const iree_hal_buffer_view_t* buffer_view) {
+  IREE_ASSERT_ARGUMENT(buffer_view);
+  return buffer_view->shape;
+}
+
+IREE_API_EXPORT iree_hal_dim_t IREE_API_CALL iree_hal_buffer_view_shape_dim(
+    const iree_hal_buffer_view_t* buffer_view, iree_host_size_t index) {
+  IREE_ASSERT_ARGUMENT(buffer_view);
+  // NOTE: must be >= so that index == shape_rank is rejected; shape[] has
+  // exactly shape_rank entries.
+  if (IREE_UNLIKELY(index >= buffer_view->shape_rank)) {
+    return 0;
+  }
+  return buffer_view->shape[index];
+}
+
+IREE_API_EXPORT iree_host_size_t
+iree_hal_buffer_view_element_count(const iree_hal_buffer_view_t* buffer_view) {
+  IREE_ASSERT_ARGUMENT(buffer_view);
+  iree_host_size_t element_count = 1;
+  for (iree_host_size_t i = 0; i < buffer_view->shape_rank; ++i) {
+    element_count *= buffer_view->shape[i];
+  }
+  return element_count;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape(
+    const iree_hal_buffer_view_t* buffer_view, iree_host_size_t rank_capacity,
+    iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank) {
+  IREE_ASSERT_ARGUMENT(buffer_view);
+  IREE_ASSERT_ARGUMENT(out_shape);
+  if (out_shape_rank) {
+    *out_shape_rank = buffer_view->shape_rank;
+  }
+  if (rank_capacity < buffer_view->shape_rank) {
+    // Not an error; just a size query.
+ return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); + } + + for (iree_host_size_t i = 0; i < buffer_view->shape_rank; ++i) { + out_shape[i] = buffer_view->shape[i]; + } + + return iree_ok_status(); +} + +IREE_API_EXPORT iree_hal_element_type_t IREE_API_CALL +iree_hal_buffer_view_element_type(const iree_hal_buffer_view_t* buffer_view) { + IREE_ASSERT_ARGUMENT(buffer_view); + return buffer_view->element_type; +} + +IREE_API_EXPORT iree_host_size_t +iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view) { + IREE_ASSERT_ARGUMENT(buffer_view); + return iree_hal_element_byte_count(buffer_view->element_type); +} + +IREE_API_EXPORT iree_device_size_t IREE_API_CALL +iree_hal_buffer_view_byte_length(const iree_hal_buffer_view_t* buffer_view) { + IREE_ASSERT_ARGUMENT(buffer_view); + return buffer_view->byte_length; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_offset( + const iree_hal_buffer_view_t* buffer_view, const iree_hal_dim_t* indices, + iree_host_size_t indices_count, iree_device_size_t* out_offset) { + IREE_ASSERT_ARGUMENT(buffer_view); + return iree_hal_buffer_compute_view_offset( + buffer_view->shape, buffer_view->shape_rank, buffer_view->element_type, + indices, indices_count, out_offset); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_range( + const iree_hal_buffer_view_t* buffer_view, + const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, + const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, + iree_device_size_t* out_start_offset, iree_device_size_t* out_length) { + IREE_ASSERT_ARGUMENT(buffer_view); + return iree_hal_buffer_compute_view_range( + buffer_view->shape, buffer_view->shape_rank, buffer_view->element_type, + start_indices, indices_count, lengths, lengths_count, out_start_offset, + out_length); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_compute_view_size( + const iree_hal_dim_t* shape, iree_host_size_t shape_rank, + iree_hal_element_type_t element_type, + iree_device_size_t* out_allocation_size) { + IREE_ASSERT_ARGUMENT(shape); + IREE_ASSERT_ARGUMENT(out_allocation_size); + *out_allocation_size = 0; + iree_device_size_t byte_length = iree_hal_element_byte_count(element_type); + for (iree_host_size_t i = 0; i < shape_rank; ++i) { + byte_length *= shape[i]; + } + *out_allocation_size = byte_length; + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_compute_view_offset( + const iree_hal_dim_t* shape, iree_host_size_t shape_rank, + iree_hal_element_type_t element_type, const iree_hal_dim_t* indices, + iree_host_size_t indices_count, iree_device_size_t* out_offset) { + IREE_ASSERT_ARGUMENT(shape); + IREE_ASSERT_ARGUMENT(indices); + IREE_ASSERT_ARGUMENT(out_offset); + *out_offset = 0; + if (IREE_UNLIKELY(shape_rank != indices_count)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "shape rank/indices mismatch: %zu != %zu", + shape_rank, indices_count); + } + + iree_device_size_t offset = 0; + for (iree_host_size_t i = 0; i < indices_count; ++i) { + if (IREE_UNLIKELY(indices[i] >= shape[i])) { + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "index[%zu] out of bounds: %d >= %d", i, + indices[i], shape[i]); + } + iree_device_size_t axis_offset = indices[i]; + for (iree_host_size_t j = i + 1; j < shape_rank; ++j) { + axis_offset *= shape[j]; + } + offset += axis_offset; + } + offset *= iree_hal_element_byte_count(element_type); + + *out_offset = offset; + return iree_ok_status(); +} + 
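The two helpers above are plain row-major layout math: the allocation size is the element byte width times the product of the dimensions, and an index vector is linearized against its trailing dimensions before being scaled to bytes. A minimal caller-side sketch (illustrative only, not part of this diff; the function name here is hypothetical, everything else is declared in buffer_view.h below):

```c
#include "iree/hal/buffer_view.h"

// Worked example of the view math for a 2x3 f32 view.
static iree_status_t buffer_view_math_example(void) {
  const iree_hal_dim_t shape[2] = {2, 3};

  // Allocation size: 4 bytes/element * 2 * 3 = 24 bytes.
  iree_device_size_t allocation_size = 0;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_compute_view_size(
      shape, 2, IREE_HAL_ELEMENT_TYPE_FLOAT_32, &allocation_size));
  // allocation_size == 24

  // Byte offset of element [1, 2]: (1 * 3 + 2) * 4 bytes = 20.
  const iree_hal_dim_t indices[2] = {1, 2};
  iree_device_size_t byte_offset = 0;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_compute_view_offset(
      shape, 2, IREE_HAL_ELEMENT_TYPE_FLOAT_32, indices, 2, &byte_offset));
  // byte_offset == 20

  return iree_ok_status();
}
```

The same linearization is what iree_hal_buffer_compute_view_range (next) uses to turn a start/length index box into a byte span, rejecting boxes that are not contiguous in memory.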
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_compute_view_range( + const iree_hal_dim_t* shape, iree_host_size_t shape_rank, + iree_hal_element_type_t element_type, const iree_hal_dim_t* start_indices, + iree_host_size_t indices_count, const iree_hal_dim_t* lengths, + iree_host_size_t lengths_count, iree_device_size_t* out_start_offset, + iree_device_size_t* out_length) { + IREE_ASSERT_ARGUMENT(shape); + IREE_ASSERT_ARGUMENT(start_indices); + IREE_ASSERT_ARGUMENT(lengths); + IREE_ASSERT_ARGUMENT(out_start_offset); + IREE_ASSERT_ARGUMENT(out_length); + *out_start_offset = 0; + *out_length = 0; + if (IREE_UNLIKELY(indices_count != lengths_count)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "indices/lengths mismatch: %zu != %zu", + indices_count, lengths_count); + } + if (IREE_UNLIKELY(shape_rank != indices_count)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "shape rank/indices mismatch: %zu != %zu", + shape_rank, indices_count); + } + + iree_hal_dim_t* end_indices = + iree_alloca(shape_rank * sizeof(iree_hal_dim_t)); + iree_device_size_t element_size = iree_hal_element_byte_count(element_type); + iree_device_size_t subspan_length = element_size; + for (iree_host_size_t i = 0; i < lengths_count; ++i) { + subspan_length *= lengths[i]; + end_indices[i] = start_indices[i] + lengths[i] - 1; + } + + iree_device_size_t start_byte_offset = 0; + IREE_RETURN_IF_ERROR(iree_hal_buffer_compute_view_offset( + shape, shape_rank, element_type, start_indices, indices_count, + &start_byte_offset)); + iree_device_size_t end_byte_offset = 0; + IREE_RETURN_IF_ERROR(iree_hal_buffer_compute_view_offset( + shape, shape_rank, element_type, end_indices, shape_rank, + &end_byte_offset)); + + // Non-contiguous regions not yet implemented. Will be easier to detect when + // we have strides. + iree_device_size_t offset_length = + end_byte_offset - start_byte_offset + element_size; + if (subspan_length != offset_length) { + return iree_make_status( + IREE_STATUS_UNIMPLEMENTED, + "non-contiguous range region computation not implemented"); + } + + *out_start_offset = start_byte_offset; + *out_length = subspan_length; + return iree_ok_status(); +} diff --git a/iree/hal/buffer_view.cc b/iree/hal/buffer_view.cc new file mode 100644 index 0000000000000..850b140ff5f3e --- /dev/null +++ b/iree/hal/buffer_view.cc @@ -0,0 +1,230 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "iree/hal/buffer_view.h" + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/ascii.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "iree/base/api.h" +#include "iree/base/memory.h" +#include "iree/base/tracing.h" +#include "iree/hal/allocator.h" +#include "iree/hal/string_util.h" + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_parse( + iree_string_view_t value, iree_hal_allocator_t* buffer_allocator, + iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) { + IREE_TRACE_SCOPE0("iree_hal_buffer_view_parse"); + IREE_ASSERT_ARGUMENT(buffer_allocator); + + // Strip whitespace that may come along (linefeeds/etc). + auto string_view = + absl::StripAsciiWhitespace(absl::string_view(value.data, value.size)); + string_view = absl::StripPrefix(string_view, "\""); + string_view = absl::StripSuffix(string_view, "\""); + if (string_view.empty()) { + // Empty lines are invalid; need at least the shape/type information. + *out_buffer_view = nullptr; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "empty string input"); + } + + // The part of the string corresponding to the shape, e.g. 1x2x3. + absl::string_view shape_str; + // The part of the string corresponding to the type, e.g. f32 + absl::string_view type_str; + // The part of the string corresponding to the buffer data, e.g. 1 2 3 4 5 6 + absl::string_view data_str; + + absl::string_view shape_and_type_str; + auto equal_index = string_view.find('='); + if (equal_index == std::string::npos) { + // Treat a lack of = as defaulting the data to zeros. + shape_and_type_str = string_view; + } else { + shape_and_type_str = string_view.substr(0, equal_index); + data_str = string_view.substr(equal_index + 1); + } + auto last_x_index = shape_and_type_str.rfind('x'); + if (last_x_index == std::string::npos) { + // Scalar. + type_str = shape_and_type_str; + } else { + // Has a shape. + shape_str = shape_and_type_str.substr(0, last_x_index); + type_str = shape_and_type_str.substr(last_x_index + 1); + } + + // AxBxC... + absl::InlinedVector shape(6); + iree_host_size_t shape_rank = 0; + iree_status_t shape_result = + iree_hal_parse_shape({shape_str.data(), shape_str.length()}, shape.size(), + shape.data(), &shape_rank); + if (iree_status_is_ok(shape_result)) { + shape.resize(shape_rank); + } else if (iree_status_is_out_of_range(shape_result)) { + shape.resize(shape_rank); + IREE_RETURN_IF_ERROR( + iree_hal_parse_shape({shape_str.data(), shape_str.length()}, + shape.size(), shape.data(), &shape_rank)); + } else { + return shape_result; + } + + // f32, i32, etc + iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE; + IREE_RETURN_IF_ERROR(iree_hal_parse_element_type( + {type_str.data(), type_str.length()}, &element_type)); + + // Allocate the buffer we will parse into from the provided allocator. 
+  iree_device_size_t buffer_length = 0;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_compute_view_size(
+      shape.data(), shape.size(), element_type, &buffer_length));
+  iree_hal_buffer_t* buffer = nullptr;
+  IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
+      buffer_allocator,
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+      IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+      buffer_length, &buffer));
+
+  iree_status_t status;
+
+  // Parse the elements directly into the buffer.
+  iree_hal_buffer_mapping_t buffer_mapping;
+  status =
+      iree_hal_buffer_map_range(buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0,
+                                buffer_length, &buffer_mapping);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_buffer_release(buffer);
+    return status;
+  }
+  status =
+      iree_hal_parse_buffer_elements({data_str.data(), data_str.length()},
+                                     element_type, buffer_mapping.contents);
+  iree_hal_buffer_unmap_range(&buffer_mapping);
+  if (!iree_status_is_ok(status)) {
+    iree_hal_buffer_release(buffer);
+    return status;
+  }
+
+  // Wrap and pass ownership of the buffer to the buffer view.
+  status = iree_hal_buffer_view_create(buffer, shape.data(), shape.size(),
+                                       element_type, out_buffer_view);
+  iree_hal_buffer_release(buffer);
+  return status;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_format(
+    const iree_hal_buffer_view_t* buffer_view,
+    iree_host_size_t max_element_count, iree_host_size_t buffer_capacity,
+    char* buffer, iree_host_size_t* out_buffer_length) {
+  IREE_TRACE_SCOPE0("iree_hal_buffer_view_format");
+  IREE_ASSERT_ARGUMENT(buffer_view);
+  if (out_buffer_length) {
+    *out_buffer_length = 0;
+  }
+  if (buffer && buffer_capacity) {
+    buffer[0] = 0;
+  }
+
+  iree_status_t status;
+  iree_host_size_t buffer_length = 0;
+  // Appends a single character, dropping to measurement-only mode
+  // (buffer = nullptr) once the capacity is exhausted.
+  auto append_char = [&](char c) {
+    if (buffer) {
+      // NOTE: written as length + 1 < capacity to avoid unsigned underflow
+      // when |buffer_capacity| is 0.
+      if (buffer_length + 1 < buffer_capacity) {
+        buffer[buffer_length] = c;
+        buffer[buffer_length + 1] = '\0';
+      } else {
+        buffer = nullptr;
+      }
+    }
+    ++buffer_length;
+  };
+
+  if (iree_hal_buffer_view_shape_rank(buffer_view) > 0) {
+    // Shape: 1x2x3
+    iree_host_size_t shape_length = 0;
+    status = iree_hal_format_shape(iree_hal_buffer_view_shape_dims(buffer_view),
+                                   iree_hal_buffer_view_shape_rank(buffer_view),
+                                   buffer ? buffer_capacity - buffer_length : 0,
+                                   buffer ? buffer + buffer_length : nullptr,
+                                   &shape_length);
+    buffer_length += shape_length;
+    if (iree_status_is_out_of_range(status)) {
+      status = iree_status_ignore(status);
+      buffer = nullptr;
+    } else if (!iree_status_is_ok(status)) {
+      return status;
+    }
+
+    // Separator: x
+    append_char('x');
+  }
+
+  // Element type: f32
+  iree_host_size_t element_type_length = 0;
+  status = iree_hal_format_element_type(
+      iree_hal_buffer_view_element_type(buffer_view),
+      buffer ? buffer_capacity - buffer_length : 0,
+      buffer ? buffer + buffer_length : nullptr, &element_type_length);
+  buffer_length += element_type_length;
+  if (iree_status_is_out_of_range(status)) {
+    status = iree_status_ignore(status);
+    buffer = nullptr;
+  } else if (!iree_status_is_ok(status)) {
+    return status;
+  }
+
+  // Separator: =
+  append_char('=');
+
+  // Buffer contents: 0 1 2 3 ...
+ iree_hal_buffer_mapping_t buffer_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + iree_hal_buffer_view_buffer(buffer_view), IREE_HAL_MEMORY_ACCESS_READ, 0, + IREE_WHOLE_BUFFER, &buffer_mapping)); + iree_host_size_t elements_length = 0; + status = iree_hal_format_buffer_elements( + iree_const_byte_span_t{buffer_mapping.contents.data, + buffer_mapping.contents.data_length}, + iree_hal_buffer_view_shape_dims(buffer_view), + iree_hal_buffer_view_shape_rank(buffer_view), + iree_hal_buffer_view_element_type(buffer_view), max_element_count, + buffer ? buffer_capacity - buffer_length : 0, + buffer ? buffer + buffer_length : nullptr, &elements_length); + buffer_length += elements_length; + iree_hal_buffer_unmap_range(&buffer_mapping); + if (iree_status_is_out_of_range(status)) { + status = iree_status_ignore(status); + buffer = nullptr; + } else if (!iree_status_is_ok(status)) { + return status; + } + + if (out_buffer_length) { + *out_buffer_length = buffer_length; + } + return buffer ? iree_ok_status() + : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); +} diff --git a/iree/hal/buffer_view.h b/iree/hal/buffer_view.h new file mode 100644 index 0000000000000..8e76cc64449b8 --- /dev/null +++ b/iree/hal/buffer_view.h @@ -0,0 +1,241 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_BUFFER_VIEW_H_ +#define IREE_HAL_BUFFER_VIEW_H_ + +#include +#include + +#include "iree/base/api.h" +#include "iree/hal/buffer.h" +#include "iree/hal/resource.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// + +// NOTE: these values must be in sync with +// iree/compiler/Dialect/HAL/IR/HALTypes.cpp + +enum iree_hal_numerical_type_e { + IREE_HAL_NUMERICAL_TYPE_UNKNOWN = 0x00u, + IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED = 0x01u, + IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED = 0x02u, + // TODO(benvanik): specialize with semantics from APFloat. + IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE = 0x03u, +}; +typedef uint8_t iree_hal_numerical_type_t; + +#define IREE_HAL_ELEMENT_TYPE_VALUE(numerical_type, bit_count) \ + (((uint32_t)(numerical_type) << 24) | (uint32_t)(bit_count)) + +#define iree_hal_make_element_type(numerical_type, bit_count) \ + (iree_hal_element_type_t)( \ + IREE_HAL_ELEMENT_TYPE_VALUE(numerical_type, bit_count)) +#define iree_hal_element_numerical_type(element_type) \ + (iree_hal_numerical_type_t)((uint32_t)(element_type) >> 24) +#define iree_hal_element_bit_count(element_type) (size_t)((element_type)&0xFF) +#define iree_hal_element_byte_count(element_type) \ + ((iree_hal_element_bit_count(element_type) + 8 - 1) / 8) + +// Defines the element type of a buffer in a standard format. +// +// Composed as a 32-bit bitfield to allow for opaque data types. Use +// iree_hal_make_element_type to make a bitfield with the appropriate ordering. 
+//
+// MSB ----------------------------------------------- LSB
+// [numerical type] [reserved] [reserved] [number of bits]
+//
+// clang-format off
+enum iree_hal_element_type_e {
+  IREE_HAL_ELEMENT_TYPE_NONE = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 0), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_OPAQUE_8 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 8), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_OPAQUE_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 16), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_OPAQUE_32 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 32), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_OPAQUE_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, 64), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_SINT_8 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 8), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_UINT_8 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, 8), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_SINT_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 16), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_UINT_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, 16), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_SINT_32 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 32), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_UINT_32 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, 32), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_SINT_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, 64), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_UINT_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, 64), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_FLOAT_16 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE, 16), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_FLOAT_32 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE, 32), // NOLINT
+  IREE_HAL_ELEMENT_TYPE_FLOAT_64 = IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE, 64), // NOLINT
+};
+typedef uint32_t iree_hal_element_type_t;
+// clang-format on
+
+// A dimension within a shape.
+typedef int32_t iree_hal_dim_t;
+
+//===----------------------------------------------------------------------===//
+// Buffer view math
+//===----------------------------------------------------------------------===//
+
+// Calculates the allocation size of a buffer view.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_compute_view_size(
+    const iree_hal_dim_t* shape, iree_host_size_t shape_rank,
+    iree_hal_element_type_t element_type,
+    iree_device_size_t* out_allocation_size);
+
+// Calculates a byte offset into a buffer at the given indices.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_compute_view_offset(
+    const iree_hal_dim_t* shape, iree_host_size_t shape_rank,
+    iree_hal_element_type_t element_type, const iree_hal_dim_t* indices,
+    iree_host_size_t indices_count, iree_device_size_t* out_offset);
+
+// Calculates a byte range into a buffer of the given contiguous range.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_compute_view_range( + const iree_hal_dim_t* shape, iree_host_size_t shape_rank, + iree_hal_element_type_t element_type, const iree_hal_dim_t* start_indices, + iree_host_size_t indices_count, const iree_hal_dim_t* lengths, + iree_host_size_t lengths_count, iree_device_size_t* out_start_offset, + iree_device_size_t* out_length); + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_t +//===----------------------------------------------------------------------===// + +// A shaped and typed view into a storage buffer. +// This is the closest thing to a "tensor" we have, and it's purely used to ease +// application code and not treated special internally by IREE. They are +// effectively just `tuple(shape, type, buffer)`, and if the application is +// already tracking this information in its own structures this entire type can +// be ignored. +typedef struct iree_hal_buffer_view_s iree_hal_buffer_view_t; + +// Creates a buffer view with the given |buffer|. +// |out_buffer_view| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_create( + iree_hal_buffer_t* buffer, const iree_hal_dim_t* shape, + iree_host_size_t shape_rank, iree_hal_element_type_t element_type, + iree_hal_buffer_view_t** out_buffer_view); + +// Creates a buffer view referencing a subview of the given |buffer_view|. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_subview( + const iree_hal_buffer_view_t* buffer_view, + const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, + const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, + iree_hal_buffer_view_t** out_buffer_view); + +// Retains the given |buffer_view| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view); + +// Releases the given |buffer_view| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view); + +// Returns the buffer underlying the buffer view. +// The caller must retain the returned buffer if they want to continue using it. +IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL +iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view); + +// Returns the rank of the shape associated with the buffer view. +IREE_API_EXPORT iree_host_size_t IREE_API_CALL +iree_hal_buffer_view_shape_rank(const iree_hal_buffer_view_t* buffer_view); + +// Returns a pointer to the shape dimensions; the array limit is defined by +// iree_hal_buffer_view_shape_rank. +IREE_API_EXPORT const iree_hal_dim_t* IREE_API_CALL +iree_hal_buffer_view_shape_dims(const iree_hal_buffer_view_t* buffer_view); + +// Returns the value of the given dimension. +IREE_API_EXPORT iree_hal_dim_t IREE_API_CALL iree_hal_buffer_view_shape_dim( + const iree_hal_buffer_view_t* buffer_view, iree_host_size_t index); + +// Returns the dimensions of the shape in |out_shape| and its rank in +// |out_shape_rank|. |rank_capacity| indicates the number of dimensions +// available in the |out_shape| buffer. If there is not enough capacity to store +// all of the dimensions IREE_STATUS_OUT_OF_RANGE is returned. +// |out_shape_rank| can be omitted if the rank is already known. 
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape( + const iree_hal_buffer_view_t* buffer_view, iree_host_size_t rank_capacity, + iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank); + +// Returns the total number of elements stored in the view. +IREE_API_EXPORT iree_host_size_t +iree_hal_buffer_view_element_count(const iree_hal_buffer_view_t* buffer_view); + +// Returns the element type of the buffer. +IREE_API_EXPORT iree_hal_element_type_t IREE_API_CALL +iree_hal_buffer_view_element_type(const iree_hal_buffer_view_t* buffer_view); + +// Returns the size of each element in the buffer view in bytes. +// Note that not all buffers are contiguous or densely packed. +IREE_API_EXPORT iree_host_size_t IREE_API_CALL +iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view); + +// Returns the total size of the specified view in bytes. +// Note that not all buffers are contiguous or densely packed. +IREE_API_EXPORT iree_device_size_t IREE_API_CALL +iree_hal_buffer_view_byte_length(const iree_hal_buffer_view_t* buffer_view); + +// Calculates a byte offset into the |buffer_view| at the given indices. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_offset( + const iree_hal_buffer_view_t* buffer_view, const iree_hal_dim_t* indices, + iree_host_size_t indices_count, iree_device_size_t* out_offset); + +// Calculates a byte range into the |buffer_view| of the given contiguous range. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_compute_range( + const iree_hal_buffer_view_t* buffer_view, + const iree_hal_dim_t* start_indices, iree_host_size_t indices_count, + const iree_hal_dim_t* lengths, iree_host_size_t lengths_count, + iree_device_size_t* out_start_offset, iree_device_size_t* out_length); + +// Parses a serialized set of buffer elements in the canonical tensor format +// (the same as produced by iree_hal_buffer_view_format). The underlying buffer +// will be allocated with |buffer_allocator| as a host-local/device-visible +// buffer. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_parse( + iree_string_view_t value, iree_hal_allocator_t* buffer_allocator, + iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view); + +// Converts buffer view elements into a fully-specified string-form format like +// `2x4xi16=[[1 2][3 4]]`. +// +// |max_element_count| can be used to limit the total number of elements printed +// when the count may be large. Elided elements will be replaced with `...`. +// +// |buffer_capacity| defines the size of |buffer| in bytes and +// |out_buffer_length| will return the string length in characters. Returns +// IREE_STATUS_OUT_OF_RANGE if the buffer capacity is insufficient to hold the +// formatted elements and |out_buffer_length| will contain the required size. +// +// Follows the standard API string formatting rules. See iree/base/api.h. 
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_format( + const iree_hal_buffer_view_t* buffer_view, + iree_host_size_t max_element_count, iree_host_size_t buffer_capacity, + char* buffer, iree_host_size_t* out_buffer_length); + +//===----------------------------------------------------------------------===// +// iree_hal_buffer_view_t implementation details +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_buffer_view_destroy(iree_hal_buffer_view_t* buffer_view); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_BUFFER_VIEW_H_ diff --git a/iree/hal/command_buffer.c b/iree/hal/command_buffer.c new file mode 100644 index 0000000000000..c8a605817413b --- /dev/null +++ b/iree/hal/command_buffer.c @@ -0,0 +1,272 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/command_buffer.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" +#include "iree/hal/device.h" + +#define _VTABLE_DISPATCH(command_buffer, method_name) \ + IREE_HAL_VTABLE_DISPATCH(command_buffer, iree_hal_command_buffer, method_name) + +IREE_HAL_API_RETAIN_RELEASE(command_buffer); + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_create( + iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, create_command_buffer)( + device, mode, command_categories, out_command_buffer); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_hal_command_category_t IREE_API_CALL +iree_hal_command_buffer_allowed_categories( + const iree_hal_command_buffer_t* command_buffer) { + IREE_ASSERT_ARGUMENT(command_buffer); + return _VTABLE_DISPATCH(command_buffer, allowed_categories)(command_buffer); +} + +IREE_API_EXPORT iree_status_t +iree_hal_command_buffer_begin(iree_hal_command_buffer_t* command_buffer) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + _VTABLE_DISPATCH(command_buffer, begin)(command_buffer); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t +iree_hal_command_buffer_end(iree_hal_command_buffer_t* command_buffer) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, end)(command_buffer); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_execution_barrier( + iree_hal_command_buffer_t* command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const 
iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, execution_barrier)( + command_buffer, source_stage_mask, target_stage_mask, + memory_barrier_count, memory_barriers, buffer_barrier_count, + buffer_barriers); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_signal_event( + iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(event); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, signal_event)( + command_buffer, event, source_stage_mask); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_reset_event( + iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(event); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, reset_event)( + command_buffer, event, source_stage_mask); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_wait_events( + iree_hal_command_buffer_t* command_buffer, iree_host_size_t event_count, + const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(!event_count || events); + IREE_ASSERT_ARGUMENT(!memory_barrier_count || memory_barriers); + IREE_ASSERT_ARGUMENT(!buffer_barrier_count || buffer_barriers); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, wait_events)( + command_buffer, event_count, events, source_stage_mask, target_stage_mask, + memory_barrier_count, memory_barriers, buffer_barrier_count, + buffer_barriers); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_discard_buffer( + iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* buffer) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + _VTABLE_DISPATCH(command_buffer, discard_buffer)(command_buffer, buffer); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_fill_buffer( + iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length, + const void* pattern, iree_host_size_t pattern_length) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(target_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, fill_buffer)( + command_buffer, target_buffer, target_offset, length, pattern, + pattern_length); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_update_buffer(iree_hal_command_buffer_t* command_buffer, + 
const void* source_buffer, + iree_host_size_t source_offset, + iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, + iree_device_size_t length) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(source_buffer); + IREE_ASSERT_ARGUMENT(target_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, update_buffer)( + command_buffer, source_buffer, source_offset, target_buffer, + target_offset, length); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_copy_buffer( + iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer, + iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, copy_buffer)( + command_buffer, source_buffer, source_offset, target_buffer, + target_offset, length); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_push_constants( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(executable_layout); + IREE_ASSERT_ARGUMENT(values); + if (IREE_UNLIKELY(values_length == 0)) { + return iree_ok_status(); + } + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, push_constants)( + command_buffer, executable_layout, offset, values, values_length); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(executable_layout); + IREE_ASSERT_ARGUMENT(!binding_count || bindings); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, push_descriptor_set)( + command_buffer, executable_layout, set, binding_count, bindings); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_bind_descriptor_set( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_hal_descriptor_set_t* descriptor_set, + iree_host_size_t dynamic_offset_count, + const iree_device_size_t* dynamic_offsets) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(executable_layout); + IREE_ASSERT_ARGUMENT(descriptor_set); + IREE_ASSERT_ARGUMENT(!dynamic_offset_count || dynamic_offsets); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, bind_descriptor_set)( + command_buffer, executable_layout, set, descriptor_set, + dynamic_offset_count, dynamic_offsets); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_dispatch( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(executable); + 
IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch)( + command_buffer, executable, entry_point, workgroup_x, workgroup_y, + workgroup_z); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, + iree_device_size_t workgroups_offset) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(executable); + IREE_ASSERT_ARGUMENT(workgroups_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch_indirect)( + command_buffer, executable, entry_point, workgroups_buffer, + workgroups_offset); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/command_buffer.cc b/iree/hal/command_buffer.cc deleted file mode 100644 index d83810a57c9e3..0000000000000 --- a/iree/hal/command_buffer.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/command_buffer.h" - -namespace iree { -namespace hal { - -std::string CommandCategoryString(CommandCategoryBitfield categories) { - return FormatBitfieldValue(categories, - { - {CommandCategory::kTransfer, "kTransfer"}, - {CommandCategory::kDispatch, "kDispatch"}, - }); -} - -} // namespace hal -} // namespace iree diff --git a/iree/hal/command_buffer.h b/iree/hal/command_buffer.h index a94b53278eee3..7f767e26195dd 100644 --- a/iree/hal/command_buffer.h +++ b/iree/hal/command_buffer.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,10 +15,10 @@ #ifndef IREE_HAL_COMMAND_BUFFER_H_ #define IREE_HAL_COMMAND_BUFFER_H_ -#include +#include +#include -#include "iree/base/bitfield.h" -#include "iree/base/status.h" +#include "iree/base/api.h" #include "iree/hal/buffer.h" #include "iree/hal/descriptor_set.h" #include "iree/hal/event.h" @@ -26,78 +26,89 @@ #include "iree/hal/executable_layout.h" #include "iree/hal/resource.h" -namespace iree { -namespace hal { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_device_s iree_hal_device_t; + +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// // A bitfield specifying the mode of operation for a command buffer. -enum class CommandBufferMode : uint32_t { +enum iree_hal_command_buffer_mode_e { // Command buffer will be submitted once and never used again. - // This may enable in-place patching of command buffers that reduces overhead + // This may enable in-place patching of command buffers that reduce overhead // when it's known that command buffers will not be reused. 
- kOneShot = 1 << 0, + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT = 1u << 0, + + // TODO(benvanik): IREE_HAL_COMMAND_BUFFER_MODE_REUSABLE = 1u << 1, + // TODO(benvanik): IREE_HAL_COMMAND_BUFFER_MODE_PRIMARY = 1u << 2, + // TODO(benvanik): IREE_HAL_COMMAND_BUFFER_MODE_SECONDARY = 1u << 3, }; -IREE_BITFIELD(CommandBufferMode); -using CommandBufferModeBitfield = CommandBufferMode; -std::string CommandBufferModeString(CommandBufferModeBitfield mode); +typedef uint32_t iree_hal_command_buffer_mode_t; // A bitfield specifying the category of commands in a command queue. -enum class CommandCategory : uint32_t { +enum iree_hal_command_category_e { // Command is considered a transfer operation (memcpy, etc). - kTransfer = 1 << 0, + IREE_HAL_COMMAND_CATEGORY_TRANSFER = 1u << 0, // Command is considered a dispatch operation (dispatch/execute). - kDispatch = 1 << 1, + IREE_HAL_COMMAND_CATEGORY_DISPATCH = 1u << 1, + // Commands may be of any type. + // Using this value may prevent optimizations and if possible callers should + // always specify the strictest set possible (for example, only transfer + // commands to ensure they get placed on a DMA queue). + IREE_HAL_COMMAND_CATEGORY_ANY = + IREE_HAL_COMMAND_CATEGORY_TRANSFER | IREE_HAL_COMMAND_CATEGORY_DISPATCH, }; -IREE_BITFIELD(CommandCategory); -using CommandCategoryBitfield = CommandCategory; -std::string CommandCategoryString(CommandCategoryBitfield categories); +typedef uint32_t iree_hal_command_category_t; -// Bitfield specifying which execution stage a brarrier should start/end at. +// Bitfield specifying which execution stage a barrier should start/end at. // // Maps to VkPipelineStageFlagBits. -enum class ExecutionStage : uint32_t { +enum iree_hal_execution_stage_e { // Top of the pipeline when commands are initially issued by the device. - kCommandIssue = 1 << 0, + IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE = 1u << 0, // Stage of the pipeline when dispatch parameter data is consumed. - kCommandProcess = 1 << 1, + IREE_HAL_EXECUTION_STAGE_COMMAND_PROCESS = 1u << 1, // Stage where dispatch commands execute. - kDispatch = 1 << 2, + IREE_HAL_EXECUTION_STAGE_DISPATCH = 1u << 2, // Stage where transfer (copy/clear/fill/etc) commands execute. - kTransfer = 1 << 3, + IREE_HAL_EXECUTION_STAGE_TRANSFER = 1u << 3, // Final stage in the pipeline when commands are retired on the device. - kCommandRetire = 1 << 4, + IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE = 1u << 4, // Pseudo-stage for read/writes by the host. Not executed on device. - kHost = 1 << 5, + IREE_HAL_EXECUTION_STAGE_HOST = 1u << 5, }; -IREE_BITFIELD(ExecutionStage); -using ExecutionStageBitfield = ExecutionStage; +typedef uint32_t iree_hal_execution_stage_t; // Bitfield specifying which scopes will access memory and how. // // Maps to VkAccessFlagBits. -enum class AccessScope : uint32_t { +enum iree_hal_access_scope_e { // Read access to indirect command data as part of an indirect dispatch. - kIndirectCommandRead = 1 << 0, + IREE_HAL_ACCESS_SCOPE_INDIRECT_COMMAND_READ = 1u << 0, // Constant uniform buffer reads by the device. - kConstantRead = 1 << 1, + IREE_HAL_ACCESS_SCOPE_CONSTANT_READ = 1u << 1, // Storage buffer reads by dispatch commands. - kDispatchRead = 1 << 2, + IREE_HAL_ACCESS_SCOPE_DISPATCH_READ = 1u << 2, // Storage buffer writes by dispatch commands. - kDispatchWrite = 1 << 3, + IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE = 1u << 3, // Source of a transfer operation. - kTransferRead = 1 << 4, + IREE_HAL_ACCESS_SCOPE_TRANSFER_READ = 1u << 4, // Target of a transfer operation. 
-  kTransferWrite = 1 << 5,
+  IREE_HAL_ACCESS_SCOPE_TRANSFER_WRITE = 1u << 5,
   // Read operation by the host through mapped memory.
-  kHostRead = 1 << 6,
+  IREE_HAL_ACCESS_SCOPE_HOST_READ = 1u << 6,
   // Write operation by the host through mapped memory.
-  kHostWrite = 1 << 7,
+  IREE_HAL_ACCESS_SCOPE_HOST_WRITE = 1u << 7,
   // External/non-specific read.
-  kMemoryRead = 1 << 8,
+  IREE_HAL_ACCESS_SCOPE_MEMORY_READ = 1u << 8,
   // External/non-specific write.
-  kMemoryWrite = 1 << 9,
+  IREE_HAL_ACCESS_SCOPE_MEMORY_WRITE = 1u << 9,
 };
-IREE_BITFIELD(AccessScope);
-using AccessScopeBitfield = AccessScope;
+typedef uint32_t iree_hal_access_scope_t;
 
 // Defines a global memory barrier.
 // These are cheaper to encode than buffer-specific barriers but may cause
@@ -106,12 +117,12 @@ using AccessScopeBitfield = AccessScope;
 // completely changing execution contexts).
 //
 // Maps to VkMemoryBarrier.
-struct MemoryBarrier {
+typedef struct {
   // All access scopes prior-to the barrier (inclusive).
-  AccessScopeBitfield source_scope;
+  iree_hal_access_scope_t source_scope;
   // All access scopes following the barrier (inclusive).
-  AccessScopeBitfield target_scope;
-};
+  iree_hal_access_scope_t target_scope;
+} iree_hal_memory_barrier_t;
 
 // Defines a memory barrier that applies to a range of a specific buffer.
 // Use of these (vs. global memory barriers) provides fine-grained execution
@@ -119,19 +130,30 @@ struct MemoryBarrier {
 // reordering.
 //
 // Maps to VkBufferMemoryBarrier.
-struct BufferBarrier {
+typedef struct {
   // All access scopes prior-to the barrier (inclusive).
-  AccessScopeBitfield source_scope;
+  iree_hal_access_scope_t source_scope;
   // All access scopes following the barrier (inclusive).
-  AccessScopeBitfield target_scope;
+  iree_hal_access_scope_t target_scope;
   // Buffer the barrier is restricted to.
   // The barrier will apply to the entire physical device allocation.
-  Buffer* buffer = nullptr;
+  iree_hal_buffer_t* buffer;
   // Relative offset/length within |buffer| (which may itself be mapped into the
   // device allocation at an offset).
-  device_size_t offset = 0;
-  device_size_t length = kWholeBuffer;
-};
+  iree_device_size_t offset;
+  iree_device_size_t length;
+} iree_hal_buffer_barrier_t;
+
+// TODO(benvanik): replace with tables for iree_string_builder_*.
+#define iree_hal_command_buffer_mode_string(...) "TODO"
+// {IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, "ONE_SHOT"},
+#define iree_hal_command_category_string(...) "TODO"
+// {IREE_HAL_COMMAND_CATEGORY_TRANSFER, "TRANSFER"},
+// {IREE_HAL_COMMAND_CATEGORY_DISPATCH, "DISPATCH"},
+
+//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_t
+//===----------------------------------------------------------------------===//
 
 // Asynchronous command buffer recording interface.
 // Commands are recorded by the implementation for later submission to command
@@ -159,196 +181,337 @@ struct BufferBarrier {
 // to record commands from multiple threads. Command buffers must not be mutated
 // between when they are submitted for execution on a queue and when the
 // semaphore fires indicating the completion of their execution.
-class CommandBuffer : public Resource {
- public:
-  virtual CommandBuffer* impl() { return this; }
-
-  // Command buffer operation mode.
-  CommandBufferModeBitfield mode() const { return mode_; }
-
-  // Command categories that may be recorded into the buffer.
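To make the new barrier structs concrete: a hedged, statement-level sketch of announcing "transfer writes must be visible to dispatch reads" over one buffer, using the `iree_hal_command_buffer_execution_barrier` entry point declared further down this header. It assumes it sits inside a recording function returning `iree_status_t`, with `command_buffer`, `buffer`, and `buffer_length` already in hand:

```c
// One buffer-scoped barrier: writes by transfer commands recorded before the
// barrier become visible to dispatch reads recorded after it.
iree_hal_buffer_barrier_t buffer_barrier = {
    .source_scope = IREE_HAL_ACCESS_SCOPE_TRANSFER_WRITE,
    .target_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_READ,
    .buffer = buffer,
    .offset = 0,
    .length = buffer_length,
};
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_execution_barrier(
    command_buffer,
    /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_TRANSFER,
    /*target_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH,
    /*memory_barrier_count=*/0, /*memory_barriers=*/NULL,
    /*buffer_barrier_count=*/1, &buffer_barrier));
```

A global `iree_hal_memory_barrier_t` would be cheaper to encode but, as the comment above notes, can stall the whole pipeline where a buffer-scoped barrier would not.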
- CommandCategoryBitfield command_categories() const { - return command_categories_; - } - - // True if the command buffer is between a Begin/End recording block. - virtual bool is_recording() const = 0; - - // Resets and begins recording into the command buffer, clearing all - // previously recorded contents. - // The command buffer must not be in-flight. - virtual Status Begin() = 0; - - // Ends recording into the command buffer. - // This must be called prior to submitting the command buffer for execution. - virtual Status End() = 0; - - // TODO(benvanik): annotations for debugging and tracing: - // enter/exit - // stack frame manipulation - // explicit timers? or profiling buffer? - - // TODO(b/138719910): cross-queue and external acquire/release. - // virtual Status AcquireBuffer() = 0; - // virtual Status ReleaseBuffer() = 0; - - // Defines a memory dependency between commands recorded before and after the - // barrier. One or more memory or buffer barriers can be specified to indicate - // between which stages or buffers the dependencies exist. - virtual Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) = 0; - - // Sets an event to the signaled state. - // |source_stage_mask| specifies when the event is signaled. - // - // Events are only valid within a single command buffer. Events can only be - // used on non-transfer queues. - virtual Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) = 0; - - // Resets an event to the non-signaled state. - // |source_stage_mask| specifies when the event is unsignaled. - // - // Events are only valid within a single command buffer. Events can only be - // used on non-transfer queues. - virtual Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) = 0; - - // Waits for one or more events to be signaled and defines a memory dependency - // between the synchronization scope of the signal operations and the commands - // following the wait. - // - // |source_stage_mask| must include ExecutionStage::kHost for Event::Signal to - // be visibile. - // - // Events are only valid within a single command buffer. Events remain - // signaled even after waiting and must be reset to be reused. Events can only - // be used on non-transfer queues. - virtual Status WaitEvents( - absl::Span events, ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) = 0; - - // Fills the target buffer with the given repeating value. - // Expects that value_length is one of 1, 2, or 4 and that the offset and - // length are aligned to the natural alignment of the value. - // The target buffer must be compatible with the devices owned by this - // device queue and be allocated with BufferUsage::kTransfer. - virtual Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) = 0; - - // Hints to the device queue that the given buffer will not be used again. - // After encoding a discard the buffer contents will be considered undefined. - // This is because the discard may be used to elide write backs to host memory - // or aggressively reuse the allocation for other purposes. 
- // - // For buffers allocated with MemoryType::kTransient this may allow - // the device queue to reclaim the memory used by the buffer earlier than - // otherwise possible. - virtual Status DiscardBuffer(Buffer* buffer) = 0; - - // Updates a range of the given target buffer from the source host memory. - // The source host memory is copied immediately into the command buffer and - // occupies command buffer space. It is strongly recommended that large buffer - // updates are performed via CopyBuffer where there is the possibility of a - // zero-copy path. - // The |source_buffer| may be releaed by the caller immediately after this - // call returns. - // The |target_buffer| must be compatible with the devices owned by this - // device queue and be allocated with BufferUsage::kTransfer. - virtual Status UpdateBuffer(const void* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) = 0; - - // Copies a range of one buffer to another. - // Both buffers must be compatible with the devices owned by this device - // queue and be allocated with BufferUsage::kTransfer. Though the source and - // target buffer may be the same the ranges must not overlap (as with memcpy). - // - // This can be used to perform device->host, host->device, and device->device - // copies. - virtual Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) = 0; - - // Pushes an inline set of constants that can be accessed by subsequent - // dispatches using a compatible executable layout. - // - // Push constants are always 4-byte values and treated as opaque, meaning that - // they may be bit-casted floats, bit-packed booleans, etc. - virtual Status PushConstants(ExecutableLayout* executable_layout, - size_t offset, - absl::Span values) = 0; - - // Pushes a descriptor set and associates it with |set|. - // This uses an internal ringbuffer inside of the command buffer to avoid the - // need for creating and binding descriptor sets and managing their lifetime. - // - // The descriptor set will remain bound and valid so long as the executable - // layouts used by dispatches are compatible (same descriptor layouts and push - // constant sizes). - virtual Status PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - absl::Span bindings) = 0; - - // Binds a descriptor set to the given |set| matching that used in the - // executable layout interface. - // - // The descriptor set will remain bound and valid so long as the executable - // layouts used by dispatches are compatible (same descriptor layouts and push - // constant sizes). - // - // If any dynamic descriptor types are defined in the descriptor set layout - // then the dynamic offsets must be provided. These offsets will be added to - // the base offset of the descriptor layout binding. - virtual Status BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span dynamic_offsets) = 0; - - // Dispatches an execution request. - // The request may execute overlapped with any other transfer operation or - // dispatch made within the same barrier-defined sequence. - // - // The executable specified must be registered for use with the device driver - // owning this queue. It must not be unregistered until all requests that use - // it have completed. 
- // - // Fails if the queue does not support dispatch operations (as indicated by - // can_dispatch). - virtual Status Dispatch(Executable* executable, int32_t entry_point, - std::array workgroups) = 0; - - // Dispatches an execution request with deferred workgroup counts. - // This is the same as Dispatch but the workgroup counts are read from the - // given |workgroups_buffer| at offset |workgroups_offset| as 3 uint32_t XYZ - // values before performing the dispatch. This allows prior dispatches within - // the command sequence to populate the workgroup counts. - // - // The buffer must have been allocated with BufferUsage::kDispatch and be - // of MemoryType::kDeviceVisible. - virtual Status DispatchIndirect(Executable* executable, int32_t entry_point, - Buffer* workgroups_buffer, - device_size_t workgroups_offset) = 0; - - protected: - CommandBuffer(CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) - : mode_(mode), command_categories_(command_categories) {} - - private: - const CommandBufferModeBitfield mode_; - const CommandCategoryBitfield command_categories_; -}; - -} // namespace hal -} // namespace iree +typedef struct iree_hal_command_buffer_s iree_hal_command_buffer_t; + +// Creates a command buffer ready to begin recording, possibly reusing an +// existing one from the |device| pool. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_create( + iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_command_buffer_t** out_command_buffer); + +// Retains the given |command_buffer| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_command_buffer_retain(iree_hal_command_buffer_t* command_buffer); + +// Releases the given |command_buffer| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_command_buffer_release(iree_hal_command_buffer_t* command_buffer); + +// Returns a bitmask indicating which command categories this command buffer +// can record. +IREE_API_EXPORT iree_hal_command_category_t IREE_API_CALL +iree_hal_command_buffer_allowed_categories( + const iree_hal_command_buffer_t* command_buffer); + +// Resets and begins recording into the command buffer, clearing all +// previously recorded contents. +// The command buffer must not be in-flight. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_begin(iree_hal_command_buffer_t* command_buffer); + +// Ends recording into the command buffer. +// This must be called prior to submitting the command buffer for execution. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_end(iree_hal_command_buffer_t* command_buffer); + +// Defines a memory dependency between commands recorded before and after the +// barrier. One or more memory or buffer barriers can be specified to indicate +// between which stages or buffers the dependencies exist. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_execution_barrier( + iree_hal_command_buffer_t* command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers); + +// Sets an event to the signaled state. +// |source_stage_mask| specifies when the event is signaled. +// +// Events are only valid within a single command buffer. Events can only be +// used on non-transfer queues. 
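The creation and begin/end declarations above form the core recording lifecycle. A minimal sketch of driving it, assuming only that a `device` exists; submission goes through the device queue APIs, which live outside this header:

```c
// Record an empty one-shot command buffer; real callers insert transfer and
// dispatch commands between begin and end.
static iree_status_t record_one_shot_example(iree_hal_device_t* device) {
  iree_hal_command_buffer_t* command_buffer = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
      device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
      IREE_HAL_COMMAND_CATEGORY_ANY, &command_buffer));
  iree_status_t status = iree_hal_command_buffer_begin(command_buffer);
  if (iree_status_is_ok(status)) {
    // ... record commands here ...
    status = iree_hal_command_buffer_end(command_buffer);
  }
  // The queue holds its own reference while the buffer executes; the local
  // reference can be released once recording is done.
  iree_hal_command_buffer_release(command_buffer);
  return status;
}
```

Passing the narrowest category bitmask possible (for example `IREE_HAL_COMMAND_CATEGORY_TRANSFER` alone) lets implementations route the work to a DMA queue, per the comment on `IREE_HAL_COMMAND_CATEGORY_ANY` above.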
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_command_buffer_signal_event(
+    iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
+    iree_hal_execution_stage_t source_stage_mask);
+
+// Resets an event to the non-signaled state.
+// |source_stage_mask| specifies when the event is unsignaled.
+//
+// Events are only valid within a single command buffer. Events can only be
+// used on non-transfer queues.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_reset_event(
+    iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
+    iree_hal_execution_stage_t source_stage_mask);
+
+// Waits for one or more events to be signaled and defines a memory dependency
+// between the synchronization scope of the signal operations and the commands
+// following the wait.
+//
+// |source_stage_mask| must include ExecutionStage::kHost for Event::Signal to
+// be visible.
+//
+// Events are only valid within a single command buffer. Events remain
+// signaled even after waiting and must be reset to be reused. Events can only
+// be used on non-transfer queues.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_wait_events(
+    iree_hal_command_buffer_t* command_buffer, iree_host_size_t event_count,
+    const iree_hal_event_t** events,
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers);
+
+// Hints to the device queue that the given buffer will not be used again.
+// After encoding a discard the buffer contents will be considered undefined.
+// This is because the discard may be used to elide write backs to host memory
+// or aggressively reuse the allocation for other purposes.
+//
+// For buffers allocated with IREE_HAL_MEMORY_TYPE_TRANSIENT this may allow
+// the device queue to reclaim the memory used by the buffer earlier than
+// otherwise possible.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_command_buffer_discard_buffer(
+    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* buffer);
+
+// Fills the target buffer with the given repeating value.
+// Expects that |pattern_length| is one of 1, 2, or 4 and that the offset and
+// length are aligned to the natural alignment of the value.
+// The target buffer must be compatible with the devices owned by this
+// device queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_fill_buffer(
+    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* target_buffer,
+    iree_device_size_t target_offset, iree_device_size_t length,
+    const void* pattern, iree_host_size_t pattern_length);
+
+// Updates a range of the given target buffer from the source host memory.
+// The source host memory is copied immediately into the command buffer and
+// occupies command buffer space. It is strongly recommended that large buffer
+// updates are performed via iree_hal_command_buffer_copy_buffer where there is
+// the possibility of a zero-copy path.
+// The |source_buffer| may be released by the caller immediately after this
+// call returns.
+// The |target_buffer| must be compatible with the devices owned by this
+// device queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER.
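The transfer entry points above carry a concrete contract: fill patterns are 1, 2, or 4 bytes with matching offset/length alignment, and inline updates are captured into command buffer storage at record time. A hedged statement-level sketch, assuming a recording `command_buffer` and a transfer-capable `target_buffer` inside a function returning `iree_status_t`:

```c
// Zero the first 1024 bytes with a 4-byte pattern; both the offset and the
// length are 4-byte aligned, as the fill contract requires.
const uint32_t zero_pattern = 0;
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_fill_buffer(
    command_buffer, target_buffer, /*target_offset=*/0, /*length=*/1024,
    /*pattern=*/&zero_pattern, /*pattern_length=*/sizeof(zero_pattern)));

// Small inline upload: the bytes are copied into the command buffer
// immediately, so the host array may go out of scope after the call returns.
const float params[4] = {1.0f, 0.5f, 0.25f, 0.125f};
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_update_buffer(
    command_buffer, params, /*source_offset=*/0, target_buffer,
    /*target_offset=*/1024, /*length=*/sizeof(params)));
```

For anything larger than a handful of constants, staging through a mappable buffer and recording a copy keeps the command buffer small and opens the zero-copy path the comment mentions.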
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_command_buffer_update_buffer(iree_hal_command_buffer_t* command_buffer,
+                                      const void* source_buffer,
+                                      iree_host_size_t source_offset,
+                                      iree_hal_buffer_t* target_buffer,
+                                      iree_device_size_t target_offset,
+                                      iree_device_size_t length);
+
+// Copies a range of one buffer to another.
+// Both buffers must be compatible with the devices owned by this device
+// queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER. Though the source
+// and target buffers may be the same, the ranges must not overlap (as with
+// memcpy).
+//
+// This can be used to perform device->host, host->device, and device->device
+// copies.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_copy_buffer(
+    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer,
+    iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer,
+    iree_device_size_t target_offset, iree_device_size_t length);
+
+// Pushes an inline set of constants that can be accessed by subsequent
+// dispatches using a compatible executable layout.
+//
+// Push constants are treated as opaque bytes, meaning that they may be
+// bit-casted floats, bit-packed booleans, etc. |offset| and |values_length| are
+// in bytes.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_command_buffer_push_constants(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset,
+    const void* values, iree_host_size_t values_length);
+
+// Pushes a descriptor set and associates it with |set|.
+// This uses an internal ringbuffer inside of the command buffer to avoid the
+// need for creating and binding descriptor sets and managing their lifetime.
+//
+// The descriptor set will remain bound and valid so long as the executable
+// layouts used by dispatches are compatible (same descriptor layouts and push
+// constant sizes).
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_command_buffer_push_descriptor_set(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_executable_layout_t* executable_layout, uint32_t set,
+    iree_host_size_t binding_count,
+    const iree_hal_descriptor_set_binding_t* bindings);
+
+// Binds a descriptor set to the given |set| matching that used in the
+// executable layout interface.
+//
+// The descriptor set will remain bound and valid so long as the executable
+// layouts used by dispatches are compatible (same descriptor layouts and push
+// constant sizes).
+//
+// If any dynamic descriptor types are defined in the descriptor set layout then
+// the dynamic offsets must be provided. These offsets will be added to the base
+// offset of the descriptor layout binding.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_command_buffer_bind_descriptor_set(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_executable_layout_t* executable_layout, uint32_t set,
+    iree_hal_descriptor_set_t* descriptor_set,
+    iree_host_size_t dynamic_offset_count,
+    const iree_device_size_t* dynamic_offsets);
+
+// Dispatches an execution request.
+// The request may execute overlapped with any other transfer operation or
+// dispatch made within the same barrier-defined sequence.
+//
+// The executable specified must be registered for use with the device driver
+// owning this queue. It must not be unregistered until all requests that use
+// it have completed.
+//
+// Fails if the queue does not support dispatch operations (as indicated by
+// can_dispatch).
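A hedged sketch of the binding-then-dispatch flow these declarations describe. The `iree_hal_descriptor_set_binding_t` field names ({binding, buffer, offset, length}) live in descriptor_set.h and are an assumption here, as are the buffers, layout, and executable; it runs inside a recording function returning `iree_status_t`:

```c
// Bind an input and an output buffer to set 0, then dispatch a 16x1x1 grid.
iree_hal_descriptor_set_binding_t bindings[2] = {
    {/*binding=*/0, /*buffer=*/input_buffer, /*offset=*/0,
     /*length=*/input_length},
    {/*binding=*/1, /*buffer=*/output_buffer, /*offset=*/0,
     /*length=*/output_length},
};
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_push_descriptor_set(
    command_buffer, executable_layout, /*set=*/0,
    /*binding_count=*/2, bindings));
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_dispatch(
    command_buffer, executable, /*entry_point=*/0,
    /*workgroup_x=*/16, /*workgroup_y=*/1, /*workgroup_z=*/1));
```

The push path trades a small ringbuffer copy for not having to allocate and track `iree_hal_descriptor_set_t` objects; the bind path is the right choice when a set is reused across many dispatches.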
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_command_buffer_dispatch( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z); + +// Dispatches an execution request with deferred workgroup counts. +// This is the same as iree_hal_command_buffer_dispatch but the workgroup counts +// are read from the given |workgroups_buffer| at offset |workgroups_offset| as +// 3 uint32_t XYZ values before performing the dispatch. This allows prior +// dispatches within the command sequence to populate the workgroup counts. +// +// The buffer must have been allocated with IREE_HAL_BUFFER_USAGE_DISPATCH and +// be of IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, iree_device_size_t workgroups_offset); + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_t validation wrapper +//===----------------------------------------------------------------------===// + +// Wraps |target_command_buffer| with a validation layer that checks the +// parameters to each call in an attempt to return errors where usage may result +// in failed or incorrect execution. This layer adds many additional checks to +// each call but must be used when dealing with untrusted incoming commands. +// +// The validation is strictly input argument and permission-based and not a full +// verification of the correctness of any barriers or memory dependencies. A +// command buffer recording that has passed validation does not indicate that it +// is guaranteed to make forward progress or properly observe memory visibility +// or availability rules. Instead, validation ensures that no command references +// memory outside of the allowed ranges or accesses memory in violation of the +// allowed usage or access rights. 
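In practice the wrapper (declared just below) is layered over a concrete command buffer at creation time. A sketch, assuming a `device` exists and that the surrounding function returns `iree_status_t`:

```c
iree_hal_command_buffer_t* raw_command_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
    device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
    IREE_HAL_COMMAND_CATEGORY_ANY, &raw_command_buffer));

iree_hal_command_buffer_t* command_buffer = NULL;
iree_status_t status = iree_hal_command_buffer_wrap_validation(
    device, raw_command_buffer, &command_buffer);
// The wrapper retains its target on success; either way the caller still owns
// (and must release) its own reference to the raw command buffer.
iree_hal_command_buffer_release(raw_command_buffer);
IREE_RETURN_IF_ERROR(status);
// All recording now flows through the validating shim before reaching the
// backend command buffer.
```

Because validation adds per-call checks, trusted in-process recordings can skip the wrapper entirely and record against the backend command buffer directly.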
+IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_wrap_validation( + iree_hal_device_t* device, iree_hal_command_buffer_t* target_command_buffer, + iree_hal_command_buffer_t** out_command_buffer); + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE + + void(IREE_API_PTR* destroy)(iree_hal_command_buffer_t* command_buffer); + + iree_hal_command_category_t(IREE_API_PTR* allowed_categories)( + const iree_hal_command_buffer_t* command_buffer); + + iree_status_t(IREE_API_PTR* begin)(iree_hal_command_buffer_t* command_buffer); + + iree_status_t(IREE_API_PTR* end)(iree_hal_command_buffer_t* command_buffer); + + iree_status_t(IREE_API_PTR* execution_barrier)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers); + + iree_status_t(IREE_API_PTR* signal_event)( + iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask); + + iree_status_t(IREE_API_PTR* reset_event)( + iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask); + + iree_status_t(IREE_API_PTR* wait_events)( + iree_hal_command_buffer_t* command_buffer, iree_host_size_t event_count, + const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers); + + iree_status_t(IREE_API_PTR* discard_buffer)( + iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* buffer); + + iree_status_t(IREE_API_PTR* fill_buffer)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length); + + iree_status_t(IREE_API_PTR* update_buffer)( + iree_hal_command_buffer_t* command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length); + + iree_status_t(IREE_API_PTR* copy_buffer)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length); + + iree_status_t(IREE_API_PTR* push_constants)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length); + + iree_status_t(IREE_API_PTR* push_descriptor_set)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings); + + iree_status_t(IREE_API_PTR* bind_descriptor_set)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + 
iree_hal_descriptor_set_t* descriptor_set, + iree_host_size_t dynamic_offset_count, + const iree_device_size_t* dynamic_offsets); + + iree_status_t(IREE_API_PTR* dispatch)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z); + + iree_status_t(IREE_API_PTR* dispatch_indirect)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, + iree_device_size_t workgroups_offset); +} iree_hal_command_buffer_vtable_t; + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_command_buffer_destroy(iree_hal_command_buffer_t* command_buffer); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_COMMAND_BUFFER_H_ diff --git a/iree/hal/command_buffer_validation.c b/iree/hal/command_buffer_validation.c new file mode 100644 index 0000000000000..036d4490b3ea4 --- /dev/null +++ b/iree/hal/command_buffer_validation.c @@ -0,0 +1,542 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/base/tracing.h" +#include "iree/hal/allocator.h" +#include "iree/hal/command_buffer.h" +#include "iree/hal/device.h" + +typedef struct { + iree_hal_resource_t resource; + iree_hal_device_t* device; + iree_hal_command_buffer_t* target_command_buffer; + iree_hal_command_category_t allowed_categories; + + bool is_recording; + // TODO(benvanik): current executable layout/descriptor set layout info. + // TODO(benvanik): valid push constant bit ranges. +} iree_hal_validating_command_buffer_t; + +static const iree_hal_command_buffer_vtable_t + iree_hal_validating_command_buffer_vtable; + +// Returns success iff the queue supports the given command categories. +static iree_status_t iree_hal_command_buffer_validate_categories( + const iree_hal_validating_command_buffer_t* command_buffer, + iree_hal_command_category_t required_categories) { + if (!iree_all_bits_set(command_buffer->allowed_categories, + required_categories)) { + return iree_make_status( + IREE_STATUS_FAILED_PRECONDITION, + "operation requires categories %s but command buffer only supports %s", + iree_hal_command_category_string(required_categories), + iree_hal_command_category_string(command_buffer->allowed_categories)); + } + return iree_ok_status(); +} + +// Returns success iff the buffer is compatible with the device. 
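The category gate above is a pure bitmask-subset test. `iree_all_bits_set` comes from the base API; a behavior-equivalent sketch (the real helper's exact signature may differ):

```c
#include <stdbool.h>
#include <stdint.h>

// True iff every bit in |required_bits| is also set in |value|.
static inline bool all_bits_set(uint32_t value, uint32_t required_bits) {
  return (value & required_bits) == required_bits;
}

// A transfer-only command buffer rejects dispatch work:
//   all_bits_set(IREE_HAL_COMMAND_CATEGORY_TRANSFER,
//                IREE_HAL_COMMAND_CATEGORY_DISPATCH) -> false
// while IREE_HAL_COMMAND_CATEGORY_ANY admits both categories.
```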
+static iree_status_t iree_hal_command_buffer_validate_buffer_compatibility( + const iree_hal_validating_command_buffer_t* command_buffer, + iree_hal_buffer_t* buffer, + iree_hal_buffer_compatibility_t required_compatibility, + iree_hal_buffer_usage_t intended_usage) { + iree_hal_buffer_compatibility_t allowed_compatibility = + iree_hal_allocator_query_buffer_compatibility( + iree_hal_device_allocator(command_buffer->device), + iree_hal_buffer_memory_type(buffer), + iree_hal_buffer_allowed_usage(buffer), intended_usage, + iree_hal_buffer_allocation_size(buffer)); + if (!iree_all_bits_set(allowed_compatibility, required_compatibility)) { + // Buffer cannot be used on the queue for the given usage. + return iree_make_status( + IREE_STATUS_PERMISSION_DENIED, + "requested buffer usage is not supported for the buffer on this queue; " + "buffer allows %s, operation requires %s", + iree_hal_buffer_usage_string(iree_hal_buffer_allowed_usage(buffer)), + iree_hal_buffer_usage_string(intended_usage)); + } + return iree_ok_status(); +} + +// Returns success iff the currently bound descriptor sets are valid for the +// given executable entry point. +static iree_status_t iree_hal_command_buffer_validate_dispatch_bindings( + iree_hal_validating_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point) { + // TODO(benvanik): validate buffers referenced have compatible memory types, + // access rights, and usage. + // TODO(benvanik): validate no aliasing between inputs/outputs. + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_command_buffer_wrap_validation( + iree_hal_device_t* device, iree_hal_command_buffer_t* target_command_buffer, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(target_command_buffer); + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_validating_command_buffer_t* command_buffer = NULL; + iree_status_t status = + iree_allocator_malloc(iree_hal_device_host_allocator(device), + sizeof(*command_buffer), (void**)&command_buffer); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_validating_command_buffer_vtable, + &command_buffer->resource); + command_buffer->device = device; + iree_hal_device_retain(command_buffer->device); + command_buffer->target_command_buffer = target_command_buffer; + iree_hal_command_buffer_retain(command_buffer->target_command_buffer); + command_buffer->allowed_categories = + iree_hal_command_buffer_allowed_categories( + command_buffer->target_command_buffer); + + command_buffer->is_recording = false; + } + + *out_command_buffer = (iree_hal_command_buffer_t*)command_buffer; + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_validating_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + iree_allocator_t host_allocator = + iree_hal_device_host_allocator(command_buffer->device); + iree_hal_command_buffer_release(command_buffer->target_command_buffer); + iree_hal_device_release(command_buffer->device); + iree_allocator_free(host_allocator, command_buffer); + IREE_TRACE_ZONE_END(z0); +} + +static iree_hal_command_category_t +iree_hal_validating_command_buffer_allowed_categories( + const iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_validating_command_buffer_t* command_buffer = + 
(iree_hal_validating_command_buffer_t*)base_command_buffer; + return command_buffer->allowed_categories; +} + +static iree_status_t iree_hal_validating_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + if (command_buffer->is_recording) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "command buffer is already in a recording state"); + } + command_buffer->is_recording = true; + + return iree_hal_command_buffer_begin(command_buffer->target_command_buffer); +} + +static iree_status_t iree_hal_validating_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + if (!command_buffer->is_recording) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "command buffer is not in a recording state"); + } + command_buffer->is_recording = false; + + return iree_hal_command_buffer_end(command_buffer->target_command_buffer); +} + +static iree_status_t iree_hal_validating_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_ANY)); + + // TODO(benvanik): additional synchronization validation. + + return iree_hal_command_buffer_execution_barrier( + command_buffer->target_command_buffer, source_stage_mask, + target_stage_mask, memory_barrier_count, memory_barriers, + buffer_barrier_count, buffer_barriers); +} + +static iree_status_t iree_hal_validating_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); + + // TODO(benvanik): additional synchronization validation. + + return iree_hal_command_buffer_signal_event( + command_buffer->target_command_buffer, event, source_stage_mask); +} + +static iree_status_t iree_hal_validating_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); + + // TODO(benvanik): additional synchronization validation. 
+
+  return iree_hal_command_buffer_reset_event(
+      command_buffer->target_command_buffer, event, source_stage_mask);
+}
+
+static iree_status_t iree_hal_validating_command_buffer_wait_events(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_host_size_t event_count, const iree_hal_event_t** events,
+    iree_hal_execution_stage_t source_stage_mask,
+    iree_hal_execution_stage_t target_stage_mask,
+    iree_host_size_t memory_barrier_count,
+    const iree_hal_memory_barrier_t* memory_barriers,
+    iree_host_size_t buffer_barrier_count,
+    const iree_hal_buffer_barrier_t* buffer_barriers) {
+  iree_hal_validating_command_buffer_t* command_buffer =
+      (iree_hal_validating_command_buffer_t*)base_command_buffer;
+
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
+      command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH));
+
+  // TODO(benvanik): additional synchronization validation.
+
+  return iree_hal_command_buffer_wait_events(
+      command_buffer->target_command_buffer, event_count, events,
+      source_stage_mask, target_stage_mask, memory_barrier_count,
+      memory_barriers, buffer_barrier_count, buffer_barriers);
+}
+
+static iree_status_t iree_hal_validating_command_buffer_discard_buffer(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) {
+  iree_hal_validating_command_buffer_t* command_buffer =
+      (iree_hal_validating_command_buffer_t*)base_command_buffer;
+
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
+      command_buffer, IREE_HAL_COMMAND_CATEGORY_TRANSFER));
+
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
+      iree_hal_buffer_memory_type(buffer),
+      IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE));
+
+  return iree_hal_command_buffer_discard_buffer(
+      command_buffer->target_command_buffer, buffer);
+}
+
+static iree_status_t iree_hal_validating_command_buffer_fill_buffer(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
+    iree_device_size_t length, const void* pattern,
+    iree_host_size_t pattern_length) {
+  iree_hal_validating_command_buffer_t* command_buffer =
+      (iree_hal_validating_command_buffer_t*)base_command_buffer;
+
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
+      command_buffer, IREE_HAL_COMMAND_CATEGORY_TRANSFER));
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_compatibility(
+      command_buffer, target_buffer,
+      IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER,
+      IREE_HAL_BUFFER_USAGE_TRANSFER));
+
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
+      iree_hal_buffer_memory_type(target_buffer),
+      IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
+      iree_hal_buffer_allowed_access(target_buffer),
+      IREE_HAL_MEMORY_ACCESS_WRITE));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
+      iree_hal_buffer_allowed_usage(target_buffer),
+      IREE_HAL_BUFFER_USAGE_TRANSFER));
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_validate_range(target_buffer, target_offset, length));
+
+  // Ensure the value length is supported.
+  if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "fill value length is not one of the supported "
+                            "values (pattern_length=%zu)",
+                            pattern_length);
+  }
+
+  // Ensure the offset and length have an alignment matching the value length.
+ if ((target_offset % pattern_length) != 0 || (length % pattern_length) != 0) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "fill offset and/or length do not match the natural alignment of the " + "fill value (target_offset=%zu, length=%zu, pattern_length=%zu)", + target_offset, length, pattern_length); + } + + return iree_hal_command_buffer_fill_buffer( + command_buffer->target_command_buffer, target_buffer, target_offset, + length, pattern, pattern_length); +} + +static iree_status_t iree_hal_validating_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_TRANSFER)); + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_compatibility( + command_buffer, target_buffer, + IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER, + IREE_HAL_BUFFER_USAGE_TRANSFER)); + + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type( + iree_hal_buffer_memory_type(target_buffer), + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access( + iree_hal_buffer_allowed_access(target_buffer), + IREE_HAL_MEMORY_ACCESS_WRITE)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage( + iree_hal_buffer_allowed_usage(target_buffer), + IREE_HAL_BUFFER_USAGE_TRANSFER)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_validate_range(target_buffer, target_offset, length)); + + return iree_hal_command_buffer_update_buffer( + command_buffer->target_command_buffer, source_buffer, source_offset, + target_buffer, target_offset, length); +} + +static iree_status_t iree_hal_validating_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_TRANSFER)); + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_compatibility( + command_buffer, source_buffer, + IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER, + IREE_HAL_BUFFER_USAGE_TRANSFER)); + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_compatibility( + command_buffer, target_buffer, + IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER, + IREE_HAL_BUFFER_USAGE_TRANSFER)); + + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access( + iree_hal_buffer_allowed_access(source_buffer), + IREE_HAL_MEMORY_ACCESS_READ)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage( + iree_hal_buffer_allowed_usage(source_buffer), + IREE_HAL_BUFFER_USAGE_TRANSFER)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_validate_range(source_buffer, source_offset, length)); + + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage( + iree_hal_buffer_allowed_usage(target_buffer), + IREE_HAL_BUFFER_USAGE_TRANSFER)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access( + iree_hal_buffer_allowed_access(target_buffer), + IREE_HAL_MEMORY_ACCESS_WRITE)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_validate_range(target_buffer, 
target_offset, length)); + + // At least source or destination must be device-visible to enable + // host->device, device->host, and device->device. + // TODO(b/117338171): host->host copies. + if (!iree_any_bit_set(iree_hal_buffer_memory_type(source_buffer), + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE) && + !iree_any_bit_set(iree_hal_buffer_memory_type(target_buffer), + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + return iree_make_status( + IREE_STATUS_PERMISSION_DENIED, + "at least one buffer must be device-visible for a copy; " + "source_buffer=%s, target_buffer=%s", + iree_hal_memory_type_string(iree_hal_buffer_memory_type(source_buffer)), + iree_hal_memory_type_string( + iree_hal_buffer_memory_type(target_buffer))); + } + + // Check for overlap - just like memcpy we don't handle that. + if (iree_hal_buffer_test_overlap(source_buffer, source_offset, length, + target_buffer, target_offset, length) != + IREE_HAL_BUFFER_OVERLAP_DISJOINT) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "source and target ranges overlap within the same buffer"); + } + + return iree_hal_command_buffer_copy_buffer( + command_buffer->target_command_buffer, source_buffer, source_offset, + target_buffer, target_offset, length); +} + +static iree_status_t iree_hal_validating_command_buffer_push_constants( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); + + if (IREE_UNLIKELY((values_length % 4) != 0)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid alignment %zu, must be 4-byte aligned", + values_length); + } + + // TODO(benvanik): validate offset and value count with layout. + + return iree_hal_command_buffer_push_constants( + command_buffer->target_command_buffer, executable_layout, offset, values, + values_length); +} + +static iree_status_t iree_hal_validating_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); + + // TODO(benvanik): validate set index. + // TODO(benvanik): validate binding_offset. + // TODO(benvanik): validate bindings. 
+ + return iree_hal_command_buffer_push_descriptor_set( + command_buffer->target_command_buffer, executable_layout, set, + binding_count, bindings); +} + +static iree_status_t iree_hal_validating_command_buffer_bind_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_hal_descriptor_set_t* descriptor_set, + iree_host_size_t dynamic_offset_count, + const iree_device_size_t* dynamic_offsets) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); + + // TODO(benvanik): validate set index. + // TODO(benvanik): validate dynamic offsets (both count and offsets). + + return iree_hal_command_buffer_bind_descriptor_set( + command_buffer->target_command_buffer, executable_layout, set, + descriptor_set, dynamic_offset_count, dynamic_offsets); +} + +static iree_status_t iree_hal_validating_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_dispatch_bindings( + command_buffer, executable, entry_point)); + + return iree_hal_command_buffer_dispatch(command_buffer->target_command_buffer, + executable, entry_point, workgroup_x, + workgroup_y, workgroup_z); +} + +static iree_status_t iree_hal_validating_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, + iree_device_size_t workgroups_offset) { + iree_hal_validating_command_buffer_t* command_buffer = + (iree_hal_validating_command_buffer_t*)base_command_buffer; + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_compatibility( + command_buffer, workgroups_buffer, + IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH, + IREE_HAL_BUFFER_USAGE_DISPATCH)); + + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type( + iree_hal_buffer_memory_type(workgroups_buffer), + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access( + iree_hal_buffer_allowed_access(workgroups_buffer), + IREE_HAL_MEMORY_ACCESS_READ)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage( + iree_hal_buffer_allowed_usage(workgroups_buffer), + IREE_HAL_BUFFER_USAGE_DISPATCH)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_range( + workgroups_buffer, workgroups_offset, sizeof(uint32_t) * 3)); + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_dispatch_bindings( + command_buffer, executable, entry_point)); + + return iree_hal_command_buffer_dispatch_indirect( + command_buffer->target_command_buffer, executable, entry_point, + workgroups_buffer, workgroups_offset); +} + +static const iree_hal_command_buffer_vtable_t + iree_hal_validating_command_buffer_vtable = { + .destroy = iree_hal_validating_command_buffer_destroy, + .allowed_categories = + 
iree_hal_validating_command_buffer_allowed_categories, + .begin = iree_hal_validating_command_buffer_begin, + .end = iree_hal_validating_command_buffer_end, + .execution_barrier = + iree_hal_validating_command_buffer_execution_barrier, + .signal_event = iree_hal_validating_command_buffer_signal_event, + .reset_event = iree_hal_validating_command_buffer_reset_event, + .wait_events = iree_hal_validating_command_buffer_wait_events, + .discard_buffer = iree_hal_validating_command_buffer_discard_buffer, + .fill_buffer = iree_hal_validating_command_buffer_fill_buffer, + .update_buffer = iree_hal_validating_command_buffer_update_buffer, + .copy_buffer = iree_hal_validating_command_buffer_copy_buffer, + .push_constants = iree_hal_validating_command_buffer_push_constants, + .push_descriptor_set = + iree_hal_validating_command_buffer_push_descriptor_set, + .bind_descriptor_set = + iree_hal_validating_command_buffer_bind_descriptor_set, + .dispatch = iree_hal_validating_command_buffer_dispatch, + .dispatch_indirect = + iree_hal_validating_command_buffer_dispatch_indirect, +}; diff --git a/iree/hal/command_buffer_validation.cc b/iree/hal/command_buffer_validation.cc deleted file mode 100644 index 58df97fae1f18..0000000000000 --- a/iree/hal/command_buffer_validation.cc +++ /dev/null @@ -1,514 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/command_buffer_validation.h" - -#include "absl/strings/str_join.h" -#include "iree/base/logging.h" -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -namespace { - -// Command buffer validation shim. -// Wraps an existing command buffer to provide in-depth validation during -// recording. This should be enabled whenever the command buffer is being driven -// by unsafe code or when early and readable diagnostics are needed. -class ValidatingCommandBuffer : public CommandBuffer { - public: - explicit ValidatingCommandBuffer(Allocator* allocator, - ref_ptr impl); - ~ValidatingCommandBuffer() override; - - // Device allocator that commands encoded into the buffer share compatibility - // with. 
- Allocator* allocator() const { return allocator_; } - - CommandBuffer* impl() override { return impl_.get(); } - - bool is_recording() const override; - - Status Begin() override; - Status End() override; - - Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) override; - Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - Status WaitEvents(absl::Span events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) override; - Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) override; - Status DiscardBuffer(Buffer* buffer) override; - Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - Status PushConstants(ExecutableLayout* executable_layout, size_t offset, - absl::Span values) override; - Status PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - absl::Span bindings) override; - Status BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span dynamic_offsets) override; - Status Dispatch(Executable* executable, int32_t entry_point, - std::array workgroups) override; - Status DispatchIndirect(Executable* executable, int32_t entry_point, - Buffer* workgroups_buffer, - device_size_t workgroups_offset) override; - - private: - // Returns a failure if the queue does not support the given caps. - Status ValidateCategories(CommandCategoryBitfield required_categories) const; - // Returns a failure if the memory type the buffer was allocated from is not - // compatible with the given type. - Status ValidateCompatibleMemoryType(Buffer* buffer, - MemoryTypeBitfield memory_type) const; - // Returns a failure if the buffer memory type or usage disallows the given - // access type. - Status ValidateAccess(Buffer* buffer, - MemoryAccessBitfield memory_access) const; - // Returns a failure if the buffer was not allocated for the given usage. - Status ValidateUsage(Buffer* buffer, BufferUsageBitfield usage) const; - // Validates that the range provided is within the given buffer. - Status ValidateRange(Buffer* buffer, device_size_t byte_offset, - device_size_t byte_length) const; - - // Validates that the currently bound descriptor sets are valid for the given - // executable entry point. 
- Status ValidateDispatchBindings(Executable* executable, int32_t entry_point); - - Allocator* const allocator_; - ref_ptr impl_; -}; - -ValidatingCommandBuffer::ValidatingCommandBuffer(Allocator* allocator, - ref_ptr impl) - : CommandBuffer(impl->mode(), impl->command_categories()), - allocator_(allocator), - impl_(std::move(impl)) {} - -ValidatingCommandBuffer::~ValidatingCommandBuffer() = default; - -bool ValidatingCommandBuffer::is_recording() const { - return impl_->is_recording(); -} - -Status ValidatingCommandBuffer::Begin() { - IREE_DVLOG(3) << "CommandBuffer::Begin()"; - if (impl_->is_recording()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Command buffer is already recording"; - } - return impl_->Begin(); -} - -Status ValidatingCommandBuffer::End() { - IREE_DVLOG(3) << "CommandBuffer::End()"; - if (!impl_->is_recording()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Command buffer is not recording"; - } - return impl_->End(); -} - -Status ValidatingCommandBuffer::ValidateCategories( - CommandCategoryBitfield required_categories) const { - if (!AllBitsSet(command_categories(), required_categories)) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Operation requires categories " - << CommandCategoryString(required_categories) - << " but buffer only supports " - << CommandCategoryString(command_categories()); - } - return OkStatus(); -} - -Status ValidatingCommandBuffer::ValidateCompatibleMemoryType( - Buffer* buffer, MemoryTypeBitfield memory_type) const { - if ((buffer->memory_type() & memory_type) != memory_type) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer memory type is not compatible with the requested " - "operation; buffer has " - << MemoryTypeString(buffer->memory_type()) << ", operation requires " - << MemoryTypeString(memory_type); - } - return OkStatus(); -} - -Status ValidatingCommandBuffer::ValidateAccess( - Buffer* buffer, MemoryAccessBitfield memory_access) const { - if ((buffer->allowed_access() & memory_access) != memory_access) { - // Bits must match exactly. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "The buffer does not support the requested access type; buffer " - "allows " - << MemoryAccessString(buffer->allowed_access()) - << ", operation requires " << MemoryAccessString(memory_access); - } - return OkStatus(); -} - -// Returns a failure if the buffer was not allocated for the given usage. -Status ValidatingCommandBuffer::ValidateUsage(Buffer* buffer, - BufferUsageBitfield usage) const { - if (!allocator()->CanUseBuffer(buffer, usage)) { - // Buffer cannot be used on the queue for the given usage. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Requested usage of " << buffer->DebugString() - << " is not supported for the buffer on this queue; " - "buffer allows " - << BufferUsageString(buffer->usage()) << ", queue requires " - << BufferUsageString(usage); - } - - if ((buffer->usage() & usage) != usage) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Requested usage was not specified when the buffer was " - "allocated; buffer allows " - << BufferUsageString(buffer->usage()) << ", operation requires " - << BufferUsageString(usage); - } - - return OkStatus(); -} - -// Validates that the range provided is within the given buffer. -Status ValidatingCommandBuffer::ValidateRange(Buffer* buffer, - device_size_t byte_offset, - device_size_t byte_length) const { - // Check if the start of the range runs off the end of the buffer. 
-// Validates that the range provided is within the given buffer.
-Status ValidatingCommandBuffer::ValidateRange(Buffer* buffer,
-                                              device_size_t byte_offset,
-                                              device_size_t byte_length) const {
-  // Check if the start of the range runs off the end of the buffer.
-  if (byte_offset > buffer->byte_length()) {
-    return OutOfRangeErrorBuilder(IREE_LOC)
-           << "Attempted to access an address off the end of the valid buffer "
-              "range (offset="
-           << byte_offset << ", length=" << byte_length
-           << ", buffer byte_length=" << buffer->byte_length() << ")";
-  }
-
-  if (byte_length == 0) {
-    // Fine to have a zero length.
-    return OkStatus();
-  }
-
-  // Check if the end runs over the allocation.
-  device_size_t end = byte_offset + byte_length;
-  if (end > buffer->byte_length()) {
-    return OutOfRangeErrorBuilder(IREE_LOC)
-           << "Attempted to access an address outside of the valid buffer "
-              "range (offset="
-           << byte_offset << ", length=" << byte_length
-           << ", end(inc)=" << (end - 1)
-           << ", buffer byte_length=" << buffer->byte_length() << ")";
-  }
-
-  return OkStatus();
-}
-
-Status ValidatingCommandBuffer::ExecutionBarrier(
-    ExecutionStageBitfield source_stage_mask,
-    ExecutionStageBitfield target_stage_mask,
-    absl::Span<const MemoryBarrier> memory_barriers,
-    absl::Span<const BufferBarrier> buffer_barriers) {
-  IREE_DVLOG(3) << "CommandBuffer::ExecutionBarrier(...)";
-
-  // TODO(benvanik): additional synchronization validation.
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer |
-                                          CommandCategory::kDispatch));
-
-  return impl_->ExecutionBarrier(source_stage_mask, target_stage_mask,
-                                 memory_barriers, buffer_barriers);
-}
-
-Status ValidatingCommandBuffer::SignalEvent(
-    Event* event, ExecutionStageBitfield source_stage_mask) {
-  IREE_DVLOG(3) << "CommandBuffer::SignalEvent(...)";
-
-  // TODO(benvanik): additional synchronization validation.
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
-  return impl_->SignalEvent(event, source_stage_mask);
-}
-
-Status ValidatingCommandBuffer::ResetEvent(
-    Event* event, ExecutionStageBitfield source_stage_mask) {
-  IREE_DVLOG(3) << "CommandBuffer::ResetEvent(...)";
-
-  // TODO(benvanik): additional synchronization validation.
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
-  return impl_->ResetEvent(event, source_stage_mask);
-}
-
-Status ValidatingCommandBuffer::WaitEvents(
-    absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
-    ExecutionStageBitfield target_stage_mask,
-    absl::Span<const MemoryBarrier> memory_barriers,
-    absl::Span<const BufferBarrier> buffer_barriers) {
-  IREE_DVLOG(3) << "CommandBuffer::WaitEvents(...)";
-
-  // TODO(benvanik): additional synchronization validation.
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
-  return impl_->WaitEvents(events, source_stage_mask, target_stage_mask,
-                           memory_barriers, buffer_barriers);
-}
-
-Status ValidatingCommandBuffer::FillBuffer(Buffer* target_buffer,
-                                           device_size_t target_offset,
-                                           device_size_t length,
-                                           const void* pattern,
-                                           size_t pattern_length) {
-  IREE_DVLOG(3) << "CommandBuffer::FillBuffer(" << target_buffer->DebugString()
-                << ", " << target_offset << ", " << length << ", ??, "
-                << pattern_length << ")";
-
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
-  IREE_RETURN_IF_ERROR(
-      ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible));
-  IREE_RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
-  IREE_RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
-  IREE_RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));
-
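One hardening note on ValidateRange above: the end-of-range check computes byte_offset + byte_length in device_size_t, which can wrap for adversarial inputs and falsely pass. Because the first check already guarantees byte_offset <= buffer->byte_length(), an overflow-safe formulation is possible (a sketch of an alternative, not what the deleted code did):

    // byte_offset <= buffer->byte_length() was established above, so the
    // subtraction cannot underflow and no overflow-prone addition is needed.
    if (byte_length > buffer->byte_length() - byte_offset) {
      return OutOfRangeErrorBuilder(IREE_LOC) << "Attempted to access ...";
    }

-  // Ensure the value length is supported.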
-  if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "Fill value length is not one of the supported values "
-              "(pattern_length="
-           << pattern_length << ")";
-  }
-
-  // Ensure the offset and length have an alignment matching the value length.
-  if ((target_offset % pattern_length) != 0 ||
-      (length % pattern_length) != 0) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "Fill offset and/or length do not match the natural alignment of "
-              "the fill value (target_offset="
-           << target_offset << ", length=" << length
-           << ", pattern_length=" << pattern_length << ")";
-  }
-
-  return impl_->FillBuffer(target_buffer, target_offset, length, pattern,
-                           pattern_length);
-}
-
-Status ValidatingCommandBuffer::DiscardBuffer(Buffer* buffer) {
-  IREE_DVLOG(3) << "CommandBuffer::DiscardBuffer(" << buffer->DebugString()
-                << ")";
-
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
-  IREE_RETURN_IF_ERROR(
-      ValidateCompatibleMemoryType(buffer, MemoryType::kDeviceVisible));
-  IREE_RETURN_IF_ERROR(ValidateUsage(buffer, BufferUsage::kNone));
-
-  return impl_->DiscardBuffer(buffer);
-}
-
-Status ValidatingCommandBuffer::UpdateBuffer(const void* source_buffer,
-                                             device_size_t source_offset,
-                                             Buffer* target_buffer,
-                                             device_size_t target_offset,
-                                             device_size_t length) {
-  IREE_DVLOG(3) << "CommandBuffer::UpdateBuffer(" << source_buffer << ", "
-                << source_offset << ", " << target_buffer->DebugString() << ", "
-                << target_offset << ", " << length << ")";
-
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
-  IREE_RETURN_IF_ERROR(
-      ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible));
-  IREE_RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
-  IREE_RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
-  IREE_RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));
-
-  return impl_->UpdateBuffer(source_buffer, source_offset, target_buffer,
-                             target_offset, length);
-}
-
-Status ValidatingCommandBuffer::CopyBuffer(Buffer* source_buffer,
-                                           device_size_t source_offset,
-                                           Buffer* target_buffer,
-                                           device_size_t target_offset,
-                                           device_size_t length) {
-  IREE_DVLOG(3) << "CommandBuffer::CopyBuffer(" << source_buffer->DebugString()
-                << ", " << source_offset << ", " << target_buffer->DebugString()
-                << ", " << target_offset << ", " << length << ")";
-
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
-
-  // At least source or destination must be device-visible to enable
-  // host->device, device->host, and device->device.
-  // TODO(b/117338171): host->host copies.
-  if (!AnyBitSet(source_buffer->memory_type() & MemoryType::kDeviceVisible) &&
-      !AnyBitSet(target_buffer->memory_type() & MemoryType::kDeviceVisible)) {
-    return PermissionDeniedErrorBuilder(IREE_LOC)
-           << "At least one buffer must be device-visible for a copy; "
-              "source_buffer="
-           << MemoryTypeString(source_buffer->memory_type())
-           << ", target_buffer="
-           << MemoryTypeString(target_buffer->memory_type());
-  }
-
-  IREE_RETURN_IF_ERROR(ValidateAccess(source_buffer, MemoryAccess::kRead));
-  IREE_RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
-  IREE_RETURN_IF_ERROR(ValidateUsage(source_buffer, BufferUsage::kTransfer));
-  IREE_RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
-  IREE_RETURN_IF_ERROR(ValidateRange(source_buffer, source_offset, length));
-  IREE_RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));
-
-  // Check for overlap - just like memcpy we don't handle that.
-  if (Buffer::TestOverlap(source_buffer, source_offset, length, target_buffer,
-                          target_offset,
-                          length) != Buffer::Overlap::kDisjoint) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "Source and target ranges overlap within the same buffer";
-  }
-
-  return impl_->CopyBuffer(source_buffer, source_offset, target_buffer,
-                           target_offset, length);
-}
-
-Status ValidatingCommandBuffer::PushConstants(
-    ExecutableLayout* executable_layout, size_t offset,
-    absl::Span<const uint32_t> values) {
-  IREE_DVLOG(3) << "CommandBuffer::PushConstants("
-                << executable_layout->DebugString() << ", " << offset << ", "
-                << absl::StrJoin(values, ", ") << ")";
-
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
-  // TODO(benvanik): validate offset and value count with layout.
-
-  return impl_->PushConstants(executable_layout, offset, values);
-}
-
-Status ValidatingCommandBuffer::PushDescriptorSet(
-    ExecutableLayout* executable_layout, int32_t set,
-    absl::Span<const DescriptorSet::Binding> bindings) {
-  IREE_DVLOG(3) << "CommandBuffer::PushDescriptorSet("
-                << executable_layout->DebugString() << ", " << set << ", ["
-                << absl::StrJoin(bindings, ", ",
-                                 DescriptorSetBindingFormatter())
-                << "])";
-
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
-  // TODO(benvanik): validate set index.
-  // TODO(benvanik): validate bindings.
-
-  return impl_->PushDescriptorSet(executable_layout, set, bindings);
-}
-
-Status ValidatingCommandBuffer::BindDescriptorSet(
-    ExecutableLayout* executable_layout, int32_t set,
-    DescriptorSet* descriptor_set,
-    absl::Span<const device_size_t> dynamic_offsets) {
-  IREE_DVLOG(3) << "CommandBuffer::BindDescriptorSet("
-                << executable_layout->DebugString() << ", " << set << ", "
-                << descriptor_set->DebugString() << ", ["
-                << absl::StrJoin(dynamic_offsets, ", ") << "])";
-
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
-  // TODO(benvanik): validate set index.
-  // TODO(benvanik): validate dynamic offsets (both count and offsets).
-
-  return impl_->BindDescriptorSet(executable_layout, set, descriptor_set,
-                                  dynamic_offsets);
-}
-
-Status ValidatingCommandBuffer::ValidateDispatchBindings(Executable* executable,
-                                                         int32_t entry_point) {
-  // Validate all buffers referenced have compatible memory types, access
-  // rights, and usage.
-  // TODO(benvanik): add validation by walking executable layout.
-  // for (const auto& binding : bindings) {
-  //   IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(binding.buffer,
-  //       MemoryType::kDeviceVisible))
-  //       << "input buffer: " << MemoryAccessString(binding.access) << " "
-  //       << binding.buffer->DebugStringShort();
-  //   IREE_RETURN_IF_ERROR(ValidateAccess(binding.buffer, binding.access));
-  //   IREE_RETURN_IF_ERROR(ValidateUsage(binding.buffer,
-  //                                      BufferUsage::kDispatch));
-  //   TODO(benvanik): validate it matches the executable expectations.
-  //   TODO(benvanik): validate buffer contains enough data for shape+size.
-  // }
-
-  // TODO(benvanik): validate no aliasing between inputs/outputs.
-
-  return OkStatus();
-}
-
-Status ValidatingCommandBuffer::Dispatch(Executable* executable,
-                                         int32_t entry_point,
-                                         std::array<uint32_t, 3> workgroups) {
-  IREE_DVLOG(3) << "CommandBuffer::Dispatch(" << executable->DebugString()
-                << ", " << entry_point << ", ["
-                << absl::StrJoin(workgroups, ", ") << "])";
-
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-  IREE_RETURN_IF_ERROR(ValidateDispatchBindings(executable, entry_point));
-
-  return impl_->Dispatch(executable, entry_point, workgroups);
-}
-
-Status ValidatingCommandBuffer::DispatchIndirect(
-    Executable* executable, int32_t entry_point, Buffer* workgroups_buffer,
-    device_size_t workgroups_offset) {
-  IREE_DVLOG(3) << "CommandBuffer::DispatchIndirect("
-                << executable->DebugString() << ", " << entry_point << ", "
-                << workgroups_buffer->DebugString() << ", " << workgroups_offset
-                << ")";
-
-  IREE_RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));
-
-  IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType(workgroups_buffer,
-                                                    MemoryType::kDeviceVisible))
-      << "input buffer: " << workgroups_buffer->DebugStringShort();
-  IREE_RETURN_IF_ERROR(ValidateAccess(workgroups_buffer, MemoryAccess::kRead));
-  IREE_RETURN_IF_ERROR(
-      ValidateUsage(workgroups_buffer, BufferUsage::kDispatch));
-  IREE_RETURN_IF_ERROR(ValidateRange(workgroups_buffer, workgroups_offset,
-                                     sizeof(uint32_t) * 3));
-
-  IREE_RETURN_IF_ERROR(ValidateDispatchBindings(executable, entry_point));
-
-  return impl_->DispatchIndirect(executable, entry_point, workgroups_buffer,
-                                 workgroups_offset);
-}
-
-}  // namespace
-
-ref_ptr<CommandBuffer> WrapCommandBufferWithValidation(
-    Allocator* allocator, ref_ptr<CommandBuffer> impl) {
-  return make_ref<ValidatingCommandBuffer>(allocator, std::move(impl));
-}
-
-}  // namespace hal
-}  // namespace iree
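For context, a usage sketch of the wrapper being deleted here. The raw_cb value is hypothetical; this mirrors how a device implementation could opt into validation rather than being a verbatim call site from the tree:

    // Wrap a driver command buffer so every recording call validates first.
    ref_ptr<CommandBuffer> raw_cb = /* driver-specific command buffer */;
    ref_ptr<CommandBuffer> cb =
        WrapCommandBufferWithValidation(device_allocator, std::move(raw_cb));
    IREE_RETURN_IF_ERROR(cb->Begin());  // FailedPrecondition if already recording.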
diff --git a/iree/hal/command_buffer_validation.h b/iree/hal/command_buffer_validation.h
deleted file mode 100644
index 026b9b68c0aae..0000000000000
--- a/iree/hal/command_buffer_validation.h
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
-#define IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
-
-#include "iree/hal/allocator.h"
-#include "iree/hal/command_buffer.h"
-
-namespace iree {
-namespace hal {
-
-// Wraps an existing command buffer to provide in-depth validation during
-// recording. This should be enabled whenever the command buffer is being
-// driven by unsafe code or when early and readable diagnostics are needed.
-ref_ptr<CommandBuffer> WrapCommandBufferWithValidation(
    Allocator* allocator, ref_ptr<CommandBuffer> impl);
-
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
diff --git a/iree/hal/command_queue.h b/iree/hal/command_queue.h
deleted file mode 100644
index 7d068d731477c..0000000000000
--- a/iree/hal/command_queue.h
+++ /dev/null
@@ -1,109 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_COMMAND_QUEUE_H_
-#define IREE_HAL_COMMAND_QUEUE_H_
-
-#include <cstdint>
-#include <string>
-
-#include "absl/types/span.h"
-#include "iree/base/bitfield.h"
-#include "iree/base/status.h"
-#include "iree/base/time.h"
-#include "iree/hal/command_buffer.h"
-#include "iree/hal/semaphore.h"
-
-namespace iree {
-namespace hal {
-
-// A batch of command buffers with synchronization information for submission.
-struct SubmissionBatch {
-  // A set of semaphores that must have their payload values meet or exceed the
-  // specified values prior to any command buffer within this batch executing.
-  absl::Span<const SemaphoreValue> wait_semaphores;
-
-  // Command buffers that will execute in this batch.
-  // The command buffers will begin execution in order but may complete out of
-  // order.
-  absl::Span<CommandBuffer* const> command_buffers;
-
-  // Semaphores to signal after execution of all command buffers complete.
-  // Semaphore payloads will be set to the maximum of the specified payload or
-  // their current payload.
-  absl::Span<const SemaphoreValue> signal_semaphores;
-};
-
-// Asynchronous command execution queue.
-//
-// CommandQueues may capture device status at Semaphore barriers, including
-// information about device state such as thermal throttling. This information
-// is a snapshot of the state at the time the semaphore was signaled and not
-// necessarily live at the time of the application query.
-//
-// Command queues are thread-safe and submissions may occur from multiple
-// threads.
-class CommandQueue {
- public:
-  virtual ~CommandQueue() = default;
-
-  // Name of the queue used for logging purposes.
-  // Try to keep at 4 characters total for prettier logging.
-  const std::string& name() const { return name_; }
-
-  // Capabilities of the command queue.
-  CommandCategoryBitfield supported_categories() const {
-    return supported_categories_;
-  }
-
-  // Whether this queue may be used for transfer commands.
-  bool can_transfer() const {
-    return AllBitsSet(supported_categories_, CommandCategory::kTransfer);
-  }
-
-  // Whether this queue may be used for dispatch commands.
-  bool can_dispatch() const {
-    return AllBitsSet(supported_categories_, CommandCategory::kDispatch);
-  }
-
-  // Submits one or more command batches for execution on the queue.
-  virtual Status Submit(absl::Span<const SubmissionBatch> batches) = 0;
-  inline Status Submit(const SubmissionBatch& batch) {
-    return Submit(absl::MakeConstSpan(&batch, 1));
-  }
-
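A small usage sketch for SubmissionBatch/Submit, patterned on the CTS tests deleted later in this change (the queue, command buffer, and semaphore handles are assumed to already exist):

    // Wait for nothing, run one command buffer, and signal `signal_semaphore`
    // to payload value 1 on completion; then block until it lands.
    IREE_RETURN_IF_ERROR(queue->Submit(
        {/*wait_semaphores=*/{},
         /*command_buffers=*/{command_buffer.get()},
         /*signal_semaphores=*/{{signal_semaphore.get(), 1ull}}}));
    IREE_RETURN_IF_ERROR(signal_semaphore->Wait(1ull, InfiniteFuture()));

-  // Blocks until all outstanding requests have been completed.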
-  // This is equivalent to having waited on all outstanding semaphores.
-  // Implicitly calls Flush to ensure delayed requests are scheduled.
-  //
-  // If the command queue has encountered an error during submission at any
-  // point it will be returned here (repeatedly).
-  virtual Status WaitIdle(Time deadline_ns) = 0;
-  inline Status WaitIdle(Duration timeout_ns) {
-    return WaitIdle(RelativeTimeoutToDeadlineNanos(timeout_ns));
-  }
-  inline Status WaitIdle() { return WaitIdle(InfiniteFuture()); }
-
- protected:
-  CommandQueue(std::string name, CommandCategoryBitfield supported_categories)
-      : name_(std::move(name)), supported_categories_(supported_categories) {}
-
-  const std::string name_;
-  const CommandCategoryBitfield supported_categories_;
-};
-
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_COMMAND_QUEUE_H_
diff --git a/iree/hal/cts/BUILD b/iree/hal/cts/BUILD
index babb000a448f3..627f3bf812e12 100644
--- a/iree/hal/cts/BUILD
+++ b/iree/hal/cts/BUILD
@@ -24,15 +24,9 @@ cc_library(
     name = "cts_test_base",
     testonly = True,
     hdrs = ["cts_test_base.h"],
-    data = [
-        # For AddressSanitizer when using Vulkan + a local Nvidia GPU
-        "//iree/tools:sanitizer_suppressions.txt",
-    ],
     deps = [
-        "//iree/base:status",
-        "//iree/hal",
+        "//iree/base:api",
         "//iree/hal:api",
-        "//iree/hal/testing:driver_registry",
         "//iree/testing:gtest",
     ],
 )
@@ -42,7 +36,6 @@ cc_test(
     srcs = ["allocator_test.cc"],
     deps = [
         ":cts_test_base",
-        "//iree/base:status",
         "//iree/hal/testing:driver_registry",
         "//iree/testing:gtest",
         "//iree/testing:gtest_main",
@@ -50,11 +43,10 @@ cc_test(
 )
 
 cc_test(
-    name = "buffer_test",
-    srcs = ["buffer_test.cc"],
+    name = "command_buffer_test",
+    srcs = ["command_buffer_test.cc"],
     deps = [
         ":cts_test_base",
-        "//iree/base:status",
         "//iree/hal/testing:driver_registry",
         "//iree/testing:gtest",
         "//iree/testing:gtest_main",
@@ -62,11 +54,10 @@ cc_test(
 )
 
 cc_test(
-    name = "command_buffer_test",
-    srcs = ["command_buffer_test.cc"],
+    name = "descriptor_set_test",
+    srcs = ["descriptor_set_test.cc"],
     deps = [
         ":cts_test_base",
-        "//iree/base:status",
         "//iree/hal/testing:driver_registry",
         "//iree/testing:gtest",
         "//iree/testing:gtest_main",
@@ -74,11 +65,10 @@ cc_test(
 )
 
 cc_test(
-    name = "command_queue_test",
-    srcs = ["command_queue_test.cc"],
+    name = "descriptor_set_layout_test",
+    srcs = ["descriptor_set_layout_test.cc"],
     deps = [
         ":cts_test_base",
-        "//iree/base:status",
         "//iree/hal/testing:driver_registry",
         "//iree/testing:gtest",
         "//iree/testing:gtest_main",
@@ -96,6 +86,28 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "event_test",
+    srcs = ["event_test.cc"],
+    deps = [
+        ":cts_test_base",
+        "//iree/hal/testing:driver_registry",
+        "//iree/testing:gtest",
+        "//iree/testing:gtest_main",
+    ],
+)
+
+cc_test(
+    name = "executable_layout_test",
+    srcs = ["executable_layout_test.cc"],
+    deps = [
+        ":cts_test_base",
+        "//iree/hal/testing:driver_registry",
+        "//iree/testing:gtest",
+        "//iree/testing:gtest_main",
+    ],
+)
+
 cc_test(
     name = "semaphore_test",
     srcs = ["semaphore_test.cc"],
diff --git a/iree/hal/cts/CMakeLists.txt b/iree/hal/cts/CMakeLists.txt
index 3589b1b9d39a5..f5c3d9445ff8a 100644
--- a/iree/hal/cts/CMakeLists.txt
+++ b/iree/hal/cts/CMakeLists.txt
@@ -19,13 +19,9 @@ iree_cc_library(
     cts_test_base
   HDRS
     "cts_test_base.h"
-  DATA
-    iree::tools::sanitizer_suppressions.txt
   DEPS
-    iree::base::status
-    iree::hal
+    iree::base::api
     iree::hal::api
-    iree::hal::testing::driver_registry
     iree::testing::gtest
   TESTONLY
   PUBLIC
 )
@@ -38,7 +34,6 @@ iree_cc_test(
     "allocator_test.cc"
   DEPS
     ::cts_test_base
-    iree::base::status
     iree::hal::testing::driver_registry
     iree::testing::gtest
     iree::testing::gtest_main
 )
 
 iree_cc_test(
   NAME
-    buffer_test
+    command_buffer_test
   SRCS
-    "buffer_test.cc"
+    "command_buffer_test.cc"
   DEPS
     ::cts_test_base
-    iree::base::status
     iree::hal::testing::driver_registry
     iree::testing::gtest
     iree::testing::gtest_main
 )
 
 iree_cc_test(
   NAME
-    command_buffer_test
+    descriptor_set_test
   SRCS
-    "command_buffer_test.cc"
+    "descriptor_set_test.cc"
   DEPS
     ::cts_test_base
-    iree::base::status
     iree::hal::testing::driver_registry
     iree::testing::gtest
     iree::testing::gtest_main
 )
 
 iree_cc_test(
   NAME
-    command_queue_test
+    descriptor_set_layout_test
   SRCS
-    "command_queue_test.cc"
+    "descriptor_set_layout_test.cc"
   DEPS
     ::cts_test_base
-    iree::base::status
     iree::hal::testing::driver_registry
     iree::testing::gtest
     iree::testing::gtest_main
@@ -95,6 +87,30 @@ iree_cc_test(
     iree::testing::gtest_main
 )
 
+iree_cc_test(
+  NAME
+    event_test
+  SRCS
+    "event_test.cc"
+  DEPS
+    ::cts_test_base
+    iree::hal::testing::driver_registry
+    iree::testing::gtest
+    iree::testing::gtest_main
+)
+
+iree_cc_test(
+  NAME
+    executable_layout_test
+  SRCS
+    "executable_layout_test.cc"
+  DEPS
+    ::cts_test_base
+    iree::hal::testing::driver_registry
+    iree::testing::gtest
+    iree::testing::gtest_main
+)
+
 iree_cc_test(
   NAME
     semaphore_test
diff --git a/iree/hal/cts/allocator_test.cc b/iree/hal/cts/allocator_test.cc
index 8677133c96751..d1f76a7f5c0ce 100644
--- a/iree/hal/cts/allocator_test.cc
+++ b/iree/hal/cts/allocator_test.cc
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "iree/base/status.h"
 #include "iree/hal/cts/cts_test_base.h"
 #include "iree/hal/testing/driver_registry.h"
 #include "iree/testing/gtest.h"
@@ -22,61 +21,71 @@ namespace iree {
 namespace hal {
 namespace cts {
 
-class AllocatorTest : public CtsTestBase {
- protected:
-  virtual void SetUp() {
-    CtsTestBase::SetUp();
+class AllocatorTest : public CtsTestBase {};
 
-    if (!device_) {
-      return;
-    }
+// Tests for baseline buffer compatibility that all HAL drivers must support.
+TEST_P(AllocatorTest, QueryBufferCompatibility) {
+  iree_host_size_t allocation_size = 1024;
 
-    allocator_ = device_->allocator();
-  }
+  // Need at least one way to get data between the host and device.
+  iree_hal_buffer_compatibility_t transfer_compatibility_host =
+      iree_hal_allocator_query_buffer_compatibility(
+          device_allocator_,
+          IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+          /*allowed_usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER,
+          /*intended_usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER, allocation_size);
+  iree_hal_buffer_compatibility_t transfer_compatibility_device =
+      iree_hal_allocator_query_buffer_compatibility(
+          device_allocator_,
+          IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
+          /*allowed_usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER,
+          /*intended_usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER, allocation_size);
+  iree_hal_buffer_compatibility_t required_transfer_compatibility =
+      IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE |
+      IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
+  EXPECT_TRUE(iree_all_bits_set(transfer_compatibility_host,
+                                required_transfer_compatibility) ||
+              iree_all_bits_set(transfer_compatibility_device,
+                                required_transfer_compatibility));
 
-  Allocator* allocator_ = nullptr;
-};
-
-TEST_P(AllocatorTest, CanAllocate) {
-  EXPECT_TRUE(allocator_->CanAllocate(
-      MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-      BufferUsage::kMapping, 1024));
-  EXPECT_TRUE(allocator_->CanAllocate(
-      MemoryType::kHostVisible | MemoryType::kDeviceLocal,
-      BufferUsage::kMapping, 1024));
-
-  // TODO(scotttodd): Minimum memory types and buffer usages necessary for use
-  // TODO(scotttodd): Test upper limits of memory size for allocations (1GB+)?
+  // Need to be able to use some type of buffer as dispatch inputs or outputs.
+  iree_hal_buffer_compatibility_t dispatch_compatibility =
+      iree_hal_allocator_query_buffer_compatibility(
+          device_allocator_, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+          /*allowed_usage=*/IREE_HAL_BUFFER_USAGE_DISPATCH,
+          /*intended_usage=*/IREE_HAL_BUFFER_USAGE_DISPATCH, allocation_size);
+  EXPECT_TRUE(
+      iree_all_bits_set(dispatch_compatibility,
+                        IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE |
+                            IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH));
 }
 
-TEST_P(AllocatorTest, Allocate) {
-  MemoryType memory_type = MemoryType::kHostLocal | MemoryType::kDeviceVisible;
-  BufferUsage usage = BufferUsage::kMapping;
-  size_t allocation_size = 1024;
+TEST_P(AllocatorTest, AllocateBuffer) {
+  iree_hal_memory_type_t memory_type =
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+  iree_hal_buffer_usage_t buffer_usage = IREE_HAL_BUFFER_USAGE_ALL;
+  iree_host_size_t allocation_size = 1024;
 
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer, allocator_->Allocate(memory_type, usage, allocation_size));
+  iree_hal_buffer_t* buffer;
+  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+      device_allocator_, memory_type, buffer_usage, allocation_size, &buffer));
 
-  EXPECT_EQ(allocator_, buffer->allocator());
+  EXPECT_EQ(device_allocator_, iree_hal_buffer_allocator(buffer));
 
   // At a minimum, the requested memory type should be respected.
   // Additional bits may be optionally set depending on the allocator.
-  EXPECT_TRUE((buffer->memory_type() & memory_type) == memory_type);
-  EXPECT_TRUE((buffer->usage() & usage) == usage);
-  EXPECT_GE(buffer->allocation_size(), allocation_size);  // Larger is okay.
-}
-
-TEST_P(AllocatorTest, CanUseBufferLike) {
-  MemoryType memory_type = MemoryType::kHostLocal | MemoryType::kDeviceVisible;
-  BufferUsage usage = BufferUsage::kMapping;
-  size_t allocation_size = 1024;
-
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer, allocator_->Allocate(memory_type, usage, allocation_size));
-
-  // Using the buffer for its original requested purpose should be fine.
   EXPECT_TRUE(
-      allocator_->CanUseBufferLike(allocator_, memory_type, usage, usage));
+      iree_all_bits_set(iree_hal_buffer_memory_type(buffer), memory_type));
+  EXPECT_TRUE(
+      iree_all_bits_set(iree_hal_buffer_allowed_usage(buffer), buffer_usage));
+  EXPECT_GE(iree_hal_buffer_allocation_size(buffer),
+            allocation_size);  // Larger is okay.
+
+  iree_hal_buffer_release(buffer);
 }
 
+// TODO(scotttodd): iree_hal_allocator_wrap_buffer
+//   * if implemented (skip test if status is "IREE_STATUS_UNAVAILABLE")
+
 INSTANTIATE_TEST_SUITE_P(
     AllDrivers, AllocatorTest,
     ::testing::ValuesIn(testing::EnumerateAvailableDrivers()),
diff --git a/iree/hal/cts/buffer_test.cc b/iree/hal/cts/buffer_test.cc
deleted file mode 100644
index 0d754e1459345..0000000000000
--- a/iree/hal/cts/buffer_test.cc
+++ /dev/null
@@ -1,385 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/base/status.h"
-#include "iree/hal/cts/cts_test_base.h"
-#include "iree/hal/testing/driver_registry.h"
-#include "iree/testing/gtest.h"
-#include "iree/testing/status_matchers.h"
-
-namespace iree {
-namespace hal {
-namespace cts {
-
-using ::testing::_;
-using ::testing::ElementsAre;
-using ::testing::Eq;
-
-// Note: this file only covers hal::Buffer APIs that can be overridden by
-// subclasses. Errors caught by hal::Buffer's common validations are not
-// covered as they are already tested in iree/hal/buffer_test.cc.
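An aside on the new C-API allocator test above: the same query can gate an allocation strategy at runtime. A minimal sketch reusing only calls from this change (error handling elided; device_allocator assumed to be a valid iree_hal_allocator_t*):

    iree_hal_buffer_compatibility_t compat =
        iree_hal_allocator_query_buffer_compatibility(
            device_allocator,
            IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
            /*allowed_usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER,
            /*intended_usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER, 1024);
    if (!iree_all_bits_set(compat, IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE)) {
      // Fall back to another memory type or stage through a transfer buffer.
    }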
-
-class BufferTest : public CtsTestBase {};
-
-TEST_P(BufferTest, Allocate) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer, device_->allocator()->Allocate(
-                       MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-                       BufferUsage::kTransfer | BufferUsage::kMapping, 14));
-
-  EXPECT_NE(nullptr, buffer->allocator());
-  EXPECT_EQ(MemoryAccess::kAll, buffer->allowed_access());
-  EXPECT_EQ(MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-            buffer->memory_type());
-  EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage());
-
-  EXPECT_LE(14, buffer->allocation_size());
-  EXPECT_EQ(0, buffer->byte_offset());
-  EXPECT_EQ(14, buffer->byte_length());
-}
-
-TEST_P(BufferTest, AllocateZeroLength) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer, device_->allocator()->Allocate(
-                       MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-                       BufferUsage::kTransfer | BufferUsage::kMapping, 0));
-  EXPECT_LE(0, buffer->allocation_size());
-}
-
-TEST_P(BufferTest, Fill8) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer, device_->allocator()->Allocate(
-                       MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-                       BufferUsage::kTransfer | BufferUsage::kMapping, 5));
-
-  std::vector<uint8_t> actual_data(buffer->allocation_size());
-
-  // Fill with a sentinel.
-  IREE_EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33));
-
-  // Zero fills are fine.
-  IREE_EXPECT_OK(buffer->Fill8(0, 0, 0x44u));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33));
-
-  // Fill the remaining parts of the buffer by using kWholeBuffer.
-  IREE_EXPECT_OK(buffer->Fill8(2, kWholeBuffer, 0x55u));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x55, 0x55, 0x55));
-
-  // Fill a small region of the buffer.
-  IREE_EXPECT_OK(buffer->Fill8(1, 1, 0x66u));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0x33, 0x66, 0x55, 0x55, 0x55));
-
-  // Whole buffer helper.
-  IREE_EXPECT_OK(buffer->Fill8(0x99u));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0x99, 0x99, 0x99, 0x99, 0x99));
-}
-
-TEST_P(BufferTest, Fill16) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer, device_->allocator()->Allocate(
-                       MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-                       BufferUsage::kTransfer | BufferUsage::kMapping, 9));
-
-  std::vector<uint8_t> actual_data(buffer->allocation_size());
-
-  // Fill with a sentinel.
-  IREE_EXPECT_OK(buffer->Fill16(0, 4, 0x1122u));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0));
-
-  // Zero fills are fine.
-  IREE_EXPECT_OK(buffer->Fill16(0, 0, 0x5566u));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0));
-
-  // Fill the remaining parts of the buffer by using kWholeBuffer.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto aligned_buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-          BufferUsage::kTransfer | BufferUsage::kMapping, 8));
-  IREE_EXPECT_OK(aligned_buffer->Fill16(4, kWholeBuffer, 0x5566u));
-  std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size());
-  IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
-                                          aligned_actual_data.size()));
-  EXPECT_THAT(aligned_actual_data,
-              ElementsAre(0, 0, 0, 0, 0x66, 0x55, 0x66, 0x55));
-
-  // Whole buffer helper.
-  IREE_EXPECT_OK(aligned_buffer->Fill16(0x5566u));
-  IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
-                                          aligned_actual_data.size()));
-  EXPECT_THAT(aligned_actual_data,
-              ElementsAre(0x66, 0x55, 0x66, 0x55, 0x66, 0x55, 0x66, 0x55));
-}
-
-TEST_P(BufferTest, Fill32) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer, device_->allocator()->Allocate(
-                       MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-                       BufferUsage::kTransfer | BufferUsage::kMapping, 9));
-
-  std::vector<uint8_t> actual_data(buffer->allocation_size());
-
-  // Fill with a sentinel.
-  IREE_EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data,
-              ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0));
-
-  // Zero fills are fine.
-  IREE_EXPECT_OK(buffer->Fill32(0, 0, 0x55667788u));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data,
-              ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0));
-
-  // Fill the remaining parts of the buffer by using kWholeBuffer.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto aligned_buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-          BufferUsage::kTransfer | BufferUsage::kMapping, 8));
-  IREE_EXPECT_OK(aligned_buffer->Fill32(4, kWholeBuffer, 0x55667788u));
-  std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size());
-  IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
-                                          aligned_actual_data.size()));
-  EXPECT_THAT(aligned_actual_data,
-              ElementsAre(0, 0, 0, 0, 0x88, 0x77, 0x66, 0x55));
-
-  // Whole buffer helper.
-  IREE_EXPECT_OK(aligned_buffer->Fill32(0x55667788u));
-  IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(),
-                                          aligned_actual_data.size()));
-  EXPECT_THAT(aligned_actual_data,
-              ElementsAre(0x88, 0x77, 0x66, 0x55, 0x88, 0x77, 0x66, 0x55));
-}
-
-TEST_P(BufferTest, ReadWriteData) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer, device_->allocator()->Allocate(
-                       MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-                       BufferUsage::kTransfer | BufferUsage::kMapping, 4));
-
-  std::vector<uint8_t> actual_data(4);
-
-  // Write over the entire buffer.
-  std::vector<uint8_t> new_data = {10, 20, 30, 40};
-  IREE_EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size()));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, Eq(new_data));
-
-  // Writing zero bytes is valid.
-  std::vector<uint8_t> zero_data;
-  IREE_EXPECT_OK(buffer->WriteData(0, zero_data.data(), 0));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, Eq(new_data));
-
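Note that the ElementsAre expectations in the Fill16/Fill32 tests above encode little-endian byte order, which is what these tests assume of the host. For example:

    uint32_t pattern = 0x11223344u;
    uint8_t bytes[4];
    std::memcpy(bytes, &pattern, sizeof(pattern));
    // On a little-endian host: bytes == {0x44, 0x33, 0x22, 0x11}, matching
    // the ElementsAre(0x44, 0x33, 0x22, 0x11, ...) expectations above.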
-  // Write over a portion of the buffer.
-  std::vector<uint8_t> partial_data = {99};
-  IREE_EXPECT_OK(
-      buffer->WriteData(1, partial_data.data(), partial_data.size()));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(10, 99, 30, 40));
-}
-
-TEST_P(BufferTest, CopyData) {
-  std::vector<uint8_t> src_data = {0, 1, 2, 3};
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto src_buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-          BufferUsage::kTransfer | BufferUsage::kMapping, src_data.size()));
-  IREE_EXPECT_OK(src_buffer->WriteData(0, src_data.data(), src_data.size()));
-
-  std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4};
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto dst_buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-          BufferUsage::kTransfer | BufferUsage::kMapping, dst_data.size()));
-  IREE_EXPECT_OK(dst_buffer->WriteData(0, dst_data.data(), dst_data.size()));
-
-  // Copy of length 0 should not change the dest buffer.
-  IREE_EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 0, 0));
-  std::vector<uint8_t> actual_data(dst_data.size());
-  IREE_EXPECT_OK(
-      dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, Eq(dst_data));
-
-  // Copy a subrange of the buffer.
-  IREE_EXPECT_OK(dst_buffer->CopyData(1, src_buffer.get(), 2, 2));
-  IREE_EXPECT_OK(
-      dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 3, 4));
-
-  // Copy the entire buffer using kWholeBuffer. This will adjust sizes
-  // to ensure that the min buffer is taken. We test both src and dst buffer
-  // offset/length calculations (note that some may end up as 0 copies).
-  IREE_EXPECT_OK(dst_buffer->CopyData(3, src_buffer.get(), 0, kWholeBuffer));
-  IREE_EXPECT_OK(
-      dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 0, 1));
-  IREE_EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 2, kWholeBuffer));
-  IREE_EXPECT_OK(
-      dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(2, 3, 3, 0, 1));
-  IREE_EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 3, kWholeBuffer));
-  IREE_EXPECT_OK(
-      dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 1));
-  IREE_EXPECT_OK(dst_buffer->CopyData(4, src_buffer.get(), 0, kWholeBuffer));
-  IREE_EXPECT_OK(
-      dst_buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 0));
-}
-
-TEST_P(BufferTest, MapMemory) {
-  std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-          BufferUsage::kTransfer | BufferUsage::kMapping, src_data.size()));
-  IREE_EXPECT_OK(buffer->WriteData(0, src_data.data(), src_data.size()));
-
-  // 0-length mappings are valid.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto mapping, buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 0, 0));
-  EXPECT_TRUE(mapping.empty());
-  EXPECT_EQ(0, mapping.size());
-  EXPECT_EQ(0, mapping.byte_length());
-  EXPECT_NE(nullptr, mapping.data());
-  IREE_ASSERT_OK_AND_ASSIGN(auto span, mapping.Subspan());
-  EXPECT_TRUE(span.empty());
-  mapping.reset();
-
-  // Map the whole buffer for reading.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      mapping, buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 0, kWholeBuffer));
-  EXPECT_EQ(src_data.size(), mapping.size());
-  IREE_ASSERT_OK_AND_ASSIGN(span, mapping.Subspan());
-  EXPECT_THAT(span, ElementsAre(0, 1, 2, 3, 4, 5, 6));
-  mapping.reset();
-
-  // Map a portion of the buffer for reading.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      mapping, buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 1, 2));
-  EXPECT_EQ(2, mapping.size());
-  IREE_ASSERT_OK_AND_ASSIGN(span, mapping.Subspan());
-  EXPECT_THAT(span, ElementsAre(1, 2));
-  mapping.reset();
-}
-
-TEST_P(BufferTest, MapMemoryNonByte) {
-  std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-          BufferUsage::kTransfer | BufferUsage::kMapping, src_data.size()));
-  IREE_EXPECT_OK(buffer->WriteData(0, src_data.data(), src_data.size()));
-
-  // Map the buffer as non-byte values.
-  // Note that we'll round down to the number of valid elements at the
-  // alignment.
-  IREE_ASSERT_OK_AND_ASSIGN(auto mapping16,
-                            buffer->MapMemory<uint16_t>(MemoryAccess::kRead));
-  EXPECT_EQ(3, mapping16.size());
-  EXPECT_LE(6, mapping16.byte_length());
-  IREE_ASSERT_OK_AND_ASSIGN(auto span16, mapping16.Subspan());
-  EXPECT_THAT(span16, ElementsAre(0x0100, 0x0302, 0x0504));
-  mapping16.reset();
-}
-
-TEST_P(BufferTest, MapMemoryWrite) {
-  std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-          BufferUsage::kTransfer | BufferUsage::kMapping, src_data.size()));
-  IREE_EXPECT_OK(buffer->WriteData(0, src_data.data(), src_data.size()));
-
-  // Map and modify the data. We should see it when we read back.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto mapping, buffer->MapMemory<uint8_t>(MemoryAccess::kWrite, 1, 2));
-  auto mutable_data = mapping.mutable_data();
-  mutable_data[0] = 0xAA;
-  mutable_data[1] = 0xBB;
-  mapping.reset();
-  std::vector<uint8_t> actual_data(src_data.size());
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0, 0xAA, 0xBB, 3, 4, 5, 6));
-}
-
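Since the templated MapMemory tests go away with the C++ API, the C-API shape of the same read-back, as used elsewhere in this change, is roughly the following (a sketch; error handling elided and the `out` destination assumed):

    iree_hal_buffer_mapping_t mapping;
    IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
        buffer, IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0,
        /*byte_length=*/IREE_WHOLE_BUFFER, &mapping));
    memcpy(out, mapping.contents.data, mapping.contents.data_length);
    iree_hal_buffer_unmap_range(&mapping);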
-TEST_P(BufferTest, MapMemoryDiscard) {
-  std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-          BufferUsage::kTransfer | BufferUsage::kMapping, src_data.size()));
-  IREE_EXPECT_OK(buffer->WriteData(0, src_data.data(), src_data.size()));
-
-  // Map for discard. Note that we can't really rely on the value of the data
-  // so we just trust that it's been discarded. It's a hint, anyway. We can be
-  // sure that the data we didn't want to discard is the same though.
-  std::vector<uint8_t> actual_data(src_data.size());
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto mapping,
-      buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite, 1, 2));
-  IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0, _, _, 3, 4, 5, 6));
-  mapping.reset();
-}
-
-TEST_P(BufferTest, MapMemorySubspan) {
-  std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6};
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto parent_buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kHostLocal | MemoryType::kDeviceVisible,
-          BufferUsage::kTransfer | BufferUsage::kMapping, src_data.size()));
-  IREE_EXPECT_OK(
-      parent_buffer->WriteData(0, src_data.data(), src_data.size()));
-
-  IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer,
-                            Buffer::Subspan(parent_buffer, 1, 3));
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto mapping,
-      subspan_buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite, 1, 2));
-  auto* mutable_data = mapping.mutable_data();
-  mutable_data[0] = 0xCC;
-  mutable_data[1] = 0xDD;
-  mapping.reset();
-
-  std::vector<uint8_t> actual_data(src_data.size());
-  IREE_EXPECT_OK(
-      parent_buffer->ReadData(0, actual_data.data(), actual_data.size()));
-  EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xCC, 0xDD, 4, 5, 6));
-}
-
-INSTANTIATE_TEST_SUITE_P(
-    AllDrivers, BufferTest,
-    ::testing::ValuesIn(testing::EnumerateAvailableDrivers()),
-    GenerateTestName());
-
-}  // namespace cts
-}  // namespace hal
-}  // namespace iree
diff --git a/iree/hal/cts/command_buffer_test.cc b/iree/hal/cts/command_buffer_test.cc
index 0dac6ddae4999..f82c9293027f1 100644
--- a/iree/hal/cts/command_buffer_test.cc
+++ b/iree/hal/cts/command_buffer_test.cc
@@ -15,7 +15,6 @@
 #include <cstdint>
 #include <vector>
 
-#include "iree/base/status.h"
 #include "iree/hal/cts/cts_test_base.h"
 #include "iree/hal/testing/driver_registry.h"
 #include "iree/testing/gtest.h"
@@ -29,213 +28,209 @@ using ::testing::ContainerEq;
 
 class CommandBufferTest : public CtsTestBase {
  protected:
-  static constexpr device_size_t kBufferNumBytes = 16;
-
-  void SubmitAndWait(CommandQueue* command_queue,
-                     CommandBuffer* command_buffer) {
-    IREE_ASSERT_OK_AND_ASSIGN(auto signal_semaphore,
-                              device_->CreateSemaphore(0ull));
-
-    IREE_ASSERT_OK(command_queue->Submit(
-        {{}, {command_buffer}, {{signal_semaphore.get(), 1ull}}}));
-    IREE_ASSERT_OK(signal_semaphore->Wait(1ull, InfiniteFuture()));
-  }
+  static constexpr iree_device_size_t kBufferSize = 4096;
 };
 
 TEST_P(CommandBufferTest, Create) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto command_buffer,
-      device_->CreateCommandBuffer(CommandBufferMode::kOneShot,
-                                   CommandCategory::kDispatch));
-
-  EXPECT_TRUE((command_buffer->mode() & CommandBufferMode::kOneShot) ==
-              CommandBufferMode::kOneShot);
-  EXPECT_TRUE((command_buffer->command_categories() &
-               CommandCategory::kDispatch) == CommandCategory::kDispatch);
-  EXPECT_FALSE(command_buffer->is_recording());
+  iree_hal_command_buffer_t* command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, &command_buffer));
+
+  EXPECT_TRUE((iree_hal_command_buffer_allowed_categories(command_buffer) &
+               IREE_HAL_COMMAND_CATEGORY_DISPATCH) ==
+              IREE_HAL_COMMAND_CATEGORY_DISPATCH);
+
+  iree_hal_command_buffer_release(command_buffer);
 }
 
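Style note on the new Create test above: the manual mask-and-compare could use the iree_all_bits_set() helper that the new allocator test already uses for the same pattern, e.g.:

    EXPECT_TRUE(iree_all_bits_set(
        iree_hal_command_buffer_allowed_categories(command_buffer),
        IREE_HAL_COMMAND_CATEGORY_DISPATCH));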
 TEST_P(CommandBufferTest, BeginEnd) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto command_buffer,
-      device_->CreateCommandBuffer(CommandBufferMode::kOneShot,
-                                   CommandCategory::kDispatch));
-
-  EXPECT_FALSE(command_buffer->is_recording());
-  IREE_EXPECT_OK(command_buffer->Begin());
-  EXPECT_TRUE(command_buffer->is_recording());
-  IREE_EXPECT_OK(command_buffer->End());
-  EXPECT_FALSE(command_buffer->is_recording());
+  iree_hal_command_buffer_t* command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, &command_buffer));
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  iree_hal_command_buffer_release(command_buffer);
+}
+
+TEST_P(CommandBufferTest, SubmitEmpty) {
+  iree_hal_command_buffer_t* command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_DISPATCH, &command_buffer));
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  IREE_ASSERT_OK(SubmitCommandBufferAndWait(IREE_HAL_COMMAND_CATEGORY_DISPATCH,
+                                            command_buffer));
+
+  iree_hal_command_buffer_release(command_buffer);
 }
 
 TEST_P(CommandBufferTest, FillBufferWithRepeatedBytes) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto command_buffer,
-      device_->CreateCommandBuffer(CommandBufferMode::kOneShot,
-                                   CommandCategory::kTransfer));
+  iree_hal_command_buffer_t* command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, &command_buffer));
 
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto device_buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kDeviceLocal | MemoryType::kHostVisible,
-          BufferUsage::kAll, kBufferNumBytes));
+  iree_hal_buffer_t* device_buffer;
+  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+      device_allocator_,
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
+      IREE_HAL_BUFFER_USAGE_ALL, kBufferSize, &device_buffer));
 
-  std::vector<uint8_t> reference_buffer(kBufferNumBytes);
+  std::vector<uint8_t> reference_buffer(kBufferSize);
 
-  IREE_EXPECT_OK(command_buffer->Begin());
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
 
   // Fill the device buffer with segments of different values so that we can
   // test both fill and offset/size.
   uint8_t val1 = 0x07;
-  IREE_EXPECT_OK(command_buffer->FillBuffer(device_buffer.get(),
-                                            /*target_offset=*/0,
-                                            /*length=*/kBufferNumBytes / 4,
-                                            &val1,
-                                            /*pattern_length=*/1));
-  std::memset(reference_buffer.data(), val1, kBufferNumBytes / 4);
+  iree_hal_command_buffer_fill_buffer(
+      command_buffer, device_buffer,
+      /*target_offset=*/0, /*length=*/kBufferSize / 4, /*pattern=*/&val1,
+      /*pattern_length=*/sizeof(val1));
+  std::memset(reference_buffer.data(), val1, kBufferSize / 4);
 
   uint8_t val2 = 0xbe;
-  IREE_EXPECT_OK(
-      command_buffer->FillBuffer(device_buffer.get(),
-                                 /*target_offset=*/kBufferNumBytes / 4,
-                                 /*length=*/kBufferNumBytes / 4, &val2,
-                                 /*pattern_length=*/1));
-  std::memset(reference_buffer.data() + kBufferNumBytes / 4, val2,
-              kBufferNumBytes / 4);
+  iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
+                                      /*target_offset=*/kBufferSize / 4,
+                                      /*length=*/kBufferSize / 4,
+                                      /*pattern=*/&val2,
+                                      /*pattern_length=*/sizeof(val2));
+  std::memset(reference_buffer.data() + kBufferSize / 4, val2, kBufferSize / 4);
 
   uint8_t val3 = 0x54;
-  IREE_EXPECT_OK(
-      command_buffer->FillBuffer(device_buffer.get(),
-                                 /*target_offset=*/kBufferNumBytes / 2,
-                                 /*length=*/kBufferNumBytes / 2, &val3,
-                                 /*pattern_length=*/1));
-  std::memset(reference_buffer.data() + kBufferNumBytes / 2, val3,
-              kBufferNumBytes / 2);
-
-  IREE_EXPECT_OK(command_buffer->End());
-
-  SubmitAndWait(device_->transfer_queues()[0], command_buffer.get());
-
-  // Read back the device buffer.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto mapped_memory,
-      device_buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-  IREE_EXPECT_OK(mapped_memory.Invalidate());
-
-  std::vector<uint8_t> actual_data(mapped_memory.data(),
-                                   mapped_memory.data() + kBufferNumBytes);
+  iree_hal_command_buffer_fill_buffer(command_buffer, device_buffer,
+                                      /*target_offset=*/kBufferSize / 2,
+                                      /*length=*/kBufferSize / 2,
+                                      /*pattern=*/&val3,
+                                      /*pattern_length=*/sizeof(val3));
+  std::memset(reference_buffer.data() + kBufferSize / 2, val3, kBufferSize / 2);
+
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  IREE_ASSERT_OK(SubmitCommandBufferAndWait(IREE_HAL_COMMAND_CATEGORY_TRANSFER,
+                                            command_buffer));
+
+  // Read the device buffer and compare.
+  std::vector<uint8_t> actual_data(kBufferSize);
+  IREE_ASSERT_OK(iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0,
+                                           /*target_buffer=*/actual_data.data(),
+                                           /*data_length=*/kBufferSize));
   EXPECT_THAT(actual_data, ContainerEq(reference_buffer));
+
+  // Must release the command buffer before resources used by it.
+  iree_hal_command_buffer_release(command_buffer);
+  iree_hal_buffer_release(device_buffer);
 }
 
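The C API collapses the old Fill8/Fill16/Fill32 helpers into a single iree_hal_buffer_fill() taking an explicit pattern length, so the 16-bit form of the deleted buffer tests would map to something like this sketch (signature as used in this change; values illustrative):

    uint16_t pattern = 0x1122u;
    IREE_RETURN_IF_ERROR(iree_hal_buffer_fill(
        buffer, /*byte_offset=*/0, /*byte_length=*/4, &pattern,
        /*pattern_length=*/sizeof(pattern)));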
 TEST_P(CommandBufferTest, CopyWholeBuffer) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto command_buffer,
-      device_->CreateCommandBuffer(CommandBufferMode::kOneShot,
-                                   CommandCategory::kTransfer));
-
-  // Create a host buffer.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto host_buffer, device_->allocator()->Allocate(
-                            MemoryType::kHostVisible | MemoryType::kHostCached |
-                                MemoryType::kDeviceVisible,
-                            BufferUsage::kAll, kBufferNumBytes));
-
-  // Fill the host buffer.
-  uint8_t i8_val = 0x55;
-  IREE_EXPECT_OK(host_buffer->Fill8(0, kWholeBuffer, i8_val));
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto host_mapped_memory,
-      // Cannot use kDiscard here given we filled in the above.
-      host_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
-  IREE_EXPECT_OK(host_mapped_memory.Flush());
-
-  std::vector<uint8_t> reference_buffer(kBufferNumBytes);
-  std::memset(reference_buffer.data(), i8_val, kBufferNumBytes);
+  iree_hal_command_buffer_t* command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, &command_buffer));
+
+  // Create and fill a host buffer.
+  iree_hal_buffer_t* host_buffer;
+  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+      device_allocator_,
+      IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_CACHED |
+          IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+      IREE_HAL_BUFFER_USAGE_ALL, kBufferSize, &host_buffer));
+  uint8_t i8_val = 0x54;
+  IREE_ASSERT_OK(iree_hal_buffer_fill(host_buffer, /*byte_offset=*/0,
+                                      /*byte_length=*/kBufferSize, &i8_val,
+                                      /*pattern_length=*/sizeof(i8_val)));
+  std::vector<uint8_t> reference_buffer(kBufferSize);
+  std::memset(reference_buffer.data(), i8_val, kBufferSize);
 
   // Create a device buffer.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto device_buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kDeviceLocal | MemoryType::kHostVisible,
-          BufferUsage::kAll, kBufferNumBytes));
+  iree_hal_buffer_t* device_buffer;
+  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+      device_allocator_,
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
+      IREE_HAL_BUFFER_USAGE_ALL, kBufferSize, &device_buffer));
 
   // Copy the host buffer to the device buffer.
-  IREE_EXPECT_OK(command_buffer->Begin());
-  IREE_EXPECT_OK(
-      command_buffer->CopyBuffer(host_buffer.get(), /*source_offset=*/0,
-                                 device_buffer.get(), /*target_offset=*/0,
-                                 /*length=*/kBufferNumBytes));
-  IREE_EXPECT_OK(command_buffer->End());
-
-  SubmitAndWait(device_->transfer_queues()[0], command_buffer.get());
-
-  // Read back the device buffer.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto device_mapped_memory,
-      device_buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-  IREE_EXPECT_OK(device_mapped_memory.Invalidate());
-
-  std::vector<uint8_t> actual_data(
-      device_mapped_memory.data(),
-      device_mapped_memory.data() + kBufferNumBytes);
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_copy_buffer(
+      command_buffer, /*source_buffer=*/host_buffer, /*source_offset=*/0,
+      /*target_buffer=*/device_buffer, /*target_offset=*/0,
+      /*length=*/kBufferSize));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  IREE_ASSERT_OK(SubmitCommandBufferAndWait(IREE_HAL_COMMAND_CATEGORY_TRANSFER,
+                                            command_buffer));
+
+  // Read the device buffer and compare.
+  std::vector<uint8_t> actual_data(kBufferSize);
+  IREE_ASSERT_OK(iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0,
+                                           /*target_buffer=*/actual_data.data(),
+                                           /*data_length=*/kBufferSize));
   EXPECT_THAT(actual_data, ContainerEq(reference_buffer));
+
+  // Must release the command buffer before resources used by it.
+  iree_hal_command_buffer_release(command_buffer);
+  iree_hal_buffer_release(device_buffer);
+  iree_hal_buffer_release(host_buffer);
 }
 
 TEST_P(CommandBufferTest, CopySubBuffer) {
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto command_buffer,
-      device_->CreateCommandBuffer(CommandBufferMode::kOneShot,
-                                   CommandCategory::kTransfer));
-  // Create a device buffer.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto device_buffer,
-      device_->allocator()->Allocate(
-          MemoryType::kDeviceLocal | MemoryType::kHostVisible,
-          BufferUsage::kAll, kBufferNumBytes));
+  iree_hal_command_buffer_t* command_buffer;
+  IREE_ASSERT_OK(iree_hal_command_buffer_create(
+      device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+      IREE_HAL_COMMAND_CATEGORY_TRANSFER, &command_buffer));
+
+  iree_hal_buffer_t* device_buffer;
+  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+      device_allocator_,
+      IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
+      IREE_HAL_BUFFER_USAGE_ALL, kBufferSize, &device_buffer));
 
   // Create another host buffer with a smaller size.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto host_buffer, device_->allocator()->Allocate(
-                            MemoryType::kHostVisible | MemoryType::kHostCached |
-                                MemoryType::kDeviceVisible,
-                            BufferUsage::kAll, kBufferNumBytes / 2));
+  iree_hal_buffer_t* host_buffer;
+  IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
+      device_allocator_,
+      IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_CACHED |
+          IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+      IREE_HAL_BUFFER_USAGE_ALL, kBufferSize / 2, &host_buffer));
 
   // Fill the host buffer.
   uint8_t i8_val = 0x88;
-  IREE_EXPECT_OK(host_buffer->Fill8(0, kWholeBuffer, i8_val));
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto host_mapped_memory,
-      // Cannot use kDiscard here given we filled in the above.
-      host_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite));
-  IREE_EXPECT_OK(host_mapped_memory.Flush());
-
-  std::vector<uint8_t> reference_buffer(kBufferNumBytes);
-  std::memset(reference_buffer.data() + 8, i8_val, kBufferNumBytes / 2 - 4);
+  IREE_ASSERT_OK(iree_hal_buffer_fill(host_buffer, /*byte_offset=*/0,
+                                      /*byte_length=*/kBufferSize / 2, &i8_val,
+                                      /*pattern_length=*/sizeof(i8_val)));
+  std::vector<uint8_t> reference_buffer(kBufferSize);
+  std::memset(reference_buffer.data() + 8, i8_val, kBufferSize / 2 - 4);
 
   // Copy the host buffer to the device buffer.
-  IREE_EXPECT_OK(command_buffer->Begin());
-  IREE_EXPECT_OK(
-      command_buffer->CopyBuffer(host_buffer.get(), /*source_offset=*/4,
-                                 device_buffer.get(), /*target_offset=*/8,
-                                 /*length=*/kBufferNumBytes / 2 - 4));
-  IREE_EXPECT_OK(command_buffer->End());
-
-  SubmitAndWait(device_->transfer_queues()[0], command_buffer.get());
-
-  // Read back the device buffer.
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto device_mapped_memory,
-      device_buffer->MapMemory<uint8_t>(MemoryAccess::kRead));
-  IREE_EXPECT_OK(device_mapped_memory.Invalidate());
-
-  std::vector<uint8_t> actual_data(
-      device_mapped_memory.data(),
-      device_mapped_memory.data() + kBufferNumBytes);
+  IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
+  IREE_ASSERT_OK(iree_hal_command_buffer_copy_buffer(
+      command_buffer, /*source_buffer=*/host_buffer, /*source_offset=*/4,
+      /*target_buffer=*/device_buffer, /*target_offset=*/8,
+      /*length=*/kBufferSize / 2 - 4));
+  IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));
+
+  IREE_ASSERT_OK(SubmitCommandBufferAndWait(IREE_HAL_COMMAND_CATEGORY_TRANSFER,
+                                            command_buffer));
+
+  // Read the device buffer and compare.
+  std::vector<uint8_t> actual_data(kBufferSize);
+  IREE_ASSERT_OK(iree_hal_buffer_read_data(device_buffer, /*source_offset=*/0,
+                                           /*target_buffer=*/actual_data.data(),
+                                           /*data_length=*/kBufferSize));
   EXPECT_THAT(actual_data, ContainerEq(reference_buffer));
-}
 
-// TODO(scotttodd): UpdateBuffer, Dispatch, Sync, etc.
+  // Must release the command buffer before resources used by it.
+  iree_hal_command_buffer_release(command_buffer);
+  iree_hal_buffer_release(device_buffer);
+  iree_hal_buffer_release(host_buffer);
+}
 
 INSTANTIATE_TEST_SUITE_P(
     AllDrivers, CommandBufferTest,
diff --git a/iree/hal/cts/command_queue_test.cc b/iree/hal/cts/command_queue_test.cc
deleted file mode 100644
index 230e3d128d9fa..0000000000000
--- a/iree/hal/cts/command_queue_test.cc
+++ /dev/null
@@ -1,158 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <cstdint>
-
-#include "iree/base/status.h"
-#include "iree/hal/cts/cts_test_base.h"
-#include "iree/hal/testing/driver_registry.h"
-#include "iree/testing/gtest.h"
-#include "iree/testing/status_matchers.h"
-
-namespace iree {
-namespace hal {
-namespace cts {
-namespace {
-
-using ::iree::testing::status::IsOkAndHolds;
-using ::testing::Eq;
-
-class CommandQueueTest : public CtsTestBase {};
-
-TEST_P(CommandQueueTest, EnumerateDeviceQueues) {
-  // Log how many queues we have so future test cases have more context.
-  // Most tests just use the first queue, but supporting multiple queues may be
-  // relevant on some implementations.
-
-  absl::Span<CommandQueue*> dispatch_queues = device_->dispatch_queues();
-  IREE_LOG(INFO) << "Device has " << dispatch_queues.size()
-                 << " dispatch queue(s)";
-  EXPECT_GE(dispatch_queues.size(), 1);
-  for (auto* dispatch_queue : dispatch_queues) {
-    EXPECT_TRUE(dispatch_queue->can_dispatch());
-  }
-
-  absl::Span<CommandQueue*> transfer_queues = device_->transfer_queues();
-  IREE_LOG(INFO) << "Device has " << transfer_queues.size()
-                 << " transfer queue(s)";
-  EXPECT_GE(transfer_queues.size(), 1);
-  for (auto* transfer_queue : transfer_queues) {
-    EXPECT_TRUE(transfer_queue->can_transfer());
-  }
-}
-
-// Tests that waiting for idle is a no-op when nothing is queued.
-TEST_P(CommandQueueTest, WaitIdleWhileIdle) {
-  for (auto* dispatch_queue : device_->dispatch_queues()) {
-    IREE_EXPECT_OK(dispatch_queue->WaitIdle());
-  }
-  for (auto* transfer_queue : device_->transfer_queues()) {
-    IREE_EXPECT_OK(transfer_queue->WaitIdle());
-  }
-}
-
-// Tests that submitting a command buffer and immediately waiting will not
-// deadlock.
-// Note: this test never completes with Vulkan timeline semaphore emulation.
-TEST_P(CommandQueueTest, BlockingSubmit) {
-  auto command_queue = device_->dispatch_queues()[0];
-
-  IREE_ASSERT_OK_AND_ASSIGN(
-      auto command_buffer,
-      device_->CreateCommandBuffer(CommandBufferMode::kOneShot,
-                                   CommandCategory::kDispatch));
-  IREE_ASSERT_OK_AND_ASSIGN(auto semaphore, device_->CreateSemaphore(0ull));
-
-  IREE_ASSERT_OK(command_queue->Submit(
-      {{}, {command_buffer.get()}, {{semaphore.get(), 1ull}}}));
-  IREE_ASSERT_OK(semaphore->Wait(1ull, InfiniteFuture()));
-}
-
-// Tests waiting while work is pending/in-flight.
-// Note: this test never completes with Vulkan timeline semaphore emulation.
-TEST_P(CommandQueueTest, WaitTimeout) { - auto command_queue = device_->dispatch_queues()[0]; - - IREE_ASSERT_OK_AND_ASSIGN( - auto command_buffer, - device_->CreateCommandBuffer(CommandBufferMode::kOneShot, - CommandCategory::kDispatch)); - IREE_ASSERT_OK_AND_ASSIGN(auto wait_semaphore, - device_->CreateSemaphore(0ull)); - IREE_ASSERT_OK_AND_ASSIGN(auto signal_semaphore, - device_->CreateSemaphore(0ull)); - - IREE_ASSERT_OK(command_queue->Submit({{{wait_semaphore.get(), 1ull}}, - {command_buffer.get()}, - {{signal_semaphore.get(), 1ull}}})); - - // Work shouldn't start until the wait semaphore reaches its payload value. - EXPECT_THAT(signal_semaphore->Query(), IsOkAndHolds(Eq(0ull))); - EXPECT_TRUE(IsDeadlineExceeded(command_queue->WaitIdle(Milliseconds(100)))); - - // Signal the wait semaphore, work should begin and complete. - IREE_ASSERT_OK(wait_semaphore->Signal(1ull)); - IREE_ASSERT_OK(signal_semaphore->Wait(1ull, InfiniteFuture())); -} - -// Tests using multiple wait and signal semaphores. -TEST_P(CommandQueueTest, WaitMultiple) { - auto command_queue = device_->dispatch_queues()[0]; - - IREE_ASSERT_OK_AND_ASSIGN( - auto command_buffer, - device_->CreateCommandBuffer(CommandBufferMode::kOneShot, - CommandCategory::kDispatch)); - IREE_ASSERT_OK_AND_ASSIGN(auto wait_semaphore_1, - device_->CreateSemaphore(0ull)); - IREE_ASSERT_OK_AND_ASSIGN(auto wait_semaphore_2, - device_->CreateSemaphore(0ull)); - IREE_ASSERT_OK_AND_ASSIGN(auto signal_semaphore_1, - device_->CreateSemaphore(0ull)); - IREE_ASSERT_OK_AND_ASSIGN(auto signal_semaphore_2, - device_->CreateSemaphore(0ull)); - - IREE_ASSERT_OK(command_queue->Submit( - {{{wait_semaphore_1.get(), 1ull}, {wait_semaphore_2.get(), 1ull}}, - {command_buffer.get()}, - {{signal_semaphore_1.get(), 1ull}, {signal_semaphore_2.get(), 1ull}}})); - - // Work shouldn't start until the wait semaphore reaches its payload value. - EXPECT_THAT(signal_semaphore_1->Query(), IsOkAndHolds(Eq(0ull))); - EXPECT_THAT(signal_semaphore_2->Query(), IsOkAndHolds(Eq(0ull))); - // Note: This fails with Vulkan timeline semaphore emulation (returns OK) - EXPECT_TRUE(IsDeadlineExceeded(command_queue->WaitIdle(Milliseconds(100)))); - - // Signal the wait semaphores, work should only begin after each is set. - IREE_ASSERT_OK(wait_semaphore_1->Signal(1ull)); - EXPECT_THAT(signal_semaphore_1->Query(), IsOkAndHolds(Eq(0ull))); - EXPECT_THAT(signal_semaphore_2->Query(), IsOkAndHolds(Eq(0ull))); - IREE_ASSERT_OK(wait_semaphore_2->Signal(1ull)); - - IREE_ASSERT_OK(command_queue->WaitIdle()); -} - -INSTANTIATE_TEST_SUITE_P( - AllDrivers, CommandQueueTest, - ::testing::ValuesIn( - // Disabled on Vulkan until tests pass with - // timeline semaphore emulation. - testing::RemoveDriverByName(testing::EnumerateAvailableDrivers(), - "vulkan")), - GenerateTestName()); - -} // namespace -} // namespace cts -} // namespace hal -} // namespace iree diff --git a/iree/hal/cts/cts_test_base.h b/iree/hal/cts/cts_test_base.h index 1aa33c3149cf6..881442157fa2a 100644 --- a/iree/hal/cts/cts_test_base.h +++ b/iree/hal/cts/cts_test_base.h @@ -19,14 +19,11 @@ #include #include -#include "iree/base/status.h" +#include "iree/base/api.h" #include "iree/hal/api.h" #include "iree/testing/gtest.h" #include "iree/testing/status_matchers.h" -// TODO(3934): rebase this all on the C API. -#include "iree/hal/driver.h" - namespace iree { namespace hal { namespace cts { @@ -34,60 +31,102 @@ namespace cts { // Common setup for tests parameterized across all registered drivers. 
class CtsTestBase : public ::testing::TestWithParam<std::string> {
 protected:
-  // Per-test-suite set-up. This is called before the first test in this test
-  // suite. We use it to set up drivers that must be reused between test cases
-  // to work around issues with driver lifetimes (specifically SwiftShader for
-  // Vulkan).
-  //
-  // TODO(#3933): this is a very nasty hack that indicates a serious issue. If
-  // we have to do it here in our test suite it means that every user of IREE
-  // will also have to do something like it. We should be reusing all drivers
-  // across tests in a suite (removing the vulkan-specific behavior here) but
-  // ALSO need a test that tries to create a driver twice.
-  static void SetUpTestSuite() {
-    iree_hal_driver_t* driver = NULL;
-    iree_status_t status = iree_hal_driver_registry_try_create_by_name(
-        iree_hal_driver_registry_default(), iree_make_cstring_view("vulkan"),
-        iree_allocator_system(), &driver);
-    if (iree_status_consume_code(status) == IREE_STATUS_OK) {
-      shared_drivers_["vulkan"] =
-          assign_ref(reinterpret_cast<Driver*>(driver));
-    }
-  }
-
-  // Per-test-suite tear-down. This is called after the last test in this test
-  // suite. We use it to destruct driver handles before program exit. This
-  // avoids relying on static object destruction happening after main(), which
-  // can cause unexpected problems when the driver also wants to perform
-  // cleanup at that time.
-  static void TearDownTestSuite() { shared_drivers_.clear(); }
-
-  static std::map<std::string, ref_ptr<Driver>> shared_drivers_;
-
  virtual void SetUp() {
    const std::string& driver_name = GetParam();

    // Get driver with the given name and create its default device.
    // Skip drivers that are (gracefully) unavailable, fail if creation fails.
-    auto driver_or = GetDriver(driver_name);
-    if (IsUnavailable(driver_or.status())) {
-      IREE_LOG(WARNING) << "Skipping test as driver is unavailable: "
-                        << driver_or.status();
+    iree_hal_driver_t* driver;
+    iree_status_t status = TryGetDriver(driver_name, &driver);
+    if (iree_status_is_unavailable(status)) {
+      iree_status_free(status);
+      IREE_LOG(WARNING) << "Skipping test as driver is unavailable";
+      GTEST_SKIP();
+      return;
+    }
+    driver_ = driver;
+
+    iree_hal_device_t* device;
+    status = iree_hal_driver_create_default_device(
+        driver_, iree_allocator_system(), &device);
+    if (iree_status_is_unavailable(status)) {
+      iree_status_free(status);
+      IREE_LOG(WARNING) << "Skipping test as driver is unavailable";
      GTEST_SKIP();
      return;
    }
-    IREE_ASSERT_OK_AND_ASSIGN(driver_, std::move(driver_or));
-    IREE_LOG(INFO) << "Creating default device...";
-    IREE_ASSERT_OK_AND_ASSIGN(device_, driver_->CreateDefaultDevice());
-    IREE_LOG(INFO) << "Created device '" << device_->info().name() << "'";
+    IREE_ASSERT_OK(status);
+    iree_status_free(status);
+    device_ = device;
+
+    device_allocator_ = iree_hal_device_allocator(device_);
+    iree_hal_allocator_retain(device_allocator_);
+  }
+
+  virtual void TearDown() {
+    if (device_allocator_) {
+      iree_hal_allocator_release(device_allocator_);
+      device_allocator_ = nullptr;
+    }
+    if (device_) {
+      iree_hal_device_release(device_);
+      device_ = nullptr;
+    }
+    if (driver_) {
+      iree_hal_driver_release(driver_);
+      driver_ = nullptr;
+    }
+  }
+
+  // Submits |command_buffer| to the device and waits for it to complete before
+  // returning.
+  iree_status_t SubmitCommandBufferAndWait(
+      iree_hal_command_category_t command_categories,
+      iree_hal_command_buffer_t* command_buffer) {
+    iree_hal_semaphore_t* signal_semaphore = NULL;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_semaphore_create(device_, 0ull, &signal_semaphore));
+
+    iree_hal_submission_batch_t submission_batch;
+
+    // No wait semaphores.
+    submission_batch.wait_semaphores.count = 0;
+    submission_batch.wait_semaphores.semaphores = NULL;
+    submission_batch.wait_semaphores.payload_values = NULL;
+
+    iree_hal_command_buffer_t* command_buffer_ptrs[] = {command_buffer};
+    submission_batch.command_buffer_count = IREE_ARRAYSIZE(command_buffer_ptrs);
+    submission_batch.command_buffers = command_buffer_ptrs;
+
+    // One signal semaphore from 0 -> 1.
+    iree_hal_semaphore_t* signal_semaphore_ptrs[] = {signal_semaphore};
+    uint64_t payload_values[] = {1ull};
+    submission_batch.signal_semaphores.count =
+        IREE_ARRAYSIZE(signal_semaphore_ptrs);
+    submission_batch.signal_semaphores.semaphores = signal_semaphore_ptrs;
+    submission_batch.signal_semaphores.payload_values = payload_values;
+
+    iree_status_t status =
+        iree_hal_device_queue_submit(device_, command_categories,
+                                     /*queue_affinity=*/0,
+                                     /*batch_count=*/1, &submission_batch);
+    if (iree_status_is_ok(status)) {
+      status = iree_hal_semaphore_wait_with_deadline(signal_semaphore, 1ull,
+                                                     IREE_TIME_INFINITE_FUTURE);
+    }
+
+    iree_hal_semaphore_release(signal_semaphore);
+    return status;
  }

-  ref_ptr<Driver> driver_;
-  ref_ptr<Device> device_;
+  iree_hal_driver_t* driver_ = nullptr;
+  iree_hal_device_t* device_ = nullptr;
+  iree_hal_allocator_t* device_allocator_ = nullptr;

 private:
  // Gets a HAL driver with the provided name, if available.
-  static StatusOr<ref_ptr<Driver>> GetDriver(const std::string& driver_name) {
+  static iree_status_t TryGetDriver(const std::string& driver_name,
+                                    iree_hal_driver_t** out_driver) {
    static std::set<std::string> unavailable_driver_names;

    // If creation failed before, don't try again.
@@ -96,15 +135,7 @@ class CtsTestBase : public ::testing::TestWithParam<std::string> {
      return UnavailableErrorBuilder(IREE_LOC) << "Driver unavailable";
    }

-    // Reuse an existing driver if possible.
-    auto found_it = shared_drivers_.find(driver_name);
-    if (found_it != shared_drivers_.end()) {
-      IREE_LOG(INFO) << "Reusing existing driver '" << driver_name << "'...";
-      return add_ref(found_it->second);
-    }
-
    // No existing driver, attempt to create.
-    IREE_LOG(INFO) << "Creating driver '" << driver_name << "'...";
    iree_hal_driver_t* driver = NULL;
    iree_status_t status = iree_hal_driver_registry_try_create_by_name(
        iree_hal_driver_registry_default(),
@@ -113,13 +144,13 @@
    if (iree_status_is_unavailable(status)) {
      unavailable_driver_names.insert(driver_name);
    }
-    IREE_RETURN_IF_ERROR(status);
-    return assign_ref(reinterpret_cast<Driver*>(driver));
+    if (iree_status_is_ok(status)) {
+      *out_driver = driver;
+    }
+    return status;
  }
};

-std::map<std::string, ref_ptr<Driver>> CtsTestBase::shared_drivers_;
-
struct GenerateTestName {
  template <typename ParamType>
  std::string operator()(
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/cts/cts_test_base.h" +#include "iree/hal/testing/driver_registry.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace iree { +namespace hal { +namespace cts { + +class DescriptorSetLayoutTest : public CtsTestBase {}; + +// Note: bindingCount == 0 is valid in VkDescriptorSetLayoutCreateInfo: +// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkDescriptorSetLayoutCreateInfo.html +TEST_P(DescriptorSetLayoutTest, CreateWithNoBindings) { + iree_hal_descriptor_set_layout_t* descriptor_set_layout; + IREE_ASSERT_OK(iree_hal_descriptor_set_layout_create( + device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE, + /*binding_count=*/0, + /*bindings=*/NULL, &descriptor_set_layout)); + iree_hal_descriptor_set_layout_release(descriptor_set_layout); +} + +TEST_P(DescriptorSetLayoutTest, CreateWithOneBinding) { + iree_hal_descriptor_set_layout_t* descriptor_set_layout; + iree_hal_descriptor_set_layout_binding_t descriptor_set_layout_bindings[] = { + {/*binding=*/0, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_READ}, + }; + IREE_ASSERT_OK(iree_hal_descriptor_set_layout_create( + device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE, + IREE_ARRAYSIZE(descriptor_set_layout_bindings), + descriptor_set_layout_bindings, &descriptor_set_layout)); + iree_hal_descriptor_set_layout_release(descriptor_set_layout); +} + +TEST_P(DescriptorSetLayoutTest, CreateWithTwoBindings) { + iree_hal_descriptor_set_layout_t* descriptor_set_layout; + iree_hal_descriptor_set_layout_binding_t descriptor_set_layout_bindings[] = { + {/*binding=*/0, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_READ}, + {/*binding=*/1, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE}, + }; + IREE_ASSERT_OK(iree_hal_descriptor_set_layout_create( + device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE, + IREE_ARRAYSIZE(descriptor_set_layout_bindings), + descriptor_set_layout_bindings, &descriptor_set_layout)); + iree_hal_descriptor_set_layout_release(descriptor_set_layout); +} + +TEST_P(DescriptorSetLayoutTest, CreateWithPushDescriptorType) { + iree_hal_descriptor_set_layout_t* descriptor_set_layout; + iree_hal_descriptor_set_layout_binding_t descriptor_set_layout_bindings[] = { + {/*binding=*/0, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_READ}, + {/*binding=*/1, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE}, + }; + IREE_ASSERT_OK(iree_hal_descriptor_set_layout_create( + device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_PUSH_ONLY, + IREE_ARRAYSIZE(descriptor_set_layout_bindings), + descriptor_set_layout_bindings, &descriptor_set_layout)); + iree_hal_descriptor_set_layout_release(descriptor_set_layout); +} + +INSTANTIATE_TEST_SUITE_P( + AllDrivers, DescriptorSetLayoutTest, + ::testing::ValuesIn(testing::EnumerateAvailableDrivers()), + GenerateTestName()); + +} // namespace cts +} // namespace hal +} // 
namespace iree diff --git a/iree/hal/cts/descriptor_set_test.cc b/iree/hal/cts/descriptor_set_test.cc new file mode 100644 index 0000000000000..abf5f0912dfa4 --- /dev/null +++ b/iree/hal/cts/descriptor_set_test.cc @@ -0,0 +1,63 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/cts/cts_test_base.h" +#include "iree/hal/testing/driver_registry.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace iree { +namespace hal { +namespace cts { + +class DescriptorSetTest : public CtsTestBase {}; + +// TODO(scotttodd): enable once any driver implements non-push descriptor sets +// * also test with buffers in the bindings +// * also test usage in iree_hal_command_buffer_bind_descriptor_set +TEST_P(DescriptorSetTest, DISABLED_CreateWithTwoBindings) { + iree_hal_descriptor_set_layout_t* descriptor_set_layout; + iree_hal_descriptor_set_layout_binding_t descriptor_set_layout_bindings[] = { + {/*binding=*/0, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_READ}, + {/*binding=*/1, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE}, + }; + IREE_ASSERT_OK(iree_hal_descriptor_set_layout_create( + device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE, + IREE_ARRAYSIZE(descriptor_set_layout_bindings), + descriptor_set_layout_bindings, &descriptor_set_layout)); + + iree_hal_descriptor_set_binding_t descriptor_set_bindings[] = { + {/*binding=*/0, /*buffer=*/NULL, /*offset=*/0, /*length=*/0}, + {/*binding=*/1, /*buffer=*/NULL, /*offset=*/0, /*length=*/0}, + }; + + iree_hal_descriptor_set_t* descriptor_set; + IREE_ASSERT_OK(iree_hal_descriptor_set_create( + device_, descriptor_set_layout, IREE_ARRAYSIZE(descriptor_set_bindings), + descriptor_set_bindings, &descriptor_set)); + + iree_hal_descriptor_set_release(descriptor_set); + iree_hal_descriptor_set_layout_release(descriptor_set_layout); +} + +INSTANTIATE_TEST_SUITE_P( + AllDrivers, DescriptorSetTest, + ::testing::ValuesIn(testing::EnumerateAvailableDrivers()), + GenerateTestName()); + +} // namespace cts +} // namespace hal +} // namespace iree diff --git a/iree/hal/cts/driver_test.cc b/iree/hal/cts/driver_test.cc index 3a7b757eaa9bd..922785221ed8d 100644 --- a/iree/hal/cts/driver_test.cc +++ b/iree/hal/cts/driver_test.cc @@ -23,17 +23,27 @@ namespace cts { class DriverTest : public CtsTestBase {}; -TEST_P(DriverTest, CreateDefaultDevice) { - IREE_LOG(INFO) << "Device details:\n" << device_->DebugString(); -} - -TEST_P(DriverTest, EnumerateAndCreateAvailableDevices) { - IREE_ASSERT_OK_AND_ASSIGN(auto devices, driver_->EnumerateAvailableDevices()); - - for (iree_host_size_t i = 0; i < devices.size(); ++i) { - IREE_ASSERT_OK_AND_ASSIGN(auto device, driver_->CreateDevice(devices[i])); - IREE_LOG(INFO) << "Device #" << i << " details:\n" << device->DebugString(); +TEST_P(DriverTest, QueryAndCreateAvailableDevices) { + iree_hal_device_info_t* device_infos; + iree_host_size_t 
device_info_count; + IREE_ASSERT_OK(iree_hal_driver_query_available_devices( + driver_, iree_allocator_system(), &device_infos, &device_info_count)); + + IREE_LOG(INFO) << "Driver has " << device_info_count << " device(s)"; + for (iree_host_size_t i = 0; i < device_info_count; ++i) { + IREE_LOG(INFO) << " Creating device '" + << std::string(device_infos[i].name.data, + device_infos[i].name.size) + << "'"; + iree_hal_device_t* device; + IREE_ASSERT_OK(iree_hal_driver_create_device( + driver_, device_infos[i].device_id, iree_allocator_system(), &device)); + iree_string_view_t device_id = iree_hal_device_id(device); + IREE_LOG(INFO) << " Created device with id: '" + << std::string(device_id.data, device_id.size) << "'"; } + + iree_allocator_free(iree_allocator_system(), device_infos); } INSTANTIATE_TEST_SUITE_P( diff --git a/iree/hal/cts/event_test.cc b/iree/hal/cts/event_test.cc new file mode 100644 index 0000000000000..b7282773c55f1 --- /dev/null +++ b/iree/hal/cts/event_test.cc @@ -0,0 +1,125 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/cts/cts_test_base.h" +#include "iree/hal/testing/driver_registry.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace iree { +namespace hal { +namespace cts { + +class EventTest : public CtsTestBase {}; + +TEST_P(EventTest, Create) { + iree_hal_event_t* event; + IREE_ASSERT_OK(iree_hal_event_create(device_, &event)); + iree_hal_event_release(event); +} + +TEST_P(EventTest, SignalAndReset) { + iree_hal_event_t* event; + IREE_ASSERT_OK(iree_hal_event_create(device_, &event)); + + iree_hal_command_buffer_t* command_buffer; + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, &command_buffer)); + + IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer)); + IREE_ASSERT_OK(iree_hal_command_buffer_signal_event( + command_buffer, event, IREE_HAL_EXECUTION_STAGE_COMMAND_PROCESS)); + IREE_ASSERT_OK(iree_hal_command_buffer_reset_event( + command_buffer, event, IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE)); + IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer)); + + IREE_ASSERT_OK(SubmitCommandBufferAndWait(IREE_HAL_COMMAND_CATEGORY_DISPATCH, + command_buffer)); + + iree_hal_event_release(event); + iree_hal_command_buffer_release(command_buffer); +} + +TEST_P(EventTest, SubmitWithChainedCommandBuffers) { + iree_hal_event_t* event; + IREE_ASSERT_OK(iree_hal_event_create(device_, &event)); + + iree_hal_command_buffer_t* command_buffer_1; + iree_hal_command_buffer_t* command_buffer_2; + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, &command_buffer_1)); + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, &command_buffer_2)); + + // First command buffer signals the event when it completes. 
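An aside on the stage masks the event test below uses (its TODO notes they still need to be checked against the Vulkan spec): the producer signals at IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE, i.e. after its commands retire, and the consumer waits with a target stage of IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE, so none of its commands are issued until the event fires. The pairing in isolation, with producer_cmd/consumer_cmd as hypothetical stand-ins for command_buffer_1/command_buffer_2:

  IREE_ASSERT_OK(iree_hal_command_buffer_signal_event(
      producer_cmd, event, IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE));
  const iree_hal_event_t* events[] = {event};
  IREE_ASSERT_OK(iree_hal_command_buffer_wait_events(
      consumer_cmd, IREE_ARRAYSIZE(events), events,
      /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE,
      /*target_stage_mask=*/IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE,
      /*memory_barrier_count=*/0, /*memory_barriers=*/NULL,
      /*buffer_barrier_count=*/0, /*buffer_barriers=*/NULL));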
+ IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer_1)); + IREE_ASSERT_OK(iree_hal_command_buffer_signal_event( + command_buffer_1, event, IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE)); + IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer_1)); + + // Second command buffer waits on the event before starting. + IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer_2)); + const iree_hal_event_t* event_pts[] = {event}; + // TODO(scotttodd): verify execution stage usage (check Vulkan spec) + IREE_ASSERT_OK(iree_hal_command_buffer_wait_events( + command_buffer_2, IREE_ARRAYSIZE(event_pts), event_pts, + /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE, + /*target_stage_mask=*/IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE, + /*memory_barrier_count=*/0, + /*memory_barriers=*/NULL, /*buffer_barrier_count=*/0, + /*buffer_barriers=*/NULL)); + IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer_2)); + + // No wait semaphores, one signal which we immediately wait on after submit. + iree_hal_submission_batch_t submission_batch; + submission_batch.wait_semaphores.count = 0; + submission_batch.wait_semaphores.semaphores = NULL; + submission_batch.wait_semaphores.payload_values = NULL; + iree_hal_command_buffer_t* command_buffer_ptrs[] = {command_buffer_1, + command_buffer_2}; + submission_batch.command_buffer_count = IREE_ARRAYSIZE(command_buffer_ptrs); + submission_batch.command_buffers = command_buffer_ptrs; + iree_hal_semaphore_t* signal_semaphore; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &signal_semaphore)); + iree_hal_semaphore_t* signal_semaphore_ptrs[] = {signal_semaphore}; + submission_batch.signal_semaphores.count = + IREE_ARRAYSIZE(signal_semaphore_ptrs); + submission_batch.signal_semaphores.semaphores = signal_semaphore_ptrs; + uint64_t payload_values[] = {1ull}; + submission_batch.signal_semaphores.payload_values = payload_values; + + IREE_ASSERT_OK( + iree_hal_device_queue_submit(device_, IREE_HAL_COMMAND_CATEGORY_DISPATCH, + /*queue_affinity=*/0, + /*batch_count=*/1, &submission_batch)); + IREE_ASSERT_OK(iree_hal_semaphore_wait_with_deadline( + signal_semaphore, 1ull, IREE_TIME_INFINITE_FUTURE)); + + iree_hal_command_buffer_release(command_buffer_1); + iree_hal_command_buffer_release(command_buffer_2); + iree_hal_semaphore_release(signal_semaphore); + iree_hal_event_release(event); +} + +INSTANTIATE_TEST_SUITE_P( + AllDrivers, EventTest, + ::testing::ValuesIn(testing::EnumerateAvailableDrivers()), + GenerateTestName()); + +} // namespace cts +} // namespace hal +} // namespace iree diff --git a/iree/hal/cts/executable_layout_test.cc b/iree/hal/cts/executable_layout_test.cc new file mode 100644 index 0000000000000..c43da43b1f480 --- /dev/null +++ b/iree/hal/cts/executable_layout_test.cc @@ -0,0 +1,111 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
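The new file that follows exercises iree_hal_executable_layout_create, which aggregates zero or more descriptor set layouts plus a push constant count. On the Vulkan note in CreateWithPushConstants below: assuming the count is in 32-bit units (an assumption, not something this diff states), /*push_constants=*/5 is 20 bytes, well under the 128-byte minimum the spec guarantees. Minimal shape of the call, using names from the tests below:

  iree_hal_executable_layout_t* executable_layout = NULL;
  IREE_ASSERT_OK(iree_hal_executable_layout_create(
      device_, /*set_layout_count=*/0, /*set_layouts=*/NULL,
      /*push_constants=*/5,  // ~20 bytes if counted in 32-bit values.
      &executable_layout));
  iree_hal_executable_layout_release(executable_layout);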
+ +#include "iree/hal/cts/cts_test_base.h" +#include "iree/hal/testing/driver_registry.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace iree { +namespace hal { +namespace cts { + +class ExecutableLayoutTest : public CtsTestBase {}; + +TEST_P(ExecutableLayoutTest, CreateWithNoLayouts) { + iree_hal_executable_layout_t* executable_layout; + IREE_ASSERT_OK(iree_hal_executable_layout_create( + device_, /*set_layout_count=*/0, NULL, + /*push_constants=*/0, &executable_layout)); + + iree_hal_executable_layout_release(executable_layout); +} + +TEST_P(ExecutableLayoutTest, CreateWithPushConstants) { + iree_hal_executable_layout_t* executable_layout; + // Note: The Vulkan maxPushConstantsSize limit must be at least 128 bytes: + // https://www.khronos.org/registry/vulkan/specs/1.2/html/vkspec.html#limits-minmax + IREE_ASSERT_OK(iree_hal_executable_layout_create( + device_, /*set_layout_count=*/0, NULL, + /*push_constants=*/5, &executable_layout)); + + iree_hal_executable_layout_release(executable_layout); +} + +TEST_P(ExecutableLayoutTest, CreateWithOneLayout) { + iree_hal_descriptor_set_layout_t* descriptor_set_layout; + iree_hal_descriptor_set_layout_binding_t descriptor_set_layout_bindings[] = { + {/*binding=*/0, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_READ}, + {/*binding=*/1, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE}, + }; + IREE_ASSERT_OK(iree_hal_descriptor_set_layout_create( + device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE, + IREE_ARRAYSIZE(descriptor_set_layout_bindings), + descriptor_set_layout_bindings, &descriptor_set_layout)); + + iree_hal_executable_layout_t* executable_layout; + IREE_ASSERT_OK(iree_hal_executable_layout_create( + device_, /*set_layout_count=*/1, &descriptor_set_layout, + /*push_constants=*/0, &executable_layout)); + + iree_hal_executable_layout_release(executable_layout); + iree_hal_descriptor_set_layout_release(descriptor_set_layout); +} + +TEST_P(ExecutableLayoutTest, CreateWithTwoLayouts) { + iree_hal_descriptor_set_layout_t* descriptor_set_layouts[2]; + iree_hal_descriptor_set_layout_binding_t layout_bindings_0[] = { + {/*binding=*/0, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_READ}, + {/*binding=*/1, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE}, + }; + IREE_ASSERT_OK(iree_hal_descriptor_set_layout_create( + device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE, + IREE_ARRAYSIZE(layout_bindings_0), layout_bindings_0, + &descriptor_set_layouts[0])); + + iree_hal_descriptor_set_layout_binding_t layout_bindings_1[] = { + {/*binding=*/0, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_READ}, + {/*binding=*/1, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_READ}, + {/*binding=*/2, /*type=*/IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER, + /*access=*/IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE}, + }; + IREE_ASSERT_OK(iree_hal_descriptor_set_layout_create( + device_, IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE, + IREE_ARRAYSIZE(layout_bindings_1), layout_bindings_1, + &descriptor_set_layouts[1])); + + iree_hal_executable_layout_t* executable_layout; + IREE_ASSERT_OK(iree_hal_executable_layout_create( + device_, IREE_ARRAYSIZE(descriptor_set_layouts), descriptor_set_layouts, + /*push_constants=*/0, &executable_layout)); + + 
iree_hal_executable_layout_release(executable_layout); + iree_hal_descriptor_set_layout_release(descriptor_set_layouts[0]); + iree_hal_descriptor_set_layout_release(descriptor_set_layouts[1]); +} + +INSTANTIATE_TEST_SUITE_P( + AllDrivers, ExecutableLayoutTest, + ::testing::ValuesIn(testing::EnumerateAvailableDrivers()), + GenerateTestName()); + +} // namespace cts +} // namespace hal +} // namespace iree diff --git a/iree/hal/cts/semaphore_test.cc b/iree/hal/cts/semaphore_test.cc index 385590c56a383..d2a31fa4a8851 100644 --- a/iree/hal/cts/semaphore_test.cc +++ b/iree/hal/cts/semaphore_test.cc @@ -26,119 +26,373 @@ class SemaphoreTest : public CtsTestBase {}; // Tests that a semaphore that is unused properly cleans itself up. TEST_P(SemaphoreTest, NoOp) { - IREE_ASSERT_OK_AND_ASSIGN(auto semaphore, device_->CreateSemaphore(123u)); - IREE_ASSERT_OK_AND_ASSIGN(uint64_t value, semaphore->Query()); - EXPECT_EQ(123u, value); + iree_hal_semaphore_t* semaphore; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 123ull, &semaphore)); + + uint64_t value; + IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore, &value)); + EXPECT_EQ(123ull, value); + + iree_hal_semaphore_release(semaphore); } // Tests that a semaphore will accept new values as it is signaled. TEST_P(SemaphoreTest, NormalSignaling) { - IREE_ASSERT_OK_AND_ASSIGN(auto semaphore, device_->CreateSemaphore(2u)); - EXPECT_EQ(2u, semaphore->Query().value()); - IREE_EXPECT_OK(semaphore->Signal(3u)); - EXPECT_EQ(3u, semaphore->Query().value()); - IREE_EXPECT_OK(semaphore->Signal(40u)); - EXPECT_EQ(40u, semaphore->Query().value()); + iree_hal_semaphore_t* semaphore; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 2ull, &semaphore)); + + uint64_t value; + IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore, &value)); + EXPECT_EQ(2ull, value); + IREE_ASSERT_OK(iree_hal_semaphore_signal(semaphore, 3ull)); + IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore, &value)); + EXPECT_EQ(3ull, value); + IREE_ASSERT_OK(iree_hal_semaphore_signal(semaphore, 40ull)); + IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore, &value)); + EXPECT_EQ(40ull, value); + + iree_hal_semaphore_release(semaphore); } // Note: Behavior is undefined when signaling with decreasing values, so we // can't reliably test it across backends. Some backends may return errors, // while others may accept the new, decreasing, values. -// Tests that a semaphore that has failed will remain in a failed state. +// Tests semaphore failure handling. TEST_P(SemaphoreTest, Failure) { - IREE_ASSERT_OK_AND_ASSIGN(auto semaphore, device_->CreateSemaphore(2u)); - // Signal to 3. - IREE_EXPECT_OK(semaphore->Signal(3u)); - EXPECT_EQ(3u, semaphore->Query().value()); + iree_hal_semaphore_t* semaphore; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 2ull, &semaphore)); - // Fail now. - semaphore->Fail(UnknownErrorBuilder(IREE_LOC)); - EXPECT_TRUE(IsUnknown(semaphore->Query().status())); + IREE_ASSERT_OK(iree_hal_semaphore_signal(semaphore, 3ull)); + uint64_t value; + IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore, &value)); + EXPECT_EQ(3ull, value); - // Signaling again is undefined behavior. Some backends may return a - // sticky failure status while others may silently process new signal values. + iree_hal_semaphore_fail(semaphore, + iree_status_from_code(IREE_STATUS_UNKNOWN)); + EXPECT_TRUE( + iree_status_is_unknown(iree_hal_semaphore_query(semaphore, &value))); + + // Signaling again is undefined behavior. 
Some backends may return a sticky + // failure status while others may silently process new signal values. + + iree_hal_semaphore_release(semaphore); } // Tests waiting on no semaphores. TEST_P(SemaphoreTest, EmptyWait) { - IREE_EXPECT_OK(device_->WaitAllSemaphores({}, InfiniteFuture())); + IREE_ASSERT_OK(iree_hal_device_wait_semaphores_with_deadline( + device_, IREE_HAL_WAIT_MODE_ANY, NULL, IREE_TIME_INFINITE_FUTURE)); + IREE_ASSERT_OK(iree_hal_device_wait_semaphores_with_deadline( + device_, IREE_HAL_WAIT_MODE_ALL, NULL, IREE_TIME_INFINITE_FUTURE)); } // Tests waiting on a semaphore that has already been signaled. TEST_P(SemaphoreTest, WaitAlreadySignaled) { - IREE_ASSERT_OK_AND_ASSIGN(auto semaphore, device_->CreateSemaphore(2u)); + iree_hal_semaphore_t* semaphore; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 2ull, &semaphore)); + // Test both previous and current values. - IREE_EXPECT_OK( - device_->WaitAllSemaphores({{semaphore.get(), 1u}}, InfiniteFuture())); - IREE_EXPECT_OK( - device_->WaitAllSemaphores({{semaphore.get(), 2u}}, InfiniteFuture())); + IREE_ASSERT_OK(iree_hal_semaphore_wait_with_deadline( + semaphore, 1ull, IREE_TIME_INFINITE_FUTURE)); + IREE_ASSERT_OK(iree_hal_semaphore_wait_with_deadline( + semaphore, 2ull, IREE_TIME_INFINITE_FUTURE)); + + iree_hal_semaphore_release(semaphore); } // Tests waiting on a semaphore that has not been signaled. TEST_P(SemaphoreTest, WaitUnsignaled) { - IREE_ASSERT_OK_AND_ASSIGN(auto semaphore, device_->CreateSemaphore(2u)); + iree_hal_semaphore_t* semaphore; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 2ull, &semaphore)); + // NOTE: we don't actually block here because otherwise we'd lock up. // Result status is undefined - some backends may return DeadlineExceededError // while others may return success. - device_->WaitAllSemaphores({{semaphore.get(), 3u}}, InfinitePast()) - .IgnoreError(); + IREE_IGNORE_ERROR(iree_hal_semaphore_wait_with_deadline( + semaphore, 3ull, IREE_TIME_INFINITE_PAST)); + + iree_hal_semaphore_release(semaphore); } // Waiting on a failed semaphore is undefined behavior. Some backends may // return UnknownError while others may succeed. -// Waiting all semaphores but not all are signaled. +// Tests IREE_HAL_WAIT_MODE_ALL when not all are signaled. TEST_P(SemaphoreTest, WaitAllButNotAllSignaled) { - IREE_ASSERT_OK_AND_ASSIGN(auto a, device_->CreateSemaphore(0u)); - IREE_ASSERT_OK_AND_ASSIGN(auto b, device_->CreateSemaphore(1u)); + iree_hal_semaphore_t* semaphore_a; + iree_hal_semaphore_t* semaphore_b; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &semaphore_a)); + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 1ull, &semaphore_b)); + + iree_hal_semaphore_list_t semaphore_list; + iree_hal_semaphore_t* semaphore_ptrs[] = {semaphore_a, semaphore_b}; + semaphore_list.count = IREE_ARRAYSIZE(semaphore_ptrs); + semaphore_list.semaphores = semaphore_ptrs; + uint64_t payload_values[] = {1ull, 1ull}; + semaphore_list.payload_values = payload_values; + // NOTE: we don't actually block here because otherwise we'd lock up. // Result status is undefined - some backends may return DeadlineExceededError // while others may return success. 
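A pattern worth noting across these wait tests: a deadline of IREE_TIME_INFINITE_PAST turns the wait into a non-blocking poll, which is how the not-yet-signaled cases avoid hanging while still exercising the wait path. The list plumbing is always the same shape (names illustrative):

  iree_hal_semaphore_t* semaphores[] = {semaphore_a, semaphore_b};
  uint64_t payloads[] = {1ull, 1ull};
  iree_hal_semaphore_list_t list;
  list.count = IREE_ARRAYSIZE(semaphores);
  list.semaphores = semaphores;
  list.payload_values = payloads;
  // Poll: returns immediately; DEADLINE_EXCEEDED vs. OK is backend-defined.
  IREE_IGNORE_ERROR(iree_hal_device_wait_semaphores_with_deadline(
      device_, IREE_HAL_WAIT_MODE_ALL, &list, IREE_TIME_INFINITE_PAST));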
-  device_->WaitAllSemaphores({{a.get(), 1u}, {b.get(), 1u}}, InfinitePast())
-      .IgnoreError();
+  IREE_IGNORE_ERROR(iree_hal_device_wait_semaphores_with_deadline(
+      device_, IREE_HAL_WAIT_MODE_ALL, &semaphore_list,
+      IREE_TIME_INFINITE_PAST));
+
+  iree_hal_semaphore_release(semaphore_a);
+  iree_hal_semaphore_release(semaphore_b);
 }

-// Waiting all semaphores and all are signaled.
+// Tests IREE_HAL_WAIT_MODE_ALL when all are signaled.
 TEST_P(SemaphoreTest, WaitAllAndAllSignaled) {
-  IREE_ASSERT_OK_AND_ASSIGN(auto a, device_->CreateSemaphore(1u));
-  IREE_ASSERT_OK_AND_ASSIGN(auto b, device_->CreateSemaphore(1u));
-  IREE_ASSERT_OK(device_->WaitAllSemaphores({{a.get(), 1u}, {b.get(), 1u}},
-                                            InfiniteFuture()));
+  iree_hal_semaphore_t* semaphore_a;
+  iree_hal_semaphore_t* semaphore_b;
+  IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 1ull, &semaphore_a));
+  IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 1ull, &semaphore_b));
+
+  iree_hal_semaphore_list_t semaphore_list;
+  iree_hal_semaphore_t* semaphore_ptrs[] = {semaphore_a, semaphore_b};
+  semaphore_list.count = IREE_ARRAYSIZE(semaphore_ptrs);
+  semaphore_list.semaphores = semaphore_ptrs;
+  uint64_t payload_values[] = {1ull, 1ull};
+  semaphore_list.payload_values = payload_values;
+
+  // Both semaphores are already signaled, so this wait should return
+  // immediately even with an infinite deadline; the result is still ignored
+  // for consistency across backends.
+  IREE_IGNORE_ERROR(iree_hal_device_wait_semaphores_with_deadline(
+      device_, IREE_HAL_WAIT_MODE_ALL, &semaphore_list,
+      IREE_TIME_INFINITE_FUTURE));
+
+  iree_hal_semaphore_release(semaphore_a);
+  iree_hal_semaphore_release(semaphore_b);
 }

-// Waiting any semaphore to signal.
-TEST_P(SemaphoreTest, WaitAny) {
-  // TODO: fix this.
-  if (driver_->name() == "dylib" || driver_->name() == "vmla" ||
-      driver_->name() == "vulkan") {
-    GTEST_SKIP();
-  }
-
-  IREE_ASSERT_OK_AND_ASSIGN(auto a, device_->CreateSemaphore(0u));
-  IREE_ASSERT_OK_AND_ASSIGN(auto b, device_->CreateSemaphore(1u));
-  IREE_ASSERT_OK(device_->WaitAnySemaphore({{a.get(), 1u}, {b.get(), 1u}},
-                                           InfiniteFuture()));
+// Tests IREE_HAL_WAIT_MODE_ANY.
+// **Fails using timeline semaphore emulation**
+TEST_P(SemaphoreTest, DISABLED_WaitAny) {
+  iree_hal_semaphore_t* semaphore_a;
+  iree_hal_semaphore_t* semaphore_b;
+  IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &semaphore_a));
+  IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 1ull, &semaphore_b));
+
+  iree_hal_semaphore_list_t semaphore_list;
+  iree_hal_semaphore_t* semaphore_ptrs[] = {semaphore_a, semaphore_b};
+  semaphore_list.count = IREE_ARRAYSIZE(semaphore_ptrs);
+  semaphore_list.semaphores = semaphore_ptrs;
+  uint64_t payload_values[] = {1ull, 1ull};
+  semaphore_list.payload_values = payload_values;
+
+  IREE_ASSERT_OK(iree_hal_device_wait_semaphores_with_deadline(
+      device_, IREE_HAL_WAIT_MODE_ANY, &semaphore_list,
+      IREE_TIME_INFINITE_FUTURE));
+
+  iree_hal_semaphore_release(semaphore_a);
+  iree_hal_semaphore_release(semaphore_b);
 }

 // Tests threading behavior by ping-ponging between the test main thread and
 // a little thread.
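Before the test body, a sketch of the hand-off it performs (payload values from the test; a timeline wait is satisfied by any value >= the requested payload, which is why the worker's first wait passes immediately):

  // main thread                        worker thread
  // -----------------                  ----------------------------------
  //                                    wait(a2b >= 0)   // passes, a2b == 0
  // wait(b2a >= 1)  <----------------  signal(b2a = 1)
  // signal(a2b = 4) ---------------->  wait(a2b >= 4)   // unblocks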
TEST_P(SemaphoreTest, PingPong) { - IREE_ASSERT_OK_AND_ASSIGN(auto a2b, device_->CreateSemaphore(0u)); - IREE_ASSERT_OK_AND_ASSIGN(auto b2a, device_->CreateSemaphore(0u)); + iree_hal_semaphore_t* a2b; + iree_hal_semaphore_t* b2a; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &a2b)); + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &b2a)); std::thread thread([&]() { // Should advance right past this because the value is already set. - IREE_ASSERT_OK( - device_->WaitAllSemaphores({{a2b.get(), 0u}}, InfiniteFuture())); - IREE_ASSERT_OK(b2a->Signal(1u)); - // Jump ahead. - IREE_ASSERT_OK( - device_->WaitAllSemaphores({{a2b.get(), 4u}}, InfiniteFuture())); + IREE_ASSERT_OK(iree_hal_semaphore_wait_with_deadline( + a2b, 0ull, IREE_TIME_INFINITE_FUTURE)); + IREE_ASSERT_OK(iree_hal_semaphore_signal(b2a, 1ull)); + // Jump ahead (blocking at first). + IREE_ASSERT_OK(iree_hal_semaphore_wait_with_deadline( + a2b, 4ull, IREE_TIME_INFINITE_FUTURE)); }); - IREE_ASSERT_OK( - device_->WaitAllSemaphores({{b2a.get(), 1u}}, InfiniteFuture())); - IREE_ASSERT_OK(a2b->Signal(4u)); + // Block until thread signals. + IREE_ASSERT_OK(iree_hal_semaphore_wait_with_deadline( + b2a, 1ull, IREE_TIME_INFINITE_FUTURE)); + IREE_ASSERT_OK(iree_hal_semaphore_signal(a2b, 4ull)); thread.join(); + + iree_hal_semaphore_release(a2b); + iree_hal_semaphore_release(b2a); +} + +TEST_P(SemaphoreTest, SubmitWithNoCommandBuffers) { + // No waits, one signal which we immediately wait on after submit. + iree_hal_submission_batch_t submission_batch; + submission_batch.wait_semaphores.count = 0; + submission_batch.wait_semaphores.semaphores = NULL; + submission_batch.wait_semaphores.payload_values = NULL; + submission_batch.command_buffer_count = 0; + submission_batch.command_buffers = NULL; + iree_hal_semaphore_t* signal_semaphore; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &signal_semaphore)); + iree_hal_semaphore_t* signal_semaphore_ptrs[] = {signal_semaphore}; + submission_batch.signal_semaphores.count = + IREE_ARRAYSIZE(signal_semaphore_ptrs); + submission_batch.signal_semaphores.semaphores = signal_semaphore_ptrs; + uint64_t payload_values[] = {1ull}; + submission_batch.signal_semaphores.payload_values = payload_values; + + IREE_ASSERT_OK( + iree_hal_device_queue_submit(device_, IREE_HAL_COMMAND_CATEGORY_DISPATCH, + /*queue_affinity=*/0, + /*batch_count=*/1, &submission_batch)); + IREE_ASSERT_OK(iree_hal_semaphore_wait_with_deadline( + signal_semaphore, 1ull, IREE_TIME_INFINITE_FUTURE)); + + iree_hal_semaphore_release(signal_semaphore); +} + +TEST_P(SemaphoreTest, SubmitAndSignal) { + iree_hal_command_buffer_t* command_buffer; + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, &command_buffer)); + + IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer)); + IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer)); + + // No waits, one signal which we immediately wait on after submit. 
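The submission-batch plumbing that follows here repeats nearly verbatim in SubmitWithNoCommandBuffers above and in SubmitWithWait/SubmitWithMultipleSemaphores below; if it grew further, a small helper like this hypothetical one (not part of the change) could fold the common no-wait case:

  // Fills a batch with no wait semaphores, the given command buffers, and
  // signal semaphores advancing to the given payloads.
  static iree_hal_submission_batch_t MakeSignalOnlyBatch(
      iree_hal_command_buffer_t** command_buffers,
      iree_host_size_t command_buffer_count,
      iree_hal_semaphore_t** signal_semaphores, uint64_t* signal_payloads,
      iree_host_size_t signal_count) {
    iree_hal_submission_batch_t batch;
    batch.wait_semaphores.count = 0;
    batch.wait_semaphores.semaphores = NULL;
    batch.wait_semaphores.payload_values = NULL;
    batch.command_buffer_count = command_buffer_count;
    batch.command_buffers = command_buffers;
    batch.signal_semaphores.count = signal_count;
    batch.signal_semaphores.semaphores = signal_semaphores;
    batch.signal_semaphores.payload_values = signal_payloads;
    return batch;
  }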
+ iree_hal_submission_batch_t submission_batch; + submission_batch.wait_semaphores.count = 0; + submission_batch.wait_semaphores.semaphores = NULL; + submission_batch.wait_semaphores.payload_values = NULL; + submission_batch.command_buffer_count = 1; + submission_batch.command_buffers = &command_buffer; + iree_hal_semaphore_t* signal_semaphore; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &signal_semaphore)); + iree_hal_semaphore_t* signal_semaphore_ptrs[] = {signal_semaphore}; + submission_batch.signal_semaphores.count = + IREE_ARRAYSIZE(signal_semaphore_ptrs); + submission_batch.signal_semaphores.semaphores = signal_semaphore_ptrs; + uint64_t payload_values[] = {1ull}; + submission_batch.signal_semaphores.payload_values = payload_values; + + IREE_ASSERT_OK( + iree_hal_device_queue_submit(device_, IREE_HAL_COMMAND_CATEGORY_DISPATCH, + /*queue_affinity=*/0, + /*batch_count=*/1, &submission_batch)); + IREE_ASSERT_OK(iree_hal_semaphore_wait_with_deadline( + signal_semaphore, 1ull, IREE_TIME_INFINITE_FUTURE)); + + iree_hal_command_buffer_release(command_buffer); + iree_hal_semaphore_release(signal_semaphore); +} + +TEST_P(SemaphoreTest, SubmitWithWait) { + // Empty command buffer. + iree_hal_command_buffer_t* command_buffer; + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, &command_buffer)); + IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer)); + IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer)); + + // One wait and one signal semaphore. + iree_hal_submission_batch_t submission_batch; + iree_hal_semaphore_t* wait_semaphore; + iree_hal_semaphore_t* signal_semaphore; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &wait_semaphore)); + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 100ull, &signal_semaphore)); + iree_hal_semaphore_t* wait_semaphore_ptrs[] = {wait_semaphore}; + iree_hal_semaphore_t* signal_semaphore_ptrs[] = {signal_semaphore}; + uint64_t wait_payload_values[] = {1ull}; + uint64_t signal_payload_values[] = {101ull}; + submission_batch.wait_semaphores.count = IREE_ARRAYSIZE(wait_semaphore_ptrs); + submission_batch.wait_semaphores.semaphores = wait_semaphore_ptrs; + submission_batch.wait_semaphores.payload_values = wait_payload_values; + submission_batch.command_buffer_count = 1; + submission_batch.command_buffers = &command_buffer; + submission_batch.signal_semaphores.count = + IREE_ARRAYSIZE(signal_semaphore_ptrs); + submission_batch.signal_semaphores.semaphores = signal_semaphore_ptrs; + submission_batch.signal_semaphores.payload_values = signal_payload_values; + + IREE_ASSERT_OK( + iree_hal_device_queue_submit(device_, IREE_HAL_COMMAND_CATEGORY_DISPATCH, + /*queue_affinity=*/0, + /*batch_count=*/1, &submission_batch)); + + // Work shouldn't start until the wait semaphore reaches its payload value. + uint64_t value; + IREE_ASSERT_OK(iree_hal_semaphore_query(signal_semaphore, &value)); + EXPECT_EQ(100ull, value); + + // Signal the wait semaphore, work should begin and complete. 
+ IREE_ASSERT_OK(iree_hal_semaphore_signal(wait_semaphore, 1ull)); + IREE_ASSERT_OK(iree_hal_semaphore_wait_with_deadline( + signal_semaphore, 101ull, IREE_TIME_INFINITE_FUTURE)); + + iree_hal_command_buffer_release(command_buffer); + iree_hal_semaphore_release(wait_semaphore); + iree_hal_semaphore_release(signal_semaphore); +} + +TEST_P(SemaphoreTest, SubmitWithMultipleSemaphores) { + iree_hal_command_buffer_t* command_buffer; + IREE_ASSERT_OK(iree_hal_command_buffer_create( + device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, + IREE_HAL_COMMAND_CATEGORY_DISPATCH, &command_buffer)); + + IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer)); + IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer)); + + iree_hal_submission_batch_t submission_batch; + iree_hal_semaphore_t* wait_semaphore_1; + iree_hal_semaphore_t* wait_semaphore_2; + iree_hal_semaphore_t* signal_semaphore_1; + iree_hal_semaphore_t* signal_semaphore_2; + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &wait_semaphore_1)); + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &wait_semaphore_2)); + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &signal_semaphore_1)); + IREE_ASSERT_OK(iree_hal_semaphore_create(device_, 0ull, &signal_semaphore_2)); + iree_hal_semaphore_t* wait_semaphore_ptrs[] = {wait_semaphore_1, + wait_semaphore_2}; + iree_hal_semaphore_t* signal_semaphore_ptrs[] = {signal_semaphore_1, + signal_semaphore_2}; + uint64_t wait_payload_values[] = {1ull, 1ull}; + uint64_t signal_payload_values[] = {1ull, 1ull}; + submission_batch.wait_semaphores.count = IREE_ARRAYSIZE(wait_semaphore_ptrs); + submission_batch.wait_semaphores.semaphores = wait_semaphore_ptrs; + submission_batch.wait_semaphores.payload_values = wait_payload_values; + submission_batch.command_buffer_count = 1; + submission_batch.command_buffers = &command_buffer; + submission_batch.signal_semaphores.count = + IREE_ARRAYSIZE(signal_semaphore_ptrs); + submission_batch.signal_semaphores.semaphores = signal_semaphore_ptrs; + submission_batch.signal_semaphores.payload_values = signal_payload_values; + + IREE_ASSERT_OK( + iree_hal_device_queue_submit(device_, IREE_HAL_COMMAND_CATEGORY_DISPATCH, + /*queue_affinity=*/0, + /*batch_count=*/1, &submission_batch)); + + // Work shouldn't start until all wait semaphores reach their payload values. + uint64_t value; + IREE_ASSERT_OK(iree_hal_semaphore_query(signal_semaphore_1, &value)); + EXPECT_EQ(0ull, value); + IREE_ASSERT_OK(iree_hal_semaphore_query(signal_semaphore_2, &value)); + EXPECT_EQ(0ull, value); + + // Signal the wait semaphores, work should begin and complete. 
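One more semantic this test pins down (my reading of it): the wait set behaves as a join, so the submission must not run until wait_semaphore_1 and wait_semaphore_2 both reach their payloads; only after the two signals below do the signal semaphores advance, which the final IREE_HAL_WAIT_MODE_ALL wait then observes.

  // Join semantics assumed here: the batch executes only once
  //   wait_semaphore_1 >= 1 && wait_semaphore_2 >= 1.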
+ IREE_ASSERT_OK(iree_hal_semaphore_signal(wait_semaphore_1, 1ull)); + IREE_ASSERT_OK(iree_hal_semaphore_signal(wait_semaphore_2, 1ull)); + + iree_hal_semaphore_list_t signal_semaphore_list; + signal_semaphore_list.count = IREE_ARRAYSIZE(signal_semaphore_ptrs); + signal_semaphore_list.semaphores = signal_semaphore_ptrs; + uint64_t payload_values[] = {1ull, 1ull}; + signal_semaphore_list.payload_values = payload_values; + IREE_ASSERT_OK(iree_hal_device_wait_semaphores_with_deadline( + device_, IREE_HAL_WAIT_MODE_ALL, &signal_semaphore_list, + IREE_TIME_INFINITE_FUTURE)); + + iree_hal_command_buffer_release(command_buffer); + iree_hal_semaphore_release(wait_semaphore_1); + iree_hal_semaphore_release(wait_semaphore_2); + iree_hal_semaphore_release(signal_semaphore_1); + iree_hal_semaphore_release(signal_semaphore_2); } INSTANTIATE_TEST_SUITE_P( diff --git a/iree/hal/debug_capture_manager.h b/iree/hal/debug_capture_manager.h deleted file mode 100644 index 6105d91e56359..0000000000000 --- a/iree/hal/debug_capture_manager.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DEBUG_CAPTURE_MANAGER_H_ -#define IREE_HAL_DEBUG_CAPTURE_MANAGER_H_ - -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -// Interface for interacting with command recorders / debuggers. -// -// Subclasses connect to tools like RenderDoc or MTLCaptureManager and use them -// to record commands sent to underlying APIs like Vulkan or Metal, for future -// debugging and analysis. -class DebugCaptureManager { - public: - DebugCaptureManager() {} - virtual ~DebugCaptureManager() = default; - - // Attempts to connect to a command recorder, if not already connected. - // - // This should be called *before* the underlying system and its devices (such - // as a VkInstance and its VkDevices) are initialized, so the command recorder - // can inject any necessary hooks. - virtual Status Connect() = 0; - - // Disconnects from a connected command recorder, if connected. - // This implicitly stops capture if currently capturing. - virtual void Disconnect() = 0; - - // Returns true if connected to a command recorder. - virtual bool is_connected() const = 0; - - // Starts capturing commands. - // Must already be connected and must not already be capturing. - virtual void StartCapture() = 0; - - // Stops capturing commands and saves the capture. - // Must already be connected and capturing. - virtual void StopCapture() = 0; - - // Returns true if currently capturing commands. 
-  virtual bool is_capturing() const = 0;
-};
-
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_DEBUG_CAPTURE_MANAGER_H_
diff --git a/iree/hal/deferred_buffer.cc b/iree/hal/deferred_buffer.cc
deleted file mode 100644
index 07b6894b74906..0000000000000
--- a/iree/hal/deferred_buffer.cc
+++ /dev/null
@@ -1,162 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/deferred_buffer.h"
-
-#include "iree/base/status.h"
-
-namespace iree {
-namespace hal {
-
-DeferredBuffer::DeferredBuffer(Allocator* allocator,
-                               MemoryTypeBitfield memory_type,
-                               MemoryAccessBitfield allowed_access,
-                               BufferUsageBitfield usage,
-                               device_size_t byte_length)
-    : Buffer(allocator, memory_type, allowed_access, usage, 0, 0, byte_length) {
-}
-
-DeferredBuffer::~DeferredBuffer() = default;
-
-Status DeferredBuffer::GrowByteLength(device_size_t new_byte_length) {
-  if (parent_buffer_) {
-    return FailedPreconditionErrorBuilder(IREE_LOC)
-           << "Attempting to set min allocation size while bound to an "
-              "allocation";
-  }
-  if (byte_length_ != kWholeBuffer && new_byte_length < byte_length_) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "Attempting to shrink a buffer to " << new_byte_length
-           << " when it has a minimum size of " << byte_length_;
-  }
-  byte_length_ = new_byte_length;
-  return OkStatus();
-}
-
-Status DeferredBuffer::BindAllocation(ref_ptr<Buffer> allocated_buffer,
-                                      device_size_t byte_offset,
-                                      device_size_t byte_length) {
-  // We can only be bound to allocations that are compatible with our specified
-  // allocator and usage.
-  if (!allocator_->CanUseBuffer(allocated_buffer.get(), usage())) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "Allocation is not compatible with the allocator specified for "
-              "the deferred buffer";
-  }
-
-  // Calculate the range in the allocated_buffer that we are interested in.
-  IREE_RETURN_IF_ERROR(
-      Buffer::CalculateRange(0, allocated_buffer->byte_length(), byte_offset,
-                             byte_length, &byte_offset, &byte_length));
-
-  // Verify that we have enough bytes for what we've promised.
-  if (byte_length < byte_length_) {
-    return OutOfRangeErrorBuilder(IREE_LOC)
-           << "Allocation range is too small; min_allocation_size="
-           << byte_length_ << " but the range of " << byte_offset << "-"
-           << (byte_offset + byte_length - 1) << " (" << byte_length
-           << "b) is too small";
-  }
-
-  allocated_buffer_ = allocated_buffer.get();
-  parent_buffer_ = std::move(allocated_buffer);
-  byte_offset_ = byte_offset;
-  return OkStatus();
-}
-
-void DeferredBuffer::ResetAllocation() {
-  allocated_buffer_ = this;
-  parent_buffer_.reset();
-  byte_offset_ = 0;
-}
-
-StatusOr<Buffer*> DeferredBuffer::ResolveAllocation() const {
-  // If you get errors here then someone allocated the buffer with
-  // MemoryType::kTransient and you are trying to use it outside of the time
-  // it is actually allocated (such as during CommandBuffer evaluation).
If - // you need to use the buffer in non-transient ways then allocate the buffer - // without the MemoryType::kTransient flag. - if (!parent_buffer_) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Attempting to use a transient buffer prior to allocation: " - << DebugString(); - } - return parent_buffer_.get(); -} - -Status DeferredBuffer::FillImpl(device_size_t byte_offset, - device_size_t byte_length, const void* pattern, - device_size_t pattern_length) { - IREE_ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->FillImpl(byte_offset, byte_length, pattern, - pattern_length); -} - -Status DeferredBuffer::ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) { - IREE_ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->ReadDataImpl(source_offset, data, data_length); -} - -Status DeferredBuffer::WriteDataImpl(device_size_t target_offset, - const void* data, - device_size_t data_length) { - IREE_ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->WriteDataImpl(target_offset, data, data_length); -} - -Status DeferredBuffer::CopyDataImpl(device_size_t target_offset, - Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) { - IREE_ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->CopyDataImpl(target_offset, source_buffer, - source_offset, data_length); -} - -Status DeferredBuffer::MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) { - IREE_ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->MapMemoryImpl(mapping_mode, memory_access, - local_byte_offset, local_byte_length, - out_data); -} - -Status DeferredBuffer::UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, - void* data) { - IREE_ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->UnmapMemoryImpl(local_byte_offset, local_byte_length, - data); -} - -Status DeferredBuffer::InvalidateMappedMemoryImpl( - device_size_t local_byte_offset, device_size_t local_byte_length) { - IREE_ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->InvalidateMappedMemoryImpl(local_byte_offset, - local_byte_length); -} - -Status DeferredBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { - IREE_ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->FlushMappedMemoryImpl(local_byte_offset, - local_byte_length); -} - -} // namespace hal -} // namespace iree diff --git a/iree/hal/deferred_buffer.h b/iree/hal/deferred_buffer.h deleted file mode 100644 index e3785417d5a6e..0000000000000 --- a/iree/hal/deferred_buffer.h +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_DEFERRED_BUFFER_H_
-#define IREE_HAL_DEFERRED_BUFFER_H_
-
-#include
-#include
-#include
-
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-
-// A Buffer that can have its underlying allocation changed at runtime.
-// Unbound buffers act as a way to logically group dependent ranges of memory
-// without needing to have allocated that memory yet.
-//
-// Usage:
-//  // Setup two spans referencing ranges of a deferred buffer.
-//  auto deferred_buffer = make_ref<DeferredBuffer>(..., 200);
-//  IREE_ASSIGN_OR_RETURN(auto span0, Buffer::Subspan(deferred_buffer, 0, 100));
-//  IREE_ASSIGN_OR_RETURN(auto span1, Buffer::Subspan(deferred_buffer, 100,
-//  100));
-//
-//  // Attempting to access |deferred_buffer| or |span0| or |span1| will fail.
-//  // ERROR: span0->Fill(false);
-//
-//  // Now allocate a real buffer to serve as storage for the data.
-//  IREE_ASSIGN_OR_RETURN(auto allocated_buffer, Buffer::Allocate(..., 200));
-//  IREE_RETURN_IF_ERROR(deferred_buffer->BindAllocation(
-//      allocated_buffer, 0, kWholeBuffer));
-//
-//  // And now we can use the spans.
-//  IREE_RETURN_IF_ERROR(span0->Fill(false));
-//
-//  // If at some point we want to detach the buffer from the allocation (so we
-//  // can use a different allocation, reuse the memory, etc).
-//  deferred_buffer->ResetAllocation();
-//
-// Thread-compatible. Attempting to rebind the allocation while other threads
-// are using the buffer will lead to undefined behavior.
-class DeferredBuffer : public Buffer {
- public:
-  DeferredBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
-                 MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
-                 device_size_t byte_length);
-  ~DeferredBuffer() override;
-
-  // Grows the minimum allocation size of the buffer to |new_byte_length|.
-  // Attempting to bind an allocation less than this size will fail. This must
-  // only be called when the buffer is not bound to an allocation.
-  Status GrowByteLength(device_size_t new_byte_length);
-
-  // Binds or rebinds the deferred buffer to an allocated buffer.
-  Status BindAllocation(ref_ptr<Buffer> allocated_buffer,
-                        device_size_t byte_offset, device_size_t byte_length);
-
-  // Resets the deferred buffer to have no binding.
-  void ResetAllocation();
-
- private:
-  // Resolves the allocated buffer that this subspan references into.
-  // This will fail if the buffer has not yet been bound to an allocation or
-  // the allocated buffer has not been committed.
- StatusOr ResolveAllocation() const; - - Status FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) override; - Status ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) override; - Status WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) override; - Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) override; - Status MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) override; - Status UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) override; - Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; - Status FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DEFERRED_BUFFER_H_ diff --git a/iree/hal/deferred_buffer_test.cc b/iree/hal/deferred_buffer_test.cc deleted file mode 100644 index 454fe2d1426a4..0000000000000 --- a/iree/hal/deferred_buffer_test.cc +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/deferred_buffer.h" - -#include "absl/memory/memory.h" -#include "iree/hal/heap_buffer.h" -#include "iree/hal/testing/mock_allocator.h" -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" - -namespace iree { -namespace hal { -namespace { - -using ::iree::hal::testing::MockAllocator; -using ::testing::_; -using ::testing::Return; - -// Tests properties of unbound buffers. -TEST(DeferredBufferTest, Unbound) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - EXPECT_EQ(&allocator, deferred_buffer->allocator()); - EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(0, deferred_buffer->allocation_size()); - EXPECT_EQ(0, deferred_buffer->byte_offset()); - EXPECT_EQ(100, deferred_buffer->byte_length()); -} - -// Tests that binding verifies allocators are compatible. -TEST(DeferredBufferTest, AllocatorCheck) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL( - allocator, - CanUseBufferLike(real_buffer->allocator(), real_buffer->memory_type(), - real_buffer->usage(), BufferUsage::kAll)) - .WillOnce(Return(false)); - EXPECT_TRUE(IsInvalidArgument( - deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100))); -} - -// Tests that binding verifies allocation sizes. 
-TEST(DeferredBufferTest, SizeCheck) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) - .WillRepeatedly(Return(true)); - - IREE_EXPECT_OK( - deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 100)); - EXPECT_EQ(256, deferred_buffer->allocation_size()); - EXPECT_EQ(10, deferred_buffer->byte_offset()); - EXPECT_EQ(100, deferred_buffer->byte_length()); - IREE_EXPECT_OK( - deferred_buffer->BindAllocation(add_ref(real_buffer), 10, kWholeBuffer)); - EXPECT_EQ(256, deferred_buffer->allocation_size()); - EXPECT_EQ(10, deferred_buffer->byte_offset()); - EXPECT_EQ(100, deferred_buffer->byte_length()); - - EXPECT_TRUE(IsOutOfRange( - deferred_buffer->BindAllocation(add_ref(real_buffer), 200, 100))); - EXPECT_TRUE(IsOutOfRange(deferred_buffer->BindAllocation(add_ref(real_buffer), - 200, kWholeBuffer))); - EXPECT_TRUE(IsOutOfRange( - deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 10))); -} - -// Tests resizing buffers after they have been allocated. -TEST(DeferredBufferTest, Resizing) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) - .WillRepeatedly(Return(true)); - - // Grow. - EXPECT_EQ(100, deferred_buffer->byte_length()); - IREE_EXPECT_OK(deferred_buffer->GrowByteLength(150)); - EXPECT_EQ(150, deferred_buffer->byte_length()); - - // Shrinking should fail. - EXPECT_TRUE(IsInvalidArgument(deferred_buffer->GrowByteLength(5))); - - // Growing should fail if bound. - IREE_EXPECT_OK( - deferred_buffer->BindAllocation(std::move(real_buffer), 0, 150)); - EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->GrowByteLength(100))); -} - -// Tests binding and rebinding behavior. -TEST(DeferredBufferTest, Rebinding) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) - .WillRepeatedly(Return(true)); - - // Safe to reset when not bound. - deferred_buffer->ResetAllocation(); - EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(0, deferred_buffer->allocation_size()); - - IREE_EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100)); - EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(256, deferred_buffer->allocation_size()); - deferred_buffer->ResetAllocation(); - EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(0, deferred_buffer->allocation_size()); - IREE_EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100)); - EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(256, deferred_buffer->allocation_size()); -} - -// Tests normal usage of bound buffers. 
-TEST(DeferredBufferTest, BoundUsage) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) - .WillRepeatedly(Return(true)); - IREE_EXPECT_OK( - deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100)); - - EXPECT_FALSE(deferred_buffer->DebugString().empty()); - EXPECT_FALSE(deferred_buffer->DebugStringShort().empty()); - - IREE_EXPECT_OK(deferred_buffer->Fill8(0, 10, 0xFF)); -} - -// Tests that unbound buffers fail to perform any buffer actions. -TEST(DeferredBufferTest, UnboundUsage) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - EXPECT_FALSE(deferred_buffer->DebugString().empty()); - EXPECT_FALSE(deferred_buffer->DebugStringShort().empty()); - - EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->Fill8(0, 10, 0xFF))); -} - -} // namespace -} // namespace hal -} // namespace iree diff --git a/iree/hal/descriptor_set.c b/iree/hal/descriptor_set.c new file mode 100644 index 0000000000000..a3b468db61fc2 --- /dev/null +++ b/iree/hal/descriptor_set.c @@ -0,0 +1,42 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/descriptor_set.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" +#include "iree/hal/device.h" + +#define _VTABLE_DISPATCH(descriptor_set, method_name) \ + IREE_HAL_VTABLE_DISPATCH(descriptor_set, iree_hal_descriptor_set, method_name) + +IREE_HAL_API_RETAIN_RELEASE(descriptor_set); + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_descriptor_set_create( + iree_hal_device_t* device, iree_hal_descriptor_set_layout_t* set_layout, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, + iree_hal_descriptor_set_t** out_descriptor_set) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(set_layout); + IREE_ASSERT_ARGUMENT(!binding_count || bindings); + IREE_ASSERT_ARGUMENT(out_descriptor_set); + *out_descriptor_set = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, create_descriptor_set)( + device, set_layout, binding_count, bindings, out_descriptor_set); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/descriptor_set.h b/iree/hal/descriptor_set.h index 9726795735a1d..1e48ec34bd688 100644 --- a/iree/hal/descriptor_set.h +++ b/iree/hal/descriptor_set.h @@ -12,58 +12,101 @@ // See the License for the specific language governing permissions and // limitations under the License. 
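(A note for readers following the C port: IREE_HAL_VTABLE_DISPATCH, used by descriptor_set.c above, is defined in iree/hal/detail.h further down in this diff. As a rough, illustrative sketch only, the dispatch inside iree_hal_descriptor_set_create expands to something like the following; the `vtable` temporary is introduced here purely for readability and does not appear in the macro:)

// Illustrative expansion of
//   IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, create_descriptor_set):
const iree_hal_device_vtable_t* vtable =
    (const iree_hal_device_vtable_t*)((const iree_hal_resource_t*)device)
        ->vtable;
iree_status_t status = vtable->create_descriptor_set(
    device, set_layout, binding_count, bindings, out_descriptor_set);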
-#include "absl/strings/str_cat.h" +#ifndef IREE_HAL_DESCRIPTOR_SET_H_ +#define IREE_HAL_DESCRIPTOR_SET_H_ + +#include +#include + +#include "iree/base/api.h" #include "iree/hal/buffer.h" +#include "iree/hal/descriptor_set_layout.h" #include "iree/hal/resource.h" -#ifndef IREE_HAL_DESCRIPTOR_SET_H_ -#define IREE_HAL_DESCRIPTOR_SET_H_ +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_device_s iree_hal_device_t; + +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// -namespace iree { -namespace hal { +// Specifies a descriptor set binding. +// The range specified by [offset, length) will be made available to executables +// on the given binding. If the descriptor type is dynamic then the range will +// be [offset + dynamic_offset, length). +// +// The IREE HAL buffer type may internally be offset; such offset is applied +// here as if it were the base address of the buffer. Note that the offset will +// be applied at the time the binding is recording into the command buffer. +// +// Maps to VkDescriptorSetBinding. +typedef struct { + // The binding number of this entry and corresponds to a resource of the + // same binding number in the executable interface. + uint32_t binding; + // Buffer bound to the binding number. + // May be NULL if the binding is not used by the executable. + iree_hal_buffer_t* buffer; + // Offset, in bytes, into the buffer that the binding starts at. + // If the descriptor type is dynamic this will be added to the dynamic + // offset provided during binding. + iree_device_size_t offset; + // Length, in bytes, of the buffer that is available to the executable. + // This can be IREE_WHOLE_BUFFER, however note that if the entire buffer + // contents are larger than supported by the device (~128MiB, usually) this + // will fail. If the descriptor type is dynamic this will be used for all + // ranges regardless of offset. + iree_device_size_t length; +} iree_hal_descriptor_set_binding_t; + +//===----------------------------------------------------------------------===// +// iree_hal_descriptor_set_t +//===----------------------------------------------------------------------===// // Opaque handle to a descriptor set object. +// A "descriptor" is effectively a bound memory range and each dispatch can use +// one or more "descriptor sets" to access their I/O memory. Each descriptor set +// conforms to a template "descriptor set layout". // // Maps to VkDescriptorSet: // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkDescriptorSet.html -class DescriptorSet : public Resource { - public: - // Specifies a descriptor set binding. - struct Binding { - // The binding number of this entry and corresponds to a resource of the - // same binding number in the executable interface. - int32_t binding = 0; - // Buffer bound to the binding number. - // May be nullptr if the binding is not used by the executable. - Buffer* buffer; - // Offset, in bytes, into the buffer that the binding starts at. - // If the descriptor type is dynamic this will be added to the dynamic - // offset provided during binding. - device_size_t offset = 0; - // Length, in bytes, of the buffer that is available to the executable. - // This can be kWholeBuffer, however note that if the entire buffer - // contents are larger than supported by the device (~128MiB, usually) this - // will fail. 
If the descriptor type is dynamic this will be used for all - // ranges regardless of offset. - device_size_t length = kWholeBuffer; - - std::string DebugStringShort() const { - return absl::StrCat("binding=", binding, ", ", buffer->DebugStringShort(), - ", offset=", offset, ", length=", length); - } - }; -}; - -struct DescriptorSetBindingFormatter { - void operator()(std::string* out, - const DescriptorSet::Binding& binding) const { - out->append("<"); - out->append(binding.DebugStringShort()); - out->append(">"); - } -}; - -} // namespace hal -} // namespace iree +typedef struct iree_hal_descriptor_set_s iree_hal_descriptor_set_t; + +// Creates a descriptor set of the given layout and bindings. +// Descriptor sets are immutable and retain their bindings. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_descriptor_set_create( + iree_hal_device_t* device, iree_hal_descriptor_set_layout_t* set_layout, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, + iree_hal_descriptor_set_t** out_descriptor_set); + +// Retains the given |descriptor_set| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_descriptor_set_retain(iree_hal_descriptor_set_t* descriptor_set); + +// Releases the given |descriptor_set| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_descriptor_set_release(iree_hal_descriptor_set_t* descriptor_set); + +//===----------------------------------------------------------------------===// +// iree_hal_descriptor_set_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE + + void(IREE_API_PTR* destroy)(iree_hal_descriptor_set_t* descriptor_set); +} iree_hal_descriptor_set_vtable_t; + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_descriptor_set_destroy(iree_hal_descriptor_set_t* descriptor_set); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_DESCRIPTOR_SET_H_ diff --git a/iree/hal/descriptor_set_layout.c b/iree/hal/descriptor_set_layout.c new file mode 100644 index 0000000000000..3f84c1e60c5c5 --- /dev/null +++ b/iree/hal/descriptor_set_layout.c @@ -0,0 +1,44 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
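(Taken together, the descriptor set API above composes roughly as follows; a minimal sketch assuming an existing `device`, `set_layout`, and `buffer`, with error handling reduced to IREE_RETURN_IF_ERROR:)

// Bind the whole of |buffer| to binding slot 0 and create a set.
iree_hal_descriptor_set_binding_t binding = {
    /*binding=*/0u,
    /*buffer=*/buffer,
    /*offset=*/0,
    /*length=*/IREE_WHOLE_BUFFER,
};
iree_hal_descriptor_set_t* descriptor_set = NULL;
IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_create(
    device, set_layout, /*binding_count=*/1, &binding, &descriptor_set));
// ... record the set into command buffers ...
iree_hal_descriptor_set_release(descriptor_set);  // sets are ref counted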
+ +#include "iree/hal/descriptor_set_layout.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" +#include "iree/hal/device.h" + +#define _VTABLE_DISPATCH(descriptor_set_layout, method_name) \ + IREE_HAL_VTABLE_DISPATCH(descriptor_set_layout, \ + iree_hal_descriptor_set_layout, method_name) + +IREE_HAL_API_RETAIN_RELEASE(descriptor_set_layout); + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_descriptor_set_layout_create( + iree_hal_device_t* device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(!binding_count || bindings); + IREE_ASSERT_ARGUMENT(out_descriptor_set_layout); + *out_descriptor_set_layout = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, + create_descriptor_set_layout)( + device, usage_type, binding_count, bindings, out_descriptor_set_layout); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/descriptor_set_layout.h b/iree/hal/descriptor_set_layout.h index 2e808889f829d..d26502d2426aa 100644 --- a/iree/hal/descriptor_set_layout.h +++ b/iree/hal/descriptor_set_layout.h @@ -12,56 +12,108 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "absl/strings/str_cat.h" +#ifndef IREE_HAL_DESCRIPTOR_SET_LAYOUT_H_ +#define IREE_HAL_DESCRIPTOR_SET_LAYOUT_H_ + +#include +#include + +#include "iree/base/api.h" #include "iree/hal/buffer.h" #include "iree/hal/resource.h" -#ifndef IREE_HAL_DESCRIPTOR_SET_LAYOUT_H_ -#define IREE_HAL_DESCRIPTOR_SET_LAYOUT_H_ +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_device_s iree_hal_device_t; -namespace iree { -namespace hal { +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// // Specifies the type of a descriptor in a descriptor set. -enum class DescriptorType : uint32_t { - kUniformBuffer = 6, - kStorageBuffer = 7, - kUniformBufferDynamic = 8, - kStorageBufferDynamic = 9, +enum iree_hal_descriptor_type_e { + IREE_HAL_DESCRIPTOR_TYPE_UNIFORM_BUFFER = 6u, + IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER = 7u, + IREE_HAL_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC = 8u, + IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC = 9u, }; +typedef uint32_t iree_hal_descriptor_type_t; + +// Specifies the usage type of the descriptor set. +enum iree_hal_descriptor_set_layout_usage_type_e { + // Descriptor set will be initialized once and never changed. + IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_IMMUTABLE = 0u, + // Descriptor set is never created and instead used with push descriptors. + IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_PUSH_ONLY = 1u, +}; +typedef uint32_t iree_hal_descriptor_set_layout_usage_type_t; + +// Specifies a descriptor set layout binding. +// +// Maps to VkDescriptorSetLayoutBinding. +typedef struct { + // The binding number of this entry and corresponds to a resource of the + // same binding number in the executable interface. + uint32_t binding; + // Specifies which type of resource descriptors are used for this binding. + iree_hal_descriptor_type_t type; + // Specifies the memory access performed by the executables. 
+ iree_hal_memory_access_t access; +} iree_hal_descriptor_set_layout_binding_t; + +//===----------------------------------------------------------------------===// +// iree_hal_descriptor_set_layout_t +//===----------------------------------------------------------------------===// // Opaque handle to a descriptor set layout object. +// A "descriptor" is effectively a bound memory range and each dispatch can use +// one or more "descriptor sets" to access their I/O memory. A "descriptor set +// layout" defines the types and usage semantics of the descriptors that make up +// one set. Implementations can use this to verify program correctness and +// accelerate reservation/allocation/computation of descriptor-related +// operations. // // Maps to VkDescriptorSetLayout: // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkDescriptorSetLayout.html -class DescriptorSetLayout : public Resource { - public: - // Specifies the usage type of the descriptor set. - enum class UsageType : uint32_t { - // Descriptor set will be initialized once and never changed. - kImmutable = 0, - // Descriptor set is never created and instead used with push descriptors. - kPushOnly = 1, - }; - - // Specifies a descriptor set layout binding. - struct Binding { - // The binding number of this entry and corresponds to a resource of the - // same binding number in the executable interface. - int32_t binding = 0; - // Specifies which type of resource descriptors are used for this binding. - DescriptorType type = DescriptorType::kStorageBuffer; - // Specifies the memory access performed by the executables. - MemoryAccessBitfield access = MemoryAccess::kRead | MemoryAccess::kWrite; - - std::string DebugStringShort() const { - return absl::StrCat("binding=", binding, ", type=", type, - ", access=", MemoryAccessString(access)); - } - }; -}; +typedef struct iree_hal_descriptor_set_layout_s + iree_hal_descriptor_set_layout_t; + +// Creates a descriptor set layout with the given bindings. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_descriptor_set_layout_create( + iree_hal_device_t* device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout); + +// Retains the given |descriptor_set_layout| for the caller. +IREE_API_EXPORT void IREE_API_CALL iree_hal_descriptor_set_layout_retain( + iree_hal_descriptor_set_layout_t* descriptor_set_layout); + +// Releases the given |descriptor_set_layout| from the caller.
+IREE_API_EXPORT void IREE_API_CALL iree_hal_descriptor_set_layout_release( + iree_hal_descriptor_set_layout_t* descriptor_set_layout); + +//===----------------------------------------------------------------------===// +// iree_hal_descriptor_set_layout_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE + + void(IREE_API_PTR* destroy)( + iree_hal_descriptor_set_layout_t* descriptor_set_layout); +} iree_hal_descriptor_set_layout_vtable_t; + +IREE_API_EXPORT void IREE_API_CALL iree_hal_descriptor_set_layout_destroy( + iree_hal_descriptor_set_layout_t* descriptor_set_layout); -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_DESCRIPTOR_SET_LAYOUT_H_ diff --git a/iree/hal/detail.h b/iree/hal/detail.h new file mode 100644 index 0000000000000..827498b57e94b --- /dev/null +++ b/iree/hal/detail.h @@ -0,0 +1,75 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DETAIL_H_ +#define IREE_HAL_DETAIL_H_ + +#include +#include + +#include "iree/base/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Dispatches a method on a HAL object vtable. +// +// In the future we can use this to compile in a mode where all indirect +// dispatches are replaced by direct calls to static methods. For example, +// by changing the macro to resolve to `iree_hal_[resource]_[method_name]` we +// can rely on LTO to perform cross-compilation unit inlining/strip unused HAL +// calls/etc. This will be particularly useful for super tiny builds +// (web/embedded) where there's only ever one usable backend and debugging +// features like command buffer validation aren't required. +// +// Some changes (mostly whack-a-mole) are still required to fully support this, +// and it's critical that a CI builds with this setting, as it's not hard to +// keep working but very easy to accidentally break (by not routing through +// this interface, using the vtable for object instance comparison, etc). +#define IREE_HAL_VTABLE_DISPATCH(resource, type_prefix, method_name) \ + ((const type_prefix##_vtable_t*)((const iree_hal_resource_t*)(resource)) \ + ->vtable) \ + ->method_name + +// Defines the iree_hal_<type_name>_destroy/_retain/_release methods.
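// (Illustrative, not part of the header: IREE_HAL_API_RETAIN_RELEASE(descriptor_set)
// expands to roughly the following over the common iree_hal_resource_t ref
// count, where vtable()/ref_count_inc()/ref_count_dec() stand in for the casts
// spelled out in the macro below:
//   void iree_hal_descriptor_set_destroy(iree_hal_descriptor_set_t* p) {
//     if (IREE_LIKELY(p)) vtable(p)->destroy(p);
//   }
//   void iree_hal_descriptor_set_retain(iree_hal_descriptor_set_t* p) {
//     if (IREE_LIKELY(p)) ref_count_inc(p);
//   }
//   void iree_hal_descriptor_set_release(iree_hal_descriptor_set_t* p) {
//     if (IREE_LIKELY(p) && ref_count_dec(p) == 1) {
//       iree_hal_descriptor_set_destroy(p);
//     }
//   }
// )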
+#define IREE_HAL_API_RETAIN_RELEASE(type_name) \ + IREE_API_EXPORT void IREE_API_CALL iree_hal_##type_name##_destroy( \ + iree_hal_##type_name##_t* type_name) { \ + if (IREE_LIKELY(type_name)) { \ + IREE_HAL_VTABLE_DISPATCH(type_name, iree_hal_##type_name, destroy) \ + (type_name); \ + } \ + } \ + IREE_API_EXPORT void IREE_API_CALL iree_hal_##type_name##_retain( \ + iree_hal_##type_name##_t* type_name) { \ + if (IREE_LIKELY(type_name)) { \ + iree_atomic_ref_count_inc( \ + &((iree_hal_resource_t*)(type_name))->ref_count); \ + } \ + } \ + IREE_API_EXPORT void IREE_API_CALL iree_hal_##type_name##_release( \ + iree_hal_##type_name##_t* type_name) { \ + if (IREE_LIKELY(type_name) && \ + iree_atomic_ref_count_dec( \ + &((iree_hal_resource_t*)(type_name))->ref_count) == 1) { \ + iree_hal_##type_name##_destroy(type_name); \ + } \ + } + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_DETAIL_H_ diff --git a/iree/hal/device.c b/iree/hal/device.c new file mode 100644 index 0000000000000..8ad7615a6b3ec --- /dev/null +++ b/iree/hal/device.c @@ -0,0 +1,103 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/device.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" + +#define _VTABLE_DISPATCH(device, method_name) \ + IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, method_name) + +IREE_HAL_API_RETAIN_RELEASE(device); + +IREE_API_EXPORT iree_string_view_t IREE_API_CALL +iree_hal_device_id(iree_hal_device_t* device) { + IREE_ASSERT_ARGUMENT(device); + return _VTABLE_DISPATCH(device, id)(device); +} + +IREE_API_EXPORT iree_allocator_t IREE_API_CALL +iree_hal_device_host_allocator(iree_hal_device_t* device) { + IREE_ASSERT_ARGUMENT(device); + return _VTABLE_DISPATCH(device, host_allocator)(device); +} + +IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL +iree_hal_device_allocator(iree_hal_device_t* device) { + IREE_ASSERT_ARGUMENT(device); + return _VTABLE_DISPATCH(device, device_allocator)(device); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_device_queue_submit( + iree_hal_device_t* device, iree_hal_command_category_t command_categories, + uint64_t queue_affinity, iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(!batch_count || batches); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(device, queue_submit)( + device, command_categories, queue_affinity, batch_count, batches); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_device_wait_semaphores_with_deadline( + iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns) { + IREE_ASSERT_ARGUMENT(device); + if (!semaphore_list || semaphore_list->count == 0) return iree_ok_status(); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + _VTABLE_DISPATCH(device, wait_semaphores_with_deadline)( + 
device, wait_mode, semaphore_list, deadline_ns); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_device_wait_semaphores_with_timeout( + iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, + iree_duration_t timeout_ns) { + IREE_ASSERT_ARGUMENT(device); + if (!semaphore_list || semaphore_list->count == 0) return iree_ok_status(); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(device, wait_semaphores_with_timeout)( + device, wait_mode, semaphore_list, timeout_ns); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t iree_hal_device_wait_idle_with_deadline( + iree_hal_device_t* device, iree_time_t deadline_ns) { + IREE_ASSERT_ARGUMENT(device); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + _VTABLE_DISPATCH(device, wait_idle_with_deadline)(device, deadline_ns); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t iree_hal_device_wait_idle_with_timeout( + iree_hal_device_t* device, iree_duration_t timeout_ns) { + IREE_ASSERT_ARGUMENT(device); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + _VTABLE_DISPATCH(device, wait_idle_with_timeout)(device, timeout_ns); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/device.h b/iree/hal/device.h index 4cd36960bb009..b683b48727be2 100644 --- a/iree/hal/device.h +++ b/iree/hal/device.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,171 +15,289 @@ #ifndef IREE_HAL_DEVICE_H_ #define IREE_HAL_DEVICE_H_ -#include +#include +#include -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/base/target_platform.h" -#include "iree/base/time.h" -#include "iree/hal/allocator.h" -#include "iree/hal/buffer.h" -#include "iree/hal/command_queue.h" +#include "iree/base/api.h" +#include "iree/hal/command_buffer.h" #include "iree/hal/descriptor_set.h" #include "iree/hal/descriptor_set_layout.h" -#include "iree/hal/device_info.h" #include "iree/hal/event.h" #include "iree/hal/executable_cache.h" #include "iree/hal/executable_layout.h" +#include "iree/hal/resource.h" #include "iree/hal/semaphore.h" -#if defined(IREE_PLATFORM_WINDOWS) -// Win32 macro name conflicts: -#undef CreateEvent -#undef CreateSemaphore -#endif // IREE_PLATFORM_WINDOWS - -namespace iree { -namespace hal { - -class Device : public RefObject { - public: - virtual ~Device() = default; - - // Information about device capabilities. - const DeviceInfo& info() const { return device_info_; } - - // Returns a debug string describing the device. - virtual std::string DebugString() const { return device_info_.DebugString(); } - - // TODO(benvanik): status (thermal, power mode, etc). - - // TODO(benvanik): throttling adjustment/power profile. - - // TODO(benvanik): control (suspend/resume, delay, etc). - - // An allocator providing buffers usable by the device. - // This allocator may be shared with other devices in the same family. - virtual Allocator* allocator() const = 0; - - // Returns a list of all general-purpose dispatch queues provided by the - // device. In general these map 1:1 with independent execution contexts, - // though some devices may hide that and expose only a single queue that is - // scheduled internally.
- virtual absl::Span dispatch_queues() const = 0; - - // Returns a list of transfer queues provided by the device. These queues may - // perform transfer operations asynchronously with respect to execution on the - // dispatch queues. For large sequences of transfer operations always prefer - // using one of these queues. - // Note that if the device does not support a dedicated transfer queue this - // list may be the same as (or a subset of) dispatch_queues. - virtual absl::Span transfer_queues() const = 0; - - // TODO(b/137153339): accept initial cache data. - // Creates a device-specific cache for executables prepared for dispatch. - // The cache manages executable compilation, caching (on disk or in memory), - // and lifetime. Users can decide to use one or more caches to allow differing - // lifetimes (such as unloading modules), persistent on disk caching of only - // specific hot executables, etc. - // - // Returns a thread-safe cache that must remain alive until all executables - // using the cache are no longer in-flight. - virtual ref_ptr CreateExecutableCache() = 0; - - // Creates a descriptor set layout with the given bindings. - virtual StatusOr> CreateDescriptorSetLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span bindings) = 0; - - // Creates an executable layout composed of the given descriptor set layouts. - // The returned executable layout can be used by multiple executables with the - // same compatible resource binding layouts. - virtual StatusOr> CreateExecutableLayout( - absl::Span set_layouts, - size_t push_constants) = 0; - - // Creates a descriptor set of the given layout and bindings. - // Descriptor sets are immutable and retain their bindings. - virtual StatusOr> CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span bindings) = 0; - - // Creates a command buffer for recording commands to submit to queues owned - // by this device. The command buffer may come from a pool but will be reset - // prior to being returned to the caller. - virtual StatusOr> CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) = 0; - - // Creates an event for recording into command buffers. - // The returned event object is only usable with this device and events must - // only be used to synchronize within the same queue. - virtual StatusOr> CreateEvent() = 0; - - // Creates a semaphore that can be used with command queues owned by this - // device. To use the semaphores with other devices or instances they must - // first be exported. - virtual StatusOr> CreateSemaphore( - uint64_t initial_value) = 0; - - // TODO(benvanik): import/export semaphore utilities. - // TODO(benvanik): semaphores to wait handles. - - // Blocks the caller until all passed |semaphores| reach or exceed the - // specified payload values or the |deadline| elapses. All |semaphores| must - // be created from this device (or be imported into it). - // - // Returns success if the wait is successful and all semaphores have been - // signaled. - // - // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all semaphores - // having been signaled. Note that a subset of the |semaphores| may have been - // signaled and each can be queried to see which ones. 
- virtual Status WaitAllSemaphores(absl::Span semaphores, - Time deadline_ns) = 0; - inline Status WaitAllSemaphores(absl::Span semaphores, - Duration timeout_ns) { - return WaitAllSemaphores(semaphores, - RelativeTimeoutToDeadlineNanos(timeout_ns)); - } - - // Blocks the caller until at least one of the |semaphores| reaches or exceeds - // the specified payload value or the |deadline| elapses. All |semaphores| - // must be created from this device (or be imported into it). - // - // Returns an arbitrary index into |semaphores| of a semaphore that was - // signaled. Note that more than one semaphore may have been signaled and all - // of the other |semaphores| should be queried or waited on again until waits - // for them succeed. - // - // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any semaphores - // having been signaled. - virtual StatusOr WaitAnySemaphore( - absl::Span semaphores, Time deadline_ns) = 0; - inline StatusOr WaitAnySemaphore( - absl::Span semaphores, Duration timeout_ns) { - return WaitAnySemaphore(semaphores, - RelativeTimeoutToDeadlineNanos(timeout_ns)); - } - - // Blocks until all outstanding requests on all queues have been - // completed. This is equivalent to having waited on all outstanding - // semaphores. - virtual Status WaitIdle(Time deadline_ns) = 0; - inline Status WaitIdle(Duration timeout_ns) { - return WaitIdle(RelativeTimeoutToDeadlineNanos(timeout_ns)); - } - inline Status WaitIdle() { return WaitIdle(InfiniteFuture()); } - - protected: - explicit Device(DeviceInfo device_info) - : device_info_(std::move(device_info)) {} - - private: - const DeviceInfo device_info_; +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// + +// An opaque driver-specific handle to identify different devices. +typedef uintptr_t iree_hal_device_id_t; + +#define IREE_HAL_DEVICE_ID_INVALID 0ull + +// Describes features supported by a device. +// These flags indicate the availability of features that may be enabled at the +// request of the calling application. Note that certain features may disable +// runtime optimizations or require compilation flags to ensure the required +// metadata is present in executables. +enum iree_hal_device_feature_e { + IREE_HAL_DEVICE_FEATURE_NONE = 0, + + // Device supports executable debugging. + // When present executables *may* be compiled with + // IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_DEBUGGING and will have usable + // debugging related methods. Note that if the input executables do not have + // embedded debugging information they still may not be able to perform + // disassembly or fine-grained breakpoint insertion. + IREE_HAL_DEVICE_FEATURE_SUPPORTS_DEBUGGING = 1 << 0, + + // Device supports executable coverage information. + // When present executables *may* be compiled with + // IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_COVERAGE and will produce + // coverage buffers during dispatch. Note that input executables must have + // partial embedded debug information to allow mapping back to source offsets. + IREE_HAL_DEVICE_FEATURE_SUPPORTS_COVERAGE = 1 << 1, + + // Device supports executable and command queue profiling. + // When present executables *may* be compiled with + // IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_PROFILING and will produce + // profiling buffers during dispatch. 
Note that input executables must have + // partial embedded debug information to allow mapping back to source offsets. + IREE_HAL_DEVICE_FEATURE_SUPPORTS_PROFILING = 1 << 2, }; +typedef uint32_t iree_hal_device_feature_t; + +// Describes an enumerated HAL device. +typedef struct { + // Opaque handle used by drivers. Not valid across driver instances. + iree_hal_device_id_t device_id; + // Name of the device as returned by the API. + iree_string_view_t name; +} iree_hal_device_info_t; + +// A list of semaphores and their corresponding payloads. +// When signaling each semaphore will be set to the new payload value provided. +// When waiting each semaphore must reach or exceed the payload value. +typedef struct { + iree_host_size_t count; + iree_hal_semaphore_t** semaphores; + uint64_t* payload_values; +} iree_hal_semaphore_list_t; + +// A single batch of command buffers submitted to a device queue. +// All of the wait semaphores must reach or exceed the given payload value prior +// to the batch beginning execution. Each command buffer begins execution in the +// order it is present in the list, though note that the command buffers +// execute concurrently and require internal synchronization via events if there +// are any dependencies between them. Only after all command buffers have +// completed will the signal semaphores be updated to the provided payload +// values. +// +// Matches Vulkan's VkSubmitInfo: +// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkSubmitInfo.html +// Note that as the HAL only models timeline semaphores we take the payload +// values directly in this struct; see: +// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkTimelineSemaphoreSubmitInfo.html +typedef struct { + // Semaphores to wait on prior to executing any command buffer. + iree_hal_semaphore_list_t wait_semaphores; + + // Command buffers to execute, in order. + iree_host_size_t command_buffer_count; + iree_hal_command_buffer_t** command_buffers; + + // Semaphores to signal once all command buffers have completed execution. + iree_hal_semaphore_list_t signal_semaphores; +} iree_hal_submission_batch_t; + +// Defines how a multi-wait operation treats the results of multiple semaphores. +enum iree_hal_wait_mode_e { + // Waits for all semaphores to reach or exceed their specified values. + IREE_HAL_WAIT_MODE_ALL = 0, + // Waits for one or more semaphores to reach or exceed their specified values. + IREE_HAL_WAIT_MODE_ANY = 1, +}; +typedef uint8_t iree_hal_wait_mode_t; + +//===----------------------------------------------------------------------===// +// iree_hal_device_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_device_s iree_hal_device_t; + +// Retains the given |device| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_device_retain(iree_hal_device_t* device); + +// Releases the given |device| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_device_release(iree_hal_device_t* device); + +// Returns the device identifier. +// This identifier may vary based on the runtime device type; for example, a +// Vulkan device may return `vulkan-v1.1` or `vulkan-v1.2-spec1`. +IREE_API_EXPORT iree_string_view_t IREE_API_CALL +iree_hal_device_id(iree_hal_device_t* device); + +// Returns the host allocator used for objects. 
+IREE_API_EXPORT iree_allocator_t IREE_API_CALL +iree_hal_device_host_allocator(iree_hal_device_t* device); + +// Returns a reference to the allocator of the device that can be used for +// allocating buffers. +IREE_API_EXPORT iree_hal_allocator_t* IREE_API_CALL +iree_hal_device_allocator(iree_hal_device_t* device); + +// Submits one or more batches of work to a device queue. +// +// The queue is selected based on the flags set in |command_categories| and the +// |queue_affinity|. As the number of available queues can vary the +// |queue_affinity| is used to hash into the available queues for the required +// categories. For example if 2 queues support transfer commands and the +// affinity is 5 the resulting queue could be index hash(5)=1. The affinity can +// thus be treated as just a way to indicate whether two submissions must be +// placed on to the same queue. Note that the exact hashing function is +// implementation dependent. +// +// The submission behavior matches Vulkan's vkQueueSubmit, with each batch +// executing its command buffers in the order they are defined but allowing the +// command buffers to complete out-of-order. See: +// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/vkQueueSubmit.html +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_device_queue_submit( + iree_hal_device_t* device, iree_hal_command_category_t command_categories, + uint64_t queue_affinity, iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches); + +// Blocks the caller until the semaphores reach or exceed the specified payload +// values or the |deadline_ns| elapses. All semaphores in |semaphore_list| must +// be created from this device (or be imported into it). +// +// |wait_mode| can be used to decide when the wait will proceed; whether *all* +// semaphores in |semaphore_list| must be signaled or whether *any* (one or +// more) can be signaled before an early return. +// +// Returns success if the wait is successful and semaphores have been signaled +// satisfying the |wait_mode|. +// +// Returns DEADLINE_EXCEEDED if the |deadline_ns| elapses without the +// |wait_mode| being satisfied. Note that even on success only a subset of the +// semaphores may have been signaled and each can be queried to see which ones. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_device_wait_semaphores_with_deadline( + iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns); + +// Blocks the caller until the semaphores reach or exceed the specified payload +// values or the |timeout_ns| elapses. +// A relative-time version of iree_hal_device_wait_semaphores_with_deadline +// using the relative nanoseconds from the time the call is made. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_device_wait_semaphores_with_timeout( + iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, + iree_duration_t timeout_ns); + +// Blocks the caller until all outstanding requests on all queues have been +// completed or the |deadline_ns| elapses. This is equivalent to having waited +// on all semaphores outstanding at the time of the call, meaning that if new +// work is submitted by another thread it may not be waited on prior to this +// call returning. +// +// Returns success if the device reaches an idle point during the call. +// +// Returns DEADLINE_EXCEEDED if the |deadline_ns| elapses without the device +// having become idle. 
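(Before the wait-idle declarations that follow, a combined sketch of submission and waiting using the types documented above; `device`, `cmd_buffer`, and the semaphores are assumed to exist, and IREE_HAL_COMMAND_CATEGORY_DISPATCH is assumed from command_buffer.h, which is not part of this diff:)

// Submit one batch: wait for |sema_a| >= 1, run |cmd_buffer|, then signal
// |sema_b| = 2.
iree_hal_semaphore_t* wait_semas[1] = {sema_a};
uint64_t wait_values[1] = {1ull};
iree_hal_semaphore_t* signal_semas[1] = {sema_b};
uint64_t signal_values[1] = {2ull};
iree_hal_submission_batch_t batch = {
    .wait_semaphores = {1, wait_semas, wait_values},
    .command_buffer_count = 1,
    .command_buffers = &cmd_buffer,
    .signal_semaphores = {1, signal_semas, signal_values},
};
IREE_RETURN_IF_ERROR(iree_hal_device_queue_submit(
    device, IREE_HAL_COMMAND_CATEGORY_DISPATCH, /*queue_affinity=*/0ull,
    /*batch_count=*/1, &batch));

// Block up to 5ms for the signal semaphore to reach its payload.
iree_hal_semaphore_t* semas[1] = {sema_b};
uint64_t payloads[1] = {2ull};
iree_hal_semaphore_list_t list = {
    .count = 1, .semaphores = semas, .payload_values = payloads};
IREE_RETURN_IF_ERROR(iree_hal_device_wait_semaphores_with_timeout(
    device, IREE_HAL_WAIT_MODE_ALL, &list, /*timeout_ns=*/5 * 1000000ull));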
+IREE_API_EXPORT iree_status_t iree_hal_device_wait_idle_with_deadline( + iree_hal_device_t* device, iree_time_t deadline_ns); + +// Blocks the caller until all outstanding requests on all queues have been +// completed or the |timeout_ns| elapses. +// A relative-time version of iree_hal_device_wait_idle_with_deadline +// using the relative nanoseconds from the time the call is made. +IREE_API_EXPORT iree_status_t iree_hal_device_wait_idle_with_timeout( + iree_hal_device_t* device, iree_duration_t timeout_ns); + +//===----------------------------------------------------------------------===// +// iree_hal_device_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE + + void(IREE_API_PTR* destroy)(iree_hal_device_t* device); + + iree_string_view_t(IREE_API_PTR* id)(iree_hal_device_t* device); + + iree_allocator_t(IREE_API_PTR* host_allocator)(iree_hal_device_t* device); + iree_hal_allocator_t*(IREE_API_PTR* device_allocator)( + iree_hal_device_t* device); + + iree_status_t(IREE_API_PTR* create_command_buffer)( + iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_command_buffer_t** out_command_buffer); + + iree_status_t(IREE_API_PTR* create_descriptor_set)( + iree_hal_device_t* device, iree_hal_descriptor_set_layout_t* set_layout, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, + iree_hal_descriptor_set_t** out_descriptor_set); + + iree_status_t(IREE_API_PTR* create_descriptor_set_layout)( + iree_hal_device_t* device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout); + + iree_status_t(IREE_API_PTR* create_event)(iree_hal_device_t* device, + iree_hal_event_t** out_event); + + iree_status_t(IREE_API_PTR* create_executable_cache)( + iree_hal_device_t* device, iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache); + + iree_status_t(IREE_API_PTR* create_executable_layout)( + iree_hal_device_t* device, iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constants, + iree_hal_executable_layout_t** out_executable_layout); + + iree_status_t(IREE_API_PTR* create_semaphore)( + iree_hal_device_t* device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore); + + iree_status_t(IREE_API_PTR* queue_submit)( + iree_hal_device_t* device, iree_hal_command_category_t command_categories, + uint64_t queue_affinity, iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches); + + iree_status_t(IREE_API_PTR* wait_semaphores_with_deadline)( + iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns); + iree_status_t(IREE_API_PTR* wait_semaphores_with_timeout)( + iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, + iree_duration_t timeout_ns); + + iree_status_t(IREE_API_PTR* wait_idle_with_deadline)( + iree_hal_device_t* device, iree_time_t deadline_ns); + iree_status_t(IREE_API_PTR* wait_idle_with_timeout)( + iree_hal_device_t* device, iree_duration_t timeout_ns); +} iree_hal_device_vtable_t; + +IREE_API_EXPORT void IREE_API_CALL
+iree_hal_device_destroy(iree_hal_device_t* device); -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_DEVICE_H_ diff --git a/iree/hal/device_info.h b/iree/hal/device_info.h deleted file mode 100644 index 0455219f56ed8..0000000000000 --- a/iree/hal/device_info.h +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DEVICE_INFO_H_ -#define IREE_HAL_DEVICE_INFO_H_ - -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "iree/base/bitfield.h" - -namespace iree { -namespace hal { - -// An opaque driver-specific handle to identify different devices. -using DriverDeviceID = uintptr_t; - -// Describes features supported by the device. -// These flags indicate the availability of features that may be enabled at the -// request of the calling application. Note that certain features may disable -// runtime optimizations or require compilation flags to ensure the required -// metadata is present in executables. -enum class DeviceFeature : uint32_t { - kNone = 0, - - // Device supports executable debugging. - // When present executables *may* be compiled with - // ExecutableCachingMode::kEnableDebugging and will have usable debugging - // related methods. Note that if the input executables do not have embedded - // debugging information they still may not be able to perform disassembly or - // fine-grained breakpoint insertion. - kDebugging = 1 << 0, - - // Device supports executable coverage information. - // When present executables *may* be compiled with - // ExecutableCachingMode::kEnableCoverage and will produce coverage buffers - // during dispatch. Note that input executables must have partial embedded - // debug information to allow mapping back to source offsets. - kCoverage = 1 << 1, - - // Device supports executable and command queue profiling. - // When present executables *may* be compiled with - // ExecutableCachingMode::kEnableProfiling and will produce profiling buffers - // during dispatch. Note that input executables must have partial embedded - // debug information to allow mapping back to source offsets. - kProfiling = 1 << 2, -}; -IREE_BITFIELD(DeviceFeature); -using DeviceFeatureBitfield = DeviceFeature; - -// TODO(benvanik): device info (caps, physical mappings, etc). -class DeviceInfo { - public: - DeviceInfo(std::string id, std::string name, - DeviceFeatureBitfield supported_features, - DriverDeviceID device_id = 0) - : id_(std::move(id)), - name_(std::move(name)), - supported_features_(supported_features), - device_id_(device_id) {} - - // Machine-friendly device identifier used to match the device against - // compiler-generated patterns. This should be consistent with the device IDs - // emitted by the compiler. For example: `vulkan-v1.1-spec`. - const std::string& id() const { return id_; } - - // Human-friendly device name. 
- const std::string& name() const { return name_; } - - // Features supported by the device. - DeviceFeatureBitfield supported_features() const { - return supported_features_; - } - - // Opaque handle used by drivers to correlate this device with their internal - // listing. This handle will not be valid across driver instances or outside - // of the current process. - DriverDeviceID device_id() const { return device_id_; } - - // Returns a debug string describing the device information. - std::string DebugString() const { - std::string features = FormatBitfieldValue( - supported_features_, { - {DeviceFeature::kDebugging, "kDebugging"}, - {DeviceFeature::kCoverage, "kCoverage"}, - {DeviceFeature::kProfiling, "kProfiling"}, - }); - - return absl::StrCat("[DeviceInfo]", // - "\n Name: ", name_, // - "\n Supported features: [", features, "]", // - "\n Device ID: ", device_id_); - } - - private: - const std::string id_; - const std::string name_; - const DeviceFeatureBitfield supported_features_; - DriverDeviceID device_id_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DEVICE_INFO_H_ diff --git a/iree/hal/device_manager.cc b/iree/hal/device_manager.cc deleted file mode 100644 index 5a527f3c52cbf..0000000000000 --- a/iree/hal/device_manager.cc +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/device_manager.h" - -#include - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/heap_buffer.h" - -namespace iree { -namespace hal { - -DeviceManager::DeviceManager() = default; - -DeviceManager::~DeviceManager() { - IREE_TRACE_SCOPE0("DeviceManager::dtor"); - WaitIdle().IgnoreError(); -} - -Status DeviceManager::RegisterDevice(ref_ptr device) { - IREE_TRACE_SCOPE0("DeviceManager::RegisterDevice"); - absl::MutexLock lock(&device_mutex_); - if (std::find(devices_.begin(), devices_.end(), device) != devices_.end()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Device already registered"; - } - devices_.push_back(std::move(device)); - return OkStatus(); -} - -Status DeviceManager::UnregisterDevice(Device* device) { - IREE_TRACE_SCOPE0("DeviceManager::UnregisterDevice"); - absl::MutexLock lock(&device_mutex_); - auto it = std::find_if(devices_.begin(), devices_.end(), - [device](const ref_ptr& other_device) { - return device == other_device.get(); - }); - if (it == devices_.end()) { - return NotFoundErrorBuilder(IREE_LOC) << "Device not registered"; - } - devices_.erase(it); - return OkStatus(); -} - -StatusOr DeviceManager::ResolvePlacement( - const PlacementSpec& placement_spec) const { - IREE_TRACE_SCOPE0("DeviceManager::ResolvePlacement"); - absl::MutexLock lock(&device_mutex_); - if (devices_.empty()) { - return NotFoundErrorBuilder(IREE_LOC) << "No devices registered"; - } - - // TODO(benvanik): multiple devices and placement. 
- IREE_QCHECK_EQ(devices_.size(), 1) - << "Multiple devices not yet supported (need placement)"; - DevicePlacement device_placement; - device_placement.device = devices_.front().get(); - - return device_placement; -} - -StatusOr DeviceManager::FindCompatibleAllocator( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - absl::Span device_placements) const { - IREE_TRACE_SCOPE0("DeviceManager::FindCompatibleAllocator"); - if (device_placements.empty()) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "No placements provided"; - } - - // Find the first allocator. As we only return an allocator if all placements - // are compatible we'll compare allocator[0] against allocator[1,N]. - Allocator* some_allocator = nullptr; - for (const auto& device_placement : device_placements) { - auto* allocator = device_placement.device->allocator(); - if (!some_allocator) { - some_allocator = allocator; - continue; - } - // NOTE: as there can be asymmetry between usage restrictions (A can use B - // but B cannot use A) we have to compare both directions. - if (!some_allocator->CanUseBufferLike(allocator, memory_type, buffer_usage, - buffer_usage) || - !allocator->CanUseBufferLike(some_allocator, memory_type, buffer_usage, - buffer_usage)) { - // Allocators are not compatible. - return NotFoundErrorBuilder(IREE_LOC) - << "No single allocator found that is compatible with all " - "placements"; - } - } - return some_allocator; -} - -StatusOr> DeviceManager::TryAllocateDeviceVisibleBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span device_placements) { - IREE_TRACE_SCOPE0("DeviceManager::TryAllocateDeviceVisibleBuffer:size"); - if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Host-local buffers require the kHostLocal bit: " - << MemoryTypeString(memory_type); - } - - // Strip kDeviceVisible as we conditionally add it based on support. - memory_type &= ~MemoryType::kDeviceVisible; - - // Find an allocator that works for device-visible buffers. - // If this fails we'll fall back to allocation a non-device-visible buffer. - auto allocator_or = - FindCompatibleAllocator(memory_type | MemoryType::kDeviceVisible, - buffer_usage, device_placements); - if (allocator_or.ok()) { - return allocator_or.value()->Allocate( - memory_type | MemoryType::kDeviceVisible, buffer_usage, - allocation_size); - } - - // Fallback to allocating a host-local buffer. - return HeapBuffer::Allocate(memory_type, buffer_usage, allocation_size); -} - -StatusOr> DeviceManager::AllocateDeviceVisibleBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span device_placements) { - IREE_TRACE_SCOPE0("DeviceManager::AllocateDeviceVisibleBuffer:size"); - if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Host-local buffers require the kHostLocal bit: " - << MemoryTypeString(memory_type); - } - - // Always use device-visible. - memory_type |= MemoryType::kDeviceVisible; - - // Find an allocator that works for device-visible buffers. 
- IREE_ASSIGN_OR_RETURN( - auto* allocator, - FindCompatibleAllocator(memory_type, buffer_usage, device_placements)); - return allocator->Allocate(memory_type, buffer_usage, allocation_size); -} - -StatusOr<ref_ptr<Buffer>> DeviceManager::AllocateDeviceLocalBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - IREE_TRACE_SCOPE0("DeviceManager::AllocateDeviceLocalBuffer:size"); - if (!AnyBitSet(memory_type & MemoryType::kDeviceLocal)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Device-local buffers require the kDeviceLocal bit: " - << MemoryTypeString(memory_type); - } - - // Find an allocator that works for device-local buffers. - IREE_ASSIGN_OR_RETURN( - auto* allocator, - FindCompatibleAllocator(memory_type, buffer_usage, device_placements)); - return allocator->Allocate(memory_type, buffer_usage, allocation_size); -} - -Status DeviceManager::Submit(Device* device, CommandQueue* command_queue, - absl::Span<const SubmissionBatch> batches, - Time deadline_ns) { - IREE_TRACE_SCOPE0("DeviceManager::Submit"); - return command_queue->Submit(batches); -} - -Status DeviceManager::Flush() { - IREE_TRACE_SCOPE0("DeviceManager::Flush"); - return OkStatus(); -} - -Status DeviceManager::WaitIdle(Time deadline_ns) { - IREE_TRACE_SCOPE0("DeviceManager::WaitIdle"); - absl::MutexLock lock(&device_mutex_); - for (const auto& device : devices_) { - IREE_RETURN_IF_ERROR(device->WaitIdle(deadline_ns)); - } - return OkStatus(); -} - -} // namespace hal -} // namespace iree diff --git a/iree/hal/device_manager.h b/iree/hal/device_manager.h deleted file mode 100644 index 492b824cde861..0000000000000 --- a/iree/hal/device_manager.h +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DEVICE_MANAGER_H_ -#define IREE_HAL_DEVICE_MANAGER_H_ - -#include <vector> - -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/allocator.h" -#include "iree/hal/buffer.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/device.h" -#include "iree/hal/device_placement.h" -#include "iree/hal/executable_format.h" -#include "iree/hal/semaphore.h" - -namespace iree { -namespace hal { - -// Specifies how devices should be resolved to DevicePlacements. -// Most fields are optional and when not included will be ignored. -struct PlacementSpec { - // TODO(benvanik): other requirements (features/caps, power, etc). - - // A list of executable formats that the placement should support. - // If more than one format is provided any device satisfying at least one - // will be considered for placement. The formats can be sorted in descending - // priority order to prefer the first available format in the case of ties. - absl::Span<const ExecutableFormat> available_formats; -}; - -// Manages device lifetime and placement resolution.
-// Optionally the DeviceManager may be used for automatic device selection for -// allocations or batched submissions, however this is not required if specific -// devices and scheduling behavior are known to the caller. -// -// Thread-safe. Note that callers must ensure that unregistered devices are kept -// alive for as long as any commands are in-flight that may be using them. -class DeviceManager final { - public: - DeviceManager(); - ~DeviceManager(); - - // Registers a device with the manager. - // The device will be used to resolve placements. Any placements resolved - // prior to the addition of the device will need to be refreshed by the caller - // if they want to make use of the new device. - Status RegisterDevice(ref_ptr<Device> device); - - // Unregisters a device with the manager. - // Placements that resolved to the device prior to unregistering will remain - // valid for that device. Callers will need to refresh the placements to - // ensure the device stops being used. - Status UnregisterDevice(Device* device); - - // TODO(benvanik): dispatch info + requirements + etc -> DevicePlacement. - - // Resolves a placement spec to a device placement based on the registered - // devices. - // If the placement is not fully specified the device and queue may be chosen - // at random. See PlacementSpec for more information about resolution and - // ranking. - StatusOr<DevicePlacement> ResolvePlacement( - const PlacementSpec& placement_spec) const; - - // Finds an allocator that can allocate buffers of the given |memory_type| and - // |buffer_usage| such that the buffers can be used interchangeably. - // Fails if there is no Allocator that can satisfy that requirement. - StatusOr<Allocator*> FindCompatibleAllocator( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - absl::Span<const DevicePlacement> device_placements) const; - - // Tries to allocate a host-local buffer that _may_ be optimal for use with - // the given |device_placements| and _may_ be device-visible. The buffer can - // be used for staging uploads to device-local buffers and is useful for times - // when the buffer will be used more on the host than the device. If a buffer - // never needs to be used with a device prefer instead - // Allocator::host_local()::Allocate. - // - // Returns a buffer even if it's not possible to satisfy the requested - // |buffer_usage| for the |device_placements| at the cost of a run-time - // performance hit. - StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements); - StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer( - BufferUsageBitfield buffer_usage, device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - return TryAllocateDeviceVisibleBuffer( - MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage, - allocation_size, device_placements); - } - - // Allocates a host-local buffer that is optimal for use on the host but is - // usable by the given |device_placements| (at a possible performance - // penalty). The buffer can be used for staging uploads to device-local - // buffers and is useful for times when the buffer will be used more on the - // host than the device. If a buffer never needs to be used with a device - // prefer instead HeapBuffer::Allocate. - // - // Fails if it is not possible to allocate and satisfy all |device_placements| - // for the requested |buffer_usage|.
- StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements); - StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer( - BufferUsageBitfield buffer_usage, device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - return AllocateDeviceVisibleBuffer( - MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage, - allocation_size, device_placements); - } - - // Allocates a device-local buffer that is optimal for use with the given - // |device_placements|. The buffer will not be host-visible and can only be - // used from compatible device queues. - // - // Fails if it is not possible to allocate and satisfy all |device_placements| - // for the requested |buffer_usage|. - StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements); - StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer( - BufferUsageBitfield buffer_usage, device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - return AllocateDeviceLocalBuffer(MemoryType::kDeviceLocal, buffer_usage, - allocation_size, device_placements); - } - - // Enqueues a submission against the given target |device| |command_queue|. - // The provided |deadline| is used to determine how long the submission can - // stay waiting in the queue prior to flushing, with absl::InfinitePast - // indicating immediate submission and absl::InfiniteFuture indicating that - // Flush must be called. - // - // If |batches| signal_semaphores are provided they will be signaled when - // their corresponding submission has completed. If a sequence of submissions - // is performed then the semaphore value relationships can be used to elide - // waits. - // - // All provided resources must remain alive until the provided semaphores are - // signaled indicating that the resources used are no longer required. - // - // Submissions may be made from any thread. Behavior is undefined - // if a thread is performing a WaitIdle while another thread submits work. - Status Submit(Device* device, CommandQueue* command_queue, - absl::Span<const SubmissionBatch> batches, Time deadline_ns); - Status Submit(Device* device, CommandQueue* command_queue, - absl::Span<const SubmissionBatch> batches, - Duration timeout_ns) { - return Submit(device, command_queue, batches, - RelativeTimeoutToDeadlineNanos(timeout_ns)); - } - Status Submit(Device* device, CommandQueue* command_queue, - absl::Span<const SubmissionBatch> batches) { - return Submit(device, command_queue, batches, InfinitePast()); - } - - // Flushes any requests that are pending in the scheduler and ensures they - // begin executing ASAP regardless of policy. - // - // If any used device has encountered an error during submission at any - // point it will be returned here (repeatedly). - Status Flush(); - - // Blocks until all outstanding requests have been completed. - // This is equivalent to having waited on all outstanding semaphore signal - // operations in all previously submitted batches. - // Implicitly calls Flush to ensure delayed requests are scheduled. - // Work submitted from other threads during a wait may not be included in the - // wait set. - // - // If any used device has encountered an error during submission at any - // point it will be returned here (repeatedly).
- Status WaitIdle(Time deadline_ns); - inline Status WaitIdle(Duration timeout_ns) { - return WaitIdle(RelativeTimeoutToDeadlineNanos(timeout_ns)); - } - inline Status WaitIdle() { return WaitIdle(InfiniteFuture()); } - - private: - mutable absl::Mutex device_mutex_; - std::vector<ref_ptr<Device>> devices_ ABSL_GUARDED_BY(device_mutex_); -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DEVICE_MANAGER_H_ diff --git a/iree/hal/device_placement.h b/iree/hal/device_placement.h deleted file mode 100644 index bd337ff6e4c98..0000000000000 --- a/iree/hal/device_placement.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DEVICE_PLACEMENT_H_ -#define IREE_HAL_DEVICE_PLACEMENT_H_ - -namespace iree { -namespace hal { - -class Device; - -// TODO(benvanik): define device-specific placement info - possibly opaque. -struct DevicePlacement { - Device* device = nullptr; - int queue_id = 0; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DEVICE_PLACEMENT_H_ diff --git a/iree/hal/driver.c b/iree/hal/driver.c new file mode 100644 index 0000000000000..8ae07319976a8 --- /dev/null +++ b/iree/hal/driver.c @@ -0,0 +1,66 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
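Before the body of driver.c (next in the diff), it may help to see the shape of the pattern it relies on: every public iree_hal_* entry point validates its arguments and forwards through a per-type vtable. The following is a minimal hedged sketch of a driver implementation against that vtable; the my_* names are hypothetical, allocation and resource initialization are omitted, and the assumption that implementations embed an iree_hal_resource_t as their first member (backing the generated retain/release) mirrors the convention suggested by IREE_HAL_API_RETAIN_RELEASE rather than anything spelled out in this hunk.

// Hypothetical skeleton driver (my_* names are illustrative, not part of the
// diff). The vtable field order matches iree_hal_driver_vtable_t as declared
// in driver.h later in this change.
typedef struct {
  iree_hal_resource_t resource;  // assumed first member; backs retain/release
  iree_allocator_t host_allocator;
} my_driver_t;

static void my_driver_destroy(iree_hal_driver_t* base_driver) {
  my_driver_t* driver = (my_driver_t*)base_driver;
  iree_allocator_free(driver->host_allocator, driver);
}

static iree_status_t my_driver_query_available_devices(
    iree_hal_driver_t* base_driver, iree_allocator_t allocator,
    iree_hal_device_info_t** out_device_infos,
    iree_host_size_t* out_device_info_count) {
  // A stub exposing no devices; real drivers allocate and fill a list here.
  *out_device_infos = NULL;
  *out_device_info_count = 0;
  return iree_ok_status();
}

static iree_status_t my_driver_create_device(
    iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id,
    iree_allocator_t allocator, iree_hal_device_t** out_device) {
  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                          "stub driver has no devices");
}

static const iree_hal_driver_vtable_t my_driver_vtable = {
    .destroy = my_driver_destroy,
    .query_available_devices = my_driver_query_available_devices,
    .create_device = my_driver_create_device,
};

With such a vtable in place, the dispatch macros below route iree_hal_driver_query_available_devices and friends to these functions without any per-driver branching.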
+ +#include "iree/hal/driver.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" + +#define _VTABLE_DISPATCH(driver, method_name) \ + IREE_HAL_VTABLE_DISPATCH(driver, iree_hal_driver, method_name) + +IREE_HAL_API_RETAIN_RELEASE(driver); + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_driver_query_available_devices( + iree_hal_driver_t* driver, iree_allocator_t allocator, + iree_hal_device_info_t** out_device_infos, + iree_host_size_t* out_device_info_count) { + IREE_ASSERT_ARGUMENT(driver); + IREE_ASSERT_ARGUMENT(out_device_infos); + IREE_ASSERT_ARGUMENT(out_device_info_count); + *out_device_info_count = 0; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(driver, query_available_devices)( + driver, allocator, out_device_infos, out_device_info_count); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_driver_create_device( + iree_hal_driver_t* driver, iree_hal_device_id_t device_id, + iree_allocator_t allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(driver); + IREE_ASSERT_ARGUMENT(out_device); + *out_device = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(driver, create_device)( + driver, device_id, allocator, out_device); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_driver_create_default_device(iree_hal_driver_t* driver, + iree_allocator_t allocator, + iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(driver); + IREE_ASSERT_ARGUMENT(out_device); + *out_device = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(driver, create_device)( + driver, IREE_HAL_DRIVER_ID_INVALID, allocator, out_device); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/driver.h b/iree/hal/driver.h index e1e25adb69df6..d68b7cd641d0a 100644 --- a/iree/hal/driver.h +++ b/iree/hal/driver.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,54 +15,118 @@ #ifndef IREE_HAL_DRIVER_H_ #define IREE_HAL_DRIVER_H_ -#include -#include -#include +#include +#include -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/debug_capture_manager.h" +#include "iree/base/api.h" #include "iree/hal/device.h" -#include "iree/hal/device_info.h" +#include "iree/hal/resource.h" -namespace iree { -namespace hal { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -class Driver : public RefObject { - public: - virtual ~Driver() = default; +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// - // Driver name used during registration. - const std::string& name() const { return name_; } +// An opaque factory-specific handle to identify different drivers. +typedef uint64_t iree_hal_driver_id_t; - // TODO(benvanik): info/query (version number, etc). +#define IREE_HAL_DRIVER_ID_INVALID 0ull - // Enumerates devices available for creation from the driver. - // This may fail if the driver is in an invalid state but otherwise will - // return an empty list if no devices are available. - virtual StatusOr> EnumerateAvailableDevices() = 0; +// Describes a driver providing device enumeration and creation. 
+// The lifetime of memory referenced by this structure (such as strings) is +// dependent on where it originated. +// +// * When using iree_hal_driver_registry_enumerate the driver info is copied +// into memory owned by the caller. +// * When queried from a live driver with iree_hal_driver_info the memory is +// only guaranteed to live for as long as the driver is. +// * When enumerating via factories the information may be valid only while the +// driver registry lock is held. +typedef struct { + IREE_API_UNSTABLE + + // Opaque handle used by factories. Unique across all factories. + iree_hal_driver_id_t driver_id; + + // Canonical name of the driver as used in command lines, documentation, etc. + // Examples: 'metal', 'vulkan' + iree_string_view_t driver_name; + + // Full human-readable name of the driver for display. + // Examples: 'Vulkan 1.2 (NVIDIA)'. + iree_string_view_t full_name; + + // TODO(benvanik): version information; useful if wanting to expose multiple + // versions that may have completely different implementations (like vulkan + // 1.0, 1.1, and 1.2) but allow a nice sort/selection process. + // TODO(benvanik): triple, feature flags, etc. +} iree_hal_driver_info_t; + +//===----------------------------------------------------------------------===// +// iree_hal_driver_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_driver_s iree_hal_driver_t; + +// Retains the given |driver| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_driver_retain(iree_hal_driver_t* driver); + +// Releases the given |driver| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_driver_release(iree_hal_driver_t* driver); + +// Queries available devices and returns them as a list. +// The provided |allocator| will be used to allocate the returned list and after +// the caller is done with it |out_device_infos| must be freed with that same +// allocator by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_driver_query_available_devices( + iree_hal_driver_t* driver, iree_allocator_t allocator, + iree_hal_device_info_t** out_device_infos, + iree_host_size_t* out_device_info_count); + +// Creates a device as queried with iree_hal_driver_query_available_devices. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_driver_create_device( + iree_hal_driver_t* driver, iree_hal_device_id_t device_id, + iree_allocator_t allocator, iree_hal_device_t** out_device); + +// Creates the driver-defined "default" device. This may simply be the first +// device enumerated. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_driver_create_default_device(iree_hal_driver_t* driver, + iree_allocator_t allocator, + iree_hal_device_t** out_device); + +//===----------------------------------------------------------------------===// +// iree_hal_driver_t implementation details +//===----------------------------------------------------------------------===// - // Creates the driver-defined 'default' device. - // This may simply be the first device enumerated. - virtual StatusOr<ref_ptr<Device>> CreateDefaultDevice() = 0; +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE - // Creates a device as queried with the given |driver_handle|.
- virtual StatusOr<ref_ptr<Device>> CreateDevice(DriverDeviceID device_id) = 0; - StatusOr<ref_ptr<Device>> CreateDevice(const DeviceInfo& device_info) { - return CreateDevice(device_info.device_id()); - } + void(IREE_API_PTR* destroy)(iree_hal_driver_t* driver); - // Gets the capture manager for this driver, if one exists. - virtual DebugCaptureManager* debug_capture_manager() { return nullptr; } + iree_status_t(IREE_API_PTR* query_available_devices)( + iree_hal_driver_t* driver, iree_allocator_t allocator, + iree_hal_device_info_t** out_device_infos, + iree_host_size_t* out_device_info_count); - protected: - explicit Driver(std::string name) : name_(std::move(name)) {} + iree_status_t(IREE_API_PTR* create_device)(iree_hal_driver_t* driver, + iree_hal_device_id_t device_id, + iree_allocator_t allocator, + iree_hal_device_t** out_device); +} iree_hal_driver_vtable_t; - private: - const std::string name_; -}; +IREE_API_EXPORT void IREE_API_CALL +iree_hal_driver_destroy(iree_hal_driver_t* driver); -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_DRIVER_H_ diff --git a/iree/hal/api.c b/iree/hal/driver_registry.c similarity index 98% rename from iree/hal/api.c rename to iree/hal/driver_registry.c index 499156c03f826..2efe308a2e63b 100644 --- a/iree/hal/api.c +++ b/iree/hal/driver_registry.c @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "iree/hal/api.h" +#include "iree/hal/driver_registry.h" #include "iree/base/synchronization.h" #include "iree/base/threading.h" #include "iree/base/tracing.h" +#include "iree/hal/detail.h" //===----------------------------------------------------------------------===// // iree_hal_driver_registry_t //===----------------------------------------------------------------------===// @@ -324,6 +325,10 @@ iree_hal_driver_registry_try_create_by_name( if (hit_driver_id != IREE_HAL_DRIVER_ID_INVALID) { status = hit_factory->try_create(hit_factory->self, hit_driver_id, allocator, out_driver); + } else { + status = + iree_make_status(IREE_STATUS_NOT_FOUND, "no driver '%.*s' registered", + (int)driver_name.size, driver_name.data); } iree_slim_mutex_unlock(&registry->mutex); diff --git a/iree/hal/driver_registry.h b/iree/hal/driver_registry.h new file mode 100644 index 0000000000000..84f54574cdc63 --- /dev/null +++ b/iree/hal/driver_registry.h @@ -0,0 +1,172 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DRIVER_REGISTRY_H_ +#define IREE_HAL_DRIVER_REGISTRY_H_ + +#include <stdbool.h> +#include <stdint.h> + +#include "iree/base/api.h" +#include "iree/hal/driver.h" +#include "iree/hal/resource.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// + +// Factory interface used for driver enumeration and creation.
+// The factory is designed so that in many cases it can live in rodata by not +// requiring any real code or processing when the driver is statically known to +// be available. When drivers may be dynamically available based on system +// configuration a factory can discover them and provide them during +// enumeration. +// +// Delay-loaded drivers that may require non-trivial setup time (such as those +// implemented in dynamic libraries or over RPC) can be speculatively enumerated +// by a factory and then rely on try_create to actually perform the slow work +// once the user has explicitly signaled that they are willing to pay the cost +// (and deal with the consequences). +// +// WARNING: this API is unstable until the HAL is fully ported. Do not use. +typedef struct { + // TODO(benvanik): version field. + IREE_API_UNSTABLE + + // User-defined pointer passed to all functions. + void* self; + + // Queries the list of available drivers provided by the factory, if any. + // |out_driver_infos| will be populated with a *reference* to factory data + // structures (such as the driver name) that callers may choose to clone if + // needed. + // + // Implementers must make their factory enumeration results immutable for the + // duration they are registered, though the behavior of try_create is allowed + // to change call-to-call. If a factory needs to mutate its set of enumerated + // drivers then it must do so by first unregistering itself and re-registering + // only after the changes have been made. + // + // Called with the driver registry lock held; may be called from any thread. + iree_status_t(IREE_API_PTR* enumerate)( + void* self, const iree_hal_driver_info_t** out_driver_infos, + iree_host_size_t* out_driver_info_count); + + // Tries to create a driver as previously queried with enumerate. + // |driver_id| is the opaque ID returned from enumeration; note that there may + // be a significant amount of time between enumeration and creation and the + // driver registry lock may have been released in between. + // + // Delay-loaded drivers may still fail here if - for example - required system + // resources are unavailable or permission is denied. + // + // Called with the driver registry lock held; may be called from any thread. + iree_status_t(IREE_API_PTR* try_create)(void* self, + iree_hal_driver_id_t driver_id, + iree_allocator_t allocator, + iree_hal_driver_t** out_driver); +} iree_hal_driver_factory_t; + +//===----------------------------------------------------------------------===// +// iree_hal_driver_registry_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_driver_registry_s iree_hal_driver_registry_t; + +// Returns the default per-process driver registry. +// In simple applications this is usually where you want to go to register and +// create drivers. More sophisticated applications that want tighter control +// over the visibility of drivers to certain callers, such as when dealing with +// requests from multiple users, may choose to allocate their own registries and +// manage their lifetime as desired. +// +// TODO(benvanik): remove global registry and make callers manage always. We can +// provide helpers to make that easier to do, but there's really no benefit to +// having this be global like it is. Alternatively, this can be opt-in thanks to +// LTO: if a user doesn't call this then the default registry is never +// allocated.
+IREE_API_EXPORT iree_hal_driver_registry_t* IREE_API_CALL +iree_hal_driver_registry_default(); + +// Registers a driver factory to serve future queries/requests for drivers. +// See iree_hal_driver_registry_t for more information. +// +// Thread-safe. The factory is not retained and must be kept alive by the caller +// until it is unregistered (or the application terminates). +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_driver_registry_register_factory( + iree_hal_driver_registry_t* registry, + const iree_hal_driver_factory_t* factory); + +// Unregisters a driver factory. +// Unregistering a factory only prevents new drivers from being created; +// existing drivers may remain live even after unregistering. Factories can +// expect that no new drivers will be created via the factory after the call +// returns. +// +// Thread-safe. As the factory is not retained by the registry the caller must +// release its memory (if needed) after this call returns. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_driver_registry_unregister_factory( + iree_hal_driver_registry_t* registry, + const iree_hal_driver_factory_t* factory); + +// Enumerates all drivers from registered factories and returns them as a list. +// The provided |allocator| will be used to allocate the returned list and after +// the caller is done with it |out_driver_infos| must be freed with that same +// allocator by the caller. +// +// The set of drivers returned should be considered the superset of those that +// may be available for successful creation as it's possible that delay-loaded +// drivers may fail even if they appear in this list. +// +// Thread-safe. Note that the factory may be unregistered between the query +// completing and any attempt to instantiate the driver. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_driver_registry_enumerate( + iree_hal_driver_registry_t* registry, iree_allocator_t allocator, + iree_hal_driver_info_t** out_driver_infos, + iree_host_size_t* out_driver_info_count); + +// Attempts to create a driver registered with the driver registry by a specific +// ID as returned during enumeration in iree_hal_driver_info_t::driver_id. +// This can be used to specify the exact driver to create in cases where there +// may be multiple factories providing drivers with the same name. +// +// Thread-safe. May block the caller if the driver is delay-loaded and needs to +// perform additional loading/verification/etc before returning. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_driver_registry_try_create( + iree_hal_driver_registry_t* registry, iree_hal_driver_id_t driver_id, + iree_allocator_t allocator, iree_hal_driver_t** out_driver); + +// Attempts to create a driver registered with the given canonical driver name. +// Effectively enumerate + find by name + try_create if found. Factories are +// searched in most-recently-added order such that it's possible to override +// drivers with newer registrations when multiple factories provide the same +// driver name. +// +// Thread-safe. May block the caller if the driver is delay-loaded and needs to +// perform additional loading/verification/etc before returning. 
+IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_driver_registry_try_create_by_name( + iree_hal_driver_registry_t* registry, iree_string_view_t driver_name, + iree_allocator_t allocator, iree_hal_driver_t** out_driver); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_DRIVER_REGISTRY_H_ diff --git a/iree/hal/drivers/BUILD b/iree/hal/drivers/BUILD index 4bf772b9033db..8c421d7dff6c7 100644 --- a/iree/hal/drivers/BUILD +++ b/iree/hal/drivers/BUILD @@ -30,8 +30,5 @@ cc_library( "//iree/hal/dylib/registration", "//iree/hal/vmla/registration", "//iree/hal/vulkan/registration", - ] + select({ - "@bazel_tools//src/conditions:darwin": ["//iree/hal/metal/registration"], - "//conditions:default": [], - }), + ], ) diff --git a/iree/hal/drivers/CMakeLists.txt b/iree/hal/drivers/CMakeLists.txt index 3671c847028b0..305bcb5c436d3 100644 --- a/iree/hal/drivers/CMakeLists.txt +++ b/iree/hal/drivers/CMakeLists.txt @@ -18,9 +18,6 @@ set(IREE_HAL_DRIVER_MODULES) if(${IREE_HAL_DRIVER_DYLIB}) list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::registration) endif() -if(${IREE_HAL_DRIVER_METAL}) - list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::metal::registration) -endif() if(${IREE_HAL_DRIVER_VMLA}) list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vmla::registration) endif() diff --git a/iree/hal/drivers/init.c b/iree/hal/drivers/init.c index 2b9a2b935b006..bb36616849d1a 100644 --- a/iree/hal/drivers/init.c +++ b/iree/hal/drivers/init.c @@ -20,10 +20,6 @@ #include "iree/hal/dylib/registration/driver_module.h" #endif // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE -#if defined(IREE_HAL_HAVE_METAL_DRIVER_MODULE) -#include "iree/hal/metal/registration/driver_module.h" -#endif // IREE_HAL_HAVE_METAL_DRIVER_MODULE - #if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE) #include "iree/hal/vmla/registration/driver_module.h" #endif // IREE_HAL_HAVE_VMLA_DRIVER_MODULE @@ -41,11 +37,6 @@ iree_hal_register_all_available_drivers(iree_hal_driver_registry_t* registry) { z0, iree_hal_dylib_driver_module_register(registry)); #endif // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE -#if defined(IREE_HAL_HAVE_METAL_DRIVER_MODULE) - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_metal_driver_module_register(registry)); -#endif // IREE_HAL_HAVE_METAL_DRIVER_MODULE - #if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE) IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_vmla_driver_module_register(registry)); diff --git a/iree/hal/dylib/BUILD b/iree/hal/dylib/BUILD index 49a5711beb42d..0e6ae3d977cd2 100644 --- a/iree/hal/dylib/BUILD +++ b/iree/hal/dylib/BUILD @@ -12,49 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content") - package( default_visibility = ["//visibility:public"], features = ["layering_check"], licenses = ["notice"], # Apache 2.0 ) - -iree_cmake_extra_content( - content = """ -if(NOT ${IREE_HAL_DRIVER_DYLIB}) - return() -endif() -""", -) - -cc_library( - name = "dylib", - srcs = [ - "dylib_device.cc", - "dylib_driver.cc", - "dylib_executable.cc", - "dylib_executable_cache.cc", - ], - hdrs = [ - "dylib_device.h", - "dylib_driver.h", - "dylib_executable.h", - "dylib_executable_cache.h", - ], - deps = [ - "//iree/base:dynamic_library", - "//iree/base:file_io", - "//iree/base:file_path", - "//iree/base:flatcc", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - "//iree/hal/host:host_executable", - "//iree/hal/host:host_local_device", - "//iree/hal/host/serial:serial_scheduling_model", - "//iree/schemas:dylib_executable_def_c_fbs", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", - ], -) diff --git a/iree/hal/dylib/CMakeLists.txt b/iree/hal/dylib/CMakeLists.txt index 03d278debe295..15e92636ef743 100644 --- a/iree/hal/dylib/CMakeLists.txt +++ b/iree/hal/dylib/CMakeLists.txt @@ -12,38 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -if(NOT ${IREE_HAL_DRIVER_DYLIB}) - return() -endif() - iree_add_all_subdirs() - -iree_cc_library( - NAME - dylib - HDRS - "dylib_device.h" - "dylib_driver.h" - "dylib_executable.h" - "dylib_executable_cache.h" - SRCS - "dylib_device.cc" - "dylib_driver.cc" - "dylib_executable.cc" - "dylib_executable_cache.cc" - DEPS - absl::inlined_vector - absl::span - iree::base::dynamic_library - iree::base::file_io - iree::base::file_path - iree::base::flatcc - iree::base::status - iree::base::tracing - iree::hal - iree::hal::host::host_executable - iree::hal::host::host_local_device - iree::hal::host::serial::serial_scheduling_model - iree::schemas::dylib_executable_def_c_fbs - PUBLIC -) diff --git a/iree/hal/dylib/dylib_device.cc b/iree/hal/dylib/dylib_device.cc deleted file mode 100644 index dc74a98ced932..0000000000000 --- a/iree/hal/dylib/dylib_device.cc +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "iree/hal/dylib/dylib_device.h" - -#include - -#include "iree/base/tracing.h" -#include "iree/hal/dylib/dylib_executable_cache.h" - -namespace iree { -namespace hal { -namespace dylib { - -DyLibDevice::DyLibDevice( - DeviceInfo device_info, - std::unique_ptr scheduling_model) - : HostLocalDevice(std::move(device_info), std::move(scheduling_model)) {} - -DyLibDevice::~DyLibDevice() = default; - -ref_ptr DyLibDevice::CreateExecutableCache() { - IREE_TRACE_SCOPE0("DyLibDevice::CreateExecutableCache"); - return make_ref(); -} - -} // namespace dylib -} // namespace hal -} // namespace iree diff --git a/iree/hal/dylib/dylib_device.h b/iree/hal/dylib/dylib_device.h deleted file mode 100644 index a99604ee3152c..0000000000000 --- a/iree/hal/dylib/dylib_device.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DYLIB_DYLIB_DEVICE_H_ -#define IREE_HAL_DYLIB_DYLIB_DEVICE_H_ - -#include "iree/hal/host/host_local_device.h" - -namespace iree { -namespace hal { -namespace dylib { - -class DyLibDevice final : public host::HostLocalDevice { - public: - DyLibDevice(DeviceInfo device_info, - std::unique_ptr scheduling_model); - ~DyLibDevice() override; - - ref_ptr CreateExecutableCache() override; -}; - -} // namespace dylib -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DYLIB_DYLIB_DEVICE_H_ diff --git a/iree/hal/dylib/dylib_driver.cc b/iree/hal/dylib/dylib_driver.cc deleted file mode 100644 index 4aaa6683023af..0000000000000 --- a/iree/hal/dylib/dylib_driver.cc +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/dylib/dylib_driver.h" - -#include - -#include "iree/hal/device_info.h" -#include "iree/hal/dylib/dylib_device.h" -#include "iree/hal/host/serial/serial_scheduling_model.h" - -namespace iree { -namespace hal { -namespace dylib { -namespace { - -DeviceInfo GetDefaultDeviceInfo() { - DeviceFeatureBitfield supported_features = DeviceFeature::kNone; - // TODO(benvanik): implement debugging/profiling features. - // supported_features |= DeviceFeature::kDebugging; - // supported_features |= DeviceFeature::kCoverage; - // supported_features |= DeviceFeature::kProfiling; - DeviceInfo device_info("dylib", "Dynamic Library (dylib)", - supported_features); - // TODO(benvanik): device info. 
- return device_info; -} - -} // namespace - -DyLibDriver::DyLibDriver() : Driver("dylib") {} - -DyLibDriver::~DyLibDriver() = default; - -StatusOr<std::vector<DeviceInfo>> DyLibDriver::EnumerateAvailableDevices() { - std::vector<DeviceInfo> device_infos; - device_infos.push_back(GetDefaultDeviceInfo()); - return device_infos; -} - -StatusOr<ref_ptr<Device>> DyLibDriver::CreateDefaultDevice() { - // Only one device, pass a dummy device_id. - return CreateDevice(0); -} - -StatusOr<ref_ptr<Device>> DyLibDriver::CreateDevice(DriverDeviceID device_id) { - // Only one device, ignore device_id. - auto scheduling_model = std::make_unique<host::SerialSchedulingModel>(); - return make_ref<DyLibDevice>(GetDefaultDeviceInfo(), - std::move(scheduling_model)); -} - -} // namespace dylib -} // namespace hal -} // namespace iree diff --git a/iree/hal/dylib/dylib_driver.h b/iree/hal/dylib/dylib_driver.h deleted file mode 100644 index a8402710eafbf..0000000000000 --- a/iree/hal/dylib/dylib_driver.h +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DYLIB_DYLIB_DRIVER_H_ -#define IREE_HAL_DYLIB_DYLIB_DRIVER_H_ - -#include "iree/hal/driver.h" - -namespace iree { -namespace hal { -namespace dylib { - -class DyLibDriver final : public Driver { - public: - DyLibDriver(); - ~DyLibDriver() override; - - StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override; - - StatusOr<ref_ptr<Device>> CreateDefaultDevice() override; - - StatusOr<ref_ptr<Device>> CreateDevice(DriverDeviceID device_id) override; -}; - -} // namespace dylib -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DYLIB_DYLIB_DRIVER_H_ diff --git a/iree/hal/dylib/dylib_executable.cc b/iree/hal/dylib/dylib_executable.cc deleted file mode 100644 index 2aa7292cce0a8..0000000000000 --- a/iree/hal/dylib/dylib_executable.cc +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/dylib/dylib_executable.h" - -#include "iree/base/file_io.h" -#include "iree/base/file_path.h" - -// flatcc schemas: -#include "iree/base/flatcc.h" -#include "iree/schemas/dylib_executable_def_reader.h" -#include "iree/schemas/dylib_executable_def_verifier.h" - -// NOTE: starting to port this to C. - -// Verifies the structure of the flatbuffer so that we can avoid doing so during -// runtime. There are still some conditions we must be aware of (such as omitted -// names on functions with internal linkage), however we shouldn't need to -// bounds check anything within the flatbuffer after this succeeds.
-static iree_status_t iree_hal_dylib_executable_flatbuffer_verify( - iree_const_byte_span_t flatbuffer_data) { - if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) { - return iree_make_status( - IREE_STATUS_INVALID_ARGUMENT, - "flatbuffer data is not present or less than 16 bytes (%zu total)", - flatbuffer_data.data_length); - } - - // Run flatcc generated verification. This ensures all pointers are in-bounds - // and that we can safely walk the file, but not that the actual contents of - // the flatbuffer meet our expectations. - int verify_ret = iree_DyLibExecutableDef_verify_as_root( - flatbuffer_data.data, flatbuffer_data.data_length); - if (verify_ret != flatcc_verify_ok) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "flatbuffer verification failed: %s", - flatcc_verify_error_string(verify_ret)); - } - - iree_DyLibExecutableDef_table_t executable_def = - iree_DyLibExecutableDef_as_root(flatbuffer_data.data); - - flatbuffers_string_vec_t entry_points_vec = - iree_DyLibExecutableDef_entry_points_get(executable_def); - size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec); - for (size_t i = 0; i < entry_point_count; ++i) { - if (!flatbuffers_string_len( - flatbuffers_string_vec_at(entry_points_vec, i))) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "executable entry point %zu has no name", i); - } - } - - if (!flatbuffers_uint8_vec_len( - iree_DyLibExecutableDef_library_embedded_get(executable_def))) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "executable library_embedded is missing/empty"); - } - - return iree_ok_status(); -} - -namespace iree { -namespace hal { -namespace dylib { - -// static -StatusOr<ref_ptr<DyLibExecutable>> DyLibExecutable::Load(ExecutableSpec spec) { - auto executable = make_ref<DyLibExecutable>(); - IREE_RETURN_IF_ERROR(executable->Initialize(spec)); - return executable; -} - -DyLibExecutable::DyLibExecutable() = default; - -DyLibExecutable::~DyLibExecutable() { - IREE_TRACE_SCOPE0("DyLibExecutable::dtor"); -#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION - // Leak the library when tracing, since the profiler may still be reading it. - // TODO(benvanik): move to an atexit handler instead, verify with ASAN/MSAN - // TODO(scotttodd): Make this compatible with testing: - // two test cases, one for each function in the same executable - // first test case passes, second fails to open the file (already open) - executable_library_.release(); -#else - executable_library_.reset(); -#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION - for (const auto& file_path : temp_file_paths_) { - file_io::DeleteFile(file_path).IgnoreError(); - } -} - -Status DyLibExecutable::Initialize(ExecutableSpec spec) { - IREE_TRACE_SCOPE0("DyLibExecutable::Initialize"); - - // Verify and fetch the executable flatbuffer wrapper. - iree_const_byte_span_t executable_data = iree_make_const_byte_span( - spec.executable_data.data(), spec.executable_data.size()); - IREE_RETURN_IF_ERROR( - iree_hal_dylib_executable_flatbuffer_verify(executable_data)); - iree_DyLibExecutableDef_table_t executable_def = - iree_DyLibExecutableDef_as_root(executable_data.data); - - // Write the embedded library out to a temp file, since all of the dynamic - // library APIs work with files. We could instead use in-memory files on - // platforms where that is convenient. - // - // TODO(#3845): use dlopen on an fd with either dlopen(/proc/self/fd/NN), - // fdlopen, or android_dlopen_ext to avoid needing to write the file to disk.
- // Can fallback to memfd_create + dlopen where available, and fallback from - // that to disk (maybe just windows/mac). - std::string base_name = "dylib_executable"; - IREE_ASSIGN_OR_RETURN(auto library_temp_path, - file_io::GetTempFile(base_name)); - temp_file_paths_.push_back(library_temp_path); - -// Add platform-specific file extensions so opinionated dynamic library -// loaders are more likely to find the file: -#if defined(IREE_PLATFORM_WINDOWS) - library_temp_path += ".dll"; -#else - library_temp_path += ".so"; -#endif - - flatbuffers_uint8_vec_t embedded_library_vec = - iree_DyLibExecutableDef_library_embedded_get(executable_def); - IREE_RETURN_IF_ERROR(file_io::SetFileContents( - library_temp_path, - absl::string_view(reinterpret_cast<const char*>(embedded_library_vec), - flatbuffers_uint8_vec_len(embedded_library_vec)))); - - IREE_ASSIGN_OR_RETURN(executable_library_, - DynamicLibrary::Load(library_temp_path.c_str())); - - flatbuffers_string_t debug_database_filename = - iree_DyLibExecutableDef_debug_database_filename_get(executable_def); - flatbuffers_uint8_vec_t debug_database_embedded_vec = - iree_DyLibExecutableDef_debug_database_embedded_get(executable_def); - if (flatbuffers_string_len(debug_database_filename) && - flatbuffers_uint8_vec_len(debug_database_embedded_vec)) { - IREE_TRACE_SCOPE0("DyLibExecutable::AttachDebugDatabase"); - auto debug_database_path = file_path::JoinPaths( - file_path::DirectoryName(library_temp_path), - absl::string_view(debug_database_filename, - flatbuffers_string_len(debug_database_filename))); - temp_file_paths_.push_back(debug_database_path); - IREE_IGNORE_ERROR(file_io::SetFileContents( - debug_database_path, - absl::string_view( - reinterpret_cast<const char*>(debug_database_embedded_vec), - flatbuffers_uint8_vec_len(debug_database_embedded_vec)))); - executable_library_->AttachDebugDatabase(debug_database_path.c_str()); - } - - flatbuffers_string_vec_t entry_points = - iree_DyLibExecutableDef_entry_points_get(executable_def); - entry_functions_.resize(flatbuffers_string_vec_len(entry_points)); - IREE_TRACE(entry_names_.resize(flatbuffers_string_vec_len(entry_points))); - for (size_t i = 0; i < entry_functions_.size(); ++i) { - flatbuffers_string_t entry_point = - flatbuffers_string_vec_at(entry_points, i); - void* symbol = executable_library_->GetSymbol(entry_point); - if (!symbol) { - return NotFoundErrorBuilder(IREE_LOC) - << "Could not find symbol: " << entry_point; - } - entry_functions_[i] = symbol; - - IREE_TRACE(entry_names_[i] = entry_point); - } - - return OkStatus(); -} - -struct DyLibDispatchState : public HostExecutable::DispatchState { - DyLibDispatchState() = default; - - IREE_TRACE(const char* entry_name = nullptr); - - std::array<uint32_t, 3> workgroup_count; - std::array<uint32_t, 3> workgroup_size; - void* entry_function = nullptr; - std::array args; - std::array push_constants; -}; - -StatusOr<ref_ptr<HostExecutable::DispatchState>> -DyLibExecutable::PrepareDispatch(const DispatchParams& params) { - IREE_TRACE_SCOPE0("DyLibExecutable::PrepareDispatch"); - - if (params.entry_point >= entry_functions_.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid entry point ordinal " << params.entry_point; - } - - auto dispatch_state = make_ref<DyLibDispatchState>(); - dispatch_state->workgroup_count = params.workgroup_count; - dispatch_state->workgroup_size = params.workgroup_size; - IREE_TRACE(dispatch_state->entry_name = entry_names_[params.entry_point]); - dispatch_state->entry_function = entry_functions_[params.entry_point]; - - int binding_count = 0; - for (size_t set = 0; set < params.set_bindings.size(); ++set) {
- for (size_t binding = 0; binding < params.set_bindings[set].size(); - ++binding) { - const auto& io_binding = params.set_bindings[set][binding]; - IREE_ASSIGN_OR_RETURN(auto memory, - io_binding.buffer->MapMemory<uint8_t>( - MemoryAccessBitfield::kWrite, io_binding.offset, - io_binding.length)); - auto data = memory.mutable_data(); - dispatch_state->args[binding_count++] = data; - } - } - dispatch_state->push_constants = params.push_constants->values; - - return std::move(dispatch_state); -} - -Status DyLibExecutable::DispatchTile(DispatchState* state, - std::array<uint32_t, 3> workgroup_xyz) { - auto* dispatch_state = static_cast<DyLibDispatchState*>(state); - IREE_TRACE_SCOPE_DYNAMIC(dispatch_state->entry_name); - - auto entry_function = (void (*)(void**, uint32_t*, uint32_t*, uint32_t*, - uint32_t*))dispatch_state->entry_function; - entry_function(dispatch_state->args.data(), - dispatch_state->push_constants.data(), workgroup_xyz.data(), - dispatch_state->workgroup_count.data(), - dispatch_state->workgroup_size.data()); - - return OkStatus(); -} - -} // namespace dylib -} // namespace hal -} // namespace iree diff --git a/iree/hal/dylib/dylib_executable.h b/iree/hal/dylib/dylib_executable.h deleted file mode 100644 index 424c096c989f3..0000000000000 --- a/iree/hal/dylib/dylib_executable.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DYLIB_DYLIB_EXECUTABLE_H_ -#define IREE_HAL_DYLIB_DYLIB_EXECUTABLE_H_ - -#include -#include - -#include "absl/container/inlined_vector.h" -#include "iree/base/dynamic_library.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/executable_spec.h" -#include "iree/hal/host/host_executable.h" - -namespace iree { -namespace hal { -namespace dylib { - -struct MemrefType; - -class DyLibExecutable final : public HostExecutable { - public: - static StatusOr<ref_ptr<DyLibExecutable>> Load(ExecutableSpec spec); - - DyLibExecutable(); - ~DyLibExecutable() override; - - bool supports_debugging() const override { return false; } - - StatusOr<ref_ptr<DispatchState>> PrepareDispatch( - const DispatchParams& params) override; - Status DispatchTile(DispatchState* state, - std::array<uint32_t, 3> workgroup_xyz) override; - - private: - Status Initialize(ExecutableSpec spec); - - absl::InlinedVector temp_file_paths_; - std::unique_ptr<DynamicLibrary> executable_library_; - std::vector<void*> entry_functions_; - - IREE_TRACE(std::vector<const char*> entry_names_); -}; - -} // namespace dylib -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DYLIB_DYLIB_EXECUTABLE_H_ diff --git a/iree/hal/dylib/dylib_executable_cache.cc b/iree/hal/dylib/dylib_executable_cache.cc deleted file mode 100644 index 010e5d997125c..0000000000000 --- a/iree/hal/dylib/dylib_executable_cache.cc +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/dylib/dylib_executable_cache.h" - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/dylib/dylib_executable.h" -#include "iree/hal/executable_format.h" - -namespace iree { -namespace hal { -namespace dylib { - -DyLibExecutableCache::DyLibExecutableCache() = default; - -DyLibExecutableCache::~DyLibExecutableCache() = default; - -bool DyLibExecutableCache::CanPrepareFormat(ExecutableFormat format) const { - return format == kExecutableFormatDyLib; -} - -StatusOr<ref_ptr<Executable>> DyLibExecutableCache::PrepareExecutable( - ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode, - const ExecutableSpec& spec) { - IREE_TRACE_SCOPE0("DyLibExecutableCache::PrepareExecutable"); - - // TODO(scotttodd): Options for using in-memory files where supported, or not - // writing to temp files on disk (and failing if necessary) if not allowed. - // TODO(scotttodd): Use stable (possibly temp, but reusable) files when - // ExecutableCachingMode::AllowPersistentCaching is set. For example, - // hash data into a filename and read from / write to GetTempPath() or - // GetCachePath() rather than use GetTempFile(). - - return DyLibExecutable::Load(spec); -} - -} // namespace dylib -} // namespace hal -} // namespace iree diff --git a/iree/hal/dylib/dylib_executable_cache.h b/iree/hal/dylib/dylib_executable_cache.h deleted file mode 100644 index 68c98722a621b..0000000000000 --- a/iree/hal/dylib/dylib_executable_cache.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
- -#ifndef IREE_HAL_DYLIB_EXECUTABLE_CACHE_H_ -#define IREE_HAL_DYLIB_EXECUTABLE_CACHE_H_ - -#include "iree/hal/executable.h" -#include "iree/hal/executable_cache.h" - -namespace iree { -namespace hal { -namespace dylib { - -class DyLibExecutableCache final : public ExecutableCache { - public: - DyLibExecutableCache(); - ~DyLibExecutableCache() override; - - bool CanPrepareFormat(ExecutableFormat format) const override; - - StatusOr<ref_ptr<Executable>> PrepareExecutable( - ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode, - const ExecutableSpec& spec) override; -}; - -} // namespace dylib -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DYLIB_EXECUTABLE_CACHE_H_ diff --git a/iree/hal/dylib/registration/BUILD b/iree/hal/dylib/registration/BUILD index d1927647200e0..ba66711a686b9 100644 --- a/iree/hal/dylib/registration/BUILD +++ b/iree/hal/dylib/registration/BUILD @@ -37,9 +37,9 @@ cc_library( "IREE_HAL_HAVE_DYLIB_DRIVER_MODULE=1", ], deps = [ - "//iree/base:flags", - "//iree/base:status", "//iree/hal:api", - "//iree/hal/dylib", + "//iree/hal/local:task_driver", + "//iree/hal/local/loaders:legacy_library_loader", + "@com_google_absl//absl/flags:flag", ], ) diff --git a/iree/hal/dylib/registration/CMakeLists.txt b/iree/hal/dylib/registration/CMakeLists.txt index 3658f8bb54f79..6dcab19b52fd1 100644 --- a/iree/hal/dylib/registration/CMakeLists.txt +++ b/iree/hal/dylib/registration/CMakeLists.txt @@ -26,10 +26,10 @@ iree_cc_library( SRCS "driver_module.cc" DEPS - iree::base::flags - iree::base::status + absl::flags iree::hal::api - iree::hal::dylib + iree::hal::local::loaders::legacy_library_loader + iree::hal::local::task_driver DEFINES "IREE_HAL_HAVE_DYLIB_DRIVER_MODULE=1" PUBLIC diff --git a/iree/hal/dylib/registration/driver_module.cc b/iree/hal/dylib/registration/driver_module.cc index d9a0245d72883..1f0a302a6f99b 100644 --- a/iree/hal/dylib/registration/driver_module.cc +++ b/iree/hal/dylib/registration/driver_module.cc @@ -16,18 +16,35 @@ #include -#include "iree/hal/dylib/dylib_driver.h" +#include "absl/flags/flag.h" +#include "iree/hal/local/loaders/legacy_library_loader.h" +#include "iree/hal/local/task_driver.h" + +// TODO(#4298): remove this driver registration and wrapper. +// By having a single iree/hal/local/registration that then has the loaders +// added to it based on compilation settings we can have a single set of flags +// for everything. We can also have API helper methods that register the driver +// using an existing executor so that we can entirely externalize the task +// system configuration from the HAL.
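For readers skimming the diff, the factory body added below boils down to a four-step creation path: build a task topology, create an executable loader, create a task executor over the topology, then wrap everything in a task driver. A condensed, hedged restatement of that sequence follows; all calls are taken verbatim from the code below, the wrapper function name is invented, and cleanup on early-failure paths is elided for brevity.

// Sketch: create the dylib/task driver directly, outside the registry.
static iree_status_t create_dylib_task_driver(iree_hal_driver_t** out_driver) {
  iree_task_topology_t topology;
  iree_task_topology_initialize(&topology);
  iree_task_topology_initialize_from_group_count(/*group_count=*/4, &topology);

  iree_hal_executable_loader_t* loader = NULL;
  IREE_RETURN_IF_ERROR(
      iree_hal_legacy_library_loader_create(iree_allocator_system(), &loader));

  iree_task_executor_t* executor = NULL;
  IREE_RETURN_IF_ERROR(iree_task_executor_create(
      IREE_TASK_SCHEDULING_MODE_RESERVED, &topology, iree_allocator_system(),
      &executor));

  IREE_RETURN_IF_ERROR(iree_hal_task_driver_create(
      iree_make_cstring_view("dylib"), /*default_params=*/NULL, executor,
      /*loader_count=*/1, &loader, iree_allocator_system(), out_driver));

  // The driver retains what it needs; release the locals we still hold.
  iree_task_executor_release(executor);
  iree_task_topology_deinitialize(&topology);
  iree_hal_executable_loader_release(loader);
  return iree_ok_status();
}

Note that the real factory below passes an initialized iree_hal_task_device_params_t rather than NULL and selects the topology from the flags; this sketch only illustrates the ordering and ownership hand-off (executor, topology, and loader are all released once the driver holds its own references).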
+ +ABSL_FLAG(int, dylib_worker_count, 0, + "Specified number of workers to use or 0 for automatic."); +ABSL_FLAG(int, dylib_max_worker_count, 16, + "Maximum number of task system workers to use."); #define IREE_HAL_DYLIB_DRIVER_ID 0x58444C4Cu // XDLL static iree_status_t iree_hal_dylib_driver_factory_enumerate( void* self, const iree_hal_driver_info_t** out_driver_infos, iree_host_size_t* out_driver_info_count) { - static const iree_hal_driver_info_t driver_infos[1] = {{ - /*driver_id=*/IREE_HAL_DYLIB_DRIVER_ID, - /*driver_name=*/iree_make_cstring_view("dylib"), - /*full_name=*/iree_make_cstring_view("Dynamic library loader"), - }}; + static const iree_hal_driver_info_t driver_infos[1] = { + { + /*.driver_id=*/IREE_HAL_DYLIB_DRIVER_ID, + /*.driver_name=*/iree_make_cstring_view("dylib"), + /*.full_name=*/ + iree_make_cstring_view("AOT compiled dynamic libraries"), + }, + }; *out_driver_info_count = IREE_ARRAYSIZE(driver_infos); *out_driver_infos = driver_infos; return iree_ok_status(); @@ -42,9 +59,42 @@ static iree_status_t iree_hal_dylib_driver_factory_try_create( " is provided by this factory", driver_id); } - auto* driver = new iree::hal::dylib::DyLibDriver(); - *out_driver = reinterpret_cast(driver); - return iree_ok_status(); + + iree_hal_task_device_params_t default_params; + iree_hal_task_device_params_initialize(&default_params); + + iree_task_topology_t topology; + iree_task_topology_initialize(&topology); + if (absl::GetFlag(FLAGS_dylib_worker_count) > 0) { + iree_task_topology_initialize_from_group_count( + absl::GetFlag(FLAGS_dylib_worker_count), &topology); + } else { + iree_task_topology_initialize_from_unique_l2_cache_groups( + /*max_group_count=*/absl::GetFlag(FLAGS_dylib_max_worker_count), + &topology); + } + + iree_hal_executable_loader_t* dylib_loader = NULL; + iree_status_t status = + iree_hal_legacy_library_loader_create(allocator, &dylib_loader); + iree_hal_executable_loader_t* loaders[1] = {dylib_loader}; + + iree_task_executor_t* executor = NULL; + if (iree_status_is_ok(status)) { + status = iree_task_executor_create(IREE_TASK_SCHEDULING_MODE_RESERVED, + &topology, allocator, &executor); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_task_driver_create( + iree_make_cstring_view("dylib"), &default_params, executor, + IREE_ARRAYSIZE(loaders), loaders, allocator, out_driver); + } + + iree_task_executor_release(executor); + iree_task_topology_deinitialize(&topology); + iree_hal_executable_loader_release(dylib_loader); + return status; } IREE_API_EXPORT iree_status_t IREE_API_CALL diff --git a/iree/hal/event.c b/iree/hal/event.c new file mode 100644 index 0000000000000..c012439502361 --- /dev/null +++ b/iree/hal/event.c @@ -0,0 +1,36 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
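For context, a minimal sketch of how a client reaches the rewritten factory above. The registration entry point is the declaration this hunk truncates, and the iree_hal_driver_registry_* helpers come from iree/hal/api; those names are assumptions here, not shown in this diff:

  // Hypothetical caller: register the module once, then create by name.
  // The --dylib_worker_count / --dylib_max_worker_count flags above are
  // consulted inside iree_hal_dylib_driver_factory_try_create.
  static iree_status_t create_dylib_driver(iree_hal_driver_t** out_driver) {
    IREE_RETURN_IF_ERROR(iree_hal_dylib_driver_module_register(
        iree_hal_driver_registry_default()));
    return iree_hal_driver_registry_try_create_by_name(
        iree_hal_driver_registry_default(), iree_make_cstring_view("dylib"),
        iree_allocator_system(), out_driver);
  }
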
+ +#include "iree/hal/event.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" +#include "iree/hal/device.h" + +#define _VTABLE_DISPATCH(event, method_name) \ + IREE_HAL_VTABLE_DISPATCH(event, iree_hal_event, method_name) + +IREE_HAL_API_RETAIN_RELEASE(event); + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_event_create(iree_hal_device_t* device, iree_hal_event_t** out_event) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(out_event); + *out_event = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = IREE_HAL_VTABLE_DISPATCH( + device, iree_hal_device, create_event)(device, out_event); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/event.h b/iree/hal/event.h index c7786f435d7e2..ca2e2adf7b1f5 100644 --- a/iree/hal/event.h +++ b/iree/hal/event.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,21 +15,63 @@ #ifndef IREE_HAL_EVENT_H_ #define IREE_HAL_EVENT_H_ +#include +#include + +#include "iree/base/api.h" #include "iree/hal/resource.h" -namespace iree { -namespace hal { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_device_s iree_hal_device_t; + +//===----------------------------------------------------------------------===// +// iree_hal_event_t +//===----------------------------------------------------------------------===// -// Events are used for defining synchronization scopes within CommandBuffers. +// Events are used for defining synchronization scopes within command buffers. // An event only exists within a single CommandBuffer and must not be used -// across CommandBuffers from the same device or others. +// across command buffers from the same device or others. // -// See CommandBuffer::SignalEvent and CommandBuffer::WaitEvents for more info. -class Event : public Resource { - public: -}; +// See iree_hal_command_buffer_signal_event and +// iree_hal_command_buffer_wait_events for more info. +// +// Maps to VkEvent: +// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkEvent.html +typedef struct iree_hal_event_s iree_hal_event_t; + +// Creates an event for recording into command buffers. +// The returned event object is only usable with this device and events must +// only be used to synchronize within the same queue. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_event_create(iree_hal_device_t* device, iree_hal_event_t** out_event); + +// Retains the given |event| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_event_retain(iree_hal_event_t* event); + +// Releases the given |event| from the caller. 
+IREE_API_EXPORT void IREE_API_CALL +iree_hal_event_release(iree_hal_event_t* event); + +//===----------------------------------------------------------------------===// +// iree_hal_event_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE + + void(IREE_API_PTR* destroy)(iree_hal_event_t* event); +} iree_hal_event_vtable_t; + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_event_destroy(iree_hal_event_t* event); -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_EVENT_H_ diff --git a/iree/hal/executable_cache.cc b/iree/hal/executable.c similarity index 66% rename from iree/hal/executable_cache.cc rename to iree/hal/executable.c index 26ce40c851606..a27c675520073 100644 --- a/iree/hal/executable_cache.cc +++ b/iree/hal/executable.c @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,14 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "iree/hal/executable_cache.h" +#include "iree/hal/executable.h" -namespace iree { -namespace hal { +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" -ExecutableCache::ExecutableCache() = default; +#define _VTABLE_DISPATCH(executable, method_name) \ + IREE_HAL_VTABLE_DISPATCH(executable, iree_hal_executable, method_name) -ExecutableCache::~ExecutableCache() = default; - -} // namespace hal -} // namespace iree +IREE_HAL_API_RETAIN_RELEASE(executable); diff --git a/iree/hal/executable.h b/iree/hal/executable.h index d724d01657110..dc41d75d541de 100644 --- a/iree/hal/executable.h +++ b/iree/hal/executable.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,43 +15,61 @@ #ifndef IREE_HAL_EXECUTABLE_H_ #define IREE_HAL_EXECUTABLE_H_ +#include +#include + +#include "iree/base/api.h" #include "iree/hal/resource.h" -namespace iree { -namespace hal { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_device_s iree_hal_device_t; -class Executable : public Resource { - public: - ~Executable() override = default; +//===----------------------------------------------------------------------===// +// iree_hal_executable_t +//===----------------------------------------------------------------------===// - // True if the executable was prepared with debugging enabled and the device - // and input data support debugging (symbols present, etc). - virtual bool supports_debugging() const = 0; +// Handle to a loaded executable. +// Loading of executables routes through an executable cache, allowing for +// context-aware scoped caches. HAL implementations can use this to preserve +// JIT'ed executables across processes or reuse executables across device +// instances. +// +// Executables provide one or more entry points that can be dispatched via +// iree_hal_command_buffer_dispatch. Some entry points may represent the same +// computation but specialized in different ways such that the runtime can +// switch strategies and choose between them per-dispatch. +// +// +// Maps (roughly) to vkShaderModule + VkPipeline[]. 
+typedef struct iree_hal_executable_s iree_hal_executable_t; - // TODO(benvanik): disassembly methods. +// Retains the given |executable| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_executable_retain(iree_hal_executable_t* executable); - // TODO(benvanik): relative offset calculation: - // - step once - // - step over - // - step out +// Releases the given |executable| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_executable_release(iree_hal_executable_t* executable); - // TODO(benvanik): create executable split on breakpoint. - // Executable should return when the breakpoint is hit without any future - // modifications to output buffers. If the breakpoint is not hit the - // executable should run to completion as normal. +//===----------------------------------------------------------------------===// +// iree_hal_executable_t implementation details +//===----------------------------------------------------------------------===// - // TODO(benvanik): retrieve coverage info. - // Returns a buffer containing offset -> coverage metrics. Note that depending - // on the device this may only contain a single coverage metric for the entire - // executable or some subset of the available offsets. +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE - // TODO(benvanik): retrieve profiling info. + void(IREE_API_PTR* destroy)(iree_hal_executable_t* executable); +} iree_hal_executable_vtable_t; - protected: - Executable() = default; -}; +IREE_API_EXPORT void IREE_API_CALL +iree_hal_executable_destroy(iree_hal_executable_t* executable); -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_EXECUTABLE_H_ diff --git a/iree/hal/executable_cache.c b/iree/hal/executable_cache.c new file mode 100644 index 0000000000000..dc17375c9532b --- /dev/null +++ b/iree/hal/executable_cache.c @@ -0,0 +1,66 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
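The vtable above only requires destroy; reference counting comes from the shared resource header (see IREE_HAL_API_RETAIN_RELEASE in executable.c earlier in this diff). A sketch of the minimal plumbing a backend would need, assuming the iree_hal_resource_initialize helper from iree/hal/resource.h; my_executable_t and its host_allocator field are hypothetical:

  // The resource header must be the first field so the common retain/release
  // machinery can locate the vtable and reference count.
  typedef struct {
    iree_hal_resource_t resource;
    iree_allocator_t host_allocator;  // hypothetical: used to free self
  } my_executable_t;

  static void my_executable_destroy(iree_hal_executable_t* base_executable) {
    my_executable_t* executable = (my_executable_t*)base_executable;
    iree_allocator_free(executable->host_allocator, executable);
  }

  static const iree_hal_executable_vtable_t my_executable_vtable = {
      /*.destroy=*/my_executable_destroy,
  };
  // A create function would allocate the struct and then call
  // iree_hal_resource_initialize(&my_executable_vtable, &executable->resource)
  // before returning it as an iree_hal_executable_t*.
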
+ +#include "iree/hal/executable_cache.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" +#include "iree/hal/device.h" + +#define _VTABLE_DISPATCH(executable_cache, method_name) \ + IREE_HAL_VTABLE_DISPATCH(executable_cache, iree_hal_executable_cache, \ + method_name) + +IREE_HAL_API_RETAIN_RELEASE(executable_cache); + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_cache_create( + iree_hal_device_t* device, iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(out_executable_cache); + *out_executable_cache = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = IREE_HAL_VTABLE_DISPATCH( + device, iree_hal_device, create_executable_cache)(device, identifier, + out_executable_cache); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT bool IREE_API_CALL iree_hal_executable_cache_can_prepare_format( + iree_hal_executable_cache_t* executable_cache, + iree_hal_executable_format_t format) { + IREE_ASSERT_ARGUMENT(executable_cache); + return _VTABLE_DISPATCH(executable_cache, can_prepare_format)( + executable_cache, format); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_executable_cache_prepare_executable( + iree_hal_executable_cache_t* executable_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + IREE_ASSERT_ARGUMENT(executable_cache); + IREE_ASSERT_ARGUMENT(executable_layout); + IREE_ASSERT_ARGUMENT(out_executable); + *out_executable = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(executable_cache, prepare_executable)( + executable_cache, executable_layout, caching_mode, executable_data, + out_executable); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/executable_cache.h b/iree/hal/executable_cache.h index a08385509aef2..3bee8ce920ec7 100644 --- a/iree/hal/executable_cache.h +++ b/iree/hal/executable_cache.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,66 +15,83 @@ #ifndef IREE_HAL_EXECUTABLE_CACHE_H_ #define IREE_HAL_EXECUTABLE_CACHE_H_ -#include "iree/base/bitfield.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" +#include +#include + +#include "iree/base/api.h" #include "iree/hal/executable.h" -#include "iree/hal/executable_format.h" #include "iree/hal/executable_layout.h" -#include "iree/hal/executable_spec.h" +#include "iree/hal/resource.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_device_s iree_hal_device_t; -namespace iree { -namespace hal { +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// + +// TODO(benvanik): eliminate fourccs and just use strings. +// An identifier for executable formats used to query support. +typedef uint32_t iree_hal_executable_format_t; + +// Constructs an iree_hal_executable_format_t 4cc at compile-time. +static inline iree_hal_executable_format_t iree_hal_make_executable_format( + char const four_cc[5]) { + return (four_cc[0] << 24) | (four_cc[1] << 16) | (four_cc[2] << 8) | + four_cc[3]; +} // Defines how the executable cache performs preparation. 
-enum class ExecutableCachingMode : uint32_t { +enum iree_hal_executable_caching_mode_e { // Allows the cache to reference the provided executable_data after it has // prepared the executable. Callers must ensure the data remains valid for the // lifetime of the cache. If memory mapping constant executable data from // disk this can be used to avoid copies. - kAliasProvidedData = 1 << 0, - + IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA = 1u << 0, // Allows the prepared executable to be cached persistently (on disk/etc). // Enable for any executable that is likely to be used in future runs. // Note that not all caches support persistent serialization and this is just // a hint. - kAllowPersistentCaching = 1 << 1, - + IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_PERSISTENT_CACHING = 1u << 1, // Allows the cache to optimize the executable as much as it can. // This may cause preparation to take significantly longer while (hopefully) // improving runtime performance. Avoid for one-shot executables. - kAllowOptimization = 1 << 2, - + IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION = 1u << 2, // Enables Executable debugging methods if supported by the device and // executable. This may disable certain optimizations or retain additional // data to allow disassembly, stepping, etc. // - // Device must support the DeviceFeature::kDebugging feature and executables - // must support the ExecutableFeature::kDebugging feature. - kEnableDebugging = 1 << 3, - + // Device must support the IREE_HAL_DEVICE_FEATURE_SUPPORTS_DEBUGGING feature + // and executables must support the ExecutableFeature::kDebugging feature. + IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_DEBUGGING = 1u << 3, // Enables Executable coverage if supported by the device and executable. // Depending on the optimization mode this may produce partial coverage // results (for example, when certain source operations were optimized away). // - // Device must support the DeviceFeature::kCoverage feature and executables - // must support the ExecutableFeature::kCoverage feature. - kEnableCoverage = 1 << 4, - + // Device must support the IREE_HAL_DEVICE_FEATURE_SUPPORTS_COVERAGE feature + // and executables must support the ExecutableFeature::kCoverage feature. + IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_COVERAGE = 1u << 4, // Enables Executable profiling if supported by the device and executable. // Depending on the optimization mode this may produce partial profiling // results. Profiling attribution (whether to the entire executable or // specific operations) depends on the implementation. // - // Device must support the DeviceFeature::kProfiling feature and executables - // must support the ExecutableFeature::kProfiling feature. - kEnableProfiling = 1 << 5, - + // Device must support the IREE_HAL_DEVICE_FEATURE_SUPPORTS_PROFILING feature + // and executables must support the ExecutableFeature::kProfiling feature. + IREE_HAL_EXECUTABLE_CACHING_MODE_ENABLE_PROFILING = 1u << 5, // Default caching mode. 
-  kDefault = kAllowPersistentCaching | kAllowOptimization,
+  IREE_HAL_EXECUTABLE_CACHING_MODE_DEFAULT =
+      IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_PERSISTENT_CACHING |
+      IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION,
 };
-IREE_BITFIELD(ExecutableCachingMode);
-using ExecutableCachingModeBitfield = ExecutableCachingMode;
+typedef uint32_t iree_hal_executable_caching_mode_t;
+
+//===----------------------------------------------------------------------===//
+// iree_hal_executable_cache_t
+//===----------------------------------------------------------------------===//
 
 // A cache of prepared executables for a particular device.
 // Caches may be shared across multiple devices from the same driver or specific
@@ -84,45 +101,80 @@ using ExecutableCachingModeBitfield = ExecutableCachingMode;
 //
 // The term 'cache' here is rather optimistic - it's perfectly acceptable for
 // implementations to not cache at all and return new Executables for each
-// PrepareExecutable called (even for the same executable). Callers should
-// expect such behavior and try to retain the results of the PrepareExecutable
-// calls to reduce overhead in re-preparing executables.
+// iree_hal_executable_cache_prepare_executable called (even for the same
+// executable). Callers should expect such behavior and try to retain the
+// results of the iree_hal_executable_cache_prepare_executable calls to reduce
+// overhead in re-preparing executables.
 //
 // Thread-safe - multiple threads may prepare executables (including the *same*
 // executable) simultaneously.
-class ExecutableCache : public RefObject<ExecutableCache> {
- public:
-  virtual ~ExecutableCache();
-
-  // TODO(benvanik): status/queries (size, etc).
-
-  // TODO(b/137153339): serialization/deserialization.
-
-  // Returns true if the executable cache can prepare the given executable input
-  // format. Preparation may still fail if the particular version or features
-  // required by the executable are not supported.
-  virtual bool CanPrepareFormat(ExecutableFormat format) const = 0;
-
-  // Prepares an executable for use.
-  // The provided |spec| and |executable_data| will be used to either lookup a
-  // previously prepared executable in the cache or prepare a new one.
-  //
-  // Depending on the driver preparation may take a non-trivial amount of time
-  // (such as when JITing/etc). As the cache is internally synchronized callers
-  // can issue preparation requests from multiple threads - even for the same
-  // executables - and calls will block until preparation completes.
-  //
-  // When preparing a large number of executables it's recommended to use the
-  // PrepareExecutables method to batch and wait on the results.
-  virtual StatusOr<ref_ptr<Executable>> PrepareExecutable(
-      ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode,
-      const ExecutableSpec& spec) = 0;
-
- protected:
-  ExecutableCache();
-};
-
-}  // namespace hal
-}  // namespace iree
+typedef struct iree_hal_executable_cache_s iree_hal_executable_cache_t;
+
+// Creates an executable cache using the given identifier.
+// The identifier is provided to the backing cache API as a way to partition
+// caches between different groups of executables (from different modules, etc).
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_cache_create(
+    iree_hal_device_t* device, iree_string_view_t identifier,
+    iree_hal_executable_cache_t** out_executable_cache);
+
+// Retains the given |executable_cache| for the caller.
+IREE_API_EXPORT void IREE_API_CALL +iree_hal_executable_cache_retain(iree_hal_executable_cache_t* executable_cache); + +// Releases the given |executable_cache| from the caller. +IREE_API_EXPORT void IREE_API_CALL iree_hal_executable_cache_release( + iree_hal_executable_cache_t* executable_cache); + +// Returns true if the executable cache can prepare the given executable input +// format. Preparation may still fail if the particular version or features +// required by the executable are not supported. +IREE_API_EXPORT bool IREE_API_CALL iree_hal_executable_cache_can_prepare_format( + iree_hal_executable_cache_t* executable_cache, + iree_hal_executable_format_t format); + +// Prepares an executable for use. +// The provided |executable_data| will be used to either lookup a previously +// prepared executable in the cache or prepare a new one. +// +// Depending on the driver preparation may take a non-trivial amount of time +// (such as when JITing/etc). As the cache is internally synchronized callers +// can issue preparation requests from multiple threads - even for the same +// executables - and calls will block until preparation completes. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_executable_cache_prepare_executable( + iree_hal_executable_cache_t* executable_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable); + +//===----------------------------------------------------------------------===// +// iree_hal_executable_cache_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE + + void(IREE_API_PTR* destroy)(iree_hal_executable_cache_t* executable_cache); + + bool(IREE_API_PTR* can_prepare_format)( + iree_hal_executable_cache_t* executable_cache, + iree_hal_executable_format_t format); + + iree_status_t(IREE_API_PTR* prepare_executable)( + iree_hal_executable_cache_t* executable_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable); +} iree_hal_executable_cache_vtable_t; + +IREE_API_EXPORT void IREE_API_CALL iree_hal_executable_cache_destroy( + iree_hal_executable_cache_t* executable_cache); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_EXECUTABLE_CACHE_H_ diff --git a/iree/hal/executable_format.h b/iree/hal/executable_format.h deleted file mode 100644 index cb536587446d5..0000000000000 --- a/iree/hal/executable_format.h +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Defines the ExecutableFormat 4cc type and a few well-known formats. 
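Taken together, the new entry points replace the C++ ExecutableCache flow; a minimal sketch, assuming |device|, |executable_layout|, and |executable_data| are already in hand (prepare_dylib_executable is a hypothetical helper, not part of this diff):

  static iree_status_t prepare_dylib_executable(
      iree_hal_device_t* device,
      iree_hal_executable_layout_t* executable_layout,
      iree_const_byte_span_t executable_data,
      iree_hal_executable_t** out_executable) {
    iree_hal_executable_cache_t* cache = NULL;
    IREE_RETURN_IF_ERROR(iree_hal_executable_cache_create(
        device, iree_make_cstring_view("default"), &cache));
    // When multiple formats may be present, consult
    // iree_hal_executable_cache_can_prepare_format() first.
    iree_status_t status = iree_hal_executable_cache_prepare_executable(
        cache, executable_layout, IREE_HAL_EXECUTABLE_CACHING_MODE_DEFAULT,
        executable_data, out_executable);
    iree_hal_executable_cache_release(cache);
    return status;
  }
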
-// Not all formats need to be defined here, however any format expected to be -// supported by debuggers/tooling will likely want to be here to ensure easier -// referencing. - -#ifndef IREE_HAL_EXECUTABLE_FORMAT_H_ -#define IREE_HAL_EXECUTABLE_FORMAT_H_ - -#include - -namespace iree { -namespace hal { - -// Executable format 4cc identifier. -using ExecutableFormat = uint32_t; - -// Constructs an ExecutableFormat 4cc at compile-time. -constexpr ExecutableFormat MakeExecutableFormatID(char const four_cc[5]) { - return (four_cc[0] << 24) | (four_cc[1] << 16) | (four_cc[2] << 8) | - four_cc[3]; -} - -// Keep these in sync with iree/compiler/Dialect/HAL/IR/HALBase.td - -// Undefined (or unknown). The format may be derived from the executable -// contents (such as file magic bytes). -constexpr ExecutableFormat kExecutableFormatUnspecified = - MakeExecutableFormatID(" "); - -// MLIR text form. -constexpr ExecutableFormat kExecutableFormatMlir = - MakeExecutableFormatID("MLIR"); - -// IREE v0 bytecode. -constexpr ExecutableFormat kExecutableFormatIreeBytecode = - MakeExecutableFormatID("IREE"); - -// IREE VMLA executable in FlatBuffer format using the -// iree/schemas/vmla_executable_def.fbs schema. -constexpr ExecutableFormat kExecutableFormatVMLA = - MakeExecutableFormatID("VMLA"); - -// SPIR-V executable in FlatBuffer format using the -// iree/schemas/spirv_executable_def.fbs schema. -constexpr ExecutableFormat kExecutableFormatSpirV = - MakeExecutableFormatID("SPVE"); - -// Metal executable in FlatBuffer format using the -// iree/schemas/metal_executable_def.fbs schema. -constexpr ExecutableFormat kExecutableFormatMetal = - MakeExecutableFormatID("MTLE"); - -// Dynamic Library (dylib) executable in FlatBuffer format using the -// iree/schemas/dylib_executable_def.fbs schema -constexpr ExecutableFormat kExecutableFormatDyLib = - MakeExecutableFormatID("DLIB"); - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_EXECUTABLE_FORMAT_H_ diff --git a/iree/hal/executable_layout.c b/iree/hal/executable_layout.c new file mode 100644 index 0000000000000..844fa7e5507c3 --- /dev/null +++ b/iree/hal/executable_layout.c @@ -0,0 +1,43 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
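Note that the deleted constexpr helper and the new static-inline iree_hal_make_executable_format in executable_cache.h compute the same packing, so the well-known 4cc values above carry over unchanged:

  // 'D'<<24 | 'L'<<16 | 'I'<<8 | 'B' packs to 0x444C4942u.
  iree_hal_executable_format_t dylib_format =
      iree_hal_make_executable_format("DLIB");
  // The IREE_HAL_DYLIB_DRIVER_ID earlier in this diff uses the same scheme
  // applied by hand: "XDLL" == 0x58444C4Cu.
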
+ +#include "iree/hal/executable_layout.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" +#include "iree/hal/device.h" + +#define _VTABLE_DISPATCH(executable_layout, method_name) \ + IREE_HAL_VTABLE_DISPATCH(executable_layout, iree_hal_executable_layout, \ + method_name) + +IREE_HAL_API_RETAIN_RELEASE(executable_layout); + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_layout_create( + iree_hal_device_t* device, iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constants, + iree_hal_executable_layout_t** out_executable_layout) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts); + IREE_ASSERT_ARGUMENT(out_executable_layout); + *out_executable_layout = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, + create_executable_layout)( + device, set_layout_count, set_layouts, push_constants, + out_executable_layout); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/executable_layout.h b/iree/hal/executable_layout.h index 2ce959d52f08d..8ba573bed0cee 100644 --- a/iree/hal/executable_layout.h +++ b/iree/hal/executable_layout.h @@ -12,15 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "iree/hal/resource.h" - #ifndef IREE_HAL_EXECUTABLE_LAYOUT_H_ #define IREE_HAL_EXECUTABLE_LAYOUT_H_ -namespace iree { -namespace hal { +#include +#include + +#include "iree/base/api.h" +#include "iree/hal/descriptor_set_layout.h" +#include "iree/hal/resource.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_device_s iree_hal_device_t; + +//===----------------------------------------------------------------------===// +// iree_hal_executable_layout_t +//===----------------------------------------------------------------------===// // Defines the resource binding layout used by an executable. +// A "descriptor" is effectively a bound memory range and each dispatch can use +// one or more "descriptor sets" to access their I/O memory. A "descriptor set +// layout" defines the types and usage semantics of the descriptors that make up +// one set. An "executable layout" defines all of the set layouts that will be +// used when dispatching. Implementations can use this to verify program +// correctness and accelerate reservation/allocatation/computation of +// descriptor-related operations. // // Executables can share the same layout even if they do not use all of the // resources referenced by descriptor sets referenced by the layout. Doing so @@ -29,11 +48,41 @@ namespace hal { // // Maps to VkPipelineLayout: // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPipelineLayout.html -class ExecutableLayout : public Resource { - public: -}; +typedef struct iree_hal_executable_layout_s iree_hal_executable_layout_t; + +// Creates an executable layout composed of the given descriptor set layouts. +// The returned executable layout can be used by multiple executables with the +// same compatible resource binding layouts. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_executable_layout_create( + iree_hal_device_t* device, iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constants, + iree_hal_executable_layout_t** out_executable_layout); + +// Retains the given |executable_layout| for the caller. 
+IREE_API_EXPORT void IREE_API_CALL iree_hal_executable_layout_retain( + iree_hal_executable_layout_t* executable_layout); + +// Releases the given |executable_layout| from the caller. +IREE_API_EXPORT void IREE_API_CALL iree_hal_executable_layout_release( + iree_hal_executable_layout_t* executable_layout); + +//===----------------------------------------------------------------------===// +// iree_hal_executable_layout_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct { + // << HAL C porting in progress >> + IREE_API_UNSTABLE + + void(IREE_API_PTR* destroy)(iree_hal_executable_layout_t* executable_layout); +} iree_hal_executable_layout_vtable_t; + +IREE_API_EXPORT void IREE_API_CALL iree_hal_executable_layout_destroy( + iree_hal_executable_layout_t* executable_layout); -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_EXECUTABLE_LAYOUT_H_ diff --git a/iree/hal/executable_spec.h b/iree/hal/executable_spec.h deleted file mode 100644 index c9486d33fe3ea..0000000000000 --- a/iree/hal/executable_spec.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_EXECUTABLE_SPEC_H_ -#define IREE_HAL_EXECUTABLE_SPEC_H_ - -#include "absl/types/span.h" -#include "iree/hal/executable_format.h" - -namespace iree { -namespace hal { - -// Defines an executable specification used by a cache to prepare an executable. -struct ExecutableSpec { - // TODO(benvanik): pre-populated hash_code/key to avoid calculation. - - // A reference to the executable data as input to the cache. - // If ExecutableCachingMode::kAliasProvidedData is set then this reference - // may be retained by the cache and the backing buffer must be kept valid for - // the lifetime of the cache. - absl::Span executable_data; - - // TODO(benvanik): add specialization info (constants/defines). - // TODO(benvanik): add compiler flags? could treat as opaque. -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_EXECUTABLE_SPEC_H_ diff --git a/iree/hal/heap_buffer.cc b/iree/hal/heap_buffer.cc deleted file mode 100644 index fae11346a9d2b..0000000000000 --- a/iree/hal/heap_buffer.cc +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
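A short usage sketch of the new layout API above for reference; make_layout is a hypothetical helper, and |set_layout| is assumed to come from the descriptor_set_layout.h API outside this diff:

  static iree_status_t make_layout(
      iree_hal_device_t* device, iree_hal_descriptor_set_layout_t* set_layout,
      iree_hal_executable_layout_t** out_layout) {
    iree_hal_descriptor_set_layout_t* set_layouts[1] = {set_layout};
    return iree_hal_executable_layout_create(
        device, IREE_ARRAYSIZE(set_layouts), set_layouts,
        /*push_constants=*/0, out_layout);
  }
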
- -#include "iree/hal/heap_buffer.h" - -#include -#include -#include -#include - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/allocator.h" -#include "iree/hal/host/host_buffer.h" - -namespace iree { -namespace hal { - -namespace { - -// An allocator that allocates or wraps host-only buffers. -// The resulting buffers are not usable by most devices without a copy and -// using a device allocator is strongly preferred. -class HeapAllocator : public Allocator { - public: - // Returns a singleton heap allocator that can provide buffers that have - // MemoryType::kHostLocal and are allocated with malloc/free. - // These buffers will not be usable by devices directly and may incur - // additional copies. - static Allocator* std_heap(); - - // TODO(benvanik): specify custom allocator (not malloc/free). - HeapAllocator(); - ~HeapAllocator() override; - - bool CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const override; - - bool CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const override; - - StatusOr> Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) override; - - StatusOr> WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, - void* data, - size_t data_length) override; -}; - -// static -Allocator* HeapAllocator::std_heap() { - static Allocator* std_heap_allocator = new HeapAllocator(); - return std_heap_allocator; -} - -HeapAllocator::HeapAllocator() = default; - -HeapAllocator::~HeapAllocator() = default; - -bool HeapAllocator::CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const { - // The host can use anything with kHostVisible. - if (!AnyBitSet(memory_type & MemoryType::kHostVisible)) { - return false; - } - - // Host currently uses mapping to copy buffers, which is done a lot. - if (!AnyBitSet(buffer_usage & BufferUsage::kMapping)) { - return false; - } - - return true; -} - -bool HeapAllocator::CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const { - // This host only allocator cannot serve device visible allocation as we - // can't know which devices these buffers will be used with. 
- return (memory_type & MemoryType::kHostLocal) == MemoryType::kHostLocal && - !AnyBitSet(memory_type & MemoryType::kDeviceLocal) && - !AnyBitSet(memory_type & MemoryType::kDeviceVisible); -} - -StatusOr> HeapAllocator::Allocate( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - size_t allocation_size) { - IREE_TRACE_SCOPE0("HeapAllocator::Allocate"); - - if (!CanAllocate(memory_type, buffer_usage, allocation_size)) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Allocation not supported; memory_type=" - << MemoryTypeString(memory_type) - << ", buffer_usage=" << BufferUsageString(buffer_usage) - << ", allocation_size=" << allocation_size; - } - - void* malloced_data = std::calloc(1, allocation_size); - if (!malloced_data) { - return ResourceExhaustedErrorBuilder(IREE_LOC) - << "Failed to malloc " << allocation_size << " bytes"; - } - - auto buffer = - make_ref(this, memory_type, MemoryAccess::kAll, buffer_usage, - allocation_size, malloced_data, true); - return buffer; -} - -StatusOr> HeapAllocator::WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, void* data, size_t data_length) { - auto buffer = make_ref(this, memory_type, allowed_access, - buffer_usage, data_length, data, false); - return buffer; -} - -} // namespace - -// static -ref_ptr HeapBuffer::Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, - size_t allocation_size) { - auto buffer_or = - HeapAllocator::std_heap()->Allocate(memory_type, usage, allocation_size); - return std::move(buffer_or.value()); -} - -// static -ref_ptr HeapBuffer::AllocateCopy(BufferUsageBitfield usage, - const void* data, size_t data_length) { - return AllocateCopy(usage, MemoryAccess::kAll, data, data_length); -} - -// static -ref_ptr HeapBuffer::AllocateCopy(BufferUsageBitfield usage, - MemoryAccessBitfield allowed_access, - const void* data, size_t data_length) { - IREE_TRACE_SCOPE0("HeapBuffer::AllocateCopy"); - // Ensure we can map so that we can copy into it. - usage |= BufferUsage::kMapping; - auto buffer_or = HeapAllocator::std_heap()->Allocate(MemoryType::kHostLocal, - usage, data_length); - auto buffer = std::move(buffer_or.value()); - buffer->WriteData(0, data, data_length).IgnoreError(); - buffer->set_allowed_access(allowed_access); - return buffer; -} - -// static -ref_ptr HeapBuffer::Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, const void* data, - size_t data_length) { - auto buffer_or = - HeapAllocator::std_heap()->Wrap(memory_type, usage, data, data_length); - return std::move(buffer_or.value()); -} - -// static -ref_ptr HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, void* data, - size_t data_length) { - auto buffer_or = HeapAllocator::std_heap()->WrapMutable( - memory_type, allowed_access, usage, data, data_length); - return std::move(buffer_or.value()); -} - -} // namespace hal -} // namespace iree diff --git a/iree/hal/heap_buffer.h b/iree/hal/heap_buffer.h deleted file mode 100644 index eba9b72eb59fc..0000000000000 --- a/iree/hal/heap_buffer.h +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HEAP_BUFFER_H_ -#define IREE_HAL_HEAP_BUFFER_H_ - -#include - -#include "iree/base/status.h" -#include "iree/hal/buffer.h" - -namespace iree { -namespace hal { - -// Factory for buffers that are allocated from the host heap (malloc/free). -// These buffers cannot be used by devices and will incur copies/transfers when -// used. Prefer device-specific allocators instead. -class HeapBuffer { - public: - // Allocates a zeroed host heap buffer of the given size. - // Returns a buffer allocated with malloc and have MemoryType::kHostLocal - // and will not be usable by devices without copies. - static ref_ptr Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, - size_t allocation_size); - static ref_ptr Allocate(BufferUsageBitfield usage, - size_t allocation_size) { - return Allocate(MemoryType::kHostLocal, usage, allocation_size); - } - - // Allocates a host heap buffer with a copy of the given data. - // Returns a buffer allocated with malloc and have MemoryType::kHostLocal - // and will not be usable by devices without copies. - static ref_ptr AllocateCopy(BufferUsageBitfield usage, - const void* data, size_t data_length); - static ref_ptr AllocateCopy(BufferUsageBitfield usage, - MemoryAccessBitfield allowed_access, - const void* data, size_t data_length); - template - static ref_ptr AllocateCopy(BufferUsageBitfield usage, - absl::Span data); - template - static ref_ptr AllocateCopy(BufferUsageBitfield usage, - MemoryAccessBitfield allowed_access, - absl::Span data); - - // Wraps an existing host heap allocation in a buffer. - // Ownership of the host allocation remains with the caller and the memory - // must remain valid for so long as the Buffer may be in use. - // Will have MemoryType::kHostLocal in most cases and may not be usable - // by the device. 
- static ref_ptr Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, const void* data, - size_t data_length); - static ref_ptr WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, void* data, - size_t data_length); - template - static ref_ptr Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, - absl::Span data); - template - static ref_ptr WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, - absl::Span data); -}; - -// Inline functions and template definitions follow: - -template -ref_ptr HeapBuffer::AllocateCopy(BufferUsageBitfield usage, - absl::Span data) { - return HeapBuffer::AllocateCopy(usage, MemoryAccess::kAll, data); -} - -template -ref_ptr HeapBuffer::AllocateCopy(BufferUsageBitfield usage, - MemoryAccessBitfield allowed_access, - absl::Span data) { - return HeapBuffer::AllocateCopy(usage, allowed_access, data.data(), - data.size() * sizeof(T)); -} - -template -ref_ptr HeapBuffer::Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, - absl::Span data) { - return HeapBuffer::Wrap(memory_type, usage, data.data(), - data.size() * sizeof(T)); -} - -template -ref_ptr HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, - absl::Span data) { - return HeapBuffer::WrapMutable(memory_type, allowed_access, usage, - data.data(), data.size() * sizeof(T)); -} - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HEAP_BUFFER_H_ diff --git a/iree/hal/host/BUILD b/iree/hal/host/BUILD deleted file mode 100644 index a968eed67676f..0000000000000 --- a/iree/hal/host/BUILD +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Default implementations for HAL types that use the host resources. -# These are generally just wrappers around host heap memory and host threads. 
- -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "condvar_semaphore", - srcs = ["condvar_semaphore.cc"], - hdrs = ["condvar_semaphore.h"], - deps = [ - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "condvar_semaphore_test", - srcs = ["condvar_semaphore_test.cc"], - deps = [ - ":condvar_semaphore", - "//iree/base:api", - "//iree/base:status", - "//iree/testing:gtest", - "//iree/testing:gtest_main", - ], -) - -cc_library( - name = "host_buffer", - srcs = ["host_buffer.cc"], - hdrs = ["host_buffer.h"], - deps = [ - "//iree/base:logging", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - ], -) - -cc_library( - name = "host_descriptor_set", - srcs = ["host_descriptor_set.cc"], - hdrs = ["host_descriptor_set.h"], - deps = [ - "//iree/hal", - "@com_google_absl//absl/container:inlined_vector", - ], -) - -cc_library( - name = "host_executable", - hdrs = ["host_executable.h"], - deps = [ - "//iree/base:status", - "//iree/hal", - ], -) - -cc_library( - name = "host_executable_layout", - srcs = ["host_executable_layout.cc"], - hdrs = ["host_executable_layout.h"], - deps = [ - "//iree/base:core_headers", - "//iree/hal", - "@com_google_absl//absl/container:inlined_vector", - ], -) - -cc_library( - name = "host_local_allocator", - srcs = ["host_local_allocator.cc"], - hdrs = ["host_local_allocator.h"], - deps = [ - ":host_buffer", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - ], -) - -cc_library( - name = "host_local_device", - srcs = ["host_local_device.cc"], - hdrs = ["host_local_device.h"], - deps = [ - ":host_descriptor_set", - ":host_executable_layout", - ":host_local_allocator", - ":scheduling_model", - "//iree/base:core_headers", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - "//iree/hal:command_buffer_validation", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "inproc_command_buffer", - srcs = ["inproc_command_buffer.cc"], - hdrs = ["inproc_command_buffer.h"], - deps = [ - "//iree/base:arena", - "//iree/base:intrusive_list", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - ], -) - -cc_library( - name = "nop_event", - srcs = ["nop_event.cc"], - hdrs = ["nop_event.h"], - deps = [ - "//iree/hal", - ], -) - -cc_library( - name = "scheduling_model", - hdrs = ["scheduling_model.h"], - deps = [ - "//iree/hal", - ], -) diff --git a/iree/hal/host/CMakeLists.txt b/iree/hal/host/CMakeLists.txt deleted file mode 100644 index 5c926c6007882..0000000000000 --- a/iree/hal/host/CMakeLists.txt +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -iree_add_all_subdirs() - -iree_cc_library( - NAME - condvar_semaphore - HDRS - "condvar_semaphore.h" - SRCS - "condvar_semaphore.cc" - DEPS - absl::core_headers - absl::inlined_vector - absl::span - absl::synchronization - iree::base::status - iree::base::tracing - iree::hal - PUBLIC -) - -iree_cc_test( - NAME - condvar_semaphore_test - SRCS - "condvar_semaphore_test.cc" - DEPS - ::condvar_semaphore - iree::base::api - iree::base::status - iree::testing::gtest - iree::testing::gtest_main -) - -iree_cc_library( - NAME - host_buffer - HDRS - "host_buffer.h" - SRCS - "host_buffer.cc" - DEPS - iree::base::logging - iree::base::status - iree::base::tracing - iree::hal - PUBLIC -) - -iree_cc_library( - NAME - host_descriptor_set - HDRS - "host_descriptor_set.h" - SRCS - "host_descriptor_set.cc" - DEPS - absl::inlined_vector - iree::hal - PUBLIC -) - -iree_cc_library( - NAME - host_executable - HDRS - "host_executable.h" - DEPS - iree::base::status - iree::hal - PUBLIC -) - -iree_cc_library( - NAME - host_executable_layout - HDRS - "host_executable_layout.h" - SRCS - "host_executable_layout.cc" - DEPS - absl::inlined_vector - iree::base::core_headers - iree::hal - PUBLIC -) - -iree_cc_library( - NAME - host_local_allocator - HDRS - "host_local_allocator.h" - SRCS - "host_local_allocator.cc" - DEPS - ::host_buffer - iree::base::status - iree::base::tracing - iree::hal - PUBLIC -) - -iree_cc_library( - NAME - host_local_device - HDRS - "host_local_device.h" - SRCS - "host_local_device.cc" - DEPS - ::host_descriptor_set - ::host_executable_layout - ::host_local_allocator - ::scheduling_model - absl::core_headers - absl::memory - absl::span - iree::base::core_headers - iree::base::status - iree::base::tracing - iree::hal - iree::hal::command_buffer_validation - PUBLIC -) - -iree_cc_library( - NAME - inproc_command_buffer - HDRS - "inproc_command_buffer.h" - SRCS - "inproc_command_buffer.cc" - DEPS - iree::base::arena - iree::base::intrusive_list - iree::base::status - iree::base::tracing - iree::hal - PUBLIC -) - -iree_cc_library( - NAME - nop_event - HDRS - "nop_event.h" - SRCS - "nop_event.cc" - DEPS - iree::hal - PUBLIC -) - -iree_cc_library( - NAME - scheduling_model - HDRS - "scheduling_model.h" - DEPS - iree::hal - PUBLIC -) diff --git a/iree/hal/host/condvar_semaphore.cc b/iree/hal/host/condvar_semaphore.cc deleted file mode 100644 index dafe86c0ef923..0000000000000 --- a/iree/hal/host/condvar_semaphore.cc +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "iree/hal/host/condvar_semaphore.h" - -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { -namespace host { - -CondVarSemaphore::CondVarSemaphore(uint64_t initial_value) - : value_(initial_value) {} - -CondVarSemaphore::~CondVarSemaphore() = default; - -StatusOr CondVarSemaphore::Query() { - absl::MutexLock lock(&mutex_); - if (!status_.ok()) { - return status_; - } - return value_.load(std::memory_order_acquire); -} - -Status CondVarSemaphore::Signal(uint64_t value) { - absl::MutexLock lock(&mutex_); - if (!status_.ok()) { - return status_; - } - if (value_.exchange(value) >= value) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Semaphore values must be monotonically increasing"; - } - return OkStatus(); -} - -void CondVarSemaphore::Fail(Status status) { - absl::MutexLock lock(&mutex_); - status_ = std::move(status); - value_.store(UINT64_MAX, std::memory_order_release); -} - -// static -Status CondVarSemaphore::WaitForSemaphores( - absl::Span semaphores, bool wait_all, - Time deadline_ns) { - IREE_TRACE_SCOPE0("CondVarSemaphore::WaitForSemaphores"); - - // Some of the semaphores may already be signaled; we only need to wait for - // those that are not yet at the expected value. - using CondVarSemaphoreValue = std::pair; - absl::InlinedVector waitable_semaphores; - waitable_semaphores.reserve(semaphores.size()); - for (auto& semaphore_value : semaphores) { - auto* semaphore = - reinterpret_cast(semaphore_value.semaphore); - IREE_ASSIGN_OR_RETURN(uint64_t current_value, semaphore->Query()); - if (current_value < semaphore_value.value) { - // Semaphore has not yet hit the required value; wait for it. - waitable_semaphores.push_back({semaphore, semaphore_value.value}); - } - } - - // TODO(benvanik): maybe sort semaphores by value in case we are waiting on - // multiple values from the same semaphore. - - // Loop over the semaphores and wait for them to complete. - // TODO(b/140026716): add WaitHandle support for !wait_all (wait any). - for (auto& semaphore_value : waitable_semaphores) { - auto* semaphore = semaphore_value.first; - absl::MutexLock lock(&semaphore->mutex_); - if (!semaphore->mutex_.AwaitWithDeadline( - absl::Condition( - +[](CondVarSemaphoreValue* semaphore_value) { - return semaphore_value->first->value_.load( - std::memory_order_acquire) >= - semaphore_value->second; - }, - &semaphore_value), - absl::FromUnixNanos(static_cast(deadline_ns)))) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for semaphores"; - } - if (!semaphore->status_.ok()) { - return semaphore->status_; - } - } - - return OkStatus(); -} - -Status CondVarSemaphore::Wait(uint64_t value, Time deadline_ns) { - return WaitForSemaphores({{this, value}}, /*wait_all=*/true, deadline_ns); -} - -} // namespace host -} // namespace hal -} // namespace iree diff --git a/iree/hal/host/condvar_semaphore.h b/iree/hal/host/condvar_semaphore.h deleted file mode 100644 index 44dd541f8d570..0000000000000 --- a/iree/hal/host/condvar_semaphore.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_CONDVAR_SEMAPHORE_H_ -#define IREE_HAL_HOST_CONDVAR_SEMAPHORE_H_ - -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/hal/semaphore.h" - -namespace iree { -namespace hal { -namespace host { - -// Simple host-only semaphore semaphore implemented with a mutex. -// Uses a condition variable to track the current value. -// -// Thread-safe (as instances may be imported and used by others). -class CondVarSemaphore final : public Semaphore { - public: - // Waits for one or more (or all) semaphores to reach or exceed the given - // values. - static Status WaitForSemaphores(absl::Span semaphores, - bool wait_all, Time deadline_ns); - - explicit CondVarSemaphore(uint64_t initial_value); - ~CondVarSemaphore() override; - - StatusOr Query() override; - - Status Signal(uint64_t value) override; - void Fail(Status status) override; - Status Wait(uint64_t value, Time deadline_ns) override; - - private: - // The mutex is not required to query the value; this lets us quickly check if - // a required value has been exceeded. The mutex is only used to update and - // notify waiters. - std::atomic value_{0}; - - // We have a full mutex here so that we can perform condvar waits on value - // changes. - mutable absl::Mutex mutex_; - Status status_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace host -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_CONDVAR_SEMAPHORE_H_ diff --git a/iree/hal/host/condvar_semaphore_test.cc b/iree/hal/host/condvar_semaphore_test.cc deleted file mode 100644 index 00cbbd42704dc..0000000000000 --- a/iree/hal/host/condvar_semaphore_test.cc +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/condvar_semaphore.h" - -#include -#include // NOLINT - -#include "iree/base/status.h" -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" - -namespace iree { -namespace hal { -namespace host { -namespace { - -// Tests that a semaphore that is unused properly cleans itself up. -TEST(CondVarSemaphoreTest, NoOp) { - CondVarSemaphore semaphore(123u); - IREE_ASSERT_OK_AND_ASSIGN(uint64_t value, semaphore.Query()); - EXPECT_EQ(123u, value); -} - -// Tests that a semaphore will accept new values as it is signaled. 
diff --git a/iree/hal/host/condvar_semaphore_test.cc b/iree/hal/host/condvar_semaphore_test.cc
deleted file mode 100644
index 00cbbd42704dc..0000000000000
--- a/iree/hal/host/condvar_semaphore_test.cc
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/condvar_semaphore.h"
-
-#include <cstdint>
-#include <thread>  // NOLINT
-
-#include "iree/base/status.h"
-#include "iree/testing/gtest.h"
-#include "iree/testing/status_matchers.h"
-
-namespace iree {
-namespace hal {
-namespace host {
-namespace {
-
-// Tests that a semaphore that is unused properly cleans itself up.
-TEST(CondVarSemaphoreTest, NoOp) {
-  CondVarSemaphore semaphore(123u);
-  IREE_ASSERT_OK_AND_ASSIGN(uint64_t value, semaphore.Query());
-  EXPECT_EQ(123u, value);
-}
-
-// Tests that a semaphore will accept new values as it is signaled.
-TEST(CondVarSemaphoreTest, NormalSignaling) {
-  CondVarSemaphore semaphore(2u);
-  EXPECT_EQ(2u, semaphore.Query().value());
-  IREE_EXPECT_OK(semaphore.Signal(3u));
-  EXPECT_EQ(3u, semaphore.Query().value());
-  IREE_EXPECT_OK(semaphore.Signal(40u));
-  EXPECT_EQ(40u, semaphore.Query().value());
-}
-
-// Tests that a semaphore will fail to set non-increasing values.
-TEST(CondVarSemaphoreTest, RequireIncreasingValues) {
-  CondVarSemaphore semaphore(2u);
-  EXPECT_EQ(2u, semaphore.Query().value());
-  // Same value.
-  EXPECT_TRUE(IsInvalidArgument(semaphore.Signal(2u)));
-  // Decreasing.
-  EXPECT_TRUE(IsInvalidArgument(semaphore.Signal(1u)));
-}
-
-// Tests that a semaphore that has failed will remain in a failed state.
-TEST(CondVarSemaphoreTest, StickyFailure) {
-  CondVarSemaphore semaphore(2u);
-  // Signal to 3.
-  IREE_EXPECT_OK(semaphore.Signal(3u));
-  EXPECT_EQ(3u, semaphore.Query().value());
-
-  // Fail now.
-  semaphore.Fail(UnknownErrorBuilder(IREE_LOC));
-  EXPECT_TRUE(IsUnknown(semaphore.Query().status()));
-
-  // Unable to signal again (it'll return the sticky failure).
-  EXPECT_TRUE(IsUnknown(semaphore.Signal(4u)));
-  EXPECT_TRUE(IsUnknown(semaphore.Query().status()));
-}
-
-// Tests waiting on no semaphores.
-TEST(CondVarSemaphoreTest, EmptyWait) {
-  IREE_EXPECT_OK(CondVarSemaphore::WaitForSemaphores({}, /*wait_all=*/true,
-                                                     InfiniteFuture()));
-}
-
-// Tests waiting on a semaphore that has already been signaled.
-TEST(CondVarSemaphoreTest, WaitAlreadySignaled) {
-  CondVarSemaphore semaphore(2u);
-  // Test both previous and current values.
-  IREE_EXPECT_OK(CondVarSemaphore::WaitForSemaphores(
-      {{&semaphore, 1u}}, /*wait_all=*/true, InfiniteFuture()));
-  IREE_EXPECT_OK(CondVarSemaphore::WaitForSemaphores(
-      {{&semaphore, 2u}}, /*wait_all=*/true, InfiniteFuture()));
-}
-
-// Tests waiting on a semaphore that has not been signaled.
-TEST(CondVarSemaphoreTest, WaitUnsignaled) {
-  CondVarSemaphore semaphore(2u);
-  // NOTE: we don't actually block here because otherwise we'd lock up.
-  EXPECT_TRUE(IsDeadlineExceeded(CondVarSemaphore::WaitForSemaphores(
-      {{&semaphore, 3u}}, /*wait_all=*/true, InfinitePast())));
-}
-
-// Tests waiting on a failed semaphore (it should return the error on the
-// semaphore).
-TEST(CondVarSemaphoreTest, WaitAlreadyFailed) {
-  CondVarSemaphore semaphore(2u);
-  semaphore.Fail(UnknownErrorBuilder(IREE_LOC));
-  EXPECT_TRUE(IsUnknown(CondVarSemaphore::WaitForSemaphores(
-      {{&semaphore, 2u}}, /*wait_all=*/true, InfinitePast())));
-}
-
-// Tests threading behavior by ping-ponging between the test main thread and
-// a little thread.
-TEST(CondVarSemaphoreTest, PingPong) {
-  CondVarSemaphore a2b(0u);
-  CondVarSemaphore b2a(0u);
-  std::thread thread([&]() {
-    // Should advance right past this because the value is already set.
-    IREE_ASSERT_OK(CondVarSemaphore::WaitForSemaphores(
-        {{&a2b, 0u}}, /*wait_all=*/true, InfiniteFuture()));
-    IREE_ASSERT_OK(b2a.Signal(1u));
-    // Jump ahead.
-    IREE_ASSERT_OK(CondVarSemaphore::WaitForSemaphores(
-        {{&a2b, 4u}}, /*wait_all=*/true, InfiniteFuture()));
-  });
-  IREE_ASSERT_OK(CondVarSemaphore::WaitForSemaphores(
-      {{&b2a, 1u}}, /*wait_all=*/true, InfiniteFuture()));
-  IREE_ASSERT_OK(a2b.Signal(4u));
-  thread.join();
-}
-
-// Tests that failure still wakes waiters and propagates the error.
-TEST(CondVarSemaphoreTest, FailNotifies) {
-  CondVarSemaphore a2b(0u);
-  CondVarSemaphore b2a(0u);
-  bool got_failure = false;
-  std::thread thread([&]() {
-    IREE_ASSERT_OK(b2a.Signal(1u));
-    got_failure = IsUnknown(CondVarSemaphore::WaitForSemaphores(
-        {{&a2b, 1u}}, /*wait_all=*/true, InfiniteFuture()));
-  });
-  IREE_ASSERT_OK(CondVarSemaphore::WaitForSemaphores(
-      {{&b2a, 1u}}, /*wait_all=*/true, InfiniteFuture()));
-  a2b.Fail(UnknownErrorBuilder(IREE_LOC));
-  thread.join();
-  ASSERT_TRUE(got_failure);
-}
-
-}  // namespace
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
diff --git a/iree/hal/host/host_buffer.cc b/iree/hal/host/host_buffer.cc
deleted file mode 100644
index 265e048f41fa4..0000000000000
--- a/iree/hal/host/host_buffer.cc
+++ /dev/null
@@ -1,149 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_buffer.h"
-
-#include <algorithm>
-#include <cstdlib>
-#include <cstring>
-
-#include "iree/base/logging.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-
-namespace iree {
-namespace hal {
-
-class Allocator;
-
-HostBuffer::HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
-                       MemoryAccessBitfield allowed_access,
-                       BufferUsageBitfield usage, device_size_t allocation_size,
-                       void* data, bool owns_data)
-    : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, 0,
-             allocation_size),
-      data_(data),
-      owns_data_(owns_data) {}
-
-HostBuffer::~HostBuffer() {
-  IREE_TRACE_SCOPE();
-  if (owns_data_ && data_) {
-    std::free(data_);
-    data_ = nullptr;
-  }
-}
-
-Status HostBuffer::FillImpl(device_size_t byte_offset,
-                            device_size_t byte_length, const void* pattern,
-                            device_size_t pattern_length) {
-  auto data_ptr = data_;
-  switch (pattern_length) {
-    case 1: {
-      uint8_t* data = static_cast<uint8_t*>(data_ptr);
-      uint8_t value_bits = *static_cast<const uint8_t*>(pattern);
-      std::fill_n(data + byte_offset, byte_length, value_bits);
-      break;
-    }
-    case 2: {
-      uint16_t* data = static_cast<uint16_t*>(data_ptr);
-      uint16_t value_bits = *static_cast<const uint16_t*>(pattern);
-      std::fill_n(data + byte_offset / sizeof(uint16_t),
-                  byte_length / sizeof(uint16_t), value_bits);
-      break;
-    }
-    case 4: {
-      uint32_t* data = static_cast<uint32_t*>(data_ptr);
-      uint32_t value_bits = *static_cast<const uint32_t*>(pattern);
-      std::fill_n(data + byte_offset / sizeof(uint32_t),
-                  byte_length / sizeof(uint32_t), value_bits);
-      break;
-    }
-    default:
-      return InvalidArgumentErrorBuilder(IREE_LOC)
-             << "Unsupported scalar data size: " << pattern_length;
-  }
-  return OkStatus();
-}
-
-Status HostBuffer::ReadDataImpl(device_size_t source_offset, void* data,
-                                device_size_t data_length) {
-  auto data_ptr = static_cast<uint8_t*>(data_);
-  std::memcpy(data, data_ptr + source_offset, data_length);
-  return OkStatus();
-}
-
-Status HostBuffer::WriteDataImpl(device_size_t target_offset, const void* data,
-                                 device_size_t data_length) {
-  auto data_ptr = static_cast<uint8_t*>(data_);
-  std::memcpy(data_ptr + target_offset, data, data_length);
-  return OkStatus();
-}
-Status HostBuffer::CopyDataImpl(device_size_t target_offset,
-                                Buffer* source_buffer,
-                                device_size_t source_offset,
-                                device_size_t data_length) {
-  // This is pretty terrible. Let's not do this.
-  // TODO(benvanik): a way for allocators to indicate transfer compat.
-  IREE_ASSIGN_OR_RETURN(auto source_data,
-                        source_buffer->MapMemory<uint8_t>(
-                            MemoryAccess::kRead, source_offset, data_length));
-  IREE_CHECK_EQ(data_length, source_data.size());
-  auto data_ptr = static_cast<uint8_t*>(data_);
-  std::memcpy(data_ptr + target_offset, source_data.data(), data_length);
-  return OkStatus();
-}
-
-Status HostBuffer::MapMemoryImpl(MappingMode mapping_mode,
-                                 MemoryAccessBitfield memory_access,
-                                 device_size_t local_byte_offset,
-                                 device_size_t local_byte_length,
-                                 void** out_data) {
-  auto data_ptr = static_cast<uint8_t*>(data_);
-  *out_data = data_ptr + local_byte_offset;
-
-  // If we mapped for discard scribble over the bytes. This is not a mandated
-  // behavior but it will make debugging issues easier. Alternatively for
-  // heap buffers we could reallocate them such that ASAN yells, but that
-  // would only work if the entire buffer was discarded.
-#ifndef NDEBUG
-  if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) {
-    std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length);
-  }
-#endif  // !NDEBUG
-
-  return OkStatus();
-}
-
-Status HostBuffer::UnmapMemoryImpl(device_size_t local_byte_offset,
-                                   device_size_t local_byte_length,
-                                   void* data) {
-  // No-op? We still want error checking to make finding misuse easier.
-  return OkStatus();
-}
-
-Status HostBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
-                                              device_size_t local_byte_length) {
-  // No-op? We still want error checking to make finding misuse easier.
-  return OkStatus();
-}
-
-Status HostBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset,
-                                         device_size_t local_byte_length) {
-  // No-op? We still want error checking to make finding misuse easier.
-  return OkStatus();
-}
-
-}  // namespace hal
-}  // namespace iree
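For readers unfamiliar with the pattern-fill trick in `FillImpl` above, here is a standalone sketch of the same logic: the pattern length selects the element width used to splat the value across the range. This is illustrative only; the function name and signature here are hypothetical, not IREE API.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>

// Fills [byte_offset, byte_offset + byte_length) of |data| with a repeating
// 1-, 2-, or 4-byte pattern, mirroring HostBuffer::FillImpl. Returns false
// for unsupported pattern sizes.
bool FillPattern(void* data, size_t byte_offset, size_t byte_length,
                 const void* pattern, size_t pattern_length) {
  uint8_t* bytes = static_cast<uint8_t*>(data);
  switch (pattern_length) {
    case 1: {
      uint8_t value = *static_cast<const uint8_t*>(pattern);
      std::fill_n(bytes + byte_offset, byte_length, value);
      return true;
    }
    case 2: {
      uint16_t value;
      std::memcpy(&value, pattern, sizeof(value));  // avoid unaligned deref
      std::fill_n(reinterpret_cast<uint16_t*>(bytes + byte_offset),
                  byte_length / sizeof(uint16_t), value);
      return true;
    }
    case 4: {
      uint32_t value;
      std::memcpy(&value, pattern, sizeof(value));
      std::fill_n(reinterpret_cast<uint32_t*>(bytes + byte_offset),
                  byte_length / sizeof(uint32_t), value);
      return true;
    }
    default:
      return false;
  }
}
```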
diff --git a/iree/hal/host/host_buffer.h b/iree/hal/host/host_buffer.h
deleted file mode 100644
index edc46359dcf63..0000000000000
--- a/iree/hal/host/host_buffer.h
+++ /dev/null
@@ -1,70 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_BUFFER_H_
-#define IREE_HAL_HOST_BUFFER_H_
-
-#include <cstdint>
-
-#include "iree/base/status.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-
-// A buffer type that operates on host pointers.
-// This can be used by Allocator implementations when they support operating
-// on host memory (or mapping their memory to host memory).
-class HostBuffer : public Buffer {
- public:
-  HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type,
-             MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
-             device_size_t allocation_size, void* data, bool owns_data);
-
-  ~HostBuffer() override;
-
-  const void* data() const { return data_; }
-  void* mutable_data() { return data_; }
-
- protected:
-  Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
-                  const void* pattern, device_size_t pattern_length) override;
-  Status ReadDataImpl(device_size_t source_offset, void* data,
-                      device_size_t data_length) override;
-  Status WriteDataImpl(device_size_t target_offset, const void* data,
-                       device_size_t data_length) override;
-  Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
-                      device_size_t source_offset,
-                      device_size_t data_length) override;
-  Status MapMemoryImpl(MappingMode mapping_mode,
-                       MemoryAccessBitfield memory_access,
-                       device_size_t local_byte_offset,
-                       device_size_t local_byte_length,
-                       void** out_data) override;
-  Status UnmapMemoryImpl(device_size_t local_byte_offset,
-                         device_size_t local_byte_length, void* data) override;
-  Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
-                                    device_size_t local_byte_length) override;
-  Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
-                               device_size_t local_byte_length) override;
-
- private:
-  void* data_ = nullptr;
-  bool owns_data_ = false;
-};
-
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_BUFFER_H_
diff --git a/iree/hal/host/host_descriptor_set.h b/iree/hal/host/host_descriptor_set.h
deleted file mode 100644
index c3c24a242a723..0000000000000
--- a/iree/hal/host/host_descriptor_set.h
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_HOST_DESCRIPTOR_SET_H_
-#define IREE_HAL_HOST_HOST_DESCRIPTOR_SET_H_
-
-#include "absl/container/inlined_vector.h"
-#include "iree/hal/descriptor_set.h"
-#include "iree/hal/descriptor_set_layout.h"
-
-namespace iree {
-namespace hal {
-
-class HostDescriptorSet final : public DescriptorSet {
- public:
-  HostDescriptorSet(DescriptorSetLayout* set_layout,
-                    absl::Span<const DescriptorSet::Binding> bindings);
-  ~HostDescriptorSet() override;
-
-  absl::Span<const DescriptorSet::Binding> bindings() const {
-    return absl::MakeConstSpan(bindings_);
-  }
-
- private:
-  absl::InlinedVector<DescriptorSet::Binding, 4> bindings_;
-};
-
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_HOST_DESCRIPTOR_SET_H_
diff --git a/iree/hal/host/host_executable.h b/iree/hal/host/host_executable.h
deleted file mode 100644
index 9b7aa788decbb..0000000000000
--- a/iree/hal/host/host_executable.h
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_HOST_EXECUTABLE_H_
-#define IREE_HAL_HOST_HOST_EXECUTABLE_H_
-
-#include "iree/base/status.h"
-#include "iree/hal/descriptor_set.h"
-#include "iree/hal/executable.h"
-
-namespace iree {
-namespace hal {
-
-// Computed push constant values available to all tiles in the grid.
-struct PushConstantBlock {
-  // We limit ourselves to 32 constants (32*sizeof(uint32) = 128B).
-  // This is the lower bound for Vulkan implementations and ensures that we
-  // have consistent support everywhere.
-  std::array<uint32_t, 32> values;
-};
-
-// Abstract host-local executable that can dispatch grid-based tiles.
-// Implementations provide the logic to process individual tiles within the
-// workgroup-defined XYZ grid.
-//
-// Thread-safe; the executable may be called to process the grid from any
-// thread in any order.
-class HostExecutable : public Executable {
- public:
-  // Grid parameters shared for all tiles within a dispatch.
-  struct DispatchParams {
-    // Entry point within the executable.
-    size_t entry_point = 0;
-
-    // Total workgroup XYZ count for the grid.
-    std::array<uint32_t, 3> workgroup_count;
-
-    // Size of each tile in the grid in local space.
-    std::array<uint32_t, 3> workgroup_size;
-
-    // Push constants populated by the command buffer.
-    const PushConstantBlock* push_constants = nullptr;
-
-    // Descriptor set bindings organized by set and binding ordinal.
-    absl::Span<absl::Span<const DescriptorSet::Binding>> set_bindings;
-  };
-
-  struct DispatchState : public RefObject<DispatchState> {
-    virtual ~DispatchState() = default;
-  };
-
-  // Begins processing a grid dispatch with the given parameters.
-  // May be called from any thread. Returns dispatch state that will be passed
-  // to all DispatchTile calls from the same dispatch operation.
-  virtual StatusOr<ref_ptr<DispatchState>> PrepareDispatch(
-      const DispatchParams& params) = 0;
-
-  // Processes a single tile within the grid.
-  // |workgroup_xyz| is the tile coordinates in the grid as defined during
-  // preparation. May be called from any thread.
-  virtual Status DispatchTile(DispatchState* state,
-                              std::array<uint32_t, 3> workgroup_xyz) = 0;
-
- protected:
-  HostExecutable() = default;
-};
-
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_HOST_EXECUTABLE_H_
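To make the prepare/tile split above concrete, here is a sketch of how a serial scheduler might drive this interface: prepare once per dispatch, then walk the XYZ grid and process one tile per workgroup coordinate. It assumes the `iree::hal` types from this diff; the free function is hypothetical and error handling is abbreviated.

```cpp
iree::Status RunDispatchSerially(
    iree::hal::HostExecutable* executable,
    const iree::hal::HostExecutable::DispatchParams& params) {
  // Shared state for all tiles of this dispatch.
  IREE_ASSIGN_OR_RETURN(auto dispatch_state,
                        executable->PrepareDispatch(params));
  const auto& count = params.workgroup_count;
  for (uint32_t z = 0; z < count[2]; ++z) {
    for (uint32_t y = 0; y < count[1]; ++y) {
      for (uint32_t x = 0; x < count[0]; ++x) {
        // The contract only requires DispatchTile be callable from any
        // thread; a real scheduling model could fan these out to workers.
        IREE_RETURN_IF_ERROR(
            executable->DispatchTile(dispatch_state.get(), {x, y, z}));
      }
    }
  }
  return iree::OkStatus();
}
```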
- -#include "iree/hal/host/host_executable_layout.h" - -#include "iree/base/memory.h" - -namespace iree { -namespace hal { - -HostDescriptorSetLayout::HostDescriptorSetLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span bindings) - : bindings_(bindings.begin(), bindings.end()) {} - -HostDescriptorSetLayout::~HostDescriptorSetLayout() = default; - -HostExecutableLayout::HostExecutableLayout( - absl::Span set_layouts, size_t push_constants) - : push_constants_(push_constants) { - dynamic_binding_map_.resize(set_layouts.size()); - for (int i = 0; i < set_layouts.size(); ++i) { - auto* set_layout = static_cast(set_layouts[i]); - auto& set_binding_map = dynamic_binding_map_[i]; - for (auto& binding : set_layout->bindings()) { - if (binding.type == DescriptorType::kStorageBufferDynamic || - binding.type == DescriptorType::kUniformBufferDynamic) { - set_binding_map.push_back(binding.binding); - } - } - } -} - -HostExecutableLayout::~HostExecutableLayout() = default; - -} // namespace hal -} // namespace iree diff --git a/iree/hal/host/host_executable_layout.h b/iree/hal/host/host_executable_layout.h deleted file mode 100644 index 9e17c40317203..0000000000000 --- a/iree/hal/host/host_executable_layout.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_HOST_EXECUTABLE_LAYOUT_H_ -#define IREE_HAL_HOST_HOST_EXECUTABLE_LAYOUT_H_ - -#include "absl/container/inlined_vector.h" -#include "iree/hal/descriptor_set_layout.h" -#include "iree/hal/executable_layout.h" - -namespace iree { -namespace hal { - -class HostDescriptorSetLayout final : public DescriptorSetLayout { - public: - HostDescriptorSetLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span bindings); - ~HostDescriptorSetLayout() override; - - absl::Span bindings() const { - return absl::MakeConstSpan(bindings_); - } - - private: - absl::InlinedVector bindings_; -}; - -class HostExecutableLayout final : public ExecutableLayout { - public: - HostExecutableLayout(absl::Span set_layouts, - size_t push_constants); - ~HostExecutableLayout() override; - - // Returns the total number of descriptor sets in the layout. - size_t set_count() const { return dynamic_binding_map_.size(); } - - // Returns a map from dynamic offset index to the binding index in |set|. 
-  absl::Span<const int32_t> GetDynamicBindingMap(int32_t set) const {
-    return dynamic_binding_map_[set];
-  }
-
-  size_t push_constants() const { return push_constants_; }
-
- private:
-  size_t push_constants_;
-  absl::InlinedVector<absl::InlinedVector<int32_t, 4>, 2> dynamic_binding_map_;
-};
-
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_HOST_EXECUTABLE_LAYOUT_H_
diff --git a/iree/hal/host/host_local_allocator.cc b/iree/hal/host/host_local_allocator.cc
deleted file mode 100644
index faf7fdc9c3d4b..0000000000000
--- a/iree/hal/host/host_local_allocator.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_local_allocator.h"
-
-#include <cstdlib>
-#include <string>
-#include <utility>
-
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/host/host_buffer.h"
-
-namespace iree {
-namespace hal {
-namespace host {
-
-HostLocalAllocator::HostLocalAllocator() = default;
-
-HostLocalAllocator::~HostLocalAllocator() = default;
-
-bool HostLocalAllocator::CanUseBufferLike(
-    Allocator* source_allocator, MemoryTypeBitfield memory_type,
-    BufferUsageBitfield buffer_usage,
-    BufferUsageBitfield intended_usage) const {
-  // Must always have visibility to the device, which ensures we can test
-  // against the host but have things work on devices with separate address
-  // spaces.
-  if (!AnyBitSet(memory_type & MemoryType::kDeviceVisible)) {
-    return false;
-  }
-
-  // kHostVisible is required for mapping.
-  if (AnyBitSet(intended_usage & BufferUsage::kMapping) &&
-      !AnyBitSet(memory_type & MemoryType::kHostVisible)) {
-    return false;
-  }
-
-  // Dispatch needs to be specified if we intend to dispatch.
-  if (AnyBitSet(intended_usage & BufferUsage::kDispatch) &&
-      !AnyBitSet(buffer_usage & BufferUsage::kDispatch)) {
-    return false;
-  }
-
-  return true;
-}
-
-bool HostLocalAllocator::CanAllocate(MemoryTypeBitfield memory_type,
-                                     BufferUsageBitfield buffer_usage,
-                                     size_t allocation_size) const {
-  // Host allows everything, pretty much, so long as it is device-visible (as
-  // the host is the device here).
-  return AnyBitSet(memory_type & MemoryType::kDeviceVisible);
-}
-
-Status HostLocalAllocator::MakeCompatible(
-    MemoryTypeBitfield* memory_type, BufferUsageBitfield* buffer_usage) const {
-  // Always ensure we are host-visible.
-  *memory_type |= MemoryType::kHostVisible;
-
-  // Host currently uses mapping to copy buffers, which is done a lot.
-  // We could probably remove this restriction somehow.
-  *buffer_usage |= BufferUsage::kMapping;
-
-  // TODO(b/111372612): tensorflow needs transfer too, but shouldn't.
-  *buffer_usage |= BufferUsage::kTransfer;
-
-  return OkStatus();
-}
-
-StatusOr<ref_ptr<Buffer>> HostLocalAllocator::Allocate(
-    MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
-    size_t allocation_size) {
-  IREE_TRACE_SCOPE0("HostLocalAllocator::Allocate");
-
-  if (!CanAllocate(memory_type, buffer_usage, allocation_size)) {
-    return FailedPreconditionErrorBuilder(IREE_LOC)
-           << "Allocation not supported; memory_type="
-           << MemoryTypeString(memory_type)
-           << ", buffer_usage=" << BufferUsageString(buffer_usage)
-           << ", allocation_size=" << allocation_size;
-  }
-
-  // Make compatible with our requirements.
-  IREE_RETURN_IF_ERROR(MakeCompatible(&memory_type, &buffer_usage));
-
-  void* malloced_data = std::calloc(1, allocation_size);
-  if (!malloced_data) {
-    return ResourceExhaustedErrorBuilder(IREE_LOC)
-           << "Failed to malloc " << allocation_size << " bytes";
-  }
-
-  auto buffer =
-      make_ref<HostBuffer>(this, memory_type, MemoryAccess::kAll, buffer_usage,
-                           allocation_size, malloced_data, true);
-  return buffer;
-}
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
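A hypothetical usage sketch of the allocator removed above: request a host-local, device-visible buffer and round-trip some data through the `Buffer` read/write helpers. This assumes the C++ HAL types in this diff and is not part of the change itself.

```cpp
iree::Status AllocatorExample() {
  iree::hal::host::HostLocalAllocator allocator;
  IREE_ASSIGN_OR_RETURN(
      auto buffer,
      allocator.Allocate(iree::hal::MemoryType::kHostLocal |
                             iree::hal::MemoryType::kDeviceVisible,
                         iree::hal::BufferUsage::kMapping |
                             iree::hal::BufferUsage::kTransfer,
                         /*allocation_size=*/16));
  // Write four values and read them back through the buffer interface.
  const uint32_t values[4] = {0, 1, 2, 3};
  IREE_RETURN_IF_ERROR(buffer->WriteData(0, values, sizeof(values)));
  uint32_t readback[4] = {0};
  IREE_RETURN_IF_ERROR(buffer->ReadData(0, readback, sizeof(readback)));
  return iree::OkStatus();
}
```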
diff --git a/iree/hal/host/host_local_allocator.h b/iree/hal/host/host_local_allocator.h
deleted file mode 100644
index 375cfb8e3a812..0000000000000
--- a/iree/hal/host/host_local_allocator.h
+++ /dev/null
@@ -1,62 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
-#define IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
-
-#include <cstddef>
-#include <memory>
-
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-namespace host {
-
-// An allocator implementation that allocates buffers from host memory.
-// This can be used for drivers that do not have a memory space of their own.
-//
-// Buffers allocated will be MemoryType::kHostLocal | kDeviceVisible as
-// the 'device' in the case of a host-local queue *is* the host. To keep code
-// written initially for a host-local queue working when other queues are used
-// the allocator only works with buffers that are kDeviceVisible.
-class HostLocalAllocator : public Allocator {
- public:
-  HostLocalAllocator();
-  ~HostLocalAllocator() override;
-
-  bool CanUseBufferLike(Allocator* source_allocator,
-                        MemoryTypeBitfield memory_type,
-                        BufferUsageBitfield buffer_usage,
-                        BufferUsageBitfield intended_usage) const override;
-
-  bool CanAllocate(MemoryTypeBitfield memory_type,
-                   BufferUsageBitfield buffer_usage,
-                   size_t allocation_size) const override;
-
-  Status MakeCompatible(MemoryTypeBitfield* memory_type,
-                        BufferUsageBitfield* buffer_usage) const override;
-
-  StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type,
-                                     BufferUsageBitfield buffer_usage,
-                                     size_t allocation_size) override;
-};
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
diff --git a/iree/hal/host/host_local_device.cc b/iree/hal/host/host_local_device.cc
deleted file mode 100644
index 1c5fc4ee97212..0000000000000
--- a/iree/hal/host/host_local_device.cc
+++ /dev/null
@@ -1,98 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/host/host_local_device.h"
-
-#include <utility>
-
-#include "absl/memory/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/command_buffer_validation.h"
-#include "iree/hal/host/host_descriptor_set.h"
-#include "iree/hal/host/host_executable_layout.h"
-
-namespace iree {
-namespace hal {
-namespace host {
-
-HostLocalDevice::HostLocalDevice(
-    DeviceInfo device_info, std::unique_ptr<SchedulingModel> scheduling_model)
-    : Device(std::move(device_info)),
-      scheduling_model_(std::move(scheduling_model)) {}
-
-HostLocalDevice::~HostLocalDevice() = default;
-
-StatusOr<ref_ptr<DescriptorSetLayout>>
-HostLocalDevice::CreateDescriptorSetLayout(
-    DescriptorSetLayout::UsageType usage_type,
-    absl::Span<const DescriptorSetLayout::Binding> bindings) {
-  IREE_TRACE_SCOPE0("HostLocalDevice::CreateDescriptorSetLayout");
-  return make_ref<HostDescriptorSetLayout>(usage_type, bindings);
-}
-
-StatusOr<ref_ptr<ExecutableLayout>> HostLocalDevice::CreateExecutableLayout(
-    absl::Span<DescriptorSetLayout* const> set_layouts, size_t push_constants) {
-  IREE_TRACE_SCOPE0("HostLocalDevice::CreateExecutableLayout");
-  return make_ref<HostExecutableLayout>(set_layouts, push_constants);
-}
-
-StatusOr<ref_ptr<DescriptorSet>> HostLocalDevice::CreateDescriptorSet(
-    DescriptorSetLayout* set_layout,
-    absl::Span<const DescriptorSet::Binding> bindings) {
-  IREE_TRACE_SCOPE0("HostLocalDevice::CreateDescriptorSet");
-  return make_ref<HostDescriptorSet>(set_layout, bindings);
-}
-
-StatusOr<ref_ptr<CommandBuffer>> HostLocalDevice::CreateCommandBuffer(
-    CommandBufferModeBitfield mode,
-    CommandCategoryBitfield command_categories) {
-  IREE_TRACE_SCOPE0("HostLocalDevice::CreateCommandBuffer");
-  // TODO(b/140026716): conditionally enable validation.
-  IREE_ASSIGN_OR_RETURN(auto impl, scheduling_model_->CreateCommandBuffer(
-                                       mode, command_categories));
-  return WrapCommandBufferWithValidation(allocator(), std::move(impl));
-}
-
-StatusOr<ref_ptr<Event>> HostLocalDevice::CreateEvent() {
-  IREE_TRACE_SCOPE0("HostLocalDevice::CreateEvent");
-  return scheduling_model_->CreateEvent();
-}
-
-StatusOr<ref_ptr<Semaphore>> HostLocalDevice::CreateSemaphore(
-    uint64_t initial_value) {
-  IREE_TRACE_SCOPE0("HostLocalDevice::CreateSemaphore");
-  return scheduling_model_->CreateSemaphore(initial_value);
-}
-
-Status HostLocalDevice::WaitAllSemaphores(
-    absl::Span<const SemaphoreValue> semaphores, Time deadline_ns) {
-  IREE_TRACE_SCOPE0("HostLocalDevice::WaitAllSemaphores");
-  return scheduling_model_->WaitAllSemaphores(semaphores, deadline_ns);
-}
-
-StatusOr<int> HostLocalDevice::WaitAnySemaphore(
-    absl::Span<const SemaphoreValue> semaphores, Time deadline_ns) {
-  IREE_TRACE_SCOPE0("HostLocalDevice::WaitAnySemaphore");
-  return scheduling_model_->WaitAnySemaphore(semaphores, deadline_ns);
-}
-
-Status HostLocalDevice::WaitIdle(Time deadline_ns) {
-  IREE_TRACE_SCOPE0("HostLocalDevice::WaitIdle");
-  return scheduling_model_->WaitIdle(deadline_ns);
-}
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
diff --git a/iree/hal/host/host_local_device.h b/iree/hal/host/host_local_device.h
deleted file mode 100644
index 8fc3b3182592f..0000000000000
--- a/iree/hal/host/host_local_device.h
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_HOST_LOCAL_DEVICE_H_
-#define IREE_HAL_HOST_HOST_LOCAL_DEVICE_H_
-
-#include "absl/types/span.h"
-#include "iree/base/memory.h"
-#include "iree/hal/device.h"
-#include "iree/hal/host/host_local_allocator.h"
-#include "iree/hal/host/scheduling_model.h"
-
-namespace iree {
-namespace hal {
-namespace host {
-
-// A host-local device that uses host-local memory and in-process execution.
-// This implements the boilerplate needed for any device that runs on the CPU
-// using the other Host* types. The scheduling model used to distribute work
-// across local CPU resources is provided by the SchedulingModel interface.
-class HostLocalDevice : public Device {
- public:
-  ~HostLocalDevice() override;
-
-  Allocator* allocator() const override { return &allocator_; }
-
-  absl::Span<CommandQueue*> dispatch_queues() const override {
-    return scheduling_model_->dispatch_queues();
-  }
-
-  absl::Span<CommandQueue*> transfer_queues() const override {
-    return scheduling_model_->transfer_queues();
-  }
-
-  StatusOr<ref_ptr<DescriptorSetLayout>> CreateDescriptorSetLayout(
-      DescriptorSetLayout::UsageType usage_type,
-      absl::Span<const DescriptorSetLayout::Binding> bindings) override;
-
-  StatusOr<ref_ptr<ExecutableLayout>> CreateExecutableLayout(
-      absl::Span<DescriptorSetLayout* const> set_layouts,
-      size_t push_constants) override;
-
-  StatusOr<ref_ptr<DescriptorSet>> CreateDescriptorSet(
-      DescriptorSetLayout* set_layout,
-      absl::Span<const DescriptorSet::Binding> bindings) override;
-
-  StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
-      CommandBufferModeBitfield mode,
-      CommandCategoryBitfield command_categories) override;
-
-  StatusOr<ref_ptr<Event>> CreateEvent() override;
-
-  StatusOr<ref_ptr<Semaphore>> CreateSemaphore(uint64_t initial_value) override;
-  Status WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores,
-                           Time deadline_ns) override;
-  StatusOr<int> WaitAnySemaphore(absl::Span<const SemaphoreValue> semaphores,
-                                 Time deadline_ns) override;
-
-  Status WaitIdle(Time deadline_ns) override;
-
- protected:
-  explicit HostLocalDevice(DeviceInfo device_info,
-                           std::unique_ptr<SchedulingModel> scheduling_model);
-
- private:
-  std::unique_ptr<SchedulingModel> scheduling_model_;
-  mutable HostLocalAllocator allocator_;
-};
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_HOST_LOCAL_DEVICE_H_
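Since the constructor is protected, concrete backends subclass HostLocalDevice and inject a SchedulingModel. A sketch of that wiring is below; `SerialSchedulingModel` is an assumed name suggested by the `serial_scheduling_model` target later in this diff, and the device class itself is hypothetical.

```cpp
class MySerialDevice final : public iree::hal::host::HostLocalDevice {
 public:
  explicit MySerialDevice(iree::hal::DeviceInfo device_info)
      : HostLocalDevice(
            std::move(device_info),
            // Serial/in-order scheduling; a fiber-based model could be
            // swapped in here without touching the rest of the device.
            absl::make_unique<iree::hal::host::SerialSchedulingModel>()) {}
};
```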
- -#include "iree/hal/host/inproc_command_buffer.h" - -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { -namespace host { - -InProcCommandBuffer::InProcCommandBuffer( - CommandBufferModeBitfield mode, CommandCategoryBitfield command_categories) - : CommandBuffer(mode, command_categories) {} - -InProcCommandBuffer::~InProcCommandBuffer() { Reset(); } - -Status InProcCommandBuffer::Begin() { - IREE_TRACE_SCOPE0("InProcCommandBuffer::Begin"); - is_recording_ = true; - Reset(); - return OkStatus(); -} - -Status InProcCommandBuffer::End() { - IREE_TRACE_SCOPE0("InProcCommandBuffer::End"); - is_recording_ = false; - return OkStatus(); -} - -Status InProcCommandBuffer::ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::ExecutionBarrier"); - IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd()); - cmd->source_stage_mask = source_stage_mask; - cmd->target_stage_mask = target_stage_mask; - cmd->memory_barriers = AppendStructSpan(memory_barriers); - cmd->buffer_barriers = AppendStructSpan(buffer_barriers); - return OkStatus(); -} - -Status InProcCommandBuffer::SignalEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::SignalEvent"); - IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd()); - cmd->event = event; - cmd->source_stage_mask = source_stage_mask; - return OkStatus(); -} - -Status InProcCommandBuffer::ResetEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::ResetEvent"); - IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd()); - cmd->event = event; - cmd->source_stage_mask = source_stage_mask; - return OkStatus(); -} - -Status InProcCommandBuffer::WaitEvents( - absl::Span events, ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::WaitEvents"); - IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd()); - cmd->events = AppendStructSpan(events); - cmd->source_stage_mask = source_stage_mask; - cmd->target_stage_mask = target_stage_mask; - cmd->memory_barriers = AppendStructSpan(memory_barriers); - cmd->buffer_barriers = AppendStructSpan(buffer_barriers); - return OkStatus(); -} - -Status InProcCommandBuffer::FillBuffer(Buffer* target_buffer, - device_size_t target_offset, - device_size_t length, - const void* pattern, - size_t pattern_length) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::FillBuffer"); - IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd()); - cmd->target_buffer = target_buffer; - cmd->target_offset = target_offset; - cmd->length = length; - std::memcpy(cmd->pattern, pattern, pattern_length); - cmd->pattern_length = pattern_length; - return OkStatus(); -} - -Status InProcCommandBuffer::DiscardBuffer(Buffer* buffer) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::DiscardBuffer"); - IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd()); - cmd->buffer = buffer; - return OkStatus(); -} - -Status InProcCommandBuffer::UpdateBuffer(const void* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::UpdateBuffer"); - IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd()); - cmd->source_buffer = AppendCmdData(source_buffer, source_offset, length); - cmd->target_buffer = target_buffer; - cmd->target_offset = 
-  cmd->target_offset = target_offset;
-  cmd->length = length;
-  return OkStatus();
-}
-
-Status InProcCommandBuffer::CopyBuffer(Buffer* source_buffer,
-                                       device_size_t source_offset,
-                                       Buffer* target_buffer,
-                                       device_size_t target_offset,
-                                       device_size_t length) {
-  IREE_TRACE_SCOPE0("InProcCommandBuffer::CopyBuffer");
-  IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd<CopyBufferCmd>());
-  cmd->source_buffer = source_buffer;
-  cmd->source_offset = source_offset;
-  cmd->target_buffer = target_buffer;
-  cmd->target_offset = target_offset;
-  cmd->length = length;
-  return OkStatus();
-}
-
-Status InProcCommandBuffer::PushConstants(ExecutableLayout* executable_layout,
-                                          size_t offset,
-                                          absl::Span<const uint32_t> values) {
-  IREE_TRACE_SCOPE0("InProcCommandBuffer::PushConstants");
-  IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd<PushConstantsCmd>());
-  cmd->executable_layout = executable_layout;
-  cmd->offset = offset;
-  cmd->values = AppendStructSpan(values);
-  return OkStatus();
-}
-
-Status InProcCommandBuffer::PushDescriptorSet(
-    ExecutableLayout* executable_layout, int32_t set,
-    absl::Span<const DescriptorSet::Binding> bindings) {
-  IREE_TRACE_SCOPE0("InProcCommandBuffer::PushDescriptorSet");
-  IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd<PushDescriptorSetCmd>());
-  cmd->executable_layout = executable_layout;
-  cmd->set = set;
-  cmd->bindings = AppendStructSpan(bindings);
-  return OkStatus();
-}
-
-Status InProcCommandBuffer::BindDescriptorSet(
-    ExecutableLayout* executable_layout, int32_t set,
-    DescriptorSet* descriptor_set,
-    absl::Span<const device_size_t> dynamic_offsets) {
-  IREE_TRACE_SCOPE0("InProcCommandBuffer::BindDescriptorSet");
-  IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd<BindDescriptorSetCmd>());
-  cmd->executable_layout = executable_layout;
-  cmd->set = set;
-  cmd->descriptor_set = descriptor_set;
-  cmd->dynamic_offsets = AppendStructSpan(dynamic_offsets);
-  return OkStatus();
-}
-
-Status InProcCommandBuffer::Dispatch(Executable* executable,
-                                     int32_t entry_point,
-                                     std::array<uint32_t, 3> workgroups) {
-  IREE_TRACE_SCOPE0("InProcCommandBuffer::Dispatch");
-  IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd<DispatchCmd>());
-  cmd->executable = executable;
-  cmd->entry_point = entry_point;
-  cmd->workgroups = workgroups;
-  return OkStatus();
-}
-
-Status InProcCommandBuffer::DispatchIndirect(Executable* executable,
-                                             int32_t entry_point,
-                                             Buffer* workgroups_buffer,
-                                             device_size_t workgroups_offset) {
-  IREE_TRACE_SCOPE0("InProcCommandBuffer::DispatchIndirect");
-  IREE_ASSIGN_OR_RETURN(auto* cmd, AppendCmd<DispatchIndirectCmd>());
-  cmd->executable = executable;
-  cmd->entry_point = entry_point;
-  cmd->workgroups_buffer = workgroups_buffer;
-  cmd->workgroups_offset = workgroups_offset;
-  return OkStatus();
-}
-
-void InProcCommandBuffer::Reset() {
-  auto* cmd_list = &current_cmd_list_;
-  cmd_list->head = cmd_list->tail = nullptr;
-  cmd_list->arena.Reset();
-}
-
-InProcCommandBuffer::CmdHeader* InProcCommandBuffer::AppendCmdHeader(
-    CmdType type, size_t cmd_size) {
-  auto* cmd_list = &current_cmd_list_;
-  auto* cmd_header = reinterpret_cast<CmdHeader*>(
-      cmd_list->arena.AllocateBytes(sizeof(CmdHeader) + cmd_size));
-  cmd_header->next = nullptr;
-  cmd_header->type = type;
-  if (!cmd_list->head) {
-    cmd_list->head = cmd_header;
-  } else if (cmd_list->tail) {
-    cmd_list->tail->next = cmd_header;
-  }
-  cmd_list->tail = cmd_header;
-  return cmd_header;
-}
-
-void* InProcCommandBuffer::AppendCmdData(const void* source_buffer,
                                         device_size_t source_offset,
-                                         device_size_t source_length) {
-  auto* cmd_list = &current_cmd_list_;
-
-  uint8_t* allocated_bytes = cmd_list->arena.AllocateBytes(source_length);
-  std::memcpy(allocated_bytes,
-              static_cast<const uint8_t*>(source_buffer) + source_offset,
-              source_length);
-  return allocated_bytes;
-}
-Status InProcCommandBuffer::Process(CommandBuffer* command_processor) const {
-  IREE_TRACE_SCOPE0("InProcCommandBuffer::Process");
-
-  IREE_RETURN_IF_ERROR(command_processor->Begin());
-
-  // Process each command in the order they were recorded.
-  auto* cmd_list = &current_cmd_list_;
-  for (CmdHeader* cmd_header = cmd_list->head; cmd_header != nullptr;
-       cmd_header = cmd_header->next) {
-    auto command_status = ProcessCmd(cmd_header, command_processor);
-    if (!command_status.ok()) {
-      IREE_LOG(ERROR)
-          << "DeviceQueue failure while executing command; permanently "
-             "failing all future commands: "
-          << command_status;
-      return command_status;
-    }
-  }
-
-  IREE_RETURN_IF_ERROR(command_processor->End());
-
-  return OkStatus();
-}
-
-Status InProcCommandBuffer::ProcessCmd(CmdHeader* cmd_header,
-                                       CommandBuffer* command_processor) const {
-  switch (cmd_header->type) {
-    case CmdType::kExecutionBarrier: {
-      auto* cmd = reinterpret_cast<ExecutionBarrierCmd*>(cmd_header + 1);
-      return command_processor->ExecutionBarrier(
-          cmd->source_stage_mask, cmd->target_stage_mask, cmd->memory_barriers,
-          cmd->buffer_barriers);
-    }
-    case CmdType::kSignalEvent: {
-      auto* cmd = reinterpret_cast<SignalEventCmd*>(cmd_header + 1);
-      return command_processor->SignalEvent(cmd->event, cmd->source_stage_mask);
-    }
-    case CmdType::kResetEvent: {
-      auto* cmd = reinterpret_cast<ResetEventCmd*>(cmd_header + 1);
-      return command_processor->ResetEvent(cmd->event, cmd->source_stage_mask);
-    }
-    case CmdType::kWaitEvents: {
-      auto* cmd = reinterpret_cast<WaitEventsCmd*>(cmd_header + 1);
-      return command_processor->WaitEvents(
-          cmd->events, cmd->source_stage_mask, cmd->target_stage_mask,
-          cmd->memory_barriers, cmd->buffer_barriers);
-    }
-    case CmdType::kFillBuffer: {
-      auto* cmd = reinterpret_cast<FillBufferCmd*>(cmd_header + 1);
-      return command_processor->FillBuffer(cmd->target_buffer,
-                                           cmd->target_offset, cmd->length,
-                                           cmd->pattern, cmd->pattern_length);
-    }
-    case CmdType::kDiscardBuffer: {
-      auto* cmd = reinterpret_cast<DiscardBufferCmd*>(cmd_header + 1);
-      return command_processor->DiscardBuffer(cmd->buffer);
-    }
-    case CmdType::kUpdateBuffer: {
-      auto* cmd = reinterpret_cast<UpdateBufferCmd*>(cmd_header + 1);
-      return command_processor->UpdateBuffer(cmd->source_buffer, 0,
-                                             cmd->target_buffer,
-                                             cmd->target_offset, cmd->length);
-    }
-    case CmdType::kCopyBuffer: {
-      auto* cmd = reinterpret_cast<CopyBufferCmd*>(cmd_header + 1);
-      return command_processor->CopyBuffer(
-          cmd->source_buffer, cmd->source_offset, cmd->target_buffer,
-          cmd->target_offset, cmd->length);
-    }
-    case CmdType::kPushConstants: {
-      auto* cmd = reinterpret_cast<PushConstantsCmd*>(cmd_header + 1);
-      return command_processor->PushConstants(cmd->executable_layout,
-                                              cmd->offset, cmd->values);
-    }
-    case CmdType::kPushDescriptorSet: {
-      auto* cmd = reinterpret_cast<PushDescriptorSetCmd*>(cmd_header + 1);
-      return command_processor->PushDescriptorSet(cmd->executable_layout,
-                                                  cmd->set, cmd->bindings);
-    }
-    case CmdType::kBindDescriptorSet: {
-      auto* cmd = reinterpret_cast<BindDescriptorSetCmd*>(cmd_header + 1);
-      return command_processor->BindDescriptorSet(cmd->executable_layout,
-                                                  cmd->set, cmd->descriptor_set,
-                                                  cmd->dynamic_offsets);
-    }
-    case CmdType::kDispatch: {
-      auto* cmd = reinterpret_cast<DispatchCmd*>(cmd_header + 1);
-      return command_processor->Dispatch(cmd->executable, cmd->entry_point,
-                                         cmd->workgroups);
-    }
-    case CmdType::kDispatchIndirect: {
-      auto* cmd = reinterpret_cast<DispatchIndirectCmd*>(cmd_header + 1);
-      return command_processor->DispatchIndirect(
-          cmd->executable, cmd->entry_point, cmd->workgroups_buffer,
-          cmd->workgroups_offset);
-    }
-    default:
-      return DataLossErrorBuilder(IREE_LOC)
-             << "Unrecognized command type "
-             << static_cast<int>(cmd_header->type) << "; corrupt buffer?";
-  }
-}
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
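The essence of the deleted implementation is a record/replay list: each command becomes a POD payload prefixed by a linked-list header in an arena, and replay walks the list invoking the matching method. Here is a minimal self-contained sketch of that pattern (one command type, plain heap storage instead of an arena, hypothetical names throughout):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

enum class CmdType { kFill };

struct CmdHeader {
  CmdHeader* next;  // singly-linked list through the arena
  CmdType type;     // identifies the payload that follows in memory
};

struct FillCmd {
  static constexpr CmdType kType = CmdType::kFill;
  size_t offset;
  size_t length;
};

class Recorder {
 public:
  ~Recorder() {
    for (char* p : allocations_) delete[] p;
  }

  // Allocates header + payload in one block and links it at the tail.
  template <typename T>
  T* Append() {
    char* storage = new char[sizeof(CmdHeader) + sizeof(T)];
    auto* header = reinterpret_cast<CmdHeader*>(storage);
    header->next = nullptr;
    header->type = T::kType;
    if (tail_) tail_->next = header; else head_ = header;
    tail_ = header;
    allocations_.push_back(storage);
    return reinterpret_cast<T*>(header + 1);  // payload follows the header
  }

  // Walks the list in recording order, dispatching on the stored type.
  void Replay() const {
    for (CmdHeader* c = head_; c != nullptr; c = c->next) {
      if (c->type == CmdType::kFill) {
        auto* cmd = reinterpret_cast<FillCmd*>(c + 1);
        std::printf("fill offset=%zu length=%zu\n", cmd->offset, cmd->length);
      }
    }
  }

 private:
  CmdHeader* head_ = nullptr;
  CmdHeader* tail_ = nullptr;
  std::vector<char*> allocations_;
};
```

The arena in the real code amortizes these allocations across command-buffer reuse, which is why `Reset()` only clears the head/tail pointers and resets the arena rather than freeing each command.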
buffer?"; - } -} - -} // namespace host -} // namespace hal -} // namespace iree diff --git a/iree/hal/host/inproc_command_buffer.h b/iree/hal/host/inproc_command_buffer.h deleted file mode 100644 index dc59288d6b6c8..0000000000000 --- a/iree/hal/host/inproc_command_buffer.h +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_ -#define IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_ - -#include "iree/base/arena.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/status.h" -#include "iree/hal/command_buffer.h" - -namespace iree { -namespace hal { -namespace host { - -// In-process command buffer with support for recording and playback. -// Commands are recorded into heap-allocated arenas with pointers to used -// resources (Buffer*, etc). To replay a command buffer against a real -// implementation use Process to call each command method as it was originally -// recorded. -// -// Thread-compatible (as with CommandBuffer itself). -class InProcCommandBuffer final : public CommandBuffer { - public: - InProcCommandBuffer(CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories); - ~InProcCommandBuffer() override; - - bool is_recording() const override { return is_recording_; } - - Status Begin() override; - Status End() override; - - Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) override; - - Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - - Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - - Status WaitEvents(absl::Span events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) override; - - Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) override; - - Status DiscardBuffer(Buffer* buffer) override; - - Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status PushConstants(ExecutableLayout* executable_layout, size_t offset, - absl::Span values) override; - - Status PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - absl::Span bindings) override; - - Status BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span dynamic_offsets) override; - - Status Dispatch(Executable* executable, int32_t entry_point, - std::array workgroups) override; - - Status DispatchIndirect(Executable* executable, int32_t entry_point, - 
-  Status DispatchIndirect(Executable* executable, int32_t entry_point,
-                          Buffer* workgroups_buffer,
-                          device_size_t workgroups_offset) override;
-
-  // Processes all commands in the buffer using the given |command_processor|.
-  // The commands are issued in the order they were recorded.
-  Status Process(CommandBuffer* command_processor) const;
-
- private:
-  // Type of Cmd, used by CmdHeader to identify the command payload.
-  enum class CmdType {
-    kExecutionBarrier,
-    kSignalEvent,
-    kResetEvent,
-    kWaitEvents,
-    kFillBuffer,
-    kDiscardBuffer,
-    kUpdateBuffer,
-    kCopyBuffer,
-    kPushConstants,
-    kPushDescriptorSet,
-    kBindDescriptorSet,
-    kDispatch,
-    kDispatchIndirect,
-  };
-
-  // Prefix for commands encoded into the CmdList.
-  // This is used to identify the type of a command as well as connect commands
-  // in the list sequence. Command data immediately follows the header in
-  // memory.
-  struct CmdHeader {
-    // Optional next command in the list.
-    CmdHeader* next;
-    // Type of the command.
-    CmdType type;
-  };
-
-  // A lightweight linked list of commands and an arena that stores them.
-  // CmdLists are designed to be reused so that the arena allocations are
-  // amortized across multiple uses.
-  //
-  // Note that this and the CmdHeader/Cmd types include raw pointers and as
-  // such are *not* portable across processes. It'd be possible, though, to
-  // extend this for cross-process use if a shared-memory Buffer was also
-  // implemented. For YAGNI we avoid that here.
-  struct CmdList : public IntrusiveLinkBase<void> {
-    static constexpr size_t kArenaBlockSize = 64 * 1024;
-
-    Arena arena{kArenaBlockSize};
-    CmdHeader* head = nullptr;
-    CmdHeader* tail = nullptr;
-  };
-
-  // Defines an execution barrier.
-  struct ExecutionBarrierCmd {
-    static constexpr CmdType kType = CmdType::kExecutionBarrier;
-    ExecutionStageBitfield source_stage_mask;
-    ExecutionStageBitfield target_stage_mask;
-    absl::Span<const MemoryBarrier> memory_barriers;
-    absl::Span<const BufferBarrier> buffer_barriers;
-  };
-
-  // Signals an event.
-  struct SignalEventCmd {
-    static constexpr CmdType kType = CmdType::kSignalEvent;
-    Event* event;
-    ExecutionStageBitfield source_stage_mask;
-  };
-
-  // Resets an event.
-  struct ResetEventCmd {
-    static constexpr CmdType kType = CmdType::kResetEvent;
-    Event* event;
-    ExecutionStageBitfield source_stage_mask;
-  };
-
-  // Waits for one or more events.
-  struct WaitEventsCmd {
-    static constexpr CmdType kType = CmdType::kWaitEvents;
-    absl::Span<Event*> events;
-    ExecutionStageBitfield source_stage_mask;
-    ExecutionStageBitfield target_stage_mask;
-    absl::Span<const MemoryBarrier> memory_barriers;
-    absl::Span<const BufferBarrier> buffer_barriers;
-  };
-
-  // Fills the target buffer with the given repeating value.
-  struct FillBufferCmd {
-    static constexpr CmdType kType = CmdType::kFillBuffer;
-    Buffer* target_buffer;
-    device_size_t target_offset;
-    device_size_t length;
-    uint8_t pattern[4];
-    size_t pattern_length;
-  };
-
-  // Hints to the device queue that the given buffer will not be used again.
-  struct DiscardBufferCmd {
-    static constexpr CmdType kType = CmdType::kDiscardBuffer;
-    Buffer* buffer;
-  };
-
-  // Writes a range of the given target buffer from the embedded memory.
-  // The source buffer contents immediately follow the command in the arena.
-  struct UpdateBufferCmd {
-    static constexpr CmdType kType = CmdType::kUpdateBuffer;
-    const void* source_buffer;
-    Buffer* target_buffer;
-    device_size_t target_offset;
-    device_size_t length;
-  };
-
-  // Copies a range of one buffer to another.
-  struct CopyBufferCmd {
-    static constexpr CmdType kType = CmdType::kCopyBuffer;
-    Buffer* source_buffer;
-    device_size_t source_offset;
-    Buffer* target_buffer;
-    device_size_t target_offset;
-    device_size_t length;
-  };
-
-  // Pushes inline constant values.
-  struct PushConstantsCmd {
-    static constexpr CmdType kType = CmdType::kPushConstants;
-    ExecutableLayout* executable_layout;
-    size_t offset;
-    absl::Span<const uint32_t> values;
-  };
-
-  // Pushes an inline descriptor set update.
-  struct PushDescriptorSetCmd {
-    static constexpr CmdType kType = CmdType::kPushDescriptorSet;
-    ExecutableLayout* executable_layout;
-    int32_t set;
-    absl::Span<const DescriptorSet::Binding> bindings;
-  };
-
-  // Binds a descriptor set.
-  struct BindDescriptorSetCmd {
-    static constexpr CmdType kType = CmdType::kBindDescriptorSet;
-    ExecutableLayout* executable_layout;
-    int32_t set;
-    DescriptorSet* descriptor_set;
-    absl::Span<const device_size_t> dynamic_offsets;
-  };
-
-  // Dispatches an execution request.
-  struct DispatchCmd {
-    static constexpr CmdType kType = CmdType::kDispatch;
-    Executable* executable;
-    int32_t entry_point;
-    std::array<uint32_t, 3> workgroups;
-  };
-
-  // Dispatches an execution request with indirect workgroup counts.
-  struct DispatchIndirectCmd {
-    static constexpr CmdType kType = CmdType::kDispatchIndirect;
-    Executable* executable;
-    int32_t entry_point;
-    Buffer* workgroups_buffer;
-    device_size_t workgroups_offset;
-  };
-
-  // Resets the command list.
-  void Reset();
-
-  // Allocates a command and appends it to the current command list.
-  // The caller must populate the fields in the returned pointer.
-  template <typename T>
-  StatusOr<T*> AppendCmd() {
-    return reinterpret_cast<T*>(AppendCmdHeader(T::kType, sizeof(T)) + 1);
-  }
-
-  // Appends a command with the given |type| and payload |cmd_size| prefixed
-  // with a CmdHeader. Returns a pointer to the CmdHeader that is followed
-  // immediately by |cmd_size| zero bytes.
-  CmdHeader* AppendCmdHeader(CmdType type, size_t cmd_size);
-
-  // Appends a byte buffer to the command buffer and returns a pointer to the
-  // copied data within the command buffer arena.
-  void* AppendCmdData(const void* source_buffer, device_size_t source_offset,
-                      device_size_t source_length);
-
-  // Appends a span of POD structs to the current CmdList and returns a span
-  // pointing into the CmdList arena.
-  template <typename T>
-  absl::Span<T> AppendStructSpan(absl::Span<T> value) {
-    static_assert(std::is_standard_layout<T>::value,
-                  "Struct must be a POD type");
-    void* data_ptr = AppendCmdData(value.data(), 0, value.size() * sizeof(T));
-    return absl::MakeSpan(static_cast<T*>(data_ptr), value.size());
-  }
-
-  // Processes a single command.
-  Status ProcessCmd(CmdHeader* cmd_header,
-                    CommandBuffer* command_processor) const;
-
-  bool is_recording_ = false;
-
-  // NOTE: not synchronized. Expected to be used from a single thread.
-  CmdList current_cmd_list_;
-};
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_
diff --git a/iree/hal/host/nop_event.h b/iree/hal/host/nop_event.h
deleted file mode 100644
index 59efe714e429e..0000000000000
--- a/iree/hal/host/nop_event.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_NOP_EVENT_H_
-#define IREE_HAL_HOST_NOP_EVENT_H_
-
-#include "iree/hal/event.h"
-
-namespace iree {
-namespace hal {
-namespace host {
-
-// A no-op event that can be used when a scheduling model does not perform
-// intra-command buffer out-of-order execution. Since events must always have
-// a signal recorded prior to recording a wait they are fine to ignore in
-// in-order command processors.
-class NopEvent final : public Event {
- public:
-  NopEvent();
-  ~NopEvent() override;
-};
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_NOP_EVENT_H_
diff --git a/iree/hal/host/scheduling_model.h b/iree/hal/host/scheduling_model.h
deleted file mode 100644
index 38771ecc6a901..0000000000000
--- a/iree/hal/host/scheduling_model.h
+++ /dev/null
@@ -1,100 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_SCHEDULING_MODEL_H_
-#define IREE_HAL_HOST_SCHEDULING_MODEL_H_
-
-#include "iree/hal/command_queue.h"
-
-namespace iree {
-namespace hal {
-namespace host {
-
-// Host-local scheduling interface that device implementations can use to
-// choose between various scheduling strategies (such as serial/in-order,
-// fiber/out-of-order, etc). The interface models a subset of the Device
-// interface relating to the scheduling primitives (such as semaphores) and the
-// device-level operations that can be performed on them (such as wait-all).
-class SchedulingModel {
- public:
-  virtual ~SchedulingModel() = default;
-
-  // Returns a list of all general-purpose dispatch queues provided by the
-  // device. In general these map 1:1 with independent execution contexts,
-  // though some devices may hide that and expose only a single queue that is
-  // scheduled internally.
-  virtual absl::Span<CommandQueue*> dispatch_queues() const = 0;
-
-  // Returns a list of transfer queues provided by the device. These queues
-  // may perform transfer operations asynchronously with respect to execution
-  // on the dispatch queues. For large sequences of transfer operations always
-  // prefer using one of these queues.
-  // Note that if the device does not support a dedicated transfer queue this
-  // list may be the same as (or a subset of) dispatch_queues.
-  virtual absl::Span<CommandQueue*> transfer_queues() const = 0;
-
-  // Creates a command buffer for recording commands to submit to queues owned
-  // by this device. The command buffer may come from a pool but will be reset
-  // prior to being returned to the caller.
-  virtual StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
-      CommandBufferModeBitfield mode,
-      CommandCategoryBitfield command_categories) = 0;
-
-  // Creates an event for recording into command buffers.
-  virtual StatusOr<ref_ptr<Event>> CreateEvent() = 0;
-
-  // Creates a semaphore that can be used with command queues owned by this
-  // device. To use the semaphores with other devices or instances they must
-  // first be exported.
-  virtual StatusOr<ref_ptr<Semaphore>> CreateSemaphore(
-      uint64_t initial_value) = 0;
-
-  // Blocks the caller until all passed |semaphores| reach or exceed the
-  // specified payload values or the |deadline| elapses. All |semaphores| must
-  // be created from this device (or be imported into it).
-  //
-  // Returns success if the wait is successful and all semaphores have been
-  // signaled.
-  //
-  // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all semaphores
-  // having been signaled. Note that a subset of the |semaphores| may have been
-  // signaled and each can be queried to see which ones.
-  virtual Status WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores,
-                                   Time deadline_ns) = 0;
-
-  // Blocks the caller until at least one of the |semaphores| reaches or
-  // exceeds the specified payload value or the |deadline| elapses. All
-  // |semaphores| must be created from this device (or be imported into it).
-  //
-  // Returns an arbitrary index into |semaphores| of a semaphore that was
-  // signaled. Note that more than one semaphore may have been signaled and all
-  // of the other |semaphores| should be queried or waited on again until waits
-  // for them succeed.
-  //
-  // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any semaphores
-  // having been signaled.
-  virtual StatusOr<int> WaitAnySemaphore(
-      absl::Span<const SemaphoreValue> semaphores, Time deadline_ns) = 0;
-
-  // Blocks until all outstanding requests on all queues have been
-  // completed. This is equivalent to having waited on all outstanding
-  // semaphores.
-  virtual Status WaitIdle(Time deadline_ns) = 0;
-};
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_SCHEDULING_MODEL_H_
diff --git a/iree/hal/host/serial/BUILD b/iree/hal/host/serial/BUILD
deleted file mode 100644
index 9228d98d32529..0000000000000
--- a/iree/hal/host/serial/BUILD
+++ /dev/null
@@ -1,103 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Default implementations for HAL types that use the host resources.
-# These are generally just wrappers around host heap memory and host threads.
- -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "async_command_queue", - srcs = ["async_command_queue.cc"], - hdrs = ["async_command_queue.h"], - deps = [ - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - "//iree/hal/host/serial:serial_submission_queue", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - ], -) - -cc_test( - name = "async_command_queue_test", - srcs = ["async_command_queue_test.cc"], - deps = [ - ":async_command_queue", - "//iree/base:status", - "//iree/base:time", - "//iree/hal", - "//iree/hal/host/serial:serial_submission_queue", - "//iree/hal/testing:mock_command_buffer", - "//iree/hal/testing:mock_command_queue", - "//iree/testing:gtest", - "//iree/testing:gtest_main", - "@com_google_absl//absl/memory", - ], -) - -cc_library( - name = "serial_command_processor", - srcs = ["serial_command_processor.cc"], - hdrs = ["serial_command_processor.h"], - deps = [ - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - "//iree/hal/host:host_descriptor_set", - "//iree/hal/host:host_executable", - "//iree/hal/host:host_executable_layout", - "@com_google_absl//absl/container:inlined_vector", - ], -) - -cc_library( - name = "serial_scheduling_model", - srcs = ["serial_scheduling_model.cc"], - hdrs = ["serial_scheduling_model.h"], - deps = [ - ":async_command_queue", - ":serial_command_processor", - ":serial_submission_queue", - "//iree/base:core_headers", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal/host:condvar_semaphore", - "//iree/hal/host:inproc_command_buffer", - "//iree/hal/host:nop_event", - "//iree/hal/host:scheduling_model", - "@com_google_absl//absl/container:inlined_vector", - ], -) - -cc_library( - name = "serial_submission_queue", - srcs = ["serial_submission_queue.cc"], - hdrs = ["serial_submission_queue.h"], - deps = [ - "//iree/base:intrusive_list", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - "//iree/hal/host:condvar_semaphore", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/synchronization", - ], -) diff --git a/iree/hal/host/serial/CMakeLists.txt b/iree/hal/host/serial/CMakeLists.txt deleted file mode 100644 index 86b7da5409f59..0000000000000 --- a/iree/hal/host/serial/CMakeLists.txt +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -iree_add_all_subdirs() - -iree_cc_library( - NAME - async_command_queue - HDRS - "async_command_queue.h" - SRCS - "async_command_queue.cc" - DEPS - absl::core_headers - absl::synchronization - iree::base::status - iree::base::tracing - iree::hal - iree::hal::host::serial::serial_submission_queue - PUBLIC -) - -iree_cc_test( - NAME - async_command_queue_test - SRCS - "async_command_queue_test.cc" - DEPS - ::async_command_queue - absl::memory - iree::base::status - iree::base::time - iree::hal - iree::hal::host::serial::serial_submission_queue - iree::hal::testing::mock_command_buffer - iree::hal::testing::mock_command_queue - iree::testing::gtest - iree::testing::gtest_main -) - -iree_cc_library( - NAME - serial_command_processor - HDRS - "serial_command_processor.h" - SRCS - "serial_command_processor.cc" - DEPS - absl::inlined_vector - iree::base::status - iree::base::tracing - iree::hal - iree::hal::host::host_descriptor_set - iree::hal::host::host_executable - iree::hal::host::host_executable_layout - PUBLIC -) - -iree_cc_library( - NAME - serial_scheduling_model - HDRS - "serial_scheduling_model.h" - SRCS - "serial_scheduling_model.cc" - DEPS - ::async_command_queue - ::serial_command_processor - ::serial_submission_queue - absl::inlined_vector - iree::base::core_headers - iree::base::status - iree::base::tracing - iree::hal::host::condvar_semaphore - iree::hal::host::inproc_command_buffer - iree::hal::host::nop_event - iree::hal::host::scheduling_model - PUBLIC -) - -iree_cc_library( - NAME - serial_submission_queue - HDRS - "serial_submission_queue.h" - SRCS - "serial_submission_queue.cc" - DEPS - absl::core_headers - absl::inlined_vector - absl::synchronization - iree::base::intrusive_list - iree::base::status - iree::base::tracing - iree::hal - iree::hal::host::condvar_semaphore - PUBLIC -) diff --git a/iree/hal/host/serial/async_command_queue.cc b/iree/hal/host/serial/async_command_queue.cc deleted file mode 100644 index 5813ba055f46c..0000000000000 --- a/iree/hal/host/serial/async_command_queue.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/serial/async_command_queue.h" - -#include "absl/base/thread_annotations.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { -namespace host { - -AsyncCommandQueue::AsyncCommandQueue(std::unique_ptr target_queue) - : CommandQueue(target_queue->name(), target_queue->supported_categories()), - target_queue_(std::move(target_queue)) { - IREE_TRACE_SCOPE0("AsyncCommandQueue::ctor"); - thread_ = std::thread([this]() { ThreadMain(); }); -} - -AsyncCommandQueue::~AsyncCommandQueue() { - IREE_TRACE_SCOPE0("AsyncCommandQueue::dtor"); - { - // Signal to thread that we want to stop. Note that the thread may have - // already been stopped and that's ok (as we'll Join right away). - // The thread will finish processing any queued submissions. 
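-    // thread_.join() below will not return until ThreadMain() has observed
-    // the shutdown signal and drained the remaining submissions.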
- absl::MutexLock lock(&submission_mutex_); - submission_queue_.SignalShutdown(); - } - thread_.join(); - - // Ensure we shut down OK. - { - absl::MutexLock lock(&submission_mutex_); - IREE_CHECK(submission_queue_.empty()) - << "Dirty shutdown of async queue (unexpected thread exit?)"; - } -} - -void AsyncCommandQueue::ThreadMain() { - IREE_TRACE_SET_THREAD_NAME(target_queue_->name().c_str()); - - bool is_exiting = false; - while (!is_exiting) { - // Block until we are either requested to exit or there are pending - // submissions. - submission_mutex_.Lock(); - submission_mutex_.Await(absl::Condition( - +[](SerialSubmissionQueue* queue) { - return queue->has_shutdown() || !queue->empty(); - }, - &submission_queue_)); - if (!submission_queue_.empty()) { - // Run all ready submissions (this may be called many times). - submission_mutex_.AssertHeld(); - submission_queue_ - .ProcessBatches( - [this](absl::Span command_buffers) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(submission_mutex_) { - // Release the lock while we perform the processing so that - // other threads can submit more work. - submission_mutex_.AssertHeld(); - submission_mutex_.Unlock(); - - // Relay the command buffers to the target queue. - // Since we are taking care of all synchronization they - // don't need any waiters or semaphores. - auto status = - target_queue_->Submit({{}, command_buffers, {}}); - - // Take back the lock so we can manipulate the queue safely. - submission_mutex_.Lock(); - submission_mutex_.AssertHeld(); - - return status; - }) - .IgnoreError(); - submission_mutex_.AssertHeld(); - } - if (submission_queue_.has_shutdown()) { - // Exit when there are no more submissions to process and an exit was - // requested (or we errored out). - is_exiting = true; - } - submission_mutex_.Unlock(); - } -} - -Status AsyncCommandQueue::Submit(absl::Span batches) { - IREE_TRACE_SCOPE0("AsyncCommandQueue::Submit"); - absl::MutexLock lock(&submission_mutex_); - return submission_queue_.Enqueue(batches); -} - -Status AsyncCommandQueue::WaitIdle(Time deadline_ns) { - IREE_TRACE_SCOPE0("AsyncCommandQueue::WaitIdle"); - - // Wait until the deadline, the thread exits, or there are no more pending - // submissions. - absl::MutexLock lock(&submission_mutex_); - if (!submission_mutex_.AwaitWithDeadline( - absl::Condition( - +[](SerialSubmissionQueue* queue) { - return queue->empty() || !queue->permanent_error().ok(); - }, - &submission_queue_), - absl::FromUnixNanos(static_cast(deadline_ns)))) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for submission thread to go idle"; - } - return submission_queue_.permanent_error(); -} - -} // namespace host -} // namespace hal -} // namespace iree diff --git a/iree/hal/host/serial/async_command_queue.h b/iree/hal/host/serial/async_command_queue.h deleted file mode 100644 index e474c23091e0a..0000000000000 --- a/iree/hal/host/serial/async_command_queue.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_SERIAL_ASYNC_COMMAND_QUEUE_H_
-#define IREE_HAL_HOST_SERIAL_ASYNC_COMMAND_QUEUE_H_
-
-#include <memory>
-#include <thread>  // NOLINT
-
-#include "absl/base/thread_annotations.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/host/serial/serial_submission_queue.h"
-#include "iree/hal/semaphore.h"
-
-namespace iree {
-namespace hal {
-namespace host {
-
-// Asynchronous command queue wrapper.
-// This creates a single thread to perform all CommandQueue operations. Any
-// submitted CommandBuffer is dispatched in FIFO order on the queue thread
-// against the provided |target_queue|.
-//
-// Target queues will receive submissions containing only command buffers as
-// all semaphore synchronization is handled by the wrapper. Semaphores will
-// also be omitted and code should safely handle nullptr.
-//
-// AsyncCommandQueue (as with CommandQueue) is thread-safe. Multiple threads
-// may submit command buffers concurrently, though the order of execution in
-// such a case depends entirely on the synchronization primitives provided.
-class AsyncCommandQueue final : public CommandQueue {
- public:
-  explicit AsyncCommandQueue(std::unique_ptr<CommandQueue> target_queue);
-  ~AsyncCommandQueue() override;
-
-  Status Submit(absl::Span<const SubmissionBatch> batches) override;
-
-  Status WaitIdle(Time deadline_ns) override;
-
- private:
-  // Thread entry point for the async worker thread.
-  // Waits for submissions to be queued up and processes them eagerly.
-  void ThreadMain();
-
-  // CommandQueue that the async queue relays submissions into.
-  std::unique_ptr<CommandQueue> target_queue_;
-
-  // Thread that runs the ThreadMain() function and processes submissions.
-  std::thread thread_;
-
-  // Queue that manages submission ordering.
-  mutable absl::Mutex submission_mutex_;
-  SerialSubmissionQueue submission_queue_ ABSL_GUARDED_BY(submission_mutex_);
-};
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_HOST_SERIAL_ASYNC_COMMAND_QUEUE_H_
diff --git a/iree/hal/host/serial/async_command_queue_test.cc b/iree/hal/host/serial/async_command_queue_test.cc
deleted file mode 100644
index 8583ba2b8138e..0000000000000
--- a/iree/hal/host/serial/async_command_queue_test.cc
+++ /dev/null
@@ -1,231 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
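-
-// Usage sketch for AsyncCommandQueue (mirrors the BlockingSubmit test below;
-// |target| stands in for any concrete CommandQueue and |cmd_buffer| for a
-// recorded command buffer):
-//
-//   auto queue = absl::make_unique<AsyncCommandQueue>(std::move(target));
-//   CondVarSemaphore semaphore(0ull);
-//   IREE_RETURN_IF_ERROR(
-//       queue->Submit({{}, {cmd_buffer.get()}, {{&semaphore, 1ull}}}));
-//   IREE_RETURN_IF_ERROR(semaphore.Wait(1ull, InfiniteFuture()));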
- -#include "iree/hal/host/serial/async_command_queue.h" - -#include -#include -#include -#include -#include - -#include "absl/memory/memory.h" -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/host/serial/serial_submission_queue.h" -#include "iree/hal/testing/mock_command_buffer.h" -#include "iree/hal/testing/mock_command_queue.h" -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" - -namespace iree { -namespace hal { -namespace host { -namespace { - -using ::testing::_; - -using testing::MockCommandBuffer; -using testing::MockCommandQueue; - -// Suspends execution of the calling thread for the given |duration_ms|. -// Depending on platform this may have an extremely coarse resolution (upwards -// of several to dozens of milliseconds). -inline void Sleep(std::chrono::milliseconds duration_ms) { - std::this_thread::sleep_for(duration_ms); -} - -struct AsyncCommandQueueTest : public ::testing::Test { - MockCommandQueue* mock_target_queue; - std::unique_ptr command_queue; - - void SetUp() override { - auto mock_queue = absl::make_unique( - "mock", CommandCategory::kTransfer | CommandCategory::kDispatch); - mock_target_queue = mock_queue.get(); - command_queue = absl::make_unique(std::move(mock_queue)); - } - - void TearDown() override { - command_queue.reset(); - mock_target_queue = nullptr; - } -}; - -// Tests that submitting a command buffer and immediately waiting will not -// deadlock. -TEST_F(AsyncCommandQueueTest, BlockingSubmit) { - ::testing::InSequence sequence; - - auto cmd_buffer = make_ref(CommandBufferMode::kOneShot, - CommandCategory::kTransfer); - - EXPECT_CALL(*mock_target_queue, Submit(_)) - .WillOnce([&](absl::Span batches) { - IREE_CHECK_EQ(1, batches.size()); - IREE_CHECK_EQ(1, batches[0].command_buffers.size()); - IREE_CHECK_EQ(cmd_buffer.get(), batches[0].command_buffers[0]); - return OkStatus(); - }); - CondVarSemaphore semaphore(0ull); - IREE_ASSERT_OK( - command_queue->Submit({{}, {cmd_buffer.get()}, {{&semaphore, 1ull}}})); - IREE_ASSERT_OK(semaphore.Wait(1ull, InfiniteFuture())); -} - -// Tests that failure is propagated along the fence from the target queue. -TEST_F(AsyncCommandQueueTest, PropagateSubmitFailure) { - ::testing::InSequence sequence; - - auto cmd_buffer = make_ref(CommandBufferMode::kOneShot, - CommandCategory::kTransfer); - - EXPECT_CALL(*mock_target_queue, Submit(_)) - .WillOnce([](absl::Span batches) { - return DataLossErrorBuilder(IREE_LOC); - }); - CondVarSemaphore semaphore(0ull); - IREE_ASSERT_OK( - command_queue->Submit({{}, {cmd_buffer.get()}, {{&semaphore, 1ull}}})); - EXPECT_TRUE(IsDataLoss(semaphore.Wait(1ull, InfiniteFuture()))); -} - -// Tests that waiting for idle is a no-op when nothing is queued. -TEST_F(AsyncCommandQueueTest, WaitIdleWhileIdle) { - IREE_ASSERT_OK(command_queue->WaitIdle()); -} - -// Tests that waiting for idle will block when work is pending/in-flight. -TEST_F(AsyncCommandQueueTest, WaitIdleWithPending) { - ::testing::InSequence sequence; - - auto cmd_buffer = make_ref(CommandBufferMode::kOneShot, - CommandCategory::kTransfer); - - EXPECT_CALL(*mock_target_queue, Submit(_)) - .WillOnce([](absl::Span batches) { - Sleep(std::chrono::milliseconds(100)); - return OkStatus(); - }); - CondVarSemaphore semaphore(0ull); - IREE_ASSERT_OK( - command_queue->Submit({{}, {cmd_buffer.get()}, {{&semaphore, 1ull}}})); - - // This should block for a sec or two. - IREE_ASSERT_OK(command_queue->WaitIdle()); - - // Should have already expired. 
- IREE_ASSERT_OK_AND_ASSIGN(uint64_t value, semaphore.Query()); - ASSERT_EQ(1ull, value); -} - -// Tests that waiting for idle with multiple pending submissions will wait until -// all of them complete while still allowing incremental progress. -TEST_F(AsyncCommandQueueTest, WaitIdleAndProgress) { - ::testing::InSequence sequence; - - EXPECT_CALL(*mock_target_queue, Submit(_)) - .WillRepeatedly([](absl::Span batches) { - Sleep(std::chrono::milliseconds(100)); - return OkStatus(); - }); - - auto cmd_buffer_0 = make_ref(CommandBufferMode::kOneShot, - CommandCategory::kTransfer); - auto cmd_buffer_1 = make_ref(CommandBufferMode::kOneShot, - CommandCategory::kTransfer); - - CondVarSemaphore semaphore_0(0u); - IREE_ASSERT_OK(command_queue->Submit( - {{}, {cmd_buffer_0.get()}, {{&semaphore_0, 1ull}}})); - CondVarSemaphore semaphore_1(0u); - IREE_ASSERT_OK( - command_queue->Submit({{}, {cmd_buffer_1.get()}, {{&semaphore_1, 1u}}})); - - // This should block for a sec or two. - IREE_ASSERT_OK(command_queue->WaitIdle()); - - // Both should have already expired. - IREE_ASSERT_OK_AND_ASSIGN(uint64_t value_0, semaphore_0.Query()); - ASSERT_EQ(1ull, value_0); - IREE_ASSERT_OK_AND_ASSIGN(uint64_t value_1, semaphore_1.Query()); - ASSERT_EQ(1ull, value_1); -} - -// Tests that failures are sticky. -TEST_F(AsyncCommandQueueTest, StickyFailures) { - ::testing::InSequence sequence; - - // Fail. - EXPECT_CALL(*mock_target_queue, Submit(_)) - .WillOnce([](absl::Span batches) { - Sleep(std::chrono::milliseconds(100)); - return DataLossErrorBuilder(IREE_LOC); - }); - auto cmd_buffer_0 = make_ref(CommandBufferMode::kOneShot, - CommandCategory::kTransfer); - CondVarSemaphore semaphore_0(0ull); - IREE_ASSERT_OK( - command_queue->Submit({{}, {cmd_buffer_0.get()}, {{&semaphore_0, 1u}}})); - EXPECT_TRUE(IsDataLoss(semaphore_0.Wait(1ull, InfiniteFuture()))); - - // Future flushes/waits/etc should also fail. - EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle())); - - // Future submits should fail asynchronously. - auto cmd_buffer_1 = make_ref(CommandBufferMode::kOneShot, - CommandCategory::kTransfer); - CondVarSemaphore semaphore_1(0ull); - EXPECT_TRUE(IsDataLoss(command_queue->Submit( - {{}, {cmd_buffer_1.get()}, {{&semaphore_1, 1ull}}}))); -} - -// Tests that a failure with two submissions pending causes the second to -// bail as well. -TEST_F(AsyncCommandQueueTest, FailuresCascadeAcrossSubmits) { - ::testing::InSequence sequence; - - // Fail. - EXPECT_CALL(*mock_target_queue, Submit(_)) - .WillOnce([](absl::Span batches) { - Sleep(std::chrono::milliseconds(100)); - return DataLossErrorBuilder(IREE_LOC); - }); - - auto cmd_buffer_0 = make_ref(CommandBufferMode::kOneShot, - CommandCategory::kTransfer); - auto cmd_buffer_1 = make_ref(CommandBufferMode::kOneShot, - CommandCategory::kTransfer); - - CondVarSemaphore semaphore_0(0ull); - IREE_ASSERT_OK(command_queue->Submit( - {{}, {cmd_buffer_0.get()}, {{&semaphore_0, 1ull}}})); - CondVarSemaphore semaphore_1(0ull); - IREE_ASSERT_OK(command_queue->Submit( - {{{&semaphore_0, 1ull}}, {cmd_buffer_1.get()}, {{&semaphore_1, 1ull}}})); - - EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle())); - - EXPECT_TRUE(IsDataLoss(semaphore_0.Wait(1ull, InfiniteFuture()))); - EXPECT_TRUE(IsDataLoss(semaphore_1.Wait(1ull, InfiniteFuture()))); - - // Future flushes/waits/etc should also fail. 
- EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle())); -} - -} // namespace -} // namespace host -} // namespace hal -} // namespace iree diff --git a/iree/hal/host/serial/serial_command_processor.cc b/iree/hal/host/serial/serial_command_processor.cc deleted file mode 100644 index c5d1cc8cf21f4..0000000000000 --- a/iree/hal/host/serial/serial_command_processor.cc +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/serial/serial_command_processor.h" - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/host/host_descriptor_set.h" -#include "iree/hal/host/host_executable_layout.h" - -namespace iree { -namespace hal { -namespace host { - -SerialCommandProcessor::SerialCommandProcessor( - CommandCategoryBitfield command_categories) - : CommandBuffer(CommandBufferMode::kOneShot, command_categories) {} - -SerialCommandProcessor::~SerialCommandProcessor() = default; - -Status SerialCommandProcessor::Begin() { - IREE_TRACE_SCOPE0("SerialCommandProcessor::Begin"); - is_recording_ = true; - return OkStatus(); -} - -Status SerialCommandProcessor::End() { - IREE_TRACE_SCOPE0("SerialCommandProcessor::End"); - is_recording_ = false; - return OkStatus(); -} - -Status SerialCommandProcessor::ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) { - IREE_TRACE_SCOPE0("SerialCommandProcessor::ExecutionBarrier"); - // No-op. - return OkStatus(); -} - -Status SerialCommandProcessor::SignalEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("SerialCommandProcessor::SignalEvent"); - // No-op. - return OkStatus(); -} - -Status SerialCommandProcessor::ResetEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("SerialCommandProcessor::ResetEvent"); - // No-op. - return OkStatus(); -} - -Status SerialCommandProcessor::WaitEvents( - absl::Span events, ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) { - IREE_TRACE_SCOPE0("SerialCommandProcessor::WaitEvents"); - // No-op. - return OkStatus(); -} - -Status SerialCommandProcessor::FillBuffer(Buffer* target_buffer, - device_size_t target_offset, - device_size_t length, - const void* pattern, - size_t pattern_length) { - IREE_TRACE_SCOPE0("SerialCommandProcessor::FillBuffer"); - return target_buffer->Fill(target_offset, length, pattern, pattern_length); -} - -Status SerialCommandProcessor::DiscardBuffer(Buffer* buffer) { - IREE_TRACE_SCOPE0("SerialCommandProcessor::DiscardBuffer"); - // No-op as we don't support lazily allocated buffers. 
-  return OkStatus();
-}
-
-Status SerialCommandProcessor::UpdateBuffer(const void* source_buffer,
-                                            device_size_t source_offset,
-                                            Buffer* target_buffer,
-                                            device_size_t target_offset,
-                                            device_size_t length) {
-  IREE_TRACE_SCOPE0("SerialCommandProcessor::UpdateBuffer");
-  return target_buffer->WriteData(
-      target_offset,
-      static_cast<const uint8_t*>(source_buffer) + source_offset, length);
-}
-
-Status SerialCommandProcessor::CopyBuffer(Buffer* source_buffer,
-                                          device_size_t source_offset,
-                                          Buffer* target_buffer,
-                                          device_size_t target_offset,
-                                          device_size_t length) {
-  IREE_TRACE_SCOPE0("SerialCommandProcessor::CopyBuffer");
-  return target_buffer->CopyData(target_offset, source_buffer, source_offset,
-                                 length);
-}
-
-Status SerialCommandProcessor::PushConstants(
-    ExecutableLayout* executable_layout, size_t offset,
-    absl::Span<const uint32_t> values) {
-  IREE_TRACE_SCOPE0("SerialCommandProcessor::PushConstants");
-  if (offset + values.size() > push_constants_.values.size()) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "Push constants out of range";
-  }
-  for (int i = 0; i < values.size(); ++i) {
-    push_constants_.values[offset + i] = values[i];
-  }
-  return OkStatus();
-}
-
-Status SerialCommandProcessor::PushDescriptorSet(
-    ExecutableLayout* executable_layout, int32_t set,
-    absl::Span<const DescriptorSet::Binding> bindings) {
-  IREE_TRACE_SCOPE0("SerialCommandProcessor::PushDescriptorSet");
-  if (!AnyBitSet(command_categories() & CommandCategory::kDispatch)) {
-    return FailedPreconditionErrorBuilder(IREE_LOC)
-           << "Command processor does not support dispatch operations";
-  }
-
-  auto* host_executable_layout =
-      static_cast<HostExecutableLayout*>(executable_layout);
-  descriptor_sets_.resize(host_executable_layout->set_count());
-  if (set < 0 || set >= descriptor_sets_.size()) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "Set " << set << " out of range (" << descriptor_sets_.size()
-           << ")";
-  }
-
-  auto& set_bindings = descriptor_sets_[set];
-  set_bindings.resize(bindings.size());
-  for (size_t i = 0; i < bindings.size(); ++i) {
-    set_bindings[i] = bindings[i];
-  }
-
-  return OkStatus();
-}
-
-Status SerialCommandProcessor::BindDescriptorSet(
-    ExecutableLayout* executable_layout, int32_t set,
-    DescriptorSet* descriptor_set,
-    absl::Span<const device_size_t> dynamic_offsets) {
-  IREE_TRACE_SCOPE0("SerialCommandProcessor::BindDescriptorSet");
-  if (!AnyBitSet(command_categories() & CommandCategory::kDispatch)) {
-    return FailedPreconditionErrorBuilder(IREE_LOC)
-           << "Command processor does not support dispatch operations";
-  }
-
-  auto* host_executable_layout =
-      static_cast<HostExecutableLayout*>(executable_layout);
-  descriptor_sets_.resize(host_executable_layout->set_count());
-  if (set < 0 || set >= descriptor_sets_.size()) {
-    return InvalidArgumentErrorBuilder(IREE_LOC)
-           << "Set " << set << " out of range (" << descriptor_sets_.size()
-           << ")";
-  }
-
-  auto* host_descriptor_set = static_cast<HostDescriptorSet*>(descriptor_set);
-  auto* set_bindings = &descriptor_sets_[set];
-  *set_bindings = {host_descriptor_set->bindings().begin(),
-                   host_descriptor_set->bindings().end()};
-  if (!dynamic_offsets.empty()) {
-    auto dynamic_binding_map =
-        host_executable_layout->GetDynamicBindingMap(set);
-    if (dynamic_offsets.size() != dynamic_binding_map.size()) {
-      return InvalidArgumentErrorBuilder(IREE_LOC)
-             << "Dynamic offset count mismatch (provided "
-             << dynamic_offsets.size() << " but expected "
-             << dynamic_binding_map.size() << ")";
-    }
-    for (int i = 0; i < dynamic_binding_map.size(); ++i) {
-      (*set_bindings)[dynamic_binding_map[i]].offset += dynamic_offsets[i];
-    }
-  }
-
-  return OkStatus();
-}
-
-Status
SerialCommandProcessor::Dispatch(Executable* executable, int32_t entry_point,
-                                 std::array<uint32_t, 3> workgroups) {
-  IREE_TRACE_SCOPE0("SerialCommandProcessor::Dispatch");
-  return DispatchGrid(executable, entry_point, workgroups);
-}
-
-Status SerialCommandProcessor::DispatchIndirect(
-    Executable* executable, int32_t entry_point, Buffer* workgroups_buffer,
-    device_size_t workgroups_offset) {
-  IREE_TRACE_SCOPE0("SerialCommandProcessor::DispatchIndirect");
-
-  std::array<uint32_t, 3> workgroup_count;
-  IREE_RETURN_IF_ERROR(workgroups_buffer->ReadData(
-      workgroups_offset, workgroup_count.data(), sizeof(uint32_t) * 3));
-
-  return DispatchGrid(executable, entry_point, workgroup_count);
-}
-
-Status SerialCommandProcessor::DispatchGrid(
-    Executable* executable, int32_t entry_point,
-    std::array<uint32_t, 3> workgroup_count) {
-  HostExecutable::DispatchParams params;
-  params.entry_point = entry_point;
-  params.workgroup_count = workgroup_count;
-  params.push_constants = &push_constants_;
-
-  absl::InlinedVector<absl::Span<const DescriptorSet::Binding>, 2>
-      descriptor_sets(descriptor_sets_.size());
-  for (int i = 0; i < descriptor_sets_.size(); ++i) {
-    descriptor_sets[i] = absl::MakeConstSpan(descriptor_sets_[i]);
-  }
-  params.set_bindings = descriptor_sets;
-
-  auto* host_executable = reinterpret_cast<HostExecutable*>(executable);
-  IREE_ASSIGN_OR_RETURN(auto dispatch_state,
-                        host_executable->PrepareDispatch(params));
-  for (uint32_t z = 0; z < params.workgroup_count[2]; ++z) {
-    for (uint32_t y = 0; y < params.workgroup_count[1]; ++y) {
-      for (uint32_t x = 0; x < params.workgroup_count[0]; ++x) {
-        IREE_RETURN_IF_ERROR(
-            host_executable->DispatchTile(dispatch_state.get(), {x, y, z}));
-      }
-    }
-  }
-  return OkStatus();
-}
-
-}  // namespace host
-}  // namespace hal
-}  // namespace iree
diff --git a/iree/hal/host/serial/serial_command_processor.h b/iree/hal/host/serial/serial_command_processor.h
deleted file mode 100644
index 1577fd5d89a75..0000000000000
--- a/iree/hal/host/serial/serial_command_processor.h
+++ /dev/null
@@ -1,110 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_HOST_SERIAL_SERIAL_COMMAND_PROCESSOR_H_
-#define IREE_HAL_HOST_SERIAL_SERIAL_COMMAND_PROCESSOR_H_
-
-#include "absl/container/inlined_vector.h"
-#include "iree/hal/command_buffer.h"
-#include "iree/hal/host/host_executable.h"
-
-namespace iree {
-namespace hal {
-namespace host {
-
-// Host-local command processor for dispatching transfer operations against
-// buffers allocated from the HostLocalAllocator.
-// This assumes that all buffers are host-visible (if not local) and that all
-// buffers can be mapped for access.
-//
-// Uses HostExecutable to perform tiled dispatch processing.
-//
-// Thread-compatible (as with CommandBuffer itself).
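-//
-// Replay sketch (illustrative only; assumes an already-recorded
-// InProcCommandBuffer |inproc| as used by UnsynchronizedCommandQueue below):
-//
-//   SerialCommandProcessor processor(CommandCategory::kTransfer |
-//                                    CommandCategory::kDispatch);
-//   IREE_RETURN_IF_ERROR(inproc->Process(&processor));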
-class SerialCommandProcessor final : public CommandBuffer { - public: - explicit SerialCommandProcessor(CommandCategoryBitfield command_categories); - ~SerialCommandProcessor() override; - - bool is_recording() const override { return is_recording_; } - - Status Begin() override; - Status End() override; - - Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) override; - - Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - - Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - - Status WaitEvents(absl::Span events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) override; - - Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) override; - - Status DiscardBuffer(Buffer* buffer) override; - - Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status PushConstants(ExecutableLayout* executable_layout, size_t offset, - absl::Span values) override; - - Status PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - absl::Span bindings) override; - - Status BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span dynamic_offsets) override; - - Status Dispatch(Executable* executable, int32_t entry_point, - std::array workgroups) override; - - Status DispatchIndirect(Executable* executable, int32_t entry_point, - Buffer* workgroups_buffer, - device_size_t workgroups_offset) override; - - private: - Status DispatchGrid(Executable* executable, int32_t entry_point, - std::array workgroup_count); - - bool is_recording_ = false; - - PushConstantBlock push_constants_; - absl::InlinedVector, 2> - descriptor_sets_; -}; - -} // namespace host -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_SERIAL_SERIAL_COMMAND_PROCESSOR_H_ diff --git a/iree/hal/host/serial/serial_scheduling_model.cc b/iree/hal/host/serial/serial_scheduling_model.cc deleted file mode 100644 index d9fdedfda3b14..0000000000000 --- a/iree/hal/host/serial/serial_scheduling_model.cc +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
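-
-// Layering note: this file composes the serial pieces end-to-end. An
-// UnsynchronizedCommandQueue (below) replays recorded command buffers inline,
-// AsyncCommandQueue moves that work onto a dedicated thread, and
-// SerialSchedulingModel exposes the result through the SchedulingModel
-// interface.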
- -#include "iree/hal/host/serial/serial_scheduling_model.h" - -#include "iree/base/tracing.h" -#include "iree/hal/host/condvar_semaphore.h" -#include "iree/hal/host/inproc_command_buffer.h" -#include "iree/hal/host/nop_event.h" -#include "iree/hal/host/serial/async_command_queue.h" -#include "iree/hal/host/serial/serial_command_processor.h" -#include "iree/hal/host/serial/serial_submission_queue.h" - -namespace iree { -namespace hal { -namespace host { -namespace { - -// A CommandQueue that performs no synchronization (semaphores/fences) and just -// directly executes command buffers inline. -// -// This is meant to be wrapped by SyncCommandQueue or AsyncCommandQueue that -// themselves perform the synchronization/threading/etc. As such we ignore -// all semaphores in the provided batches under the assumption that if Submit is -// being called then all dependencies are valid. The wrapping queue is also -// responsible for signaling the fence as well as propagating errors in a way -// that is dependent on how it is performing its synchronization. -class UnsynchronizedCommandQueue final : public CommandQueue { - public: - UnsynchronizedCommandQueue(std::string name, - CommandCategoryBitfield supported_categories) - : CommandQueue(std::move(name), supported_categories) {} - ~UnsynchronizedCommandQueue() override = default; - - Status Submit(absl::Span batches) override { - IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::Submit"); - - // Process command buffers and propagate errors asynchronously through the - // fence. This ensures that even if we are running synchronously we still - // get consistent failure behavior with drivers that are purely async. - for (auto& batch : batches) { - IREE_DCHECK(batch.wait_semaphores.empty() && - batch.signal_semaphores.empty()) - << "Semaphores must be handled by the wrapping queue"; - IREE_RETURN_IF_ERROR(ProcessCommandBuffers(batch.command_buffers)); - } - - return OkStatus(); - } - - Status WaitIdle(Time deadline_ns) override { - // No-op. - return OkStatus(); - } - - private: - // Processes each command buffer in-turn with a fresh processor. - // This ensures we don't have any state that can carry across buffers. - Status ProcessCommandBuffers( - absl::Span command_buffers) { - IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::ProcessCommandBuffers"); - for (auto* command_buffer : command_buffers) { - auto* inproc_command_buffer = - static_cast(command_buffer->impl()); - SerialCommandProcessor command_processor(supported_categories()); - IREE_RETURN_IF_ERROR(inproc_command_buffer->Process(&command_processor)); - } - return OkStatus(); - } -}; - -} // namespace - -SerialSchedulingModel::SerialSchedulingModel() { - // We currently only expose a single command queue. - auto command_queue = absl::make_unique( - "cpu0", CommandCategory::kTransfer | CommandCategory::kDispatch); - - // Wrap in the simple async command queue. 
- auto async_command_queue = - absl::make_unique(std::move(command_queue)); - command_queues_.push_back(std::move(async_command_queue)); -} - -SerialSchedulingModel::~SerialSchedulingModel() = default; - -StatusOr> SerialSchedulingModel::CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) { - return make_ref(mode, command_categories); -} - -StatusOr> SerialSchedulingModel::CreateEvent() { - return make_ref(); -} - -StatusOr> SerialSchedulingModel::CreateSemaphore( - uint64_t initial_value) { - return make_ref(initial_value); -} - -Status SerialSchedulingModel::WaitAllSemaphores( - absl::Span semaphores, Time deadline_ns) { - return CondVarSemaphore::WaitForSemaphores(semaphores, /*wait_all=*/true, - deadline_ns); -} - -StatusOr SerialSchedulingModel::WaitAnySemaphore( - absl::Span semaphores, Time deadline_ns) { - return CondVarSemaphore::WaitForSemaphores(semaphores, /*wait_all=*/false, - deadline_ns); -} - -Status SerialSchedulingModel::WaitIdle(Time deadline_ns) { - for (auto& command_queue : command_queues_) { - IREE_RETURN_IF_ERROR(command_queue->WaitIdle(deadline_ns)); - } - return OkStatus(); -} - -} // namespace host -} // namespace hal -} // namespace iree diff --git a/iree/hal/host/serial/serial_scheduling_model.h b/iree/hal/host/serial/serial_scheduling_model.h deleted file mode 100644 index b065025f50ed2..0000000000000 --- a/iree/hal/host/serial/serial_scheduling_model.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_SERIAL_SERIAL_SCHEDULING_MODEL_H_ -#define IREE_HAL_HOST_SERIAL_SERIAL_SCHEDULING_MODEL_H_ - -#include "absl/container/inlined_vector.h" -#include "iree/base/memory.h" -#include "iree/hal/host/scheduling_model.h" - -namespace iree { -namespace hal { -namespace host { - -// Performs host-local scheduling by way of a simple serial queue. -// Submissions and commands are processed in-order one at a time on a single -// core. This is a reference implementation that has no dependencies beyond -// std::thread and allows us to quickly bring up new platforms and more easily -// debug/profile as we won't have OS fibers/other weird constructs involved. 
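-//
-// Usage sketch (illustrative; a real device wires these into its Device
-// interface):
-//
-//   SerialSchedulingModel model;
-//   IREE_ASSIGN_OR_RETURN(auto semaphore, model.CreateSemaphore(0ull));
-//   IREE_ASSIGN_OR_RETURN(auto command_buffer,
-//                         model.CreateCommandBuffer(
-//                             CommandBufferMode::kOneShot,
-//                             CommandCategory::kDispatch));
-//   // ... record into |command_buffer|, submit on
-//   // model.dispatch_queues()[0], then model.WaitIdle(InfiniteFuture()).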
-class SerialSchedulingModel final : public SchedulingModel { - public: - SerialSchedulingModel(); - ~SerialSchedulingModel() override; - - absl::Span dispatch_queues() const override { - return RawPtrSpan(absl::MakeSpan(command_queues_)); - } - - absl::Span transfer_queues() const override { - return RawPtrSpan(absl::MakeSpan(command_queues_)); - } - - StatusOr> CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) override; - - StatusOr> CreateEvent() override; - - StatusOr> CreateSemaphore(uint64_t initial_value) override; - - Status WaitAllSemaphores(absl::Span semaphores, - Time deadline_ns) override; - StatusOr WaitAnySemaphore(absl::Span semaphores, - Time deadline_ns) override; - Status WaitIdle(Time deadline_ns) override; - - private: - mutable absl::InlinedVector, 4> command_queues_; -}; - -} // namespace host -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_SERIAL_SERIAL_SCHEDULING_MODEL_H_ diff --git a/iree/hal/host/serial/serial_submission_queue.cc b/iree/hal/host/serial/serial_submission_queue.cc deleted file mode 100644 index 5780e78ceffa1..0000000000000 --- a/iree/hal/host/serial/serial_submission_queue.cc +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/serial/serial_submission_queue.h" - -#include -#include - -#include "absl/synchronization/mutex.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { -namespace host { - -SerialSubmissionQueue::SerialSubmissionQueue() = default; - -SerialSubmissionQueue::~SerialSubmissionQueue() = default; - -StatusOr SerialSubmissionQueue::CheckBatchReady( - const PendingBatch& batch) const { - for (auto& wait_point : batch.wait_semaphores) { - auto* semaphore = reinterpret_cast(wait_point.semaphore); - IREE_ASSIGN_OR_RETURN(uint64_t value, semaphore->Query()); - if (value < wait_point.value) { - return false; - } - } - return true; -} - -Status SerialSubmissionQueue::Enqueue( - absl::Span batches) { - IREE_TRACE_SCOPE0("SerialSubmissionQueue::Enqueue"); - - if (has_shutdown_) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Cannot enqueue new submissions; queue is exiting"; - } else if (!permanent_error_.ok()) { - return permanent_error_; - } - - // Add to list in submission order. 
-  auto submission = absl::make_unique<Submission>();
-  submission->pending_batches.resize(batches.size());
-  for (int i = 0; i < batches.size(); ++i) {
-    submission->pending_batches[i] = PendingBatch{
-        {batches[i].wait_semaphores.begin(), batches[i].wait_semaphores.end()},
-        {batches[i].command_buffers.begin(), batches[i].command_buffers.end()},
-        {batches[i].signal_semaphores.begin(),
-         batches[i].signal_semaphores.end()},
-    };
-  }
-  list_.push_back(std::move(submission));
-
-  return OkStatus();
-}
-
-Status SerialSubmissionQueue::ProcessBatches(ExecuteFn execute_fn) {
-  IREE_TRACE_SCOPE0("SerialSubmissionQueue::ProcessBatches");
-
-  if (!permanent_error_.ok()) {
-    // Sticky failure state.
-    return permanent_error_;
-  }
-
-  // Repeatedly try to run things until we quiesce or are blocked.
-  while (permanent_error_.ok() && !list_.empty()) {
-    // NOTE: to support re-entrancy where |execute_fn| may modify the
-    // submission list we need to always start from the beginning. If we
-    // wanted we could track a list of ready submissions however that's a lot
-    // of bookkeeping and the list is usually short.
-    auto* submission = list_.front();
-    for (int i = 0; i < submission->pending_batches.size(); ++i) {
-      auto& batch = submission->pending_batches[i];
-      auto wait_status_or = CheckBatchReady(batch);
-      if (!wait_status_or.ok()) {
-        // Batch dependencies failed; set the permanent error flag and abort
-        // so we don't try to process anything else.
-        permanent_error_ = std::move(wait_status_or).status();
-        CompleteSubmission(submission, permanent_error_);
-        FailAllPending(permanent_error_);
-        return permanent_error_;
-      } else if (wait_status_or.ok() && !wait_status_or.value()) {
-        // To preserve submission order we bail if we encounter a batch that
-        // is not ready and wait for something to become ready before pumping
-        // again.
-        // Note that if we were properly threading here we would potentially
-        // be evaluating this while previous batches were still processing
-        // but for now we do everything serially.
        return OkStatus();
-      }
-
-      // Batch can run! Process now and remove it from the list so we don't
-      // try to run it again.
-      auto batch_status = ProcessBatch(batch, execute_fn);
-      if (!batch_status.ok()) {
-        // Batch failed; set the permanent error flag and abort so we don't
-        // try to process anything else.
-        permanent_error_ = Status(batch_status);
-        CompleteSubmission(submission, batch_status);
-        FailAllPending(permanent_error_);
-        return permanent_error_;
-      }
-      submission->pending_batches.erase(submission->pending_batches.begin() +
-                                        i);
-
-      // Batch succeeded. Since we want to preserve submission order we'll
-      // break out of the loop and try from the first submission again.
-      if (submission->pending_batches.empty()) {
-        // All work for this submission completed successfully. Signal the
-        // semaphore and remove the submission from the list.
-        CompleteSubmission(submission, OkStatus());
-        break;
-      }
-    }
-  }
-
-  return OkStatus();
-}
-
-Status SerialSubmissionQueue::ProcessBatch(const PendingBatch& batch,
-                                           const ExecuteFn& execute_fn) {
-  IREE_TRACE_SCOPE0("SerialSubmissionQueue::ProcessBatch");
-
-  // NOTE: the precondition is that the batch is ready to execute so we don't
-  // need to check the wait semaphores here.
-
-  // Let the caller handle execution of the command buffers.
-  IREE_RETURN_IF_ERROR(execute_fn(batch.command_buffers));
-
-  // Signal all semaphores to allow them to unblock waiters.
- for (auto& signal_point : batch.signal_semaphores) { - auto* semaphore = - reinterpret_cast(signal_point.semaphore); - IREE_RETURN_IF_ERROR(semaphore->Signal(signal_point.value)); - } - - return OkStatus(); -} - -void SerialSubmissionQueue::CompleteSubmission(Submission* submission, - Status status) { - IREE_TRACE_SCOPE0("SerialSubmissionQueue::CompleteSubmission"); - - if (status.ok() && !submission->pending_batches.empty()) { - // Completed with work remaining? Cause a failure. - status = FailedPreconditionErrorBuilder(IREE_LOC) - << "Submission ended prior to completion of all batches"; - } - if (!status.ok()) { - // Fail all pending batch semaphores that we would have signaled. - for (auto& batch : submission->pending_batches) { - for (auto& signal_point : batch.signal_semaphores) { - auto* semaphore = - reinterpret_cast(signal_point.semaphore); - semaphore->Fail(status); - } - } - submission->pending_batches.clear(); - } - - list_.take(submission).reset(); -} - -void SerialSubmissionQueue::FailAllPending(Status status) { - IREE_TRACE_SCOPE0("SerialSubmissionQueue::FailAllPending"); - while (!list_.empty()) { - CompleteSubmission(list_.front(), status); - } -} - -void SerialSubmissionQueue::SignalShutdown() { - IREE_TRACE_SCOPE0("SerialSubmissionQueue::SignalShutdown"); - has_shutdown_ = true; -} - -} // namespace host -} // namespace hal -} // namespace iree diff --git a/iree/hal/host/serial/serial_submission_queue.h b/iree/hal/host/serial/serial_submission_queue.h deleted file mode 100644 index f7d262f4c7d98..0000000000000 --- a/iree/hal/host/serial/serial_submission_queue.h +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_SERIAL_SERIAL_SUBMISSION_QUEUE_H_ -#define IREE_HAL_HOST_SERIAL_SERIAL_SUBMISSION_QUEUE_H_ - -#include "absl/base/thread_annotations.h" -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/status.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/host/condvar_semaphore.h" - -namespace iree { -namespace hal { -namespace host { - -// A queue managing CommandQueue submissions that uses host-local -// synchronization primitives. Evaluates submission order by respecting the -// wait and signal semaphores defined per batch and notifies semaphores upon -// submission completion. -// -// Note that it's possible for HAL users to deadlock themselves; we don't try to -// avoid that as in device backends it may not be possible and we want to have -// some kind of warning in the host implementation that TSAN can catch. -// -// Thread-compatible. Const methods may be called from any thread. -class SerialSubmissionQueue final { - public: - using ExecuteFn = - std::function command_buffers)>; - - SerialSubmissionQueue(); - ~SerialSubmissionQueue(); - - // Returns true if the queue is currently empty. 
- bool empty() const { return list_.empty(); } - // Returns true if SignalShutdown has been called. - bool has_shutdown() const { return has_shutdown_; } - // The sticky error status, if an error has occurred. - Status permanent_error() const { return permanent_error_; } - - // Enqueues a new submission. - // No work will be performed until Process is called. - Status Enqueue(absl::Span batches); - - // Processes all ready batches using the provided |execute_fn|. - // The function may be called several times if new batches become ready due to - // prior batches in the sequence completing during processing. - // - // Returns any errors returned by |execute_fn| (which will be the same as - // permanent_error()). When an error occurs all in-flight submissions are - // aborted, the permanent_error() is set, and the queue is shutdown. - Status ProcessBatches(ExecuteFn execute_fn); - - // Marks the queue as having shutdown. All pending submissions will be allowed - // to complete but future enqueues will fail. - void SignalShutdown(); - - private: - // A submitted command buffer batch and its synchronization information. - struct PendingBatch { - absl::InlinedVector wait_semaphores; - absl::InlinedVector command_buffers; - absl::InlinedVector signal_semaphores; - }; - struct Submission : public IntrusiveLinkBase { - absl::InlinedVector pending_batches; - }; - - // Returns true if all wait semaphores in the |batch| are signaled. - // If one or more of the wait semaphores have failed then returns a status - // from one of them arbitrarily. - StatusOr CheckBatchReady(const PendingBatch& batch) const; - - // Processes a batch by resetting semaphores, dispatching the command buffers - // to the specified |execute_fn|, and signaling semaphores. - // - // Preconditions: CheckBatchReady(batch) == true - Status ProcessBatch(const PendingBatch& batch, const ExecuteFn& execute_fn); - - // Completes a submission. Assumes that all batches have had their semaphores - // signaled and that any remaining here will need to be signaled for failure. - void CompleteSubmission(Submission* submission, Status status); - - // Fails all pending submissions with the given status. - // Errors that occur during this process are silently ignored. - void FailAllPending(Status status); - - // True to exit the thread after all submissions complete. - bool has_shutdown_ = false; - - // A sticky error that is set on the first failed submit. All future - // submissions will be skipped except for semaphores, which will receive this - // error. - Status permanent_error_; - - // Pending submissions in submission order. - // Note that we may evaluate batches within the list out of order. - IntrusiveList> list_; -}; - -} // namespace host -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_SERIAL_SERIAL_SUBMISSION_QUEUE_H_ diff --git a/iree/hal/local/BUILD b/iree/hal/local/BUILD new file mode 100644 index 0000000000000..82c641671bdba --- /dev/null +++ b/iree/hal/local/BUILD @@ -0,0 +1,121 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Default implementations for HAL types that use the host resources. +# These are generally just wrappers around host heap memory and host threads. + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +# TODO(benvanik): replace iree/base/arena.h with this one. We still want the +# old-style arena for pure stack use; we may be able to do that with a change +# to block pool that allows for on-stack initialization (iree_stack_arena_t +# that has storage for one block inside itself and then dynamically allocates +# new ones if needed). That way we have only one arena implementation and can +# easily use the iree_allocator_t interface without worry. +cc_library( + name = "arena", + srcs = ["arena.c"], + hdrs = ["arena.h"], + deps = [ + "//iree/base:api", + "//iree/base:atomic_slist", + "//iree/base:core_headers", + "//iree/base:synchronization", + ], +) + +# TODO(benvanik): move into base/? may be useful for other backends or for other +# parts of the system (like modules handling IO/RPC). +cc_library( + name = "event_pool", + srcs = ["event_pool.c"], + hdrs = ["event_pool.h"], + deps = [ + "//iree/base:api", + "//iree/base:core_headers", + "//iree/base:synchronization", + "//iree/base:tracing", + "//iree/base:wait_handle", + ], +) + +cc_library( + name = "executable_library", + hdrs = ["executable_library.h"], +) + +cc_library( + name = "local", + srcs = [ + "executable_loader.c", + "local_descriptor_set.c", + "local_descriptor_set_layout.c", + "local_executable.c", + "local_executable_cache.c", + "local_executable_layout.c", + ], + hdrs = [ + "executable_loader.h", + "local_descriptor_set.h", + "local_descriptor_set_layout.h", + "local_executable.h", + "local_executable_cache.h", + "local_executable_layout.h", + ], + deps = [ + ":executable_library", + "//iree/base:api", + "//iree/base:core_headers", + "//iree/base:tracing", + "//iree/hal:api", + ], +) + +cc_library( + name = "task_driver", + srcs = [ + "task_command_buffer.c", + "task_device.c", + "task_driver.c", + "task_event.c", + "task_queue.c", + "task_queue_state.c", + "task_semaphore.c", + ], + hdrs = [ + "task_command_buffer.h", + "task_device.h", + "task_driver.h", + "task_event.h", + "task_queue.h", + "task_queue_state.h", + "task_semaphore.h", + ], + deps = [ + ":arena", + ":event_pool", + ":local", + "//iree/base:api", + "//iree/base:core_headers", + "//iree/base:synchronization", + "//iree/base:tracing", + "//iree/base:wait_handle", + "//iree/hal:api", + "//iree/task", + ], +) diff --git a/iree/hal/local/CMakeLists.txt b/iree/hal/local/CMakeLists.txt new file mode 100644 index 0000000000000..3eb8267912a52 --- /dev/null +++ b/iree/hal/local/CMakeLists.txt @@ -0,0 +1,113 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
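+
+# For reference, downstream CMake targets depend on the libraries below using
+# their qualified names, e.g. iree::hal::local::task_driver (matching the
+# Bazel targets in the BUILD file above).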
+ +iree_add_all_subdirs() + +iree_cc_library( + NAME + arena + HDRS + "arena.h" + SRCS + "arena.c" + DEPS + iree::base::api + iree::base::atomic_slist + iree::base::core_headers + iree::base::synchronization + PUBLIC +) + +iree_cc_library( + NAME + event_pool + HDRS + "event_pool.h" + SRCS + "event_pool.c" + DEPS + iree::base::api + iree::base::core_headers + iree::base::synchronization + iree::base::tracing + iree::base::wait_handle + PUBLIC +) + +iree_cc_library( + NAME + executable_library + HDRS + "executable_library.h" + PUBLIC +) + +iree_cc_library( + NAME + local + HDRS + "executable_loader.h" + "local_descriptor_set.h" + "local_descriptor_set_layout.h" + "local_executable.h" + "local_executable_cache.h" + "local_executable_layout.h" + SRCS + "executable_loader.c" + "local_descriptor_set.c" + "local_descriptor_set_layout.c" + "local_executable.c" + "local_executable_cache.c" + "local_executable_layout.c" + DEPS + ::executable_library + iree::base::api + iree::base::core_headers + iree::base::tracing + iree::hal::api + PUBLIC +) + +iree_cc_library( + NAME + task_driver + HDRS + "task_command_buffer.h" + "task_device.h" + "task_driver.h" + "task_event.h" + "task_queue.h" + "task_queue_state.h" + "task_semaphore.h" + SRCS + "task_command_buffer.c" + "task_device.c" + "task_driver.c" + "task_event.c" + "task_queue.c" + "task_queue_state.c" + "task_semaphore.c" + DEPS + ::arena + ::event_pool + ::local + iree::base::api + iree::base::core_headers + iree::base::synchronization + iree::base::tracing + iree::base::wait_handle + iree::hal::api + iree::task + PUBLIC +) diff --git a/iree/hal/local/arena.c b/iree/hal/local/arena.c new file mode 100644 index 0000000000000..7b9529a1bb8de --- /dev/null +++ b/iree/hal/local/arena.c @@ -0,0 +1,185 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/arena.h" + +#include "iree/base/alignment.h" + +//===----------------------------------------------------------------------===// +// iree_arena_block_pool_t +//===----------------------------------------------------------------------===// + +void iree_arena_block_pool_initialize(iree_host_size_t total_block_size, + iree_allocator_t block_allocator, + iree_arena_block_pool_t* out_block_pool) { + memset(out_block_pool, 0, sizeof(*out_block_pool)); + out_block_pool->total_block_size = total_block_size; + out_block_pool->usable_block_size = + total_block_size - sizeof(iree_arena_block_t); + out_block_pool->block_allocator = block_allocator; + iree_atomic_arena_block_slist_initialize(&out_block_pool->available_slist); +} + +void iree_arena_block_pool_deinitialize(iree_arena_block_pool_t* block_pool) { + // Since all blocks must have been released we can just reuse trim (today) as + // it doesn't retain any blocks. 
+ iree_arena_block_pool_trim(block_pool); + iree_atomic_arena_block_slist_deinitialize(&block_pool->available_slist); +} + +void iree_arena_block_pool_trim(iree_arena_block_pool_t* block_pool) { + iree_arena_block_t* head = NULL; + iree_atomic_arena_block_slist_flush( + &block_pool->available_slist, + IREE_ATOMIC_SLIST_FLUSH_ORDER_APPROXIMATE_LIFO, &head, NULL); + while (head) { + void* ptr = (uint8_t*)head - block_pool->usable_block_size; + head = head->next; + iree_allocator_free(block_pool->block_allocator, ptr); + } +} + +iree_status_t iree_arena_block_pool_acquire(iree_arena_block_pool_t* block_pool, + iree_arena_block_t** out_block) { + iree_arena_block_t* block = + iree_atomic_arena_block_slist_pop(&block_pool->available_slist); + + if (!block) { + // No blocks available; allocate one now. + // Note that it's possible for there to be a race here where one thread + // releases a block to the pool while we are trying to acquire one - in that + // case we may end up allocating a block when perhaps we didn't need to but + // that's fine - it's just one block and the contention means there's likely + // to be a need for more anyway. + uint8_t* block_base = NULL; + IREE_RETURN_IF_ERROR(iree_allocator_malloc(block_pool->block_allocator, + block_pool->total_block_size, + (void**)&block_base)); + block = (iree_arena_block_t*)(block_base + (block_pool->total_block_size - + sizeof(iree_arena_block_t))); + } + + block->next = NULL; + *out_block = block; + return iree_ok_status(); +} + +void iree_arena_block_pool_release(iree_arena_block_pool_t* block_pool, + iree_arena_block_t* block_head, + iree_arena_block_t* block_tail) { + iree_atomic_arena_block_slist_concat(&block_pool->available_slist, block_head, + block_tail); +} + +//===----------------------------------------------------------------------===// +// iree_arena_allocator_t +//===----------------------------------------------------------------------===// + +void iree_arena_initialize(iree_arena_block_pool_t* block_pool, + iree_arena_allocator_t* out_arena) { + memset(out_arena, 0, sizeof(*out_arena)); + out_arena->block_pool = block_pool; +} + +void iree_arena_deinitialize(iree_arena_allocator_t* arena) { + iree_arena_reset(arena); +} + +void iree_arena_reset(iree_arena_allocator_t* arena) { + if (arena->allocation_head != NULL) { + iree_arena_oversized_allocation_t* head = arena->allocation_head; + do { + void* ptr = (void*)head; + head = head->next; + iree_allocator_free(arena->block_pool->block_allocator, ptr); + } while (head); + arena->allocation_head = NULL; + } + if (arena->block_head != NULL) { + iree_arena_block_pool_release(arena->block_pool, arena->block_head, + arena->block_tail); + arena->block_head = NULL; + arena->block_tail = NULL; + } +} + +iree_status_t iree_arena_allocate(iree_arena_allocator_t* arena, + iree_host_size_t byte_length, + void** out_ptr) { + *out_ptr = NULL; + + iree_arena_block_pool_t* block_pool = arena->block_pool; + + if (byte_length > block_pool->usable_block_size) { + // Oversized allocation that can't be handled by the block pool. We'll + // allocate directly from the system allocator and track it ourselves for + // freeing during reset. 
+ iree_host_size_t allocation_size = + sizeof(iree_arena_oversized_allocation_t) + byte_length; + iree_arena_oversized_allocation_t* allocation = NULL; + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + block_pool->block_allocator, allocation_size, (void**)&allocation)); + allocation->next = arena->allocation_head; + arena->allocation_head = allocation; + arena->total_allocation_size += allocation_size; + arena->used_allocation_size += byte_length; + *out_ptr = (uint8_t*)allocation + sizeof(iree_arena_oversized_allocation_t); + return iree_ok_status(); + } + + // Pad length allocated so that each pointer bump is always ending at an + // aligned address and the next allocation will start aligned. + iree_host_size_t aligned_length = iree_align(byte_length, iree_max_align_t); + + // Check to see if the current block (if any) has space - if not, get another. + if (arena->block_head == NULL || + arena->block_bytes_remaining < aligned_length) { + iree_arena_block_t* block = NULL; + IREE_RETURN_IF_ERROR( + iree_arena_block_pool_acquire(arena->block_pool, &block)); + block->next = arena->block_head; + arena->block_head = block; + if (!arena->block_tail) arena->block_tail = block; + arena->total_allocation_size += block_pool->total_block_size; + arena->block_bytes_remaining = block_pool->usable_block_size; + } + + // Slice out the allocation from the current block. + void* ptr = (uint8_t*)arena->block_head - arena->block_bytes_remaining; + arena->block_bytes_remaining -= aligned_length; + arena->used_allocation_size += aligned_length; + *out_ptr = ptr; + return iree_ok_status(); +} + +static iree_status_t iree_arena_allocate_thunk(void* self, + iree_allocation_mode_t mode, + iree_host_size_t byte_length, + void** out_ptr) { + iree_arena_allocator_t* arena = (iree_arena_allocator_t*)self; + IREE_RETURN_IF_ERROR(iree_arena_allocate(arena, byte_length, out_ptr)); + if (mode & IREE_ALLOCATION_MODE_ZERO_CONTENTS) { + memset(*out_ptr, 0, byte_length); + } + return iree_ok_status(); +} + +iree_allocator_t iree_arena_allocator(iree_arena_allocator_t* arena) { + iree_allocator_t v = { + .self = arena, + .alloc = (iree_allocator_alloc_fn_t)iree_arena_allocate_thunk, + .free = NULL, + }; + return v; +} diff --git a/iree/hal/local/arena.h b/iree/hal/local/arena.h new file mode 100644 index 0000000000000..3498a699d99fd --- /dev/null +++ b/iree/hal/local/arena.h @@ -0,0 +1,155 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
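
The block pool and the arena above are designed to be used as a pair. Here is a minimal caller-side sketch; the function name, the 4KB block size, and the allocation sizes are illustrative assumptions rather than anything prescribed by this change:

// Hypothetical caller of the iree_arena_* API added in this diff.
#include "iree/hal/local/arena.h"

iree_status_t arena_usage_example(void) {
  // Shared pool of fixed-size blocks; 4KB is an arbitrary power-of-two choice.
  iree_arena_block_pool_t block_pool;
  iree_arena_block_pool_initialize(/*total_block_size=*/4096,
                                   iree_allocator_system(), &block_pool);

  // Arena that bump-allocates out of pooled blocks; one arena per thread.
  iree_arena_allocator_t arena;
  iree_arena_initialize(&block_pool, &arena);

  // Small requests are sliced from the current block; requests larger than
  // the usable block size fall through to the pool's underlying allocator.
  void* small_ptr = NULL;
  iree_status_t status = iree_arena_allocate(&arena, 64, &small_ptr);
  void* large_ptr = NULL;
  if (iree_status_is_ok(status)) {
    status = iree_arena_allocate(&arena, 64 * 1024, &large_ptr);
  }

  // No per-allocation frees: a reset returns all blocks to the pool at once.
  iree_arena_reset(&arena);
  iree_arena_deinitialize(&arena);
  iree_arena_block_pool_deinitialize(&block_pool);
  return status;
}

Allocation is O(1) pointer bumping; the tradeoff is that memory is only reclaimed at reset granularity, which suits transient per-submission data.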
+ +#ifndef IREE_HAL_LOCAL_ARENA_H_ +#define IREE_HAL_LOCAL_ARENA_H_ + +#include "iree/base/api.h" +#include "iree/base/atomic_slist.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// iree_arena_block_pool_t +//===----------------------------------------------------------------------===// + +// NOTE: this struct is at the *end* of allocated blocks such that we don't mess +// with alignment - byte 0 of a block is always byte 0 of the allocation from +// the system. We can do this as all blocks have the same size so computing the +// footer offset from a pointer is easy. +typedef struct iree_arena_block_s { + struct iree_arena_block_s* next; +} iree_arena_block_t; + +// An atomic approximately LIFO singly-linked list. +IREE_TYPED_ATOMIC_SLIST_WRAPPER(iree_atomic_arena_block, iree_arena_block_t, + offsetof(iree_arena_block_t, next)); + +// A simple atomic fixed-size block pool. +// Blocks are allocated from the system as required and kept in the pool to +// satisfy future requests. Blocks are all of a uniform size specified when the +// pool is created. It's recommended that power-of-two sizes are used for the +// blocks so that the underlying allocator is more likely to bucket them +// appropriately. +// +// Thread-safe; multiple threads may acquire and release blocks from the pool. +// The underlying allocator must also be thread-safe. +typedef struct { + // Block size, in bytes. All blocks in the available_slist will have this + // byte size which includes the iree_arena_block_t footer. + iree_host_size_t total_block_size; + // Block size, in bytes, of the usable bytes within a block. + iree_host_size_t usable_block_size; + // Allocator used for allocating/freeing each allocation block. + iree_allocator_t block_allocator; + // Linked list of free blocks (LIFO). + iree_atomic_arena_block_slist_t available_slist; +} iree_arena_block_pool_t; + +// Initializes a new block pool in |out_block_pool|. +// |block_allocator| will be used to allocate and free blocks for the pool. +// Each block allocated will be |total_block_size| but have a slightly smaller +// usable size due to the tracking overhead. Prefer powers of two. +void iree_arena_block_pool_initialize(iree_host_size_t total_block_size, + iree_allocator_t block_allocator, + iree_arena_block_pool_t* out_block_pool); + +// Deinitializes a block pool and frees all allocations. +// All blocks that were acquired from the pool must have already been released +// back to it. +void iree_arena_block_pool_deinitialize(iree_arena_block_pool_t* block_pool); + +// Trims the pool by freeing unused blocks back to the allocator. +// Acquired blocks are not freed and remain valid. +void iree_arena_block_pool_trim(iree_arena_block_pool_t* block_pool); + +// Acquires a single block from the pool and returns it in |out_block|. +// The block may be either a new allocation with undefined contents or a reused +// prior allocation with undefined contents. +iree_status_t iree_arena_block_pool_acquire(iree_arena_block_pool_t* block_pool, + iree_arena_block_t** out_block); + +// Releases one or more blocks back to the block pool. +// Any blocks chained in |block_head| will also be released allowing for +// low-overhead resets when the blocks are already tracked in linked lists. 
+void iree_arena_block_pool_release(iree_arena_block_pool_t* block_pool, + iree_arena_block_t* block_head, + iree_arena_block_t* block_tail); + +//===----------------------------------------------------------------------===// +// iree_arena_allocator_t +//===----------------------------------------------------------------------===// + +typedef struct iree_arena_oversized_allocation_s { + struct iree_arena_oversized_allocation_s* next; +} iree_arena_oversized_allocation_t; + +// A lightweight bump-pointer arena allocator using a shared block pool. +// As allocations are made from the arena and block capacity is exhausted new +// blocks will be acquired from the pool. Upon being reset all blocks will be +// released back to the pool for reuse by either the same arena in the future or +// other arenas sharing the same pool. +// +// The size of each allocated block used by the arena is inherited from the +// block pool. Allocations from the arena may exceed the block size but will +// incur additional allocation overhead as the block pool is bypassed and the +// system allocator is directly used to service the request. +// +// Thread-compatible; the shared block pool is thread-safe and may be used by +// arenas on multiple threads but each arena must only be used by a single +// thread. +typedef struct { + // Fixed-size block pool used to acquire new blocks for the arena. + iree_arena_block_pool_t* block_pool; + // Total bytes allocated to the arena from the block pool or system allocator. + iree_host_size_t total_allocation_size; + // Total bytes allocated from the arena; the utilization of the arena can be + // checked with `used_allocation_size / total_allocation_size`. + iree_host_size_t used_allocation_size; + // Linked list of oversized allocations made directly from the system + // allocator used by the block pool. + iree_arena_oversized_allocation_t* allocation_head; + // Linked list of allocated blocks maintained so that reset can release them. + iree_arena_block_t* block_head; + iree_arena_block_t* block_tail; + // The number of bytes remaining in the block pointed to by block_head. + iree_host_size_t block_bytes_remaining; +} iree_arena_allocator_t; + +// Initializes an arena that will use |block_pool| for allocating blocks as +// needed. +void iree_arena_initialize(iree_arena_block_pool_t* block_pool, + iree_arena_allocator_t* out_arena); + +// Deinitializes the arena and returns allocated blocks to the parent pool. +void iree_arena_deinitialize(iree_arena_allocator_t* arena); + +// Resets the entire arena and returns allocated blocks to the parent pool. +void iree_arena_reset(iree_arena_allocator_t* arena); + +// Allocates |byte_length| contiguous bytes from the arena. +iree_status_t iree_arena_allocate(iree_arena_allocator_t* arena, + iree_host_size_t byte_length, void** out_ptr); + +// Returns an iree_allocator_t that allocates from the given |arena|. +// Frees are ignored as arenas can only be reset as a whole. +iree_allocator_t iree_arena_allocator(iree_arena_allocator_t* arena); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_ARENA_H_ diff --git a/iree/hal/local/event_pool.c b/iree/hal/local/event_pool.c new file mode 100644 index 0000000000000..0a6c1dda9d81e --- /dev/null +++ b/iree/hal/local/event_pool.c @@ -0,0 +1,172 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/local/event_pool.h"
+
+#include "iree/base/debugging.h"
+#include "iree/base/synchronization.h"
+#include "iree/base/tracing.h"
+
+struct iree_hal_local_event_pool_s {
+  // Allocator used to create the event pool.
+  iree_allocator_t host_allocator;
+  // Guards the pool. Since this pool is used to get operating system-level
+  // event objects that will be signaled and waited on using syscalls it's got
+  // relatively low contention: callers are rate limited by how fast they can
+  // signal and wait on the events they get.
+  iree_slim_mutex_t mutex;
+  // Maximum number of events that will be maintained in the pool. More events
+  // may be allocated at any time but when they are no longer needed they will
+  // be disposed directly.
+  iree_host_size_t available_capacity;
+  // Total number of events currently available in the pool.
+  iree_host_size_t available_count;
+  // Dense left-aligned list of available_count events.
+  iree_event_t available_list[];
+};
+
+iree_status_t iree_hal_local_event_pool_allocate(
+    iree_host_size_t available_capacity, iree_allocator_t host_allocator,
+    iree_hal_local_event_pool_t** out_event_pool) {
+  IREE_ASSERT_ARGUMENT(out_event_pool);
+  *out_event_pool = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_local_event_pool_t* event_pool = NULL;
+  iree_host_size_t total_size =
+      sizeof(*event_pool) +
+      available_capacity * sizeof(event_pool->available_list[0]);
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0,
+      iree_allocator_malloc(host_allocator, total_size, (void**)&event_pool));
+  event_pool->host_allocator = host_allocator;
+  iree_slim_mutex_initialize(&event_pool->mutex);
+  event_pool->available_capacity = available_capacity;
+  event_pool->available_count = 0;
+
+  iree_status_t status = iree_ok_status();
+  for (iree_host_size_t i = 0; i < available_capacity; ++i) {
+    status = iree_event_initialize(
+        /*initial_state=*/false,
+        &event_pool->available_list[event_pool->available_count++]);
+    if (!iree_status_is_ok(status)) break;
+  }
+
+  if (iree_status_is_ok(status)) {
+    *out_event_pool = event_pool;
+  } else {
+    iree_hal_local_event_pool_free(event_pool);
+  }
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_local_event_pool_free(iree_hal_local_event_pool_t* event_pool) {
+  iree_allocator_t host_allocator = event_pool->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  for (iree_host_size_t i = 0; i < event_pool->available_count; ++i) {
+    iree_event_deinitialize(&event_pool->available_list[i]);
+  }
+  iree_slim_mutex_deinitialize(&event_pool->mutex);
+  iree_allocator_free(host_allocator, event_pool);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+iree_status_t iree_hal_local_event_pool_acquire(
+    iree_hal_local_event_pool_t* event_pool, iree_host_size_t event_count,
+    iree_event_t* out_events) {
+  IREE_ASSERT_ARGUMENT(event_pool);
+  if (!event_count) return iree_ok_status();
+  IREE_ASSERT_ARGUMENT(out_events);
+
+  // We'll try to get what we can from the pool and fall back to initializing
+  // new events.
+  iree_host_size_t remaining_count = event_count;
+
+  // Try first to grab from the pool.
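+  // NOTE: iree_event_t is a small value type so copying events in and out of
+  // the dense available_list while holding the lock keeps the critical
+  // section short even for larger event_count requests.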
+ iree_slim_mutex_lock(&event_pool->mutex); + iree_host_size_t from_pool_count = + iree_min(event_pool->available_count, event_count); + if (from_pool_count > 0) { + iree_host_size_t pool_base_index = + event_pool->available_count - from_pool_count; + memcpy(out_events, &event_pool->available_list[pool_base_index], + from_pool_count * sizeof(iree_event_t)); + event_pool->available_count -= from_pool_count; + remaining_count -= from_pool_count; + } + iree_slim_mutex_unlock(&event_pool->mutex); + + // Allocate the rest of the events. + if (remaining_count > 0) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < remaining_count; ++i) { + status = iree_event_initialize(/*initial_state=*/false, + &out_events[from_pool_count + i]); + if (!iree_status_is_ok(status)) { + // Must release all events we've acquired so far. + iree_hal_local_event_pool_release(event_pool, from_pool_count + i, + out_events); + IREE_TRACE_ZONE_END(z0); + return status; + } + } + IREE_TRACE_ZONE_END(z0); + } + + return iree_ok_status(); +} + +void iree_hal_local_event_pool_release(iree_hal_local_event_pool_t* event_pool, + iree_host_size_t event_count, + iree_event_t* events) { + IREE_ASSERT_ARGUMENT(event_pool); + if (!event_count) return; + IREE_ASSERT_ARGUMENT(events); + + // We'll try to release all we can back to the pool and then deinitialize + // the ones that won't fit. + iree_host_size_t remaining_count = event_count; + + // Try first to release to the pool. + // Note that we reset the events we add back to the pool so that they are + // ready to be acquired again. + iree_slim_mutex_lock(&event_pool->mutex); + iree_host_size_t to_pool_count = + iree_min(event_pool->available_capacity - event_pool->available_count, + event_count); + if (to_pool_count > 0) { + iree_host_size_t pool_base_index = event_pool->available_count; + for (iree_host_size_t i = 0; i < to_pool_count; ++i) { + iree_event_reset(&events[i]); + } + memcpy(&event_pool->available_list[pool_base_index], events, + to_pool_count * sizeof(iree_event_t)); + event_pool->available_count += to_pool_count; + remaining_count -= to_pool_count; + } + iree_slim_mutex_unlock(&event_pool->mutex); + + // Deallocate the rest of the events. We don't bother resetting them as we are + // getting rid of them. + if (remaining_count > 0) { + IREE_TRACE_ZONE_BEGIN(z0); + for (iree_host_size_t i = 0; i < remaining_count; ++i) { + iree_event_deinitialize(&events[to_pool_count + i]); + } + IREE_TRACE_ZONE_END(z0); + } +} diff --git a/iree/hal/local/event_pool.h b/iree/hal/local/event_pool.h new file mode 100644 index 0000000000000..6a4dd9c26071e --- /dev/null +++ b/iree/hal/local/event_pool.h @@ -0,0 +1,57 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
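
Before the header itself, a small caller-side sketch of the pool may help; the capacity, event count, and function name are arbitrary assumptions:

// Hypothetical caller of the event pool API declared below.
#include "iree/hal/local/event_pool.h"

iree_status_t event_pool_usage_example(void) {
  iree_hal_local_event_pool_t* event_pool = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_local_event_pool_allocate(
      /*available_capacity=*/8, iree_allocator_system(), &event_pool));

  // Acquired events arrive unsignaled; up to the pool capacity they are
  // recycled rather than re-created on each acquire.
  iree_event_t events[2];
  iree_status_t status =
      iree_hal_local_event_pool_acquire(event_pool, 2, events);
  if (iree_status_is_ok(status)) {
    iree_event_set(&events[0]);  // ... signal/wait as the workload requires ...
    // Release resets the events so they are ready for the next acquire.
    iree_hal_local_event_pool_release(event_pool, 2, events);
  }

  iree_hal_local_event_pool_free(event_pool);
  return status;
}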
+
+#ifndef IREE_HAL_LOCAL_EVENT_POOL_H_
+#define IREE_HAL_LOCAL_EVENT_POOL_H_
+
+#include "iree/base/api.h"
+#include "iree/base/wait_handle.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// A simple pool of iree_event_ts to recycle.
+//
+// Thread-safe; multiple threads may acquire and release events from the pool.
+typedef struct iree_hal_local_event_pool_s iree_hal_local_event_pool_t;
+
+// Allocates a new event pool with up to |available_capacity| events.
+iree_status_t iree_hal_local_event_pool_allocate(
+    iree_host_size_t available_capacity, iree_allocator_t host_allocator,
+    iree_hal_local_event_pool_t** out_event_pool);
+
+// Deallocates an event pool and destroys all events.
+// All events that were acquired from the pool must have already been released
+// back to it prior to deallocation.
+void iree_hal_local_event_pool_free(iree_hal_local_event_pool_t* event_pool);
+
+// Acquires one or more events from the event pool.
+// The returned events will be unsignaled and ready for use. Callers may set and
+// reset the events as much as they want prior to releasing them back to the
+// pool with iree_hal_local_event_pool_release.
+iree_status_t iree_hal_local_event_pool_acquire(
+    iree_hal_local_event_pool_t* event_pool, iree_host_size_t event_count,
+    iree_event_t* out_events);
+
+// Releases one or more events back to the event pool.
+void iree_hal_local_event_pool_release(iree_hal_local_event_pool_t* event_pool,
+                                       iree_host_size_t event_count,
+                                       iree_event_t* events);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_LOCAL_EVENT_POOL_H_
diff --git a/iree/hal/local/executable_library.h b/iree/hal/local/executable_library.h
new file mode 100644
index 0000000000000..a8ec8b2e66b3d
--- /dev/null
+++ b/iree/hal/local/executable_library.h
@@ -0,0 +1,147 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_LOCAL_EXECUTABLE_LIBRARY_H_
+#define IREE_HAL_LOCAL_EXECUTABLE_LIBRARY_H_
+
+// NOTE: this file is designed to be a standalone header: it is embedded in the
+// compiler and must not take any dependencies on the runtime HAL code.
+// Changes here will require changes to the compiler and must be versioned as if
+// this was a schema: backwards-incompatible changes require version bumps or
+// the ability to feature-detect at runtime.
+
+#include <stddef.h>
+#include <stdint.h>
+
+//===----------------------------------------------------------------------===//
+// Versioning and interface querying
+//===----------------------------------------------------------------------===//
+
+// Known valid version values.
+enum iree_hal_executable_library_version_e {
+  // iree_hal_executable_library_v0_t is used as the API communication
+  // structure.
+  IREE_HAL_EXECUTABLE_LIBRARY_VERSION_0 = 0u,
+};
+typedef uint32_t iree_hal_executable_library_version_t;
+
+// The latest version of the library API; can be used to populate the
+// iree_hal_executable_library_header_t::version when building libraries.
+#define IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION \ + IREE_HAL_EXECUTABLE_LIBRARY_VERSION_0 + +// A header present at the top of all versions of the library API used by the +// runtime to ensure version compatibility. +typedef struct { + // Version of the API this library was built with, which was likely the value + // of IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION. + iree_hal_executable_library_version_t version; + + // Name used for logging/diagnostics. + const char* name; +} iree_hal_executable_library_header_t; + +// Exported function from dynamic libraries for querying library information. +// The provided |max_version| is the maximum version the caller supports; +// callees must return NULL if their lowest available version is greater +// than the max version supported by the caller. +typedef const iree_hal_executable_library_header_t* ( + *iree_hal_executable_library_query_fn_t)( + iree_hal_executable_library_version_t max_version); + +// Function name exported from dynamic libraries (pass to dlsym). +#define IREE_HAL_EXECUTABLE_LIBRARY_EXPORT_NAME \ + "iree_hal_executable_library_query" + +//===----------------------------------------------------------------------===// +// IREE_HAL_EXECUTABLE_LIBRARY_VERSION_0 +//===----------------------------------------------------------------------===// + +// Read-only per-dispatch state passed to each tile in a dispatch. +typedef struct { + uint32_t reserved; +} iree_hal_executable_dispatch_state_v0_t; + +typedef union { + struct { + uint32_t x; + uint32_t y; + uint32_t z; + }; + uint32_t value[3]; +} iree_hal_vec3_t; + +#if defined(_MSC_VER) +typedef __declspec( + align(16)) const uint32_t* iree_hal_executable_push_constants_ptr_t; +#else +typedef const uint32_t* iree_hal_executable_push_constants_ptr_t + __attribute__((align_value(16))); +#endif // MSVC + +typedef void* iree_hal_executable_binding_ptr_t; + +// Function signature of exported executable entry points. +// The same |state| is passed to all tiles in a dispatch, with other arguments +// such as |workgroup_id| varying per-tile (counting to the |workgroup_count|). +// Each tile represents |workgroup_size| local invocations in the global +// |workgroup_count| grid. +// +// 0 or more push constants are available at |push_constants| with the count +// being determined by the sidechannel information provided by the compiler. +// +// The |bindings| list is a dense set of pointers to I/O data with the count and +// ordering determined by the compiler. +typedef void (*iree_hal_executable_dispatch_v0_t)( + const iree_hal_executable_dispatch_state_v0_t* state, + const iree_hal_vec3_t* workgroup_id, const iree_hal_vec3_t* workgroup_size, + const iree_hal_vec3_t* workgroup_count, + const iree_hal_executable_push_constants_ptr_t push_constants, + const iree_hal_executable_binding_ptr_t* bindings); + +// Structure used for v0 library interfaces. +// The entire structure is designed to be read-only and able to live embedded in +// the binary .rdata section. +// +// Implementations may still choose to heap allocate this structure and modify +// at runtime so long as they observe the thread-safety guarantees. For example, +// a JIT may default all entry_points to JIT thunk functions and then swap them +// out for the translated function pointers. +typedef struct { + // Version/metadata header. Will have a version of + // IREE_HAL_EXECUTABLE_LIBRARY_VERSION_0. + const iree_hal_executable_library_header_t* header; + + // The total number of entry points available in the library. 
Bounds all of + // the tables below. + uint32_t entry_point_count; + + // Table of export function entry points matching the ordinals defined during + // library generation. The runtime will use this table to map the ordinals to + // function pointers for execution. + const iree_hal_executable_dispatch_v0_t* entry_points; + + // Optional table of export function entry point names 1:1 with entry_points. + // These names are only used for tracing/debugging and can be omitted to save + // binary size. + const char** entry_point_names; + + // Optional table of entry point tags that describe the entry point in a + // human-readable format useful for verbose logging. The string values, when + // present, may be attached to tracing/debugging events related to the entry + // point. + const char** entry_point_tags; +} iree_hal_executable_library_v0_t; + +#endif // IREE_HAL_LOCAL_EXECUTABLE_LIBRARY_H_ diff --git a/iree/hal/local/executable_loader.c b/iree/hal/local/executable_loader.c new file mode 100644 index 0000000000000..b3a400562b1eb --- /dev/null +++ b/iree/hal/local/executable_loader.c @@ -0,0 +1,60 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/executable_loader.h" + +void iree_hal_executable_loader_initialize( + const void* vtable, iree_hal_executable_loader_t* out_base_loader) { + iree_atomic_ref_count_init(&out_base_loader->ref_count); + out_base_loader->vtable = vtable; +} + +void iree_hal_executable_loader_retain( + iree_hal_executable_loader_t* executable_loader) { + if (IREE_LIKELY(executable_loader)) { + iree_atomic_ref_count_inc(&executable_loader->ref_count); + } +} + +void iree_hal_executable_loader_release( + iree_hal_executable_loader_t* executable_loader) { + if (IREE_LIKELY(executable_loader) && + iree_atomic_ref_count_dec(&executable_loader->ref_count) == 1) { + executable_loader->vtable->destroy(executable_loader); + } +} + +bool iree_hal_executable_loader_query_support( + iree_hal_executable_loader_t* executable_loader, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode) { + IREE_ASSERT_ARGUMENT(executable_loader); + return executable_loader->vtable->query_support( + executable_loader, executable_format, caching_mode); +} + +iree_status_t iree_hal_executable_loader_try_load( + iree_hal_executable_loader_t* executable_loader, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + IREE_ASSERT_ARGUMENT(executable_loader); + IREE_ASSERT_ARGUMENT(executable_data.data); + IREE_ASSERT_ARGUMENT(out_executable); + return executable_loader->vtable->try_load( + executable_loader, executable_layout, executable_format, caching_mode, + executable_data, out_executable); +} diff --git a/iree/hal/local/executable_loader.h b/iree/hal/local/executable_loader.h new file mode 100644 index 
0000000000000..7c18ff5c557e2
--- /dev/null
+++ b/iree/hal/local/executable_loader.h
@@ -0,0 +1,116 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_LOCAL_EXECUTABLE_LOADER_H_
+#define IREE_HAL_LOCAL_EXECUTABLE_LOADER_H_
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/base/atomics.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_executable_loader_t
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_executable_loader_vtable_s
+    iree_hal_executable_loader_vtable_t;
+
+// Interface for compiled executable loader implementations.
+// A loader may be as simple as something that resolves function pointers in the
+// local executable for statically linked executables or as complex as a custom
+// relocatable ELF loader. Loaders are registered and persist for each device
+// they are attached to and may keep internal caches or memoize resources shared
+// by multiple loaded executables.
+//
+// Thread-safe - multiple threads may load executables (including the *same*
+// executable) simultaneously.
+typedef struct {
+  iree_atomic_ref_count_t ref_count;
+  const iree_hal_executable_loader_vtable_t* vtable;
+} iree_hal_executable_loader_t;
+
+// Initializes the base iree_hal_executable_loader_t type.
+// Called by subclasses upon allocating their loader.
+void iree_hal_executable_loader_initialize(
+    const void* vtable, iree_hal_executable_loader_t* out_base_loader);
+
+// Retains the given |executable_loader| for the caller.
+void iree_hal_executable_loader_retain(
+    iree_hal_executable_loader_t* executable_loader);
+
+// Releases the given |executable_loader| from the caller.
+void iree_hal_executable_loader_release(
+    iree_hal_executable_loader_t* executable_loader);
+
+// Returns true if the loader can load executables of the given
+// |executable_format|. Note that loading may still fail if the executable uses
+// features not available on the current host or runtime.
+bool iree_hal_executable_loader_query_support(
+    iree_hal_executable_loader_t* executable_loader,
+    iree_hal_executable_format_t executable_format,
+    iree_hal_executable_caching_mode_t caching_mode);
+
+// Tries loading the |executable_data| provided in the given
+// |executable_format|. May fail even if the executable is valid if it requires
+// features not supported by the current host or runtime (such as available
+// architectures, imports, etc).
+//
+// Depending on loader ability the |caching_mode| is used to enable certain
+// features such as instrumented profiling. Not all formats support these
+// features and cooperation of both the compiler producing the executables and
+// the runtime loader and system are required.
+//
+// Returns IREE_STATUS_CANCELLED when the loader cannot load the file in the
+// given format.
+iree_status_t iree_hal_executable_loader_try_load( + iree_hal_executable_loader_t* executable_loader, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable); + +//===----------------------------------------------------------------------===// +// iree_hal_executable_loader_t implementation details +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_executable_loader_vtable_s { + void(IREE_API_PTR* destroy)(iree_hal_executable_loader_t* executable_loader); + + bool(IREE_API_PTR* query_support)( + iree_hal_executable_loader_t* executable_loader, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode); + + iree_status_t(IREE_API_PTR* try_load)( + iree_hal_executable_loader_t* executable_loader, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable); +} iree_hal_executable_loader_vtable_t; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_EXECUTABLE_LOADER_H_ diff --git a/iree/hal/local/loaders/BUILD b/iree/hal/local/loaders/BUILD new file mode 100644 index 0000000000000..efbd44ddab41b --- /dev/null +++ b/iree/hal/local/loaders/BUILD @@ -0,0 +1,95 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Default implementations for HAL types that use the host resources. +# These are generally just wrappers around host heap memory and host threads. 
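
Since the loaders below exist to resolve and invoke these library exports, a producer-side sketch of the v0 ABI from executable_library.h may help. All names are hypothetical, the dispatch body is a placeholder, and the query return convention (a pointer to the version-specific table whose first field references the header) is inferred from the union usage in system_library_loader.c rather than stated in the header:

// Hypothetical hand-authored executable library implementing the v0 ABI.
static void my_dispatch_tile(
    const iree_hal_executable_dispatch_state_v0_t* state,
    const iree_hal_vec3_t* workgroup_id, const iree_hal_vec3_t* workgroup_size,
    const iree_hal_vec3_t* workgroup_count,
    const iree_hal_executable_push_constants_ptr_t push_constants,
    const iree_hal_executable_binding_ptr_t* bindings) {
  // ... process one workgroup of the dispatch grid using |bindings| I/O ...
}

static const iree_hal_executable_library_header_t header = {
    .version = IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION,
    .name = "my_library",
};
static const iree_hal_executable_dispatch_v0_t entry_points[1] = {
    my_dispatch_tile,
};
static const char* entry_point_names[1] = {"my_dispatch_tile"};
static const iree_hal_executable_library_v0_t library = {
    .header = &header,
    .entry_point_count = 1,
    .entry_points = entry_points,
    .entry_point_names = entry_point_names,  // optional, for tracing
    .entry_point_tags = NULL,                // optional
};

// Resolved via dlsym using IREE_HAL_EXECUTABLE_LIBRARY_EXPORT_NAME. NULL is
// returned when the caller's |max_version| is older than what we provide.
const iree_hal_executable_library_header_t* iree_hal_executable_library_query(
    iree_hal_executable_library_version_t max_version) {
  return max_version >= IREE_HAL_EXECUTABLE_LIBRARY_VERSION_0
             ? (const iree_hal_executable_library_header_t*)&library
             : NULL;
}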
+ +load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "legacy_library_loader", + srcs = ["legacy_library_loader.cc"], + hdrs = ["legacy_library_loader.h"], + defines = [ + "IREE_HAL_HAVE_LEGACY_LIBRARY_LOADER=1", + ], + deps = [ + "//iree/base:api", + "//iree/base:dynamic_library", + "//iree/base:file_io", + "//iree/base:file_path", + "//iree/base:flatcc", + "//iree/base:tracing", + "//iree/hal:api", + "//iree/hal/local", + "//iree/schemas:dylib_executable_def_c_fbs", + ], +) + +cc_library( + name = "system_library_loader", + srcs = ["system_library_loader.c"], + hdrs = ["system_library_loader.h"], + defines = [ + "IREE_HAL_HAVE_SYSTEM_LIBRARY_LOADER=1", + ], + deps = [ + "//iree/base:api", + "//iree/base:file_io", + "//iree/base:flatcc", + "//iree/base:tracing", + "//iree/hal:api", + "//iree/hal/local", + ], +) + +iree_cmake_extra_content( + content = """ +if(${IREE_HAL_DRIVER_VMLA}) +""", + inline = True, +) + +cc_library( + name = "vmla_module_loader", + srcs = ["vmla_module_loader.cc"], + hdrs = ["vmla_module_loader.h"], + defines = [ + "IREE_HAL_HAVE_VMLA_MODULE_LOADER=1", + ], + deps = [ + "//iree/base:api", + "//iree/base:flatcc", + "//iree/base:tracing", + "//iree/hal:api", + "//iree/hal/local", + "//iree/modules/vmla:op_module", + "//iree/schemas:vmla_executable_def_c_fbs", + "//iree/vm", + "//iree/vm:bytecode_module", + ], +) + +iree_cmake_extra_content( + content = """ +endif() +""", + inline = True, +) diff --git a/iree/hal/local/loaders/CMakeLists.txt b/iree/hal/local/loaders/CMakeLists.txt new file mode 100644 index 0000000000000..65bf612568cf0 --- /dev/null +++ b/iree/hal/local/loaders/CMakeLists.txt @@ -0,0 +1,82 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
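
For orientation, here is a sketch of how a device might drive one of these loaders through the iree_hal_executable_loader_t interface; the wrapper function is hypothetical, the zero caching mode is an assumption, and |executable_layout|/|executable_data| are assumed caller-provided:

// Hypothetical load path over the loader interface; on CANCELLED a device
// would fall through to the next registered loader instead of failing.
iree_status_t try_load_example(
    iree_hal_executable_layout_t* executable_layout,
    iree_const_byte_span_t executable_data,
    iree_hal_executable_t** out_executable) {
  iree_hal_executable_loader_t* loader = NULL;
  IREE_RETURN_IF_ERROR(
      iree_hal_legacy_library_loader_create(iree_allocator_system(), &loader));

  iree_hal_executable_format_t format =
      iree_hal_make_executable_format("DLIB");
  iree_hal_executable_caching_mode_t caching_mode = 0;  // assumed: no flags

  iree_status_t status = iree_ok_status();
  if (iree_hal_executable_loader_query_support(loader, format, caching_mode)) {
    status = iree_hal_executable_loader_try_load(
        loader, executable_layout, format, caching_mode, executable_data,
        out_executable);
  } else {
    status = iree_make_status(IREE_STATUS_UNAVAILABLE,
                              "no loader support for DLIB executables");
  }

  iree_hal_executable_loader_release(loader);
  return status;
}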
+ +iree_add_all_subdirs() + +iree_cc_library( + NAME + legacy_library_loader + HDRS + "legacy_library_loader.h" + SRCS + "legacy_library_loader.cc" + DEPS + iree::base::api + iree::base::dynamic_library + iree::base::file_io + iree::base::file_path + iree::base::flatcc + iree::base::tracing + iree::hal::api + iree::hal::local + iree::schemas::dylib_executable_def_c_fbs + DEFINES + "IREE_HAL_HAVE_LEGACY_LIBRARY_LOADER=1" + PUBLIC +) + +iree_cc_library( + NAME + system_library_loader + HDRS + "system_library_loader.h" + SRCS + "system_library_loader.c" + DEPS + iree::base::api + iree::base::file_io + iree::base::flatcc + iree::base::tracing + iree::hal::api + iree::hal::local + DEFINES + "IREE_HAL_HAVE_SYSTEM_LIBRARY_LOADER=1" + PUBLIC +) + +if(${IREE_HAL_DRIVER_VMLA}) + +iree_cc_library( + NAME + vmla_module_loader + HDRS + "vmla_module_loader.h" + SRCS + "vmla_module_loader.cc" + DEPS + iree::base::api + iree::base::flatcc + iree::base::tracing + iree::hal::api + iree::hal::local + iree::modules::vmla::op_module + iree::schemas::vmla_executable_def_c_fbs + iree::vm + iree::vm::bytecode_module + DEFINES + "IREE_HAL_HAVE_VMLA_MODULE_LOADER=1" + PUBLIC +) + +endif() diff --git a/iree/hal/local/loaders/legacy_library_loader.cc b/iree/hal/local/loaders/legacy_library_loader.cc new file mode 100644 index 0000000000000..30c78447c6d11 --- /dev/null +++ b/iree/hal/local/loaders/legacy_library_loader.cc @@ -0,0 +1,407 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/loaders/legacy_library_loader.h" + +#include "iree/base/dynamic_library.h" +#include "iree/base/file_io.h" +#include "iree/base/file_path.h" +#include "iree/base/tracing.h" +#include "iree/hal/local/local_executable.h" + +// flatcc schemas: +#include "iree/base/flatcc.h" +#include "iree/schemas/dylib_executable_def_reader.h" +#include "iree/schemas/dylib_executable_def_verifier.h" + +//===----------------------------------------------------------------------===// +// Verification and file utilities +//===----------------------------------------------------------------------===// + +// Verifies the structure of the flatbuffer so that we can avoid doing so during +// runtime. There are still some conditions we must be aware of (such as omitted +// names on functions with internal linkage), however we shouldn't need to +// bounds check anything within the flatbuffer after this succeeds. +static iree_status_t iree_hal_dylib_executable_flatbuffer_verify( + iree_const_byte_span_t flatbuffer_data) { + // Special handling for valid but mismatching flatbuffers. + if (!flatbuffer_data.data || flatbuffer_data.data_length < 16 || + !flatbuffers_has_identifier(flatbuffer_data.data, + iree_DyLibExecutableDef_file_identifier)) { + return iree_status_from_code(IREE_STATUS_CANCELLED); + } + + // Run flatcc generated verification. 
This ensures all pointers are in-bounds + // and that we can safely walk the file, but not that the actual contents of + // the flatbuffer meet our expectations. + int verify_ret = iree_DyLibExecutableDef_verify_as_root( + flatbuffer_data.data, flatbuffer_data.data_length); + if (verify_ret != flatcc_verify_ok) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "flatbuffer verification failed: %s", + flatcc_verify_error_string(verify_ret)); + } + + iree_DyLibExecutableDef_table_t executable_def = + iree_DyLibExecutableDef_as_root(flatbuffer_data.data); + + flatbuffers_string_vec_t entry_points_vec = + iree_DyLibExecutableDef_entry_points_get(executable_def); + size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec); + for (size_t i = 0; i < entry_point_count; ++i) { + if (!flatbuffers_string_len( + flatbuffers_string_vec_at(entry_points_vec, i))) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "executable entry point %zu has no name", i); + } + } + + if (!flatbuffers_uint8_vec_len( + iree_DyLibExecutableDef_library_embedded_get(executable_def))) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "executable library_embedded is missing/empty"); + } + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_legacy_executable_t +//===----------------------------------------------------------------------===// + +typedef void (*iree_hal_legacy_executable_fn_ptr_t)(void* const*, + const uint32_t*, + const uint32_t*, + const uint32_t*, + const uint32_t*); + +typedef struct { + iree_hal_local_executable_t base; + + // Flatbuffer definition referencing the executable memory. + iree_DyLibExecutableDef_table_t def; + + // Temporary files created as part of extraction. + // Strings are allocated from the host allocator. + iree_host_size_t temp_file_count; + iree_string_view_t temp_files[8]; + + // Loaded platform dynamic library. + iree::DynamicLibrary* library; + + // Resolved entry points from the dynamic library. + iree_host_size_t entry_fn_count; + iree_hal_legacy_executable_fn_ptr_t entry_fns[]; +} iree_hal_legacy_executable_t; + +extern const iree_hal_local_executable_vtable_t + iree_hal_legacy_executable_vtable; + +static iree_status_t iree_hal_legacy_executable_extract_and_load( + iree_hal_legacy_executable_t* executable, iree_allocator_t host_allocator) { + // Write the embedded library out to a temp file, since all of the dynamic + // library APIs work with files. We could instead use in-memory files on + // platforms where that is convenient. + // + // TODO(#3845): use dlopen on an fd with either dlopen(/proc/self/fd/NN), + // fdlopen, or android_dlopen_ext to avoid needing to write the file to disk. + // Can fallback to memfd_create + dlopen where available, and fallback from + // that to disk (maybe just windows/mac). 
+  IREE_ASSIGN_OR_RETURN(auto library_temp_path,
+                        iree::file_io::GetTempFile("dylib_executable"));
+
+// Add platform-specific file extensions so opinionated dynamic library
+// loaders are more likely to find the file:
+#if defined(IREE_PLATFORM_WINDOWS)
+  library_temp_path += ".dll";
+#else
+  library_temp_path += ".so";
+#endif  // IREE_PLATFORM_WINDOWS
+
+  iree_string_view_t library_temp_file = iree_string_view_empty();
+  IREE_RETURN_IF_ERROR(
+      iree_allocator_clone(host_allocator,
+                           iree_make_const_byte_span(library_temp_path.data(),
+                                                     library_temp_path.size()),
+                           (void**)&library_temp_file.data));
+  library_temp_file.size = library_temp_path.size();
+  executable->temp_files[executable->temp_file_count++] = library_temp_file;
+
+  flatbuffers_uint8_vec_t embedded_library_vec =
+      iree_DyLibExecutableDef_library_embedded_get(executable->def);
+  IREE_RETURN_IF_ERROR(iree::file_io::SetFileContents(
+      library_temp_path,
+      absl::string_view(reinterpret_cast<const char*>(embedded_library_vec),
+                        flatbuffers_uint8_vec_len(embedded_library_vec))));
+
+  IREE_ASSIGN_OR_RETURN(auto library,
+                        iree::DynamicLibrary::Load(library_temp_path.c_str()));
+
+  flatbuffers_string_t debug_database_filename =
+      iree_DyLibExecutableDef_debug_database_filename_get(executable->def);
+  flatbuffers_uint8_vec_t debug_database_embedded_vec =
+      iree_DyLibExecutableDef_debug_database_embedded_get(executable->def);
+  if (flatbuffers_string_len(debug_database_filename) &&
+      flatbuffers_uint8_vec_len(debug_database_embedded_vec)) {
+    IREE_TRACE_SCOPE0("DyLibExecutable::AttachDebugDatabase");
+    auto debug_database_path = iree::file_path::JoinPaths(
+        iree::file_path::DirectoryName(library_temp_path),
+        absl::string_view(debug_database_filename,
+                          flatbuffers_string_len(debug_database_filename)));
+    iree_string_view_t debug_database_file = iree_string_view_empty();
+    IREE_RETURN_IF_ERROR(iree_allocator_clone(
+        host_allocator,
+        iree_make_const_byte_span(debug_database_path.data(),
+                                  debug_database_path.size()),
+        (void**)&debug_database_file.data));
+    debug_database_file.size = debug_database_path.size();
+    executable->temp_files[executable->temp_file_count++] = debug_database_file;
+    IREE_IGNORE_ERROR(iree::file_io::SetFileContents(
+        debug_database_path,
+        absl::string_view(
+            reinterpret_cast<const char*>(debug_database_embedded_vec),
+            flatbuffers_uint8_vec_len(debug_database_embedded_vec))));
+    library->AttachDebugDatabase(debug_database_path.c_str());
+  }
+
+  executable->library = library.release();
+
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_legacy_executable_resolve_symbols(
+    iree_hal_legacy_executable_t* executable) {
+  flatbuffers_string_vec_t entry_points_vec =
+      iree_DyLibExecutableDef_entry_points_get(executable->def);
+  for (iree_host_size_t i = 0; i < executable->entry_fn_count; ++i) {
+    flatbuffers_string_t entry_point_str =
+        flatbuffers_string_vec_at(entry_points_vec, i);
+    void* symbol = executable->library->GetSymbol(entry_point_str);
+    if (!symbol) {
+      return iree_make_status(
+          IREE_STATUS_NOT_FOUND,
+          "symbol %s not exported by the dynamic library, check visibility",
+          entry_point_str);
+    }
+    executable->entry_fns[i] = (iree_hal_legacy_executable_fn_ptr_t)symbol;
+  }
+  return iree_ok_status();
+}
+
+static iree_status_t iree_hal_legacy_executable_create(
+    iree_hal_executable_layout_t* base_layout,
+    iree_DyLibExecutableDef_table_t executable_def,
+    iree_hal_executable_t** out_executable) {
+  IREE_ASSERT_ARGUMENT(base_layout);
+  IREE_ASSERT_ARGUMENT(executable_def);
+  IREE_ASSERT_ARGUMENT(out_executable);
+ *out_executable = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_local_executable_layout_t* local_layout = + iree_hal_local_executable_layout_cast(base_layout); + IREE_ASSERT_ARGUMENT(local_layout); + + flatbuffers_string_vec_t entry_points_vec = + iree_DyLibExecutableDef_entry_points_get(executable_def); + iree_host_size_t entry_point_count = + flatbuffers_string_vec_len(entry_points_vec); + + iree_hal_legacy_executable_t* executable = NULL; + iree_host_size_t total_size = + sizeof(*executable) + entry_point_count * sizeof(*executable->entry_fns); + iree_status_t status = iree_allocator_malloc(local_layout->host_allocator, + total_size, (void**)&executable); + if (iree_status_is_ok(status)) { + iree_hal_local_executable_initialize(&iree_hal_legacy_executable_vtable, + local_layout, &executable->base); + executable->def = executable_def; + executable->entry_fn_count = entry_point_count; + } + if (iree_status_is_ok(status)) { + // Attempt to extract the embedded flatbuffer library and load it. + // Will scribble information into executable. + // This is bad, but ehh all this is getting deleted soon and hopefully we + // can avoid ever touching the disk at all. + status = iree_hal_legacy_executable_extract_and_load( + executable, local_layout->host_allocator); + } + if (iree_status_is_ok(status)) { + // Attempt to resolve symbols for all entry points. + status = iree_hal_legacy_executable_resolve_symbols(executable); + } + + if (iree_status_is_ok(status)) { + *out_executable = (iree_hal_executable_t*)executable; + } else { + iree_hal_executable_release((iree_hal_executable_t*)executable); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_legacy_executable_destroy( + iree_hal_executable_t* base_executable) { + iree_hal_legacy_executable_t* executable = + (iree_hal_legacy_executable_t*)base_executable; + iree_allocator_t host_allocator = executable->base.layout->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + +#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + // Leak the library when tracing, since the profiler may still be reading it. 
+ // TODO(benvanik): move to an atexit handler instead, verify with ASAN/MSAN + // TODO(scotttodd): Make this compatible with testing: + // two test cases, one for each function in the same executable + // first test case passes, second fails to open the file (already open) +#else + delete executable->library; +#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + + for (iree_host_size_t i = 0; i < executable->temp_file_count; ++i) { + iree_string_view_t file_path = executable->temp_files[i]; + iree::file_io::DeleteFile(std::string(file_path.data, file_path.size)) + .IgnoreError(); + iree_allocator_free(host_allocator, (void*)file_path.data); + } + + iree_hal_local_executable_deinitialize( + (iree_hal_local_executable_t*)base_executable); + iree_allocator_free(host_allocator, executable); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_legacy_executable_issue_call( + iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal, + const iree_hal_local_executable_call_t* call) { + iree_hal_legacy_executable_t* executable = + (iree_hal_legacy_executable_t*)base_executable; + + if (IREE_UNLIKELY(ordinal >= executable->entry_fn_count)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "entry point ordinal out of bounds"); + } + +#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + flatbuffers_string_t entry_point_str = flatbuffers_string_vec_at( + iree_DyLibExecutableDef_entry_points_get(executable->def), ordinal); + iree_string_view_t entry_point_name = iree_make_string_view( + entry_point_str, flatbuffers_string_len(entry_point_str)); + if (iree_string_view_is_empty(entry_point_name)) { + entry_point_name = iree_make_cstring_view("unknown_dylib_call"); + } + IREE_TRACE_ZONE_BEGIN_NAMED_DYNAMIC(z0, entry_point_name.data, + entry_point_name.size); +#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + + executable->entry_fns[ordinal](call->bindings, call->push_constants, + (const uint32_t*)&call->workgroup_id, + (const uint32_t*)&call->workgroup_count, + (const uint32_t*)&call->workgroup_size); + + IREE_TRACE_ZONE_END(z0); + + return iree_ok_status(); +} + +const iree_hal_local_executable_vtable_t iree_hal_legacy_executable_vtable = { + /*.base=*/ + { + /*.destroy=*/iree_hal_legacy_executable_destroy, + }, + /*.issue_call=*/iree_hal_legacy_executable_issue_call, +}; + +//===----------------------------------------------------------------------===// +// iree_hal_legacy_library_loader_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_executable_loader_t base; + iree_allocator_t host_allocator; +} iree_hal_legacy_library_loader_t; + +extern const iree_hal_executable_loader_vtable_t + iree_hal_legacy_library_loader_vtable; + +iree_status_t iree_hal_legacy_library_loader_create( + iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader) { + IREE_ASSERT_ARGUMENT(out_executable_loader); + *out_executable_loader = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_legacy_library_loader_t* executable_loader = NULL; + iree_status_t status = iree_allocator_malloc( + host_allocator, sizeof(*executable_loader), (void**)&executable_loader); + if (iree_status_is_ok(status)) { + iree_hal_executable_loader_initialize( + &iree_hal_legacy_library_loader_vtable, &executable_loader->base); + executable_loader->host_allocator = host_allocator; + *out_executable_loader = (iree_hal_executable_loader_t*)executable_loader; + } + + 
IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_legacy_library_loader_destroy( + iree_hal_executable_loader_t* base_executable_loader) { + iree_hal_legacy_library_loader_t* executable_loader = + (iree_hal_legacy_library_loader_t*)base_executable_loader; + iree_allocator_t host_allocator = executable_loader->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, executable_loader); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_legacy_library_loader_query_support( + iree_hal_executable_loader_t* base_executable_loader, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode) { + return executable_format == iree_hal_make_executable_format("DLIB"); +} + +static iree_status_t iree_hal_legacy_library_loader_try_load( + iree_hal_executable_loader_t* base_executable_loader, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + IREE_TRACE_ZONE_BEGIN(z0); + + // Verify and fetch the executable flatbuffer wrapper. + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_dylib_executable_flatbuffer_verify(executable_data)); + iree_DyLibExecutableDef_table_t executable_def = + iree_DyLibExecutableDef_as_root(executable_data.data); + + // Perform the load (and requisite disgusting hackery). + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_legacy_executable_create(executable_layout, executable_def, + out_executable)); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +const iree_hal_executable_loader_vtable_t + iree_hal_legacy_library_loader_vtable = { + /*.destroy=*/iree_hal_legacy_library_loader_destroy, + /*.query_support=*/iree_hal_legacy_library_loader_query_support, + /*.try_load=*/iree_hal_legacy_library_loader_try_load, +}; diff --git a/iree/hal/local/loaders/legacy_library_loader.h b/iree/hal/local/loaders/legacy_library_loader.h new file mode 100644 index 0000000000000..98cc5effe877b --- /dev/null +++ b/iree/hal/local/loaders/legacy_library_loader.h @@ -0,0 +1,42 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_LOADERS_LEGACY_LIBRARY_LOADER_H_ +#define IREE_HAL_LOCAL_LOADERS_LEGACY_LIBRARY_LOADER_H_ + +#include +#include + +#include "iree/base/api.h" +#include "iree/hal/local/executable_loader.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates an executable loader that can load files from platform-supported +// dynamic libraries (such as .dylib on darwin, .so on linux, .dll on windows). +// +// This uses the legacy "dylib"-style format that will be deleted soon and is +// only a placeholder until the compiler can be switched to output +// iree_hal_executable_library_t-compatible files. 
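+//
+// Illustrative usage sketch (not part of the API contract; assumes the loader
+// is handed off to an executable cache that retains it):
+//   iree_hal_executable_loader_t* loader = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_legacy_library_loader_create(
+//       iree_allocator_system(), &loader));
+//   // ... pass |loader| to iree_hal_local_executable_cache_create ...
+//   iree_hal_executable_loader_release(loader);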
+iree_status_t iree_hal_legacy_library_loader_create( + iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOADERS_LEGACY_LIBRARY_LOADER_H_ diff --git a/iree/hal/local/loaders/system_library_loader.c b/iree/hal/local/loaders/system_library_loader.c new file mode 100644 index 0000000000000..503f41f8bad39 --- /dev/null +++ b/iree/hal/local/loaders/system_library_loader.c @@ -0,0 +1,205 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/loaders/system_library_loader.h" + +#include "iree/base/tracing.h" +#include "iree/hal/local/local_executable.h" + +// flatcc schemas: +#include "iree/base/flatcc.h" + +//===----------------------------------------------------------------------===// +// iree_hal_system_executable_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_local_executable_t base; + + // TODO(benvanik): library handle for ownership. + + union { + const iree_hal_executable_library_header_t* header; + const iree_hal_executable_library_v0_t* v0; + } library; +} iree_hal_system_executable_t; + +static const iree_hal_local_executable_vtable_t + iree_hal_system_executable_vtable; + +static iree_status_t iree_hal_system_executable_create( + iree_hal_executable_layout_t* base_layout, + const iree_hal_executable_library_header_t* library_header, + iree_hal_executable_t** out_executable) { + IREE_ASSERT_ARGUMENT(base_layout); + IREE_ASSERT_ARGUMENT(library_header); + IREE_ASSERT_ARGUMENT(out_executable); + *out_executable = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_local_executable_layout_t* local_layout = + iree_hal_local_executable_layout_cast(base_layout); + IREE_ASSERT_ARGUMENT(local_layout); + + iree_hal_system_executable_t* executable = NULL; + iree_status_t status = iree_allocator_malloc( + local_layout->host_allocator, sizeof(*executable), (void**)&executable); + if (iree_status_is_ok(status)) { + iree_hal_local_executable_initialize(&iree_hal_system_executable_vtable, + local_layout, &executable->base); + executable->library.header = library_header; + *out_executable = (iree_hal_executable_t*)executable; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_system_executable_destroy( + iree_hal_executable_t* base_executable) { + iree_hal_system_executable_t* executable = + (iree_hal_system_executable_t*)base_executable; + iree_allocator_t host_allocator = executable->base.layout->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_local_executable_deinitialize( + (iree_hal_local_executable_t*)base_executable); + iree_allocator_free(host_allocator, executable); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_system_executable_issue_call( + iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal, + const iree_hal_local_executable_call_t* call) { + 
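+  // Bounds-check the requested entry point against the loaded library and
+  // then invoke it with the v0 calling convention below: dispatch state,
+  // workgroup id/size/count, push constants, and bindings.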
iree_hal_system_executable_t* executable = + (iree_hal_system_executable_t*)base_executable; + + iree_host_size_t ordinal_count = executable->library.v0->entry_point_count; + if (IREE_UNLIKELY(ordinal >= ordinal_count)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "entry point ordinal out of bounds"); + } + +#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + iree_string_view_t entry_point_name = iree_make_cstring_view( + executable->library.v0->entry_point_names[ordinal]); + if (iree_string_view_is_empty(entry_point_name)) { + entry_point_name = iree_make_cstring_view("unknown_dylib_call"); + } + IREE_TRACE_ZONE_BEGIN_NAMED_DYNAMIC(z0, entry_point_name.data, + entry_point_name.size); +#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + + executable->library.v0->entry_points[ordinal]( + call->state, &call->workgroup_id, &call->workgroup_size, + &call->workgroup_count, call->push_constants, call->bindings); + + IREE_TRACE_ZONE_END(z0); + + return iree_ok_status(); +} + +static const iree_hal_local_executable_vtable_t + iree_hal_system_executable_vtable = { + .base = + { + .destroy = iree_hal_system_executable_destroy, + }, + .issue_call = iree_hal_system_executable_issue_call, +}; + +//===----------------------------------------------------------------------===// +// iree_hal_system_library_loader_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_executable_loader_t base; + iree_allocator_t host_allocator; +} iree_hal_system_library_loader_t; + +static const iree_hal_executable_loader_vtable_t + iree_hal_system_library_loader_vtable; + +iree_status_t iree_hal_system_library_loader_create( + iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader) { + IREE_ASSERT_ARGUMENT(out_executable_loader); + *out_executable_loader = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_system_library_loader_t* executable_loader = NULL; + iree_status_t status = iree_allocator_malloc( + host_allocator, sizeof(*executable_loader), (void**)&executable_loader); + if (iree_status_is_ok(status)) { + iree_hal_executable_loader_initialize( + &iree_hal_system_library_loader_vtable, &executable_loader->base); + executable_loader->host_allocator = host_allocator; + *out_executable_loader = (iree_hal_executable_loader_t*)executable_loader; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_system_library_loader_destroy( + iree_hal_executable_loader_t* base_executable_loader) { + iree_hal_system_library_loader_t* executable_loader = + (iree_hal_system_library_loader_t*)base_executable_loader; + iree_allocator_t host_allocator = executable_loader->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, executable_loader); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_system_library_loader_query_support( + iree_hal_executable_loader_t* base_executable_loader, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode) { + return executable_format == iree_hal_make_executable_format("DYEX"); +} + +static iree_status_t iree_hal_system_library_loader_try_load( + iree_hal_executable_loader_t* base_executable_loader, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + IREE_TRACE_ZONE_BEGIN(z0); + 
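+  // NOTE: loading is not wired up yet; the commented-out query below sketches
+  // the intended flow once the compiler emits the new library format.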
+ iree_status_t status = + iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "new executable library format not yet implemented"); + + // Query the executable library to get the latest interface. + // Will fail if the executable is using a newer interface than we support. + // iree_hal_executable_library_header_t* header = NULL; + // IREE_RETURN_AND_END_ZONE_IF_ERROR( + // z0, iree_hal_executable_library_handle_query( + // executable_handle, IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION, + // &header)); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static const iree_hal_executable_loader_vtable_t + iree_hal_system_library_loader_vtable = { + .destroy = iree_hal_system_library_loader_destroy, + .query_support = iree_hal_system_library_loader_query_support, + .try_load = iree_hal_system_library_loader_try_load, +}; diff --git a/iree/hal/local/loaders/system_library_loader.h b/iree/hal/local/loaders/system_library_loader.h new file mode 100644 index 0000000000000..666fe7fee4d29 --- /dev/null +++ b/iree/hal/local/loaders/system_library_loader.h @@ -0,0 +1,38 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_LOADERS_SYSTEM_LIBRARY_LOADER_H_ +#define IREE_HAL_LOCAL_LOADERS_SYSTEM_LIBRARY_LOADER_H_ + +#include +#include + +#include "iree/base/api.h" +#include "iree/hal/local/executable_loader.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates an executable loader that can load files from platform-supported +// dynamic libraries (such as .dylib on darwin, .so on linux, .dll on windows). +iree_status_t iree_hal_system_library_loader_create( + iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOADERS_SYSTEM_LIBRARY_LOADER_H_ diff --git a/iree/hal/local/loaders/vmla_module_loader.cc b/iree/hal/local/loaders/vmla_module_loader.cc new file mode 100644 index 0000000000000..3175038b74c97 --- /dev/null +++ b/iree/hal/local/loaders/vmla_module_loader.cc @@ -0,0 +1,375 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "iree/hal/local/loaders/vmla_module_loader.h" + +#include "iree/base/tracing.h" +#include "iree/hal/local/local_descriptor_set_layout.h" +#include "iree/hal/local/local_executable.h" +#include "iree/modules/vmla/op_module.h" +#include "iree/vm/bytecode_module.h" + +// flatcc schemas: +#include "iree/base/flatcc.h" +#include "iree/schemas/vmla_executable_def_reader.h" +#include "iree/schemas/vmla_executable_def_verifier.h" + +//===----------------------------------------------------------------------===// +// Verification and file utilities +//===----------------------------------------------------------------------===// + +// Verifies the structure of the flatbuffer so that we can avoid doing so during +// runtime. There are still some conditions we must be aware of (such as omitted +// names on functions with internal linkage), however we shouldn't need to +// bounds check anything within the flatbuffer after this succeeds. +static iree_status_t iree_hal_vmla_executable_flatbuffer_verify( + iree_const_byte_span_t flatbuffer_data) { + // Special handling for valid but mismatching flatbuffers. + if (!flatbuffer_data.data || flatbuffer_data.data_length < 16 || + !flatbuffers_has_identifier(flatbuffer_data.data, + iree_VMLAExecutableDef_file_identifier)) { + return iree_status_from_code(IREE_STATUS_CANCELLED); + } + + // Run flatcc generated verification. This ensures all pointers are in-bounds + // and that we can safely walk the file, but not that the actual contents of + // the flatbuffer meet our expectations. + int verify_ret = iree_VMLAExecutableDef_verify_as_root( + flatbuffer_data.data, flatbuffer_data.data_length); + if (verify_ret != flatcc_verify_ok) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "flatbuffer verification failed: %s", + flatcc_verify_error_string(verify_ret)); + } + + iree_VMLAExecutableDef_table_t executable_def = + iree_VMLAExecutableDef_as_root(flatbuffer_data.data); + + if (flatbuffers_uint8_vec_len( + iree_VMLAExecutableDef_bytecode_module_get(executable_def)) < 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "executable bytecode_module is missing/empty"); + } + + // NOTE: we don't check the actual bytecode module contents here; it's opaque + // to us and passed on to the VM. + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_vmla_executable_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_local_executable_t base; + + // Context containing both the VMLA module and the loaded executable. + iree_vm_context_t* context; + + // Resolved entry functions from the module. 
+ iree_host_size_t entry_fn_count; + iree_vm_function_t entry_fns[]; +} iree_hal_vmla_executable_t; + +extern const iree_hal_local_executable_vtable_t iree_hal_vmla_executable_vtable; + +static iree_status_t iree_hal_vmla_executable_create( + iree_hal_executable_layout_t* base_layout, iree_vm_context_t* context, + iree_vm_module_t* bytecode_module, iree_hal_executable_t** out_executable) { + IREE_ASSERT_ARGUMENT(base_layout); + IREE_ASSERT_ARGUMENT(context); + IREE_ASSERT_ARGUMENT(bytecode_module); + IREE_ASSERT_ARGUMENT(out_executable); + *out_executable = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_local_executable_layout_t* local_layout = + iree_hal_local_executable_layout_cast(base_layout); + IREE_ASSERT_ARGUMENT(local_layout); + + iree_allocator_t host_allocator = local_layout->host_allocator; + iree_hal_vmla_executable_t* executable = NULL; + iree_host_size_t entry_count = + iree_vm_module_signature(bytecode_module).export_function_count; + iree_host_size_t total_size = + sizeof(*executable) + entry_count * sizeof(*executable->entry_fns); + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&executable); + if (iree_status_is_ok(status)) { + iree_hal_local_executable_initialize(&iree_hal_vmla_executable_vtable, + local_layout, &executable->base); + executable->context = context; + iree_vm_context_retain(executable->context); + + executable->entry_fn_count = entry_count; + for (iree_host_size_t i = 0; i < executable->entry_fn_count; ++i) { + status = iree_vm_module_lookup_function_by_ordinal( + bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, + &executable->entry_fns[i], NULL); + if (!iree_status_is_ok(status)) break; + } + } + + if (iree_status_is_ok(status)) { + *out_executable = (iree_hal_executable_t*)executable; + } else { + iree_hal_executable_release((iree_hal_executable_t*)executable); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vmla_executable_destroy( + iree_hal_executable_t* base_executable) { + iree_hal_vmla_executable_t* executable = + (iree_hal_vmla_executable_t*)base_executable; + iree_allocator_t host_allocator = executable->base.layout->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_vm_context_release(executable->context); + iree_hal_local_executable_deinitialize( + (iree_hal_local_executable_t*)base_executable); + iree_allocator_free(host_allocator, executable); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_vmla_executable_issue_call( + iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal, + const iree_hal_local_executable_call_t* call) { + iree_hal_vmla_executable_t* executable = + (iree_hal_vmla_executable_t*)base_executable; + + if (IREE_UNLIKELY(ordinal >= executable->entry_fn_count)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "entry point ordinal out of bounds"); + } + +#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + iree_string_view_t entry_point_name = + iree_vm_function_name(&executable->entry_fns[ordinal]); + if (iree_string_view_is_empty(entry_point_name)) { + entry_point_name = iree_make_cstring_view("unknown_vmla_call"); + } + IREE_TRACE_ZONE_BEGIN_NAMED_DYNAMIC(z0, entry_point_name.data, + entry_point_name.size); +#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + + // We churn memory here but I don't rightly care: this entire VMLA approach is + // deprecated and will be going away at some point. 
There are about 100
+  // low-hanging branches we can hack at in the compiler before this extra
+  // allocation matters :)
+  iree_allocator_t host_allocator = executable->base.layout->host_allocator;
+  iree::hal::vmla::Interface interface;
+  iree_vm_ref_t interface_ref = Interface_retain_ref(&interface);
+  iree_host_size_t input_list_size = iree_vm_list_storage_size(
+      /*element_type=*/NULL, /*interface*/ 1 + /*workgroup_xyz[3]*/ 3);
+  void* input_list_storage = iree_alloca(input_list_size);
+  iree_vm_list_t* input_list = NULL;
+  IREE_CHECK_OK(iree_vm_list_initialize(
+      iree_make_byte_span(input_list_storage, input_list_size),
+      /*element_type=*/NULL,
+      /*interface*/ 1 + /*workgroup_xyz[3]*/ 3, &input_list));
+  iree_vm_list_push_ref_retain(input_list, &interface_ref);
+  iree_vm_value_t workgroup_id_x = iree_vm_value_make_i32(call->workgroup_id.x);
+  iree_vm_value_t workgroup_id_y = iree_vm_value_make_i32(call->workgroup_id.y);
+  iree_vm_value_t workgroup_id_z = iree_vm_value_make_i32(call->workgroup_id.z);
+  iree_vm_list_push_value(input_list, &workgroup_id_x);
+  iree_vm_list_push_value(input_list, &workgroup_id_y);
+  iree_vm_list_push_value(input_list, &workgroup_id_z);
+
+  iree_hal_local_executable_layout_t* local_layout = executable->base.layout;
+
+  IREE_CHECK_OK(interface.SetConstants(
+      absl::MakeConstSpan(call->push_constants, local_layout->push_constants)));
+
+  for (iree_host_size_t set_ordinal = 0;
+       set_ordinal < local_layout->set_layout_count; ++set_ordinal) {
+    iree_hal_local_descriptor_set_layout_t* local_set_layout =
+        iree_hal_local_descriptor_set_layout_cast(
+            local_layout->set_layouts[set_ordinal]);
+    for (iree_host_size_t i = 0; i < local_set_layout->binding_count; ++i) {
+      auto buffer_or = iree::hal::vmla::Buffer::WrapMutable(
+          call->bindings[i], call->binding_lengths[i], iree_allocator_null());
+      IREE_CHECK_OK(buffer_or.status());
+      IREE_CHECK_OK(interface.SetBinding(set_ordinal,
+                                         local_set_layout->bindings[i].binding,
+                                         {std::move(buffer_or.value())}));
+    }
+  }
+
+  iree_status_t status =
+      iree_vm_invoke(executable->context, executable->entry_fns[ordinal],
+                     /*policy=*/NULL, input_list,
+                     /*outputs=*/NULL, host_allocator);
+
+  iree_vm_list_deinitialize(input_list);
+  iree_vm_ref_release(&interface_ref);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+const iree_hal_local_executable_vtable_t iree_hal_vmla_executable_vtable = {
+    /*.base=*/
+    {
+        /*.destroy=*/iree_hal_vmla_executable_destroy,
+    },
+    /*.issue_call=*/iree_hal_vmla_executable_issue_call,
+};
+
+//===----------------------------------------------------------------------===//
+// iree_hal_vmla_module_loader_t
+//===----------------------------------------------------------------------===//
+
+typedef struct {
+  iree_hal_executable_loader_t base;
+  iree_allocator_t host_allocator;
+  iree_vm_instance_t* instance;
+  iree_vm_module_t* vmla_module;
+} iree_hal_vmla_module_loader_t;
+
+extern const iree_hal_executable_loader_vtable_t
+    iree_hal_vmla_module_loader_vtable;
+
+iree_status_t iree_hal_vmla_module_loader_create(
+    iree_vm_instance_t* instance, iree_allocator_t host_allocator,
+    iree_hal_executable_loader_t** out_executable_loader) {
+  IREE_ASSERT_ARGUMENT(instance);
+  IREE_ASSERT_ARGUMENT(out_executable_loader);
+  *out_executable_loader = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // A single VMLA module is shared across all loaded executables.
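+  // The loader retains its own reference below (released in the loader's
+  // destroy) and each loaded executable's VM context retains it independently.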
+ IREE_RETURN_IF_ERROR(iree::hal::vmla::ModuleRegisterTypes()); + iree_vm_module_t* vmla_module = NULL; + IREE_RETURN_IF_ERROR( + iree::hal::vmla::ModuleCreate(host_allocator, &vmla_module)); + + iree_hal_vmla_module_loader_t* executable_loader = NULL; + iree_status_t status = iree_allocator_malloc( + host_allocator, sizeof(*executable_loader), (void**)&executable_loader); + if (iree_status_is_ok(status)) { + iree_hal_executable_loader_initialize(&iree_hal_vmla_module_loader_vtable, + &executable_loader->base); + executable_loader->host_allocator = host_allocator; + executable_loader->instance = instance; + iree_vm_instance_retain(executable_loader->instance); + executable_loader->vmla_module = vmla_module; + iree_vm_module_retain(executable_loader->vmla_module); + *out_executable_loader = (iree_hal_executable_loader_t*)executable_loader; + } + + iree_vm_module_release(vmla_module); + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vmla_module_loader_destroy( + iree_hal_executable_loader_t* base_executable_loader) { + iree_hal_vmla_module_loader_t* executable_loader = + (iree_hal_vmla_module_loader_t*)base_executable_loader; + iree_allocator_t host_allocator = executable_loader->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_vm_module_release(executable_loader->vmla_module); + iree_vm_instance_release(executable_loader->instance); + iree_allocator_free(host_allocator, executable_loader); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_vmla_module_loader_query_support( + iree_hal_executable_loader_t* base_executable_loader, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode) { + return executable_format == iree_hal_make_executable_format("VMLA"); +} + +static iree_status_t iree_hal_vmla_module_loader_try_load( + iree_hal_executable_loader_t* base_executable_loader, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_format_t executable_format, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + iree_hal_vmla_module_loader_t* executable_loader = + (iree_hal_vmla_module_loader_t*)base_executable_loader; + IREE_TRACE_ZONE_BEGIN(z0); + + // Verify that we have a valid flatbuffer that contains a VMLA executable. + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_vmla_executable_flatbuffer_verify(executable_data)); + iree_VMLAExecutableDef_table_t executable_def = + iree_VMLAExecutableDef_as_root(executable_data.data); + flatbuffers_uint8_vec_t bytecode_module_vec = + iree_VMLAExecutableDef_bytecode_module_get(executable_def); + iree_const_byte_span_t bytecode_module_data = iree_make_const_byte_span( + bytecode_module_vec, flatbuffers_uint8_vec_len(bytecode_module_vec)); + + // If the caching mode allows for aliasing the existing flatbuffer data then + // we avoid allocations and just pass the pointer on through. The caller + // ensures that the data remains valid for the duration the executable is + // loaded. Otherwise, we clone it and let the bytecode module take ownership. + iree_allocator_t bytecode_module_allocator; + if (iree_all_bits_set(caching_mode, + IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA)) { + // Zero-copy route. 
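+    // iree_allocator_null() signals that the bytecode module does not own the
+    // storage and must not attempt to free it.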
+ bytecode_module_allocator = iree_allocator_null(); + } else { + bytecode_module_allocator = executable_loader->host_allocator; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_clone(executable_loader->host_allocator, + bytecode_module_data, + (void**)&bytecode_module_data.data)); + } + + // Load the user-provided bytecode module. We pass ownership of the data (if + // we have it) to the module to manage. + iree_vm_module_t* bytecode_module = NULL; + iree_status_t status = iree_vm_bytecode_module_create( + bytecode_module_data, bytecode_module_allocator, + executable_loader->host_allocator, &bytecode_module); + + // Create the context tying together the shared VMLA module and the + // user-provided module that references it. If we wanted to allow custom + // modules here for user-provided functions we'd mix them in here. + iree_vm_context_t* context = NULL; + if (iree_status_is_ok(status)) { + iree_vm_module_t* modules[2] = {executable_loader->vmla_module, + bytecode_module}; + status = iree_vm_context_create_with_modules( + executable_loader->instance, modules, IREE_ARRAYSIZE(modules), + executable_loader->host_allocator, &context); + } + + // Executable takes ownership of the entire context (including the bytecode + // module, which itself may own the underlying allocation). + if (iree_status_is_ok(status)) { + status = iree_hal_vmla_executable_create(executable_layout, context, + bytecode_module, out_executable); + } + + iree_vm_context_release(context); + iree_vm_module_release(bytecode_module); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +const iree_hal_executable_loader_vtable_t iree_hal_vmla_module_loader_vtable = { + /*.destroy=*/iree_hal_vmla_module_loader_destroy, + /*.query_support=*/iree_hal_vmla_module_loader_query_support, + /*.try_load=*/iree_hal_vmla_module_loader_try_load, +}; diff --git a/iree/hal/local/loaders/vmla_module_loader.h b/iree/hal/local/loaders/vmla_module_loader.h new file mode 100644 index 0000000000000..041b7d43f4762 --- /dev/null +++ b/iree/hal/local/loaders/vmla_module_loader.h @@ -0,0 +1,39 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_LOADERS_VMLA_MODULE_LOADER_H_ +#define IREE_HAL_LOCAL_LOADERS_VMLA_MODULE_LOADER_H_ + +#include +#include + +#include "iree/base/api.h" +#include "iree/hal/local/executable_loader.h" +#include "iree/vm/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates an executable loader that can load compiled IREE VM bytecode modules +// using the VMLA module. |instance| will be used for all loaded contexts. 
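+//
+// Illustrative usage sketch (assumes an |instance| created elsewhere via
+// iree_vm_instance_create):
+//   iree_hal_executable_loader_t* loader = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_vmla_module_loader_create(
+//       instance, iree_allocator_system(), &loader));
+//   // ... register with the device's executable cache ...
+//   iree_hal_executable_loader_release(loader);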
+iree_status_t iree_hal_vmla_module_loader_create( + iree_vm_instance_t* instance, iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOADERS_VMLA_MODULE_LOADER_H_ diff --git a/iree/hal/local/local_descriptor_set.c b/iree/hal/local/local_descriptor_set.c new file mode 100644 index 0000000000000..b37d9617da34d --- /dev/null +++ b/iree/hal/local/local_descriptor_set.c @@ -0,0 +1,87 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/local_descriptor_set.h" + +#include "iree/base/tracing.h" + +static const iree_hal_descriptor_set_vtable_t + iree_hal_local_descriptor_set_vtable; + +iree_hal_local_descriptor_set_t* iree_hal_local_descriptor_set_cast( + iree_hal_descriptor_set_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_local_descriptor_set_vtable); + return (iree_hal_local_descriptor_set_t*)base_value; +} + +iree_status_t iree_hal_local_descriptor_set_create( + iree_hal_descriptor_set_layout_t* base_layout, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, + iree_hal_descriptor_set_t** out_descriptor_set) { + IREE_ASSERT_ARGUMENT(base_layout); + IREE_ASSERT_ARGUMENT(!binding_count || bindings); + IREE_ASSERT_ARGUMENT(out_descriptor_set); + *out_descriptor_set = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_local_descriptor_set_layout_t* local_layout = + iree_hal_local_descriptor_set_layout_cast(base_layout); + IREE_ASSERT_ARGUMENT(local_layout); + + iree_hal_local_descriptor_set_t* descriptor_set = NULL; + iree_host_size_t total_size = + sizeof(*descriptor_set) + + binding_count * sizeof(*descriptor_set->bindings); + iree_status_t status = iree_allocator_malloc( + local_layout->host_allocator, total_size, (void**)&descriptor_set); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_local_descriptor_set_vtable, + &descriptor_set->resource); + descriptor_set->layout = local_layout; + iree_hal_descriptor_set_layout_retain(base_layout); + descriptor_set->binding_count = binding_count; + memcpy(descriptor_set->bindings, bindings, + binding_count * sizeof(iree_hal_descriptor_set_binding_t)); + for (iree_host_size_t i = 0; i < descriptor_set->binding_count; ++i) { + iree_hal_buffer_retain(descriptor_set->bindings[i].buffer); + } + *out_descriptor_set = (iree_hal_descriptor_set_t*)descriptor_set; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_local_descriptor_set_destroy( + iree_hal_descriptor_set_t* base_descriptor_set) { + iree_hal_local_descriptor_set_t* descriptor_set = + iree_hal_local_descriptor_set_cast(base_descriptor_set); + iree_allocator_t host_allocator = descriptor_set->layout->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < descriptor_set->binding_count; ++i) { + iree_hal_buffer_release(descriptor_set->bindings[i].buffer); + } + 
iree_hal_descriptor_set_layout_release( + (iree_hal_descriptor_set_layout_t*)descriptor_set->layout); + iree_allocator_free(host_allocator, descriptor_set); + + IREE_TRACE_ZONE_END(z0); +} + +static const iree_hal_descriptor_set_vtable_t + iree_hal_local_descriptor_set_vtable = { + .destroy = iree_hal_local_descriptor_set_destroy, +}; diff --git a/iree/hal/local/local_descriptor_set.h b/iree/hal/local/local_descriptor_set.h new file mode 100644 index 0000000000000..1032008aa90c8 --- /dev/null +++ b/iree/hal/local/local_descriptor_set.h @@ -0,0 +1,45 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_LOCAL_DESCRIPTOR_SET_H_ +#define IREE_HAL_LOCAL_LOCAL_DESCRIPTOR_SET_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/local/local_descriptor_set_layout.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct { + iree_hal_resource_t resource; + iree_hal_local_descriptor_set_layout_t* layout; + iree_host_size_t binding_count; + iree_hal_descriptor_set_binding_t bindings[]; +} iree_hal_local_descriptor_set_t; + +iree_status_t iree_hal_local_descriptor_set_create( + iree_hal_descriptor_set_layout_t* layout, iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, + iree_hal_descriptor_set_t** out_descriptor_set); + +iree_hal_local_descriptor_set_t* iree_hal_local_descriptor_set_cast( + iree_hal_descriptor_set_t* base_value); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOCAL_DESCRIPTOR_SET_H_ diff --git a/iree/hal/local/local_descriptor_set_layout.c b/iree/hal/local/local_descriptor_set_layout.c new file mode 100644 index 0000000000000..f1c91fe020211 --- /dev/null +++ b/iree/hal/local/local_descriptor_set_layout.c @@ -0,0 +1,82 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "iree/hal/local/local_descriptor_set_layout.h" + +#include "iree/base/tracing.h" + +static const iree_hal_descriptor_set_layout_vtable_t + iree_hal_local_descriptor_set_layout_vtable; + +iree_hal_local_descriptor_set_layout_t* +iree_hal_local_descriptor_set_layout_cast( + iree_hal_descriptor_set_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_local_descriptor_set_layout_vtable); + return (iree_hal_local_descriptor_set_layout_t*)base_value; +} + +iree_status_t iree_hal_local_descriptor_set_layout_create( + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_allocator_t host_allocator, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + IREE_ASSERT_ARGUMENT(!binding_count || bindings); + IREE_ASSERT_ARGUMENT(out_descriptor_set_layout); + *out_descriptor_set_layout = NULL; + if (binding_count > IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, "binding count %zu over the limit of %d", + binding_count, IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_local_descriptor_set_layout_t* layout = NULL; + iree_host_size_t total_size = + sizeof(*layout) + binding_count * sizeof(*layout->bindings); + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&layout); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_local_descriptor_set_layout_vtable, + &layout->resource); + layout->host_allocator = host_allocator; + layout->usage_type = usage_type; + layout->binding_count = binding_count; + memcpy(layout->bindings, bindings, + binding_count * sizeof(iree_hal_descriptor_set_layout_binding_t)); + *out_descriptor_set_layout = (iree_hal_descriptor_set_layout_t*)layout; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_local_descriptor_set_layout_destroy( + iree_hal_descriptor_set_layout_t* base_layout) { + iree_hal_local_descriptor_set_layout_t* layout = + iree_hal_local_descriptor_set_layout_cast(base_layout); + iree_allocator_t host_allocator = layout->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, layout); + + IREE_TRACE_ZONE_END(z0); +} + +static const iree_hal_descriptor_set_layout_vtable_t + iree_hal_local_descriptor_set_layout_vtable = { + .destroy = iree_hal_local_descriptor_set_layout_destroy, +}; diff --git a/iree/hal/local/local_descriptor_set_layout.h b/iree/hal/local/local_descriptor_set_layout.h new file mode 100644 index 0000000000000..3ee7dc7874712 --- /dev/null +++ b/iree/hal/local/local_descriptor_set_layout.h @@ -0,0 +1,50 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef IREE_HAL_LOCAL_LOCAL_DESCRIPTOR_SET_LAYOUT_H_ +#define IREE_HAL_LOCAL_LOCAL_DESCRIPTOR_SET_LAYOUT_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT 32 + +typedef struct { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + iree_hal_descriptor_set_layout_usage_type_t usage_type; + iree_host_size_t binding_count; + iree_hal_descriptor_set_layout_binding_t bindings[]; +} iree_hal_local_descriptor_set_layout_t; + +iree_status_t iree_hal_local_descriptor_set_layout_create( + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_allocator_t host_allocator, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout); + +iree_hal_local_descriptor_set_layout_t* +iree_hal_local_descriptor_set_layout_cast( + iree_hal_descriptor_set_layout_t* base_value); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOCAL_DESCRIPTOR_SET_LAYOUT_H_ diff --git a/iree/hal/local/local_executable.c b/iree/hal/local/local_executable.c new file mode 100644 index 0000000000000..616c3d785c92e --- /dev/null +++ b/iree/hal/local/local_executable.c @@ -0,0 +1,45 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/local_executable.h" + +void iree_hal_local_executable_initialize( + const iree_hal_local_executable_vtable_t* vtable, + iree_hal_local_executable_layout_t* layout, + iree_hal_local_executable_t* out_base_executable) { + iree_hal_resource_initialize(vtable, &out_base_executable->resource); + out_base_executable->layout = layout; + iree_hal_executable_layout_retain((iree_hal_executable_layout_t*)layout); +} + +void iree_hal_local_executable_deinitialize( + iree_hal_local_executable_t* base_executable) { + iree_hal_executable_layout_release( + (iree_hal_executable_layout_t*)base_executable->layout); +} + +iree_hal_local_executable_t* iree_hal_local_executable_cast( + iree_hal_executable_t* base_value) { + return (iree_hal_local_executable_t*)base_value; +} + +iree_status_t iree_hal_local_executable_issue_call( + iree_hal_local_executable_t* executable, iree_host_size_t ordinal, + const iree_hal_local_executable_call_t* call) { + IREE_ASSERT_ARGUMENT(executable); + IREE_ASSERT_ARGUMENT(call); + return ((const iree_hal_local_executable_vtable_t*) + executable->resource.vtable) + ->issue_call(executable, ordinal, call); +} diff --git a/iree/hal/local/local_executable.h b/iree/hal/local/local_executable.h new file mode 100644 index 0000000000000..a3e2acfa512b4 --- /dev/null +++ b/iree/hal/local/local_executable.h @@ -0,0 +1,69 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_LOCAL_EXECUTABLE_H_ +#define IREE_HAL_LOCAL_LOCAL_EXECUTABLE_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/local/executable_library.h" +#include "iree/hal/local/local_executable_layout.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct { + const iree_hal_executable_dispatch_state_v0_t* state; + iree_hal_vec3_t workgroup_id; + iree_hal_vec3_t workgroup_size; + iree_hal_vec3_t workgroup_count; + iree_hal_executable_push_constants_ptr_t push_constants; + const iree_hal_executable_binding_ptr_t* bindings; + const iree_host_size_t* binding_lengths; +} iree_hal_local_executable_call_t; + +typedef struct { + iree_hal_resource_t resource; + iree_hal_local_executable_layout_t* layout; +} iree_hal_local_executable_t; + +typedef struct { + iree_hal_executable_vtable_t base; + + iree_status_t(IREE_API_PTR* issue_call)( + iree_hal_local_executable_t* executable, iree_host_size_t ordinal, + const iree_hal_local_executable_call_t* call); +} iree_hal_local_executable_vtable_t; + +void iree_hal_local_executable_initialize( + const iree_hal_local_executable_vtable_t* vtable, + iree_hal_local_executable_layout_t* layout, + iree_hal_local_executable_t* out_base_executable); + +void iree_hal_local_executable_deinitialize( + iree_hal_local_executable_t* base_executable); + +iree_hal_local_executable_t* iree_hal_local_executable_cast( + iree_hal_executable_t* base_value); + +iree_status_t iree_hal_local_executable_issue_call( + iree_hal_local_executable_t* executable, iree_host_size_t ordinal, + const iree_hal_local_executable_call_t* call); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOCAL_EXECUTABLE_H_ diff --git a/iree/hal/local/local_executable_cache.c b/iree/hal/local/local_executable_cache.c new file mode 100644 index 0000000000000..e5d5748f46a08 --- /dev/null +++ b/iree/hal/local/local_executable_cache.c @@ -0,0 +1,144 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "iree/hal/local/local_executable_cache.h" + +#include "iree/base/tracing.h" + +typedef struct { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + iree_string_view_t identifier; + iree_host_size_t loader_count; + iree_hal_executable_loader_t* loaders[]; +} iree_hal_local_executable_cache_t; + +static const iree_hal_executable_cache_vtable_t + iree_hal_local_executable_cache_vtable; + +static iree_hal_local_executable_cache_t* iree_hal_local_executable_cache_cast( + iree_hal_executable_cache_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_local_executable_cache_vtable); + return (iree_hal_local_executable_cache_t*)base_value; +} + +iree_status_t iree_hal_local_executable_cache_create( + iree_string_view_t identifier, iree_host_size_t loader_count, + iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator, + iree_hal_executable_cache_t** out_executable_cache) { + IREE_ASSERT_ARGUMENT(!loader_count || loaders); + IREE_ASSERT_ARGUMENT(out_executable_cache); + *out_executable_cache = NULL; + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_local_executable_cache_t* executable_cache = NULL; + iree_host_size_t total_size = + sizeof(*executable_cache) + + loader_count * sizeof(*executable_cache->loaders) + identifier.size; + iree_status_t status = iree_allocator_malloc(host_allocator, total_size, + (void**)&executable_cache); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_local_executable_cache_vtable, + &executable_cache->resource); + executable_cache->host_allocator = host_allocator; + iree_string_view_append_to_buffer( + identifier, &executable_cache->identifier, + (char*)executable_cache + total_size - identifier.size); + + executable_cache->loader_count = loader_count; + for (iree_host_size_t i = 0; i < executable_cache->loader_count; ++i) { + executable_cache->loaders[i] = loaders[i]; + iree_hal_executable_loader_retain(executable_cache->loaders[i]); + } + + *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_local_executable_cache_destroy( + iree_hal_executable_cache_t* base_executable_cache) { + iree_hal_local_executable_cache_t* executable_cache = + iree_hal_local_executable_cache_cast(base_executable_cache); + iree_allocator_t host_allocator = executable_cache->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < executable_cache->loader_count; ++i) { + iree_hal_executable_loader_release(executable_cache->loaders[i]); + } + iree_allocator_free(host_allocator, executable_cache); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_local_executable_cache_can_prepare_format( + iree_hal_executable_cache_t* base_executable_cache, + iree_hal_executable_format_t format) { + iree_hal_local_executable_cache_t* executable_cache = + iree_hal_local_executable_cache_cast(base_executable_cache); + for (iree_host_size_t i = 0; i < executable_cache->loader_count; ++i) { + if (iree_hal_executable_loader_query_support( + executable_cache->loaders[i], format, + IREE_HAL_EXECUTABLE_CACHING_MODE_DEFAULT)) { + return true; + } + } + return false; +} + +static iree_status_t iree_hal_local_executable_cache_prepare_executable( + iree_hal_executable_cache_t* base_executable_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + iree_hal_local_executable_cache_t* 
executable_cache = + iree_hal_local_executable_cache_cast(base_executable_cache); + for (iree_host_size_t i = 0; i < executable_cache->loader_count; ++i) { + // TODO(benvanik): pass executable format through from the HAL. + // if (iree_hal_executable_loader_query_support( + // executable_cache->loaders[i], executable_format, + // IREE_HAL_EXECUTABLE_CACHING_MODE_DEFAULT)) { + // return iree_hal_executable_loader_try_load( + // executable_cache->loaders[i], executable_layout, + // executable_format, caching_mode, executable_data, + // out_executable); + // } + iree_status_t status = iree_hal_executable_loader_try_load( + executable_cache->loaders[i], executable_layout, + /*executable_format=*/0, caching_mode, executable_data, out_executable); + if (iree_status_is_ok(status)) { + // Executable was successfully loaded. + return status; + } else if (!iree_status_is_cancelled(status)) { + // Error beyond just the try failing due to unsupported formats. + return status; + } + } + return iree_make_status( + IREE_STATUS_NOT_FOUND, + "no executable loader registered for the given file format"); +} + +static const iree_hal_executable_cache_vtable_t + iree_hal_local_executable_cache_vtable = { + .destroy = iree_hal_local_executable_cache_destroy, + .can_prepare_format = + iree_hal_local_executable_cache_can_prepare_format, + .prepare_executable = + iree_hal_local_executable_cache_prepare_executable, +}; diff --git a/iree/hal/local/local_executable_cache.h b/iree/hal/local/local_executable_cache.h new file mode 100644 index 0000000000000..09f502489fb05 --- /dev/null +++ b/iree/hal/local/local_executable_cache.h @@ -0,0 +1,43 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_LOCAL_EXECUTABLE_CACHE_H_ +#define IREE_HAL_LOCAL_LOCAL_EXECUTABLE_CACHE_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/local/executable_loader.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TODO(benvanik): when we refactor executable caches this can become something +// more specialized; like nop_executable_cache (does nothing but pass through) +// or inproc_lru_executable_cache (simple in-memory LRU of recent executables). +// +// We can also set this up so they share storage. Ideally a JIT'ed executable in +// one device is the same JIT'ed executable in another, and in multi-tenant +// situations we're likely to want that isolation _and_ sharing. 
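+//
+// Illustrative usage sketch (|vmla_loader| is a placeholder for any loader
+// created from iree/hal/local/loaders/):
+//   iree_hal_executable_loader_t* loaders[1] = {vmla_loader};
+//   iree_hal_executable_cache_t* cache = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_local_executable_cache_create(
+//       iree_make_cstring_view("default"), IREE_ARRAYSIZE(loaders), loaders,
+//       iree_allocator_system(), &cache));
+//   // The cache retains the loaders; callers may release their references.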
+ +iree_status_t iree_hal_local_executable_cache_create( + iree_string_view_t identifier, iree_host_size_t loader_count, + iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator, + iree_hal_executable_cache_t** out_executable_cache); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOCAL_EXECUTABLE_CACHE_H_ diff --git a/iree/hal/local/local_executable_layout.c b/iree/hal/local/local_executable_layout.c new file mode 100644 index 0000000000000..b505aa243ae74 --- /dev/null +++ b/iree/hal/local/local_executable_layout.c @@ -0,0 +1,111 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/local_executable_layout.h" + +#include "iree/base/tracing.h" +#include "iree/hal/local/local_descriptor_set_layout.h" + +static const iree_hal_executable_layout_vtable_t + iree_hal_local_executable_layout_vtable; + +iree_hal_local_executable_layout_t* iree_hal_local_executable_layout_cast( + iree_hal_executable_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_local_executable_layout_vtable); + return (iree_hal_local_executable_layout_t*)base_value; +} + +iree_status_t iree_hal_local_executable_layout_create( + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constants, iree_allocator_t host_allocator, + iree_hal_executable_layout_t** out_executable_layout) { + IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts); + IREE_ASSERT_ARGUMENT(out_executable_layout); + *out_executable_layout = NULL; + if (set_layout_count > IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "set layout count %zu over the limit of %d", + set_layout_count, + IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT); + } + if (push_constants > IREE_HAL_LOCAL_MAX_PUSH_CONSTANT_COUNT) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "push constant count %zu over the limit of %d", + push_constants, + IREE_HAL_LOCAL_MAX_PUSH_CONSTANT_COUNT); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_host_size_t total_size = + sizeof(iree_hal_local_executable_layout_t) + + set_layout_count * sizeof(iree_hal_descriptor_set_layout_t*); + + iree_hal_local_executable_layout_t* layout = NULL; + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&layout); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_local_executable_layout_vtable, + &layout->resource); + layout->host_allocator = host_allocator; + layout->push_constants = push_constants; + layout->dynamic_binding_count = 0; + layout->used_bindings = 0; + layout->set_layout_count = set_layout_count; + for (iree_host_size_t i = 0; i < set_layout_count; ++i) { + layout->set_layouts[i] = set_layouts[i]; + iree_hal_descriptor_set_layout_retain(layout->set_layouts[i]); + + iree_hal_local_descriptor_set_layout_t* local_set_layout = + iree_hal_local_descriptor_set_layout_cast(set_layouts[i]); + for (iree_host_size_t 
j = 0; j < local_set_layout->binding_count; ++j) { + const iree_hal_descriptor_set_layout_binding_t* binding = + &local_set_layout->bindings[j]; + layout->used_bindings |= + 1ull << (i * IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT + j); + switch (binding->type) { + case IREE_HAL_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC: + case IREE_HAL_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC: + ++layout->dynamic_binding_count; + break; + } + } + } + *out_executable_layout = (iree_hal_executable_layout_t*)layout; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_local_executable_layout_destroy( + iree_hal_executable_layout_t* base_layout) { + iree_hal_local_executable_layout_t* layout = + iree_hal_local_executable_layout_cast(base_layout); + iree_allocator_t host_allocator = layout->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < layout->set_layout_count; ++i) { + iree_hal_descriptor_set_layout_release(layout->set_layouts[i]); + } + iree_allocator_free(host_allocator, layout); + + IREE_TRACE_ZONE_END(z0); +} + +static const iree_hal_executable_layout_vtable_t + iree_hal_local_executable_layout_vtable = { + .destroy = iree_hal_local_executable_layout_destroy, +}; diff --git a/iree/hal/local/local_executable_layout.h b/iree/hal/local/local_executable_layout.h new file mode 100644 index 0000000000000..801236aed6a7e --- /dev/null +++ b/iree/hal/local/local_executable_layout.h @@ -0,0 +1,53 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef IREE_HAL_LOCAL_LOCAL_EXECUTABLE_LAYOUT_H_ +#define IREE_HAL_LOCAL_LOCAL_EXECUTABLE_LAYOUT_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT 2 +#define IREE_HAL_LOCAL_MAX_PUSH_CONSTANT_COUNT 64 + +typedef uint64_t iree_hal_local_binding_mask_t; + +typedef struct { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + iree_host_size_t push_constants; + iree_host_size_t dynamic_binding_count; + iree_hal_local_binding_mask_t used_bindings; + iree_host_size_t set_layout_count; + iree_hal_descriptor_set_layout_t* set_layouts[]; +} iree_hal_local_executable_layout_t; + +iree_status_t iree_hal_local_executable_layout_create( + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constants, iree_allocator_t host_allocator, + iree_hal_executable_layout_t** out_executable_layout); + +iree_hal_local_executable_layout_t* iree_hal_local_executable_layout_cast( + iree_hal_executable_layout_t* base_value); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOCAL_EXECUTABLE_LAYOUT_H_ diff --git a/iree/hal/local/task_command_buffer.c b/iree/hal/local/task_command_buffer.c new file mode 100644 index 0000000000000..3bb46bb72a2f7 --- /dev/null +++ b/iree/hal/local/task_command_buffer.c @@ -0,0 +1,870 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/task_command_buffer.h" + +#include "iree/base/debugging.h" +#include "iree/base/tracing.h" +#include "iree/hal/local/local_descriptor_set_layout.h" +#include "iree/hal/local/local_executable.h" +#include "iree/hal/local/local_executable_layout.h" +#include "iree/task/list.h" +#include "iree/task/submission.h" +#include "iree/task/task.h" + +//===----------------------------------------------------------------------===// +// iree_hal_task_command_buffer_t +//===----------------------------------------------------------------------===// + +// iree/task/-based command buffer. +// We track a minimal amount of state here and incrementally build out the task +// DAG that we can submit to the task system directly. There are no intermediate +// data structures; we produce the iree_task_ts directly. In the steady state +// all allocations are served from a shared per-device block pool with no +// additional allocations required during recording or execution. That means our +// command buffer here is essentially just a builder for the task system types +// and a manager of the lifetime of the tasks. +typedef struct { + iree_hal_resource_t resource; + + iree_hal_device_t* device; + iree_task_scope_t* scope; + iree_hal_command_buffer_mode_t mode; + iree_hal_command_category_t allowed_categories; + + // Arena used for all allocations; references the shared device block pool. + iree_arena_allocator_t arena; + + // One or more tasks at the root of the command buffer task DAG. 
+ // These tasks are all able to execute concurrently and will be the initial + // ready task set in the submission. + iree_task_list_t root_tasks; + + // One or more tasks at the leaves of the DAG. + // Only once all these tasks have completed execution will the command buffer + // be considered completed as a whole. + // + // An empty list indicates that root_tasks are also the leaves. + iree_task_list_t leaf_tasks; + + // TODO(benvanik): move this out of the struct and allocate from the arena - + // we only need this during recording and it's ~4KB of waste otherwise. + // State tracked within the command buffer during recording only. + struct { + // The last global barrier that was inserted, if any. + // The barrier is allocated and inserted into the DAG when requested but the + // actual barrier dependency list is only allocated and set on flushes. + // This lets us allocate the appropriately sized barrier task list from the + // arena even though when the barrier is recorded we don't yet know what + // other tasks we'll be emitting as we walk the command stream. + iree_task_barrier_t* open_barrier; + + // The number of tasks in the open barrier (|open_tasks|), used to quickly + // allocate storage for the task list without needing to walk the list. + iree_host_size_t open_task_count; + + // All execution tasks emitted that must execute after |open_barrier|. + iree_task_list_t open_tasks; + + // A flattened list of all available descriptor set bindings. + // As descriptor sets are pushed/bound the bindings will be updated to + // represent the fully-translated binding data pointer. + // TODO(benvanik): support proper mapping semantics and track the + // iree_hal_buffer_mapping_t and map/unmap where appropriate. + iree_hal_executable_binding_ptr_t + bindings[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT * + IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT]; + iree_device_size_t + binding_lengths[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT * + IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT]; + + // All available push constants updated each time push_constants is called. + // Reset only with the command buffer and otherwise will maintain its values + // during recording to allow for partial push_constants updates. + uint32_t push_constants[IREE_HAL_LOCAL_MAX_PUSH_CONSTANT_COUNT]; + } state; +} iree_hal_task_command_buffer_t; + +static const iree_hal_command_buffer_vtable_t + iree_hal_task_command_buffer_vtable; + +static iree_hal_task_command_buffer_t* iree_hal_task_command_buffer_cast( + iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_task_command_buffer_vtable); + return (iree_hal_task_command_buffer_t*)base_value; +} + +iree_status_t iree_hal_task_command_buffer_create( + iree_hal_device_t* device, iree_task_scope_t* scope, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_arena_block_pool_t* block_pool, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = NULL; + if (mode != IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) { + // If we want reuse we'd need to support duplicating the task DAG after + // recording or have some kind of copy-on-submit behavior that does so if + // a command buffer is submitted for execution twice. 
Allowing for the same + // command buffer to be enqueued multiple times would be fine so long as + // execution doesn't overlap (`cmdbuf|cmdbuf` vs + // `cmdbuf -> semaphore -> cmdbuf`) though we'd still need to be careful + // that we did the enqueuing and reset of the task structures at the right + // times. Definitely something that'll be useful in the future... but not + // today :) + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "only one-shot command buffer usage is supported"); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_task_command_buffer_t* command_buffer = NULL; + iree_status_t status = + iree_allocator_malloc(iree_hal_device_host_allocator(device), + sizeof(*command_buffer), (void**)&command_buffer); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_task_command_buffer_vtable, + &command_buffer->resource); + command_buffer->device = device; + command_buffer->scope = scope; + command_buffer->mode = mode; + command_buffer->allowed_categories = command_categories; + iree_arena_initialize(block_pool, &command_buffer->arena); + iree_task_list_initialize(&command_buffer->root_tasks); + iree_task_list_initialize(&command_buffer->leaf_tasks); + memset(&command_buffer->state, 0, sizeof(command_buffer->state)); + *out_command_buffer = (iree_hal_command_buffer_t*)command_buffer; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_task_command_buffer_reset( + iree_hal_task_command_buffer_t* command_buffer) { + memset(&command_buffer->state, 0, sizeof(command_buffer->state)); + iree_task_list_discard(&command_buffer->leaf_tasks); + iree_task_list_discard(&command_buffer->root_tasks); + iree_arena_reset(&command_buffer->arena); +} + +static void iree_hal_task_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + iree_allocator_t host_allocator = + iree_hal_device_host_allocator(command_buffer->device); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_task_command_buffer_reset(command_buffer); + iree_arena_deinitialize(&command_buffer->arena); + iree_allocator_free(host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_hal_command_category_t +iree_hal_task_command_buffer_allowed_categories( + const iree_hal_command_buffer_t* base_command_buffer) { + return ((const iree_hal_task_command_buffer_t*)base_command_buffer) + ->allowed_categories; +} + +//===----------------------------------------------------------------------===// +// iree_hal_task_command_buffer_t recording +//===----------------------------------------------------------------------===// + +static iree_status_t iree_hal_task_command_buffer_flush_tasks( + iree_hal_task_command_buffer_t* command_buffer); + +static iree_status_t iree_hal_task_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + iree_hal_task_command_buffer_reset(command_buffer); + return iree_ok_status(); +} + +static iree_status_t iree_hal_task_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + + // Flush any open barriers. 
+ IREE_RETURN_IF_ERROR( + iree_hal_task_command_buffer_flush_tasks(command_buffer)); + + // Move the tasks from the leaf list (tail) to the root list (head) if this + // was the first set of tasks recorded. + if (iree_task_list_is_empty(&command_buffer->root_tasks) && + !iree_task_list_is_empty(&command_buffer->leaf_tasks)) { + iree_task_list_move(&command_buffer->leaf_tasks, + &command_buffer->root_tasks); + } + + return iree_ok_status(); +} + +// Flushes all open tasks to the previous barrier and prepares for more +// recording. The root tasks are also populated here when required as this is +// the one place where we can see both halves of the most recent synchronization +// event: those tasks recorded prior (if any) and the task that marks the set of +// tasks that will be recorded after (if any). +static iree_status_t iree_hal_task_command_buffer_flush_tasks( + iree_hal_task_command_buffer_t* command_buffer) { + iree_task_barrier_t* open_barrier = command_buffer->state.open_barrier; + if (open_barrier != NULL) { + // There is an open barrier; we need to fix up its fork out to all of the + // open tasks that were recorded after it. + iree_task_t* task_head = + iree_task_list_front(&command_buffer->state.open_tasks); + iree_host_size_t dependent_task_count = + command_buffer->state.open_task_count; + if (dependent_task_count == 1) { + // Special-case: only one open task so we can avoid the additional barrier + // overhead by reusing the completion task. + iree_task_set_completion_task(&open_barrier->header, task_head); + } else if (dependent_task_count > 1) { + // Allocate the list of tasks we'll stash back on the previous barrier. + // Since we couldn't know at the time how many tasks would end up in the + // barrier we had to defer it until now. + iree_task_t** dependent_tasks = NULL; + IREE_RETURN_IF_ERROR(iree_arena_allocate( + &command_buffer->arena, dependent_task_count * sizeof(iree_task_t*), + (void**)&dependent_tasks)); + iree_task_t* task = task_head; + for (iree_host_size_t i = 0; i < dependent_task_count; ++i) { + dependent_tasks[i] = task; + task = task->next_task; + } + iree_task_barrier_set_dependent_tasks(open_barrier, dependent_task_count, + dependent_tasks); + } + } + command_buffer->state.open_barrier = NULL; + + // Move the open tasks to the tail as they represent the first half of the + // *next* barrier that will be inserted. + if (command_buffer->state.open_task_count > 0) { + iree_task_list_move(&command_buffer->state.open_tasks, + &command_buffer->leaf_tasks); + command_buffer->state.open_task_count = 0; + } + + return iree_ok_status(); +} + +// Emits a global barrier, splitting execution into all prior recorded tasks +// and all subsequent recorded tasks. This is currently the critical piece that +// limits our concurrency: changing to fine-grained barriers (via barrier +// buffers or events) will allow more work to overlap at the cost of more brain +// to build out the proper task graph. +static iree_status_t iree_hal_task_command_buffer_emit_global_barrier( + iree_hal_task_command_buffer_t* command_buffer) { + // Flush open tasks to the previous barrier. This resets our state such that + // we can assign the new open barrier and start recording tasks for it. + // Previous tasks will be moved into the leaf_tasks list. + IREE_RETURN_IF_ERROR( + iree_hal_task_command_buffer_flush_tasks(command_buffer)); + + // Allocate the new open barrier. 
+ // As we are recording forward we can't yet assign the dependent tasks (the + // second half of the synchronization domain) and instead are just inserting + // it so we can set up the join from previous tasks (the first half of the + // synchronization domain). + iree_task_barrier_t* barrier = NULL; + IREE_RETURN_IF_ERROR(iree_arena_allocate(&command_buffer->arena, + sizeof(*barrier), (void**)&barrier)); + iree_task_barrier_initialize_empty(command_buffer->scope, barrier); + + // If there were previous tasks then join them to the barrier. + for (iree_task_t* task = iree_task_list_front(&command_buffer->leaf_tasks); + task != NULL; task = task->next_task) { + iree_task_set_completion_task(task, &barrier->header); + } + + // Move the tasks from the leaf list (tail) to the root list (head) if this + // was the first set of tasks recorded. + if (iree_task_list_is_empty(&command_buffer->root_tasks) && + !iree_task_list_is_empty(&command_buffer->leaf_tasks)) { + iree_task_list_move(&command_buffer->leaf_tasks, + &command_buffer->root_tasks); + } + + // Reset the tail of the command buffer to the barrier. This leaves us in a + // consistent state if the recording ends immediately after this (the barrier + // will be the last task). + iree_task_list_initialize(&command_buffer->leaf_tasks); + iree_task_list_push_back(&command_buffer->leaf_tasks, &barrier->header); + + // NOTE: all new tasks emitted will be executed after this barrier. + command_buffer->state.open_barrier = barrier; + command_buffer->state.open_task_count = 0; + + return iree_ok_status(); +} + +// Emits the given execution |task| into the current open synchronization +// scope (after state.open_barrier and before the next barrier). +static iree_status_t iree_hal_task_command_buffer_emit_execution_task( + iree_hal_task_command_buffer_t* command_buffer, iree_task_t* task) { + if (command_buffer->state.open_barrier == NULL) { + // If there is no open barrier then we are at the head and going right into + // the task DAG. + iree_task_list_push_back(&command_buffer->leaf_tasks, task); + } else { + // Append to the open task list that will be flushed to the open barrier. + iree_task_list_push_back(&command_buffer->state.open_tasks, task); + ++command_buffer->state.open_task_count; + } + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_task_command_buffer_t execution +//===----------------------------------------------------------------------===// + +iree_status_t iree_hal_task_command_buffer_issue( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_task_queue_state_t* queue_state, iree_task_t* retire_task, + iree_arena_allocator_t* arena, iree_task_submission_t* pending_submission) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + + // If the command buffer is empty (valid!) then we are a no-op. + bool has_root_tasks = !iree_task_list_is_empty(&command_buffer->root_tasks); + if (!has_root_tasks) { + return iree_ok_status(); + } + + bool has_leaf_tasks = !iree_task_list_is_empty(&command_buffer->leaf_tasks); + if (has_leaf_tasks) { + // Chain the retire task onto the leaf tasks as their completion indicates + // that all commands have completed. 
+ for (iree_task_t* task = command_buffer->leaf_tasks.head; task != NULL; + task = task->next_task) { + iree_task_set_completion_task(task, retire_task); + } + } else { + // If we have no leaf tasks it means that this is a single layer DAG and + // after the root tasks complete the entire command buffer has completed. + for (iree_task_t* task = command_buffer->root_tasks.head; task != NULL; + task = task->next_task) { + iree_task_set_completion_task(task, retire_task); + } + } + + // Enqueue all root tasks that are ready to run immediately. + // After this all of the command buffer tasks are owned by the submission and + // we need to ensure the command buffer doesn't try to discard them. + iree_task_submission_enqueue_list(pending_submission, + &command_buffer->root_tasks); + iree_task_list_initialize(&command_buffer->leaf_tasks); + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_execution_barrier +//===----------------------------------------------------------------------===// + +static iree_status_t iree_hal_task_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + + // TODO(benvanik): actual DAG construction. Right now we are just doing simple + // global barriers each time and forcing a join-fork point. + return iree_hal_task_command_buffer_emit_global_barrier(command_buffer); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_signal_event +//===----------------------------------------------------------------------===// + +static iree_status_t iree_hal_task_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + // TODO(#4518): implement events. For now we just insert global barriers. + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_reset_event +//===----------------------------------------------------------------------===// + +static iree_status_t iree_hal_task_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + // TODO(#4518): implement events. For now we just insert global barriers. 
+ return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_wait_events +//===----------------------------------------------------------------------===// + +static iree_status_t iree_hal_task_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + // TODO(#4518): implement events. For now we just insert global barriers. + return iree_hal_task_command_buffer_emit_global_barrier(command_buffer); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_discard_buffer +//===----------------------------------------------------------------------===// + +static iree_status_t iree_hal_task_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_fill_buffer +//===----------------------------------------------------------------------===// +// NOTE: for large fills we could dispatch this as tiles for parallelism. +// We'd want to do some measurement for when it's worth it; filling a 200KB +// buffer: maybe not, filling a 200MB buffer: yeah. + +typedef struct { + iree_task_call_t task; + iree_hal_buffer_t* target_buffer; + iree_device_size_t target_offset; + iree_device_size_t length; + uint32_t pattern_length; + uint8_t pattern[8]; +} iree_hal_cmd_fill_buffer_t; + +static iree_status_t iree_hal_cmd_fill_buffer( + uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + const iree_hal_cmd_fill_buffer_t* cmd = + (const iree_hal_cmd_fill_buffer_t*)user_context; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + iree_hal_buffer_fill(cmd->target_buffer, cmd->target_offset, cmd->length, + cmd->pattern, cmd->pattern_length); + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_task_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + + iree_hal_cmd_fill_buffer_t* cmd = NULL; + IREE_RETURN_IF_ERROR( + iree_arena_allocate(&command_buffer->arena, sizeof(*cmd), (void**)&cmd)); + + iree_task_call_initialize( + command_buffer->scope, + iree_task_make_call_closure(iree_hal_cmd_fill_buffer, (uintptr_t)cmd), + &cmd->task); + cmd->target_buffer = target_buffer; + cmd->target_offset = target_offset; + cmd->length = length; + memcpy(cmd->pattern, pattern, pattern_length); + cmd->pattern_length = pattern_length; + + return iree_hal_task_command_buffer_emit_execution_task(command_buffer, + &cmd->task.header); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_update_buffer 
+//===----------------------------------------------------------------------===// + +typedef struct { + iree_task_call_t task; + iree_hal_buffer_t* target_buffer; + iree_device_size_t target_offset; + iree_device_size_t length; + uint8_t source_buffer[]; +} iree_hal_cmd_update_buffer_t; + +static iree_status_t iree_hal_cmd_update_buffer( + uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + const iree_hal_cmd_update_buffer_t* cmd = + (const iree_hal_cmd_update_buffer_t*)user_context; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_hal_buffer_write_data( + cmd->target_buffer, cmd->target_offset, cmd->source_buffer, cmd->length); + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_task_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + + iree_host_size_t total_cmd_size = + sizeof(iree_hal_cmd_update_buffer_t) + length; + + iree_hal_cmd_update_buffer_t* cmd = NULL; + IREE_RETURN_IF_ERROR(iree_arena_allocate(&command_buffer->arena, + total_cmd_size, (void**)&cmd)); + + iree_task_call_initialize( + command_buffer->scope, + iree_task_make_call_closure(iree_hal_cmd_update_buffer, (uintptr_t)cmd), + &cmd->task); + cmd->target_buffer = (iree_hal_buffer_t*)target_buffer; + cmd->target_offset = target_offset; + cmd->length = length; + + memcpy(cmd->source_buffer, (const uint8_t*)source_buffer + source_offset, + cmd->length); + + return iree_hal_task_command_buffer_emit_execution_task(command_buffer, + &cmd->task.header); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_copy_buffer +//===----------------------------------------------------------------------===// +// NOTE: for large copies we could dispatch this as tiles for parallelism. +// We'd want to do some measurement for when it's worth it; copying a 200KB +// buffer: maybe not, copying a 200MB buffer: yeah. 
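As with fill and update above, the copy implementation that follows wraps the operation in an arena-allocated iree_task_call_t. For orientation, a hedged sketch of driving this recording path from the outside; it assumes the public iree_hal_command_buffer_*/iree_hal_device_* wrappers mirror the vtable signatures in this file, and |device|, |src|, and |dst| are stand-ins created elsewhere:

// Sketch only: record a one-shot transfer command buffer.
iree_hal_command_buffer_t* cmd = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_create_command_buffer(
    device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
    IREE_HAL_COMMAND_CATEGORY_TRANSFER, &cmd));
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_begin(cmd));
const uint32_t zero = 0;  // 4-byte pattern repeated across the range
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_fill_buffer(
    cmd, dst, /*target_offset=*/0, /*length=*/4096, &zero, sizeof(zero)));
// Order the fill before the copy; in this implementation this becomes a
// global barrier task in the DAG.
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_execution_barrier(
    cmd, IREE_HAL_EXECUTION_STAGE_TRANSFER, IREE_HAL_EXECUTION_STAGE_TRANSFER,
    0, NULL, 0, NULL));
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_copy_buffer(
    cmd, src, /*source_offset=*/0, dst, /*target_offset=*/0, /*length=*/4096));
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_end(cmd));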
+ +typedef struct { + iree_task_call_t task; + iree_hal_buffer_t* source_buffer; + iree_device_size_t source_offset; + iree_hal_buffer_t* target_buffer; + iree_device_size_t target_offset; + iree_device_size_t length; +} iree_hal_cmd_copy_buffer_t; + +static iree_status_t iree_hal_cmd_copy_buffer( + uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + const iree_hal_cmd_copy_buffer_t* cmd = + (const iree_hal_cmd_copy_buffer_t*)user_context; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_hal_buffer_copy_data( + cmd->source_buffer, cmd->source_offset, cmd->target_buffer, + cmd->target_offset, cmd->length); + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_task_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + + iree_hal_cmd_copy_buffer_t* cmd = NULL; + IREE_RETURN_IF_ERROR( + iree_arena_allocate(&command_buffer->arena, sizeof(*cmd), (void**)&cmd)); + + iree_task_call_initialize( + command_buffer->scope, + iree_task_make_call_closure(iree_hal_cmd_copy_buffer, (uintptr_t)cmd), + &cmd->task); + cmd->source_buffer = (iree_hal_buffer_t*)source_buffer; + cmd->source_offset = source_offset; + cmd->target_buffer = (iree_hal_buffer_t*)target_buffer; + cmd->target_offset = target_offset; + cmd->length = length; + + return iree_hal_task_command_buffer_emit_execution_task(command_buffer, + &cmd->task.header); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_push_constants +//===----------------------------------------------------------------------===// +// NOTE: command buffer state change only; enqueues no tasks. + +static iree_status_t iree_hal_task_command_buffer_push_constants( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + + if (IREE_UNLIKELY(offset + values_length > + sizeof(command_buffer->state.push_constants))) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "push constant range %zu (length=%zu) out of range", + offset, values_length); + } + + memcpy((uint8_t*)&command_buffer->state.push_constants + offset, values, + values_length); + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_push_descriptor_set +//===----------------------------------------------------------------------===// +// NOTE: command buffer state change only; enqueues no tasks. 
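A hedged sketch of driving the function below, using only the types visible in this change (the public wrapper is assumed to mirror the vtable signature); |cmd|, |executable_layout|, |src|, and |dst| are stand-ins:

// Sketch only: bind two whole buffers to set 0 while recording. The
// implementation that follows resolves them to host pointers immediately.
const iree_hal_descriptor_set_binding_t bindings[] = {
    {/*binding=*/0, /*buffer=*/src, /*offset=*/0,
     /*length=*/IREE_WHOLE_BUFFER},
    {/*binding=*/1, /*buffer=*/dst, /*offset=*/0,
     /*length=*/IREE_WHOLE_BUFFER},
};
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_push_descriptor_set(
    cmd, executable_layout, /*set=*/0, IREE_ARRAYSIZE(bindings), bindings));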
+ +static iree_status_t iree_hal_task_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + + if (IREE_UNLIKELY(set >= IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "set %u out of bounds", set); + } + + iree_hal_local_executable_layout_t* local_executable_layout = + iree_hal_local_executable_layout_cast(executable_layout); + iree_hal_local_descriptor_set_layout_t* local_set_layout = + iree_hal_local_descriptor_set_layout_cast( + local_executable_layout->set_layouts[set]); + + iree_host_size_t binding_base = + set * IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT; + for (iree_host_size_t i = 0; i < binding_count; ++i) { + if (IREE_UNLIKELY(bindings[i].binding >= + IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "buffer binding index out of bounds"); + } + iree_host_size_t binding_ordinal = binding_base + bindings[i].binding; + + // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc. + iree_hal_buffer_mapping_t buffer_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + bindings[i].buffer, + local_set_layout->bindings[bindings[i].binding].access, + bindings[i].offset, bindings[i].length, &buffer_mapping)); + command_buffer->state.bindings[binding_ordinal] = + buffer_mapping.contents.data; + command_buffer->state.binding_lengths[binding_ordinal] = + buffer_mapping.contents.data_length; + } + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_bind_descriptor_set +//===----------------------------------------------------------------------===// +// NOTE: command buffer state change only; enqueues no tasks. + +static iree_status_t iree_hal_task_command_buffer_bind_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_hal_descriptor_set_t* descriptor_set, + iree_host_size_t dynamic_offset_count, + const iree_device_size_t* dynamic_offsets) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "descriptor set binding not yet implemented"); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_dispatch +//===----------------------------------------------------------------------===// + +typedef struct { + iree_task_dispatch_t task; + iree_hal_local_executable_t* executable; + iree_host_size_t ordinal; + iree_hal_executable_binding_ptr_t* IREE_RESTRICT bindings; + iree_device_size_t* IREE_RESTRICT binding_lengths; + uint32_t* IREE_RESTRICT push_constants; +} iree_hal_cmd_dispatch_t; + +static iree_status_t iree_hal_cmd_dispatch_tile( + uintptr_t user_context, const iree_task_tile_context_t* tile_context, + iree_task_submission_t* pending_submission) { + const iree_hal_cmd_dispatch_t* cmd = + (const iree_hal_cmd_dispatch_t*)user_context; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_executable_dispatch_state_v0_t state; + // TODO(benvanik): wire up device state (imports, etc) and cache on the + // command buffer for reuse across all tiles. 
+ + iree_hal_local_executable_call_t call = { + .state = &state, + .push_constants = cmd->push_constants, + .bindings = cmd->bindings, + .binding_lengths = cmd->binding_lengths, + }; + memcpy(call.workgroup_id.value, tile_context->workgroup_xyz, + sizeof(iree_hal_vec3_t)); + memcpy(call.workgroup_size.value, tile_context->workgroup_size, + sizeof(iree_hal_vec3_t)); + memcpy(call.workgroup_count.value, tile_context->workgroup_count, + sizeof(iree_hal_vec3_t)); + iree_status_t status = iree_hal_local_executable_issue_call( + cmd->executable, cmd->ordinal, &call); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_task_command_buffer_build_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_cmd_dispatch_t** out_cmd) { + iree_hal_task_command_buffer_t* command_buffer = + iree_hal_task_command_buffer_cast(base_command_buffer); + + iree_hal_local_executable_t* local_executable = + iree_hal_local_executable_cast(executable); + iree_host_size_t push_constant_count = + local_executable->layout->push_constants; + iree_hal_local_binding_mask_t used_binding_mask = + local_executable->layout->used_bindings; + iree_host_size_t used_binding_count = + iree_math_count_ones_u64(used_binding_mask); + + iree_hal_cmd_dispatch_t* cmd = NULL; + iree_host_size_t total_cmd_size = + sizeof(*cmd) + push_constant_count * sizeof(uint32_t) + + used_binding_count * sizeof(iree_hal_executable_binding_ptr_t) + + used_binding_count * sizeof(iree_device_size_t); + IREE_RETURN_IF_ERROR(iree_arena_allocate(&command_buffer->arena, + total_cmd_size, (void**)&cmd)); + + cmd->executable = local_executable; + cmd->ordinal = entry_point; + + uint32_t workgroup_count[3] = {workgroup_x, workgroup_y, workgroup_z}; + // TODO(benvanik): expose on API or keep fixed on executable. + uint32_t workgroup_size[3] = {1, 1, 1}; + iree_task_dispatch_initialize(command_buffer->scope, + iree_task_make_dispatch_closure( + iree_hal_cmd_dispatch_tile, (uintptr_t)cmd), + workgroup_size, workgroup_count, &cmd->task); + + // Copy only the push constant range used by the executable. + uint8_t* cmd_ptr = (uint8_t*)cmd + sizeof(*cmd); + cmd->push_constants = (uint32_t*)cmd_ptr; + memcpy(cmd->push_constants, command_buffer->state.push_constants, + push_constant_count * sizeof(*cmd->push_constants)); + cmd_ptr += push_constant_count * sizeof(*cmd->push_constants); + + // Produce the dense binding list based on the declared bindings used. + // This allows us to change the descriptor sets and bindings counts supported + // in the HAL independent of any executable as each executable just gets the + // flat dense list and doesn't care about our descriptor set stuff. + // + // Note that we are just directly setting the binding data pointers here with + // no ownership/retaining/etc - it's part of the HAL contract that buffers are + // kept valid for the duration they may be in use. 
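  // Worked example (illustrative, not in the original change): with
  // used_binding_mask = 0b10011 the extraction loop below visits flat
  // ordinals 0, 1, and 4:
  //   i=0: ctz(0b10011)=0 -> ordinal 0+0=0, base=1, mask >>= 1 -> 0b1001
  //   i=1: ctz(0b1001)=0  -> ordinal 1+0=1, base=2, mask >>= 1 -> 0b100
  //   i=2: ctz(0b100)=2   -> ordinal 2+2=4, base=5, mask -> 0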
+ cmd->bindings = (iree_hal_executable_binding_ptr_t*)cmd_ptr; + cmd_ptr += used_binding_count * sizeof(*cmd->bindings); + cmd->binding_lengths = (iree_device_size_t*)cmd_ptr; + cmd_ptr += used_binding_count * sizeof(*cmd->binding_lengths); + iree_host_size_t binding_base = 0; + for (iree_host_size_t i = 0; i < used_binding_count; ++i) { + int mask_offset = iree_math_count_trailing_zeros_u64(used_binding_mask); + int binding_ordinal = binding_base + mask_offset; + binding_base += mask_offset + 1; + used_binding_mask = used_binding_mask >> (mask_offset + 1); + cmd->bindings[i] = command_buffer->state.bindings[binding_ordinal]; + cmd->binding_lengths[i] = + command_buffer->state.binding_lengths[binding_ordinal]; + if (!cmd->bindings[i]) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "(flat) binding %d is NULL", binding_ordinal); + } + } + + *out_cmd = cmd; + return iree_hal_task_command_buffer_emit_execution_task(command_buffer, + &cmd->task.header); +} + +static iree_status_t iree_hal_task_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + iree_hal_cmd_dispatch_t* cmd = NULL; + return iree_hal_task_command_buffer_build_dispatch( + base_command_buffer, executable, entry_point, workgroup_x, workgroup_y, + workgroup_z, &cmd); +} + +static iree_status_t iree_hal_task_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, + iree_device_size_t workgroups_offset) { + // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc. + iree_hal_buffer_mapping_t buffer_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + workgroups_buffer, IREE_HAL_MEMORY_ACCESS_READ, workgroups_offset, + 3 * sizeof(uint32_t), &buffer_mapping)); + + iree_hal_cmd_dispatch_t* cmd = NULL; + IREE_RETURN_IF_ERROR(iree_hal_task_command_buffer_build_dispatch( + base_command_buffer, executable, entry_point, 0, 0, 0, &cmd)); + cmd->task.workgroup_count.ptr = (const uint32_t*)buffer_mapping.contents.data; + cmd->task.header.flags |= IREE_TASK_FLAG_DISPATCH_INDIRECT; + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_vtable_t +//===----------------------------------------------------------------------===// + +static const iree_hal_command_buffer_vtable_t + iree_hal_task_command_buffer_vtable = { + .destroy = iree_hal_task_command_buffer_destroy, + .allowed_categories = iree_hal_task_command_buffer_allowed_categories, + .begin = iree_hal_task_command_buffer_begin, + .end = iree_hal_task_command_buffer_end, + .execution_barrier = iree_hal_task_command_buffer_execution_barrier, + .signal_event = iree_hal_task_command_buffer_signal_event, + .reset_event = iree_hal_task_command_buffer_reset_event, + .wait_events = iree_hal_task_command_buffer_wait_events, + .discard_buffer = iree_hal_task_command_buffer_discard_buffer, + .fill_buffer = iree_hal_task_command_buffer_fill_buffer, + .update_buffer = iree_hal_task_command_buffer_update_buffer, + .copy_buffer = iree_hal_task_command_buffer_copy_buffer, + .push_constants = iree_hal_task_command_buffer_push_constants, + .push_descriptor_set = iree_hal_task_command_buffer_push_descriptor_set, + .bind_descriptor_set = iree_hal_task_command_buffer_bind_descriptor_set, + .dispatch = 
iree_hal_task_command_buffer_dispatch, + .dispatch_indirect = iree_hal_task_command_buffer_dispatch_indirect, +}; diff --git a/iree/hal/local/task_command_buffer.h b/iree/hal/local/task_command_buffer.h new file mode 100644 index 0000000000000..ece73886094cc --- /dev/null +++ b/iree/hal/local/task_command_buffer.h @@ -0,0 +1,60 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_TASK_COMMAND_BUFFER_H_ +#define IREE_HAL_LOCAL_TASK_COMMAND_BUFFER_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/local/arena.h" +#include "iree/hal/local/task_queue_state.h" +#include "iree/task/scope.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +iree_status_t iree_hal_task_command_buffer_create( + iree_hal_device_t* device, iree_task_scope_t* scope, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_arena_block_pool_t* block_pool, + iree_hal_command_buffer_t** out_command_buffer); + +// Issues a recorded command buffer using the serial |queue_state|. +// |queue_state| is used to track the synchronization scope of the queue from +// prior commands such as signaled events and will be mutated as events are +// reset or new events are signaled. +// +// |retire_task| will be scheduled once all commands issued from the command +// buffer retire and can be used as a fence point. +// +// Any new tasks that are allocated as part of the issue operation (such as +// barrier tasks to handle event synchronization) will be acquired from |arena|. +// The lifetime of |arena| must be at least that of |retire_task| ensuring that +// all of the allocated commands issued have completed and their memory in the +// arena can be recycled. +// +// |pending_submission| will receive the ready list of commands and must be +// submitted to the executor (or discarded on failure) by the caller. +iree_status_t iree_hal_task_command_buffer_issue( + iree_hal_command_buffer_t* command_buffer, + iree_hal_task_queue_state_t* queue_state, iree_task_t* retire_task, + iree_arena_allocator_t* arena, iree_task_submission_t* pending_submission); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_TASK_COMMAND_BUFFER_H_ diff --git a/iree/hal/local/task_device.c b/iree/hal/local/task_device.c new file mode 100644 index 0000000000000..63740116da946 --- /dev/null +++ b/iree/hal/local/task_device.c @@ -0,0 +1,349 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/task_device.h" + +#include "iree/base/tracing.h" +#include "iree/hal/local/arena.h" +#include "iree/hal/local/event_pool.h" +#include "iree/hal/local/local_descriptor_set.h" +#include "iree/hal/local/local_descriptor_set_layout.h" +#include "iree/hal/local/local_executable.h" +#include "iree/hal/local/local_executable_cache.h" +#include "iree/hal/local/local_executable_layout.h" +#include "iree/hal/local/task_command_buffer.h" +#include "iree/hal/local/task_event.h" +#include "iree/hal/local/task_queue.h" +#include "iree/hal/local/task_semaphore.h" + +#define IREE_HAL_LOCAL_TASK_EVENT_POOL_CAPACITY 32 + +typedef struct { + iree_hal_resource_t resource; + iree_string_view_t identifier; + + // Block pool used for small allocations like tasks and submissions. + iree_arena_block_pool_t small_block_pool; + + // Block pool used for command buffers with a larger block size (as command + // buffers can contain inlined data uploads). + iree_arena_block_pool_t large_block_pool; + + // iree_event_t pool for semaphore wait operations. + iree_hal_local_event_pool_t* event_pool; + + iree_task_executor_t* executor; + + iree_host_size_t loader_count; + iree_hal_executable_loader_t** loaders; + + iree_allocator_t host_allocator; + iree_hal_allocator_t* device_allocator; + + iree_host_size_t queue_count; + iree_hal_task_queue_t queues[]; +} iree_hal_task_device_t; + +static const iree_hal_device_vtable_t iree_hal_task_device_vtable; + +static iree_hal_task_device_t* iree_hal_task_device_cast( + iree_hal_device_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_task_device_vtable); + return (iree_hal_task_device_t*)base_value; +} + +void iree_hal_task_device_params_initialize( + iree_hal_task_device_params_t* out_params) { + out_params->arena_block_size = 32 * 1024; + out_params->queue_count = 8; +} + +static iree_status_t iree_hal_task_device_check_params( + const iree_hal_task_device_params_t* params) { + if (params->arena_block_size < 4096) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "arena block size too small (< 4096 bytes)"); + } + if (params->queue_count == 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "at least one queue is required"); + } + return iree_ok_status(); +} + +iree_status_t iree_hal_task_device_create( + iree_string_view_t identifier, const iree_hal_task_device_params_t* params, + iree_task_executor_t* executor, iree_host_size_t loader_count, + iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(params); + IREE_ASSERT_ARGUMENT(!loader_count || loaders); + IREE_ASSERT_ARGUMENT(out_device); + *out_device = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, + iree_hal_task_device_check_params(params)); + + iree_hal_task_device_t* device = NULL; + iree_host_size_t total_size = + sizeof(*device) + params->queue_count * sizeof(*device->queues) + + identifier.size + loader_count * sizeof(*device->loaders); + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&device); + if (iree_status_is_ok(status)) { + memset(device, 0, total_size); + iree_hal_resource_initialize(&iree_hal_task_device_vtable, + &device->resource); + iree_string_view_append_to_buffer( + identifier, &device->identifier, + (char*)device + sizeof(*device) + + params->queue_count * sizeof(*device->queues)); + 
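    // Illustrative memory map of the single allocation above (not in the
    // original change):
    //   [device struct][queues[queue_count]][identifier chars][loaders[]]
    // The identifier is copied into the trailing storage so the device owns
    // its lifetime; the loader pointer array is placed directly after the
    // identifier bytes (see the pointer math below).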
device->host_allocator = host_allocator; + iree_arena_block_pool_initialize(4096, host_allocator, + &device->small_block_pool); + iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, + &device->large_block_pool); + device->event_pool = NULL; + + device->executor = executor; + iree_task_executor_retain(device->executor); + + device->loader_count = loader_count; + device->loaders = + (iree_hal_executable_loader_t**)((uint8_t*)device->identifier.data + + identifier.size); + for (iree_host_size_t i = 0; i < device->loader_count; ++i) { + device->loaders[i] = loaders[i]; + iree_hal_executable_loader_retain(device->loaders[i]); + } + + device->queue_count = params->queue_count; + for (iree_host_size_t i = 0; i < device->queue_count; ++i) { + // TODO(benvanik): add a number to each queue ID. + iree_hal_task_queue_initialize(device->identifier, device->executor, + &device->small_block_pool, + &device->queues[i]); + } + } + + if (iree_status_is_ok(status)) { + status = iree_hal_allocator_create_heap(identifier, host_allocator, + &device->device_allocator); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_local_event_pool_allocate( + IREE_HAL_LOCAL_TASK_EVENT_POOL_CAPACITY, host_allocator, + &device->event_pool); + } + + if (iree_status_is_ok(status)) { + *out_device = (iree_hal_device_t*)device; + } else { + iree_hal_device_release((iree_hal_device_t*)device); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_task_device_destroy(iree_hal_device_t* base_device) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < device->queue_count; ++i) { + iree_hal_task_queue_deinitialize(&device->queues[i]); + } + for (iree_host_size_t i = 0; i < device->loader_count; ++i) { + iree_hal_executable_loader_release(device->loaders[i]); + } + iree_task_executor_release(device->executor); + iree_hal_local_event_pool_free(device->event_pool); + iree_arena_block_pool_deinitialize(&device->large_block_pool); + iree_arena_block_pool_deinitialize(&device->small_block_pool); + iree_hal_allocator_release(device->device_allocator); + iree_allocator_free(host_allocator, device); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_string_view_t iree_hal_task_device_id( + iree_hal_device_t* base_device) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + return device->identifier; +} + +static iree_allocator_t iree_hal_task_device_host_allocator( + iree_hal_device_t* base_device) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + return device->host_allocator; +} + +static iree_hal_allocator_t* iree_hal_task_device_allocator( + iree_hal_device_t* base_device) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + return device->device_allocator; +} + +static iree_status_t iree_hal_task_device_create_command_buffer( + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_command_buffer_t** out_command_buffer) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + // TODO(benvanik): prevent the need for taking a scope here. We need it to + // construct the tasks as we record but unfortunately then that means we would + // need to know which queue we'd be submitting against ahead of time. 
+ return iree_hal_task_command_buffer_create( + base_device, &device->queues[0].scope, mode, command_categories, + &device->large_block_pool, out_command_buffer); +} + +static iree_status_t iree_hal_task_device_create_descriptor_set( + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_t* set_layout, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, + iree_hal_descriptor_set_t** out_descriptor_set) { + return iree_hal_local_descriptor_set_create(set_layout, binding_count, + bindings, out_descriptor_set); +} + +static iree_status_t iree_hal_task_device_create_descriptor_set_layout( + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + return iree_hal_local_descriptor_set_layout_create( + usage_type, binding_count, bindings, + iree_hal_device_host_allocator(base_device), out_descriptor_set_layout); +} + +static iree_status_t iree_hal_task_device_create_event( + iree_hal_device_t* base_device, iree_hal_event_t** out_event) { + return iree_hal_task_event_create(iree_hal_device_host_allocator(base_device), + out_event); +} + +static iree_status_t iree_hal_task_device_create_executable_cache( + iree_hal_device_t* base_device, iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + return iree_hal_local_executable_cache_create( + identifier, device->loader_count, device->loaders, + iree_hal_device_host_allocator(base_device), out_executable_cache); +} + +static iree_status_t iree_hal_task_device_create_executable_layout( + iree_hal_device_t* base_device, iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constants, + iree_hal_executable_layout_t** out_executable_layout) { + return iree_hal_local_executable_layout_create( + set_layout_count, set_layouts, push_constants, + iree_hal_device_host_allocator(base_device), out_executable_layout); +} + +static iree_status_t iree_hal_task_device_create_semaphore( + iree_hal_device_t* base_device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + return iree_hal_task_semaphore_create(device->event_pool, initial_value, + device->host_allocator, out_semaphore); +} + +// Returns the queue index to submit work to based on the |queue_affinity|. +// +// If we wanted to have dedicated transfer queues we'd fork off based on +// command_categories. For now all queues are general purpose. +static iree_host_size_t iree_hal_device_select_queue( + iree_hal_task_device_t* device, + iree_hal_command_category_t command_categories, uint64_t queue_affinity) { + // TODO(benvanik): evaluate if we want to obscure this mapping a bit so that + // affinity really means "equivalent affinities map to equivalent queues" and + // not a specific queue index. 
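  // e.g. with queue_count=8: affinity 0 -> queue 0, affinity 9 -> queue 1,
  // and any large affinity value still lands in-range via the modulo below.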
+ return queue_affinity % device->queue_count; +} + +static iree_status_t iree_hal_task_device_queue_submit( + iree_hal_device_t* base_device, + iree_hal_command_category_t command_categories, uint64_t queue_affinity, + iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + iree_host_size_t queue_index = + iree_hal_device_select_queue(device, command_categories, queue_affinity); + return iree_hal_task_queue_submit(&device->queues[queue_index], batch_count, + batches); +} + +static iree_status_t iree_hal_task_device_wait_semaphores_with_deadline( + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + return iree_hal_task_semaphore_multi_wait(wait_mode, semaphore_list, + deadline_ns, device->event_pool, + &device->large_block_pool); +} + +static iree_status_t iree_hal_task_device_wait_semaphores_with_timeout( + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, + iree_duration_t timeout_ns) { + return iree_hal_task_device_wait_semaphores_with_deadline( + base_device, wait_mode, semaphore_list, + iree_relative_timeout_to_deadline_ns(timeout_ns)); +} + +static iree_status_t iree_hal_task_device_wait_idle_with_deadline( + iree_hal_device_t* base_device, iree_time_t deadline_ns) { + iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < device->queue_count; ++i) { + status = iree_hal_task_queue_wait_idle_with_deadline(&device->queues[i], + deadline_ns); + if (!iree_status_is_ok(status)) break; + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_task_device_wait_idle_with_timeout( + iree_hal_device_t* base_device, iree_duration_t timeout_ns) { + return iree_hal_task_device_wait_idle_with_deadline( + base_device, iree_relative_timeout_to_deadline_ns(timeout_ns)); +} + +static const iree_hal_device_vtable_t iree_hal_task_device_vtable = { + .destroy = iree_hal_task_device_destroy, + .id = iree_hal_task_device_id, + .host_allocator = iree_hal_task_device_host_allocator, + .device_allocator = iree_hal_task_device_allocator, + .create_command_buffer = iree_hal_task_device_create_command_buffer, + .create_descriptor_set = iree_hal_task_device_create_descriptor_set, + .create_descriptor_set_layout = + iree_hal_task_device_create_descriptor_set_layout, + .create_event = iree_hal_task_device_create_event, + .create_executable_cache = iree_hal_task_device_create_executable_cache, + .create_executable_layout = iree_hal_task_device_create_executable_layout, + .create_semaphore = iree_hal_task_device_create_semaphore, + .queue_submit = iree_hal_task_device_queue_submit, + .wait_semaphores_with_deadline = + iree_hal_task_device_wait_semaphores_with_deadline, + .wait_semaphores_with_timeout = + iree_hal_task_device_wait_semaphores_with_timeout, + .wait_idle_with_deadline = iree_hal_task_device_wait_idle_with_deadline, + .wait_idle_with_timeout = iree_hal_task_device_wait_idle_with_timeout, +}; diff --git a/iree/hal/local/task_device.h b/iree/hal/local/task_device.h new file mode 100644 index 0000000000000..c1d1edc6a34e0 --- /dev/null +++ b/iree/hal/local/task_device.h @@ -0,0 +1,58 @@ +// Copyright 2020 Google LLC +// +// Licensed 
under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_TASK_DEVICE_H_ +#define IREE_HAL_LOCAL_TASK_DEVICE_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/local/executable_loader.h" +#include "iree/task/executor.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Parameters configuring an iree_hal_task_device_t. +// Must be initialized with iree_hal_task_device_params_initialize prior to use. +typedef struct { + // Number of queues exposed on the device. + // Each queue acts as a separate synchronization scope where all work executes + // concurrently unless prohibited by semaphores. + iree_host_size_t queue_count; + + // Total size of each block in the device shared block pool. + // Larger sizes will lower overhead and ensure the heap isn't hit for + // transient allocations while also increasing memory consumption. + iree_host_size_t arena_block_size; +} iree_hal_task_device_params_t; + +// Initializes |out_params| to default values. +void iree_hal_task_device_params_initialize( + iree_hal_task_device_params_t* out_params); + +// Creates a new iree/task/-based local CPU device that uses |executor| for +// scheduling tasks. |loaders| is the set of executable loaders that are +// available for loading in the device context. +iree_status_t iree_hal_task_device_create( + iree_string_view_t identifier, const iree_hal_task_device_params_t* params, + iree_task_executor_t* executor, iree_host_size_t loader_count, + iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator, + iree_hal_device_t** out_device); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_TASK_DEVICE_H_ diff --git a/iree/hal/local/task_driver.c b/iree/hal/local/task_driver.c new file mode 100644 index 0000000000000..b017e8f986feb --- /dev/null +++ b/iree/hal/local/task_driver.c @@ -0,0 +1,133 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
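For orientation before the driver implementation, a hedged construction sketch using only the signatures introduced in this change; |executor|, |loader_count|, and |loaders| are assumed to come from iree/task/ executor and loader setup done elsewhere, and "local-task" is an illustrative identifier:

// Sketch only: create a task driver with default device parameters.
iree_hal_task_device_params_t params;
iree_hal_task_device_params_initialize(&params);
iree_hal_driver_t* driver = NULL;
IREE_RETURN_IF_ERROR(iree_hal_task_driver_create(
    iree_make_cstring_view("local-task"), &params, executor, loader_count,
    loaders, iree_allocator_system(), &driver));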
+ +#include "iree/hal/local/task_driver.h" + +#include "iree/base/tracing.h" + +#define IREE_HAL_TASK_DEVICE_ID_DEFAULT 0 + +typedef struct { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + + iree_string_view_t identifier; + iree_hal_task_device_params_t default_params; + + iree_task_executor_t* executor; + + iree_host_size_t loader_count; + iree_hal_executable_loader_t* loaders[]; +} iree_hal_task_driver_t; + +static const iree_hal_driver_vtable_t iree_hal_task_driver_vtable; + +static iree_hal_task_driver_t* iree_hal_task_driver_cast( + iree_hal_driver_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_task_driver_vtable); + return (iree_hal_task_driver_t*)base_value; +} + +iree_status_t iree_hal_task_driver_create( + iree_string_view_t identifier, + const iree_hal_task_device_params_t* default_params, + iree_task_executor_t* executor, iree_host_size_t loader_count, + iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator, + iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(default_params); + IREE_ASSERT_ARGUMENT(!loader_count || loaders); + IREE_ASSERT_ARGUMENT(out_driver); + *out_driver = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_task_driver_t* driver = NULL; + iree_host_size_t total_size = sizeof(*driver) + + loader_count * sizeof(*driver->loaders) + + identifier.size; + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&driver); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_task_driver_vtable, + &driver->resource); + driver->host_allocator = host_allocator; + + iree_string_view_append_to_buffer( + identifier, &driver->identifier, + (char*)driver + total_size - identifier.size); + memcpy(&driver->default_params, default_params, + sizeof(driver->default_params)); + + driver->executor = executor; + iree_task_executor_retain(driver->executor); + + driver->loader_count = loader_count; + for (iree_host_size_t i = 0; i < driver->loader_count; ++i) { + driver->loaders[i] = loaders[i]; + iree_hal_executable_loader_retain(driver->loaders[i]); + } + } + + if (iree_status_is_ok(status)) { + *out_driver = (iree_hal_driver_t*)driver; + } else { + iree_hal_driver_release((iree_hal_driver_t*)driver); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_task_driver_destroy(iree_hal_driver_t* base_driver) { + iree_hal_task_driver_t* driver = iree_hal_task_driver_cast(base_driver); + iree_allocator_t host_allocator = driver->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < driver->loader_count; ++i) { + iree_hal_executable_loader_release(driver->loaders[i]); + } + iree_task_executor_release(driver->executor); + iree_allocator_free(host_allocator, driver); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_task_driver_query_available_devices( + iree_hal_driver_t* base_driver, iree_allocator_t allocator, + iree_hal_device_info_t** out_device_infos, + iree_host_size_t* out_device_info_count) { + static const iree_hal_device_info_t device_infos[1] = { + { + .device_id = IREE_HAL_TASK_DEVICE_ID_DEFAULT, + .name = iree_string_view_literal("default"), + }, + }; + *out_device_info_count = IREE_ARRAYSIZE(device_infos); + return iree_allocator_clone( + allocator, iree_make_const_byte_span(device_infos, sizeof(device_infos)), + (void**)out_device_infos); +} + +static iree_status_t iree_hal_task_driver_create_device( + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + iree_allocator_t allocator, 
iree_hal_device_t** out_device) { + iree_hal_task_driver_t* driver = iree_hal_task_driver_cast(base_driver); + return iree_hal_task_device_create( + driver->identifier, &driver->default_params, driver->executor, + driver->loader_count, driver->loaders, allocator, out_device); +} + +static const iree_hal_driver_vtable_t iree_hal_task_driver_vtable = { + .destroy = iree_hal_task_driver_destroy, + .query_available_devices = iree_hal_task_driver_query_available_devices, + .create_device = iree_hal_task_driver_create_device, +}; diff --git a/iree/hal/local/task_driver.h b/iree/hal/local/task_driver.h new file mode 100644 index 0000000000000..92117dd50344a --- /dev/null +++ b/iree/hal/local/task_driver.h @@ -0,0 +1,42 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_TASK_DRIVER_H_ +#define IREE_HAL_LOCAL_TASK_DRIVER_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/local/executable_loader.h" +#include "iree/hal/local/task_device.h" +#include "iree/task/executor.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a new iree/task/-based local CPU driver that creates devices sharing +// the same |executor| for scheduling tasks. |loaders| is the set of executable +// loaders that are available for loading in each device context. +iree_status_t iree_hal_task_driver_create( + iree_string_view_t identifier, + const iree_hal_task_device_params_t* default_params, + iree_task_executor_t* executor, iree_host_size_t loader_count, + iree_hal_executable_loader_t** loaders, iree_allocator_t host_allocator, + iree_hal_driver_t** out_driver); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_TASK_DRIVER_H_ diff --git a/iree/hal/local/task_event.c b/iree/hal/local/task_event.c new file mode 100644 index 0000000000000..6152417052c28 --- /dev/null +++ b/iree/hal/local/task_event.c @@ -0,0 +1,63 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
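// [Editor's sketch - not part of this change] Creating the driver declared in
// task_driver.h above and then a device from it. iree_hal_driver_create_device
// is assumed here to route to the .create_device vtable entry shown earlier;
// |params|, |executor|, and |loader| come from elsewhere.
static iree_status_t sample_create_driver_and_device(
    const iree_hal_task_device_params_t* params,
    iree_task_executor_t* executor, iree_hal_executable_loader_t* loader,
    iree_hal_device_t** out_device) {
  iree_hal_driver_t* driver = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_task_driver_create(
      iree_make_cstr_view("local-task"), params, executor,
      /*loader_count=*/1, &loader, iree_allocator_system(), &driver));
  iree_status_t status = iree_hal_driver_create_device(
      driver, /*device_id=*/0, iree_allocator_system(), out_device);
  // The device retains the executor/loaders it needs; the driver can go away.
  iree_hal_driver_release(driver);
  return status;
}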
+ +#include "iree/hal/local/task_event.h" + +#include "iree/base/tracing.h" + +typedef struct { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; +} iree_hal_task_event_t; + +static const iree_hal_event_vtable_t iree_hal_task_event_vtable; + +static iree_hal_task_event_t* iree_hal_task_event_cast( + iree_hal_event_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_task_event_vtable); + return (iree_hal_task_event_t*)base_value; +} + +iree_status_t iree_hal_task_event_create(iree_allocator_t host_allocator, + iree_hal_event_t** out_event) { + IREE_ASSERT_ARGUMENT(out_event); + *out_event = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_task_event_t* event = NULL; + iree_status_t status = + iree_allocator_malloc(host_allocator, sizeof(*event), (void**)&event); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_task_event_vtable, &event->resource); + event->host_allocator = host_allocator; + *out_event = (iree_hal_event_t*)event; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_task_event_destroy(iree_hal_event_t* base_event) { + iree_hal_task_event_t* event = iree_hal_task_event_cast(base_event); + iree_allocator_t host_allocator = event->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, event); + + IREE_TRACE_ZONE_END(z0); +} + +static const iree_hal_event_vtable_t iree_hal_task_event_vtable = { + .destroy = iree_hal_task_event_destroy, +}; diff --git a/iree/hal/metal/registration/driver_module.h b/iree/hal/local/task_event.h similarity index 71% rename from iree/hal/metal/registration/driver_module.h rename to iree/hal/local/task_event.h index edb6c05c4d4b6..e4009be26e7a0 100644 --- a/iree/hal/metal/registration/driver_module.h +++ b/iree/hal/local/task_event.h @@ -12,20 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef IREE_HAL_METAL_REGISTRATION_DRIVER_MODULE_H_ -#define IREE_HAL_METAL_REGISTRATION_DRIVER_MODULE_H_ +#ifndef IREE_HAL_LOCAL_TASK_EVENT_H_ +#define IREE_HAL_LOCAL_TASK_EVENT_H_ +#include "iree/base/api.h" #include "iree/hal/api.h" #ifdef __cplusplus extern "C" { #endif // __cplusplus -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_metal_driver_module_register(iree_hal_driver_registry_t* registry); +iree_status_t iree_hal_task_event_create(iree_allocator_t host_allocator, + iree_hal_event_t** out_event); #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // IREE_HAL_METAL_REGISTRATION_DRIVER_MODULE_H_ +#endif // IREE_HAL_LOCAL_TASK_EVENT_H_ diff --git a/iree/hal/local/task_queue.c b/iree/hal/local/task_queue.c new file mode 100644 index 0000000000000..6a45f3565dd16 --- /dev/null +++ b/iree/hal/local/task_queue.c @@ -0,0 +1,523 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "iree/hal/local/task_queue.h" + +#include "iree/base/tracing.h" +#include "iree/hal/local/task_command_buffer.h" +#include "iree/hal/local/task_semaphore.h" +#include "iree/task/submission.h" + +// Each submission is turned into a DAG for execution: +// +// +--------------------+ To preserve the sequential issue order an edge is +// | (previous issue) | added between the previous outstanding issue (if +// +--------------------+ it exists) such that all issues run in the order +// | they were submitted to the queue. Note that this +// v is *only* the issue; the commands issued by two +// +--------------------+ submissions may still overlap and are only +// | sequence barrier | guaranteed to begin execution in order. +// +--------------------+ +// | +// | +--------------+ +// +-> | +--------------+ Unsatisfied waits are scheduled as wait tasks and +// . +-| sema waits | block the issuing of commands until all have +// . +--------------+ been satisfied. If the wait is immediately +// . | | | | | following a signal from the same queue then it +// +--------+-+-+-+-+ elided - only cross-queue or external waits +// | actually go down to system wait handles. +// v +// +--------------------+ Command buffers in the batch are issued in-order +// | command issue | as if all commands had been recorded into the same +// +--------------------+ command buffer (excluding recording state like +// | push constants). The dependencies between commands +// | +--------------+ are determined by the events and barriers recorded +// +-> | +--------------+ in each command buffer. +// . +-| commands | +// . +--------------+ +// . | | | | | +// +--------+-+-+-+-+ +// | +// v +// +--------------------+ After all commands within the batch complete the +// | semaphore signals | submission is retired and all semaphores are +// +--------------------+ signaled. Note that this may happen *before* other +// | earlier submissions complete if there were no +// ... dependencies between the commands in each batch. +// +// Could this be simplified? Probably. Improvements to the task system to allow +// for efficient multiwaits and better stitching of independent DAGs would help. + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +// Clones a list of semaphores into an |arena| and initializes |out_target_list| +// to reference the newly-cloned data. 
+static iree_status_t iree_hal_semaphore_list_clone( + const iree_hal_semaphore_list_t* source_list, iree_arena_allocator_t* arena, + iree_hal_semaphore_list_t* out_target_list) { + iree_host_size_t semaphores_size = + source_list->count * sizeof(out_target_list->semaphores[0]); + iree_host_size_t payload_values_size = + source_list->count * sizeof(out_target_list->payload_values[0]); + iree_host_size_t total_size = semaphores_size + payload_values_size; + uint8_t* buffer = NULL; + IREE_RETURN_IF_ERROR(iree_arena_allocate(arena, total_size, (void**)&buffer)); + + out_target_list->count = source_list->count; + out_target_list->semaphores = (iree_hal_semaphore_t**)buffer; + out_target_list->payload_values = (uint64_t*)(buffer + semaphores_size); + + for (iree_host_size_t i = 0; i < source_list->count; ++i) { + out_target_list->semaphores[i] = source_list->semaphores[i]; + iree_hal_semaphore_retain(out_target_list->semaphores[i]); + out_target_list->payload_values[i] = source_list->payload_values[i]; + } + + return iree_ok_status(); +} + +static void iree_hal_semaphore_list_release(iree_hal_semaphore_list_t* list) { + for (iree_host_size_t i = 0; i < list->count; ++i) { + iree_hal_semaphore_release(list->semaphores[i]); + } +} + +//===----------------------------------------------------------------------===// +// iree_hal_task_queue_wait_cmd_t +//===----------------------------------------------------------------------===// + +// Task to fork out and wait on one or more semaphores. +// This optimizes for same-queue semaphore chaining by ensuring that semaphores +// used to stitch together subsequent submissions never have to go to the system +// to wait as the implicit queue ordering ensures that the signals would have +// happened prior to the sequence command being executed. Cross-queue semaphores +// will still cause waits if they have not yet been signaled. +typedef struct { + // Call to iree_hal_task_queue_wait_cmd. + iree_task_call_t task; + + // Arena used for the submission - additional tasks can be allocated from + // this. + iree_arena_allocator_t* arena; + + // A list of semaphores to wait on prior to issuing the rest of the + // submission. + iree_hal_semaphore_list_t wait_semaphores; +} iree_hal_task_queue_wait_cmd_t; + +// Forks out multiple wait tasks prior to issuing the commands. +static iree_status_t iree_hal_task_queue_wait_cmd( + uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + iree_hal_task_queue_wait_cmd_t* cmd = (iree_hal_task_queue_wait_cmd_t*)task; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < cmd->wait_semaphores.count; ++i) { + status = iree_hal_task_semaphore_enqueue_timepoint( + cmd->wait_semaphores.semaphores[i], + cmd->wait_semaphores.payload_values[i], + cmd->task.header.completion_task, cmd->arena, pending_submission); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) break; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +// Cleanup for iree_hal_task_queue_wait_cmd_t that releases the retained +// semaphores. +static void iree_hal_task_queue_wait_cmd_cleanup(iree_task_t* task, + iree_status_t status) { + iree_hal_task_queue_wait_cmd_t* cmd = (iree_hal_task_queue_wait_cmd_t*)task; + iree_hal_semaphore_list_release(&cmd->wait_semaphores); +} + +// Allocates and initializes a iree_hal_task_queue_wait_cmd_t task. 
+static iree_status_t iree_hal_task_queue_wait_cmd_allocate( + iree_task_scope_t* scope, const iree_hal_semaphore_list_t* wait_semaphores, + iree_arena_allocator_t* arena, iree_hal_task_queue_wait_cmd_t** out_cmd) { + iree_hal_task_queue_wait_cmd_t* cmd = NULL; + IREE_RETURN_IF_ERROR(iree_arena_allocate(arena, sizeof(*cmd), (void**)&cmd)); + iree_task_call_initialize( + scope, iree_task_make_call_closure(iree_hal_task_queue_wait_cmd, 0), + &cmd->task); + iree_task_set_cleanup_fn(&cmd->task.header, + iree_hal_task_queue_wait_cmd_cleanup); + cmd->arena = arena; + + // Clone the wait semaphores from the batch - we retain them and their + // payloads. + IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_clone(wait_semaphores, arena, + &cmd->wait_semaphores)); + + *out_cmd = cmd; + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_task_queue_issue_cmd_t +//===----------------------------------------------------------------------===// + +// Task to issue all the command buffers in the batch. +// After this task completes the commands have been issued but have not yet +// completed and the issued commands may complete in any order. +typedef struct { + // Call to iree_hal_task_queue_issue_cmd. + iree_task_call_t task; + + // Arena used for the submission - additional tasks can be allocated from + // this. + iree_arena_allocator_t* arena; + + // Nasty back reference to the queue so that we can clear the tail_issue_task + // if we are the last issue pending. + iree_hal_task_queue_t* queue; + + // Command buffers to be issued in the order they appeared in the submission. + iree_host_size_t command_buffer_count; + iree_hal_command_buffer_t* command_buffers[]; +} iree_hal_task_queue_issue_cmd_t; + +// Issues a set of command buffers without waiting for them to complete. +static iree_status_t iree_hal_task_queue_issue_cmd( + uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + iree_hal_task_queue_issue_cmd_t* cmd = (iree_hal_task_queue_issue_cmd_t*)task; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + + // NOTE: it's ok for there to be no command buffers - in that case the + // submission was purely for synchronization. + if (cmd->command_buffer_count > 0) { + for (iree_host_size_t i = 0; i < cmd->command_buffer_count; ++i) { + status = iree_hal_task_command_buffer_issue( + cmd->command_buffers[i], &cmd->queue->state, + cmd->task.header.completion_task, cmd->arena, pending_submission); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) break; + } + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +// Cleanup for iree_hal_task_queue_issue_cmd_t that resets the queue state +// tracking the last in-flight issue. +static void iree_hal_task_queue_issue_cmd_cleanup(iree_task_t* task, + iree_status_t status) { + iree_hal_task_queue_issue_cmd_t* cmd = (iree_hal_task_queue_issue_cmd_t*)task; + + // Reset queue tail issue task if it was us. + iree_slim_mutex_lock(&cmd->queue->mutex); + if (cmd->queue->tail_issue_task == task) { + cmd->queue->tail_issue_task = NULL; + } + iree_slim_mutex_unlock(&cmd->queue->mutex); +} + +// Allocates and initializes a iree_hal_task_queue_issue_cmd_t task.
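// [Editor's note - not part of this change] For orientation, the chaining that
// iree_hal_task_queue_submit_batch below assembles between these command tasks
// is, in effect:
//
//   // waits gate the issue; the issue gates the retire:
//   iree_task_set_completion_task(&wait_cmd->task.header,
//                                 &issue_cmd->task.header);
//   iree_task_set_completion_task(&issue_cmd->task.header,
//                                 &retire_cmd->task.header);
//
// matching the wait -> issue -> retire DAG sketched at the top of this file.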
+static iree_status_t iree_hal_task_queue_issue_cmd_allocate( + iree_task_scope_t* scope, iree_hal_task_queue_t* queue, + iree_task_t* retire_task, iree_host_size_t command_buffer_count, + iree_hal_command_buffer_t** const command_buffers, + iree_arena_allocator_t* arena, iree_hal_task_queue_issue_cmd_t** out_cmd) { + iree_hal_task_queue_issue_cmd_t* cmd = NULL; + iree_host_size_t total_cmd_size = + sizeof(*cmd) + command_buffer_count * sizeof(*cmd->command_buffers); + IREE_RETURN_IF_ERROR( + iree_arena_allocate(arena, total_cmd_size, (void**)&cmd)); + iree_task_call_initialize( + scope, iree_task_make_call_closure(iree_hal_task_queue_issue_cmd, 0), + &cmd->task); + iree_task_set_completion_task(&cmd->task.header, retire_task); + iree_task_set_cleanup_fn(&cmd->task.header, + iree_hal_task_queue_issue_cmd_cleanup); + cmd->arena = arena; + cmd->queue = queue; + + cmd->command_buffer_count = command_buffer_count; + memcpy(cmd->command_buffers, command_buffers, + cmd->command_buffer_count * sizeof(*cmd->command_buffers)); + + *out_cmd = cmd; + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_task_queue_retire_cmd_t +//===----------------------------------------------------------------------===// + +// Task to retire the submission and free the transient memory allocated for +// it. The task is issued only once all commands from all command buffers in +// the submission complete. Semaphores will be signaled and dependent +// submissions may be issued. +typedef struct { + // Call to iree_hal_task_queue_retire_cmd. + iree_task_call_t task; + + // Original arena used for all transient allocations required for the + // submission. All queue-related commands are allocated from this, **including + // this retire command**. + iree_arena_allocator_t arena; + + // A list of semaphores to signal upon retiring. + iree_hal_semaphore_list_t signal_semaphores; +} iree_hal_task_queue_retire_cmd_t; + +// Retires a submission by signaling semaphores to their desired value and +// disposing of the temporary arena memory used for the submission. +static iree_status_t iree_hal_task_queue_retire_cmd( + uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + iree_hal_task_queue_retire_cmd_t* cmd = + (iree_hal_task_queue_retire_cmd_t*)task; + IREE_TRACE_ZONE_BEGIN(z0); + + // Signal all semaphores to their new values. + // Note that if any signal fails then the whole command will fail and all + // semaphores will be signaled to the failure state. + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < cmd->signal_semaphores.count; ++i) { + status = + iree_hal_semaphore_signal(cmd->signal_semaphores.semaphores[i], + cmd->signal_semaphores.payload_values[i]); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) break; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +// Cleanup for iree_hal_task_queue_retire_cmd_t that ensures that the arena +// holding the submission is properly disposed and that semaphores are signaled +// (or signaled to failure if the command failed). +static void iree_hal_task_queue_retire_cmd_cleanup(iree_task_t* task, + iree_status_t status) { + iree_hal_task_queue_retire_cmd_t* cmd = + (iree_hal_task_queue_retire_cmd_t*)task; + + // If the command failed then fail all semaphores to ensure future + // submissions fail as well (including those on other queues). 
+ if (!iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < cmd->signal_semaphores.count; ++i) { + iree_hal_semaphore_fail(cmd->signal_semaphores.semaphores[i], + iree_status_clone(status)); + } + } + + // Release all semaphores. + iree_hal_semaphore_list_release(&cmd->signal_semaphores); + + // Drop all memory used by the submission (**including cmd**). + iree_arena_allocator_t arena = cmd->arena; + cmd = NULL; + iree_arena_deinitialize(&arena); +} + +// Allocates and initializes a iree_hal_task_queue_retire_cmd_t task. +// The command will own an arena that can be used for other submission-related +// allocations. +static iree_status_t iree_hal_task_queue_retire_cmd_allocate( + iree_task_scope_t* scope, + const iree_hal_semaphore_list_t* signal_semaphores, + iree_arena_block_pool_t* block_pool, + iree_hal_task_queue_retire_cmd_t** out_cmd) { + // Make an arena we'll use for allocating the command itself. + iree_arena_allocator_t arena; + iree_arena_initialize(block_pool, &arena); + + // Allocate the command from the arena. + iree_hal_task_queue_retire_cmd_t* cmd = NULL; + iree_status_t status = + iree_arena_allocate(&arena, sizeof(*cmd), (void**)&cmd); + if (iree_status_is_ok(status)) { + iree_task_call_initialize( + scope, iree_task_make_call_closure(iree_hal_task_queue_retire_cmd, 0), + &cmd->task); + iree_task_set_cleanup_fn(&cmd->task.header, + iree_hal_task_queue_retire_cmd_cleanup); + } + + // Clone the signal semaphores from the batch - we retain them and their + // payloads. + if (iree_status_is_ok(status)) { + status = iree_hal_semaphore_list_clone(signal_semaphores, &arena, + &cmd->signal_semaphores); + } + + if (iree_status_is_ok(status)) { + // Transfer ownership of the arena to the command. + memcpy(&cmd->arena, &arena, sizeof(cmd->arena)); + *out_cmd = cmd; + } else { + iree_arena_deinitialize(&arena); + } + return status; +} + +//===----------------------------------------------------------------------===// +// iree_hal_task_queue_t +//===----------------------------------------------------------------------===// + +void iree_hal_task_queue_initialize(iree_string_view_t identifier, + iree_task_executor_t* executor, + iree_arena_block_pool_t* block_pool, + iree_hal_task_queue_t* out_queue) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_TEXT(z0, identifier.data, identifier.size); + + memset(out_queue, 0, sizeof(*out_queue)); + + out_queue->executor = executor; + iree_task_executor_retain(out_queue->executor); + out_queue->block_pool = block_pool; + + iree_task_scope_initialize(identifier, &out_queue->scope); + + iree_slim_mutex_initialize(&out_queue->mutex); + iree_hal_task_queue_state_initialize(&out_queue->state); + out_queue->tail_issue_task = NULL; + + IREE_TRACE_ZONE_END(z0); +} + +void iree_hal_task_queue_deinitialize(iree_hal_task_queue_t* queue) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_ignore( + iree_task_scope_wait_idle(&queue->scope, IREE_TIME_INFINITE_FUTURE)); + + iree_slim_mutex_lock(&queue->mutex); + IREE_ASSERT(!queue->tail_issue_task); + iree_slim_mutex_unlock(&queue->mutex); + + iree_hal_task_queue_state_deinitialize(&queue->state); + iree_slim_mutex_deinitialize(&queue->mutex); + iree_task_scope_deinitialize(&queue->scope); + iree_task_executor_release(queue->executor); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_task_queue_submit_batch( + iree_hal_task_queue_t* queue, const iree_hal_submission_batch_t* batch) { + // Task to retire the submission and free the transient memory allocated for + // it (including the
command itself). We allocate this first so it can get an + // arena which we will use to allocate all other commands. + iree_hal_task_queue_retire_cmd_t* retire_cmd = NULL; + IREE_RETURN_IF_ERROR(iree_hal_task_queue_retire_cmd_allocate( + &queue->scope, &batch->signal_semaphores, queue->block_pool, + &retire_cmd)); + + // NOTE: if we fail from here on we must drop the retire_cmd arena. + iree_status_t status = iree_ok_status(); + + // A fence we'll use to detect when the entire submission has completed. + // TODO(benvanik): fold into the retire command. + iree_task_fence_t* fence = NULL; + status = + iree_task_executor_acquire_fence(queue->executor, &queue->scope, &fence); + iree_task_set_completion_task(&retire_cmd->task.header, &fence->header); + + // Task to fork and wait for unsatisfied semaphore dependencies. + // This is optional and only required if we have previous submissions still + // in-flight - if the queue is empty then we can directly schedule the waits. + iree_hal_task_queue_wait_cmd_t* wait_cmd = NULL; + if (iree_status_is_ok(status) && batch->wait_semaphores.count > 0) { + status = iree_hal_task_queue_wait_cmd_allocate( + &queue->scope, &batch->wait_semaphores, &retire_cmd->arena, &wait_cmd); + } + + // Task to issue all the command buffers in the batch. + // After this task completes the commands have been issued but have not yet + // completed and the issued commands may complete in any order. + iree_hal_task_queue_issue_cmd_t* issue_cmd = NULL; + if (iree_status_is_ok(status)) { + status = iree_hal_task_queue_issue_cmd_allocate( + &queue->scope, queue, &retire_cmd->task.header, + batch->command_buffer_count, batch->command_buffers, &retire_cmd->arena, + &issue_cmd); + } + + // Last chance for failure - from here on we are submitting. + if (IREE_UNLIKELY(!iree_status_is_ok(status))) { + iree_arena_deinitialize(&retire_cmd->arena); + return status; + } + + iree_task_submission_t submission; + iree_task_submission_initialize(&submission); + + // Sequencing: wait on semaphores or go directly into the executor queue. + if (wait_cmd != NULL) { + // Ensure that we only issue command buffers after all waits have completed. + iree_task_set_completion_task(&wait_cmd->task.header, + &issue_cmd->task.header); + iree_task_submission_enqueue(&submission, &wait_cmd->task.header); + } else { + // No waits needed; directly enqueue. + iree_task_submission_enqueue(&submission, &issue_cmd->task.header); + } + + iree_slim_mutex_lock(&queue->mutex); + + // If there is an in-flight issue pending then we need to chain onto that + // so that we ensure FIFO submission order is preserved. Note that we are only + // waiting for the issue to complete and *not* all of the commands that are + // issued. + if (queue->tail_issue_task != NULL) { + iree_task_set_completion_task(queue->tail_issue_task, + &issue_cmd->task.header); + } + queue->tail_issue_task = &issue_cmd->task.header; + + iree_slim_mutex_unlock(&queue->mutex); + + // Submit the tasks immediately. The executor may queue them up until we + // force the flush after all batches have been processed. + iree_task_executor_submit(queue->executor, &submission); + return iree_ok_status(); +} + +iree_status_t iree_hal_task_queue_submit( + iree_hal_task_queue_t* queue, iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches) { + IREE_TRACE_ZONE_BEGIN(z0); + + // For now we process each batch independently. 
To elide additional semaphore + // work and prevent unneeded coordinator scheduling logic we could instead + // build the whole DAG prior to submitting. + for (iree_host_size_t i = 0; i < batch_count; ++i) { + const iree_hal_submission_batch_t* batch = &batches[i]; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_task_queue_submit_batch(queue, batch)); + } + + iree_task_executor_flush(queue->executor); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +iree_status_t iree_hal_task_queue_wait_idle_with_deadline( + iree_hal_task_queue_t* queue, iree_time_t deadline_ns) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_task_scope_wait_idle(&queue->scope, deadline_ns); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/local/task_queue.h b/iree/hal/local/task_queue.h new file mode 100644 index 0000000000000..5dfc449d244a6 --- /dev/null +++ b/iree/hal/local/task_queue.h @@ -0,0 +1,78 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_TASK_QUEUE_H_ +#define IREE_HAL_LOCAL_TASK_QUEUE_H_ + +#include "iree/base/api.h" +#include "iree/base/synchronization.h" +#include "iree/hal/api.h" +#include "iree/hal/local/arena.h" +#include "iree/hal/local/task_queue_state.h" +#include "iree/task/executor.h" +#include "iree/task/scope.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct { + // Shared executor that the queue submits tasks to. + iree_task_executor_t* executor; + + // Shared block pool for allocating submission transients (tasks/events/etc). + iree_arena_block_pool_t* block_pool; + + // Scope used for all tasks in the queue. + // This allows for easy waits on all outstanding queue tasks as well as + // differentiation of tasks within the executor. + iree_task_scope_t scope; + + // Guards queue state. Submissions and waits may come from any user thread and + // we do a bit of bookkeeping during command buffer issue that will come from + // an executor thread. + iree_slim_mutex_t mutex; + + // State tracking used during command buffer issue. + // The intra-queue synchronization (barriers/events) carries across command + // buffers and this is used to rendezvous the tasks in each set. + iree_hal_task_queue_state_t state; + + // The last active iree_hal_task_queue_issue_cmd_t submitted to the queue. + // If this is NULL then there are no issues pending - though there may still + // be active work that was previously issued. This is used to chain together + // issues in FIFO order such that all submissions *issue* in order but not + // *execute* in order. 
+ iree_task_t* tail_issue_task; +} iree_hal_task_queue_t; + +void iree_hal_task_queue_initialize(iree_string_view_t identifier, + iree_task_executor_t* executor, + iree_arena_block_pool_t* block_pool, + iree_hal_task_queue_t* out_queue); + +void iree_hal_task_queue_deinitialize(iree_hal_task_queue_t* queue); + +iree_status_t iree_hal_task_queue_submit( + iree_hal_task_queue_t* queue, iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches); + +iree_status_t iree_hal_task_queue_wait_idle_with_deadline( + iree_hal_task_queue_t* queue, iree_time_t deadline_ns); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_TASK_QUEUE_H_ diff --git a/iree/hal/host/nop_event.cc b/iree/hal/local/task_queue_state.c similarity index 61% rename from iree/hal/host/nop_event.cc rename to iree/hal/local/task_queue_state.c index 6db302c753153..7e912414222b2 100644 --- a/iree/hal/host/nop_event.cc +++ b/iree/hal/local/task_queue_state.c @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,16 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "iree/hal/host/nop_event.h" +#include "iree/hal/local/task_queue_state.h" -namespace iree { -namespace hal { -namespace host { +#include "iree/base/tracing.h" -NopEvent::NopEvent() = default; +void iree_hal_task_queue_state_initialize( + iree_hal_task_queue_state_t* out_queue_state) { + memset(out_queue_state, 0, sizeof(*out_queue_state)); +} -NopEvent::~NopEvent() = default; - -} // namespace host -} // namespace hal -} // namespace iree +void iree_hal_task_queue_state_deinitialize( + iree_hal_task_queue_state_t* queue_state) {} diff --git a/iree/hal/local/task_queue_state.h b/iree/hal/local/task_queue_state.h new file mode 100644 index 0000000000000..9b3c04be61553 --- /dev/null +++ b/iree/hal/local/task_queue_state.h @@ -0,0 +1,49 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_TASK_QUEUE_STATE_H_ +#define IREE_HAL_LOCAL_TASK_QUEUE_STATE_H_ + +#include "iree/base/api.h" +#include "iree/base/atomics.h" +#include "iree/hal/api.h" +#include "iree/task/scope.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// State tracking for an individual queue. +// +// Thread-compatible: only intended to be used by a queue with the submission +// lock held. +typedef struct { + // TODO(#4518): track event state. + int reserved; +} iree_hal_task_queue_state_t; + +// Initializes |out_queue_state| to its default values. +void iree_hal_task_queue_state_initialize( + iree_hal_task_queue_state_t* out_queue_state); + +// Deinitializes queue state and cleans up any tracking intermediates.
+void iree_hal_task_queue_state_deinitialize( + iree_hal_task_queue_state_t* queue_state); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_TASK_QUEUE_STATE_H_ diff --git a/iree/hal/local/task_semaphore.c b/iree/hal/local/task_semaphore.c new file mode 100644 index 0000000000000..2eb63d67a00df --- /dev/null +++ b/iree/hal/local/task_semaphore.c @@ -0,0 +1,511 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/task_semaphore.h" + +#include <inttypes.h> + +#include "iree/base/synchronization.h" +#include "iree/base/tracing.h" +#include "iree/base/wait_handle.h" + +// Sentinel used when the semaphore has failed and an error status is set. +#define IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE UINT64_MAX + +//===----------------------------------------------------------------------===// +// iree_hal_task_timepoint_t +//===----------------------------------------------------------------------===// + +// Represents a point in the timeline that someone is waiting to be reached. +// When the semaphore is signaled to at least the specified value then the +// given event will be signaled and the timepoint discarded. +// +// Instances are owned and retained by the caller that requested them - usually +// in the arena associated with the submission, but could be on the stack of a +// synchronously waiting thread. +typedef struct iree_hal_task_timepoint_s { + struct iree_hal_task_timepoint_s* next; + struct iree_hal_task_timepoint_s* prev; + uint64_t payload_value; + iree_event_t event; +} iree_hal_task_timepoint_t; + +// A doubly-linked FIFO list of timepoints. +// The order of the timepoints does *not* match increasing payload values but +// instead the order they were added to the list. +// +// Note that the timepoints are not owned by the list - this just nicely +// stitches together timepoints for the semaphore. +typedef struct { + iree_hal_task_timepoint_t* head; + iree_hal_task_timepoint_t* tail; +} iree_hal_task_timepoint_list_t; + +static void iree_hal_task_timepoint_list_initialize( + iree_hal_task_timepoint_list_t* out_list) { + memset(out_list, 0, sizeof(*out_list)); +} + +// Moves |source_list| into |out_target_list|. +// |source_list| will be reset and the prior contents of |out_target_list| will +// be discarded. +static void iree_hal_task_timepoint_list_move( + iree_hal_task_timepoint_list_t* source_list, + iree_hal_task_timepoint_list_t* out_target_list) { + memcpy(out_target_list, source_list, sizeof(*out_target_list)); + memset(source_list, 0, sizeof(*source_list)); +} + +// Appends a timepoint to the end of the timepoint list.
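// [Editor's note - not part of this change] A timepoint's path through these
// helpers, with illustrative values:
//
//   iree_hal_task_timepoint_t tp;  // owned by the waiter (arena or stack)
//   // ... tp.event acquired from the event pool, tp.payload_value = 10 ...
//   iree_hal_task_timepoint_list_append(&semaphore->timepoint_list, &tp);
//   // A later signal to >= 10 moves tp out via _take_ready() and
//   // _notify_ready() sets tp.event; on timeout the waiter erases tp itself.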
+static void iree_hal_task_timepoint_list_append( + iree_hal_task_timepoint_list_t* list, + iree_hal_task_timepoint_t* timepoint) { + timepoint->next = NULL; + timepoint->prev = list->tail; + if (list->tail != NULL) { + list->tail->next = timepoint; + list->tail = timepoint; + } else { + list->head = timepoint; + list->tail = timepoint; + } +} + +// Erases a timepoint from the list. +static void iree_hal_task_timepoint_list_erase( + iree_hal_task_timepoint_list_t* list, + iree_hal_task_timepoint_t* timepoint) { + if (timepoint->prev != NULL) timepoint->prev->next = timepoint->next; + if (timepoint->next != NULL) timepoint->next->prev = timepoint->prev; + if (timepoint == list->head) list->head = timepoint->next; + if (timepoint == list->tail) list->tail = timepoint->prev; + timepoint->prev = NULL; + timepoint->next = NULL; +} + +// Scans the |pending_list| for all timepoints that are satisfied by the +// timeline having reached |payload_value|. Each satisfied timepoint will be +// moved to |out_ready_list|. +static void iree_hal_task_timepoint_list_take_ready( + iree_hal_task_timepoint_list_t* pending_list, uint64_t payload_value, + iree_hal_task_timepoint_list_t* out_ready_list) { + iree_hal_task_timepoint_list_initialize(out_ready_list); + iree_hal_task_timepoint_t* next = pending_list->head; + while (next != NULL) { + iree_hal_task_timepoint_t* timepoint = next; + next = timepoint->next; + bool is_satisfied = timepoint->payload_value <= payload_value; + if (!is_satisfied) continue; + + // Remove from pending list. + iree_hal_task_timepoint_list_erase(pending_list, timepoint); + + // Add to ready list. + iree_hal_task_timepoint_list_append(out_ready_list, timepoint); + } +} + +// Notifies all of the timepoints in the |ready_list| that their condition has +// been satisfied. |ready_list| will be reset as ownership of the events is +// held by the originator. +static void iree_hal_task_timepoint_list_notify_ready( + iree_hal_task_timepoint_list_t* ready_list) { + iree_hal_task_timepoint_t* next = ready_list->head; + while (next != NULL) { + iree_hal_task_timepoint_t* timepoint = next; + next = timepoint->next; + timepoint->next = NULL; + timepoint->prev = NULL; + iree_event_set(&timepoint->event); + } + iree_hal_task_timepoint_list_initialize(ready_list); +} + +//===----------------------------------------------------------------------===// +// iree_hal_task_semaphore_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + iree_hal_local_event_pool_t* event_pool; + + // Guards all mutable fields. We expect low contention on semaphores and since + // iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler + // than trying to make the entire structure lock-free. + iree_slim_mutex_t mutex; + + // Current signaled value. May be IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE to + // indicate that the semaphore has been signaled for failure and + // |failure_status| contains the error. + uint64_t current_value; + + // OK or the status passed to iree_hal_semaphore_fail. Owned by the semaphore. + iree_status_t failure_status; + + // In-process notification signaled when the semaphore value changes. This is + // used exclusively for wait-ones to avoid going to the kernel for a full wait + // handle operation. + iree_notification_t notification; + + // A list of all reserved timepoints waiting for the semaphore to reach a + // certain payload value.
+ iree_hal_task_timepoint_list_t timepoint_list; +} iree_hal_task_semaphore_t; + +static const iree_hal_semaphore_vtable_t iree_hal_task_semaphore_vtable; + +static iree_hal_task_semaphore_t* iree_hal_task_semaphore_cast( + iree_hal_semaphore_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_task_semaphore_vtable); + return (iree_hal_task_semaphore_t*)base_value; +} + +iree_status_t iree_hal_task_semaphore_create( + iree_hal_local_event_pool_t* event_pool, uint64_t initial_value, + iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { + IREE_ASSERT_ARGUMENT(event_pool); + IREE_ASSERT_ARGUMENT(out_semaphore); + *out_semaphore = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_task_semaphore_t* semaphore = NULL; + iree_status_t status = iree_allocator_malloc( + host_allocator, sizeof(*semaphore), (void**)&semaphore); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_task_semaphore_vtable, + &semaphore->resource); + semaphore->host_allocator = host_allocator; + semaphore->event_pool = event_pool; + + iree_slim_mutex_initialize(&semaphore->mutex); + semaphore->current_value = initial_value; + semaphore->failure_status = iree_ok_status(); + iree_notification_initialize(&semaphore->notification); + iree_hal_task_timepoint_list_initialize(&semaphore->timepoint_list); + + *out_semaphore = (iree_hal_semaphore_t*)semaphore; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_task_semaphore_destroy( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_task_semaphore_t* semaphore = + iree_hal_task_semaphore_cast(base_semaphore); + iree_allocator_t host_allocator = semaphore->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_free(semaphore->failure_status); + iree_notification_deinitialize(&semaphore->notification); + iree_allocator_free(host_allocator, semaphore); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_task_semaphore_query( + iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { + iree_hal_task_semaphore_t* semaphore = + iree_hal_task_semaphore_cast(base_semaphore); + + iree_slim_mutex_lock(&semaphore->mutex); + + *out_value = semaphore->current_value; + + iree_status_t status = iree_ok_status(); + if (*out_value >= IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE) { + status = iree_status_clone(semaphore->failure_status); + } + + iree_slim_mutex_unlock(&semaphore->mutex); + + return status; +} + +static iree_status_t iree_hal_task_semaphore_signal( + iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { + iree_hal_task_semaphore_t* semaphore = + iree_hal_task_semaphore_cast(base_semaphore); + + iree_slim_mutex_lock(&semaphore->mutex); + + if (new_value <= semaphore->current_value) { + uint64_t current_value = semaphore->current_value; + iree_slim_mutex_unlock(&semaphore->mutex); + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "semaphore values must be monotonically " + "increasing; current_value=%" PRIu64 + ", new_value=%" PRIu64, + current_value, new_value); + } + + semaphore->current_value = new_value; + + // Scan for all timepoints that are now satisfied and move them to our local + // ready list. This way we can notify them without needing to continue holding + // the semaphore lock. 
+ iree_hal_task_timepoint_list_t ready_list; + iree_hal_task_timepoint_list_take_ready(&semaphore->timepoint_list, new_value, + &ready_list); + + iree_notification_post(&semaphore->notification, IREE_ALL_WAITERS); + iree_slim_mutex_unlock(&semaphore->mutex); + + // Notify all waiters - note that this must happen outside the lock. + iree_hal_task_timepoint_list_notify_ready(&ready_list); + + return iree_ok_status(); +} + +static void iree_hal_task_semaphore_fail(iree_hal_semaphore_t* base_semaphore, + iree_status_t status) { + iree_hal_task_semaphore_t* semaphore = + iree_hal_task_semaphore_cast(base_semaphore); + + iree_slim_mutex_lock(&semaphore->mutex); + + // Try to set our local status - we only preserve the first failure so only + // do this if we are going from a valid semaphore to a failed one. + if (!iree_status_is_ok(semaphore->failure_status)) { + // Previous status was not OK; drop our new status. + IREE_IGNORE_ERROR(status); + iree_slim_mutex_unlock(&semaphore->mutex); + return; + } + + // Signal to our failure sentinel value. + semaphore->current_value = IREE_HAL_TASK_SEMAPHORE_FAILURE_VALUE; + semaphore->failure_status = status; + + // Take the whole timepoint list as we'll be signaling all of them. Since + // we hold the lock no other timepoints can be created while we are cleaning + // up. + iree_hal_task_timepoint_list_t ready_list; + iree_hal_task_timepoint_list_move(&semaphore->timepoint_list, &ready_list); + + iree_notification_post(&semaphore->notification, IREE_ALL_WAITERS); + iree_slim_mutex_unlock(&semaphore->mutex); + + // Notify all waiters - note that this must happen outside the lock. + iree_hal_task_timepoint_list_notify_ready(&ready_list); +} + +// Acquires a timepoint waiting for the given value. +// |out_timepoint| is owned by the caller and must be kept live until the +// timepoint has been reached (or it is cancelled by the caller). +static iree_status_t iree_hal_task_semaphore_acquire_timepoint( + iree_hal_task_semaphore_t* semaphore, uint64_t minimum_value, + iree_hal_task_timepoint_t* out_timepoint) { + memset(out_timepoint, 0, sizeof(*out_timepoint)); + out_timepoint->payload_value = minimum_value; + IREE_RETURN_IF_ERROR(iree_hal_local_event_pool_acquire( + semaphore->event_pool, 1, &out_timepoint->event)); + iree_hal_task_timepoint_list_append(&semaphore->timepoint_list, + out_timepoint); + return iree_ok_status(); +} + +typedef struct { + iree_task_wait_t task; + iree_hal_task_semaphore_t* semaphore; + iree_hal_task_timepoint_t timepoint; +} iree_hal_task_semaphore_wait_cmd_t; + +// Cleans up a wait task by returning the event used to the pool and - if the +// task failed - ensuring we scrub it from the timepoint list. +static void iree_hal_task_semaphore_wait_cmd_cleanup(iree_task_t* task, + iree_status_t status) { + iree_hal_task_semaphore_wait_cmd_t* cmd = + (iree_hal_task_semaphore_wait_cmd_t*)task; + iree_hal_local_event_pool_release(cmd->semaphore->event_pool, 1, + &cmd->timepoint.event); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) { + // Abort the timepoint. Note that this is not designed to be fast as + // semaphore failure is an exceptional case. 
+ iree_slim_mutex_lock(&cmd->semaphore->mutex); + iree_hal_task_timepoint_list_erase(&cmd->semaphore->timepoint_list, + &cmd->timepoint); + iree_slim_mutex_unlock(&cmd->semaphore->mutex); + } +} + +iree_status_t iree_hal_task_semaphore_enqueue_timepoint( + iree_hal_semaphore_t* base_semaphore, uint64_t minimum_value, + iree_task_t* issue_task, iree_arena_allocator_t* arena, + iree_task_submission_t* submission) { + iree_hal_task_semaphore_t* semaphore = + iree_hal_task_semaphore_cast(base_semaphore); + + iree_slim_mutex_lock(&semaphore->mutex); + + iree_status_t status = iree_ok_status(); + if (semaphore->current_value >= minimum_value) { + // Fast path: already satisfied. + } else { + // Slow path: acquire a system wait handle and perform a full wait. + iree_hal_task_semaphore_wait_cmd_t* cmd = NULL; + status = iree_arena_allocate(arena, sizeof(*cmd), (void**)&cmd); + if (iree_status_is_ok(status)) { + status = iree_hal_task_semaphore_acquire_timepoint( + semaphore, minimum_value, &cmd->timepoint); + } + if (iree_status_is_ok(status)) { + iree_task_wait_initialize(issue_task->scope, cmd->timepoint.event, + &cmd->task); + iree_task_set_cleanup_fn(&cmd->task.header, + iree_hal_task_semaphore_wait_cmd_cleanup); + iree_task_set_completion_task(&cmd->task.header, issue_task); + cmd->semaphore = semaphore; + iree_task_submission_enqueue(submission, &cmd->task.header); + } + } + + iree_slim_mutex_unlock(&semaphore->mutex); + return status; +} + +static iree_status_t iree_hal_task_semaphore_wait_with_deadline( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_time_t deadline_ns) { + iree_hal_task_semaphore_t* semaphore = + iree_hal_task_semaphore_cast(base_semaphore); + + iree_slim_mutex_lock(&semaphore->mutex); + + if (semaphore->current_value >= value) { + // Fast path: already satisfied. + iree_slim_mutex_unlock(&semaphore->mutex); + return iree_ok_status(); + } else if (deadline_ns == IREE_TIME_INFINITE_PAST) { + // Not satisfied but a poll, so can avoid the expensive wait handle work. + iree_slim_mutex_unlock(&semaphore->mutex); + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + } + + // Slow path: acquire a timepoint while we hold the lock. + iree_hal_task_timepoint_t timepoint; + iree_status_t status = + iree_hal_task_semaphore_acquire_timepoint(semaphore, value, &timepoint); + + iree_slim_mutex_unlock(&semaphore->mutex); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) return status; + + // Wait until the timepoint resolves. + // If satisfied the timepoint is automatically cleaned up and we are done. If + // the deadline is reached before satisfied then we have to clean it up. 
+ status = iree_wait_one(&timepoint.event, deadline_ns); + if (!iree_status_is_ok(status)) { + iree_slim_mutex_lock(&semaphore->mutex); + iree_hal_task_timepoint_list_erase(&semaphore->timepoint_list, &timepoint); + iree_slim_mutex_unlock(&semaphore->mutex); + } + iree_hal_local_event_pool_release(semaphore->event_pool, 1, &timepoint.event); + return status; +} + +static iree_status_t iree_hal_task_semaphore_wait_with_timeout( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_duration_t timeout_ns) { + return iree_hal_task_semaphore_wait_with_deadline( + base_semaphore, value, iree_relative_timeout_to_deadline_ns(timeout_ns)); +} + +iree_status_t iree_hal_task_semaphore_multi_wait( + iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns, + iree_hal_local_event_pool_t* event_pool, + iree_arena_block_pool_t* block_pool) { + IREE_ASSERT_ARGUMENT(semaphore_list); + if (semaphore_list->count == 0) { + return iree_ok_status(); + } else if (semaphore_list->count == 1) { + // Fast-path for a single semaphore. + return iree_hal_semaphore_wait_with_deadline( + semaphore_list->semaphores[0], semaphore_list->payload_values[0], + deadline_ns); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + // Avoid heap allocations by using the device block pool for the wait set. + iree_arena_allocator_t arena; + iree_arena_initialize(block_pool, &arena); + iree_wait_set_t* wait_set = NULL; + iree_status_t status = iree_wait_set_allocate( + semaphore_list->count, iree_arena_allocator(&arena), &wait_set); + + // Acquire a wait handle for each semaphore timepoint we are to wait on. + // TODO(benvanik): flip this API around so we can batch request events from + // the event pool. We should be acquiring all required time points in one + // call. + iree_host_size_t timepoint_count = 0; + iree_hal_task_timepoint_t* timepoints = NULL; + iree_host_size_t total_timepoint_size = + semaphore_list->count * sizeof(timepoints[0]); + // NOTE: only allocate the timepoints if the wait set allocation succeeded so + // that an earlier failure status is not clobbered. + if (iree_status_is_ok(status)) { + status = + iree_arena_allocate(&arena, total_timepoint_size, (void**)&timepoints); + } + if (iree_status_is_ok(status)) { + memset(timepoints, 0, total_timepoint_size); + for (iree_host_size_t i = 0; i < semaphore_list->count; ++i) { + iree_hal_task_semaphore_t* semaphore = + iree_hal_task_semaphore_cast(semaphore_list->semaphores[i]); + iree_slim_mutex_lock(&semaphore->mutex); + if (semaphore->current_value >= semaphore_list->payload_values[i]) { + // Fast path: already satisfied. + } else { + // Slow path: get a native wait handle for the timepoint. + iree_hal_task_timepoint_t* timepoint = &timepoints[timepoint_count++]; + status = iree_hal_task_semaphore_acquire_timepoint( + semaphore, semaphore_list->payload_values[i], timepoint); + if (iree_status_is_ok(status)) { + status = iree_wait_set_insert(wait_set, timepoint->event); + } + } + iree_slim_mutex_unlock(&semaphore->mutex); + if (!iree_status_is_ok(status)) break; + } + } + + // Perform the wait. + if (iree_status_is_ok(status)) { + if (wait_mode == IREE_HAL_WAIT_MODE_ANY) { + status = iree_wait_any(wait_set, deadline_ns, /*out_wake_handle=*/NULL); + } else { + status = iree_wait_all(wait_set, deadline_ns); + } + } + + if (timepoints != NULL) { + // TODO(benvanik): if we flip the API to multi-acquire events from the pool + // above then we can multi-release here too.
+ for (iree_host_size_t i = 0; i < timepoint_count; ++i) { + iree_hal_local_event_pool_release(event_pool, 1, &timepoints[i].event); + } + } + iree_wait_set_free(wait_set); + iree_arena_deinitialize(&arena); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static const iree_hal_semaphore_vtable_t iree_hal_task_semaphore_vtable = { + .destroy = iree_hal_task_semaphore_destroy, + .query = iree_hal_task_semaphore_query, + .signal = iree_hal_task_semaphore_signal, + .fail = iree_hal_task_semaphore_fail, + .wait_with_deadline = iree_hal_task_semaphore_wait_with_deadline, + .wait_with_timeout = iree_hal_task_semaphore_wait_with_timeout, +}; diff --git a/iree/hal/local/task_semaphore.h b/iree/hal/local/task_semaphore.h new file mode 100644 index 0000000000000..88eafb9c80f18 --- /dev/null +++ b/iree/hal/local/task_semaphore.h @@ -0,0 +1,58 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_TASK_SEMAPHORE_H_ +#define IREE_HAL_LOCAL_TASK_SEMAPHORE_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/local/arena.h" +#include "iree/hal/local/event_pool.h" +#include "iree/task/submission.h" +#include "iree/task/task.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a semaphore that integrates with the task system to allow for +// pipelined wait and signal operations. +iree_status_t iree_hal_task_semaphore_create( + iree_hal_local_event_pool_t* event_pool, uint64_t initial_value, + iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore); + +// Reserves a new timepoint in the timeline for the given minimum payload value. +// |issue_task| will wait until the timeline semaphore is signaled to at least +// |minimum_value| before proceeding, with a possible wait task generated and +// appended to the |submission|. Allocations for any intermediates will be made +// from |arena| whose lifetime must be tied to the submission. +iree_status_t iree_hal_task_semaphore_enqueue_timepoint( + iree_hal_semaphore_t* semaphore, uint64_t minimum_value, + iree_task_t* issue_task, iree_arena_allocator_t* arena, + iree_task_submission_t* submission); + +// Performs a multi-wait on one or more semaphores. +// Returns IREE_STATUS_DEADLINE_EXCEEDED if the wait does not complete before +// |deadline_ns| elapses. 
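// [Editor's sketch - not part of this change] A hypothetical wait-all on two
// semaphores with a one second timeout; |event_pool| and |block_pool| come
// from the owning device and the payload values are illustrative:
//
//   iree_hal_semaphore_t* semaphores[2] = {semaphore_a, semaphore_b};
//   uint64_t payload_values[2] = {10, 20};
//   iree_hal_semaphore_list_t list = {
//       .count = IREE_ARRAYSIZE(semaphores),
//       .semaphores = semaphores,
//       .payload_values = payload_values,
//   };
//   IREE_RETURN_IF_ERROR(iree_hal_task_semaphore_multi_wait(
//       IREE_HAL_WAIT_MODE_ALL, &list,
//       iree_relative_timeout_to_deadline_ns(1000000000ull), event_pool,
//       block_pool));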
+iree_status_t iree_hal_task_semaphore_multi_wait( + iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns, + iree_hal_local_event_pool_t* event_pool, + iree_arena_block_pool_t* block_pool); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_TASK_SEMAPHORE_H_ diff --git a/iree/hal/metal/BUILD.bazel b/iree/hal/metal/BUILD.bazel deleted file mode 100644 index 388b824291663..0000000000000 --- a/iree/hal/metal/BUILD.bazel +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -objc_library( - name = "metal", - srcs = [ - "metal_pipeline_argument_buffer.cc", - ], - hdrs = [ - "dispatch_time_util.h", - "metal_buffer.h", - "metal_capture_manager.h", - "metal_command_buffer.h", - "metal_command_queue.h", - "metal_device.h", - "metal_direct_allocator.h", - "metal_driver.h", - "metal_kernel_library.h", - "metal_pipeline_argument_buffer.h", - "metal_pipeline_cache.h", - "metal_shared_event.h", - ], - copts = ["-std=c++14"], - non_arc_srcs = [ - "metal_buffer.mm", - "metal_capture_manager.mm", - "metal_command_buffer.mm", - "metal_command_queue.mm", - "metal_device.mm", - "metal_direct_allocator.mm", - "metal_driver.mm", - "metal_kernel_library.mm", - "metal_pipeline_cache.mm", - "metal_shared_event.mm", - ], - sdk_frameworks = [ - "Foundation", - "Metal", - ], - deps = [ - "//iree/base:arena", - "//iree/base:file_io", - "//iree/base:flatcc", - "//iree/base:logging", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - "//iree/hal:command_buffer_validation", - "//iree/schemas:metal_executable_def_c_fbs", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - ], -) diff --git a/iree/hal/metal/CMakeLists.txt b/iree/hal/metal/CMakeLists.txt deleted file mode 100644 index b8b5f0e40868f..0000000000000 --- a/iree/hal/metal/CMakeLists.txt +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-if(NOT ${IREE_HAL_DRIVER_METAL})
-  return()
-endif()
-
-iree_add_all_subdirs()
-
-iree_cc_library(
-  NAME
-    metal
-  HDRS
-    "metal_buffer.h"
-    "metal_capture_manager.h"
-    "metal_command_buffer.h"
-    "metal_command_queue.h"
-    "metal_device.h"
-    "metal_direct_allocator.h"
-    "metal_driver.h"
-    "metal_kernel_library.h"
-    "metal_pipeline_argument_buffer.h"
-    "metal_pipeline_cache.h"
-    "metal_shared_event.h"
-  SRCS
-    "metal_buffer.mm"
-    "metal_capture_manager.mm"
-    "metal_command_buffer.mm"
-    "metal_command_queue.mm"
-    "metal_device.mm"
-    "metal_direct_allocator.mm"
-    "metal_driver.mm"
-    "metal_kernel_library.mm"
-    "metal_pipeline_argument_buffer.cc"
-    "metal_pipeline_cache.mm"
-    "metal_shared_event.mm"
-  DEPS
-    absl::flat_hash_map
-    absl::inlined_vector
-    absl::memory
-    absl::span
-    absl::strings
-    iree::base::flatcc
-    iree::base::file_io
-    iree::base::logging
-    iree::base::status
-    iree::base::time
-    iree::base::tracing
-    iree::hal
-    iree::schemas::metal_executable_def_c_fbs
-  LINKOPTS
-    "-framework Foundation"
-    "-framework Metal"
-  PUBLIC
-)
diff --git a/iree/hal/metal/README.md b/iree/hal/metal/README.md
deleted file mode 100644
index 301d026b519e5..0000000000000
--- a/iree/hal/metal/README.md
+++ /dev/null
@@ -1,5 +0,0 @@
-# Metal HAL Driver
-
-This directory contains the source code for the Metal HAL driver. See the
-[design doc](https://google.github.io/iree/design-docs/metal-hal-driver) for
-more details.
diff --git a/iree/hal/metal/dispatch_time_util.h b/iree/hal/metal/dispatch_time_util.h
deleted file mode 100644
index 6023d38e65461..0000000000000
--- a/iree/hal/metal/dispatch_time_util.h
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_METAL_APPLE_TIME_UTIL_H_
-#define IREE_HAL_METAL_APPLE_TIME_UTIL_H_
-
-#include <dispatch/dispatch.h>
-
-#include "iree/base/time.h"
-
-namespace iree {
-namespace hal {
-namespace metal {
-
-// Converts a relative iree::Duration against the current time to the
-// corresponding dispatch_time_t value.
-static inline dispatch_time_t DurationToDispatchTime(Duration duration_ns) {
-  if (duration_ns == InfiniteDuration()) return DISPATCH_TIME_FOREVER;
-  if (duration_ns == ZeroDuration()) return DISPATCH_TIME_NOW;
-  return dispatch_time(DISPATCH_TIME_NOW, static_cast<int64_t>(duration_ns));
-}
-
-// Converts an absolute iree::Time to the corresponding dispatch_time_t
-// value.
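// [Editor's note] The net effect of the two helpers, restated as a plain C
// sketch (libdispatch's dispatch_time is a C API; now_ns stands in for the
// current clock reading and is an assumption):
//   dispatch_time_t t;
//   if (deadline_ns == INT64_MAX)    t = DISPATCH_TIME_FOREVER;  // infinite
//   else if (deadline_ns <= now_ns)  t = DISPATCH_TIME_NOW;      // expired
//   else t = dispatch_time(DISPATCH_TIME_NOW, deadline_ns - now_ns);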
-static inline dispatch_time_t DeadlineToDispatchTime(Time deadline_ns) {
-  return DurationToDispatchTime(DeadlineToRelativeTimeoutNanos(deadline_ns));
-}
-
-}  // namespace metal
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_METAL_APPLE_TIME_UTIL_H_
diff --git a/iree/hal/metal/metal_buffer.h b/iree/hal/metal/metal_buffer.h
deleted file mode 100644
index dd62e8ec8ccd2..0000000000000
--- a/iree/hal/metal/metal_buffer.h
+++ /dev/null
@@ -1,103 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_METAL_METAL_BUFFER_H_
-#define IREE_HAL_METAL_METAL_BUFFER_H_
-
-#import <Metal/Metal.h>
-
-#include "iree/hal/buffer.h"
-
-namespace iree {
-namespace hal {
-namespace metal {
-
-class MetalDirectAllocator;
-
-// A buffer implementation for Metal that directly wraps a MTLBuffer.
-class MetalBuffer final : public Buffer {
- public:
-  // Creates a MetalBuffer instance with retaining the given id<MTLBuffer>.
-  static StatusOr<ref_ptr<MetalBuffer>> Create(
-      MetalDirectAllocator* allocator, MemoryTypeBitfield memory_type,
-      MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
-      device_size_t allocation_size, device_size_t byte_offset,
-      device_size_t byte_length, id<MTLBuffer> buffer,
-      id<MTLCommandQueue> transfer_queue);
-
-  // Creates a MetalBuffer instance without retaining the given id<MTLBuffer>.
-  static StatusOr<ref_ptr<MetalBuffer>> CreateUnretained(
-      MetalDirectAllocator* allocator, MemoryTypeBitfield memory_type,
-      MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
-      device_size_t allocation_size, device_size_t byte_offset,
-      device_size_t byte_length, id<MTLBuffer> buffer,
-      id<MTLCommandQueue> transfer_queue);
-
-  ~MetalBuffer() override;
-
-  id<MTLBuffer> handle() const { return metal_handle_; }
-
- private:
-  // Creates a MetalBuffer instance without retaining the given id<MTLBuffer>.
-  MetalBuffer(MetalDirectAllocator* allocator, MemoryTypeBitfield memory_type,
-              MemoryAccessBitfield allowed_access, BufferUsageBitfield usage,
-              device_size_t allocation_size, device_size_t byte_offset,
-              device_size_t byte_length, id<MTLBuffer> buffer,
-              id<MTLCommandQueue> transfer_queue);
-
-  Status FillImpl(device_size_t byte_offset, device_size_t byte_length,
-                  const void* pattern, device_size_t pattern_length) override;
-  Status ReadDataImpl(device_size_t source_offset, void* data,
-                      device_size_t data_length) override;
-  Status WriteDataImpl(device_size_t target_offset, const void* data,
-                       device_size_t data_length) override;
-  Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer,
-                      device_size_t source_offset,
-                      device_size_t data_length) override;
-
-  Status MapMemoryImpl(MappingMode mapping_mode,
-                       MemoryAccessBitfield memory_access,
-                       device_size_t local_byte_offset,
-                       device_size_t local_byte_length,
-                       void** out_data) override;
-  Status UnmapMemoryImpl(device_size_t local_byte_offset,
-                         device_size_t local_byte_length, void* data) override;
-  Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset,
-                                    device_size_t local_byte_length) override;
-  Status FlushMappedMemoryImpl(device_size_t local_byte_offset,
-                               device_size_t local_byte_length) override;
-
-  // Returns true if we need to automatically invalidate/flush CPU caches to
-  // keep the memory hierarchy consistent.
-  //
-  // Note: this is needed when the buffer is requested with the
-  // MemoryType::kHostCoherent bit but under the hood we are using memory types
-  // that do not have that property natively, e.g., MTLStorageModeManaged.
-  // Under such circumstances, we need to perform the invalidate/flush
-  // operation "automatically" for users.
-  bool requires_autosync() const;
-
-  // We need to hold a reference to the queue so that we can encode
-  // synchronizeResource commands for synchronizing the buffer with
-  // MTLResourceStorageModeManaged.
-  id<MTLCommandQueue> metal_transfer_queue_;
-
-  id<MTLBuffer> metal_handle_;
-};
-
-}  // namespace metal
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_METAL_METAL_BUFFER_H_
diff --git a/iree/hal/metal/metal_buffer.mm b/iree/hal/metal/metal_buffer.mm
deleted file mode 100644
index 55533b6843bd1..0000000000000
--- a/iree/hal/metal/metal_buffer.mm
+++ /dev/null
@@ -1,203 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
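// [Editor's sketch] The autosync rule implemented at the bottom of this file,
// restated as a tiny C predicate (the bool inputs are assumptions standing in
// for the real bitfield and storage-mode queries):
//   static bool requires_autosync(bool wants_host_coherent, bool is_managed) {
//     return wants_host_coherent && is_managed;  // managed => macOS only
//   }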
- -#include "iree/hal/metal/metal_buffer.h" - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/metal/metal_direct_allocator.h" - -namespace iree { -namespace hal { -namespace metal { - -// static -StatusOr> MetalBuffer::Create( - MetalDirectAllocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, device_size_t allocation_size, - device_size_t byte_offset, device_size_t byte_length, id buffer, - id transfer_queue) { - IREE_TRACE_SCOPE0("MetalBuffer::Create"); - return assign_ref(new MetalBuffer(allocator, memory_type, allowed_access, usage, allocation_size, - byte_offset, byte_length, [buffer retain], transfer_queue)); -} - -// static -StatusOr> MetalBuffer::CreateUnretained( - MetalDirectAllocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, device_size_t allocation_size, - device_size_t byte_offset, device_size_t byte_length, id buffer, - id transfer_queue) { - IREE_TRACE_SCOPE0("MetalBuffer::Create"); - return assign_ref(new MetalBuffer(allocator, memory_type, allowed_access, usage, allocation_size, - byte_offset, byte_length, buffer, transfer_queue)); -} - -MetalBuffer::MetalBuffer(MetalDirectAllocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size, device_size_t byte_offset, - device_size_t byte_length, id buffer, - id transfer_queue) - : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, byte_offset, - byte_length), - metal_transfer_queue_([transfer_queue retain]), - metal_handle_(buffer) {} - -MetalBuffer::~MetalBuffer() { - IREE_TRACE_SCOPE0("MetalBuffer::dtor"); - [metal_handle_ release]; - [metal_transfer_queue_ release]; -} - -Status MetalBuffer::FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) { - IREE_ASSIGN_OR_RETURN(auto mapping, - MapMemory(MemoryAccess::kDiscardWrite, byte_offset, byte_length)); - void* data_ptr = static_cast(mapping.mutable_data()); - switch (pattern_length) { - case 1: { - uint8_t* data = static_cast(data_ptr); - uint8_t value_bits = *static_cast(pattern); - std::fill_n(data, byte_length, value_bits); - break; - } - case 2: { - uint16_t* data = static_cast(data_ptr); - uint16_t value_bits = *static_cast(pattern); - std::fill_n(data, byte_length / sizeof(uint16_t), value_bits); - break; - } - case 4: { - uint32_t* data = static_cast(data_ptr); - uint32_t value_bits = *static_cast(pattern); - std::fill_n(data, byte_length / sizeof(uint32_t), value_bits); - break; - } - default: - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported scalar data size: " << pattern_length; - } - return OkStatus(); -} - -Status MetalBuffer::ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) { - IREE_ASSIGN_OR_RETURN(auto mapping, - MapMemory(MemoryAccess::kRead, source_offset, data_length)); - std::memcpy(data, mapping.data(), mapping.byte_length()); - return OkStatus(); -} - -Status MetalBuffer::WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) { - IREE_ASSIGN_OR_RETURN( - auto mapping, MapMemory(MemoryAccess::kDiscardWrite, target_offset, data_length)); - std::memcpy(mapping.mutable_data(), data, mapping.byte_length()); - return OkStatus(); -} - -Status MetalBuffer::CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, - device_size_t 
source_offset, device_size_t data_length) { - // This is pretty terrible. Let's not do this. - // TODO(benvanik): a way for allocators to indicate transfer compat. - IREE_ASSIGN_OR_RETURN(auto source_mapping, source_buffer->MapMemory( - MemoryAccess::kRead, source_offset, data_length)); - IREE_CHECK_EQ(data_length, source_mapping.size()); - IREE_ASSIGN_OR_RETURN(auto target_mapping, MapMemory(MemoryAccess::kDiscardWrite, - target_offset, data_length)); - IREE_CHECK_EQ(data_length, target_mapping.size()); - std::memcpy(target_mapping.mutable_data(), source_mapping.data(), data_length); - return OkStatus(); -} - -Status MetalBuffer::MapMemoryImpl(MappingMode mapping_mode, MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, device_size_t local_byte_length, - void** out_data) { - uint8_t* data_ptr = reinterpret_cast([metal_handle_ contents]); - *out_data = data_ptr + local_byte_offset; - - // If we mapped for discard scribble over the bytes. This is not a mandated - // behavior but it will make debugging issues easier. Alternatively for - // heap buffers we could reallocate them such that ASAN yells, but that - // would only work if the entire buffer was discarded. -#ifndef NDEBUG - if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) { - std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); - } -#endif // !NDEBUG - - if (requires_autosync()) { - IREE_RETURN_IF_ERROR(InvalidateMappedMemoryImpl(local_byte_offset, local_byte_length)); - } - - return OkStatus(); -} - -Status MetalBuffer::UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) { - if (requires_autosync()) { - IREE_RETURN_IF_ERROR(FlushMappedMemoryImpl(local_byte_offset, local_byte_length)); - } - - return OkStatus(); -} - -Status MetalBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { -#ifdef IREE_PLATFORM_MACOS - // The following is only necessary for MTLStorageManaged. - if (metal_handle_.storageMode == MTLStorageModeManaged) { - @autoreleasepool { - id command_buffer = - [metal_transfer_queue_ commandBufferWithUnretainedReferences]; - - id blit_encoder = [command_buffer blitCommandEncoder]; - [blit_encoder synchronizeResource:metal_handle_]; - [blit_encoder endEncoding]; - - [command_buffer commit]; - [command_buffer waitUntilCompleted]; - } - } -#endif - - return OkStatus(); -} - -Status MetalBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { -#ifdef IREE_PLATFORM_MACOS - // The following is only necessary for MTLStorageManaged. - if (metal_handle_.storageMode == MTLStorageModeManaged) { - [metal_handle_ didModifyRange:NSMakeRange(local_byte_offset, local_byte_length)]; - } -#endif - - return OkStatus(); -} - -bool MetalBuffer::requires_autosync() const { - // We only need to perform "automatic" resource synchronization if it's MTLStorageModeManaged, - // which is only available on macOS. 
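  // [Editor's note] For orientation: InvalidateMappedMemoryImpl above is the
  // GPU->CPU direction (an encoded synchronizeResource: blit plus a blocking
  // wait), while FlushMappedMemoryImpl is CPU->GPU (didModifyRange: on the
  // dirty span, no GPU round-trip). Both only matter when this predicate
  // returns true.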
-#ifdef IREE_PLATFORM_MACOS
-  return AllBitsSet(memory_type(), MemoryType::kHostCoherent) &&
-         metal_handle_.storageMode == MTLStorageModeManaged;
-#else
-  return false;
-#endif
-}
-
-}  // namespace metal
-}  // namespace hal
-}  // namespace iree
diff --git a/iree/hal/metal/metal_capture_manager.h b/iree/hal/metal/metal_capture_manager.h
deleted file mode 100644
index 90b8ac6401bab..0000000000000
--- a/iree/hal/metal/metal_capture_manager.h
+++ /dev/null
@@ -1,65 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_METAL_METAL_CAPTURE_MANAGER_H_
-#define IREE_HAL_METAL_METAL_CAPTURE_MANAGER_H_
-
-#include <string>
-
-#import <Metal/Metal.h>
-
-#include "iree/base/status.h"
-#include "iree/hal/debug_capture_manager.h"
-
-namespace iree {
-namespace hal {
-namespace metal {
-
-// A DebugCaptureManager implementation for Metal that directly wraps a
-// MTLCaptureManager.
-class MetalCaptureManager final : public DebugCaptureManager {
- public:
-  // Creates a capture manager that captures Metal commands to the given |capture_file| if not
-  // empty. Captures to Xcode otherwise.
-  static StatusOr<std::unique_ptr<MetalCaptureManager>> Create(const std::string& capture_file);
-  ~MetalCaptureManager() override;
-
-  Status Connect() override;
-
-  void Disconnect() override;
-
-  bool is_connected() const override;
-
-  void SetCaptureObject(id object);
-
-  void StartCapture() override;
-
-  void StopCapture() override;
-
-  bool is_capturing() const override;
-
- private:
-  explicit MetalCaptureManager(NSURL* capture_file);
-
-  MTLCaptureManager* metal_handle_ = nil;
-  // The path for storing the .gputrace file. Empty means capturing to Xcode.
-  NSURL* capture_file_ = nil;
-  id capture_object_ = nil;
-};
-
-}  // namespace metal
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_METAL_METAL_CAPTURE_MANAGER_H_
diff --git a/iree/hal/metal/metal_capture_manager.mm b/iree/hal/metal/metal_capture_manager.mm
deleted file mode 100644
index 4437951f1cca6..0000000000000
--- a/iree/hal/metal/metal_capture_manager.mm
+++ /dev/null
@@ -1,128 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
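// [Editor's note] Intended lifecycle, per the header above: Create(path) ->
// Connect() -> SetCaptureObject(device) -> StartCapture() ... StopCapture()
// -> Disconnect(). An empty |capture_file| routes the capture to Xcode
// instead of a .gputrace document.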
- -#include "iree/hal/metal/metal_capture_manager.h" - -#include - -#include "absl/memory/memory.h" -#include "iree/base/file_io.h" -#include "iree/base/logging.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { -namespace metal { - -// static -StatusOr> MetalCaptureManager::Create( - const std::string& capture_file) { - IREE_TRACE_SCOPE0("MetalCaptureManager::Create"); - @autoreleasepool { - NSURL* capture_url = nil; - if (!capture_file.empty()) { - NSString* ns_string = [NSString stringWithCString:capture_file.c_str() - encoding:[NSString defaultCStringEncoding]]; - NSString* capture_path = ns_string.stringByStandardizingPath; - capture_url = [[NSURL fileURLWithPath:capture_path isDirectory:false] retain]; - } - return absl::WrapUnique(new MetalCaptureManager(capture_url)); - } -} - -MetalCaptureManager::MetalCaptureManager(NSURL* capture_file) : capture_file_(capture_file) {} - -MetalCaptureManager::~MetalCaptureManager() { - IREE_TRACE_SCOPE0("MetalCaptureManager::dtor"); - Disconnect(); - if (capture_file_) [capture_file_ release]; -} - -Status MetalCaptureManager::Connect() { - IREE_TRACE_SCOPE0("MetalCaptureManager::Connect"); - - if (metal_handle_) return OkStatus(); - - @autoreleasepool { - metal_handle_ = [[MTLCaptureManager sharedCaptureManager] retain]; - - if (capture_file_ && - [metal_handle_ supportsDestination:MTLCaptureDestinationGPUTraceDocument]) { - IREE_LOG(INFO) << "Connected to shared Metal capture manager; writing capture to " - << std::string([capture_file_.absoluteString UTF8String]); - } else { - IREE_LOG(INFO) << "Connected to shared Metal capture manager; capturing to Xcode"; - } - } - - return OkStatus(); -} - -void MetalCaptureManager::Disconnect() { - IREE_TRACE_SCOPE0("MetalCaptureManager::Disconnect"); - - if (!metal_handle_) return; - - if (is_capturing()) StopCapture(); - - [metal_handle_ release]; - metal_handle_ = nil; -} - -bool MetalCaptureManager::is_connected() const { return metal_handle_ != nil; } - -void MetalCaptureManager::SetCaptureObject(id object) { capture_object_ = object; } - -void MetalCaptureManager::StartCapture() { - IREE_TRACE_SCOPE0("MetalCaptureManager::StartCapture"); - - IREE_CHECK(is_connected()) << "Can't start capture when not connected"; - IREE_CHECK(!is_capturing()) << "Capture is already started"; - IREE_CHECK(capture_object_) << "Must set capture object before starting"; - - IREE_LOG(INFO) << "Starting Metal capture"; - @autoreleasepool { - MTLCaptureDescriptor* capture_descriptor = [[[MTLCaptureDescriptor alloc] init] autorelease]; - capture_descriptor.captureObject = capture_object_; - if (capture_file_) { - capture_descriptor.destination = MTLCaptureDestinationGPUTraceDocument; - capture_descriptor.outputURL = capture_file_; - } else { - capture_descriptor.destination = MTLCaptureDestinationDeveloperTools; - } - - NSError* error; - if (![metal_handle_ startCaptureWithDescriptor:capture_descriptor error:&error]) { - NSLog(@"Failed to start capture, error %@", error); - } - } -} - -void MetalCaptureManager::StopCapture() { - IREE_TRACE_SCOPE0("MetalCaptureManager::StopCapture"); - - IREE_CHECK(is_capturing()) << "Can't stop capture when not capturing"; - - IREE_LOG(INFO) << "Ending Metal capture"; - [metal_handle_ stopCapture]; -} - -bool MetalCaptureManager::is_capturing() const { - if (!is_connected()) return false; - return metal_handle_.isCapturing; -} - -} // namespace metal -} // namespace hal -} // namespace iree diff --git a/iree/hal/metal/metal_command_buffer.h 
b/iree/hal/metal/metal_command_buffer.h deleted file mode 100644 index 06973a2616d62..0000000000000 --- a/iree/hal/metal/metal_command_buffer.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_METAL_METAL_COMMAND_BUFFER_H_ -#define IREE_HAL_METAL_METAL_COMMAND_BUFFER_H_ - -#import - -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "iree/hal/command_buffer.h" -#include "iree/hal/metal/metal_buffer.h" - -namespace iree { -namespace hal { -namespace metal { - -// A command buffer implementation for Metal that directly wraps a -// MTLCommandBuffer. -// -// Objects of this class are not expected to be accessed by multiple threads. -class MetalCommandBuffer final : public CommandBuffer { - public: - static StatusOr> Create( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories, - id command_buffer); - ~MetalCommandBuffer() override; - - id handle() const { return metal_handle_; } - - bool is_recording() const override { return is_recording_; } - - Status Begin() override; - Status End() override; - - Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) override; - - Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - Status WaitEvents(absl::Span events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) override; - - Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) override; - Status DiscardBuffer(Buffer* buffer) override; - Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status PushConstants(ExecutableLayout* executable_layout, size_t offset, - absl::Span values) override; - - Status PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - absl::Span bindings) override; - Status BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span dynamic_offsets) override; - - Status Dispatch(Executable* executable, int32_t entry_point, - std::array workgroups) override; - Status DispatchIndirect(Executable* executable, int32_t entry_point, - Buffer* workgroups_buffer, - device_size_t workgroups_offset) override; - - private: - // A struct containing all resources states of the current pipeline. 
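// [Editor's note] Each map below is keyed by descriptor set number; Metal has
// no native descriptor-set concept, so the recorded state is replayed into
// argument buffers when Dispatch() encodes the pipeline (see
// metal_command_buffer.mm).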
- struct PipelineStateObject { - struct PushState { - absl::InlinedVector resource_bindings; - }; - // Map from set number to push descriptor states - absl::flat_hash_map push_states; - - struct BindState { - DescriptorSet* descriptor_set; - }; - // Map from set number to bind descriptor states - absl::flat_hash_map bind_states; - - struct ConstantState { - absl::InlinedVector values; - }; - // Map from set number to push constant states - absl::flat_hash_map constant_states; - }; - - MetalCommandBuffer(CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories, - id command_buffer); - - StatusOr CastBuffer(Buffer* buffer) const; - - // Gets or begins an active MTLBlitCommandEncoder. This also ends all previous - // encoded compute commands if any. - id GetOrBeginBlitEncoder(); - void EndBlitEncoder(); - - // Gets or begins a new MTLComputeCommandEncoder. This also ends all previous - // encoded blit commands if any. - id GetOrBeginComputeEncoder(); - void EndComputeEncoder(); - - private: - bool is_recording_ = false; - id metal_handle_; - - id current_compute_encoder_ = nil; - id current_blit_encoder_ = nil; - - absl::flat_hash_map - pipeline_state_objects_; -}; - -} // namespace metal -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_METAL_METAL_COMMAND_BUFFER_H_ diff --git a/iree/hal/metal/metal_command_buffer.mm b/iree/hal/metal/metal_command_buffer.mm deleted file mode 100644 index b40369fce5c6b..0000000000000 --- a/iree/hal/metal/metal_command_buffer.mm +++ /dev/null @@ -1,389 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
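// [Editor's note] The ExecutionBarrier translation in this file is coarse by
// design: any memory barrier becomes a device-wide
// memoryBarrierWithScope:MTLBarrierScopeBuffers, while buffer-only barriers
// pass the affected MTLBuffer handles to memoryBarrierWithResources:count:.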
- -#include "iree/hal/metal/metal_command_buffer.h" - -#include "iree/base/logging.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/metal/metal_kernel_library.h" -#include "iree/hal/metal/metal_pipeline_argument_buffer.h" - -namespace iree { -namespace hal { -namespace metal { - -namespace { - -MTLResourceUsage ConvertResourceUsage(MemoryAccessBitfield memory_access) { - MTLResourceUsage usage = 0; - if (AllBitsSet(memory_access, MemoryAccess::kRead)) usage |= MTLResourceUsageRead; - if (AllBitsSet(memory_access, MemoryAccess::kWrite)) usage |= MTLResourceUsageWrite; - return usage; -} - -} // namespace - -// static -StatusOr> MetalCommandBuffer::Create( - CommandBufferModeBitfield mode, CommandCategoryBitfield command_categories, - id command_buffer) { - return assign_ref(new MetalCommandBuffer(mode, command_categories, command_buffer)); -} - -MetalCommandBuffer::MetalCommandBuffer(CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories, - id command_buffer) - : CommandBuffer(mode, command_categories), metal_handle_([command_buffer retain]) { - metal_handle_.label = @"IREE MetalCommandBuffer"; -} - -MetalCommandBuffer::~MetalCommandBuffer() { - IREE_TRACE_SCOPE0("MetalCommandBuffer::dtor"); - [metal_handle_ release]; -} - -StatusOr MetalCommandBuffer::CastBuffer(Buffer* buffer) const { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return static_cast(buffer->allocated_buffer()); -} - -id MetalCommandBuffer::GetOrBeginBlitEncoder() { - IREE_TRACE_SCOPE0("MetalCommandBuffer::GetOrBeginBlitEncoder"); - - if (current_compute_encoder_) EndComputeEncoder(); - - @autoreleasepool { - if (!current_blit_encoder_) { - current_blit_encoder_ = [[metal_handle_ blitCommandEncoder] retain]; - } - } - - return current_blit_encoder_; -} - -void MetalCommandBuffer::EndBlitEncoder() { - IREE_TRACE_SCOPE0("MetalCommandBuffer::EndBlitEncoder"); - if (current_blit_encoder_) { - [current_blit_encoder_ endEncoding]; - [current_blit_encoder_ release]; - current_blit_encoder_ = nil; - } -} - -id MetalCommandBuffer::GetOrBeginComputeEncoder() { - IREE_TRACE_SCOPE0("MetalCommandBuffer::GetOrBeginComputeEncoder"); - - if (current_blit_encoder_) EndBlitEncoder(); - - @autoreleasepool { - if (!current_compute_encoder_) { - current_compute_encoder_ = [[metal_handle_ computeCommandEncoder] retain]; - } - } - - return current_compute_encoder_; -} - -void MetalCommandBuffer::EndComputeEncoder() { - IREE_TRACE_SCOPE0("MetalCommandBuffer::EndComputeEncoder"); - if (current_compute_encoder_) { - [current_compute_encoder_ endEncoding]; - [current_compute_encoder_ release]; - current_compute_encoder_ = nil; - } -} - -Status MetalCommandBuffer::Begin() { - IREE_TRACE_SCOPE0("MetalCommandBuffer::Begin"); - is_recording_ = true; - return OkStatus(); -} - -Status MetalCommandBuffer::End() { - IREE_TRACE_SCOPE0("MetalCommandBuffer::End"); - EndBlitEncoder(); - EndComputeEncoder(); - is_recording_ = false; - return OkStatus(); -} - -Status MetalCommandBuffer::ExecutionBarrier(ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::ExecutionBarrier"); - - if (AllBitsSet(source_stage_mask, ExecutionStage::kHost) || - AllBitsSet(target_stage_mask, ExecutionStage::kHost)) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::ExecutionBarrier with 
host bit set"; - } - - // If there is a memory barrier specified, we have to place a catch-all barrier for all buffers. - // Metal does not provide a more fine-grained control here. But we do have the option to specify a - // list of buffers to synchronize if only buffer barriers are specified. - if (!memory_barriers.empty()) { - [GetOrBeginComputeEncoder() memoryBarrierWithScope:MTLBarrierScopeBuffers]; - } else if (!buffer_barriers.empty()) { - std::vector> buffers; - buffers.reserve(buffer_barriers.size()); - for (const auto& barrier : buffer_barriers) { - buffers.push_back(static_cast(barrier.buffer)->handle()); - } - [GetOrBeginComputeEncoder() memoryBarrierWithResources:buffers.data() count:buffers.size()]; - } - - return OkStatus(); -} - -Status MetalCommandBuffer::SignalEvent(Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::SignalEvent"); - return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::SignalEvent"; -} - -Status MetalCommandBuffer::ResetEvent(Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::ResetEvent"); - return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::ResetEvent"; -} - -Status MetalCommandBuffer::WaitEvents(absl::Span events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::WaitEvents"); - return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::WaitEvents"; -} - -Status MetalCommandBuffer::FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::FillBuffer"); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); - - target_offset += target_buffer->byte_offset(); - - // Per the spec for fillBuffer:range:value: "The alignment and length of the range must both be a - // multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS." Although iOS/tvOS is more relaxed on - // this front, we still require 4-byte alignment for uniformity across IREE. - if (target_offset % 4 != 0) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::FillBuffer with offset that is not a multiple of 4 bytes"; - } - - // Note that fillBuffer:range:value: only accepts a single byte as the pattern but FillBuffer - // can accept 1/2/4 bytes. If the pattern itself contains repeated bytes, we can call into - // fillBuffer:range:value:. Otherwise we may need to find another way. Just implement the case - // where we have a single byte to fill for now. - if (pattern_length != 1) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::FillBuffer with non-1-byte pattern"; - } - uint8_t byte_pattern = *reinterpret_cast(pattern); - - [GetOrBeginBlitEncoder() fillBuffer:target_device_buffer->handle() - range:NSMakeRange(target_offset, length) - value:byte_pattern]; - - return OkStatus(); -} - -Status MetalCommandBuffer::DiscardBuffer(Buffer* buffer) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::DiscardBuffer"); - // This is a hint. Nothing to do for Metal. 
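// [Editor's note] On the FillBuffer limitation above: a 2- or 4-byte pattern
// whose bytes all match could still lower to fillBuffer:range:value:. A
// hypothetical gate (assumed helper, not in the original):
//   static bool pattern_is_splattable(const uint8_t* p, size_t len) {
//     for (size_t i = 1; i < len; ++i) {
//       if (p[i] != p[0]) return false;
//     }
//     return true;  // one repeated byte -> fillBuffer:range:value: works
//   }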
- return OkStatus(); -} - -Status MetalCommandBuffer::UpdateBuffer(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::UpdateBuffer"); - return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::UpdateBuffer"; -} - -Status MetalCommandBuffer::CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::CopyBuffer"); - - IREE_ASSIGN_OR_RETURN(auto* source_device_buffer, CastBuffer(source_buffer)); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); - - source_offset += source_buffer->byte_offset(); - target_offset += target_buffer->byte_offset(); - - // Per the spec for copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size, the source/target - // offset must be a multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS. Although iOS/tvOS - // is more relaxed on this front, we still require 4-byte alignment for uniformity across IREE. - if (source_offset % 4 != 0 || target_offset % 4 != 0) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::CopyBuffer with offset that is not a multiple of 4 bytes"; - } - - [GetOrBeginBlitEncoder() copyFromBuffer:source_device_buffer->handle() - sourceOffset:source_offset - toBuffer:target_device_buffer->handle() - destinationOffset:target_offset - size:length]; - - return OkStatus(); -} - -Status MetalCommandBuffer::PushConstants(ExecutableLayout* executable_layout, size_t offset, - absl::Span values) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::PushConstants"); - return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::PushConstants"; -} - -Status MetalCommandBuffer::PushDescriptorSet(ExecutableLayout* executable_layout, int32_t set, - absl::Span bindings) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::PushDescriptorSet"); - if (set != 0) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::PushDescriptorSet with set number > 0"; - } - auto& push_state = pipeline_state_objects_[executable_layout].push_states[set]; - push_state.resource_bindings.assign(bindings.begin(), bindings.end()); - return OkStatus(); -} - -Status MetalCommandBuffer::BindDescriptorSet(ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span dynamic_offsets) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::BindDescriptorSet"); - if (set != 0) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::BindDescriptorSet with set number > 0"; - } - if (!dynamic_offsets.empty()) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::BindDescriptorSet with dynamic offsets"; - } - pipeline_state_objects_[executable_layout].bind_states[set].descriptor_set = descriptor_set; - return OkStatus(); -} - -Status MetalCommandBuffer::Dispatch(Executable* executable, int32_t entry_point, - std::array workgroups) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::Dispatch"); - IREE_DVLOG(2) << "MetalCommandBuffer::Dispatch"; - - auto* kernel_library = static_cast(executable); - IREE_ASSIGN_OR_RETURN(auto metal_kernel, kernel_library->GetKernelForEntryPoint(entry_point)); - IREE_ASSIGN_OR_RETURN(auto metal_pso, kernel_library->GetPipelineStateForEntryPoint(entry_point)); - IREE_ASSIGN_OR_RETURN(auto workgroup_size, - kernel_library->GetThreadgroupSizeForEntryPoint(entry_point)); - - id compute_encoder = 
GetOrBeginComputeEncoder(); - [compute_encoder setComputePipelineState:metal_pso]; - - // TODO(antiagainst): only update the PSO for the current executable. - for (const auto& pso_kv : pipeline_state_objects_) { - const auto* pipeline_layout = static_cast(pso_kv.first); - IREE_DVLOG(3) << "Current pipeline layout: " << pipeline_layout->DebugString(); - - const auto& pso = pso_kv.second; - if (pso.push_states.size() > 1) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::Dispatch with more than one push descriptor sets"; - } - if (!pso.bind_states.empty()) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::Dispatch with bound descriptor sets"; - } - if (!pso.constant_states.empty()) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::Dispatch with push constants"; - } - - IREE_DVLOG(3) << "Encoding push descriptors.."; - for (const auto& push_kv : pso.push_states) { - int32_t set_number = push_kv.first; - const PipelineStateObject::PushState& push_state = push_kv.second; - IREE_DVLOG(3) << " For set #" << set_number; - - id argument_encoder = - [metal_kernel newArgumentEncoderWithBufferIndex:set_number]; // retained - argument_encoder.label = @"IREE MetalCommandBuffer::Dispatch ArgumentEncoder"; - if (!argument_encoder) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Buffer index #" << set_number << " is not an argument buffer"; - } - - __block id argument_buffer = - [metal_handle_.device newBufferWithLength:argument_encoder.encodedLength - options:MTLResourceStorageModeShared]; // retained - argument_encoder.label = @"IREE MetalCommandBuffer::Dispatch ArgumentBuffer"; - if (!argument_buffer) { - return InternalErrorBuilder(IREE_LOC) - << "Failed to create argument buffer with length=" << argument_encoder.encodedLength; - } - [metal_handle_ addCompletedHandler:^(id) { - [argument_buffer release]; - [argument_encoder release]; - }]; - - [argument_encoder setArgumentBuffer:argument_buffer offset:0]; - - for (const auto& resource_binding : push_state.resource_bindings) { - IREE_DVLOG(3) << " Resource @[" << resource_binding.DebugStringShort() << "]"; - - if (resource_binding.length != kWholeBuffer && - resource_binding.length != resource_binding.buffer->allocation_size()) { - return UnimplementedErrorBuilder(IREE_LOC) - << "MetalCommandBuffer::Dispatch with sub-buffer"; - } - - IREE_ASSIGN_OR_RETURN(auto buffer, CastBuffer(resource_binding.buffer)); - [argument_encoder setBuffer:buffer->handle() - offset:resource_binding.offset - atIndex:resource_binding.binding]; - - const auto* set_layout = pipeline_layout->set_layouts()[set_number]; - const auto* layout_binding = set_layout->GetBindingForIndex(resource_binding.binding); - if (!layout_binding) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Cannot find binding #" << resource_binding.binding - << " in argument buffer layout"; - } - [compute_encoder useResource:buffer->handle() - usage:ConvertResourceUsage(layout_binding->access)]; - } - - [compute_encoder setBuffer:argument_buffer offset:0 atIndex:set_number]; - } - } - - IREE_DVLOG(2) << "Dispatch workgroup count: (" << workgroups[0] << ", " << workgroups[1] << ", " - << workgroups[2] << "), workgroup size: (" << workgroup_size.x << ", " - << workgroup_size.y << ", " << workgroup_size.z << ")"; - [compute_encoder - dispatchThreadgroups:MTLSizeMake(workgroups[0], workgroups[1], workgroups[2]) - threadsPerThreadgroup:MTLSizeMake(workgroup_size.x, workgroup_size.y, workgroup_size.z)]; - - return OkStatus(); -} - 
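// [Editor's note] DispatchIndirect below is stubbed out, though Metal exposes
// a matching primitive; a hedged sketch (assuming the workgroup buffer holds
// a MTLDispatchThreadgroupsIndirectArguments struct and size is the entry
// point's threadgroup size) would be roughly:
//   [compute_encoder
//       dispatchThreadgroupsWithIndirectBuffer:workgroups_device_buffer->handle()
//                         indirectBufferOffset:workgroups_offset
//                        threadsPerThreadgroup:MTLSizeMake(size.x, size.y, size.z)];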
-Status MetalCommandBuffer::DispatchIndirect(Executable* executable, int32_t entry_point, - Buffer* workgroups_buffer, - device_size_t workgroups_offset) { - IREE_TRACE_SCOPE0("MetalCommandBuffer::DispatchIndirect"); - return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::DispatchIndirect"; -} - -} // namespace metal -} // namespace hal -} // namespace iree diff --git a/iree/hal/metal/metal_command_queue.h b/iree/hal/metal/metal_command_queue.h deleted file mode 100644 index caf43f0dd7255..0000000000000 --- a/iree/hal/metal/metal_command_queue.h +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_METAL_METAL_COMMAND_QUEUE_H_ -#define IREE_HAL_METAL_METAL_COMMAND_QUEUE_H_ - -#import - -#include "iree/base/arena.h" -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/command_queue.h" - -namespace iree { -namespace hal { -namespace metal { - -// A command queue implementation for Metal that directly wraps a -// MTLCommandQueue. -// -// Thread-safe. -class MetalCommandQueue final : public CommandQueue { - public: - MetalCommandQueue(std::string name, - CommandCategoryBitfield supported_categories, - id queue); - ~MetalCommandQueue() override; - - id handle() const { return metal_handle_; } - - Status Submit(absl::Span batches) override; - - Status WaitIdle(Time deadline_ns) override; - - private: - id metal_handle_; -}; - -} // namespace metal -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_METAL_METAL_COMMAND_QUEUE_H_ diff --git a/iree/hal/metal/metal_command_queue.mm b/iree/hal/metal/metal_command_queue.mm deleted file mode 100644 index 3f1b6e261e938..0000000000000 --- a/iree/hal/metal/metal_command_queue.mm +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
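// [Editor's note] Submit below encodes each batch as up to three pieces in
// commit order: a wait command buffer (encodeWaitForEvent: per wait
// semaphore), the user command buffers, then a signal command buffer
// (encodeSignalEvent: per signal semaphore). The queue's commit ordering is
// what keeps the three-way split correct.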
- -#include "iree/hal/metal/metal_command_queue.h" - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/metal/dispatch_time_util.h" -#include "iree/hal/metal/metal_command_buffer.h" -#include "iree/hal/metal/metal_shared_event.h" - -namespace iree { -namespace hal { -namespace metal { - -MetalCommandQueue::MetalCommandQueue(std::string name, CommandCategoryBitfield supported_categories, - id queue) - : CommandQueue(std::move(name), supported_categories), metal_handle_([queue retain]) { - metal_handle_.label = @"IREE MetalQueue"; -} - -MetalCommandQueue::~MetalCommandQueue() { [metal_handle_ release]; } - -Status MetalCommandQueue::Submit(absl::Span batches) { - IREE_TRACE_SCOPE0("MetalCommandQueue::Submit"); - for (const auto& batch : batches) { - @autoreleasepool { - // Wait for semaphores blocking this batch. - if (!batch.wait_semaphores.empty()) { - id wait_buffer = [metal_handle_ commandBufferWithUnretainedReferences]; - wait_buffer.label = @"IREE MetalCommandQueue::Submit Wait Semaphore CommandBuffer"; - - for (const auto& semaphore : batch.wait_semaphores) { - auto* event = static_cast(semaphore.semaphore); - [wait_buffer encodeWaitForEvent:event->handle() value:semaphore.value]; - } - [wait_buffer commit]; - } - - // Commit command buffers to the queue. - for (const auto* command_buffer : batch.command_buffers) { - const auto* cmdbuf = static_cast(command_buffer); - [cmdbuf->handle() commit]; - } - - // Signal semaphores advanced by this batch. - if (!batch.signal_semaphores.empty()) { - id signal_buffer = [metal_handle_ commandBufferWithUnretainedReferences]; - signal_buffer.label = @"IREE MetalCommandQueue::Submit Signal Semaphore CommandBuffer"; - - for (const auto& semaphore : batch.signal_semaphores) { - auto* event = static_cast(semaphore.semaphore); - [signal_buffer encodeSignalEvent:event->handle() value:semaphore.value]; - } - [signal_buffer commit]; - } - } - } - return OkStatus(); -} - -Status MetalCommandQueue::WaitIdle(Time deadline_ns) { - IREE_TRACE_SCOPE0("MetalCommandQueue::WaitIdle"); - - dispatch_time_t timeout = DeadlineToDispatchTime(deadline_ns); - - // Submit an empty command buffer and wait for it to complete. That will indicate all previous - // work has completed too. - @autoreleasepool { - id comand_buffer = [metal_handle_ commandBufferWithUnretainedReferences]; - comand_buffer.label = @"IREE MetalCommandQueue::WaitIdle Command Buffer"; - __block dispatch_semaphore_t work_done = dispatch_semaphore_create(0); - [comand_buffer addCompletedHandler:^(id) { - dispatch_semaphore_signal(work_done); - }]; - [comand_buffer commit]; - long timed_out = dispatch_semaphore_wait(work_done, timeout); - dispatch_release(work_done); - if (timed_out) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for dispatch_semaphore_t"; - } - return OkStatus(); - } -} - -} // namespace metal -} // namespace hal -} // namespace iree diff --git a/iree/hal/metal/metal_device.h b/iree/hal/metal/metal_device.h deleted file mode 100644 index f183efa39fbf7..0000000000000 --- a/iree/hal/metal/metal_device.h +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_METAL_METAL_DEVICE_H_ -#define IREE_HAL_METAL_METAL_DEVICE_H_ - -#import - -#include - -#include "absl/types/span.h" -#include "iree/base/memory.h" -#include "iree/hal/allocator.h" -#include "iree/hal/debug_capture_manager.h" -#include "iree/hal/device.h" -#include "iree/hal/driver.h" -#include "iree/hal/semaphore.h" - -namespace iree { -namespace hal { -namespace metal { - -// A device implementation for Metal that directly wraps a MTLDevice. -class MetalDevice final : public Device { - public: - // Creates a device that retains the underlying Metal GPU device. - // The DriverDeviceID in |device_info| is expected to be an id. - static StatusOr> Create( - ref_ptr driver, const DeviceInfo& device_info, - DebugCaptureManager* debug_capture_manager); - - ~MetalDevice() override; - - std::string DebugString() const override; - - Allocator* allocator() const override { return allocator_.get(); } - - absl::Span dispatch_queues() const override { - return absl::MakeSpan(&common_queue_, 1); - } - - absl::Span transfer_queues() const override { - return absl::MakeSpan(&common_queue_, 1); - } - - ref_ptr CreateExecutableCache() override; - - StatusOr> CreateDescriptorSetLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span bindings) override; - - StatusOr> CreateExecutableLayout( - absl::Span set_layouts, - size_t push_constants) override; - - StatusOr> CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span bindings) override; - - StatusOr> CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) override; - - StatusOr> CreateEvent() override; - - StatusOr> CreateSemaphore(uint64_t initial_value) override; - Status WaitAllSemaphores(absl::Span semaphores, - Time deadline_ns) override; - StatusOr WaitAnySemaphore(absl::Span semaphores, - Time deadline_ns) override; - - Status WaitIdle(Time deadline_ns) override; - - private: - MetalDevice(ref_ptr driver, const DeviceInfo& device_info, - DebugCaptureManager* debug_capture_manager); - - ref_ptr driver_; - id metal_handle_; - - std::unique_ptr allocator_; - - // Metal does not have clear graphics/dispatch/transfer queue distinction like - // Vulkan; one just use the same newCommandQueue() API call on MTLDevice to - // get command queues. Command encoders differ for different categories of - // commands though. We expose one queue here for everything. This can be - // changed later if more queues prove to be useful. - - std::unique_ptr command_queue_; - mutable CommandQueue* common_queue_ = nullptr; - - // A dispatch queue and associated event listener for running Objective-C - // blocks. This is typically used to wake up threads waiting on some HAL - // semaphore. 
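// [Editor's note] Concretely, WaitAnySemaphore in metal_device.mm registers a
// notifyListener: block on each MTLSharedEvent; the blocks run on this serial
// dispatch queue and post a dispatch_semaphore_t that the calling thread is
// parked on.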
- dispatch_queue_t wait_notifier_; - MTLSharedEventListener* event_listener_; - - DebugCaptureManager* debug_capture_manager_ = nullptr; -}; - -} // namespace metal -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_METAL_METAL_DEVICE_H_ diff --git a/iree/hal/metal/metal_device.mm b/iree/hal/metal/metal_device.mm deleted file mode 100644 index f11d4b81ac511..0000000000000 --- a/iree/hal/metal/metal_device.mm +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/metal/metal_device.h" - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/base/tracing.h" -#include "iree/hal/allocator.h" -#include "iree/hal/command_buffer_validation.h" -#include "iree/hal/metal/dispatch_time_util.h" -#include "iree/hal/metal/metal_capture_manager.h" -#include "iree/hal/metal/metal_command_buffer.h" -#include "iree/hal/metal/metal_command_queue.h" -#include "iree/hal/metal/metal_direct_allocator.h" -#include "iree/hal/metal/metal_pipeline_argument_buffer.h" -#include "iree/hal/metal/metal_pipeline_cache.h" -#include "iree/hal/metal/metal_shared_event.h" - -namespace iree { -namespace hal { -namespace metal { - -// static -StatusOr> MetalDevice::Create(ref_ptr driver, - const DeviceInfo& device_info, - DebugCaptureManager* debug_capture_manager) { - return assign_ref(new MetalDevice(std::move(driver), device_info, debug_capture_manager)); -} - -MetalDevice::MetalDevice(ref_ptr driver, const DeviceInfo& device_info, - DebugCaptureManager* debug_capture_manager) - : Device(device_info), - driver_(std::move(driver)), - metal_handle_([(__bridge id)device_info.device_id() retain]), - debug_capture_manager_(debug_capture_manager) { - IREE_TRACE_SCOPE0("MetalDevice::ctor"); - - // Grab one queue for dispatch and transfer. - std::string name = absl::StrCat(device_info.name(), ":queue"); - id metal_queue = [metal_handle_ newCommandQueue]; // retained - - allocator_ = MetalDirectAllocator::Create(metal_handle_, metal_queue); - - if (debug_capture_manager_ && debug_capture_manager_->is_connected()) { - // Record a capture covering the duration of this device lifetime. - static_cast(debug_capture_manager_)->SetCaptureObject(metal_handle_); - debug_capture_manager_->StartCapture(); - } - - command_queue_ = absl::make_unique( - name, CommandCategory::kDispatch | CommandCategory::kTransfer, metal_queue); - common_queue_ = command_queue_.get(); - // MetalCommandQueue retains by itself. Release here to avoid leaking. 
- [metal_queue release]; - - wait_notifier_ = dispatch_queue_create("com.google.iree.semaphore_wait_notifier", NULL); - event_listener_ = [[MTLSharedEventListener alloc] initWithDispatchQueue:wait_notifier_]; -} - -MetalDevice::~MetalDevice() { - IREE_TRACE_SCOPE0("MetalDevice::dtor"); - - if (debug_capture_manager_ && debug_capture_manager_->is_capturing()) { - debug_capture_manager_->StopCapture(); - } - - [event_listener_ release]; - dispatch_release(wait_notifier_); - - [metal_handle_ release]; -} - -std::string MetalDevice::DebugString() const { - return absl::StrCat(Device::DebugString(), // - "\n[MetalDevice]", // - "\n - Dispatch Queues: 1", // - "\n - Transfer Queues: 1"); -} - -ref_ptr MetalDevice::CreateExecutableCache() { - IREE_TRACE_SCOPE0("MetalDevice::CreateExecutableCache"); - return make_ref(metal_handle_); -} - -StatusOr> MetalDevice::CreateDescriptorSetLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span bindings) { - IREE_TRACE_SCOPE0("MetalDevice::CreateDescriptorSetLayout"); - return make_ref(usage_type, bindings); -} - -StatusOr> MetalDevice::CreateExecutableLayout( - absl::Span set_layouts, size_t push_constants) { - IREE_TRACE_SCOPE0("MetalDevice::CreateExecutableLayout"); - return make_ref(set_layouts, push_constants); -} - -StatusOr> MetalDevice::CreateDescriptorSet( - DescriptorSetLayout* set_layout, absl::Span bindings) { - IREE_TRACE_SCOPE0("MetalDevice::CreateDescriptorSet"); - return make_ref(static_cast(set_layout), - bindings); -} - -StatusOr> MetalDevice::CreateCommandBuffer( - CommandBufferModeBitfield mode, CommandCategoryBitfield command_categories) { - IREE_TRACE_SCOPE0("MetalDevice::CreateCommandBuffer"); - @autoreleasepool { - StatusOr> command_buffer; - // We use commandBufferWithUnretainedReferences here to be performant. This is okay becasue - // IREE tracks the lifetime of various objects with the help from compilers. - id cmdbuf = [static_cast(common_queue_)->handle() - commandBufferWithUnretainedReferences]; - command_buffer = MetalCommandBuffer::Create(mode, command_categories, cmdbuf); - // TODO: WrapCommandBufferWithValidation(allocator(), std::move(impl)); - return command_buffer; - } -} - -StatusOr> MetalDevice::CreateEvent() { - IREE_TRACE_SCOPE0("MetalDevice::CreateEvent"); - return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::CreateEvent"; -} - -StatusOr> MetalDevice::CreateSemaphore(uint64_t initial_value) { - IREE_TRACE_SCOPE0("MetalDevice::CreateSemaphore"); - return MetalSharedEvent::Create(metal_handle_, event_listener_, initial_value); -} - -Status MetalDevice::WaitAllSemaphores(absl::Span semaphores, - Time deadline_ns) { - IREE_TRACE_SCOPE0("MetalDevice::WaitAllSemaphores"); - // Go through all MetalSharedEvents and wait on each of them given we need all of them to be - // signaled anyway. - for (int i = 0; i < semaphores.size(); ++i) { - auto* semaphore = static_cast(semaphores[i].semaphore); - IREE_RETURN_IF_ERROR(semaphore->Wait(semaphores[i].value, deadline_ns)); - } - return OkStatus(); -} - -StatusOr MetalDevice::WaitAnySemaphore(absl::Span semaphores, - Time deadline_ns) { - IREE_TRACE_SCOPE0("MetalDevice::WaitAnySemaphore"); - - if (semaphores.empty()) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "expected to have at least one semaphore"; - } - - // If there is just one semaphore, just wait on it. 
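  // [Editor's note] In the multi-semaphore path below, several events may
  // fire before the waiter wakes; signaled_index then reflects whichever
  // listener block ran last. For a wait-any semantic that is still a valid
  // answer, as any satisfied semaphore may be reported.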
- if (semaphores.size() == 1) {
- auto* semaphore = static_cast<MetalSharedEvent*>(semaphores[0].semaphore);
- IREE_RETURN_IF_ERROR(semaphore->Wait(semaphores[0].value, deadline_ns));
- return 0;
- }
-
- // Otherwise, we need to go down a more complicated path by registering listeners to all
- // MTLSharedEvents to notify us when at least one of them has done the work on GPU by signaling a
- // semaphore. The signaling will happen in a new dispatch queue; the current thread will wait on
- // the semaphore.
-
- dispatch_time_t timeout = DeadlineToDispatchTime(deadline_ns);
-
- // Store the handle as a __block variable so that all blocks access the same copy of the
- // semaphore handle on the heap.
- // Use an initial value of zero so that any semaphore signal will unblock the wait.
- __block dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
- // Also create a __block variable to store the index of the signaled semaphore.
- __block int signaled_index = 0;
-
- // The dispatch queue created above is a serial one, so even if multiple semaphores signal,
- // the signaling blocks run serialized.
- for (int i = 0; i < semaphores.size(); ++i) {
- auto* semaphore = static_cast<MetalSharedEvent*>(semaphores[i].semaphore);
- [semaphore->handle() notifyListener:event_listener_
- atValue:semaphores[i].value
- block:^(id<MTLSharedEvent>, uint64_t) {
- dispatch_semaphore_signal(work_done);
- // This captures the *current* index for each semaphore.
- signaled_index = i;
- }];
- }
-
- long timed_out = dispatch_semaphore_wait(work_done, timeout);
-
- dispatch_release(work_done);
-
- if (timed_out) {
- return DeadlineExceededErrorBuilder(IREE_LOC)
- << "Deadline exceeded waiting for dispatch_semaphore_t";
- }
- return signaled_index;
-}
-
-Status MetalDevice::WaitIdle(Time deadline_ns) {
- IREE_TRACE_SCOPE0("MetalDevice::WaitIdle");
- return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::WaitIdle";
-}
-
-} // namespace metal
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/metal/metal_direct_allocator.h b/iree/hal/metal/metal_direct_allocator.h
deleted file mode 100644
index bf8dde8c7c790..0000000000000
--- a/iree/hal/metal/metal_direct_allocator.h
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_METAL_METAL_DIRECT_ALLOCATOR_H_
-#define IREE_HAL_METAL_METAL_DIRECT_ALLOCATOR_H_
-
-#import <Metal/Metal.h>
-
-#include <memory>
-
-#include "iree/base/status.h"
-#include "iree/hal/allocator.h"
-
-namespace iree {
-namespace hal {
-namespace metal {
-
-class MetalBuffer;
-
-// An allocator implementation for Metal that directly wraps a MTLDevice and
-// requests all allocations on the device. This is not great for performance,
-// but it is a good starting point.
-class MetalDirectAllocator final : public Allocator { - public: - static std::unique_ptr Create( - id device, id transfer_queue); - - ~MetalDirectAllocator() override; - - bool CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const override; - - bool CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const override; - - Status MakeCompatible(MemoryTypeBitfield* memory_type, - BufferUsageBitfield* buffer_usage) const override; - - StatusOr> Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) override; - - StatusOr> AllocateConstant( - BufferUsageBitfield buffer_usage, ref_ptr source_buffer) override; - - StatusOr> WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, - void* data, - size_t data_length) override; - - private: - explicit MetalDirectAllocator(id device, - id transfer_queue); - - StatusOr> AllocateInternal( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - MemoryAccessBitfield allowed_access, size_t allocation_size); - - id metal_device_; - id metal_transfer_queue_; -}; - -} // namespace metal -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_METAL_METAL_DIRECT_ALLOCATOR_H_ diff --git a/iree/hal/metal/metal_direct_allocator.mm b/iree/hal/metal/metal_direct_allocator.mm deleted file mode 100644 index 3a88c700bc805..0000000000000 --- a/iree/hal/metal/metal_direct_allocator.mm +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/metal/metal_direct_allocator.h" - -#include "absl/memory/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/metal/metal_buffer.h" - -namespace iree { -namespace hal { -namespace metal { - -namespace { - -// Returns the proper Metal resource storage mode given the specific MemoryType. -MTLResourceOptions SelectMTLResourceStorageMode(MemoryType memory_type) { - // There are four MTLStorageMode: - // * Managed: The CPU and GPU may maintain separate copies of the resource, and any changes - // must be explicitly synchronized. - // * Shared: The resource is stored in system memory and is accessible to both the CPU and - // the GPU. - // * Private: The resource can be accessed only by the GPU. - // * Memoryless: The resource’s contents can be accessed only by the GPU and only exist - // temporarily during a render pass. - // macOS has all of the above; MTLStorageModeManaged is not available on iOS. - // - // The IREE HAL is modeled after Vulkan so it's quite explicit. 
For buffers visible to both
- // the host and the device, we would like to opt in with the explicit version
- // (MTLStorageManaged) when possible because it should be more performant: "In macOS,
- // there’s no difference in GPU performance between managed and private buffers." But for
- // iOS, MTLStorageShared should be good given we have a unified memory model there.
-
- if (AllBitsSet(memory_type, MemoryType::kDeviceLocal)) {
- if (AllBitsSet(memory_type, MemoryType::kHostVisible)) {
- // Device-local, host-visible.
- // TODO(antiagainst): Enable using MTLResourceStorageModeManaged on macOS once we have
- // defined invalidate/flush C APIs and wired up their usage through the stack. At the
- // moment, if we use MTLResourceStorageModeManaged without proper invalidate/flush
- // actions, reads/writes from kernel invocations will not be properly synchronized.
- return MTLResourceStorageModeShared;
- } else {
- // Device-local only.
- return MTLResourceStorageModePrivate;
- }
- } else {
- if (AllBitsSet(memory_type, MemoryType::kDeviceVisible)) {
- // Host-local, device-visible.
- return MTLResourceStorageModeShared;
- } else {
- // Host-local only.
- // TODO(antiagainst): we probably want to just use HostBuffer here.
- return MTLResourceStorageModeShared;
- }
- }
-}
-
-} // namespace
-
-// static
-std::unique_ptr<MetalDirectAllocator> MetalDirectAllocator::Create(
- id<MTLDevice> device, id<MTLCommandQueue> transfer_queue) {
- IREE_TRACE_SCOPE0("MetalDirectAllocator::Create");
- return absl::WrapUnique(new MetalDirectAllocator(device, transfer_queue));
-}
-
-MetalDirectAllocator::MetalDirectAllocator(id<MTLDevice> device, id<MTLCommandQueue> transfer_queue)
- : metal_device_([device retain]), metal_transfer_queue_([transfer_queue retain]) {}
-
-MetalDirectAllocator::~MetalDirectAllocator() {
- IREE_TRACE_SCOPE0("MetalDirectAllocator::dtor");
- [metal_transfer_queue_ release];
- [metal_device_ release];
-}
-
-bool MetalDirectAllocator::CanUseBufferLike(Allocator* source_allocator,
- MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- BufferUsageBitfield intended_usage) const {
- // TODO(benvanik): ensure there is a memory type that can satisfy the request.
- return source_allocator == this;
-}
-
-bool MetalDirectAllocator::CanAllocate(MemoryTypeBitfield memory_type,
- BufferUsageBitfield buffer_usage,
- size_t allocation_size) const {
- // TODO(benvanik): ensure there is a memory type that can satisfy the request.
- return true;
-}
-
-Status MetalDirectAllocator::MakeCompatible(MemoryTypeBitfield* memory_type,
- BufferUsageBitfield* buffer_usage) const {
- // TODO(benvanik): mutate to match supported memory types.
- return OkStatus();
-}
-
-StatusOr<ref_ptr<MetalBuffer>> MetalDirectAllocator::AllocateInternal(
- MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage,
- MemoryAccessBitfield allowed_access, size_t allocation_size) {
- IREE_TRACE_SCOPE0("MetalDirectAllocator::AllocateInternal");
-
- MTLResourceOptions resource_options = SelectMTLResourceStorageMode(memory_type);
-
- // IREE is more explicit than Metal: it tracks various state by itself. There is no need
- // to incur Metal runtime overhead for hazard tracking.
- resource_options |= MTLResourceHazardTrackingModeUntracked; - - id metal_buffer = [metal_device_ newBufferWithLength:allocation_size - options:resource_options]; // retained - - return MetalBuffer::CreateUnretained( - this, memory_type, allowed_access, buffer_usage, allocation_size, /*byte_offset=*/0, - /*byte_length=*/allocation_size, metal_buffer, metal_transfer_queue_); -} - -StatusOr> MetalDirectAllocator::Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) { - IREE_TRACE_SCOPE0("MetalDirectAllocator::Allocate"); - return AllocateInternal(memory_type, buffer_usage, MemoryAccess::kAll, allocation_size); -} - -StatusOr> MetalDirectAllocator::AllocateConstant(BufferUsageBitfield buffer_usage, - ref_ptr source_buffer) { - IREE_TRACE_SCOPE0("MetalDirectAllocator::AllocateConstant"); - // TODO(benvanik): import memory to avoid the copy. - IREE_ASSIGN_OR_RETURN( - auto buffer, AllocateInternal(MemoryType::kDeviceLocal | MemoryType::kHostVisible, - buffer_usage, MemoryAccess::kRead | MemoryAccess::kDiscardWrite, - source_buffer->byte_length())); - IREE_RETURN_IF_ERROR(buffer->CopyData(0, source_buffer.get(), 0, kWholeBuffer)); - return buffer; -} - -StatusOr> MetalDirectAllocator::WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, - void* data, size_t data_length) { - IREE_TRACE_SCOPE0("MetalDirectAllocator::WrapMutable"); - return UnimplementedErrorBuilder(IREE_LOC) << "MetalDirectAllocator::WrapMutable"; -} - -} // namespace metal -} // namespace hal -} // namespace iree diff --git a/iree/hal/metal/metal_driver.h b/iree/hal/metal/metal_driver.h deleted file mode 100644 index f42e9eda67cef..0000000000000 --- a/iree/hal/metal/metal_driver.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_METAL_METAL_DRIVER_H_ -#define IREE_HAL_METAL_METAL_DRIVER_H_ - -#include -#include - -#include "iree/hal/debug_capture_manager.h" -#include "iree/hal/driver.h" - -namespace iree { -namespace hal { -namespace metal { - -struct MetalDriverOptions { - // Whether to enable Metal command capture. - bool enable_capture; - // The file to contain the Metal capture. Empty means capturing to Xcode. - std::string capture_file; -}; - -// A pseudo Metal GPU driver which retains all available Metal GPU devices -// during its lifetime. -// -// It uses the DriverDeviceID to store the underlying id. 
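Editor's aside, before the class declaration continues below: a hedged sketch of the DriverDeviceID round-trip the comment above describes. The helper names are hypothetical; the __bridge casts mirror the ones used in metal_driver.mm further down, and ownership stays manual (the driver retains each device for as long as the id may be dereferenced).

// Hypothetical helpers illustrating how an id<MTLDevice> can be stored in
// the opaque integer DriverDeviceID without transferring ownership.
static DriverDeviceID PackDeviceID(id<MTLDevice> device) {
  return reinterpret_cast<DriverDeviceID>((__bridge void*)device);
}
static id<MTLDevice> UnpackDeviceID(DriverDeviceID device_id) {
  // __bridge implies no retain/release; the caller must ensure the device
  // was retained elsewhere (as MetalDriver's constructor does).
  return (__bridge id<MTLDevice>)reinterpret_cast<void*>(device_id);
}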
-class MetalDriver final : public Driver {
- public:
- static StatusOr<ref_ptr<MetalDriver>> Create(
- const MetalDriverOptions& options);
-
- ~MetalDriver() override;
-
- StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override;
-
- StatusOr<ref_ptr<Device>> CreateDefaultDevice() override;
-
- StatusOr<ref_ptr<Device>> CreateDevice(DriverDeviceID device_id) override;
-
- private:
- MetalDriver(std::vector<DeviceInfo> devices,
- std::unique_ptr<DebugCaptureManager> debug_capture_manager);
-
- std::vector<DeviceInfo> devices_;
-
- std::unique_ptr<DebugCaptureManager> debug_capture_manager_;
-};
-
-} // namespace metal
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_METAL_METAL_DRIVER_H_
diff --git a/iree/hal/metal/metal_driver.mm b/iree/hal/metal/metal_driver.mm
deleted file mode 100644
index 3742d3e311684..0000000000000
--- a/iree/hal/metal/metal_driver.mm
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/metal/metal_driver.h"
-
-#import <Metal/Metal.h>
-
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/metal/metal_capture_manager.h"
-#include "iree/hal/metal/metal_device.h"
-
-namespace iree {
-namespace hal {
-namespace metal {
-
-namespace {
-
-// Returns an autoreleased array of available Metal GPU devices.
-NSArray<id<MTLDevice>>* GetAvailableMetalDevices() {
-#if defined(IREE_PLATFORM_MACOS)
- // For macOS, we might have more than one GPU device.
- return [MTLCopyAllDevices() autorelease];
-#else
- // For other Apple platforms, we only have one GPU device.
- id<MTLDevice> device = [MTLCreateSystemDefaultDevice() autorelease];
- return [NSArray arrayWithObject:device];
-#endif
-}
-
-} // namespace
-
-// static
-StatusOr<ref_ptr<MetalDriver>> MetalDriver::Create(const MetalDriverOptions& options) {
- IREE_TRACE_SCOPE0("MetalDriver::Create");
-
- @autoreleasepool {
- NSArray<id<MTLDevice>>* devices = GetAvailableMetalDevices();
- if (devices == nil) {
- return UnavailableErrorBuilder(IREE_LOC) << "no Metal GPU devices available";
- }
-
- std::unique_ptr<MetalCaptureManager> metal_capture_manager;
- if (options.enable_capture) {
- IREE_ASSIGN_OR_RETURN(metal_capture_manager,
- MetalCaptureManager::Create(options.capture_file));
- IREE_RETURN_IF_ERROR(metal_capture_manager->Connect());
- }
-
- std::vector<DeviceInfo> device_infos;
- for (id<MTLDevice> device in devices) {
- std::string name = std::string([device.name UTF8String]);
- DeviceFeatureBitfield supported_features = DeviceFeature::kNone;
- DriverDeviceID device_id = reinterpret_cast<DriverDeviceID>((__bridge void*)device);
- device_infos.emplace_back("metal", std::move(name), supported_features, device_id);
- }
- return assign_ref(new MetalDriver(std::move(device_infos), std::move(metal_capture_manager)));
- }
-}
-
-MetalDriver::MetalDriver(std::vector<DeviceInfo> devices,
- std::unique_ptr<DebugCaptureManager> debug_capture_manager)
- : Driver("metal"),
- devices_(std::move(devices)),
- debug_capture_manager_(std::move(debug_capture_manager)) {
- // Retain all the enumerated Metal GPU devices.
- for (const auto& device : devices_) {
- [(__bridge id<MTLDevice>)device.device_id() retain];
- }
-}
-
-MetalDriver::~MetalDriver() {
- IREE_TRACE_SCOPE0("MetalDriver::dtor");
-
- // Release all the retained Metal GPU devices.
- for (const auto& device : devices_) {
- [(__bridge id<MTLDevice>)device.device_id() release];
- }
-}
-
-StatusOr<std::vector<DeviceInfo>> MetalDriver::EnumerateAvailableDevices() {
- IREE_TRACE_SCOPE0("MetalDriver::EnumerateAvailableDevices");
-
- return devices_;
-}
-
-StatusOr<ref_ptr<Device>> MetalDriver::CreateDefaultDevice() {
- IREE_TRACE_SCOPE0("MetalDriver::CreateDefaultDevice");
-
- if (devices_.empty()) {
- return UnavailableErrorBuilder(IREE_LOC) << "no Metal GPU devices available";
- }
- return CreateDevice(devices_.front().device_id());
-}
-
-StatusOr<ref_ptr<Device>> MetalDriver::CreateDevice(DriverDeviceID device_id) {
- IREE_TRACE_SCOPE0("MetalDriver::CreateDevice");
-
- for (const DeviceInfo& info : devices_) {
- if (info.device_id() == device_id) {
- return MetalDevice::Create(add_ref(this), info, debug_capture_manager_.get());
- }
- }
- return InvalidArgumentErrorBuilder(IREE_LOC) << "unknown driver device id: " << device_id;
-}
-
-} // namespace metal
-} // namespace hal
-} // namespace iree
diff --git a/iree/hal/metal/metal_kernel_library.h b/iree/hal/metal/metal_kernel_library.h
deleted file mode 100644
index 7ac2506e60f4c..0000000000000
--- a/iree/hal/metal/metal_kernel_library.h
+++ /dev/null
@@ -1,90 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_METAL_METAL_KERNEL_LIBRARY_H_
-#define IREE_HAL_METAL_METAL_KERNEL_LIBRARY_H_
-
-#import <Metal/Metal.h>
-
-#include
-
-#include "absl/container/inlined_vector.h"
-#include "iree/base/status.h"
-#include "iree/hal/executable.h"
-#include "iree/hal/executable_cache.h"
-#include "iree/hal/executable_spec.h"
-
-// flatcc schemas:
-#include "iree/base/flatcc.h"
-#include "iree/schemas/metal_executable_def_builder.h"
-#include "iree/schemas/metal_executable_def_reader.h"
-#include "iree/schemas/metal_executable_def_verifier.h"
-
-namespace iree {
-namespace hal {
-namespace metal {
-
-// An executable implementation for Metal that wraps MTLLibrary and MTLFunction.
-//
-// Metal represents compute kernels as MTLFunctions. MTLLibrary is just a
-// collection of MTLFunctions. One creates a MTLComputePipelineState from a
-// MTLFunction and uses that pipeline state when dispatching the compute kernel.
-// This class bundles all the necessary Metal objects for getting pipeline state
-// objects for a compute kernel.
-class MetalKernelLibrary final : public Executable {
- public:
- static StatusOr<ref_ptr<MetalKernelLibrary>> Create(
- id<MTLDevice> device, ExecutableCachingModeBitfield mode,
- const ExecutableSpec& spec);
- ~MetalKernelLibrary() override;
-
- bool supports_debugging() const override { return false; }
-
- // Returns the MTLFunction for the entry point with the given |ordinal|.
- StatusOr<id<MTLFunction>> GetKernelForEntryPoint(int ordinal) const;
-
- // Returns the threadgroup size for the entry point with the given |ordinal|.
- StatusOr GetThreadgroupSizeForEntryPoint( - int ordinal) const; - - // Returns the pipeline state object for the entry point with the given - // |ordinal|. - StatusOr> GetPipelineStateForEntryPoint( - int ordinal) const; - - private: - struct KernelObjects { - id function; - iree_MetalThreadgroupSize_t threadgroup_size; - // Baked pipeline state object. - id pipeline_state; - }; - - // Creates a MetalKernelLibrary assuming all Metal objects are already - // retained before passing in. - MetalKernelLibrary(id device, - absl::InlinedVector, 4> libraries, - absl::InlinedVector kernel_objects); - - id device_; - - absl::InlinedVector, 4> libraries_; - absl::InlinedVector kernel_objects_; -}; - -} // namespace metal -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_METAL_METAL_KERNEL_LIBRARY_H_ diff --git a/iree/hal/metal/metal_kernel_library.mm b/iree/hal/metal/metal_kernel_library.mm deleted file mode 100644 index 8ea7e6ae25360..0000000000000 --- a/iree/hal/metal/metal_kernel_library.mm +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/metal/metal_kernel_library.h" - -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -// NOTE: starting to port this to ObjC. - -// Verifies the structure of the flatbuffer so that we can avoid doing so during -// runtime. There are still some conditions we must be aware of (such as omitted -// names on functions with internal linkage), however we shouldn't need to -// bounds check anything within the flatbuffer after this succeeds. -static iree_status_t iree_hal_metal_executable_flatbuffer_verify( - iree_const_byte_span_t flatbuffer_data) { - if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "flatbuffer data is not present or less than 16 bytes (%zu total)", - flatbuffer_data.data_length); - } - - // Run flatcc generated verification. This ensures all pointers are in-bounds - // and that we can safely walk the file, but not that the actual contents of - // the flatbuffer meet our expectations. 
- int verify_ret = - iree_MetalExecutableDef_verify_as_root(flatbuffer_data.data, flatbuffer_data.data_length); - if (verify_ret != flatcc_verify_ok) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "flatbuffer verification failed: %s", - flatcc_verify_error_string(verify_ret)); - } - - iree_MetalExecutableDef_table_t executable_def = - iree_MetalExecutableDef_as_root(flatbuffer_data.data); - - flatbuffers_string_vec_t entry_points_vec = - iree_MetalExecutableDef_entry_points_get(executable_def); - size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec); - for (size_t i = 0; i < entry_point_count; ++i) { - if (!flatbuffers_string_len(flatbuffers_string_vec_at(entry_points_vec, i))) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "executable entry point %zu has no name", i); - } - } - - iree_MetalThreadgroupSize_vec_t threadgroup_sizes_vec = - iree_MetalExecutableDef_threadgroup_sizes(executable_def); - size_t threadgroup_size_count = iree_MetalThreadgroupSize_vec_len(threadgroup_sizes_vec); - if (!threadgroup_size_count) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "No threadgroup sizes present"; - } - - flatbuffers_string_vec_t shader_sources_vec = - iree_MetalExecutableDef_shader_sources_get(executable_def); - size_t shader_source_count = flatbuffers_string_vec_len(shader_sources_vec); - for (size_t i = 0; i < shader_source_count; ++i) { - if (!flatbuffers_string_len(flatbuffers_string_vec_at(shader_sources_vec, i))) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "executable shader source %zu is empty", - i); - } - } - - if (entry_point_count != threadgroup_size_count || entry_point_count != shader_source_count) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "mismatch among the numbers of entry points (%zu), thread group sizes " - "(%zu), and source strings (%zu)", - entry_point_count, threadgroup_size_count, shader_source_count); - } - - return iree_ok_status(); -} - -namespace iree { -namespace hal { -namespace metal { - -// static -StatusOr> MetalKernelLibrary::Create(id device, - ExecutableCachingModeBitfield mode, - const ExecutableSpec& spec) { - IREE_TRACE_SCOPE0("MetalKernelLibrary::Create"); - - // Verify and fetch the executable flatbuffer wrapper. - iree_const_byte_span_t executable_data = - iree_make_const_byte_span(spec.executable_data.data(), spec.executable_data.size()); - IREE_RETURN_IF_ERROR(iree_hal_metal_executable_flatbuffer_verify(executable_data)); - iree_MetalExecutableDef_table_t executable_def = - iree_MetalExecutableDef_as_root(executable_data.data); - - flatbuffers_string_vec_t entry_points_vec = - iree_MetalExecutableDef_entry_points_get(executable_def); - iree_MetalThreadgroupSize_vec_t threadgroup_sizes_vec = - iree_MetalExecutableDef_threadgroup_sizes(executable_def); - flatbuffers_string_vec_t shader_sources_vec = - iree_MetalExecutableDef_shader_sources_get(executable_def); - - // Compile each MSL source string into a MTLLibrary and get the MTLFunction for the entry point to - // build the pipeline state object. 
- - absl::InlinedVector, 4> libraries; - absl::InlinedVector kernel_objects; - - MTLCompileOptions* msl_compile_options = [MTLCompileOptions new]; - msl_compile_options.languageVersion = MTLLanguageVersion2_0; - - auto cleanup = MakeCleanup([&]() { - for (const auto& kernel : kernel_objects) { - [kernel.pipeline_state release]; - [kernel.function release]; - } - for (id library : libraries) [library release]; - [msl_compile_options release]; - }); - - // TODO(antiagainst): We are performing synchronous compilation at runtime here. This is good for - // debugging purposes but bad for performance. Enable offline compilation and make that as the - // default. - - for (size_t entry_ordinal = 0; entry_ordinal < flatbuffers_string_vec_len(shader_sources_vec); - ++entry_ordinal) { - flatbuffers_string_t entry_point = flatbuffers_string_vec_at(entry_points_vec, entry_ordinal); - @autoreleasepool { - NSError* error = nil; - - NSString* shader_source = - [NSString stringWithCString:flatbuffers_string_vec_at(shader_sources_vec, entry_ordinal) - encoding:[NSString defaultCStringEncoding]]; - id library = [device newLibraryWithSource:shader_source - options:msl_compile_options - error:&error]; - if (!library) { - NSLog(@"Failed to create MTLLibrary: %@", error); -#ifndef NDEBUG - NSLog(@"Original MSL source: %@", shader_source); -#endif - return InvalidArgumentErrorBuilder(IREE_LOC) << "Invalid MSL source"; - } - libraries.push_back(library); - - id function = [library - newFunctionWithName:[NSString stringWithCString:entry_point - encoding:[NSString defaultCStringEncoding]]]; - if (!function) { - NSLog(@"Failed to create MTLFunction"); -#ifndef NDEBUG - NSLog(@"Original MSL source: %@", shader_source); -#endif - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Cannot find entry point '" << entry_point << "' in shader source index " - << entry_ordinal; - } - - id pso = [device newComputePipelineStateWithFunction:function - error:&error]; - if (!pso) { - NSLog(@"Failed to create MTLComputePipelineState: %@", error); -#ifndef NDEBUG - NSLog(@"Original MSL source: %@", shader_source); -#endif - return InvalidArgumentErrorBuilder(IREE_LOC) << "Invalid MSL source"; - } - - kernel_objects.push_back( - KernelObjects{function, {static_cast(iree_MetalThreadgroupSize__size())}, pso}); - } - } - - return assign_ref( - new MetalKernelLibrary([device retain], std::move(libraries), std::move(kernel_objects))); -} - -MetalKernelLibrary::MetalKernelLibrary(id device, - absl::InlinedVector, 4> libraries, - absl::InlinedVector kernel_objects) - : device_(device), - libraries_(std::move(libraries)), - kernel_objects_(std::move(kernel_objects)) {} - -MetalKernelLibrary::~MetalKernelLibrary() { - IREE_TRACE_SCOPE0("MetalKernelLibrary::dtor"); - for (const auto& kernel : kernel_objects_) { - [kernel.pipeline_state release]; - [kernel.function release]; - } - for (id library : libraries_) [library release]; -} - -StatusOr> MetalKernelLibrary::GetKernelForEntryPoint(int ordinal) const { - if (ordinal < 0 || ordinal >= kernel_objects_.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal: " << ordinal; - } - return kernel_objects_[ordinal].function; -} - -StatusOr MetalKernelLibrary::GetThreadgroupSizeForEntryPoint( - int ordinal) const { - if (ordinal < 0 || ordinal >= kernel_objects_.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal: " << ordinal; - } - return kernel_objects_[ordinal].threadgroup_size; -} - -StatusOr> 
MetalKernelLibrary::GetPipelineStateForEntryPoint( - int ordinal) const { - if (ordinal < 0 || ordinal >= kernel_objects_.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal: " << ordinal; - } - return kernel_objects_[ordinal].pipeline_state; -} - -} // namespace metal -} // namespace hal -} // namespace iree diff --git a/iree/hal/metal/metal_pipeline_argument_buffer.cc b/iree/hal/metal/metal_pipeline_argument_buffer.cc deleted file mode 100644 index 018beaba9a369..0000000000000 --- a/iree/hal/metal/metal_pipeline_argument_buffer.cc +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/metal/metal_pipeline_argument_buffer.h" - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" - -namespace iree { -namespace hal { -namespace metal { - -MetalArgumentBufferLayout::MetalArgumentBufferLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span bindings) - : usage_type_(usage_type), bindings_(bindings.begin(), bindings.end()) {} - -const DescriptorSetLayout::Binding* -MetalArgumentBufferLayout::GetBindingForIndex(int index) const { - for (const auto& binding : bindings_) { - if (binding.binding == index) return &binding; - } - return nullptr; -} - -std::string MetalArgumentBufferLayout::DebugString() const { - std::vector binding_strings; - binding_strings.reserve(bindings_.size()); - for (const auto& binding : bindings_) { - binding_strings.push_back( - absl::StrCat("[", binding.DebugStringShort(), "]")); - } - return absl::StrCat("bindings=[", absl::StrJoin(binding_strings, ", "), "]"); -} - -MetalPipelineArgumentBufferLayout::MetalPipelineArgumentBufferLayout( - absl::Span set_layouts, size_t push_constants) - : set_layouts_(set_layouts.size()), push_constants_(push_constants) { - for (int i = 0; i < set_layouts.size(); ++i) { - set_layouts_[i] = static_cast(set_layouts[i]); - set_layouts_[i]->AddReference(); - } -} - -MetalPipelineArgumentBufferLayout::~MetalPipelineArgumentBufferLayout() { - for (auto* layout : set_layouts_) layout->ReleaseReference(); -} - -std::string MetalPipelineArgumentBufferLayout::DebugString() const { - std::vector set_strings; - set_strings.reserve(set_layouts_.size()); - for (int i = 0; i < set_layouts_.size(); ++i) { - set_strings.push_back( - absl::StrCat("{set=", i, ", ", set_layouts_[i]->DebugString(), "}")); - } - return absl::StrCat("sets={", absl::StrJoin(set_strings, "; "), "}"); -} - -MetalArgumentBuffer::MetalArgumentBuffer( - MetalArgumentBufferLayout* layout, - absl::Span resources) - : layout_(layout), resources_(resources.begin(), resources.end()) { - layout_->AddReference(); -} - -MetalArgumentBuffer::~MetalArgumentBuffer() { layout_->ReleaseReference(); } - -} // namespace metal -} // namespace hal -} // namespace iree diff --git a/iree/hal/metal/metal_pipeline_argument_buffer.h b/iree/hal/metal/metal_pipeline_argument_buffer.h deleted file mode 100644 index 2b349f5b59c18..0000000000000 --- 
a/iree/hal/metal/metal_pipeline_argument_buffer.h
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_METAL_METAL_PIPELINE_ARGUMENT_BUFFER_H_
-#define IREE_HAL_METAL_METAL_PIPELINE_ARGUMENT_BUFFER_H_
-
-#include <string>
-
-#include "absl/container/inlined_vector.h"
-#include "absl/types/span.h"
-#include "iree/hal/descriptor_set.h"
-#include "iree/hal/descriptor_set_layout.h"
-#include "iree/hal/executable_layout.h"
-
-// Metal implementation classes for resource descriptor related interfaces.
-//
-// See docs/design_docs/metal_hal_driver.md#resource-descriptors for more
-// details.
-
-namespace iree {
-namespace hal {
-namespace metal {
-
-class MetalArgumentBufferLayout final : public DescriptorSetLayout {
- public:
- MetalArgumentBufferLayout(UsageType usage_type,
- absl::Span<const Binding> bindings);
- ~MetalArgumentBufferLayout() override = default;
-
- absl::Span<const Binding> bindings() const { return bindings_; }
- const Binding* GetBindingForIndex(int index) const;
-
- std::string DebugString() const override;
-
- private:
- UsageType usage_type_;
- absl::InlinedVector bindings_;
-};
-
-class MetalPipelineArgumentBufferLayout final : public ExecutableLayout {
- public:
- MetalPipelineArgumentBufferLayout(
- absl::Span<DescriptorSetLayout* const> set_layouts,
- size_t push_constants);
- ~MetalPipelineArgumentBufferLayout() override;
-
- absl::Span<MetalArgumentBufferLayout* const> set_layouts() const {
- return set_layouts_;
- }
-
- std::string DebugString() const override;
-
- private:
- absl::InlinedVector set_layouts_;
- size_t push_constants_;
-};
-
-class MetalArgumentBuffer final : public DescriptorSet {
- public:
- MetalArgumentBuffer(MetalArgumentBufferLayout* layout,
- absl::Span<const Binding> resources);
- ~MetalArgumentBuffer() override;
-
- private:
- MetalArgumentBufferLayout* layout_;
- absl::InlinedVector resources_;
-};
-
-} // namespace metal
-} // namespace hal
-} // namespace iree
-
-#endif // IREE_HAL_METAL_METAL_PIPELINE_ARGUMENT_BUFFER_H_
diff --git a/iree/hal/metal/metal_pipeline_cache.h b/iree/hal/metal/metal_pipeline_cache.h
deleted file mode 100644
index 1b59e1ca91ceb..0000000000000
--- a/iree/hal/metal/metal_pipeline_cache.h
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright 2020 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
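Editor's aside on the argument-buffer classes above: a hedged usage sketch of MetalArgumentBufferLayout's binding lookup. The UsageType enumerator and the Binding fields used here are assumptions based on the HAL interfaces, not taken from the patch.

// Build a layout over two bindings and look one up by its binding index.
DescriptorSetLayout::Binding bindings[2] = {};
bindings[0].binding = 0;
bindings[1].binding = 1;
auto layout = make_ref<MetalArgumentBufferLayout>(
    DescriptorSetLayout::UsageType::kImmutable,  // assumed enumerator
    absl::MakeConstSpan(bindings));
if (const auto* binding = layout->GetBindingForIndex(1)) {
  // Found: binding->binding == 1. A nullptr result means no binding is
  // declared at that index in this layout.
}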
- -#ifndef IREE_HAL_METAL_METAL_PIPELINE_CACHE_H_ -#define IREE_HAL_METAL_METAL_PIPELINE_CACHE_H_ - -#import - -#include "iree/hal/executable.h" -#include "iree/hal/executable_cache.h" - -namespace iree { -namespace hal { -namespace metal { - -// An ExecutableCache implementation for Metal. -class MetalPipelineCache final : public ExecutableCache { - public: - explicit MetalPipelineCache(id device); - ~MetalPipelineCache() override; - - bool CanPrepareFormat(ExecutableFormat format) const override; - - StatusOr> PrepareExecutable( - ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode, - const ExecutableSpec& spec) override; - - private: - id metal_device_; -}; - -} // namespace metal -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_METAL_METAL_PIPELINE_CACHE_H_ diff --git a/iree/hal/metal/metal_pipeline_cache.mm b/iree/hal/metal/metal_pipeline_cache.mm deleted file mode 100644 index 1e987704b7320..0000000000000 --- a/iree/hal/metal/metal_pipeline_cache.mm +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/metal/metal_pipeline_cache.h" - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/executable_format.h" -#include "iree/hal/metal/metal_kernel_library.h" - -namespace iree { -namespace hal { -namespace metal { - -MetalPipelineCache::MetalPipelineCache(id device) : metal_device_([device retain]) {} - -MetalPipelineCache::~MetalPipelineCache() { [metal_device_ release]; } - -bool MetalPipelineCache::CanPrepareFormat(ExecutableFormat format) const { - return format == kExecutableFormatMetal; -} - -StatusOr> MetalPipelineCache::PrepareExecutable( - ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode, - const ExecutableSpec& spec) { - IREE_TRACE_SCOPE0("MetalPipelineCache::PrepareExecutable"); - - // Create the Metal library (which may itself own many pipeline states). - IREE_ASSIGN_OR_RETURN(auto executable, MetalKernelLibrary::Create(metal_device_, mode, spec)); - - return executable; -} - -} // namespace metal -} // namespace hal -} // namespace iree diff --git a/iree/hal/metal/metal_shared_event.h b/iree/hal/metal/metal_shared_event.h deleted file mode 100644 index 10a96ce2ea07f..0000000000000 --- a/iree/hal/metal/metal_shared_event.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
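Editor's aside on MetalPipelineCache above: a hedged sketch of the prepare flow. The caching-mode enumerator is an assumption; CanPrepareFormat/PrepareExecutable and the spec's executable_data field come from the code above.

// Feed a serialized MetalExecutableDef flatbuffer through the cache to get
// a MetalKernelLibrary-backed Executable.
MetalPipelineCache cache(metal_device);  // metal_device: an id<MTLDevice>
if (cache.CanPrepareFormat(kExecutableFormatMetal)) {
  ExecutableSpec spec;
  // spec.executable_data = ...;  // serialized flatbuffer bytes
  auto executable_or = cache.PrepareExecutable(
      /*executable_layout=*/nullptr, ExecutableCachingMode::kDefault, spec);
  if (executable_or.ok()) {
    // The returned Executable wraps the compiled MTLLibrary/MTLFunction and
    // baked MTLComputePipelineState objects.
  }
}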
- -#ifndef IREE_HAL_METAL_METAL_SHARED_EVENT_H_ -#define IREE_HAL_METAL_METAL_SHARED_EVENT_H_ - -#import - -#include "absl/synchronization/mutex.h" -#include "iree/hal/semaphore.h" - -namespace iree { -namespace hal { -namespace metal { - -// A semaphore implementation for Metal that directly wraps a MTLSharedEvent. -class MetalSharedEvent final : public Semaphore { - public: - // Creates a MetalSharedEvent with the given |initial_value|. - static StatusOr> Create( - id device, MTLSharedEventListener* event_listener, - uint64_t initial_value); - - ~MetalSharedEvent() override; - - id handle() const { return metal_handle_; } - - StatusOr Query() override; - - Status Signal(uint64_t value) override; - - void Fail(Status status) override; - - Status Wait(uint64_t value, Time deadline_ns) override; - - private: - MetalSharedEvent(id device, MTLSharedEventListener* event_listener, - uint64_t initial_value); - - id metal_handle_; - - // An event listener for waiting and signaling. Its lifetime is managed by - // the parent device. - MTLSharedEventListener* event_listener_; - - // NOTE: the MTLSharedEvent is the source of truth. We only need to access - // this status (and thus take the lock) when we want to either signal failure - // or query the status in the case of the semaphore being set to UINT64_MAX. - mutable absl::Mutex status_mutex_; - Status status_ ABSL_GUARDED_BY(status_mutex_); -}; - -} // namespace metal -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_METAL_METAL_SHARED_EVENT_H_ diff --git a/iree/hal/metal/metal_shared_event.mm b/iree/hal/metal/metal_shared_event.mm deleted file mode 100644 index 325c30a6f20ec..0000000000000 --- a/iree/hal/metal/metal_shared_event.mm +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
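Editor's aside on the MetalSharedEvent interface above: a hedged sketch of the timeline contract it implements. InfiniteFuture() and the status macros are assumed from iree/base of this period; the Signal/Wait/Query/Fail semantics are as documented above.

// One thread (or GPU queue) advances the timeline; another blocks on a payload.
IREE_ASSIGN_OR_RETURN(
    auto semaphore,
    MetalSharedEvent::Create(device, listener, /*initial_value=*/0));
IREE_RETURN_IF_ERROR(semaphore->Signal(1));                  // payload: 0 -> 1
IREE_RETURN_IF_ERROR(semaphore->Wait(1, InfiniteFuture()));  // returns once >= 1
IREE_ASSIGN_OR_RETURN(uint64_t payload, semaphore->Query());
// After Fail(status), the payload is pinned to UINT64_MAX and Query() returns
// the stored failure status instead of a value.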
- -#include "iree/hal/metal/metal_shared_event.h" - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/metal/dispatch_time_util.h" - -namespace iree { -namespace hal { -namespace metal { - -// static -StatusOr> MetalSharedEvent::Create(id device, - MTLSharedEventListener* event_listener, - uint64_t initial_value) { - return assign_ref(new MetalSharedEvent(device, event_listener, initial_value)); -} - -MetalSharedEvent::MetalSharedEvent(id device, MTLSharedEventListener* event_listener, - uint64_t initial_value) - : metal_handle_([device newSharedEvent]), event_listener_(event_listener) { - IREE_TRACE_SCOPE0("MetalSharedEvent::ctor"); - metal_handle_.signaledValue = initial_value; -} - -MetalSharedEvent::~MetalSharedEvent() { - IREE_TRACE_SCOPE0("MetalSharedEvent::dtor"); - [metal_handle_ release]; -} - -StatusOr MetalSharedEvent::Query() { - IREE_TRACE_SCOPE0("MetalSharedEvent::Query"); - uint64_t value = metal_handle_.signaledValue; - if (value == UINT64_MAX) { - absl::MutexLock lock(&status_mutex_); - return status_; - } - return value; -} - -Status MetalSharedEvent::Signal(uint64_t value) { - IREE_TRACE_SCOPE0("MetalSharedEvent::Signal"); - metal_handle_.signaledValue = value; - return OkStatus(); -} - -void MetalSharedEvent::Fail(Status status) { - IREE_TRACE_SCOPE0("MetalSharedEvent::Fail"); - absl::MutexLock lock(&status_mutex_); - status_ = std::move(status); - metal_handle_.signaledValue = UINT64_MAX; -} - -Status MetalSharedEvent::Wait(uint64_t value, Time deadline_ns) { - IREE_TRACE_SCOPE0("MetalSharedEvent::Wait"); - - Duration duration_ns = DeadlineToRelativeTimeoutNanos(deadline_ns); - dispatch_time_t timeout = DurationToDispatchTime(duration_ns); - - // Quick path for impatient waiting to avoid all the overhead of dispatch queues and semaphores. - if (duration_ns == ZeroDuration()) { - if (metal_handle_.signaledValue < value) { - return DeadlineExceededErrorBuilder(IREE_LOC) << "Deadline exceeded waiting for semaphores"; - } - return OkStatus(); - } - - // Theoretically we don't really need to mark the semaphore handle as __block given that the - // handle itself is not modified and there is only one block and it will copy the handle. - // But marking it as __block serves as good documentation purpose. - __block dispatch_semaphore_t work_done = dispatch_semaphore_create(0); - - // Use a listener to the MTLSharedEvent to notify us when the work is done on GPU by signaling a - // semaphore. The signaling will happen in a new dispatch queue; the current thread will wait on - // the semaphore. - [metal_handle_ notifyListener:event_listener_ - atValue:value - block:^(id, uint64_t) { - dispatch_semaphore_signal(work_done); - }]; - - long timed_out = dispatch_semaphore_wait(work_done, timeout); - - dispatch_release(work_done); - - if (timed_out) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for dispatch_semaphore_t"; - } - return OkStatus(); -} - -} // namespace metal -} // namespace hal -} // namespace iree diff --git a/iree/hal/metal/registration/BUILD.bazel b/iree/hal/metal/registration/BUILD.bazel deleted file mode 100644 index 6a87e62762ab2..0000000000000 --- a/iree/hal/metal/registration/BUILD.bazel +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content") - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -iree_cmake_extra_content( - content = """ -if(${IREE_HAL_DRIVER_METAL}) -""", - inline = True, -) - -cc_library( - name = "registration", - srcs = ["driver_module.cc"], - hdrs = ["driver_module.h"], - defines = [ - "IREE_HAL_HAVE_METAL_DRIVER_MODULE=1", - ], - deps = [ - "//iree/base:flags", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:api", - "//iree/hal/metal", - "@com_google_absl//absl/flags:flag", - ], -) - -iree_cmake_extra_content( - content = """ -endif() -""", - inline = True, -) diff --git a/iree/hal/metal/registration/CMakeLists.txt b/iree/hal/metal/registration/CMakeLists.txt deleted file mode 100644 index 90033d6f10eb2..0000000000000 --- a/iree/hal/metal/registration/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -iree_add_all_subdirs() - -if(${IREE_HAL_DRIVER_METAL}) - -iree_cc_library( - NAME - registration - HDRS - "driver_module.h" - SRCS - "driver_module.cc" - DEPS - absl::flags - iree::base::flags - iree::base::status - iree::base::tracing - iree::hal::api - iree::hal::metal - DEFINES - "IREE_HAL_HAVE_METAL_DRIVER_MODULE=1" - PUBLIC -) - -endif() diff --git a/iree/hal/metal/registration/driver_module.cc b/iree/hal/metal/registration/driver_module.cc deleted file mode 100644 index 1a9da7d3435ca..0000000000000 --- a/iree/hal/metal/registration/driver_module.cc +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
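Editor's aside: once the registration target above is linked in, host code wires the factory into a registry roughly as sketched below. The registry helper names are assumptions based on the IREE C API of this period, not taken from the patch.

// Register the Metal driver factory, then create a driver by name.
IREE_RETURN_IF_ERROR(iree_hal_metal_driver_module_register(
    iree_hal_driver_registry_default()));
iree_hal_driver_t* driver = NULL;
IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create_by_name(
    iree_hal_driver_registry_default(), iree_make_cstring_view("metal"),
    iree_allocator_system(), &driver));
// ... enumerate/create devices from |driver| ...
iree_hal_driver_release(driver);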
- -#include "iree/hal/metal/registration/driver_module.h" - -#include - -#include "absl/flags/flag.h" -#include "iree/base/flags.h" -#include "iree/hal/metal/metal_driver.h" - -ABSL_FLAG(bool, metal_capture, false, "Enables capturing Metal commands."); -ABSL_FLAG( - std::string, metal_capture_to_file, "", - "Full path to store the GPU trace file (empty means capture to Xcode)"); - -#define IREE_HAL_METAL_DRIVER_ID 0x4D544C31u // MTL1 - -static iree_status_t iree_hal_metal_driver_factory_enumerate( - void* self, const iree_hal_driver_info_t** out_driver_infos, - iree_host_size_t* out_driver_info_count) { - // NOTE: we could query supported metal versions or featuresets here. - static const iree_hal_driver_info_t driver_infos[1] = {{ - /*driver_id=*/IREE_HAL_METAL_DRIVER_ID, - /*driver_name=*/iree_make_cstring_view("metal"), - /*full_name=*/iree_make_cstring_view("Apple Metal GPU"), - }}; - *out_driver_info_count = IREE_ARRAYSIZE(driver_infos); - *out_driver_infos = driver_infos; - return iree_ok_status(); -} - -static iree_status_t iree_hal_metal_driver_factory_try_create( - void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator, - iree_hal_driver_t** out_driver) { - if (driver_id != IREE_HAL_METAL_DRIVER_ID) { - return iree_make_status(IREE_STATUS_UNAVAILABLE, - "no driver with ID %016" PRIu64 - " is provided by this factory", - driver_id); - } - iree::hal::metal::MetalDriverOptions options; - options.enable_capture = absl::GetFlag(FLAGS_metal_capture); - options.capture_file = absl::GetFlag(FLAGS_metal_capture_to_file); - IREE_ASSIGN_OR_RETURN(auto driver, - iree::hal::metal::MetalDriver::Create(options)); - *out_driver = reinterpret_cast(driver.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_metal_driver_module_register(iree_hal_driver_registry_t* registry) { - static const iree_hal_driver_factory_t factory = { - /*self=*/NULL, - iree_hal_metal_driver_factory_enumerate, - iree_hal_metal_driver_factory_try_create, - }; - return iree_hal_driver_registry_register_factory(registry, &factory); -} diff --git a/iree/hal/resource.h b/iree/hal/resource.h index 8105133a8a8ba..d906a9d7ca392 100644 --- a/iree/hal/resource.h +++ b/iree/hal/resource.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,38 +15,67 @@ #ifndef IREE_HAL_RESOURCE_H_ #define IREE_HAL_RESOURCE_H_ -#include -#include - -#include "iree/base/ref_ptr.h" - -namespace iree { -namespace hal { - -// Abstract resource type whose lifetime is managed by a ResourceSet. -// Used mostly just to get a virtual dtor, though we could add nicer logging -// by allowing resources to capture debug names, stack traces of creation, etc. -class Resource : public RefObject { - public: - virtual ~Resource() = default; - - // Returns a longer debug string describing the resource and its attributes. - virtual std::string DebugString() const { return DebugStringShort(); } - // Returns a short debug string describing the resource. - virtual std::string DebugStringShort() const { - // TODO(benvanik): remove this when all resource types have custom logic. 
- return std::string("resource_") + std::to_string(static_cast( - reinterpret_cast(this))); - } -}; - -} // namespace hal -} // namespace iree - -inline std::ostream& operator<<(std::ostream& stream, - const iree::hal::Resource& resource) { - stream << resource.DebugStringShort(); - return stream; +#include +#include + +#include "iree/base/api.h" +#include "iree/base/atomics.h" +#include "iree/base/debugging.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Abstract resource type whose lifetime is managed by reference counting. +// Used mostly just to get a virtual dtor and vtable, though we could add nicer +// logging by allowing resources to capture debug names, stack traces of +// creation, etc. +// +// All resource types must have the iree_hal_resource_t at offset 0. This allows +// the HAL code to cast any type pointer to a resource to gain access to the +// ref count and vtable at predictable locations. Note that this allows for the +// resource to be at >0 of the allocation but the pointers used with the HAL +// (iree_hal_event_t*, etc) must point to the iree_hal_resource_t. +typedef struct iree_hal_resource_s { + // Reference count used to manage resource lifetime. The vtable->destroy + // method will be called when the reference count falls to zero. + iree_atomic_ref_count_t ref_count; + + // Opaque vtable for the resource object. + // + // NOTE: this field may be hidden in the future. Only use this for + // IREE_HAL_VTABLE_DISPATCH and not equality/direct dereferencing. + const void* vtable; + + // TODO(benvanik): debug string/logging utilities. +} iree_hal_resource_t; + +static inline void iree_hal_resource_initialize( + const void* vtable, iree_hal_resource_t* out_resource) { + iree_atomic_ref_count_init(&out_resource->ref_count); + out_resource->vtable = vtable; +} + +// Returns true if the |resource| has the given |vtable| type. +// This is *not* a way to ensure that an instance is of a specific type but +// instead that it has a compatible vtable. This is because LTO may very rarely +// dedupe identical vtables and cause the pointer comparison to succeed even if +// the spellings of the types differs. +static inline bool iree_hal_resource_is(const void* resource, + const void* vtable) { + return resource ? ((const iree_hal_resource_t*)resource)->vtable == vtable + : false; } +// Asserts (**DEBUG ONLY**) that the |resource| has the given |vtable| type. +// This is only useful to check for programmer error and may have false +// positives - do not rely on it for handling untrusted user input. +#define IREE_HAL_ASSERT_TYPE(resource, vtable) \ + IREE_ASSERT_TRUE(iree_hal_resource_is(resource, vtable), \ + "type does not match expected " #vtable) + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + #endif // IREE_HAL_RESOURCE_H_ diff --git a/iree/hal/semaphore.c b/iree/hal/semaphore.c new file mode 100644 index 0000000000000..4d642879507c8 --- /dev/null +++ b/iree/hal/semaphore.c @@ -0,0 +1,91 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/semaphore.h" + +#include "iree/base/tracing.h" +#include "iree/hal/detail.h" +#include "iree/hal/device.h" + +#define _VTABLE_DISPATCH(semaphore, method_name) \ + IREE_HAL_VTABLE_DISPATCH(semaphore, iree_hal_semaphore, method_name) + +IREE_HAL_API_RETAIN_RELEASE(semaphore); + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_semaphore_create(iree_hal_device_t* device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(out_semaphore); + *out_semaphore = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, create_semaphore)( + device, initial_value, out_semaphore); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_semaphore_query(iree_hal_semaphore_t* semaphore, uint64_t* out_value) { + IREE_ASSERT_ARGUMENT(semaphore); + IREE_ASSERT_ARGUMENT(out_value); + *out_value = 0; + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + _VTABLE_DISPATCH(semaphore, query)(semaphore, out_value); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_semaphore_signal(iree_hal_semaphore_t* semaphore, uint64_t new_value) { + IREE_ASSERT_ARGUMENT(semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + _VTABLE_DISPATCH(semaphore, signal)(semaphore, new_value); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT void IREE_API_CALL +iree_hal_semaphore_fail(iree_hal_semaphore_t* semaphore, iree_status_t status) { + IREE_ASSERT_ARGUMENT(semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + _VTABLE_DISPATCH(semaphore, fail)(semaphore, status); + IREE_TRACE_ZONE_END(z0); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_semaphore_wait_with_deadline(iree_hal_semaphore_t* semaphore, + uint64_t value, iree_time_t deadline_ns) { + IREE_ASSERT_ARGUMENT(semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(semaphore, wait_with_deadline)( + semaphore, value, deadline_ns); + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_semaphore_wait_with_timeout(iree_hal_semaphore_t* semaphore, + uint64_t value, + iree_duration_t timeout_ns) { + IREE_ASSERT_ARGUMENT(semaphore); + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = _VTABLE_DISPATCH(semaphore, wait_with_timeout)( + semaphore, value, timeout_ns); + IREE_TRACE_ZONE_END(z0); + return status; +} diff --git a/iree/hal/semaphore.h b/iree/hal/semaphore.h index 255585cbb1cf6..a0105039e26ba 100644 --- a/iree/hal/semaphore.h +++ b/iree/hal/semaphore.h @@ -1,4 +1,4 @@ -// Copyright 2019 Google LLC +// Copyright 2020 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,22 +15,21 @@ #ifndef IREE_HAL_SEMAPHORE_H_ #define IREE_HAL_SEMAPHORE_H_ -#include +#include +#include -#include "iree/base/status.h" -#include "iree/base/time.h" +#include "iree/base/api.h" #include "iree/hal/resource.h" -namespace iree { -namespace hal { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -class Semaphore; +typedef struct iree_hal_device_s iree_hal_device_t; -// A reference to a semaphore and associated payload value. 
-struct SemaphoreValue { - Semaphore* semaphore = nullptr; - uint64_t value = 0; -}; +//===----------------------------------------------------------------------===// +// iree_hal_semaphore_t +//===----------------------------------------------------------------------===// // Synchronization mechanism for host->device, device->host, host->host, // and device->device notification. Semaphores behave like Vulkan timeline @@ -61,42 +60,99 @@ struct SemaphoreValue { // https://www.youtube.com/watch?v=SpE--Rf516Y // https://www.khronos.org/assets/uploads/developers/library/2018-xdc/Vulkan-Timeline-Semaphores-Part-1_Sep18.pdf // https://docs.microsoft.com/en-us/windows/win32/direct3d12/user-mode-heap-synchronization -class Semaphore : public Resource { - public: - // Queries the current payload of the semaphore. As the payload is - // monotonically increasing it is guaranteed that the value is at least equal - // to the previous result of a Query call and coherent with any waits for - // a specified value via Device::WaitAllSemaphores. - // - // Returns the status/payload at the time the method is called without - // blocking and as such is only valid after a semaphore has been signaled. The - // same failure status will be returned regardless of when in the timeline the - // error occurred. - virtual StatusOr Query() = 0; - - // Signals the semaphore to the given payload value. - // The call is ignored if the current payload value exceeds |value|. - virtual Status Signal(uint64_t value) = 0; - - // Signals the semaphore with a failure. The |status| will be returned from - // Query and Signal for the lifetime of the semaphore. - virtual void Fail(Status status) = 0; - - // Blocks the caller until the semaphore reaches or exceedes the specified - // payload value or the |deadline_ns| elapses. - // - // Returns success if the wait is successful and the semaphore has met or - // exceeded the required payload value. - // - // Returns DEADLINE_EXCEEDED if the |deadline_ns| elapses without the - // semaphore reaching the required value. - virtual Status Wait(uint64_t value, Time deadline_ns) = 0; - inline Status Wait(uint64_t value, Duration timeout_ns) { - return Wait(value, RelativeTimeoutToDeadlineNanos(timeout_ns)); - } -}; - -} // namespace hal -} // namespace iree +typedef struct iree_hal_semaphore_s iree_hal_semaphore_t; + +// Creates a semaphore that can be used with command queues owned by this +// device. To use the semaphores with other devices or instances they must +// first be exported. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_semaphore_create(iree_hal_device_t* device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore); + +// Retains the given |semaphore| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore); + +// Releases the given |semaphore| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore); + +// Queries the current payload of the semaphore and stores the result in +// |out_value|. As the payload is monotonically increasing it is guaranteed that +// the value is at least equal to the previous result of a +// iree_hal_semaphore_query call and coherent with any waits for a +// specified value via iree_device_wait_all_semaphores. +// +// Returns the status at the time the method is called without blocking and as +// such is only valid after a semaphore has been signaled. 
The same failure
+// status will be returned regardless of when in the timeline the error
+// occurred.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_semaphore_query(iree_hal_semaphore_t* semaphore, uint64_t* out_value);
+
+// Signals the |semaphore| to the given payload value.
+// The call is ignored if the current payload value exceeds |new_value|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_semaphore_signal(iree_hal_semaphore_t* semaphore, uint64_t new_value);
+
+// Signals the |semaphore| with a failure. The |status| will be returned from
+// iree_hal_semaphore_query and iree_hal_semaphore_signal for the lifetime
+// of the semaphore. Ownership of the status transfers to the semaphore and
+// callers must clone it if they wish to retain it.
+IREE_API_EXPORT void IREE_API_CALL
+iree_hal_semaphore_fail(iree_hal_semaphore_t* semaphore, iree_status_t status);
+
+// Blocks the caller until the semaphore reaches or exceeds the specified
+// payload value or the |deadline_ns| elapses.
+//
+// Returns success if the wait is successful and the semaphore has met or
+// exceeded the required payload value.
+//
+// Returns DEADLINE_EXCEEDED if the |deadline_ns| elapses without the semaphore
+// reaching the required value. If an asynchronous failure occurred this will
+// return the failure status that was set immediately.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_semaphore_wait_with_deadline(iree_hal_semaphore_t* semaphore,
+                                      uint64_t value, iree_time_t deadline_ns);
+
+// Blocks the caller until the semaphore reaches or exceeds the specified
+// payload value or the |timeout_ns| elapses.
+// A relative-time version of iree_hal_semaphore_wait_with_deadline using the
+// relative nanoseconds from the time the call is made.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_hal_semaphore_wait_with_timeout(iree_hal_semaphore_t* semaphore,
+                                     uint64_t value,
+                                     iree_duration_t timeout_ns);
+
+//===----------------------------------------------------------------------===//
+// iree_hal_semaphore_t implementation details
+//===----------------------------------------------------------------------===//
+
+typedef struct {
+  // << HAL C porting in progress >>
+  IREE_API_UNSTABLE
+
+  void(IREE_API_PTR* destroy)(iree_hal_semaphore_t* semaphore);
+
+  iree_status_t(IREE_API_PTR* query)(iree_hal_semaphore_t* semaphore,
+                                     uint64_t* out_value);
+  iree_status_t(IREE_API_PTR* signal)(iree_hal_semaphore_t* semaphore,
+                                      uint64_t new_value);
+  void(IREE_API_PTR* fail)(iree_hal_semaphore_t* semaphore,
+                           iree_status_t status);
+
+  iree_status_t(IREE_API_PTR* wait_with_deadline)(
+      iree_hal_semaphore_t* semaphore, uint64_t value, iree_time_t deadline_ns);
+  iree_status_t(IREE_API_PTR* wait_with_timeout)(
+      iree_hal_semaphore_t* semaphore, uint64_t value,
+      iree_duration_t timeout_ns);
+} iree_hal_semaphore_vtable_t;
+
+IREE_API_EXPORT void IREE_API_CALL
+iree_hal_semaphore_destroy(iree_hal_semaphore_t* semaphore);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
 
 #endif  // IREE_HAL_SEMAPHORE_H_
diff --git a/iree/hal/stack_trace.h b/iree/hal/stack_trace.h
deleted file mode 100644
index ca4c450a5b45d..0000000000000
--- a/iree/hal/stack_trace.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_STACK_TRACE_H_ -#define IREE_HAL_STACK_TRACE_H_ - -namespace iree { -namespace hal { - -class StackTrace { - public: - // TODO(benvanik): define contents. - // frame: - // device type (cpu, etc) - // effective processor type (determines disasm/etc) <- r52, vliw, etc - // effective offset <- in disasm (abstract, could be op ordinal, byte - // offset) - // source offset <- used in source map lookup - // physical offset <- informative, void* (real memory address) - // physical_context (x86 registers, etc) - // effective_context (??) - // source_context (buffer views/etc?) -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_STACK_TRACE_H_ diff --git a/iree/hal/string_util.cc b/iree/hal/string_util.cc new file mode 100644 index 0000000000000..6cb48bd889df1 --- /dev/null +++ b/iree/hal/string_util.cc @@ -0,0 +1,607 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
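+
+// Usage sketch (illustrative only, not part of this change; all identifiers
+// below are declared in string_util.h): shapes round-trip through the
+// canonical `AxBxC` form:
+//
+//   iree_hal_dim_t shape[4];
+//   iree_host_size_t shape_rank = 0;
+//   IREE_CHECK_OK(iree_hal_parse_shape(iree_make_cstring_view("2x3x4"),
+//                                      IREE_ARRAYSIZE(shape), shape,
+//                                      &shape_rank));
+//   // shape = {2, 3, 4}, shape_rank = 3
+//
+//   char buffer[16];
+//   iree_host_size_t length = 0;
+//   IREE_CHECK_OK(iree_hal_format_shape(shape, shape_rank, sizeof(buffer),
+//                                       buffer, &length));
+//   // buffer = "2x3x4", length = 5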
+ +#include "iree/hal/string_util.h" + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/ascii.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/types/span.h" +#include "iree/base/api.h" +#include "iree/base/memory.h" +#include "iree/base/tracing.h" +#include "iree/hal/buffer.h" +#include "iree/hal/buffer_view.h" +#include "third_party/half/half.hpp" + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_shape( + iree_string_view_t value, iree_host_size_t shape_capacity, + iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank) { + IREE_ASSERT_ARGUMENT(out_shape_rank); + *out_shape_rank = 0; + + auto str_value = absl::string_view(value.data, value.size); + if (str_value.empty()) { + return iree_ok_status(); // empty shape + } + + absl::InlinedVector dims; + for (auto dim_str : absl::StrSplit(str_value, 'x')) { + int dim_value = 0; + if (!absl::SimpleAtoi(dim_str, &dim_value)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "shape[%zu] invalid value '%.*s' of '%.*s'", + dims.size(), (int)dim_str.size(), dim_str.data(), + (int)value.size, value.data); + } + if (dim_value < 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "shape[%zu] unsupported value %d of '%.*s'", + dims.size(), dim_value, (int)value.size, + value.data); + } + dims.push_back(dim_value); + } + if (out_shape_rank) { + *out_shape_rank = dims.size(); + } + if (dims.size() > shape_capacity) { + return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); + } + if (out_shape) { + std::memcpy(out_shape, dims.data(), dims.size() * sizeof(*out_shape)); + } + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_format_shape(const iree_hal_dim_t* shape, iree_host_size_t shape_rank, + iree_host_size_t buffer_capacity, char* buffer, + iree_host_size_t* out_buffer_length) { + if (out_buffer_length) { + *out_buffer_length = 0; + } + iree_host_size_t buffer_length = 0; + for (iree_host_size_t i = 0; i < shape_rank; ++i) { + int n = std::snprintf(buffer ? buffer + buffer_length : nullptr, + buffer ? buffer_capacity - buffer_length : 0, + (i < shape_rank - 1) ? "%dx" : "%d", shape[i]); + if (n < 0) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "snprintf failed to write dimension %zu", i); + } else if (buffer && n >= buffer_capacity - buffer_length) { + buffer = nullptr; + } + buffer_length += n; + } + if (out_buffer_length) { + *out_buffer_length = buffer_length; + } + return buffer ? 
iree_ok_status() + : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element_type( + iree_string_view_t value, iree_hal_element_type_t* out_element_type) { + IREE_ASSERT_ARGUMENT(out_element_type); + *out_element_type = IREE_HAL_ELEMENT_TYPE_NONE; + + auto str_value = absl::string_view(value.data, value.size); + + iree_hal_numerical_type_t numerical_type = IREE_HAL_NUMERICAL_TYPE_UNKNOWN; + if (absl::StartsWith(str_value, "i")) { + numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED; + str_value.remove_prefix(1); + } else if (absl::StartsWith(str_value, "u")) { + numerical_type = IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED; + str_value.remove_prefix(1); + } else if (absl::StartsWith(str_value, "f")) { + numerical_type = IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE; + str_value.remove_prefix(1); + } else if (absl::StartsWith(str_value, "x") || + absl::StartsWith(str_value, "*")) { + numerical_type = IREE_HAL_NUMERICAL_TYPE_UNKNOWN; + str_value.remove_prefix(1); + } else { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "unhandled element type prefix in '%.*s'", + (int)value.size, value.data); + } + + uint32_t bit_count = 0; + if (!absl::SimpleAtoi(str_value, &bit_count) || bit_count > 0xFFu) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "out of range bit count in '%.*s'", (int)value.size, + value.data); + } + + *out_element_type = iree_hal_make_element_type(numerical_type, bit_count); + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element_type( + iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity, + char* buffer, iree_host_size_t* out_buffer_length) { + if (out_buffer_length) { + *out_buffer_length = 0; + } + const char* prefix; + switch (iree_hal_element_numerical_type(element_type)) { + case IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED: + prefix = "i"; + break; + case IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED: + prefix = "u"; + break; + case IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE: + prefix = "f"; + break; + default: + prefix = "*"; + break; + } + int n = std::snprintf( + buffer, buffer_capacity, "%s%d", prefix, + static_cast(iree_hal_element_bit_count(element_type))); + if (n < 0) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "snprintf failed"); + } + if (out_buffer_length) { + *out_buffer_length = n; + } + return n >= buffer_capacity ? iree_status_from_code(IREE_STATUS_OUT_OF_RANGE) + : iree_ok_status(); +} + +// Parses a string of two character pairs representing hex numbers into bytes. 
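+// For example "DEAD" decodes to {0xDE, 0xAD}. Both upper- and lower-case hex
+// digits are accepted and any non-hex character decodes as 0.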
+static void iree_hal_hex_string_to_bytes(const char* from, uint8_t* to,
+                                         ptrdiff_t num) {
+  /* clang-format off */
+  static constexpr char kHexValue[256] = {
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0,  // '0'..'9'
+      0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0,  // 'A'..'F'
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0,  // 'a'..'f'
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+  };
+  /* clang-format on */
+  for (int i = 0; i < num; i++) {
+    to[i] = (kHexValue[from[i * 2] & 0xFF] << 4) +
+            (kHexValue[from[i * 2 + 1] & 0xFF]);
+  }
+}
+
+// Parses a single element string, assuming that the caller has validated that
+// |out_data| has enough storage space for the parsed element data.
+static iree_status_t iree_hal_parse_element_unsafe(
+    iree_string_view_t data_str, iree_hal_element_type_t element_type,
+    uint8_t* out_data) {
+  switch (element_type) {
+    case IREE_HAL_ELEMENT_TYPE_SINT_8: {
+      int32_t temp = 0;
+      if (!absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size),
+                            &temp) ||
+          temp > INT8_MAX) {
+        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+      }
+      *reinterpret_cast<int8_t*>(out_data) = static_cast<int8_t>(temp);
+      return iree_ok_status();
+    }
+    case IREE_HAL_ELEMENT_TYPE_UINT_8: {
+      uint32_t temp = 0;
+      if (!absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size),
+                            &temp) ||
+          temp > UINT8_MAX) {
+        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+      }
+      *reinterpret_cast<uint8_t*>(out_data) = static_cast<uint8_t>(temp);
+      return iree_ok_status();
+    }
+    case IREE_HAL_ELEMENT_TYPE_SINT_16: {
+      int32_t temp = 0;
+      if (!absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size),
+                            &temp) ||
+          temp > INT16_MAX) {
+        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+      }
+      *reinterpret_cast<int16_t*>(out_data) = static_cast<int16_t>(temp);
+      return iree_ok_status();
+    }
+    case IREE_HAL_ELEMENT_TYPE_UINT_16: {
+      uint32_t temp = 0;
+      if (!absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size),
+                            &temp) ||
+          temp > UINT16_MAX) {
+        return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+      }
+      *reinterpret_cast<uint16_t*>(out_data) = static_cast<uint16_t>(temp);
+      return iree_ok_status();
+    }
+    case IREE_HAL_ELEMENT_TYPE_SINT_32:
+      return absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size),
+                              reinterpret_cast<int32_t*>(out_data))
+                 ? iree_ok_status()
+                 : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+    case IREE_HAL_ELEMENT_TYPE_UINT_32:
+      return absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size),
+                              reinterpret_cast<uint32_t*>(out_data))
+                 ? iree_ok_status()
+                 : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT);
+    case IREE_HAL_ELEMENT_TYPE_SINT_64:
+      return absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size),
+                              reinterpret_cast<int64_t*>(out_data))
+                 ?
iree_ok_status() + : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + case IREE_HAL_ELEMENT_TYPE_UINT_64: + return absl::SimpleAtoi(absl::string_view(data_str.data, data_str.size), + reinterpret_cast(out_data)) + ? iree_ok_status() + : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + case IREE_HAL_ELEMENT_TYPE_FLOAT_16: { + float temp = 0; + if (!absl::SimpleAtof(absl::string_view(data_str.data, data_str.size), + &temp)) { + return iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + } + *reinterpret_cast(out_data) = + half_float::detail::float2half(temp); + return iree_ok_status(); + } + case IREE_HAL_ELEMENT_TYPE_FLOAT_32: + return absl::SimpleAtof(absl::string_view(data_str.data, data_str.size), + reinterpret_cast(out_data)) + ? iree_ok_status() + : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + case IREE_HAL_ELEMENT_TYPE_FLOAT_64: + return absl::SimpleAtod(absl::string_view(data_str.data, data_str.size), + reinterpret_cast(out_data)) + ? iree_ok_status() + : iree_status_from_code(IREE_STATUS_INVALID_ARGUMENT); + default: { + // Treat any unknown format as binary. + iree_host_size_t element_size = iree_hal_element_byte_count(element_type); + if (data_str.size != element_size * 2) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "binary hex element count mismatch: buffer " + "length=%zu < expected=%zu", + data_str.size, element_size * 2); + } + iree_hal_hex_string_to_bytes(data_str.data, out_data, element_size); + return iree_ok_status(); + } + } +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element( + iree_string_view_t data_str, iree_hal_element_type_t element_type, + iree_byte_span_t data_ptr) { + iree_host_size_t element_size = iree_hal_element_byte_count(element_type); + if (data_ptr.data_length < element_size) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "output data buffer overflow: data_length=%zu < element_size=%zu", + data_ptr.data_length, element_size); + } + return iree_hal_parse_element_unsafe(data_str, element_type, data_ptr.data); +} + +// Converts a sequence of bytes into hex number strings. 
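+// For example {0xDE, 0xAD} encodes to "DEAD" (always upper-case); the caller
+// must ensure |dest| has capacity for 2*|num| characters.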
+static void iree_hal_bytes_to_hex_string(const uint8_t* src, char* dest, + ptrdiff_t num) { + static constexpr char kHexTable[513] = + "000102030405060708090A0B0C0D0E0F" + "101112131415161718191A1B1C1D1E1F" + "202122232425262728292A2B2C2D2E2F" + "303132333435363738393A3B3C3D3E3F" + "404142434445464748494A4B4C4D4E4F" + "505152535455565758595A5B5C5D5E5F" + "606162636465666768696A6B6C6D6E6F" + "707172737475767778797A7B7C7D7E7F" + "808182838485868788898A8B8C8D8E8F" + "909192939495969798999A9B9C9D9E9F" + "A0A1A2A3A4A5A6A7A8A9AAABACADAEAF" + "B0B1B2B3B4B5B6B7B8B9BABBBCBDBEBF" + "C0C1C2C3C4C5C6C7C8C9CACBCCCDCECF" + "D0D1D2D3D4D5D6D7D8D9DADBDCDDDEDF" + "E0E1E2E3E4E5E6E7E8E9EAEBECEDEEEF" + "F0F1F2F3F4F5F6F7F8F9FAFBFCFDFEFF"; + for (auto src_ptr = src; src_ptr != (src + num); ++src_ptr, dest += 2) { + const char* hex_p = &kHexTable[*src_ptr * 2]; + std::copy(hex_p, hex_p + 2, dest); + } +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element( + iree_const_byte_span_t data, iree_hal_element_type_t element_type, + iree_host_size_t buffer_capacity, char* buffer, + iree_host_size_t* out_buffer_length) { + iree_host_size_t element_size = iree_hal_element_byte_count(element_type); + if (data.data_length < element_size) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "data buffer underflow: data_length=%zu < element_size=%zu", + data.data_length, element_size); + } + int n = 0; + switch (element_type) { + case IREE_HAL_ELEMENT_TYPE_SINT_8: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIi8, + *reinterpret_cast(data.data)); + break; + case IREE_HAL_ELEMENT_TYPE_UINT_8: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu8, + *reinterpret_cast(data.data)); + break; + case IREE_HAL_ELEMENT_TYPE_SINT_16: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIi16, + *reinterpret_cast(data.data)); + break; + case IREE_HAL_ELEMENT_TYPE_UINT_16: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu16, + *reinterpret_cast(data.data)); + break; + case IREE_HAL_ELEMENT_TYPE_SINT_32: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIi32, + *reinterpret_cast(data.data)); + break; + case IREE_HAL_ELEMENT_TYPE_UINT_32: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu32, + *reinterpret_cast(data.data)); + break; + case IREE_HAL_ELEMENT_TYPE_SINT_64: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIi64, + *reinterpret_cast(data.data)); + break; + case IREE_HAL_ELEMENT_TYPE_UINT_64: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%" PRIu64, + *reinterpret_cast(data.data)); + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_16: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G", + half_float::detail::half2float( + *reinterpret_cast(data.data))); + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_32: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G", + *reinterpret_cast(data.data)); + break; + case IREE_HAL_ELEMENT_TYPE_FLOAT_64: + n = std::snprintf(buffer, buffer ? buffer_capacity : 0, "%G", + *reinterpret_cast(data.data)); + break; + default: { + // Treat any unknown format as binary. 
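+      // Two hex characters per byte; the would-be length is still computed
+      // when |buffer| is NULL so callers can measure first and format second.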
+ n = 2 * (int)element_size; + if (buffer && buffer_capacity > n) { + iree_hal_bytes_to_hex_string(data.data, buffer, element_size); + buffer[n] = 0; + } + } + } + if (n < 0) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "snprintf failed"); + } else if (buffer && n >= buffer_capacity) { + buffer = nullptr; + } + if (out_buffer_length) { + *out_buffer_length = n; + } + return buffer ? iree_ok_status() + : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_buffer_elements( + iree_string_view_t data_str, iree_hal_element_type_t element_type, + iree_byte_span_t data_ptr) { + IREE_TRACE_SCOPE0("iree_hal_parse_buffer_elements"); + iree_host_size_t element_size = iree_hal_element_byte_count(element_type); + iree_host_size_t element_capacity = data_ptr.data_length / element_size; + if (iree_string_view_is_empty(data_str)) { + memset(data_ptr.data, 0, data_ptr.data_length); + return iree_ok_status(); + } + size_t src_i = 0; + size_t dst_i = 0; + size_t token_start = std::string::npos; + while (src_i < data_str.size) { + char c = data_str.data[src_i++]; + bool is_separator = + absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']'; + if (token_start == std::string::npos) { + if (!is_separator) { + token_start = src_i - 1; + } + continue; + } else if (token_start != std::string::npos && !is_separator) { + continue; + } + if (dst_i >= element_capacity) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "output data buffer overflow: element_capacity=%zu < dst_i=%zu+", + element_capacity, dst_i); + } + IREE_RETURN_IF_ERROR(iree_hal_parse_element_unsafe( + iree_string_view_t{data_str.data + token_start, + src_i - 2 - token_start + 1}, + element_type, data_ptr.data + dst_i * element_size)); + ++dst_i; + token_start = std::string::npos; + } + if (token_start != std::string::npos) { + if (dst_i >= element_capacity) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "output data overflow: element_capacity=%zu < dst_i=%zu", + element_capacity, dst_i); + } + IREE_RETURN_IF_ERROR(iree_hal_parse_element_unsafe( + iree_string_view_t{data_str.data + token_start, + data_str.size - token_start}, + element_type, data_ptr.data + dst_i * element_size)); + ++dst_i; + } + if (dst_i == 1 && element_capacity > 1) { + // Splat the single value we got to the entire buffer. + uint8_t* p = data_ptr.data + element_size; + for (int i = 1; i < element_capacity; ++i, p += element_size) { + memcpy(p, data_ptr.data, element_size); + } + } else if (dst_i < element_capacity) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "input data string underflow: dst_i=%zu < element_capacity=%zu", dst_i, + element_capacity); + } + return iree_ok_status(); +} + +static iree_status_t iree_hal_format_buffer_elements_recursive( + iree_const_byte_span_t data, const iree_hal_dim_t* shape, + iree_host_size_t shape_rank, iree_hal_element_type_t element_type, + iree_host_size_t* max_element_count, iree_host_size_t buffer_capacity, + char* buffer, iree_host_size_t* out_buffer_length) { + iree_host_size_t buffer_length = 0; + auto append_char = [&](char c) { + if (buffer) { + if (buffer_length < buffer_capacity - 1) { + buffer[buffer_length] = c; + buffer[buffer_length + 1] = '\0'; + } else { + buffer = nullptr; + } + } + ++buffer_length; + }; + + if (shape_rank == 0) { + // Scalar value; recurse to get on to the leaf dimension path. 
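+    // A rank-0 (scalar) buffer is formatted by treating it as a rank-1 buffer
+    // of a single element so the leaf path below handles both uniformly.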
+ const iree_hal_dim_t one = 1; + return iree_hal_format_buffer_elements_recursive( + data, &one, 1, element_type, max_element_count, buffer_capacity, buffer, + out_buffer_length); + } else if (shape_rank > 1) { + // Nested dimension; recurse into the next innermost dimension. + iree_hal_dim_t dim_length = 1; + for (iree_host_size_t i = 1; i < shape_rank; ++i) { + dim_length *= shape[i]; + } + iree_device_size_t dim_stride = + dim_length * iree_hal_element_byte_count(element_type); + if (data.data_length < dim_stride * shape[0]) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "input data underflow: data_length=%zu < expected=%zu", + data.data_length, + static_cast(dim_stride * shape[0])); + } + iree_const_byte_span_t subdata; + subdata.data = data.data; + subdata.data_length = dim_stride; + for (iree_hal_dim_t i = 0; i < shape[0]; ++i) { + append_char('['); + iree_host_size_t actual_length = 0; + iree_status_t status = iree_hal_format_buffer_elements_recursive( + subdata, shape + 1, shape_rank - 1, element_type, max_element_count, + buffer ? buffer_capacity - buffer_length : 0, + buffer ? buffer + buffer_length : nullptr, &actual_length); + buffer_length += actual_length; + if (iree_status_is_out_of_range(status)) { + buffer = nullptr; + } else if (!iree_status_is_ok(status)) { + return status; + } + subdata.data += dim_stride; + append_char(']'); + } + } else { + // Leaf dimension; output data. + iree_host_size_t max_count = + std::min(*max_element_count, static_cast(shape[0])); + iree_device_size_t element_stride = + iree_hal_element_byte_count(element_type); + if (data.data_length < max_count * element_stride) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "input data underflow; data_length=%zu < expected=%zu", + data.data_length, + static_cast(max_count * element_stride)); + } + *max_element_count -= max_count; + iree_const_byte_span_t subdata; + subdata.data = data.data; + subdata.data_length = element_stride; + for (iree_hal_dim_t i = 0; i < max_count; ++i) { + if (i > 0) append_char(' '); + iree_host_size_t actual_length = 0; + iree_status_t status = iree_hal_format_element( + subdata, element_type, buffer ? buffer_capacity - buffer_length : 0, + buffer ? buffer + buffer_length : nullptr, &actual_length); + subdata.data += element_stride; + buffer_length += actual_length; + if (iree_status_is_out_of_range(status)) { + buffer = nullptr; + } else if (!iree_status_is_ok(status)) { + return status; + } + } + if (max_count < shape[0]) { + append_char('.'); + append_char('.'); + append_char('.'); + } + } + if (out_buffer_length) { + *out_buffer_length = buffer_length; + } + return buffer ? 
iree_ok_status() + : iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_buffer_elements( + iree_const_byte_span_t data, const iree_hal_dim_t* shape, + iree_host_size_t shape_rank, iree_hal_element_type_t element_type, + iree_host_size_t max_element_count, iree_host_size_t buffer_capacity, + char* buffer, iree_host_size_t* out_buffer_length) { + IREE_TRACE_SCOPE0("iree_hal_format_buffer_elements"); + if (out_buffer_length) { + *out_buffer_length = 0; + } + if (buffer && buffer_capacity) { + buffer[0] = '\0'; + } + return iree_hal_format_buffer_elements_recursive( + data, shape, shape_rank, element_type, &max_element_count, + buffer_capacity, buffer, out_buffer_length); +} diff --git a/iree/hal/string_util.h b/iree/hal/string_util.h new file mode 100644 index 0000000000000..787a40f7bcc7d --- /dev/null +++ b/iree/hal/string_util.h @@ -0,0 +1,112 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_STRING_UTIL_H_ +#define IREE_HAL_STRING_UTIL_H_ + +#include +#include + +#include "iree/base/api.h" +#include "iree/hal/buffer.h" +#include "iree/hal/buffer_view.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Parses a serialized set of shape dimensions using the canonical shape format +// (the same as produced by iree_hal_format_shape). +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_shape( + iree_string_view_t value, iree_host_size_t shape_capacity, + iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank); + +// Converts shape dimensions into a `4x5x6` format. +// +// Follows the standard API string formatting rules. See iree/base/api.h. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_format_shape(const iree_hal_dim_t* shape, iree_host_size_t shape_rank, + iree_host_size_t buffer_capacity, char* buffer, + iree_host_size_t* out_buffer_length); + +// Parses a serialized iree_hal_element_type_t and sets |out_element_type| if +// it is valid. The format is the same as produced by +// iree_hal_format_element_type. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element_type( + iree_string_view_t value, iree_hal_element_type_t* out_element_type); + +// Converts an iree_hal_element_type_t enum value to a canonical string +// representation, like `IREE_HAL_ELEMENT_TYPE_FLOAT_16` to `f16`. +// |buffer_capacity| defines the size of |buffer| in bytes and +// |out_buffer_length| will return the string length in characters. +// +// Follows the standard API string formatting rules. See iree/base/api.h. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element_type( + iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity, + char* buffer, iree_host_size_t* out_buffer_length); + +// Parses a serialized element of |element_type| to its in-memory form. +// |data_ptr| must be at least large enough to contain the bytes of the element. 
+// For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT32 will write the 4 +// byte float value of 1.2 to |data_ptr|. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_element( + iree_string_view_t data_str, iree_hal_element_type_t element_type, + iree_byte_span_t data_ptr); + +// Converts a single element of |element_type| to a string. +// +// |buffer_capacity| defines the size of |buffer| in bytes and +// |out_buffer_length| will return the string length in characters. Returns +// IREE_STATUS_OUT_OF_RANGE if the buffer capacity is insufficient to hold the +// formatted elements and |out_buffer_length| will contain the required size. +// +// Follows the standard API string formatting rules. See iree/base/api.h. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_element( + iree_const_byte_span_t data, iree_hal_element_type_t element_type, + iree_host_size_t buffer_capacity, char* buffer, + iree_host_size_t* out_buffer_length); + +// Parses a serialized set of elements of the given |element_type|. +// The resulting parsed data is written to |data_ptr|, which must be at least +// large enough to contain the parsed elements. The format is the same as +// produced by iree_hal_format_buffer_elements. Supports additional inputs of +// empty to denote a 0 fill and a single element to denote a splat. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_parse_buffer_elements( + iree_string_view_t data_str, iree_hal_element_type_t element_type, + iree_byte_span_t data_ptr); + +// Converts a shaped buffer of |element_type| elements to a string. +// This will include []'s to denote each dimension, for example for a shape of +// 2x3 the elements will be formatted as `[1 2 3][4 5 6]`. +// +// |max_element_count| can be used to limit the total number of elements printed +// when the count may be large. Elided elements will be replaced with `...`. +// +// |buffer_capacity| defines the size of |buffer| in bytes and +// |out_buffer_length| will return the string length in characters. Returns +// IREE_STATUS_OUT_OF_RANGE if the buffer capacity is insufficient to hold the +// formatted elements and |out_buffer_length| will contain the required size. +// +// Follows the standard API string formatting rules. See iree/base/api.h. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_format_buffer_elements( + iree_const_byte_span_t data, const iree_hal_dim_t* shape, + iree_host_size_t shape_rank, iree_hal_element_type_t element_type, + iree_host_size_t max_element_count, iree_host_size_t buffer_capacity, + char* buffer, iree_host_size_t* out_buffer_length); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_STRING_UTIL_H_ diff --git a/iree/hal/api_string_util_test.cc b/iree/hal/string_util_test.cc similarity index 99% rename from iree/hal/api_string_util_test.cc rename to iree/hal/string_util_test.cc index a148af1b6db1e..6f425d0a9977b 100644 --- a/iree/hal/api_string_util_test.cc +++ b/iree/hal/string_util_test.cc @@ -390,8 +390,9 @@ struct Allocator final // used. 
static StatusOr CreateHostLocal() { Allocator allocator; - iree_status_t status = iree_hal_allocator_create_host_local( - iree_allocator_system(), &allocator); + iree_status_t status = + iree_hal_allocator_create_heap(iree_make_cstring_view("host_local"), + iree_allocator_system(), &allocator); IREE_RETURN_IF_ERROR(std::move(status)); return std::move(allocator); } @@ -432,8 +433,7 @@ struct BufferView final iree_hal_element_type_t element_type) { BufferView buffer_view; iree_status_t status = iree_hal_buffer_view_create( - buffer, shape.data(), shape.size(), element_type, - iree_allocator_system(), &buffer_view); + buffer, shape.data(), shape.size(), element_type, &buffer_view); IREE_RETURN_IF_ERROR(std::move(status)); return std::move(buffer_view); } diff --git a/iree/hal/testing/BUILD b/iree/hal/testing/BUILD index dbf5b16c534ba..f8b767023c648 100644 --- a/iree/hal/testing/BUILD +++ b/iree/hal/testing/BUILD @@ -29,33 +29,3 @@ cc_library( "//iree/hal/drivers", ], ) - -cc_library( - name = "mock_allocator", - testonly = True, - hdrs = ["mock_allocator.h"], - deps = [ - "//iree/hal", - "//iree/testing:gtest", - ], -) - -cc_library( - name = "mock_command_buffer", - testonly = True, - hdrs = ["mock_command_buffer.h"], - deps = [ - "//iree/hal", - "//iree/testing:gtest", - ], -) - -cc_library( - name = "mock_command_queue", - testonly = True, - hdrs = ["mock_command_queue.h"], - deps = [ - "//iree/hal", - "//iree/testing:gtest", - ], -) diff --git a/iree/hal/testing/CMakeLists.txt b/iree/hal/testing/CMakeLists.txt index 4874217d1621d..8faecc3dfe003 100644 --- a/iree/hal/testing/CMakeLists.txt +++ b/iree/hal/testing/CMakeLists.txt @@ -25,39 +25,3 @@ iree_cc_library( TESTONLY PUBLIC ) - -iree_cc_library( - NAME - mock_allocator - HDRS - "mock_allocator.h" - DEPS - iree::hal - iree::testing::gtest - TESTONLY - PUBLIC -) - -iree_cc_library( - NAME - mock_command_buffer - HDRS - "mock_command_buffer.h" - DEPS - iree::hal - iree::testing::gtest - TESTONLY - PUBLIC -) - -iree_cc_library( - NAME - mock_command_queue - HDRS - "mock_command_queue.h" - DEPS - iree::hal - iree::testing::gtest - TESTONLY - PUBLIC -) diff --git a/iree/hal/testing/mock_allocator.h b/iree/hal/testing/mock_allocator.h deleted file mode 100644 index e13689447d994..0000000000000 --- a/iree/hal/testing/mock_allocator.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef IREE_HAL_TESTING_MOCK_ALLOCATOR_H_ -#define IREE_HAL_TESTING_MOCK_ALLOCATOR_H_ - -#include "iree/hal/allocator.h" -#include "iree/testing/gtest.h" - -namespace iree { -namespace hal { -namespace testing { - -class MockAllocator : public ::testing::StrictMock { - public: - MockAllocator() : ::testing::StrictMock() {} - - MOCK_METHOD(bool, CanUseBufferLike, - (Allocator * source_allocator, MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage), - (const, override)); - - MOCK_METHOD(bool, CanAllocate, - (MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - size_t allocation_size), - (const, override)); - - MOCK_METHOD(StatusOr>, Allocate, - (MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - size_t allocation_size), - (override)); - - MOCK_METHOD(StatusOr>, WrapMutable, - (MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, void* data, - size_t data_length), - (override)); -}; - -} // namespace testing -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_TESTING_MOCK_ALLOCATOR_H_ diff --git a/iree/hal/testing/mock_command_buffer.h b/iree/hal/testing/mock_command_buffer.h deleted file mode 100644 index edae319d1787d..0000000000000 --- a/iree/hal/testing/mock_command_buffer.h +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_ -#define IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_ - -#include "iree/hal/command_buffer.h" -#include "iree/testing/gtest.h" - -namespace iree { -namespace hal { -namespace testing { - -class MockCommandBuffer : public ::testing::StrictMock { - public: - MockCommandBuffer(CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) - : ::testing::StrictMock(mode, command_categories) {} - - bool is_recording() const override { return false; } - - MOCK_METHOD(Status, Begin, (), (override)); - MOCK_METHOD(Status, End, (), (override)); - - MOCK_METHOD(Status, ExecutionBarrier, - (ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers), - (override)); - - MOCK_METHOD(Status, SignalEvent, - (Event * event, ExecutionStageBitfield source_stage_mask), - (override)); - - MOCK_METHOD(Status, ResetEvent, - (Event * event, ExecutionStageBitfield source_stage_mask), - (override)); - - MOCK_METHOD(Status, WaitEvents, - (absl::Span events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers), - (override)); - - MOCK_METHOD(Status, FillBuffer, - (Buffer * target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length), - (override)); - - MOCK_METHOD(Status, DiscardBuffer, (Buffer * buffer), (override)); - - MOCK_METHOD(Status, UpdateBuffer, - (const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length), - (override)); - - MOCK_METHOD(Status, CopyBuffer, - (Buffer * source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length), - (override)); - - MOCK_METHOD(Status, PushConstants, - (ExecutableLayout * executable_layout, size_t offset, - absl::Span values), - (override)); - - MOCK_METHOD(Status, PushDescriptorSet, - (ExecutableLayout * executable_layout, int32_t set, - absl::Span bindings), - (override)); - MOCK_METHOD(Status, BindDescriptorSet, - (ExecutableLayout * executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span dynamic_offsets), - (override)); - - MOCK_METHOD(Status, Dispatch, - (Executable * executable, int32_t entry_point, - (std::array workgroups)), - (override)); - - MOCK_METHOD(Status, DispatchIndirect, - (Executable * executable, int32_t entry_point, - Buffer* workgroups_buffer, device_size_t workgroups_offset), - (override)); -}; - -} // namespace testing -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_ diff --git a/iree/hal/testing/mock_command_queue.h b/iree/hal/testing/mock_command_queue.h deleted file mode 100644 index c281026784dbb..0000000000000 --- a/iree/hal/testing/mock_command_queue.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_ -#define IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_ - -#include "iree/hal/command_queue.h" -#include "iree/testing/gtest.h" - -namespace iree { -namespace hal { -namespace testing { - -class MockCommandQueue : public ::testing::StrictMock { - public: - MockCommandQueue(std::string name, - CommandCategoryBitfield supported_categories) - : ::testing::StrictMock(std::move(name), - supported_categories) {} - - MOCK_METHOD(Status, Submit, (absl::Span batches), - (override)); - - MOCK_METHOD(Status, WaitIdle, (Time deadline_ns), (override)); -}; - -} // namespace testing -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_ diff --git a/iree/hal/vmla/BUILD b/iree/hal/vmla/BUILD index 54a4d9e4efab6..f3d5aa65e2299 100644 --- a/iree/hal/vmla/BUILD +++ b/iree/hal/vmla/BUILD @@ -14,108 +14,8 @@ # A VMLA (VM-based Linear Algebra) runtime HAL backend. -load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content") - package( default_visibility = ["//visibility:public"], features = ["layering_check"], licenses = ["notice"], # Apache 2.0 ) - -iree_cmake_extra_content( - content = """ -if(NOT ${IREE_HAL_DRIVER_VMLA}) - return() -endif() -""", -) - -cc_library( - name = "op_kernels", - hdrs = ["op_kernels.h"], - textual_hdrs = [ - # TODO(benvanik): SIMD variants. - "op_kernels_generic.h", - "op_kernels_ruy.h", - "op_kernels_fft.h", - ], - deps = [ - "//iree/base:status", - "//iree/base:tracing", - "@com_google_absl//absl/algorithm", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - "@com_google_ruy//ruy", - "@com_google_ruy//ruy:context", - "@pffft", - ], -) - -cc_test( - name = "op_kernels_test", - srcs = ["op_kernels_test.cc"], - deps = [ - ":op_kernels", - "//iree/base:core_headers", - "//iree/testing:gtest", - "//iree/testing:gtest_main", - "@com_google_absl//absl/container:inlined_vector", - ], -) - -cc_library( - name = "op_module", - srcs = ["op_module.cc"], - hdrs = ["op_module.h"], - deps = [ - ":op_kernels", - "//iree/base:api", - "//iree/base:core_headers", - "//iree/base:ref_ptr", - "//iree/base:status", - "//iree/base:tracing", - "//iree/vm", - "//iree/vm:cc", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "vmla", - srcs = [ - "vmla_cache.cc", - "vmla_device.cc", - "vmla_driver.cc", - "vmla_executable.cc", - ], - hdrs = [ - "vmla_cache.h", - "vmla_device.h", - "vmla_driver.h", - "vmla_executable.h", - ], - deps = [ - ":op_module", - "//iree/base:api", - "//iree/base:core_headers", - "//iree/base:flatcc", - "//iree/base:ref_ptr", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - "//iree/hal/host:host_buffer", - "//iree/hal/host:host_executable", - "//iree/hal/host:host_local_device", - "//iree/hal/host/serial:serial_scheduling_model", - "//iree/schemas:vmla_executable_def_c_fbs", - "//iree/vm", - "//iree/vm:bytecode_module", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) diff --git a/iree/hal/vmla/CMakeLists.txt b/iree/hal/vmla/CMakeLists.txt index 330da765e1c98..8b864e5427773 100644 --- a/iree/hal/vmla/CMakeLists.txt +++ b/iree/hal/vmla/CMakeLists.txt @@ -12,100 +12,4 @@ 
# See the License for the specific language governing permissions and # limitations under the License. -if(NOT ${IREE_HAL_DRIVER_VMLA}) - return() -endif() - iree_add_all_subdirs() - -iree_cc_library( - NAME - op_kernels - HDRS - "op_kernels.h" - TEXTUAL_HDRS - "op_kernels_fft.h" - "op_kernels_generic.h" - "op_kernels_ruy.h" - DEPS - absl::algorithm - absl::core_headers - absl::flat_hash_set - absl::inlined_vector - absl::memory - absl::span - iree::base::status - iree::base::tracing - pffft - ruy - PUBLIC -) - -iree_cc_test( - NAME - op_kernels_test - SRCS - "op_kernels_test.cc" - DEPS - ::op_kernels - absl::inlined_vector - iree::base::core_headers - iree::testing::gtest - iree::testing::gtest_main -) - -iree_cc_library( - NAME - op_module - HDRS - "op_module.h" - SRCS - "op_module.cc" - DEPS - ::op_kernels - absl::span - iree::base::api - iree::base::core_headers - iree::base::ref_ptr - iree::base::status - iree::base::tracing - iree::vm - iree::vm::cc - PUBLIC -) - -iree_cc_library( - NAME - vmla - HDRS - "vmla_cache.h" - "vmla_device.h" - "vmla_driver.h" - "vmla_executable.h" - SRCS - "vmla_cache.cc" - "vmla_device.cc" - "vmla_driver.cc" - "vmla_executable.cc" - DEPS - ::op_module - absl::inlined_vector - absl::memory - absl::span - absl::strings - iree::base::api - iree::base::core_headers - iree::base::flatcc - iree::base::ref_ptr - iree::base::status - iree::base::tracing - iree::hal - iree::hal::host::host_buffer - iree::hal::host::host_executable - iree::hal::host::host_local_device - iree::hal::host::serial::serial_scheduling_model - iree::schemas::vmla_executable_def_c_fbs - iree::vm - iree::vm::bytecode_module - PUBLIC -) diff --git a/iree/hal/vmla/registration/BUILD b/iree/hal/vmla/registration/BUILD index 4a4d34cd20ad7..dbd5860430d58 100644 --- a/iree/hal/vmla/registration/BUILD +++ b/iree/hal/vmla/registration/BUILD @@ -29,16 +29,15 @@ if(${IREE_HAL_DRIVER_VMLA}) cc_library( name = "registration", - srcs = ["driver_module.cc"], + srcs = ["driver_module.c"], hdrs = ["driver_module.h"], defines = [ "IREE_HAL_HAVE_VMLA_DRIVER_MODULE=1", ], deps = [ - "//iree/base:flags", - "//iree/base:status", "//iree/hal:api", - "//iree/hal/vmla", + "//iree/hal/local:task_driver", + "//iree/hal/local/loaders:vmla_module_loader", ], ) diff --git a/iree/hal/vmla/registration/CMakeLists.txt b/iree/hal/vmla/registration/CMakeLists.txt index 66ebf7b2c1562..6c5e18494f4fc 100644 --- a/iree/hal/vmla/registration/CMakeLists.txt +++ b/iree/hal/vmla/registration/CMakeLists.txt @@ -22,12 +22,11 @@ iree_cc_library( HDRS "driver_module.h" SRCS - "driver_module.cc" + "driver_module.c" DEPS - iree::base::flags - iree::base::status iree::hal::api - iree::hal::vmla + iree::hal::local::loaders::vmla_module_loader + iree::hal::local::task_driver DEFINES "IREE_HAL_HAVE_VMLA_DRIVER_MODULE=1" PUBLIC diff --git a/iree/hal/vmla/registration/driver_module.cc b/iree/hal/vmla/registration/driver_module.c similarity index 50% rename from iree/hal/vmla/registration/driver_module.cc rename to iree/hal/vmla/registration/driver_module.c index 6e93527a2bd9a..6fb9f07a7f689 100644 --- a/iree/hal/vmla/registration/driver_module.cc +++ b/iree/hal/vmla/registration/driver_module.c @@ -16,18 +16,24 @@ #include -#include "iree/hal/vmla/vmla_driver.h" +#include "iree/hal/local/loaders/vmla_module_loader.h" +#include "iree/hal/local/task_driver.h" + +// TODO(#4298): remove this driver registration and wrapper. 
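+// NOTE: the driver ID below is just the ASCII bytes of "VMLA" packed into a
+// uint32_t (0x56='V', 0x4D='M', 0x4C='L', 0x41='A') so the ID stays
+// recognizable in logs and hex dumps.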
#define IREE_HAL_VMLA_DRIVER_ID 0x564D4C41u // VMLA static iree_status_t iree_hal_vmla_driver_factory_enumerate( void* self, const iree_hal_driver_info_t** out_driver_infos, iree_host_size_t* out_driver_info_count) { - static const iree_hal_driver_info_t driver_infos[1] = {{ - /*driver_id=*/IREE_HAL_VMLA_DRIVER_ID, - /*driver_name=*/iree_make_cstring_view("vmla"), - /*full_name=*/iree_make_cstring_view("VMLA Reference Backend"), - }}; + static const iree_hal_driver_info_t driver_infos[1] = { + { + .driver_id = IREE_HAL_VMLA_DRIVER_ID, + .driver_name = iree_string_view_literal("vmla"), + .full_name = + iree_string_view_literal("Reference backend (deprecated)"), + }, + }; *out_driver_info_count = IREE_ARRAYSIZE(driver_infos); *out_driver_infos = driver_infos; return iree_ok_status(); @@ -42,9 +48,42 @@ static iree_status_t iree_hal_vmla_driver_factory_try_create( " is provided by this factory", driver_id); } - IREE_ASSIGN_OR_RETURN(auto driver, iree::hal::vmla::VMLADriver::Create()); - *out_driver = reinterpret_cast(driver.release()); - return iree_ok_status(); + + iree_hal_task_device_params_t default_params; + iree_hal_task_device_params_initialize(&default_params); + + // NOTE: VMLA doesn't tile so we don't really need many workers - having + // multiple does make it easier to test overlapping execution, though. + iree_task_topology_t topology; + iree_task_topology_initialize_from_group_count(4, &topology); + + iree_vm_instance_t* instance = NULL; + iree_status_t status = iree_vm_instance_create(allocator, &instance); + + iree_hal_executable_loader_t* vmla_loader = NULL; + if (iree_status_is_ok(status)) { + status = + iree_hal_vmla_module_loader_create(instance, allocator, &vmla_loader); + } + iree_hal_executable_loader_t* loaders[1] = {vmla_loader}; + + iree_task_executor_t* executor = NULL; + if (iree_status_is_ok(status)) { + status = iree_task_executor_create(IREE_TASK_SCHEDULING_MODE_RESERVED, + &topology, allocator, &executor); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_task_driver_create( + iree_make_cstring_view("vmla"), &default_params, executor, + IREE_ARRAYSIZE(loaders), loaders, allocator, out_driver); + } + + iree_task_executor_release(executor); + iree_task_topology_deinitialize(&topology); + iree_hal_executable_loader_release(vmla_loader); + iree_vm_instance_release(instance); + return status; } IREE_API_EXPORT iree_status_t IREE_API_CALL diff --git a/iree/hal/vmla/vmla_cache.cc b/iree/hal/vmla/vmla_cache.cc deleted file mode 100644 index 3b71bb2f80f17..0000000000000 --- a/iree/hal/vmla/vmla_cache.cc +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "iree/hal/vmla/vmla_cache.h" - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/executable_format.h" -#include "iree/hal/vmla/vmla_executable.h" - -namespace iree { -namespace hal { -namespace vmla { - -VMLACache::VMLACache(iree_vm_instance_t* instance, - iree_vm_module_t* vmla_module) - : instance_(instance), vmla_module_(vmla_module) { - iree_vm_instance_retain(instance_); - iree_vm_module_retain(vmla_module_); -} - -VMLACache::~VMLACache() { - iree_vm_module_release(vmla_module_); - iree_vm_instance_release(instance_); -} - -bool VMLACache::CanPrepareFormat(ExecutableFormat format) const { - return format == kExecutableFormatVMLA; -} - -StatusOr> VMLACache::PrepareExecutable( - ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode, - const ExecutableSpec& spec) { - IREE_TRACE_SCOPE0("VMLACache::PrepareExecutable"); - // Wrap the data (or copy it). - bool allow_aliasing_data = - AllBitsSet(mode, ExecutableCachingMode::kAliasProvidedData); - IREE_ASSIGN_OR_RETURN( - auto executable, - VMLAExecutable::Load(instance_, vmla_module_, spec, allow_aliasing_data)); - - return executable; -} - -} // namespace vmla -} // namespace hal -} // namespace iree diff --git a/iree/hal/vmla/vmla_cache.h b/iree/hal/vmla/vmla_cache.h deleted file mode 100644 index 560d4bda8f12b..0000000000000 --- a/iree/hal/vmla/vmla_cache.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VMLA_VMLA_CACHE_H_ -#define IREE_HAL_VMLA_VMLA_CACHE_H_ - -#include "iree/hal/allocator.h" -#include "iree/hal/executable.h" -#include "iree/hal/executable_cache.h" -#include "iree/vm/api.h" - -namespace iree { -namespace hal { -namespace vmla { - -class VMLACache final : public ExecutableCache { - public: - explicit VMLACache(iree_vm_instance_t* instance, - iree_vm_module_t* vmla_module); - ~VMLACache() override; - - bool CanPrepareFormat(ExecutableFormat format) const override; - - StatusOr> PrepareExecutable( - ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode, - const ExecutableSpec& spec) override; - - private: - iree_vm_instance_t* instance_ = nullptr; - iree_vm_module_t* vmla_module_ = nullptr; -}; - -} // namespace vmla -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VMLA_VMLA_CACHE_H_ diff --git a/iree/hal/vmla/vmla_device.cc b/iree/hal/vmla/vmla_device.cc deleted file mode 100644 index e9ce3e8868c30..0000000000000 --- a/iree/hal/vmla/vmla_device.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vmla/vmla_device.h" - -#include "absl/memory/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/vmla/vmla_cache.h" - -namespace iree { -namespace hal { -namespace vmla { - -VMLADevice::VMLADevice(DeviceInfo device_info, - std::unique_ptr scheduling_model, - iree_vm_instance_t* instance, - iree_vm_module_t* vmla_module) - : HostLocalDevice(std::move(device_info), std::move(scheduling_model)), - instance_(instance), - vmla_module_(vmla_module) { - iree_vm_instance_retain(instance_); - iree_vm_module_retain(vmla_module_); -} - -VMLADevice::~VMLADevice() { - iree_vm_module_release(vmla_module_); - iree_vm_instance_release(instance_); -} - -ref_ptr VMLADevice::CreateExecutableCache() { - IREE_TRACE_SCOPE0("VMLADevice::CreateExecutableCache"); - return make_ref(instance_, vmla_module_); -} - -} // namespace vmla -} // namespace hal -} // namespace iree diff --git a/iree/hal/vmla/vmla_device.h b/iree/hal/vmla/vmla_device.h deleted file mode 100644 index 3d3f3aa214225..0000000000000 --- a/iree/hal/vmla/vmla_device.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VMLA_VMLA_DEVICE_H_ -#define IREE_HAL_VMLA_VMLA_DEVICE_H_ - -#include "iree/base/memory.h" -#include "iree/hal/host/host_local_device.h" -#include "iree/vm/api.h" - -namespace iree { -namespace hal { -namespace vmla { - -class VMLADevice final : public host::HostLocalDevice { - public: - explicit VMLADevice(DeviceInfo device_info, - std::unique_ptr scheduling_model, - iree_vm_instance_t* instance, - iree_vm_module_t* vmla_module); - ~VMLADevice() override; - - ref_ptr CreateExecutableCache() override; - - private: - iree_vm_instance_t* instance_ = nullptr; - iree_vm_module_t* vmla_module_ = nullptr; -}; - -} // namespace vmla -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VMLA_VMLA_DEVICE_H_ diff --git a/iree/hal/vmla/vmla_driver.cc b/iree/hal/vmla/vmla_driver.cc deleted file mode 100644 index 54696b1544243..0000000000000 --- a/iree/hal/vmla/vmla_driver.cc +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vmla/vmla_driver.h" - -#include - -#include "iree/base/tracing.h" -#include "iree/hal/device_info.h" -#include "iree/hal/host/serial/serial_scheduling_model.h" -#include "iree/hal/vmla/op_module.h" -#include "iree/hal/vmla/vmla_device.h" - -namespace iree { -namespace hal { -namespace vmla { - -namespace { - -DeviceInfo GetDefaultDeviceInfo() { - DeviceFeatureBitfield supported_features = DeviceFeature::kNone; - // TODO(benvanik): implement debugging/profiling features. - // supported_features |= DeviceFeature::kDebugging; - // supported_features |= DeviceFeature::kCoverage; - // supported_features |= DeviceFeature::kProfiling; - DeviceInfo device_info("vmla", "vmla", supported_features); - // TODO(benvanik): device info. - return device_info; -} - -} // namespace - -// static -StatusOr> VMLADriver::Create() { - IREE_TRACE_SCOPE0("VMLADriver::Create"); - - // NOTE: we could use our own allocator here to hide these from any default - // tracing we have. - iree_vm_instance_t* instance = nullptr; - IREE_RETURN_IF_ERROR( - iree_vm_instance_create(iree_allocator_system(), &instance)); - - // TODO(benvanik): move to instance-based registration. - IREE_RETURN_IF_ERROR(ModuleRegisterTypes()) - << "VMLA type registration failed"; - - iree_vm_module_t* vmla_module = nullptr; - IREE_RETURN_IF_ERROR(ModuleCreate(iree_allocator_system(), &vmla_module)) - << "VMLA shared module creation failed"; - - return make_ref(instance, vmla_module); -} - -VMLADriver::VMLADriver(iree_vm_instance_t* instance, - iree_vm_module_t* vmla_module) - : Driver("vmla"), instance_(instance), vmla_module_(vmla_module) {} - -VMLADriver::~VMLADriver() { - IREE_TRACE_SCOPE0("VMLADriver::dtor"); - iree_vm_module_release(vmla_module_); - iree_vm_instance_release(instance_); -} - -StatusOr> VMLADriver::EnumerateAvailableDevices() { - std::vector device_infos; - device_infos.push_back(GetDefaultDeviceInfo()); - return device_infos; -} - -StatusOr> VMLADriver::CreateDefaultDevice() { - return CreateDevice(0); -} - -StatusOr> VMLADriver::CreateDevice(DriverDeviceID device_id) { - auto scheduling_model = std::make_unique(); - auto device = - make_ref(GetDefaultDeviceInfo(), std::move(scheduling_model), - instance_, vmla_module_); - return device; -} - -} // namespace vmla -} // namespace hal -} // namespace iree diff --git a/iree/hal/vmla/vmla_driver.h b/iree/hal/vmla/vmla_driver.h deleted file mode 100644 index c701f4d20b449..0000000000000 --- a/iree/hal/vmla/vmla_driver.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
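// For contrast with the C factory above: the C++ driver implementation being
// deleted here composed fallible steps with StatusOr<T> and
// IREE_ASSIGN_OR_RETURN rather than threading an iree_status_t by hand.
// Roughly (a sketch; CreateThing is a hypothetical StatusOr-returning
// factory, not part of this change):
//
//   IREE_ASSIGN_OR_RETURN(auto thing, CreateThing());
//   // on failure: returns CreateThing()'s status from the enclosing function
//   // on success: moves the value into `thing` and execution continues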
- -#ifndef IREE_HAL_VMLA_VMLA_DRIVER_H_ -#define IREE_HAL_VMLA_VMLA_DRIVER_H_ - -#include "iree/hal/driver.h" -#include "iree/vm/api.h" - -namespace iree { -namespace hal { -namespace vmla { - -class VMLADriver final : public Driver { - public: - static StatusOr> Create(); - - VMLADriver(iree_vm_instance_t* instance, iree_vm_module_t* vmla_module); - ~VMLADriver() override; - - StatusOr> EnumerateAvailableDevices() override; - - StatusOr> CreateDefaultDevice() override; - - StatusOr> CreateDevice(DriverDeviceID device_id) override; - - private: - iree_vm_instance_t* instance_ = nullptr; - iree_vm_module_t* vmla_module_ = nullptr; -}; - -} // namespace vmla -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VMLA_VMLA_DRIVER_H_ diff --git a/iree/hal/vmla/vmla_executable.cc b/iree/hal/vmla/vmla_executable.cc deleted file mode 100644 index da0cff2c09605..0000000000000 --- a/iree/hal/vmla/vmla_executable.cc +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vmla/vmla_executable.h" - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/host/host_buffer.h" -#include "iree/hal/vmla/op_module.h" -#include "iree/vm/bytecode_module.h" - -// flatcc schemas: -#include "iree/base/flatcc.h" -#include "iree/schemas/vmla_executable_def_reader.h" -#include "iree/schemas/vmla_executable_def_verifier.h" - -// NOTE: starting to port this to C. - -// Verifies the structure of the flatbuffer so that we can avoid doing so during -// runtime. There are still some conditions we must be aware of (such as omitted -// names on functions with internal linkage), however we shouldn't need to -// bounds check anything within the flatbuffer after this succeeds. -static iree_status_t iree_hal_vmla_executable_flatbuffer_verify( - iree_const_byte_span_t flatbuffer_data) { - if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) { - return iree_make_status( - IREE_STATUS_INVALID_ARGUMENT, - "flatbuffer data is not present or less than 16 bytes (%zu total)", - flatbuffer_data.data_length); - } - - // Run flatcc generated verification. This ensures all pointers are in-bounds - // and that we can safely walk the file, but not that the actual contents of - // the flatbuffer meet our expectations. 
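// In general form, the verify-then-read flatcc pattern used below is (a
// sketch for a hypothetical `MyDef` schema; the my_* names are illustrative
// only, the flatcc_* helpers are the same ones used here):
//
//   int ret = my_MyDef_verify_as_root(data.data, data.data_length);
//   if (ret != flatcc_verify_ok) {
//     return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
//                             "flatbuffer verification failed: %s",
//                             flatcc_verify_error_string(ret));
//   }
//   // Only after verification succeeds is it safe to walk the tables:
//   my_MyDef_table_t def = my_MyDef_as_root(data.data);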
- int verify_ret = iree_VMLAExecutableDef_verify_as_root( - flatbuffer_data.data, flatbuffer_data.data_length); - if (verify_ret != flatcc_verify_ok) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "flatbuffer verification failed: %s", - flatcc_verify_error_string(verify_ret)); - } - - iree_VMLAExecutableDef_table_t executable_def = - iree_VMLAExecutableDef_as_root(flatbuffer_data.data); - - if (flatbuffers_uint8_vec_len( - iree_VMLAExecutableDef_bytecode_module_get(executable_def)) < 0) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "executable bytecode_module is missing/empty"); - } - - // NOTE: we don't check the actual bytecode module contents here; it's opaque - // to us and passed on to the VM. - return iree_ok_status(); -} - -namespace iree { -namespace hal { -namespace vmla { - -// static -StatusOr> VMLAExecutable::Load( - iree_vm_instance_t* instance, iree_vm_module_t* vmla_module, - ExecutableSpec spec, bool allow_aliasing_data) { - IREE_TRACE_SCOPE0("VMLAExecutable::Load"); - // Allocate the executable now. - // We do this here so that if we need to clone the data we are passing that - // to the VM loader instead of the data we may not have access to later. - auto executable = make_ref(spec, allow_aliasing_data); - IREE_RETURN_IF_ERROR(executable->Initialize(instance, vmla_module)); - return executable; -} - -VMLAExecutable::VMLAExecutable(ExecutableSpec spec, bool allow_aliasing_data) - : spec_(spec) { - if (!allow_aliasing_data) { - // Clone data. - cloned_executable_data_ = {spec.executable_data.begin(), - spec.executable_data.end()}; - spec_.executable_data = absl::MakeConstSpan(cloned_executable_data_); - } -} - -VMLAExecutable::~VMLAExecutable() { - IREE_TRACE_SCOPE0("VMLAExecutable::dtor"); - iree_vm_context_release(context_); - context_ = nullptr; -} - -Status VMLAExecutable::Initialize(iree_vm_instance_t* instance, - iree_vm_module_t* vmla_module) { - IREE_TRACE_SCOPE0("VMLAExecutable::Initialize"); - - // Verify and fetch the executable flatbuffer wrapper. - iree_const_byte_span_t executable_data = iree_make_const_byte_span( - spec_.executable_data.data(), spec_.executable_data.size()); - IREE_RETURN_IF_ERROR( - iree_hal_vmla_executable_flatbuffer_verify(executable_data)); - iree_VMLAExecutableDef_table_t executable_def = - iree_VMLAExecutableDef_as_root(executable_data.data); - - // Load bytecode module from the executable spec. - flatbuffers_uint8_vec_t bytecode_module_vec = - iree_VMLAExecutableDef_bytecode_module_get(executable_def); - iree_const_byte_span_t bytecode_module_data = iree_make_const_byte_span( - bytecode_module_vec, flatbuffers_uint8_vec_len(bytecode_module_vec)); - iree_vm_module_t* bytecode_module = nullptr; - IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create( - bytecode_module_data, iree_allocator_null(), iree_allocator_system(), - &bytecode_module)) - << "Failed to load executable bytecode module"; - - entry_functions_.resize( - iree_vm_module_signature(bytecode_module).export_function_count); - for (size_t i = 0; i < entry_functions_.size(); ++i) { - IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal( - bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, - &entry_functions_[i], nullptr)); - } - - // Create context and initialize shared state. Note that each executable here - // has its own context (and thus its own vmla.interface instance). 
- std::array modules = {vmla_module, bytecode_module}; - auto result = StatusBuilder(iree_vm_context_create_with_modules( - instance, modules.data(), modules.size(), - iree_allocator_system(), &context_), - IREE_LOC) - << "Failed resolving imports for executable module"; - iree_vm_module_release(bytecode_module); - - return std::move(result); -} - -struct VMLADispatchState : public HostExecutable::DispatchState { - VMLADispatchState() { interface_ref = Interface_retain_ref(&interface); } - ~VMLADispatchState() override { iree_vm_ref_release(&interface_ref); } - - iree_vm_function_t function; - Interface interface; - iree_vm_ref_t interface_ref; - iree_host_size_t input_list_size = 0; -}; - -StatusOr> -VMLAExecutable::PrepareDispatch(const DispatchParams& params) { - IREE_TRACE_SCOPE0("VMLAExecutable::PrepareDispatch"); - - if (params.entry_point >= entry_functions_.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid entry point ordinal " << params.entry_point; - } - - auto dispatch_state = make_ref(); - dispatch_state->function = entry_functions_[params.entry_point]; - dispatch_state->input_list_size = iree_vm_list_storage_size( - /*element_type=*/nullptr, /*interface*/ 1 + /*workgroup_xyz[3]*/ 3); - - auto* interface = &dispatch_state->interface; - IREE_RETURN_IF_ERROR(interface->SetConstants(params.push_constants->values)); - - for (size_t set_ordinal = 0; set_ordinal < params.set_bindings.size(); - ++set_ordinal) { - for (const auto& binding : params.set_bindings[set_ordinal]) { - // TODO(benvanik): plumb binding directly into VMLA to avoid this. - void* data = static_cast(binding.buffer->allocated_buffer()) - ->mutable_data(); - data = reinterpret_cast(reinterpret_cast(data) + - binding.buffer->byte_offset() + - binding.offset); - IREE_ASSIGN_OR_RETURN( - auto buffer, Buffer::WrapMutable(data, binding.buffer->byte_length(), - iree_allocator_null())); - IREE_RETURN_IF_ERROR(interface->SetBinding(set_ordinal, binding.binding, - {std::move(buffer)})); - } - } - - return std::move(dispatch_state); -} - -Status VMLAExecutable::DispatchTile(DispatchState* state, - std::array workgroup_xyz) { - auto* dispatch_state = static_cast(state); - IREE_TRACE_SCOPE_DYNAMIC( - iree_vm_function_name(&dispatch_state->function).data); - - auto* input_list_storage = alloca(dispatch_state->input_list_size); - iree_vm_list_t* input_list = nullptr; - IREE_RETURN_IF_ERROR(iree_vm_list_initialize( - iree_make_byte_span(input_list_storage, dispatch_state->input_list_size), - /*element_type=*/nullptr, - /*interface*/ 1 + /*workgroup_xyz[3]*/ 3, &input_list)); - iree_vm_list_push_ref_retain(input_list, &dispatch_state->interface_ref); - for (size_t i = 0; i < workgroup_xyz.size(); ++i) { - iree_vm_value_t value = iree_vm_value_make_i32(workgroup_xyz[i]); - iree_vm_list_push_value(input_list, &value); - } - - // TODO(benvanik): switch to direct calling to avoid the invoke overhead. 
- auto status = - Status(iree_vm_invoke(context(), dispatch_state->function, - /*policy=*/nullptr, input_list, - /*outputs=*/nullptr, iree_allocator_system())); - - iree_vm_list_deinitialize(input_list); - - return status; -} - -} // namespace vmla -} // namespace hal -} // namespace iree diff --git a/iree/hal/vmla/vmla_executable.h b/iree/hal/vmla/vmla_executable.h deleted file mode 100644 index 7eb3c27472b0f..0000000000000 --- a/iree/hal/vmla/vmla_executable.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VMLA_VMLA_EXECUTABLE_H_ -#define IREE_HAL_VMLA_VMLA_EXECUTABLE_H_ - -#include - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/hal/executable_spec.h" -#include "iree/hal/host/host_executable.h" -#include "iree/vm/api.h" - -namespace iree { -namespace hal { -namespace vmla { - -class Interface; - -class VMLAExecutable final : public HostExecutable { - public: - static StatusOr> Load(iree_vm_instance_t* instance, - iree_vm_module_t* vmla_module, - ExecutableSpec spec, - bool allow_aliasing_data); - - VMLAExecutable(ExecutableSpec spec, bool allow_aliasing_data); - ~VMLAExecutable() override; - - bool supports_debugging() const override { return false; } - - // Reference to the bytecode blob contents. - absl::Span executable_data() const { - return spec_.executable_data; - } - - // VM context containing the loaded executable module. - iree_vm_context_t* context() const { return context_; } - - // Entry point functions in export order. 
- absl::Span entry_functions() const { - return absl::MakeConstSpan(entry_functions_); - } - - StatusOr> PrepareDispatch( - const DispatchParams& params) override; - Status DispatchTile(DispatchState* state, - std::array workgroup_xyz) override; - - private: - Status Initialize(iree_vm_instance_t* instance, - iree_vm_module_t* vmla_module); - - ExecutableSpec spec_; - std::vector cloned_executable_data_; - - iree_vm_context_t* context_ = nullptr; - absl::InlinedVector entry_functions_; -}; - -} // namespace vmla -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VMLA_VMLA_EXECUTABLE_H_ diff --git a/iree/hal/vulkan/BUILD b/iree/hal/vulkan/BUILD index 582ae18d2e011..901ca887894d1 100644 --- a/iree/hal/vulkan/BUILD +++ b/iree/hal/vulkan/BUILD @@ -31,151 +31,119 @@ endif() ) cc_library( - name = "api", - srcs = ["api.cc"], - hdrs = ["api.h"], - visibility = ["//visibility:public"], - deps = [ - ":utils", - ":vulkan", - "//iree/base:api", - "//iree/base:tracing", - "//iree/hal:api", - ], -) - -cc_library( - name = "utils", + name = "vulkan", srcs = [ + "api.cc", + "command_queue.h", "debug_reporter.cc", - "dynamic_symbols.cc", - "extensibility_util.cc", - "renderdoc_capture_manager.cc", - "status_util.cc", - "timepoint_util.cc", - ], - hdrs = [ "debug_reporter.h", - "dynamic_symbol_tables.h", - "dynamic_symbols.h", + "descriptor_pool_cache.cc", + "descriptor_pool_cache.h", + "descriptor_set_arena.cc", + "descriptor_set_arena.h", + "direct_command_buffer.cc", + "direct_command_buffer.h", + "direct_command_queue.cc", + "direct_command_queue.h", + "emulated_semaphore.cc", + "emulated_semaphore.h", + "extensibility_util.cc", "extensibility_util.h", "handle_util.h", - "renderdoc_capture_manager.h", + "internal_vk_mem_alloc.cc", + "internal_vk_mem_alloc.h", + "native_descriptor_set.cc", + "native_descriptor_set.h", + "native_descriptor_set_layout.cc", + "native_descriptor_set_layout.h", + "native_event.cc", + "native_event.h", + "native_executable.cc", + "native_executable.h", + "native_executable_layout.cc", + "native_executable_layout.h", + "native_semaphore.cc", + "native_semaphore.h", + "nop_executable_cache.cc", + "nop_executable_cache.h", + "serializing_command_queue.cc", + "serializing_command_queue.h", + "status_util.c", "status_util.h", + "timepoint_util.cc", "timepoint_util.h", + "vma_allocator.cc", + "vma_allocator.h", + "vma_buffer.cc", + "vma_buffer.h", + "vulkan_device.cc", + "vulkan_driver.cc", "vulkan_headers.h", ], + hdrs = [ + # TODO(benvanik): hide all but api.h. 
+ "api.h", + "vulkan_device.h", + "vulkan_driver.h", + ], + visibility = ["//visibility:public"], deps = [ + ":dynamic_symbols", + "//iree/base:api", + "//iree/base:arena", "//iree/base:core_headers", - "//iree/base:dynamic_library", + "//iree/base:flatcc", "//iree/base:intrusive_list", "//iree/base:logging", "//iree/base:ref_ptr", "//iree/base:status", - "//iree/base:time", + "//iree/base:synchronization", "//iree/base:tracing", - "//iree/hal", + "//iree/hal:api", + "//iree/schemas:spirv_executable_def_c_fbs", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@iree_vulkan_headers//:vulkan_headers", - "@renderdoc_api//:renderdoc_app", - ], -) - -cc_test( - name = "dynamic_symbols_test", - srcs = ["dynamic_symbols_test.cc"], - tags = ["driver=vulkan"], - deps = [ - ":utils", - "//iree/testing:gtest", - "//iree/testing:gtest_main", - ], -) - -cc_library( - name = "vma_allocator", - srcs = [ - "internal_vk_mem_alloc.cc", - "internal_vk_mem_alloc.h", - "vma_allocator.cc", - "vma_buffer.cc", - ], - hdrs = [ - "vma_allocator.h", - "vma_buffer.h", - ], - deps = [ - ":utils", - "//iree/base:core_headers", - "//iree/base:logging", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", "@vulkan_memory_allocator//:impl_header_only", ], ) cc_library( - name = "vulkan", + name = "dynamic_symbols", srcs = [ - "descriptor_pool_cache.cc", - "descriptor_set_arena.cc", - "direct_command_buffer.cc", - "direct_command_queue.cc", - "emulated_timeline_semaphore.cc", - "native_descriptor_set.cc", - "native_event.cc", - "native_timeline_semaphore.cc", - "pipeline_cache.cc", - "pipeline_executable.cc", - "pipeline_executable_layout.cc", - "serializing_command_queue.cc", - "vulkan_device.cc", - "vulkan_driver.cc", + "dynamic_symbols.cc", + "vulkan_headers.h", ], hdrs = [ - "descriptor_pool_cache.h", - "descriptor_set_arena.h", - "direct_command_buffer.h", - "direct_command_queue.h", - "emulated_timeline_semaphore.h", - "native_descriptor_set.h", - "native_event.h", - "native_timeline_semaphore.h", - "pipeline_cache.h", - "pipeline_executable.h", - "pipeline_executable_layout.h", - "serializing_command_queue.h", - "vulkan_device.h", - "vulkan_driver.h", + "dynamic_symbol_tables.h", + "dynamic_symbols.h", ], deps = [ - ":utils", - ":vma_allocator", - "//iree/base:api", - "//iree/base:arena", "//iree/base:core_headers", - "//iree/base:flatcc", - "//iree/base:intrusive_list", + "//iree/base:dynamic_library", "//iree/base:ref_ptr", "//iree/base:status", - "//iree/base:time", "//iree/base:tracing", - "//iree/hal", - "//iree/hal:command_buffer_validation", - "//iree/schemas:spirv_executable_def_c_fbs", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@iree_vulkan_headers//:vulkan_headers", + ], +) + +cc_test( + name = "dynamic_symbols_test", + srcs = ["dynamic_symbols_test.cc"], + tags = ["driver=vulkan"], + deps = [ + ":dynamic_symbols", + "//iree/testing:gtest", + "//iree/testing:gtest_main", ], ) diff --git a/iree/hal/vulkan/CMakeLists.txt 
b/iree/hal/vulkan/CMakeLists.txt index 3bdeeacefd1ca..016e11477ad61 100644 --- a/iree/hal/vulkan/CMakeLists.txt +++ b/iree/hal/vulkan/CMakeLists.txt @@ -20,57 +20,104 @@ iree_add_all_subdirs() iree_cc_library( NAME - api + vulkan HDRS "api.h" + "vulkan_device.h" + "vulkan_driver.h" SRCS "api.cc" + "command_queue.h" + "debug_reporter.cc" + "debug_reporter.h" + "descriptor_pool_cache.cc" + "descriptor_pool_cache.h" + "descriptor_set_arena.cc" + "descriptor_set_arena.h" + "direct_command_buffer.cc" + "direct_command_buffer.h" + "direct_command_queue.cc" + "direct_command_queue.h" + "emulated_semaphore.cc" + "emulated_semaphore.h" + "extensibility_util.cc" + "extensibility_util.h" + "handle_util.h" + "internal_vk_mem_alloc.cc" + "internal_vk_mem_alloc.h" + "native_descriptor_set.cc" + "native_descriptor_set.h" + "native_descriptor_set_layout.cc" + "native_descriptor_set_layout.h" + "native_event.cc" + "native_event.h" + "native_executable.cc" + "native_executable.h" + "native_executable_layout.cc" + "native_executable_layout.h" + "native_semaphore.cc" + "native_semaphore.h" + "nop_executable_cache.cc" + "nop_executable_cache.h" + "serializing_command_queue.cc" + "serializing_command_queue.h" + "status_util.c" + "status_util.h" + "timepoint_util.cc" + "timepoint_util.h" + "vma_allocator.cc" + "vma_allocator.h" + "vma_buffer.cc" + "vma_buffer.h" + "vulkan_device.cc" + "vulkan_driver.cc" + "vulkan_headers.h" DEPS - ::utils - ::vulkan + ::dynamic_symbols + Vulkan::Headers + absl::core_headers + absl::flat_hash_map + absl::inlined_vector + absl::memory + absl::span + absl::strings + absl::synchronization iree::base::api + iree::base::arena + iree::base::core_headers + iree::base::flatcc + iree::base::intrusive_list + iree::base::logging + iree::base::ref_ptr + iree::base::status + iree::base::synchronization iree::base::tracing iree::hal::api + iree::schemas::spirv_executable_def_c_fbs + vulkan_memory_allocator PUBLIC ) iree_cc_library( NAME - utils + dynamic_symbols HDRS - "debug_reporter.h" "dynamic_symbol_tables.h" "dynamic_symbols.h" - "extensibility_util.h" - "handle_util.h" - "renderdoc_capture_manager.h" - "status_util.h" - "timepoint_util.h" - "vulkan_headers.h" SRCS - "debug_reporter.cc" "dynamic_symbols.cc" - "extensibility_util.cc" - "renderdoc_capture_manager.cc" - "status_util.cc" - "timepoint_util.cc" + "vulkan_headers.h" DEPS Vulkan::Headers absl::core_headers absl::memory absl::span absl::strings - absl::synchronization iree::base::core_headers iree::base::dynamic_library - iree::base::intrusive_list - iree::base::logging iree::base::ref_ptr iree::base::status - iree::base::time iree::base::tracing - iree::hal - renderdoc_api::renderdoc_app PUBLIC ) @@ -80,91 +127,9 @@ iree_cc_test( SRCS "dynamic_symbols_test.cc" DEPS - ::utils + ::dynamic_symbols iree::testing::gtest iree::testing::gtest_main LABELS "driver=vulkan" ) - -iree_cc_library( - NAME - vma_allocator - HDRS - "vma_allocator.h" - "vma_buffer.h" - SRCS - "internal_vk_mem_alloc.cc" - "internal_vk_mem_alloc.h" - "vma_allocator.cc" - "vma_buffer.cc" - DEPS - ::utils - absl::flat_hash_map - absl::memory - absl::synchronization - iree::base::core_headers - iree::base::logging - iree::base::status - iree::base::tracing - iree::hal - vulkan_memory_allocator - PUBLIC -) - -iree_cc_library( - NAME - vulkan - HDRS - "descriptor_pool_cache.h" - "descriptor_set_arena.h" - "direct_command_buffer.h" - "direct_command_queue.h" - "emulated_timeline_semaphore.h" - "native_descriptor_set.h" - "native_event.h" - "native_timeline_semaphore.h" - 
"pipeline_cache.h" - "pipeline_executable.h" - "pipeline_executable_layout.h" - "serializing_command_queue.h" - "vulkan_device.h" - "vulkan_driver.h" - SRCS - "descriptor_pool_cache.cc" - "descriptor_set_arena.cc" - "direct_command_buffer.cc" - "direct_command_queue.cc" - "emulated_timeline_semaphore.cc" - "native_descriptor_set.cc" - "native_event.cc" - "native_timeline_semaphore.cc" - "pipeline_cache.cc" - "pipeline_executable.cc" - "pipeline_executable_layout.cc" - "serializing_command_queue.cc" - "vulkan_device.cc" - "vulkan_driver.cc" - DEPS - ::utils - ::vma_allocator - absl::core_headers - absl::inlined_vector - absl::memory - absl::span - absl::strings - absl::synchronization - iree::base::api - iree::base::arena - iree::base::core_headers - iree::base::flatcc - iree::base::intrusive_list - iree::base::ref_ptr - iree::base::status - iree::base::time - iree::base::tracing - iree::hal - iree::hal::command_buffer_validation - iree::schemas::spirv_executable_def_c_fbs - PUBLIC -) diff --git a/iree/hal/vulkan/api.cc b/iree/hal/vulkan/api.cc index d664577a82bf3..5653ef9466723 100644 --- a/iree/hal/vulkan/api.cc +++ b/iree/hal/vulkan/api.cc @@ -21,16 +21,17 @@ #include "iree/hal/vulkan/vulkan_device.h" #include "iree/hal/vulkan/vulkan_driver.h" -namespace iree { -namespace hal { -namespace vulkan { +using namespace iree::hal::vulkan; + +// TODO(benvanik): move these into the appropriate files and delete this .cc. //===----------------------------------------------------------------------===// // iree::hal::vulkan::DynamicSymbols //===----------------------------------------------------------------------===// IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_syms_create( - void* vkGetInstanceProcAddr_fn, iree_hal_vulkan_syms_t** out_syms) { + void* vkGetInstanceProcAddr_fn, iree_allocator_t host_allocator, + iree_hal_vulkan_syms_t** out_syms) { IREE_TRACE_SCOPE0("iree_hal_vulkan_syms_create"); IREE_ASSERT_ARGUMENT(out_syms); *out_syms = nullptr; @@ -53,7 +54,7 @@ IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_syms_create( IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_syms_create_from_system_loader( - iree_hal_vulkan_syms_t** out_syms) { + iree_allocator_t host_allocator, iree_hal_vulkan_syms_t** out_syms) { IREE_TRACE_SCOPE0("iree_hal_vulkan_syms_create_from_system_loader"); IREE_ASSERT_ARGUMENT(out_syms); *out_syms = nullptr; @@ -63,248 +64,20 @@ iree_hal_vulkan_syms_create_from_system_loader( return iree_ok_status(); } -IREE_API_EXPORT iree_status_t -iree_hal_vulkan_syms_release(iree_hal_vulkan_syms_t* syms) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_syms_release"); +IREE_API_EXPORT void IREE_API_CALL +iree_hal_vulkan_syms_retain(iree_hal_vulkan_syms_t* syms) { IREE_ASSERT_ARGUMENT(syms); auto* handle = reinterpret_cast(syms); - handle->ReleaseReference(); - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::vulkan Extensibility Util -//===----------------------------------------------------------------------===// - -namespace { - -ExtensibilitySpec GetInstanceExtensibilitySpec( - const iree_hal_vulkan_features_t& features) { - ExtensibilitySpec spec; - - // Multiple extensions depend on VK_KHR_get_physical_device_properties2. - // This extension was deprecated in Vulkan 1.1 as its functionality was - // promoted to core, so we list it as optional even though we require it. 
- spec.optional_extensions.push_back( - VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); - - if (features & IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS) { - spec.optional_layers.push_back("VK_LAYER_KHRONOS_standard_validation"); - } - - if (features & IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS) { - spec.optional_extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME); - } - - // Polyfill layer - enable if present. - spec.optional_layers.push_back("VK_LAYER_KHRONOS_timeline_semaphore"); - - return spec; -} - -ExtensibilitySpec GetDeviceExtensibilitySpec( - const iree_hal_vulkan_features_t& features) { - ExtensibilitySpec spec; - - // REQUIRED: these are required extensions that must be present for IREE to - // work (such as those relied upon by SPIR-V kernels, etc). - spec.required_extensions.push_back( - VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME); - // Timeline semaphore support is required. - spec.required_extensions.push_back(VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME); - - if (features & IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS) { - spec.optional_extensions.push_back(VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME); - } - - return spec; -} - -} // namespace - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_get_extensions( - iree_hal_vulkan_extensibility_set_t extensibility_set, - iree_hal_vulkan_features_t features, iree_host_size_t extensions_capacity, - const char** out_extensions, iree_host_size_t* out_extensions_count) { - IREE_ASSERT_ARGUMENT(out_extensions_count); - *out_extensions_count = 0; - - bool is_instance = extensibility_set & IREE_HAL_VULKAN_INSTANCE_BIT; - bool is_required = extensibility_set & IREE_HAL_VULKAN_REQUIRED_BIT; - - ExtensibilitySpec spec = is_instance ? GetInstanceExtensibilitySpec(features) - : GetDeviceExtensibilitySpec(features); - *out_extensions_count = is_required ? spec.required_extensions.size() - : spec.optional_extensions.size(); - - // Return early if only querying number of extensions in this configuration. - if (!out_extensions) { - return iree_ok_status(); - } - - if (extensions_capacity < *out_extensions_count) { - // Not an error; just a size query. - return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); - } - - const std::vector& extensions = - is_required ? spec.required_extensions : spec.optional_extensions; - for (int i = 0; i < extensions.size(); ++i) { - out_extensions[i] = extensions[i]; - } - - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_get_layers( - iree_hal_vulkan_extensibility_set_t extensibility_set, - iree_hal_vulkan_features_t features, iree_host_size_t layers_capacity, - const char** out_layers, iree_host_size_t* out_layers_count) { - IREE_ASSERT_ARGUMENT(out_layers_count); - *out_layers_count = 0; - - // Device layers are deprecated and unsupported here. - if (!(extensibility_set & IREE_HAL_VULKAN_INSTANCE_BIT)) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "device layers are deprecated in Vulkan"); - } - - bool is_required = extensibility_set & IREE_HAL_VULKAN_REQUIRED_BIT; - - ExtensibilitySpec spec = GetInstanceExtensibilitySpec(features); - *out_layers_count = - is_required ? spec.required_layers.size() : spec.optional_layers.size(); - - // Return early if only querying number of layers in this configuration. - if (!out_layers) { - return iree_ok_status(); - } - - if (layers_capacity < *out_layers_count) { - // Not an error; just a size query. - return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); - } - - const std::vector& layers = - is_required ? 
spec.required_layers : spec.optional_layers; - for (int i = 0; i < layers.size(); ++i) { - out_layers[i] = layers[i]; + if (handle) { + handle->AddReference(); } - - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::vulkan::VulkanDriver -//===----------------------------------------------------------------------===// - -namespace { - -VulkanDriver::Options ConvertDriverOptions( - iree_hal_vulkan_driver_options_t options) { - VulkanDriver::Options driver_options; - driver_options.api_version = options.api_version; - driver_options.instance_extensibility = - GetInstanceExtensibilitySpec(options.features); - driver_options.device_options.extensibility_spec = - GetDeviceExtensibilitySpec(options.features); - return driver_options; -} - -} // namespace - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_create( - iree_hal_vulkan_driver_options_t options, iree_hal_vulkan_syms_t* syms, - iree_hal_driver_t** out_driver) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_driver_create"); - IREE_ASSERT_ARGUMENT(syms); - IREE_ASSERT_ARGUMENT(out_driver); - *out_driver = nullptr; - - IREE_ASSIGN_OR_RETURN( - auto driver, - VulkanDriver::Create(ConvertDriverOptions(options), - add_ref(reinterpret_cast(syms)))); - *out_driver = reinterpret_cast(driver.release()); - return iree_ok_status(); } -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_driver_create_using_instance( - iree_hal_vulkan_driver_options_t options, iree_hal_vulkan_syms_t* syms, - VkInstance instance, iree_hal_driver_t** out_driver) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_driver_create_using_instance"); +IREE_API_EXPORT void IREE_API_CALL +iree_hal_vulkan_syms_release(iree_hal_vulkan_syms_t* syms) { IREE_ASSERT_ARGUMENT(syms); - IREE_ASSERT_ARGUMENT(instance); - IREE_ASSERT_ARGUMENT(out_driver); - *out_driver = nullptr; - - IREE_ASSIGN_OR_RETURN( - auto driver, - VulkanDriver::CreateUsingInstance( - ConvertDriverOptions(options), - add_ref(reinterpret_cast(syms)), instance)); - *out_driver = reinterpret_cast(driver.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_driver_create_default_device(iree_hal_driver_t* driver, - iree_hal_device_t** out_device) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_driver_create_default_device"); - IREE_ASSERT_ARGUMENT(driver); - IREE_ASSERT_ARGUMENT(out_device); - *out_device = nullptr; - - auto* handle = reinterpret_cast(driver); - - IREE_LOG(INFO) << "Enumerating available Vulkan devices..."; - IREE_ASSIGN_OR_RETURN(auto available_devices, - handle->EnumerateAvailableDevices()); - for (const auto& device_info : available_devices) { - IREE_LOG(INFO) << " Device: " << device_info.name(); + auto* handle = reinterpret_cast(syms); + if (handle) { + handle->ReleaseReference(); } - IREE_LOG(INFO) << "Creating default device..."; - IREE_ASSIGN_OR_RETURN(auto device, handle->CreateDefaultDevice()); - IREE_LOG(INFO) << "Successfully created device '" << device->info().name() - << "'"; - - *out_device = reinterpret_cast(device.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_wrap_device( - iree_hal_driver_t* driver, VkPhysicalDevice physical_device, - VkDevice logical_device, iree_hal_vulkan_queue_set_t compute_queue_set, - iree_hal_vulkan_queue_set_t transfer_queue_set, - iree_hal_device_t** out_device) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_driver_create_device"); - IREE_ASSERT_ARGUMENT(driver); - 
IREE_ASSERT_ARGUMENT(physical_device); - IREE_ASSERT_ARGUMENT(logical_device); - IREE_ASSERT_ARGUMENT(out_device); - *out_device = nullptr; - - auto* handle = reinterpret_cast(driver); - - IREE_LOG(INFO) << "Creating VulkanDevice..."; - QueueSet compute_qs; - compute_qs.queue_family_index = compute_queue_set.queue_family_index; - compute_qs.queue_indices = compute_queue_set.queue_indices; - QueueSet transfer_qs; - transfer_qs.queue_family_index = transfer_queue_set.queue_family_index; - transfer_qs.queue_indices = transfer_queue_set.queue_indices; - IREE_ASSIGN_OR_RETURN(auto device, - handle->WrapDevice(physical_device, logical_device, - compute_qs, transfer_qs)); - IREE_LOG(INFO) << "Successfully created device '" << device->info().name() - << "'"; - - *out_device = reinterpret_cast(device.release()); - - return iree_ok_status(); } - -} // namespace vulkan -} // namespace hal -} // namespace iree diff --git a/iree/hal/vulkan/api.h b/iree/hal/vulkan/api.h index 56d1346965c26..3f5912ae295e4 100644 --- a/iree/hal/vulkan/api.h +++ b/iree/hal/vulkan/api.h @@ -29,63 +29,83 @@ extern "C" { #endif // __cplusplus //===----------------------------------------------------------------------===// -// Types and Enums +// iree_hal_vulkan_device_t extensibility util //===----------------------------------------------------------------------===// -// Describes the type of a set of Vulkan extensions. -typedef enum { - IREE_HAL_VULKAN_REQUIRED_BIT = 1 << 0, - IREE_HAL_VULKAN_INSTANCE_BIT = 1 << 1, - - // A set of required instance extension names. - IREE_HAL_VULKAN_INSTANCE_REQUIRED = - IREE_HAL_VULKAN_INSTANCE_BIT | IREE_HAL_VULKAN_REQUIRED_BIT, - // A set of optional instance extension names. - IREE_HAL_VULKAN_INSTANCE_OPTIONAL = IREE_HAL_VULKAN_INSTANCE_BIT, - // A set of required device extension names. - IREE_HAL_VULKAN_DEVICE_REQUIRED = IREE_HAL_VULKAN_REQUIRED_BIT, - // A set of optional device extension names. - IREE_HAL_VULKAN_DEVICE_OPTIONAL = 0, -} iree_hal_vulkan_extensibility_set_t; - +// TODO(benvanik): replace with feature list (easier to version). // Bitfield that defines sets of Vulkan features. -typedef enum { - // Use VK_LAYER_KHRONOS_standard_validation. - IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS = 1 << 0, +enum iree_hal_vulkan_feature_e { + // Use VK_LAYER_KHRONOS_standard_validation to validate Vulkan API usage. + // Has a significant performance penalty and is *not* a security mechanism. + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS = 1 << 0, // Use VK_EXT_debug_utils, record markers, and log errors. - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS = 1 << 1, + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS = 1 << 1, +}; +typedef uint64_t iree_hal_vulkan_features_t; - // Use vkCmdPushDescriptorSetKHR. - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS = 1 << 2, -} iree_hal_vulkan_features_t; +// Describes the type of a set of Vulkan extensions. +enum iree_hal_vulkan_extensibility_set_e { + // A set of required instance layer names. These must all be enabled on + // the VkInstance for IREE to function. + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED = 0, -// Vulkan driver creation options. -typedef struct { - // Vulkan version that will be requested, e.g. `VK_API_VERSION_1_0`. - // Driver creation will fail if the required version is not available. - uint32_t api_version; + // A set of optional instance layer names. If omitted fallbacks may be + // used or debugging features may not be available. + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL = 1, - // Vulkan features to request. 
- iree_hal_vulkan_features_t features; -} iree_hal_vulkan_driver_options_t; + // A set of required instance extension names. These must all be enabled on + // the VkInstance for IREE to function. + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED = 2, -// A set of queues within a specific queue family on a VkDevice. -typedef struct { - // The index of a particular queue family on a VkPhysicalDevice, as described - // by vkGetPhysicalDeviceQueueFamilyProperties. - uint32_t queue_family_index; + // A set of optional instance extension names. If omitted fallbacks may be + // used or debugging features may not be available. + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL = 3, - // Bitfield of queue indices within the queue family at |queue_family_index|. - uint64_t queue_indices; -} iree_hal_vulkan_queue_set_t; + // A set of required device extension names. These must all be enabled on + // the VkDevice for IREE to function. + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED = 4, + + // A set of optional device extension names. If omitted fallbacks may be + // used or debugging features may not be available. + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL = 5, -typedef struct iree_hal_vulkan_syms iree_hal_vulkan_syms_t; + IREE_HAL_VULKAN_EXTENSIBILITY_SET_COUNT, +}; +typedef uint32_t iree_hal_vulkan_extensibility_set_t; + +// Queries the names of the Vulkan layers and extensions used for a given set of +// IREE |requested_features|. All devices used by IREE must have the required +// layers and extensions as defined by these sets. Optional layers and +// extensions will be used when needed and otherwise have fallbacks for when +// they are not available. +// +// Instance extensions should be enabled on VkInstances passed to +// |iree_hal_vulkan_driver_create_using_instance| and device extensions should +// be enabled on VkDevices passed to |iree_hal_vulkan_driver_wrap_device|. +// +// |string_capacity| defines the number of elements available in +// |out_string_values| and |out_string_count| will be set with the actual number +// of strings returned. If |string_capacity| is too small then +// IREE_STATUS_OUT_OF_RANGE will be returned with the required capacity in +// |out_string_count|. To only query the required capacity then +// |out_string_values| may be passed as NULL. +// +// The returned strings originate from the _EXTENSION_NAME Vulkan macros +// (such as 'VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME') and have a +// lifetime matching whatever module they are defined in. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_vulkan_query_extensibility_set( + iree_hal_vulkan_features_t requested_features, + iree_hal_vulkan_extensibility_set_t set, iree_host_size_t string_capacity, + const char** out_string_values, iree_host_size_t* out_string_count); //===----------------------------------------------------------------------===// -// iree::hal::vulkan::DynamicSymbols +// iree_hal_vulkan_syms_t //===----------------------------------------------------------------------===// +typedef struct iree_hal_vulkan_syms_s iree_hal_vulkan_syms_t; + // Loads Vulkan functions by invoking |vkGetInstanceProcAddr|. // // |vkGetInstanceProcAddr| can be obtained in whatever way suites the calling @@ -95,99 +115,58 @@ typedef struct iree_hal_vulkan_syms iree_hal_vulkan_syms_t; // // |out_syms| must be released by the caller. 
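// For example (a sketch only; error handling is elided and the system
// allocator is assumed):
//
//   iree_hal_vulkan_syms_t* syms = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_vulkan_syms_create(
//       (void*)vkGetInstanceProcAddr, iree_allocator_system(), &syms));
//   // ... use syms, sharing with iree_hal_vulkan_syms_retain() as needed ...
//   iree_hal_vulkan_syms_release(syms);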
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_syms_create( - void* vkGetInstanceProcAddr_fn, iree_hal_vulkan_syms_t** out_syms); + void* vkGetInstanceProcAddr_fn, iree_allocator_t host_allocator, + iree_hal_vulkan_syms_t** out_syms); // Loads Vulkan functions from the Vulkan loader. // This will look for a Vulkan loader on the system (like libvulkan.so) and // dlsym the functions from that. // -// |out_syms| must be released by the caller. +// |out_syms| must be released by the caller with iree_hal_vulkan_syms_release. IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_syms_create_from_system_loader( - iree_hal_vulkan_syms_t** out_syms); + iree_allocator_t host_allocator, iree_hal_vulkan_syms_t** out_syms); + +// Retains the given |syms| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_vulkan_syms_retain(iree_hal_vulkan_syms_t* syms); // Releases the given |syms| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL +IREE_API_EXPORT void IREE_API_CALL iree_hal_vulkan_syms_release(iree_hal_vulkan_syms_t* syms); //===----------------------------------------------------------------------===// -// iree::hal::vulkan Extensibility Util +// iree_hal_vulkan_device_t //===----------------------------------------------------------------------===// -// Gets the names of the Vulkan extensions used for a given set of |features|. -// -// Instance extensions should be enabled on VkInstances passed to -// |iree_hal_vulkan_driver_create_using_instance| and device extensions should -// be enabled on VkDevices passed to |iree_hal_vulkan_driver_wrap_device|. -// -// |extensions_capacity| defines the number of elements available in -// |out_extensions| and |out_extensions_count| will be set with the actual -// number of extensions returned. If |extensions_capacity| is too small -// IREE_STATUS_OUT_OF_RANGE will be returned with the required capacity in -// |out_extensions_count|. To only query the required capacity |out_extensions| -// may be passed as nullptr. -// -// Extension string lifetime is tied to the loader shared object or instance, -// depending on where they came from. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_get_extensions( - iree_hal_vulkan_extensibility_set_t extensibility_set, - iree_hal_vulkan_features_t features, iree_host_size_t extensions_capacity, - const char** out_extensions, iree_host_size_t* out_extensions_count); - -// Gets the names of the Vulkan layers used for a given set of |features|. -// -// Instance layers should be enabled on VkInstances passed to -// |iree_hal_vulkan_driver_create_using_instance|. Device layers are deprecated -// and unsupported here. -// -// |layers_capacity| defines the number of elements available in |out_layers| -// and |out_layers_count| will be set with the actual number of layers returned. -// If |layers_capacity| is too small IREE_STATUS_OUT_OF_RANGE will be returned -// with the required capacity in |out_layers_count|. To only query the required -// capacity |out_layers| may be passed as nullptr. -// -// Layer string lifetime is tied to the loader shared object or instance, -// depending on where they came from. 
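// With the replacement iree_hal_vulkan_query_extensibility_set() declared
// above, the capacity-query pattern that supersedes these removed calls looks
// roughly like this (a sketch; a fixed-size array is assumed for brevity):
//
//   const char* values[8] = {NULL};
//   iree_host_size_t count = 0;
//   iree_status_t status = iree_hal_vulkan_query_extensibility_set(
//       features, IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED,
//       IREE_ARRAYSIZE(values), values, &count);
//   // On IREE_STATUS_OUT_OF_RANGE `count` holds the required capacity and
//   // the query can be retried with more storage; passing NULL for the
//   // string values performs a pure size query per the contract above.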
-IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_get_layers( - iree_hal_vulkan_extensibility_set_t extensibility_set, - iree_hal_vulkan_features_t features, iree_host_size_t layers_capacity, - const char** out_layers, iree_host_size_t* out_layers_count); - -//===----------------------------------------------------------------------===// -// iree::hal::vulkan::VulkanDriver -//===----------------------------------------------------------------------===// +// A set of queues within a specific queue family on a VkDevice. +typedef struct { + // The index of a particular queue family on a VkPhysicalDevice, as described + // by vkGetPhysicalDeviceQueueFamilyProperties. + uint32_t queue_family_index; -// TODO(scotttodd): Allow applications to provide their own allocators here + // Bitfield of queue indices within the queue family at |queue_family_index|. + uint64_t queue_indices; +} iree_hal_vulkan_queue_set_t; -// Creates a Vulkan HAL driver that manages its own VkInstance. -// -// |out_driver| must be released by the caller (see |iree_hal_driver_release|). -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_create( - iree_hal_vulkan_driver_options_t options, iree_hal_vulkan_syms_t* syms, - iree_hal_driver_t** out_driver); +// TODO(benvanik): replace with flag list (easier to version). +enum iree_hal_vulkan_device_flag_e { + // Uses timeline semaphore emulation even if native support exists. + // May be removed in future versions when timeline semaphores can be assumed + // present on all platforms (looking at you, Android ಠ_ಠ). + IREE_HAL_VULKAN_DEVICE_FORCE_TIMELINE_SEMAPHORE_EMULATION = 1 << 0, +}; +typedef uint64_t iree_hal_vulkan_device_flags_t; -// Creates a Vulkan HAL driver that shares an existing VkInstance. -// -// |instance| is expected to have been created with all extensions returned by -// |iree_hal_vulkan_get_extensions| and IREE_HAL_VULKAN_INSTANCE_REQUIRED using -// |options| enabled. -// -// |instance| must remain valid for the life of |out_driver| and |out_driver| -// itself must be released by the caller (see |iree_hal_driver_release|). -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_driver_create_using_instance( - iree_hal_vulkan_driver_options_t options, iree_hal_vulkan_syms_t* syms, - VkInstance instance, iree_hal_driver_t** out_driver); +typedef struct { + // Flags controlling device behavior. + iree_hal_vulkan_device_flags_t flags; +} iree_hal_vulkan_device_options_t; -// Creates the default Vulkan HAL device using |driver| that manages its own -// VkPhysicalDevice/VkDevice. -// -// |out_device| must be released by the caller (see |iree_hal_device_release|). -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_driver_create_default_device(iree_hal_driver_t* driver, - iree_hal_device_t** out_device); +IREE_API_EXPORT void IREE_API_CALL iree_hal_vulkan_device_options_initialize( + iree_hal_vulkan_device_options_t* out_options); -// Creates a Vulkan HAL device using |driver| that wraps an existing VkDevice. +// Creates a Vulkan HAL device that wraps an existing VkDevice. // // HAL devices created in this way may share Vulkan resources and synchronize // within the same physical VkPhysicalDevice and logical VkDevice directly. @@ -197,6 +176,9 @@ iree_hal_vulkan_driver_create_default_device(iree_hal_driver_t* driver, // IREE_HAL_VULKAN_DEVICE_REQUIRED using the features provided during driver // creation. 
// +// |instance_syms| must have at least the instance-specific functions resolved +// and device symbols will be queried from |logical_device| as needed. +// // The device will schedule commands against the queues in // |compute_queue_set| and (if set) |transfer_queue_set|. // @@ -210,14 +192,74 @@ iree_hal_vulkan_driver_create_default_device(iree_hal_driver_t* driver, // |compute_queue_set|, if they are available. // Similarly, dedicated transfer queues (no compute or graphics) are preferred // within |transfer_queue_set|. -// The queues may be the same. +// The queue sets can be the same. // // |out_device| must be released by the caller (see |iree_hal_device_release|). -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_wrap_device( - iree_hal_driver_t* driver, VkPhysicalDevice physical_device, - VkDevice logical_device, iree_hal_vulkan_queue_set_t compute_queue_set, - iree_hal_vulkan_queue_set_t transfer_queue_set, - iree_hal_device_t** out_device); +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_wrap_device( + iree_string_view_t identifier, + const iree_hal_vulkan_device_options_t* options, + const iree_hal_vulkan_syms_t* instance_syms, VkInstance instance, + VkPhysicalDevice physical_device, VkDevice logical_device, + const iree_hal_vulkan_queue_set_t* compute_queue_set, + const iree_hal_vulkan_queue_set_t* transfer_queue_set, + iree_allocator_t host_allocator, iree_hal_device_t** out_device); + +//===----------------------------------------------------------------------===// +// iree_hal_vulkan_driver_t +//===----------------------------------------------------------------------===// + +// Vulkan driver creation options. +typedef struct { + // Vulkan version that will be requested, e.g. `VK_API_VERSION_1_0`. + // Driver creation will fail if the required version is not available. + uint32_t api_version; + + // IREE features used to configure the VkInstance and VkDevices created using + // it. These are used to populate the active Vulkan layers and extensions when + // the instance and its devices are created. + iree_hal_vulkan_features_t requested_features; + + // TODO(benvanik): remove this single setting - it would be nice instead to + // pass a list to force device enumeration/matrix expansion or omit entirely + // to have auto-discovered options based on capabilities. Right now this + // forces all devices - even if from different vendors - to have the same + // options. + // Options to use for all devices created by the driver. + iree_hal_vulkan_device_options_t device_options; + + // TODO(benvanik): change to something more canonically vulkan (like + // VkPhysicalDeviceProperties::deviceID). + // Index of the default Vulkan device to use within the list of available + // devices. Devices are discovered via vkEnumeratePhysicalDevices then + // considered "available" if compatible with the |requested_features|. + int default_device_index; +} iree_hal_vulkan_driver_options_t; + +IREE_API_EXPORT void IREE_API_CALL iree_hal_vulkan_driver_options_initialize( + iree_hal_vulkan_driver_options_t* out_options); + +// Creates a Vulkan HAL driver that manages its own VkInstance. +// +// |out_driver| must be released by the caller (see |iree_hal_driver_release|). 
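// For example (a sketch; `syms` comes from one of the iree_hal_vulkan_syms_*
// creation calls above and error handling is elided):
//
//   iree_hal_vulkan_driver_options_t options;
//   iree_hal_vulkan_driver_options_initialize(&options);
//   iree_hal_driver_t* driver = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_create(
//       iree_make_cstring_view("vulkan"), &options, syms,
//       iree_allocator_system(), &driver));
//   // ... create devices from the driver ...
//   iree_hal_driver_release(driver);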
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_create( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + iree_hal_vulkan_syms_t* syms, iree_allocator_t host_allocator, + iree_hal_driver_t** out_driver); + +// Creates a Vulkan HAL driver that shares an existing VkInstance. +// +// |instance| is expected to have been created with all extensions returned by +// the instance-specific |iree_hal_vulkan_query_extensibility_set| queries. +// +// |instance| must remain valid for the life of |out_driver| and |out_driver| +// itself must be released by the caller (see |iree_hal_driver_release|). +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_vulkan_driver_create_using_instance( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + iree_hal_vulkan_syms_t* instance_syms, VkInstance instance, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver); #ifdef __cplusplus } // extern "C" diff --git a/iree/hal/vulkan/command_queue.h b/iree/hal/vulkan/command_queue.h new file mode 100644 index 0000000000000..47b3212b3457c --- /dev/null +++ b/iree/hal/vulkan/command_queue.h @@ -0,0 +1,77 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_COMMAND_QUEUE_H_ +#define IREE_HAL_VULKAN_COMMAND_QUEUE_H_ + +#include + +#include "iree/base/arena.h" +#include "iree/base/status.h" +#include "iree/base/synchronization.h" +#include "iree/hal/api.h" +#include "iree/hal/vulkan/dynamic_symbols.h" +#include "iree/hal/vulkan/handle_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class CommandQueue { + public: + virtual ~CommandQueue() { + IREE_TRACE_SCOPE0("CommandQueue::dtor"); + iree_slim_mutex_lock(&queue_mutex_); + syms()->vkQueueWaitIdle(queue_); + iree_slim_mutex_unlock(&queue_mutex_); + iree_slim_mutex_deinitialize(&queue_mutex_); + } + + const ref_ptr& syms() const { + return logical_device_->syms(); + } + + bool can_dispatch() const { + return iree_all_bits_set(supported_categories_, + IREE_HAL_COMMAND_CATEGORY_DISPATCH); + } + virtual iree_status_t Submit(iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches) = 0; + + virtual iree_status_t WaitIdle(iree_time_t deadline_ns) = 0; + + protected: + CommandQueue(VkDeviceHandle* logical_device, std::string name, + iree_hal_command_category_t supported_categories, VkQueue queue) + : logical_device_(logical_device), + name_(std::move(name)), + supported_categories_(supported_categories), + queue_(queue) { + iree_slim_mutex_initialize(&queue_mutex_); + } + + VkDeviceHandle* logical_device_; + const std::string name_; + const iree_hal_command_category_t supported_categories_; + + // VkQueue needs to be externally synchronized. 
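// Subclasses therefore bracket every raw queue access the same way the
// destructor above does, e.g. (a sketch; submit_count/submit_infos/fence are
// whatever the concrete queue builds for the batch):
//   iree_slim_mutex_lock(&queue_mutex_);
//   VkResult result = syms()->vkQueueSubmit(queue_, submit_count,
//                                           submit_infos, fence);
//   iree_slim_mutex_unlock(&queue_mutex_);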
+ iree_slim_mutex_t queue_mutex_; + VkQueue queue_ IREE_GUARDED_BY(queue_mutex_); +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_COMMAND_QUEUE_H_ diff --git a/iree/hal/vulkan/debug_reporter.cc b/iree/hal/vulkan/debug_reporter.cc index 62f0a16794411..c600030070c89 100644 --- a/iree/hal/vulkan/debug_reporter.cc +++ b/iree/hal/vulkan/debug_reporter.cc @@ -17,21 +17,23 @@ #include "iree/base/tracing.h" #include "iree/hal/vulkan/status_util.h" -namespace iree { -namespace hal { -namespace vulkan { - -namespace { +struct iree_hal_vulkan_debug_reporter_s { + iree_allocator_t host_allocator; + VkInstance instance; + iree::hal::vulkan::DynamicSymbols* syms; + const VkAllocationCallbacks* allocation_callbacks; + VkDebugUtilsMessengerEXT messenger; +}; // NOTE: |user_data| may be nullptr if we are being called during instance // creation. Otherwise it is a pointer to the DebugReporter instance. - +// // NOTE: this callback must be thread safe and must be careful not to reach too // far outside of the call - it is called in-context from arbitrary threads with // some amount of Vulkan state on the stack. Assume that creating or deleting // Vulkan objects, issuing most Vulkan commands, etc are off-limits. - -VKAPI_ATTR VkBool32 VKAPI_CALL DebugUtilsMessageCallback( +static VKAPI_ATTR VkBool32 VKAPI_CALL +iree_hal_vulkan_debug_utils_message_callback( VkDebugUtilsMessageSeverityFlagBitsEXT message_severity, VkDebugUtilsMessageTypeFlagsEXT message_type, const VkDebugUtilsMessengerCallbackDataEXT* callback_data, @@ -41,122 +43,89 @@ VKAPI_ATTR VkBool32 VKAPI_CALL DebugUtilsMessageCallback( } else { IREE_VLOG(1) << callback_data->pMessage; } - - return VK_FALSE; // VK_TRUE is reserved for future use. -} - -VKAPI_ATTR VkBool32 VKAPI_CALL DebugReportCallback( - VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT object_type, - uint64_t object, size_t location, int32_t message_code, - const char* layer_prefix, const char* message, void* user_data) { - IREE_VLOG(1) << message; - return VK_FALSE; // VK_TRUE is reserved for future use. } -} // namespace - -// static -void DebugReporter::PopulateStaticCreateInfo( - VkDebugUtilsMessengerCreateInfoEXT* create_info) { - create_info->sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT; - create_info->pNext = nullptr; - create_info->flags = 0; +// Populates |create_info| with an instance-agnostic callback. +// This can be used during instance creation by chaining the |create_info| to +// VkInstanceCreateInfo::pNext. +// +// Only use if VK_EXT_debug_utils is present. +static void iree_hal_vulkan_debug_reporter_populate_create_info( + VkDebugUtilsMessengerCreateInfoEXT* out_create_info) { + out_create_info->sType = + VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT; + out_create_info->pNext = nullptr; + out_create_info->flags = 0; // TODO(benvanik): only enable the severities that logging has enabled. - create_info->messageSeverity = + out_create_info->messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT; // TODO(benvanik): allow filtering by category as a flag. 
- create_info->messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | - VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | - VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT; - - create_info->pfnUserCallback = DebugUtilsMessageCallback; - create_info->pUserData = nullptr; + out_create_info->messageType = + VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT; + + out_create_info->pfnUserCallback = + iree_hal_vulkan_debug_utils_message_callback; + out_create_info->pUserData = nullptr; } -// static -void DebugReporter::PopulateStaticCreateInfo( - VkDebugReportCallbackCreateInfoEXT* create_info) { - create_info->sType = VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT; - create_info->pNext = nullptr; - create_info->flags = 0; - - // TODO(benvanik): only enable the severities that logging has enabled. - create_info->flags |= - VK_DEBUG_REPORT_INFORMATION_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | - VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | - VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_DEBUG_BIT_EXT; - - create_info->pfnCallback = DebugReportCallback; - create_info->pUserData = nullptr; -} - -// static -StatusOr<std::unique_ptr<DebugReporter>> -DebugReporter::CreateDebugUtilsMessenger( - VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks) { - IREE_TRACE_SCOPE0("DebugReporter::CreateDebugUtilsMessenger"); - - auto debug_reporter = std::unique_ptr<DebugReporter>( - new DebugReporter(instance, syms, allocation_callbacks)); +iree_status_t iree_hal_vulkan_debug_reporter_allocate( + VkInstance instance, iree::hal::vulkan::DynamicSymbols* syms, + const VkAllocationCallbacks* allocation_callbacks, + iree_allocator_t host_allocator, + iree_hal_vulkan_debug_reporter_t** out_reporter) { + IREE_ASSERT_ARGUMENT(instance); + IREE_ASSERT_ARGUMENT(syms); + IREE_ASSERT_ARGUMENT(out_reporter); + IREE_TRACE_ZONE_BEGIN(z0); + + // Allocate our struct first as we need to pass the pointer to the userdata + // of the messenger instance when we create it. 
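// (The callback may fire from within vkCreateDebugUtilsMessengerEXT itself,
// so the struct must be fully initialized before the create call below.)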
+ iree_hal_vulkan_debug_reporter_t* reporter = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*reporter), + (void**)&reporter)); + reporter->host_allocator = host_allocator; + reporter->instance = instance; + reporter->syms = syms; + reporter->allocation_callbacks = allocation_callbacks; VkDebugUtilsMessengerCreateInfoEXT create_info; - PopulateStaticCreateInfo(&create_info); - create_info.pUserData = debug_reporter.get(); - - VK_RETURN_IF_ERROR(syms->vkCreateDebugUtilsMessengerEXT( - instance, &create_info, allocation_callbacks, - &debug_reporter->messenger_)); - - return debug_reporter; -} - -// static -StatusOr> -DebugReporter::CreateDebugReportCallback( - VkInstance instance, const ref_ptr& syms, - const VkAllocationCallbacks* allocation_callbacks) { - IREE_TRACE_SCOPE0("DebugReporter::CreateDebugReportCallback"); - - auto debug_reporter = std::unique_ptr( - new DebugReporter(instance, syms, allocation_callbacks)); - - VkDebugReportCallbackCreateInfoEXT create_info; - PopulateStaticCreateInfo(&create_info); - create_info.pUserData = debug_reporter.get(); - - VK_RETURN_IF_ERROR(syms->vkCreateDebugReportCallbackEXT( - instance, &create_info, allocation_callbacks, - &debug_reporter->callback_)); - - return debug_reporter; + iree_hal_vulkan_debug_reporter_populate_create_info(&create_info); + create_info.pUserData = reporter; + iree_status_t status = VK_RESULT_TO_STATUS( + syms->vkCreateDebugUtilsMessengerEXT( + instance, &create_info, allocation_callbacks, &reporter->messenger), + "vkCreateDebugUtilsMessengerEXT"); + + if (iree_status_is_ok(status)) { + *out_reporter = reporter; + } else { + iree_hal_vulkan_debug_reporter_free(reporter); + } + IREE_TRACE_ZONE_END(z0); + return status; } -DebugReporter::DebugReporter(VkInstance instance, - const ref_ptr& syms, - const VkAllocationCallbacks* allocation_callbacks) - : instance_(instance), - syms_(add_ref(syms)), - allocation_callbacks_(allocation_callbacks) {} +void iree_hal_vulkan_debug_reporter_free( + iree_hal_vulkan_debug_reporter_t* reporter) { + if (!reporter) return; + iree_allocator_t host_allocator = reporter->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); -DebugReporter::~DebugReporter() { - IREE_TRACE_SCOPE0("DebugReporter::dtor"); - if (messenger_ != VK_NULL_HANDLE) { - syms_->vkDestroyDebugUtilsMessengerEXT(instance_, messenger_, - allocation_callbacks_); + if (reporter->messenger != VK_NULL_HANDLE) { + reporter->syms->vkDestroyDebugUtilsMessengerEXT( + reporter->instance, reporter->messenger, + reporter->allocation_callbacks); } - if (callback_ != VK_NULL_HANDLE) { - syms_->vkDestroyDebugReportCallbackEXT(instance_, callback_, - allocation_callbacks_); - } -} + iree_allocator_free(host_allocator, reporter); -} // namespace vulkan -} // namespace hal -} // namespace iree + IREE_TRACE_ZONE_END(z0); +} diff --git a/iree/hal/vulkan/debug_reporter.h b/iree/hal/vulkan/debug_reporter.h index 82dad6e9ae79f..3c92d827196d5 100644 --- a/iree/hal/vulkan/debug_reporter.h +++ b/iree/hal/vulkan/debug_reporter.h @@ -15,22 +15,13 @@ #ifndef IREE_HAL_VULKAN_DEBUG_REPORTER_H_ #define IREE_HAL_VULKAN_DEBUG_REPORTER_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/base/status.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/dynamic_symbols.h" -namespace iree { -namespace hal { -namespace vulkan { - // A debug reporter that works with the VK_EXT_debug_utils extension. 
// One reporter should be created per VkInstance to receive callbacks from the -// API and route them to our logging systems. In general VK_EXT_debug_utils -// should be preferred if available as it provides a much cleaner interface and -// more plug-points than VK_EXT_debug_report. +// API and route them to our logging systems. // // Since creating a reporter requires a VkInstance it's not possible to report // on messages during instance creation. To work around this it's possible to @@ -38,52 +29,16 @@ namespace vulkan { // VkInstanceCreateInfo::pNext chain. The callback will only be used this way // during the creation call after which users can create the real // instance-specific reporter. -class DebugReporter final { - public: - // Populates |create_info| with an instance-agnostic callback. - // This can be used during instance creation by chaining the |create_info| to - // VkInstanceCreateInfo::pNext. - // - // Only use if VK_EXT_debug_utils is present. - static void PopulateStaticCreateInfo( - VkDebugUtilsMessengerCreateInfoEXT* create_info); - - // Populates |create_info| with an instance-agnostic callback. - // This can be used during instance creation by chaining the |create_info| to - // VkInstanceCreateInfo::pNext. - // - // Only use if VK_EXT_debug_report is present. - static void PopulateStaticCreateInfo( - VkDebugReportCallbackCreateInfoEXT* create_info); - - // Creates a debug messenger for the given Vulkan |instance| with - // VK_EXT_debug_utils enabled. - static StatusOr> CreateDebugUtilsMessenger( - VkInstance instance, const ref_ptr& syms, - const VkAllocationCallbacks* allocation_callbacks); - - // Creates a debug report callback for the given Vulkan |instance| with - // VK_EXT_debug_report enabled. - static StatusOr> CreateDebugReportCallback( - VkInstance instance, const ref_ptr& syms, - const VkAllocationCallbacks* allocation_callbacks); - - ~DebugReporter(); - - private: - DebugReporter(VkInstance instance, const ref_ptr& syms, - const VkAllocationCallbacks* allocation_callbacks); - - VkInstance instance_ = VK_NULL_HANDLE; - ref_ptr syms_; - const VkAllocationCallbacks* allocation_callbacks_ = nullptr; +typedef struct iree_hal_vulkan_debug_reporter_s + iree_hal_vulkan_debug_reporter_t; - VkDebugUtilsMessengerEXT messenger_ = VK_NULL_HANDLE; - VkDebugReportCallbackEXT callback_ = VK_NULL_HANDLE; -}; +iree_status_t iree_hal_vulkan_debug_reporter_allocate( + VkInstance instance, iree::hal::vulkan::DynamicSymbols* syms, + const VkAllocationCallbacks* allocation_callbacks, + iree_allocator_t host_allocator, + iree_hal_vulkan_debug_reporter_t** out_reporter); -} // namespace vulkan -} // namespace hal -} // namespace iree +void iree_hal_vulkan_debug_reporter_free( + iree_hal_vulkan_debug_reporter_t* reporter); #endif // IREE_HAL_VULKAN_DEBUG_REPORTER_H_ diff --git a/iree/hal/vulkan/descriptor_pool_cache.cc b/iree/hal/vulkan/descriptor_pool_cache.cc index 6feea168d14d8..3853796825b0d 100644 --- a/iree/hal/vulkan/descriptor_pool_cache.cc +++ b/iree/hal/vulkan/descriptor_pool_cache.cc @@ -48,8 +48,8 @@ Status DescriptorSetGroup::Reset() { return OkStatus(); } -DescriptorPoolCache::DescriptorPoolCache(ref_ptr logical_device) - : logical_device_(std::move(logical_device)) {} +DescriptorPoolCache::DescriptorPoolCache(VkDeviceHandle* logical_device) + : logical_device_(logical_device) {} StatusOr DescriptorPoolCache::AcquireDescriptorPool( VkDescriptorType descriptor_type, int max_descriptor_count) { @@ -74,8 +74,9 @@ StatusOr DescriptorPoolCache::AcquireDescriptorPool( 
descriptor_pool.handle = VK_NULL_HANDLE; VK_RETURN_IF_ERROR(syms().vkCreateDescriptorPool( - *logical_device_, &create_info, logical_device_->allocator(), - &descriptor_pool.handle)); + *logical_device_, &create_info, + logical_device_->allocator(), &descriptor_pool.handle), + "vkCreateDescriptorPool"); return descriptor_pool; } @@ -89,7 +90,8 @@ Status DescriptorPoolCache::ReleaseDescriptorPools( // this leads to better errors when using the validation layers as we'll // throw if there are in-flight command buffers using the sets in the pool. VK_RETURN_IF_ERROR(syms().vkResetDescriptorPool(*logical_device_, - descriptor_pool.handle, 0)); + descriptor_pool.handle, 0), + "vkResetDescriptorPool"); // TODO(benvanik): release to cache. syms().vkDestroyDescriptorPool(*logical_device_, descriptor_pool.handle, diff --git a/iree/hal/vulkan/descriptor_pool_cache.h b/iree/hal/vulkan/descriptor_pool_cache.h index bb4fa331725f1..0e001c0f533c8 100644 --- a/iree/hal/vulkan/descriptor_pool_cache.h +++ b/iree/hal/vulkan/descriptor_pool_cache.h @@ -16,7 +16,6 @@ #define IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_ #include "absl/container/inlined_vector.h" -#include "iree/base/ref_ptr.h" #include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/handle_util.h" @@ -43,9 +42,9 @@ struct DescriptorPool { class DescriptorSetGroup final { public: DescriptorSetGroup() = default; - DescriptorSetGroup(ref_ptr descriptor_pool_cache, + DescriptorSetGroup(DescriptorPoolCache* descriptor_pool_cache, absl::InlinedVector descriptor_pools) - : descriptor_pool_cache_(std::move(descriptor_pool_cache)), + : descriptor_pool_cache_(descriptor_pool_cache), descriptor_pools_(std::move(descriptor_pools)) {} DescriptorSetGroup(const DescriptorSetGroup&) = delete; DescriptorSetGroup& operator=(const DescriptorSetGroup&) = delete; @@ -62,7 +61,7 @@ class DescriptorSetGroup final { Status Reset(); private: - ref_ptr descriptor_pool_cache_; + DescriptorPoolCache* descriptor_pool_cache_; absl::InlinedVector descriptor_pools_; }; @@ -72,13 +71,11 @@ class DescriptorSetGroup final { // resources. After the descriptors in the pool are no longer used (all // command buffers using descriptor sets allocated from the pool have retired) // the pool is returned here to be reused in the future. -class DescriptorPoolCache final : public RefObject { +class DescriptorPoolCache final { public: - explicit DescriptorPoolCache(ref_ptr logical_device); + explicit DescriptorPoolCache(VkDeviceHandle* logical_device); - const ref_ptr& logical_device() const { - return logical_device_; - } + VkDeviceHandle* logical_device() const { return logical_device_; } const DynamicSymbols& syms() const { return *logical_device_->syms(); } // Acquires a new descriptor pool for use by the caller. 
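// (Note that with this change the cache only borrows |logical_device| as a
// raw pointer where it previously retained a ref_ptr; the owning device must
// outlive the cache and any descriptor pools handed out from it.)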
@@ -93,7 +90,7 @@ class DescriptorPoolCache final : public RefObject { Status ReleaseDescriptorPools(absl::Span descriptor_pools); private: - ref_ptr logical_device_; + VkDeviceHandle* logical_device_; }; } // namespace vulkan diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc index ee60e6bb019da..7c86d8c2ddc5f 100644 --- a/iree/hal/vulkan/descriptor_set_arena.cc +++ b/iree/hal/vulkan/descriptor_set_arena.cc @@ -17,6 +17,8 @@ #include "iree/base/alignment.h" #include "iree/base/math.h" #include "iree/base/tracing.h" +#include "iree/hal/vulkan/native_descriptor_set_layout.h" +#include "iree/hal/vulkan/native_executable_layout.h" #include "iree/hal/vulkan/status_util.h" #include "iree/hal/vulkan/vma_buffer.h" @@ -26,27 +28,24 @@ namespace vulkan { namespace { -StatusOr CastBuffer(Buffer* buffer) { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return static_cast(buffer->allocated_buffer()); -} - -StatusOr> PopulateDescriptorSetWriteInfos( - absl::Span bindings, VkDescriptorSet dst_set, +static StatusOr> +PopulateDescriptorSetWriteInfos( + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, VkDescriptorSet dst_set, Arena* arena) { arena->Reset(); auto buffer_infos = - arena->AllocateSpan(bindings.size()); - auto write_infos = arena->AllocateSpan(bindings.size()); + arena->AllocateSpan(binding_count); + auto write_infos = arena->AllocateSpan(binding_count); - for (int i = 0; i < bindings.size(); ++i) { + for (int i = 0; i < binding_count; ++i) { const auto& binding = bindings[i]; auto& buffer_info = buffer_infos[i]; - IREE_ASSIGN_OR_RETURN(auto buffer, CastBuffer(binding.buffer)); - buffer_info.buffer = buffer->handle(); - buffer_info.offset = binding.buffer->byte_offset() + binding.offset; + buffer_info.buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(binding.buffer)); + buffer_info.offset = + iree_hal_buffer_byte_offset(binding.buffer) + binding.offset; // Round up to a multiple of 32-bit. 32-bit is the most native bitwidth on // GPUs; it has the best support compared to other bitwidths. We use VMA to // manage GPU memory for us and VMA should already handled proper alignment @@ -64,10 +63,10 @@ StatusOr> PopulateDescriptorSetWriteInfos( // the shader is considered as out of bounds per the Vulkan spec. // See https://github.com/google/iree/issues/2022#issuecomment-640617234 // for more details. 
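// Worked example (illustrative): a 6-byte binding that reaches the end of
// its buffer yields min(6, 6) = 6, which iree_align(..., 4) rounds up to 8;
// the allocator's own alignment padding is what makes those trailing bytes
// safe to address.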
- buffer_info.range = - iree_align(std::min(binding.length, - binding.buffer->byte_length() - binding.offset), - 4); + buffer_info.range = iree_align( + std::min(binding.length, + iree_hal_buffer_byte_length(binding.buffer) - binding.offset), + 4); auto& write_info = write_infos[i]; write_info.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; @@ -85,15 +84,16 @@ StatusOr> PopulateDescriptorSetWriteInfos( return write_infos; } -VkDescriptorSetAllocateInfo PopulateDescriptorSetsAllocateInfo( +static VkDescriptorSetAllocateInfo PopulateDescriptorSetsAllocateInfo( const DescriptorPool& descriptor_pool, - NativeDescriptorSetLayout* set_layout) { + iree_hal_descriptor_set_layout_t* set_layout) { VkDescriptorSetAllocateInfo allocate_info; allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; allocate_info.pNext = nullptr; allocate_info.descriptorPool = descriptor_pool.handle; - VkDescriptorSetLayout set_layout_handle = set_layout->handle(); + VkDescriptorSetLayout set_layout_handle = + iree_hal_vulkan_native_descriptor_set_layout_handle(set_layout); allocate_info.descriptorSetCount = 1; allocate_info.pSetLayouts = &set_layout_handle; @@ -103,9 +103,9 @@ VkDescriptorSetAllocateInfo PopulateDescriptorSetsAllocateInfo( } // namespace DescriptorSetArena::DescriptorSetArena( - ref_ptr descriptor_pool_cache) - : logical_device_(add_ref(descriptor_pool_cache->logical_device())), - descriptor_pool_cache_(std::move(descriptor_pool_cache)) {} + DescriptorPoolCache* descriptor_pool_cache) + : logical_device_(descriptor_pool_cache->logical_device()), + descriptor_pool_cache_(descriptor_pool_cache) {} DescriptorSetArena::~DescriptorSetArena() { if (!used_descriptor_pools_.empty()) { @@ -117,21 +117,25 @@ DescriptorSetArena::~DescriptorSetArena() { } Status DescriptorSetArena::BindDescriptorSet( - VkCommandBuffer command_buffer, PipelineExecutableLayout* executable_layout, - int32_t set, absl::Span bindings) { + VkCommandBuffer command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { // Always prefer using push descriptors when available as we can avoid the // additional API overhead of updating/resetting pools. if (logical_device_->enabled_extensions().push_descriptors) { - return PushDescriptorSet(command_buffer, executable_layout, set, bindings); + return PushDescriptorSet(command_buffer, executable_layout, set, + binding_count, bindings); } IREE_TRACE_SCOPE0("DescriptorSetArena::BindDescriptorSet"); - auto* set_layout = executable_layout->set_layouts()[set].get(); + auto* set_layout = + iree_hal_vulkan_native_executable_layout_set(executable_layout, set); // Pick a bucket based on the number of descriptors required. // NOTE: right now we are 1:1 with bindings. 
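// Worked example (illustrative): 5 bindings need 5 descriptors, which rounds
// up to the pow2 bucket of 8 (also the enforced minimum); 13 bindings round
// up to 16. Bucketing by pow2 size lets differently-shaped dispatches share
// the same pools.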
- uint32_t required_descriptor_count = static_cast(bindings.size() * 1); + uint32_t required_descriptor_count = static_cast(binding_count * 1); uint32_t max_descriptor_count = std::max(8u, iree_math_round_up_to_pow2_u32(required_descriptor_count)); uint32_t bucket = @@ -155,7 +159,8 @@ Status DescriptorSetArena::BindDescriptorSet( allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; allocate_info.pNext = nullptr; allocate_info.descriptorPool = descriptor_pool.handle; - VkDescriptorSetLayout set_layout_handle = set_layout->handle(); + VkDescriptorSetLayout set_layout_handle = + iree_hal_vulkan_native_descriptor_set_layout_handle(set_layout); allocate_info.descriptorSetCount = 1; allocate_info.pSetLayouts = &set_layout_handle; @@ -177,18 +182,18 @@ Status DescriptorSetArena::BindDescriptorSet( allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; allocate_info.pNext = nullptr; allocate_info.descriptorPool = descriptor_pool_buckets_[bucket].handle; - VkDescriptorSetLayout set_layout_handle = set_layout->handle(); allocate_info.descriptorSetCount = 1; allocate_info.pSetLayouts = &set_layout_handle; descriptor_set = VK_NULL_HANDLE; VK_RETURN_IF_ERROR(syms().vkAllocateDescriptorSets( - *logical_device_, &allocate_info, &descriptor_set)); + *logical_device_, &allocate_info, &descriptor_set), + "vkAllocateDescriptorSets"); } // Get a list of VkWriteDescriptorSet structs with all bound buffers. - IREE_ASSIGN_OR_RETURN(auto write_infos, - PopulateDescriptorSetWriteInfos( - bindings, descriptor_set, &scratch_arena_)); + IREE_ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos( + binding_count, bindings, + descriptor_set, &scratch_arena_)); // This is the reason why push descriptor sets are good. // We can't batch these effectively as we don't know prior to recording what @@ -200,29 +205,33 @@ Status DescriptorSetArena::BindDescriptorSet( write_infos.data(), 0, nullptr); // Bind the descriptor set. - syms().vkCmdBindDescriptorSets(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, - executable_layout->handle(), set, 1, - &descriptor_set, 0, nullptr); + syms().vkCmdBindDescriptorSets( + command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, + iree_hal_vulkan_native_executable_layout_handle(executable_layout), set, + 1, &descriptor_set, 0, nullptr); return OkStatus(); } Status DescriptorSetArena::PushDescriptorSet( - VkCommandBuffer command_buffer, PipelineExecutableLayout* executable_layout, - int32_t set, absl::Span bindings) { + VkCommandBuffer command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { IREE_TRACE_SCOPE0("DescriptorSetArena::PushDescriptorSet"); + VkPipelineLayout device_executable_layout = + iree_hal_vulkan_native_executable_layout_handle(executable_layout); // Get a list of VkWriteDescriptorSet structs with all bound buffers. - IREE_ASSIGN_OR_RETURN(auto write_infos, - PopulateDescriptorSetWriteInfos( - bindings, VK_NULL_HANDLE, &scratch_arena_)); + IREE_ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos( + binding_count, bindings, + VK_NULL_HANDLE, &scratch_arena_)); // Fast path using push descriptors. These are pooled internally by the // command buffer and prevent the need for our own pooling mechanisms. 
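// (This path is only reachable when the device was created with
// VK_KHR_push_descriptor enabled; see the enabled_extensions().push_descriptors
// check in BindDescriptorSet above.)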
syms().vkCmdPushDescriptorSetKHR( - command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, - executable_layout->handle(), set, - static_cast(write_infos.size()), write_infos.data()); + command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, device_executable_layout, + set, static_cast(write_infos.size()), write_infos.data()); return OkStatus(); } @@ -238,7 +247,7 @@ StatusOr DescriptorSetArena::Flush() { for (auto& bucket : descriptor_pool_buckets_) { bucket = {}; } - return DescriptorSetGroup(add_ref(descriptor_pool_cache_), + return DescriptorSetGroup(descriptor_pool_cache_, std::move(used_descriptor_pools_)); } diff --git a/iree/hal/vulkan/descriptor_set_arena.h b/iree/hal/vulkan/descriptor_set_arena.h index 8cc27adbe6670..e34c35b442243 100644 --- a/iree/hal/vulkan/descriptor_set_arena.h +++ b/iree/hal/vulkan/descriptor_set_arena.h @@ -20,9 +20,8 @@ #include "iree/base/arena.h" #include "iree/base/status.h" -#include "iree/hal/command_buffer.h" #include "iree/hal/vulkan/descriptor_pool_cache.h" -#include "iree/hal/vulkan/pipeline_executable.h" +#include "iree/hal/vulkan/native_executable.h" namespace iree { namespace hal { @@ -31,17 +30,16 @@ namespace vulkan { // A reusable arena for allocating descriptor sets and batching updates. class DescriptorSetArena final { public: - explicit DescriptorSetArena( - ref_ptr descriptor_pool_cache); + explicit DescriptorSetArena(DescriptorPoolCache* descriptor_pool_cache); ~DescriptorSetArena(); // Allocates and binds a descriptor set from the arena. // The command buffer will have the descriptor set containing |bindings| bound // to it. Status BindDescriptorSet(VkCommandBuffer command_buffer, - PipelineExecutableLayout* executable_layout, - int32_t set, - absl::Span bindings); + iree_hal_executable_layout_t* executable_layout, + uint32_t set, iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings); // Flushes all pending writes to descriptor sets allocated from the arena and // returns a group that - when dropped - will release the descriptor sets @@ -53,12 +51,12 @@ class DescriptorSetArena final { // Pushes the descriptor set to the command buffer, if supported. Status PushDescriptorSet(VkCommandBuffer command_buffer, - PipelineExecutableLayout* executable_layout, - int32_t set, - absl::Span bindings); + iree_hal_executable_layout_t* executable_layout, + uint32_t set, iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings); - ref_ptr logical_device_; - ref_ptr descriptor_pool_cache_; + VkDeviceHandle* logical_device_; + DescriptorPoolCache* descriptor_pool_cache_; // Arena used for temporary binding information used during allocation. 
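// (Both pointers above are now borrowed rather than ref-counted; the device
// that owns the DescriptorPoolCache is responsible for keeping them alive
// while any arena or flushed DescriptorSetGroup is still in use.)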
Arena scratch_arena_; diff --git a/iree/hal/vulkan/direct_command_buffer.cc b/iree/hal/vulkan/direct_command_buffer.cc index b5a9aabee69c5..52a7bb0565c11 100644 --- a/iree/hal/vulkan/direct_command_buffer.cc +++ b/iree/hal/vulkan/direct_command_buffer.cc @@ -14,332 +14,442 @@ #include "iree/hal/vulkan/direct_command_buffer.h" -#include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" #include "iree/base/math.h" -#include "iree/base/status.h" #include "iree/base/tracing.h" +#include "iree/hal/vulkan/descriptor_set_arena.h" +#include "iree/hal/vulkan/dynamic_symbols.h" +#include "iree/hal/vulkan/native_descriptor_set.h" +#include "iree/hal/vulkan/native_event.h" +#include "iree/hal/vulkan/native_executable_layout.h" #include "iree/hal/vulkan/status_util.h" +#include "iree/hal/vulkan/vma_buffer.h" + +using namespace iree::hal::vulkan; + +// Command buffer implementation that directly maps to VkCommandBuffer. +// This records the commands on the calling thread without additional threading +// indirection. +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + iree_hal_command_buffer_mode_t mode; + iree_hal_command_category_t allowed_categories; + + VkCommandPoolHandle* command_pool; + VkCommandBuffer handle; + + DynamicSymbols* syms; + + // TODO(benvanik): may grow large - should try to reclaim or reuse. + DescriptorSetArena descriptor_set_arena; + + // The current descriptor set group in use by the command buffer, if any. + // This must remain valid until all in-flight submissions of the command + // buffer complete. + DescriptorSetGroup descriptor_set_group; +} iree_hal_vulkan_direct_command_buffer_t; + +extern const iree_hal_command_buffer_vtable_t + iree_hal_vulkan_direct_command_buffer_vtable; + +static iree_hal_vulkan_direct_command_buffer_t* +iree_hal_vulkan_direct_command_buffer_cast( + iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_direct_command_buffer_vtable); + return (iree_hal_vulkan_direct_command_buffer_t*)base_value; +} + +iree_status_t iree_hal_vulkan_direct_command_buffer_allocate( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree::hal::vulkan::VkCommandPoolHandle* command_pool, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree::hal::vulkan::DescriptorPoolCache* descriptor_pool_cache, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(command_pool); + IREE_ASSERT_ARGUMENT(descriptor_pool_cache); + IREE_ASSERT_ARGUMENT(out_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + VkCommandBufferAllocateInfo allocate_info; + allocate_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + allocate_info.pNext = NULL; + allocate_info.commandPool = *command_pool; + allocate_info.commandBufferCount = 1; + allocate_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + + VkCommandBuffer handle = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_pool->Allocate(&allocate_info, &handle)); + + iree_hal_vulkan_direct_command_buffer_t* command_buffer = NULL; + iree_status_t status = + iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*command_buffer), (void**)&command_buffer); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_direct_command_buffer_vtable, + &command_buffer->resource); + command_buffer->logical_device = logical_device; + command_buffer->mode = mode; + 
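// (Since the struct is allocated with iree_allocator_malloc as raw memory,
// the C++ members are constructed below with placement new and torn down
// with explicit destructor calls in _destroy.)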
command_buffer->allowed_categories = command_categories; + command_buffer->command_pool = command_pool; + command_buffer->handle = handle; + command_buffer->syms = logical_device->syms().get(); + + new (&command_buffer->descriptor_set_arena) + DescriptorSetArena(descriptor_pool_cache); + new (&command_buffer->descriptor_set_group) DescriptorSetGroup(); + + *out_command_buffer = (iree_hal_command_buffer_t*)command_buffer; + } else { + command_pool->Free(handle); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_direct_command_buffer_reset( + iree_hal_vulkan_direct_command_buffer_t* command_buffer) { + // NOTE: we require that command buffers not be recorded while they are + // in-flight so this is safe. + IREE_IGNORE_ERROR(command_buffer->descriptor_set_group.Reset()); +} + +static void iree_hal_vulkan_direct_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + iree_allocator_t host_allocator = + command_buffer->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_direct_command_buffer_reset(command_buffer); + command_buffer->command_pool->Free(command_buffer->handle); + + command_buffer->descriptor_set_group.~DescriptorSetGroup(); + command_buffer->descriptor_set_arena.~DescriptorSetArena(); -namespace iree { -namespace hal { -namespace vulkan { + iree_allocator_free(host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +VkCommandBuffer iree_hal_vulkan_direct_command_buffer_handle( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + return command_buffer->handle; +} -namespace { +static iree_hal_command_category_t +iree_hal_vulkan_direct_command_buffer_allowed_categories( + const iree_hal_command_buffer_t* base_command_buffer) { + return ((const iree_hal_vulkan_direct_command_buffer_t*)base_command_buffer) + ->allowed_categories; +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + iree_hal_vulkan_direct_command_buffer_reset(command_buffer); + + VkCommandBufferBeginInfo begin_info; + begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + begin_info.pNext = NULL; + begin_info.flags = iree_all_bits_set(command_buffer->mode, + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) + ? VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT + : 0; + begin_info.pInheritanceInfo = NULL; + VK_RETURN_IF_ERROR(command_buffer->syms->vkBeginCommandBuffer( + command_buffer->handle, &begin_info), + "vkBeginCommandBuffer"); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + VK_RETURN_IF_ERROR( + command_buffer->syms->vkEndCommandBuffer(command_buffer->handle), + "vkEndCommandBuffer"); + + // Flush all pending descriptor set writes (if any). 
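// Flush() also hands ownership of all descriptor pools used while recording
// to |descriptor_set_group|, which must stay live until every in-flight
// submission of this command buffer has retired.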
+ IREE_ASSIGN_OR_RETURN(command_buffer->descriptor_set_group, + command_buffer->descriptor_set_arena.Flush()); -VkPipelineStageFlags ConvertPipelineStageFlags( - ExecutionStageBitfield stage_mask) { + return iree_ok_status(); +} + +static VkPipelineStageFlags iree_hal_vulkan_convert_pipeline_stage_flags( + iree_hal_execution_stage_t stage_mask) { VkPipelineStageFlags flags = 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandIssue) + flags |= iree_any_bit_set(stage_mask, IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE) ? VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandProcess) - ? VK_PIPELINE_STAGE_DRAW_INDIRECT_BIT - : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kDispatch) + flags |= + iree_any_bit_set(stage_mask, IREE_HAL_EXECUTION_STAGE_COMMAND_PROCESS) + ? VK_PIPELINE_STAGE_DRAW_INDIRECT_BIT + : 0; + flags |= iree_any_bit_set(stage_mask, IREE_HAL_EXECUTION_STAGE_DISPATCH) ? VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kTransfer) + flags |= iree_any_bit_set(stage_mask, IREE_HAL_EXECUTION_STAGE_TRANSFER) ? VK_PIPELINE_STAGE_TRANSFER_BIT : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandRetire) + flags |= iree_any_bit_set(stage_mask, IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE) ? VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kHost) + flags |= iree_any_bit_set(stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) ? VK_PIPELINE_STAGE_HOST_BIT : 0; return flags; } -VkAccessFlags ConvertAccessMask(AccessScopeBitfield access_mask) { +static VkAccessFlags iree_hal_vulkan_convert_access_mask( + iree_hal_access_scope_t access_mask) { VkAccessFlags flags = 0; - flags |= AnyBitSet(access_mask & AccessScope::kIndirectCommandRead) - ? VK_ACCESS_INDIRECT_COMMAND_READ_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kConstantRead) + flags |= + iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_INDIRECT_COMMAND_READ) + ? VK_ACCESS_INDIRECT_COMMAND_READ_BIT + : 0; + flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_CONSTANT_READ) ? VK_ACCESS_UNIFORM_READ_BIT : 0; - flags |= AnyBitSet(access_mask & AccessScope::kDispatchRead) + flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_DISPATCH_READ) ? VK_ACCESS_SHADER_READ_BIT : 0; - flags |= AnyBitSet(access_mask & AccessScope::kDispatchWrite) + flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE) ? VK_ACCESS_SHADER_WRITE_BIT : 0; - flags |= AnyBitSet(access_mask & AccessScope::kTransferRead) + flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_TRANSFER_READ) ? VK_ACCESS_TRANSFER_READ_BIT : 0; - flags |= AnyBitSet(access_mask & AccessScope::kTransferWrite) + flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_TRANSFER_WRITE) ? VK_ACCESS_TRANSFER_WRITE_BIT : 0; - flags |= AnyBitSet(access_mask & AccessScope::kHostRead) + flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_HOST_READ) ? VK_ACCESS_HOST_READ_BIT : 0; - flags |= AnyBitSet(access_mask & AccessScope::kHostWrite) + flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_HOST_WRITE) ? VK_ACCESS_HOST_WRITE_BIT : 0; - flags |= AnyBitSet(access_mask & AccessScope::kMemoryRead) + flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_MEMORY_READ) ? VK_ACCESS_MEMORY_READ_BIT : 0; - flags |= AnyBitSet(access_mask & AccessScope::kMemoryWrite) + flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_MEMORY_WRITE) ? 
VK_ACCESS_MEMORY_WRITE_BIT : 0; return flags; } -// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value. -uint32_t SplatPattern(const void* pattern, size_t pattern_length) { - switch (pattern_length) { - case 1: { - uint32_t pattern_value = *static_cast(pattern); - return (pattern_value << 24) | (pattern_value << 16) | - (pattern_value << 8) | pattern_value; - } - case 2: { - uint32_t pattern_value = *static_cast(pattern); - return (pattern_value << 16) | pattern_value; - } - case 4: { - uint32_t pattern_value = *static_cast(pattern); - return pattern_value; - } - default: - return 0; // Already verified that this should not be possible. - } -} - -} // namespace - -DirectCommandBuffer::DirectCommandBuffer( - CommandBufferModeBitfield mode, CommandCategoryBitfield command_categories, - ref_ptr descriptor_pool_cache, - ref_ptr command_pool, VkCommandBuffer command_buffer) - : CommandBuffer(mode, command_categories), - command_pool_(std::move(command_pool)), - command_buffer_(command_buffer), - descriptor_set_arena_(std::move(descriptor_pool_cache)) {} - -DirectCommandBuffer::~DirectCommandBuffer() { - IREE_TRACE_SCOPE0("DirectCommandBuffer::dtor"); - descriptor_set_group_.Reset().IgnoreError(); - absl::MutexLock lock(command_pool_->mutex()); - syms()->vkFreeCommandBuffers(*command_pool_->logical_device(), *command_pool_, - 1, &command_buffer_); -} - -StatusOr DirectCommandBuffer::CastEvent(Event* event) const { - // TODO(benvanik): assert the event is valid. - return static_cast(event); -} - -StatusOr DirectCommandBuffer::CastBuffer(Buffer* buffer) const { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return static_cast(buffer->allocated_buffer()); -} - -StatusOr DirectCommandBuffer::CastDescriptorSet( - DescriptorSet* descriptor_set) const { - // TODO(benvanik): assert the descriptor_set is valid. - return static_cast(descriptor_set); -} - -StatusOr DirectCommandBuffer::CastExecutableLayout( - ExecutableLayout* executable_layout) const { - // TODO(benvanik): assert the executable_layout is valid. - return static_cast(executable_layout); -} - -StatusOr DirectCommandBuffer::CastExecutable( - Executable* executable) const { - // TODO(benvanik): assert the executable is valid. - return static_cast(executable); -} - -Status DirectCommandBuffer::Begin() { - IREE_TRACE_SCOPE0("DirectCommandBuffer::Begin"); - - is_recording_ = true; - - // NOTE: we require that command buffers not be recorded while they are - // in-flight so this is safe. - IREE_RETURN_IF_ERROR(descriptor_set_group_.Reset()); - - VkCommandBufferBeginInfo begin_info; - begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; - begin_info.pNext = nullptr; - begin_info.flags = AllBitsSet(mode(), CommandBufferMode::kOneShot) - ? VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT - : 0; - begin_info.pInheritanceInfo = nullptr; - VK_RETURN_IF_ERROR( - syms()->vkBeginCommandBuffer(command_buffer_, &begin_info)); - - return OkStatus(); -} - -Status DirectCommandBuffer::End() { - IREE_TRACE_SCOPE0("DirectCommandBuffer::End"); - - VK_RETURN_IF_ERROR(syms()->vkEndCommandBuffer(command_buffer_)); - - // Flush all pending descriptor set writes (if any). 
- IREE_ASSIGN_OR_RETURN(descriptor_set_group_, descriptor_set_arena_.Flush()); - - is_recording_ = false; - - return OkStatus(); -} - -Status DirectCommandBuffer::ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::ExecutionBarrier"); +static iree_status_t iree_hal_vulkan_direct_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); absl::InlinedVector memory_barrier_infos( - memory_barriers.size()); - for (int i = 0; i < memory_barriers.size(); ++i) { + memory_barrier_count); + for (int i = 0; i < memory_barrier_count; ++i) { const auto& memory_barrier = memory_barriers[i]; auto& info = memory_barrier_infos[i]; info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope); + info.pNext = NULL; + info.srcAccessMask = + iree_hal_vulkan_convert_access_mask(memory_barrier.source_scope); + info.dstAccessMask = + iree_hal_vulkan_convert_access_mask(memory_barrier.target_scope); } absl::InlinedVector buffer_barrier_infos( - buffer_barriers.size()); - for (int i = 0; i < buffer_barriers.size(); ++i) { + buffer_barrier_count); + for (int i = 0; i < buffer_barrier_count; ++i) { const auto& buffer_barrier = buffer_barriers[i]; auto& info = buffer_barrier_infos[i]; info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope); + info.pNext = NULL; + info.srcAccessMask = + iree_hal_vulkan_convert_access_mask(buffer_barrier.source_scope); + info.dstAccessMask = + iree_hal_vulkan_convert_access_mask(buffer_barrier.target_scope); info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - IREE_ASSIGN_OR_RETURN(auto* device_buffer, - CastBuffer(buffer_barrier.buffer)); - info.buffer = device_buffer->handle(); + info.buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(buffer_barrier.buffer)); info.offset = buffer_barrier.offset; info.size = buffer_barrier.length; } - syms()->vkCmdPipelineBarrier( - command_buffer_, ConvertPipelineStageFlags(source_stage_mask), - ConvertPipelineStageFlags(target_stage_mask), /*dependencyFlags=*/0, - static_cast(memory_barrier_infos.size()), + command_buffer->syms->vkCmdPipelineBarrier( + command_buffer->handle, + iree_hal_vulkan_convert_pipeline_stage_flags(source_stage_mask), + iree_hal_vulkan_convert_pipeline_stage_flags(target_stage_mask), + /*dependencyFlags=*/0, static_cast(memory_barrier_infos.size()), memory_barrier_infos.data(), static_cast(buffer_barrier_infos.size()), - buffer_barrier_infos.data(), 0, nullptr); + buffer_barrier_infos.data(), 0, NULL); - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandBuffer::SignalEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - 
IREE_TRACE_SCOPE0("DirectCommandBuffer::SignalEvent"); - IREE_ASSIGN_OR_RETURN(auto* device_event, CastEvent(event)); - syms()->vkCmdSetEvent(command_buffer_, device_event->handle(), - ConvertPipelineStageFlags(source_stage_mask)); - return OkStatus(); +static iree_status_t iree_hal_vulkan_direct_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + command_buffer->syms->vkCmdSetEvent( + command_buffer->handle, iree_hal_vulkan_native_event_handle(event), + iree_hal_vulkan_convert_pipeline_stage_flags(source_stage_mask)); + + return iree_ok_status(); } -Status DirectCommandBuffer::ResetEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::ResetEvent"); - IREE_ASSIGN_OR_RETURN(auto* device_event, CastEvent(event)); - syms()->vkCmdResetEvent(command_buffer_, device_event->handle(), - ConvertPipelineStageFlags(source_stage_mask)); - return OkStatus(); +static iree_status_t iree_hal_vulkan_direct_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + command_buffer->syms->vkCmdResetEvent( + command_buffer->handle, iree_hal_vulkan_native_event_handle(event), + iree_hal_vulkan_convert_pipeline_stage_flags(source_stage_mask)); + + return iree_ok_status(); } -Status DirectCommandBuffer::WaitEvents( - absl::Span events, ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span memory_barriers, - absl::Span buffer_barriers) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::WaitEvents"); - - absl::InlinedVector event_handles(events.size()); - for (int i = 0; i < events.size(); ++i) { - IREE_ASSIGN_OR_RETURN(auto* device_event, CastEvent(events[i])); - event_handles[i] = device_event->handle(); +static iree_status_t iree_hal_vulkan_direct_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + absl::InlinedVector event_handles(event_count); + for (int i = 0; i < event_count; ++i) { + event_handles[i] = iree_hal_vulkan_native_event_handle(events[i]); } absl::InlinedVector memory_barrier_infos( - memory_barriers.size()); - for (int i = 0; i < memory_barriers.size(); ++i) { + memory_barrier_count); + for (int i = 0; i < memory_barrier_count; ++i) { const auto& memory_barrier = memory_barriers[i]; auto& info = memory_barrier_infos[i]; info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope); + info.pNext = NULL; + info.srcAccessMask = + iree_hal_vulkan_convert_access_mask(memory_barrier.source_scope); + info.dstAccessMask 
= + iree_hal_vulkan_convert_access_mask(memory_barrier.target_scope); } absl::InlinedVector buffer_barrier_infos( - buffer_barriers.size()); - for (int i = 0; i < buffer_barriers.size(); ++i) { + buffer_barrier_count); + for (int i = 0; i < buffer_barrier_count; ++i) { const auto& buffer_barrier = buffer_barriers[i]; auto& info = buffer_barrier_infos[i]; info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope); + info.pNext = NULL; + info.srcAccessMask = + iree_hal_vulkan_convert_access_mask(buffer_barrier.source_scope); + info.dstAccessMask = + iree_hal_vulkan_convert_access_mask(buffer_barrier.target_scope); info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - IREE_ASSIGN_OR_RETURN(auto* device_buffer, - CastBuffer(buffer_barrier.buffer)); - info.buffer = device_buffer->handle(); + info.buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(buffer_barrier.buffer)); info.offset = buffer_barrier.offset; info.size = buffer_barrier.length; } - syms()->vkCmdWaitEvents(command_buffer_, event_handles.size(), - event_handles.data(), - ConvertPipelineStageFlags(source_stage_mask), - ConvertPipelineStageFlags(target_stage_mask), - static_cast(memory_barrier_infos.size()), - memory_barrier_infos.data(), - static_cast(buffer_barrier_infos.size()), - buffer_barrier_infos.data(), 0, nullptr); - return OkStatus(); -} + command_buffer->syms->vkCmdWaitEvents( + command_buffer->handle, (uint32_t)event_count, event_handles.data(), + iree_hal_vulkan_convert_pipeline_stage_flags(source_stage_mask), + iree_hal_vulkan_convert_pipeline_stage_flags(target_stage_mask), + (uint32_t)memory_barrier_count, memory_barrier_infos.data(), + (uint32_t)buffer_barrier_count, buffer_barrier_infos.data(), 0, NULL); -Status DirectCommandBuffer::FillBuffer(Buffer* target_buffer, - device_size_t target_offset, - device_size_t length, - const void* pattern, - size_t pattern_length) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::FillBuffer"); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); + return iree_ok_status(); +} - // Note that fill only accepts 4-byte aligned values so we need to splat out - // our variable-length pattern. - target_offset += target_buffer->byte_offset(); - uint32_t dword_pattern = SplatPattern(pattern, pattern_length); - syms()->vkCmdFillBuffer(command_buffer_, target_device_buffer->handle(), - target_offset, length, dword_pattern); +static iree_status_t iree_hal_vulkan_direct_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { + // NOTE: we could use this to prevent queue family transitions. + return iree_ok_status(); +} - return OkStatus(); +// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value. +static uint32_t iree_hal_vulkan_splat_pattern(const void* pattern, + size_t pattern_length) { + switch (pattern_length) { + case 1: { + uint32_t pattern_value = *static_cast(pattern); + return (pattern_value << 24) | (pattern_value << 16) | + (pattern_value << 8) | pattern_value; + } + case 2: { + uint32_t pattern_value = *static_cast(pattern); + return (pattern_value << 16) | pattern_value; + } + case 4: { + uint32_t pattern_value = *static_cast(pattern); + return pattern_value; + } + default: + return 0; // Already verified that this should not be possible. 
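// Examples (illustrative): pattern 0xAB with length 1 splats to 0xABABABAB;
// pattern 0xBEEF with length 2 splats to 0xBEEFBEEF; 4-byte patterns pass
// through unchanged.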
+ } } -Status DirectCommandBuffer::DiscardBuffer(Buffer* buffer) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::DiscardBuffer"); - // NOTE: we could use this to prevent queue family transitions. - return OkStatus(); +static iree_status_t iree_hal_vulkan_direct_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + VkBuffer target_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(target_buffer)); + + // Note that fill only accepts 4-byte aligned values so we need to splat out + // our variable-length pattern. + target_offset += iree_hal_buffer_byte_offset(target_buffer); + uint32_t dword_pattern = + iree_hal_vulkan_splat_pattern(pattern, pattern_length); + command_buffer->syms->vkCmdFillBuffer(command_buffer->handle, + target_device_buffer, target_offset, + length, dword_pattern); + + return iree_ok_status(); } -Status DirectCommandBuffer::UpdateBuffer(const void* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::UpdateBuffer"); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + VkBuffer target_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(target_buffer)); // Vulkan only allows updates of <= 65536 because you really, really, really // shouldn't do large updates like this (as it wastes command buffer space and @@ -347,137 +457,176 @@ Status DirectCommandBuffer::UpdateBuffer(const void* source_buffer, // recommendation in the spec for larger updates is to split the single update // into multiple updates over the entire desired range. 
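// e.g. (illustrative): a 150000-byte update issues three vkCmdUpdateBuffer
// calls of 65536, 65536, and 18928 bytes at successive offsets.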
const auto* source_buffer_ptr = static_cast(source_buffer); - target_offset += target_buffer->byte_offset(); + target_offset += iree_hal_buffer_byte_offset(target_buffer); while (length > 0) { - device_size_t chunk_length = - std::min(static_cast(65536u), length); - syms()->vkCmdUpdateBuffer(command_buffer_, target_device_buffer->handle(), - target_offset, chunk_length, source_buffer_ptr); + iree_device_size_t chunk_length = + iree_min((iree_device_size_t)65536u, length); + command_buffer->syms->vkCmdUpdateBuffer(command_buffer->handle, + target_device_buffer, target_offset, + chunk_length, source_buffer_ptr); source_buffer_ptr += chunk_length; target_offset += chunk_length; length -= chunk_length; } - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandBuffer::CopyBuffer(Buffer* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::CopyBuffer"); - IREE_ASSIGN_OR_RETURN(auto* source_device_buffer, CastBuffer(source_buffer)); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + VkBuffer source_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(source_buffer)); + VkBuffer target_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(target_buffer)); VkBufferCopy region; - region.srcOffset = source_buffer->byte_offset() + source_offset; - region.dstOffset = target_buffer->byte_offset() + target_offset; + region.srcOffset = iree_hal_buffer_byte_offset(source_buffer) + source_offset; + region.dstOffset = iree_hal_buffer_byte_offset(target_buffer) + target_offset; region.size = length; - syms()->vkCmdCopyBuffer(command_buffer_, source_device_buffer->handle(), - target_device_buffer->handle(), 1, ®ion); + command_buffer->syms->vkCmdCopyBuffer(command_buffer->handle, + source_device_buffer, + target_device_buffer, 1, ®ion); - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandBuffer::PushConstants(ExecutableLayout* executable_layout, - size_t offset, - absl::Span values) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::PushConstants"); - IREE_ASSIGN_OR_RETURN(auto* device_executable_layout, - CastExecutableLayout(executable_layout)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_push_constants( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); - syms()->vkCmdPushConstants( - command_buffer_, device_executable_layout->handle(), - VK_SHADER_STAGE_COMPUTE_BIT, - static_cast(offset * sizeof(uint32_t)), - static_cast(values.size() * sizeof(uint32_t)), values.data()); + command_buffer->syms->vkCmdPushConstants( + command_buffer->handle, + iree_hal_vulkan_native_executable_layout_handle(executable_layout), + VK_SHADER_STAGE_COMPUTE_BIT, (uint32_t)offset, (uint32_t)values_length, + 
values); - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandBuffer::PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - absl::Span bindings) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::PushDescriptorSet"); - IREE_ASSIGN_OR_RETURN(auto* device_executable_layout, - CastExecutableLayout(executable_layout)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); // Either allocate, update, and bind a descriptor set or use push descriptor // sets to use the command buffer pool when supported. - return descriptor_set_arena_.BindDescriptorSet( - command_buffer_, device_executable_layout, set, bindings); + return command_buffer->descriptor_set_arena.BindDescriptorSet( + command_buffer->handle, executable_layout, set, binding_count, bindings); } -Status DirectCommandBuffer::BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span dynamic_offsets) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::BindDescriptorSet"); - IREE_ASSIGN_OR_RETURN(auto* device_executable_layout, - CastExecutableLayout(executable_layout)); - IREE_ASSIGN_OR_RETURN(auto* device_descriptor_set, - CastDescriptorSet(descriptor_set)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_bind_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_hal_descriptor_set_t* descriptor_set, + iree_host_size_t dynamic_offset_count, + const iree_device_size_t* dynamic_offsets) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); // Vulkan takes uint32_t as the size here, unlike everywhere else. 
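// (iree_device_size_t is typically 64-bit, so the casts below would silently
// truncate any dynamic offset above UINT32_MAX; offsets that large are not
// expected in practice.)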
 
-Status DirectCommandBuffer::BindDescriptorSet(
-    ExecutableLayout* executable_layout, int32_t set,
-    DescriptorSet* descriptor_set,
-    absl::Span<const device_size_t> dynamic_offsets) {
-  IREE_TRACE_SCOPE0("DirectCommandBuffer::BindDescriptorSet");
-  IREE_ASSIGN_OR_RETURN(auto* device_executable_layout,
-                        CastExecutableLayout(executable_layout));
-  IREE_ASSIGN_OR_RETURN(auto* device_descriptor_set,
-                        CastDescriptorSet(descriptor_set));
+static iree_status_t iree_hal_vulkan_direct_command_buffer_bind_descriptor_set(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_layout_t* executable_layout, uint32_t set,
+    iree_hal_descriptor_set_t* descriptor_set,
+    iree_host_size_t dynamic_offset_count,
+    const iree_device_size_t* dynamic_offsets) {
+  iree_hal_vulkan_direct_command_buffer_t* command_buffer =
+      iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
 
   // Vulkan takes uint32_t as the size here, unlike everywhere else.
-  absl::InlinedVector<uint32_t, 4> dynamic_offsets_i32(dynamic_offsets.size());
-  for (int i = 0; i < dynamic_offsets.size(); ++i) {
+  absl::InlinedVector<uint32_t, 4> dynamic_offsets_i32(dynamic_offset_count);
+  for (int i = 0; i < dynamic_offset_count; ++i) {
     dynamic_offsets_i32[i] = static_cast<uint32_t>(dynamic_offsets[i]);
   }
 
-  std::array<VkDescriptorSet, 1> descriptor_sets = {
-      device_descriptor_set->handle()};
-  syms()->vkCmdBindDescriptorSets(
-      command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
-      device_executable_layout->handle(), set,
-      static_cast<uint32_t>(descriptor_sets.size()), descriptor_sets.data(),
+  VkDescriptorSet descriptor_sets[1] = {
+      iree_hal_vulkan_native_descriptor_set_handle(descriptor_set),
+  };
+  command_buffer->syms->vkCmdBindDescriptorSets(
+      command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE,
+      iree_hal_vulkan_native_executable_layout_handle(executable_layout), set,
+      (uint32_t)IREE_ARRAYSIZE(descriptor_sets), descriptor_sets,
       static_cast<uint32_t>(dynamic_offsets_i32.size()),
       dynamic_offsets_i32.data());
 
-  return OkStatus();
+  return iree_ok_status();
 }
 
-Status DirectCommandBuffer::Dispatch(Executable* executable,
-                                     int32_t entry_point,
-                                     std::array<uint32_t, 3> workgroups) {
-  IREE_TRACE_SCOPE0("DirectCommandBuffer::Dispatch");
+static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+  iree_hal_vulkan_direct_command_buffer_t* command_buffer =
+      iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
 
   // Get the compiled and linked pipeline for the specified entry point and
   // bind it to the command buffer.
- IREE_ASSIGN_OR_RETURN(auto* device_executable, CastExecutable(executable)); - IREE_ASSIGN_OR_RETURN( - VkPipeline pipeline, - device_executable->GetPipelineForEntryPoint(entry_point)); - syms()->vkCmdBindPipeline(command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline); - - IREE_ASSIGN_OR_RETURN(auto* workgroups_device_buffer, - CastBuffer(workgroups_buffer)); - syms()->vkCmdDispatchIndirect( - command_buffer_, workgroups_device_buffer->handle(), workgroups_offset); - return OkStatus(); + VkPipeline pipeline_handle = VK_NULL_HANDLE; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_native_executable_pipeline_for_entry_point( + executable, entry_point, &pipeline_handle)); + command_buffer->syms->vkCmdBindPipeline( + command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_handle); + + VkBuffer workgroups_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(workgroups_buffer)); + workgroups_offset += iree_hal_buffer_byte_offset(workgroups_buffer); + command_buffer->syms->vkCmdDispatchIndirect( + command_buffer->handle, workgroups_device_buffer, workgroups_offset); + + return iree_ok_status(); } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_command_buffer_vtable_t + iree_hal_vulkan_direct_command_buffer_vtable = { + /*.destroy=*/iree_hal_vulkan_direct_command_buffer_destroy, + /*.allowed_categories=*/ + iree_hal_vulkan_direct_command_buffer_allowed_categories, + /*.begin=*/iree_hal_vulkan_direct_command_buffer_begin, + /*.end=*/iree_hal_vulkan_direct_command_buffer_end, + /*.execution_barrier=*/ + iree_hal_vulkan_direct_command_buffer_execution_barrier, + /*.signal_event=*/ + iree_hal_vulkan_direct_command_buffer_signal_event, + /*.reset_event=*/iree_hal_vulkan_direct_command_buffer_reset_event, + /*.wait_events=*/iree_hal_vulkan_direct_command_buffer_wait_events, + /*.discard_buffer=*/ + iree_hal_vulkan_direct_command_buffer_discard_buffer, + /*.fill_buffer=*/iree_hal_vulkan_direct_command_buffer_fill_buffer, + /*.update_buffer=*/ + iree_hal_vulkan_direct_command_buffer_update_buffer, + /*.copy_buffer=*/iree_hal_vulkan_direct_command_buffer_copy_buffer, + /*.push_constants=*/ + iree_hal_vulkan_direct_command_buffer_push_constants, + /*.push_descriptor_set=*/ + iree_hal_vulkan_direct_command_buffer_push_descriptor_set, + /*.bind_descriptor_set=*/ + iree_hal_vulkan_direct_command_buffer_bind_descriptor_set, + /*.dispatch=*/iree_hal_vulkan_direct_command_buffer_dispatch, + /*.dispatch_indirect=*/ + iree_hal_vulkan_direct_command_buffer_dispatch_indirect, +}; diff --git a/iree/hal/vulkan/direct_command_buffer.h b/iree/hal/vulkan/direct_command_buffer.h index cced6257adf5c..7046093523096 100644 --- a/iree/hal/vulkan/direct_command_buffer.h +++ b/iree/hal/vulkan/direct_command_buffer.h @@ -15,113 +15,29 @@ #ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_ #define IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/hal/command_buffer.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/descriptor_pool_cache.h" -#include "iree/hal/vulkan/descriptor_set_arena.h" -#include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/native_descriptor_set.h" -#include "iree/hal/vulkan/native_event.h" -#include "iree/hal/vulkan/pipeline_executable.h" -#include "iree/hal/vulkan/pipeline_executable_layout.h" -#include "iree/hal/vulkan/vma_buffer.h" - 
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// Command buffer implementation that directly maps to VkCommandBuffer.
-// This records the commands on the calling thread without additional threading
-// indirection.
-class DirectCommandBuffer final : public CommandBuffer {
- public:
-  DirectCommandBuffer(CommandBufferModeBitfield mode,
-                      CommandCategoryBitfield command_categories,
-                      ref_ptr<DescriptorPoolCache> descriptor_pool_cache,
-                      ref_ptr<VkCommandPoolHandle> command_pool,
-                      VkCommandBuffer command_buffer);
-  ~DirectCommandBuffer() override;
-
-  VkCommandBuffer handle() const { return command_buffer_; }
-
-  bool is_recording() const override { return is_recording_; }
-
-  Status Begin() override;
-  Status End() override;
-
-  Status ExecutionBarrier(
-      ExecutionStageBitfield source_stage_mask,
-      ExecutionStageBitfield target_stage_mask,
-      absl::Span<const MemoryBarrier> memory_barriers,
-      absl::Span<const BufferBarrier> buffer_barriers) override;
-  Status SignalEvent(Event* event,
-                     ExecutionStageBitfield source_stage_mask) override;
-  Status ResetEvent(Event* event,
-                    ExecutionStageBitfield source_stage_mask) override;
-  Status WaitEvents(absl::Span<Event*> events,
-                    ExecutionStageBitfield source_stage_mask,
-                    ExecutionStageBitfield target_stage_mask,
-                    absl::Span<const MemoryBarrier> memory_barriers,
-                    absl::Span<const BufferBarrier> buffer_barriers) override;
-
-  Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
-                    device_size_t length, const void* pattern,
-                    size_t pattern_length) override;
-  Status DiscardBuffer(Buffer* buffer) override;
-  Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
-                      Buffer* target_buffer, device_size_t target_offset,
-                      device_size_t length) override;
-  Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
-                    Buffer* target_buffer, device_size_t target_offset,
-                    device_size_t length) override;
-
-  Status PushConstants(ExecutableLayout* executable_layout, size_t offset,
-                       absl::Span<const uint32_t> values) override;
-
-  Status PushDescriptorSet(
-      ExecutableLayout* executable_layout, int32_t set,
-      absl::Span<const DescriptorSetBinding> bindings) override;
-  Status BindDescriptorSet(
-      ExecutableLayout* executable_layout, int32_t set,
-      DescriptorSet* descriptor_set,
-      absl::Span<const device_size_t> dynamic_offsets) override;
-
-  Status Dispatch(Executable* executable, int32_t entry_point,
-                  std::array<uint32_t, 3> workgroups) override;
-  Status DispatchIndirect(Executable* executable, int32_t entry_point,
-                          Buffer* workgroups_buffer,
-                          device_size_t workgroups_offset) override;
-
- private:
-  const ref_ptr<DynamicSymbols>& syms() const { return command_pool_->syms(); }
-
-  StatusOr<NativeEvent*> CastEvent(Event* event) const;
-  StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) const;
-  StatusOr<NativeDescriptorSet*> CastDescriptorSet(
-      DescriptorSet* descriptor_set) const;
-  StatusOr<PipelineExecutableLayout*> CastExecutableLayout(
-      ExecutableLayout* executable_layout) const;
-  StatusOr<PipelineExecutable*> CastExecutable(Executable* executable) const;
-
-  bool is_recording_ = false;
-  ref_ptr<VkCommandPoolHandle> command_pool_;
-  VkCommandBuffer command_buffer_;
-
-  // TODO(b/140026716): may grow large - should try to reclaim or reuse.
-  DescriptorSetArena descriptor_set_arena_;
-
-  // The current descriptor set group in use by the command buffer, if any.
-  // This must remain valid until all in-flight submissions of the command
-  // buffer complete.
-  DescriptorSetGroup descriptor_set_group_;
-};
-
-}  // namespace vulkan
-}  // namespace hal
-}  // namespace iree
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
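This header rework shows the pattern used throughout the port: the C++ class is replaced by extern "C" entry points that dispatch through a function-pointer vtable like the one in the previous hunk. A minimal sketch of the pattern with hypothetical types, not the actual IREE definitions:

```c++
#include <cstdio>

typedef struct thing_t thing_t;

// Function-pointer table filling the role of a C++ vtable.
typedef struct {
  void (*destroy)(thing_t* thing);
  void (*greet)(const thing_t* thing);
} thing_vtable_t;

// Every instance carries a pointer to its vtable, like a vptr.
struct thing_t {
  const thing_vtable_t* vtable;
};

static void thing_destroy_impl(thing_t* thing) { /* release resources */ }
static void thing_greet_impl(const thing_t* thing) { std::printf("hello\n"); }

static const thing_vtable_t thing_vtable = {
    /*.destroy=*/thing_destroy_impl,
    /*.greet=*/thing_greet_impl,
};

// Public entry points dispatch through the vtable, emulating virtual calls.
void thing_greet(const thing_t* thing) { thing->vtable->greet(thing); }
void thing_destroy(thing_t* thing) { thing->vtable->destroy(thing); }
```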
+// Creates a command buffer that directly records into a VkCommandBuffer.
+iree_status_t iree_hal_vulkan_direct_command_buffer_allocate(
+    iree::hal::vulkan::VkDeviceHandle* logical_device,
+    iree::hal::vulkan::VkCommandPoolHandle* command_pool,
+    iree_hal_command_buffer_mode_t mode,
+    iree_hal_command_category_t command_categories,
+    iree::hal::vulkan::DescriptorPoolCache* descriptor_pool_cache,
+    iree_hal_command_buffer_t** out_command_buffer);
+
+// Returns the native Vulkan VkCommandBuffer handle.
+VkCommandBuffer iree_hal_vulkan_direct_command_buffer_handle(
+    iree_hal_command_buffer_t* command_buffer);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
 
 #endif  // IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
diff --git a/iree/hal/vulkan/direct_command_queue.cc b/iree/hal/vulkan/direct_command_queue.cc
index 71e8e33ad1979..461ce9aa59ab9 100644
--- a/iree/hal/vulkan/direct_command_queue.cc
+++ b/iree/hal/vulkan/direct_command_queue.cc
@@ -16,11 +16,9 @@
 
 #include
 
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
 #include "iree/base/tracing.h"
 #include "iree/hal/vulkan/direct_command_buffer.h"
-#include "iree/hal/vulkan/native_timeline_semaphore.h"
+#include "iree/hal/vulkan/native_semaphore.h"
 #include "iree/hal/vulkan/status_util.h"
 
 namespace iree {
@@ -28,20 +26,15 @@ namespace hal {
 namespace vulkan {
 
 DirectCommandQueue::DirectCommandQueue(
-    std::string name, CommandCategoryBitfield supported_categories,
-    const ref_ptr<VkDeviceHandle>& logical_device, VkQueue queue)
-    : CommandQueue(std::move(name), supported_categories),
-      logical_device_(add_ref(logical_device)),
-      queue_(queue) {}
-
-DirectCommandQueue::~DirectCommandQueue() {
-  IREE_TRACE_SCOPE0("DirectCommandQueue::dtor");
-  absl::MutexLock lock(&queue_mutex_);
-  syms()->vkQueueWaitIdle(queue_);
-}
+    VkDeviceHandle* logical_device, std::string name,
+    iree_hal_command_category_t supported_categories, VkQueue queue)
+    : CommandQueue(logical_device, std::move(name), supported_categories,
+                   queue) {}
+
+DirectCommandQueue::~DirectCommandQueue() = default;
 
-Status DirectCommandQueue::TranslateBatchInfo(
-    const SubmissionBatch& batch, VkSubmitInfo* submit_info,
+iree_status_t DirectCommandQueue::TranslateBatchInfo(
+    const iree_hal_submission_batch_t* batch, VkSubmitInfo* submit_info,
     VkTimelineSemaphoreSubmitInfo* timeline_submit_info, Arena* arena) {
   // TODO(benvanik): see if we can go to finer-grained stages.
   // For example, if this was just queue ownership transfers then we can use
@@ -50,39 +43,33 @@ Status DirectCommandQueue::TranslateBatchInfo(
       VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;
 
   auto wait_semaphore_handles =
-      arena->AllocateSpan<VkSemaphore>(batch.wait_semaphores.size());
+      arena->AllocateSpan<VkSemaphore>(batch->wait_semaphores.count);
   auto wait_semaphore_values =
-      arena->AllocateSpan<uint64_t>(batch.wait_semaphores.size());
+      arena->AllocateSpan<uint64_t>(batch->wait_semaphores.count);
   auto wait_dst_stage_masks =
-      arena->AllocateSpan<VkPipelineStageFlags>(batch.wait_semaphores.size());
-  for (int i = 0; i < batch.wait_semaphores.size(); ++i) {
-    const auto& wait_point = batch.wait_semaphores[i];
-    const auto* semaphore =
-        static_cast<NativeTimelineSemaphore*>(wait_point.semaphore);
-    wait_semaphore_handles[i] = semaphore->handle();
-    wait_semaphore_values[i] = wait_point.value;
+      arena->AllocateSpan<VkPipelineStageFlags>(batch->wait_semaphores.count);
+  for (iree_host_size_t i = 0; i < batch->wait_semaphores.count; ++i) {
+    wait_semaphore_handles[i] = iree_hal_vulkan_native_semaphore_handle(
+        batch->wait_semaphores.semaphores[i]);
+    wait_semaphore_values[i] = batch->wait_semaphores.payload_values[i];
     wait_dst_stage_masks[i] = dst_stage_mask;
   }
 
   auto signal_semaphore_handles =
-      arena->AllocateSpan<VkSemaphore>(batch.signal_semaphores.size());
+      arena->AllocateSpan<VkSemaphore>(batch->signal_semaphores.count);
   auto signal_semaphore_values =
-      arena->AllocateSpan<uint64_t>(batch.signal_semaphores.size());
-  for (int i = 0; i < batch.signal_semaphores.size(); ++i) {
-    const auto& signal_point = batch.signal_semaphores[i];
-    const auto* semaphore =
-        static_cast<NativeTimelineSemaphore*>(signal_point.semaphore);
-    signal_semaphore_handles[i] = semaphore->handle();
-    signal_semaphore_values[i] = signal_point.value;
+      arena->AllocateSpan<uint64_t>(batch->signal_semaphores.count);
+  for (iree_host_size_t i = 0; i < batch->signal_semaphores.count; ++i) {
+    signal_semaphore_handles[i] = iree_hal_vulkan_native_semaphore_handle(
+        batch->signal_semaphores.semaphores[i]);
+    signal_semaphore_values[i] = batch->signal_semaphores.payload_values[i];
   }
 
   auto command_buffer_handles =
-      arena->AllocateSpan<VkCommandBuffer>(batch.command_buffers.size());
-  for (int i = 0; i < batch.command_buffers.size(); ++i) {
-    const auto& command_buffer = batch.command_buffers[i];
-    auto* direct_command_buffer =
-        static_cast<DirectCommandBuffer*>(command_buffer->impl());
-    command_buffer_handles[i] = direct_command_buffer->handle();
+      arena->AllocateSpan<VkCommandBuffer>(batch->command_buffer_count);
+  for (iree_host_size_t i = 0; i < batch->command_buffer_count; ++i) {
+    command_buffer_handles[i] =
+        iree_hal_vulkan_direct_command_buffer_handle(batch->command_buffers[i]);
   }
 
   submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
@@ -111,39 +98,43 @@ Status DirectCommandQueue::TranslateBatchInfo(
   return OkStatus();
 }
 
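TranslateBatchInfo above chains the timeline payload values into VkSubmitInfo via pNext. A sketch of that wiring using plain Vulkan 1.2 types; the caller owns the arrays, which must outlive vkQueueSubmit (hence the arena in the patch):

```c++
#include <vulkan/vulkan.h>

#include <cstdint>

// Fills a VkSubmitInfo and the VkTimelineSemaphoreSubmitInfo it points at.
// The per-semaphore 64-bit payload values ride along on the pNext chain.
void FillSubmitInfo(VkSubmitInfo* submit_info,
                    VkTimelineSemaphoreSubmitInfo* timeline_info,
                    const VkSemaphore* wait_semaphores,
                    const uint64_t* wait_values,
                    const VkPipelineStageFlags* wait_stage_masks,
                    uint32_t wait_count,
                    const VkCommandBuffer* command_buffers,
                    uint32_t command_buffer_count,
                    const VkSemaphore* signal_semaphores,
                    const uint64_t* signal_values, uint32_t signal_count) {
  timeline_info->sType = VK_STRUCTURE_TYPE_TIMELINE_SEMAPHORE_SUBMIT_INFO;
  timeline_info->pNext = nullptr;
  timeline_info->waitSemaphoreValueCount = wait_count;
  timeline_info->pWaitSemaphoreValues = wait_values;
  timeline_info->signalSemaphoreValueCount = signal_count;
  timeline_info->pSignalSemaphoreValues = signal_values;

  submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submit_info->pNext = timeline_info;  // Timeline values attach here.
  submit_info->waitSemaphoreCount = wait_count;
  submit_info->pWaitSemaphores = wait_semaphores;
  submit_info->pWaitDstStageMask = wait_stage_masks;
  submit_info->commandBufferCount = command_buffer_count;
  submit_info->pCommandBuffers = command_buffers;
  submit_info->signalSemaphoreCount = signal_count;
  submit_info->pSignalSemaphores = signal_semaphores;
}
```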
-Status DirectCommandQueue::Submit(absl::Span<const SubmissionBatch> batches) {
+iree_status_t DirectCommandQueue::Submit(
+    iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) {
   IREE_TRACE_SCOPE0("DirectCommandQueue::Submit");
 
   // Map the submission batches to VkSubmitInfos.
   // Note that we must keep all arrays referenced alive until submission
   // completes and since there are a bunch of them we use an arena.
   Arena arena(4 * 1024);
-  auto submit_infos = arena.AllocateSpan<VkSubmitInfo>(batches.size());
+  auto submit_infos = arena.AllocateSpan<VkSubmitInfo>(batch_count);
   auto timeline_submit_infos =
-      arena.AllocateSpan<VkTimelineSemaphoreSubmitInfo>(batches.size());
-  for (int i = 0; i < batches.size(); ++i) {
-    IREE_RETURN_IF_ERROR(TranslateBatchInfo(batches[i], &submit_infos[i],
+      arena.AllocateSpan<VkTimelineSemaphoreSubmitInfo>(batch_count);
+  for (int i = 0; i < batch_count; ++i) {
+    IREE_RETURN_IF_ERROR(TranslateBatchInfo(&batches[i], &submit_infos[i],
                                             &timeline_submit_infos[i], &arena));
   }
 
-  {
-    absl::MutexLock lock(&queue_mutex_);
-    VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(
-        queue_, static_cast<uint32_t>(submit_infos.size()), submit_infos.data(),
-        VK_NULL_HANDLE));
-  }
+  iree_slim_mutex_lock(&queue_mutex_);
+  iree_status_t status = VK_RESULT_TO_STATUS(
+      syms()->vkQueueSubmit(queue_, static_cast<uint32_t>(submit_infos.size()),
+                            submit_infos.data(), VK_NULL_HANDLE),
+      "vkQueueSubmit");
+  iree_slim_mutex_unlock(&queue_mutex_);
+  IREE_RETURN_IF_ERROR(status);
 
-  return OkStatus();
+  return iree_ok_status();
 }
 
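In the WaitIdle rework that follows, absolute deadlines become the relative nanosecond timeouts that vkWaitForFences expects, and VK_TIMEOUT maps to a deadline-exceeded result. A distilled sketch of both steps; std::chrono stands in for iree_time_now() here (an assumption, not the IREE clock):

```c++
#include <vulkan/vulkan.h>

#include <chrono>
#include <cstdint>

// Converts an absolute deadline to a relative timeout, covering the three
// cases handled in the patch: poll only, wait forever, and a finite wait.
uint64_t DeadlineToTimeoutNs(int64_t deadline_ns) {
  constexpr int64_t kInfinitePast = INT64_MIN;
  constexpr int64_t kInfiniteFuture = INT64_MAX;
  if (deadline_ns == kInfinitePast) return 0;             // Poll, don't wait.
  if (deadline_ns == kInfiniteFuture) return UINT64_MAX;  // Wait forever.
  int64_t now_ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                       std::chrono::steady_clock::now().time_since_epoch())
                       .count();
  // Note: the patch returns DEADLINE_EXCEEDED for already-expired deadlines;
  // this sketch just degrades to a poll.
  return deadline_ns > now_ns ? static_cast<uint64_t>(deadline_ns - now_ns) : 0;
}

enum class WaitOutcome { kSignaled, kDeadlineExceeded, kError };

// Maps the vkWaitForFences result onto the tri-state outcome the HAL needs.
WaitOutcome WaitOnFence(VkDevice device, VkFence fence, uint64_t timeout_ns) {
  VkResult result = vkWaitForFences(device, /*fenceCount=*/1, &fence,
                                    /*waitAll=*/VK_TRUE, timeout_ns);
  switch (result) {
    case VK_SUCCESS:
      return WaitOutcome::kSignaled;
    case VK_TIMEOUT:
      return WaitOutcome::kDeadlineExceeded;
    default:
      return WaitOutcome::kError;  // e.g. VK_ERROR_DEVICE_LOST.
  }
}
```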
-Status DirectCommandQueue::WaitIdle(Time deadline_ns) {
-  if (deadline_ns == InfiniteFuture()) {
+iree_status_t DirectCommandQueue::WaitIdle(iree_time_t deadline_ns) {
+  if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
     // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it
     // requires fewer calls into the driver).
     IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#vkQueueWaitIdle");
-    absl::MutexLock lock(&queue_mutex_);
-    VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_));
-    return OkStatus();
+    iree_slim_mutex_lock(&queue_mutex_);
+    iree_status_t status =
+        VK_RESULT_TO_STATUS(syms()->vkQueueWaitIdle(queue_), "vkQueueWaitIdle");
+    iree_slim_mutex_unlock(&queue_mutex_);
+    return status;
   }
 
   IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#Fence");
@@ -155,46 +146,52 @@ Status DirectCommandQueue::WaitIdle(Time deadline_ns) {
   create_info.pNext = nullptr;
   create_info.flags = 0;
   VkFence fence = VK_NULL_HANDLE;
-  VK_RETURN_IF_ERROR(syms()->vkCreateFence(
-      *logical_device_, &create_info, logical_device_->allocator(), &fence));
-  auto fence_cleanup = MakeCleanup([this, fence]() {
-    syms()->vkDestroyFence(*logical_device_, fence,
-                           logical_device_->allocator());
-  });
+  VK_RETURN_IF_ERROR(
+      syms()->vkCreateFence(*logical_device_, &create_info,
+                            logical_device_->allocator(), &fence),
+      "vkCreateFence");
 
   uint64_t timeout_ns;
-  if (deadline_ns == InfinitePast()) {
+  if (deadline_ns == IREE_TIME_INFINITE_PAST) {
     // Do not wait.
     timeout_ns = 0;
-  } else if (deadline_ns == InfiniteFuture()) {
+  } else if (deadline_ns == IREE_TIME_INFINITE_FUTURE) {
     // Wait forever.
     timeout_ns = UINT64_MAX;
   } else {
     // Convert to relative time in nanoseconds.
-    // The implementation may not wait with this granularity (like, by 10000x).
-    Time now_ns = Now();
+    // The implementation may not wait with this granularity (like by 10000x).
+    iree_time_t now_ns = iree_time_now();
     if (deadline_ns < now_ns) {
-      return DeadlineExceededErrorBuilder(IREE_LOC) << "Deadline in the past";
+      return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
     }
-    timeout_ns = static_cast<uint64_t>(deadline_ns - now_ns);
+    timeout_ns = (uint64_t)(deadline_ns - now_ns);
   }
 
-  {
-    absl::MutexLock lock(&queue_mutex_);
-    VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(queue_, 0, nullptr, fence));
+  iree_slim_mutex_lock(&queue_mutex_);
+  iree_status_t status = VK_RESULT_TO_STATUS(
+      syms()->vkQueueSubmit(queue_, 0, nullptr, fence), "vkQueueSubmit");
+  iree_slim_mutex_unlock(&queue_mutex_);
+
+  if (iree_status_is_ok(status)) {
+    VkResult result = syms()->vkWaitForFences(*logical_device_, 1, &fence,
+                                              VK_TRUE, timeout_ns);
+    switch (result) {
+      case VK_SUCCESS:
+        status = iree_ok_status();
+        break;
+      case VK_TIMEOUT:
+        status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+        break;
+      default:
+        status = VK_RESULT_TO_STATUS(result, "vkWaitForFences");
+        break;
+    }
   }
-  VkResult result =
-      syms()->vkWaitForFences(*logical_device_, 1, &fence, VK_TRUE, timeout_ns);
-  switch (result) {
-    case VK_SUCCESS:
-      return OkStatus();
-    case VK_TIMEOUT:
-      return DeadlineExceededErrorBuilder(IREE_LOC)
-             << "Deadline exceeded waiting for idle";
-    default:
-      return VkResultToStatus(result, IREE_LOC);
-  }
+  syms()->vkDestroyFence(*logical_device_, fence, logical_device_->allocator());
+
+  return status;
 }
 
 }  // namespace vulkan
diff --git a/iree/hal/vulkan/direct_command_queue.h b/iree/hal/vulkan/direct_command_queue.h
index 5df9e73b4a8ad..905523361e9c0 100644
--- a/iree/hal/vulkan/direct_command_queue.h
+++ b/iree/hal/vulkan/direct_command_queue.h
@@ -15,21 +15,8 @@
 #ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
 #define IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
 
-// clang-format off: Must be included before all other headers:
-#include "iree/hal/vulkan/vulkan_headers.h"
-// clang-format on
-
-#include
-#include
-
-#include "absl/base/thread_annotations.h"
-#include "absl/synchronization/mutex.h"
 #include "iree/base/arena.h"
-#include "iree/base/status.h"
-#include "iree/base/time.h"
-#include "iree/hal/command_queue.h"
-#include "iree/hal/vulkan/dynamic_symbols.h"
-#include "iree/hal/vulkan/handle_util.h"
+#include "iree/hal/vulkan/command_queue.h"
 
 namespace iree {
 namespace hal {
@@ -38,31 +25,20 @@ namespace vulkan {
 // Command queue implementation directly maps to VkQueue.
 class DirectCommandQueue final : public CommandQueue {
  public:
-  DirectCommandQueue(std::string name,
-                     CommandCategoryBitfield supported_categories,
-                     const ref_ptr<VkDeviceHandle>& logical_device,
+  DirectCommandQueue(VkDeviceHandle* logical_device, std::string name,
+                     iree_hal_command_category_t supported_categories,
                      VkQueue queue);
   ~DirectCommandQueue() override;
 
-  const ref_ptr<DynamicSymbols>& syms() const {
-    return logical_device_->syms();
-  }
-
-  Status Submit(absl::Span<const SubmissionBatch> batches) override;
+  iree_status_t Submit(iree_host_size_t batch_count,
+                       const iree_hal_submission_batch_t* batches) override;
 
-  Status WaitIdle(Time deadline_ns) override;
+  iree_status_t WaitIdle(iree_time_t deadline_ns) override;
 
  private:
-  Status TranslateBatchInfo(const SubmissionBatch& batch,
-                            VkSubmitInfo* submit_info,
-                            VkTimelineSemaphoreSubmitInfo* timeline_submit_info,
-                            Arena* arena);
-
-  ref_ptr<VkDeviceHandle> logical_device_;
-
-  // VkQueue needs to be externally synchronized.
- mutable absl::Mutex queue_mutex_; - VkQueue queue_ ABSL_GUARDED_BY(queue_mutex_); + iree_status_t TranslateBatchInfo( + const iree_hal_submission_batch_t* batch, VkSubmitInfo* submit_info, + VkTimelineSemaphoreSubmitInfo* timeline_submit_info, Arena* arena); }; } // namespace vulkan diff --git a/iree/hal/vulkan/dynamic_symbol_tables.h b/iree/hal/vulkan/dynamic_symbol_tables.h index b709e57baea9a..05dcd591a1ea9 100644 --- a/iree/hal/vulkan/dynamic_symbol_tables.h +++ b/iree/hal/vulkan/dynamic_symbol_tables.h @@ -300,12 +300,12 @@ namespace vulkan { DEV_PFN(OPTIONAL, vkSignalSemaphore) \ DEV_PFN(OPTIONAL, vkSignalSemaphoreKHR) \ \ - INS_PFN(OPTIONAL, vkCreateDebugReportCallbackEXT) \ + INS_PFN(EXCLUDED, vkCreateDebugReportCallbackEXT) \ INS_PFN(OPTIONAL, vkCreateDebugUtilsMessengerEXT) \ INS_PFN(EXCLUDED, vkCreateDisplayPlaneSurfaceKHR) \ INS_PFN(EXCLUDED, vkCreateHeadlessSurfaceEXT) \ INS_PFN(EXCLUDED, vkDebugReportMessageEXT) \ - INS_PFN(OPTIONAL, vkDestroyDebugReportCallbackEXT) \ + INS_PFN(EXCLUDED, vkDestroyDebugReportCallbackEXT) \ INS_PFN(OPTIONAL, vkDestroyDebugUtilsMessengerEXT) \ INS_PFN(REQUIRED, vkDestroyInstance) \ INS_PFN(EXCLUDED, vkDestroySurfaceKHR) \ diff --git a/iree/hal/vulkan/dynamic_symbols_test.cc b/iree/hal/vulkan/dynamic_symbols_test.cc index c06e6a65425f5..594673bf034b1 100644 --- a/iree/hal/vulkan/dynamic_symbols_test.cc +++ b/iree/hal/vulkan/dynamic_symbols_test.cc @@ -14,7 +14,6 @@ #include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/status_util.h" #include "iree/testing/gtest.h" #include "iree/testing/status_matchers.h" @@ -58,8 +57,8 @@ TEST(DynamicSymbolsTest, CreateFromSystemLoader) { VkApplicationInfo app_info = GetApplicationInfo(); VkInstanceCreateInfo create_info = GetInstanceCreateInfo(&app_info); VkInstance instance = VK_NULL_HANDLE; - VK_CHECK_OK( - syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance)); + ASSERT_EQ(VK_SUCCESS, syms->vkCreateInstance( + &create_info, /*pAllocator=*/nullptr, &instance)); IREE_ASSERT_OK(syms->LoadFromInstance(instance)); diff --git a/iree/hal/vulkan/emulated_semaphore.cc b/iree/hal/vulkan/emulated_semaphore.cc new file mode 100644 index 0000000000000..b287d5e34cb5d --- /dev/null +++ b/iree/hal/vulkan/emulated_semaphore.cc @@ -0,0 +1,634 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
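The DEV_PFN/INS_PFN lists in the dynamic_symbol_tables.h hunk above are an X-macro table: one symbol list expanded in several contexts (struct members, loaders, validation). A minimal sketch of the technique with made-up symbol names, not IREE's actual macros:

```c++
#include <cstdio>

// The single source of truth: one entry per symbol name.
#define MY_SYMBOL_TABLE(PFN) \
  PFN(vkFoo)                 \
  PFN(vkBar)

// Expansion 1: declare one function-pointer member per symbol.
struct Symbols {
#define DECLARE_PFN(name) void (*name)() = nullptr;
  MY_SYMBOL_TABLE(DECLARE_PFN)
#undef DECLARE_PFN
};

// Expansion 2: iterate the same list again, e.g. to report missing symbols.
void PrintMissing(const Symbols& syms) {
#define CHECK_PFN(name) \
  if (!syms.name) std::printf("missing: %s\n", #name);
  MY_SYMBOL_TABLE(CHECK_PFN)
#undef CHECK_PFN
}
```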
+ +#include "iree/hal/vulkan/emulated_semaphore.h" + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "iree/base/intrusive_list.h" +#include "iree/base/ref_ptr.h" +#include "iree/base/status.h" +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/dynamic_symbols.h" +#include "iree/hal/vulkan/serializing_command_queue.h" +#include "iree/hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class EmulatedTimelineSemaphore final { + public: + EmulatedTimelineSemaphore(VkDeviceHandle* logical_device, + TimePointSemaphorePool* semaphore_pool, + iree_host_size_t command_queue_count, + iree::hal::vulkan::CommandQueue** command_queues, + uint64_t initial_value); + + ~EmulatedTimelineSemaphore(); + + iree_status_t Query(uint64_t* out_value); + + iree_status_t Signal(uint64_t value); + + iree_status_t Wait(uint64_t value, iree_time_t deadline_ns); + + void Fail(iree_status_t status); + + // Gets a binary semaphore for waiting on the timeline to advance to the given + // |value|. The semaphore returned won't be waited by anyone else. Returns + // VK_NULL_HANDLE if no available semaphores for the given |value|. + // |wait_fence| is the fence associated with the queue submission that waiting + // on this semaphore. + VkSemaphore GetWaitSemaphore(uint64_t value, + const ref_ptr& wait_fence); + + // Cancels the waiting attempt on the given binary |semaphore|. This allows + // the |semaphore| to be waited by others. + iree_status_t CancelWaitSemaphore(VkSemaphore semaphore); + + // Gets a binary semaphore for signaling the timeline to the given |value|. + // |value| must be smaller than the current timeline value. |signal_fence| is + // the fence associated with the queue submission that signals this semaphore. + iree_status_t GetSignalSemaphore(uint64_t value, + const ref_ptr& signal_fence, + VkSemaphore* out_handle); + + private: + // Tries to advance the timeline to the given |to_upper_value| without + // blocking and returns whether the |to_upper_value| is reached. + iree_status_t TryToAdvanceTimeline(uint64_t to_upper_value, + bool* out_reached_upper_value) + ABSL_LOCKS_EXCLUDED(mutex_); + // Similar to the above, but also returns the fences that are known to have + // already signaled via |signaled_fences|. + iree_status_t TryToAdvanceTimeline( + uint64_t to_upper_value, bool* out_reached_upper_value, + absl::InlinedVector* out_signaled_fences) + ABSL_LOCKS_EXCLUDED(mutex_); + + std::atomic signaled_value_; + + VkDeviceHandle* logical_device_; + TimePointSemaphorePool* semaphore_pool_; + + iree_host_size_t command_queue_count_; + CommandQueue** command_queues_; + + mutable absl::Mutex mutex_; + + // A list of outstanding semaphores used to emulate time points. + // + // The life time of each semaphore is in one of the following state: + // + // * Unused state: value = UINT64_MAX, signal/wait fence = nullptr. This is + // the state of the semaphore when it's initially acquired from the pool and + // not put in the queue for emulating a time point yet. + // * Pending state: signaled value < value < UINT64_MAX, signal fence = + // , wait fence == nullptr. This is the state of the semaphore + // when it's put into the GPU queue for emulating a time point. + // * Pending and waiting state: signaled value < value < UINT64_MAX, signal + // fence = , wait fence == . 
+  //   This is the state of the semaphore when it's put into the GPU queue
+  //   for emulating a time point and there is another queue submission
+  //   waiting on it on the GPU.
+  // * Signaled and not ever waited state: value <= signaled value, signal/wait
+  //   fence = nullptr. This is the state of the semaphore when we know it's
+  //   already signaled on the GPU and there are no waiters for it.
+  // * Signaled and waiting state: value <= signaled value, signal fence =
+  //   nullptr, wait fence = <some fence>. This is the state of the semaphore
+  //   when we know it's already signaled on the GPU and there is still one
+  //   queue submission on the GPU waiting for it.
+  IntrusiveList<TimePointSemaphore> outstanding_semaphores_
+      ABSL_GUARDED_BY(mutex_);
+
+  // NOTE: We only need to access this status (and thus take the lock) when we
+  // want to either signal failure or query the status in the case of the
+  // semaphore being set to UINT64_MAX.
+  iree_status_t status_ ABSL_GUARDED_BY(mutex_) = iree_ok_status();
+};
+
+EmulatedTimelineSemaphore::EmulatedTimelineSemaphore(
+    VkDeviceHandle* logical_device, TimePointSemaphorePool* semaphore_pool,
+    iree_host_size_t command_queue_count, CommandQueue** command_queues,
+    uint64_t initial_value)
+    : signaled_value_(initial_value),
+      logical_device_(logical_device),
+      semaphore_pool_(semaphore_pool),
+      command_queue_count_(command_queue_count),
+      command_queues_(command_queues) {}
+
+EmulatedTimelineSemaphore::~EmulatedTimelineSemaphore() {
+  IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::dtor");
+  IREE_CHECK_OK(
+      TryToAdvanceTimeline(UINT64_MAX, /*out_reached_upper_value=*/NULL));
+  absl::MutexLock lock(&mutex_);
+  IREE_CHECK(outstanding_semaphores_.empty())
+      << "Destroying an emulated timeline semaphore without first waiting on "
+         "outstanding signals";
+  iree_status_free(status_);
+}
+
+iree_status_t EmulatedTimelineSemaphore::Query(uint64_t* out_value) {
+  IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Query");
+  IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Query";
+  IREE_RETURN_IF_ERROR(
+      TryToAdvanceTimeline(UINT64_MAX, /*out_reached_upper_value=*/NULL));
+  uint64_t value = signaled_value_.load();
+  IREE_DVLOG(2) << "Current timeline value: " << value;
+  if (value == UINT64_MAX) {
+    absl::MutexLock lock(&mutex_);
+    return iree_status_clone(status_);
+  }
+  *out_value = value;
+  return iree_ok_status();
+}
+
+iree_status_t EmulatedTimelineSemaphore::Signal(uint64_t value) {
+  IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Signal");
+  IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Signal";
+  auto signaled_value = signaled_value_.exchange(value);
+  IREE_DVLOG(2) << "Previous value: " << signaled_value
+                << "; new value: " << value;
+  // Make sure the previous signaled value is smaller than the new value.
+  IREE_CHECK(signaled_value < value)
+      << "Attempting to signal a timeline value out of order; trying " << value
+      << " but " << signaled_value << " already signaled";
+
+  // Inform the device to make progress given we have a new value signaled now.
+ for (iree_host_size_t i = 0; i < command_queue_count_; ++i) { + IREE_RETURN_IF_ERROR(((SerializingCommandQueue*)command_queues_[i]) + ->AdvanceQueueSubmission()); + } + + return iree_ok_status(); +} + +iree_status_t EmulatedTimelineSemaphore::Wait(uint64_t value, + iree_time_t deadline_ns) { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait"); + IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Wait"; + + VkFence fence = VK_NULL_HANDLE; + do { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait#loop"); + // First try to advance the timeline without blocking to see whether we've + // already reached the desired value. + bool reached_desired_value = false; + IREE_RETURN_IF_ERROR(TryToAdvanceTimeline(value, &reached_desired_value)); + if (reached_desired_value) return iree_ok_status(); + + // We must wait now. Find the first emulated time point that has a value >= + // the desired value so we can wait on its associated signal fence to make + // sure the timeline is advanced to the desired value. + absl::MutexLock lock(&mutex_); + auto semaphore = outstanding_semaphores_.begin(); + for (; semaphore != outstanding_semaphores_.end(); ++semaphore) { + if ((*semaphore)->value >= value) break; + } + if (semaphore != outstanding_semaphores_.end()) { + if (!(*semaphore)->signal_fence) { + return iree_make_status(IREE_STATUS_INTERNAL, + "timeline should have a signal fence for the " + "first time point beyond the signaled value"); + } + IREE_DVLOG(2) << "Found timepoint semaphore " << *semaphore + << " (value: " << (*semaphore)->value + << ") to wait for desired timeline value: " << value; + fence = (*semaphore)->signal_fence->value(); + // Found; we can break the loop and proceed to waiting now. + break; + } + // TODO(antiagainst): figure out a better way instead of the busy loop here. + } while (iree_time_now() < deadline_ns); + + if (fence == VK_NULL_HANDLE) { + // NOTE: not an error; it may be expected that the semaphore is not ready. 
+    return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED);
+  }
+
+  uint64_t timeout_ns =
+      static_cast<uint64_t>(iree_absolute_deadline_to_timeout_ns(deadline_ns));
+  VK_RETURN_IF_ERROR(logical_device_->syms()->vkWaitForFences(
+                         *logical_device_, /*fenceCount=*/1, &fence,
+                         /*waitAll=*/true, timeout_ns),
+                     "vkWaitForFences");
+
+  return TryToAdvanceTimeline(value, /*out_reached_upper_value=*/NULL);
+}
+
+void EmulatedTimelineSemaphore::Fail(iree_status_t status) {
+  IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Fail");
+  absl::MutexLock lock(&mutex_);
+  if (status_) return;
+  status_ = status;
+  signaled_value_.store(UINT64_MAX);
+}
+
+VkSemaphore EmulatedTimelineSemaphore::GetWaitSemaphore(
+    uint64_t value, const ref_ptr<TimePointFence>& wait_fence) {
+  IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetWaitSemaphore");
+  IREE_DVLOG(2) << "EmulatedTimelineSemaphore::GetWaitSemaphore";
+
+  absl::MutexLock lock(&mutex_);
+
+  VkSemaphore semaphore = VK_NULL_HANDLE;
+  for (TimePointSemaphore* point : outstanding_semaphores_) {
+    if (point->value > value && point->wait_fence) {
+      point->wait_fence = add_ref(wait_fence);
+      semaphore = point->semaphore;
+      break;
+    }
+  }
+
+  IREE_DVLOG(2) << "Binary VkSemaphore to wait on for timeline value (" << value
+                << ") and wait fence (" << wait_fence.get()
+                << "): " << semaphore;
+
+  return semaphore;
+}
+
+iree_status_t EmulatedTimelineSemaphore::CancelWaitSemaphore(
+    VkSemaphore semaphore) {
+  IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::CancelWaitSemaphore");
+  IREE_DVLOG(2) << "EmulatedTimelineSemaphore::CancelWaitSemaphore";
+
+  absl::MutexLock lock(&mutex_);
+  for (TimePointSemaphore* point : outstanding_semaphores_) {
+    if (point->semaphore != semaphore) continue;
+
+    if (!point->wait_fence) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "time point wasn't waited before");
+    }
+    point->wait_fence = nullptr;
+    IREE_DVLOG(2) << "Cancelled waiting on binary VkSemaphore: " << semaphore;
+    return iree_ok_status();
+  }
+  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                          "no time point for the given semaphore");
+}
+
+iree_status_t EmulatedTimelineSemaphore::GetSignalSemaphore(
+    uint64_t value, const ref_ptr<TimePointFence>& signal_fence,
+    VkSemaphore* out_handle) {
+  IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetSignalSemaphore");
+  IREE_DVLOG(2) << "EmulatedTimelineSemaphore::GetSignalSemaphore";
+
+  if (signaled_value_.load() >= value) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "timeline semaphore already signaled past %" PRIu64,
+                            value);
+  }
+
+  absl::MutexLock lock(&mutex_);
+
+  auto insertion_point = outstanding_semaphores_.begin();
+  while (insertion_point != outstanding_semaphores_.end()) {
+    if ((*insertion_point)->value > value) break;
+    ++insertion_point;  // Advance; without this the scan would never end.
+  }
+
+  IREE_ASSIGN_OR_RETURN(TimePointSemaphore * semaphore,
+                        semaphore_pool_->Acquire());
+  semaphore->value = value;
+  semaphore->signal_fence = add_ref(signal_fence);
+  if (semaphore->wait_fence) {
+    return iree_make_status(
+        IREE_STATUS_INTERNAL,
+        "newly acquired time point semaphore should not have waiters");
+  }
+  outstanding_semaphores_.insert(insertion_point, semaphore);
+  IREE_DVLOG(2) << "Timepoint semaphore to signal for timeline value (" << value
+                << ") and wait fence (" << signal_fence.get()
+                << "): " << semaphore
+                << " (binary VkSemaphore: " << semaphore->semaphore << ")";
+
+  *out_handle = semaphore->semaphore;
+  return iree_ok_status();
+}
+
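GetSignalSemaphore keeps outstanding_semaphores_ sorted ascending by timeline value so resolution can walk from the front. A distilled sketch of that ordered insertion, using std::list as a stand-in for IREE's intrusive list (types hypothetical):

```c++
#include <cstdint>
#include <list>

struct TimePoint {
  uint64_t value;
};

// Inserts |point| before the first element with a larger value, keeping the
// list sorted ascending by timeline value.
void InsertSorted(std::list<TimePoint*>& points, TimePoint* point) {
  auto it = points.begin();
  while (it != points.end() && (*it)->value <= point->value) ++it;
  points.insert(it, point);
}
```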
+iree_status_t EmulatedTimelineSemaphore::TryToAdvanceTimeline(
+    uint64_t to_upper_value, bool* out_reached_upper_value) {
+  absl::InlinedVector<VkFence, 4> signaled_fences;
+  iree_status_t status = TryToAdvanceTimeline(
+      to_upper_value, out_reached_upper_value, &signaled_fences);
+  // Inform the queue that some fences are known to have signaled. This should
+  // happen here instead of inside the other TryToAdvanceTimeline to avoid a
+  // potential mutex deadlock, given that here we are no longer holding the
+  // mutex.
+  if (!signaled_fences.empty()) {
+    for (iree_host_size_t i = 0; i < command_queue_count_; ++i) {
+      ((SerializingCommandQueue*)command_queues_[i])
+          ->SignalFences(absl::MakeSpan(signaled_fences));
+    }
+  }
+  return status;
+}
+
+iree_status_t EmulatedTimelineSemaphore::TryToAdvanceTimeline(
+    uint64_t to_upper_value, bool* out_reached_upper_value,
+    absl::InlinedVector<VkFence, 4>* out_signaled_fences) {
+  IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::TryToAdvanceTimeline");
+  IREE_DVLOG(3) << "EmulatedTimelineSemaphore::TryToAdvanceTimeline";
+  if (out_reached_upper_value) *out_reached_upper_value = false;
+
+  uint64_t past_value = signaled_value_.load();
+  IREE_DVLOG(3) << "Current timeline value: " << past_value
+                << "; desired timeline value: " << to_upper_value;
+
+  // Fast path for when already signaled past the desired value.
+  if (past_value >= to_upper_value) {
+    if (out_reached_upper_value) *out_reached_upper_value = true;
+    return iree_ok_status();
+  }
+
+  // We hold the lock during the entire resolve process so that we can resolve
+  // to the furthest possible value.
+  absl::MutexLock lock(&mutex_);
+
+  IREE_DVLOG(3) << "# outstanding semaphores: "
+                << outstanding_semaphores_.size();
+
+  // The timeline has not signaled past the desired value and there is no
+  // binary semaphore pending on the GPU yet: certainly the timeline cannot
+  // advance to the desired value.
+  if (outstanding_semaphores_.empty()) return iree_ok_status();
+
+  IntrusiveList<TimePointSemaphore> resolved_semaphores;
+
+  auto clear_signal_fence =
+      [&out_signaled_fences](ref_ptr<TimePointFence>& fence) {
+        if (fence) {
+          if (out_signaled_fences)
+            out_signaled_fences->push_back(fence->value());
+          fence.reset();
+        }
+      };
+
+  bool keep_resolving = true;
+  bool reached_desired_value = false;
+  while (keep_resolving && !outstanding_semaphores_.empty()) {
+    auto* semaphore = outstanding_semaphores_.front();
+    IREE_DVLOG(3) << "Looking at timepoint semaphore " << semaphore << "..";
+    IREE_DVLOG(3) << "  value: " << semaphore->value;
+    IREE_DVLOG(3) << "  VkSemaphore: " << semaphore->semaphore;
+    IREE_DVLOG(3) << "  signal fence: " << semaphore->signal_fence.get();
+    IREE_DVLOG(3) << "  wait fence: " << semaphore->wait_fence.get();
+
+    // If the current semaphore is for a value beyond our upper limit, then
+    // early exit so that we don't spend time dealing with signals we don't yet
+    // care about. This can prevent live lock where one thread is signaling
+    // fences as fast/faster than another thread can consume them.
+    if (semaphore->value > to_upper_value) {
+      keep_resolving = false;
+      reached_desired_value = true;
+      break;
+    }
+
+    // If the current semaphore is for a value not greater than the past
+    // signaled value, then we know it was signaled previously. But there might
+    // be a waiter on it on the GPU.
+    if (semaphore->value <= past_value) {
+      if (semaphore->signal_fence) {
+        return iree_make_status(IREE_STATUS_INTERNAL,
+                                "timeline should already signaled past this "
+                                "time point and cleared the signal fence");
+      }
+
+      // If there are no waiters, we can recycle this semaphore now. If there
+      // exists one waiter, then query its status and recycle on success. We
+      // only handle the success status here; others will be handled when the
+      // fence is checked for other semaphores' signaling status for the same
+      // queue submission.
+      if (!semaphore->wait_fence ||
+          semaphore->wait_fence->GetStatus() == VK_SUCCESS) {
+        clear_signal_fence(semaphore->signal_fence);
+        semaphore->wait_fence = nullptr;
+        outstanding_semaphores_.erase(semaphore);
+        resolved_semaphores.push_back(semaphore);
+        IREE_DVLOG(3) << "Resolved and recycling semaphore " << semaphore;
+      }
+
+      continue;
+    }
+
+    // This semaphore represents a value greater than the known previously
+    // signaled value. We don't know its status so we need to really query now.
+
+    if (!semaphore->signal_fence) {
+      return iree_make_status(IREE_STATUS_INTERNAL,
+                              "status of this time point in the timeline "
+                              "should still be pending with a signal fence");
+    }
+    VkResult signal_status = semaphore->signal_fence->GetStatus();
+
+    switch (signal_status) {
+      case VK_SUCCESS:
+        IREE_DVLOG(3) << "..semaphore signaled";
+        signaled_value_.store(semaphore->value);
+        clear_signal_fence(semaphore->signal_fence);
+        // If no waiters, we can recycle this semaphore now.
+        if (!semaphore->wait_fence) {
+          semaphore->wait_fence = nullptr;
+          outstanding_semaphores_.erase(semaphore);
+          resolved_semaphores.push_back(semaphore);
+          IREE_DVLOG(3) << "Resolved and recycling semaphore " << semaphore;
+        }
+        break;
+      case VK_NOT_READY:
+        // The fence has not been signaled yet so this is the furthest time
+        // point we can go in this timeline.
+        keep_resolving = false;
+        IREE_DVLOG(3) << "..semaphore not yet signaled";
+        break;
+      default:
+        // Fence indicates an error (device lost, out of memory, etc).
+        // Propagate this back to our status (and thus any waiters).
+        // Since we only take the first error we find we skip all remaining
+        // fences.
+        keep_resolving = false;
+        clear_signal_fence(semaphore->signal_fence);
+        status_ = VK_RESULT_TO_STATUS(signal_status, "signal status");
+        signaled_value_.store(UINT64_MAX);
+        break;
+    }
+  }
+
+  IREE_DVLOG(3) << "Releasing " << resolved_semaphores.size()
+                << " resolved semaphores; " << outstanding_semaphores_.size()
+                << " still outstanding";
+  semaphore_pool_->ReleaseResolved(&resolved_semaphores);
+  if (!iree_status_is_ok(status_)) {
+    for (iree_host_size_t i = 0; i < command_queue_count_; ++i) {
+      ((SerializingCommandQueue*)command_queues_[i])->AbortQueueSubmission();
+    }
+    semaphore_pool_->ReleaseUnresolved(&outstanding_semaphores_);
+    return status_;
+  }
+
+  if (out_reached_upper_value) *out_reached_upper_value = reached_desired_value;
+  return iree_ok_status();
+}
+
+}  // namespace vulkan
+}  // namespace hal
+}  // namespace iree
+
+using namespace iree::hal::vulkan;
+
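What follows wraps the C++ EmulatedTimelineSemaphore behind C entry points: the C struct holds the resource header plus a pointer to the C++ implementation, and each C function casts and forwards. A minimal sketch of that box-and-forward pattern with hypothetical names:

```c++
#include <cstdlib>
#include <new>

class Impl {
 public:
  int Query() const { return 42; }
};

// C-visible handle: a plain struct holding the C++ object pointer, mirroring
// the iree_hal_vulkan_emulated_semaphore_t below (fields simplified).
typedef struct wrapper_t {
  Impl* handle;
} wrapper_t;

extern "C" wrapper_t* wrapper_create() {
  auto* wrapper = static_cast<wrapper_t*>(std::malloc(sizeof(wrapper_t)));
  if (!wrapper) return nullptr;
  wrapper->handle = new (std::nothrow) Impl();
  return wrapper;
}

extern "C" int wrapper_query(wrapper_t* wrapper) {
  return wrapper->handle->Query();  // Forward to the C++ implementation.
}

extern "C" void wrapper_destroy(wrapper_t* wrapper) {
  delete wrapper->handle;  // Mirrors `delete semaphore->handle` below.
  std::free(wrapper);
}
```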
+// Wrap the C++ type above so that we have a somewhat normal C interface.
+// Porting the above to C is ideal but since this is just a fallback layer I'm
+// not sure it's worth it (given that we may require Vulkan 1.2 with timeline
+// semaphores built in at some point soon).
+typedef struct {
+  iree_hal_resource_t resource;
+  iree_allocator_t host_allocator;
+  EmulatedTimelineSemaphore* handle;
+} iree_hal_vulkan_emulated_semaphore_t;
+
+extern const iree_hal_semaphore_vtable_t
+    iree_hal_vulkan_emulated_semaphore_vtable;
+
+static EmulatedTimelineSemaphore* iree_hal_vulkan_emulated_semaphore_cast(
+    iree_hal_semaphore_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_emulated_semaphore_vtable);
+  return ((iree_hal_vulkan_emulated_semaphore_t*)base_value)->handle;
+}
+
+iree_status_t iree_hal_vulkan_emulated_semaphore_create(
+    iree::hal::vulkan::VkDeviceHandle* logical_device,
+    iree::hal::vulkan::TimePointSemaphorePool* semaphore_pool,
+    iree_host_size_t command_queue_count,
+    iree::hal::vulkan::CommandQueue** command_queues, uint64_t initial_value,
+    iree_hal_semaphore_t** out_semaphore) {
+  iree_hal_vulkan_emulated_semaphore_t* semaphore = NULL;
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(logical_device->host_allocator(),
+                                             sizeof(*semaphore),
+                                             (void**)&semaphore));
+  iree_hal_resource_initialize(&iree_hal_vulkan_emulated_semaphore_vtable,
+                               &semaphore->resource);
+  semaphore->host_allocator = logical_device->host_allocator();
+  semaphore->handle = new EmulatedTimelineSemaphore(
+      logical_device, semaphore_pool, command_queue_count, command_queues,
+      initial_value);
+
+  *out_semaphore = (iree_hal_semaphore_t*)semaphore;
+  return iree_ok_status();
+}
+
+static void iree_hal_vulkan_emulated_semaphore_destroy(
+    iree_hal_semaphore_t* base_semaphore) {
+  iree_hal_vulkan_emulated_semaphore_t* semaphore =
+      (iree_hal_vulkan_emulated_semaphore_t*)base_semaphore;
+  iree_allocator_t host_allocator = semaphore->host_allocator;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  delete semaphore->handle;
+  iree_allocator_free(host_allocator, semaphore);
+
+  IREE_TRACE_ZONE_END(z0);
+}
+
+iree_status_t iree_hal_vulkan_emulated_semaphore_acquire_wait_handle(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    const iree::ref_ptr<iree::hal::vulkan::TimePointFence>& wait_fence,
+    VkSemaphore* out_handle) {
+  EmulatedTimelineSemaphore* semaphore =
+      iree_hal_vulkan_emulated_semaphore_cast(base_semaphore);
+  *out_handle = semaphore->GetWaitSemaphore(value, wait_fence);
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_vulkan_emulated_semaphore_cancel_wait_handle(
+    iree_hal_semaphore_t* base_semaphore, VkSemaphore handle) {
+  EmulatedTimelineSemaphore* semaphore =
+      iree_hal_vulkan_emulated_semaphore_cast(base_semaphore);
+  return semaphore->CancelWaitSemaphore(handle);
+}
+
+iree_status_t iree_hal_vulkan_emulated_semaphore_acquire_signal_handle(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    const iree::ref_ptr<iree::hal::vulkan::TimePointFence>& signal_fence,
+    VkSemaphore* out_handle) {
+  EmulatedTimelineSemaphore* semaphore =
+      iree_hal_vulkan_emulated_semaphore_cast(base_semaphore);
+  return semaphore->GetSignalSemaphore(value, signal_fence, out_handle);
+}
+
+static iree_status_t iree_hal_vulkan_emulated_semaphore_query(
+    iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) {
+  EmulatedTimelineSemaphore* semaphore =
+      iree_hal_vulkan_emulated_semaphore_cast(base_semaphore);
+  return semaphore->Query(out_value);
+}
+
+static iree_status_t iree_hal_vulkan_emulated_semaphore_signal(
+    iree_hal_semaphore_t* base_semaphore, uint64_t new_value) {
+  EmulatedTimelineSemaphore* semaphore =
+      iree_hal_vulkan_emulated_semaphore_cast(base_semaphore);
+  return semaphore->Signal(new_value);
+}
+
+static void iree_hal_vulkan_emulated_semaphore_fail(
+    iree_hal_semaphore_t* base_semaphore, iree_status_t status) {
+  EmulatedTimelineSemaphore* semaphore =
+      iree_hal_vulkan_emulated_semaphore_cast(base_semaphore);
+  semaphore->Fail(status);
+}
+
+static iree_status_t iree_hal_vulkan_emulated_semaphore_wait_with_deadline(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_time_t deadline_ns) {
+  EmulatedTimelineSemaphore* semaphore =
+      iree_hal_vulkan_emulated_semaphore_cast(base_semaphore);
+  return semaphore->Wait(value, deadline_ns);
+}
+
+static iree_status_t iree_hal_vulkan_emulated_semaphore_wait_with_timeout(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_duration_t timeout_ns) {
+  return iree_hal_vulkan_emulated_semaphore_wait_with_deadline(
+      base_semaphore, value, iree_relative_timeout_to_deadline_ns(timeout_ns));
+}
+
+iree_status_t iree_hal_vulkan_emulated_semaphore_multi_wait(
+    iree::hal::vulkan::VkDeviceHandle* logical_device,
+    const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns,
+    VkSemaphoreWaitFlags wait_flags) {
+  // TODO(antiagainst): We should actually get the fences associated with the
+  // emulated timeline semaphores so that we can wait on them in a batch. This
+  // implementation is problematic if we want to wait on any semaphore and the
+  // first semaphore takes extra long to signal while the following ones signal
+  // quickly.
+  for (iree_host_size_t i = 0; i < semaphore_list->count; ++i) {
+    IREE_RETURN_IF_ERROR(iree_hal_vulkan_emulated_semaphore_wait_with_deadline(
+        semaphore_list->semaphores[i], semaphore_list->payload_values[i],
+        deadline_ns));
+    if (wait_flags & VK_SEMAPHORE_WAIT_ANY_BIT) return iree_ok_status();
+  }
+  return iree_ok_status();
+}
+
+const iree_hal_semaphore_vtable_t iree_hal_vulkan_emulated_semaphore_vtable = {
+    /*.destroy=*/iree_hal_vulkan_emulated_semaphore_destroy,
+    /*.query=*/iree_hal_vulkan_emulated_semaphore_query,
+    /*.signal=*/iree_hal_vulkan_emulated_semaphore_signal,
+    /*.fail=*/iree_hal_vulkan_emulated_semaphore_fail,
+    /*.wait_with_deadline=*/
+    iree_hal_vulkan_emulated_semaphore_wait_with_deadline,
+    /*.wait_with_timeout=*/
+    iree_hal_vulkan_emulated_semaphore_wait_with_timeout,
+};
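The multi_wait above degrades to sequential waits, returning early for the any-wait case. A distilled sketch of that all-vs-any control flow with a simplified interface, not the IREE API:

```c++
#include <cstddef>
#include <cstdint>

// Waits on one semaphore up to the deadline; returns false on failure.
using WaitOneFn = bool (*)(size_t index, uint64_t value, int64_t deadline_ns);

// All-wait loops over every semaphore; an any-wait returns after the first
// satisfied wait. Note the caveat from the TODO above: waiting sequentially
// makes an any-wait stall on a slow first semaphore.
bool MultiWait(WaitOneFn wait_one, const uint64_t* values, size_t count,
               int64_t deadline_ns, bool wait_any) {
  for (size_t i = 0; i < count; ++i) {
    if (!wait_one(i, values[i], deadline_ns)) return false;
    if (wait_any) return true;  // First success completes an any-wait.
  }
  return true;
}
```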
diff --git a/iree/hal/vulkan/emulated_timeline_semaphore.h b/iree/hal/vulkan/emulated_semaphore.h
similarity index 50%
rename from iree/hal/vulkan/emulated_timeline_semaphore.h
rename to iree/hal/vulkan/emulated_semaphore.h
index 30c83d9345447..28af0c651879d 100644
--- a/iree/hal/vulkan/emulated_timeline_semaphore.h
+++ b/iree/hal/vulkan/emulated_semaphore.h
@@ -12,31 +12,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_
-#define IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_
+#ifndef IREE_HAL_VULKAN_ENUMLATED_SEMAPHORE_H_
+#define IREE_HAL_VULKAN_ENUMLATED_SEMAPHORE_H_
 
-// clang-format off: Must be included before all other headers:
-#include "iree/hal/vulkan/vulkan_headers.h"
-// clang-format on
-
-#include
-#include
-
-#include "absl/base/thread_annotations.h"
-#include "absl/container/inlined_vector.h"
-#include "absl/synchronization/mutex.h"
-#include "iree/base/intrusive_list.h"
-#include "iree/base/ref_ptr.h"
-#include "iree/base/status.h"
-#include "iree/hal/semaphore.h"
+#include "iree/hal/api.h"
+#include "iree/hal/vulkan/command_queue.h"
 #include "iree/hal/vulkan/handle_util.h"
 #include "iree/hal/vulkan/timepoint_util.h"
 
-namespace iree {
-namespace hal {
-namespace vulkan {
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
 
-// A timeline semaphore emulated via `VkFence`s and binary `VkSemaphore`s.
+// Creates a timeline semaphore emulated via `VkFence`s and binary
+// `VkSemaphore`s.
 //
 // Vulkan provides several explicit synchronization primitives: fences,
 // (binary/timeline) semaphores, events, pipeline barriers, and render passes.
@@ -126,111 +115,51 @@ namespace vulkan {
 // synchronization primitives. So this should not be treated as a full
 // emulation of the Vulkan spec and thus does not substitute
 // Vulkan-ExtensionLayer.
-class EmulatedTimelineSemaphore final : public Semaphore {
- public:
-  // Creates a timeline semaphore with the given |initial_value|.
-  static StatusOr<ref_ptr<Semaphore>> Create(
-      ref_ptr<VkDeviceHandle> logical_device,
-      std::function<Status(Semaphore*)> on_semaphore_signal,
-      std::function<void(Status)> on_semaphore_failure,
-      std::function<Status(absl::Span<VkFence>)> on_fence_signal,
-      ref_ptr<TimePointSemaphorePool> semaphore_pool, uint64_t initial_value);
-
-  EmulatedTimelineSemaphore(
-      ref_ptr<VkDeviceHandle> logical_device,
-      std::function<Status(Semaphore*)> on_semaphore_signal,
-      std::function<void(Status)> on_semaphore_failure,
-      std::function<Status(absl::Span<VkFence>)> on_fence_signal,
-      ref_ptr<TimePointSemaphorePool> semaphore_pool, uint64_t initial_value);
-
-  ~EmulatedTimelineSemaphore() override;
-
-  StatusOr<uint64_t> Query() override;
-
-  Status Signal(uint64_t value) override;
-
-  Status Wait(uint64_t value, Time deadline_ns) override;
-
-  void Fail(Status status) override;
-
-  // Gets a binary semaphore for waiting on the timeline to advance to the given
-  // |value|. The semaphore returned won't be waited by anyone else. Returns
-  // VK_NULL_HANDLE if no available semaphores for the given |value|.
-  // |wait_fence| is the fence associated with the queue submission that waiting
-  // on this semaphore.
-  VkSemaphore GetWaitSemaphore(uint64_t value,
-                               const ref_ptr<TimePointFence>& wait_fence);
-
-  // Cancels the waiting attempt on the given binary |semaphore|. This allows
-  // the |semaphore| to be waited by others.
-  Status CancelWaitSemaphore(VkSemaphore semaphore);
-
-  // Gets a binary semaphore for signaling the timeline to the given |value|.
-  // |value| must be smaller than the current timeline value. |signal_fence| is
-  // the fence associated with the queue submission that signals this semaphore.
-  StatusOr<VkSemaphore> GetSignalSemaphore(
-      uint64_t value, const ref_ptr<TimePointFence>& signal_fence);
-
- private:
-  // Tries to advance the timeline to the given |to_upper_value| without
-  // blocking and returns whether the |to_upper_value| is reached.
-  StatusOr<bool> TryToAdvanceTimeline(uint64_t to_upper_value)
-      ABSL_LOCKS_EXCLUDED(mutex_);
-  // Similar to the above, but also returns the fences that are known to have
-  // already signaled via |signaled_fences|.
-  StatusOr<bool> TryToAdvanceTimeline(
-      uint64_t to_upper_value, absl::InlinedVector<VkFence, 4>* signaled_fences)
-      ABSL_LOCKS_EXCLUDED(mutex_);
-
-  std::atomic<uint64_t> signaled_value_;
-
-  ref_ptr<VkDeviceHandle> logical_device_;
-
-  // Callback to inform that this timeline semaphore has signaled a new value.
-  std::function<Status(Semaphore*)> on_semaphore_signal_;
-
-  // Callback to inform that this timeline semaphore has encountered a failure.
-  std::function<void(Status)> on_semaphore_failure_;
-
-  // Callback to inform that the given fences have signaled.
-  std::function<Status(absl::Span<VkFence>)> on_fence_signal_;
-
-  ref_ptr<TimePointSemaphorePool> semaphore_pool_;
-
-  mutable absl::Mutex mutex_;
-
-  // A list of outstanding semaphores used to emulate time points.
-  //
-  // The life time of each semaphore is in one of the following state:
-  //
-  // * Unused state: value = UINT64_MAX, signal/wait fence = nullptr. This is
-  //   the state of the semaphore when it's initially acquired from the pool and
-  //   not put in the queue for emulating a time point yet.
-  // * Pending state: signaled value < value < UINT64_MAX, signal fence =
-  //   <some fence>, wait fence == nullptr. This is the state of the semaphore
-  //   when it's put into the GPU queue for emulating a time point.
-  // * Pending and waiting state: signaled value < value < UINT64_MAX, signal
-  //   fence = <some fence>, wait fence == <some fence>. This is the state of
-  //   the semaphore when it's put into the GPU queue for emulating a time
-  //   point and there is another queue submission waiting on it in GPU.
-  // * Signaled and not ever waited state: value <= signaled value, singal/wait
-  //   fence = nullptr. This is the state of the semaphore when we know it's
-  //   already signaled on GPU and there is no waiters for it.
-  // * Signaled and waiting state: value <= signaled value, signal fence =
-  //   nullptr, wait fence = <some fence>. This is the state of the semaphore
-  //   when we know it's already signaled on GPU and there is still one queue
-  //   submission on GPU is waiting for it.
-  IntrusiveList<TimePointSemaphore> outstanding_semaphores_
-      ABSL_GUARDED_BY(mutex_);
-
-  // NOTE: We only need to access this status (and thus take the lock) when we
-  // want to either signal failure or query the status in the case of the
-  // semaphore being set to UINT64_MAX.
-  Status status_ ABSL_GUARDED_BY(mutex_);
-};
-
-}  // namespace vulkan
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_
+iree_status_t iree_hal_vulkan_emulated_semaphore_create(
+    iree::hal::vulkan::VkDeviceHandle* logical_device,
+    iree::hal::vulkan::TimePointSemaphorePool* semaphore_pool,
+    iree_host_size_t command_queue_count,
+    iree::hal::vulkan::CommandQueue** command_queues, uint64_t initial_value,
+    iree_hal_semaphore_t** out_semaphore);
+
+// Acquires a binary semaphore for waiting on the timeline to advance to the
+// given |value|. The semaphore returned won't be waited on by anyone else.
+// |wait_fence| is the fence associated with the queue submission that is
+// waiting on this semaphore.
+//
+// Returns VK_NULL_HANDLE if there are no available semaphores for the given
+// |value|.
+iree_status_t iree_hal_vulkan_emulated_semaphore_acquire_wait_handle(
+    iree_hal_semaphore_t* semaphore, uint64_t value,
+    const iree::ref_ptr<iree::hal::vulkan::TimePointFence>& wait_fence,
+    VkSemaphore* out_handle);
+
+// Cancels the waiting attempt on the given binary |semaphore|. This allows
+// the |semaphore| to be waited on by others.
+iree_status_t iree_hal_vulkan_emulated_semaphore_cancel_wait_handle(
+    iree_hal_semaphore_t* semaphore, VkSemaphore handle);
+
+// Acquires a binary semaphore for signaling the timeline to the given |value|.
+// |value| must be larger than the current timeline value. |signal_fence| is
+// the fence associated with the queue submission that signals this semaphore.
+iree_status_t iree_hal_vulkan_emulated_semaphore_acquire_signal_handle(
+    iree_hal_semaphore_t* semaphore, uint64_t value,
+    const iree::ref_ptr<iree::hal::vulkan::TimePointFence>& signal_fence,
+    VkSemaphore* out_handle);
+
+// Performs a multi-wait on one or more semaphores.
+// By default this is an all-wait but |wait_flags| may contain
+// VK_SEMAPHORE_WAIT_ANY_BIT to change to an any-wait.
+//
+// Returns IREE_STATUS_DEADLINE_EXCEEDED if the wait does not complete before
+// |deadline_ns| elapses.
+iree_status_t iree_hal_vulkan_emulated_semaphore_multi_wait(
+    iree::hal::vulkan::VkDeviceHandle* logical_device,
+    const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns,
+    VkSemaphoreWaitFlags wait_flags);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_VULKAN_ENUMLATED_SEMAPHORE_H_
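The deleted implementation below (like its replacement above) waits by polling TryToAdvanceTimeline in a loop until the deadline passes; the TODO in the code notes this busy loop is a known wart. A distilled sketch of that control flow, with std::chrono standing in for the IREE clock (an assumption):

```c++
#include <chrono>
#include <functional>

// Polls |try_advance| until it succeeds or |deadline| passes. The caller maps
// a false return onto a deadline-exceeded status.
bool PollUntilDeadline(const std::function<bool()>& try_advance,
                       std::chrono::steady_clock::time_point deadline) {
  do {
    if (try_advance()) return true;  // Reached the desired timeline value.
  } while (std::chrono::steady_clock::now() < deadline);
  return false;
}
```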
- -#include "iree/hal/vulkan/emulated_timeline_semaphore.h" - -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/time.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// static -StatusOr> EmulatedTimelineSemaphore::Create( - ref_ptr logical_device, - std::function on_semaphore_signal, - std::function on_semaphore_failure, - std::function)> on_fence_signal, - ref_ptr semaphore_pool, uint64_t initial_value) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Create"); - return make_ref( - std::move(logical_device), std::move(on_semaphore_signal), - std::move(on_semaphore_failure), std::move(on_fence_signal), - std::move(semaphore_pool), initial_value); -} - -EmulatedTimelineSemaphore::EmulatedTimelineSemaphore( - ref_ptr logical_device, - std::function on_semaphore_signal, - std::function on_semaphore_failure, - std::function)> on_fence_signal, - ref_ptr semaphore_pool, uint64_t initial_value) - : signaled_value_(initial_value), - logical_device_(std::move(logical_device)), - on_semaphore_signal_(std::move(on_semaphore_signal)), - on_semaphore_failure_(std::move(on_semaphore_failure)), - on_fence_signal_(std::move(on_fence_signal)), - semaphore_pool_(std::move(semaphore_pool)) {} - -EmulatedTimelineSemaphore::~EmulatedTimelineSemaphore() { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::dtor"); - IREE_CHECK_OK(TryToAdvanceTimeline(UINT64_MAX).status()); - absl::MutexLock lock(&mutex_); - IREE_CHECK(outstanding_semaphores_.empty()) - << "Destroying an emulated timeline semaphore without first waiting on " - "outstanding signals"; -} - -StatusOr EmulatedTimelineSemaphore::Query() { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Query"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Query"; - IREE_ASSIGN_OR_RETURN(bool signaled, TryToAdvanceTimeline(UINT64_MAX)); - (void)signaled; - uint64_t value = signaled_value_.load(); - IREE_DVLOG(2) << "Current timeline value: " << value; - if (value == UINT64_MAX) { - absl::MutexLock lock(&mutex_); - return status_; - } - return value; -} - -Status EmulatedTimelineSemaphore::Signal(uint64_t value) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Signal"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Signal"; - auto signaled_value = signaled_value_.exchange(value); - IREE_DVLOG(2) << "Previous value: " << signaled_value - << "; new value: " << value; - // Make sure the previous signaled value is smaller than the new value. - IREE_CHECK(signaled_value < value) - << "Attempting to signal a timeline value out of order; trying " << value - << " but " << signaled_value << " already signaled"; - - // Inform the device to make progress given we have a new value signaled now. - IREE_RETURN_IF_ERROR(on_semaphore_signal_(this)); - - return OkStatus(); -} - -Status EmulatedTimelineSemaphore::Wait(uint64_t value, Time deadline_ns) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Wait"; - - VkFence fence = VK_NULL_HANDLE; - do { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait#loop"); - // First try to advance the timeline without blocking to see whether we've - // already reached the desired value. - IREE_ASSIGN_OR_RETURN(bool reached_desired_value, - TryToAdvanceTimeline(value)); - if (reached_desired_value) return OkStatus(); - - // We must wait now. 
Find the first emulated time point that has a value >= - // the desired value so we can wait on its associated signal fence to make - // sure the timeline is advanced to the desired value. - absl::MutexLock lock(&mutex_); - auto semaphore = outstanding_semaphores_.begin(); - for (; semaphore != outstanding_semaphores_.end(); ++semaphore) { - if ((*semaphore)->value >= value) break; - } - if (semaphore != outstanding_semaphores_.end()) { - if (!(*semaphore)->signal_fence) { - return InternalErrorBuilder(IREE_LOC) - << "Timeline should have a signal fence for the first time " - "point beyond the signaled value"; - } - IREE_DVLOG(2) << "Found timepoint semaphore " << *semaphore - << " (value: " << (*semaphore)->value - << ") to wait for desired timeline value: " << value; - fence = (*semaphore)->signal_fence->value(); - // Found; we can break the loop and proceed to waiting now. - break; - } - // TODO(antiagainst): figure out a better way instead of the busy loop here. - } while (Now() < deadline_ns); - - if (fence == VK_NULL_HANDLE) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline reached when waiting timeline semaphore"; - } - - uint64_t timeout_ns = - static_cast(DeadlineToRelativeTimeoutNanos(deadline_ns)); - VK_RETURN_IF_ERROR(logical_device_->syms()->vkWaitForFences( - *logical_device_, /*fenceCount=*/1, &fence, /*waitAll=*/true, - timeout_ns)); - - return TryToAdvanceTimeline(value).status(); -} - -void EmulatedTimelineSemaphore::Fail(Status status) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Fail"); - absl::MutexLock lock(&mutex_); - status_ = std::move(status); - signaled_value_.store(UINT64_MAX); -} - -VkSemaphore EmulatedTimelineSemaphore::GetWaitSemaphore( - uint64_t value, const ref_ptr& wait_fence) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetWaitSemaphore"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::GetWaitSemaphore"; - - absl::MutexLock lock(&mutex_); - - VkSemaphore semaphore = VK_NULL_HANDLE; - for (TimePointSemaphore* point : outstanding_semaphores_) { - if (point->value > value && point->wait_fence) { - point->wait_fence = add_ref(wait_fence); - semaphore = point->semaphore; - break; - } - } - - IREE_DVLOG(2) << "Binary VkSemaphore to wait on for timeline value (" << value - << ") and wait fence (" << wait_fence.get() - << "): " << semaphore; - - return semaphore; -} - -Status EmulatedTimelineSemaphore::CancelWaitSemaphore(VkSemaphore semaphore) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::CancelWaitSemaphore"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::CancelWaitSemaphore"; - - absl::MutexLock lock(&mutex_); - for (TimePointSemaphore* point : outstanding_semaphores_) { - if (point->semaphore != semaphore) continue; - - if (!point->wait_fence) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Time point wasn't waited before"; - } - point->wait_fence = nullptr; - IREE_DVLOG(2) << "Cancelled waiting on binary VkSemaphore: " << semaphore; - return OkStatus(); - } - return InvalidArgumentErrorBuilder(IREE_LOC) - << "No time point for the given semaphore"; -} - -StatusOr EmulatedTimelineSemaphore::GetSignalSemaphore( - uint64_t value, const ref_ptr& signal_fence) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetSignalSemaphore"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::GetSignalSemaphore"; - - if (signaled_value_.load() >= value) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Timeline semaphore already signaled past " << value; - } - - absl::MutexLock lock(&mutex_); - - auto insertion_point = 
outstanding_semaphores_.begin(); - while (insertion_point != outstanding_semaphores_.end()) { - if ((*insertion_point)->value > value) break; - ++insertion_point; - } - - IREE_ASSIGN_OR_RETURN(TimePointSemaphore * semaphore, - semaphore_pool_->Acquire()); - semaphore->value = value; - semaphore->signal_fence = add_ref(signal_fence); - if (semaphore->wait_fence) { - return InternalErrorBuilder(IREE_LOC) - << "Newly acquired time point semaphore should not have waiters"; - } - outstanding_semaphores_.insert(insertion_point, semaphore); - IREE_DVLOG(2) << "Timepoint semaphore to signal for timeline value (" << value - << ") and signal fence (" << signal_fence.get() - << "): " << semaphore - << " (binary VkSemaphore: " << semaphore->semaphore << ")"; - - return semaphore->semaphore; -} - -StatusOr EmulatedTimelineSemaphore::TryToAdvanceTimeline( - uint64_t to_upper_value) { - absl::InlinedVector signaled_fences; - auto status = TryToAdvanceTimeline(to_upper_value, &signaled_fences); - // Inform the queue that some fences are known to have signaled. This should - // happen here instead of inside the other TryToAdvanceTimeline to avoid - // potential mutex deadlock, given here we are not holding a mutex anymore. - if (!signaled_fences.empty()) { - on_fence_signal_(absl::MakeSpan(signaled_fences)); - } - return status; -} - -StatusOr EmulatedTimelineSemaphore::TryToAdvanceTimeline( - uint64_t to_upper_value, absl::InlinedVector* signaled_fences) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::TryToAdvanceTimeline"); - IREE_DVLOG(3) << "EmulatedTimelineSemaphore::TryToAdvanceTimeline"; - - uint64_t past_value = signaled_value_.load(); - IREE_DVLOG(3) << "Current timeline value: " << past_value - << "; desired timeline value: " << to_upper_value; - - // Fast path for when already signaled past the desired value. - if (past_value >= to_upper_value) return true; - - // We hold the lock during the entire resolve process so that we can resolve - // to the furthest possible value. - absl::MutexLock lock(&mutex_); - - IREE_DVLOG(3) << "# outstanding semaphores: " - << outstanding_semaphores_.size(); - - // The timeline has not signaled past the desired value and there is no - // binary semaphore pending on GPU yet: certainly the timeline cannot - // advance to the desired value. - if (outstanding_semaphores_.empty()) return false; - - IntrusiveList resolved_semaphores; - - auto clear_signal_fence = [&signaled_fences](ref_ptr& fence) { - if (fence) { - if (signaled_fences) signaled_fences->push_back(fence->value()); - fence = nullptr; - } - }; - - bool keep_resolving = true; - bool reached_desired_value = false; - while (keep_resolving && !outstanding_semaphores_.empty()) { - auto* semaphore = outstanding_semaphores_.front(); - IREE_DVLOG(3) << "Looking at timepoint semaphore " << semaphore << ".."; - IREE_DVLOG(3) << " value: " << semaphore->value; - IREE_DVLOG(3) << " VkSemaphore: " << semaphore->semaphore; - IREE_DVLOG(3) << " signal fence: " << semaphore->signal_fence.get(); - IREE_DVLOG(3) << " wait fence: " << semaphore->wait_fence.get(); - - // If the current semaphore is for a value beyond our upper limit, then - // early exit so that we don't spend time dealing with signals we don't yet - // care about. This can prevent live lock where one thread is signaling - // fences as fast/faster than another thread can consume them.
- if (semaphore->value > to_upper_value) { - keep_resolving = false; - reached_desired_value = true; - break; - } - - // If the current semaphore is for a value not greater than the past - // signaled value, then we know it was signaled previously. But there might - // be a waiter on it on GPU. - if (semaphore->value <= past_value) { - if (semaphore->signal_fence) { - return InternalErrorBuilder(IREE_LOC) - << "Timeline should have already signaled past this time point and " - "cleared the signal fence"; - } - - // If there are no waiters, we can recycle this semaphore now. If there - // exists one waiter, then query its status and recycle on success. We - // only handle success status here. Others will be handled when the fence - // is checked for other semaphores' signaling status for the same queue - // submission. - if (!semaphore->wait_fence || - semaphore->wait_fence->GetStatus() == VK_SUCCESS) { - clear_signal_fence(semaphore->signal_fence); - semaphore->wait_fence = nullptr; - outstanding_semaphores_.erase(semaphore); - resolved_semaphores.push_back(semaphore); - IREE_DVLOG(3) << "Resolved and recycling semaphore " << semaphore; - } - - continue; - } - - // This semaphore represents a value greater than the known previously - // signaled value. We don't know its status so we need to really query now. - - if (!semaphore->signal_fence) { - return InternalErrorBuilder(IREE_LOC) - << "The status of this time point in the timeline should still be " - "pending with a signal fence"; - } - VkResult signal_status = semaphore->signal_fence->GetStatus(); - - switch (signal_status) { - case VK_SUCCESS: - IREE_DVLOG(3) << "..semaphore signaled"; - signaled_value_.store(semaphore->value); - clear_signal_fence(semaphore->signal_fence); - // If no waiters, we can recycle this semaphore now. - if (!semaphore->wait_fence) { - semaphore->wait_fence = nullptr; - outstanding_semaphores_.erase(semaphore); - resolved_semaphores.push_back(semaphore); - IREE_DVLOG(3) << "Resolved and recycling semaphore " << semaphore; - } - break; - case VK_NOT_READY: - // The fence has not been signaled yet so this is the furthest time - // point we can go in this timeline. - keep_resolving = false; - IREE_DVLOG(3) << "..semaphore not yet signaled"; - break; - default: - // Fence indicates an error (device lost, out of memory, etc). - // Propagate this back to our status (and thus any waiters). - // Since we only take the first error we find we skip all remaining - // fences.
- keep_resolving = false; - clear_signal_fence(semaphore->signal_fence); - status_ = VkResultToStatus(signal_status, IREE_LOC); - signaled_value_.store(UINT64_MAX); - break; - } - } - - IREE_DVLOG(3) << "Releasing " << resolved_semaphores.size() - << " resolved semaphores; " << outstanding_semaphores_.size() - << " still outstanding"; - semaphore_pool_->ReleaseResolved(&resolved_semaphores); - if (!status_.ok()) { - on_semaphore_failure_(this); - semaphore_pool_->ReleaseUnresolved(&outstanding_semaphores_); - return status_; - } - - return reached_desired_value; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree diff --git a/iree/hal/vulkan/extensibility_util.cc b/iree/hal/vulkan/extensibility_util.cc index d78892bf211b7..7320cd34c225a 100644 --- a/iree/hal/vulkan/extensibility_util.cc +++ b/iree/hal/vulkan/extensibility_util.cc @@ -14,195 +14,213 @@ #include "iree/hal/vulkan/extensibility_util.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" #include "iree/hal/vulkan/status_util.h" -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -StatusOr> MatchAvailableLayers( - absl::Span required_layers, - absl::Span optional_layers, - absl::Span properties) { - IREE_TRACE_SCOPE0("MatchAvailableLayers"); - - std::vector enabled_layers; - enabled_layers.reserve(required_layers.size() + optional_layers.size()); - - for (const char* layer_name : required_layers) { - bool found = false; - for (const auto& layer_properties : properties) { - if (std::strcmp(layer_name, layer_properties.layerName) == 0) { - IREE_VLOG(1) << "Enabling required layer: " << layer_name; - found = true; - enabled_layers.push_back(layer_name); - break; - } - } - if (!found) { - return UnavailableErrorBuilder(IREE_LOC) - << "Required layer " << layer_name << " not available"; +// Returns true if |layers| contains a layer matching |layer_name|. 
+static bool iree_hal_vulkan_layer_list_contains(uint32_t layer_count, + const VkLayerProperties* layers, + const char* layer_name) { + for (uint32_t i = 0; i < layer_count; ++i) { + if (strcmp(layer_name, layers[i].layerName) == 0) { + return true; } } + return false; +} - for (const char* layer_name : optional_layers) { - bool found = false; - for (const auto& layer_properties : properties) { - if (std::strcmp(layer_name, layer_properties.layerName) == 0) { - IREE_VLOG(1) << "Enabling optional layer: " << layer_name; - found = true; - enabled_layers.push_back(layer_name); - break; - } +static iree_status_t iree_hal_vulkan_match_available_layers( + iree_host_size_t available_layers_count, + const VkLayerProperties* available_layers, + const iree_hal_vulkan_string_list_t* required_layers, + const iree_hal_vulkan_string_list_t* optional_layers, + iree_hal_vulkan_string_list_t* out_enabled_layers) { + memset(out_enabled_layers->values, 0, + (required_layers->count + optional_layers->count) * + sizeof(out_enabled_layers->values[0])); + + for (iree_host_size_t i = 0; i < required_layers->count; ++i) { + const char* layer_name = required_layers->values[i]; + if (!iree_hal_vulkan_layer_list_contains(available_layers_count, + available_layers, layer_name)) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "required layer %s not available", layer_name); } - if (!found) { - IREE_VLOG(1) << "Optional layer " << layer_name << " not available"; + out_enabled_layers->values[out_enabled_layers->count++] = layer_name; + } + + for (iree_host_size_t i = 0; i < optional_layers->count; ++i) { + const char* layer_name = optional_layers->values[i]; + if (iree_hal_vulkan_layer_list_contains(available_layers_count, + available_layers, layer_name)) { + out_enabled_layers->values[out_enabled_layers->count++] = layer_name; } } - return enabled_layers; + return iree_ok_status(); } -StatusOr> MatchAvailableExtensions( - absl::Span required_extensions, - absl::Span optional_extensions, - absl::Span properties) { - IREE_TRACE_SCOPE0("MatchAvailableExtensions"); - - std::vector enabled_extensions; - enabled_extensions.reserve(required_extensions.size() + - optional_extensions.size()); - - for (const char* extension_name : required_extensions) { - bool found = false; - for (const auto& extension_properties : properties) { - if (std::strcmp(extension_name, extension_properties.extensionName) == - 0) { - IREE_VLOG(1) << "Enabling required extension: " << extension_name; - found = true; - enabled_extensions.push_back(extension_name); - break; - } - } - if (!found) { - return UnavailableErrorBuilder(IREE_LOC) - << "Required extension " << extension_name << " not available"; +iree_status_t iree_hal_vulkan_match_available_instance_layers( + const iree::hal::vulkan::DynamicSymbols* syms, + const iree_hal_vulkan_string_list_t* required_layers, + const iree_hal_vulkan_string_list_t* optional_layers, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_enabled_layers) { + uint32_t layer_property_count = 0; + VK_RETURN_IF_ERROR( + syms->vkEnumerateInstanceLayerProperties(&layer_property_count, NULL), + "vkEnumerateInstanceLayerProperties"); + VkLayerProperties* layer_properties = + (VkLayerProperties*)arena->AllocateBytes(layer_property_count * + sizeof(VkLayerProperties)); + VK_RETURN_IF_ERROR(syms->vkEnumerateInstanceLayerProperties( + &layer_property_count, layer_properties), + "vkEnumerateInstanceLayerProperties"); + out_enabled_layers->count = 0; + out_enabled_layers->values = (const char**)arena->AllocateBytes( + 
(required_layers->count + optional_layers->count) * + sizeof(out_enabled_layers->values[0])); + return iree_hal_vulkan_match_available_layers( + layer_property_count, layer_properties, required_layers, optional_layers, + out_enabled_layers); +} + +// Returns true if |extensions| contains an extension matching +// |extension_name|. +static bool iree_hal_vulkan_extension_list_contains( + uint32_t extension_count, const VkExtensionProperties* extensions, + const char* extension_name) { + for (uint32_t i = 0; i < extension_count; ++i) { + if (strcmp(extension_name, extensions[i].extensionName) == 0) { + return true; } } + return false; +} - for (const char* extension_name : optional_extensions) { - bool found = false; - for (const auto& extension_properties : properties) { - if (std::strcmp(extension_name, extension_properties.extensionName) == - 0) { - IREE_VLOG(1) << "Enabling optional extension: " << extension_name; - found = true; - enabled_extensions.push_back(extension_name); - break; - } - } - if (!found) { - IREE_VLOG(1) << "Optional extension " << extension_name - << " not available"; +static iree_status_t iree_hal_vulkan_match_available_extensions( + iree_host_size_t available_extension_count, + const VkExtensionProperties* available_extensions, + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree_hal_vulkan_string_list_t* out_enabled_extensions) { + memset(out_enabled_extensions->values, 0, + (required_extensions->count + optional_extensions->count) * + sizeof(out_enabled_extensions->values[0])); + + for (iree_host_size_t i = 0; i < required_extensions->count; ++i) { + const char* extension_name = required_extensions->values[i]; + if (!iree_hal_vulkan_extension_list_contains( + available_extension_count, available_extensions, extension_name)) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "required extension %s not available", + extension_name); + } + out_enabled_extensions->values[out_enabled_extensions->count++] = + extension_name; } - return enabled_extensions; -} - -} // namespace + for (iree_host_size_t i = 0; i < optional_extensions->count; ++i) { + const char* extension_name = optional_extensions->values[i]; + if (iree_hal_vulkan_extension_list_contains( + available_extension_count, available_extensions, extension_name)) { + out_enabled_extensions->values[out_enabled_extensions->count++] = + extension_name; + } + } -StatusOr> MatchAvailableInstanceLayers( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { - uint32_t layer_property_count = 0; - VK_RETURN_IF_ERROR( - syms.vkEnumerateInstanceLayerProperties(&layer_property_count, nullptr)); - std::vector layer_properties(layer_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceLayerProperties( - &layer_property_count, layer_properties.data())); - IREE_ASSIGN_OR_RETURN(auto enabled_layers, - MatchAvailableLayers(extensibility_spec.required_layers, - extensibility_spec.optional_layers, - layer_properties), - _ << "Unable to find all required instance layers"); - return enabled_layers; + return iree_ok_status(); } -StatusOr> MatchAvailableInstanceExtensions( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { +iree_status_t iree_hal_vulkan_match_available_instance_extensions( + const iree::hal::vulkan::DynamicSymbols* syms, + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree::Arena* arena, iree_hal_vulkan_string_list_t*
out_enabled_extensions) { uint32_t extension_property_count = 0; - // Warning: leak checks remain disabled if an error is returned. - IREE_DISABLE_LEAK_CHECKS(); - VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties( - nullptr, &extension_property_count, nullptr)); - std::vector extension_properties( - extension_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties( - nullptr, &extension_property_count, extension_properties.data())); - IREE_ASSIGN_OR_RETURN( - auto enabled_extensions, - MatchAvailableExtensions(extensibility_spec.required_extensions, - extensibility_spec.optional_extensions, - extension_properties), - _ << "Unable to find all required instance extensions"); - IREE_ENABLE_LEAK_CHECKS(); - return enabled_extensions; + VK_RETURN_IF_ERROR(syms->vkEnumerateInstanceExtensionProperties( + NULL, &extension_property_count, NULL), + "vkEnumerateInstanceExtensionProperties"); + VkExtensionProperties* extension_properties = + (VkExtensionProperties*)arena->AllocateBytes( + extension_property_count * sizeof(VkExtensionProperties)); + VK_RETURN_IF_ERROR(syms->vkEnumerateInstanceExtensionProperties( + NULL, &extension_property_count, extension_properties), + "vkEnumerateInstanceExtensionProperties"); + out_enabled_extensions->count = 0; + out_enabled_extensions->values = (const char**)arena->AllocateBytes( + (required_extensions->count + optional_extensions->count) * + sizeof(out_enabled_extensions->values[0])); + return iree_hal_vulkan_match_available_extensions( + extension_property_count, extension_properties, required_extensions, + optional_extensions, out_enabled_extensions); } -StatusOr> MatchAvailableDeviceExtensions( +iree_status_t iree_hal_vulkan_match_available_device_extensions( + const iree::hal::vulkan::DynamicSymbols* syms, VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree::Arena* arena, iree_hal_vulkan_string_list_t* out_enabled_extensions) { uint32_t extension_property_count = 0; - VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties( - physical_device, nullptr, &extension_property_count, nullptr)); - std::vector extension_properties( - extension_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties( - physical_device, nullptr, &extension_property_count, - extension_properties.data())); - IREE_ASSIGN_OR_RETURN( - auto enabled_extensions, - MatchAvailableExtensions(extensibility_spec.required_extensions, - extensibility_spec.optional_extensions, - extension_properties), - _ << "Unable to find all required device extensions"); - return enabled_extensions; + VK_RETURN_IF_ERROR( + syms->vkEnumerateDeviceExtensionProperties( + physical_device, NULL, &extension_property_count, NULL), + "vkEnumerateDeviceExtensionProperties"); + VkExtensionProperties* extension_properties = + (VkExtensionProperties*)arena->AllocateBytes( + extension_property_count * sizeof(VkExtensionProperties)); + VK_RETURN_IF_ERROR(syms->vkEnumerateDeviceExtensionProperties( + physical_device, NULL, &extension_property_count, + extension_properties), + "vkEnumerateDeviceExtensionProperties"); + out_enabled_extensions->count = 0; + out_enabled_extensions->values = (const char**)arena->AllocateBytes( + (required_extensions->count + optional_extensions->count) * + sizeof(out_enabled_extensions->values[0])); + return 
iree_hal_vulkan_match_available_extensions( + extension_property_count, extension_properties, required_extensions, + optional_extensions, out_enabled_extensions); } -InstanceExtensions PopulateEnabledInstanceExtensions( - absl::Span extension_names) { - InstanceExtensions extensions = {0}; - for (const char* extension_name : extension_names) { - if (std::strcmp(extension_name, VK_EXT_DEBUG_REPORT_EXTENSION_NAME) == 0) { - extensions.debug_report = true; - } else if (std::strcmp(extension_name, VK_EXT_DEBUG_UTILS_EXTENSION_NAME) == - 0) { +iree_hal_vulkan_instance_extensions_t +iree_hal_vulkan_populate_enabled_instance_extensions( + const iree_hal_vulkan_string_list_t* enabled_extensions) { + iree_hal_vulkan_instance_extensions_t extensions; + memset(&extensions, 0, sizeof(extensions)); + for (iree_host_size_t i = 0; i < enabled_extensions->count; ++i) { + const char* extension_name = enabled_extensions->values[i]; + if (strcmp(extension_name, VK_EXT_DEBUG_UTILS_EXTENSION_NAME) == 0) { extensions.debug_utils = true; } } return extensions; } -DeviceExtensions PopulateEnabledDeviceExtensions( - absl::Span extension_names) { - DeviceExtensions extensions = {0}; - for (const char* extension_name : extension_names) { - if (std::strcmp(extension_name, VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME) == - 0) { +iree_hal_vulkan_device_extensions_t +iree_hal_vulkan_populate_enabled_device_extensions( + const iree_hal_vulkan_string_list_t* enabled_extensions) { + iree_hal_vulkan_device_extensions_t extensions; + memset(&extensions, 0, sizeof(extensions)); + for (iree_host_size_t i = 0; i < enabled_extensions->count; ++i) { + const char* extension_name = enabled_extensions->values[i]; + if (strcmp(extension_name, VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME) == 0) { extensions.push_descriptors = true; - } else if (std::strcmp(extension_name, - VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME) == 0) { + } else if (strcmp(extension_name, + VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME) == 0) { extensions.timeline_semaphore = true; } } return extensions; } -} // namespace vulkan -} // namespace hal -} // namespace iree +iree_hal_vulkan_device_extensions_t +iree_hal_vulkan_infer_enabled_device_extensions( + const iree::hal::vulkan::DynamicSymbols* device_syms) { + iree_hal_vulkan_device_extensions_t extensions; + memset(&extensions, 0, sizeof(extensions)); + if (device_syms->vkCmdPushDescriptorSetKHR) { + extensions.push_descriptors = true; + } + if (device_syms->vkSignalSemaphore || device_syms->vkSignalSemaphoreKHR) { + extensions.timeline_semaphore = true; + } + return extensions; +} diff --git a/iree/hal/vulkan/extensibility_util.h b/iree/hal/vulkan/extensibility_util.h index 3d9435b9325ae..c0c8ff8f39203 100644 --- a/iree/hal/vulkan/extensibility_util.h +++ b/iree/hal/vulkan/extensibility_util.h @@ -12,89 +12,89 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Utilities for working with layers and extensions. - #ifndef IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_ #define IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include - -#include "absl/types/span.h" -#include "iree/base/status.h" +#include "iree/base/arena.h" +#include "iree/hal/vulkan/api.h" #include "iree/hal/vulkan/dynamic_symbols.h" -namespace iree { -namespace hal { -namespace vulkan { - -// Describes required and optional extensibility points. 
-struct ExtensibilitySpec { - // A list of required and optional layers. - std::vector required_layers; - std::vector optional_layers; - - // A list of required and optional extensions. - // Prefer using the _EXTENSION_NAME macros to make tracking easier (such as - // 'VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME'). - std::vector required_extensions; - std::vector optional_extensions; -}; - -// Returns a list of layer names available for instances. -// Fails if any required_layers are unavailable. -StatusOr> MatchAvailableInstanceLayers( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); - -// Returns a list of extension names available for instances. -// Fails if any required_extensions are unavailable. -StatusOr> MatchAvailableInstanceExtensions( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); - -// Returns a list of extension names available for the given |physical_device|. -// Fails if any required_extensions are unavailable. -StatusOr> MatchAvailableDeviceExtensions( +// A list of NUL-terminated strings (so they can be passed directly to Vulkan). +typedef struct { + iree_host_size_t count; + const char** values; +} iree_hal_vulkan_string_list_t; + +// Populates |out_enabled_layers| with all layers that are both available in +// the implementation and listed in |required_layers| or |optional_layers|. +// |out_enabled_layers| must have capacity at least the sum of +// |required_layers|.count and |optional_layers|.count. +// Returns failure if any |required_layers| are unavailable. +iree_status_t iree_hal_vulkan_match_available_instance_layers( + const iree::hal::vulkan::DynamicSymbols* syms, + const iree_hal_vulkan_string_list_t* required_layers, + const iree_hal_vulkan_string_list_t* optional_layers, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_enabled_layers); + +// Populates |out_enabled_extensions| with all extensions that are both +// available in the implementation and listed in |required_extensions| or +// |optional_extensions|. |out_enabled_extensions| must have capacity at +// least the sum of |required_extensions|.count and |optional_extensions|.count. +// Returns failure if any |required_extensions| are unavailable. +iree_status_t iree_hal_vulkan_match_available_instance_extensions( + const iree::hal::vulkan::DynamicSymbols* syms, + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree::Arena* arena, iree_hal_vulkan_string_list_t* out_enabled_extensions); + +// Populates |out_enabled_extensions| with all extensions that are both +// available in the implementation and listed in |required_extensions| or +// |optional_extensions|. |out_enabled_extensions| must have capacity at +// least the sum of |required_extensions|.count and |optional_extensions|.count. +// Returns failure if any |required_extensions| are unavailable. +iree_status_t iree_hal_vulkan_match_available_device_extensions( + const iree::hal::vulkan::DynamicSymbols* syms, VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree::Arena* arena, iree_hal_vulkan_string_list_t* out_enabled_extensions); // Bits for enabled instance extensions.
// We must use this to query support instead of just detecting symbol names as // ICDs will resolve the functions sometimes even if they don't support the // extension (or we didn't ask for it to be enabled). -struct InstanceExtensions { - // VK_EXT_debug_report is enabled and a callback is registered. - // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_report - bool debug_report : 1; - +typedef struct { // VK_EXT_debug_utils is enabled and a debug messenger is registered. // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_utils bool debug_utils : 1; -}; +} iree_hal_vulkan_instance_extensions_t; // Returns a bitfield with all of the provided extension names. -InstanceExtensions PopulateEnabledInstanceExtensions( - absl::Span extension_names); +iree_hal_vulkan_instance_extensions_t +iree_hal_vulkan_populate_enabled_instance_extensions( + const iree_hal_vulkan_string_list_t* enabled_extensions); // Bits for enabled device extensions. // We must use this to query support instead of just detecting symbol names as // ICDs will resolve the functions sometimes even if they don't support the // extension (or we didn't ask for it to be enabled). -struct DeviceExtensions { +typedef struct { // VK_KHR_push_descriptor is enabled and vkCmdPushDescriptorSetKHR is valid. bool push_descriptors : 1; // VK_KHR_timeline_semaphore is enabled. bool timeline_semaphore : 1; -}; +} iree_hal_vulkan_device_extensions_t; // Returns a bitfield with all of the provided extension names. -DeviceExtensions PopulateEnabledDeviceExtensions( - absl::Span extension_names); - -} // namespace vulkan -} // namespace hal -} // namespace iree +iree_hal_vulkan_device_extensions_t +iree_hal_vulkan_populate_enabled_device_extensions( + const iree_hal_vulkan_string_list_t* enabled_extensions); + +// Returns a bitfield with the extensions that are (likely) available on the +// device symbols. This is less reliable than setting the bits directly when +// the known set of extensions is available.
+iree_hal_vulkan_device_extensions_t +iree_hal_vulkan_infer_enabled_device_extensions( + const iree::hal::vulkan::DynamicSymbols* device_syms); #endif // IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_ diff --git a/iree/hal/vulkan/handle_util.h b/iree/hal/vulkan/handle_util.h index 2cd3f642e6164..7df7402aba672 100644 --- a/iree/hal/vulkan/handle_util.h +++ b/iree/hal/vulkan/handle_util.h @@ -28,11 +28,12 @@ #include "iree/hal/vulkan/vulkan_headers.h" // clang-format on -#include "absl/synchronization/mutex.h" #include "iree/base/ref_ptr.h" #include "iree/base/status.h" +#include "iree/base/synchronization.h" #include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/extensibility_util.h" +#include "iree/hal/vulkan/status_util.h" namespace iree { namespace hal { @@ -40,13 +41,15 @@ namespace vulkan { class VkDeviceHandle : public RefObject { public: - VkDeviceHandle(const ref_ptr& syms, - DeviceExtensions enabled_extensions, bool owns_device, + VkDeviceHandle(DynamicSymbols* syms, + iree_hal_vulkan_device_extensions_t enabled_extensions, + bool owns_device, iree_allocator_t host_allocator, const VkAllocationCallbacks* allocator = nullptr) : syms_(add_ref(syms)), enabled_extensions_(enabled_extensions), owns_device_(owns_device), - allocator_(allocator) {} + allocator_(allocator), + host_allocator_(host_allocator) {} ~VkDeviceHandle() { reset(); } VkDeviceHandle(const VkDeviceHandle&) = delete; @@ -57,7 +60,8 @@ class VkDeviceHandle : public RefObject { syms_(std::move(other.syms_)), enabled_extensions_(other.enabled_extensions_), owns_device_(other.owns_device_), - allocator_(other.allocator_) {} + allocator_(other.allocator_), + host_allocator_(other.host_allocator_) {} void reset() { if (value_ == VK_NULL_HANDLE) return; @@ -73,24 +77,31 @@ class VkDeviceHandle : public RefObject { const ref_ptr& syms() const noexcept { return syms_; } const VkAllocationCallbacks* allocator() const noexcept { return allocator_; } + iree_allocator_t host_allocator() const noexcept { return host_allocator_; } - const DeviceExtensions& enabled_extensions() const { + const iree_hal_vulkan_device_extensions_t& enabled_extensions() const { return enabled_extensions_; } private: VkDevice value_ = VK_NULL_HANDLE; ref_ptr syms_; - DeviceExtensions enabled_extensions_; + iree_hal_vulkan_device_extensions_t enabled_extensions_; bool owns_device_; const VkAllocationCallbacks* allocator_ = nullptr; + iree_allocator_t host_allocator_; }; -class VkCommandPoolHandle : public RefObject { +class VkCommandPoolHandle { public: - explicit VkCommandPoolHandle(const ref_ptr& logical_device) - : logical_device_(add_ref(logical_device)) {} - ~VkCommandPoolHandle() { reset(); } + explicit VkCommandPoolHandle(VkDeviceHandle* logical_device) + : logical_device_(logical_device) { + iree_slim_mutex_initialize(&mutex_); + } + ~VkCommandPoolHandle() { + reset(); + iree_slim_mutex_deinitialize(&mutex_); + } VkCommandPoolHandle(const VkCommandPoolHandle&) = delete; VkCommandPoolHandle& operator=(const VkCommandPoolHandle&) = delete; @@ -114,7 +125,7 @@ class VkCommandPoolHandle : public RefObject { VkCommandPool* mutable_value() noexcept { return &value_; } operator VkCommandPool() const noexcept { return value_; } - const ref_ptr& logical_device() const noexcept { + const VkDeviceHandle* logical_device() const noexcept { return logical_device_; } const ref_ptr& syms() const noexcept { @@ -124,16 +135,31 @@ class VkCommandPoolHandle : public RefObject { return logical_device_->allocator(); } - absl::Mutex* mutex() const { return 
&mutex_; } + iree_status_t Allocate(const VkCommandBufferAllocateInfo* allocate_info, + VkCommandBuffer* out_handle) { + iree_slim_mutex_lock(&mutex_); + iree_status_t status = + VK_RESULT_TO_STATUS(syms()->vkAllocateCommandBuffers( + *logical_device_, allocate_info, out_handle), + "vkAllocateCommandBuffers"); + iree_slim_mutex_unlock(&mutex_); + return status; + } + + void Free(VkCommandBuffer handle) { + iree_slim_mutex_lock(&mutex_); + syms()->vkFreeCommandBuffers(*logical_device_, value_, 1, &handle); + iree_slim_mutex_unlock(&mutex_); + } private: - ref_ptr logical_device_; + VkDeviceHandle* logical_device_; VkCommandPool value_ = VK_NULL_HANDLE; // Vulkan command pools are not thread safe and require external // synchronization. Since we allow arbitrary threads to allocate and // deallocate the HAL command buffers we need to externally synchronize. - mutable absl::Mutex mutex_; + iree_slim_mutex_t mutex_; }; } // namespace vulkan diff --git a/iree/hal/vulkan/native_descriptor_set.cc b/iree/hal/vulkan/native_descriptor_set.cc index f000c1c6c53a7..a047d31d57c99 100644 --- a/iree/hal/vulkan/native_descriptor_set.cc +++ b/iree/hal/vulkan/native_descriptor_set.cc @@ -14,23 +14,80 @@ #include "iree/hal/vulkan/native_descriptor_set.h" -namespace iree { -namespace hal { -namespace vulkan { +#include "iree/base/tracing.h" -NativeDescriptorSet::NativeDescriptorSet(ref_ptr logical_device, - VkDescriptorSet handle) - : logical_device_(std::move(logical_device)), handle_(handle) {} +using namespace iree::hal::vulkan; + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + VkDescriptorSet handle; +} iree_hal_vulkan_native_descriptor_set_t; + +extern const iree_hal_descriptor_set_vtable_t + iree_hal_vulkan_native_descriptor_set_vtable; + +static iree_hal_vulkan_native_descriptor_set_t* +iree_hal_vulkan_native_descriptor_set_cast( + iree_hal_descriptor_set_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_native_descriptor_set_vtable); + return (iree_hal_vulkan_native_descriptor_set_t*)base_value; +} + +iree_status_t iree_hal_vulkan_native_descriptor_set_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, VkDescriptorSet handle, + iree_hal_descriptor_set_t** out_descriptor_set) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(handle); + IREE_ASSERT_ARGUMENT(out_descriptor_set); + *out_descriptor_set = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_native_descriptor_set_t* descriptor_set = NULL; + iree_status_t status = + iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*descriptor_set), (void**)&descriptor_set); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_native_descriptor_set_vtable, + &descriptor_set->resource); + descriptor_set->logical_device = logical_device; + descriptor_set->handle = handle; + *out_descriptor_set = (iree_hal_descriptor_set_t*)descriptor_set; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_descriptor_set_destroy( + iree_hal_descriptor_set_t* base_descriptor_set) { + iree_hal_vulkan_native_descriptor_set_t* descriptor_set = + iree_hal_vulkan_native_descriptor_set_cast(base_descriptor_set); + iree_allocator_t host_allocator = + descriptor_set->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); -NativeDescriptorSet::~NativeDescriptorSet() { // TODO(benvanik): return to pool. 
For now we rely on the descriptor cache to // reset entire pools at once via vkResetDescriptorPool so we don't need // to do anything here (the VkDescriptorSet handle will just be invalidated). // In the future if we want to have generational collection/defragmentation // of the descriptor cache we'll want to allow both pooled and unpooled // descriptors and clean them up here appropriately. + + iree_allocator_free(host_allocator, descriptor_set); + + IREE_TRACE_ZONE_END(z0); +} + +VkDescriptorSet iree_hal_vulkan_native_descriptor_set_handle( + iree_hal_descriptor_set_t* base_descriptor_set) { + iree_hal_vulkan_native_descriptor_set_t* descriptor_set = + iree_hal_vulkan_native_descriptor_set_cast(base_descriptor_set); + return descriptor_set->handle; } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_descriptor_set_vtable_t + iree_hal_vulkan_native_descriptor_set_vtable = { + /*.destroy=*/iree_hal_vulkan_native_descriptor_set_destroy, +}; diff --git a/iree/hal/vulkan/native_descriptor_set.h b/iree/hal/vulkan/native_descriptor_set.h index b7649282989dd..cf9379eb318ec 100644 --- a/iree/hal/vulkan/native_descriptor_set.h +++ b/iree/hal/vulkan/native_descriptor_set.h @@ -15,33 +15,24 @@ #ifndef IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_ #define IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/hal/descriptor_set.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/handle_util.h" -namespace iree { -namespace hal { -namespace vulkan { - -// A DescriptorSet implemented with the native VkDescriptorSet type. -class NativeDescriptorSet final : public DescriptorSet { - public: - NativeDescriptorSet(ref_ptr logical_device, - VkDescriptorSet handle); - ~NativeDescriptorSet() override; +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus - VkDescriptorSet handle() const { return handle_; } +// Creates a native Vulkan VkDescriptorSet object. +iree_status_t iree_hal_vulkan_native_descriptor_set_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, VkDescriptorSet handle, + iree_hal_descriptor_set_t** out_descriptor_set); - private: - ref_ptr logical_device_; - VkDescriptorSet handle_; -}; +// Returns the native Vulkan VkDescriptorSet handle. +VkDescriptorSet iree_hal_vulkan_native_descriptor_set_handle( + iree_hal_descriptor_set_t* base_descriptor_set); -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_ diff --git a/iree/hal/vulkan/native_descriptor_set_layout.cc b/iree/hal/vulkan/native_descriptor_set_layout.cc new file mode 100644 index 0000000000000..29fbe4b4e6f46 --- /dev/null +++ b/iree/hal/vulkan/native_descriptor_set_layout.cc @@ -0,0 +1,158 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
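For illustration, a hedged caller-side sketch of the descriptor set wrapper API above (assumed usage, not from this patch; |logical_device| and |vk_descriptor_set| come from elsewhere, and iree_hal_descriptor_set_release is the standard HAL resource release):

// Wrap a pool-allocated VkDescriptorSet; the wrapper only frees its own
// struct on destroy - the VkDescriptorSet itself is reclaimed by pool reset.
iree_hal_descriptor_set_t* set = NULL;
IREE_RETURN_IF_ERROR(iree_hal_vulkan_native_descriptor_set_create(
    logical_device, vk_descriptor_set, &set));
VkDescriptorSet raw_handle = iree_hal_vulkan_native_descriptor_set_handle(set);
(void)raw_handle;  // e.g. recorded via vkCmdBindDescriptorSets
iree_hal_descriptor_set_release(set);  // drops the wrapper's last reference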
+ +#include "iree/hal/vulkan/native_descriptor_set_layout.h" + +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/status_util.h" + +using namespace iree::hal::vulkan; + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + VkDescriptorSetLayout handle; +} iree_hal_vulkan_native_descriptor_set_layout_t; + +extern const iree_hal_descriptor_set_layout_vtable_t + iree_hal_vulkan_native_descriptor_set_layout_vtable; + +static iree_hal_vulkan_native_descriptor_set_layout_t* +iree_hal_vulkan_native_descriptor_set_layout_cast( + iree_hal_descriptor_set_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_native_descriptor_set_layout_vtable); + return (iree_hal_vulkan_native_descriptor_set_layout_t*)base_value; +} + +static iree_status_t iree_hal_vulkan_create_descriptor_set_layout( + VkDeviceHandle* logical_device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + VkDescriptorSetLayout* out_handle) { + VkDescriptorSetLayoutCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; + create_info.pNext = NULL; + create_info.flags = 0; + if (usage_type == IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_PUSH_ONLY && + logical_device->enabled_extensions().push_descriptors) { + // Note that we can *only* use push descriptor sets if we set this create + // flag. If push descriptors aren't supported we emulate them with normal + // descriptors so it's fine to have kPushOnly without support. + create_info.flags |= + VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; + } + + VkDescriptorSetLayoutBinding* native_bindings = NULL; + if (binding_count > 0) { + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + logical_device->host_allocator(), + binding_count * sizeof(VkDescriptorSetLayoutBinding), + (void**)&native_bindings)); + for (iree_host_size_t i = 0; i < binding_count; ++i) { + VkDescriptorSetLayoutBinding* native_binding = &native_bindings[i]; + native_binding->binding = bindings[i].binding; + native_binding->descriptorType = + static_cast(bindings[i].type); + native_binding->descriptorCount = 1; + native_binding->stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + native_binding->pImmutableSamplers = NULL; + } + } + create_info.bindingCount = (uint32_t)binding_count; + create_info.pBindings = native_bindings; + + iree_status_t status = + VK_RESULT_TO_STATUS(logical_device->syms()->vkCreateDescriptorSetLayout( + *logical_device, &create_info, + logical_device->allocator(), out_handle), + "vkCreateDescriptorSetLayout"); + + iree_allocator_free(logical_device->host_allocator(), native_bindings); + return status; +} + +static void iree_hal_vulkan_destroy_descriptor_set_layout( + VkDeviceHandle* logical_device, VkDescriptorSetLayout handle) { + if (handle == VK_NULL_HANDLE) return; + logical_device->syms()->vkDestroyDescriptorSetLayout( + *logical_device, handle, logical_device->allocator()); +} + +iree_status_t iree_hal_vulkan_native_descriptor_set_layout_create( + VkDeviceHandle* logical_device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(!binding_count || bindings); + IREE_ASSERT_ARGUMENT(out_descriptor_set_layout); + *out_descriptor_set_layout = NULL; + 
IREE_TRACE_ZONE_BEGIN(z0); + + VkDescriptorSetLayout handle = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_vulkan_create_descriptor_set_layout( + logical_device, usage_type, binding_count, bindings, &handle)); + + iree_hal_vulkan_native_descriptor_set_layout_t* descriptor_set_layout = NULL; + iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*descriptor_set_layout), + (void**)&descriptor_set_layout); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize( + &iree_hal_vulkan_native_descriptor_set_layout_vtable, + &descriptor_set_layout->resource); + descriptor_set_layout->logical_device = logical_device; + descriptor_set_layout->handle = handle; + *out_descriptor_set_layout = + (iree_hal_descriptor_set_layout_t*)descriptor_set_layout; + } else { + iree_hal_vulkan_destroy_descriptor_set_layout(logical_device, handle); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_descriptor_set_layout_destroy( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + iree_hal_vulkan_native_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_vulkan_native_descriptor_set_layout_cast( + base_descriptor_set_layout); + iree_allocator_t host_allocator = + descriptor_set_layout->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_destroy_descriptor_set_layout( + descriptor_set_layout->logical_device, descriptor_set_layout->handle); + iree_allocator_free(host_allocator, descriptor_set_layout); + + IREE_TRACE_ZONE_END(z0); +} + +VkDescriptorSetLayout iree_hal_vulkan_native_descriptor_set_layout_handle( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + iree_hal_vulkan_native_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_vulkan_native_descriptor_set_layout_cast( + base_descriptor_set_layout); + return descriptor_set_layout->handle; +} + +const iree_hal_descriptor_set_layout_vtable_t + iree_hal_vulkan_native_descriptor_set_layout_vtable = { + /*.destroy=*/iree_hal_vulkan_native_descriptor_set_layout_destroy, +}; diff --git a/iree/hal/vulkan/native_descriptor_set_layout.h b/iree/hal/vulkan/native_descriptor_set_layout.h new file mode 100644 index 0000000000000..d7fc86bbb643c --- /dev/null +++ b/iree/hal/vulkan/native_descriptor_set_layout.h @@ -0,0 +1,41 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_LAYOUT_H_ +#define IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_LAYOUT_H_ + +#include "iree/hal/api.h" +#include "iree/hal/vulkan/handle_util.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a native Vulkan VkDescriptorSetLayout object. 
+iree_status_t iree_hal_vulkan_native_descriptor_set_layout_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout); + +// Returns the native Vulkan VkDescriptorSetLayout handle. +VkDescriptorSetLayout iree_hal_vulkan_native_descriptor_set_layout_handle( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_LAYOUT_H_ diff --git a/iree/hal/vulkan/native_event.cc b/iree/hal/vulkan/native_event.cc index 28dbc568fa839..c9a7b1366378f 100644 --- a/iree/hal/vulkan/native_event.cc +++ b/iree/hal/vulkan/native_event.cc @@ -14,18 +14,89 @@ #include "iree/hal/vulkan/native_event.h" -namespace iree { -namespace hal { -namespace vulkan { +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/status_util.h" -NativeEvent::NativeEvent(ref_ptr logical_device, VkEvent handle) - : logical_device_(std::move(logical_device)), handle_(handle) {} +using namespace iree::hal::vulkan; -NativeEvent::~NativeEvent() { - logical_device_->syms()->vkDestroyEvent(*logical_device_, handle_, - logical_device_->allocator()); +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + VkEvent handle; +} iree_hal_vulkan_native_event_t; + +extern const iree_hal_event_vtable_t iree_hal_vulkan_native_event_vtable; + +static iree_hal_vulkan_native_event_t* iree_hal_vulkan_native_event_cast( + iree_hal_event_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_native_event_vtable); + return (iree_hal_vulkan_native_event_t*)base_value; +} + +static iree_status_t iree_hal_vulkan_create_event( + VkDeviceHandle* logical_device, VkEvent* out_handle) { + VkEventCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_EVENT_CREATE_INFO; + create_info.pNext = NULL; + create_info.flags = 0; + return VK_RESULT_TO_STATUS(logical_device->syms()->vkCreateEvent( + *logical_device, &create_info, + logical_device->allocator(), out_handle), + "vkCreateEvent"); +} + +static void iree_hal_vulkan_destroy_event(VkDeviceHandle* logical_device, + VkEvent handle) { + if (handle == VK_NULL_HANDLE) return; + logical_device->syms()->vkDestroyEvent(*logical_device, handle, + logical_device->allocator()); +} + +iree_status_t iree_hal_vulkan_native_event_create( + VkDeviceHandle* logical_device, iree_hal_event_t** out_event) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(out_event); + *out_event = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + VkEvent handle = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_vulkan_create_event(logical_device, &handle)); + + iree_hal_vulkan_native_event_t* event = NULL; + iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*event), (void**)&event); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_native_event_vtable, + &event->resource); + event->logical_device = logical_device; + event->handle = handle; + *out_event = (iree_hal_event_t*)event; + } else { + iree_hal_vulkan_destroy_event(logical_device, handle); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_event_destroy(iree_hal_event_t* base_event) { + iree_hal_vulkan_native_event_t* event = + iree_hal_vulkan_native_event_cast(base_event); + 
iree_allocator_t host_allocator = event->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_destroy_event(event->logical_device, event->handle); + iree_allocator_free(host_allocator, event); + + IREE_TRACE_ZONE_END(z0); +} + +VkEvent iree_hal_vulkan_native_event_handle( + const iree_hal_event_t* base_event) { + return ((const iree_hal_vulkan_native_event_t*)base_event)->handle; } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_event_vtable_t iree_hal_vulkan_native_event_vtable = { + /*.destroy=*/iree_hal_vulkan_native_event_destroy, +}; diff --git a/iree/hal/vulkan/native_event.h b/iree/hal/vulkan/native_event.h index beab80028eac1..83c79192facb7 100644 --- a/iree/hal/vulkan/native_event.h +++ b/iree/hal/vulkan/native_event.h @@ -15,32 +15,23 @@ #ifndef IREE_HAL_VULKAN_NATIVE_EVENT_H_ #define IREE_HAL_VULKAN_NATIVE_EVENT_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/hal/event.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/handle_util.h" -namespace iree { -namespace hal { -namespace vulkan { - -// An event implemented with the native VkEvent type. -class NativeEvent final : public Event { - public: - NativeEvent(ref_ptr logical_device, VkEvent handle); - ~NativeEvent() override; +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus - VkEvent handle() const { return handle_; } +// Creates a native Vulkan VkEvent object. +iree_status_t iree_hal_vulkan_native_event_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_hal_event_t** out_event); - private: - ref_ptr logical_device_; - VkEvent handle_; -}; +// Returns the native Vulkan VkEvent handle. +VkEvent iree_hal_vulkan_native_event_handle(const iree_hal_event_t* event); -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_NATIVE_EVENT_H_ diff --git a/iree/hal/vulkan/native_executable.cc b/iree/hal/vulkan/native_executable.cc new file mode 100644 index 0000000000000..cbe2b10eaa870 --- /dev/null +++ b/iree/hal/vulkan/native_executable.cc @@ -0,0 +1,287 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
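Analogous usage for the event wrapper above (again a sketch under the same assumptions; iree_hal_event_release comes from the HAL C API):

// Create a VkEvent-backed HAL event and fetch the raw handle for command
// buffer recording (vkCmdSetEvent / vkCmdWaitEvents).
iree_hal_event_t* event = NULL;
IREE_RETURN_IF_ERROR(
    iree_hal_vulkan_native_event_create(logical_device, &event));
VkEvent vk_event = iree_hal_vulkan_native_event_handle(event);
(void)vk_event;  // used while recording command buffers
iree_hal_event_release(event);  // destroys the VkEvent through the vtable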
+ +#include "iree/hal/vulkan/native_executable.h" + +#include "iree/base/memory.h" +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/handle_util.h" +#include "iree/hal/vulkan/native_executable_layout.h" +#include "iree/hal/vulkan/status_util.h" + +// flatcc schemas: +#include "iree/base/flatcc.h" +#include "iree/schemas/spirv_executable_def_reader.h" +#include "iree/schemas/spirv_executable_def_verifier.h" + +using namespace iree::hal::vulkan; + +static iree_status_t iree_hal_vulkan_create_shader_module( + VkDeviceHandle* logical_device, iree_const_byte_span_t code, + VkShaderModule* out_shader_module) { + VkShaderModuleCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + create_info.pNext = NULL; + create_info.flags = 0; + create_info.codeSize = code.data_length; + create_info.pCode = (const uint32_t*)code.data; + VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateShaderModule( + *logical_device, &create_info, + logical_device->allocator(), out_shader_module), + "vkCreateShaderModule"); + return iree_ok_status(); +} + +static void iree_hal_vulkan_destroy_shader_module( + VkDeviceHandle* logical_device, VkShaderModule handle) { + if (handle == VK_NULL_HANDLE) return; + logical_device->syms()->vkDestroyShaderModule(*logical_device, handle, + logical_device->allocator()); +} + +static iree_status_t iree_hal_vulkan_create_pipelines( + VkDeviceHandle* logical_device, VkPipelineCache pipeline_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_SpirVExecutableDef_table_t executable_def, + VkShaderModule shader_module, iree_host_size_t pipeline_count, + VkPipeline* out_pipelines) { + VkComputePipelineCreateInfo* create_infos = NULL; + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + logical_device->host_allocator(), + pipeline_count * sizeof(VkComputePipelineCreateInfo), + (void**)&create_infos)); + + flatbuffers_string_vec_t entry_points_vec = + iree_SpirVExecutableDef_entry_points_get(executable_def); + for (iree_host_size_t entry_ordinal = 0; entry_ordinal < pipeline_count; + ++entry_ordinal) { + VkComputePipelineCreateInfo* create_info = &create_infos[entry_ordinal]; + create_info->sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + create_info->pNext = NULL; + create_info->flags = 0; + if (!iree_all_bits_set( + caching_mode, + IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION)) { + create_info->flags |= VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT; + } + if (entry_ordinal == 0) { + create_info->flags |= VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT; + } else { + create_info->flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT; + } + create_info->layout = + iree_hal_vulkan_native_executable_layout_handle(executable_layout); + create_info->basePipelineHandle = VK_NULL_HANDLE; + create_info->basePipelineIndex = 0; + VkPipelineShaderStageCreateInfo* stage_create_info = &create_info->stage; + stage_create_info->sType = + VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + stage_create_info->pNext = NULL; + stage_create_info->flags = 0; + stage_create_info->stage = VK_SHADER_STAGE_COMPUTE_BIT; + stage_create_info->module = shader_module; + stage_create_info->pName = + flatbuffers_string_vec_at(entry_points_vec, entry_ordinal); + stage_create_info->pSpecializationInfo = NULL; + } + + iree_status_t status = VK_RESULT_TO_STATUS( + logical_device->syms()->vkCreateComputePipelines( + *logical_device, pipeline_cache, (uint32_t)pipeline_count, + create_infos, logical_device->allocator(), 
          out_pipelines),
+      "vkCreateComputePipelines");
+
+  iree_allocator_free(logical_device->host_allocator(), create_infos);
+  return status;
+}
+
+static void iree_hal_vulkan_destroy_pipeline(VkDeviceHandle* logical_device,
+                                             VkPipeline handle) {
+  if (handle == VK_NULL_HANDLE) return;
+  logical_device->syms()->vkDestroyPipeline(*logical_device, handle,
+                                            logical_device->allocator());
+}
+
+// Verifies the structure of the flatbuffer so that we can avoid doing so during
+// runtime. There are still some conditions we must be aware of (such as omitted
+// names on functions with internal linkage), however we shouldn't need to
+// bounds check anything within the flatbuffer after this succeeds.
+static iree_status_t iree_hal_spirv_executable_flatbuffer_verify(
+    iree_const_byte_span_t flatbuffer_data) {
+  if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "flatbuffer data is not present or less than 16 bytes (%zu total)",
+        flatbuffer_data.data_length);
+  }
+
+  // Run flatcc generated verification. This ensures all pointers are in-bounds
+  // and that we can safely walk the file, but not that the actual contents of
+  // the flatbuffer meet our expectations.
+  int verify_ret = iree_SpirVExecutableDef_verify_as_root(
+      flatbuffer_data.data, flatbuffer_data.data_length);
+  if (verify_ret != flatcc_verify_ok) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "flatbuffer verification failed: %s",
+                            flatcc_verify_error_string(verify_ret));
+  }
+
+  iree_SpirVExecutableDef_table_t executable_def =
+      iree_SpirVExecutableDef_as_root(flatbuffer_data.data);
+
+  flatbuffers_string_vec_t entry_points_vec =
+      iree_SpirVExecutableDef_entry_points_get(executable_def);
+  size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
+  for (size_t i = 0; i < entry_point_count; ++i) {
+    if (!flatbuffers_string_len(
+            flatbuffers_string_vec_at(entry_points_vec, i))) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "executable entry point %zu has no name", i);
+    }
+  }
+
+  if (flatbuffers_uint32_vec_len(
+          iree_SpirVExecutableDef_code_get(executable_def)) == 0) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "executable SPIR-V code is missing/empty");
+  }
+
+  // TODO(benvanik): pull PopulateSpecializationInfo from history and update.
+  // For now the compiler isn't generating them, and we don't use them.
+ if (iree_SpirVExecutableDef_specialization_info_is_present(executable_def)) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "executable uses SPIR-V specialization constants; " + "they need to be revived"); + } + + return iree_ok_status(); +} + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + iree_host_size_t pipeline_count; + VkPipeline pipelines[]; +} iree_hal_vulkan_native_executable_t; + +extern const iree_hal_executable_vtable_t + iree_hal_vulkan_native_executable_vtable; + +static iree_hal_vulkan_native_executable_t* +iree_hal_vulkan_native_executable_cast(iree_hal_executable_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_native_executable_vtable); + return (iree_hal_vulkan_native_executable_t*)base_value; +} + +iree_status_t iree_hal_vulkan_native_executable_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + VkPipelineCache pipeline_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(out_executable); + *out_executable = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + // Verify and fetch the executable flatbuffer wrapper. + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_spirv_executable_flatbuffer_verify(executable_data)); + iree_SpirVExecutableDef_table_t executable_def = + iree_SpirVExecutableDef_as_root(executable_data.data); + + // Create the shader module. + flatbuffers_uint32_vec_t code_vec = + iree_SpirVExecutableDef_code_get(executable_def); + VkShaderModule shader_module = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_vulkan_create_shader_module( + logical_device, + iree_make_const_byte_span( + code_vec, + flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t)), + &shader_module)); + + // Create pipelines for each entry point. 
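+  // (One VkPipeline per entry point; all share the single shader module.)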
+ flatbuffers_string_vec_t entry_points_vec = + iree_SpirVExecutableDef_entry_points_get(executable_def); + iree_host_size_t pipeline_count = + flatbuffers_string_vec_len(entry_points_vec); + + iree_hal_vulkan_native_executable_t* executable = NULL; + iree_host_size_t total_size = + sizeof(*executable) + pipeline_count * sizeof(*executable->pipelines); + iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(), + total_size, (void**)&executable); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_native_executable_vtable, + &executable->resource); + executable->logical_device = logical_device; + executable->pipeline_count = pipeline_count; + memset(executable->pipelines, 0, + pipeline_count * sizeof(*executable->pipelines)); + } + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_create_pipelines( + logical_device, pipeline_cache, executable_layout, caching_mode, + executable_def, shader_module, executable->pipeline_count, + executable->pipelines); + } + iree_hal_vulkan_destroy_shader_module(logical_device, shader_module); + + if (iree_status_is_ok(status)) { + *out_executable = (iree_hal_executable_t*)executable; + } else { + iree_hal_executable_destroy((iree_hal_executable_t*)executable); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_executable_destroy( + iree_hal_executable_t* base_executable) { + iree_hal_vulkan_native_executable_t* executable = + iree_hal_vulkan_native_executable_cast(base_executable); + iree_allocator_t host_allocator = + executable->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < executable->pipeline_count; ++i) { + iree_hal_vulkan_destroy_pipeline(executable->logical_device, + executable->pipelines[i]); + } + iree_allocator_free(host_allocator, executable); + + IREE_TRACE_ZONE_END(z0); +} + +iree_status_t iree_hal_vulkan_native_executable_pipeline_for_entry_point( + iree_hal_executable_t* base_executable, iree_host_size_t entry_ordinal, + VkPipeline* out_pipeline_handle) { + iree_hal_vulkan_native_executable_t* executable = + iree_hal_vulkan_native_executable_cast(base_executable); + if (entry_ordinal >= executable->pipeline_count) { + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "invalid entry point ordinal %zu", entry_ordinal); + } + *out_pipeline_handle = executable->pipelines[entry_ordinal]; + return iree_ok_status(); +} + +const iree_hal_executable_vtable_t iree_hal_vulkan_native_executable_vtable = { + /*.destroy=*/iree_hal_vulkan_native_executable_destroy, +}; diff --git a/iree/hal/vulkan/native_executable.h b/iree/hal/vulkan/native_executable.h new file mode 100644 index 0000000000000..d8372c749b1ac --- /dev/null +++ b/iree/hal/vulkan/native_executable.h @@ -0,0 +1,49 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
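+
+// NOTE: executables created here are reference-counted like all HAL
+// resources; callers release them with iree_hal_executable_release and must
+// not destroy the underlying VkPipelines directly.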
+
+#ifndef IREE_HAL_VULKAN_NATIVE_EXECUTABLE_H_
+#define IREE_HAL_VULKAN_NATIVE_EXECUTABLE_H_
+
+// clang-format off: Must be included before all other headers:
+#include "iree/hal/vulkan/vulkan_headers.h"
+// clang-format on
+
+#include "iree/hal/api.h"
+#include "iree/hal/vulkan/handle_util.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a wrapper for one or more VkPipelines that are sourced from the same
+// IREE executable. Each pipeline shares the same shader module and differs
+// only in the entry point into the shader module it references.
+iree_status_t iree_hal_vulkan_native_executable_create(
+    iree::hal::vulkan::VkDeviceHandle* logical_device,
+    VkPipelineCache pipeline_cache,
+    iree_hal_executable_layout_t* executable_layout,
+    iree_hal_executable_caching_mode_t caching_mode,
+    iree_const_byte_span_t executable_data,
+    iree_hal_executable_t** out_executable);
+
+// Returns the cached VkPipeline for the given executable |entry_ordinal|.
+iree_status_t iree_hal_vulkan_native_executable_pipeline_for_entry_point(
+    iree_hal_executable_t* executable, iree_host_size_t entry_ordinal,
+    VkPipeline* out_pipeline_handle);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_VULKAN_NATIVE_EXECUTABLE_H_
diff --git a/iree/hal/vulkan/native_executable_layout.cc b/iree/hal/vulkan/native_executable_layout.cc
new file mode 100644
index 0000000000000..6a250691c7095
--- /dev/null
+++ b/iree/hal/vulkan/native_executable_layout.cc
@@ -0,0 +1,173 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
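+
+// Usage sketch (assumed caller flow; |set_layouts| created via
+// iree_hal_vulkan_native_descriptor_set_layout_create; error handling
+// elided):
+//
+//   iree_hal_executable_layout_t* layout = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_vulkan_native_executable_layout_create(
+//       logical_device, set_layout_count, set_layouts,
+//       /*push_constant_count=*/0, &layout));
+//   VkPipelineLayout vk_layout =
+//       iree_hal_vulkan_native_executable_layout_handle(layout);
+//   // ... create pipelines against |vk_layout| ...
+//   iree_hal_executable_layout_release(layout);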
+ +#include "iree/hal/vulkan/native_executable_layout.h" + +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/native_descriptor_set_layout.h" +#include "iree/hal/vulkan/status_util.h" + +using namespace iree::hal::vulkan; + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + VkPipelineLayout handle; + iree_host_size_t set_layout_count; + iree_hal_descriptor_set_layout_t* set_layouts[]; +} iree_hal_vulkan_native_executable_layout_t; + +extern const iree_hal_executable_layout_vtable_t + iree_hal_vulkan_native_executable_layout_vtable; + +static iree_hal_vulkan_native_executable_layout_t* +iree_hal_vulkan_native_executable_layout_cast( + iree_hal_executable_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_native_executable_layout_vtable); + return (iree_hal_vulkan_native_executable_layout_t*)base_value; +} + +static iree_status_t iree_hal_vulkan_create_pipeline_layout( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constant_count, VkPipelineLayout* out_handle) { + VkDescriptorSetLayout* set_layout_handles = + (VkDescriptorSetLayout*)iree_alloca(set_layout_count * + sizeof(VkDescriptorSetLayout)); + for (iree_host_size_t i = 0; i < set_layout_count; ++i) { + set_layout_handles[i] = + iree_hal_vulkan_native_descriptor_set_layout_handle(set_layouts[i]); + } + + VkPushConstantRange push_constant_ranges[1]; + push_constant_ranges[0].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + push_constant_ranges[0].offset = 0; + push_constant_ranges[0].size = + (uint32_t)(push_constant_count * sizeof(uint32_t)); + + VkPipelineLayoutCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + create_info.setLayoutCount = (uint32_t)set_layout_count; + create_info.pSetLayouts = set_layout_handles; + create_info.pushConstantRangeCount = push_constant_count > 0 ? 
1 : 0; + create_info.pPushConstantRanges = push_constant_ranges; + + return VK_RESULT_TO_STATUS(logical_device->syms()->vkCreatePipelineLayout( + *logical_device, &create_info, + logical_device->allocator(), out_handle), + "vkCreatePipelineLayout"); +} + +static void iree_hal_vulkan_destroy_pipeline_layout( + VkDeviceHandle* logical_device, VkPipelineLayout handle) { + if (handle == VK_NULL_HANDLE) return; + logical_device->syms()->vkDestroyPipelineLayout(*logical_device, handle, + logical_device->allocator()); +} + +iree_status_t iree_hal_vulkan_native_executable_layout_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constant_count, + iree_hal_executable_layout_t** out_executable_layout) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts); + IREE_ASSERT_ARGUMENT(out_executable_layout); + *out_executable_layout = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + VkPipelineLayout handle = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_vulkan_create_pipeline_layout(logical_device, + set_layout_count, set_layouts, + push_constant_count, &handle)); + + iree_hal_vulkan_native_executable_layout_t* executable_layout = NULL; + iree_host_size_t total_size = + sizeof(*executable_layout) + + set_layout_count * sizeof(*executable_layout->set_layouts); + iree_status_t status = iree_allocator_malloc( + logical_device->host_allocator(), total_size, (void**)&executable_layout); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize( + &iree_hal_vulkan_native_executable_layout_vtable, + &executable_layout->resource); + executable_layout->logical_device = logical_device; + executable_layout->handle = handle; + executable_layout->set_layout_count = set_layout_count; + for (iree_host_size_t i = 0; i < set_layout_count; ++i) { + executable_layout->set_layouts[i] = set_layouts[i]; + iree_hal_descriptor_set_layout_retain(set_layouts[i]); + } + *out_executable_layout = (iree_hal_executable_layout_t*)executable_layout; + } else { + iree_hal_vulkan_destroy_pipeline_layout(logical_device, handle); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_executable_layout_destroy( + iree_hal_executable_layout_t* base_executable_layout) { + iree_hal_vulkan_native_executable_layout_t* executable_layout = + iree_hal_vulkan_native_executable_layout_cast(base_executable_layout); + iree_allocator_t host_allocator = + executable_layout->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_destroy_pipeline_layout(executable_layout->logical_device, + executable_layout->handle); + for (iree_host_size_t i = 0; i < executable_layout->set_layout_count; ++i) { + iree_hal_descriptor_set_layout_release(executable_layout->set_layouts[i]); + } + iree_allocator_free(host_allocator, executable_layout); + + IREE_TRACE_ZONE_END(z0); +} + +VkPipelineLayout iree_hal_vulkan_native_executable_layout_handle( + iree_hal_executable_layout_t* base_executable_layout) { + iree_hal_vulkan_native_executable_layout_t* executable_layout = + iree_hal_vulkan_native_executable_layout_cast(base_executable_layout); + return executable_layout->handle; +} + +iree_host_size_t iree_hal_vulkan_native_executable_layout_set_count( + iree_hal_executable_layout_t* base_executable_layout) { + iree_hal_vulkan_native_executable_layout_t* executable_layout = + 
iree_hal_vulkan_native_executable_layout_cast(base_executable_layout); + return executable_layout->set_layout_count; +} + +iree_hal_descriptor_set_layout_t* iree_hal_vulkan_native_executable_layout_set( + iree_hal_executable_layout_t* base_executable_layout, + iree_host_size_t set_index) { + iree_hal_vulkan_native_executable_layout_t* executable_layout = + iree_hal_vulkan_native_executable_layout_cast(base_executable_layout); + if (IREE_UNLIKELY(set_index >= executable_layout->set_layout_count)) { + return NULL; + } + return executable_layout->set_layouts[set_index]; +} + +const iree_hal_executable_layout_vtable_t + iree_hal_vulkan_native_executable_layout_vtable = { + /*.destroy=*/iree_hal_vulkan_native_executable_layout_destroy, +}; diff --git a/iree/hal/vulkan/native_executable_layout.h b/iree/hal/vulkan/native_executable_layout.h new file mode 100644 index 0000000000000..58500fa2b0ff2 --- /dev/null +++ b/iree/hal/vulkan/native_executable_layout.h @@ -0,0 +1,55 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NATIVE_EXECUTABLE_LAYOUT_H_ +#define IREE_HAL_VULKAN_NATIVE_EXECUTABLE_LAYOUT_H_ + +// clang-format off: Must be included before all other headers: +#include "iree/hal/vulkan/vulkan_headers.h" +// clang-format on + +#include "iree/hal/api.h" +#include "iree/hal/vulkan/handle_util.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a VkPipelineLayout-based executable layout composed of one or more +// descriptor set layouts. +iree_status_t iree_hal_vulkan_native_executable_layout_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constant_count, + iree_hal_executable_layout_t** out_executable_layout); + +// Returns the native VkPipelineLayout handle for the executable layout. +VkPipelineLayout iree_hal_vulkan_native_executable_layout_handle( + iree_hal_executable_layout_t* executable_layout); + +// Returns the total number of descriptor sets within the layout. +iree_host_size_t iree_hal_vulkan_native_executable_layout_set_count( + iree_hal_executable_layout_t* executable_layout); + +// Returns the descriptor set layout with the given |set_index|. +iree_hal_descriptor_set_layout_t* iree_hal_vulkan_native_executable_layout_set( + iree_hal_executable_layout_t* executable_layout, + iree_host_size_t set_index); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_VULKAN_NATIVE_EXECUTABLE_LAYOUT_H_ diff --git a/iree/hal/vulkan/native_semaphore.cc b/iree/hal/vulkan/native_semaphore.cc new file mode 100644 index 0000000000000..ec158e1938647 --- /dev/null +++ b/iree/hal/vulkan/native_semaphore.cc @@ -0,0 +1,283 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/vulkan/native_semaphore.h"
+
+#include "iree/base/tracing.h"
+#include "iree/hal/vulkan/status_util.h"
+
+// The maximum valid payload value of an iree_hal_semaphore_t.
+// Payload values larger than this indicate that the semaphore has failed.
+//
+// This originates from Vulkan having a lower-bound of INT_MAX for
+// maxTimelineSemaphoreValueDifference and many Android devices only supporting
+// that lower-bound. At ~100 signals per second it'll take ~8 months to
+// saturate. We may increase this value at some point but so long as there are
+// some devices in the wild that may have this limitation we can ensure better
+// consistency across the backends by observing this.
+//
+// The major mitigation here is that in proper usage of IREE there are no
+// semaphores that are implicitly referenced by multiple VMs (each creates its
+// own internally) and in a multitenant system each session should have its own
+// semaphores - so even if the process lives for years it's highly unlikely any
+// particular session does. Whatever, 640K is enough for anyone.
+//
+// See:
+// https://vulkan.gpuinfo.org/displayextensionproperty.php?name=maxTimelineSemaphoreValueDifference
+#define IREE_HAL_VULKAN_SEMAPHORE_MAX_VALUE (2147483647ull - 1)
+
+using namespace iree::hal::vulkan;
+
+typedef struct {
+  iree_hal_resource_t resource;
+  VkDeviceHandle* logical_device;
+  VkSemaphore handle;
+  iree_atomic_intptr_t failure_status;
+} iree_hal_vulkan_native_semaphore_t;
+
+extern const iree_hal_semaphore_vtable_t
+    iree_hal_vulkan_native_semaphore_vtable;
+
+static iree_hal_vulkan_native_semaphore_t*
+iree_hal_vulkan_native_semaphore_cast(iree_hal_semaphore_t* base_value) {
+  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_native_semaphore_vtable);
+  return (iree_hal_vulkan_native_semaphore_t*)base_value;
+}
+
+iree_status_t iree_hal_vulkan_native_semaphore_create(
+    iree::hal::vulkan::VkDeviceHandle* logical_device, uint64_t initial_value,
+    iree_hal_semaphore_t** out_semaphore) {
+  IREE_ASSERT_ARGUMENT(logical_device);
+  IREE_ASSERT_ARGUMENT(out_semaphore);
+  *out_semaphore = NULL;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  VkSemaphoreTypeCreateInfo timeline_create_info;
+  timeline_create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_TYPE_CREATE_INFO;
+  timeline_create_info.pNext = NULL;
+  timeline_create_info.semaphoreType = VK_SEMAPHORE_TYPE_TIMELINE;
+  timeline_create_info.initialValue = initial_value;
+
+  VkSemaphoreCreateInfo create_info;
+  create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO;
+  create_info.pNext = &timeline_create_info;
+  create_info.flags = 0;
+  VkSemaphore handle = VK_NULL_HANDLE;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, VK_RESULT_TO_STATUS(logical_device->syms()->vkCreateSemaphore(
+                                  *logical_device, &create_info,
+                                  logical_device->allocator(), &handle),
+                              "vkCreateSemaphore"));
+
+  iree_hal_vulkan_native_semaphore_t* semaphore = NULL;
+  iree_status_t status = iree_allocator_malloc(
+      logical_device->host_allocator(), sizeof(*semaphore), (void**)&semaphore);
+  if (iree_status_is_ok(status)) {
iree_hal_resource_initialize(&iree_hal_vulkan_native_semaphore_vtable, + &semaphore->resource); + semaphore->logical_device = logical_device; + semaphore->handle = handle; + iree_atomic_store_intptr(&semaphore->failure_status, 0, + iree_memory_order_release); + *out_semaphore = (iree_hal_semaphore_t*)semaphore; + } else { + logical_device->syms()->vkDestroySemaphore(*logical_device, handle, + logical_device->allocator()); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_semaphore_destroy( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + iree_allocator_t host_allocator = semaphore->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_free((iree_status_t)iree_atomic_load_intptr( + &semaphore->failure_status, iree_memory_order_acquire)); + semaphore->logical_device->syms()->vkDestroySemaphore( + *semaphore->logical_device, semaphore->handle, + semaphore->logical_device->allocator()); + iree_allocator_free(host_allocator, semaphore); + + IREE_TRACE_ZONE_END(z0); +} + +VkSemaphore iree_hal_vulkan_native_semaphore_handle( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + return semaphore->handle; +} + +static iree_status_t iree_hal_vulkan_native_semaphore_query( + iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + *out_value = 0; + + uint64_t value = 0; + IREE_RETURN_IF_ERROR(VK_RESULT_TO_STATUS( + semaphore->logical_device->syms()->vkGetSemaphoreCounterValue( + *semaphore->logical_device, semaphore->handle, &value), + "vkGetSemaphoreCounterValue")); + + if (value > IREE_HAL_VULKAN_SEMAPHORE_MAX_VALUE) { + iree_status_t failure_status = (iree_status_t)iree_atomic_load_intptr( + &semaphore->failure_status, iree_memory_order_acquire); + if (iree_status_is_ok(failure_status)) { + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "overflowed timeline semaphore max value"); + } + return iree_status_clone(failure_status); + } + + *out_value = value; + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_native_semaphore_signal( + iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + + VkSemaphoreSignalInfo signal_info; + signal_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO; + signal_info.pNext = NULL; + signal_info.semaphore = semaphore->handle; + signal_info.value = new_value; + return VK_RESULT_TO_STATUS( + semaphore->logical_device->syms()->vkSignalSemaphore( + *semaphore->logical_device, &signal_info), + "vkSignalSemaphore"); +} + +static void iree_hal_vulkan_native_semaphore_fail( + iree_hal_semaphore_t* base_semaphore, iree_status_t status) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + + // Try to set our local status - we only preserve the first failure so only + // do this if we are going from a valid semaphore to a failed one. + iree_status_t old_status = iree_ok_status(); + if (!iree_atomic_compare_exchange_strong_intptr( + &semaphore->failure_status, (intptr_t*)&old_status, (intptr_t)status, + iree_memory_order_seq_cst, iree_memory_order_seq_cst)) { + // Previous status was not OK; drop our new status. 
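+    // (IREE_IGNORE_ERROR frees the storage owned by |status|.)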
+ IREE_IGNORE_ERROR(status); + return; + } + + VkSemaphoreSignalInfo signal_info; + signal_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO; + signal_info.pNext = NULL; + signal_info.semaphore = semaphore->handle; + signal_info.value = IREE_HAL_VULKAN_SEMAPHORE_MAX_VALUE + 1; + // NOTE: we don't care about the result in case of failures as we are + // failing and the caller will likely be tearing everything down anyway. + semaphore->logical_device->syms()->vkSignalSemaphore( + *semaphore->logical_device, &signal_info); +} + +iree_status_t iree_hal_vulkan_native_semaphore_multi_wait( + iree::hal::vulkan::VkDeviceHandle* logical_device, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns, + VkSemaphoreWaitFlags wait_flags) { + if (semaphore_list->count == 0) return iree_ok_status(); + + uint64_t timeout_ns; + if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { + timeout_ns = UINT64_MAX; + } else if (deadline_ns == IREE_TIME_INFINITE_PAST) { + timeout_ns = 0; + } else { + iree_time_t now_ns = iree_time_now(); + if (deadline_ns < now_ns) { + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + } + timeout_ns = (uint64_t)(deadline_ns - now_ns); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + VkSemaphore* semaphore_handles = + (VkSemaphore*)iree_alloca(semaphore_list->count * sizeof(VkSemaphore)); + for (iree_host_size_t i = 0; i < semaphore_list->count; ++i) { + semaphore_handles[i] = + iree_hal_vulkan_native_semaphore_handle(semaphore_list->semaphores[i]); + } + + VkSemaphoreWaitInfo wait_info; + wait_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_WAIT_INFO; + wait_info.pNext = nullptr; + wait_info.flags = wait_flags; + wait_info.semaphoreCount = semaphore_list->count; + wait_info.pSemaphores = semaphore_handles; + wait_info.pValues = semaphore_list->payload_values; + static_assert( + sizeof(wait_info.pValues[0]) == sizeof(semaphore_list->payload_values[0]), + "payload value type must match vulkan expected size"); + + // NOTE: this may fail with a timeout (VK_TIMEOUT) or in the case of a + // device loss event may return either VK_SUCCESS *or* VK_ERROR_DEVICE_LOST. + // We may want to explicitly query for device loss after a successful wait + // to ensure we consistently return errors. + VkResult result = logical_device->syms()->vkWaitSemaphores( + *logical_device, &wait_info, timeout_ns); + + IREE_TRACE_ZONE_END(z0); + + if (result == VK_SUCCESS) { + return iree_ok_status(); + } else if (result == VK_ERROR_DEVICE_LOST) { + // Nothing we do now matters. 
+ return VK_RESULT_TO_STATUS(result, "vkWaitSemaphores"); + } else if (result == VK_TIMEOUT) { + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + } + return VK_RESULT_TO_STATUS(result, "vkWaitSemaphores"); +} + +static iree_status_t iree_hal_vulkan_native_semaphore_wait_with_deadline( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_time_t deadline_ns) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + iree_hal_semaphore_list_t semaphore_list = { + /*.count=*/1, + /*.semaphores=*/&base_semaphore, + /*.payload_values=*/&value, + }; + return iree_hal_vulkan_native_semaphore_multi_wait( + semaphore->logical_device, &semaphore_list, deadline_ns, 0); +} + +static iree_status_t iree_hal_vulkan_native_semaphore_wait_with_timeout( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_duration_t timeout_ns) { + return iree_hal_vulkan_native_semaphore_wait_with_deadline( + base_semaphore, value, iree_relative_timeout_to_deadline_ns(timeout_ns)); +} + +const iree_hal_semaphore_vtable_t iree_hal_vulkan_native_semaphore_vtable = { + /*.destroy=*/iree_hal_vulkan_native_semaphore_destroy, + /*.query=*/iree_hal_vulkan_native_semaphore_query, + /*.signal=*/iree_hal_vulkan_native_semaphore_signal, + /*.fail=*/iree_hal_vulkan_native_semaphore_fail, + /*.wait_with_deadline=*/iree_hal_vulkan_native_semaphore_wait_with_deadline, + /*.wait_with_timeout=*/iree_hal_vulkan_native_semaphore_wait_with_timeout, +}; diff --git a/iree/hal/vulkan/native_semaphore.h b/iree/hal/vulkan/native_semaphore.h new file mode 100644 index 0000000000000..31f2611c287a7 --- /dev/null +++ b/iree/hal/vulkan/native_semaphore.h @@ -0,0 +1,51 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NATIVE_SEMAPHORE_H_ +#define IREE_HAL_VULKAN_NATIVE_SEMAPHORE_H_ + +#include "iree/hal/api.h" +#include "iree/hal/vulkan/handle_util.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a timeline semaphore implemented using the native VkSemaphore type. +// This may require emulation pre-Vulkan 1.2 when timeline semaphores were only +// an extension. +iree_status_t iree_hal_vulkan_native_semaphore_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore); + +// Returns the Vulkan timeline semaphore handle. +VkSemaphore iree_hal_vulkan_native_semaphore_handle( + iree_hal_semaphore_t* semaphore); + +// Performs a multi-wait on one or more semaphores. +// By default this is an all-wait but |wait_flags| may contain +// VK_SEMAPHORE_WAIT_ANY_BIT to change to an any-wait. +// +// Returns IREE_STATUS_DEADLINE_EXCEEDED if the wait does not complete before +// |deadline_ns| elapses. 
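+//
+// Usage sketch (wait-all on a single semaphore; error handling elided):
+//   uint64_t payload_value = 1ull;
+//   iree_hal_semaphore_list_t list = {
+//       /*.count=*/1,
+//       /*.semaphores=*/&semaphore,
+//       /*.payload_values=*/&payload_value,
+//   };
+//   IREE_RETURN_IF_ERROR(iree_hal_vulkan_native_semaphore_multi_wait(
+//       logical_device, &list, IREE_TIME_INFINITE_FUTURE, /*wait_flags=*/0));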
+iree_status_t iree_hal_vulkan_native_semaphore_multi_wait(
+    iree::hal::vulkan::VkDeviceHandle* logical_device,
+    const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns,
+    VkSemaphoreWaitFlags wait_flags);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_VULKAN_NATIVE_SEMAPHORE_H_
diff --git a/iree/hal/vulkan/native_timeline_semaphore.cc b/iree/hal/vulkan/native_timeline_semaphore.cc
deleted file mode 100644
index ecc340b074476..0000000000000
--- a/iree/hal/vulkan/native_timeline_semaphore.cc
+++ /dev/null
@@ -1,145 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/native_timeline_semaphore.h"
-
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// static
-StatusOr<ref_ptr<Semaphore>> NativeTimelineSemaphore::Create(
-    ref_ptr<VkDeviceHandle> logical_device, uint64_t initial_value) {
-  IREE_TRACE_SCOPE0("NativeTimelineSemaphore::Create");
-
-  VkSemaphoreTypeCreateInfo timeline_create_info;
-  timeline_create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_TYPE_CREATE_INFO;
-  timeline_create_info.pNext = nullptr;
-  timeline_create_info.semaphoreType = VK_SEMAPHORE_TYPE_TIMELINE;
-  timeline_create_info.initialValue = initial_value;
-
-  VkSemaphoreCreateInfo create_info;
-  create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO;
-  create_info.pNext = &timeline_create_info;
-  create_info.flags = 0;
-  VkSemaphore semaphore_handle = VK_NULL_HANDLE;
-  VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateSemaphore(
-      *logical_device, &create_info, logical_device->allocator(),
-      &semaphore_handle));
-
-  return make_ref<NativeTimelineSemaphore>(std::move(logical_device),
-                                           semaphore_handle, initial_value);
-}
-
-NativeTimelineSemaphore::NativeTimelineSemaphore(
-    ref_ptr<VkDeviceHandle> logical_device, VkSemaphore handle,
-    uint64_t initial_value)
-    : logical_device_(std::move(logical_device)), handle_(handle) {}
-
-NativeTimelineSemaphore::~NativeTimelineSemaphore() {
-  IREE_TRACE_SCOPE0("NativeTimelineSemaphore::dtor");
-  logical_device_->syms()->vkDestroySemaphore(*logical_device_, handle_,
-                                              logical_device_->allocator());
-}
-
-StatusOr<uint64_t> NativeTimelineSemaphore::Query() {
-  uint64_t value = 0;
-  VK_RETURN_IF_ERROR(logical_device_->syms()->vkGetSemaphoreCounterValue(
-      *logical_device_, handle_, &value));
-  if (value == UINT64_MAX) {
-    absl::MutexLock lock(&status_mutex_);
-    return status_;
-  }
-  return value;
-}
-
-Status NativeTimelineSemaphore::Signal(uint64_t value) {
-  IREE_TRACE_SCOPE0("NativeTimelineSemaphore::Signal");
-
-  VkSemaphoreSignalInfo signal_info;
-  signal_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO;
-  signal_info.pNext = nullptr;
-  signal_info.semaphore = handle_;
-  signal_info.value = value;
-  return VkResultToStatus(logical_device_->syms()->vkSignalSemaphore(
-                              *logical_device_, &signal_info),
-                          IREE_LOC);
-}
-
-void NativeTimelineSemaphore::Fail(Status status) {
-  IREE_TRACE_SCOPE0("NativeTimelineSemaphore::Fail");
-
-  // NOTE: we hold the lock here as the vkSignalSemaphore may wake a waiter and
-  // we want to be able to immediately give them the status.
-  absl::MutexLock lock(&status_mutex_);
-  status_ = std::move(status);
-
-  VkSemaphoreSignalInfo signal_info;
-  signal_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO;
-  signal_info.pNext = nullptr;
-  signal_info.semaphore = handle_;
-  signal_info.value = UINT64_MAX;
-  // NOTE: we don't care about the result in case of failures as we are
-  // failing and the caller will likely be tearing everything down anyway.
-  logical_device_->syms()->vkSignalSemaphore(*logical_device_, &signal_info);
-}
-
-Status NativeTimelineSemaphore::Wait(uint64_t value, Time deadline_ns) {
-  IREE_TRACE_SCOPE0("NativeTimelineSemaphore::Wait");
-
-  VkSemaphoreWaitInfo wait_info;
-  wait_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_WAIT_INFO;
-  wait_info.pNext = nullptr;
-  wait_info.flags = 0;
-  wait_info.semaphoreCount = 1;
-  wait_info.pSemaphores = &handle_;
-  wait_info.pValues = &value;
-
-  uint64_t timeout_ns;
-  if (deadline_ns == InfiniteFuture()) {
-    timeout_ns = UINT64_MAX;
-  } else if (deadline_ns == InfinitePast()) {
-    timeout_ns = 0;
-  } else {
-    Duration relative_ns = deadline_ns - Now();
-    timeout_ns = static_cast<uint64_t>(
-        relative_ns < ZeroDuration() ? ZeroDuration() : relative_ns);
-  }
-
-  // NOTE: this may fail with a timeout (VK_TIMEOUT) or in the case of a
-  // device loss event may return either VK_SUCCESS *or* VK_ERROR_DEVICE_LOST.
-  // We may want to explicitly query for device loss after a successful wait
-  // to ensure we consistently return errors.
-  if (!logical_device_->syms()->vkWaitSemaphores) {
-    return UnknownErrorBuilder(IREE_LOC) << "vkWaitSemaphores not defined";
-  }
-  VkResult result = logical_device_->syms()->vkWaitSemaphores(
-      *logical_device_, &wait_info, timeout_ns);
-  if (result == VK_ERROR_DEVICE_LOST) {
-    // Nothing we do now matters.
-    return VkResultToStatus(result, IREE_LOC);
-  } else if (result == VK_TIMEOUT) {
-    return DeadlineExceededErrorBuilder(IREE_LOC)
-           << "Deadline exceeded waiting for semaphore";
-  }
-
-  return VkResultToStatus(result, IREE_LOC);
-}
-
-}  // namespace vulkan
-}  // namespace hal
-}  // namespace iree
diff --git a/iree/hal/vulkan/native_timeline_semaphore.h b/iree/hal/vulkan/native_timeline_semaphore.h
deleted file mode 100644
index 03e2d45d83596..0000000000000
--- a/iree/hal/vulkan/native_timeline_semaphore.h
+++ /dev/null
@@ -1,66 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_
-#define IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_
-
-// clang-format off: Must be included before all other headers:
-#include "iree/hal/vulkan/vulkan_headers.h"
-// clang-format on
-
-#include "absl/synchronization/mutex.h"
-#include "iree/hal/semaphore.h"
-#include "iree/hal/vulkan/handle_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-// A timeline semaphore implemented using the native VkSemaphore type.
-// This may require emulation pre-Vulkan 1.2 when timeline semaphores were only
-// an extension.
-class NativeTimelineSemaphore final : public Semaphore {
- public:
-  // Creates a timeline semaphore with the given |initial_value|.
-  static StatusOr<ref_ptr<Semaphore>> Create(
-      ref_ptr<VkDeviceHandle> logical_device, uint64_t initial_value);
-
-  NativeTimelineSemaphore(ref_ptr<VkDeviceHandle> logical_device,
-                          VkSemaphore handle, uint64_t initial_value);
-  ~NativeTimelineSemaphore() override;
-
-  VkSemaphore handle() const { return handle_; }
-
-  StatusOr<uint64_t> Query() override;
-
-  Status Signal(uint64_t value) override;
-  void Fail(Status status) override;
-  Status Wait(uint64_t value, Time deadline_ns) override;
-
- private:
-  ref_ptr<VkDeviceHandle> logical_device_;
-  VkSemaphore handle_;
-
-  // NOTE: the Vulkan semaphore is the source of truth. We only need to access
-  // this status (and thus take the lock) when we want to either signal failure
-  // or query the status in the case of the semaphore being set to UINT64_MAX.
-  mutable absl::Mutex status_mutex_;
-  Status status_ ABSL_GUARDED_BY(status_mutex_);
-};
-
-}  // namespace vulkan
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_
diff --git a/iree/hal/vulkan/nop_executable_cache.cc b/iree/hal/vulkan/nop_executable_cache.cc
new file mode 100644
index 0000000000000..5aa6bc003735e
--- /dev/null
+++ b/iree/hal/vulkan/nop_executable_cache.cc
@@ -0,0 +1,105 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
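+
+// Usage sketch (assumed caller flow; |executable_data| holds the SPIR-V
+// flatbuffer; error handling elided):
+//
+//   iree_hal_executable_cache_t* cache = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_vulkan_nop_executable_cache_create(
+//       logical_device, iree_make_cstring_view("default"), &cache));
+//   iree_hal_executable_t* executable = NULL;
+//   IREE_RETURN_IF_ERROR(iree_hal_executable_cache_prepare_executable(
+//       cache, executable_layout,
+//       IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION, executable_data,
+//       &executable));
+//   iree_hal_executable_cache_release(cache);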
+ +#include "iree/hal/vulkan/nop_executable_cache.h" + +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/native_executable.h" + +using namespace iree::hal::vulkan; + +static const iree_hal_executable_format_t kExecutableFormatSpirV = + iree_hal_make_executable_format("SPVE"); + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; +} iree_hal_vulkan_nop_executable_cache_t; + +extern const iree_hal_executable_cache_vtable_t + iree_hal_vulkan_nop_executable_cache_vtable; + +static iree_hal_vulkan_nop_executable_cache_t* +iree_hal_vulkan_nop_executable_cache_cast( + iree_hal_executable_cache_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_nop_executable_cache_vtable); + return (iree_hal_vulkan_nop_executable_cache_t*)base_value; +} + +iree_status_t iree_hal_vulkan_nop_executable_cache_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) { + IREE_ASSERT_ARGUMENT(out_executable_cache); + *out_executable_cache = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_nop_executable_cache_t* executable_cache = NULL; + iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*executable_cache), + (void**)&executable_cache); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_nop_executable_cache_vtable, + &executable_cache->resource); + executable_cache->logical_device = logical_device; + + *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_nop_executable_cache_destroy( + iree_hal_executable_cache_t* base_executable_cache) { + iree_hal_vulkan_nop_executable_cache_t* executable_cache = + iree_hal_vulkan_nop_executable_cache_cast(base_executable_cache); + iree_allocator_t host_allocator = + executable_cache->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, executable_cache); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_vulkan_nop_executable_cache_can_prepare_format( + iree_hal_executable_cache_t* base_executable_cache, + iree_hal_executable_format_t format) { + return format == kExecutableFormatSpirV; +} + +static iree_status_t iree_hal_vulkan_nop_executable_cache_prepare_executable( + iree_hal_executable_cache_t* base_executable_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + iree_hal_vulkan_nop_executable_cache_t* executable_cache = + iree_hal_vulkan_nop_executable_cache_cast(base_executable_cache); + return iree_hal_vulkan_native_executable_create( + executable_cache->logical_device, + /*pipeline_cache=*/VK_NULL_HANDLE, executable_layout, caching_mode, + executable_data, out_executable); +} + +const iree_hal_executable_cache_vtable_t + iree_hal_vulkan_nop_executable_cache_vtable = { + /*.destroy=*/iree_hal_vulkan_nop_executable_cache_destroy, + /*.can_prepare_format=*/ + iree_hal_vulkan_nop_executable_cache_can_prepare_format, + /*.prepare_executable=*/ + iree_hal_vulkan_nop_executable_cache_prepare_executable, +}; diff --git a/iree/hal/vulkan/nop_executable_cache.h b/iree/hal/vulkan/nop_executable_cache.h new file mode 100644 index 0000000000000..b6ed2b6303c0d --- /dev/null +++ b/iree/hal/vulkan/nop_executable_cache.h @@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// 
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_VULKAN_NOP_EXECUTABLE_CACHE_H_
+#define IREE_HAL_VULKAN_NOP_EXECUTABLE_CACHE_H_
+
+#include "iree/hal/api.h"
+#include "iree/hal/vulkan/handle_util.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// Creates a no-op executable cache that does not cache at all.
+// This is useful to isolate pipeline caching behavior and verify compilation
+// behavior.
+iree_status_t iree_hal_vulkan_nop_executable_cache_create(
+    iree::hal::vulkan::VkDeviceHandle* logical_device,
+    iree_string_view_t identifier,
+    iree_hal_executable_cache_t** out_executable_cache);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_VULKAN_NOP_EXECUTABLE_CACHE_H_
diff --git a/iree/hal/vulkan/pipeline_cache.cc b/iree/hal/vulkan/pipeline_cache.cc
deleted file mode 100644
index 5404cf9410315..0000000000000
--- a/iree/hal/vulkan/pipeline_cache.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/pipeline_cache.h"
-
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/executable_format.h"
-#include "iree/hal/vulkan/status_util.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-PipelineCache::PipelineCache(ref_ptr<VkDeviceHandle> logical_device)
-    : logical_device_(std::move(logical_device)) {}
-
-PipelineCache::~PipelineCache() = default;
-
-bool PipelineCache::CanPrepareFormat(ExecutableFormat format) const {
-  return format == kExecutableFormatSpirV;
-}
-
-StatusOr<ref_ptr<Executable>> PipelineCache::PrepareExecutable(
-    ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode,
-    const ExecutableSpec& spec) {
-  IREE_TRACE_SCOPE0("PipelineCache::PrepareExecutable");
-
-  // Create the executable (which may itself own many pipelines).
-  IREE_ASSIGN_OR_RETURN(
-      auto executable,
-      PipelineExecutable::Create(
-          add_ref(logical_device_),
-          /*pipeline_cache=*/VK_NULL_HANDLE,
-          static_cast<PipelineExecutableLayout*>(executable_layout), mode,
-          spec));
-  return executable;
-}
-
-}  // namespace vulkan
-}  // namespace hal
-}  // namespace iree
diff --git a/iree/hal/vulkan/pipeline_cache.h b/iree/hal/vulkan/pipeline_cache.h
deleted file mode 100644
index 60d2675e08d55..0000000000000
--- a/iree/hal/vulkan/pipeline_cache.h
+++ /dev/null
@@ -1,55 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef IREE_HAL_VULKAN_PIPELINE_CACHE_H_
-#define IREE_HAL_VULKAN_PIPELINE_CACHE_H_
-
-// clang-format off: Must be included before all other headers:
-#include "iree/hal/vulkan/vulkan_headers.h"
-// clang-format on
-
-#include "absl/container/inlined_vector.h"
-#include "iree/hal/executable.h"
-#include "iree/hal/executable_cache.h"
-#include "iree/hal/vulkan/handle_util.h"
-#include "iree/hal/vulkan/pipeline_executable.h"
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-class PipelineCache final : public ExecutableCache {
- public:
-  explicit PipelineCache(ref_ptr<VkDeviceHandle> logical_device);
-  ~PipelineCache() override;
-
-  const ref_ptr<DynamicSymbols>& syms() const {
-    return logical_device_->syms();
-  }
-
-  bool CanPrepareFormat(ExecutableFormat format) const override;
-
-  StatusOr<ref_ptr<Executable>> PrepareExecutable(
-      ExecutableLayout* executable_layout, ExecutableCachingModeBitfield mode,
-      const ExecutableSpec& spec) override;
-
- private:
-  ref_ptr<VkDeviceHandle> logical_device_;
-};
-
-}  // namespace vulkan
-}  // namespace hal
-}  // namespace iree
-
-#endif  // IREE_HAL_VULKAN_PIPELINE_CACHE_H_
diff --git a/iree/hal/vulkan/pipeline_executable.cc b/iree/hal/vulkan/pipeline_executable.cc
deleted file mode 100644
index b954f24a93682..0000000000000
--- a/iree/hal/vulkan/pipeline_executable.cc
+++ /dev/null
@@ -1,234 +0,0 @@
-// Copyright 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "iree/hal/vulkan/pipeline_executable.h"
-
-#include "absl/container/inlined_vector.h"
-#include "iree/base/memory.h"
-#include "iree/base/status.h"
-#include "iree/base/tracing.h"
-#include "iree/hal/vulkan/status_util.h"
-
-// flatcc schemas:
-#include "iree/base/flatcc.h"
-#include "iree/schemas/spirv_executable_def_reader.h"
-#include "iree/schemas/spirv_executable_def_verifier.h"
-
-// NOTE: starting to port this to C.
-
-// Verifies the structure of the flatbuffer so that we can avoid doing so during
-// runtime. There are still some conditions we must be aware of (such as omitted
-// names on functions with internal linkage), however we shouldn't need to
-// bounds check anything within the flatbuffer after this succeeds.
-static iree_status_t iree_hal_spirv_executable_flatbuffer_verify(
-    iree_const_byte_span_t flatbuffer_data) {
-  if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) {
-    return iree_make_status(
-        IREE_STATUS_INVALID_ARGUMENT,
-        "flatbuffer data is not present or less than 16 bytes (%zu total)",
-        flatbuffer_data.data_length);
-  }
-
-  // Run flatcc generated verification. This ensures all pointers are in-bounds
-  // and that we can safely walk the file, but not that the actual contents of
-  // the flatbuffer meet our expectations.
-  int verify_ret = iree_SpirVExecutableDef_verify_as_root(
-      flatbuffer_data.data, flatbuffer_data.data_length);
-  if (verify_ret != flatcc_verify_ok) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "flatbuffer verification failed: %s",
-                            flatcc_verify_error_string(verify_ret));
-  }
-
-  iree_SpirVExecutableDef_table_t executable_def =
-      iree_SpirVExecutableDef_as_root(flatbuffer_data.data);
-
-  flatbuffers_string_vec_t entry_points_vec =
-      iree_SpirVExecutableDef_entry_points_get(executable_def);
-  size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);
-  for (size_t i = 0; i < entry_point_count; ++i) {
-    if (!flatbuffers_string_len(
-            flatbuffers_string_vec_at(entry_points_vec, i))) {
-      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                              "executable entry point %zu has no name", i);
-    }
-  }
-
-  if (flatbuffers_uint32_vec_len(
-          iree_SpirVExecutableDef_code_get(executable_def)) < 0) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "executable SPIR-V code is missing/empty");
-  }
-
-  // TODO(benvanik): pull PopulateSpecializationInfo from history and update.
-  // For now the compiler isn't generating them, and we don't use them.
-  if (iree_SpirVExecutableDef_specialization_info_is_present(executable_def)) {
-    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
-                            "executable uses SPIR-V specialization constants; "
-                            "they need to be revived");
-  }
-
-  return iree_ok_status();
-}
-
-namespace iree {
-namespace hal {
-namespace vulkan {
-
-namespace {
-
-class VkShaderModuleHandle : public RefObject<VkShaderModuleHandle> {
- public:
-  explicit VkShaderModuleHandle(const ref_ptr<VkDeviceHandle>& logical_device)
-      : logical_device_(add_ref(logical_device)) {}
-  ~VkShaderModuleHandle() { reset(); }
-
-  VkShaderModuleHandle(const VkShaderModuleHandle&) = delete;
-  VkShaderModuleHandle& operator=(const VkShaderModuleHandle&) = delete;
-  VkShaderModuleHandle(VkShaderModuleHandle&& other) noexcept
-      : logical_device_(std::move(other.logical_device_)),
-        value_(absl::exchange(other.value_,
-                              static_cast<VkShaderModule>(VK_NULL_HANDLE))) {}
-  VkShaderModuleHandle& operator=(VkShaderModuleHandle&& other) {
-    std::swap(logical_device_, other.logical_device_);
-    std::swap(value_, other.value_);
-    return *this;
-  }
-
-  void reset() {
-    if (value_ == VK_NULL_HANDLE) return;
-    logical_device_->syms()->vkDestroyShaderModule(
-        *logical_device_, value_, logical_device_->allocator());
-    value_ = VK_NULL_HANDLE;
-  }
-
-  VkShaderModule value() const noexcept { return value_; }
-  VkShaderModule* mutable_value() noexcept { return &value_; }
-  operator VkShaderModule() const noexcept { return value_; }
-
- private:
-  ref_ptr<VkDeviceHandle> logical_device_;
-  VkShaderModule value_ = VK_NULL_HANDLE;
-};
-
-}  // namespace
-
-// static
-StatusOr<ref_ptr<PipelineExecutable>> PipelineExecutable::Create(
-    ref_ptr<VkDeviceHandle> logical_device, VkPipelineCache pipeline_cache,
-    PipelineExecutableLayout* executable_layout,
-    ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) {
-  IREE_TRACE_SCOPE0("PipelineExecutable::Create");
-  const auto& syms = logical_device->syms();
-
-  // Verify and fetch the executable flatbuffer wrapper.
- iree_const_byte_span_t executable_data = iree_make_const_byte_span( - spec.executable_data.data(), spec.executable_data.size()); - IREE_RETURN_IF_ERROR( - iree_hal_spirv_executable_flatbuffer_verify(executable_data)); - iree_SpirVExecutableDef_table_t executable_def = - iree_SpirVExecutableDef_as_root(executable_data.data); - - // Create the shader module. - VkShaderModuleCreateInfo shader_module_create_info; - shader_module_create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; - shader_module_create_info.pNext = nullptr; - shader_module_create_info.flags = 0; - flatbuffers_uint32_vec_t code_vec = - iree_SpirVExecutableDef_code_get(executable_def); - shader_module_create_info.codeSize = - flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t); - shader_module_create_info.pCode = code_vec; - VkShaderModuleHandle shader_module(add_ref(logical_device)); - VK_RETURN_IF_ERROR(syms->vkCreateShaderModule( - *logical_device, &shader_module_create_info, logical_device->allocator(), - shader_module.mutable_value())); - - // Create pipelines for each entry point. - flatbuffers_string_vec_t entry_points_vec = - iree_SpirVExecutableDef_entry_points_get(executable_def); - absl::InlinedVector pipeline_create_infos; - pipeline_create_infos.resize(flatbuffers_string_vec_len(entry_points_vec)); - for (size_t entry_ordinal = 0; - entry_ordinal < flatbuffers_string_vec_len(entry_points_vec); - ++entry_ordinal) { - auto& pipeline_create_info = pipeline_create_infos[entry_ordinal]; - pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; - pipeline_create_info.pNext = nullptr; - pipeline_create_info.flags = 0; - if (!AllBitsSet(mode, ExecutableCachingMode::kAllowOptimization)) { - pipeline_create_info.flags |= VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT; - } - if (entry_ordinal == 0) { - pipeline_create_info.flags |= VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT; - } else { - pipeline_create_info.flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT; - } - pipeline_create_info.layout = executable_layout->handle(); - pipeline_create_info.basePipelineHandle = VK_NULL_HANDLE; - pipeline_create_info.basePipelineIndex = 0; - auto& stage_create_info = pipeline_create_info.stage; - stage_create_info.sType = - VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; - stage_create_info.pNext = nullptr; - stage_create_info.flags = 0; - stage_create_info.stage = VK_SHADER_STAGE_COMPUTE_BIT; - stage_create_info.module = shader_module; - stage_create_info.pName = - flatbuffers_string_vec_at(entry_points_vec, entry_ordinal); - stage_create_info.pSpecializationInfo = NULL; - } - absl::InlinedVector pipelines; - pipelines.resize(flatbuffers_string_vec_len(entry_points_vec)); - - // Some ICDs appear to leak in here, out of our control. - // Warning: leak checks remain disabled if an error is returned. 
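// (Editor's note, not part of the patch: the gap this warning describes, namely VK_RETURN_IF_ERROR returning between IREE_DISABLE_LEAK_CHECKS() and IREE_ENABLE_LEAK_CHECKS(), could be closed with the MakeCleanup helper from iree/base/memory.h, e.g. auto leak_guard = MakeCleanup([] { IREE_ENABLE_LEAK_CHECKS(); });. This is a sketch under the assumption that the helper is usable in this scope.)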
- IREE_DISABLE_LEAK_CHECKS(); - VK_RETURN_IF_ERROR(syms->vkCreateComputePipelines( - *logical_device, pipeline_cache, - static_cast(pipeline_create_infos.size()), - pipeline_create_infos.data(), logical_device->allocator(), - pipelines.data())); - IREE_ENABLE_LEAK_CHECKS(); - - return make_ref(std::move(logical_device), - std::move(pipelines)); -} - -PipelineExecutable::PipelineExecutable( - ref_ptr logical_device, - absl::InlinedVector pipelines) - : logical_device_(std::move(logical_device)), - pipelines_(std::move(pipelines)) {} - -PipelineExecutable::~PipelineExecutable() { - IREE_TRACE_SCOPE0("PipelineExecutable::dtor"); - for (auto pipeline : pipelines_) { - syms()->vkDestroyPipeline(*logical_device_, pipeline, - logical_device_->allocator()); - } - pipelines_.clear(); -} - -StatusOr PipelineExecutable::GetPipelineForEntryPoint( - int entry_ordinal) const { - if (entry_ordinal < 0 || entry_ordinal >= pipelines_.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal"; - } - return pipelines_[entry_ordinal]; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree diff --git a/iree/hal/vulkan/pipeline_executable.h b/iree/hal/vulkan/pipeline_executable.h deleted file mode 100644 index 67159375c80f8..0000000000000 --- a/iree/hal/vulkan/pipeline_executable.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_ -#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_ - -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include - -#include "absl/container/inlined_vector.h" -#include "iree/base/status.h" -#include "iree/hal/executable.h" -#include "iree/hal/executable_cache.h" -#include "iree/hal/executable_layout.h" -#include "iree/hal/executable_spec.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/native_descriptor_set.h" -#include "iree/hal/vulkan/pipeline_executable_layout.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class PipelineExecutable final : public Executable { - public: - static StatusOr> Create( - ref_ptr logical_device, VkPipelineCache pipeline_cache, - PipelineExecutableLayout* executable_layout, - ExecutableCachingModeBitfield mode, const ExecutableSpec& spec); - - PipelineExecutable(ref_ptr logical_device, - absl::InlinedVector pipelines); - ~PipelineExecutable() override; - - const ref_ptr& syms() const { - return logical_device_->syms(); - } - - bool supports_debugging() const override { return false; } - - StatusOr GetPipelineForEntryPoint(int entry_ordinal) const; - - private: - ref_ptr logical_device_; - - // One pipeline per entry point. 
- absl::InlinedVector pipelines_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_ diff --git a/iree/hal/vulkan/pipeline_executable_layout.cc b/iree/hal/vulkan/pipeline_executable_layout.cc deleted file mode 100644 index 3628b64609905..0000000000000 --- a/iree/hal/vulkan/pipeline_executable_layout.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/pipeline_executable_layout.h" - -namespace iree { -namespace hal { -namespace vulkan { - -NativeDescriptorSetLayout::NativeDescriptorSetLayout( - ref_ptr logical_device, VkDescriptorSetLayout handle) - : logical_device_(std::move(logical_device)), handle_(handle) {} - -NativeDescriptorSetLayout::~NativeDescriptorSetLayout() { - logical_device_->syms()->vkDestroyDescriptorSetLayout( - *logical_device_, handle_, logical_device_->allocator()); -} - -PipelineExecutableLayout::PipelineExecutableLayout( - ref_ptr logical_device, VkPipelineLayout handle, - absl::InlinedVector, 2> set_layouts) - : logical_device_(std::move(logical_device)), - handle_(handle), - set_layouts_(std::move(set_layouts)) {} - -PipelineExecutableLayout::~PipelineExecutableLayout() { - logical_device_->syms()->vkDestroyPipelineLayout( - *logical_device_, handle_, logical_device_->allocator()); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree diff --git a/iree/hal/vulkan/pipeline_executable_layout.h b/iree/hal/vulkan/pipeline_executable_layout.h deleted file mode 100644 index 89203dbcf7178..0000000000000 --- a/iree/hal/vulkan/pipeline_executable_layout.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_ -#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_ - -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/hal/descriptor_set_layout.h" -#include "iree/hal/executable_layout.h" -#include "iree/hal/vulkan/handle_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// A DescriptorSetLayout implemented with the native VkDescriptorSetLayout type. 
-class NativeDescriptorSetLayout final : public DescriptorSetLayout { - public: - NativeDescriptorSetLayout(ref_ptr logical_device, - VkDescriptorSetLayout handle); - ~NativeDescriptorSetLayout() override; - - VkDescriptorSetLayout handle() const { return handle_; } - - private: - ref_ptr logical_device_; - VkDescriptorSetLayout handle_; -}; - -class PipelineExecutableLayout final : public ExecutableLayout { - public: - PipelineExecutableLayout( - ref_ptr logical_device, VkPipelineLayout handle, - absl::InlinedVector, 2> set_layouts); - ~PipelineExecutableLayout() override; - - VkPipelineLayout handle() const { return handle_; } - - absl::Span> set_layouts() const { - return set_layouts_; - } - - private: - ref_ptr logical_device_; - VkPipelineLayout handle_; - absl::InlinedVector, 2> set_layouts_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_ diff --git a/iree/hal/vulkan/registration/BUILD b/iree/hal/vulkan/registration/BUILD index e94df4970827d..56f30260b2544 100644 --- a/iree/hal/vulkan/registration/BUILD +++ b/iree/hal/vulkan/registration/BUILD @@ -35,12 +35,12 @@ cc_library( "IREE_HAL_HAVE_VULKAN_DRIVER_MODULE=1", ], deps = [ + "//iree/base:core_headers", "//iree/base:flags", "//iree/base:status", "//iree/base:tracing", "//iree/hal:api", "//iree/hal/vulkan", - "//iree/hal/vulkan:utils", "@com_google_absl//absl/flags:flag", ], ) diff --git a/iree/hal/vulkan/registration/CMakeLists.txt b/iree/hal/vulkan/registration/CMakeLists.txt index 858710f80a1a5..7ae1dbea292e3 100644 --- a/iree/hal/vulkan/registration/CMakeLists.txt +++ b/iree/hal/vulkan/registration/CMakeLists.txt @@ -25,12 +25,12 @@ iree_cc_library( "driver_module.cc" DEPS absl::flags + iree::base::core_headers iree::base::flags iree::base::status iree::base::tracing iree::hal::api iree::hal::vulkan - iree::hal::vulkan::utils DEFINES "IREE_HAL_HAVE_VULKAN_DRIVER_MODULE=1" PUBLIC diff --git a/iree/hal/vulkan/registration/driver_module.cc b/iree/hal/vulkan/registration/driver_module.cc index 34565e617dd09..cf396f22f8922 100644 --- a/iree/hal/vulkan/registration/driver_module.cc +++ b/iree/hal/vulkan/registration/driver_module.cc @@ -14,124 +14,77 @@ #include "iree/hal/vulkan/registration/driver_module.h" +#include + #include "absl/flags/flag.h" #include "iree/base/flags.h" #include "iree/base/status.h" +#include "iree/base/target_platform.h" #include "iree/base/tracing.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/vulkan_driver.h" +#include "iree/hal/vulkan/api.h" + +#define IREE_HAL_VULKAN_1_X_DRIVER_ID 0x564C4B31u // VLK1 ABSL_FLAG(bool, vulkan_validation_layers, true, "Enables standard Vulkan validation layers."); ABSL_FLAG(bool, vulkan_debug_utils, true, "Enables VK_EXT_debug_utils, records markers, and logs errors."); -ABSL_FLAG(bool, vulkan_debug_report, false, - "Enables VK_EXT_debug_report and logs errors."); -ABSL_FLAG(bool, vulkan_push_descriptors, true, - "Enables use of vkCmdPushDescriptorSetKHR, if available."); + ABSL_FLAG(int, vulkan_default_index, 0, "Index of the default Vulkan device."); -ABSL_FLAG(bool, vulkan_renderdoc, false, "Enables RenderDoc API integration."); + ABSL_FLAG(bool, vulkan_force_timeline_semaphore_emulation, false, "Uses timeline semaphore emulation even if native support exists."); -// Vulkan Memory Allocator (VMA) flags -#if VMA_RECORDING_ENABLED -ABSL_FLAG(std::string, vma_recording_file, "", - "File path to write a CSV containing the VMA recording."); -ABSL_FLAG(bool, 
vma_recording_flush_after_call, false, - "Flush the VMA recording file after every call (useful if " - "crashing/not exiting cleanly)."); -#endif // VMA_RECORDING_ENABLED - -namespace iree { -namespace hal { -namespace vulkan { -namespace { - -StatusOr> CreateVulkanDriver() { - IREE_TRACE_SCOPE0("CreateVulkanDriver"); - - // Load the Vulkan library. This will fail if the library cannot be found or - // does not have the expected functions. - IREE_ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader()); +static iree_status_t iree_hal_vulkan_create_driver_with_flags( + iree_string_view_t identifier, iree_allocator_t allocator, + iree_hal_driver_t** out_driver) { + IREE_TRACE_SCOPE(); // Set up driver options from flags. We do this here as we want to enable other // consumers that may not be using modules/command line flags to be able to // set their options however they want. - VulkanDriver::Options options; - - // TODO: validation layers have bugs when using VK_EXT_debug_report, so if the - // user requested both we force them off with a warning. Prefer using - // VK_EXT_debug_utils when available. - if (absl::GetFlag(FLAGS_vulkan_debug_report) && - absl::GetFlag(FLAGS_vulkan_validation_layers)) { - IREE_LOG(WARNING) - << "VK_EXT_debug_report has issues with modern validation " - "layers; disabling validation"; - absl::SetFlag(&FLAGS_vulkan_validation_layers, false); - } - - // REQUIRED: these are required extensions that must be present for IREE to - // work (such as those relied upon by SPIR-V kernels, etc). - options.device_options.extensibility_spec.required_extensions.push_back( - VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME); - // Multiple extensions depend on VK_KHR_get_physical_device_properties2. - // This extension was deprecated in Vulkan 1.1 as its functionality was - // promoted to core, so we list it as optional even though we require it. - options.instance_extensibility.optional_extensions.push_back( - VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); - - // Timeline semaphore support is optional and will be emulated if necessary. - options.device_options.extensibility_spec.optional_extensions.push_back( - VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME); - // Polyfill layer - enable if present (instead of our custom emulation). - options.instance_extensibility.optional_layers.push_back( - "VK_LAYER_KHRONOS_timeline_semaphore"); + iree_hal_vulkan_driver_options_t driver_options; + iree_hal_vulkan_driver_options_initialize(&driver_options); + +// TODO(benvanik): make this a flag - it's useful for testing the same binary +// against multiple versions of Vulkan. +#if defined(IREE_PLATFORM_ANDROID) + // TODO(#4494): let's see when we can always enable timeline semaphores.
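// (Editor's note: timeline semaphores were promoted to core in Vulkan 1.2, so requesting 1.2 below lets capable devices use them natively, while Android loaders commonly top out at Vulkan 1.1 and therefore stay on the emulated-semaphore path this TODO refers to.)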
+ driver_options.api_version = VK_API_VERSION_1_1; +#else + driver_options.api_version = VK_API_VERSION_1_2; +#endif // IREE_PLATFORM_ANDROID if (absl::GetFlag(FLAGS_vulkan_validation_layers)) { - options.instance_extensibility.optional_layers.push_back( - "VK_LAYER_KHRONOS_validation"); - } - - if (absl::GetFlag(FLAGS_vulkan_debug_report)) { - options.instance_extensibility.optional_extensions.push_back( - VK_EXT_DEBUG_REPORT_EXTENSION_NAME); + driver_options.requested_features |= + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS; } if (absl::GetFlag(FLAGS_vulkan_debug_utils)) { - options.instance_extensibility.optional_extensions.push_back( - VK_EXT_DEBUG_UTILS_EXTENSION_NAME); + driver_options.requested_features |= + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS; } - if (absl::GetFlag(FLAGS_vulkan_push_descriptors)) { - options.device_options.extensibility_spec.optional_extensions.push_back( - VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME); + driver_options.default_device_index = + absl::GetFlag(FLAGS_vulkan_default_index); + + if (absl::GetFlag(FLAGS_vulkan_force_timeline_semaphore_emulation)) { + driver_options.device_options.flags |= + IREE_HAL_VULKAN_DEVICE_FORCE_TIMELINE_SEMAPHORE_EMULATION; } - options.default_device_index = absl::GetFlag(FLAGS_vulkan_default_index); - options.enable_renderdoc = absl::GetFlag(FLAGS_vulkan_renderdoc); - options.device_options.force_timeline_semaphore_emulation = - absl::GetFlag(FLAGS_vulkan_force_timeline_semaphore_emulation); + // Load the Vulkan library. This will fail if the library cannot be found or + // does not have the expected functions. + iree_hal_vulkan_syms_t* syms = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_syms_create_from_system_loader(allocator, &syms)); -#if VMA_RECORDING_ENABLED - options.device_options.vma_options.recording_file = - absl::GetFlag(FLAGS_vma_recording_file); - options.device_options.vma_options.recording_flush_after_call = - absl::GetFlag(FLAGS_vma_recording_flush_after_call); -#endif // VMA_RECORDING_ENABLED + iree_status_t status = iree_hal_vulkan_driver_create( + identifier, &driver_options, syms, allocator, out_driver); - // Create the driver and VkInstance. 
- return VulkanDriver::Create(options, std::move(syms)); + iree_hal_vulkan_syms_release(syms); + return status; } -} // namespace -} // namespace vulkan -} // namespace hal -} // namespace iree - -#include - -#define IREE_HAL_VULKAN_1_X_DRIVER_ID 0x564C4B31u // VLK1 - static iree_status_t iree_hal_vulkan_driver_factory_enumerate( void* self, const iree_hal_driver_info_t** out_driver_infos, iree_host_size_t* out_driver_info_count) { @@ -155,9 +108,13 @@ static iree_status_t iree_hal_vulkan_driver_factory_try_create( " is provided by this factory", driver_id); } - IREE_ASSIGN_OR_RETURN(auto driver, iree::hal::vulkan::CreateVulkanDriver()); - *out_driver = reinterpret_cast(driver.release()); - return iree_ok_status(); + + // When we expose more than one driver (different vulkan versions, etc) we + // can name them here: + iree_string_view_t identifier = iree_make_cstring_view("vulkan"); + + return iree_hal_vulkan_create_driver_with_flags(identifier, allocator, + out_driver); } IREE_API_EXPORT iree_status_t IREE_API_CALL diff --git a/iree/hal/vulkan/renderdoc_capture_manager.cc b/iree/hal/vulkan/renderdoc_capture_manager.cc deleted file mode 100644 index b8b06ce5e3aae..0000000000000 --- a/iree/hal/vulkan/renderdoc_capture_manager.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "iree/hal/vulkan/renderdoc_capture_manager.h" - -#include "absl/types/span.h" -#include "iree/base/logging.h" -#include "iree/base/target_platform.h" -#include "iree/base/tracing.h" - -#if !defined(IREE_PLATFORM_WINDOWS) -#include -#endif // IREE_PLATFORM_WINDOWS - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -static const char* kRenderDocSearchNames[] = { -#if defined(IREE_PLATFORM_WINDOWS) - "renderdoc.dll", - "C:/Program Files/RenderDoc/renderdoc.dll", -#else - "librenderdoc.so", -#endif // IREE_PLATFORM_WINDOWS -}; - -} // namespace - -RenderDocCaptureManager::RenderDocCaptureManager() {} - -RenderDocCaptureManager::~RenderDocCaptureManager() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::dtor"); - Disconnect(); -} - -Status RenderDocCaptureManager::Connect() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::Connect"); - - if (renderdoc_library_ != nullptr) { - return OkStatus(); - } - - IREE_ASSIGN_OR_RETURN( - renderdoc_library_, - DynamicLibrary::Load(absl::MakeSpan(kRenderDocSearchNames))); - - auto renderdoc_get_api_fn = - renderdoc_library_->GetSymbol("RENDERDOC_GetAPI"); - int ret = renderdoc_get_api_fn(eRENDERDOC_API_Version_1_4_0, - (void**)&renderdoc_api_); - if (ret != 1) { - renderdoc_api_ = nullptr; - return InternalErrorBuilder(IREE_LOC) - << "Failed to get RenderDoc API object"; - } - - IREE_LOG(INFO) << "Connected to RenderDoc's API; writing captures to " - << renderdoc_api_->GetCaptureFilePathTemplate(); - - return OkStatus(); -} - -void RenderDocCaptureManager::Disconnect() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::Disconnect"); - - if (renderdoc_library_ == nullptr) { - return; - } - - if (is_capturing()) { - StopCapture(); - } - - renderdoc_api_ = nullptr; - renderdoc_library_.reset(); -} - -void RenderDocCaptureManager::StartCapture() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::StartCapture"); - - IREE_CHECK(is_connected()) << "Can't start capture when not connected"; - IREE_CHECK(!is_capturing()) << "Capture is already started"; - - IREE_LOG(INFO) << "Starting RenderDoc capture"; - renderdoc_api_->StartFrameCapture(NULL, NULL); -} - -void RenderDocCaptureManager::StopCapture() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::StopCapture"); - - IREE_CHECK(is_capturing()) << "Can't stop capture when not capturing"; - - IREE_LOG(INFO) << "Ending RenderDoc capture"; - renderdoc_api_->EndFrameCapture(NULL, NULL); -} - -bool RenderDocCaptureManager::is_capturing() const { - if (!is_connected()) { - return false; - } - - return renderdoc_api_->IsFrameCapturing() == 1; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree diff --git a/iree/hal/vulkan/renderdoc_capture_manager.h b/iree/hal/vulkan/renderdoc_capture_manager.h deleted file mode 100644 index 26eb7b6bf62a1..0000000000000 --- a/iree/hal/vulkan/renderdoc_capture_manager.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef IREE_HAL_VULKAN_RENDERDOC_CAPTURE_MANAGER_H_ -#define IREE_HAL_VULKAN_RENDERDOC_CAPTURE_MANAGER_H_ - -#include "iree/base/dynamic_library.h" -#include "iree/base/status.h" -#include "iree/hal/debug_capture_manager.h" -#include "third_party/renderdoc_api/app/renderdoc_app.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// Capture manager using RenderDoc to record Vulkan commands. -// See https://renderdoc.org/ and https://github.com/baldurk/renderdoc. -class RenderDocCaptureManager final : public DebugCaptureManager { - public: - RenderDocCaptureManager(); - ~RenderDocCaptureManager() override; - - // Note: Connect() must be called *before* creating a VkInstance. - Status Connect() override; - - void Disconnect() override; - - bool is_connected() const override { return renderdoc_api_ != nullptr; } - - // Note: StartCapture() must be called *after* creating a VkDevice. - void StartCapture() override; - - void StopCapture() override; - - bool is_capturing() const override; - - private: - std::unique_ptr renderdoc_library_; - RENDERDOC_API_1_4_0* renderdoc_api_ = nullptr; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_RENDERDOC_CAPTURE_MANAGER_H_ diff --git a/iree/hal/vulkan/serializing_command_queue.cc b/iree/hal/vulkan/serializing_command_queue.cc index f4c3ba8a2ad6b..b524f7bda4523 100644 --- a/iree/hal/vulkan/serializing_command_queue.cc +++ b/iree/hal/vulkan/serializing_command_queue.cc @@ -20,11 +20,9 @@ #include "iree/base/api.h" #include "iree/base/memory.h" #include "iree/base/tracing.h" -#include "iree/hal/command_buffer.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/semaphore.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/direct_command_buffer.h" -#include "iree/hal/vulkan/emulated_timeline_semaphore.h" +#include "iree/hal/vulkan/emulated_semaphore.h" #include "iree/hal/vulkan/status_util.h" namespace iree { @@ -39,199 +37,195 @@ namespace { // batch is ready to be submitted to GPU. // |wait_semaphores| and |signal_semaphores| will be filled with the binary // `VkSemaphores` on success. -StatusOr TryToPrepareSemaphores( +iree_status_t TryToPrepareSemaphores( const absl::InlinedVector& batch_wait_semaphores, const absl::InlinedVector& batch_signal_semaphores, const ref_ptr& batch_fence, absl::InlinedVector* wait_semaphores, - absl::InlinedVector* signal_semaphores) { + absl::InlinedVector* signal_semaphores, + bool* out_ready_to_submit) { IREE_TRACE_SCOPE0("TryToPrepareSemaphores"); - IREE_DVLOG(3) << "TryToPrepareSemaphores"; + *out_ready_to_submit = false; wait_semaphores->clear(); for (const auto& timeline_semaphore : batch_wait_semaphores) { - IREE_DVLOG(3) << "Preparing binary VkSemaphore for timeline semaphore " - << timeline_semaphore.semaphore << ".."; // Query first to progress this timeline semaphore to the furthest. - IREE_ASSIGN_OR_RETURN(auto signaled_value, - timeline_semaphore.semaphore->Query()); + uint64_t signaled_value = 0; + IREE_RETURN_IF_ERROR( + iree_hal_semaphore_query(timeline_semaphore.first, &signaled_value)); // If it's already signaled to a value greater than we require here, // we can just ignore this semaphore now. - if (signaled_value >= timeline_semaphore.value) { - IREE_DVLOG(3) << "..already signaled past; ignoring"; + if (signaled_value >= timeline_semaphore.second) { continue; } - // SerializingCommandQueue only works with EmulatedTimelineSemaphore. 
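// (Editor's note: "emulated" here means each (timeline semaphore, payload value) pair is backed by a plain binary VkSemaphore acquired for that time point, as the acquire_wait_handle/acquire_signal_handle calls below show; the queue then serializes submissions so the binary handles signal in order.)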
- auto* emulated_semaphore = - static_cast(timeline_semaphore.semaphore); - // Otherwise try to get a binary semaphore for this time point so that // we can wait on it. - VkSemaphore binary_semaphore = emulated_semaphore->GetWaitSemaphore( - timeline_semaphore.value, batch_fence); - - if (binary_semaphore == VK_NULL_HANDLE) { + // TODO(antiagainst): if this fails we need to cancel. + VkSemaphore wait_semaphore = VK_NULL_HANDLE; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_emulated_semaphore_acquire_wait_handle( + timeline_semaphore.first, timeline_semaphore.second, batch_fence, + &wait_semaphore)); + wait_semaphores->push_back(wait_semaphore); + + if (wait_semaphore == VK_NULL_HANDLE) { // We cannot wait on this time point yet: there are no previous semaphores - // submitted to the GPU that can signal a value greater than what's + // submitted to the GPU that can signal a value greater than what's // desired here. // Cancel the wait so others may make progress. - for (VkSemaphore semaphore : *wait_semaphores) { + // TODO(antiagainst): if any of these fail we need to cancel. + for (iree_host_size_t i = 0; i < batch_wait_semaphores.size(); ++i) { + if (!wait_semaphores->at(i)) break; IREE_RETURN_IF_ERROR( - emulated_semaphore->CancelWaitSemaphore(semaphore)); + iree_hal_vulkan_emulated_semaphore_cancel_wait_handle( + batch_wait_semaphores[i].first, wait_semaphores->at(i))); } // This batch cannot be submitted to GPU yet. - return false; + return iree_ok_status(); } - IREE_DVLOG(3) << "..acquired binary VkSemaphore " << binary_semaphore; - - wait_semaphores->push_back(binary_semaphore); } // We've collected all necessary binary semaphores for each timeline we need // to wait on. Now prepare binary semaphores for signaling. signal_semaphores->clear(); for (const auto& timeline_semaphore : batch_signal_semaphores) { - IREE_DVLOG(3) << "Preparing binary VkSemaphore for timeline semaphore " - << timeline_semaphore.semaphore << ".."; // SerializingCommandQueue only works with EmulatedTimelineSemaphore. - auto* emulated_semaphore = - static_cast(timeline_semaphore.semaphore); - - IREE_ASSIGN_OR_RETURN(auto binary_semaphore, - emulated_semaphore->GetSignalSemaphore( - timeline_semaphore.value, batch_fence)); - signal_semaphores->push_back(binary_semaphore); - IREE_DVLOG(3) << "..acquired binary VkSemaphore " << binary_semaphore; + VkSemaphore signal_semaphore = VK_NULL_HANDLE; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_emulated_semaphore_acquire_signal_handle( + timeline_semaphore.first, timeline_semaphore.second, batch_fence, + &signal_semaphore)); + signal_semaphores->push_back(signal_semaphore); } // Good to submit! - IREE_DVLOG(3) << "Succeeded in preparing binary VkSemaphores for submission"; - return true; + *out_ready_to_submit = true; + return iree_ok_status(); } // Prepares `VkSubmitInfo` to submit the given list of |command_buffers| that // wait on |wait_semaphores| and signal |signal_semaphores|. Necessary // structures are allocated from |arena| and the result `VkSubmitInfo` is // written to |submit_info|. -void PrepareSubmitInfo( - const absl::InlinedVector& wait_semaphores, - absl::Span command_buffers, - const absl::InlinedVector& signal_semaphores, - VkSubmitInfo* submit_info, Arena* arena) { - IREE_TRACE_SCOPE0("PrepareSubmitInfo"); - +void PrepareSubmitInfo(absl::Span wait_semaphore_handles, + absl::Span command_buffer_handles, + absl::Span signal_semaphore_handles, + VkSubmitInfo* submit_info, Arena* arena) { // TODO(benvanik): see if we can go to finer-grained stages.
// For example, if this was just queue ownership transfers then we can use // the pseudo-stage of VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT. - VkPipelineStageFlags dst_stage_mask = - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; - - auto wait_semaphore_handles = - arena->AllocateSpan(wait_semaphores.size()); auto wait_dst_stage_masks = - arena->AllocateSpan(wait_semaphores.size()); - for (size_t i = 0, e = wait_semaphores.size(); i < e; ++i) { - wait_semaphore_handles[i] = wait_semaphores[i]; - wait_dst_stage_masks[i] = dst_stage_mask; + arena->AllocateSpan(wait_semaphore_handles.size()); + for (size_t i = 0, e = wait_semaphore_handles.size(); i < e; ++i) { + wait_dst_stage_masks[i] = + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; } - auto signal_semaphore_handles = - arena->AllocateSpan(signal_semaphores.size()); - for (size_t i = 0, e = signal_semaphores.size(); i < e; ++i) { - signal_semaphore_handles[i] = signal_semaphores[i]; + // NOTE: this code does some very weird things - the handles we take in as + // args are mutated in-place after this function is called so we can't + // reference them here. If we were going to preserve this code post-Vulkan 1.2 + // then we'd really want to rework all of this to properly use the arena from + // the start instead of all this InlinedVector tomfoolery. + auto wait_semaphores = + arena->AllocateSpan(wait_semaphore_handles.size()); + for (size_t i = 0, e = wait_semaphore_handles.size(); i < e; ++i) { + wait_semaphores[i] = wait_semaphore_handles[i]; } - - auto command_buffer_handles = - arena->AllocateSpan(command_buffers.size()); - for (size_t i = 0, e = command_buffers.size(); i < e; ++i) { - const auto& command_buffer = command_buffers[i]; - auto* direct_command_buffer = - static_cast(command_buffer->impl()); - command_buffer_handles[i] = direct_command_buffer->handle(); + auto command_buffers = + arena->AllocateSpan(command_buffer_handles.size()); + for (size_t i = 0, e = command_buffer_handles.size(); i < e; ++i) { + command_buffers[i] = command_buffer_handles[i]; + } + auto signal_semaphores = + arena->AllocateSpan(signal_semaphore_handles.size()); + for (size_t i = 0, e = signal_semaphore_handles.size(); i < e; ++i) { + signal_semaphores[i] = signal_semaphore_handles[i]; } submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; submit_info->pNext = nullptr; submit_info->waitSemaphoreCount = - static_cast(wait_semaphore_handles.size()); - submit_info->pWaitSemaphores = wait_semaphore_handles.data(); + static_cast(wait_semaphores.size()); + submit_info->pWaitSemaphores = wait_semaphores.data(); submit_info->pWaitDstStageMask = wait_dst_stage_masks.data(); submit_info->commandBufferCount = - static_cast(command_buffer_handles.size()); - submit_info->pCommandBuffers = command_buffer_handles.data(); + static_cast(command_buffers.size()); + submit_info->pCommandBuffers = command_buffers.data(); submit_info->signalSemaphoreCount = - static_cast(signal_semaphore_handles.size()); - submit_info->pSignalSemaphores = signal_semaphore_handles.data(); + static_cast(signal_semaphores.size()); + submit_info->pSignalSemaphores = signal_semaphores.data(); } } // namespace SerializingCommandQueue::SerializingCommandQueue( - std::string name, CommandCategoryBitfield supported_categories, - const ref_ptr& logical_device, - const ref_ptr& fence_pool, VkQueue queue) - : CommandQueue(std::move(name), supported_categories), - logical_device_(add_ref(logical_device)), - fence_pool_(add_ref(fence_pool)), - queue_(queue) {} - 
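Editor's aside, not part of the patch: a minimal, self-contained sketch of what PrepareSubmitInfo above ultimately builds and what the queue later hands to Vulkan; the helper name SubmitOneBatch and its parameters are illustrative only.

```c
#include <vulkan/vulkan.h>

// Hypothetical helper mirroring PrepareSubmitInfo + vkQueueSubmit for a single
// deferred submission. Every array passed in must stay alive until
// vkQueueSubmit returns; the real code guarantees that by copying handles into
// spans allocated from an Arena that outlives the call.
static VkResult SubmitOneBatch(VkQueue queue, VkFence fence,
                               uint32_t wait_count,
                               const VkSemaphore* wait_semaphores,
                               const VkPipelineStageFlags* wait_stage_masks,
                               uint32_t command_buffer_count,
                               const VkCommandBuffer* command_buffers,
                               uint32_t signal_count,
                               const VkSemaphore* signal_semaphores) {
  VkSubmitInfo submit_info;
  submit_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submit_info.pNext = NULL;
  submit_info.waitSemaphoreCount = wait_count;
  submit_info.pWaitSemaphores = wait_semaphores;
  submit_info.pWaitDstStageMask = wait_stage_masks;  // one stage mask per wait
  submit_info.commandBufferCount = command_buffer_count;
  submit_info.pCommandBuffers = command_buffers;
  submit_info.signalSemaphoreCount = signal_count;
  submit_info.pSignalSemaphores = signal_semaphores;
  // One fence per call is what lets SerializingCommandQueue poll which batches
  // have retired (see pending_fences_).
  return vkQueueSubmit(queue, /*submitCount=*/1, &submit_info, fence);
}
```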
-SerializingCommandQueue::~SerializingCommandQueue() { - IREE_TRACE_SCOPE0("SerializingCommandQueue::dtor"); - absl::MutexLock lock(&mutex_); - syms()->vkQueueWaitIdle(queue_); -} + VkDeviceHandle* logical_device, std::string name, + iree_hal_command_category_t supported_categories, VkQueue queue, + TimePointFencePool* fence_pool) + : CommandQueue(logical_device, std::move(name), supported_categories, + queue), + fence_pool_(fence_pool) {} + +SerializingCommandQueue::~SerializingCommandQueue() = default; -Status SerializingCommandQueue::Submit( - absl::Span batches) { +iree_status_t SerializingCommandQueue::Submit( + iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) { IREE_TRACE_SCOPE0("SerializingCommandQueue::Submit"); - IREE_DVLOG(2) << "SerializingCommandQueue::Submit"; - absl::MutexLock lock(&mutex_); - for (size_t i = 0; i < batches.size(); ++i) { + IntrusiveList> new_submissions; + for (iree_host_size_t i = 0; i < batch_count; ++i) { + const iree_hal_submission_batch_t* batch = &batches[i]; + // Grab a fence for this submission first. This will be used to check the // progress of emulated timeline semaphores later. - IREE_ASSIGN_OR_RETURN(auto fence, fence_pool_->Acquire()); auto submission = std::make_unique(); - submission->batch = PendingBatch{ - {batches[i].wait_semaphores.begin(), batches[i].wait_semaphores.end()}, - {batches[i].command_buffers.begin(), batches[i].command_buffers.end()}, - {batches[i].signal_semaphores.begin(), - batches[i].signal_semaphores.end()}}; - submission->fence = std::move(fence); - deferred_submissions_.push_back(std::move(submission)); - } - - return ProcessDeferredSubmissions().status(); -} + IREE_ASSIGN_OR_RETURN(submission->fence, fence_pool_->Acquire()); -StatusOr SerializingCommandQueue::ProcessDeferredSubmissions() { - IREE_TRACE_SCOPE0("SerializingCommandQueue::ProcessDeferredSubmissions"); - IREE_DVLOG(2) << "SerializingCommandQueue::ProcessDeferredSubmissions"; + submission->wait_semaphores.resize(batch->wait_semaphores.count); + for (iree_host_size_t j = 0; j < batch->wait_semaphores.count; ++j) { + submission->wait_semaphores[j] = { + batch->wait_semaphores.semaphores[j], + batch->wait_semaphores.payload_values[j]}; + } - // Prepare `VkSubmitInfo`s for all submissions we are able to submit. + submission->command_buffers.resize(batch->command_buffer_count); + for (iree_host_size_t j = 0; j < batch->command_buffer_count; ++j) { + submission->command_buffers[j] = + iree_hal_vulkan_direct_command_buffer_handle( + batch->command_buffers[j]); + } - // Note that we must keep all arrays referenced alive until submission - // completes and since there are a bunch of them we use an arena. - Arena arena(4 * 1024); + submission->signal_semaphores.resize(batch->signal_semaphores.count); + for (iree_host_size_t j = 0; j < batch->signal_semaphores.count; ++j) { + submission->signal_semaphores[j] = { + batch->signal_semaphores.semaphores[j], + batch->signal_semaphores.payload_values[j]}; + } - absl::InlinedVector submit_infos; - absl::InlinedVector submit_fences; + new_submissions.push_back(std::move(submission)); + } - absl::InlinedVector wait_semaphores; - absl::InlinedVector signal_semaphores; + iree_slim_mutex_lock(&queue_mutex_); + deferred_submissions_.merge_from(&new_submissions); + iree_status_t status = ProcessDeferredSubmissions(); + iree_slim_mutex_unlock(&queue_mutex_); + return status; +} - // A list of submissions that still needs to be deferred. 
- IntrusiveList> remaining_submissions; +iree_status_t SerializingCommandQueue::ProcessDeferredSubmissions( + bool* out_work_submitted) { + IREE_TRACE_SCOPE0("SerializingCommandQueue::ProcessDeferredSubmissions"); + if (out_work_submitted) *out_work_submitted = false; // We need to return all remaining submissions back to the queue to avoid // dropping work. + IntrusiveList> remaining_submissions; auto submission_cleanup = MakeCleanup([this, &remaining_submissions]() { // Disable thread-safety-analysis as it doesn't understand this lambda. -// - This entire function is ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) +// - This entire function is ABSL_EXCLUSIVE_LOCKS_REQUIRED(queue_mutex_) // - This Cleanup object is destroyed when it drops out of scope // - The mutex is always held when executing this function #ifdef __clang__ @@ -242,122 +236,125 @@ StatusOr SerializingCommandQueue::ProcessDeferredSubmissions() { deferred_submissions_.push_back( remaining_submissions.take(remaining_submissions.front())); } - - IREE_DVLOG(2) << deferred_submissions_.size() - << " deferred submissions still remaining"; #ifdef __clang__ #pragma clang diagnostic pop #endif }); + Arena arena(4 * 1024); + absl::InlinedVector submit_infos; + absl::InlinedVector submit_fences; while (!deferred_submissions_.empty()) { - IREE_DVLOG(2) << "Looking at deferred submission with timepoint fence " - << deferred_submissions_.front()->fence.get() << ".."; - - wait_semaphores.clear(); - signal_semaphores.clear(); - FencedSubmission* submission = deferred_submissions_.front(); - const PendingBatch& batch = submission->batch; ref_ptr& fence = submission->fence; - IREE_ASSIGN_OR_RETURN( - bool ready_to_submit, - TryToPrepareSemaphores(batch.wait_semaphores, batch.signal_semaphores, - fence, &wait_semaphores, &signal_semaphores)); - + absl::InlinedVector wait_semaphores; + absl::InlinedVector signal_semaphores; + bool ready_to_submit = false; + IREE_RETURN_IF_ERROR(TryToPrepareSemaphores( + submission->wait_semaphores, submission->signal_semaphores, fence, + &wait_semaphores, &signal_semaphores, &ready_to_submit)); if (ready_to_submit) { submit_infos.emplace_back(); - PrepareSubmitInfo(wait_semaphores, batch.command_buffers, + PrepareSubmitInfo(wait_semaphores, submission->command_buffers, signal_semaphores, &submit_infos.back(), &arena); submit_fences.push_back(fence->value()); pending_fences_.emplace_back(std::move(fence)); deferred_submissions_.pop_front(); - IREE_DVLOG(2) << "..ready to submit"; } else { // We need to defer the submission until later. remaining_submissions.push_back(deferred_submissions_.take(submission)); - IREE_DVLOG(2) << "..not ready to submit"; } } - - if (submit_infos.empty()) return false; - - auto infos = arena.AllocateSpan(submit_infos.size()); - for (size_t i = 0, e = submit_infos.size(); i < e; ++i) { - infos[i] = submit_infos[i]; + if (submit_infos.empty()) { + if (out_work_submitted) *out_work_submitted = false; + return iree_ok_status(); } // Note: We might be able to batch the submission but it involves non-trivial // fence handling. We can handle that if really needed. 
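// (Editor's note: vkQueueSubmit signals at most one VkFence per call, covering every VkSubmitInfo passed in that call; because each deferred submission here carries its own TimePointFence, folding them into a single call would collapse their completion tracking, hence the one-submission-per-call loop below.)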
for (size_t i = 0, e = submit_infos.size(); i < e; ++i) { - VK_RETURN_IF_ERROR(syms()->vkQueueSubmit( - queue_, /*submitCount=*/1, &submit_infos[i], submit_fences[i])); + VK_RETURN_IF_ERROR( + syms()->vkQueueSubmit(queue_, /*submitCount=*/1, &submit_infos[i], + submit_fences[i]), + "vkQueueSubmit"); } - IREE_DVLOG(2) << "Released " << submit_infos.size() - << " deferred submissions"; - - return true; + if (out_work_submitted) *out_work_submitted = true; + return iree_ok_status(); } -Status SerializingCommandQueue::WaitIdle(Time deadline_ns) { - absl::MutexLock lock(&mutex_); - IREE_DVLOG(2) << "SerializingCommandQueue::WaitIdle"; +iree_status_t SerializingCommandQueue::WaitIdle(iree_time_t deadline_ns) { + iree_status_t status = iree_ok_status(); - if (deadline_ns == InfiniteFuture()) { + if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { IREE_TRACE_SCOPE0("SerializingCommandQueue::WaitIdle#vkQueueWaitIdle"); // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it // requires fewer calls into the driver). + iree_slim_mutex_lock(&queue_mutex_); + // Complete all pending work on the queue. - VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_)); + status = + VK_RESULT_TO_STATUS(syms()->vkQueueWaitIdle(queue_), "vkQueueWaitIdle"); + if (!iree_status_is_ok(status)) { + iree_slim_mutex_unlock(&queue_mutex_); + return status; + } pending_fences_.clear(); // Submit and complete all deferred work. while (!deferred_submissions_.empty()) { - IREE_ASSIGN_OR_RETURN(bool work_submitted, ProcessDeferredSubmissions()); + bool work_submitted = false; + status = ProcessDeferredSubmissions(&work_submitted); + if (!iree_status_is_ok(status)) break; if (work_submitted) { - VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_)); + status = VK_RESULT_TO_STATUS(syms()->vkQueueWaitIdle(queue_), + "vkQueueWaitIdle"); + if (!iree_status_is_ok(status)) break; pending_fences_.clear(); } } - return OkStatus(); + iree_slim_mutex_unlock(&queue_mutex_); + return status; } IREE_TRACE_SCOPE0("SerializingCommandQueue::WaitIdle#Fence"); // Keep trying to submit more workload to the GPU until reaching the deadline. + iree_slim_mutex_lock(&queue_mutex_); do { - IREE_RETURN_IF_ERROR(ProcessDeferredSubmissions().status()); + status = ProcessDeferredSubmissions(); + bool has_deferred_submissions = !deferred_submissions_.empty(); + absl::InlinedVector fence_handles(pending_fences_.size()); + for (size_t i = 0; i < pending_fences_.size(); ++i) { + fence_handles[i] = pending_fences_[i]->value(); + } + if (!iree_status_is_ok(status)) { + break; // unable to process submissions + } else if (!has_deferred_submissions && fence_handles.empty()) { + break; // no more work - idle achieved + } uint64_t timeout_ns; - if (deadline_ns == InfiniteFuture()) { + if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { timeout_ns = UINT64_MAX; - } else if (deadline_ns == InfinitePast()) { + } else if (deadline_ns == IREE_TIME_INFINITE_PAST) { timeout_ns = 0; } else { // Convert to relative time in nanoseconds. - // The implementation may not wait with this granularity (like, by - // 10000x). - Duration relative_ns = deadline_ns - Now(); - if (relative_ns < ZeroDuration()) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for idle"; + // The implementation may not wait with this granularity (like by 10000x). 
+ iree_time_t now_ns = iree_time_now(); + if (deadline_ns < now_ns) { + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); } - timeout_ns = static_cast(relative_ns); + timeout_ns = (uint64_t)(deadline_ns - now_ns); } - - if (pending_fences_.empty()) continue; - - std::vector fences; - fences.reserve(pending_fences_.size()); - for (const auto& fence : pending_fences_) fences.push_back(fence->value()); - VkResult result = syms()->vkWaitForFences( - *logical_device_, static_cast(fences.size()), fences.data(), + *logical_device_, static_cast(fence_handles.size()), + fence_handles.data(), /*waitAll=*/VK_TRUE, timeout_ns); switch (result) { @@ -365,56 +362,62 @@ Status SerializingCommandQueue::WaitIdle(Time deadline_ns) { pending_fences_.clear(); break; case VK_TIMEOUT: - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for idle"; + status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + break; default: - return VkResultToStatus(result, IREE_LOC); + status = VK_RESULT_TO_STATUS(result, "vkWaitForFences"); + break; } // As long as there is submitted or deferred work still pending. - } while (!pending_fences_.empty() || !deferred_submissions_.empty()); - - return OkStatus(); + } while (iree_status_is_ok(status)); + iree_slim_mutex_unlock(&queue_mutex_); + return status; } -Status SerializingCommandQueue::AdvanceQueueSubmission() { - absl::MutexLock lock(&mutex_); +iree_status_t SerializingCommandQueue::AdvanceQueueSubmission() { // The returned value just indicates whether any newly ready // submissions were submitted to the GPU. Other callers might be // interested in that information but for this API we just want to advance // queue submission if possible. So we ignore it here. - IREE_ASSIGN_OR_RETURN(std::ignore, ProcessDeferredSubmissions()); - return OkStatus(); + iree_slim_mutex_lock(&queue_mutex_); + iree_status_t status = ProcessDeferredSubmissions(); + iree_slim_mutex_unlock(&queue_mutex_); + return status; } void SerializingCommandQueue::AbortQueueSubmission() { - absl::MutexLock lock(&mutex_); + iree_slim_mutex_lock(&queue_mutex_); // We have fences in deferred_submissions_ but they are not submitted to GPU // yet so we don't need to reset. deferred_submissions_.clear(); - std::vector fences; - fences.reserve(pending_fences_.size()); - for (const auto& fence : pending_fences_) fences.push_back(fence->value()); + absl::InlinedVector fence_handles(pending_fences_.size()); + for (size_t i = 0; i < pending_fences_.size(); ++i) { + fence_handles[i] = pending_fences_[i]->value(); + } syms()->vkWaitForFences(*logical_device_, - static_cast(fences.size()), fences.data(), + static_cast(fence_handles.size()), + fence_handles.data(), /*waitAll=*/VK_TRUE, /*timeout=*/UINT64_MAX); + // Clear the list. Fences will be automatically returned back to the queue // after refcount reaches 0.
pending_fences_.clear(); + + iree_slim_mutex_unlock(&queue_mutex_); } void SerializingCommandQueue::SignalFences(absl::Span fences) { - auto span_contains = [&fences](VkFence fence) { + const auto span_contains = [fences](VkFence fence) { for (VkFence f : fences) { if (f == fence) return true; } return false; }; - absl::MutexLock lock(&mutex_); - + iree_slim_mutex_lock(&queue_mutex_); auto it = pending_fences_.begin(); while (it != pending_fences_.end()) { if (span_contains((*it)->value())) { @@ -423,6 +426,7 @@ void SerializingCommandQueue::SignalFences(absl::Span fences) { ++it; } } + iree_slim_mutex_unlock(&queue_mutex_); } } // namespace vulkan diff --git a/iree/hal/vulkan/serializing_command_queue.h b/iree/hal/vulkan/serializing_command_queue.h index f98226d3815c1..6b413bd8fade7 100644 --- a/iree/hal/vulkan/serializing_command_queue.h +++ b/iree/hal/vulkan/serializing_command_queue.h @@ -24,13 +24,11 @@ #include "absl/base/thread_annotations.h" #include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" #include "iree/base/intrusive_list.h" #include "iree/base/ref_ptr.h" #include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/command_buffer.h" -#include "iree/hal/command_queue.h" +#include "iree/hal/api.h" +#include "iree/hal/vulkan/command_queue.h" #include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/handle_util.h" #include "iree/hal/vulkan/timepoint_util.h" @@ -39,6 +37,8 @@ namespace iree { namespace hal { namespace vulkan { +using SemaphoreValue = std::pair; + // A command queue that potentially defers and serializes command buffer // submission to the GPU. // @@ -52,23 +52,22 @@ namespace vulkan { // the GPU. class SerializingCommandQueue final : public CommandQueue { public: - SerializingCommandQueue(std::string name, - CommandCategoryBitfield supported_categories, - const ref_ptr& logical_device, - const ref_ptr& fence_pool, - VkQueue queue); + SerializingCommandQueue(VkDeviceHandle* logical_device, std::string name, + iree_hal_command_category_t supported_categories, + VkQueue queue, TimePointFencePool* fence_pool); ~SerializingCommandQueue() override; const ref_ptr& syms() const { return logical_device_->syms(); } - Status Submit(absl::Span batches) override; + iree_status_t Submit(iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches) override; - Status WaitIdle(Time deadline_ns) override; + iree_status_t WaitIdle(iree_time_t deadline_ns) override; // Releases all deferred submissions ready to submit to the GPU. - Status AdvanceQueueSubmission(); + iree_status_t AdvanceQueueSubmission(); // Aborts all deferred submissions and waits for submitted work to complete. void AbortQueueSubmission(); @@ -77,37 +76,26 @@ class SerializingCommandQueue final : public CommandQueue { void SignalFences(absl::Span fences); private: - struct PendingBatch { - absl::InlinedVector wait_semaphores; - absl::InlinedVector command_buffers; - absl::InlinedVector signal_semaphores; - }; // A submission batch together with the fence to signal its status. struct FencedSubmission : public IntrusiveLinkBase { - PendingBatch batch; + absl::InlinedVector wait_semaphores; + absl::InlinedVector command_buffers; + absl::InlinedVector signal_semaphores; ref_ptr fence; }; // Processes deferred submissions in this queue and returns whether there was // new workload submitted to the GPU if no errors happen.
- StatusOr ProcessDeferredSubmissions() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - ref_ptr logical_device_; + iree_status_t ProcessDeferredSubmissions(bool* out_work_submitted = NULL); - ref_ptr fence_pool_; - - mutable absl::Mutex mutex_; + TimePointFencePool* fence_pool_; // A list of fences that are submitted to GPU. absl::InlinedVector, 4> pending_fences_ - ABSL_GUARDED_BY(mutex_); + IREE_GUARDED_BY(mutex_); // A list of deferred submissions that haven't been submitted to GPU. IntrusiveList> deferred_submissions_ - ABSL_GUARDED_BY(mutex_); - - // VkQueue needs to be externally synchronized. - VkQueue queue_ ABSL_GUARDED_BY(mutex_); + IREE_GUARDED_BY(mutex_); }; } // namespace vulkan diff --git a/iree/hal/vulkan/status_util.cc b/iree/hal/vulkan/status_util.c similarity index 72% rename from iree/hal/vulkan/status_util.cc rename to iree/hal/vulkan/status_util.c index 8231db1cb0aa7..6117caa9bd045 100644 --- a/iree/hal/vulkan/status_util.cc +++ b/iree/hal/vulkan/status_util.c @@ -14,49 +14,48 @@ #include "iree/hal/vulkan/status_util.h" -namespace iree { -namespace hal { -namespace vulkan { - -Status VkResultToStatus(VkResult result, SourceLocation loc) { +iree_status_t iree_hal_vulkan_result_to_status(VkResult result, + const char* file, + uint32_t line) { switch (result) { // Success codes. case VK_SUCCESS: // Command successfully completed. - return OkStatus(); + return iree_ok_status(); case VK_NOT_READY: // A fence or query has not yet completed. - return OkStatus(); + return iree_ok_status(); case VK_TIMEOUT: // A wait operation has not completed in the specified time. - return OkStatus(); + return iree_ok_status(); case VK_EVENT_SET: // An event is signaled. - return OkStatus(); + return iree_ok_status(); case VK_EVENT_RESET: // An event is unsignaled. - return OkStatus(); + return iree_ok_status(); case VK_INCOMPLETE: // A return array was too small for the result. - return OkStatus(); + return iree_ok_status(); case VK_SUBOPTIMAL_KHR: // A swapchain no longer matches the surface properties exactly, but can // still be used to present to the surface successfully. - return OkStatus(); + return iree_ok_status(); // Error codes. case VK_ERROR_OUT_OF_HOST_MEMORY: // A host memory allocation has failed. - return ResourceExhaustedErrorBuilder(loc) - << "VK_ERROR_OUT_OF_HOST_MEMORY"; + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_OUT_OF_HOST_MEMORY"); case VK_ERROR_OUT_OF_DEVICE_MEMORY: // A device memory allocation has failed. - return ResourceExhaustedErrorBuilder(loc) - << "VK_ERROR_OUT_OF_DEVICE_MEMORY"; + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_OUT_OF_DEVICE_MEMORY"); case VK_ERROR_INITIALIZATION_FAILED: // Initialization of an object could not be completed for // implementation-specific reasons. - return InternalErrorBuilder(loc) << "VK_ERROR_INITIALIZATION_FAILED"; + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "VK_ERROR_INITIALIZATION_FAILED"); case VK_ERROR_DEVICE_LOST: // The logical or physical device has been lost. // @@ -125,77 +124,87 @@ Status VkResultToStatus(VkResult result, SourceLocation loc) { // command buffer is in the pending state, or whether resources are // considered in-use by the device, a return value of // VK_ERROR_DEVICE_LOST is equivalent to VK_SUCCESS. - return InternalErrorBuilder(loc) << "VK_ERROR_DEVICE_LOST"; + return iree_make_status(IREE_STATUS_INTERNAL, "VK_ERROR_DEVICE_LOST"); case VK_ERROR_MEMORY_MAP_FAILED: // Mapping of a memory object has failed. 
- return InternalErrorBuilder(loc) << "VK_ERROR_MEMORY_MAP_FAILED"; + return iree_make_status(IREE_STATUS_INTERNAL, + "VK_ERROR_MEMORY_MAP_FAILED"); case VK_ERROR_LAYER_NOT_PRESENT: // A requested layer is not present or could not be loaded. - return UnimplementedErrorBuilder(loc) << "VK_ERROR_LAYER_NOT_PRESENT"; + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "VK_ERROR_LAYER_NOT_PRESENT"); case VK_ERROR_EXTENSION_NOT_PRESENT: // A requested extension is not supported. - return UnimplementedErrorBuilder(loc) << "VK_ERROR_EXTENSION_NOT_PRESENT"; + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "VK_ERROR_EXTENSION_NOT_PRESENT"); case VK_ERROR_FEATURE_NOT_PRESENT: // A requested feature is not supported. - return UnimplementedErrorBuilder(loc) << "VK_ERROR_FEATURE_NOT_PRESENT"; + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "VK_ERROR_FEATURE_NOT_PRESENT"); case VK_ERROR_INCOMPATIBLE_DRIVER: // The requested version of Vulkan is not supported by the driver or is // otherwise incompatible for implementation-specific reasons. - return FailedPreconditionErrorBuilder(loc) - << "VK_ERROR_INCOMPATIBLE_DRIVER"; + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "VK_ERROR_INCOMPATIBLE_DRIVER"); case VK_ERROR_TOO_MANY_OBJECTS: // Too many objects of the type have already been created. - return ResourceExhaustedErrorBuilder(loc) << "VK_ERROR_TOO_MANY_OBJECTS"; + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_TOO_MANY_OBJECTS"); case VK_ERROR_FORMAT_NOT_SUPPORTED: // A requested format is not supported on this device. - return UnimplementedErrorBuilder(loc) << "VK_ERROR_FORMAT_NOT_SUPPORTED"; + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "VK_ERROR_FORMAT_NOT_SUPPORTED"); case VK_ERROR_FRAGMENTED_POOL: - // A pool allocation has failed due to fragmentation of the pool’s memory. - // This must only be returned if no attempt to allocate host or device - // memory was made to accommodate the new allocation. - return ResourceExhaustedErrorBuilder(loc) << "VK_ERROR_FRAGMENTED_POOL"; + // A pool allocation has failed due to fragmentation of the pool’s + // memory. This must only be returned if no attempt to allocate host + // or device memory was made to accommodate the new allocation. + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_FRAGMENTED_POOL"); case VK_ERROR_OUT_OF_POOL_MEMORY: // A pool memory allocation has failed. This must only be returned if no // attempt to allocate host or device memory was made to accommodate the // new allocation. If the failure was definitely due to fragmentation of // the pool, VK_ERROR_FRAGMENTED_POOL should be returned instead. - return ResourceExhaustedErrorBuilder(loc) - << "VK_ERROR_OUT_OF_POOL_MEMORY"; + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_OUT_OF_POOL_MEMORY"); case VK_ERROR_INVALID_EXTERNAL_HANDLE: // An external handle is not a valid handle of the specified type. - return InvalidArgumentErrorBuilder(loc) - << "VK_ERROR_INVALID_EXTERNAL_HANDLE"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_INVALID_EXTERNAL_HANDLE"); case VK_ERROR_SURFACE_LOST_KHR: // A surface is no longer available. - return UnavailableErrorBuilder(loc) << "VK_ERROR_SURFACE_LOST_KHR"; + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "VK_ERROR_SURFACE_LOST_KHR"); case VK_ERROR_NATIVE_WINDOW_IN_USE_KHR: // The requested window is already in use by Vulkan or another API in a // manner which prevents it from being used again. 
- return InvalidArgumentErrorBuilder(loc) - << "VK_ERROR_NATIVE_WINDOW_IN_USE_KHR"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_NATIVE_WINDOW_IN_USE_KHR"); case VK_ERROR_OUT_OF_DATE_KHR: // A surface has changed in such a way that it is no longer compatible // with the swapchain, and further presentation requests using the // swapchain will fail. Applications must query the new surface properties // and recreate their swapchain if they wish to continue presenting to the // surface. - return FailedPreconditionErrorBuilder(loc) << "VK_ERROR_OUT_OF_DATE_KHR"; + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "VK_ERROR_OUT_OF_DATE_KHR"); case VK_ERROR_INCOMPATIBLE_DISPLAY_KHR: // The display used by a swapchain does not use the same presentable image // layout, or is incompatible in a way that prevents sharing an image. - return InvalidArgumentErrorBuilder(loc) - << "VK_ERROR_INCOMPATIBLE_DISPLAY_KHR"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_INCOMPATIBLE_DISPLAY_KHR"); case VK_ERROR_VALIDATION_FAILED_EXT: // Validation layer testing failed. It is not expected that an // application would see this error code during normal use of the // validation layers. - return InvalidArgumentErrorBuilder(loc) - << "VK_ERROR_VALIDATION_FAILED_EXT"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_VALIDATION_FAILED_EXT"); case VK_ERROR_INVALID_SHADER_NV: // One or more shaders failed to compile or link. More details are // reported back to the application when the validation layer is enabled // using the extension VK_EXT_debug_report. - return InvalidArgumentErrorBuilder(loc) << "VK_ERROR_INVALID_SHADER_NV"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_INVALID_SHADER_NV"); case VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT: // When creating an image with // VkImageDrmFormatModifierExplicitCreateInfoEXT, it is the application’s
- return OutOfRangeErrorBuilder(loc) - << "VK_ERROR_INVALID_DEVICE_ADDRESS_EXT"; + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "VK_ERROR_INVALID_DEVICE_ADDRESS_EXT"); case VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT: // An operation on a swapchain created with // VK_FULL_SCREEN_EXCLUSIVE_APPLICATION_CONTROLLED_EXT failed as it did // not have exclusive full-screen access. This may occur due to // implementation-dependent reasons, outside of the application’s control. - return UnavailableErrorBuilder(loc) - << "VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT"; + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT"); default: - return UnknownErrorBuilder(loc) << result; + return iree_make_status(IREE_STATUS_UNKNOWN, "VkResult=%u", + (uint32_t)result); } } - -} // namespace vulkan -} // namespace hal -} // namespace iree diff --git a/iree/hal/vulkan/status_util.h b/iree/hal/vulkan/status_util.h index d1497fd447154..f225ec58facb4 100644 --- a/iree/hal/vulkan/status_util.h +++ b/iree/hal/vulkan/status_util.h @@ -19,19 +19,27 @@ #include "iree/hal/vulkan/vulkan_headers.h" // clang-format on -#include "iree/base/status.h" +#include "iree/base/api.h" -namespace iree { -namespace hal { -namespace vulkan { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Converts a VkResult to an iree_status_t. +// +// Usage: +// iree_status_t status = VK_RESULT_TO_STATUS(vkDoThing(...)); +#define VK_RESULT_TO_STATUS(expr, ...) \ + iree_hal_vulkan_result_to_status((expr), __FILE__, __LINE__) // IREE_RETURN_IF_ERROR but implicitly converts the VkResult return value to // a Status. // // Usage: -// VK_RETURN_IF_ERROR(vkDoThing(...)); -#define VK_RETURN_IF_ERROR(expr) \ - IREE_RETURN_IF_ERROR(::iree::hal::vulkan::VkResultToStatus(expr, IREE_LOC)) +// VK_RETURN_IF_ERROR(vkDoThing(...), "message"); +#define VK_RETURN_IF_ERROR(expr, ...) \ + IREE_RETURN_IF_ERROR( \ + iree_hal_vulkan_result_to_status(expr, __FILE__, __LINE__), __VA_ARGS__) // IREE_CHECK_OK but implicitly converts the VkResult return value to a // Status and checks that it is OkStatus. @@ -39,7 +47,7 @@ namespace vulkan { // Usage: // VK_CHECK_OK(vkDoThing(...)); #define VK_CHECK_OK(expr) \ - IREE_CHECK_OK(::iree::hal::vulkan::VkResultToStatus(expr, IREE_LOC)) + IREE_CHECK_OK(iree_hal_vulkan_result_to_status(expr, __FILE__, __LINE__)) // Converts a VkResult to a Status object.
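For illustration, a hedged sketch of the three conversion macros in use; the |syms|, |device|, |fence|, and |timeout_ns| values here are hypothetical. Note that success codes such as VK_TIMEOUT and VK_NOT_READY convert to an OK status, so callers that care about them must branch on the raw VkResult before converting:

// Early-return form: annotates and propagates any failing VkResult.
VK_RETURN_IF_ERROR(syms->vkResetFences(device, 1, &fence), "vkResetFences");
// Captured form: useful when cleanup must run before returning.
iree_status_t status =
    VK_RESULT_TO_STATUS(syms->vkDeviceWaitIdle(device), "vkDeviceWaitIdle");
// Abort-on-failure form for paths that cannot propagate a status.
VK_CHECK_OK(syms->vkDeviceWaitIdle(device));
// VK_TIMEOUT maps to iree_ok_status(); inspect the VkResult directly when
// timeouts are meaningful to the caller.
VkResult result =
    syms->vkWaitForFences(device, 1, &fence, /*waitAll=*/VK_TRUE, timeout_ns);
if (result == VK_TIMEOUT) { /* handle the timeout before converting */ }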
// @@ -81,10 +89,11 @@ namespace vulkan { // - VK_ERROR_NOT_PERMITTED_EXT -> PermissionDeniedError("VK...") // - VK_ERROR_INVALID_DEVICE_ADDRESS_EXT -> OutOfRangeError("VK...") // - VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT -> InternalError("VK...") -Status VkResultToStatus(VkResult result, SourceLocation loc); +iree_status_t iree_hal_vulkan_result_to_status(VkResult result, + const char* file, uint32_t line); -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_STATUS_UTIL_H_ diff --git a/iree/hal/vulkan/timepoint_util.cc b/iree/hal/vulkan/timepoint_util.cc index c14ea26e13ed2..a8080d00f6564 100644 --- a/iree/hal/vulkan/timepoint_util.cc +++ b/iree/hal/vulkan/timepoint_util.cc @@ -17,7 +17,6 @@ #include <memory> #include "absl/synchronization/mutex.h" -#include "iree/base/time.h" #include "iree/base/tracing.h" #include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/status_util.h" @@ -47,13 +46,13 @@ void TimePointFence::ResetStatus() { } // static -StatusOr<ref_ptr<TimePointFencePool>> TimePointFencePool::Create( - ref_ptr<VkDeviceHandle> logical_device) { +iree_status_t TimePointFencePool::Create(VkDeviceHandle* logical_device, + TimePointFencePool** out_pool) { IREE_TRACE_SCOPE0("TimePointFencePool::Create"); - ref_ptr<TimePointFencePool> pool( - new TimePointFencePool(std::move(logical_device))); + ref_ptr<TimePointFencePool> pool(new TimePointFencePool(logical_device)); IREE_RETURN_IF_ERROR(pool->PreallocateFences()); - return pool; + *out_pool = pool.release(); + return iree_ok_status(); } TimePointFencePool::~TimePointFencePool() { @@ -100,8 +99,8 @@ void TimePointFencePool::ReleaseResolved(TimePointFence* fence) { free_fences_.push_back(std::unique_ptr<TimePointFence>(fence)); } -TimePointFencePool::TimePointFencePool(ref_ptr<VkDeviceHandle> logical_device) - : logical_device_(std::move(logical_device)) {} +TimePointFencePool::TimePointFencePool(VkDeviceHandle* logical_device) - : logical_device_(logical_device) {} const ref_ptr<DynamicSymbols>& TimePointFencePool::syms() const { return logical_device_->syms(); } @@ -120,9 +119,10 @@ Status TimePointFencePool::PreallocateFences() { absl::MutexLock lock(&mutex_); for (int i = 0; i < fences.size(); ++i) { VkFence fence = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateFence(*logical_device_, &create_info, - logical_device_->allocator(), - &fence)); + VK_RETURN_IF_ERROR( + syms()->vkCreateFence(*logical_device_, &create_info, + logical_device_->allocator(), &fence), + "vkCreateFence"); fences[i] = std::make_unique<TimePointFence>(this, fence); } } @@ -142,13 +142,14 @@ Status TimePointFencePool::PreallocateFences() { } // static -StatusOr<ref_ptr<TimePointSemaphorePool>> TimePointSemaphorePool::Create( - ref_ptr<VkDeviceHandle> logical_device) { +iree_status_t TimePointSemaphorePool::Create( + VkDeviceHandle* logical_device, TimePointSemaphorePool** out_pool) { IREE_TRACE_SCOPE0("TimePointSemaphorePool::Create"); ref_ptr<TimePointSemaphorePool> pool( - new TimePointSemaphorePool(std::move(logical_device))); + new TimePointSemaphorePool(logical_device)); IREE_RETURN_IF_ERROR(pool->PreallocateSemaphores()); - return pool; + *out_pool = pool.release(); + return iree_ok_status(); } TimePointSemaphorePool::~TimePointSemaphorePool() { @@ -206,9 +207,8 @@ void TimePointSemaphorePool::ReleaseUnresolved( IntrusiveList<TimePointSemaphore>* semaphores) { free_semaphores_.merge_from(semaphores); } -TimePointSemaphorePool::TimePointSemaphorePool( - ref_ptr<VkDeviceHandle> logical_device) - : logical_device_(std::move(logical_device)) {} +TimePointSemaphorePool::TimePointSemaphorePool(VkDeviceHandle* logical_device) + : logical_device_(logical_device) {} const ref_ptr<DynamicSymbols>& TimePointSemaphorePool::syms() const { return
logical_device_->syms(); @@ -227,7 +227,8 @@ Status TimePointSemaphorePool::PreallocateSemaphores() { auto* semaphore = &storage_[i]; VK_RETURN_IF_ERROR(syms()->vkCreateSemaphore(*logical_device_, &create_info, logical_device_->allocator(), - &semaphore->semaphore)); + &semaphore->semaphore), + "vkCreateSemaphore"); free_semaphores_.push_back(semaphore); } diff --git a/iree/hal/vulkan/timepoint_util.h b/iree/hal/vulkan/timepoint_util.h index 7d10129e63dab..1f127e218d415 100644 --- a/iree/hal/vulkan/timepoint_util.h +++ b/iree/hal/vulkan/timepoint_util.h @@ -125,8 +125,8 @@ class TimePointFencePool final : public RefObject<TimePointFencePool> { static constexpr int kMaxInFlightFenceCount = 64; // Creates a new pool and pre-allocates `kMaxInFlightFenceCount` fences. - static StatusOr<ref_ptr<TimePointFencePool>> Create( - ref_ptr<VkDeviceHandle> logical_device); + static iree_status_t Create(VkDeviceHandle* logical_device, + TimePointFencePool** out_pool); ~TimePointFencePool(); @@ -143,18 +143,16 @@ class TimePointFencePool final : public RefObject<TimePointFencePool> { // not be in flight on GPU. void ReleaseResolved(TimePointFence* fence); - const ref_ptr<VkDeviceHandle>& logical_device() const { - return logical_device_; - } + VkDeviceHandle* logical_device() const { return logical_device_; } private: - explicit TimePointFencePool(ref_ptr<VkDeviceHandle> logical_device); + explicit TimePointFencePool(VkDeviceHandle* logical_device); const ref_ptr<DynamicSymbols>& syms() const; Status PreallocateFences() ABSL_LOCKS_EXCLUDED(mutex_); - ref_ptr<VkDeviceHandle> logical_device_; + VkDeviceHandle* logical_device_; absl::Mutex mutex_; @@ -171,8 +169,8 @@ class TimePointSemaphorePool final : public RefObject<TimePointSemaphorePool> { // Creates a new pool and pre-allocates `kMaxInFlightSemaphoreCount` binary // semaphores. - static StatusOr<ref_ptr<TimePointSemaphorePool>> Create( - ref_ptr<VkDeviceHandle> logical_device); + static iree_status_t Create(VkDeviceHandle* logical_device, + TimePointSemaphorePool** out_pool); ~TimePointSemaphorePool(); @@ -195,13 +193,13 @@ class TimePointSemaphorePool final : public RefObject<TimePointSemaphorePool> { void ReleaseUnresolved(IntrusiveList<TimePointSemaphore>* semaphores); private: - explicit TimePointSemaphorePool(ref_ptr<VkDeviceHandle> logical_device); + explicit TimePointSemaphorePool(VkDeviceHandle* logical_device); const ref_ptr<DynamicSymbols>& syms() const; Status PreallocateSemaphores() ABSL_LOCKS_EXCLUDED(mutex_); - ref_ptr<VkDeviceHandle> logical_device_; + VkDeviceHandle* logical_device_; absl::Mutex mutex_; diff --git a/iree/hal/vulkan/vma_allocator.cc b/iree/hal/vulkan/vma_allocator.cc index bbdaa43ca22fe..635a94d72988d 100644 --- a/iree/hal/vulkan/vma_allocator.cc +++ b/iree/hal/vulkan/vma_allocator.cc @@ -14,26 +14,39 @@ #include "iree/hal/vulkan/vma_allocator.h" -#include "absl/memory/memory.h" -#include "iree/base/status.h" #include "iree/base/tracing.h" -#include "iree/hal/buffer.h" #include "iree/hal/vulkan/status_util.h" #include "iree/hal/vulkan/vma_buffer.h" -namespace iree { -namespace hal { -namespace vulkan { +using namespace iree::hal::vulkan; -// static -StatusOr<std::unique_ptr<VmaAllocator>> VmaAllocator::Create( - VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, VkInstance instance, - Options options) { - IREE_TRACE_SCOPE0("VmaAllocator::Create"); +typedef struct iree_hal_vulkan_vma_allocator_s { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + VmaAllocator vma; +} iree_hal_vulkan_vma_allocator_t; + +extern const iree_hal_allocator_vtable_t iree_hal_vulkan_vma_allocator_vtable; + +static iree_hal_vulkan_vma_allocator_t* iree_hal_vulkan_vma_allocator_cast( + iree_hal_allocator_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_vma_allocator_vtable); + return
(iree_hal_vulkan_vma_allocator_t*)base_value; +} + +iree_status_t iree_hal_vulkan_vma_allocator_create( + VkInstance instance, VkPhysicalDevice physical_device, + VkDeviceHandle* logical_device, VmaRecordSettings record_settings, + iree_hal_allocator_t** out_allocator) { + IREE_ASSERT_ARGUMENT(instance); + IREE_ASSERT_ARGUMENT(physical_device); + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(out_allocator); + IREE_TRACE_ZONE_BEGIN(z0); const auto& syms = logical_device->syms(); VmaVulkanFunctions vulkan_fns; + memset(&vulkan_fns, 0, sizeof(vulkan_fns)); vulkan_fns.vkGetPhysicalDeviceProperties = syms->vkGetPhysicalDeviceProperties; vulkan_fns.vkGetPhysicalDeviceMemoryProperties = @@ -56,76 +69,110 @@ StatusOr<std::unique_ptr<VmaAllocator>> VmaAllocator::Create( vulkan_fns.vkDestroyImage = syms->vkDestroyImage; vulkan_fns.vkCmdCopyBuffer = syms->vkCmdCopyBuffer; - VmaRecordSettings record_settings; -#if VMA_RECORDING_ENABLED - record_settings.flags = - options.recording_flush_after_call ? VMA_RECORD_FLUSH_AFTER_CALL_BIT : 0; - record_settings.pFilePath = options.recording_file.c_str(); -#else - record_settings.flags = 0; - record_settings.pFilePath = nullptr; -#endif // VMA_RECORDING_ENABLED - - VmaAllocatorCreateInfo create_info{}; + VmaAllocatorCreateInfo create_info; + memset(&create_info, 0, sizeof(create_info)); create_info.flags = 0; create_info.physicalDevice = physical_device; create_info.device = *logical_device; create_info.instance = instance; create_info.preferredLargeHeapBlockSize = 64 * 1024 * 1024; create_info.pAllocationCallbacks = logical_device->allocator(); - create_info.pDeviceMemoryCallbacks = nullptr; + create_info.pDeviceMemoryCallbacks = NULL; create_info.frameInUseCount = 0; - create_info.pHeapSizeLimit = nullptr; + create_info.pHeapSizeLimit = NULL; create_info.pVulkanFunctions = &vulkan_fns; create_info.pRecordSettings = &record_settings; - ::VmaAllocator vma = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(vmaCreateAllocator(&create_info, &vma)); + VmaAllocator vma = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, VK_RESULT_TO_STATUS(vmaCreateAllocator(&create_info, &vma), + "vmaCreateAllocator")); - auto allocator = - absl::WrapUnique(new VmaAllocator(physical_device, logical_device, vma)); - // TODO(benvanik): query memory properties/types.
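A hedged sketch of a hypothetical call site for the new C entry point; the |instance|, |physical_device|, |logical_device|, and |record_settings| values are assumed to come from the driver setup elsewhere:

iree_hal_allocator_t* device_allocator = NULL;
IREE_RETURN_IF_ERROR(iree_hal_vulkan_vma_allocator_create(
    instance, physical_device, logical_device, record_settings,
    &device_allocator));
// ... allocate buffers through the generic iree_hal_allocator_* API ...
iree_hal_allocator_release(device_allocator);  // destroys via the vtable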
- return allocator; + iree_allocator_t host_allocator = logical_device->host_allocator(); + iree_hal_vulkan_vma_allocator_t* allocator = NULL; + iree_status_t status = iree_allocator_malloc( + host_allocator, sizeof(*allocator), (void**)&allocator); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_vma_allocator_vtable, + &allocator->resource); + allocator->host_allocator = host_allocator; + allocator->vma = vma; + *out_allocator = (iree_hal_allocator_t*)allocator; + } else { + vmaDestroyAllocator(vma); + } + + IREE_TRACE_ZONE_END(z0); + return status; } -VmaAllocator::VmaAllocator(VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, - ::VmaAllocator vma) - : physical_device_(physical_device), - logical_device_(add_ref(logical_device)), - vma_(vma) {} +static void iree_hal_vulkan_vma_allocator_destroy( + iree_hal_allocator_t* base_allocator) { + iree_hal_vulkan_vma_allocator_t* allocator = + iree_hal_vulkan_vma_allocator_cast(base_allocator); + iree_allocator_t host_allocator = allocator->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); -VmaAllocator::~VmaAllocator() { - IREE_TRACE_SCOPE0("VmaAllocator::dtor"); - vmaDestroyAllocator(vma_); -} + vmaDestroyAllocator(allocator->vma); + iree_allocator_free(host_allocator, allocator); -bool VmaAllocator::CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const { - // TODO(benvanik): ensure there is a memory type that can satisfy the request. - return source_allocator == this; + IREE_TRACE_ZONE_END(z0); } -bool VmaAllocator::CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const { - // TODO(benvanik): ensure there is a memory type that can satisfy the request. - return true; +static iree_allocator_t iree_hal_vulkan_vma_allocator_host_allocator( + const iree_hal_allocator_t* base_allocator) { + iree_hal_vulkan_vma_allocator_t* allocator = + (iree_hal_vulkan_vma_allocator_t*)base_allocator; + return allocator->host_allocator; } -Status VmaAllocator::MakeCompatible(MemoryTypeBitfield* memory_type, - BufferUsageBitfield* buffer_usage) const { - // TODO(benvanik): mutate to match supported memory types. - return OkStatus(); +static iree_hal_buffer_compatibility_t +iree_hal_vulkan_vma_allocator_query_buffer_compatibility( + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, + iree_hal_buffer_usage_t intended_usage, + iree_device_size_t allocation_size) { + // TODO(benvanik): check to ensure the allocator can serve the memory type. + + // Disallow usage not permitted by the buffer itself. Since we then use this + // to determine compatibility below we'll naturally set the right compat flags + // based on what's both allowed and intended. + intended_usage &= allowed_usage; + + // All buffers can be allocated on the heap. + iree_hal_buffer_compatibility_t compatibility = + IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE; + + // Buffers can only be used on the queue if they are device visible.
+ if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + } + if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH; + } + } + + return compatibility; } -StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateInternal( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - MemoryAccessBitfield allowed_access, size_t allocation_size, - VmaAllocationCreateFlags flags) { - IREE_TRACE_SCOPE0("VmaAllocator::AllocateInternal"); +static iree_status_t iree_hal_vulkan_vma_allocator_make_compatible( + iree_hal_memory_type_t* memory_type, + iree_hal_memory_access_t* allowed_access, + iree_hal_buffer_usage_t* allowed_usage) { + // TODO(benvanik): remove this entirely! + // Host currently uses mapping to copy buffers, which is done a lot. + // We could probably remove this mutation by preventing copies in those cases + // or issuing small copy command buffers. + *memory_type |= IREE_HAL_MEMORY_TYPE_HOST_VISIBLE; + *allowed_usage |= IREE_HAL_BUFFER_USAGE_MAPPING; + return iree_ok_status(); +} +static iree_status_t iree_hal_vulkan_vma_allocator_allocate_internal( + iree_hal_vulkan_vma_allocator_t* allocator, + iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t allowed_usage, + iree_hal_memory_access_t allowed_access, size_t allocation_size, + VmaAllocationCreateFlags flags, iree_hal_buffer_t** out_buffer) { // Guard against the corner case where the requested buffer size is 0. The // application is unlikely to do anything when requesting a 0-byte buffer; but // it can happen in real world use cases. So we should at least not crash. @@ -133,22 +180,22 @@ StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateInternal( VkBufferCreateInfo buffer_create_info; buffer_create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - buffer_create_info.pNext = nullptr; + buffer_create_info.pNext = NULL; buffer_create_info.flags = 0; buffer_create_info.size = allocation_size; buffer_create_info.usage = 0; - if (AllBitsSet(buffer_usage, BufferUsage::kTransfer)) { + if (iree_all_bits_set(allowed_usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_SRC_BIT; buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_DST_BIT; } - if (AllBitsSet(buffer_usage, BufferUsage::kDispatch)) { + if (iree_all_bits_set(allowed_usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) { buffer_create_info.usage |= VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; buffer_create_info.usage |= VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; buffer_create_info.usage |= VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT; } buffer_create_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; buffer_create_info.queueFamilyIndexCount = 0; - buffer_create_info.pQueueFamilyIndices = nullptr; + buffer_create_info.pQueueFamilyIndices = NULL; VmaAllocationCreateInfo allocation_create_info; allocation_create_info.flags = flags; @@ -157,9 +204,9 @@ StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateInternal( allocation_create_info.preferredFlags = 0; allocation_create_info.memoryTypeBits = 0; // Automatic selection.
allocation_create_info.pool = VK_NULL_HANDLE; - allocation_create_info.pUserData = nullptr; - if (AllBitsSet(memory_type, MemoryType::kDeviceLocal)) { - if (AllBitsSet(memory_type, MemoryType::kHostVisible)) { + allocation_create_info.pUserData = NULL; + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { // Device-local, host-visible. allocation_create_info.usage = VMA_MEMORY_USAGE_CPU_TO_GPU; allocation_create_info.preferredFlags |= @@ -171,7 +218,7 @@ StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateInternal( VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; } } else { - if (AllBitsSet(memory_type, MemoryType::kDeviceVisible)) { + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { // Host-local, device-visible. allocation_create_info.usage = VMA_MEMORY_USAGE_GPU_TO_CPU; } else { @@ -179,67 +226,68 @@ StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateInternal( // Host-local, host-visible. allocation_create_info.usage = VMA_MEMORY_USAGE_CPU_ONLY; } } - if (AllBitsSet(memory_type, MemoryType::kHostCached)) { + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_CACHED)) { allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_CACHED_BIT; } - if (AllBitsSet(memory_type, MemoryType::kHostCoherent)) { + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_COHERENT)) { allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; } - if (AllBitsSet(memory_type, MemoryType::kTransient)) { + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_TRANSIENT)) { allocation_create_info.preferredFlags |= VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT; } - if (AllBitsSet(buffer_usage, BufferUsage::kMapping)) { + if (iree_all_bits_set(allowed_usage, IREE_HAL_BUFFER_USAGE_MAPPING)) { allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; } - VkBuffer buffer = VK_NULL_HANDLE; + VkBuffer handle = VK_NULL_HANDLE; VmaAllocation allocation = VK_NULL_HANDLE; VmaAllocationInfo allocation_info; - VK_RETURN_IF_ERROR(vmaCreateBuffer(vma_, &buffer_create_info, - &allocation_create_info, &buffer, - &allocation, &allocation_info)); + VK_RETURN_IF_ERROR(vmaCreateBuffer(allocator->vma, &buffer_create_info, - &allocation_create_info, &handle, + &allocation, &allocation_info), + "vmaCreateBuffer"); - return make_ref<VmaBuffer>(this, memory_type, allowed_access, buffer_usage, - allocation_size, 0, allocation_size, buffer, - allocation, allocation_info); + return iree_hal_vulkan_vma_buffer_wrap( + (iree_hal_allocator_t*)allocator, memory_type, allowed_access, + allowed_usage, allocation_size, + /*byte_offset=*/0, + /*byte_length=*/allocation_size, allocator->vma, handle, allocation, + allocation_info, out_buffer); } -StatusOr<ref_ptr<Buffer>> VmaAllocator::Allocate( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - size_t allocation_size) { - IREE_TRACE_SCOPE0("VmaAllocator::Allocate"); - return AllocateInternal(memory_type, buffer_usage, MemoryAccess::kAll, - allocation_size, /*flags=*/0); -} +static iree_status_t iree_hal_vulkan_vma_allocator_allocate_buffer( + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size, + iree_hal_buffer_t** out_buffer) { + iree_hal_vulkan_vma_allocator_t* allocator = + iree_hal_vulkan_vma_allocator_cast(base_allocator); + + // Coerce options into those required for use by VMA.
+ iree_hal_memory_access_t allowed_access = IREE_HAL_MEMORY_ACCESS_ALL; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_vma_allocator_make_compatible( + &memory_type, &allowed_access, &allowed_usage)); -StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateConstant( - BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) { - IREE_TRACE_SCOPE0("VmaAllocator::AllocateConstant"); - // TODO(benvanik): import memory to avoid the copy. - IREE_ASSIGN_OR_RETURN( - auto buffer, - AllocateInternal(MemoryType::kDeviceLocal | MemoryType::kHostVisible, - buffer_usage, - MemoryAccess::kRead | MemoryAccess::kDiscardWrite, - source_buffer->byte_length(), - /*flags=*/0)); - IREE_RETURN_IF_ERROR( - buffer->CopyData(0, source_buffer.get(), 0, kWholeBuffer)); - buffer->set_allowed_access(MemoryAccess::kRead); - return buffer; + return iree_hal_vulkan_vma_allocator_allocate_internal( + allocator, memory_type, allowed_usage, allowed_access, allocation_size, + /*flags=*/0, out_buffer); } -StatusOr<ref_ptr<Buffer>> VmaAllocator::WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, void* data, size_t data_length) { - IREE_TRACE_SCOPE0("VmaAllocator::WrapMutable"); - // TODO(benvanik): import memory. - return UnimplementedErrorBuilder(IREE_LOC) - << "Wrapping host memory is not yet implemented"; +static iree_status_t iree_hal_vulkan_vma_allocator_wrap_buffer( + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data, + iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "wrapping of external buffers not supported"); } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_allocator_vtable_t iree_hal_vulkan_vma_allocator_vtable = { + /*.destroy=*/iree_hal_vulkan_vma_allocator_destroy, + /*.host_allocator=*/iree_hal_vulkan_vma_allocator_host_allocator, + /*.query_buffer_compatibility = */ + iree_hal_vulkan_vma_allocator_query_buffer_compatibility, + /*.allocate_buffer=*/iree_hal_vulkan_vma_allocator_allocate_buffer, + /*.wrap_buffer=*/iree_hal_vulkan_vma_allocator_wrap_buffer, +}; diff --git a/iree/hal/vulkan/vma_allocator.h b/iree/hal/vulkan/vma_allocator.h index b48cbd8dbf1d3..bfa1f55ee0c93 100644 --- a/iree/hal/vulkan/vma_allocator.h +++ b/iree/hal/vulkan/vma_allocator.h @@ -15,25 +15,18 @@ #ifndef IREE_HAL_VULKAN_VMA_ALLOCATOR_H_ #define IREE_HAL_VULKAN_VMA_ALLOCATOR_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include <memory> - -#include "iree/base/status.h" -#include "iree/hal/allocator.h" -#include "iree/hal/vulkan/dynamic_symbols.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/handle_util.h" #include "iree/hal/vulkan/internal_vk_mem_alloc.h" -namespace iree { -namespace hal { -namespace vulkan { - -class VmaBuffer; +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -// A HAL allocator using the Vulkan Memory Allocator (VMA) to manage memory. +// Creates a VMA-based allocator that performs internal suballocation and a +// bunch of other fancy things. +// +// This uses the Vulkan Memory Allocator (VMA) to manage memory. // VMA (//third_party/vulkan_memory_allocator) provides dlmalloc-like behavior // with suballocations made with various policies (best fit, first fit, etc).
// This reduces the number of allocations we need from the Vulkan implementation @@ -47,78 +40,13 @@ class VmaBuffer; // More information: // https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator // https://gpuopen-librariesandsdks.github.io/VulkanMemoryAllocator/html/ -class VmaAllocator final : public Allocator { - public: - struct Options { -#if VMA_RECORDING_ENABLED - // File path to write a CSV containing the VMA recording. - std::string recording_file = ""; - - // Flush the VMA recording file after every call (useful if crashing or - // not exiting cleanly). - bool recording_flush_after_call = false; -#endif // VMA_RECORDING_ENABLED - }; - - static StatusOr<std::unique_ptr<VmaAllocator>> Create( - VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, VkInstance instance, - Options options); - - ~VmaAllocator() override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - ::VmaAllocator vma() const { return vma_; } - - bool CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const override; - - bool CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const override; - - Status MakeCompatible(MemoryTypeBitfield* memory_type, - BufferUsageBitfield* buffer_usage) const override; - - StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) override; - - StatusOr<ref_ptr<Buffer>> AllocateConstant( - BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) override; - - StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, - void* data, - size_t data_length) override; - - private: - VmaAllocator(VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, - ::VmaAllocator vma); - - StatusOr<ref_ptr<Buffer>> AllocateInternal( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - MemoryAccessBitfield allowed_access, size_t allocation_size, - VmaAllocationCreateFlags flags); - - VkPhysicalDevice physical_device_; - ref_ptr<VkDeviceHandle> logical_device_; - - // Internally synchronized. We could externally synchronize if we thought it - // was worth it, however I'm not sure we'd be able to do much better with the - // current Allocator API.
- ::VmaAllocator vma_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +iree_status_t iree_hal_vulkan_vma_allocator_create( + VkInstance instance, VkPhysicalDevice physical_device, + iree::hal::vulkan::VkDeviceHandle* logical_device, + VmaRecordSettings record_settings, iree_hal_allocator_t** out_allocator); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_VMA_ALLOCATOR_H_ diff --git a/iree/hal/vulkan/vma_buffer.cc b/iree/hal/vulkan/vma_buffer.cc index 1fb9e41992005..6b88c007ac3be 100644 --- a/iree/hal/vulkan/vma_buffer.cc +++ b/iree/hal/vulkan/vma_buffer.cc @@ -14,148 +14,154 @@ #include "iree/hal/vulkan/vma_buffer.h" -#include "iree/base/status.h" #include "iree/base/tracing.h" #include "iree/hal/vulkan/status_util.h" -#include "iree/hal/vulkan/vma_allocator.h" - -namespace iree { -namespace hal { -namespace vulkan { - -VmaBuffer::VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, device_size_t allocation_size, - device_size_t byte_offset, device_size_t byte_length, - VkBuffer buffer, VmaAllocation allocation, - VmaAllocationInfo allocation_info) - : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, - byte_offset, byte_length), - vma_(allocator->vma()), - buffer_(buffer), - allocation_(allocation), - allocation_info_(allocation_info) { - // TODO(benvanik): set debug name instead and use the - // VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT flag. - vmaSetAllocationUserData(vma_, allocation_, this); -} -VmaBuffer::~VmaBuffer() { - IREE_TRACE_SCOPE0("VmaBuffer::dtor"); - vmaDestroyBuffer(vma_, buffer_, allocation_); +typedef struct iree_hal_vulkan_vma_buffer_s { + iree_hal_buffer_t base; + + VmaAllocator vma; + VkBuffer handle; + VmaAllocation allocation; + VmaAllocationInfo allocation_info; +} iree_hal_vulkan_vma_buffer_t; + +extern const iree_hal_buffer_vtable_t iree_hal_vulkan_vma_buffer_vtable; + +static iree_hal_vulkan_vma_buffer_t* iree_hal_vulkan_vma_buffer_cast( + iree_hal_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_vma_buffer_vtable); + return (iree_hal_vulkan_vma_buffer_t*)base_value; } -Status VmaBuffer::FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) { - IREE_ASSIGN_OR_RETURN( - auto mapping, MapMemory(MemoryAccess::kDiscardWrite, byte_offset, - byte_length)); - void* data_ptr = static_cast<void*>(mapping.mutable_data()); - switch (pattern_length) { - case 1: { - uint8_t* data = static_cast<uint8_t*>(data_ptr); - uint8_t value_bits = *static_cast<const uint8_t*>(pattern); - std::fill_n(data, byte_length, value_bits); - break; - } - case 2: { - uint16_t* data = static_cast<uint16_t*>(data_ptr); - uint16_t value_bits = *static_cast<const uint16_t*>(pattern); - std::fill_n(data, byte_length / sizeof(uint16_t), value_bits); - break; - } - case 4: { - uint32_t* data = static_cast<uint32_t*>(data_ptr); - uint32_t value_bits = *static_cast<const uint32_t*>(pattern); - std::fill_n(data, byte_length / sizeof(uint32_t), value_bits); - break; - } - default: - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported scalar data size: " << pattern_length; +iree_status_t iree_hal_vulkan_vma_buffer_wrap( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size, + iree_device_size_t byte_offset, iree_device_size_t byte_length, + VmaAllocator vma, VkBuffer handle, VmaAllocation
allocation, + VmaAllocationInfo allocation_info, iree_hal_buffer_t** out_buffer) { + IREE_ASSERT_ARGUMENT(allocator); + IREE_ASSERT_ARGUMENT(vma); + IREE_ASSERT_ARGUMENT(handle); + IREE_ASSERT_ARGUMENT(allocation); + IREE_ASSERT_ARGUMENT(out_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_vma_buffer_t* buffer = NULL; + iree_status_t status = + iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator), + sizeof(*buffer), (void**)&buffer); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_vma_buffer_vtable, + &buffer->base.resource); + buffer->base.allocator = allocator; + buffer->base.allocated_buffer = &buffer->base; + buffer->base.allocation_size = allocation_size; + buffer->base.byte_offset = byte_offset; + buffer->base.byte_length = byte_length; + buffer->base.memory_type = memory_type; + buffer->base.allowed_access = allowed_access; + buffer->base.allowed_usage = allowed_usage; + buffer->vma = vma; + buffer->handle = handle; + buffer->allocation = allocation; + buffer->allocation_info = allocation_info; + + // TODO(benvanik): set debug name instead and use the + // VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT flag. + vmaSetAllocationUserData(buffer->vma, buffer->allocation, buffer); + + *out_buffer = &buffer->base; + } else { + vmaDestroyBuffer(vma, handle, allocation); } - return OkStatus(); -} -Status VmaBuffer::ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) { - IREE_ASSIGN_OR_RETURN( - auto mapping, - MapMemory(MemoryAccess::kRead, source_offset, data_length)); - std::memcpy(data, mapping.data(), mapping.byte_length()); - return OkStatus(); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); } -Status VmaBuffer::WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) { - IREE_ASSIGN_OR_RETURN(auto mapping, - MapMemory(MemoryAccess::kDiscardWrite, - target_offset, data_length)); - std::memcpy(mapping.mutable_data(), data, mapping.byte_length()); - return OkStatus(); +static void iree_hal_vulkan_vma_buffer_destroy(iree_hal_buffer_t* base_buffer) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + iree_allocator_t host_allocator = + iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer)); + IREE_TRACE_ZONE_BEGIN(z0); + + vmaDestroyBuffer(buffer->vma, buffer->handle, buffer->allocation); + iree_allocator_free(host_allocator, buffer); + + IREE_TRACE_ZONE_END(z0); } -Status VmaBuffer::CopyDataImpl(device_size_t target_offset, - Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) { - // This is pretty terrible. Let's not do this. - // TODO(benvanik): a way for allocators to indicate transfer compat. 
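The removed *Impl methods collapse into the vtable entry points below; reading data back now goes through the generic mapping API instead. A hedged sketch (|buffer| and |host_data| are hypothetical; the buffer is assumed host-visible and mappable, and non-host-coherent memory would additionally need an invalidate before reading):

iree_hal_buffer_mapping_t mapping;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
    buffer, IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0,
    IREE_WHOLE_BUFFER, &mapping));
memcpy(host_data, mapping.contents.data, mapping.contents.data_length);
iree_hal_buffer_unmap_range(&mapping);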
- IREE_ASSIGN_OR_RETURN(auto source_mapping, - source_buffer->MapMemory( - MemoryAccess::kRead, source_offset, data_length)); - IREE_CHECK_EQ(data_length, source_mapping.size()); - IREE_ASSIGN_OR_RETURN(auto target_mapping, - MapMemory(MemoryAccess::kDiscardWrite, - target_offset, data_length)); - IREE_CHECK_EQ(data_length, target_mapping.size()); - std::memcpy(target_mapping.mutable_data(), source_mapping.data(), - data_length); - return OkStatus(); +VkBuffer iree_hal_vulkan_vma_buffer_handle(iree_hal_buffer_t* base_buffer) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + return buffer->handle; } -Status VmaBuffer::MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) { +static iree_status_t iree_hal_vulkan_vma_buffer_map_range( + iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode, + iree_hal_memory_access_t memory_access, + iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, + void** out_data_ptr) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + uint8_t* data_ptr = nullptr; VK_RETURN_IF_ERROR( - vmaMapMemory(vma_, allocation_, reinterpret_cast<void**>(&data_ptr))); - *out_data = data_ptr + local_byte_offset; + vmaMapMemory(buffer->vma, buffer->allocation, (void**)&data_ptr), + "vmaMapMemory"); + *out_data_ptr = data_ptr + local_byte_offset; // If we mapped for discard, scribble over the bytes. This is not a mandated // behavior but it will make debugging issues easier. Alternatively for // heap buffers we could reallocate them such that ASAN yells, but that // would only work if the entire buffer was discarded. #ifndef NDEBUG - if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) { - std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); + if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) { + memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); } #endif // !NDEBUG - return OkStatus(); + return iree_ok_status(); } -Status VmaBuffer::UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) { - vmaUnmapMemory(vma_, allocation_); - return OkStatus(); +static void iree_hal_vulkan_vma_buffer_unmap_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length, void* data_ptr) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + vmaUnmapMemory(buffer->vma, buffer->allocation); } -Status VmaBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { - vmaInvalidateAllocation(vma_, allocation_, local_byte_offset, - local_byte_length); - return OkStatus(); +static iree_status_t iree_hal_vulkan_vma_buffer_invalidate_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + VK_RETURN_IF_ERROR( + vmaInvalidateAllocation(buffer->vma, buffer->allocation, + local_byte_offset, local_byte_length), + "vmaInvalidateAllocation"); + return iree_ok_status(); } -Status VmaBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { - vmaFlushAllocation(vma_, allocation_, local_byte_offset, local_byte_length); - return OkStatus(); +static iree_status_t
iree_hal_vulkan_vma_buffer_flush_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + VK_RETURN_IF_ERROR(vmaFlushAllocation(buffer->vma, buffer->allocation, + local_byte_offset, local_byte_length), + "vmaFlushAllocation"); + return iree_ok_status(); } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_buffer_vtable_t iree_hal_vulkan_vma_buffer_vtable = { + /*.destroy=*/iree_hal_vulkan_vma_buffer_destroy, + /*.map_range=*/iree_hal_vulkan_vma_buffer_map_range, + /*.unmap_range=*/iree_hal_vulkan_vma_buffer_unmap_range, + /*.invalidate_range=*/iree_hal_vulkan_vma_buffer_invalidate_range, + /*.flush_range=*/iree_hal_vulkan_vma_buffer_flush_range, +}; diff --git a/iree/hal/vulkan/vma_buffer.h b/iree/hal/vulkan/vma_buffer.h index 23008c1c31c5b..0771e65239b47 100644 --- a/iree/hal/vulkan/vma_buffer.h +++ b/iree/hal/vulkan/vma_buffer.h @@ -15,67 +15,30 @@ #ifndef IREE_HAL_VULKAN_VMA_BUFFER_H_ #define IREE_HAL_VULKAN_VMA_BUFFER_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/hal/buffer.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/internal_vk_mem_alloc.h" -namespace iree { -namespace hal { -namespace vulkan { - -class VmaAllocator; - -// A buffer implementation representing an allocation made from within a pool of -// a Vulkan Memory Allocator instance. See VmaAllocator for more information. -class VmaBuffer final : public Buffer { - public: - VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size, device_size_t byte_offset, - device_size_t byte_length, VkBuffer buffer, - VmaAllocation allocation, VmaAllocationInfo allocation_info); - ~VmaBuffer() override; - - VkBuffer handle() const { return buffer_; } - VmaAllocation allocation() const { return allocation_; } - const VmaAllocationInfo& allocation_info() const { return allocation_info_; } - - // Exposed so that VmaAllocator can reset access after initial mapping. 
- using Buffer::set_allowed_access; - - private: - Status FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) override; - Status ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) override; - Status WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) override; - Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) override; - Status MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) override; - Status UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) override; - Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; - Status FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; - - ::VmaAllocator vma_; - VkBuffer buffer_; - VmaAllocation allocation_; - VmaAllocationInfo allocation_info_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Wraps a VMA allocation in an iree_hal_buffer_t. +// The allocation will be released back to VMA when the buffer is released. +iree_status_t iree_hal_vulkan_vma_buffer_wrap( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size, + iree_device_size_t byte_offset, iree_device_size_t byte_length, + VmaAllocator vma, VkBuffer handle, VmaAllocation allocation, + VmaAllocationInfo allocation_info, iree_hal_buffer_t** out_buffer); + +// Returns the Vulkan handle backing the given |buffer|. +// This is the entire allocated_buffer and must be offset by the buffer +// byte_offset and byte_length when used. 
+VkBuffer iree_hal_vulkan_vma_buffer_handle(iree_hal_buffer_t* buffer); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_VMA_BUFFER_H_ diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc index 582a8c3e61cbe..e127b6a5dee62 100644 --- a/iree/hal/vulkan/vulkan_device.cc +++ b/iree/hal/vulkan/vulkan_device.cc @@ -20,94 +20,205 @@ #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" #include "iree/base/math.h" +#include "iree/base/memory.h" #include "iree/base/status.h" -#include "iree/base/time.h" #include "iree/base/tracing.h" -#include "iree/hal/command_buffer_validation.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/semaphore.h" +#include "iree/hal/vulkan/api.h" +#include "iree/hal/vulkan/command_queue.h" +#include "iree/hal/vulkan/descriptor_pool_cache.h" #include "iree/hal/vulkan/direct_command_buffer.h" #include "iree/hal/vulkan/direct_command_queue.h" #include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/emulated_timeline_semaphore.h" +#include "iree/hal/vulkan/emulated_semaphore.h" #include "iree/hal/vulkan/extensibility_util.h" +#include "iree/hal/vulkan/handle_util.h" #include "iree/hal/vulkan/native_descriptor_set.h" +#include "iree/hal/vulkan/native_descriptor_set_layout.h" #include "iree/hal/vulkan/native_event.h" -#include "iree/hal/vulkan/native_timeline_semaphore.h" -#include "iree/hal/vulkan/pipeline_cache.h" -#include "iree/hal/vulkan/pipeline_executable_layout.h" +#include "iree/hal/vulkan/native_executable_layout.h" +#include "iree/hal/vulkan/native_semaphore.h" +#include "iree/hal/vulkan/nop_executable_cache.h" #include "iree/hal/vulkan/serializing_command_queue.h" #include "iree/hal/vulkan/status_util.h" #include "iree/hal/vulkan/vma_allocator.h" -namespace iree { -namespace hal { -namespace vulkan { +using namespace iree::hal::vulkan; + +//===----------------------------------------------------------------------===// +// iree_hal_vulkan_device_t extensibility util +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_vulkan_query_extensibility_set( + iree_hal_vulkan_features_t requested_features, + iree_hal_vulkan_extensibility_set_t set, iree_host_size_t string_capacity, + const char** out_string_values, iree_host_size_t* out_string_count) { + *out_string_count = 0; + + iree_status_t status = iree_ok_status(); + iree_host_size_t string_count = 0; +#define ADD_EXT(target_set, name_literal) \ + if (iree_status_is_ok(status) && set == (target_set)) { \ + if (string_count >= string_capacity && out_string_values) { \ + status = iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); \ + } else if (out_string_values) { \ + out_string_values[string_count] = (name_literal); \ + } \ + ++string_count; \ + } + + //===--------------------------------------------------------------------===// + // Baseline IREE requirements + //===--------------------------------------------------------------------===// + // Using IREE at all requires these extensions unconditionally. Adding things + // here changes our minimum requirements and should be done carefully. + // Optional extensions here are feature detected by the runtime. + + // VK_KHR_storage_buffer_storage_class: + // Our generated SPIR-V kernels use storage buffers for all their data access. 
+ ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED, + VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME); + + // VK_KHR_get_physical_device_properties2: + // Multiple extensions depend on VK_KHR_get_physical_device_properties2. + // This extension was deprecated in Vulkan 1.1 as its functionality was + // promoted to core so we list it as optional even though we require it. + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL, + VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); + + // VK_KHR_push_descriptor: + // We can avoid a lot of additional Vulkan descriptor set manipulation + // overhead when this extension is present. Android is a holdout, though, and + // we have a fallback for when it's not available. + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, + VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME); + + //===--------------------------------------------------------------------===// + // Vulkan forward-compatibility shims + //===--------------------------------------------------------------------===// + // These are shims or extensions that are made core later in the spec and can + // be removed once we require the core version that contains them. + + // VK_KHR_timeline_semaphore: + // timeline semaphore support is optional and will be emulated if necessary. + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, + VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME); + + // VK_LAYER_KHRONOS_timeline_semaphore: + // polyfill layer - enable if present instead of our custom emulation. Ignored + // if timeline semaphores are supported natively (Vulkan 1.2+). + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, + "VK_LAYER_KHRONOS_timeline_semaphore"); + + //===--------------------------------------------------------------------===// + // Optional debugging features + //===--------------------------------------------------------------------===// + // Used only when explicitly requested as they drastically change the + // performance behavior of Vulkan. + + // VK_LAYER_KHRONOS_validation: + // only enabled if validation is desired. Since validation in Vulkan is just an + // API correctness check, it can't be used as a security mechanism and is fine + // to ignore. + if (iree_all_bits_set(requested_features, + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS)) { + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, + "VK_LAYER_KHRONOS_validation"); + } + + // VK_EXT_debug_utils: + // only enabled if debugging is desired to route Vulkan debug messages through + // our logging sinks. Note that this adds a non-trivial runtime overhead and + // we may want to disable it even in debug builds.
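Because ADD_EXT only fails with OUT_OF_RANGE when an output buffer is provided, the expected calling convention is the usual size-then-fill two-call pattern; a hedged sketch (first-call status check elided; iree_alloca is used as elsewhere in this patch):

iree_host_size_t count = 0;
// First call with no output buffer: only counts the matching strings.
iree_hal_vulkan_query_extensibility_set(
    requested_features,
    IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED,
    /*string_capacity=*/0, /*out_string_values=*/NULL, &count);
const char** names = (const char**)iree_alloca(count * sizeof(const char*));
// Second call fills |names| (or fails with OUT_OF_RANGE if capacity is low).
IREE_RETURN_IF_ERROR(iree_hal_vulkan_query_extensibility_set(
    requested_features,
    IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED,
    count, names, &count));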
+ if (iree_all_bits_set(requested_features, + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS)) { + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL, + VK_EXT_DEBUG_UTILS_EXTENSION_NAME); + } + + *out_string_count = string_count; + return status; } -namespace { +//===----------------------------------------------------------------------===// +// Queue selection +//===----------------------------------------------------------------------===// -constexpr uint32_t kInvalidQueueFamilyIndex = -1; +#define IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX (-1) -struct QueueFamilyInfo { - uint32_t dispatch_index = kInvalidQueueFamilyIndex; - uint32_t dispatch_queue_count = 0; - uint32_t transfer_index = kInvalidQueueFamilyIndex; - uint32_t transfer_queue_count = 0; -}; +typedef struct { + uint32_t dispatch_index; + iree_host_size_t dispatch_queue_count; + uint32_t transfer_index; + iree_host_size_t transfer_queue_count; +} iree_hal_vulkan_queue_family_info_t; // Finds the first queue in the listing (which is usually the // driver-preferred) that has all of the |required_queue_flags| and none of -// the |excluded_queue_flags|. Returns kInvalidQueueFamilyIndex if no matching -// queue is found. -uint32_t FindFirstQueueFamilyWithFlags( - absl::Span<const VkQueueFamilyProperties> queue_family_properties, - uint32_t required_queue_flags, uint32_t excluded_queue_flags) { - for (int queue_family_index = 0; - queue_family_index < queue_family_properties.size(); +// the |excluded_queue_flags|. +// Returns IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX if no matching queue is +// found. +static uint32_t iree_hal_vulkan_find_first_queue_family_with_flags( + uint32_t queue_family_count, + const VkQueueFamilyProperties* queue_family_properties, + VkQueueFlags required_queue_flags, VkQueueFlags excluded_queue_flags) { + for (uint32_t queue_family_index = 0; queue_family_index < queue_family_count; ++queue_family_index) { - const auto& properties = queue_family_properties[queue_family_index]; - if ((properties.queueFlags & required_queue_flags) == - required_queue_flags && - (properties.queueFlags & excluded_queue_flags) == 0) { + const VkQueueFamilyProperties* properties = + &queue_family_properties[queue_family_index]; + if (iree_all_bits_set(properties->queueFlags, required_queue_flags) && + !iree_any_bit_set(properties->queueFlags, excluded_queue_flags)) { return queue_family_index; } } - return kInvalidQueueFamilyIndex; + return IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX; } // Selects queue family indices for compute and transfer queues. // Note that both queue families may be the same if there is only one family // available. -StatusOr<QueueFamilyInfo> SelectQueueFamilies( - VkPhysicalDevice physical_device, const ref_ptr<DynamicSymbols>& syms) { +static iree_status_t iree_hal_vulkan_select_queue_families( + VkPhysicalDevice physical_device, iree::hal::vulkan::DynamicSymbols* syms, + iree_hal_vulkan_queue_family_info_t* out_family_info) { // Enumerate queue families available on the device.
uint32_t queue_family_count = 0; syms->vkGetPhysicalDeviceQueueFamilyProperties(physical_device, - &queue_family_count, nullptr); - absl::InlinedVector<VkQueueFamilyProperties, 4> queue_family_properties( - queue_family_count); + &queue_family_count, NULL); + VkQueueFamilyProperties* queue_family_properties = + (VkQueueFamilyProperties*)iree_alloca(queue_family_count * + sizeof(VkQueueFamilyProperties)); syms->vkGetPhysicalDeviceQueueFamilyProperties( - physical_device, &queue_family_count, queue_family_properties.data()); + physical_device, &queue_family_count, queue_family_properties); - QueueFamilyInfo queue_family_info; + memset(out_family_info, 0, sizeof(*out_family_info)); + out_family_info->dispatch_index = IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX; + out_family_info->dispatch_queue_count = 0; + out_family_info->transfer_index = IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX; + out_family_info->transfer_queue_count = 0; // Try to find a dedicated compute queue (no graphics caps). // Some may support both transfer and compute. If that fails then fallback // to any queue that supports compute. - queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_COMPUTE_BIT, VK_QUEUE_GRAPHICS_BIT); - if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) { - queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_COMPUTE_BIT, 0); - } - if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) { - return NotFoundErrorBuilder(IREE_LOC) - << "Unable to find any queue family supporting compute operations"; - } - queue_family_info.dispatch_queue_count = - queue_family_properties[queue_family_info.dispatch_index].queueCount; + out_family_info->dispatch_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_COMPUTE_BIT, + VK_QUEUE_GRAPHICS_BIT); + if (out_family_info->dispatch_index == + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + out_family_info->dispatch_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_COMPUTE_BIT, + 0); + } + if (out_family_info->dispatch_index == + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + return iree_make_status( + IREE_STATUS_NOT_FOUND, + "unable to find any queue family supporting compute operations"); + } + out_family_info->dispatch_queue_count = + queue_family_properties[out_family_info->dispatch_index].queueCount; // Try to find a dedicated transfer queue (no compute or graphics caps). // Not all devices have one, and some have only a queue family for // @@ -115,147 +226,430 @@ StatusOr<QueueFamilyInfo> SelectQueueFamilies( // fails then fallback to any queue that supports transfer. Finally, if // /that/ fails then we just won't create a transfer queue and instead use // the compute queue for all operations.
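As a worked example on a hypothetical device exposing families {0: GRAPHICS|COMPUTE|TRANSFER, 1: COMPUTE|TRANSFER, 2: TRANSFER}: the dispatch pass selects family 1 (compute without graphics) and the transfer pass selects family 2 (transfer without compute or graphics). On a device with only family 0, both selections resolve to family 0 and the sharing logic below clamps the transfer queue count so the shared family is not oversubscribed.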
- queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_TRANSFER_BIT, - VK_QUEUE_COMPUTE_BIT | VK_QUEUE_GRAPHICS_BIT); - if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) { - queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_TRANSFER_BIT, VK_QUEUE_GRAPHICS_BIT); - } - if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) { - queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_TRANSFER_BIT, 0); - } - if (queue_family_info.transfer_index != kInvalidQueueFamilyIndex) { - queue_family_info.transfer_queue_count = - queue_family_properties[queue_family_info.transfer_index].queueCount; + out_family_info->transfer_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_TRANSFER_BIT, + VK_QUEUE_COMPUTE_BIT | VK_QUEUE_GRAPHICS_BIT); + if (out_family_info->transfer_index == + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + out_family_info->transfer_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_TRANSFER_BIT, + VK_QUEUE_GRAPHICS_BIT); + } + if (out_family_info->transfer_index == + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + out_family_info->transfer_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_TRANSFER_BIT, + 0); + } + if (out_family_info->transfer_index != + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + out_family_info->transfer_queue_count = + queue_family_properties[out_family_info->transfer_index].queueCount; } // Ensure that we don't share the dispatch queues with transfer queues if // that would put us over the queue count. + if (out_family_info->dispatch_index == out_family_info->transfer_index) { + out_family_info->transfer_queue_count = iree_min( + queue_family_properties[out_family_info->dispatch_index].queueCount - + out_family_info->dispatch_queue_count, + out_family_info->transfer_queue_count); + } + + // Limit the number of queues we create (for now). + // We may want to allow this to grow, but each queue adds overhead and we + // need to measure to make sure we can effectively use them all. + out_family_info->dispatch_queue_count = + iree_min(2u, out_family_info->dispatch_queue_count); + out_family_info->transfer_queue_count = + iree_min(1u, out_family_info->transfer_queue_count); + + return iree_ok_status(); +} + +// Builds a set of compute and transfer queues based on the queues available on +// the device and some magic heuristic goo. +static iree_status_t iree_hal_vulkan_build_queue_sets( + VkPhysicalDevice physical_device, iree::hal::vulkan::DynamicSymbols* syms, + iree_hal_vulkan_queue_set_t* out_compute_queue_set, + iree_hal_vulkan_queue_set_t* out_transfer_queue_set) { + // Select which queues to use (and fail if the implementation can't handle them). + iree_hal_vulkan_queue_family_info_t queue_family_info; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_select_queue_families( + physical_device, syms, &queue_family_info)); + + // Build queue indices for the selected queue families.
+ memset(out_compute_queue_set, 0, sizeof(*out_compute_queue_set)); + out_compute_queue_set->queue_family_index = queue_family_info.dispatch_index; + for (iree_host_size_t i = 0; i < queue_family_info.dispatch_queue_count; + ++i) { + out_compute_queue_set->queue_indices |= 1ull << i; + } + + memset(out_transfer_queue_set, 0, sizeof(*out_transfer_queue_set)); + out_transfer_queue_set->queue_family_index = queue_family_info.transfer_index; + uint32_t base_queue_index = 0; if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { - queue_family_info.transfer_queue_count = std::min( - queue_family_properties[queue_family_info.dispatch_index].queueCount - - queue_family_info.dispatch_queue_count, - queue_family_info.transfer_queue_count); + // Sharing a family, so transfer queues follow compute queues. + base_queue_index = queue_family_info.dispatch_index; } + for (iree_host_size_t i = 0; i < queue_family_info.transfer_queue_count; + ++i) { + out_transfer_queue_set->queue_indices |= 1ull << (i + base_queue_index); + } + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_vulkan_device_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_resource_t resource; + iree_string_view_t identifier; + + // Optional driver that owns the instance. We retain it for our lifetime to + // ensure the instance remains valid. + iree_hal_driver_t* driver; + + // Flags overriding default device behavior. + iree_hal_vulkan_device_flags_t flags; + // Which optional extensions are active and available on the device. + iree_hal_vulkan_device_extensions_t device_extensions; + + VkInstance instance; + VkPhysicalDevice physical_device; + VkDeviceHandle* logical_device; + + iree_allocator_t host_allocator; + iree_hal_allocator_t* device_allocator; + + // All queues available on the device; the device owns these. + iree_host_size_t queue_count; + CommandQueue** queues; + // The subset of queues that support dispatch operations. May overlap with + // transfer_queues. + iree_host_size_t dispatch_queue_count; + CommandQueue** dispatch_queues; + // The subset of queues that support transfer operations. May overlap with + // dispatch_queues. + iree_host_size_t transfer_queue_count; + CommandQueue** transfer_queues; + + DescriptorPoolCache* descriptor_pool_cache; + + VkCommandPoolHandle* dispatch_command_pool; + VkCommandPoolHandle* transfer_command_pool; + + // Used only for emulated timeline semaphores. + TimePointSemaphorePool* semaphore_pool; + TimePointFencePool* fence_pool; +} iree_hal_vulkan_device_t; + +extern const iree_hal_device_vtable_t iree_hal_vulkan_device_vtable; + +static iree_hal_vulkan_device_t* iree_hal_vulkan_device_cast( + iree_hal_device_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_device_vtable); + return (iree_hal_vulkan_device_t*)base_value; +} - return queue_family_info; +IREE_API_EXPORT void IREE_API_CALL iree_hal_vulkan_device_options_initialize( + iree_hal_vulkan_device_options_t* out_options) { + memset(out_options, 0, sizeof(*out_options)); + out_options->flags = 0; } // Creates a transient command pool for the given queue family. // Command buffers allocated from the pool must only be issued on queues // belonging to the specified family. 
-StatusOr> CreateTransientCommandPool( - const ref_ptr& logical_device, - uint32_t queue_family_index) { +static iree_status_t iree_hal_vulkan_create_transient_command_pool( + VkDeviceHandle* logical_device, uint32_t queue_family_index, + VkCommandPoolHandle** out_handle) { VkCommandPoolCreateInfo create_info; create_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; - create_info.pNext = nullptr; + create_info.pNext = NULL; create_info.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT; create_info.queueFamilyIndex = queue_family_index; - - auto command_pool = make_ref(logical_device); - VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateCommandPool( - *logical_device, &create_info, logical_device->allocator(), - command_pool->mutable_value())); - return command_pool; + VkCommandPoolHandle* command_pool = new VkCommandPoolHandle(logical_device); + iree_status_t status = VK_RESULT_TO_STATUS( + logical_device->syms()->vkCreateCommandPool( + *logical_device, &create_info, logical_device->allocator(), + command_pool->mutable_value()), + "vkCreateCommandPool"); + if (iree_status_is_ok(status)) { + *out_handle = command_pool; + } else { + delete command_pool; + } + return status; } -// Creates command queues for the given sets of queues. -absl::InlinedVector, 4> CreateCommandQueues( - const DeviceInfo& device_info, - const ref_ptr& logical_device, - const QueueSet& compute_queue_set, const QueueSet& transfer_queue_set, - const ref_ptr& fence_pool, - const ref_ptr& syms) { - absl::InlinedVector, 4> command_queues; +// Creates a command queue of the given queue family. +static CommandQueue* iree_hal_vulkan_device_create_queue( + VkDeviceHandle* logical_device, + iree_hal_command_category_t command_category, uint32_t queue_family_index, + uint32_t queue_index, TimePointFencePool* fence_pool) { + VkQueue queue = VK_NULL_HANDLE; + logical_device->syms()->vkGetDeviceQueue(*logical_device, queue_family_index, + queue_index, &queue); + std::string queue_name; + if (!iree_all_bits_set(command_category, + IREE_HAL_COMMAND_CATEGORY_DISPATCH)) { + queue_name = "q(t):"; + } else { + queue_name = "q(d):"; + } + queue_name += std::to_string(queue_index); - uint64_t compute_queue_count = - iree_math_count_ones_u64(compute_queue_set.queue_indices); - for (uint32_t i = 0; i < compute_queue_count; ++i) { - if (!(compute_queue_set.queue_indices & (1ull << i))) continue; - - VkQueue queue = VK_NULL_HANDLE; - syms->vkGetDeviceQueue(*logical_device, - compute_queue_set.queue_family_index, i, &queue); - std::string queue_name = absl::StrCat(device_info.name(), ":d", i); - - if (fence_pool != nullptr) { - command_queues.push_back(absl::make_unique( - std::move(queue_name), - CommandCategory::kDispatch | CommandCategory::kTransfer, - logical_device, fence_pool, queue)); - } else { - command_queues.push_back(absl::make_unique( - std::move(queue_name), - CommandCategory::kDispatch | CommandCategory::kTransfer, - logical_device, queue)); - } + // When emulating timeline semaphores we use a special queue that allows us to + // sequence the semaphores correctly. + if (fence_pool != NULL) { + return new SerializingCommandQueue(logical_device, std::move(queue_name), + command_category, queue, fence_pool); } + return new DirectCommandQueue(logical_device, std::move(queue_name), + command_category, queue); +} + +// Creates command queues for the given sets of queues and populates the +// device queue lists. 
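// (Dispatch queues are created with IREE_HAL_COMMAND_CATEGORY_ANY and double
// as transfer queues when no dedicated transfer queues exist; dedicated
// transfer queues are created with IREE_HAL_COMMAND_CATEGORY_TRANSFER only.
// The iree_all_bits_set(value, bits) predicate used for the "q(d):"/"q(t):"
// naming above is assumed to be the usual mask test:
//   (value & bits) == bits
// so a queue is labeled "q(d)" only when the DISPATCH bit is present.)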
+static void iree_hal_vulkan_device_initialize_command_queues( + iree_hal_vulkan_device_t* device, iree_string_view_t queue_prefix, + const iree_hal_vulkan_queue_set_t* compute_queue_set, + const iree_hal_vulkan_queue_set_t* transfer_queue_set) { + device->queue_count = 0; + device->dispatch_queue_count = 0; + device->transfer_queue_count = 0; + + uint64_t compute_queue_count = + iree_math_count_ones_u64(compute_queue_set->queue_indices); uint64_t transfer_queue_count = - iree_math_count_ones_u64(transfer_queue_set.queue_indices); - for (uint32_t i = 0; i < transfer_queue_count; ++i) { - if (!(transfer_queue_set.queue_indices & (1ull << i))) continue; - - VkQueue queue = VK_NULL_HANDLE; - syms->vkGetDeviceQueue(*logical_device, - transfer_queue_set.queue_family_index, i, &queue); - std::string queue_name = absl::StrCat(device_info.name(), ":t", i); - if (fence_pool != nullptr) { - command_queues.push_back(absl::make_unique( - std::move(queue_name), CommandCategory::kTransfer, logical_device, - fence_pool, queue)); - } else { - command_queues.push_back(absl::make_unique( - std::move(queue_name), CommandCategory::kTransfer, logical_device, - queue)); + iree_math_count_ones_u64(transfer_queue_set->queue_indices); + for (iree_host_size_t i = 0; i < compute_queue_count; ++i) { + if (!(compute_queue_set->queue_indices & (1ull << i))) continue; + CommandQueue* queue = iree_hal_vulkan_device_create_queue( + device->logical_device, IREE_HAL_COMMAND_CATEGORY_ANY, + compute_queue_set->queue_family_index, i, device->fence_pool); + device->queues[device->queue_count++] = queue; + device->dispatch_queues[device->dispatch_queue_count++] = queue; + if (!transfer_queue_count) { + // If we don't have any dedicated transfer queues then use all dispatch + // queues as transfer queues. 
+ device->transfer_queues[device->transfer_queue_count++] = queue; } } - - return command_queues; + for (iree_host_size_t i = 0; i < transfer_queue_count; ++i) { + if (!(transfer_queue_set->queue_indices & (1ull << i))) continue; + CommandQueue* queue = iree_hal_vulkan_device_create_queue( + device->logical_device, IREE_HAL_COMMAND_CATEGORY_TRANSFER, + transfer_queue_set->queue_family_index, i, device->fence_pool); + device->queues[device->queue_count++] = queue; + device->transfer_queues[device->transfer_queue_count++] = queue; + } } -} // namespace +static iree_status_t iree_hal_vulkan_device_create_internal( + iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_vulkan_device_options_t* options, VkInstance instance, + VkPhysicalDevice physical_device, VkDeviceHandle* logical_device, + const iree_hal_vulkan_device_extensions_t* device_extensions, + const iree_hal_vulkan_queue_set_t* compute_queue_set, + const iree_hal_vulkan_queue_set_t* transfer_queue_set, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + auto& device_syms = logical_device->syms(); + + iree_host_size_t compute_queue_count = + iree_math_count_ones_u64(compute_queue_set->queue_indices); + iree_host_size_t transfer_queue_count = + iree_math_count_ones_u64(transfer_queue_set->queue_indices); + iree_host_size_t total_queue_count = + compute_queue_count + transfer_queue_count; + + iree_hal_vulkan_device_t* device = NULL; + iree_host_size_t total_size = + sizeof(*device) + identifier.size + + total_queue_count * sizeof(device->queues[0]) + + total_queue_count * sizeof(device->dispatch_queues[0]) + + total_queue_count * sizeof(device->transfer_queues[0]); + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(host_allocator, total_size, (void**)&device)); + memset(device, 0, total_size); + iree_hal_resource_initialize(&iree_hal_vulkan_device_vtable, + &device->resource); + device->host_allocator = host_allocator; + device->driver = driver; + iree_hal_driver_retain(device->driver); + uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device); + buffer_ptr += iree_string_view_append_to_buffer( + identifier, &device->identifier, (char*)buffer_ptr); + device->flags = options->flags; + + device->device_extensions = *device_extensions; + device->instance = instance; + device->physical_device = physical_device; + device->logical_device = logical_device; + device->logical_device->AddReference(); + + // Point the queue storage into the new device allocation. The queues + // themselves are populated later by + // iree_hal_vulkan_device_initialize_command_queues. + device->queues = (CommandQueue**)buffer_ptr; + buffer_ptr += total_queue_count * sizeof(device->queues[0]); + device->dispatch_queues = (CommandQueue**)buffer_ptr; + buffer_ptr += total_queue_count * sizeof(device->dispatch_queues[0]); + device->transfer_queues = (CommandQueue**)buffer_ptr; + buffer_ptr += total_queue_count * sizeof(device->transfer_queues[0]); + + device->descriptor_pool_cache = + new DescriptorPoolCache(device->logical_device); + + // Create the device memory allocator that will service all buffer + // allocation requests.
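// Schematic of the single host allocation carved out above: one
// iree_allocator_malloc covers the struct, the identifier string, and the
// three queue pointer arrays, so teardown is a single iree_allocator_free.
//
//   +--------------------------+ <- device
//   | iree_hal_vulkan_device_t |
//   +--------------------------+ <- buffer_ptr = (uint8_t*)device + sizeof(*device)
//   | identifier characters    |
//   +--------------------------+
//   | queues[total]            | (CommandQueue* slots)
//   +--------------------------+
//   | dispatch_queues[total]   |
//   +--------------------------+
//   | transfer_queues[total]   |
//   +--------------------------+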
+ VmaRecordSettings vma_record_settings; + memset(&vma_record_settings, 0, sizeof(vma_record_settings)); + iree_status_t status = iree_hal_vulkan_vma_allocator_create( + instance, physical_device, logical_device, vma_record_settings, + &device->device_allocator); -// static -StatusOr> VulkanDevice::Create( - ref_ptr driver, VkInstance instance, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, Options options, - const ref_ptr& syms, - DebugCaptureManager* debug_capture_manager) { - IREE_TRACE_SCOPE0("VulkanDevice::Create"); + // Create command pools for each queue family. If we don't have a transfer + // queue then we'll ignore that one and just use the dispatch pool. + // If we wanted to expose the pools through the HAL to allow the VM to more + // effectively manage them (pool per fiber, etc) we could, however I doubt + // the overhead of locking the pool will be even a blip. + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_create_transient_command_pool( + device->logical_device, compute_queue_set->queue_family_index, + &device->dispatch_command_pool); + } + if (transfer_queue_set->queue_indices != 0 && iree_status_is_ok(status)) { + status = iree_hal_vulkan_create_transient_command_pool( + device->logical_device, transfer_queue_set->queue_family_index, + &device->transfer_command_pool); + } + + // Emulate timeline semaphores when the extension is not available and we are + // on Vulkan versions prior to 1.2, where they were promoted to core. + bool emulate_timeline_semaphores = + device_syms->vkGetSemaphoreCounterValue == NULL || + iree_all_bits_set( + options->flags, + IREE_HAL_VULKAN_DEVICE_FORCE_TIMELINE_SEMAPHORE_EMULATION); + if (emulate_timeline_semaphores && iree_status_is_ok(status)) { + status = TimePointSemaphorePool::Create(device->logical_device, + &device->semaphore_pool); + } + if (emulate_timeline_semaphores && iree_status_is_ok(status)) { + status = + TimePointFencePool::Create(device->logical_device, &device->fence_pool); + } + + // Initialize queues now that we've completed the rest of the device + // initialization; this happens last as the queues require the pools allocated + // above. + if (iree_status_is_ok(status)) { + iree_hal_vulkan_device_initialize_command_queues( + device, identifier, compute_queue_set, transfer_queue_set); + } + + if (iree_status_is_ok(status)) { + *out_device = (iree_hal_device_t*)device; + } else { + iree_hal_device_destroy((iree_hal_device_t*)device); + } + return status; +} + +static void iree_hal_vulkan_device_destroy(iree_hal_device_t* base_device) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); + IREE_TRACE_ZONE_BEGIN(z0); - if (!options.extensibility_spec.optional_layers.empty() || - !options.extensibility_spec.required_layers.empty()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Device layers are deprecated and unsupported by IREE"; + // Drop all command queues. These may wait until idle in their destructor. + for (iree_host_size_t i = 0; i < device->queue_count; ++i) { + delete device->queues[i]; } + // Drop command pools now that we know there are no more outstanding command + // buffers. + delete device->dispatch_command_pool; + delete device->transfer_command_pool; + + // Now that no commands are outstanding we can release all resources that may + // have been in use.
+ delete device->descriptor_pool_cache; + delete device->semaphore_pool; + delete device->fence_pool; + + // There should be no more buffers live that use the allocator. + iree_hal_allocator_release(device->device_allocator); + + // Finally, destroy the device. + device->logical_device->ReleaseReference(); + iree_hal_driver_release(device->driver); + + iree_allocator_free(host_allocator, device); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_vulkan_device_query_extensibility_set( + iree_hal_vulkan_features_t requested_features, + iree_hal_vulkan_extensibility_set_t set, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_string_list) { + IREE_RETURN_IF_ERROR(iree_hal_vulkan_query_extensibility_set( + requested_features, set, 0, NULL, &out_string_list->count)); + out_string_list->values = (const char**)arena->AllocateBytes( + out_string_list->count * sizeof(out_string_list->values[0])); + IREE_RETURN_IF_ERROR(iree_hal_vulkan_query_extensibility_set( + requested_features, set, out_string_list->count, out_string_list->values, + &out_string_list->count)); + return iree_ok_status(); +} + +iree_status_t iree_hal_vulkan_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + iree_hal_vulkan_features_t enabled_features, + const iree_hal_vulkan_device_options_t* options, + iree_hal_vulkan_syms_t* opaque_syms, VkInstance instance, + VkPhysicalDevice physical_device, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + DynamicSymbols* instance_syms = (DynamicSymbols*)opaque_syms; + // Find the extensions we need (or want) that are also available // on the device. This will fail when required ones are not present. - IREE_ASSIGN_OR_RETURN( - auto enabled_extension_names, - MatchAvailableDeviceExtensions(physical_device, - options.extensibility_spec, *syms)); - auto enabled_device_extensions = - PopulateEnabledDeviceExtensions(enabled_extension_names); + // TODO(benvanik): replace with a real arena. + iree::Arena arena(128 * 1024); + iree_hal_vulkan_string_list_t required_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_device_query_extensibility_set( + enabled_features, + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED, &arena, + &required_extensions)); + iree_hal_vulkan_string_list_t optional_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_device_query_extensibility_set( + enabled_features, + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, &arena, + &optional_extensions)); + iree_hal_vulkan_string_list_t enabled_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_match_available_device_extensions( + instance_syms, physical_device, &required_extensions, + &optional_extensions, &arena, &enabled_extensions)); + iree_hal_vulkan_device_extensions_t enabled_device_extensions = + iree_hal_vulkan_populate_enabled_device_extensions(&enabled_extensions); // Find queue families we will expose as HAL queues. - IREE_ASSIGN_OR_RETURN(auto queue_family_info, - SelectQueueFamilies(physical_device, syms)); + iree_hal_vulkan_queue_family_info_t queue_family_info; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_select_queue_families( + physical_device, instance_syms, &queue_family_info)); - // Limit the number of queues we create (for now). - // We may want to allow this to grow, but each queue adds overhead and we - // need to measure to make sure we can effectively use them all. 
- queue_family_info.dispatch_queue_count = - std::min(2u, queue_family_info.dispatch_queue_count); - queue_family_info.transfer_queue_count = - std::min(1u, queue_family_info.transfer_queue_count); bool has_dedicated_transfer_queues = queue_family_info.transfer_queue_count > 0; + // TODO(benvanik): convert to using the arena. // Setup the queue info we'll be using. // Each queue here (created from within a family) will map to a HAL queue. // @@ -263,34 +657,24 @@ StatusOr> VulkanDevice::Create( // are of the same queue family as the dispatch queues: Vulkan requires that // all queues created from the same family are done in the same // VkDeviceQueueCreateInfo struct. - IREE_DVLOG(1) << "Creating " << queue_family_info.dispatch_queue_count - << " dispatch queue(s) in queue family " - << queue_family_info.dispatch_index; absl::InlinedVector queue_create_info; absl::InlinedVector dispatch_queue_priorities; absl::InlinedVector transfer_queue_priorities; queue_create_info.push_back({}); auto& dispatch_queue_info = queue_create_info.back(); dispatch_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - dispatch_queue_info.pNext = nullptr; + dispatch_queue_info.pNext = NULL; dispatch_queue_info.flags = 0; dispatch_queue_info.queueFamilyIndex = queue_family_info.dispatch_index; dispatch_queue_info.queueCount = queue_family_info.dispatch_queue_count; if (has_dedicated_transfer_queues) { if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { - IREE_DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count - << " dedicated transfer queue(s) in shared queue family " - << queue_family_info.transfer_index; dispatch_queue_info.queueCount += queue_family_info.transfer_queue_count; } else { - IREE_DVLOG(1) - << "Creating " << queue_family_info.transfer_queue_count - << " dedicated transfer queue(s) in independent queue family " - << queue_family_info.transfer_index; queue_create_info.push_back({}); auto& transfer_queue_info = queue_create_info.back(); transfer_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - transfer_queue_info.pNext = nullptr; + transfer_queue_info.pNext = NULL; transfer_queue_info.queueFamilyIndex = queue_family_info.transfer_index; transfer_queue_info.queueCount = queue_family_info.transfer_queue_count; transfer_queue_info.flags = 0; @@ -302,548 +686,316 @@ StatusOr> VulkanDevice::Create( dispatch_queue_info.pQueuePriorities = dispatch_queue_priorities.data(); // Create device and its queues. 
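// Worked example of the create info arithmetic above (hypothetical counts):
// dispatch_index == transfer_index == 0 with 2 dispatch + 1 transfer queues
// folds into a single entry, since Vulkan requires all queues of one family
// to be requested through one struct:
//   queue_create_info[0] = { queueFamilyIndex = 0, queueCount = 3 }
// while a dedicated transfer family (transfer_index == 1) yields two entries:
//   queue_create_info[0] = { queueFamilyIndex = 0, queueCount = 2 }
//   queue_create_info[1] = { queueFamilyIndex = 1, queueCount = 1 }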
- VkDeviceCreateInfo device_create_info = {}; + VkDeviceCreateInfo device_create_info; + memset(&device_create_info, 0, sizeof(device_create_info)); device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; device_create_info.enabledLayerCount = 0; - device_create_info.ppEnabledLayerNames = nullptr; - device_create_info.enabledExtensionCount = enabled_extension_names.size(); - device_create_info.ppEnabledExtensionNames = enabled_extension_names.data(); + device_create_info.ppEnabledLayerNames = NULL; + device_create_info.enabledExtensionCount = enabled_extensions.count; + device_create_info.ppEnabledExtensionNames = enabled_extensions.values; device_create_info.queueCreateInfoCount = queue_create_info.size(); device_create_info.pQueueCreateInfos = queue_create_info.data(); - device_create_info.pEnabledFeatures = nullptr; + device_create_info.pEnabledFeatures = NULL; - VkPhysicalDeviceTimelineSemaphoreFeatures semaphore_features; - std::memset(&semaphore_features, 0, sizeof(semaphore_features)); - semaphore_features.sType = - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES; - semaphore_features.timelineSemaphore = VK_TRUE; VkPhysicalDeviceFeatures2 features2; - std::memset(&features2, 0, sizeof(features2)); + memset(&features2, 0, sizeof(features2)); features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; - features2.pNext = &semaphore_features; - - if (!enabled_device_extensions.timeline_semaphore || - options.force_timeline_semaphore_emulation) { - device_create_info.pNext = nullptr; - } else { - device_create_info.pNext = &features2; - } - - auto logical_device = - make_ref(syms, enabled_device_extensions, - /*owns_device=*/true, /*allocator=*/nullptr); - // The Vulkan loader can leak here, depending on which features are enabled. - // This is out of our control, so disable leak checks. - IREE_DISABLE_LEAK_CHECKS(); - VK_RETURN_IF_ERROR(syms->vkCreateDevice(physical_device, &device_create_info, - logical_device->allocator(), - logical_device->mutable_value())); - IREE_RETURN_IF_ERROR(logical_device->syms()->LoadFromDevice( - instance, logical_device->value())); - IREE_ENABLE_LEAK_CHECKS(); - - // Create the device memory allocator. - // TODO(benvanik): allow other types to be plugged in. - IREE_ASSIGN_OR_RETURN( - auto allocator, - VmaAllocator::Create(physical_device, logical_device, instance, - std::move(options.vma_options))); + device_create_info.pNext = &features2; - // Create command pools for each queue family. If we don't have a transfer - // queue then we'll ignore that one and just use the dispatch pool. - // If we wanted to expose the pools through the HAL to allow the VM to more - // effectively manage them (pool per fiber, etc) we could, however I doubt - // the overhead of locking the pool will be even a blip. 
- IREE_ASSIGN_OR_RETURN(auto dispatch_command_pool, - CreateTransientCommandPool( - logical_device, queue_family_info.dispatch_index)); - ref_ptr transfer_command_pool; - if (has_dedicated_transfer_queues) { - IREE_ASSIGN_OR_RETURN( - transfer_command_pool, - CreateTransientCommandPool(logical_device, - queue_family_info.transfer_index)); + VkPhysicalDeviceTimelineSemaphoreFeatures semaphore_features; + bool emulate_timeline_semaphores = + !enabled_device_extensions.timeline_semaphore || + iree_all_bits_set( + options->flags, + IREE_HAL_VULKAN_DEVICE_FORCE_TIMELINE_SEMAPHORE_EMULATION); + if (!emulate_timeline_semaphores) { + memset(&semaphore_features, 0, sizeof(semaphore_features)); + semaphore_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES; + semaphore_features.pNext = features2.pNext; + features2.pNext = &semaphore_features; + semaphore_features.timelineSemaphore = VK_TRUE; + } + + auto logical_device = new VkDeviceHandle( + instance_syms, enabled_device_extensions, + /*owns_device=*/true, host_allocator, /*allocator=*/NULL); + + iree_status_t status = VK_RESULT_TO_STATUS( + instance_syms->vkCreateDevice(physical_device, &device_create_info, + logical_device->allocator(), + logical_device->mutable_value()), + "vkCreateDevice"); + if (iree_status_is_ok(status)) { + status = logical_device->syms()->LoadFromDevice(instance, + logical_device->value()); } // Select queue indices and create command queues with them. - QueueSet compute_queue_set = {}; - compute_queue_set.queue_family_index = queue_family_info.dispatch_index; - for (uint32_t i = 0; i < queue_family_info.dispatch_queue_count; ++i) { - compute_queue_set.queue_indices |= 1ull << i; - } - QueueSet transfer_queue_set = {}; - transfer_queue_set.queue_family_index = queue_family_info.transfer_index; - uint32_t base_queue_index = 0; - if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { - // Sharing a family, so transfer queues follow compute queues. - base_queue_index = queue_family_info.dispatch_index; - } - for (uint32_t i = 0; i < queue_family_info.transfer_queue_count; ++i) { - transfer_queue_set.queue_indices |= 1ull << (i + base_queue_index); + iree_hal_vulkan_queue_set_t compute_queue_set; + iree_hal_vulkan_queue_set_t transfer_queue_set; + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_build_queue_sets( + physical_device, logical_device->syms().get(), &compute_queue_set, + &transfer_queue_set); } - // Emulate timeline semaphores if associated functions are not defined. - ref_ptr semaphore_pool = nullptr; - ref_ptr fence_pool = nullptr; - if (syms->vkGetSemaphoreCounterValue == nullptr || - options.force_timeline_semaphore_emulation) { - IREE_ASSIGN_OR_RETURN(semaphore_pool, TimePointSemaphorePool::Create( - add_ref(logical_device))); - IREE_ASSIGN_OR_RETURN(fence_pool, - TimePointFencePool::Create(add_ref(logical_device))); + // Allocate and initialize the device. 
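// Resulting pNext chain when native timeline semaphores are in use (sketch):
//
//   VkDeviceCreateInfo.pNext
//     -> VkPhysicalDeviceFeatures2 (features2)
//          .pNext -> VkPhysicalDeviceTimelineSemaphoreFeatures
//                      { .timelineSemaphore = VK_TRUE }
//
// When emulation is selected the semaphore_features node is never spliced in
// and the chain simply ends at features2.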
+ if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_device_create_internal( + driver, identifier, options, instance, physical_device, logical_device, + &enabled_device_extensions, &compute_queue_set, &transfer_queue_set, + host_allocator, out_device); } - auto command_queues = - CreateCommandQueues(device_info, logical_device, compute_queue_set, - transfer_queue_set, fence_pool, syms); - - return assign_ref(new VulkanDevice( - std::move(driver), device_info, physical_device, - std::move(logical_device), std::move(allocator), - std::move(command_queues), std::move(dispatch_command_pool), - std::move(transfer_command_pool), std::move(semaphore_pool), - std::move(fence_pool), debug_capture_manager)); + logical_device->ReleaseReference(); + return status; } -// static -StatusOr> VulkanDevice::Wrap( - ref_ptr driver, VkInstance instance, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, VkDevice logical_device, Options options, - const QueueSet& compute_queue_set, const QueueSet& transfer_queue_set, - const ref_ptr& syms) { - IREE_TRACE_SCOPE0("VulkanDevice::Wrap"); +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_wrap_device( + iree_string_view_t identifier, + const iree_hal_vulkan_device_options_t* options, + const iree_hal_vulkan_syms_t* instance_syms, VkInstance instance, + VkPhysicalDevice physical_device, VkDevice logical_device, + const iree_hal_vulkan_queue_set_t* compute_queue_set, + const iree_hal_vulkan_queue_set_t* transfer_queue_set, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(instance_syms); + IREE_ASSERT_ARGUMENT(instance); + IREE_ASSERT_ARGUMENT(physical_device); + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(out_device); + + if (iree_math_count_ones_u64(compute_queue_set->queue_indices) == 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "at least one compute queue is required"); + } + + // Grab symbols from the device. + auto device_syms = iree::make_ref(); + device_syms->vkGetInstanceProcAddr = + ((const DynamicSymbols*)instance_syms)->vkGetInstanceProcAddr; + IREE_RETURN_IF_ERROR(device_syms->LoadFromDevice(instance, logical_device)); - uint64_t compute_queue_count = - iree_math_count_ones_u64(compute_queue_set.queue_indices); - uint64_t transfer_queue_count = - iree_math_count_ones_u64(transfer_queue_set.queue_indices); - - if (compute_queue_count == 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "At least one compute queue is required"; - } - - // Find the extensions we need (or want) that are also available on the - // device. This will fail when required ones are not present. - // // Since the device is already created, we can't actually enable any // extensions or query if they are really enabled - we just have to trust - // that the caller already enabled them for us (or we may fail later). - IREE_ASSIGN_OR_RETURN( - auto enabled_extension_names, - MatchAvailableDeviceExtensions(physical_device, - options.extensibility_spec, *syms)); - auto enabled_device_extensions = - PopulateEnabledDeviceExtensions(enabled_extension_names); + // that the caller already enabled them for us or we may fail later. For the + // optional extensions we check for the symbols but this is not always + // guaranteed to work. + iree_hal_vulkan_device_extensions_t enabled_device_extensions = + iree_hal_vulkan_infer_enabled_device_extensions(device_syms.get()); // Wrap the provided VkDevice with a VkDeviceHandle for use within the HAL. 
- auto device_handle = - make_ref(syms, enabled_device_extensions, - /*owns_device=*/false, /*allocator=*/nullptr); - *device_handle->mutable_value() = logical_device; - - // Create the device memory allocator. - // TODO(benvanik): allow other types to be plugged in. - IREE_ASSIGN_OR_RETURN( - auto allocator, - VmaAllocator::Create(physical_device, device_handle, instance, - std::move(options.vma_options))); - - bool has_dedicated_transfer_queues = transfer_queue_count > 0; - - // Create command pools for each queue family. If we don't have a transfer - // queue then we'll ignore that one and just use the dispatch pool. - // If we wanted to expose the pools through the HAL to allow the VM to more - // effectively manage them (pool per fiber, etc) we could, however I doubt - // the overhead of locking the pool will be even a blip. - IREE_ASSIGN_OR_RETURN( - auto dispatch_command_pool, - CreateTransientCommandPool(device_handle, - compute_queue_set.queue_family_index)); - ref_ptr transfer_command_pool; - if (has_dedicated_transfer_queues) { - IREE_ASSIGN_OR_RETURN( - transfer_command_pool, - CreateTransientCommandPool(device_handle, - transfer_queue_set.queue_family_index)); - } - - // Emulate timeline semaphores if associated functions are not defined. - ref_ptr semaphore_pool = nullptr; - ref_ptr fence_pool = nullptr; - if (syms->vkGetSemaphoreCounterValue == nullptr || - options.force_timeline_semaphore_emulation) { - IREE_ASSIGN_OR_RETURN( - semaphore_pool, TimePointSemaphorePool::Create(add_ref(device_handle))); - IREE_ASSIGN_OR_RETURN(fence_pool, - TimePointFencePool::Create(add_ref(device_handle))); - } - - auto command_queues = - CreateCommandQueues(device_info, device_handle, compute_queue_set, - transfer_queue_set, fence_pool, syms); - - return assign_ref(new VulkanDevice( - std::move(driver), device_info, physical_device, std::move(device_handle), - std::move(allocator), std::move(command_queues), - std::move(dispatch_command_pool), std::move(transfer_command_pool), - std::move(semaphore_pool), std::move(fence_pool), - /*debug_capture_manager=*/nullptr)); -} - -VulkanDevice::VulkanDevice( - ref_ptr driver, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, ref_ptr logical_device, - std::unique_ptr allocator, - absl::InlinedVector, 4> command_queues, - ref_ptr dispatch_command_pool, - ref_ptr transfer_command_pool, - ref_ptr semaphore_pool, - ref_ptr fence_pool, - DebugCaptureManager* debug_capture_manager) - : Device(device_info), - driver_(std::move(driver)), - physical_device_(physical_device), - logical_device_(std::move(logical_device)), - allocator_(std::move(allocator)), - command_queues_(std::move(command_queues)), - descriptor_pool_cache_( - make_ref(add_ref(logical_device_))), - dispatch_command_pool_(std::move(dispatch_command_pool)), - transfer_command_pool_(std::move(transfer_command_pool)), - semaphore_pool_(std::move(semaphore_pool)), - fence_pool_(std::move(fence_pool)), - debug_capture_manager_(debug_capture_manager) { - // Populate the queue lists based on queue capabilities. - for (auto& command_queue : command_queues_) { - if (command_queue->can_dispatch()) { - dispatch_queues_.push_back(command_queue.get()); - if (transfer_command_pool_ == VK_NULL_HANDLE) { - transfer_queues_.push_back(command_queue.get()); - } - } else { - transfer_queues_.push_back(command_queue.get()); - } - } - - if (debug_capture_manager_ && debug_capture_manager_->is_connected()) { - // Record a capture covering the duration of this VkDevice's lifetime. 
- debug_capture_manager_->StartCapture(); - } + auto logical_device_handle = new VkDeviceHandle( + device_syms.get(), enabled_device_extensions, + /*owns_device=*/false, host_allocator, /*allocator=*/NULL); + *logical_device_handle->mutable_value() = logical_device; + + // Allocate and initialize the device. + iree_status_t status = iree_hal_vulkan_device_create_internal( + /*driver=*/NULL, identifier, options, instance, physical_device, + logical_device_handle, &enabled_device_extensions, compute_queue_set, + transfer_queue_set, host_allocator, out_device); + + logical_device_handle->ReleaseReference(); + return status; } -VulkanDevice::~VulkanDevice() { - IREE_TRACE_SCOPE0("VulkanDevice::dtor"); - if (debug_capture_manager_ && debug_capture_manager_->is_capturing()) { - debug_capture_manager_->StopCapture(); - } - - // Drop all command queues. These may wait until idle. - command_queues_.clear(); - dispatch_queues_.clear(); - transfer_queues_.clear(); - - // Drop command pools now that we know there are no more outstanding command - // buffers. - dispatch_command_pool_.reset(); - transfer_command_pool_.reset(); - - // Now that no commands are outstanding we can release all descriptor sets. - descriptor_pool_cache_.reset(); - - // Finally, destroy the device. - logical_device_.reset(); -} - -std::string VulkanDevice::DebugString() const { - return absl::StrCat(Device::DebugString(), // - "\n[VulkanDevice]", // - "\n Command Queues: ", command_queues_.size(), // - "\n - Dispatch Queues: ", dispatch_queues_.size(), // - "\n - Transfer Queues: ", transfer_queues_.size()); -} - -ref_ptr VulkanDevice::CreateExecutableCache() { - IREE_TRACE_SCOPE0("VulkanDevice::CreateExecutableCache"); - return make_ref(add_ref(logical_device_)); -} - -StatusOr> VulkanDevice::CreateDescriptorSetLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span bindings) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateDescriptorSetLayout"); - - absl::InlinedVector native_bindings( - bindings.size()); - for (int i = 0; i < bindings.size(); ++i) { - auto& native_binding = native_bindings[i]; - native_binding.binding = bindings[i].binding; - native_binding.descriptorType = - static_cast(bindings[i].type); - native_binding.descriptorCount = 1; - native_binding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - native_binding.pImmutableSamplers = nullptr; - } - - VkDescriptorSetLayoutCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - if (usage_type == DescriptorSetLayout::UsageType::kPushOnly && - logical_device_->enabled_extensions().push_descriptors) { - // Note that we can *only* use push descriptor sets if we set this create - // flag. If push descriptors aren't supported we emulate them with normal - // descriptors so it's fine to have kPushOnly without support. - create_info.flags |= - VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; - } - create_info.bindingCount = native_bindings.size(); - create_info.pBindings = native_bindings.data(); - - // Create and insert into the cache. 
- VkDescriptorSetLayout descriptor_set_layout = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateDescriptorSetLayout( - *logical_device_, &create_info, logical_device_->allocator(), - &descriptor_set_layout)); - - return make_ref(add_ref(logical_device_), - descriptor_set_layout); +static iree_string_view_t iree_hal_vulkan_device_id( + iree_hal_device_t* base_device) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return device->identifier; } -StatusOr> VulkanDevice::CreateExecutableLayout( - absl::Span set_layouts, size_t push_constants) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateExecutableLayout"); - - absl::InlinedVector, 2> typed_set_layouts( - set_layouts.size()); - absl::InlinedVector set_layout_handles( - set_layouts.size()); - for (int i = 0; i < set_layouts.size(); ++i) { - typed_set_layouts[i] = - add_ref(static_cast(set_layouts[i])); - set_layout_handles[i] = typed_set_layouts[i]->handle(); - } - - absl::InlinedVector push_constant_ranges; - if (push_constants > 0) { - push_constant_ranges.push_back(VkPushConstantRange{ - VK_SHADER_STAGE_COMPUTE_BIT, 0, - static_cast(sizeof(uint32_t) * push_constants)}); - } - - VkPipelineLayoutCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - create_info.setLayoutCount = set_layout_handles.size(); - create_info.pSetLayouts = set_layout_handles.data(); - create_info.pushConstantRangeCount = push_constant_ranges.size(); - create_info.pPushConstantRanges = push_constant_ranges.data(); - - // Create and insert into the cache. - VkPipelineLayout pipeline_layout = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreatePipelineLayout( - *logical_device_, &create_info, logical_device_->allocator(), - &pipeline_layout)); - - return make_ref( - add_ref(logical_device_), pipeline_layout, std::move(typed_set_layouts)); +static iree_allocator_t iree_hal_vulkan_device_host_allocator( + iree_hal_device_t* base_device) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return device->host_allocator; } -StatusOr> VulkanDevice::CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span bindings) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateDescriptorSet"); - return UnimplementedErrorBuilder(IREE_LOC) - << "CreateDescriptorSet not yet implemented (needs timeline)"; +static iree_hal_allocator_t* iree_hal_vulkan_device_allocator( + iree_hal_device_t* base_device) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return device->device_allocator; } -StatusOr> VulkanDevice::CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateCommandBuffer"); +static iree_status_t iree_hal_vulkan_device_create_command_buffer( + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_command_buffer_t** out_command_buffer) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); // Select the command pool to use based on the types of commands used. // Note that we may not have a dedicated transfer command pool if there are // no dedicated transfer queues.
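// Pool selection below, summarized (assuming both pools were created):
//   command_categories                  pool used
//   TRANSFER only                       transfer_command_pool
//   DISPATCH or DISPATCH|TRANSFER       dispatch_command_pool
//   TRANSFER, no dedicated pool exists  dispatch_command_pool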
- ref_ptr command_pool; - if (transfer_command_pool_ && - !AllBitsSet(command_categories, CommandCategory::kDispatch)) { - command_pool = add_ref(transfer_command_pool_); + VkCommandPoolHandle* command_pool = NULL; + if (device->transfer_command_pool && + !iree_all_bits_set(command_categories, + IREE_HAL_COMMAND_CATEGORY_DISPATCH)) { + command_pool = device->transfer_command_pool; } else { - command_pool = add_ref(dispatch_command_pool_); + command_pool = device->dispatch_command_pool; } - VkCommandBufferAllocateInfo allocate_info; - allocate_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; - allocate_info.pNext = nullptr; - allocate_info.commandPool = *command_pool; - allocate_info.commandBufferCount = 1; - allocate_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; - - VkCommandBuffer command_buffer = VK_NULL_HANDLE; - { - absl::MutexLock lock(command_pool->mutex()); - VK_RETURN_IF_ERROR(syms()->vkAllocateCommandBuffers( - *logical_device_, &allocate_info, &command_buffer)); - } - - // TODO(b/140026716): conditionally enable validation. - auto impl = make_ref( - mode, command_categories, add_ref(descriptor_pool_cache_), - add_ref(command_pool), command_buffer); - return WrapCommandBufferWithValidation(allocator(), std::move(impl)); + return iree_hal_vulkan_direct_command_buffer_allocate( + device->logical_device, command_pool, mode, command_categories, + device->descriptor_pool_cache, out_command_buffer); } -StatusOr> VulkanDevice::CreateEvent() { - IREE_TRACE_SCOPE0("VulkanDevice::CreateEvent"); +static iree_status_t iree_hal_vulkan_device_create_descriptor_set( + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_t* set_layout, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, + iree_hal_descriptor_set_t** out_descriptor_set) { + // TODO(benvanik): rework the create fn to take the bindings. + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "non-push descriptor sets still need work"); +} - // TODO(b/138729892): pool events. 
- VkEventCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_EVENT_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - VkEvent event_handle = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateEvent(*logical_device_, &create_info, - logical_device_->allocator(), - &event_handle)); +static iree_status_t iree_hal_vulkan_device_create_descriptor_set_layout( + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return iree_hal_vulkan_native_descriptor_set_layout_create( + device->logical_device, usage_type, binding_count, bindings, + out_descriptor_set_layout); +} - return make_ref(add_ref(logical_device_), event_handle); +static iree_status_t iree_hal_vulkan_device_create_event( + iree_hal_device_t* base_device, iree_hal_event_t** out_event) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return iree_hal_vulkan_native_event_create(device->logical_device, out_event); } -StatusOr> VulkanDevice::CreateSemaphore( - uint64_t initial_value) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateSemaphore"); - - if (emulating_timeline_semaphores()) { - return EmulatedTimelineSemaphore::Create( - add_ref(logical_device_), - // Triggers necessary processing on all queues due to new values gotten - // signaled for the given timeline |semaphore|. - // Different clang-format versions disagree about asterisk placement. - // clang-format off - [this](Semaphore* /*semaphore*/) -> Status { - // clang-format on - IREE_TRACE_SCOPE0("::OnSemaphoreSignal"); - for (const auto& queue : command_queues_) { - IREE_RETURN_IF_ERROR( - static_cast(queue.get()) - ->AdvanceQueueSubmission()); - } - return OkStatus(); - }, - // Triggers necessary processing on all queues due to failures for the - // given timeline |semaphore|. - [this](Semaphore* /*semaphore*/) { - IREE_TRACE_SCOPE0("::OnSemaphoreFailure"); - for (const auto& queue : command_queues_) { - static_cast(queue.get()) - ->AbortQueueSubmission(); - } - }, - // Triggers necessary processing on all queues due to the given |fence| - // being signaled. This allows the queue to drop the fence ref it holds - // even when we are not waiting on the queue directly. 
- [this](absl::Span fences) { - IREE_TRACE_SCOPE0("::OnFenceSignal"); - for (const auto& queue : command_queues_) { - static_cast(queue.get()) - ->SignalFences(fences); - } - }, - add_ref(semaphore_pool_), initial_value); - } - - return NativeTimelineSemaphore::Create(add_ref(logical_device_), - initial_value); +static iree_status_t iree_hal_vulkan_device_create_executable_cache( + iree_hal_device_t* base_device, iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return iree_hal_vulkan_nop_executable_cache_create( + device->logical_device, identifier, out_executable_cache); } -Status VulkanDevice::WaitAllSemaphores( - absl::Span semaphores, Time deadline_ns) { - IREE_TRACE_SCOPE0("VulkanDevice::WaitAllSemaphores"); - return WaitSemaphores(semaphores, deadline_ns, /*wait_flags=*/0); +static iree_status_t iree_hal_vulkan_device_create_executable_layout( + iree_hal_device_t* base_device, iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constants, + iree_hal_executable_layout_t** out_executable_layout) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return iree_hal_vulkan_native_executable_layout_create( + device->logical_device, set_layout_count, set_layouts, push_constants, + out_executable_layout); } -StatusOr VulkanDevice::WaitAnySemaphore( - absl::Span semaphores, Time deadline_ns) { - IREE_TRACE_SCOPE0("VulkanDevice::WaitAnySemaphore"); - return WaitSemaphores(semaphores, deadline_ns, - /*wait_flags=*/VK_SEMAPHORE_WAIT_ANY_BIT); +static iree_status_t iree_hal_vulkan_device_create_semaphore( + iree_hal_device_t* base_device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + if (device->semaphore_pool != NULL) { + return iree_hal_vulkan_emulated_semaphore_create( + device->logical_device, device->semaphore_pool, device->queue_count, + device->queues, initial_value, out_semaphore); + } + return iree_hal_vulkan_native_semaphore_create(device->logical_device, + initial_value, out_semaphore); } -Status VulkanDevice::WaitSemaphores(absl::Span semaphores, - Time deadline_ns, - VkSemaphoreWaitFlags wait_flags) { - IREE_TRACE_SCOPE0("VulkanDevice::WaitSemaphores"); - - if (emulating_timeline_semaphores()) { - // TODO(antiagainst): We actually should get the fences associated with the - // emulated timeline semaphores so that we can wait them in a bunch. This - // implementation is problematic if we wait to wait any and we have the - // first semaphore taking extra long time but the following ones signal - // quickly. - for (int i = 0; i < semaphores.size(); ++i) { - auto* semaphore = - static_cast(semaphores[i].semaphore); - IREE_RETURN_IF_ERROR(semaphore->Wait(semaphores[i].value, deadline_ns)); - if (wait_flags & VK_SEMAPHORE_WAIT_ANY_BIT) return OkStatus(); - } +// Returns the queue to submit work to based on the |queue_affinity|. +static CommandQueue* iree_hal_vulkan_device_select_queue( + iree_hal_vulkan_device_t* device, + iree_hal_command_category_t command_categories, uint64_t queue_affinity) { + // TODO(benvanik): meaningful heuristics for affinity. We don't generate + // anything from the compiler that uses multiple queues and until we do it's + // best not to do anything too clever here. 
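// Affinity mapping sketch for the selection below: with
// transfer_queue_count == 2 (hypothetical), queue_affinity values 0,1,2,3
// map to transfer queues 0,1,0,1 via the modulo; any category combination
// other than pure TRANSFER rotates through the dispatch queue list the same
// way.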
+ if (command_categories == IREE_HAL_COMMAND_CATEGORY_TRANSFER) { + return device + ->transfer_queues[queue_affinity % device->transfer_queue_count]; + } + return device->dispatch_queues[queue_affinity % device->dispatch_queue_count]; +} - return OkStatus(); - } +static iree_status_t iree_hal_vulkan_device_queue_submit( + iree_hal_device_t* base_device, + iree_hal_command_category_t command_categories, uint64_t queue_affinity, + iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + CommandQueue* queue = iree_hal_vulkan_device_select_queue( + device, command_categories, queue_affinity); + return queue->Submit(batch_count, batches); +} - absl::InlinedVector semaphore_handles(semaphores.size()); - absl::InlinedVector semaphore_values(semaphores.size()); - for (int i = 0; i < semaphores.size(); ++i) { - semaphore_handles[i] = - static_cast(semaphores[i].semaphore) - ->handle(); - semaphore_values[i] = semaphores[i].value; +static iree_status_t iree_hal_vulkan_device_wait_semaphores_with_deadline( + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + VkSemaphoreWaitFlags wait_flags = 0; + if (wait_mode == IREE_HAL_WAIT_MODE_ANY) { + wait_flags |= VK_SEMAPHORE_WAIT_ANY_BIT; } - - VkSemaphoreWaitInfo wait_info; - wait_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_WAIT_INFO; - wait_info.pNext = nullptr; - wait_info.flags = wait_flags; - wait_info.semaphoreCount = semaphore_handles.size(); - wait_info.pSemaphores = semaphore_handles.data(); - wait_info.pValues = semaphore_values.data(); - - // NOTE: this may fail with a timeout (VK_TIMEOUT) or in the case of a - // device loss event may return either VK_SUCCESS *or* VK_ERROR_DEVICE_LOST. - // We may want to explicitly query for device loss after a successful wait - // to ensure we consistently return errors. - uint64_t timeout_ns = - static_cast(DeadlineToRelativeTimeoutNanos(deadline_ns)); - VkResult result = - syms()->vkWaitSemaphores(*logical_device_, &wait_info, timeout_ns); - if (result == VK_ERROR_DEVICE_LOST) { - // Nothing we do now matters. - return VkResultToStatus(result, IREE_LOC); + if (device->semaphore_pool != NULL) { + return iree_hal_vulkan_emulated_semaphore_multi_wait( + device->logical_device, semaphore_list, deadline_ns, wait_flags); } + return iree_hal_vulkan_native_semaphore_multi_wait( + device->logical_device, semaphore_list, deadline_ns, wait_flags); +} - // TODO(benvanik): notify the resource timeline that it should check for the - // semaphores we waited on (including those already expired above). 
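// The *_with_timeout entry points below are thin shims over the deadline
// variants: a relative timeout is converted to an absolute deadline once up
// front so that retries deeper in the stack don't extend the total wait. The
// conversion is assumed to reduce to roughly:
//   deadline_ns = iree_time_now() + timeout_ns
// (saturating for infinite timeouts).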
- - return OkStatus(); +static iree_status_t iree_hal_vulkan_device_wait_semaphores_with_timeout( + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, + iree_duration_t timeout_ns) { + return iree_hal_vulkan_device_wait_semaphores_with_deadline( + base_device, wait_mode, semaphore_list, + iree_relative_timeout_to_deadline_ns(timeout_ns)); } -Status VulkanDevice::WaitIdle(Time deadline_ns) { - if (deadline_ns == InfiniteFuture()) { +static iree_status_t iree_hal_vulkan_device_wait_idle_with_deadline( + iree_hal_device_t* base_device, iree_time_t deadline_ns) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { // Fast path for using vkDeviceWaitIdle, which is usually cheaper (as it // requires fewer calls into the driver). - IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#vkDeviceWaitIdle"); - VK_RETURN_IF_ERROR(syms()->vkDeviceWaitIdle(*logical_device_)); - return OkStatus(); + return VK_RESULT_TO_STATUS(device->logical_device->syms()->vkDeviceWaitIdle( + *device->logical_device), + "vkDeviceWaitIdle"); } - - IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#Semaphores"); - for (auto& command_queue : command_queues_) { - IREE_RETURN_IF_ERROR(command_queue->WaitIdle(deadline_ns)); + for (iree_host_size_t i = 0; i < device->queue_count; ++i) { + IREE_RETURN_IF_ERROR(device->queues[i]->WaitIdle(deadline_ns)); } - return OkStatus(); + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_device_wait_idle_with_timeout( + iree_hal_device_t* base_device, iree_duration_t timeout_ns) { + return iree_hal_vulkan_device_wait_idle_with_deadline( + base_device, iree_relative_timeout_to_deadline_ns(timeout_ns)); } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_device_vtable_t iree_hal_vulkan_device_vtable = { + /*.destroy=*/iree_hal_vulkan_device_destroy, + /*.id=*/iree_hal_vulkan_device_id, + /*.host_allocator=*/iree_hal_vulkan_device_host_allocator, + /*.device_allocator=*/iree_hal_vulkan_device_allocator, + /*.create_command_buffer=*/iree_hal_vulkan_device_create_command_buffer, + /*.create_descriptor_set=*/iree_hal_vulkan_device_create_descriptor_set, + /*.create_descriptor_set_layout=*/ + iree_hal_vulkan_device_create_descriptor_set_layout, + /*.create_event=*/iree_hal_vulkan_device_create_event, + /*.create_executable_cache=*/ + iree_hal_vulkan_device_create_executable_cache, + /*.create_executable_layout=*/ + iree_hal_vulkan_device_create_executable_layout, + /*.create_semaphore=*/iree_hal_vulkan_device_create_semaphore, + /*.queue_submit=*/iree_hal_vulkan_device_queue_submit, + /*.wait_semaphores_with_deadline=*/ + iree_hal_vulkan_device_wait_semaphores_with_deadline, + /*.wait_semaphores_with_timeout=*/ + iree_hal_vulkan_device_wait_semaphores_with_timeout, + /*.wait_idle_with_deadline=*/ + iree_hal_vulkan_device_wait_idle_with_deadline, + /*.wait_idle_with_timeout=*/ + iree_hal_vulkan_device_wait_idle_with_timeout, +}; diff --git a/iree/hal/vulkan/vulkan_device.h b/iree/hal/vulkan/vulkan_device.h index 98d6ee283c96c..c34ad1ca38b07 100644 --- a/iree/hal/vulkan/vulkan_device.h +++ b/iree/hal/vulkan/vulkan_device.h @@ -15,159 +15,31 @@ #ifndef IREE_HAL_VULKAN_VULKAN_DEVICE_H_ #define IREE_HAL_VULKAN_VULKAN_DEVICE_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include -#include - -#include "absl/container/inlined_vector.h" -#include 
"absl/types/span.h" -#include "iree/base/memory.h" -#include "iree/hal/allocator.h" -#include "iree/hal/debug_capture_manager.h" -#include "iree/hal/device.h" -#include "iree/hal/driver.h" -#include "iree/hal/semaphore.h" -#include "iree/hal/vulkan/descriptor_pool_cache.h" +#include "iree/hal/api.h" +#include "iree/hal/vulkan/api.h" #include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/emulated_timeline_semaphore.h" #include "iree/hal/vulkan/extensibility_util.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/vma_allocator.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// A set of queues within a specific queue family on a VkDevice. -struct QueueSet { - // The index of a particular queue family on a VkPhysicalDevice, as described - // by vkGetPhysicalDeviceQueueFamilyProperties. - uint32_t queue_family_index; - - // Bitfield of queue indices within the queue family at |queue_family_index|. - uint64_t queue_indices; -}; - -class VulkanDevice final : public Device { - public: - struct Options { - // Extensibility descriptions for the device. - ExtensibilitySpec extensibility_spec; - - // Options for Vulkan Memory Allocator (VMA). - VmaAllocator::Options vma_options; - - // Uses timeline semaphore emulation even if native support exists. - bool force_timeline_semaphore_emulation = false; - }; - - // Creates a device that manages its own VkDevice. - static StatusOr> Create( - ref_ptr driver, VkInstance instance, - const DeviceInfo& device_info, VkPhysicalDevice physical_device, - Options options, const ref_ptr& syms, - DebugCaptureManager* debug_capture_manager); - - // Creates a device that wraps an externally managed VkDevice. - static StatusOr> Wrap( - ref_ptr driver, VkInstance instance, - const DeviceInfo& device_info, VkPhysicalDevice physical_device, - VkDevice logical_device, Options options, - const QueueSet& compute_queue_set, const QueueSet& transfer_queue_set, - const ref_ptr& syms); - - ~VulkanDevice() override; - - std::string DebugString() const override; - - const ref_ptr& syms() const { - return logical_device_->syms(); - } - - Allocator* allocator() const override { return allocator_.get(); } - - absl::Span dispatch_queues() const override { - return absl::MakeSpan(dispatch_queues_); - } - - absl::Span transfer_queues() const override { - return absl::MakeSpan(transfer_queues_); - } - - ref_ptr CreateExecutableCache() override; - StatusOr> CreateDescriptorSetLayout( - DescriptorSetLayout::UsageType usage_type, - absl::Span bindings) override; +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus - StatusOr> CreateExecutableLayout( - absl::Span set_layouts, - size_t push_constants) override; - - StatusOr> CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span bindings) override; - - StatusOr> CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) override; - - StatusOr> CreateEvent() override; - - StatusOr> CreateSemaphore(uint64_t initial_value) override; - Status WaitAllSemaphores(absl::Span semaphores, - Time deadline_ns) override; - StatusOr WaitAnySemaphore(absl::Span semaphores, - Time deadline_ns) override; - - Status WaitIdle(Time deadline_ns) override; - - private: - VulkanDevice( - ref_ptr driver, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, ref_ptr logical_device, - std::unique_ptr allocator, - absl::InlinedVector, 4> command_queues, - ref_ptr dispatch_command_pool, - ref_ptr transfer_command_pool, - ref_ptr semaphore_pool, - 
ref_ptr fence_pool, - DebugCaptureManager* debug_capture_manager); - - Status WaitSemaphores(absl::Span semaphores, - Time deadline_ns, VkSemaphoreWaitFlags wait_flags); - - bool emulating_timeline_semaphores() const { - return semaphore_pool_ != nullptr; - } - - ref_ptr driver_; - VkPhysicalDevice physical_device_; - ref_ptr logical_device_; - - std::unique_ptr allocator_; - - mutable absl::InlinedVector, 4> command_queues_; - mutable absl::InlinedVector dispatch_queues_; - mutable absl::InlinedVector transfer_queues_; - - ref_ptr descriptor_pool_cache_; - - ref_ptr dispatch_command_pool_; - ref_ptr transfer_command_pool_; - - // Fields used for emulated timeline semaphores. - ref_ptr semaphore_pool_; - ref_ptr fence_pool_; - - DebugCaptureManager* debug_capture_manager_ = nullptr; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +// Creates a device that owns and manages its own VkDevice. +// +// The |driver| will be retained for as long as the device is live such that if +// the driver owns the |instance| provided it is ensured to be valid. |driver| +// may be NULL if there is no parent driver to retain (such as when wrapping +// existing VkInstances provided by the application). +iree_status_t iree_hal_vulkan_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + iree_hal_vulkan_features_t enabled_features, + const iree_hal_vulkan_device_options_t* options, + iree_hal_vulkan_syms_t* instance_syms, VkInstance instance, + VkPhysicalDevice physical_device, iree_allocator_t host_allocator, + iree_hal_device_t** out_device); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_VULKAN_DEVICE_H_ diff --git a/iree/hal/vulkan/vulkan_driver.cc b/iree/hal/vulkan/vulkan_driver.cc index 0e1c7b504381a..5466f17fcef8c 100644 --- a/iree/hal/vulkan/vulkan_driver.cc +++ b/iree/hal/vulkan/vulkan_driver.cc @@ -16,312 +16,467 @@ #include -#include "absl/container/inlined_vector.h" #include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/base/target_platform.h" #include "iree/base/tracing.h" -#include "iree/hal/device_info.h" +#include "iree/hal/vulkan/api.h" +#include "iree/hal/vulkan/debug_reporter.h" +#include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/extensibility_util.h" #include "iree/hal/vulkan/status_util.h" +#include "iree/hal/vulkan/vulkan_device.h" -namespace iree { -namespace hal { -namespace vulkan { +using namespace iree::hal::vulkan; -namespace { +typedef struct { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + + // Identifier used for the driver in the IREE driver registry. + // We allow overriding so that multiple Vulkan versions can be exposed in the + // same process. + iree_string_view_t identifier; + + iree_hal_vulkan_device_options_t device_options; + int default_device_index; + + iree_hal_vulkan_features_t enabled_features; + + // Which optional extensions are active and available on the instance. + iree_hal_vulkan_instance_extensions_t instance_extensions; + + // (Partial) loaded Vulkan symbols. Devices created within the driver may have + // different function pointers for device-specific functions that change + // behavior with enabled layers/extensions. + iree::ref_ptr syms; + + // The Vulkan instance that all devices created from the driver will share. + VkInstance instance; + bool owns_instance; + + // Optional debug reporter: may be disabled or unavailable (no debug layers). 
+ iree_hal_vulkan_debug_reporter_t* debug_reporter; +} iree_hal_vulkan_driver_t; + +extern const iree_hal_driver_vtable_t iree_hal_vulkan_driver_vtable; + +static iree_hal_vulkan_driver_t* iree_hal_vulkan_driver_cast( + iree_hal_driver_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_driver_vtable); + return (iree_hal_vulkan_driver_t*)base_value; +} + +IREE_API_EXPORT void IREE_API_CALL iree_hal_vulkan_driver_options_initialize( + iree_hal_vulkan_driver_options_t* out_options) { + memset(out_options, 0, sizeof(*out_options)); + out_options->api_version = VK_API_VERSION_1_2; + out_options->requested_features = 0; + iree_hal_vulkan_device_options_initialize(&out_options->device_options); + out_options->default_device_index = 0; +} // Returns a VkApplicationInfo struct populated with the default app info. // We may allow hosting applications to override this via weak-linkage if it's // useful, otherwise this is enough to create the application. -VkApplicationInfo GetDefaultApplicationInfo() { - VkApplicationInfo info; - info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - info.pNext = nullptr; - info.pApplicationName = "IREE-ML"; - info.applicationVersion = 0; - info.pEngineName = "IREE"; - info.engineVersion = 0; -#ifdef IREE_PLATFORM_ANDROID - info.apiVersion = VK_API_VERSION_1_1; -#else - info.apiVersion = VK_API_VERSION_1_2; -#endif - return info; +static void iree_hal_vulkan_driver_populate_default_app_info( + const iree_hal_vulkan_driver_options_t* options, + VkApplicationInfo* out_app_info) { + memset(out_app_info, 0, sizeof(*out_app_info)); + out_app_info->sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + out_app_info->pNext = NULL; + out_app_info->pApplicationName = "IREE-ML"; + out_app_info->applicationVersion = 0; + out_app_info->pEngineName = "IREE"; + out_app_info->engineVersion = 0; + out_app_info->apiVersion = options->api_version; } -// Populates device information from the given Vulkan physical device handle. -StatusOr PopulateDeviceInfo(VkPhysicalDevice physical_device, - const ref_ptr& syms) { - VkPhysicalDeviceFeatures physical_device_features; - syms->vkGetPhysicalDeviceFeatures(physical_device, &physical_device_features); - // TODO(benvanik): check and optionally require these features: - // - physical_device_features.robustBufferAccess - // - physical_device_features.shaderInt16 - // - physical_device_features.shaderInt64 - // - physical_device_features.shaderFloat64 +// NOTE: takes ownership of |instance|. +static iree_status_t iree_hal_vulkan_driver_create_internal( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + const iree_hal_vulkan_string_list_t* enabled_extensions, + iree_hal_vulkan_syms_t* opaque_syms, VkInstance instance, + bool owns_instance, iree_allocator_t host_allocator, + iree_hal_driver_t** out_driver) { + auto* instance_syms = (DynamicSymbols*)opaque_syms; - VkPhysicalDeviceProperties physical_device_properties; - syms->vkGetPhysicalDeviceProperties(physical_device, - &physical_device_properties); - // TODO(benvanik): check and optionally require reasonable limits. + iree_hal_vulkan_instance_extensions_t instance_extensions = + iree_hal_vulkan_populate_enabled_instance_extensions(enabled_extensions); - // TODO(benvanik): more clever/sanitized device naming. - std::string name = std::string(physical_device_properties.deviceName); - - DeviceFeatureBitfield supported_features = DeviceFeature::kNone; - // TODO(benvanik): implement debugging/profiling features. 
- // TODO(benvanik): use props to determine if we have timing info. - // supported_features |= DeviceFeature::kDebugging; - // supported_features |= DeviceFeature::kCoverage; - // supported_features |= DeviceFeature::kProfiling; - return DeviceInfo("vulkan", std::move(name), supported_features, - reinterpret_cast(physical_device)); -} + // The real debug messenger (not just the static one used above) can now be + // created as we've loaded all the required symbols. + // TODO(benvanik): strip in min-size release builds. + iree_hal_vulkan_debug_reporter_t* debug_reporter = NULL; + if (instance_extensions.debug_utils) { + IREE_RETURN_IF_ERROR(iree_hal_vulkan_debug_reporter_allocate( + instance, instance_syms, /*allocation_callbacks=*/NULL, host_allocator, + &debug_reporter)); + } -} // namespace + iree_hal_vulkan_driver_t* driver = NULL; + iree_host_size_t total_size = sizeof(*driver) + identifier.size; + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&driver); + if (!iree_status_is_ok(status)) { + // Need to clean up if we fail (as we own these). + iree_hal_vulkan_debug_reporter_free(debug_reporter); + return status; + } + iree_hal_resource_initialize(&iree_hal_vulkan_driver_vtable, + &driver->resource); + driver->host_allocator = host_allocator; + iree_string_view_append_to_buffer( + identifier, &driver->identifier, + (char*)driver + total_size - identifier.size); + memcpy(&driver->device_options, &options->device_options, + sizeof(driver->device_options)); + driver->default_device_index = options->default_device_index; + driver->enabled_features = options->requested_features; + driver->syms = iree::add_ref(instance_syms); + driver->instance = instance; + driver->owns_instance = owns_instance; + driver->debug_reporter = debug_reporter; + *out_driver = (iree_hal_driver_t*)driver; + return status; +} -// static -StatusOr> VulkanDriver::Create( - Options options, ref_ptr syms) { - IREE_TRACE_SCOPE0("VulkanDriver::Create"); +static void iree_hal_vulkan_driver_destroy(iree_hal_driver_t* base_driver) { + iree_hal_vulkan_driver_t* driver = iree_hal_vulkan_driver_cast(base_driver); + iree_allocator_t host_allocator = driver->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); - // Load and connect to RenderDoc before instance creation. 
- // Note: RenderDoc assumes that only a single VkDevice is used: - // https://renderdoc.org/docs/behind_scenes/vulkan_support.html#current-support - std::unique_ptr renderdoc_capture_manager; - if (options.enable_renderdoc) { - renderdoc_capture_manager = std::make_unique(); - IREE_RETURN_IF_ERROR(renderdoc_capture_manager->Connect()); + iree_hal_vulkan_debug_reporter_free(driver->debug_reporter); + if (driver->owns_instance) { + driver->syms->vkDestroyInstance(driver->instance, /*pAllocator=*/NULL); } + driver->syms.reset(); + iree_allocator_free(host_allocator, driver); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_vulkan_driver_query_extensibility_set( + iree_hal_vulkan_features_t requested_features, + iree_hal_vulkan_extensibility_set_t set, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_string_list) { + IREE_RETURN_IF_ERROR(iree_hal_vulkan_query_extensibility_set( + requested_features, set, 0, NULL, &out_string_list->count)); + out_string_list->values = (const char**)arena->AllocateBytes( + out_string_list->count * sizeof(out_string_list->values[0])); + IREE_RETURN_IF_ERROR(iree_hal_vulkan_query_extensibility_set( + requested_features, set, out_string_list->count, out_string_list->values, + &out_string_list->count)); + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_driver_compute_enabled_extensibility_sets( + iree::hal::vulkan::DynamicSymbols* syms, + iree_hal_vulkan_features_t requested_features, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_enabled_layers, + iree_hal_vulkan_string_list_t* out_enabled_extensions) { + // Query our required and optional layers and extensions based on the IREE + // features the user requested. + iree_hal_vulkan_string_list_t required_layers; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_query_extensibility_set( + requested_features, + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED, arena, + &required_layers)); + iree_hal_vulkan_string_list_t optional_layers; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_query_extensibility_set( + requested_features, + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, arena, + &optional_layers)); + iree_hal_vulkan_string_list_t required_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_query_extensibility_set( + requested_features, + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED, arena, + &required_extensions)); + iree_hal_vulkan_string_list_t optional_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_query_extensibility_set( + requested_features, + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL, arena, + &optional_extensions)); // Find the layers and extensions we need (or want) that are also available // on the instance. This will fail when required ones are not present. 
- IREE_ASSIGN_OR_RETURN( - auto enabled_layer_names, - MatchAvailableInstanceLayers(options.instance_extensibility, *syms)); - IREE_ASSIGN_OR_RETURN( - auto enabled_extension_names, - MatchAvailableInstanceExtensions(options.instance_extensibility, *syms)); - auto instance_extensions = - PopulateEnabledInstanceExtensions(enabled_extension_names); + IREE_RETURN_IF_ERROR(iree_hal_vulkan_match_available_instance_layers( + syms, &required_layers, &optional_layers, arena, out_enabled_layers)); + IREE_RETURN_IF_ERROR(iree_hal_vulkan_match_available_instance_extensions( + syms, &required_extensions, &optional_extensions, arena, + out_enabled_extensions)); + + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_create( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + iree_hal_vulkan_syms_t* opaque_syms, iree_allocator_t host_allocator, + iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(options); + IREE_ASSERT_ARGUMENT(opaque_syms); + IREE_ASSERT_ARGUMENT(out_driver); + IREE_TRACE_SCOPE(); + + auto* instance_syms = (DynamicSymbols*)opaque_syms; + + // Query required and optional instance layers/extensions for the requested + // features. + iree::Arena arena; + iree_hal_vulkan_string_list_t enabled_layers; + iree_hal_vulkan_string_list_t enabled_extensions; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_driver_compute_enabled_extensibility_sets( + instance_syms, options->requested_features, &arena, &enabled_layers, + &enabled_extensions)); // Create the instance this driver will use for all requests. - VkApplicationInfo app_info = GetDefaultApplicationInfo(); - app_info.apiVersion = options.api_version; + VkApplicationInfo app_info; + iree_hal_vulkan_driver_populate_default_app_info(options, &app_info); VkInstanceCreateInfo create_info; create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - create_info.pNext = nullptr; + create_info.pNext = NULL; create_info.flags = 0; create_info.pApplicationInfo = &app_info; - create_info.enabledLayerCount = - static_cast(enabled_layer_names.size()); - create_info.ppEnabledLayerNames = enabled_layer_names.data(); - create_info.enabledExtensionCount = - static_cast(enabled_extension_names.size()); - create_info.ppEnabledExtensionNames = enabled_extension_names.data(); - - // If we have the debug_utils extension then we can chain a one-shot messenger - // callback that we can use to log out the instance creation errors. Once we - // have the real instance we can then register a real messenger. - union { - VkDebugUtilsMessengerCreateInfoEXT debug_utils_create_info; - VkDebugReportCallbackCreateInfoEXT debug_report_create_info; - }; - if (instance_extensions.debug_utils) { - create_info.pNext = &debug_utils_create_info; - DebugReporter::PopulateStaticCreateInfo(&debug_utils_create_info); - } else if (instance_extensions.debug_report) { - create_info.pNext = &debug_report_create_info; - DebugReporter::PopulateStaticCreateInfo(&debug_report_create_info); - } + create_info.enabledLayerCount = enabled_layers.count; + create_info.ppEnabledLayerNames = enabled_layers.values; + create_info.enabledExtensionCount = enabled_extensions.count; + create_info.ppEnabledExtensionNames = enabled_extensions.values; - // Some ICDs appear to leak in here, out of our control. - // Warning: leak checks remain disabled if an error is returned. 
- IREE_DISABLE_LEAK_CHECKS(); VkInstance instance = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR( - syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance)) - << "Unable to create Vulkan instance"; - IREE_ENABLE_LEAK_CHECKS(); - - // TODO(benvanik): enable validation layers if needed. + VK_RETURN_IF_ERROR(instance_syms->vkCreateInstance( + &create_info, /*pAllocator=*/NULL, &instance), + "vkCreateInstance: invalid instance configuration"); // Now that the instance has been created we can fetch all of the instance // symbols. - IREE_RETURN_IF_ERROR(syms->LoadFromInstance(instance)); + iree_status_t status = instance_syms->LoadFromInstance(instance); - // The real debug messenger (not just the static one used above) can now be - // created as we've loaded all the required symbols. - // TODO(benvanik): strip in release builds. - std::unique_ptr debug_reporter; - if (instance_extensions.debug_utils) { - IREE_ASSIGN_OR_RETURN(debug_reporter, - DebugReporter::CreateDebugUtilsMessenger( - instance, syms, - /*allocation_callbacks=*/nullptr)); - } else if (instance_extensions.debug_report) { - IREE_ASSIGN_OR_RETURN( - debug_reporter, DebugReporter::CreateDebugReportCallback( - instance, syms, /*allocation_callbacks=*/nullptr)); + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_driver_create_internal( + identifier, options, &enabled_extensions, opaque_syms, instance, + /*owns_instance=*/true, host_allocator, out_driver); } - return assign_ref(new VulkanDriver( - std::move(syms), instance, - /*owns_instance=*/true, std::move(options.device_options), - options.default_device_index, std::move(debug_reporter), - std::move(renderdoc_capture_manager))); + if (!iree_status_is_ok(status)) { + instance_syms->vkDestroyInstance(instance, /*pAllocator=*/NULL); + } + return status; } -// static -StatusOr> VulkanDriver::CreateUsingInstance( - Options options, ref_ptr syms, VkInstance instance) { - IREE_TRACE_SCOPE0("VulkanDriver::CreateUsingInstance"); - +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_vulkan_driver_create_using_instance( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + iree_hal_vulkan_syms_t* opaque_syms, VkInstance instance, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(options); + IREE_ASSERT_ARGUMENT(opaque_syms); + IREE_ASSERT_ARGUMENT(out_driver); if (instance == VK_NULL_HANDLE) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "VkInstance must not be VK_NULL_HANDLE"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "a non-NULL VkInstance must be provided"); } + IREE_TRACE_ZONE_BEGIN(z0); - // Find the extensions we need (or want) that are also available on the - // instance. This will fail when required ones are not present. - // - // Since the instance is already created, we can't actually enable any - // extensions or query if they are really enabled - we just have to trust - // that the caller already enabled them for us (or we may fail later). - IREE_ASSIGN_OR_RETURN( - auto enabled_extension_names, - MatchAvailableInstanceExtensions(options.instance_extensibility, *syms)); - auto instance_extensions = - PopulateEnabledInstanceExtensions(enabled_extension_names); + // May be a no-op but don't rely on that so we can be sure we have the right + // function pointers. 
+  auto* instance_syms = (DynamicSymbols*)opaque_syms;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, instance_syms->LoadFromInstance(instance));
 
-  IREE_RETURN_IF_ERROR(syms->LoadFromInstance(instance));
 
+  // Since the instance is already created we can't actually enable any
+  // extensions or even query if they are really enabled - we just have to trust
+  // that the caller already enabled them for us (or we may fail later).
+  iree::Arena arena;
+  iree_hal_vulkan_string_list_t enabled_layers;
+  iree_hal_vulkan_string_list_t enabled_extensions;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_vulkan_driver_compute_enabled_extensibility_sets(
+              instance_syms, options->requested_features, &arena,
+              &enabled_layers, &enabled_extensions));
+
+  // NOTE: the caller retains ownership of |instance| and it must outlive the
+  // driver; we must not destroy it during driver teardown.
+  iree_status_t status = iree_hal_vulkan_driver_create_internal(
+      identifier, options, &enabled_extensions, opaque_syms, instance,
+      /*owns_instance=*/false, host_allocator, out_driver);
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
 
-  // TODO(benvanik): strip in release builds.
-  std::unique_ptr<DebugReporter> debug_reporter;
-  if (instance_extensions.debug_utils) {
-    IREE_ASSIGN_OR_RETURN(debug_reporter,
-                          DebugReporter::CreateDebugUtilsMessenger(
-                              instance, syms,
-                              /*allocation_callbacks=*/nullptr));
-  } else if (instance_extensions.debug_report) {
-    IREE_ASSIGN_OR_RETURN(
-        debug_reporter, DebugReporter::CreateDebugReportCallback(
-                            instance, syms, /*allocation_callbacks=*/nullptr));
+// Enumerates all physical devices on |instance| and returns them as an
+// allocated list in |out_physical_devices|, which must be freed by the caller.
+static iree_status_t iree_hal_vulkan_driver_enumerate_physical_devices(
+    iree::hal::vulkan::DynamicSymbols* instance_syms, VkInstance instance,
+    iree_allocator_t host_allocator, uint32_t* out_physical_device_count,
+    VkPhysicalDevice** out_physical_devices) {
+  uint32_t physical_device_count = 0;
+  VK_RETURN_IF_ERROR(instance_syms->vkEnumeratePhysicalDevices(
+                         instance, &physical_device_count, NULL),
+                     "vkEnumeratePhysicalDevices");
+  VkPhysicalDevice* physical_devices = NULL;
+  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+      host_allocator, physical_device_count * sizeof(physical_devices[0]),
+      (void**)&physical_devices));
+  iree_status_t status = VK_RESULT_TO_STATUS(
+      instance_syms->vkEnumeratePhysicalDevices(
+          instance, &physical_device_count, physical_devices),
+      "vkEnumeratePhysicalDevices");
+  if (iree_status_is_ok(status)) {
+    *out_physical_device_count = physical_device_count;
+    *out_physical_devices = physical_devices;
+  } else {
+    iree_allocator_free(host_allocator, physical_devices);
  }
+  return status;
+}
 
-  // Note: no RenderDocCaptureManager here since the VkInstance is already
-  // created externally. Applications using this function must provide their
-  // own RenderDoc / debugger integration as desired.
-
-  return assign_ref(
-      new VulkanDriver(std::move(syms), instance, /*owns_instance=*/false,
-                       std::move(options.device_options),
-                       options.default_device_index, std::move(debug_reporter),
-                       /*debug_capture_manager=*/nullptr));
+// Returns the size, in bytes, of the iree_hal_device_info_t storage required
+// for holding the given |physical_device|.
+static iree_host_size_t iree_hal_vulkan_calculate_device_info_size( + VkPhysicalDevice physical_device, iree::hal::vulkan::DynamicSymbols* syms) { + VkPhysicalDeviceProperties physical_device_properties; + syms->vkGetPhysicalDeviceProperties(physical_device, + &physical_device_properties); + return strlen(physical_device_properties.deviceName); } -VulkanDriver::VulkanDriver( - ref_ptr syms, VkInstance instance, bool owns_instance, - VulkanDevice::Options device_options, int default_device_index, - std::unique_ptr debug_reporter, - std::unique_ptr renderdoc_capture_manager) - : Driver("vulkan"), - syms_(std::move(syms)), - instance_(instance), - owns_instance_(owns_instance), - device_options_(std::move(device_options)), - default_device_index_(default_device_index), - debug_reporter_(std::move(debug_reporter)), - renderdoc_capture_manager_(std::move(renderdoc_capture_manager)) {} - -VulkanDriver::~VulkanDriver() { - IREE_TRACE_SCOPE0("VulkanDriver::dtor"); - debug_reporter_.reset(); - if (owns_instance_) { - syms()->vkDestroyInstance(instance_, /*pAllocator=*/nullptr); - } +// Populates device information from the given Vulkan physical device handle. +// |out_device_info| must point to valid memory and additional data will be +// appended to |buffer_ptr| and the new pointer is returned. +static uint8_t* iree_hal_vulkan_populate_device_info( + VkPhysicalDevice physical_device, DynamicSymbols* syms, uint8_t* buffer_ptr, + iree_hal_device_info_t* out_device_info) { + memset(out_device_info, 0, sizeof(*out_device_info)); + out_device_info->device_id = (iree_hal_device_id_t)physical_device; + + VkPhysicalDeviceFeatures physical_device_features; + syms->vkGetPhysicalDeviceFeatures(physical_device, &physical_device_features); + // TODO(benvanik): check and optionally require these features: + // - physical_device_features.robustBufferAccess + // - physical_device_features.shaderInt16 + // - physical_device_features.shaderInt64 + // - physical_device_features.shaderFloat64 + + VkPhysicalDeviceProperties physical_device_properties; + syms->vkGetPhysicalDeviceProperties(physical_device, + &physical_device_properties); + // TODO(benvanik): check and optionally require reasonable limits. + + // TODO(benvanik): more clever/sanitized device naming. + iree_string_view_t device_name = + iree_make_string_view(physical_device_properties.deviceName, + strlen(physical_device_properties.deviceName)); + buffer_ptr += iree_string_view_append_to_buffer( + device_name, &out_device_info->name, (char*)buffer_ptr); + + return buffer_ptr; } -StatusOr> VulkanDriver::EnumerateAvailableDevices() { - IREE_TRACE_SCOPE0("VulkanDriver::EnumerateAvailableDevices"); +static iree_status_t iree_hal_vulkan_driver_query_available_devices( + iree_hal_driver_t* base_driver, iree_allocator_t host_allocator, + iree_hal_device_info_t** out_device_infos, + iree_host_size_t* out_device_info_count) { + iree_hal_vulkan_driver_t* driver = iree_hal_vulkan_driver_cast(base_driver); - // Query all available devices (at this moment, note that this may change!). + // Query all devices from the Vulkan instance. uint32_t physical_device_count = 0; - VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices( - instance_, &physical_device_count, nullptr)); - absl::InlinedVector physical_devices( - physical_device_count); - VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices( - instance_, &physical_device_count, physical_devices.data())); - - // Convert to our HAL structure. 
-  std::vector<DeviceInfo> device_infos;
-  device_infos.reserve(physical_device_count);
-  for (auto physical_device : physical_devices) {
-    // TODO(benvanik): if we fail should we just ignore the device in the list?
-    IREE_ASSIGN_OR_RETURN(auto device_info,
-                          PopulateDeviceInfo(physical_device, syms()));
-    device_infos.push_back(std::move(device_info));
+  VkPhysicalDevice* physical_devices = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_enumerate_physical_devices(
+      driver->syms.get(), driver->instance, host_allocator,
+      &physical_device_count, &physical_devices));
+
+  // Allocate the return infos and populate with the devices.
+  iree_hal_device_info_t* device_infos = NULL;
+  iree_host_size_t total_size =
+      physical_device_count * sizeof(iree_hal_device_info_t);
+  for (uint32_t i = 0; i < physical_device_count; ++i) {
+    total_size += iree_hal_vulkan_calculate_device_info_size(
+        physical_devices[i], driver->syms.get());
+  }
+  iree_status_t status =
+      iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos);
+  if (iree_status_is_ok(status)) {
+    // The info structs are stored inline first, followed by the string
+    // storage that their name fields reference.
+    uint8_t* buffer_ptr =
+        (uint8_t*)device_infos +
+        physical_device_count * sizeof(iree_hal_device_info_t);
+    for (uint32_t i = 0; i < physical_device_count; ++i) {
+      buffer_ptr = iree_hal_vulkan_populate_device_info(
+          physical_devices[i], driver->syms.get(), buffer_ptr,
+          &device_infos[i]);
+    }
+    *out_device_info_count = physical_device_count;
+    *out_device_infos = device_infos;
  }
-  return device_infos;
-}
 
-StatusOr<ref_ptr<Device>> VulkanDriver::CreateDefaultDevice() {
-  IREE_TRACE_SCOPE0("VulkanDriver::CreateDefaultDevice");
+  iree_allocator_free(host_allocator, physical_devices);
+  return status;
+}
 
-  // Query available devices.
-  IREE_ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices());
-  if (default_device_index_ < 0 ||
-      default_device_index_ >= available_devices.size()) {
-    return NotFoundErrorBuilder(IREE_LOC)
-           << "Device index " << default_device_index_ << " not found "
-           << "(of " << available_devices.size() << ")";
+static iree_status_t iree_hal_vulkan_driver_select_default_device(
+    iree::hal::vulkan::DynamicSymbols* instance_syms, VkInstance instance,
+    int default_device_index, iree_allocator_t host_allocator,
+    VkPhysicalDevice* out_physical_device) {
+  uint32_t physical_device_count = 0;
+  VkPhysicalDevice* physical_devices = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_enumerate_physical_devices(
+      instance_syms, instance, host_allocator, &physical_device_count,
+      &physical_devices));
+  iree_status_t status = iree_ok_status();
+  if (default_device_index < 0 ||
+      (uint32_t)default_device_index >= physical_device_count) {
+    status = iree_make_status(IREE_STATUS_NOT_FOUND,
+                              "default device %d not found (of %u enumerated)",
+                              default_device_index, physical_device_count);
+  } else {
+    *out_physical_device = physical_devices[default_device_index];
  }
-
-  // Just create the first one we find.
- return CreateDevice(available_devices[default_device_index_].device_id()); + iree_allocator_free(host_allocator, physical_devices); + return status; } -StatusOr> VulkanDriver::CreateDevice(DriverDeviceID device_id) { - IREE_TRACE_SCOPE0("VulkanDriver::CreateDevice"); +static iree_status_t iree_hal_vulkan_driver_create_device( + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + iree_hal_vulkan_driver_t* driver = iree_hal_vulkan_driver_cast(base_driver); + IREE_TRACE_ZONE_BEGIN(z0); + + // Use either the specified device (enumerated earlier) or whatever default + // one was specified when the driver was created. + VkPhysicalDevice physical_device = (VkPhysicalDevice)device_id; + if (physical_device == VK_NULL_HANDLE) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_hal_vulkan_driver_select_default_device( + driver->syms.get(), driver->instance, driver->default_device_index, + host_allocator, &physical_device)); + } - auto physical_device = reinterpret_cast(device_id); - IREE_ASSIGN_OR_RETURN(auto device_info, - PopulateDeviceInfo(physical_device, syms())); + // TODO(benvanik): remove HAL module dependence on the identifier for matching + // devices. Today it *must* be vulkan* to work, whereas really that should be + // a device type (vs the identifier, which is arbitrary). + // Query the device name to use as an identifier. + // VkPhysicalDeviceProperties physical_device_properties; + // driver->syms->vkGetPhysicalDeviceProperties(physical_device, + // &physical_device_properties); + // iree_string_view_t device_name = + // iree_make_string_view(physical_device_properties.deviceName, + // strlen(physical_device_properties.deviceName)); + iree_string_view_t device_name = iree_make_cstring_view("vulkan"); // Attempt to create the device. // This may fail if the device was enumerated but is in exclusive use, // disabled by the system, or permission is denied. - IREE_ASSIGN_OR_RETURN( - auto device, - VulkanDevice::Create(add_ref(this), instance(), device_info, - physical_device, device_options_, syms(), - renderdoc_capture_manager_.get())); - - IREE_LOG(INFO) << "Created Vulkan Device: " << device->info().name(); + iree_status_t status = iree_hal_vulkan_device_create( + base_driver, device_name, driver->enabled_features, + &driver->device_options, (iree_hal_vulkan_syms_t*)driver->syms.get(), + driver->instance, physical_device, host_allocator, out_device); - return device; -} - -StatusOr> VulkanDriver::WrapDevice( - VkPhysicalDevice physical_device, VkDevice logical_device, - const QueueSet& compute_queue_set, const QueueSet& transfer_queue_set) { - IREE_TRACE_SCOPE0("VulkanDriver::WrapDevice"); - - IREE_ASSIGN_OR_RETURN(auto device_info, - PopulateDeviceInfo(physical_device, syms())); - - // Attempt to create the device. - // This may fail if the VkDevice does not support all necessary features. 
- IREE_ASSIGN_OR_RETURN( - auto device, - VulkanDevice::Wrap(add_ref(this), instance(), device_info, - physical_device, logical_device, device_options_, - compute_queue_set, transfer_queue_set, syms())); - return device; + IREE_TRACE_ZONE_END(z0); + return status; } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_driver_vtable_t iree_hal_vulkan_driver_vtable = { + /*.destroy=*/iree_hal_vulkan_driver_destroy, + /*.query_available_devices=*/ + iree_hal_vulkan_driver_query_available_devices, + /*.create_device=*/iree_hal_vulkan_driver_create_device, +}; diff --git a/iree/hal/vulkan/vulkan_driver.h b/iree/hal/vulkan/vulkan_driver.h index 611b255eaa9c1..8f53786eaced0 100644 --- a/iree/hal/vulkan/vulkan_driver.h +++ b/iree/hal/vulkan/vulkan_driver.h @@ -15,105 +15,11 @@ #ifndef IREE_HAL_VULKAN_VULKAN_DRIVER_H_ #define IREE_HAL_VULKAN_VULKAN_DRIVER_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on +#include "iree/hal/api.h" +#include "iree/hal/vulkan/api.h" -#include -#include - -#include "iree/hal/driver.h" -#include "iree/hal/vulkan/debug_reporter.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/extensibility_util.h" -#include "iree/hal/vulkan/renderdoc_capture_manager.h" -#include "iree/hal/vulkan/vulkan_device.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class VulkanDriver final : public Driver { - public: - struct Options { - // Vulkan version that will be requested. - // Driver creation will fail if the required version is not available. - uint32_t api_version = VK_API_VERSION_1_0; - - // Extensibility descriptions for instances. - // See VulkanDevice::Options for device extensibility descriptions. - ExtensibilitySpec instance_extensibility; - - // Options to use for all devices created by the driver. - VulkanDevice::Options device_options; - - // Index of the default Vulkan device to use within the list of available - // devices. Devices are discovered via vkEnumeratePhysicalDevices then - // considered "available" if compatible with the driver options. - int default_device_index = 0; - - // Enables RenderDoc integration, connecting via RenderDoc's API and - // recording Vulkan calls for offline inspection and debugging. - bool enable_renderdoc = false; - }; - - // Creates a VulkanDriver that manages its own VkInstance. - static StatusOr> Create(Options options, - ref_ptr syms); - - // Creates a VulkanDriver that shares an externally managed VkInstance. - // - // |options| are checked for compatibility. - // - // |syms| must at least have |vkGetInstanceProcAddr| set. Other symbols will - // be loaded as needed from |instance|. - // - // |instance| must remain valid for the life of the returned VulkanDriver. - static StatusOr> CreateUsingInstance( - Options options, ref_ptr syms, VkInstance instance); - - ~VulkanDriver() override; - - const ref_ptr& syms() const { return syms_; } - - VkInstance instance() const { return instance_; } - - StatusOr> EnumerateAvailableDevices() override; - - StatusOr> CreateDefaultDevice() override; - - StatusOr> CreateDevice(DriverDeviceID device_id) override; - - // Creates a device that wraps an externally managed VkDevice. - // - // The device will schedule commands against the provided queues. 
- StatusOr> WrapDevice(VkPhysicalDevice physical_device, - VkDevice logical_device, - const QueueSet& compute_queue_set, - const QueueSet& transfer_queue_set); - - DebugCaptureManager* debug_capture_manager() override { - return renderdoc_capture_manager_.get(); - } - - private: - VulkanDriver( - ref_ptr syms, VkInstance instance, bool owns_instance, - VulkanDevice::Options device_options, int default_device_index, - std::unique_ptr debug_reporter, - std::unique_ptr renderdoc_capture_manager); - - ref_ptr syms_; - VkInstance instance_; - bool owns_instance_; - VulkanDevice::Options device_options_; - int default_device_index_; - std::unique_ptr debug_reporter_; - std::unique_ptr renderdoc_capture_manager_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +// NOTE: the driver API calls are defined in api.h. +// TODO(benvanik): clean that up? api.h is nice because then we only need to +// deploy a single header file for the backend, but it is a bit tricky. #endif // IREE_HAL_VULKAN_VULKAN_DRIVER_H_ diff --git a/iree/modules/check/check_test.cc b/iree/modules/check/check_test.cc index 37598a7731639..eb1c130ab1f36 100644 --- a/iree/modules/check/check_test.cc +++ b/iree/modules/check/check_test.cc @@ -94,17 +94,11 @@ class CheckTest : public ::testing::Test { IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE), IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(int32_t), &buffer)); - iree_hal_mapped_memory_t mapped_memory; - IREE_ASSERT_OK(iree_hal_buffer_map(buffer.get(), - IREE_HAL_MEMORY_ACCESS_WRITE, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); - memcpy(mapped_memory.contents.data, - static_cast(contents.data()), - mapped_memory.contents.data_length); - IREE_ASSERT_OK(iree_hal_buffer_unmap(buffer.get(), &mapped_memory)); + IREE_ASSERT_OK(iree_hal_buffer_write_data( + buffer.get(), 0, contents.data(), contents.size() * sizeof(int32_t))); IREE_ASSERT_OK(iree_hal_buffer_view_create( buffer.get(), shape.data(), shape.size(), IREE_HAL_ELEMENT_TYPE_SINT_32, - iree_allocator_system(), &*out_buffer_view)); + &*out_buffer_view)); } void CreateFloat32BufferView(absl::Span contents, @@ -122,18 +116,11 @@ class CheckTest : public ::testing::Test { IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE), IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(float), &buffer)); - iree_hal_mapped_memory_t mapped_memory; - IREE_ASSERT_OK(iree_hal_buffer_map(buffer.get(), - IREE_HAL_MEMORY_ACCESS_WRITE, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); - memcpy(mapped_memory.contents.data, - static_cast(contents.data()), - mapped_memory.contents.data_length); - IREE_ASSERT_OK(iree_hal_buffer_unmap(buffer.get(), &mapped_memory)); + IREE_ASSERT_OK(iree_hal_buffer_write_data(buffer.get(), 0, contents.data(), + contents.size() * sizeof(float))); IREE_ASSERT_OK(iree_hal_buffer_view_create( buffer.get(), shape.data(), shape.size(), - IREE_HAL_ELEMENT_TYPE_FLOAT_32, iree_allocator_system(), - &*out_buffer_view)); + IREE_HAL_ELEMENT_TYPE_FLOAT_32, &*out_buffer_view)); } void CreateFloat64BufferView(absl::Span contents, @@ -151,18 +138,11 @@ class CheckTest : public ::testing::Test { IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE), IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(double), &buffer)); - iree_hal_mapped_memory_t mapped_memory; - IREE_ASSERT_OK(iree_hal_buffer_map(buffer.get(), - IREE_HAL_MEMORY_ACCESS_WRITE, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); - memcpy(mapped_memory.contents.data, - static_cast(contents.data()), - 
mapped_memory.contents.data_length); - IREE_ASSERT_OK(iree_hal_buffer_unmap(buffer.get(), &mapped_memory)); + IREE_ASSERT_OK(iree_hal_buffer_write_data( + buffer.get(), 0, contents.data(), contents.size() * sizeof(double))); IREE_ASSERT_OK(iree_hal_buffer_view_create( buffer.get(), shape.data(), shape.size(), - IREE_HAL_ELEMENT_TYPE_FLOAT_64, iree_allocator_system(), - &*out_buffer_view)); + IREE_HAL_ELEMENT_TYPE_FLOAT_64, &*out_buffer_view)); } Status Invoke(absl::string_view function_name) { diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc index 5686a1fc3ee72..74287a764852c 100644 --- a/iree/modules/check/native_module.cc +++ b/iree/modules/check/native_module.cc @@ -182,13 +182,13 @@ class CheckModuleState final { iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(view); iree_hal_buffer_t* buf = iree_hal_buffer_view_buffer(view); - iree_hal_mapped_memory_t mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map( + iree_hal_buffer_mapping_t mapped_memory; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( buf, IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0, IREE_WHOLE_BUFFER, &mapped_memory)); IREE_RETURN_IF_ERROR( ::iree::ExpectAllTrue(mapped_memory.contents, element_type)); - iree_hal_buffer_unmap(buf, &mapped_memory); + iree_hal_buffer_unmap_range(&mapped_memory); return OkStatus(); } @@ -212,13 +212,13 @@ class CheckModuleState final { iree_hal_buffer_view_element_type(rhs); iree_hal_buffer_t* lhs_buf = iree_hal_buffer_view_buffer(lhs); - iree_hal_mapped_memory_t lhs_mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map( + iree_hal_buffer_mapping_t lhs_mapped_memory; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( lhs_buf, IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory)); iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs); - iree_hal_mapped_memory_t rhs_mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map( + iree_hal_buffer_mapping_t rhs_mapped_memory; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( rhs_buf, IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory)); @@ -226,8 +226,8 @@ class CheckModuleState final { bool shape_eq = lhs_shape == rhs_shape; bool contents_eq = EqByteSpan(lhs_mapped_memory.contents, rhs_mapped_memory.contents); - iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory); - iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory); + iree_hal_buffer_unmap_range(&lhs_mapped_memory); + iree_hal_buffer_unmap_range(&rhs_mapped_memory); if (!element_types_eq || !shape_eq || !contents_eq) { std::ostringstream os; @@ -281,13 +281,13 @@ class CheckModuleState final { iree_hal_buffer_view_element_type(rhs); iree_hal_buffer_t* lhs_buf = iree_hal_buffer_view_buffer(lhs); - iree_hal_mapped_memory_t lhs_mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map( + iree_hal_buffer_mapping_t lhs_mapped_memory; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( lhs_buf, IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory)); iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs); - iree_hal_mapped_memory_t rhs_mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map( + iree_hal_buffer_mapping_t rhs_mapped_memory; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( rhs_buf, IREE_HAL_MEMORY_ACCESS_READ, /*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory)); @@ -301,8 +301,8 @@ class CheckModuleState final { AlmostEqByteSpan(lhs_mapped_memory.contents, rhs_mapped_memory.contents, 
lhs_element_type)); } - iree_hal_buffer_unmap(lhs_buf, &lhs_mapped_memory); - iree_hal_buffer_unmap(rhs_buf, &rhs_mapped_memory); + iree_hal_buffer_unmap_range(&lhs_mapped_memory); + iree_hal_buffer_unmap_range(&rhs_mapped_memory); if (!element_types_eq || !shape_eq || !contents_could_be_almost_eq) { std::ostringstream os; diff --git a/iree/modules/hal/BUILD b/iree/modules/hal/BUILD index 6780425259859..6c6d53806c0e9 100644 --- a/iree/modules/hal/BUILD +++ b/iree/modules/hal/BUILD @@ -25,7 +25,6 @@ cc_library( deps = [ "//iree/base:api", "//iree/base:tracing", - "//iree/hal", "//iree/hal:api", "//iree/vm", "//iree/vm:cc", diff --git a/iree/modules/hal/CMakeLists.txt b/iree/modules/hal/CMakeLists.txt index 636a96f56a4fe..36800aaa0b965 100644 --- a/iree/modules/hal/CMakeLists.txt +++ b/iree/modules/hal/CMakeLists.txt @@ -28,7 +28,6 @@ iree_cc_library( absl::span iree::base::api iree::base::tracing - iree::hal iree::hal::api iree::vm iree::vm::cc diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc index 9725e354a2909..c10883a8ecc5f 100644 --- a/iree/modules/hal/hal_module.cc +++ b/iree/modules/hal/hal_module.cc @@ -21,14 +21,8 @@ #include "iree/base/api.h" #include "iree/base/tracing.h" #include "iree/hal/api.h" -#include "iree/hal/api_detail.h" -#include "iree/hal/device.h" #include "iree/vm/native_module_cc.h" -namespace iree { -namespace hal { -namespace { - //===----------------------------------------------------------------------===// // Type registration //===----------------------------------------------------------------------===// @@ -41,36 +35,61 @@ static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_descriptor = {0}; static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_layout_descriptor = {0}; static iree_vm_ref_type_descriptor_t iree_hal_device_descriptor = {0}; +static iree_vm_ref_type_descriptor_t iree_hal_event_descriptor = {0}; static iree_vm_ref_type_descriptor_t iree_hal_executable_descriptor = {0}; static iree_vm_ref_type_descriptor_t iree_hal_executable_cache_descriptor = {0}; static iree_vm_ref_type_descriptor_t iree_hal_executable_layout_descriptor = { 0}; static iree_vm_ref_type_descriptor_t iree_hal_semaphore_descriptor = {0}; +#define IREE_VM_REGISTER_HAL_C_TYPE(type, name, destroy_fn, descriptor) \ + descriptor.type_name = iree_make_cstring_view(name); \ + descriptor.offsetof_counter = offsetof(iree_hal_resource_t, ref_count); \ + descriptor.destroy = (iree_vm_ref_destroy_t)destroy_fn; \ + IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor)); + IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_module_register_types() { static bool has_registered = false; if (has_registered) return iree_ok_status(); - IREE_VM_REGISTER_CC_TYPE(Allocator, "hal.allocator", - iree_hal_allocator_descriptor); - IREE_VM_REGISTER_CC_TYPE(Buffer, "hal.buffer", iree_hal_buffer_descriptor); - IREE_VM_REGISTER_CC_TYPE(iree_hal_buffer_view, "hal.buffer_view", - iree_hal_buffer_view_descriptor); - IREE_VM_REGISTER_CC_TYPE(CommandBuffer, "hal.command_buffer", - iree_hal_command_buffer_descriptor); - IREE_VM_REGISTER_CC_TYPE(DescriptorSet, "hal.descriptor_set", - iree_hal_descriptor_set_descriptor); - IREE_VM_REGISTER_CC_TYPE(DescriptorSetLayout, "hal.descriptor_set_layout", - iree_hal_descriptor_set_layout_descriptor); - IREE_VM_REGISTER_CC_TYPE(Device, "hal.device", iree_hal_device_descriptor); - IREE_VM_REGISTER_CC_TYPE(Executable, "hal.executable", - iree_hal_executable_descriptor); - IREE_VM_REGISTER_CC_TYPE(ExecutableCache, 
"hal.executable_cache", - iree_hal_executable_cache_descriptor); - IREE_VM_REGISTER_CC_TYPE(ExecutableLayout, "hal.executable_layout", - iree_hal_executable_layout_descriptor); - IREE_VM_REGISTER_CC_TYPE(Semaphore, "hal.semaphore", - iree_hal_semaphore_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_allocator_t, "hal.allocator", + iree_hal_allocator_destroy, + iree_hal_allocator_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_t, "hal.buffer", + iree_hal_buffer_destroy, + iree_hal_buffer_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_buffer_view_t, "hal.buffer_view", + iree_hal_buffer_view_destroy, + iree_hal_buffer_view_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_command_buffer_t, "hal.command_buffer", + iree_hal_command_buffer_destroy, + iree_hal_command_buffer_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_t, "hal.descriptor_set", + iree_hal_descriptor_set_destroy, + iree_hal_descriptor_set_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_descriptor_set_layout_t, + "hal.descriptor_set_layout", + iree_hal_descriptor_set_layout_destroy, + iree_hal_descriptor_set_layout_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_device_t, "hal.device", + iree_hal_device_destroy, + iree_hal_device_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_event_t, "hal.event", + iree_hal_event_destroy, + iree_hal_event_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_t, "hal.executable", + iree_hal_executable_destroy, + iree_hal_executable_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE( + iree_hal_executable_cache_t, "hal.executable_cache", + iree_hal_executable_cache_destroy, iree_hal_executable_cache_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_executable_layout_t, + "hal.executable_layout", + iree_hal_executable_layout_destroy, + iree_hal_executable_layout_descriptor); + IREE_VM_REGISTER_HAL_C_TYPE(iree_hal_semaphore_t, "hal.semaphore", + iree_hal_semaphore_destroy, + iree_hal_semaphore_descriptor); has_registered = true; return iree_ok_status(); @@ -90,6 +109,7 @@ IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set, IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set_layout, iree_hal_descriptor_set_layout_t); IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_device, iree_hal_device_t); +IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_event, iree_hal_event_t); IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable, iree_hal_executable_t); IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_cache, iree_hal_executable_cache_t); @@ -97,21 +117,27 @@ IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_layout, iree_hal_executable_layout_t); IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_semaphore, iree_hal_semaphore_t); +namespace iree { +namespace hal { +namespace { + //===----------------------------------------------------------------------===// // Module type definitions //===----------------------------------------------------------------------===// class HALModuleState final { public: - HALModuleState(iree_allocator_t allocator, ref_ptr shared_device, - ref_ptr executable_cache) - : allocator_(allocator), shared_device_(std::move(shared_device)) {} + HALModuleState(iree_allocator_t allocator, iree_hal_device_t* shared_device) + : allocator_(allocator), shared_device_(shared_device) { + iree_hal_device_retain(shared_device_); + } ~HALModuleState() { for (auto& ref : deferred_releases_) { iree_vm_ref_release(&ref); } deferred_releases_.clear(); + iree_hal_device_release(shared_device_); } //===--------------------------------------------------------------------===// @@ -121,8 
+147,7 @@ class HALModuleState final { // using these APIs are not forward compatible. StatusOr> ExSharedDevice() { - return vm::retain_ref( - reinterpret_cast(shared_device_.get())); + return vm::retain_ref(shared_device_); } template @@ -137,8 +162,8 @@ class HALModuleState final { IREE_TRACE_SCOPE0("HALModuleState::ExSubmitAndWait"); vm::ref semaphore; - IREE_RETURN_IF_ERROR(iree_hal_semaphore_create( - device.get(), 0ull, iree_allocator_system(), &semaphore)); + IREE_RETURN_IF_ERROR( + iree_hal_semaphore_create(device.get(), 0ull, &semaphore)); iree_hal_submission_batch_t batch; memset(&batch, 0, sizeof(batch)); @@ -168,7 +193,7 @@ class HALModuleState final { } //===--------------------------------------------------------------------===// - // iree::hal::Allocator + // iree_hal_allocator_t //===--------------------------------------------------------------------===// StatusOr> AllocatorAllocate( @@ -191,6 +216,8 @@ class HALModuleState final { // TODO(benvanik): wrap when supported. + buffer_usage |= IREE_HAL_BUFFER_USAGE_MAPPING; + size_t buffer_length = source->data.data_length; if (length == -1) { length = static_cast(buffer_length); @@ -216,7 +243,7 @@ class HALModuleState final { } //===--------------------------------------------------------------------===// - // iree::hal::Buffer + // iree_hal_buffer_t //===--------------------------------------------------------------------===// StatusOr> BufferAllocator( @@ -230,7 +257,7 @@ class HALModuleState final { IREE_TRACE_SCOPE0("HALModuleState::BufferSubspan"); vm::ref target_buffer; IREE_RETURN_IF_ERROR(iree_hal_buffer_subspan( - source_buffer.get(), source_offset, length, allocator_, &target_buffer)) + source_buffer.get(), source_offset, length, &target_buffer)) << "Subspan of an existing buffer (source_offset=" << source_offset << ", length=" << length << ")"; return target_buffer; @@ -306,16 +333,15 @@ class HALModuleState final { } //===--------------------------------------------------------------------===// - // iree::hal::BufferView + // iree_hal_buffer_view_t //===--------------------------------------------------------------------===// StatusOr> BufferViewCreate( const vm::ref& buffer, absl::Span shape, iree_hal_element_type_t element_type) { vm::ref buffer_view; - IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(buffer.get(), shape.data(), - shape.size(), element_type, - allocator_, &buffer_view)) + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create( + buffer.get(), shape.data(), shape.size(), element_type, &buffer_view)) << "Failed to create buffer view"; return std::move(buffer_view); } @@ -326,7 +352,7 @@ class HALModuleState final { vm::ref new_buffer_view; IREE_RETURN_IF_ERROR(iree_hal_buffer_view_subview( buffer_view.get(), indices.data(), indices.size(), lengths.data(), - lengths.size(), allocator_, &new_buffer_view)) + lengths.size(), &new_buffer_view)) << "Failed to create subview"; return std::move(new_buffer_view); } @@ -429,7 +455,7 @@ class HALModuleState final { } //===--------------------------------------------------------------------===// - // iree::hal::CommandBuffer + // iree_hal_command_buffer_t //===--------------------------------------------------------------------===// StatusOr> CommandBufferCreate( @@ -438,8 +464,7 @@ class HALModuleState final { iree_hal_command_category_t command_categories) { vm::ref command_buffer; IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create( - device.get(), modes, command_categories, iree_allocator_system(), - &command_buffer)) + device.get(), modes, command_categories, 
&command_buffer)) << "Failed to create command buffer"; return command_buffer; } @@ -497,14 +522,15 @@ class HALModuleState final { uint32_t offset, absl::Span values) { ExDeferRelease(executable_layout); return iree_hal_command_buffer_push_constants( - command_buffer.get(), executable_layout.get(), offset, values.data(), + command_buffer.get(), executable_layout.get(), + offset * sizeof(uint32_t), values.data(), values.size() * sizeof(uint32_t)); } Status CommandBufferPushDescriptorSet( const vm::ref& command_buffer, const vm::ref& executable_layout, - int32_t set, absl::Span binding_ordinals, + uint32_t set, absl::Span binding_ordinals, absl::Span> binding_buffers, absl::Span binding_offsets, absl::Span binding_lengths) { @@ -526,7 +552,7 @@ class HALModuleState final { Status CommandBufferBindDescriptorSet( const vm::ref& command_buffer, const vm::ref& executable_layout, - int32_t set, const vm::ref& descriptor_set, + uint32_t set, const vm::ref& descriptor_set, absl::Span dynamic_offsets) { ExDeferRelease(executable_layout); ExDeferRelease(descriptor_set); @@ -565,16 +591,16 @@ class HALModuleState final { } //===--------------------------------------------------------------------===// - // iree::hal::DescriptorSet + // iree_hal_descriptor_set_t //===--------------------------------------------------------------------===// StatusOr> DescriptorSetCreate( const vm::ref& device, const vm::ref& set_layout, - absl::Span binding_ordinals, + absl::Span binding_ordinals, absl::Span> binding_buffers, - absl::Span binding_offsets, - absl::Span binding_lengths) { + absl::Span binding_offsets, + absl::Span binding_lengths) { absl::InlinedVector binding_structs( binding_ordinals.size()); for (int i = 0; i < binding_ordinals.size(); ++i) { @@ -587,18 +613,18 @@ class HALModuleState final { vm::ref descriptor_set; IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_create( device.get(), set_layout.get(), binding_structs.size(), - binding_structs.data(), allocator_, &descriptor_set)); + binding_structs.data(), &descriptor_set)); return std::move(descriptor_set); } //===--------------------------------------------------------------------===// - // iree::hal::DescriptorSetLayout + // iree_hal_descriptor_set_layout_t //===--------------------------------------------------------------------===// StatusOr> DescriptorSetLayoutCreate( const vm::ref& device, iree_hal_descriptor_set_layout_usage_type_t usage_type, - absl::Span> bindings) { // TODO(benvanik): custom marshaling for the structs. 
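Note on the push-constant hunk above: the VM-level offset counts 32-bit words while iree_hal_command_buffer_push_constants takes byte offsets and lengths, hence the offset * sizeof(uint32_t) scaling. A minimal sketch of a direct caller under this convention (the command_buffer/executable_layout handles and constant values are illustrative only, not part of this change):

    // Push two 32-bit constants starting at word offset 2 of the layout's
    // push-constant range; the HAL C API wants offset and length in bytes.
    const uint32_t kValues[2] = {16, 32};
    IREE_RETURN_IF_ERROR(iree_hal_command_buffer_push_constants(
        command_buffer, executable_layout,
        /*offset=*/2 * sizeof(uint32_t), kValues, sizeof(kValues)));
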
@@ -611,12 +637,12 @@ class HALModuleState final { vm::ref descriptor_set_layout; IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_layout_create( device.get(), usage_type, binding_structs.size(), - binding_structs.data(), allocator_, &descriptor_set_layout)); + binding_structs.data(), &descriptor_set_layout)); return std::move(descriptor_set_layout); } //===--------------------------------------------------------------------===// - // iree::hal::Device + // iree_hal_device_t //===--------------------------------------------------------------------===// StatusOr> DeviceAllocator( @@ -634,7 +660,11 @@ class HALModuleState final { } //===--------------------------------------------------------------------===// - // iree::hal::ExecutableCache + // iree_hal_event_t + //===--------------------------------------------------------------------===// + + //===--------------------------------------------------------------------===// + // iree_hal_executable_cache_t //===--------------------------------------------------------------------===// StatusOr> ExecutableCacheCreate( @@ -642,7 +672,7 @@ class HALModuleState final { vm::ref executable_cache; IREE_RETURN_IF_ERROR(iree_hal_executable_cache_create( device.get(), iree_string_view_t{identifier.data(), identifier.size()}, - allocator_, &executable_cache)); + &executable_cache)); return std::move(executable_cache); } @@ -666,12 +696,12 @@ class HALModuleState final { vm::ref executable; IREE_RETURN_IF_ERROR(iree_hal_executable_cache_prepare_executable( executable_cache.get(), executable_layout.get(), caching_mode, - executable_data->data, allocator_, &executable)); + executable_data->data, &executable)); return std::move(executable); } //===--------------------------------------------------------------------===// - // iree::hal::ExecutableLayout + // iree_hal_executable_layout_t //===--------------------------------------------------------------------===// StatusOr> ExecutableLayoutCreate( @@ -684,19 +714,19 @@ class HALModuleState final { reinterpret_cast( const_cast*>( set_layouts.data())), - push_constants, allocator_, &executable_layout)); + push_constants, &executable_layout)); return std::move(executable_layout); } //===--------------------------------------------------------------------===// - // iree::hal::Semaphore + // iree_hal_semaphore_t //===--------------------------------------------------------------------===// StatusOr> SemaphoreCreate( const vm::ref& device, uint32_t initial_value) { vm::ref semaphore; - IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device.get(), initial_value, - allocator_, &semaphore)); + IREE_RETURN_IF_ERROR( + iree_hal_semaphore_create(device.get(), initial_value, &semaphore)); return std::move(semaphore); } @@ -738,7 +768,7 @@ class HALModuleState final { private: iree_allocator_t allocator_; - ref_ptr shared_device_; + iree_hal_device_t* shared_device_ = NULL; std::vector deferred_releases_; }; @@ -846,33 +876,31 @@ static const vm::NativeFunction kHALModuleFunctions[] = { class HALModule final : public vm::NativeModule { public: - HALModule(iree_allocator_t allocator, ref_ptr shared_device) + HALModule(iree_allocator_t allocator, iree_hal_device_t* shared_device) : vm::NativeModule( "hal", allocator, absl::MakeConstSpan(kHALModuleFunctions)), - shared_device_(std::move(shared_device)) {} - ~HALModule() = default; + shared_device_(shared_device) { + iree_hal_device_retain(shared_device_); + } + + ~HALModule() { iree_hal_device_release(shared_device_); } Status Initialize() { 
IREE_TRACE_SCOPE0("HALModule::Initialize"); - - executable_cache_ = shared_device_->CreateExecutableCache(); - return OkStatus(); } StatusOr> CreateState( iree_allocator_t allocator) override { IREE_TRACE_SCOPE0("HALModule::CreateState"); - auto state = std::make_unique( - allocator, add_ref(shared_device_), add_ref(executable_cache_)); + auto state = std::make_unique(allocator, shared_device_); // TODO(benvanik): allocate context-specific variables (allocator pool, // etc). return state; } private: - ref_ptr shared_device_; - ref_ptr executable_cache_; + iree_hal_device_t* shared_device_ = NULL; }; IREE_API_EXPORT iree_status_t IREE_API_CALL @@ -881,8 +909,7 @@ iree_hal_module_create(iree_hal_device_t* device, iree_allocator_t allocator, IREE_ASSERT_ARGUMENT(device); IREE_ASSERT_ARGUMENT(out_module); *out_module = nullptr; - auto module = std::make_unique( - allocator, add_ref(reinterpret_cast(device))); + auto module = std::make_unique(allocator, device); IREE_RETURN_IF_ERROR(module->Initialize()); *out_module = module.release()->interface(); return iree_ok_status(); diff --git a/iree/modules/hal/hal_module.h b/iree/modules/hal/hal_module.h index cf3df76e622b8..749828a624e83 100644 --- a/iree/modules/hal/hal_module.h +++ b/iree/modules/hal/hal_module.h @@ -31,6 +31,7 @@ IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_descriptor_set, IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_descriptor_set_layout, iree_hal_descriptor_set_layout_t); IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_device, iree_hal_device_t); +IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_event, iree_hal_event_t); IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_executable, iree_hal_executable_t); IREE_VM_DECLARE_TYPE_ADAPTERS(iree_hal_executable_cache, iree_hal_executable_cache_t); diff --git a/iree/modules/strings/strings_module.cc b/iree/modules/strings/strings_module.cc index c8ea45fd1e912..3243e5b2cfe65 100644 --- a/iree/modules/strings/strings_module.cc +++ b/iree/modules/strings/strings_module.cc @@ -138,10 +138,10 @@ class StringsModuleState final { size_t tensor_size = element_size * num_elements; iree_hal_buffer_t* hal_buffer = iree_hal_buffer_view_buffer(hal_buffer_view.get()); - iree_hal_mapped_memory_t tensor_mapping; - IREE_RETURN_IF_ERROR( - iree_hal_buffer_map(hal_buffer, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, tensor_size, &tensor_mapping)); + iree_hal_buffer_mapping_t tensor_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + hal_buffer, IREE_HAL_MEMORY_ACCESS_READ, + /*byte_offset=*/0, tensor_size, &tensor_mapping)); iree_hal_element_type_t type = iree_hal_buffer_view_element_type(hal_buffer_view.get()); @@ -195,7 +195,7 @@ class StringsModuleState final { } // Unmap used buffer. - IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap(hal_buffer, &tensor_mapping)); + iree_hal_buffer_unmap_range(&tensor_mapping); // Place into iree_string_views. 
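For reference, the mapping discipline these hunks migrate to, in isolation: iree_hal_buffer_map_range pairs with a void iree_hal_buffer_unmap_range that takes only the mapping, so unmapping can no longer fail and no longer needs a status check. A minimal read-side sketch (assumes a host-visible buffer):

iree_hal_buffer_mapping_t mapping;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
    buffer, IREE_HAL_MEMORY_ACCESS_READ,
    /*byte_offset=*/0, IREE_WHOLE_BUFFER, &mapping));
// ... read mapping.contents.data, up to mapping.contents.data_length bytes ...
iree_hal_buffer_unmap_range(&mapping);  // void; nothing to propagate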
std::vector string_views; @@ -237,10 +237,10 @@ class StringsModuleState final { size_t element_size = iree_hal_buffer_view_element_size(ids.get()); size_t tensor_size = element_size * num_elements; iree_hal_buffer_t* hal_buffer = iree_hal_buffer_view_buffer(ids.get()); - iree_hal_mapped_memory_t tensor_mapping; - IREE_RETURN_IF_ERROR( - iree_hal_buffer_map(hal_buffer, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, tensor_size, &tensor_mapping)); + iree_hal_buffer_mapping_t tensor_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + hal_buffer, IREE_HAL_MEMORY_ACCESS_READ, + /*byte_offset=*/0, tensor_size, &tensor_mapping)); iree_string_view_t str; const auto& contents = tensor_mapping.contents; std::vector string_views; @@ -255,7 +255,7 @@ class StringsModuleState final { } // Unmap used buffer. - IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap(hal_buffer, &tensor_mapping)); + iree_hal_buffer_unmap_range(&tensor_mapping); strings_string_tensor_t* string_tensor; IREE_RETURN_IF_ERROR(strings_string_tensor_create( @@ -309,7 +309,7 @@ class StringsModuleState final { iree_allocator_t allocator_ = iree_allocator_system(); template - void GenerateStringsByType(iree_hal_mapped_memory_t tensor_mapping, + void GenerateStringsByType(iree_hal_buffer_mapping_t tensor_mapping, std::vector& strings) { const auto& contents = tensor_mapping.contents; for (const T *p = (const T*)contents.data, diff --git a/iree/modules/strings/strings_module_test.cc b/iree/modules/strings/strings_module_test.cc index 7b187cdc413a7..f49b7f71c97d0 100644 --- a/iree/modules/strings/strings_module_test.cc +++ b/iree/modules/strings/strings_module_test.cc @@ -120,17 +120,10 @@ class StringsModuleTest : public ::testing::Test { IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE), IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(T), &buffer)); - iree_hal_mapped_memory_t mapped_memory; - IREE_ASSERT_OK(iree_hal_buffer_map(buffer.get(), - IREE_HAL_MEMORY_ACCESS_WRITE, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); - memcpy(mapped_memory.contents.data, - static_cast(contents.data()), - mapped_memory.contents.data_length); - IREE_ASSERT_OK(iree_hal_buffer_unmap(buffer.get(), &mapped_memory)); + IREE_ASSERT_OK(iree_hal_buffer_write_data(buffer.get(), 0, contents.data(), + contents.size() * sizeof(T))); IREE_ASSERT_OK(iree_hal_buffer_view_create( - buffer.get(), shape.data(), shape.size(), E, iree_allocator_system(), - &*out_buffer_view)); + buffer.get(), shape.data(), shape.size(), E, &*out_buffer_view)); } void TestStringTensorToString( diff --git a/iree/modules/tensorlist/native_module.cc b/iree/modules/tensorlist/native_module.cc index 18d47704f8db6..0d509e0771420 100644 --- a/iree/modules/tensorlist/native_module.cc +++ b/iree/modules/tensorlist/native_module.cc @@ -109,13 +109,12 @@ class TensorList final : public RefObject { vm::ref subview_buffer; IREE_RETURN_IF_ERROR(iree_hal_buffer_subspan( iree_hal_buffer_view_buffer(tensor.get()), start_offset, - subview_length, iree_allocator_system(), &subview_buffer)); + subview_length, &subview_buffer)); iree_hal_buffer_view_t* slice = nullptr; IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create( subview_buffer.get(), element_shape.data(), element_shape.size(), - iree_hal_buffer_view_element_type(tensor.get()), - iree_allocator_system(), &slice)); + iree_hal_buffer_view_element_type(tensor.get()), &slice)); list->SetItem(i, slice); } return list; @@ -171,9 +170,9 @@ class TensorList final : public RefObject { result_shape.push_back(dim); } vm::ref result_view; - 
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create( - result_buffer.get(), result_shape.data(), result_shape.size(), type, - iree_allocator_system(), &result_view)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_create(result_buffer.get(), result_shape.data(), + result_shape.size(), type, &result_view)); return std::move(result_view); } @@ -241,21 +240,21 @@ class TensorList final : public RefObject { result_shape.push_back(dim); } vm::ref result_view; - IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create( - result_buffer.get(), result_shape.data(), result_shape.size(), type, - iree_allocator_system(), &result_view)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_view_create(result_buffer.get(), result_shape.data(), + result_shape.size(), type, &result_view)); return std::move(result_view); } private: iree_status_t CopyTensorBytes(iree_hal_buffer_t* buffer) { - iree_hal_mapped_memory_t result_mapping; + iree_hal_buffer_mapping_t result_mapping; iree_device_size_t dest_byte_size = iree_hal_buffer_byte_length(buffer); - IREE_RETURN_IF_ERROR( - iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_WRITE, - /*byte_offset=*/0, - /*byte_length=*/dest_byte_size, &result_mapping)); + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + buffer, IREE_HAL_MEMORY_ACCESS_WRITE, + /*byte_offset=*/0, + /*byte_length=*/dest_byte_size, &result_mapping)); // Copy each buffer into the result at the right offset. // This is just a naive map+memcpy. @@ -283,21 +282,12 @@ class TensorList final : public RefObject { } iree_hal_buffer_t* tensor_buffer = iree_hal_buffer_view_buffer(tensor); - iree_hal_mapped_memory_t tensor_mapping; - iree_device_size_t tensor_byte_size = - iree_hal_buffer_byte_length(tensor_buffer); - IREE_RETURN_IF_ERROR( - iree_hal_buffer_map(tensor_buffer, IREE_HAL_MEMORY_ACCESS_READ, 0, - tensor_byte_size, &tensor_mapping)); - - memcpy(block_begin, tensor_mapping.contents.data, block_size); - - IREE_RETURN_IF_ERROR( - iree_hal_buffer_unmap(tensor_buffer, &tensor_mapping)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_read_data(tensor_buffer, 0, block_begin, block_size)); } - return iree_hal_buffer_unmap(buffer, &result_mapping); + iree_hal_buffer_unmap_range(&result_mapping); + return iree_ok_status(); } std::vector> list_; @@ -349,11 +339,11 @@ static StatusOr ReadInt32FromScalarBufferView( << "expected rank-0 buffer view"; } iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view); - iree_hal_mapped_memory_t mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ, - 0, 4, &mapped_memory)); + iree_hal_buffer_mapping_t mapped_memory; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + buffer, IREE_HAL_MEMORY_ACCESS_READ, 0, 4, &mapped_memory)); int32_t scalar = *reinterpret_cast(mapped_memory.contents.data); - IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap(buffer, &mapped_memory)); + iree_hal_buffer_unmap_range(&mapped_memory); return scalar; } @@ -373,16 +363,9 @@ static StatusOr> ReadInt32VectorFromBufferView( buffer_view, /*rank_capacity=*/1, &length, nullptr)); iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view); - iree_hal_mapped_memory_t mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ, - 0, length * sizeof(int32_t), - &mapped_memory)); - - std::vector contents( - reinterpret_cast(mapped_memory.contents.data), - reinterpret_cast(mapped_memory.contents.data) + length); - - IREE_RETURN_IF_ERROR(iree_hal_buffer_unmap(buffer, &mapped_memory)); + std::vector contents(length); + IREE_RETURN_IF_ERROR(iree_hal_buffer_read_data(
+ buffer, 0, contents.data(), contents.size() * sizeof(int32_t))); return contents; } diff --git a/iree/modules/tensorlist/tensorlist_test.cc b/iree/modules/tensorlist/tensorlist_test.cc index 7779173cbeeb0..c76a8e5c588b2 100644 --- a/iree/modules/tensorlist/tensorlist_test.cc +++ b/iree/modules/tensorlist/tensorlist_test.cc @@ -140,16 +140,16 @@ class TensorListModulesTest : public ::testing::Test { iree_hal_buffer_view_buffer(returned_buffer_view); ASSERT_NE(returned_buffer, nullptr); - iree_hal_mapped_memory_t mapped_memory; - IREE_ASSERT_OK(iree_hal_buffer_map(returned_buffer, - IREE_HAL_MEMORY_ACCESS_READ, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); + iree_hal_buffer_mapping_t mapped_memory; + IREE_ASSERT_OK( + iree_hal_buffer_map_range(returned_buffer, IREE_HAL_MEMORY_ACCESS_READ, + 0, IREE_WHOLE_BUFFER, &mapped_memory)); for (int i = 0; i < expected_values.size(); i++) { EXPECT_EQ(reinterpret_cast(mapped_memory.contents.data)[i], expected_values[i]); } - IREE_ASSERT_OK(iree_hal_buffer_unmap(returned_buffer, &mapped_memory)); + iree_hal_buffer_unmap_range(&mapped_memory); } void CreateBufferView(absl::Span contents, @@ -169,18 +169,11 @@ class TensorListModulesTest : public ::testing::Test { IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE), IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(float), &buffer)); - iree_hal_mapped_memory_t mapped_memory; - IREE_ASSERT_OK(iree_hal_buffer_map(buffer.get(), - IREE_HAL_MEMORY_ACCESS_WRITE, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); - memcpy(mapped_memory.contents.data, - static_cast(contents.data()), - mapped_memory.contents.data_length); - IREE_ASSERT_OK(iree_hal_buffer_unmap(buffer.get(), &mapped_memory)); + IREE_ASSERT_OK(iree_hal_buffer_write_data(buffer.get(), 0, contents.data(), + contents.size() * sizeof(float))); IREE_ASSERT_OK(iree_hal_buffer_view_create( buffer.get(), shape.data(), shape.size(), - IREE_HAL_ELEMENT_TYPE_FLOAT_32, iree_allocator_system(), - &*out_buffer_view)); + IREE_HAL_ELEMENT_TYPE_FLOAT_32, &*out_buffer_view)); } iree_hal_device_t* device_ = nullptr; diff --git a/iree/modules/vmla/BUILD b/iree/modules/vmla/BUILD new file mode 100644 index 0000000000000..e13764df31d15 --- /dev/null +++ b/iree/modules/vmla/BUILD @@ -0,0 +1,82 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//iree:build_defs.oss.bzl", "iree_cmake_extra_content") + +package( + default_visibility = ["//visibility:public"], + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_cmake_extra_content( + content = """ +if(NOT ${IREE_HAL_DRIVER_VMLA}) + return() +endif() +""", +) + +cc_library( + name = "op_kernels", + hdrs = ["op_kernels.h"], + textual_hdrs = [ + # TODO(benvanik): SIMD variants. 
+ "op_kernels_generic.h", + "op_kernels_ruy.h", + "op_kernels_fft.h", + ], + deps = [ + "//iree/base:status", + "//iree/base:tracing", + "@com_google_absl//absl/algorithm", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@com_google_ruy//ruy", + "@com_google_ruy//ruy:context", + "@pffft", + ], +) + +cc_test( + name = "op_kernels_test", + srcs = ["op_kernels_test.cc"], + deps = [ + ":op_kernels", + "//iree/base:core_headers", + "//iree/testing:gtest", + "//iree/testing:gtest_main", + "@com_google_absl//absl/container:inlined_vector", + ], +) + +cc_library( + name = "op_module", + srcs = ["op_module.cc"], + hdrs = ["op_module.h"], + deps = [ + ":op_kernels", + "//iree/base:api", + "//iree/base:core_headers", + "//iree/base:ref_ptr", + "//iree/base:status", + "//iree/base:tracing", + "//iree/vm", + "//iree/vm:cc", + "@com_google_absl//absl/types:span", + ], +) diff --git a/iree/modules/vmla/CMakeLists.txt b/iree/modules/vmla/CMakeLists.txt new file mode 100644 index 0000000000000..0a31ca4399200 --- /dev/null +++ b/iree/modules/vmla/CMakeLists.txt @@ -0,0 +1,75 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(NOT ${IREE_HAL_DRIVER_VMLA}) + return() +endif() + +iree_add_all_subdirs() + +iree_cc_library( + NAME + op_kernels + HDRS + "op_kernels.h" + TEXTUAL_HDRS + "op_kernels_fft.h" + "op_kernels_generic.h" + "op_kernels_ruy.h" + DEPS + absl::algorithm + absl::core_headers + absl::flat_hash_set + absl::inlined_vector + absl::memory + absl::span + iree::base::status + iree::base::tracing + pffft + ruy + PUBLIC +) + +iree_cc_test( + NAME + op_kernels_test + SRCS + "op_kernels_test.cc" + DEPS + ::op_kernels + absl::inlined_vector + iree::base::core_headers + iree::testing::gtest + iree::testing::gtest_main +) + +iree_cc_library( + NAME + op_module + HDRS + "op_module.h" + SRCS + "op_module.cc" + DEPS + ::op_kernels + absl::span + iree::base::api + iree::base::core_headers + iree::base::ref_ptr + iree::base::status + iree::base::tracing + iree::vm + iree::vm::cc + PUBLIC +) diff --git a/iree/hal/vmla/op_kernels.h b/iree/modules/vmla/op_kernels.h similarity index 98% rename from iree/hal/vmla/op_kernels.h rename to iree/modules/vmla/op_kernels.h index ac4e30fe4de25..8ad19dd43708c 100644 --- a/iree/hal/vmla/op_kernels.h +++ b/iree/modules/vmla/op_kernels.h @@ -31,8 +31,8 @@ // semantics as reference and platform-specific versions can be implemented // as needed. -#ifndef IREE_HAL_VMLA_OP_KERNELS_H_ -#define IREE_HAL_VMLA_OP_KERNELS_H_ +#ifndef IREE_MODULES_VMLA_OP_KERNELS_H_ +#define IREE_MODULES_VMLA_OP_KERNELS_H_ #include @@ -496,9 +496,9 @@ struct PoolingMax { // Inconsistent automated formatting here. Just disable clang-format (for now?). 
// clang-format off -#include "iree/hal/vmla/op_kernels_generic.h" // IWYU pragma: export -#include "iree/hal/vmla/op_kernels_ruy.h" // IWYU pragma: export -#include "iree/hal/vmla/op_kernels_fft.h" // IWYU pragma: export +#include "iree/modules/vmla/op_kernels_generic.h" // IWYU pragma: export +#include "iree/modules/vmla/op_kernels_ruy.h" // IWYU pragma: export +#include "iree/modules/vmla/op_kernels_fft.h" // IWYU pragma: export // clang-format on -#endif // IREE_HAL_VMLA_OP_KERNELS_H_ +#endif // IREE_MODULES_VMLA_OP_KERNELS_H_ diff --git a/iree/hal/vmla/op_kernels_fft.h b/iree/modules/vmla/op_kernels_fft.h similarity index 97% rename from iree/hal/vmla/op_kernels_fft.h rename to iree/modules/vmla/op_kernels_fft.h index 453e2a02fd688..b4a623cbe3540 100644 --- a/iree/hal/vmla/op_kernels_fft.h +++ b/iree/modules/vmla/op_kernels_fft.h @@ -31,8 +31,8 @@ // semantics as reference and platform-specific versions can be implemented // as needed. -#ifndef IREE_HAL_VMLA_OP_KERNELS_FFT_H_ -#define IREE_HAL_VMLA_OP_KERNELS_FFT_H_ +#ifndef IREE_MODULES_VMLA_OP_KERNELS_FFT_H_ +#define IREE_MODULES_VMLA_OP_KERNELS_FFT_H_ #include "absl/types/span.h" #include "iree/base/logging.h" @@ -176,4 +176,4 @@ struct Irfft { } // namespace hal } // namespace iree -#endif // IREE_HAL_VMLA_OP_KERNELS_FFT_H_ +#endif // IREE_MODULES_VMLA_OP_KERNELS_FFT_H_ diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/modules/vmla/op_kernels_generic.h similarity index 99% rename from iree/hal/vmla/op_kernels_generic.h rename to iree/modules/vmla/op_kernels_generic.h index 15c21418800e6..8ddf6c4fddf22 100644 --- a/iree/hal/vmla/op_kernels_generic.h +++ b/iree/modules/vmla/op_kernels_generic.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_ -#define IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_ +#ifndef IREE_MODULES_VMLA_OP_KERNELS_GENERIC_H_ +#define IREE_MODULES_VMLA_OP_KERNELS_GENERIC_H_ #include #include @@ -1167,4 +1167,4 @@ Status PoolingMax::Execute(absl::Span src_buffer, } // namespace hal } // namespace iree -#endif // IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_ +#endif // IREE_MODULES_VMLA_OP_KERNELS_GENERIC_H_ diff --git a/iree/hal/vmla/op_kernels_ruy.h b/iree/modules/vmla/op_kernels_ruy.h similarity index 97% rename from iree/hal/vmla/op_kernels_ruy.h rename to iree/modules/vmla/op_kernels_ruy.h index f252d075503f2..46d73704c2250 100644 --- a/iree/hal/vmla/op_kernels_ruy.h +++ b/iree/modules/vmla/op_kernels_ruy.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef IREE_HAL_VMLA_OP_KERNELS_RUY_H_ -#define IREE_HAL_VMLA_OP_KERNELS_RUY_H_ +#ifndef IREE_MODULES_VMLA_OP_KERNELS_RUY_H_ +#define IREE_MODULES_VMLA_OP_KERNELS_RUY_H_ #include @@ -128,4 +128,4 @@ Status MatMul::Execute(RuntimeState* runtime_state, } // namespace hal } // namespace iree -#endif // IREE_HAL_VMLA_OP_KERNELS_RUY_H_ +#endif // IREE_MODULES_VMLA_OP_KERNELS_RUY_H_ diff --git a/iree/hal/vmla/op_kernels_test.cc b/iree/modules/vmla/op_kernels_test.cc similarity index 99% rename from iree/hal/vmla/op_kernels_test.cc rename to iree/modules/vmla/op_kernels_test.cc index 3bb83de4204f8..5cc17300e52b5 100644 --- a/iree/hal/vmla/op_kernels_test.cc +++ b/iree/modules/vmla/op_kernels_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License.
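These reference kernels are stateless structs with static Execute methods over absl::Span, which is what the relocated test below exercises. A rough usage sketch, assuming the Add kernel's generic signature and the iree::hal::vmla::kernels namespace (treat both as assumptions inferred from the surrounding headers, not as confirmed API):

std::vector<float> lhs = {1, 2, 3, 4};
std::vector<float> rhs = {5, 6, 7, 8};
std::vector<float> dst(4);
IREE_ASSERT_OK(iree::hal::vmla::kernels::Add::Execute<float>(
    absl::MakeConstSpan(lhs), absl::MakeConstSpan(rhs), absl::MakeSpan(dst)));
// dst is now {6, 8, 10, 12}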
-#include "iree/hal/vmla/op_kernels.h" +#include "iree/modules/vmla/op_kernels.h" #include "absl/container/inlined_vector.h" #include "iree/base/memory.h" @@ -544,7 +544,7 @@ TEST(Transpose, 2Dimen) { std::vector src_buffer = {1, 2, 3, 4, 5, 6}; std::vector expected_dst = {1, 4, - 2, 5, + 2, 5, 3, 6}; // clang-format on std::vector dst_buffer(GetShapeElementCount(dst_shape), UINT16_MAX); @@ -565,7 +565,7 @@ TEST(Transpose, 3Dimen) { 7, 8, 9, 10, 11, 12}; std::vector expected_dst = {1, 4, - 2, 5, + 2, 5, 3, 6, 7, 10, 8, 11, diff --git a/iree/hal/vmla/op_module.cc b/iree/modules/vmla/op_module.cc similarity index 99% rename from iree/hal/vmla/op_module.cc rename to iree/modules/vmla/op_module.cc index 659b93343683d..26ddaa2c3c59c 100644 --- a/iree/hal/vmla/op_module.cc +++ b/iree/modules/vmla/op_module.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "iree/hal/vmla/op_module.h" +#include "iree/modules/vmla/op_module.h" #include #include "absl/types/span.h" #include "iree/base/tracing.h" -#include "iree/hal/vmla/op_kernels.h" +#include "iree/modules/vmla/op_kernels.h" #include "iree/vm/module_abi_packing.h" //===----------------------------------------------------------------------===// @@ -157,16 +157,16 @@ Status Interface::SetConstants(absl::Span values) { } StatusOr Interface::GetBinding( - int32_t set, int32_t binding) const { - if (set < 0 || set > kMaxSets || binding < 0 || binding > kMaxBindings) { + uint32_t set, uint32_t binding) const { + if (set >= kMaxSets || binding >= kMaxBindings) { return InvalidArgumentErrorBuilder(IREE_LOC) << "Invalid binding set=" << set << ", binding=" << binding; } return bindings_[set][binding]; } -Status Interface::SetBinding(int32_t set, int32_t binding, Binding value) { - if (set < 0 || set > kMaxSets || binding < 0 || binding > kMaxBindings) { +Status Interface::SetBinding(uint32_t set, uint32_t binding, Binding value) { + if (set >= kMaxSets || binding >= kMaxBindings) { return InvalidArgumentErrorBuilder(IREE_LOC) << "Invalid binding set=" << set << ", binding=" << binding; } @@ -205,7 +205,7 @@ class VMLAModuleState final { } StatusOr> InterfaceBinding(vm::ref interface, - int32_t set, int32_t binding) { + uint32_t set, uint32_t binding) { IREE_TRACE_SCOPE0("VMLAModuleState::InterfaceBinding"); IREE_ASSIGN_OR_RETURN(const auto& value, interface->GetBinding(set, binding)); diff --git a/iree/hal/vmla/op_module.h b/iree/modules/vmla/op_module.h similarity index 94% rename from iree/hal/vmla/op_module.h rename to iree/modules/vmla/op_module.h index f94a6718f06b3..5eaf07950c18b 100644 --- a/iree/hal/vmla/op_module.h +++ b/iree/modules/vmla/op_module.h @@ -16,8 +16,8 @@ // linked into the same library, because of this we can avoid the C shims and // directly use C++ types. -#ifndef IREE_HAL_VMLA_OP_MODULE_H_ -#define IREE_HAL_VMLA_OP_MODULE_H_ +#ifndef IREE_MODULES_VMLA_OP_MODULE_H_ +#define IREE_MODULES_VMLA_OP_MODULE_H_ #include @@ -115,10 +115,10 @@ class Interface final : public RefObject { Status SetConstants(absl::Span values); // Gets the binding within a set. Note that the buffer may be null. - StatusOr GetBinding(int32_t set, int32_t binding) const; + StatusOr GetBinding(uint32_t set, uint32_t binding) const; // Sets a binding within a set to the given buffer value (possibly null). 
- Status SetBinding(int32_t set, int32_t binding, Binding value); + Status SetBinding(uint32_t set, uint32_t binding, Binding value); private: std::array constants_; @@ -136,4 +136,4 @@ Status ModuleCreate(iree_allocator_t allocator, iree_vm_module_t** out_module); IREE_VM_DECLARE_TYPE_ADAPTERS(Buffer, iree::hal::vmla::Buffer); IREE_VM_DECLARE_TYPE_ADAPTERS(Interface, iree::hal::vmla::Interface); -#endif // IREE_HAL_VMLA_OP_MODULE_H_ +#endif // IREE_MODULES_VMLA_OP_MODULE_H_ diff --git a/iree/samples/custom_modules/BUILD b/iree/samples/custom_modules/BUILD index c980e20e7b462..3154cc3be0ab8 100644 --- a/iree/samples/custom_modules/BUILD +++ b/iree/samples/custom_modules/BUILD @@ -23,7 +23,8 @@ package( iree_cmake_extra_content( content = """ -if(NOT ${IREE_TARGET_BACKEND_VMLA} OR NOT ${IREE_HAL_DRIVER_VMLA}) +if(NOT "${IREE_TARGET_BACKEND_VMLA}" OR + NOT "${IREE_HAL_DRIVER_VMLA}") return() endif() """, diff --git a/iree/samples/custom_modules/CMakeLists.txt b/iree/samples/custom_modules/CMakeLists.txt index f7942f702d3e0..0e939725e031b 100644 --- a/iree/samples/custom_modules/CMakeLists.txt +++ b/iree/samples/custom_modules/CMakeLists.txt @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -if(NOT ${IREE_TARGET_BACKEND_VMLA} OR NOT ${IREE_HAL_DRIVER_VMLA}) +if(NOT "${IREE_TARGET_BACKEND_VMLA}" OR + NOT "${IREE_HAL_DRIVER_VMLA}") return() endif() diff --git a/iree/samples/custom_modules/custom_modules_test.cc b/iree/samples/custom_modules/custom_modules_test.cc index beef630f02dfd..33fdb4e914baa 100644 --- a/iree/samples/custom_modules/custom_modules_test.cc +++ b/iree/samples/custom_modules/custom_modules_test.cc @@ -53,6 +53,7 @@ class CustomModulesTest : public ::testing::Test { hal_driver, iree_allocator_system(), &hal_device)); IREE_CHECK_OK(iree_hal_module_create(hal_device, iree_allocator_system(), &hal_module_)); + hal_allocator_ = iree_hal_device_allocator(hal_device); iree_hal_device_release(hal_device); iree_hal_driver_release(hal_driver); @@ -100,6 +101,7 @@ class CustomModulesTest : public ::testing::Test { iree_vm_module_t* bytecode_module_ = nullptr; iree_vm_module_t* native_module_ = nullptr; iree_vm_module_t* hal_module_ = nullptr; + iree_hal_allocator_t* hal_allocator_ = nullptr; }; TEST_F(CustomModulesTest, ReverseAndPrint) { @@ -145,12 +147,12 @@ TEST_F(CustomModulesTest, PrintTensor) { static float kBufferContents[2 * 4] = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; iree_hal_buffer_t* buffer = nullptr; - IREE_ASSERT_OK(iree_hal_heap_buffer_allocate_copy( - IREE_HAL_MEMORY_TYPE_HOST_LOCAL, IREE_HAL_BUFFER_USAGE_ALL, - IREE_HAL_MEMORY_ACCESS_ALL, + IREE_ASSERT_OK(iree_hal_allocator_wrap_buffer( + hal_allocator_, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, + IREE_HAL_MEMORY_ACCESS_ALL, IREE_HAL_BUFFER_USAGE_ALL, iree_byte_span_t{reinterpret_cast(kBufferContents), sizeof(kBufferContents)}, - iree_allocator_system(), iree_allocator_system(), &buffer)); + iree_allocator_null(), &buffer)); // Pass in the tensor as an expanded HAL buffer. 
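iree_hal_allocator_wrap_buffer adopts existing host memory instead of allocating and copying; passing iree_allocator_null() as the data allocator marks the span as unowned so the buffer will not attempt to free it. A condensed sketch of the pattern used in these tests (the /*data_allocator=*/ parameter name is an assumption):

static float kData[4] = {0.0f, 1.0f, 2.0f, 3.0f};
iree_hal_buffer_t* buffer = nullptr;
IREE_ASSERT_OK(iree_hal_allocator_wrap_buffer(
    hal_allocator_, IREE_HAL_MEMORY_TYPE_HOST_LOCAL,
    IREE_HAL_MEMORY_ACCESS_ALL, IREE_HAL_BUFFER_USAGE_ALL,
    iree_byte_span_t{reinterpret_cast<uint8_t*>(kData), sizeof(kData)},
    /*data_allocator=*/iree_allocator_null(), &buffer));  // buffer does not own kData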
iree::vm::ref inputs; @@ -185,12 +187,12 @@ TEST_F(CustomModulesTest, RoundTripTensor) { static float kBufferContents[2 * 4] = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; iree_hal_buffer_t* buffer = nullptr; - IREE_ASSERT_OK(iree_hal_heap_buffer_allocate_copy( - IREE_HAL_MEMORY_TYPE_HOST_LOCAL, IREE_HAL_BUFFER_USAGE_ALL, - IREE_HAL_MEMORY_ACCESS_ALL, + IREE_ASSERT_OK(iree_hal_allocator_wrap_buffer( + hal_allocator_, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, + IREE_HAL_MEMORY_ACCESS_ALL, IREE_HAL_BUFFER_USAGE_ALL, iree_byte_span_t{reinterpret_cast(kBufferContents), sizeof(kBufferContents)}, - iree_allocator_system(), iree_allocator_system(), &buffer)); + iree_allocator_null(), &buffer)); // Pass in the tensor as an expanded HAL buffer. iree::vm::ref inputs; diff --git a/iree/samples/custom_modules/native_module.cc b/iree/samples/custom_modules/native_module.cc index d2499fcee8b3b..d4f910ef20c1c 100644 --- a/iree/samples/custom_modules/native_module.cc +++ b/iree/samples/custom_modules/native_module.cc @@ -141,8 +141,9 @@ class CustomModuleState final { // Setup a host-local allocator we can use because this sample doesn't have // a real device allocator. - IREE_RETURN_IF_ERROR(iree_hal_allocator_create_host_local( - allocator_, &host_local_allocator_)); + IREE_RETURN_IF_ERROR( + iree_hal_allocator_create_heap(iree_make_cstring_view("host_local"), + allocator_, &host_local_allocator_)); return OkStatus(); } diff --git a/iree/samples/simple_embedding/BUILD b/iree/samples/simple_embedding/BUILD index 6abae0de02753..c93615348c6d7 100644 --- a/iree/samples/simple_embedding/BUILD +++ b/iree/samples/simple_embedding/BUILD @@ -34,10 +34,6 @@ iree_bytecode_module( cc_test( name = "simple_embedding_test", srcs = ["simple_embedding_test.cc"], - data = [ - # For AddressSanitizer when using Vulkan + a local Nvidia GPU - "//iree/tools:sanitizer_suppressions.txt", - ], deps = [ ":simple_embedding_test_bytecode_module_cc", "//iree/base:api", diff --git a/iree/samples/simple_embedding/CMakeLists.txt b/iree/samples/simple_embedding/CMakeLists.txt index a9dada3008c07..01904c4be7901 100644 --- a/iree/samples/simple_embedding/CMakeLists.txt +++ b/iree/samples/simple_embedding/CMakeLists.txt @@ -33,8 +33,6 @@ iree_cc_test( simple_embedding_test SRCS "simple_embedding_test.cc" - DATA - iree::tools::sanitizer_suppressions.txt DEPS ::simple_embedding_test_bytecode_module_cc absl::core_headers diff --git a/iree/samples/simple_embedding/simple_embedding_test.cc b/iree/samples/simple_embedding/simple_embedding_test.cc index 2e3cbd571cb3d..b98338e8b8974 100644 --- a/iree/samples/simple_embedding/simple_embedding_test.cc +++ b/iree/samples/simple_embedding/simple_embedding_test.cc @@ -171,14 +171,15 @@ TEST_P(SimpleEmbeddingTest, RunOnce) { // Read back the results and ensure we got the right values. 
IREE_LOG(INFO) << "Reading back results..."; - iree_hal_mapped_memory_t mapped_memory; - IREE_ASSERT_OK(iree_hal_buffer_map(ret_buffer, IREE_HAL_MEMORY_ACCESS_READ, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); + iree_hal_buffer_mapping_t mapped_memory; + IREE_ASSERT_OK(iree_hal_buffer_map_range(ret_buffer, + IREE_HAL_MEMORY_ACCESS_READ, 0, + IREE_WHOLE_BUFFER, &mapped_memory)); ASSERT_THAT(absl::Span( reinterpret_cast(mapped_memory.contents.data), mapped_memory.contents.data_length / sizeof(float)), ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f})); - IREE_ASSERT_OK(iree_hal_buffer_unmap(ret_buffer, &mapped_memory)); + iree_hal_buffer_unmap_range(&mapped_memory); IREE_LOG(INFO) << "Results match!"; inputs.reset(); diff --git a/iree/samples/vulkan/CMakeLists.txt b/iree/samples/vulkan/CMakeLists.txt index 785eb2c375003..9dad3f3a36332 100644 --- a/iree/samples/vulkan/CMakeLists.txt +++ b/iree/samples/vulkan/CMakeLists.txt @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV} OR NOT ${IREE_HAL_DRIVER_VULKAN}) +if(NOT "${IREE_TARGET_BACKEND_VULKAN-SPIRV}" OR + NOT "${IREE_HAL_DRIVER_VULKAN}") return() endif() diff --git a/iree/samples/vulkan/vulkan_inference_gui.cc b/iree/samples/vulkan/vulkan_inference_gui.cc index 45e4ab4c9f9d6..f14e228fb0f4e 100644 --- a/iree/samples/vulkan/vulkan_inference_gui.cc +++ b/iree/samples/vulkan/vulkan_inference_gui.cc @@ -88,9 +88,8 @@ int iree::IreeMain(int argc, char** argv) { // Setup Vulkan iree_hal_vulkan_features_t iree_vulkan_features = static_cast( - IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS | - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS | - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS); + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS | + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); std::vector layers = GetInstanceLayers(iree_vulkan_features); std::vector extensions = GetInstanceExtensions(window, iree_vulkan_features); @@ -202,28 +201,32 @@ int iree::IreeMain(int argc, char** argv) { // Load symbols from our static `vkGetInstanceProcAddr` for IREE to use. iree_hal_vulkan_syms_t* iree_vk_syms = nullptr; IREE_CHECK_OK(iree_hal_vulkan_syms_create( - reinterpret_cast(&vkGetInstanceProcAddr), &iree_vk_syms)); + reinterpret_cast(&vkGetInstanceProcAddr), iree_allocator_system(), + &iree_vk_syms)); // Create the driver sharing our VkInstance. iree_hal_driver_t* iree_vk_driver = nullptr; - iree_hal_vulkan_driver_options_t options; - options.api_version = VK_API_VERSION_1_0; - options.features = static_cast( - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS | - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS); + iree_string_view_t driver_identifier = iree_make_cstring_view("vulkan"); + iree_hal_vulkan_driver_options_t driver_options; + driver_options.api_version = VK_API_VERSION_1_0; + driver_options.requested_features = static_cast( + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance( - options, iree_vk_syms, g_Instance, &iree_vk_driver)); + driver_identifier, &driver_options, iree_vk_syms, g_Instance, + iree_allocator_system(), &iree_vk_driver)); // Create a device sharing our VkDevice and queue. // We could also create a separate (possibly low priority) compute queue for // IREE, and/or provide a dedicated transfer queue. 
+ iree_string_view_t device_identifier = iree_make_cstring_view("vulkan"); iree_hal_vulkan_queue_set_t compute_queue_set; compute_queue_set.queue_family_index = g_QueueFamily; compute_queue_set.queue_indices = 1 << 0; iree_hal_vulkan_queue_set_t transfer_queue_set; transfer_queue_set.queue_indices = 0; iree_hal_device_t* iree_vk_device = nullptr; - IREE_CHECK_OK(iree_hal_vulkan_driver_wrap_device( - iree_vk_driver, g_PhysicalDevice, g_Device, compute_queue_set, - transfer_queue_set, &iree_vk_device)); + IREE_CHECK_OK(iree_hal_vulkan_wrap_device( + device_identifier, &driver_options.device_options, iree_vk_syms, + g_Instance, g_PhysicalDevice, g_Device, &compute_queue_set, + &transfer_queue_set, iree_allocator_system(), &iree_vk_device)); // Create a HAL module using the HAL device. iree_vm_module_t* hal_module = nullptr; IREE_CHECK_OK(iree_hal_module_create(iree_vk_device, iree_allocator_system(), @@ -282,11 +285,11 @@ int iree::IreeMain(int argc, char** argv) { // Create wait and signal semaphores for async execution. vm::ref wait_semaphore; - IREE_CHECK_OK(iree_hal_semaphore_create( - iree_vk_device, 0ull, iree_allocator_system(), &wait_semaphore)); + IREE_CHECK_OK( + iree_hal_semaphore_create(iree_vk_device, 0ull, &wait_semaphore)); vm::ref signal_semaphore; - IREE_CHECK_OK(iree_hal_semaphore_create( - iree_vk_device, 0ull, iree_allocator_system(), &signal_semaphore)); + IREE_CHECK_OK( + iree_hal_semaphore_create(iree_vk_device, 0ull, &signal_semaphore)); // -------------------------------------------------------------------------- // -------------------------------------------------------------------------- @@ -392,12 +395,10 @@ int iree::IreeMain(int argc, char** argv) { iree_hal_buffer_view_t* input1_buffer_view = nullptr; IREE_CHECK_OK(iree_hal_buffer_view_create( input0_buffer, /*shape=*/&kElementCount, /*shape_rank=*/1, - IREE_HAL_ELEMENT_TYPE_FLOAT_32, iree_allocator_system(), - &input0_buffer_view)); + IREE_HAL_ELEMENT_TYPE_FLOAT_32, &input0_buffer_view)); IREE_CHECK_OK(iree_hal_buffer_view_create( input1_buffer, /*shape=*/&kElementCount, /*shape_rank=*/1, - IREE_HAL_ELEMENT_TYPE_FLOAT_32, iree_allocator_system(), - &input1_buffer_view)); + IREE_HAL_ELEMENT_TYPE_FLOAT_32, &input1_buffer_view)); iree_hal_buffer_release(input0_buffer); iree_hal_buffer_release(input1_buffer); // Marshal inputs through a VM variant list. 
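The marshaling step referenced above, sketched standalone (the null element type and the initial capacity are assumptions based on the variant-list usage elsewhere in this change):

iree::vm::ref<iree_vm_list_t> inputs;
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, /*initial_capacity=*/2,
                                  iree_allocator_system(), &inputs));
iree_vm_ref_t input0_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_view_ref));
iree_vm_ref_t input1_view_ref = iree_hal_buffer_view_move_ref(input1_buffer_view);
IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input1_view_ref));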
@@ -449,14 +450,9 @@ int iree::IreeMain(int argc, char** argv) { auto* output_buffer_view = reinterpret_cast( iree_vm_list_get_ref_deref(outputs.get(), 0, iree_hal_buffer_view_get_descriptor())); - auto* output_buffer = iree_hal_buffer_view_buffer(output_buffer_view); - iree_hal_mapped_memory_t mapped_memory; - IREE_CHECK_OK(iree_hal_buffer_map(output_buffer, - IREE_HAL_MEMORY_ACCESS_READ, 0, - IREE_WHOLE_BUFFER, &mapped_memory)); - memcpy(&latest_output, mapped_memory.contents.data, - mapped_memory.contents.data_length); - iree_hal_buffer_unmap(output_buffer, &mapped_memory); + IREE_CHECK_OK(iree_hal_buffer_read_data( + iree_hal_buffer_view_buffer(output_buffer_view), 0, latest_output, + sizeof(latest_output))); dirty = false; } diff --git a/iree/task/BUILD b/iree/task/BUILD index c5e62e81ef857..c3c7a12b1bd54 100644 --- a/iree/task/BUILD +++ b/iree/task/BUILD @@ -111,7 +111,10 @@ cc_test( cc_test( name = "scope_test", - srcs = ["scope_test.cc"], + srcs = [ + "scope_test.cc", + "task_impl.h", + ], deps = [ ":task", "//iree/base:api", @@ -121,6 +124,25 @@ cc_test( ], ) +cc_test( + name = "task_tests", + srcs = [ + "task_test_barrier.cc", + "task_test_call.cc", + "task_test_dispatch.cc", + "task_test_fence.cc", + "task_test_nop.cc", + "task_test_wait.cc", + ], + deps = [ + ":task", + "//iree/base:api", + "//iree/task/testing:task_test", + "//iree/testing:gtest", + "//iree/testing:gtest_main", + ], +) + cc_test( name = "topology_test", srcs = ["topology_test.cc"], @@ -129,5 +151,6 @@ cc_test( "//iree/base:api", "//iree/testing:gtest", "//iree/testing:gtest_main", + "@cpuinfo", ], ) diff --git a/iree/task/CMakeLists.txt b/iree/task/CMakeLists.txt index f365ccf6d0ac6..f88fe09a7e795 100644 --- a/iree/task/CMakeLists.txt +++ b/iree/task/CMakeLists.txt @@ -113,6 +113,7 @@ iree_cc_test( scope_test SRCS "scope_test.cc" + "task_impl.h" DEPS ::task iree::base::api @@ -121,6 +122,24 @@ iree_cc_test( iree::testing::gtest_main ) +iree_cc_test( + NAME + task_tests + SRCS + "task_test_barrier.cc" + "task_test_call.cc" + "task_test_dispatch.cc" + "task_test_fence.cc" + "task_test_nop.cc" + "task_test_wait.cc" + DEPS + ::task + iree::base::api + iree::task::testing::task_test + iree::testing::gtest + iree::testing::gtest_main +) + iree_cc_test( NAME topology_test @@ -128,6 +147,7 @@ iree_cc_test( "topology_test.cc" DEPS ::task + cpuinfo iree::base::api iree::testing::gtest iree::testing::gtest_main diff --git a/iree/task/executor.c b/iree/task/executor.c index 205f5b0952dec..647db0a5dc8b1 100644 --- a/iree/task/executor.c +++ b/iree/task/executor.c @@ -86,16 +86,18 @@ iree_status_t iree_task_executor_create( // executor and since we know the precise lifetime of them we can keep them // entirely within the system here. 
if (iree_status_is_ok(status)) { - status = iree_task_pool_initialize( - allocator, sizeof(iree_task_dispatch_slice_t), - worker_count * IREE_TASK_EXECUTOR_INITIAL_SLICE_RESERVATION_PER_WORKER, - &executor->slice_task_pool); + status = iree_task_pool_initialize(allocator, sizeof(iree_task_fence_t), 8, + &executor->fence_task_pool); } if (iree_status_is_ok(status)) { status = iree_task_pool_initialize( - allocator, sizeof(iree_task_dispatch_shard_t), - worker_count * IREE_TASK_EXECUTOR_INITIAL_SHARD_RESERVATION_PER_WORKER, - &executor->shard_task_pool); + allocator, + iree_max(sizeof(iree_task_dispatch_shard_t), + sizeof(iree_task_dispatch_slice_t)), + worker_count * + iree_max(IREE_TASK_EXECUTOR_INITIAL_SHARD_RESERVATION_PER_WORKER, + IREE_TASK_EXECUTOR_INITIAL_SLICE_RESERVATION_PER_WORKER), + &executor->dispatch_task_pool); } // Bring up the workers; the threads will be created here but be suspended @@ -169,8 +171,8 @@ static void iree_task_executor_destroy(iree_task_executor_t* executor) { iree_slim_mutex_deinitialize(&executor->coordinator_mutex); iree_atomic_task_slist_deinitialize(&executor->incoming_ready_slist); iree_atomic_task_slist_deinitialize(&executor->incoming_waiting_slist); - iree_task_pool_deinitialize(&executor->slice_task_pool); - iree_task_pool_deinitialize(&executor->shard_task_pool); + iree_task_pool_deinitialize(&executor->fence_task_pool); + iree_task_pool_deinitialize(&executor->dispatch_task_pool); iree_allocator_free(executor->allocator, executor); IREE_TRACE_ZONE_END(z0); @@ -188,6 +190,19 @@ void iree_task_executor_release(iree_task_executor_t* executor) { } } +iree_status_t iree_task_executor_acquire_fence(iree_task_executor_t* executor, + iree_task_scope_t* scope, + iree_task_fence_t** out_fence) { + *out_fence = NULL; + iree_task_fence_t* fence = NULL; + IREE_RETURN_IF_ERROR(iree_task_pool_acquire(&executor->fence_task_pool, + (iree_task_t**)&fence)); + iree_task_fence_initialize(scope, fence); + fence->header.pool = &executor->fence_task_pool; + *out_fence = fence; + return iree_ok_status(); +} + // Schedules a generic task to a worker matching its affinity. // The task will be posted to the worker mailbox and available for the worker to // begin processing as soon as the |post_batch| is submitted. @@ -221,6 +236,9 @@ void iree_task_executor_schedule_ready_tasks( while ((task = iree_task_list_pop_front(&pending_submission->ready_list))) { switch (task->type) { case IREE_TASK_TYPE_NOP: + // Doesn't do anything; just retire and continue on to any dependents. + iree_task_nop_retire((iree_task_nop_t*)task, pending_submission); + break; case IREE_TASK_TYPE_CALL: case IREE_TASK_TYPE_DISPATCH_SLICE: { // Generic routing to workers for tasks that should always run there. 
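With submit and flush now returning void, completion is observed through a scope (optionally gated by a fence acquired from the new executor fence pool) rather than through per-call statuses. The caller-side pattern, condensed from the executor test changes later in this diff (task and scope setup elided; names illustrative):

iree_task_fence_t* fence = NULL;
IREE_CHECK_OK(iree_task_executor_acquire_fence(executor, &scope, &fence));
iree_task_set_completion_task(&last_task.header, &fence->header);

iree_task_submission_t submission;
iree_task_submission_initialize(&submission);
iree_task_submission_enqueue(&submission, &first_task.header);
iree_task_executor_submit(executor, &submission);  // void: failures surface on the scope
iree_task_executor_flush(executor);
IREE_CHECK_OK(iree_task_scope_wait_idle(&scope, IREE_TIME_INFINITE_FUTURE));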
@@ -259,11 +277,11 @@ void iree_task_executor_schedule_ready_tasks( } else { if (task->flags & IREE_TASK_FLAG_DISPATCH_SLICED) { iree_task_dispatch_issue_sliced((iree_task_dispatch_t*)task, - &executor->slice_task_pool, + &executor->dispatch_task_pool, pending_submission, post_batch); } else { iree_task_dispatch_issue_sharded((iree_task_dispatch_t*)task, - &executor->shard_task_pool, + &executor->dispatch_task_pool, pending_submission, post_batch); } } @@ -293,28 +311,26 @@ void iree_task_executor_merge_submission(iree_task_executor_t* executor, iree_task_submission_reset(submission); } -iree_status_t iree_task_executor_submit(iree_task_executor_t* executor, - iree_task_submission_t* submission) { +void iree_task_executor_submit(iree_task_executor_t* executor, + iree_task_submission_t* submission) { IREE_TRACE_ZONE_BEGIN(z0); // Concatenate the submitted tasks onto our primary LIFO incoming lists. iree_task_executor_merge_submission(executor, submission); IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); } -iree_status_t iree_task_executor_flush(iree_task_executor_t* executor) { +void iree_task_executor_flush(iree_task_executor_t* executor) { IREE_TRACE_ZONE_BEGIN(z0); // Mostly a no-op today as we aren't deferring submission with the scheduling // mode. Instead, we'll just run the coordinator inline to ensure all tasks // are pushed to workers. iree_task_executor_coordinate(executor, /*current_worker=*/NULL, - /*speculative=*/false); + /*wait_on_idle=*/false); IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); } // Merges incoming likely-unresolved wait tasks into the primary executor lists. @@ -506,101 +522,103 @@ static void iree_task_executor_wait_any_task( // Dispatches tasks in the global submission queue to workers. // This is called by users upon submission of new tasks or by workers when they -// run out of tasks to process. |speculative| indicates whether the coordination -// request is done as a fallback in the event of there possibly being new work -// available. +// run out of tasks to process. |wait_on_idle| indicates whether the +// coordination request is done as a fallback in the event of there possibly +// being new work available. // // If a coordination run ends up with no ready tasks and one or more waiting // tasks then the coordinator will wait for one of the tasks to become ready. -// This only happens in the speculative case (so it's always a worker) as in +// This only happens in the |wait_on_idle| case (so it's always a worker) as in // those cases the next step for the worker would have been to wait anyway. In // the non-speculative case the coordinator polls the wait handles to see if // they have resolved instead, possibly readying more tasks immediately. void iree_task_executor_coordinate(iree_task_executor_t* executor, iree_task_worker_t* current_worker, - bool speculative) { - if (speculative) { - if (!iree_slim_mutex_try_lock(&executor->coordinator_mutex)) { - // Another thread is already holding the coordination lock. - // Return to the caller to wait for it to finish. - // TODO(benvanik): spin here if it's likely we'll have work after the - // other coordinator finishes - that way we don't enter the wait. - return; - } - } else { - iree_slim_mutex_lock(&executor->coordinator_mutex); - } + bool wait_on_idle) { + iree_slim_mutex_lock(&executor->coordinator_mutex); IREE_TRACE_ZONE_BEGIN(z0); - // Check for incoming submissions and move their posted tasks into our - // local lists. 
Any of the tasks here are ready to execute immediately and - // ones we should be able to distribute to workers without delay. The - // waiting tasks are to the best of the caller's knowledge not ready yet. - // - // Note that we only do this once per coordination; that's so we don't - // starve if submissions come in faster than we can schedule them. - // Coordination will run again when workers become idle and will pick up - // any changes then. - // - // As we schedule tasks we may spawn new ones (like a dispatch -> many - // dispatch slices) and we keep track of those here. By doing a pass through - // all ready tasks and only then merging in the new submission we get - // breadth-first traversal of task graphs even if they originate from - // various places and have no relation - hopefully leading to better average - // latency. - iree_task_submission_t pending_submission; - iree_task_submission_initialize_from_lifo_slist( - &executor->incoming_ready_slist, &pending_submission); - iree_task_list_append_from_fifo_slist(&pending_submission.waiting_list, - &executor->incoming_waiting_slist); - - // Scratch coordinator submission batch used during scheduling to batch up - // all tasks that will be posted to each worker. We could stash this on the - // executor but given that which thread is playing the role of the coordinator - // is random it's better to ensure that these bytes never incur a cache miss - // by making them live here in the stack of the chosen thread. - iree_task_post_batch_t* post_batch = - iree_alloca(sizeof(iree_task_post_batch_t) + - executor->worker_count * sizeof(iree_task_list_t)); - iree_task_post_batch_initialize(executor, current_worker, post_batch); - - // Poll the waiting tasks to see if any have resolved. This dramatically - // cuts latency in cases where the wait handle completes prior to us - // entering the real wait. When we have semaphores sequencing back-to-back - // work this ensures that we pack in future dispatch work earlier vs. - // waiting for a full thread hop. - // - // If any waits have resolved then they'll be moved to the ready list here - // and then get processed FIFO with the tasks that were ready in the - // request. - iree_task_executor_poll_waiting_tasks(executor, &pending_submission); - - // Schedule all ready tasks in this batch. Some may complete inline (such - // as ready barriers with all their dependencies resolved) while others may - // be scheduled on workers via the post batch. - iree_task_executor_schedule_ready_tasks(executor, &pending_submission, - post_batch); - - // Merge any newly waiting tasks into the global wait list. - iree_task_executor_merge_wait_list(executor, - &pending_submission.waiting_list); - - // Post all new work to workers; they may wake and begin executing - // immediately. Returns whether this worker has new tasks for it to work on. - bool did_post = iree_task_post_batch_submit(post_batch); - if (!did_post && speculative) { - // No work was found; wait on one or more of our wait handles. - // This will block the calling thread but that's fine as they were going - // to wait anyway and were just speculatively seeing if there was work first - // by requesting coordination. If work completes here we'll catch it on - // the poll next loop around. 
- iree_task_executor_wait_any_task(executor, current_worker, - &pending_submission); - } + // We may be adding tasks/waiting/etc on each pass through coordination - to + // ensure we completely drain the incoming queues and satisfied waits we loop + // until there's nothing left to coordinate. + bool schedule_dirty = true; + do { + // Check for incoming submissions and move their posted tasks into our + // local lists. Any of the tasks here are ready to execute immediately and + // ones we should be able to distribute to workers without delay. The + // waiting tasks are to the best of the caller's knowledge not ready yet. + // + // Note that we only do this once per coordination; that's so we don't + // starve if submissions come in faster than we can schedule them. + // Coordination will run again when workers become idle and will pick up + // any changes then. + // + // As we schedule tasks we may spawn new ones (like a dispatch -> many + // dispatch slices) and we keep track of those here. By doing a pass through + // all ready tasks and only then merging in the new submission we get + // breadth-first traversal of task graphs even if they originate from + // various places and have no relation - hopefully leading to better average + // latency. + iree_task_submission_t pending_submission; + iree_task_submission_initialize_from_lifo_slist( + &executor->incoming_ready_slist, &pending_submission); + iree_task_list_append_from_fifo_slist(&pending_submission.waiting_list, + &executor->incoming_waiting_slist); + + // Scratch coordinator submission batch used during scheduling to batch up + // all tasks that will be posted to each worker. We could stash this on the + // executor but given that which thread is playing the role of the + // coordinator is random it's better to ensure that these bytes never incur + // a cache miss by making them live here in the stack of the chosen thread. + iree_task_post_batch_t* post_batch = + iree_alloca(sizeof(iree_task_post_batch_t) + + executor->worker_count * sizeof(iree_task_list_t)); + iree_task_post_batch_initialize(executor, current_worker, post_batch); + + // Poll the waiting tasks to see if any have resolved. This dramatically + // cuts latency in cases where the wait handle completes prior to us + // entering the real wait. When we have semaphores sequencing back-to-back + // work this ensures that we pack in future dispatch work earlier vs. + // waiting for a full thread hop. + // + // If any waits have resolved then they'll be moved to the ready list here + // and then get processed FIFO with the tasks that were ready in the + // request. + iree_task_executor_poll_waiting_tasks(executor, &pending_submission); + + // Schedule all ready tasks in this batch. Some may complete inline (such + // as ready barriers with all their dependencies resolved) while others may + // be scheduled on workers via the post batch. + iree_task_executor_schedule_ready_tasks(executor, &pending_submission, + post_batch); + + // Merge any newly waiting tasks into the global wait list. + iree_task_executor_merge_wait_list(executor, + &pending_submission.waiting_list); + + // Post all new work to workers; they may wake and begin executing + // immediately. Returns whether this worker has new tasks for it to work on. + bool did_post = iree_task_post_batch_submit(post_batch); + if (!did_post && wait_on_idle) { + // No work was found; wait on one or more of our wait handles. 
+ // This will block the calling thread but that's fine as they were going + // to wait anyway and were just speculatively seeing if there was work + // first by requesting coordination. If work completes here we'll catch it + // on the poll next loop around. + iree_task_executor_wait_any_task(executor, current_worker, + &pending_submission); + } - // Merge any new work into the submission list for future coordinators to - // deal with - we don't want the possibility of starvation by looping on this. - iree_task_executor_merge_submission(executor, &pending_submission); + // Merge any new work into the submission list for future coordinators to + // deal with - we don't want the possibility of starvation by looping on + // this. + if (!iree_task_submission_is_empty(&pending_submission)) { + iree_task_executor_merge_submission(executor, &pending_submission); + schedule_dirty = true; + } else { + schedule_dirty = false; + } + } while (schedule_dirty); iree_slim_mutex_unlock(&executor->coordinator_mutex); IREE_TRACE_ZONE_END(z0); diff --git a/iree/task/executor.h b/iree/task/executor.h index 0f49878ad03f3..fc6cc2d6db3d1 100644 --- a/iree/task/executor.h +++ b/iree/task/executor.h @@ -319,6 +319,11 @@ void iree_task_executor_retain(iree_task_executor_t* executor); // Releases the given |executor| from the caller. void iree_task_executor_release(iree_task_executor_t* executor); +// Acquires a fence for the given |scope| from the executor fence pool. +iree_status_t iree_task_executor_acquire_fence(iree_task_executor_t* executor, + iree_task_scope_t* scope, + iree_task_fence_t** out_fence); + // TODO(benvanik): scheduling mode mutation, compute quota control, etc. // Submits a batch of tasks for execution. @@ -334,8 +339,8 @@ void iree_task_executor_release(iree_task_executor_t* executor); // // NOTE: it's possible for all work in the submission to complete prior to this // function returning. -iree_status_t iree_task_executor_submit(iree_task_executor_t* executor, - iree_task_submission_t* submission); +void iree_task_executor_submit(iree_task_executor_t* executor, + iree_task_submission_t* submission); // Flushes any pending task batches for execution. // @@ -344,7 +349,7 @@ iree_status_t iree_task_executor_submit(iree_task_executor_t* executor, // // NOTE: due to races it's possible for new work to arrive from other threads // after the flush has occurred but prior to this call returning. -iree_status_t iree_task_executor_flush(iree_task_executor_t* executor); +void iree_task_executor_flush(iree_task_executor_t* executor); // Donates the calling thread to the executor until either |wait_handle| // resolves or |deadline_ns| is exceeded. diff --git a/iree/task/executor_impl.h b/iree/task/executor_impl.h index dbb17c57a4ff9..ce17b7795d43b 100644 --- a/iree/task/executor_impl.h +++ b/iree/task/executor_impl.h @@ -50,8 +50,8 @@ struct iree_task_executor_s { // Pools of transient dispatch tasks shared across all workers. // Depending on configuration the task pool may allocate after creation using // the allocator provided upon executor creation. - iree_task_pool_t slice_task_pool; - iree_task_pool_t shard_task_pool; + iree_task_pool_t fence_task_pool; + iree_task_pool_t dispatch_task_pool; // A list of incoming tasks that are ready to execute immediately. 
// The list is LIFO and we require that task lists are reversed by the @@ -126,9 +126,13 @@ void iree_task_executor_schedule_ready_tasks( // |current_worker| will be NULL if called from a non-worker thread and // otherwise be the current worker; used to avoid round-tripping through the // whole system to post to oneself. +// +// If the |current_worker| has no more work remaining and |wait_on_idle| is set +// then the calling thread may wait on any pending wait tasks until one resolves +// or more work is scheduled for the worker. void iree_task_executor_coordinate(iree_task_executor_t* executor, iree_task_worker_t* current_worker, - bool speculative); + bool wait_on_idle); // Tries to steal an entire task from a sibling worker (based on topology). // Returns a task that is available (has not yet begun processing at all). diff --git a/iree/task/executor_test.cc b/iree/task/executor_test.cc index 6641edba65f59..e4ed043877ef3 100644 --- a/iree/task/executor_test.cc +++ b/iree/task/executor_test.cc @@ -24,7 +24,7 @@ namespace { static thread_local volatile uint64_t xxx = 0; -static void simulate_work(iree_task_tile_context_t* tile_context) { +static void simulate_work(const iree_task_tile_context_t* tile_context) { iree_prng_splitmix64_state_t state; iree_prng_splitmix64_initialize(xxx, &state); bool slow = false; // tile_context->workgroup_xyz[0] % 3 == 1; @@ -50,23 +50,23 @@ TEST(ExecutorTest, Any) { iree_allocator_t allocator = iree_allocator_system(); - iree_task_topology_t* topology = NULL; + iree_task_topology_t topology; #if 1 - IREE_CHECK_OK(iree_task_topology_from_physical_cores( - /*max_core_count=*/6, allocator, &topology)); + iree_task_topology_initialize_from_physical_cores( + /*max_core_count=*/6, &topology); #elif 0 - IREE_CHECK_OK(iree_task_topology_from_unique_l2_cache_groups( - /*max_group_count=*/6, allocator, &topology)); + iree_task_topology_initialize_from_unique_l2_cache_groups( + /*max_group_count=*/6, &topology); #else - IREE_CHECK_OK(iree_task_topology_from_group_count(/*group_count=*/6, - allocator, &topology)); + iree_task_topology_initialize_from_group_count(/*group_count=*/6, &topology); #endif + iree_task_executor_t* executor = NULL; iree_task_scheduling_mode_t scheduling_mode = IREE_TASK_SCHEDULING_MODE_RESERVED; - IREE_CHECK_OK(iree_task_executor_create(scheduling_mode, topology, allocator, + IREE_CHECK_OK(iree_task_executor_create(scheduling_mode, &topology, allocator, &executor)); - iree_task_topology_free(topology); + iree_task_topology_deinitialize(&topology); // iree_task_scope_t scope_a; @@ -74,27 +74,27 @@ TEST(ExecutorTest, Any) { // iree_task_call_t call0; - iree_task_call_initialize( - &scope_a, - iree_task_make_closure( - [](uintptr_t user_context, uintptr_t task_context) { - IREE_TRACE_SCOPE0("call0"); - EXPECT_EQ(0, user_context); - return iree_ok_status(); - }, - 0), - &call0); + iree_task_call_initialize(&scope_a, + iree_task_make_call_closure( + [](uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + IREE_TRACE_SCOPE0("call0"); + EXPECT_EQ(0, user_context); + return iree_ok_status(); + }, + 0), + &call0); const uint32_t workgroup_size_0[3] = {256, 1, 1}; const uint32_t workgroup_count_0[3] = {32, 4, 2}; iree_task_dispatch_t dispatch0; iree_task_dispatch_initialize( &scope_a, - iree_task_make_closure( - [](uintptr_t user_context, uintptr_t task_context) { + iree_task_make_dispatch_closure( + [](uintptr_t user_context, + const iree_task_tile_context_t* tile_context, + iree_task_submission_t* 
pending_submission) { IREE_TRACE_SCOPE0("tile0"); - iree_task_tile_context_t* tile_context = - (iree_task_tile_context_t*)task_context; EXPECT_EQ(0, user_context); simulate_work(tile_context); iree_atomic_fetch_add_int32(&tile_context->statistics->reserved, 1, @@ -110,11 +110,11 @@ TEST(ExecutorTest, Any) { iree_task_dispatch_t dispatch1; iree_task_dispatch_initialize( &scope_a, - iree_task_make_closure( - [](uintptr_t user_context, uintptr_t task_context) { + iree_task_make_dispatch_closure( + [](uintptr_t user_context, + const iree_task_tile_context_t* tile_context, + iree_task_submission_t* pending_submission) { IREE_TRACE_SCOPE0("tile1"); - iree_task_tile_context_t* tile_context = - (iree_task_tile_context_t*)task_context; EXPECT_EQ(0, user_context); simulate_work(tile_context); iree_atomic_fetch_add_int32(&tile_context->statistics->reserved, 1, @@ -123,20 +123,20 @@ TEST(ExecutorTest, Any) { }, 0), workgroup_size_1, workgroup_count_1, &dispatch1); - // dispatch1.header.flags |= IREE_TASK_FLAG_DISPATCH_SLICED; + dispatch1.header.flags |= IREE_TASK_FLAG_DISPATCH_SLICED; // iree_task_call_t call1; - iree_task_call_initialize( - &scope_a, - iree_task_make_closure( - [](uintptr_t user_context, uintptr_t task_context) { - IREE_TRACE_SCOPE0("call1"); - EXPECT_EQ(1, user_context); - return iree_ok_status(); - }, - 1), - &call1); + iree_task_call_initialize(&scope_a, + iree_task_make_call_closure( + [](uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + IREE_TRACE_SCOPE0("call1"); + EXPECT_EQ(1, user_context); + return iree_ok_status(); + }, + 1), + &call1); #if 1 // no barrier between dispatches; fanout @@ -155,22 +155,22 @@ TEST(ExecutorTest, Any) { #endif // fence - iree_task_fence_t fence0; - iree_task_fence_initialize(&scope_a, &fence0); - iree_task_set_completion_task(&call1.header, &fence0.header); + iree_task_fence_t* fence0 = NULL; + IREE_CHECK_OK(iree_task_executor_acquire_fence(executor, &scope_a, &fence0)); + iree_task_set_completion_task(&call1.header, &fence0->header); // iree_task_submission_t sub0; iree_task_submission_initialize(&sub0); iree_task_submission_enqueue(&sub0, &call0.header); - IREE_CHECK_OK(iree_task_executor_submit(executor, &sub0)); + iree_task_executor_submit(executor, &sub0); // // iree_task_submission_t sub1; // iree_task_submission_initialize(&sub1); // IREE_CHECK_OK(iree_task_executor_submit(executor, &sub1)); - IREE_CHECK_OK(iree_task_executor_flush(executor)); + iree_task_executor_flush(executor); IREE_CHECK_OK(iree_task_scope_wait_idle(&scope_a, IREE_TIME_INFINITE_FUTURE)); diff --git a/iree/task/list.c b/iree/task/list.c index 6e700005c1373..bba9a21fc7a74 100644 --- a/iree/task/list.c +++ b/iree/task/list.c @@ -107,8 +107,9 @@ iree_task_t* iree_task_list_pop_front(iree_task_list_t* list) { void iree_task_list_erase(iree_task_list_t* list, iree_task_t* prev_task, iree_task_t* task) { if (task == list->head) { - // Removing head. + // Removing head (which may _also_ be the tail). list->head = task->next_task; + if (list->tail == task) list->tail = task->next_task; } else if (task == list->tail) { // Removing tail. 
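// [Illustrative aside, not part of this patch] A worked example of the
// head==tail edge case the added line above guards: with a single-element list
// both head and tail point at the same task, so erasing the head must also
// clear the tail or the list would keep a dangling pointer.
//
//   iree_task_list_t list;
//   iree_task_list_initialize(&list);
//   iree_task_list_push_back(&list, &task_a);  // head == tail == &task_a
//   iree_task_list_erase(&list, /*prev_task=*/NULL, &task_a);
//   // task_a.next_task is NULL, so both list.head and list.tail are now NULL.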
list->tail = prev_task; diff --git a/iree/task/list_test.cc b/iree/task/list_test.cc index 19b472f84cd8b..66f80e0ebd2f7 100644 --- a/iree/task/list_test.cc +++ b/iree/task/list_test.cc @@ -83,6 +83,40 @@ TEST(TaskListTest, Move) { EXPECT_TRUE(CheckListOrderFIFO(&list_b)); } +TEST(TaskListTest, DiscardEmpty) { + iree_task_list_t list; + iree_task_list_initialize(&list); + + EXPECT_TRUE(iree_task_list_is_empty(&list)); + iree_task_list_discard(&list); + EXPECT_TRUE(iree_task_list_is_empty(&list)); +} + +TEST(TaskListTest, Discard) { + auto pool = AllocateNopPool(); + auto scope = AllocateScope("a"); + + iree_task_list_t list; + iree_task_list_initialize(&list); + EXPECT_TRUE(iree_task_list_is_empty(&list)); + + auto task0 = AcquireNopTask(pool, scope, 0); + auto task1 = AcquireNopTask(pool, scope, 1); + auto task2 = AcquireNopTask(pool, scope, 2); + auto task3 = AcquireNopTask(pool, scope, 3); + iree_task_list_push_back(&list, task0); + iree_task_list_push_back(&list, task1); + iree_task_list_push_back(&list, task2); + iree_task_list_push_back(&list, task3); + EXPECT_EQ(4, iree_task_list_calculate_size(&list)); + EXPECT_TRUE(CheckListOrderFIFO(&list)); + + iree_task_list_discard(&list); + EXPECT_TRUE(iree_task_list_is_empty(&list)); + + // IMPLICIT: if the tasks were not released back to the pool we'll leak. +} + TEST(TaskListTest, PushFront) { auto pool = AllocateNopPool(); auto scope = AllocateScope("a"); @@ -175,6 +209,8 @@ TEST(TaskListTest, Erase) { iree_task_list_erase(&list, NULL, task1); EXPECT_TRUE(iree_task_list_is_empty(&list)); + EXPECT_EQ(NULL, iree_task_list_front(&list)); + EXPECT_EQ(NULL, iree_task_list_back(&list)); } TEST(TaskListTest, PrependEmpty) { diff --git a/iree/task/pool.c b/iree/task/pool.c index 172a2ce7046ab..753d9c8b173de 100644 --- a/iree/task/pool.c +++ b/iree/task/pool.c @@ -14,6 +14,7 @@ #include "iree/task/pool.h" +#include "iree/base/debugging.h" #include "iree/base/math.h" // Minimum byte size of a block in bytes, including the tasks as well as the @@ -94,9 +95,11 @@ static iree_status_t iree_task_pool_grow(iree_task_pool_t* pool, iree_task_t* head = (iree_task_t*)p; iree_task_t* tail = head; head->next_task = NULL; + head->pool = pool; for (iree_host_size_t i = 0; i < actual_capacity; ++i, p -= pool->task_size) { iree_task_t* task = (iree_task_t*)p; task->next_task = head; + task->pool = pool; head = task; } @@ -289,6 +292,6 @@ iree_status_t iree_task_pool_acquire_many(iree_task_pool_t* pool, void iree_task_pool_release(iree_task_pool_t* pool, iree_task_t* task) { if (!pool) return; - assert(task->pool == pool); + IREE_ASSERT_EQ(task->pool, pool); iree_atomic_task_slist_push(&pool->available_slist, task); } diff --git a/iree/task/pool.h b/iree/task/pool.h index 7610ff7463414..b011c98e55ca6 100644 --- a/iree/task/pool.h +++ b/iree/task/pool.h @@ -95,7 +95,7 @@ void iree_task_pool_deinitialize(iree_task_pool_t* pool); void iree_task_pool_trim(iree_task_pool_t* pool); // Acquires a task from the task pool. The returned task will have undefined -// contents and must be intialized by the caller. +// contents and must be initialized by the caller. iree_status_t iree_task_pool_acquire(iree_task_pool_t* pool, iree_task_t** out_task); diff --git a/iree/task/pool_test.cc b/iree/task/pool_test.cc index 641292f33df1a..feb05681c4564 100644 --- a/iree/task/pool_test.cc +++ b/iree/task/pool_test.cc @@ -19,8 +19,80 @@ namespace { -TEST(PoolTest, Any) { - // TODO(benvanik): tests. 
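// [Illustrative aside, not part of this patch] The pool contract the tests
// that follow exercise: acquire returns storage with undefined contents that
// the caller must initialize, and release must hand the task back to the same
// pool it came from (now checked via IREE_ASSERT_EQ above). my_task_t and the
// capacity parameter name are assumptions based on the tests below.
//
//   iree_task_pool_t pool;
//   IREE_CHECK_OK(iree_task_pool_initialize(iree_allocator_system(),
//                                           sizeof(my_task_t),
//                                           /*initial_capacity=*/8, &pool));
//   iree_task_t* task = NULL;
//   IREE_CHECK_OK(iree_task_pool_acquire(&pool, &task));
//   // ... initialize and use the task ...
//   iree_task_pool_release(&pool, task);
//   iree_task_pool_deinitialize(&pool);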
+typedef struct { + iree_task_t base; + uint8_t payload[32]; +} iree_test_task_t; + +TEST(PoolTest, Lifetime) { + iree_task_pool_t pool; + IREE_ASSERT_OK(iree_task_pool_initialize( + iree_allocator_system(), sizeof(iree_test_task_t), 32, &pool)); + iree_task_pool_deinitialize(&pool); +} + +TEST(PoolTest, AcquireRelease) { + // Start with 2 preallocated tasks so we can test both acquiring existing and + // growing to allocate new tasks. + iree_task_pool_t pool; + IREE_ASSERT_OK(iree_task_pool_initialize(iree_allocator_system(), + sizeof(iree_test_task_t), 2, &pool)); + + // Acquire 4 tasks (so we test both the initial size and allocated tasks). + iree_test_task_t* tasks[4] = {NULL, NULL, NULL, NULL}; + for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(tasks); ++i) { + IREE_ASSERT_OK(iree_task_pool_acquire(&pool, (iree_task_t**)&tasks[i])); + EXPECT_TRUE(tasks[i] != NULL); + } + + // Release all tasks back to the pool. + for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(tasks); ++i) { + iree_task_pool_release(&pool, (iree_task_t*)tasks[i]); + } + + // Acquire all tasks again to make sure we put them back in correctly. + for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(tasks); ++i) { + IREE_ASSERT_OK(iree_task_pool_acquire(&pool, (iree_task_t**)&tasks[i])); + EXPECT_TRUE(tasks[i] != NULL); + } + for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(tasks); ++i) { + iree_task_pool_release(&pool, (iree_task_t*)tasks[i]); + } + + iree_task_pool_deinitialize(&pool); +} + +TEST(PoolTest, Trim) { + // Start with 2 preallocated tasks so we can test both acquiring existing and + // growing to allocate new tasks. + iree_task_pool_t pool; + IREE_ASSERT_OK(iree_task_pool_initialize(iree_allocator_system(), + sizeof(iree_test_task_t), 2, &pool)); + + // Acquire and release some tasks. + iree_test_task_t* tasks[8] = {NULL, NULL, NULL, NULL}; + for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(tasks); ++i) { + IREE_ASSERT_OK(iree_task_pool_acquire(&pool, (iree_task_t**)&tasks[i])); + EXPECT_TRUE(tasks[i] != NULL); + } + for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(tasks); ++i) { + iree_task_pool_release(&pool, (iree_task_t*)tasks[i]); + } + + // Trim to shrink the pool memory. + // NOTE: trimming is only supported when there are no outstanding tasks. + iree_task_pool_trim(&pool); + + // Acquire again to make sure we can reallocate the pool. + for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(tasks); ++i) { + IREE_ASSERT_OK(iree_task_pool_acquire(&pool, (iree_task_t**)&tasks[i])); + EXPECT_TRUE(tasks[i] != NULL); + } + for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(tasks); ++i) { + iree_task_pool_release(&pool, (iree_task_t*)tasks[i]); + } + + iree_task_pool_deinitialize(&pool); } } // namespace diff --git a/iree/task/post_batch.c b/iree/task/post_batch.c index 111c39c4e9dee..1af8c9b0648ad 100644 --- a/iree/task/post_batch.c +++ b/iree/task/post_batch.c @@ -55,8 +55,11 @@ static iree_host_size_t iree_task_post_batch_select_random_worker( iree_host_size_t iree_task_post_batch_select_worker( iree_task_post_batch_t* post_batch, iree_task_affinity_set_t affinity_set) { if (post_batch->current_worker) { - // Posting from a worker - prefer sending right back to this worker. - if (affinity_set & post_batch->current_worker->worker_bit) { + // Posting from a worker - prefer sending right back to this worker if we + // haven't already scheduled for it. 
+ if ((affinity_set & post_batch->current_worker->worker_bit) && + !(post_batch->worker_pending_mask & + post_batch->current_worker->worker_bit)) { + return iree_task_affinity_set_count_trailing_zeros( + post_batch->current_worker->worker_bit); + } @@ -65,10 +68,13 @@ iree_host_size_t iree_task_post_batch_select_worker( // Prefer workers that are idle: even though they'll need to wake up, it is // guaranteed that they aren't working on something else and the latency of // waking should (hopefully) be less than the latency of waiting for a - // worker's queue to finish. + // worker's queue to finish. Note that we only consider workers idle if we + // haven't already queued work for them in this batch (as then they + // aren't going to be idle). iree_task_affinity_set_t worker_idle_mask = iree_atomic_task_affinity_set_load( &post_batch->executor->worker_idle_mask, iree_memory_order_relaxed); + worker_idle_mask &= ~post_batch->worker_pending_mask; iree_task_affinity_set_t idle_affinity_set = affinity_set & worker_idle_mask; if (idle_affinity_set) { return iree_task_post_batch_select_random_worker(post_batch, diff --git a/iree/task/queue.c b/iree/task/queue.c index d546b0d18b90c..23e6d0493237c 100644 --- a/iree/task/queue.c +++ b/iree/task/queue.c @@ -49,7 +49,7 @@ void iree_task_queue_append_from_lifo_list_unsafe(iree_task_queue_t* queue, iree_slim_mutex_unlock(&queue->mutex); } -iree_task_t* iree_task_queue_append_from_lifo_slist( +iree_task_t* iree_task_queue_flush_from_lifo_slist( iree_task_queue_t* queue, iree_atomic_task_slist_t* source_slist) { // Perform the flush and swap outside of the lock; acquiring the list is // atomic and then we own it exclusively. diff --git a/iree/task/queue.h b/iree/task/queue.h index 6248551a34d6e..e3e5121ff76cf 100644 --- a/iree/task/queue.h +++ b/iree/task/queue.h @@ -146,7 +146,7 @@ void iree_task_queue_append_from_lifo_list_unsafe(iree_task_queue_t* queue, // pre-existing or from the newly flushed tasks. // // Must only be called from the owning worker's thread. -iree_task_t* iree_task_queue_append_from_lifo_slist( +iree_task_t* iree_task_queue_flush_from_lifo_slist( iree_task_queue_t* queue, iree_atomic_task_slist_t* source_slist); // Pops a task from the front of the queue if any are available. diff --git a/iree/task/queue_test.cc b/iree/task/queue_test.cc index 9b8b448d842f5..6aef795d7949c 100644 --- a/iree/task/queue_test.cc +++ b/iree/task/queue_test.cc @@ -19,8 +19,313 @@ namespace { -TEST(QueueTest, Any) { - // TODO(benvanik): tests.
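// [Illustrative sketch, not part of this patch] The rename makes the mailbox
// semantics explicit: producers push to the atomic slist in LIFO order and the
// owning worker's flush inverts it into FIFO execution order, returning the
// first task so the worker can start it without touching the queue again:
//
//   iree_atomic_task_slist_push(&slist, &task_a);  // slist: a
//   iree_atomic_task_slist_push(&slist, &task_b);  // slist: b->a
//   iree_atomic_task_slist_push(&slist, &task_c);  // slist: c->b->a
//   iree_task_t* next_task =
//       iree_task_queue_flush_from_lifo_slist(&queue, &slist);
//   // next_task == &task_a; the queue now holds b->c in FIFO order.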
+TEST(QueueTest, Lifetime) { + iree_task_queue_t queue; + iree_task_queue_initialize(&queue); + iree_task_queue_deinitialize(&queue); +} + +TEST(QueueTest, Empty) { + iree_task_queue_t queue; + iree_task_queue_initialize(&queue); + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + EXPECT_FALSE(iree_task_queue_pop_front(&queue)); + iree_task_queue_deinitialize(&queue); +} + +TEST(QueueTest, PushPop) { + iree_task_queue_t queue; + iree_task_queue_initialize(&queue); + + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + EXPECT_FALSE(iree_task_queue_pop_front(&queue)); + + iree_task_t task_a = {0}; + iree_task_queue_push_front(&queue, &task_a); + + EXPECT_FALSE(iree_task_queue_is_empty(&queue)); + + iree_task_t task_b = {0}; + iree_task_queue_push_front(&queue, &task_b); + + EXPECT_FALSE(iree_task_queue_is_empty(&queue)); + EXPECT_EQ(&task_b, iree_task_queue_pop_front(&queue)); + + EXPECT_FALSE(iree_task_queue_is_empty(&queue)); + EXPECT_EQ(&task_a, iree_task_queue_pop_front(&queue)); + + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + EXPECT_FALSE(iree_task_queue_pop_front(&queue)); + + iree_task_queue_deinitialize(&queue); +} + +TEST(QueueTest, AppendListEmpty) { + iree_task_queue_t queue; + iree_task_queue_initialize(&queue); + + iree_task_list_t list = {0}; + + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + iree_task_queue_append_from_lifo_list_unsafe(&queue, &list); + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + EXPECT_TRUE(iree_task_list_is_empty(&list)); + + iree_task_queue_deinitialize(&queue); +} + +TEST(QueueTest, AppendList1) { + iree_task_queue_t queue; + iree_task_queue_initialize(&queue); + + iree_task_list_t list = {0}; + iree_task_t task_a = {0}; + iree_task_list_push_front(&list, &task_a); + + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + iree_task_queue_append_from_lifo_list_unsafe(&queue, &list); + EXPECT_FALSE(iree_task_queue_is_empty(&queue)); + EXPECT_TRUE(iree_task_list_is_empty(&list)); + + EXPECT_EQ(&task_a, iree_task_queue_pop_front(&queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + + iree_task_queue_deinitialize(&queue); +} + +TEST(QueueTest, AppendListOrdered) { + iree_task_queue_t queue; + iree_task_queue_initialize(&queue); + + // Make a lifo list: b<-a. + iree_task_list_t list = {0}; + iree_task_t task_a = {0}; + iree_task_list_push_front(&list, &task_a); + iree_task_t task_b = {0}; + iree_task_list_push_front(&list, &task_b); + + // Append the list to the queue; it should swap LIFO->FIFO. + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + iree_task_queue_append_from_lifo_list_unsafe(&queue, &list); + EXPECT_FALSE(iree_task_queue_is_empty(&queue)); + EXPECT_TRUE(iree_task_list_is_empty(&list)); + + // Pop list and ensure order: a->b. 
+ EXPECT_EQ(&task_a, iree_task_queue_pop_front(&queue)); + EXPECT_EQ(&task_b, iree_task_queue_pop_front(&queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + + iree_task_queue_deinitialize(&queue); +} + +TEST(QueueTest, FlushSlistEmpty) { + iree_task_queue_t queue; + iree_task_queue_initialize(&queue); + + iree_atomic_task_slist_t slist; + iree_atomic_task_slist_initialize(&slist); + + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + EXPECT_FALSE(iree_task_queue_flush_from_lifo_slist(&queue, &slist)); + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + + iree_atomic_task_slist_deinitialize(&slist); + + iree_task_queue_deinitialize(&queue); +} + +TEST(QueueTest, FlushSlist1) { + iree_task_queue_t queue; + iree_task_queue_initialize(&queue); + + iree_atomic_task_slist_t slist; + iree_atomic_task_slist_initialize(&slist); + iree_task_t task_a = {0}; + iree_atomic_task_slist_push(&slist, &task_a); + + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + EXPECT_EQ(&task_a, iree_task_queue_flush_from_lifo_slist(&queue, &slist)); + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + + iree_atomic_task_slist_deinitialize(&slist); + + iree_task_queue_deinitialize(&queue); +} + +TEST(QueueTest, FlushSlistOrdered) { + iree_task_queue_t queue; + iree_task_queue_initialize(&queue); + + // Make a lifo list: c<-b<-a. + iree_atomic_task_slist_t slist; + iree_atomic_task_slist_initialize(&slist); + iree_task_t task_a = {0}; + iree_atomic_task_slist_push(&slist, &task_a); + iree_task_t task_b = {0}; + iree_atomic_task_slist_push(&slist, &task_b); + iree_task_t task_c = {0}; + iree_atomic_task_slist_push(&slist, &task_c); + + // Flush the list to the queue; it should swap LIFO->FIFO and return the + // first task in the queue. + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + EXPECT_EQ(&task_a, iree_task_queue_flush_from_lifo_slist(&queue, &slist)); + EXPECT_FALSE(iree_task_queue_is_empty(&queue)); + + // Pop list and ensure order: [a->]b->c. 
+ EXPECT_EQ(&task_b, iree_task_queue_pop_front(&queue)); + EXPECT_EQ(&task_c, iree_task_queue_pop_front(&queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&queue)); + + iree_atomic_task_slist_deinitialize(&slist); + + iree_task_queue_deinitialize(&queue); +} + +TEST(QueueTest, TryStealEmpty) { + iree_task_queue_t source_queue; + iree_task_queue_initialize(&source_queue); + iree_task_queue_t target_queue; + iree_task_queue_initialize(&target_queue); + + iree_task_t task_a = {0}; + iree_task_queue_push_front(&source_queue, &task_a); + iree_task_t task_b = {0}; + iree_task_queue_push_front(&source_queue, &task_b); + iree_task_t task_c = {0}; + iree_task_queue_push_front(&source_queue, &task_c); + + EXPECT_EQ(&task_a, + iree_task_queue_try_steal(&source_queue, &target_queue, 1)); + + iree_task_queue_deinitialize(&source_queue); + iree_task_queue_deinitialize(&target_queue); +} + +TEST(QueueTest, TryStealLast) { + iree_task_queue_t source_queue; + iree_task_queue_initialize(&source_queue); + iree_task_queue_t target_queue; + iree_task_queue_initialize(&target_queue); + + iree_task_t task_a = {0}; + iree_task_queue_push_front(&source_queue, &task_a); + + EXPECT_EQ(&task_a, + iree_task_queue_try_steal(&source_queue, &target_queue, 100)); + EXPECT_TRUE(iree_task_queue_is_empty(&target_queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&source_queue)); + + iree_task_queue_deinitialize(&source_queue); + iree_task_queue_deinitialize(&target_queue); +} + +TEST(QueueTest, TrySteal1) { + iree_task_queue_t source_queue; + iree_task_queue_initialize(&source_queue); + iree_task_queue_t target_queue; + iree_task_queue_initialize(&target_queue); + + iree_task_t task_a = {0}; + iree_task_t task_b = {0}; + iree_task_t task_c = {0}; + iree_task_queue_push_front(&source_queue, &task_c); + iree_task_queue_push_front(&source_queue, &task_b); + iree_task_queue_push_front(&source_queue, &task_a); + + EXPECT_EQ(&task_c, + iree_task_queue_try_steal(&source_queue, &target_queue, 1)); + EXPECT_TRUE(iree_task_queue_is_empty(&target_queue)); + + EXPECT_EQ(&task_a, iree_task_queue_pop_front(&source_queue)); + EXPECT_EQ(&task_b, iree_task_queue_pop_front(&source_queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&source_queue)); + + iree_task_queue_deinitialize(&source_queue); + iree_task_queue_deinitialize(&target_queue); +} + +TEST(QueueTest, TryStealIntoExisting) { + iree_task_queue_t source_queue; + iree_task_queue_initialize(&source_queue); + iree_task_queue_t target_queue; + iree_task_queue_initialize(&target_queue); + + iree_task_t task_a = {0}; + iree_task_t task_b = {0}; + iree_task_queue_push_front(&source_queue, &task_b); + iree_task_queue_push_front(&source_queue, &task_a); + + iree_task_t task_existing = {0}; + iree_task_queue_push_front(&target_queue, &task_existing); + + EXPECT_EQ(&task_existing, + iree_task_queue_try_steal(&source_queue, &target_queue, 1)); + + EXPECT_EQ(&task_a, iree_task_queue_pop_front(&source_queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&source_queue)); + + EXPECT_EQ(&task_b, iree_task_queue_pop_front(&target_queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&target_queue)); + + iree_task_queue_deinitialize(&source_queue); + iree_task_queue_deinitialize(&target_queue); +} + +TEST(QueueTest, TryStealMany) { + iree_task_queue_t source_queue; + iree_task_queue_initialize(&source_queue); + iree_task_queue_t target_queue; + iree_task_queue_initialize(&target_queue); + + iree_task_t task_a = {0}; + iree_task_t task_b = {0}; + iree_task_t task_c = {0}; + iree_task_t task_d = {0}; + 
iree_task_queue_push_front(&source_queue, &task_d); + iree_task_queue_push_front(&source_queue, &task_c); + iree_task_queue_push_front(&source_queue, &task_b); + iree_task_queue_push_front(&source_queue, &task_a); + + EXPECT_EQ(&task_c, + iree_task_queue_try_steal(&source_queue, &target_queue, 2)); + EXPECT_EQ(&task_d, iree_task_queue_pop_front(&target_queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&target_queue)); + + EXPECT_EQ(&task_a, iree_task_queue_pop_front(&source_queue)); + EXPECT_EQ(&task_b, iree_task_queue_pop_front(&source_queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&source_queue)); + + iree_task_queue_deinitialize(&source_queue); + iree_task_queue_deinitialize(&target_queue); +} + +TEST(QueueTest, TryStealAll) { + iree_task_queue_t source_queue; + iree_task_queue_initialize(&source_queue); + iree_task_queue_t target_queue; + iree_task_queue_initialize(&target_queue); + + iree_task_t task_a = {0}; + iree_task_t task_b = {0}; + iree_task_t task_c = {0}; + iree_task_t task_d = {0}; + iree_task_queue_push_front(&source_queue, &task_d); + iree_task_queue_push_front(&source_queue, &task_c); + iree_task_queue_push_front(&source_queue, &task_b); + iree_task_queue_push_front(&source_queue, &task_a); + + EXPECT_EQ(&task_c, + iree_task_queue_try_steal(&source_queue, &target_queue, 1000)); + EXPECT_EQ(&task_d, iree_task_queue_pop_front(&target_queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&target_queue)); + + EXPECT_EQ(&task_a, iree_task_queue_pop_front(&source_queue)); + EXPECT_EQ(&task_b, iree_task_queue_pop_front(&source_queue)); + EXPECT_TRUE(iree_task_queue_is_empty(&source_queue)); + + iree_task_queue_deinitialize(&source_queue); + iree_task_queue_deinitialize(&target_queue); } } // namespace diff --git a/iree/task/scope.c b/iree/task/scope.c index f35a6870dda64..747cfb7a36583 100644 --- a/iree/task/scope.c +++ b/iree/task/scope.c @@ -14,6 +14,8 @@ #include "iree/task/scope.h" +#include "iree/base/debugging.h" + void iree_task_scope_initialize(iree_string_view_t name, iree_task_scope_t* out_scope) { IREE_TRACE_ZONE_BEGIN(z0); @@ -36,9 +38,10 @@ void iree_task_scope_initialize(iree_string_view_t name, void iree_task_scope_deinitialize(iree_task_scope_t* scope) { IREE_TRACE_ZONE_BEGIN(z0); - assert(iree_task_scope_is_idle(scope) && - "pending submissions must be aborted prior to deinitializing their " - "scope"); + IREE_ASSERT( + iree_task_scope_is_idle(scope), + "pending submissions must be aborted prior to deinitializing their " + "scope"); // Makes it easier to see if we were incorrectly using the name even after the // scope is deinitialized. Since scopes may be stack allocated we don't want @@ -46,8 +49,8 @@ void iree_task_scope_deinitialize(iree_task_scope_t* scope) { memset(scope->name, 0xCD, sizeof(scope->name)); // In most cases the status will have been consumed by the scope owner. 
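// [Illustrative aside, not part of this patch] The permanent status is stored
// as an atomic intptr so it can be exchanged without locks. Consuming it (see
// the exchange below) transfers ownership of the full status object to the
// caller while leaving the bare status code behind, keeping the scope sticky:
//
//   iree_task_scope_fail(&scope, &task,
//                        iree_make_status(IREE_STATUS_DATA_LOSS, "oops"));
//   iree_status_t status = iree_task_scope_consume_status(&scope);
//   // |status| now owns the rich DATA_LOSS status (message and all).
//   iree_status_ignore(status);
//   // A later consume still reports IREE_STATUS_DATA_LOSS via the code-only
//   // status the CAS loop left behind.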
- iree_status_t status = (iree_status_t)iree_atomic_exchange_ptr( - &scope->permanent_status, (uintptr_t)NULL, iree_memory_order_acquire); + iree_status_t status = (iree_status_t)iree_atomic_exchange_intptr( + &scope->permanent_status, (intptr_t)NULL, iree_memory_order_acquire); IREE_IGNORE_ERROR(status); iree_notification_deinitialize(&scope->idle_notification); @@ -55,6 +58,29 @@ void iree_task_scope_deinitialize(iree_task_scope_t* scope) { IREE_TRACE_ZONE_END(z0); } +iree_string_view_t iree_task_scope_name(iree_task_scope_t* scope) { + return iree_make_cstring_view(scope->name); +} + +iree_task_dispatch_statistics_t iree_task_scope_consume_statistics( + iree_task_scope_t* scope) { + iree_task_dispatch_statistics_t result = scope->dispatch_statistics; + memset(&scope->dispatch_statistics, 0, sizeof(scope->dispatch_statistics)); + return result; +} + +iree_status_t iree_task_scope_consume_status(iree_task_scope_t* scope) { + iree_status_t old_status = iree_ok_status(); + iree_status_t new_status = iree_ok_status(); + while (!iree_atomic_compare_exchange_strong_intptr( + &scope->permanent_status, (intptr_t*)&old_status, (intptr_t)new_status, + iree_memory_order_seq_cst, iree_memory_order_seq_cst)) { + // Previous status was not OK; we have it now though and can try again. + new_status = iree_status_from_code(iree_status_code(old_status)); + } + return old_status; +} + static void iree_task_scope_try_set_status(iree_task_scope_t* scope, iree_status_t new_status) { if (IREE_UNLIKELY(iree_status_is_ok(new_status))) return; @@ -65,16 +91,14 @@ static void iree_task_scope_try_set_status(iree_task_scope_t* scope, z0, iree_status_code_string(iree_status_code(new_status))); iree_status_t old_status = iree_ok_status(); - if (!iree_atomic_compare_exchange_strong_ptr( - &scope->permanent_status, (uintptr_t*)&old_status, - (uintptr_t)new_status, iree_memory_order_seq_cst, + if (!iree_atomic_compare_exchange_strong_intptr( + &scope->permanent_status, (intptr_t*)&old_status, + (intptr_t)new_status, iree_memory_order_seq_cst, iree_memory_order_seq_cst)) { // Previous status was not OK; drop our new status. IREE_IGNORE_ERROR(new_status); } - // TODO(#4026): poke to wake idle waiters. - IREE_TRACE_ZONE_END(z0); } @@ -90,39 +114,35 @@ void iree_task_scope_fail(iree_task_scope_t* scope, iree_task_t* task, iree_task_scope_try_set_status(scope, status); } -iree_status_t iree_task_scope_consume_status(iree_task_scope_t* scope) { - iree_status_t old_status = iree_ok_status(); - iree_status_t new_status = iree_ok_status(); - while (!iree_atomic_compare_exchange_strong_ptr( - &scope->permanent_status, (uintptr_t*)&old_status, (uintptr_t)new_status, - iree_memory_order_seq_cst, iree_memory_order_seq_cst)) { - // Previous status was not OK; we have it now though and can try again. - new_status = iree_status_from_code(iree_status_code(new_status)); - } - return old_status; -} - -iree_task_dispatch_statistics_t iree_task_scope_consume_statistics( - iree_task_scope_t* scope) { - iree_task_dispatch_statistics_t result = scope->dispatch_statistics; - memset(&scope->dispatch_statistics, 0, sizeof(scope->dispatch_statistics)); - return result; -} - bool iree_task_scope_is_idle(iree_task_scope_t* scope) { return iree_atomic_load_int32(&scope->pending_submissions, - iree_memory_order_relaxed) == 0; + iree_memory_order_acquire) == 0; } iree_status_t iree_task_scope_wait_idle(iree_task_scope_t* scope, iree_time_t deadline_ns) { IREE_TRACE_ZONE_BEGIN(z0); - // Wait for the scope to enter the idle state. 
- // NOTE: we are currently ignoring |deadline_ns|. - iree_notification_await(&scope->idle_notification, - (iree_condition_fn_t)iree_task_scope_is_idle, scope); + iree_status_t status = iree_ok_status(); + if (deadline_ns == IREE_TIME_INFINITE_PAST) { + // Polling for idle. + if (iree_task_scope_is_idle(scope)) { + status = iree_ok_status(); + } else { + status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + } + } else if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { + // Wait for the scope to enter the idle state. + iree_notification_await(&scope->idle_notification, + (iree_condition_fn_t)iree_task_scope_is_idle, + scope); + } else { + // NOTE: we are currently ignoring |deadline_ns|. + // We need to support timeouts on iree_notification_t to support this. + status = iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "scope-based waits do not yet support timeouts"); + } IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); + return status; } diff --git a/iree/task/scope.h b/iree/task/scope.h index ec10a3453f1ec..4a1adafed08c3 100644 --- a/iree/task/scope.h +++ b/iree/task/scope.h @@ -58,7 +58,7 @@ typedef struct iree_task_scope_s { // A permanent status code set when a task within the scope fails. All pending // tasks will be cancelled, though any in-flight tasks may continue executing // to completion. - iree_atomic_ptr_t permanent_status; + iree_atomic_intptr_t permanent_status; // Dispatch statistics aggregated from all dispatches in this scope. Updated // relatively infrequently and must not be used for task control as values @@ -86,6 +86,20 @@ void iree_task_scope_initialize(iree_string_view_t name, // No tasks may be pending and the scope must be idle. void iree_task_scope_deinitialize(iree_task_scope_t* scope); +// Returns the name of the scope. Informational only and may be the empty +// string. +iree_string_view_t iree_task_scope_name(iree_task_scope_t* scope); + +// Returns and resets the statistics for the scope. +// Statistics may experience tearing (non-atomic update across fields) if this +// is performed while tasks are in-flight. +iree_task_dispatch_statistics_t iree_task_scope_consume_statistics( + iree_task_scope_t* scope); + +// Returns the permanent scope failure status to the caller (transferring +// ownership). The scope will remain in a failed state with the status code. +iree_status_t iree_task_scope_consume_status(iree_task_scope_t* scope); + // Marks the scope as having been aborted by the user with IREE_STATUS_ABORTED. // All pending tasks will be dropped though in-flight tasks may complete // execution. Callers must use iree_task_scope_wait_idle to ensure the scope @@ -102,16 +116,6 @@ void iree_task_scope_abort(iree_task_scope_t* scope); void iree_task_scope_fail(iree_task_scope_t* scope, iree_task_t* task, iree_status_t status); -// Returns the permanent scope failure status to the caller (transfering -// ownership). The scope will remain in a failed state with the status code. -iree_status_t iree_task_scope_consume_status(iree_task_scope_t* scope); - -// Returns and resets the statistics for the scope. -// Statistics may experience tearing (non-atomic update across fields) if this -// is performed while tasks are in-flight. -iree_task_dispatch_statistics_t iree_task_scope_consume_statistics( - iree_task_scope_t* scope); - // Returns true if the scope has no pending or in-flight tasks.
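// [Illustrative usage, not part of this patch] The three deadline modes the
// rewritten wait above distinguishes; iree_time_now() is assumed available
// from iree/base/time.h:
//
//   // Poll: returns DEADLINE_EXCEEDED immediately if not yet idle.
//   iree_status_t s0 = iree_task_scope_wait_idle(&scope, IREE_TIME_INFINITE_PAST);
//   // Block: waits on the idle notification until all tasks retire.
//   iree_status_t s1 = iree_task_scope_wait_idle(&scope, IREE_TIME_INFINITE_FUTURE);
//   // Finite deadline: not yet implemented and returns UNIMPLEMENTED.
//   iree_status_t s2 = iree_task_scope_wait_idle(&scope, iree_time_now() + 1000000);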
// // May race with other threads enqueuing work and be out of date immediately diff --git a/iree/task/scope_test.cc b/iree/task/scope_test.cc index 9f42dd4964135..e56d2eea872fa 100644 --- a/iree/task/scope_test.cc +++ b/iree/task/scope_test.cc @@ -14,13 +14,242 @@ #include "iree/task/scope.h" +#include <thread> + +#include "iree/task/submission.h" +#include "iree/task/task_impl.h" #include "iree/testing/gtest.h" #include "iree/testing/status_matchers.h" namespace { -TEST(ScopeTest, Any) { - // TODO(benvanik): tests. +TEST(ScopeTest, Lifetime) { + iree_task_scope_t scope; + iree_task_scope_initialize(iree_make_cstring_view("scope_a"), &scope); + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + iree_task_scope_deinitialize(&scope); +} + +// NOTE: the exact capacity (and whether we store the name at all) is an +// implementation detail. +TEST(ScopeTest, LongNameTruncation) { + iree_task_scope_t scope; + iree_task_scope_initialize(iree_make_cstring_view("01234567890123456789"), + &scope); + EXPECT_TRUE(iree_string_view_equal(iree_make_cstring_view("012345678901234"), + iree_task_scope_name(&scope))); + iree_task_scope_deinitialize(&scope); +} + +TEST(ScopeTest, AbortEmpty) { + iree_task_scope_t scope; + iree_task_scope_initialize(iree_make_cstring_view("scope_a"), &scope); + + // Current state is OK. + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + EXPECT_TRUE(iree_status_is_ok(iree_task_scope_consume_status(&scope))); + + // Enter aborted state. + iree_task_scope_abort(&scope); + iree_status_t consumed_status = iree_task_scope_consume_status(&scope); + EXPECT_TRUE(iree_status_is_aborted(consumed_status)); + iree_status_ignore(consumed_status); + + // Ensure aborted state is sticky. + EXPECT_TRUE(iree_status_is_aborted(iree_task_scope_consume_status(&scope))); + + iree_task_scope_deinitialize(&scope); +} + +TEST(ScopeTest, FailEmpty) { + iree_task_scope_t scope; + iree_task_scope_initialize(iree_make_cstring_view("scope_a"), &scope); + + // Current state is OK. + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + EXPECT_TRUE(iree_status_is_ok(iree_task_scope_consume_status(&scope))); + + // Enter failure state. + iree_task_t failed_task = {0}; + failed_task.scope = &scope; + iree_task_scope_fail(&scope, &failed_task, + iree_make_status(IREE_STATUS_DATA_LOSS, "whoops!")); + iree_status_t consumed_status = iree_task_scope_consume_status(&scope); + EXPECT_TRUE(iree_status_is_data_loss(consumed_status)); + iree_status_ignore(consumed_status); + + // Ensure failure state is sticky. + EXPECT_TRUE(iree_status_is_data_loss(iree_task_scope_consume_status(&scope))); + + iree_task_scope_deinitialize(&scope); +} + +// NOTE: only the first failure is recorded and made sticky; subsequent failure +// calls are ignored. +TEST(ScopeTest, FailAgain) { + iree_task_scope_t scope; + iree_task_scope_initialize(iree_make_cstring_view("scope_a"), &scope); + + // Current state is OK. + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + EXPECT_TRUE(iree_status_is_ok(iree_task_scope_consume_status(&scope))); + + // Enter initial failure state. + iree_task_t failed_task_a = {0}; + failed_task_a.scope = &scope; + iree_task_scope_fail(&scope, &failed_task_a, + iree_make_status(IREE_STATUS_DATA_LOSS, "whoops 1")); + iree_status_t consumed_status_a = iree_task_scope_consume_status(&scope); + EXPECT_TRUE(iree_status_is_data_loss(consumed_status_a)); + iree_status_ignore(consumed_status_a); + + // Ensure failure state is sticky.
+ EXPECT_TRUE(iree_status_is_data_loss(iree_task_scope_consume_status(&scope))); + + // Try failing again - it should be ignored and correctly iree_status_free'd. + iree_task_t failed_task_b = {0}; + failed_task_b.scope = &scope; + iree_task_scope_fail( + &scope, &failed_task_b, + iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "whoops 2")); + iree_status_t consumed_status_b = iree_task_scope_consume_status(&scope); + EXPECT_TRUE(iree_status_is_data_loss(consumed_status_b)); + iree_status_ignore(consumed_status_b); + + // Still the first failure status. + EXPECT_TRUE(iree_status_is_data_loss(iree_task_scope_consume_status(&scope))); + + iree_task_scope_deinitialize(&scope); +} + +TEST(ScopeTest, WaitIdleWhenIdle) { + iree_task_scope_t scope; + iree_task_scope_initialize(iree_make_cstring_view("scope_a"), &scope); + + // Current state is OK and idle. + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + EXPECT_TRUE(iree_status_is_ok(iree_task_scope_consume_status(&scope))); + + // Wait until idle... which is now. + EXPECT_TRUE(iree_status_is_ok( + iree_task_scope_wait_idle(&scope, IREE_TIME_INFINITE_FUTURE))); + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + + iree_task_scope_deinitialize(&scope); +} + +TEST(ScopeTest, WaitIdleDeadlineExceeded) { + iree_task_scope_t scope; + iree_task_scope_initialize(iree_make_cstring_view("scope_a"), &scope); + + // Current state is OK and idle. + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + EXPECT_TRUE(iree_status_is_ok(iree_task_scope_consume_status(&scope))); + + // Enqueue a task to the scope so it is no longer idle. + iree_task_fence_t fence_task; + iree_task_fence_initialize(&scope, &fence_task); + EXPECT_FALSE(iree_task_scope_is_idle(&scope)); + + // Poll, which should fail immediately because we have the outstanding task. + iree_status_t wait_status = + iree_task_scope_wait_idle(&scope, IREE_TIME_INFINITE_PAST); + EXPECT_TRUE(iree_status_is_deadline_exceeded(wait_status)); + EXPECT_FALSE(iree_task_scope_is_idle(&scope)); + + // Complete the task (required as part of the scope contract). + iree_task_submission_t pending_submission; + iree_task_submission_initialize(&pending_submission); + iree_task_fence_retire(&fence_task, &pending_submission); + EXPECT_TRUE(iree_task_submission_is_empty(&pending_submission)); + + iree_task_scope_deinitialize(&scope); +} + +TEST(ScopeTest, WaitIdleSuccess) { + iree_task_scope_t scope; + iree_task_scope_initialize(iree_make_cstring_view("scope_a"), &scope); + + // Current state is OK and idle. + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + EXPECT_TRUE(iree_status_is_ok(iree_task_scope_consume_status(&scope))); + + // Enqueue a task to the scope so it is no longer idle. + iree_task_fence_t fence_task; + iree_task_fence_initialize(&scope, &fence_task); + EXPECT_FALSE(iree_task_scope_is_idle(&scope)); + + // Spin up a thread to wait on the scope. + std::thread wait_thread([&]() { + EXPECT_FALSE(iree_task_scope_is_idle(&scope)); + EXPECT_TRUE(iree_status_is_ok( + iree_task_scope_wait_idle(&scope, IREE_TIME_INFINITE_FUTURE))); + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + }); + + // Wait a moment for the thread to spin up. + // NOTE: this may flake. Need to see if there's a better way to do this. + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + // Complete the task. 
+ iree_task_submission_t pending_submission; + iree_task_submission_initialize(&pending_submission); + iree_task_fence_retire(&fence_task, &pending_submission); + EXPECT_TRUE(iree_task_submission_is_empty(&pending_submission)); + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + + // Join with the thread - this will hang if it didn't wake correctly. + wait_thread.join(); + + iree_task_scope_deinitialize(&scope); +} + +TEST(ScopeTest, WaitIdleFailure) { + iree_task_scope_t scope; + iree_task_scope_initialize(iree_make_cstring_view("scope_a"), &scope); + + // Current state is OK and idle. + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + EXPECT_TRUE(iree_status_is_ok(iree_task_scope_consume_status(&scope))); + + // Enqueue a task to the scope so it is no longer idle. + iree_task_fence_t fence_task; + iree_task_fence_initialize(&scope, &fence_task); + EXPECT_FALSE(iree_task_scope_is_idle(&scope)); + + // Spin up a thread to wait on the scope. + std::thread wait_thread([&]() { + EXPECT_FALSE(iree_task_scope_is_idle(&scope)); + EXPECT_TRUE(iree_status_is_ok( + iree_task_scope_wait_idle(&scope, IREE_TIME_INFINITE_FUTURE))); + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + }); + + // Wait a moment for the thread to spin up. + // NOTE: this may flake. Need to see if there's a better way to do this. + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + // Set the failure state. + iree_task_scope_fail( + &scope, &fence_task.header, + iree_make_status(IREE_STATUS_FAILED_PRECONDITION, "whoops")); + EXPECT_FALSE(iree_task_scope_is_idle(&scope)); + + // Complete the task. + // Note that even if a scope fails we still must complete the tasks so it + // becomes idle. This ensures that if the scope state is used to control + // deallocation we don't go deallocating the tasks still in flight and waiting + // to gracefully fail. + iree_task_submission_t pending_submission; + iree_task_submission_initialize(&pending_submission); + iree_task_fence_retire(&fence_task, &pending_submission); + EXPECT_TRUE(iree_task_submission_is_empty(&pending_submission)); + EXPECT_TRUE(iree_task_scope_is_idle(&scope)); + + // Join with the thread - this will hang if it didn't wake correctly. + wait_thread.join(); + + iree_task_scope_deinitialize(&scope); } } // namespace diff --git a/iree/task/submission.c b/iree/task/submission.c index 6a1aefe2e8406..2dfc24d5e468d 100644 --- a/iree/task/submission.c +++ b/iree/task/submission.c @@ -63,3 +63,14 @@ void iree_task_submission_enqueue(iree_task_submission_t* submission, iree_task_list_push_front(&submission->ready_list, task); } } + +void iree_task_submission_enqueue_list(iree_task_submission_t* submission, + iree_task_list_t* list) { + iree_task_t* task = list->head; + list->head = list->tail = NULL; + while (task) { + iree_task_t* next = task->next_task; + iree_task_submission_enqueue(submission, task); + task = next; + } +} diff --git a/iree/task/submission.h b/iree/task/submission.h index 0a1e84ee7fa08..1ff23d7d34dde 100644 --- a/iree/task/submission.h +++ b/iree/task/submission.h @@ -44,7 +44,7 @@ extern "C" { // // Thread-compatible; designed to be used from a single thread producing the // submission. -typedef struct { +typedef struct iree_task_submission_s { // List of tasks that are ready for execution immediately. Upon submission to // a queue the tasks will be passed on to the executor with no delay. 
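// [Illustrative sketch, not part of this patch] enqueue_list lets retiring
// tasks hand an entire ready list to a submission in one call instead of
// walking it manually; the list is reset on return:
//
//   iree_task_list_t ready_list;  // e.g. tasks readied by a retiring barrier
//   iree_task_list_initialize(&ready_list);
//   // ... populate ready_list ...
//   iree_task_submission_t submission;
//   iree_task_submission_initialize(&submission);
//   iree_task_submission_enqueue_list(&submission, &ready_list);
//   // ready_list is now empty; the submission owns the tasks.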
// @@ -95,6 +95,12 @@ bool iree_task_submission_is_empty(iree_task_submission_t* submission); void iree_task_submission_enqueue(iree_task_submission_t* submission, iree_task_t* task); +// Enqueues all tasks in |list| to the pending |submission|. +// Ownership of the tasks transfers to the submission and the |list| will be +// reset upon return. Ready tasks may execute in any order. +void iree_task_submission_enqueue_list(iree_task_submission_t* submission, + iree_task_list_t* list); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/iree/task/task.c b/iree/task/task.c index 38c708b49fb81..0a46b1346df73 100644 --- a/iree/task/task.c +++ b/iree/task/task.c @@ -16,6 +16,7 @@ #include <string.h> +#include "iree/base/debugging.h" #include "iree/task/task_impl.h" //============================================================================== @@ -31,9 +32,14 @@ void iree_task_initialize(iree_task_type_t type, iree_task_scope_t* scope, out_task->type = type; } +void iree_task_set_cleanup_fn(iree_task_t* task, + iree_task_cleanup_fn_t cleanup_fn) { + task->cleanup_fn = cleanup_fn; +} + void iree_task_set_completion_task(iree_task_t* task, iree_task_t* completion_task) { - assert(!task->completion_task); + IREE_ASSERT(!task->completion_task); task->completion_task = completion_task; iree_atomic_fetch_add_int32(&completion_task->pending_dependency_count, 1, iree_memory_order_seq_cst); @@ -48,7 +54,25 @@ bool iree_task_is_ready(iree_task_t* task) { return true; } +static void iree_task_cleanup(iree_task_t* task, iree_status_t status) { + // Call the (optional) cleanup function, passing along the status the task + // completed with (unowned by the callee). + // NOTE: this may free the memory of the task itself! + iree_task_pool_t* pool = task->pool; + if (task->cleanup_fn) { + task->cleanup_fn(task, status); + } + + // Return the task to the pool it was allocated from. + // Some tasks are allocated as part of arenas/ringbuffers and won't have a + // pool as they'll be cleaned up as part of a larger operation. + if (pool) { + iree_task_pool_release(pool, task); + } +} + void iree_task_discard(iree_task_t* task, iree_task_list_t* discard_worklist) { + IREE_TRACE_ZONE_BEGIN(z0); + // NOTE: we always try adding to the head of the discard_worklist so that // we hopefully get some locality benefits. This models a DFS discard in // our non-recursive approach. @@ -76,7 +100,7 @@ void iree_task_discard(iree_task_t* task, iree_task_list_t* discard_worklist) { // TODO(benvanik): signal as error. // iree_task_fence_t* fence_task = (iree_task_fence_t*)task; iree_atomic_fetch_sub_int32(&task->scope->pending_submissions, 1, - iree_memory_order_relaxed); + iree_memory_order_release); break; } case IREE_TASK_TYPE_WAIT: @@ -85,34 +109,29 @@ void iree_task_discard(iree_task_t* task, iree_task_list_t* discard_worklist) { break; } - // Release the task back to the pool it was allocated from, if any. - // Some tasks are allocated from arenas and may not be able to be freed - // individually. - if (task->pool) { - iree_task_pool_release(task->pool, task); - } - + iree_task_cleanup(task, iree_status_from_code(IREE_STATUS_ABORTED)); // NOTE: task is invalidated here and cannot be used! + + IREE_TRACE_ZONE_END(z0); } static void iree_task_retire(iree_task_t* task, iree_task_submission_t* pending_submission) { + IREE_ASSERT_EQ(0, iree_atomic_load_int32(&task->pending_dependency_count, + iree_memory_order_acquire)); + // Decrement the pending count on the completion task, if any.
iree_task_t* completion_task = task->completion_task; + task->completion_task = NULL; if (completion_task && iree_atomic_fetch_sub_int32(&completion_task->pending_dependency_count, 1, iree_memory_order_acq_rel) == 1) { // The completion task has retired and can now be made ready. iree_task_submission_enqueue(pending_submission, completion_task); } - task->completion_task = NULL; - // Return the task to the pool it was allocated from. - // Some tasks are allocated as part of arenas/ringbuffers and won't have a - // pool as they'll be cleaned up as part of a larger operation. - if (task->pool) { - iree_task_pool_release(task->pool, task); - } + iree_task_cleanup(task, iree_ok_status()); + // NOTE: task is invalidated here and cannot be used! } //============================================================================== @@ -124,12 +143,17 @@ void iree_task_nop_initialize(iree_task_scope_t* scope, iree_task_initialize(IREE_TASK_TYPE_NOP, scope, &out_task->header); } +void iree_task_nop_retire(iree_task_nop_t* task, + iree_task_submission_t* pending_submission) { + iree_task_retire(&task->header, pending_submission); +} + //============================================================================== // IREE_TASK_TYPE_CALL //============================================================================== void iree_task_call_initialize(iree_task_scope_t* scope, - iree_task_closure_t closure, + iree_task_call_closure_t closure, iree_task_call_t* out_task) { iree_task_initialize(IREE_TASK_TYPE_CALL, scope, &out_task->header); out_task->closure = closure; @@ -139,10 +163,16 @@ iree_status_t iree_task_call_execute( iree_task_call_t* task, iree_task_submission_t* pending_submission) { IREE_TRACE_ZONE_BEGIN(z0); - iree_status_t status = - task->closure.fn(task->closure.user_context, /*task_context=*/0); + // Execute the user callback. + // Note that this may enqueue more nested tasks, including tasks that prevent + // this task from retiring. 
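// [Illustrative sketch, not part of this patch] A call closure that issues
// nested work: it enqueues a follow-up task and registers itself as that
// task's completion target, so the pending_dependency_count check below sees a
// nonzero count and defers retirement until the nested task completes. The
// follow-up allocation helper is hypothetical.
//
//   static iree_status_t my_call_fn(uintptr_t user_context, iree_task_t* task,
//                                   iree_task_submission_t* pending_submission) {
//     iree_task_call_t* next = acquire_next_call();  // hypothetical helper
//     iree_task_set_completion_task(&next->header, task);  // keeps |task| live
//     iree_task_submission_enqueue(pending_submission, &next->header);
//     return iree_ok_status();
//   }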
+ iree_status_t status = task->closure.fn(task->closure.user_context, + &task->header, pending_submission); + if (iree_atomic_load_int32(&task->header.pending_dependency_count, + iree_memory_order_acquire) == 0) { + iree_task_retire(&task->header, pending_submission); + } - iree_task_retire(&task->header, pending_submission); IREE_TRACE_ZONE_END(z0); return status; } @@ -165,6 +195,25 @@ void iree_task_barrier_initialize(iree_task_scope_t* scope, } } +void iree_task_barrier_initialize_empty(iree_task_scope_t* scope, + iree_task_barrier_t* out_task) { + iree_task_initialize(IREE_TASK_TYPE_BARRIER, scope, &out_task->header); + out_task->dependent_task_count = 0; + out_task->dependent_tasks = NULL; +} + +void iree_task_barrier_set_dependent_tasks( + iree_task_barrier_t* task, iree_host_size_t dependent_task_count, + iree_task_t* const* dependent_tasks) { + task->dependent_task_count = dependent_task_count; + task->dependent_tasks = dependent_tasks; + for (iree_host_size_t i = 0; i < task->dependent_task_count; ++i) { + iree_task_t* dependent_task = task->dependent_tasks[i]; + iree_atomic_fetch_add_int32(&dependent_task->pending_dependency_count, 1, + iree_memory_order_relaxed); + } +} + void iree_task_barrier_retire(iree_task_barrier_t* task, iree_task_submission_t* pending_submission) { IREE_TRACE_ZONE_BEGIN(z0); @@ -192,7 +241,7 @@ void iree_task_fence_initialize(iree_task_scope_t* scope, iree_task_fence_t* out_task) { iree_task_initialize(IREE_TASK_TYPE_FENCE, scope, &out_task->header); iree_atomic_fetch_add_int32(&scope->pending_submissions, 1, - iree_memory_order_relaxed); + iree_memory_order_release); } void iree_task_fence_retire(iree_task_fence_t* task, @@ -317,10 +366,9 @@ void iree_task_dispatch_statistics_merge( // IREE_TASK_TYPE_DISPATCH //============================================================================== -static void iree_task_dispatch_initialize_base(iree_task_scope_t* scope, - iree_task_closure_t closure, - const uint32_t workgroup_size[3], - iree_task_dispatch_t* out_task) { +static void iree_task_dispatch_initialize_base( + iree_task_scope_t* scope, iree_task_dispatch_closure_t closure, + const uint32_t workgroup_size[3], iree_task_dispatch_t* out_task) { iree_task_initialize(IREE_TASK_TYPE_DISPATCH, scope, &out_task->header); out_task->closure = closure; memcpy(out_task->workgroup_size, workgroup_size, @@ -330,7 +378,7 @@ static void iree_task_dispatch_initialize_base(iree_task_scope_t* scope, } void iree_task_dispatch_initialize(iree_task_scope_t* scope, - iree_task_closure_t closure, + iree_task_dispatch_closure_t closure, const uint32_t workgroup_size[3], const uint32_t workgroup_count[3], iree_task_dispatch_t* out_task) { @@ -339,11 +387,10 @@ void iree_task_dispatch_initialize(iree_task_scope_t* scope, sizeof(out_task->workgroup_count.value)); } -void iree_task_dispatch_initialize_indirect(iree_task_scope_t* scope, - iree_task_closure_t closure, - const uint32_t workgroup_size[3], - const uint32_t* workgroup_count_ptr, - iree_task_dispatch_t* out_task) { +void iree_task_dispatch_initialize_indirect( + iree_task_scope_t* scope, iree_task_dispatch_closure_t closure, + const uint32_t workgroup_size[3], const uint32_t* workgroup_count_ptr, + iree_task_dispatch_t* out_task) { iree_task_dispatch_initialize_base(scope, closure, workgroup_size, out_task); out_task->header.flags |= IREE_TASK_FLAG_DISPATCH_INDIRECT; out_task->workgroup_count.ptr = workgroup_count_ptr; @@ -370,6 +417,14 @@ void iree_task_dispatch_issue_sliced(iree_task_dispatch_t* dispatch_task, 
memcpy(workgroup_count, dispatch_task->workgroup_count.value, sizeof(workgroup_count)); } + uint32_t total_workgroup_count = + workgroup_count[0] * workgroup_count[1] * workgroup_count[2]; + if (total_workgroup_count == 0) { + // No workgroups to execute - bail early. + iree_task_dispatch_retire(dispatch_task, pending_submission); + IREE_TRACE_ZONE_END(z0); + return; + } #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION char xyz_string[32]; @@ -411,12 +466,15 @@ void iree_task_dispatch_issue_sliced(iree_task_dispatch_t* dispatch_task, workgroup_base[1] = slice_y * tiles_per_slice_y; workgroup_base[2] = slice_z * tiles_per_slice_z; uint32_t workgroup_range[3]; - workgroup_range[0] = iree_min( - workgroup_count[0], workgroup_base[0] + tiles_per_slice_x - 1); - workgroup_range[1] = iree_min( - workgroup_count[1], workgroup_base[1] + tiles_per_slice_y - 1); - workgroup_range[2] = iree_min( - workgroup_count[2], workgroup_base[2] + tiles_per_slice_z - 1); + workgroup_range[0] = iree_min(workgroup_count[0], + workgroup_base[0] + tiles_per_slice_x) - + 1; + workgroup_range[1] = iree_min(workgroup_count[1], + workgroup_base[1] + tiles_per_slice_y) - + 1; + workgroup_range[2] = iree_min(workgroup_count[2], + workgroup_base[2] + tiles_per_slice_z) - + 1; // Allocate and initialize the slice. iree_task_dispatch_slice_t* slice_task = @@ -653,8 +711,8 @@ iree_status_t iree_task_dispatch_slice_execute( IREE_TRACE_ZONE_APPEND_VALUE(z_tile, z); // IREE_TRACE_ZONE_APPEND_VALUE(z_tile, (uint64_t)task->closure.fn); - iree_status_t status = task->closure.fn(task->closure.user_context, - (uintptr_t)&tile_context); + iree_status_t status = task->closure.fn( + task->closure.user_context, &tile_context, pending_submission); IREE_TRACE_ZONE_END(z_tile); if (IREE_UNLIKELY(!iree_status_is_ok(status))) { @@ -751,10 +809,10 @@ iree_status_t iree_task_dispatch_shard_execute( // TODO(benvanik): faster math here, especially knowing we pull off N // sequential indices per reservation. 
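// [Worked example, illustrative only] The corrected XYZ decomposition below
// with workgroup_count = {4, 3, 2} and tile_index = 17:
//   x = 17 % 4 = 1;   tile_i = 17 / 4 = 4
//   y =  4 % 3 = 1;   tile_i =  4 / 3 = 1
//   z = 1             -> workgroup_xyz = {1, 1, 1}
// The old code modded/divided by count+1, which both skipped coordinates and
// could produce out-of-range workgroup IDs.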
uint32_t tile_i = tile_index; - tile_context.workgroup_xyz[0] = tile_i % (workgroup_count_x + 1); - tile_i /= (workgroup_count_x + 1); - tile_context.workgroup_xyz[1] = tile_i % (workgroup_count_y + 1); - tile_i /= (workgroup_count_y + 1); + tile_context.workgroup_xyz[0] = tile_i % workgroup_count_x; + tile_i /= workgroup_count_x; + tile_context.workgroup_xyz[1] = tile_i % workgroup_count_y; + tile_i /= workgroup_count_y; tile_context.workgroup_xyz[2] = tile_i; IREE_TRACE_ZONE_BEGIN_NAMED(z_tile, @@ -768,8 +826,9 @@ iree_status_t iree_task_dispatch_shard_execute( IREE_TRACE_ZONE_APPEND_VALUE(z_tile, tile_context.workgroup_xyz[2]); // IREE_TRACE_ZONE_APPEND_VALUE(z_tile, (uint64_t)task->closure.fn); - iree_status_t status = dispatch_task->closure.fn( - dispatch_task->closure.user_context, (uintptr_t)&tile_context); + iree_status_t status = + dispatch_task->closure.fn(dispatch_task->closure.user_context, + &tile_context, pending_submission); IREE_TRACE_ZONE_END(z_tile); if (IREE_UNLIKELY(!iree_status_is_ok(status))) { diff --git a/iree/task/task.h b/iree/task/task.h index 35a632fba04dc..cc4ffbfc86c96 100644 --- a/iree/task/task.h +++ b/iree/task/task.h @@ -29,37 +29,7 @@ extern "C" { typedef struct iree_task_list_s iree_task_list_t; typedef struct iree_task_pool_s iree_task_pool_t; typedef struct iree_task_scope_s iree_task_scope_t; - -//============================================================================== -// Function closures -//============================================================================== - -typedef iree_status_t(IREE_API_PTR* iree_task_closure_fn_t)( - uintptr_t user_context, uintptr_t task_context); - -// A function closure representing the function to call and its arguments. -typedef struct { - // Function called per tile invocation. - iree_task_closure_fn_t fn; - - // User-defined argument passed to task functions during invocation. - // Opaque pointer-sized values that could point to user data structures or - // contain embedded values. No lifetime management is performed by the task - // system and it is required that users ensure that the memory referenced is - // live until after the task has completed. - uintptr_t user_context; - - // TODO(benvanik): cleanup function? right now assume arg is never freed. -} iree_task_closure_t; - -// Binds a function pointer and the arguments it should be called with. -// If the arguments represent pointers they must remain live until the task -// has completed execution. -static inline iree_task_closure_t iree_task_make_closure( - iree_task_closure_fn_t fn, uintptr_t user_context) { - iree_task_closure_t closure = {fn, user_context}; - return closure; -} +typedef struct iree_task_submission_s iree_task_submission_t; //============================================================================== // Task header for internal tracking @@ -151,6 +121,12 @@ typedef uint16_t iree_task_flags_t; typedef struct iree_task_s iree_task_t; +// A function called to cleanup tasks. +// The provided |status| is unowned and must be cloned if used beyond the scope +// of the cleanup function (such as when stored for later usage). +typedef void(IREE_API_PTR* iree_task_cleanup_fn_t)(iree_task_t* task, + iree_status_t status); + // A task within the task system that runs on an executor. // Tasks have an iree_task_type_t that defines which parameters are valid and // how the executor is to treat the task. Dependency edges can be defined that @@ -166,6 +142,11 @@ struct iree_alignas(iree_max_align_t) iree_task_s { // be skipped. 
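// [Illustrative aside, not part of this patch] A typical cleanup_fn for a
// heap-allocated task that is not pool-owned frees its wrapper once the system
// is done with it; per the contract above |status| is unowned and must be
// cloned if retained. my_task_wrapper_t is hypothetical.
//
//   static void my_task_cleanup(iree_task_t* task, iree_status_t status) {
//     // Inspect |status| only; do not free it here.
//     my_task_wrapper_t* wrapper = (my_task_wrapper_t*)task;
//     iree_allocator_free(iree_allocator_system(), wrapper);
//   }
//   ...
//   iree_task_set_cleanup_fn(&wrapper->header, my_task_cleanup);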
iree_task_scope_t* scope; + // Optional function to call to clean up the task on completion. + // Will be called after the task has retired or if the task fails to issue + // (dependency failed, etc.). + iree_task_cleanup_fn_t cleanup_fn; + // Optional task that will be notified when the task completes. // The task will have its pending_dependency_count decremented and will be // readied for execution when the count reaches 0. @@ -207,6 +188,11 @@ static_assert(offsetof(iree_task_t, next_task) == 0, void iree_task_initialize(iree_task_type_t type, iree_task_scope_t* scope, iree_task_t* out_task); +// Sets the optional function called when the task completes (whether successful +// or not). +void iree_task_set_cleanup_fn(iree_task_t* task, + iree_task_cleanup_fn_t cleanup_fn); + // Sets up a dependency edge from |task| to |completion_task| such that when // |task| completes |completion_task| will be notified and have its // pending_dependency_count decremented. @@ -248,6 +234,34 @@ void iree_task_nop_initialize(iree_task_scope_t* scope, // IREE_TASK_TYPE_CALL //============================================================================== +typedef iree_status_t(IREE_API_PTR* iree_task_call_closure_fn_t)( + uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission); + +// A function closure representing the function to call and its arguments. +typedef struct { + // Function called once when the call task executes. + iree_task_call_closure_fn_t fn; + + // User-defined argument passed to task functions during invocation. + // Opaque pointer-sized values that could point to user data structures or + // contain embedded values. No lifetime management is performed by the task + // system and it is required that users ensure that the memory referenced is + // live until after the task has completed. + uintptr_t user_context; + + // TODO(benvanik): cleanup function? right now assume arg is never freed. +} iree_task_call_closure_t; + +// Binds a function pointer and the arguments it should be called with. +// If the arguments represent pointers they must remain live until the task +// has completed execution. +static inline iree_task_call_closure_t iree_task_make_call_closure( + iree_task_call_closure_fn_t fn, uintptr_t user_context) { + iree_task_call_closure_t closure = {fn, user_context}; + return closure; +} + // A task that will synchronously call a function from the executor and wait // for it to complete before continuing. // @@ -258,11 +272,11 @@ typedef iree_alignas(iree_max_align_t) struct { iree_task_t header; // Function closure to call when the task is executed.
- iree_task_closure_t closure; + iree_task_call_closure_t closure; } iree_task_call_t; void iree_task_call_initialize(iree_task_scope_t* scope, - iree_task_closure_t closure, + iree_task_call_closure_t closure, iree_task_call_t* out_task); //============================================================================== @@ -310,6 +324,13 @@ void iree_task_barrier_initialize(iree_task_scope_t* scope, iree_task_t* const* dependent_tasks, iree_task_barrier_t* out_task); +void iree_task_barrier_initialize_empty(iree_task_scope_t* scope, + iree_task_barrier_t* out_task); + +void iree_task_barrier_set_dependent_tasks( + iree_task_barrier_t* task, iree_host_size_t dependent_task_count, + iree_task_t* const* dependent_tasks); + //============================================================================== // IREE_TASK_TYPE_FENCE //============================================================================== @@ -341,6 +362,7 @@ typedef struct { iree_task_t header; // The external wait handle that the task is waiting on. + // TODO(benvanik): multiple wait handles. iree_wait_handle_t wait_handle; // TODO(benvanik): deadline_ns. @@ -447,6 +469,36 @@ typedef iree_alignas(iree_max_align_t) struct { iree_byte_span_t shared_memory; } iree_task_dispatch_shard_state_t; +//============================================================================== +// Dispatch function closures +//============================================================================== + +typedef iree_status_t(IREE_API_PTR* iree_task_dispatch_closure_fn_t)( + uintptr_t user_context, const iree_task_tile_context_t* tile_context, + iree_task_submission_t* pending_submission); + +// A function closure representing the function to call and its arguments. +typedef struct { + // Function called per tile invocation. + iree_task_dispatch_closure_fn_t fn; + + // User-defined argument passed to task functions during invocation. + // Opaque pointer-sized values that could point to user data structures or + // contain embedded values. No lifetime management is performed by the task + // system and it is required that users ensure that the memory referenced is + // live until after the task has completed. + uintptr_t user_context; +} iree_task_dispatch_closure_t; + +// Binds a function pointer and the arguments it should be called with. +// If the arguments represent pointers they must remain live until the task +// has completed execution. +static inline iree_task_dispatch_closure_t iree_task_make_dispatch_closure( + iree_task_dispatch_closure_fn_t fn, uintptr_t user_context) { + iree_task_dispatch_closure_t closure = {fn, user_context}; + return closure; +} + //============================================================================== // IREE_TASK_TYPE_DISPATCH //============================================================================== @@ -473,7 +525,7 @@ typedef iree_alignas(iree_max_align_t) struct iree_task_dispatch_s { iree_task_t header; // Function closure to call per tile. - iree_task_closure_t closure; + iree_task_dispatch_closure_t closure; // Workgroup size for each invocation. Passed on to tiles without // modification and not used for scheduling. 
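A usage sketch of the reworked call closure plus the new cleanup hook follows; `my_state_t`, `my_call`, `my_cleanup`, and `my_make_call_task` are illustrative names, not part of this change:

// Hypothetical caller-owned state passed through |user_context|.
typedef struct {
  int counter;
} my_state_t;

// Matches iree_task_call_closure_fn_t: the closure now receives the task
// itself and a |pending_submission| it may enqueue follow-on tasks into.
static iree_status_t my_call(uintptr_t user_context, iree_task_t* task,
                             iree_task_submission_t* pending_submission) {
  my_state_t* state = (my_state_t*)user_context;
  ++state->counter;
  (void)task;
  (void)pending_submission;
  return iree_ok_status();
}

// Matches iree_task_cleanup_fn_t; |status| is unowned and would need to be
// cloned if kept beyond this call.
static void my_cleanup(iree_task_t* task, iree_status_t status) {
  (void)task;
  (void)status;
}

static void my_make_call_task(iree_task_scope_t* scope, my_state_t* state,
                              iree_task_call_t* out_task) {
  iree_task_call_initialize(
      scope, iree_task_make_call_closure(my_call, (uintptr_t)state), out_task);
  iree_task_set_cleanup_fn(&out_task->header, my_cleanup);
}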
@@ -506,16 +558,15 @@ typedef iree_alignas(iree_max_align_t) struct iree_task_dispatch_s { } iree_task_dispatch_t; void iree_task_dispatch_initialize(iree_task_scope_t* scope, - iree_task_closure_t closure, + iree_task_dispatch_closure_t closure, const uint32_t workgroup_size[3], const uint32_t workgroup_count[3], iree_task_dispatch_t* out_task); -void iree_task_dispatch_initialize_indirect(iree_task_scope_t* scope, - iree_task_closure_t closure, - const uint32_t workgroup_size[3], - const uint32_t* workgroup_count_ptr, - iree_task_dispatch_t* out_task); +void iree_task_dispatch_initialize_indirect( + iree_task_scope_t* scope, iree_task_dispatch_closure_t closure, + const uint32_t workgroup_size[3], const uint32_t* workgroup_count_ptr, + iree_task_dispatch_t* out_task); //============================================================================== // IREE_TASK_TYPE_DISPATCH_SLICE @@ -552,7 +603,7 @@ typedef iree_alignas(iree_max_align_t) struct { // tile which would likely be a cache miss as we fan out to other cores. // Function closure to call per tile (same as the closure in the dispatch). - iree_task_closure_t closure; + iree_task_dispatch_closure_t closure; // Base workgroup ID for the slice range. uint32_t workgroup_base[3]; diff --git a/iree/task/task_impl.h b/iree/task/task_impl.h index 63209989cfceb..8682a7fa53f3c 100644 --- a/iree/task/task_impl.h +++ b/iree/task/task_impl.h @@ -29,6 +29,12 @@ extern "C" { // IREE_TASK_TYPE_NOP //============================================================================== +// Retires a no-op task. +// No-op tasks don't *do* anything but must still be handled like any other +// task in the system so dependent tasks are properly scheduled. +void iree_task_nop_retire(iree_task_nop_t* task, + iree_task_submission_t* pending_submission); + //============================================================================== // IREE_TASK_TYPE_CALL //============================================================================== diff --git a/iree/task/task_test_barrier.cc b/iree/task/task_test_barrier.cc new file mode 100644 index 0000000000000..c834348876dc7 --- /dev/null +++ b/iree/task/task_test_barrier.cc @@ -0,0 +1,153 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
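The barrier tests that follow use the `{ a | barrier | b }` notation in their comments. As a sketch of the two-step wiring path declared in task.h above (`iree_task_barrier_initialize_empty` plus `iree_task_barrier_set_dependent_tasks`); `my_wire_barrier` is an illustrative helper, and the lifetime note is an assumption drawn from the `iree_task_barrier_initialize` contract:

// Late-binds barrier dependents: useful when the dependent tasks are not yet
// known at the time the barrier itself must be allocated.
static void my_wire_barrier(iree_task_scope_t* scope,
                            iree_task_barrier_t* out_barrier,
                            iree_host_size_t dependent_task_count,
                            iree_task_t* const* dependent_tasks) {
  iree_task_barrier_initialize_empty(scope, out_barrier);
  // |dependent_tasks| is assumed to be referenced (not copied) and must stay
  // live until the barrier retires, as with iree_task_barrier_initialize.
  iree_task_barrier_set_dependent_tasks(out_barrier, dependent_task_count,
                                        dependent_tasks);
}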
+ +#include <atomic> + +#include "iree/task/testing/task_test.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace { + +class TaskBarrierTest : public TaskTest {}; + +enum { + TASK_A = 1 << 0, + TASK_B = 1 << 1, + TASK_C = 1 << 2, + TASK_D = 1 << 3, +}; + +// We track which tasks were successfully executed. +struct TaskCtx { + std::atomic<uint32_t> tasks_called = {0}; +}; + +#define MAKE_CALL_TASK_CLOSURE(task_ctx, task_id) \ + iree_task_make_call_closure( \ + [](uintptr_t user_context, iree_task_t* task, \ + iree_task_submission_t* pending_submission) { \ + auto* ctx = (TaskCtx*)user_context; \ + EXPECT_EQ(0, (ctx->tasks_called & (task_id))); \ + ctx->tasks_called |= (task_id); \ + return iree_ok_status(); \ + }, \ + (uintptr_t)task_ctx) + +// Issues a standalone empty barrier: +// { barrier } +TEST_F(TaskBarrierTest, IssueStandalone) { + iree_task_barrier_t barrier_task; + iree_task_barrier_initialize_empty(&scope_, &barrier_task); + IREE_ASSERT_OK( + SubmitTasksAndWaitIdle(&barrier_task.header, &barrier_task.header)); +} + +// Issues a serialized sequence: +// { a | barrier | b } +TEST_F(TaskBarrierTest, IssueSerializedSequence) { + TaskCtx task_ctx; + + iree_task_call_t task_a; + iree_task_call_initialize(&scope_, MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_A), + &task_a); + iree_task_call_t task_b; + iree_task_call_initialize(&scope_, MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_B), + &task_b); + + iree_task_t* dependent_tasks[1] = {&task_b.header}; + iree_task_barrier_t barrier_task; + iree_task_barrier_initialize(&scope_, IREE_ARRAYSIZE(dependent_tasks), + dependent_tasks, &barrier_task); + iree_task_set_completion_task(&task_a.header, &barrier_task.header); + + IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task_a.header, &task_b.header)); + EXPECT_EQ(TASK_A | TASK_B, task_ctx.tasks_called); +} + +// Issues a join: +// { a, b, c | barrier | d } +TEST_F(TaskBarrierTest, IssueJoin) { + TaskCtx task_ctx; + + iree_task_call_t task_a; + iree_task_call_initialize(&scope_, MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_A), + &task_a); + iree_task_call_t task_b; + iree_task_call_initialize(&scope_, MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_B), + &task_b); + iree_task_call_t task_c; + iree_task_call_initialize(&scope_, MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_C), + &task_c); + iree_task_call_t task_d; + iree_task_call_initialize(&scope_, MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_D), + &task_d); + + iree_task_t* dependent_tasks[1] = {&task_d.header}; + iree_task_barrier_t barrier_task; + iree_task_barrier_initialize(&scope_, IREE_ARRAYSIZE(dependent_tasks), + dependent_tasks, &barrier_task); + iree_task_set_completion_task(&task_a.header, &barrier_task.header); + iree_task_set_completion_task(&task_b.header, &barrier_task.header); + iree_task_set_completion_task(&task_c.header, &barrier_task.header); + + iree_task_submission_t submission; + iree_task_submission_initialize(&submission); + iree_task_submission_enqueue(&submission, &task_a.header); + iree_task_submission_enqueue(&submission, &task_b.header); + iree_task_submission_enqueue(&submission, &task_c.header); + IREE_ASSERT_OK(SubmitAndWaitIdle(&submission, &task_d.header)); + EXPECT_EQ(TASK_A | TASK_B | TASK_C | TASK_D, task_ctx.tasks_called); +} + +// Issues a fork: +// { a | barrier | b, c, d | nop } +TEST_F(TaskBarrierTest, IssueFork) { + TaskCtx task_ctx; + + iree_task_call_t task_a; + iree_task_call_initialize(&scope_, MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_A), + &task_a); + iree_task_call_t task_b; + iree_task_call_initialize(&scope_, 
MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_B), + &task_b); + iree_task_call_t task_c; + iree_task_call_initialize(&scope_, MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_C), + &task_c); + iree_task_call_t task_d; + iree_task_call_initialize(&scope_, MAKE_CALL_TASK_CLOSURE(&task_ctx, TASK_D), + &task_d); + + iree_task_t* dependent_tasks[3] = { + &task_b.header, + &task_c.header, + &task_d.header, + }; + iree_task_barrier_t barrier_task; + iree_task_barrier_initialize(&scope_, IREE_ARRAYSIZE(dependent_tasks), + dependent_tasks, &barrier_task); + iree_task_set_completion_task(&task_a.header, &barrier_task.header); + + // Just to give us a tail task to wait on. + iree_task_nop_t nop_task; + iree_task_nop_initialize(&scope_, &nop_task); + iree_task_set_completion_task(&task_b.header, &nop_task.header); + iree_task_set_completion_task(&task_c.header, &nop_task.header); + iree_task_set_completion_task(&task_d.header, &nop_task.header); + + IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task_a.header, &nop_task.header)); + EXPECT_EQ(TASK_A | TASK_B | TASK_C | TASK_D, task_ctx.tasks_called); +} + +} // namespace diff --git a/iree/task/task_test_call.cc b/iree/task/task_test_call.cc new file mode 100644 index 0000000000000..8c4ad43596736 --- /dev/null +++ b/iree/task/task_test_call.cc @@ -0,0 +1,108 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <atomic> + +#include "iree/task/testing/task_test.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace { + +class TaskCallTest : public TaskTest {}; + +TEST_F(TaskCallTest, Issue) { + struct TestCtx { + int did_call = 0; + }; + TestCtx ctx; + + iree_task_call_t task; + iree_task_call_initialize(&scope_, + iree_task_make_call_closure( + [](uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + auto* ctx = (TestCtx*)user_context; + EXPECT_TRUE(NULL != ctx); + EXPECT_EQ(0, ctx->did_call); + ++ctx->did_call; + return iree_ok_status(); + }, + (uintptr_t)&ctx), + &task); + IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task.header, &task.header)); + EXPECT_EQ(1, ctx.did_call); +} + +// Issues task_a which then issues a nested task_b and waits for it to complete +// prior to progressing. This models dynamic parallelism: +// http://developer.download.nvidia.com/GTC/PDF/GTC2012/PresentationPDF/S0338-GTC2012-CUDA-Programming-Model.pdf +TEST_F(TaskCallTest, IssueNested) { + struct TestCtx { + std::atomic<int> did_call_a = {0}; + std::atomic<int> did_call_b = {0}; + std::atomic<bool> has_issued = {false}; + iree_task_call_t task_b; + }; + TestCtx ctx; + + // task_a will get called twice: the first time it will schedule task_b and + // then it'll get called again when task_b completes. This is not the only way + // to do this: task_a could set it up so that a task_c ran after task_b + // completed instead of getting itself called twice. Both approaches have + // their uses. 
+ iree_task_call_t task_a; + iree_task_call_initialize( + &scope_, + iree_task_make_call_closure( + [](uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + auto* ctx = (TestCtx*)user_context; + EXPECT_TRUE(NULL != ctx); + + if (!ctx->has_issued) { + ctx->has_issued = true; + EXPECT_EQ(0, ctx->did_call_a); + ++ctx->did_call_a; + iree_task_call_initialize( + task->scope, + iree_task_make_call_closure( + [](uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + auto* ctx = (TestCtx*)user_context; + EXPECT_TRUE(NULL != ctx); + EXPECT_EQ(0, ctx->did_call_b); + ++ctx->did_call_b; + return iree_ok_status(); + }, + user_context), + &ctx->task_b); + iree_task_set_completion_task(&ctx->task_b.header, task); + iree_task_submission_enqueue(pending_submission, + &ctx->task_b.header); + } else { + EXPECT_EQ(1, ctx->did_call_a); + ++ctx->did_call_a; + } + + return iree_ok_status(); + }, + (uintptr_t)&ctx), + &task_a); + IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task_a.header, &task_a.header)); + EXPECT_EQ(2, ctx.did_call_a); + EXPECT_EQ(1, ctx.did_call_b); +} + +} // namespace diff --git a/iree/task/task_test_dispatch.cc b/iree/task/task_test_dispatch.cc new file mode 100644 index 0000000000000..c19c251c36756 --- /dev/null +++ b/iree/task/task_test_dispatch.cc @@ -0,0 +1,171 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
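The dispatch tests that follow mark one counter slot per workgroup and then verify full coverage of the grid. For reference, a sketch of the index math involved: `my_decompose_tile_index` mirrors the corrected modulo/divide sequence in the shard executor at the top of this change, and `my_flatten_xyz` mirrors the slot computation in `GridCoverage::Tile` below (both helper names are illustrative; counts are assumed non-zero whenever a tile actually executes):

#include <stdint.h>

// Linear tile index -> workgroup XYZ, with x varying fastest; the inverse of
// my_flatten_xyz below.
static void my_decompose_tile_index(uint32_t tile_index,
                                    const uint32_t workgroup_count[3],
                                    uint32_t out_xyz[3]) {
  uint32_t tile_i = tile_index;
  out_xyz[0] = tile_i % workgroup_count[0];
  tile_i /= workgroup_count[0];
  out_xyz[1] = tile_i % workgroup_count[1];
  tile_i /= workgroup_count[1];
  out_xyz[2] = tile_i;
}

// Workgroup XYZ -> linear slot in a dense [z][y][x] grid.
static uint32_t my_flatten_xyz(const uint32_t xyz[3],
                               const uint32_t workgroup_count[3]) {
  return xyz[2] * (workgroup_count[1] * workgroup_count[0]) +
         xyz[1] * workgroup_count[0] + xyz[0];
}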
+ +#include <memory> + +#include "iree/task/testing/task_test.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace { + +class GridCoverage { + public: + explicit GridCoverage(const uint32_t workgroup_count[3]) + : workgroup_count_(workgroup_count[0] * workgroup_count[1] * + workgroup_count[2]), + storage_(new iree_atomic_int32_t[workgroup_count_]) { + for (iree_host_size_t i = 0; i < workgroup_count_; ++i) { + storage_[i] = IREE_ATOMIC_VAR_INIT(0); + } + } + + bool Verify() { + fflush(stdout); + for (iree_host_size_t i = 0; i < workgroup_count_; ++i) { + if (iree_atomic_load_int32(&storage_[i], iree_memory_order_seq_cst) != + 1) { + return false; + } + } + return true; + } + + static iree_status_t Tile(uintptr_t user_context, + const iree_task_tile_context_t* tile_context, + iree_task_submission_t* pending_submission) { + GridCoverage* coverage = reinterpret_cast<GridCoverage*>(user_context); + uint32_t slot = + tile_context->workgroup_xyz[2] * (tile_context->workgroup_count[1] * + tile_context->workgroup_count[0]) + + tile_context->workgroup_xyz[1] * tile_context->workgroup_count[0] + + tile_context->workgroup_xyz[0]; + iree_atomic_fetch_add_int32(&coverage->storage_[slot], 1, + iree_memory_order_seq_cst); + + // Useful when testing large grids: + // printf("%u, %u, %u\n", tile_context->workgroup_xyz[0], + // tile_context->workgroup_xyz[1], tile_context->workgroup_xyz[2]); + + return iree_ok_status(); + } + + private: + size_t workgroup_count_; + std::unique_ptr<iree_atomic_int32_t[]> storage_; +}; + +class TaskDispatchTest : public TaskTest { + public: + void DispatchAndVerifyGrid(const uint32_t workgroup_size[3], + const uint32_t workgroup_count[3], + uint32_t dispatch_flags) { + GridCoverage coverage(workgroup_count); + iree_task_dispatch_t task; + iree_task_dispatch_initialize(&scope_, + iree_task_make_dispatch_closure( + GridCoverage::Tile, (uintptr_t)&coverage), + workgroup_size, workgroup_count, &task); + task.header.flags |= dispatch_flags; + IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task.header, &task.header)); + EXPECT_TRUE(coverage.Verify()); + } +}; + +TEST_F(TaskDispatchTest, Issue000Sharded) { + const uint32_t kWorkgroupSize[3] = {1, 1, 1}; + const uint32_t kWorkgroupCount[3] = {0, 0, 0}; + DispatchAndVerifyGrid(kWorkgroupSize, kWorkgroupCount, 0); +} + +TEST_F(TaskDispatchTest, Issue000Sliced) { + const uint32_t kWorkgroupSize[3] = {1, 1, 1}; + const uint32_t kWorkgroupCount[3] = {0, 0, 0}; + DispatchAndVerifyGrid(kWorkgroupSize, kWorkgroupCount, + IREE_TASK_FLAG_DISPATCH_SLICED); +} + +TEST_F(TaskDispatchTest, Issue120Sharded) { + const uint32_t kWorkgroupSize[3] = {1, 1, 1}; + const uint32_t kWorkgroupCount[3] = {1, 2, 0}; + DispatchAndVerifyGrid(kWorkgroupSize, kWorkgroupCount, 0); +} + +TEST_F(TaskDispatchTest, Issue120Sliced) { + const uint32_t kWorkgroupSize[3] = {1, 1, 1}; + const uint32_t kWorkgroupCount[3] = {1, 2, 0}; + DispatchAndVerifyGrid(kWorkgroupSize, kWorkgroupCount, + IREE_TASK_FLAG_DISPATCH_SLICED); +} + +TEST_F(TaskDispatchTest, Issue111Sharded) { + const uint32_t kWorkgroupSize[3] = {1, 1, 1}; + const uint32_t kWorkgroupCount[3] = {1, 1, 1}; + DispatchAndVerifyGrid(kWorkgroupSize, kWorkgroupCount, 0); +} + +TEST_F(TaskDispatchTest, Issue111Sliced) { + const uint32_t kWorkgroupSize[3] = {1, 1, 1}; + const uint32_t kWorkgroupCount[3] = {1, 1, 1}; + DispatchAndVerifyGrid(kWorkgroupSize, kWorkgroupCount, + IREE_TASK_FLAG_DISPATCH_SLICED); +} + +TEST_F(TaskDispatchTest, Issue345Sharded) { + const uint32_t kWorkgroupSize[3] = {1, 1, 1}; + const uint32_t 
kWorkgroupCount[3] = {3, 4, 5}; + DispatchAndVerifyGrid(kWorkgroupSize, kWorkgroupCount, 0); +} + +TEST_F(TaskDispatchTest, Issue345Sliced) { + const uint32_t kWorkgroupSize[3] = {1, 1, 1}; + const uint32_t kWorkgroupCount[3] = {3, 4, 5}; + DispatchAndVerifyGrid(kWorkgroupSize, kWorkgroupCount, + IREE_TASK_FLAG_DISPATCH_SLICED); +} + +TEST_F(TaskDispatchTest, IssueIndirect) { + static const uint32_t kWorkgroupSize[3] = {1, 1, 1}; + static const uint32_t kWorkgroupCount[3] = {3, 4, 5}; + uint32_t indirect_workgroup_count[3] = {0, 0, 0}; + GridCoverage coverage(kWorkgroupCount); + + iree_task_call_t calculate_task; + iree_task_call_initialize( + &scope_, + iree_task_make_call_closure( + [](uintptr_t user_context, iree_task_t* task, + iree_task_submission_t* pending_submission) { + uint32_t* indirect_workgroup_count_ptr = (uint32_t*)user_context; + for (size_t i = 0; i < IREE_ARRAYSIZE(kWorkgroupCount); ++i) { + indirect_workgroup_count_ptr[i] = kWorkgroupCount[i]; + } + return iree_ok_status(); + }, + (uintptr_t)indirect_workgroup_count), + &calculate_task); + + iree_task_dispatch_t dispatch_task; + iree_task_dispatch_initialize_indirect( + &scope_, + iree_task_make_dispatch_closure(GridCoverage::Tile, (uintptr_t)&coverage), + kWorkgroupSize, indirect_workgroup_count, &dispatch_task); + iree_task_set_completion_task(&calculate_task.header, &dispatch_task.header); + + IREE_ASSERT_OK( + SubmitTasksAndWaitIdle(&calculate_task.header, &dispatch_task.header)); + EXPECT_TRUE(coverage.Verify()); +} + +} // namespace diff --git a/iree/task/task_test_fence.cc b/iree/task/task_test_fence.cc new file mode 100644 index 0000000000000..f84f737253611 --- /dev/null +++ b/iree/task/task_test_fence.cc @@ -0,0 +1,38 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/task/testing/task_test.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace { + +class TaskFenceTest : public TaskTest {}; + +TEST_F(TaskFenceTest, IssueChained) { + iree_task_fence_t task_a; + iree_task_fence_initialize(&scope_, &task_a); + + iree_task_fence_t task_b; + iree_task_fence_initialize(&scope_, &task_b); + iree_task_set_completion_task(&task_a.header, &task_b.header); + + iree_task_fence_t task_c; + iree_task_fence_initialize(&scope_, &task_c); + iree_task_set_completion_task(&task_b.header, &task_c.header); + + IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task_a.header, &task_c.header)); +} + +} // namespace diff --git a/iree/hal/host/host_descriptor_set.cc b/iree/task/task_test_nop.cc similarity index 59% rename from iree/hal/host/host_descriptor_set.cc rename to iree/task/task_test_nop.cc index 2e841dfaf755e..8fa651e5da412 100644 --- a/iree/hal/host/host_descriptor_set.cc +++ b/iree/task/task_test_nop.cc @@ -1,4 +1,4 @@ -// Copyright 2020 Google LLC +// Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "iree/hal/host/host_descriptor_set.h" +#include "iree/task/testing/task_test.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" -namespace iree { -namespace hal { +namespace { -HostDescriptorSet::HostDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span bindings) - : bindings_(bindings.begin(), bindings.end()) {} +class TaskNopTest : public TaskTest {}; -HostDescriptorSet::~HostDescriptorSet() = default; +TEST_F(TaskNopTest, Issue) { + iree_task_nop_t task; + iree_task_nop_initialize(&scope_, &task); + IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task.header, &task.header)); +} -} // namespace hal -} // namespace iree +} // namespace diff --git a/iree/task/task_test_wait.cc b/iree/task/task_test_wait.cc new file mode 100644 index 0000000000000..5ff3d6516f9b0 --- /dev/null +++ b/iree/task/task_test_wait.cc @@ -0,0 +1,65 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "iree/task/testing/task_test.h" +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +namespace { + +class TaskWaitTest : public TaskTest {}; + +TEST_F(TaskWaitTest, IssueSignaled) { + iree_event_t event; + iree_event_initialize(/*initial_state=*/true, &event); + + iree_task_wait_t task; + iree_task_wait_initialize(&scope_, event, &task); + + IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task.header, &task.header)); + + iree_event_deinitialize(&event); +} + +TEST_F(TaskWaitTest, DISABLED_IssueUnsignaled) { + iree_event_t event; + iree_event_initialize(/*initial_state=*/false, &event); + + iree_task_wait_t task; + iree_task_wait_initialize(&scope_, event, &task); + + // Spin up a thread that will signal the event after we start waiting on it. + std::atomic<bool> has_signaled = {false}; + std::thread signal_thread([&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + EXPECT_FALSE(has_signaled); + has_signaled = true; + iree_event_set(&event); + }); + + EXPECT_FALSE(has_signaled); + IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task.header, &task.header)); + EXPECT_TRUE(has_signaled); + + signal_thread.join(); + iree_event_deinitialize(&event); +} + +// TODO(benvanik): multi-waits: join wait a/b/c to task d. +// TODO(benvanik): multi-waits: co-issue wait a/b/c to task d/e/f. 
+ +} // namespace diff --git a/iree/task/testing/BUILD b/iree/task/testing/BUILD index 5c32559e4aa42..40184a71d2566 100644 --- a/iree/task/testing/BUILD +++ b/iree/task/testing/BUILD @@ -18,6 +18,16 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "task_test", + testonly = 1, + hdrs = ["task_test.h"], + deps = [ + "//iree/task", + "//iree/testing:gtest", + ], +) + cc_library( name = "test_util", testonly = 1, diff --git a/iree/task/testing/CMakeLists.txt b/iree/task/testing/CMakeLists.txt index 5ba9ce927b8ef..1134143c2e74d 100644 --- a/iree/task/testing/CMakeLists.txt +++ b/iree/task/testing/CMakeLists.txt @@ -14,6 +14,18 @@ iree_add_all_subdirs() +iree_cc_library( + NAME + task_test + HDRS + "task_test.h" + DEPS + iree::task + iree::testing::gtest + TESTONLY + PUBLIC +) + iree_cc_library( NAME test_util diff --git a/iree/task/testing/task_test.h b/iree/task/testing/task_test.h new file mode 100644 index 0000000000000..140b33cf8fb88 --- /dev/null +++ b/iree/task/testing/task_test.h @@ -0,0 +1,84 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// NOTE: the best kind of synchronization is no synchronization; always try to +// design your algorithm so that you don't need anything from this file :) +// See https://travisdowns.github.io/blog/2020/07/06/concurrency-costs.html + +#ifndef IREE_TASK_TESTING_TASK_TEST_H_ +#define IREE_TASK_TESTING_TASK_TEST_H_ + +#include + +#include "iree/task/executor.h" +#include "iree/task/scope.h" +#include "iree/task/task.h" +#include "iree/task/topology.h" +#include "iree/testing/status_matchers.h" + +class TaskTest : public ::testing::Test { + protected: + virtual void SetUp() { + iree_task_topology_t topology; + iree_task_topology_initialize_from_group_count(8, &topology); + IREE_ASSERT_OK(iree_task_executor_create(IREE_TASK_SCHEDULING_MODE_RESERVED, + &topology, iree_allocator_system(), + &executor_)); + iree_task_topology_deinitialize(&topology); + + iree_task_scope_initialize(iree_make_cstring_view("scope"), &scope_); + } + + virtual void TearDown() { + iree_task_scope_deinitialize(&scope_); + + iree_task_executor_release(executor_); + } + + // Submits a sequence of tasks with |head_task| at the head and |tail_task| at + // the tail (they can be the same). + iree_status_t SubmitTasksAndWaitIdle(iree_task_t* head_task, + iree_task_t* tail_task) { + iree_task_fence_t* fence = NULL; + IREE_RETURN_IF_ERROR( + iree_task_executor_acquire_fence(executor_, &scope_, &fence)); + iree_task_set_completion_task(tail_task, &fence->header); + + iree_task_submission_t submission; + iree_task_submission_initialize(&submission); + iree_task_submission_enqueue(&submission, head_task); + iree_task_executor_submit(executor_, &submission); + iree_task_executor_flush(executor_); + return iree_task_scope_wait_idle(&scope_, IREE_TIME_INFINITE_FUTURE); + } + + // Submits a DAG of tasks with |tail_task| at the tail (used just for idle + // detection). 
+ iree_status_t SubmitAndWaitIdle(iree_task_submission_t* submission, + iree_task_t* tail_task) { + iree_task_fence_t* fence = NULL; + IREE_RETURN_IF_ERROR( + iree_task_executor_acquire_fence(executor_, &scope_, &fence)); + iree_task_set_completion_task(tail_task, &fence->header); + + iree_task_executor_submit(executor_, submission); + iree_task_executor_flush(executor_); + return iree_task_scope_wait_idle(&scope_, IREE_TIME_INFINITE_FUTURE); + } + + iree_task_executor_t* executor_ = NULL; + iree_task_scope_t scope_; +}; + +#endif // IREE_TASK_TESTING_TASK_TEST_H_ diff --git a/iree/task/topology.c b/iree/task/topology.c index ee0061f3e5fdb..f20f2ac628674 100644 --- a/iree/task/topology.c +++ b/iree/task/topology.c @@ -14,50 +14,36 @@ #include "iree/task/topology.h" +#include #include #include +#include "iree/base/debugging.h" #include "iree/base/math.h" #include "iree/base/tracing.h" +#include "iree/task/tuning.h" -struct iree_task_topology_s { - iree_allocator_t allocator; - iree_host_size_t group_capacity; - iree_host_size_t group_count; - iree_task_topology_group_t groups[0]; -}; - -iree_status_t iree_task_topology_allocate(iree_host_size_t group_capacity, - iree_allocator_t allocator, - iree_task_topology_t** out_topology) { - IREE_TRACE_ZONE_BEGIN(z0); - - iree_host_size_t topology_size = - sizeof(iree_task_topology_t) + - group_capacity * sizeof(iree_task_topology_group_t); - - iree_task_topology_t* topology = NULL; - IREE_RETURN_IF_ERROR( - iree_allocator_malloc(allocator, topology_size, (void**)&topology)); - topology->allocator = allocator; - topology->group_capacity = group_capacity; - topology->group_count = 0; +void iree_task_topology_group_initialize( + uint8_t group_index, iree_task_topology_group_t* out_group) { + memset(out_group, 0, sizeof(*out_group)); + out_group->group_index = group_index; + snprintf(out_group->name, IREE_ARRAYSIZE(out_group->name), "worker[%u]", + group_index); + iree_thread_affinity_set_any(&out_group->ideal_thread_affinity); + out_group->constructive_sharing_mask = IREE_TASK_TOPOLOGY_GROUP_MASK_ALL; +} - *out_topology = topology; - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); +void iree_task_topology_initialize(iree_task_topology_t* out_topology) { + IREE_ASSERT_ARGUMENT(out_topology); + memset(out_topology, 0, sizeof(*out_topology)); } -void iree_task_topology_free(iree_task_topology_t* topology) { - if (!topology) return; - IREE_TRACE_ZONE_BEGIN(z0); - iree_allocator_free(topology->allocator, topology); - IREE_TRACE_ZONE_END(z0); +void iree_task_topology_deinitialize(iree_task_topology_t* topology) { + IREE_ASSERT_ARGUMENT(topology); } iree_status_t iree_task_topology_parse(iree_string_view_t value, - iree_allocator_t allocator, - iree_task_topology_t** out_topology) { + iree_task_topology_t* out_topology) { // TODO(benvanik): define a format that is generally useful alongside cpuinfo. // Maybe colon-separated group-id values from thread affinities? 
Like: // 0.0:0.2:0.4:0.8 to indicate cores 0,2,4,8 on group 0 @@ -73,6 +59,11 @@ bool iree_task_topology_format(const iree_task_topology_t* topology, return false; } +iree_host_size_t iree_task_topology_group_capacity( + const iree_task_topology_t* topology) { + return IREE_ARRAYSIZE(topology->groups); +} + iree_host_size_t iree_task_topology_group_count( const iree_task_topology_t* topology) { return topology->group_count; @@ -86,7 +77,7 @@ const iree_task_topology_group_t* iree_task_topology_get_group( iree_status_t iree_task_topology_push_group( iree_task_topology_t* topology, const iree_task_topology_group_t* group) { - if (topology->group_count + 1 > topology->group_capacity) { + if (topology->group_count + 1 > IREE_ARRAYSIZE(topology->groups)) { return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, "group capacity exceeded"); } @@ -97,27 +88,18 @@ iree_status_t iree_task_topology_push_group( return iree_ok_status(); } -iree_status_t iree_task_topology_from_group_count( - iree_host_size_t group_count, iree_allocator_t allocator, - iree_task_topology_t** out_topology) { +void iree_task_topology_initialize_from_group_count( + iree_host_size_t group_count, iree_task_topology_t* out_topology) { IREE_TRACE_ZONE_BEGIN(z0); - iree_task_topology_t* topology = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_task_topology_allocate(group_count, allocator, &topology)); - + iree_task_topology_initialize(out_topology); for (iree_host_size_t i = 0; i < group_count; ++i) { - iree_task_topology_group_t* group = &topology->groups[i]; - group->group_index = i; - snprintf(group->name, IREE_ARRAYSIZE(group->name), "worker[%d]", (int)i); - iree_thread_affinity_set_any(&group->ideal_thread_affinity); - group->constructive_sharing_mask = IREE_TASK_TOPOLOGY_GROUP_MASK_ALL; + iree_task_topology_group_t* group = &out_topology->groups[i]; + iree_task_topology_group_initialize(i, group); } - topology->group_count = group_count; + out_topology->group_count = group_count; - *out_topology = topology; IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); } // Runs the cpuinfo initializer which caches its result on the first call. @@ -130,6 +112,10 @@ static iree_status_t iree_task_topology_ensure_cpuinfo_available() { return iree_ok_status(); } +static bool iree_task_topology_is_cpuinfo_available() { + return cpuinfo_initialize(); +} + // Returns the core of the calling thread or NULL if not supported. // We wrap this here because cpuinfo only returns non-NULL on linux. static const struct cpuinfo_core* iree_task_topology_get_current_core() { @@ -227,10 +213,7 @@ static uint64_t iree_task_topology_calculate_constructive_sharing_mask( static void iree_task_topology_group_initialize_from_core( uint32_t group_index, const struct cpuinfo_core* core, iree_task_topology_group_t* out_group) { - memset(out_group, 0, sizeof(*out_group)); - out_group->group_index = group_index; - snprintf(out_group->name, IREE_ARRAYSIZE(out_group->name), "worker[%u]", - group_index); + iree_task_topology_group_initialize(group_index, out_group); // Guess: always pick the first processor in a core. // When pinning to threads we'll take into account whether the core is SMT @@ -273,18 +256,30 @@ static void iree_task_topology_fixup_constructive_sharing_masks( } } +// Initializes |out_topology| with a standardized behavior when cpuinfo is not +// available (unsupported arch, failed to query, etc). 
+static void iree_task_topology_initialize_fallback( + iree_host_size_t max_group_count, iree_task_topology_t* out_topology) { + IREE_TRACE_ZONE_BEGIN(z0); + // TODO(benvanik): implement our own query... but that seems not so great. + // For now we default to a single group: if a user wants more then they can + // either get cpuinfo working for their platform or manually construct the + // topology themselves. + iree_host_size_t group_count = 1; + iree_task_topology_initialize_from_group_count(group_count, out_topology); + IREE_TRACE_ZONE_END(z0); +} + // Matches all cores. static bool iree_task_topology_core_filter_all(const struct cpuinfo_core* core, uintptr_t user_data) { return true; } -iree_status_t iree_task_topology_from_physical_cores( - iree_host_size_t max_core_count, iree_allocator_t allocator, - iree_task_topology_t** out_topology) { - return iree_task_topology_from_physical_cores_with_filter( - iree_task_topology_core_filter_all, 0, max_core_count, allocator, - out_topology); +void iree_task_topology_initialize_from_physical_cores( + iree_host_size_t max_core_count, iree_task_topology_t* out_topology) { + iree_task_topology_initialize_from_physical_cores_with_filter( + iree_task_topology_core_filter_all, 0, max_core_count, out_topology); } // Matches only cores with the uarch as specified in |user_data|. @@ -293,21 +288,23 @@ static bool iree_task_topology_core_filter_uarch( return core->uarch == user_data; } -iree_status_t iree_task_topology_from_physical_cores_with_uarch( +void iree_task_topology_initialize_from_physical_cores_with_uarch( uint32_t cpuinfo_uarch, iree_host_size_t max_core_count, - iree_allocator_t allocator, iree_task_topology_t** out_topology) { - return iree_task_topology_from_physical_cores_with_filter( + iree_task_topology_t* out_topology) { + iree_task_topology_initialize_from_physical_cores_with_filter( iree_task_topology_core_filter_uarch, cpuinfo_uarch, max_core_count, - allocator, out_topology); + out_topology); } -iree_status_t iree_task_topology_from_physical_cores_with_filter( +void iree_task_topology_initialize_from_physical_cores_with_filter( iree_task_topology_core_filter_t filter_fn, uintptr_t filter_fn_data, - iree_host_size_t max_core_count, iree_allocator_t allocator, - iree_task_topology_t** out_topology) { + iree_host_size_t max_core_count, iree_task_topology_t* out_topology) { + if (!iree_task_topology_is_cpuinfo_available()) { + iree_task_topology_initialize_fallback(max_core_count, out_topology); + return; + } + IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_task_topology_ensure_cpuinfo_available()); // Count cores that match the filter. iree_host_size_t core_count = 0; @@ -317,17 +314,15 @@ iree_status_t iree_task_topology_from_physical_cores_with_filter( } core_count = iree_min(core_count, max_core_count); - iree_task_topology_t* topology = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_task_topology_allocate(core_count, allocator, &topology)); + iree_task_topology_initialize(out_topology); // Build each core up to the max allowed. // TODO(benvanik): if our group_count <= core_count/2 then distribute better; // for now we just do a straight-line through (cores 0-N) when instead we may // want to take advantage of L3 cache info (half of groups on one L3 cache, // half of groups on another, etc). 
- topology->group_count = core_count; - for (uint32_t core_i = 0, group_i = 0; group_i < topology->group_count; + out_topology->group_count = core_count; + for (uint32_t core_i = 0, group_i = 0; group_i < out_topology->group_count; ++core_i) { // Rotate the core ID so that we avoid setting the affinity to the calling // thread which we assume is something the user has plans for and doesn't @@ -335,31 +330,29 @@ iree_status_t iree_task_topology_from_physical_cores_with_filter( const struct cpuinfo_core* core = cpuinfo_get_core(iree_task_topology_rotate_from_base_core(core_i)); if (filter_fn(core, filter_fn_data)) { - iree_task_topology_group_initialize_from_core(group_i, core, - &topology->groups[group_i]); + iree_task_topology_group_initialize_from_core( + group_i, core, &out_topology->groups[group_i]); ++group_i; } } - iree_task_topology_fixup_constructive_sharing_masks(topology); - *out_topology = topology; + iree_task_topology_fixup_constructive_sharing_masks(out_topology); IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); } -iree_status_t iree_task_topology_from_unique_l2_cache_groups( - iree_host_size_t max_group_count, iree_allocator_t allocator, - iree_task_topology_t** out_topology) { +void iree_task_topology_initialize_from_unique_l2_cache_groups( + iree_host_size_t max_group_count, iree_task_topology_t* out_topology) { + if (!iree_task_topology_is_cpuinfo_available()) { + iree_task_topology_initialize_fallback(max_group_count, out_topology); + return; + } + IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_task_topology_ensure_cpuinfo_available()); iree_host_size_t cache_count = cpuinfo_get_l2_caches_count(); cache_count = iree_min(cache_count, max_group_count); - iree_task_topology_t* topology = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_task_topology_allocate(cache_count, allocator, &topology)); + iree_task_topology_initialize(out_topology); // TODO(benvanik): iree_task_topology_rotate_from_base_core to offset all of // the selection here (while still preserving the cache groups). May need to @@ -369,19 +362,17 @@ iree_status_t iree_task_topology_from_unique_l2_cache_groups( // TODO(benvanik): if our group_count <= cache_count/2 then distribute better; // we could use l3 cache in addition to ensure we are selecting cores that do // (or do not) share. 
- topology->group_count = cache_count; - for (uint32_t cache_i = 0, group_i = 0; group_i < topology->group_count; + out_topology->group_count = cache_count; + for (uint32_t cache_i = 0, group_i = 0; group_i < out_topology->group_count; ++cache_i) { const struct cpuinfo_cache* cache = cpuinfo_get_l2_cache(cache_i); const struct cpuinfo_core* core = cpuinfo_get_processor(cache->processor_start)->core; - iree_task_topology_group_initialize_from_core(group_i, core, - &topology->groups[group_i]); + iree_task_topology_group_initialize_from_core( + group_i, core, &out_topology->groups[group_i]); ++group_i; } - iree_task_topology_fixup_constructive_sharing_masks(topology); - *out_topology = topology; + iree_task_topology_fixup_constructive_sharing_masks(out_topology); IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); } diff --git a/iree/task/topology.h b/iree/task/topology.h index c07d68f225593..8a06b76a8ba9a 100644 --- a/iree/task/topology.h +++ b/iree/task/topology.h @@ -19,6 +19,7 @@ #include "iree/base/api.h" #include "iree/base/threading.h" +#include "iree/task/tuning.h" #ifdef __cplusplus extern "C" { @@ -59,6 +60,10 @@ typedef struct { iree_task_topology_group_mask_t constructive_sharing_mask; } iree_task_topology_group_t; +// Initializes |out_group| with a |group_index| derived name. +void iree_task_topology_group_initialize(uint8_t group_index, + iree_task_topology_group_t* out_group); + // Task system topology information used to define the workers within an // executor. // @@ -76,20 +81,20 @@ typedef struct { // and attempt to derive some (hopefully) useful task system topology from it. // We can add the more common heuristics over time to the core and leave the // edge cases for applications to construct. -typedef struct iree_task_topology_s iree_task_topology_t; +typedef struct { + iree_host_size_t group_count; + iree_task_topology_group_t groups[IREE_TASK_EXECUTOR_MAX_WORKER_COUNT]; +} iree_task_topology_t; -// Allocates a task topology with at least |group_capacity|. -iree_status_t iree_task_topology_allocate(iree_host_size_t group_capacity, - iree_allocator_t allocator, - iree_task_topology_t** out_topology); +// Initializes an empty task topology. +void iree_task_topology_initialize(iree_task_topology_t* out_topology); -// Frees a topology structure. -void iree_task_topology_free(iree_task_topology_t* topology); +// Deinitializes a topology structure. +void iree_task_topology_deinitialize(iree_task_topology_t* topology); // Parses a serialized topology in string form. iree_status_t iree_task_topology_parse(iree_string_view_t value, - iree_allocator_t allocator, - iree_task_topology_t** out_topology); + iree_task_topology_t* out_topology); // Formats the topology as a string value that can be parsed with // iree_task_topology_parse. @@ -97,6 +102,10 @@ bool iree_task_topology_format(const iree_task_topology_t* topology, iree_host_size_t buffer_capacity, char* buffer, iree_host_size_t* out_buffer_length); +// Returns the group capacity in the topology structure. +iree_host_size_t iree_task_topology_group_capacity( + const iree_task_topology_t* topology); + // Returns the total group count defined by the topology. iree_host_size_t iree_task_topology_group_count( const iree_task_topology_t* topology); @@ -106,51 +115,57 @@ const iree_task_topology_group_t* iree_task_topology_get_group( const iree_task_topology_t* topology, iree_host_size_t group_index); // Pushes a new group onto the topology set. -// The provided group data will be copied into the toplogy structure. 
+// The provided group data will be copied into the topology structure. iree_status_t iree_task_topology_push_group( iree_task_topology_t* topology, const iree_task_topology_group_t* group); -// Allocates a topology with the specified number of groups. +// Initializes a topology with the specified number of groups. // 0 is a valid value, indicating that only donated threads will be used to // perform work. Groups will have no specific affinity and rely on the OS // scheduler to ensure they are distributed in a meaningful way; this generally // works out as threads created within a process are usually rotated across // preferred processors by default. -iree_status_t iree_task_topology_from_group_count( - iree_host_size_t group_count, iree_allocator_t allocator, - iree_task_topology_t** out_topology); +void iree_task_topology_initialize_from_group_count( + iree_host_size_t group_count, iree_task_topology_t* out_topology); -// Allocates a topology with one group for each physical core in the machine. +// Initializes a topology with one group for each physical core in the machine. +// // If detailed cache information is not available this is a decent // approximation that can be used as a fallback. -iree_status_t iree_task_topology_from_physical_cores( - iree_host_size_t max_core_count, iree_allocator_t allocator, - iree_task_topology_t** out_topology); +void iree_task_topology_initialize_from_physical_cores( + iree_host_size_t max_core_count, iree_task_topology_t* out_topology); -// Allocates a topology with one group for each physical core in the machine +// Initializes a topology with one group for each physical core in the machine // with the given microarchitecture specified as a cpuinfo_uarch value. -iree_status_t iree_task_topology_from_physical_cores_with_uarch( +// +// If detailed uarch information is not available this falls back to the same +// behavior as iree_task_topology_initialize_from_physical_cores. +void iree_task_topology_initialize_from_physical_cores_with_uarch( uint32_t cpuinfo_uarch, iree_host_size_t max_core_count, - iree_allocator_t allocator, iree_task_topology_t** out_topology); + iree_task_topology_t* out_topology); // Returns true if the given |core| passes the filter and should be included. // |user_data| is the value passed alongside the filter function. typedef bool (*iree_task_topology_core_filter_t)( const struct cpuinfo_core* core, uintptr_t user_data); -// Allocates a topology with one group for each core that matches |filter_fn|. -iree_status_t iree_task_topology_from_physical_cores_with_filter( +// Initializes a topology with one group for each core that matches |filter_fn|. +// +// If cpuinfo is not available this falls back to the same behavior as +// iree_task_topology_initialize_from_physical_cores. +void iree_task_topology_initialize_from_physical_cores_with_filter( iree_task_topology_core_filter_t filter_fn, uintptr_t filter_fn_data, - iree_host_size_t max_core_count, iree_allocator_t allocator, - iree_task_topology_t** out_topology); + iree_host_size_t max_core_count, iree_task_topology_t* out_topology); -// Allocates a topology with one group for each unique L2 cache group across +// Initializes a topology with one group for each unique L2 cache group across // all available cores. This optimizes for temporal and spatial cache locality // but may suffer from oversubscription if there are other processes trying to // use the same cores. 
-iree_status_t iree_task_topology_from_unique_l2_cache_groups( - iree_host_size_t max_group_count, iree_allocator_t allocator, - iree_task_topology_t** out_topology); +// +// If detailed cache information is not available this falls back to the same +// behavior as iree_task_topology_initialize_from_physical_cores. +void iree_task_topology_initialize_from_unique_l2_cache_groups( + iree_host_size_t max_group_count, iree_task_topology_t* out_topology); // TODO(benvanik): more? or just make users implement as desired? Ideas: // - _from_unique_l2_cache_groups but with a min/max count (N% utilization) diff --git a/iree/task/topology_test.cc b/iree/task/topology_test.cc index 2b7a665376604..40126c29b11da 100644 --- a/iree/task/topology_test.cc +++ b/iree/task/topology_test.cc @@ -14,13 +14,141 @@ #include "iree/task/topology.h" +#include + #include "iree/testing/gtest.h" #include "iree/testing/status_matchers.h" namespace { -TEST(TopologyTest, Any) { - // TODO(benvanik): tests. +using namespace iree::testing::status; + +TEST(TopologyTest, Lifetime) { + iree_task_topology_t topology; + iree_task_topology_initialize(&topology); + EXPECT_GT(iree_task_topology_group_capacity(&topology), 0); + EXPECT_EQ(0, iree_task_topology_group_count(&topology)); + iree_task_topology_deinitialize(&topology); +} + +TEST(TopologyTest, Empty) { + iree_task_topology_t topology; + iree_task_topology_initialize(&topology); + + EXPECT_EQ(0, iree_task_topology_group_count(&topology)); + EXPECT_EQ(NULL, iree_task_topology_get_group(&topology, 0)); + EXPECT_EQ(NULL, iree_task_topology_get_group(&topology, 100)); + + iree_task_topology_deinitialize(&topology); +} + +TEST(TopologyTest, Parsing) { + // TODO(benvanik): implement parsing. +} + +TEST(TopologyTest, Formatting) { + // TODO(benvanik): implement formatting. +} + +TEST(TopologyTest, Construction) { + iree_task_topology_t topology; + iree_task_topology_initialize(&topology); + + EXPECT_EQ(0, iree_task_topology_group_count(&topology)); + + for (iree_host_size_t i = 0; i < 8; ++i) { + iree_task_topology_group_t group; + iree_task_topology_group_initialize(i, &group); + IREE_EXPECT_OK(iree_task_topology_push_group(&topology, &group)); + EXPECT_EQ(i + 1, iree_task_topology_group_count(&topology)); + } + EXPECT_EQ(8, iree_task_topology_group_count(&topology)); + + for (iree_host_size_t i = 0; i < 8; ++i) { + const iree_task_topology_group_t* group = + iree_task_topology_get_group(&topology, i); + EXPECT_EQ(i, group->group_index); + } + + iree_task_topology_deinitialize(&topology); +} + +TEST(TopologyTest, MaxCapacity) { + iree_task_topology_t topology; + iree_task_topology_initialize(&topology); + + EXPECT_EQ(0, iree_task_topology_group_count(&topology)); + + // Fill up to capacity. + for (iree_host_size_t i = 0; i < iree_task_topology_group_capacity(&topology); + ++i) { + iree_task_topology_group_t group; + iree_task_topology_group_initialize(i, &group); + IREE_EXPECT_OK(iree_task_topology_push_group(&topology, &group)); + EXPECT_EQ(i + 1, iree_task_topology_group_count(&topology)); + } + EXPECT_EQ(iree_task_topology_group_capacity(&topology), + iree_task_topology_group_count(&topology)); + + // Try adding one more - it should fail because we are at capacity. 
+ iree_task_topology_group_t extra_group; + iree_task_topology_group_initialize(UINT8_MAX, &extra_group); + iree_status_t status = iree_task_topology_push_group(&topology, &extra_group); + EXPECT_TRUE(iree_status_is_resource_exhausted(status)); + iree_status_ignore(status); + + // Confirm that the only groups we have are the valid ones we added above. + for (iree_host_size_t i = 0; i < 8; ++i) { + const iree_task_topology_group_t* group = + iree_task_topology_get_group(&topology, i); + EXPECT_EQ(i, group->group_index); + } + + iree_task_topology_deinitialize(&topology); +} + +TEST(TopologyTest, FromGroupCount) { + static constexpr iree_host_size_t kGroupCount = 4; + iree_task_topology_t topology; + iree_task_topology_initialize(&topology); + + iree_task_topology_initialize_from_group_count(kGroupCount, &topology); + EXPECT_LE(iree_task_topology_group_count(&topology), + iree_task_topology_group_capacity(&topology)); + EXPECT_EQ(iree_task_topology_group_count(&topology), kGroupCount); + for (iree_host_size_t i = 0; i < kGroupCount; ++i) { + const iree_task_topology_group_t* group = + iree_task_topology_get_group(&topology, i); + EXPECT_EQ(i, group->group_index); + } + + iree_task_topology_deinitialize(&topology); +} + +// Verifies only that the |topology| is usable. +// If we actually checked the contents here then we'd just be validating that +// cpuinfo was working and the tests would become machine-dependent. +static void EnsureTopologyValid(iree_host_size_t max_group_count, + iree_task_topology_t* topology) { + EXPECT_LE(iree_task_topology_group_count(topology), + iree_task_topology_group_capacity(topology)); + EXPECT_LE(iree_task_topology_group_count(topology), max_group_count); + EXPECT_GE(iree_task_topology_group_count(topology), 1); + for (iree_host_size_t i = 0; i < iree_task_topology_group_count(topology); + ++i) { + const iree_task_topology_group_t* group = + iree_task_topology_get_group(topology, i); + EXPECT_EQ(i, group->group_index); + } +} + +TEST(TopologyTest, FromPhysicalCores) { + static constexpr iree_host_size_t kMaxGroupCount = 4; + iree_task_topology_t topology; + iree_task_topology_initialize(&topology); + iree_task_topology_initialize_from_physical_cores(kMaxGroupCount, &topology); + EnsureTopologyValid(kMaxGroupCount, &topology); + iree_task_topology_deinitialize(&topology); } } // namespace diff --git a/iree/task/worker.c b/iree/task/worker.c index a2eecb206ad94..a5572cbfd393c 100644 --- a/iree/task/worker.c +++ b/iree/task/worker.c @@ -47,7 +47,7 @@ iree_status_t iree_task_worker_initialize( initial_state = IREE_TASK_WORKER_STATE_SUSPENDED; } iree_atomic_store_int32(&out_worker->state, initial_state, - iree_memory_order_relaxed); + iree_memory_order_seq_cst); iree_notification_initialize(&out_worker->wake_notification); iree_notification_initialize(&out_worker->state_notification); @@ -125,7 +125,7 @@ void iree_task_worker_request_exit(iree_task_worker_t* worker) { case IREE_TASK_WORKER_STATE_ZOMBIE: // Worker already exited; reset state to ZOMBIE. iree_atomic_store_int32(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, - iree_memory_order_relaxed); + iree_memory_order_seq_cst); break; default: // Worker now set to EXITING and should exit soon. @@ -166,8 +166,9 @@ iree_task_t* iree_task_worker_try_steal_task(iree_task_worker_t* worker, // Executes a task on a worker. // Only task types that are scheduled to workers are handled; all others must be // handled by the coordinator during scheduling. 
-static iree_status_t iree_task_worker_execute(iree_task_worker_t* worker, - iree_task_t* task) { +static iree_status_t iree_task_worker_execute( + iree_task_worker_t* worker, iree_task_t* task, + iree_task_submission_t* pending_submission) { // Execute the task and resolve the task and gather any tasks that are now // ready for submission to the executor. They'll be scheduled the next time // the coordinator runs. @@ -176,22 +177,20 @@ static iree_status_t iree_task_worker_execute(iree_task_worker_t* worker, // BFS behavior at the cost of the additional merge overhead - it's probably // worth it? // TODO(benvanik): handle partial tasks and re-queuing. - iree_task_submission_t pending_submission; - iree_task_submission_initialize(&pending_submission); switch (task->type) { case IREE_TASK_TYPE_CALL: { IREE_RETURN_IF_ERROR( - iree_task_call_execute((iree_task_call_t*)task, &pending_submission)); + iree_task_call_execute((iree_task_call_t*)task, pending_submission)); break; } case IREE_TASK_TYPE_DISPATCH_SLICE: { IREE_RETURN_IF_ERROR(iree_task_dispatch_slice_execute( - (iree_task_dispatch_slice_t*)task, &pending_submission)); + (iree_task_dispatch_slice_t*)task, pending_submission)); break; } case IREE_TASK_TYPE_DISPATCH_SHARD: { IREE_RETURN_IF_ERROR(iree_task_dispatch_shard_execute( - (iree_task_dispatch_shard_t*)task, &pending_submission)); + (iree_task_dispatch_shard_t*)task, pending_submission)); break; } default: @@ -202,16 +201,14 @@ static iree_status_t iree_task_worker_execute(iree_task_worker_t* worker, // NOTE: task is invalidated here! task = NULL; - if (!iree_task_submission_is_empty(&pending_submission)) { - iree_task_executor_merge_submission(worker->executor, &pending_submission); - } return iree_ok_status(); } // Pumps the worker thread once, processing a single task. // Returns true if pumping should continue as there are more tasks remaining or // false if the caller should wait for more tasks to be posted. -static bool iree_task_worker_pump_once(iree_task_worker_t* worker) { +static bool iree_task_worker_pump_once( + iree_task_worker_t* worker, iree_task_submission_t* pending_submission) { IREE_TRACE_ZONE_BEGIN(z0); // Check the local work queue for any work we know we should start @@ -229,8 +226,8 @@ static bool iree_task_worker_pump_once(iree_task_worker_t* worker) { // first place (large uneven workloads for various workers, bad distribution // in the face of heterogenous multi-core architectures where some workers // complete tasks faster than others, etc). - task = iree_task_queue_append_from_lifo_slist(&worker->local_task_queue, - &worker->mailbox_slist); + task = iree_task_queue_flush_from_lifo_slist(&worker->local_task_queue, + &worker->mailbox_slist); } // If we ran out of work assigned to this specific worker try to steal some @@ -252,7 +249,8 @@ static bool iree_task_worker_pump_once(iree_task_worker_t* worker) { // Execute the task (may call out to arbitrary user code and may submit more // tasks for execution). - iree_status_t status = iree_task_worker_execute(worker, task); + iree_status_t status = + iree_task_worker_execute(worker, task, pending_submission); // TODO(#4026): propagate failure to task scope. // We currently drop the error on the floor here; that's because the error @@ -275,14 +273,6 @@ static bool iree_task_worker_pump_once(iree_task_worker_t* worker) { static void iree_task_worker_pump_until_exit(iree_task_worker_t* worker) { // Pump the thread loop to process more tasks. while (true) { - // Check state to see if we've been asked to exit. 
- if (iree_atomic_load_int32(&worker->state, iree_memory_order_relaxed) == - IREE_TASK_WORKER_STATE_EXITING) { - // Thread exit requested - cancel pumping. - // TODO(benvanik): complete tasks before exiting? - break; - } - // If we fail to find any work to do we'll wait at the end of this loop. // In order to not miss any work that is enqueued after we've already // checked a particular source we use an interruptible wait token that @@ -292,12 +282,38 @@ static void iree_task_worker_pump_until_exit(iree_task_worker_t* worker) { iree_notification_prepare_wait(&worker->wake_notification); iree_atomic_task_affinity_set_fetch_and(&worker->executor->worker_idle_mask, ~worker->worker_bit, - iree_memory_order_relaxed); + iree_memory_order_seq_cst); + + // Check state to see if we've been asked to exit. + if (iree_atomic_load_int32(&worker->state, iree_memory_order_seq_cst) == + IREE_TASK_WORKER_STATE_EXITING) { + // Thread exit requested - cancel pumping. + iree_notification_cancel_wait(&worker->wake_notification); + // TODO(benvanik): complete tasks before exiting? + break; + } + + iree_task_submission_t pending_submission; + iree_task_submission_initialize(&pending_submission); - while (iree_task_worker_pump_once(worker)) { + while (iree_task_worker_pump_once(worker, &pending_submission)) { // All work done ^, which will return false when the worker should wait. } + bool schedule_dirty = false; + if (!iree_task_submission_is_empty(&pending_submission)) { + iree_task_executor_merge_submission(worker->executor, + &pending_submission); + schedule_dirty = true; + } + + // We've finished all the work we have scheduled so set our idle flag. + // This ensures that if any other thread comes in and wants to give us + // work we will properly coordinate/wake below. + iree_atomic_task_affinity_set_fetch_or(&worker->executor->worker_idle_mask, + worker->worker_bit, + iree_memory_order_seq_cst); + // When we encounter a complete lack of work we can self-nominate to check // the global work queue and distribute work to other threads. Only one // coordinator can be running at a time so we also ensure that if another @@ -312,15 +328,13 @@ static void iree_task_worker_pump_until_exit(iree_task_worker_t* worker) { // If nothing has been enqueued since we started this loop (so even // coordination didn't find anything) we go idle. Otherwise we fall // through and try the loop again. - if (!iree_task_queue_is_empty(&worker->local_task_queue)) { + if (schedule_dirty || + !iree_task_queue_is_empty(&worker->local_task_queue)) { // Have more work to do; loop around to try another pump.
iree_notification_cancel_wait(&worker->wake_notification); } else { IREE_TRACE_ZONE_BEGIN_NAMED(z_wait, "iree_task_worker_main_pump_wake_wait"); - iree_atomic_task_affinity_set_fetch_or( - &worker->executor->worker_idle_mask, worker->worker_bit, - iree_memory_order_relaxed); iree_notification_commit_wait(&worker->wake_notification, wait_token); IREE_TRACE_ZONE_END(z_wait); } @@ -353,7 +367,7 @@ static int iree_task_worker_main(iree_task_worker_t* worker) { IREE_TRACE_ZONE_END(thread_zone); iree_atomic_store_int32(&worker->state, IREE_TASK_WORKER_STATE_ZOMBIE, - iree_memory_order_release); + iree_memory_order_seq_cst); iree_notification_post(&worker->state_notification, IREE_ALL_WAITERS); return 0; } diff --git a/iree/test/e2e/regression/dynamic_dot_general.mlir b/iree/test/e2e/regression/dynamic_dot_general.mlir index e79fac328933f..3d5fb99b63066 100644 --- a/iree/test/e2e/regression/dynamic_dot_general.mlir +++ b/iree/test/e2e/regression/dynamic_dot_general.mlir @@ -1,4 +1,4 @@ -// RUN: iree-run-mlir -export-all %s -iree-hal-target-backends=vmla -function-input="2x2xf32=[[1.0, 0.0], [0.0, 1.0]]" -function-input="2x3xf32=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]" -function-input="2x2x2xf32=[[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]" -function-input="2x2x3xf32=[[[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]" | IreeFileCheck %s +// RUN: iree-run-mlir -export-all %s -iree-hal-target-backends=vmla -function-input="2x2xf32=[[1.0,0.0],[0.0,1.0]]" -function-input="2x3xf32=[[1.0,2.0,3.0],[4.0,5.0,6.0]]" -function-input="2x2x2xf32=[[[1.0,0.0],[0.0,1.0]],[[2.0,0.0],[0.0,2.0]]]" -function-input="2x2x3xf32=[[[1.5,2.5,3.5],[4.5,5.5,6.5]],[[1.0,2.0,3.0],[4.0,5.0,6.0]]]" | IreeFileCheck %s // TODO(silvasean): Extend xla_ops directory test infra to support // testing dynamic shapes.
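The worker.c changes above restructure submission handling: instead of each executed task merging its newly-ready tasks into the executor (one contended merge per task), the pump loop now threads a single pending_submission through every task it executes and merges once per wake cycle, with schedule_dirty forcing another pump pass when a merge happened. Below is a minimal C sketch of that accumulate-then-merge pattern; the types and helpers here (task_t, submission_t, worker_try_pop_task, task_execute, executor_merge_submission) are simplified stand-ins for illustration, not the real IREE task API.

#include <stdbool.h>
#include <stddef.h>

// Simplified stand-ins for iree_task_t / iree_task_submission_t.
typedef struct task_t { struct task_t* next; } task_t;
typedef struct { task_t* head; } submission_t;

static void submission_initialize(submission_t* s) { s->head = NULL; }
static bool submission_is_empty(const submission_t* s) { return s->head == NULL; }

// Hypothetical worker/executor hooks (declared only; not the IREE API).
extern task_t* worker_try_pop_task(void);
extern void task_execute(task_t* task, submission_t* out_ready);
extern void executor_merge_submission(submission_t* pending);

// One wake cycle: drain local work, then merge follow-up work exactly once.
static void worker_pump_cycle(void) {
  submission_t pending;
  submission_initialize(&pending);
  task_t* task = NULL;
  while ((task = worker_try_pop_task()) != NULL) {
    // Each task appends newly-ready tasks to |pending| rather than
    // touching the shared executor queue itself.
    task_execute(task, &pending);
  }
  if (!submission_is_empty(&pending)) {
    // Single cross-thread merge per cycle; this corresponds to the
    // schedule_dirty path in the diff above.
    executor_merge_submission(&pending);
  }
}

The payoff is fewer contended operations: the shared executor structures are touched at most once per wake cycle rather than once per executed task, while schedule_dirty ensures the worker does not go idle with unscheduled work in flight.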
diff --git a/iree/test/e2e/regression/dynamic_torch_index_select_high_rank.mlir b/iree/test/e2e/regression/dynamic_torch_index_select_high_rank.mlir index 8a02faa6d74f8..9dfc3f64b2472 100644 --- a/iree/test/e2e/regression/dynamic_torch_index_select_high_rank.mlir +++ b/iree/test/e2e/regression/dynamic_torch_index_select_high_rank.mlir @@ -1,4 +1,4 @@ -// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=dylib-llvm-aot -function-input="2x2xi32=[6, 7] [8, 9]" -function-input="2x2x2x2xi32=[[[0, 1] [1, 0]] [[0, 0] [1, 1]]] [[[1, 1] [0, 0]] [[0, 1] [1, 0]]]" | IreeFileCheck %s) +// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-run-mlir %s -iree-hal-target-backends=dylib-llvm-aot -function-input="2x2xi32=[6,7][8,9]" -function-input="2x2x2x2xi32=[[[0,1][1,0]][[0,0][1,1]]][[[1,1][0,0]][[0,1][1,0]]]" | IreeFileCheck %s) // CHECK-LABEL: EXEC @torch_index_select1 func @torch_index_select1(%arg0: tensor<?x?xi32>, %arg1: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> attributes {iree.module.export} { diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD index a4abcb57a0b59..2693834edf59b 100644 --- a/iree/test/e2e/xla_ops/BUILD +++ b/iree/test/e2e/xla_ops/BUILD @@ -33,45 +33,6 @@ iree_check_single_backend_test_suite( target_backend = "vmla", ) -iree_check_single_backend_test_suite( - name = "check_metal-spirv_metal", - srcs = [ - "abs.mlir", - "add.mlir", - "broadcast.mlir", - "broadcast_add.mlir", - "broadcast_in_dim.mlir", - "clamp.mlir", - "compare.mlir", - "constant.mlir", - "convert.mlir", - "cosine.mlir", - "divide.mlir", - "exponential.mlir", - "gather.mlir", - "log.mlir", - "log_plus_one.mlir", - "maximum.mlir", - "minimum.mlir", - "multiply.mlir", - "negate.mlir", - "remainder.mlir", - "reshape.mlir", - "rsqrt.mlir", - "select.mlir", - "sine.mlir", - "slice.mlir", - "sqrt.mlir", - "subtract.mlir", - "tanh.mlir", - "torch_index_select.mlir", - "transpose.mlir", - "while.mlir", - ], - driver = "metal", - target_backend = "metal-spirv", -) - iree_check_single_backend_test_suite( name = "check_vulkan-spirv_vulkan", srcs = [ @@ -177,7 +138,6 @@ test_suite( name = "check", tests = [ ":check_dylib-llvm-aot_dylib", - ":check_metal-spirv_metal", ":check_vmla_vmla", ":check_vulkan-spirv_vulkan", ], diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt index 0c8978dbc6b5f..65c4d936a8550 100644 --- a/iree/test/e2e/xla_ops/CMakeLists.txt +++ b/iree/test/e2e/xla_ops/CMakeLists.txt @@ -26,47 +26,6 @@ iree_check_single_backend_test_suite( "vmla" ) -iree_check_single_backend_test_suite( - NAME - check_metal-spirv_metal - SRCS - "abs.mlir" - "add.mlir" - "broadcast.mlir" - "broadcast_add.mlir" - "broadcast_in_dim.mlir" - "clamp.mlir" - "compare.mlir" - "constant.mlir" - "convert.mlir" - "cosine.mlir" - "divide.mlir" - "exponential.mlir" - "gather.mlir" - "log.mlir" - "log_plus_one.mlir" - "maximum.mlir" - "minimum.mlir" - "multiply.mlir" - "negate.mlir" - "remainder.mlir" - "reshape.mlir" - "rsqrt.mlir" - "select.mlir" - "sine.mlir" - "slice.mlir" - "sqrt.mlir" - "subtract.mlir" - "tanh.mlir" - "torch_index_select.mlir" - "transpose.mlir" - "while.mlir" - TARGET_BACKEND - "metal-spirv" - DRIVER - "metal" -) - iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan diff --git a/iree/test/e2e/xla_ops/partial/BUILD b/iree/test/e2e/xla_ops/partial/BUILD index 9adea0cfe0b99..3228fcdd59ec1 100644 --- a/iree/test/e2e/xla_ops/partial/BUILD +++ b/iree/test/e2e/xla_ops/partial/BUILD @@ -38,14 +38,6 @@ iree_check_single_backend_test_suite( target_backend =
"vmla", ) -iree_check_single_backend_test_suite( - name = "check_metal-spirv_metal", - srcs = [ - ], - driver = "metal", - target_backend = "metal-spirv", -) - iree_check_single_backend_test_suite( name = "check_vulkan-spirv_vulkan", srcs = [ @@ -66,7 +58,6 @@ test_suite( name = "check", tests = [ ":check_dylib-llvm-aot_dylib", - ":check_metal-spirv_metal", ":check_vmla_vmla", ":check_vulkan-spirv_vulkan", ], diff --git a/iree/test/e2e/xla_ops/partial/CMakeLists.txt b/iree/test/e2e/xla_ops/partial/CMakeLists.txt index 7cc0f007b7511..43b2bba9f496f 100644 --- a/iree/test/e2e/xla_ops/partial/CMakeLists.txt +++ b/iree/test/e2e/xla_ops/partial/CMakeLists.txt @@ -26,15 +26,6 @@ iree_check_single_backend_test_suite( "vmla" ) -iree_check_single_backend_test_suite( - NAME - check_metal-spirv_metal - TARGET_BACKEND - "metal-spirv" - DRIVER - "metal" -) - iree_check_single_backend_test_suite( NAME check_vulkan-spirv_vulkan diff --git a/iree/testing/gtest_main.cc b/iree/testing/gtest_main.cc index 1e8ff363f7588..b8f45825b517b 100644 --- a/iree/testing/gtest_main.cc +++ b/iree/testing/gtest_main.cc @@ -16,8 +16,8 @@ #include "iree/testing/gtest.h" extern "C" int main(int argc, char** argv) { - iree_flags_parse_checked(&argc, &argv); ::testing::InitGoogleTest(&argc, argv); + iree_flags_parse_checked(&argc, &argv); return RUN_ALL_TESTS(); } diff --git a/iree/testing/status_matchers.h b/iree/testing/status_matchers.h index b0052bbfa7952..d68500f263f1a 100644 --- a/iree/testing/status_matchers.h +++ b/iree/testing/status_matchers.h @@ -327,9 +327,9 @@ inline internal::IsOkMatcherGenerator IsOk() { // Macros for testing the results of functions that return iree::Status or // iree::StatusOr (for any type T). #define IREE_EXPECT_OK(rexpr) \ - EXPECT_THAT(rexpr, ::iree::testing::status::IsOk()) + EXPECT_THAT(rexpr, ::iree::testing::status::StatusIs(::iree::StatusCode::kOk)) #define IREE_ASSERT_OK(rexpr) \ - ASSERT_THAT(rexpr, ::iree::testing::status::IsOk()) + ASSERT_THAT(rexpr, ::iree::testing::status::StatusIs(::iree::StatusCode::kOk)) #define IREE_EXPECT_STATUS_IS(expected_code, expr) \ EXPECT_THAT(expr, ::iree::testing::status::StatusIs( \ static_cast<::iree::StatusCode>(expected_code))) diff --git a/iree/testing/vulkan/CMakeLists.txt b/iree/testing/vulkan/CMakeLists.txt index 8b650aa936d5c..263ff177b64fa 100644 --- a/iree/testing/vulkan/CMakeLists.txt +++ b/iree/testing/vulkan/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-if(NOT ${IREE_HAL_DRIVER_VULKAN} OR NOT ${IREE_BUILD_SAMPLES}) +if(NOT "${IREE_HAL_DRIVER_VULKAN}" OR NOT "${IREE_BUILD_SAMPLES}") return() endif() @@ -51,7 +51,7 @@ iree_cc_library( imgui::imgui iree::base::api iree::base::logging - iree::hal::vulkan::api + iree::hal::vulkan SDL2::SDL2 Vulkan::Vulkan ) diff --git a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc index 28bccd044a546..1d20bf2bf8e82 100644 --- a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc +++ b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc @@ -168,9 +168,8 @@ int iree::IreeMain(int argc, char** argv) { // Setup Vulkan iree_hal_vulkan_features_t iree_vulkan_features = static_cast<iree_hal_vulkan_features_t>( - IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS | - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS | - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS); + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS | + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features); std::vector<const char*> extensions = GetInstanceExtensions(window, iree_vulkan_features); @@ -275,28 +274,32 @@ int iree::IreeMain(int argc, char** argv) { // Load symbols from our static `vkGetInstanceProcAddr` for IREE to use. iree_hal_vulkan_syms_t* iree_vk_syms = nullptr; IREE_CHECK_OK(iree_hal_vulkan_syms_create( - reinterpret_cast<void*>(&vkGetInstanceProcAddr), &iree_vk_syms)); + reinterpret_cast<void*>(&vkGetInstanceProcAddr), iree_allocator_system(), + &iree_vk_syms)); // Create the driver sharing our VkInstance. iree_hal_driver_t* iree_vk_driver = nullptr; - iree_hal_vulkan_driver_options_t options; - options.api_version = VK_API_VERSION_1_0; - options.features = static_cast<iree_hal_vulkan_features_t>( - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS | - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS); + iree_string_view_t driver_identifier = iree_make_cstring_view("vulkan"); + iree_hal_vulkan_driver_options_t driver_options; + driver_options.api_version = VK_API_VERSION_1_0; + driver_options.requested_features = static_cast<iree_hal_vulkan_features_t>( + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance( - options, iree_vk_syms, g_Instance, &iree_vk_driver)); + driver_identifier, &driver_options, iree_vk_syms, g_Instance, + iree_allocator_system(), &iree_vk_driver)); // Create a device sharing our VkDevice and queue. This makes capturing with // vendor tools easier because we will have sync compute residing in the // rendered frame. + iree_string_view_t device_identifier = iree_make_cstring_view("vulkan"); iree_hal_vulkan_queue_set_t compute_queue_set; compute_queue_set.queue_family_index = g_QueueFamily; compute_queue_set.queue_indices = 1 << 0; iree_hal_vulkan_queue_set_t transfer_queue_set; transfer_queue_set.queue_indices = 0; iree_hal_device_t* iree_vk_device = nullptr; - IREE_CHECK_OK(iree_hal_vulkan_driver_wrap_device( - iree_vk_driver, g_PhysicalDevice, g_Device, compute_queue_set, - transfer_queue_set, &iree_vk_device)); + IREE_CHECK_OK(iree_hal_vulkan_wrap_device( + device_identifier, &driver_options.device_options, iree_vk_syms, + g_Instance, g_PhysicalDevice, g_Device, &compute_queue_set, + &transfer_queue_set, iree_allocator_system(), &iree_vk_device)); // Create a HAL module using the HAL device.
iree_vm_module_t* hal_module = nullptr; IREE_CHECK_OK(iree_hal_module_create(iree_vk_device, iree_allocator_system(), diff --git a/iree/testing/vulkan/vulkan_gui_util.cc b/iree/testing/vulkan/vulkan_gui_util.cc index c633c2cb75aa9..9eb7b57119627 100644 --- a/iree/testing/vulkan/vulkan_gui_util.cc +++ b/iree/testing/vulkan/vulkan_gui_util.cc @@ -35,11 +35,13 @@ std::vector<const char*> GetIreeLayers( iree_hal_vulkan_extensibility_set_t extensibility_set, iree_hal_vulkan_features_t features) { iree_host_size_t required_count; - iree_hal_vulkan_get_layers(extensibility_set, features, 0, NULL, - &required_count); + iree_hal_vulkan_query_extensibility_set( + features, extensibility_set, /*string_capacity=*/0, + /*out_string_values=*/NULL, &required_count); std::vector<const char*> layers(required_count); - iree_hal_vulkan_get_layers(extensibility_set, features, layers.size(), - layers.data(), &required_count); + iree_hal_vulkan_query_extensibility_set(features, extensibility_set, + layers.size(), layers.data(), + &required_count); return layers; } @@ -49,11 +51,13 @@ std::vector<const char*> GetIreeExtensions( iree_hal_vulkan_extensibility_set_t extensibility_set, iree_hal_vulkan_features_t features) { iree_host_size_t required_count; - iree_hal_vulkan_get_extensions(extensibility_set, features, 0, NULL, - &required_count); + iree_hal_vulkan_query_extensibility_set( + features, extensibility_set, /*string_capacity=*/0, + /*out_string_values=*/NULL, &required_count); std::vector<const char*> extensions(required_count); - iree_hal_vulkan_get_extensions(extensibility_set, features, extensions.size(), - extensions.data(), &required_count); + iree_hal_vulkan_query_extensibility_set(features, extensibility_set, + extensions.size(), extensions.data(), + &required_count); return extensions; } @@ -61,10 +65,12 @@ std::vector<const char*> GetIreeExtensions( // |vulkan_features|. std::vector<const char*> GetDeviceExtensions( iree_hal_vulkan_features_t vulkan_features) { - std::vector<const char*> iree_required_extensions = - GetIreeExtensions(IREE_HAL_VULKAN_DEVICE_REQUIRED, vulkan_features); - std::vector<const char*> iree_optional_extensions = - GetIreeExtensions(IREE_HAL_VULKAN_DEVICE_OPTIONAL, vulkan_features); + std::vector<const char*> iree_required_extensions = GetIreeExtensions( + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED, + vulkan_features); + std::vector<const char*> iree_optional_extensions = GetIreeExtensions( + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, + vulkan_features); // Merge extensions lists, including optional and required for simplicity. std::set<const char*> ext_set; @@ -82,10 +88,10 @@ std::vector<const char*> GetDeviceExtensions( std::vector<const char*> GetInstanceLayers( iree_hal_vulkan_features_t vulkan_features) { // Query the layers that IREE wants / needs. - std::vector<const char*> required_layers = - GetIreeLayers(IREE_HAL_VULKAN_INSTANCE_REQUIRED, vulkan_features); - std::vector<const char*> optional_layers = - GetIreeLayers(IREE_HAL_VULKAN_INSTANCE_OPTIONAL, vulkan_features); + std::vector<const char*> required_layers = GetIreeLayers( + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED, vulkan_features); + std::vector<const char*> optional_layers = GetIreeLayers( + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, vulkan_features); // Query the layers that are available on the Vulkan ICD.
uint32_t layer_property_count = 0; @@ -131,10 +137,12 @@ std::vector<const char*> GetInstanceExtensions( SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count, sdl_extensions.data()); - std::vector<const char*> iree_required_extensions = - GetIreeExtensions(IREE_HAL_VULKAN_INSTANCE_REQUIRED, vulkan_features); - std::vector<const char*> iree_optional_extensions = - GetIreeExtensions(IREE_HAL_VULKAN_INSTANCE_OPTIONAL, vulkan_features); + std::vector<const char*> iree_required_extensions = GetIreeExtensions( + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED, + vulkan_features); + std::vector<const char*> iree_optional_extensions = GetIreeExtensions( + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL, + vulkan_features); // Merge extensions lists, including optional and required for simplicity. std::set<const char*> ext_set; diff --git a/iree/tools/BUILD b/iree/tools/BUILD index ee866808e54d1..9864340dcbf3e 100644 --- a/iree/tools/BUILD +++ b/iree/tools/BUILD @@ -23,7 +23,6 @@ package( exports_files([ "run_lit.sh", - "sanitizer_suppressions.txt", ]) cc_binary( diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt index 5a49475f30647..d35e3c4a23930 100644 --- a/iree/tools/CMakeLists.txt +++ b/iree/tools/CMakeLists.txt @@ -21,19 +21,19 @@ add_subdirectory(utils) # Enable compiler targets based on options. set(IREE_COMPILER_TARGETS "") set(IREE_COMPILER_TARGET_COPTS "") -if(${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}) +if("${IREE_TARGET_BACKEND_DYLIB-LLVM-AOT}") list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::LLVM) list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_LLVMAOT_TARGET") endif() -if(${IREE_TARGET_BACKEND_METAL-SPIRV}) +if("${IREE_TARGET_BACKEND_METAL-SPIRV}") list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::MetalSPIRV) list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_METALSPIRV_TARGET") endif() -if(${IREE_TARGET_BACKEND_VMLA}) +if("${IREE_TARGET_BACKEND_VMLA}") list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VMLA) list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VMLA_TARGET") endif() -if(${IREE_TARGET_BACKEND_VULKAN-SPIRV}) +if("${IREE_TARGET_BACKEND_VULKAN-SPIRV}") list(APPEND IREE_COMPILER_TARGETS iree::compiler::Dialect::HAL::Target::VulkanSPIRV) list(APPEND IREE_COMPILER_TARGET_COPTS "-DIREE_HAVE_VULKANSPIRV_TARGET") endif() diff --git a/iree/tools/sanitizer_suppressions.txt b/iree/tools/sanitizer_suppressions.txt deleted file mode 100644 index e8c20b9300d38..0000000000000 --- a/iree/tools/sanitizer_suppressions.txt +++ /dev/null @@ -1 +0,0 @@ -leak:libGLX_nvidia.so diff --git a/iree/tools/test/iree-benchmark-module.mlir b/iree/tools/test/iree-benchmark-module.mlir new file mode 100644 index 0000000000000..a75b6394a8bdf --- /dev/null +++ b/iree/tools/test/iree-benchmark-module.mlir @@ -0,0 +1,9 @@ +// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vmla --entry_function=abs --function_inputs="i32=-2" | IreeFileCheck %s +// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vulkan --entry_function=abs --function_inputs="i32=-2" | IreeFileCheck %s) +// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=dylib-llvm-aot -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=dylib --entry_function=abs --function_inputs="i32=-2" | IreeFileCheck %s) + +// CHECK-LABEL: BM_abs +func @abs(%input : tensor<i32>) -> (tensor<i32>)
attributes { iree.module.export } { + %result = "mhlo.abs"(%input) : (tensor<i32>) -> tensor<i32> + return %result : tensor<i32> +} diff --git a/iree/tools/test/iree-run-mlir.mlir b/iree/tools/test/iree-run-mlir.mlir new file mode 100644 index 0000000000000..a79dc6860b1d5 --- /dev/null +++ b/iree/tools/test/iree-run-mlir.mlir @@ -0,0 +1,10 @@ +// RUN: (iree-run-mlir --iree-hal-target-backends=vmla --function-input="i32=-2" %s) | IreeFileCheck %s +// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv --function-input="i32=-2" %s | IreeFileCheck %s) +// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=dylib-llvm-aot --function-input="i32=-2" %s | IreeFileCheck %s) + +// CHECK-LABEL: EXEC @abs +func @abs(%input : tensor<i32>) -> (tensor<i32>) attributes { iree.module.export } { + %result = "mhlo.abs"(%input) : (tensor<i32>) -> tensor<i32> + return %result : tensor<i32> +} +// CHECK: i32=2 diff --git a/iree/tools/test/iree-run-module.mlir b/iree/tools/test/iree-run-module.mlir new file mode 100644 index 0000000000000..a97daf912e636 --- /dev/null +++ b/iree/tools/test/iree-run-module.mlir @@ -0,0 +1,10 @@ +// RUN: (iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vmla --entry_function=abs --function_inputs="i32=-2") | IreeFileCheck %s +// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vulkan --entry_function=abs --function_inputs="i32=-2") | IreeFileCheck %s) + // RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=dylib-llvm-aot -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=dylib --entry_function=abs --function_inputs="i32=-2") | IreeFileCheck %s) + +// CHECK-LABEL: EXEC @abs +func @abs(%input : tensor<i32>) -> (tensor<i32>) attributes { iree.module.export } { + %result = "mhlo.abs"(%input) : (tensor<i32>) -> tensor<i32> + return %result : tensor<i32> +} +// CHECK: i32=2 diff --git a/iree/tools/test/simple.mlir b/iree/tools/test/simple.mlir deleted file mode 100644 index e81c719e8391f..0000000000000 --- a/iree/tools/test/simple.mlir +++ /dev/null @@ -1,22 +0,0 @@ -// iree-run-module -// RUN: (iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vmla --entry_function=abs --function_inputs="i32=-2") | IreeFileCheck %s -// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=vulkan --entry_function=abs --function_inputs="i32=-2") | IreeFileCheck %s) -// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || ((iree-translate --iree-hal-target-backends=dylib-llvm-aot -iree-mlir-to-vm-bytecode-module %s | iree-run-module --driver=dylib --entry_function=abs --function_inputs="i32=-2") | IreeFileCheck %s) - -// iree-benchmark-module -// RUN: iree-translate --iree-hal-target-backends=vmla -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vmla --entry_function=abs --function_inputs="i32=-2" | IreeFileCheck %s --check-prefix=BENCHMARK -// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=vulkan-spirv -iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=vulkan --entry_function=abs --function_inputs="i32=-2" | IreeFileCheck %s --check-prefix=BENCHMARK) -// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-translate --iree-hal-target-backends=dylib-llvm-aot
-iree-mlir-to-vm-bytecode-module %s | iree-benchmark-module --driver=dylib --entry_function=abs --function_inputs="i32=-2" | IreeFileCheck %s --check-prefix=BENCHMARK) - -// iree-run-mlir -// RUN: (iree-run-mlir --iree-hal-target-backends=vmla --function-input="i32=-2" %s) | IreeFileCheck %s -// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv --function-input="i32=-2" %s | IreeFileCheck %s) -// RUN: [[ $IREE_LLVMAOT_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=dylib-llvm-aot --function-input="i32=-2" %s | IreeFileCheck %s) - -// BENCHMARK-LABEL: BM_abs -// CHECK-LABEL: EXEC @abs -func @abs(%input : tensor<i32>) -> (tensor<i32>) attributes { iree.module.export } { - %result = "mhlo.abs"(%input) : (tensor<i32>) -> tensor<i32> - return %result : tensor<i32> -} -// CHECK: i32=2 diff --git a/iree/vm/native_module_cc.h b/iree/vm/native_module_cc.h index 32c0d56ed64cd..d902aeb54b0e0 100644 --- a/iree/vm/native_module_cc.h +++ b/iree/vm/native_module_cc.h @@ -235,12 +235,15 @@ class NativeModule { /*frame_cleanup_fn=*/nullptr, &callee_frame)); auto* state = FromStatePointer(callee_frame->module_state); - IREE_RETURN_IF_ERROR(info.call(info.ptr, state, stack, call, out_result), - "while invoking C++ function %s.%.*s", module->name_, - (int)info.name.size, info.name.data); + iree_status_t status = info.call(info.ptr, state, stack, call, out_result); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) { + status = iree_status_annotate_f( + status, "while invoking C++ function %s.%.*s", module->name_, + (int)info.name.size, info.name.data); + return status; + } return iree_vm_stack_function_leave(stack); - ; } const char* name_;
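The native_module_cc.h hunk above expands the macro by hand so that the printf-style annotation only runs when the call actually failed: the common OK path does no formatting work and falls straight through to leaving the stack frame. Below is a minimal C sketch of that annotate-only-on-failure pattern; status_t, status_annotate, and do_native_call are simplified stand-ins for iree_status_t, iree_status_annotate_f, and the registered function pointer, not the real IREE API.

#include <stdbool.h>

typedef int status_t;  // Stand-in for iree_status_t.
enum { STATUS_OK = 0 };
static bool status_is_ok(status_t s) { return s == STATUS_OK; }

// Hypothetical annotate hook: attaches context to a failing status. The
// real iree_status_annotate_f is only ever invoked on the error path.
extern status_t status_annotate(status_t s, const char* context);
extern status_t do_native_call(void);  // Hypothetical callee.

static status_t invoke_native(void) {
  status_t status = do_native_call();
  if (!status_is_ok(status)) {
    // Only the failure path pays for annotation; this mirrors the
    // IREE_UNLIKELY(!iree_status_is_ok(status)) branch in the diff.
    return status_annotate(status, "while invoking C++ function");
  }
  return STATUS_OK;  // Common path: no formatting, no allocation.
}

Keeping the annotation off the hot path matters because these invocations sit on every VM call boundary; the branch hint in the diff additionally tells the compiler to lay out the OK path as the fall-through case.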