Skip to content

Commit

Permalink
[runtime][python] Add multi-device HAL module construction (iree-org#…
Browse files Browse the repository at this point in the history
…17943)

The underlying C HAL function supports creating the HAL with multiple
devices. The Python API should support that as well.

---------

Signed-off-by: Boian Petkantchin <[email protected]>
  • Loading branch information
sogartar authored Aug 7, 2024
1 parent ca24b96 commit ea8b4fb
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 7 deletions.
40 changes: 34 additions & 6 deletions runtime/bindings/python/hal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

#include "./hal.h"

#include <nanobind/nanobind.h>
#include <nanobind/stl/vector.h>

#include <optional>

#include "./local_dlpack.h"
#include "./numpy_interop.h"
#include "./vm.h"
Expand Down Expand Up @@ -1066,12 +1071,34 @@ HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri,
// HAL module
//------------------------------------------------------------------------------

// TODO(multi-device): allow for multiple devices to be passed in.
VmModule CreateHalModule(VmInstance* instance, HalDevice* device) {
iree_hal_device_t* device_ptr = device->raw_ptr();
VmModule CreateHalModule(VmInstance* instance, std::optional<HalDevice*> device,
std::optional<py::list> devices) {
if (device && devices) {
PyErr_SetString(
PyExc_ValueError,
"\"device\" and \"devices\" are mutually exclusive arguments.");
}
std::vector<iree_hal_device_t*> devices_vector;
iree_hal_device_t* device_ptr;
iree_hal_device_t** devices_ptr;
iree_host_size_t device_count;
iree_vm_module_t* module = NULL;
CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), /*device_count=*/1,
&device_ptr, IREE_HAL_MODULE_FLAG_NONE,
if (device) {
device_ptr = device.value()->raw_ptr();
devices_ptr = &device_ptr;
device_count = 1;
} else {
// Set device related arguments in the case of multiple devices.
devices_vector.reserve(devices->size());
for (auto devicesIt = devices->begin(); devicesIt != devices->end();
++devicesIt) {
devices_vector.push_back(py::cast<HalDevice*>(*devicesIt)->raw_ptr());
}
devices_ptr = devices_vector.data();
device_count = devices_vector.size();
}
CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device_count,
devices_ptr, IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &module),
"Error creating hal module");
return VmModule::StealFromRawPtr(module);
Expand All @@ -1085,7 +1112,8 @@ void SetupHalBindings(nanobind::module_ m) {
py::dict driver_cache;

// Built-in module creation.
m.def("create_hal_module", &CreateHalModule);
m.def("create_hal_module", &CreateHalModule, py::arg("instance"),
py::arg("device") = py::none(), py::arg("devices") = py::none());

// Enums.
py::enum_<enum iree_hal_memory_type_bits_t>(m, "MemoryType")
Expand Down
6 changes: 5 additions & 1 deletion runtime/bindings/python/iree/runtime/_binding.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ from typing import overload

import asyncio

def create_hal_module(instance: VmInstance, device: HalDevice) -> VmModule: ...
def create_hal_module(
instance: VmInstance,
device: Optional[HalDevice] = None,
devices: Optional[List[HalDevice]] = None,
) -> VmModule: ...
def create_io_parameters_module(
instance: VmInstance, *providers: ParameterProvider
) -> VmModule: ...
Expand Down
9 changes: 9 additions & 0 deletions runtime/bindings/python/tests/vm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ def test_synchronous_invoke_function_new_abi(self):
logging.info("result: %s", result)
np.testing.assert_allclose(result, [4.0, 10.0, 18.0, 28.0])

def test_create_vm_module_with_multiple_devices(self):
"""Sanity test that we can create a VM module with 2 devices."""
devices = [
iree.runtime.get_device("local-task"),
iree.runtime.get_device("local-sync"),
]
module = iree.runtime.create_hal_module(self.instance, devices=devices)
assert isinstance(module, iree.runtime.VmModule)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
Expand Down

0 comments on commit ea8b4fb

Please sign in to comment.