Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding experimental synchronous executor using inline command buffers. #5509

Merged
merged 1 commit into from
Apr 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions iree/base/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,14 @@ static inline iree_timeout_t iree_immediate_timeout() {
return timeout;
}

// Returns true if the |timeout| indicates an immediate/polling/nonblocking
// timeout.
static inline bool iree_timeout_is_immediate(iree_timeout_t timeout) {
return timeout.type == IREE_TIMEOUT_ABSOLUTE
? timeout.nanos == IREE_TIME_INFINITE_PAST
: timeout.nanos == IREE_DURATION_ZERO;
}

// Returns a timeout that will never be reached.
// This can be used with APIs that can wait to disable the early
// deadline-exceeded returns when a condition is not met. It should be used with
Expand All @@ -852,6 +860,13 @@ static inline iree_timeout_t iree_infinite_timeout() {
return timeout;
}

// Returns true if the |timeout| indicates an infinite/forever blocking timeout.
static inline bool iree_timeout_is_infinite(iree_timeout_t timeout) {
return timeout.type == IREE_TIMEOUT_ABSOLUTE
? timeout.nanos == IREE_TIME_INFINITE_FUTURE
: timeout.nanos == IREE_DURATION_INFINITE;
}

// Defines an absolute timeout with the given time in nanoseconds.
static inline iree_timeout_t iree_make_deadline(iree_time_t deadline_ns) {
iree_timeout_t timeout = {IREE_TIMEOUT_ABSOLUTE, deadline_ns};
Expand Down
6 changes: 6 additions & 0 deletions iree/hal/cts/command_buffer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ namespace cts {
using ::testing::ContainerEq;

class CommandBufferTest : public CtsTestBase {
public:
CommandBufferTest() {
// TODO(#4680): command buffer recording so that this can run on sync HAL.
SkipUnavailableDriver("dylib-sync");
}

protected:
static constexpr iree_device_size_t kBufferSize = 4096;
};
Expand Down
6 changes: 4 additions & 2 deletions iree/hal/cts/cts_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class CtsTestBase : public ::testing::TestWithParam<std::string> {
driver_block_list_.insert(driver_name);
}
// Allow skipping tests for unsupported features.
void declareUnavailableDriver(const std::string& driver_name) {
void SkipUnavailableDriver(const std::string& driver_name) {
driver_block_list_.insert(driver_name);
}

Expand Down Expand Up @@ -170,7 +170,9 @@ struct GenerateTestName {
template <class ParamType>
std::string operator()(
const ::testing::TestParamInfo<ParamType>& info) const {
return info.param;
std::string name = info.param;
std::replace(name.begin(), name.end(), '-', '_');
return name;
}
};

Expand Down
8 changes: 7 additions & 1 deletion iree/hal/cts/event_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ namespace iree {
namespace hal {
namespace cts {

class EventTest : public CtsTestBase {};
class EventTest : public CtsTestBase {
public:
EventTest() {
// TODO(#4680): command buffer recording so that this can run on sync HAL.
SkipUnavailableDriver("dylib-sync");
}
};

TEST_P(EventTest, Create) {
iree_hal_event_t* event;
Expand Down
8 changes: 6 additions & 2 deletions iree/hal/cts/semaphore_submission_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ namespace cts {

class SemaphoreSubmissionTest : public CtsTestBase {
public:
// Disable cuda backend for this test as semaphores are not implemented yet.
SemaphoreSubmissionTest() { declareUnavailableDriver("cuda"); }
SemaphoreSubmissionTest() {
// Disable cuda backend for this test as semaphores are not implemented yet.
SkipUnavailableDriver("cuda");
// TODO(#4680): command buffer recording so that this can run on sync HAL.
SkipUnavailableDriver("dylib-sync");
}
};

TEST_P(SemaphoreSubmissionTest, SubmitWithNoCommandBuffers) {
Expand Down
2 changes: 1 addition & 1 deletion iree/hal/cts/semaphore_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace cts {
class SemaphoreTest : public CtsTestBase {
public:
// Disable cuda backend for this test as semaphores are not implemented yet.
SemaphoreTest() { declareUnavailableDriver("cuda"); }
SemaphoreTest() { SkipUnavailableDriver("cuda"); }
};

// Tests that a semaphore that is unused properly cleans itself up.
Expand Down
10 changes: 7 additions & 3 deletions iree/hal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,13 @@ IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_device_submit_and_wait(
// Returns success if the wait is successful and semaphores have been signaled
// satisfying the |wait_mode|.
//
// Returns DEADLINE_EXCEEDED if the |timeout| elapses without the |wait_mode|
// being satisfied. Note that even on success only a subset of the semaphores
// may have been signaled and each can be queried to see which ones.
// Returns IREE_STATUS_DEADLINE_EXCEEDED if the |timeout| elapses without the
// |wait_mode| being satisfied. Note that even on success only a subset of the
// semaphores may have been signaled and each can be queried to see which ones.
//
// Returns IREE_STATUS_ABORTED if one or more semaphores has failed. Callers can
// use iree_hal_semaphore_query on the semaphores to find the ones that have
// failed and get the status.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_device_wait_semaphores(
iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode,
const iree_hal_semaphore_list_t* semaphore_list, iree_timeout_t timeout);
Expand Down
1 change: 1 addition & 0 deletions iree/hal/drivers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ cc_library(
] + [
# TODO(*): select() and only pull in based on build configuration.
"//iree/hal/dylib/registration",
"//iree/hal/dylib/registration:sync",
"//iree/hal/vmla/registration",
"//iree/hal/vulkan/registration",
] + IREE_CUDA_DEPS,
Expand Down
13 changes: 2 additions & 11 deletions iree/hal/drivers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ if(${IREE_HAL_DRIVER_CUDA})
endif()
if(${IREE_HAL_DRIVER_DYLIB})
list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::registration)
# TODO(benvanik): add a IREE_HAL_DRIVER_DYLIB_SYNC or global flag.
list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::dylib::registration::sync)
endif()
if(${IREE_HAL_DRIVER_VMLA})
list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vmla::registration)
Expand All @@ -28,14 +30,6 @@ if(${IREE_HAL_DRIVER_VULKAN})
list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vulkan::registration)
endif()

# TODO: Move to either hal/metal/CMakeLists.txt or
# hal/metal/registration/CMakeLists.txt if bazel-to-cmake issues can be
# resolved.
if(APPLE)
find_library(Foundation Foundation)
find_library(Metal Metal)
endif()

iree_cc_library(
NAME
drivers
Expand All @@ -47,8 +41,5 @@ iree_cc_library(
iree::base::api
iree::base::tracing
${IREE_HAL_DRIVER_MODULES}
# TODO: Also move as above if bazel-to-cmake issues can be resolved.
$<$<PLATFORM_ID:Darwin>:${Foundation}>
$<$<PLATFORM_ID:Darwin>:${Metal}>
PUBLIC
)
9 changes: 9 additions & 0 deletions iree/hal/drivers/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
#include "iree/hal/dylib/registration/driver_module.h"
#endif // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE

#if defined(IREE_HAL_HAVE_DYLIB_SYNC_DRIVER_MODULE)
#include "iree/hal/dylib/registration/driver_module_sync.h"
#endif // IREE_HAL_HAVE_DYLIB_SYNC_DRIVER_MODULE

#if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE)
#include "iree/hal/vmla/registration/driver_module.h"
#endif // IREE_HAL_HAVE_VMLA_DRIVER_MODULE
Expand All @@ -46,6 +50,11 @@ iree_hal_register_all_available_drivers(iree_hal_driver_registry_t* registry) {
z0, iree_hal_dylib_driver_module_register(registry));
#endif // IREE_HAL_HAVE_DYLIB_DRIVER_MODULE

#if defined(IREE_HAL_HAVE_DYLIB_SYNC_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_dylib_sync_driver_module_register(registry));
#endif // IREE_HAL_HAVE_DYLIB_SYNC_DRIVER_MODULE

#if defined(IREE_HAL_HAVE_VMLA_DRIVER_MODULE)
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_vmla_driver_module_register(registry));
Expand Down
14 changes: 14 additions & 0 deletions iree/hal/dylib/registration/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,17 @@ cc_library(
"@com_google_absl//absl/flags:flag",
],
)

cc_library(
name = "sync",
srcs = ["driver_module_sync.c"],
hdrs = ["driver_module_sync.h"],
defines = [
"IREE_HAL_HAVE_DYLIB_SYNC_DRIVER_MODULE=1",
],
deps = [
"//iree/hal:api",
"//iree/hal/local:sync_driver",
"//iree/hal/local/loaders:legacy_library_loader",
],
)
16 changes: 16 additions & 0 deletions iree/hal/dylib/registration/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,20 @@ iree_cc_library(
PUBLIC
)

iree_cc_library(
NAME
sync
HDRS
"driver_module_sync.h"
SRCS
"driver_module_sync.c"
DEPS
iree::hal::api
iree::hal::local::loaders::legacy_library_loader
iree::hal::local::sync_driver
DEFINES
"IREE_HAL_HAVE_DYLIB_SYNC_DRIVER_MODULE=1"
PUBLIC
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
79 changes: 79 additions & 0 deletions iree/hal/dylib/registration/driver_module_sync.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "iree/hal/dylib/registration/driver_module_sync.h"

#include <inttypes.h>

#include "iree/hal/local/loaders/legacy_library_loader.h"
#include "iree/hal/local/sync_driver.h"

// TODO(#4298): remove this driver registration and wrapper.
// By having a single iree/hal/local/registration that then has the loaders
// added to it based on compilation settings we can have a single set of flags
// for everything.

#define IREE_HAL_DYLIB_SYNC_DRIVER_ID 0x53444C4Cu // SDLL

static iree_status_t iree_hal_dylib_sync_driver_factory_enumerate(
void* self, const iree_hal_driver_info_t** out_driver_infos,
iree_host_size_t* out_driver_info_count) {
static const iree_hal_driver_info_t default_driver_info = {
.driver_id = IREE_HAL_DYLIB_SYNC_DRIVER_ID,
.driver_name = iree_string_view_literal("dylib-sync"),
.full_name = iree_string_view_literal("AOT compiled dynamic libraries"),
};
*out_driver_info_count = 1;
*out_driver_infos = &default_driver_info;
return iree_ok_status();
}

static iree_status_t iree_hal_dylib_sync_driver_factory_try_create(
void* self, iree_hal_driver_id_t driver_id, iree_allocator_t allocator,
iree_hal_driver_t** out_driver) {
if (driver_id != IREE_HAL_DYLIB_SYNC_DRIVER_ID) {
return iree_make_status(IREE_STATUS_UNAVAILABLE,
"no driver with ID %016" PRIu64
" is provided by this factory",
driver_id);
}

iree_hal_sync_device_params_t default_params;
iree_hal_sync_device_params_initialize(&default_params);

iree_hal_executable_loader_t* dylib_loader = NULL;
iree_status_t status =
iree_hal_legacy_library_loader_create(allocator, &dylib_loader);
iree_hal_executable_loader_t* loaders[1] = {dylib_loader};

if (iree_status_is_ok(status)) {
status = iree_hal_sync_driver_create(
iree_make_cstring_view("dylib"), &default_params,
IREE_ARRAYSIZE(loaders), loaders, allocator, out_driver);
}

iree_hal_executable_loader_release(dylib_loader);
return status;
}

IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_dylib_sync_driver_module_register(
iree_hal_driver_registry_t* registry) {
static const iree_hal_driver_factory_t factory = {
.self = NULL,
.enumerate = iree_hal_dylib_sync_driver_factory_enumerate,
.try_create = iree_hal_dylib_sync_driver_factory_try_create,
};
return iree_hal_driver_registry_register_factory(registry, &factory);
}
34 changes: 34 additions & 0 deletions iree/hal/dylib/registration/driver_module_sync.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef IREE_HAL_DYLIB_REGISTRATION_DRIVER_MODULE_SYNC_H_
#define IREE_HAL_DYLIB_REGISTRATION_DRIVER_MODULE_SYNC_H_

#include "iree/hal/api.h"

#ifdef __cplusplus
extern "C" {
#endif // __cplusplus

// DEPRECATED: this entire driver will be removed soon.
benvanik marked this conversation as resolved.
Show resolved Hide resolved
// TODO(#3580): remove this entire driver w/ iree_hal_executable_library_t.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_hal_dylib_sync_driver_module_register(
iree_hal_driver_registry_t* registry);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

#endif // IREE_HAL_DYLIB_REGISTRATION_DRIVER_MODULE_SYNC_H_
28 changes: 28 additions & 0 deletions iree/hal/local/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,34 @@ cc_library(
],
)

cc_library(
name = "sync_driver",
srcs = [
"sync_device.c",
"sync_driver.c",
"sync_event.c",
"sync_semaphore.c",
],
hdrs = [
"sync_device.h",
"sync_driver.h",
"sync_event.h",
"sync_semaphore.h",
],
deps = [
":arena",
":local",
"//iree/base:api",
"//iree/base:core_headers",
"//iree/base:synchronization",
"//iree/base:tracing",
"//iree/base/internal",
"//iree/base/internal:wait_handle",
"//iree/hal:api",
"//iree/task",
],
)

cc_library(
name = "task_driver",
srcs = [
Expand Down
Loading