From 5d92f75fabb439b1172a33627e7d33557f4ed9d1 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Sat, 2 Apr 2022 23:26:38 +0200 Subject: [PATCH] DLTensor progress --- .github/workflows/ci.yml | 2 +- README.md | 91 ++---- cmake/nanobind-config.cmake | 8 + docs/lowlevel.md | 67 ++++ docs/tensor.md | 218 +++++++++++++ include/nanobind/dlpack.h | 223 ------------- include/nanobind/nb_descr.h | 15 + include/nanobind/nb_lib.h | 27 +- include/nanobind/tensor.h | 336 +++++++++++++++++++ src/common.cpp | 194 +---------- src/internals.cpp | 5 +- src/internals.h | 27 ++ src/nb_func.cpp | 3 +- src/tensor.cpp | 630 ++++++++++++++++++++++++++++++++++++ tests/CMakeLists.txt | 4 +- tests/test_classes.cpp | 2 +- tests/test_classes.py | 2 +- tests/test_dlpack.cpp | 30 -- tests/test_dlpack.py | 69 ---- tests/test_tensor.cpp | 128 ++++++++ tests/test_tensor.py | 293 +++++++++++++++++ 21 files changed, 1797 insertions(+), 577 deletions(-) create mode 100644 docs/lowlevel.md create mode 100644 docs/tensor.md delete mode 100644 include/nanobind/dlpack.h create mode 100644 include/nanobind/tensor.h create mode 100644 src/tensor.cpp delete mode 100644 tests/test_dlpack.cpp delete mode 100644 tests/test_dlpack.py create mode 100644 tests/test_tensor.cpp create mode 100644 tests/test_tensor.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fefde58e..105d5c31 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,7 +44,7 @@ jobs: - name: Install Python dependencies run: | - python -m pip install pytest pytest-github-actions-annotate-failures + python -m pip install pytest pytest-github-actions-annotate-failures numpy - name: Configure run: > diff --git a/README.md b/README.md index be8e67f9..8c028bf1 100644 --- a/README.md +++ b/README.md @@ -167,7 +167,10 @@ Removed features include: - ● PyPy support is gone. (PyPy requires many workaround in _pybind11_ that complicate the its internals. Making PyPy interoperate with _nanobind_ will likely require changes to the PyPy CPython emulation layer.) -- ◑ Eigen and NumPy integration have been removed. +- ◑ NumPy integration was replaced by a more general ``nb::tensor<>`` + integration that supports CPU/GPU tensors produced by various frameworks + (NumPy, PyTorch, TensorFlow, JAX, ..). +- ◑ Eigen integration was removed. - ◑ Buffer protocol functionality was removed. - ◑ Nested exceptions are not supported. - ◑ Features to facilitate pickling and unpickling were removed. @@ -240,6 +243,11 @@ improvements for developers: when the binding code of those C++ types hasn't yet run. _nanobind_ does not pre-render function docstrings, they are created on the fly when queried. +- _nanobind_ has [greatly + improved](https://github.com/wjakob/nanobind/blob/master/docs/tensor.md) + support for exchanging tensor data structures with modern array programming + frameworks. + ### Dependencies _nanobind_ depends on recent versions of everything: @@ -467,6 +475,19 @@ changes are detailed below. - **New features.** + - **Unified DLPack/Buffer protocol integration**: _nanobind_ can retrieve and + return tensors using two standard protocols: + [DLPack](https://github.com/dmlc/dlpack), and the the [buffer + protocol](https://docs.python.org/3/c-api/buffer.html). This enables + zero-copy data exchange of CPU and GPU tensors with array programming + frameworks including [NumPy](https://numpy.org), + [PyTorch](https://pytorch.org), [TensorFlow](https://www.tensorflow.org), + [JAX](https://jax.readthedocs.io), etc. 
+ + Details on using this feature can be found + [here](https://github.com/wjakob/nanobind/blob/master/docs/tensor.md). + + - **Supplemental type data**: _nanobind_ can store supplemental data along with registered types. This information is co-located with the Python type object. An example use of this fairly advanced feature are libraries that @@ -488,75 +509,21 @@ changes are detailed below. - **Low-level interface**: _nanobind_ exposes a low-level interface to provide fine-grained control over the sequence of steps that instantiates a - Python object wrapping a C++ instance. An example is shown below: - - ```cpp - /* Look up the Python type object associated with a C++ class named `MyClass`. - Requires a previous nb::class_<> binding declaration, otherwise this line - will return a NULL pointer (this can be checked via py_type.is_valid()). */ - nb::handle py_type = nb::type(); - - // Type metadata can also be queried in the other direction - assert(py_type.is_valid() && // Did the type lookup work? - nb::type_check(py_type) && - nb::type_size(py_type) == sizeof(MyClass) && // nanobind knows the size+alignment - nb::type_align(py_type) == alignof(MyClass) && - nb::type_info(py_type) == typeid(MyClass)); // Query C++ RTTI record - - /* Allocate an uninitialized Python instance of this type. Nanobind will - refuse to pass this (still unitialized) object to bound C++ functions */ - nb::object py_inst = nb::inst_alloc(py_type); - assert(nb::inst_check(py_inst) && py_inst.type().is(py_type) && !nb::inst_ready(py_inst)); - - /* For POD types, the following line zero-initializes the object and marks - it as ready. Alternatively, the next lines show how to perform a fancy - object initialization using the C++ constructor */ - // nb::inst_zero(py_inst); - - // Get a C++ pointer to the uninitialized instance data - MyClass *ptr = nb::inst_ptr(py_inst); - - // Perform an in-place construction of the C++ object - new (ptr) MyClass(); - - /* Mark the Python object as ready. When reference count reaches zero, - nanobind will automatically call the destructor (MyClass::~MyClass). */ - nb::inst_mark_ready(py_inst); - assert(nb::inst_ready(py_inst)); - - /* Alternatively, we can force-call the destructor and transition the - instance back to non-ready status. The instance could then be reused - by initializing it yet again. */ - nb::inst_destruct(py_inst); - assert(!nb::inst_ready(py_inst)); - - /* We can copy- or move-construct 'py_inst' from another instance of the - same type. This calls the C++ copy or move constructor and transitions - 'py_inst' back to 'ready' status. Note that this is equivalent to calling - an in-place version of these constructors above but compiles to more - compact code (the 'nb::class_' declaration has already created - bindings for both constructors, and this simply calls those bindings). */ - // nb::inst_copy(/* dst = */ py_inst, /* src = */ some_other_instance); - // nb::inst_move(/* dst = */ py_inst, /* src = */ some_other_instance); - ``` + Python object wrapping a C++ instance. Like the above point, this is useful + when writing generic binding code that manipulates _nanobind_-based objects + of various types. - Note that these functions are all _unsafe_ in the sense that they do not - verify that their input arguments are valid. This is done for performance - reasons, and such checks (if needed) are therefore the responsibility of - the caller. 
Functions labeled `nb::type_*` should only be called with - _nanobind_ type objects, and functions labeled `nb::inst_*` should only be - called with _nanobind_ instance objects. The functions `nb::type_check()` - and `nb::inst_check()` accept any Python object and test whether something - is a _nanobind_ type or instance object. + Details on using this feature can be found + [here](https://github.com/wjakob/nanobind/blob/master/docs/lowlevel.md). - **Python type wrappers**: The `nb::handle_of` type behaves just like the `nb::handle` class and wraps a `PyObject *` pointer. However, when binding a function that takes such an argument, _nanobind_ will only call the associated function overload when the underlying Python object wraps a C++ - `T` instance. + instance of type `T`. - **Raw docstrings**: In cases where absolute control over docstrings is - required (for example, so that they can be parsed by a tool like + required (for example, so that complex cases can be parsed by a tool like [Sphinx](https://www.sphinx-doc.org)), the ``nb::raw_doc`` attribute can be specified to functions. In this case, _nanobind_ will _skip_ generation of a combined docstring that enumerates overloads along with type information. diff --git a/cmake/nanobind-config.cmake b/cmake/nanobind-config.cmake index e1ac385f..3323e3b9 100644 --- a/cmake/nanobind-config.cmake +++ b/cmake/nanobind-config.cmake @@ -54,12 +54,20 @@ function (nanobuild_build_library TARGET_NAME TARGET_TYPE) ${NB_DIR}/include/nanobind/nb_traits.h ${NB_DIR}/include/nanobind/nb_types.h ${NB_DIR}/include/nanobind/trampoline.h + ${NB_DIR}/include/nanobind/tensor.h + ${NB_DIR}/include/nanobind/operators.h + ${NB_DIR}/include/nanobind/stl/shared_ptr.h + ${NB_DIR}/include/nanobind/stl/unique_ptr.h ${NB_DIR}/include/nanobind/stl/string.h + ${NB_DIR}/include/nanobind/stl/tuple.h + ${NB_DIR}/include/nanobind/stl/pair.h + ${NB_DIR}/include/nanobind/stl/function.h ${NB_DIR}/src/internals.h ${NB_DIR}/src/buffer.h ${NB_DIR}/src/internals.cpp ${NB_DIR}/src/common.cpp + ${NB_DIR}/src/tensor.cpp ${NB_DIR}/src/nb_func.cpp ${NB_DIR}/src/nb_type.cpp ${NB_DIR}/src/nb_enum.cpp diff --git a/docs/lowlevel.md b/docs/lowlevel.md new file mode 100644 index 00000000..34835a86 --- /dev/null +++ b/docs/lowlevel.md @@ -0,0 +1,67 @@ +# **Low-level instance interface** + +_nanobind_ exposes a low-level interface to provide fine-grained control over +the sequence of steps that instantiates a Python object wrapping a C++ +instance. Like the above point, this is useful when writing generic binding +code that manipulates _nanobind_-based objects of various types. + +An example is shown below: + +```cpp +/* Look up the Python type object associated with a C++ class named `MyClass`. + Requires a previous nb::class_<> binding declaration, otherwise this line + will return a NULL pointer (this can be checked via py_type.is_valid()). */ +nb::handle py_type = nb::type(); + +// Type metadata can also be queried in the other direction +assert(py_type.is_valid() && // Did the type lookup work? + nb::type_check(py_type) && + nb::type_size(py_type) == sizeof(MyClass) && // nanobind knows the size+alignment + nb::type_align(py_type) == alignof(MyClass) && + nb::type_info(py_type) == typeid(MyClass)); // Query C++ RTTI record + +/* Allocate an uninitialized Python instance of this type. 
Nanobind will + refuse to pass this (still unitialized) object to bound C++ functions */ +nb::object py_inst = nb::inst_alloc(py_type); +assert(nb::inst_check(py_inst) && py_inst.type().is(py_type) && !nb::inst_ready(py_inst)); + +/* For POD types, the following line zero-initializes the object and marks + it as ready. Alternatively, the next lines show how to perform a fancy + object initialization using the C++ constructor */ +// nb::inst_zero(py_inst); + +// Get a C++ pointer to the uninitialized instance data +MyClass *ptr = nb::inst_ptr(py_inst); + +// Perform an in-place construction of the C++ object +new (ptr) MyClass(); + +/* Mark the Python object as ready. When reference count reaches zero, + nanobind will automatically call the destructor (MyClass::~MyClass). */ +nb::inst_mark_ready(py_inst); +assert(nb::inst_ready(py_inst)); + +/* Alternatively, we can force-call the destructor and transition the + instance back to non-ready status. The instance could then be reused + by initializing it yet again. */ +nb::inst_destruct(py_inst); +assert(!nb::inst_ready(py_inst)); + +/* We can copy- or move-construct 'py_inst' from another instance of the + same type. This calls the C++ copy or move constructor and transitions + 'py_inst' back to 'ready' status. Note that this is equivalent to calling + an in-place version of these constructors above but compiles to more + compact code (the 'nb::class_' declaration has already created + bindings for both constructors, and this simply calls those bindings). */ +// nb::inst_copy(/* dst = */ py_inst, /* src = */ some_other_instance); +// nb::inst_move(/* dst = */ py_inst, /* src = */ some_other_instance); +``` + +Note that these functions are all _unsafe_ in the sense that they do not +verify that their input arguments are valid. This is done for performance +reasons, and such checks (if needed) are therefore the responsibility of +the caller. Functions labeled `nb::type_*` should only be called with +_nanobind_ type objects, and functions labeled `nb::inst_*` should only be +called with _nanobind_ instance objects. The functions `nb::type_check()` +and `nb::inst_check()` accept any Python object and test whether something +is a _nanobind_ type or instance object. diff --git a/docs/tensor.md b/docs/tensor.md new file mode 100644 index 00000000..6dbbe5bc --- /dev/null +++ b/docs/tensor.md @@ -0,0 +1,218 @@ +# Retrieving and returning tensors + +_nanobind_ can retrieve and return tensors using two common data exchange +formats. + +- The classic [buffer protocol](https://docs.python.org/3/c-api/buffer.html). +- [DLPack](https://github.com/dmlc/dlpack), which is a GPU-compatible + generalization of the buffer protocol. + +This feature enables _zero-copy_ data exchange with various modern array +programming frameworks including [NumPy](https://numpy.org), +[PyTorch](https://pytorch.org), [TensorFlow](https://www.tensorflow.org), and +[JAX](https://jax.readthedocs.io). _nanobind_ knows how to talk to each of +these frameworks and takes care of all the nitty-gritty details. + +To use this feature, you must include the optional header file +[`nanobind/tensor.h`](https://github.com/wjakob/nanobind/blob/master/include/nanobind/tensor.h) +Following this, you can bind functions that involve `nb::tensor<>`-typed +parameters or return values. + +## Binding functions that take tensors as input + +A function that accepts a `nb::tensor<>` parameter can be called with *any* +tensor from any framework regardless of the device on which it is stored. 
The
+following example binding declaration uses this functionality to inspect the
+properties of an arbitrary tensor:
+
+```cpp
+m.def("inspect", [](nb::tensor<> tensor) {
+    printf("Tensor data pointer : %p\n", tensor.data());
+    printf("Tensor dimension : %zu\n", tensor.ndim());
+    for (size_t i = 0; i < tensor.ndim(); ++i) {
+        printf("Tensor dimension [%zu] : %zu\n", i, tensor.shape(i));
+        printf("Tensor stride [%zu] : %zd\n", i, tensor.stride(i));
+    }
+    printf("Device ID = %u (cpu=%i, cuda=%i)\n", tensor.device_id(),
+        int(tensor.device_type() == nb::device::cpu::value),
+        int(tensor.device_type() == nb::device::cuda::value)
+    );
+    printf("Tensor dtype: int16=%i, uint32=%i, float32=%i\n",
+        tensor.dtype() == nb::dtype<int16_t>(),
+        tensor.dtype() == nb::dtype<uint32_t>(),
+        tensor.dtype() == nb::dtype<float>()
+    );
+});
+```
+
+Below is an example of what this function does when called with a NumPy array:
+```pycon
+>>> my_module.inspect(np.array([[1,2,3], [3,4,5]], dtype=np.float32))
+Tensor data pointer : 0x1c30f60
+Tensor dimension : 2
+Tensor dimension [0] : 2
+Tensor stride [0] : 3
+Tensor dimension [1] : 3
+Tensor stride [1] : 1
+Device ID = 0 (cpu=1, cuda=0)
+Tensor dtype: int16=0, uint32=0, float32=1
+```
+
+## Tensor constraints
+
+In practice, it can often be useful to *constrain* what kinds of tensors
+constitute valid inputs to a function. For example, a function expecting CPU
+storage would likely crash if given a pointer to GPU memory, and _nanobind_
+should therefore prevent such undefined behavior. `nb::tensor<>` accepts
+template arguments to specify such constraints. For example, the function
+interface below guarantees that the implementation is only invoked when
+`tensor` represents an `MxNx3` tensor of 8-bit unsigned integers that is
+stored contiguously in CPU memory using a C-style array ordering convention.
+
+```cpp
+m.def("process", [](nb::tensor<uint8_t, nb::shape<nb::any, nb::any, 3>,
+                               nb::c_contig, nb::device::cpu> tensor) {
+    // Double the brightness of the MxNx3 RGB image
+    for (size_t y = 0; y < tensor.shape(0); ++y)
+        for (size_t x = 0; x < tensor.shape(1); ++x)
+            for (size_t ch = 0; ch < 3; ++ch)
+                tensor(y, x, ch) = (uint8_t) std::min(255, tensor(y, x, ch) * 2);
+});
+```
+
+The above example also demonstrates the use of `nb::tensor<...>::operator()`,
+which provides direct (i.e., high-performance) read/write access to the tensor
+data. Note that this function is only available when the underlying data type
+and tensor rank are specified via the template parameters.
+
+### Available constraints
+
+The following tensor constraints are available:
+
+- The `nb::shape<...>` annotation simultaneously constrains the tensor rank and
+  the size along specific dimensions. An `nb::any` entry leaves the
+  corresponding dimension unconstrained.
+
+- Device tags: `nb::device::cpu`, `nb::device::cuda`, `nb::device::cuda_host`,
+  `nb::device::cuda_managed`, `nb::device::opencl`, `nb::device::vulkan`,
+  `nb::device::metal`, `nb::device::rocm`, `nb::device::rocm_host`, and
+  `nb::device::oneapi`.
+
+- Ordering tags: `nb::c_contig` (contiguous C-style array storage),
+  and `nb::f_contig` (contiguous Fortran-style array storage).
+
+## Passing `nb::tensor<>` instances in C++ code
+
+`nb::tensor<>` behaves like a shared pointer with built-in reference
+counting: it can be moved or copied within C++ code. Copies will point to
+the same underlying buffer and increase the reference count until they go
+out of scope. It is legal to call `nb::tensor<>` members from multithreaded
+code even when the
+[GIL](https://wiki.python.org/moin/GlobalInterpreterLock) is not held.
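+
+A minimal sketch of these reference semantics is shown below (the helper
+function `reference_demo` is hypothetical and exists only for illustration):
+
+```cpp
+void reference_demo(nb::tensor<> t) {
+    nb::tensor<> t2 = t;             // copy: refers to the same buffer, reference count + 1
+    nb::tensor<> t3 = std::move(t2); // move: reference count unchanged, 't2' becomes invalid
+
+    // Every valid handle observes the same storage
+    assert(t.data() == t3.data() && !t2.is_valid());
+}   // 't2' and 't3' expire here; the buffer stays alive while other references exist
+```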
+
+## Tensors in docstrings
+
+_nanobind_ displays tensor constraints in docstrings and error messages.
+For example, suppose that we now call the `process()` function with an
+invalid input. This produces the following error message:
+
+```pycon
+>>> my_module.process(tensor=np.zeros(1))
+
+TypeError: process(): incompatible function arguments. The following argument types are supported:
+1. process(arg: tensor[dtype=uint8, shape=(*, *, 3), order='C', device='cpu'], /) -> None
+
+Invoked with types: numpy.ndarray
+```
+
+Note that these type annotations are intended for humans; they will not
+currently work with automatic type checking tools like
+[MyPy](https://mypy.readthedocs.io/en/stable/) (which, at least for the time
+being, don't provide a portable or sufficiently flexible annotation of tensor
+objects).
+
+## Tensors and function overloads
+
+A bound function taking a tensor argument can declare multiple overloads
+with different constraints (e.g. a CPU and a GPU implementation), in which
+case the first matching overload is called. When no perfect match can be
+found, _nanobind_ will try each overload once more while performing basic
+implicit conversions: it will convert strided arrays into C- or F-contiguous
+arrays (if requested) and perform type conversion. This makes it possible,
+e.g., to call a function expecting a `float32` array with `float64` data.
+Implicit conversions create temporary tensors containing a copy of the data,
+which can be undesirable. To suppress them, add a
+`nb::arg("tensor").noconvert()` or `"tensor"_a.noconvert()` function
+binding annotation.
+
+## Binding functions that return tensors
+
+To return a tensor from C++ code, you must indicate its type, shape, a pointer
+to CPU/GPU memory, and which tensor framework (NumPy, PyTorch, etc.) should be
+used to encapsulate the data.
+
+The following simple binding declaration shows how to return a `2x4` NumPy
+floating point matrix:
+
+```cpp
+float *data = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
+
+m.def("ret_numpy", []() {
+    size_t shape[2] = { 2, 4 };
+    return nb::tensor<nb::numpy, float, nb::shape<2, nb::any>>(
+        data, /* ndim = */ 2, shape);
+});
+
+// The auto-generated docstring of this function is:
+//   ret_numpy() -> np.ndarray[float32, shape=(2, *)]
+//
+// Calling it in Python yields
+//   array([[1., 2., 3., 4.],
+//          [5., 6., 7., 8.]], dtype=float32)
+```
+
+The following additional tensor declarations are possible for return values:
+
+- `nb::numpy`. Returns the tensor as a `numpy.ndarray`.
+- `nb::pytorch`. Returns the tensor as a `torch.Tensor`.
+- `nb::tensorflow`. Returns the tensor as a `tensorflow.python.framework.ops.EagerTensor`.
+- `nb::jax`. Returns the tensor as a `jaxlib.xla_extension.DeviceArray`.
+- `nb::none` (the default). In this case, _nanobind_ will return a raw Python
+  `dltensor` [capsule](https://docs.python.org/3/c-api/capsule.html)
+  representing the [DLPack](https://github.com/dmlc/dlpack) metadata.
+
+Note that shape and order annotations like `nb::shape<...>` and `nb::c_contig`
+enter into the docstring, but _nanobind_ won't spend time on additional checks.
+It trusts that your method returns what it declares.
+
+The full signature of the tensor constructor is:
+```cpp
+tensor(void *value,
+       size_t ndim,
+       const size_t *shape,
+       handle owner = nanobind::handle(),
+       const int64_t *strides = nullptr,
+       dlpack::dtype dtype = nanobind::dtype<Scalar>(),
+       int32_t device_type = device::cpu::value,
+       int32_t device_id = 0) { ..
} +``` + +The `owner` parameter can be used to keep another Python object alive while +the tensor data is referenced by a consumer. This mechanism can be used to +implement a data destructor as follows: + + +```cpp +m.def("ret_pytorch", []() { + float *data = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 }; + size_t shape[2] = { 2, 4 }; + + /// Delete 'data' when the 'deleter' capsule expires + nb::capsule deleter(data, [](void *p) { + delete[] (float *) p; + }); + + return nb::tensor(data, 2, shape, /* owner = */ deleter); +}); +``` diff --git a/include/nanobind/dlpack.h b/include/nanobind/dlpack.h deleted file mode 100644 index 46002404..00000000 --- a/include/nanobind/dlpack.h +++ /dev/null @@ -1,223 +0,0 @@ -/* - nanobind/dlpack.h: functionality to input/output tensors via DLPack - - Copyright (c) 2022 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. - - The API below is based on the DLPack project - (https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h) -*/ - -#include - -NAMESPACE_BEGIN(NB_NAMESPACE) -NAMESPACE_BEGIN(dlpack) - -enum class DeviceType : int32_t { - Undefined = 0, CPU = 1, CUDA = 2, CUDAHost = 3, - OpenCL = 4, Vulkan = 7, Metal = 8, ROCM = 10, - ROCMHost = 11, CUDAManaged = 13, OneAPI = 14 -}; - -enum class DataTypeCode : uint8_t { - Int = 0, UInt = 1, Float = 2, Bfloat = 4, Complex = 5 -} ; - -struct Device { - DeviceType device_type = DeviceType::Undefined; - int32_t device_id = 0; -}; - -struct DataType { - uint8_t code = 0; - uint8_t bits = 0; - uint16_t lanes = 0; - - bool operator==(const DataType &o) const { - return code == o.code && bits == o.bits && lanes == o.lanes; - } - bool operator!=(const DataType &o) const { return !operator==(o); } -}; - -struct Tensor { - void *data = nullptr; - Device device; - int32_t ndim = 0; - DataType dtype; - int64_t *shape = nullptr; - int64_t *strides = nullptr; - uint64_t byte_offset = 0; -}; - -NAMESPACE_END(dlpack) - -constexpr size_t any = (size_t) -1; -template class shape { }; -template class order { - static_assert(O == 'C' || O == 'F', "Only C ('C') and Fortran ('F')-style " - "ordering conventions are supported!"); -}; - -template constexpr dlpack::DataType dtype() { - static_assert( - std::is_floating_point_v || std::is_integral_v, - "nanobind::dtype: T must be a floating point or integer variable!" 
- ); - - dlpack::DataType result; - - if constexpr (std::is_floating_point_v) - result.code = (uint8_t) dlpack::DataTypeCode::Float; - else if constexpr (std::is_signed_v) - result.code = (uint8_t) dlpack::DataTypeCode::Int; - else - result.code = (uint8_t) dlpack::DataTypeCode::UInt; - - result.bits = sizeof(T) * 8; - result.lanes = 1; - - return result; -} - - -NAMESPACE_BEGIN(detail) - -struct TensorReq { - dlpack::DataType dtype; - uint32_t ndim = 0; - size_t *shape = nullptr; - bool req_shape = false; - bool req_dtype = false; - char req_order = '\0'; -}; - -template struct tensor_arg; - -template struct tensor_arg>> { - static constexpr size_t size = 0; - - static constexpr auto name = - const_name("dtype=float") + const_name(); - - static void apply(TensorReq &tr) { - tr.dtype = dtype(); - tr.req_dtype = true; - } -}; - -template struct tensor_arg>> { - static constexpr size_t size = 0; - - static constexpr auto name = - const_name("dtype=") + const_name>("u", "") + - const_name("int") + const_name(); - - static void apply(TensorReq &tr) { - tr.dtype = dtype(); - tr.req_dtype = true; - } -}; - -template struct tensor_arg> { - static constexpr size_t size = sizeof...(Is); - static constexpr auto name = - const_name("shape=(") + - concat(const_name(const_name("*"), const_name())...) + - const_name(")"); - - static void apply(TensorReq &tr) { - size_t i = 0; - ((tr.shape[i++] = Is), ...); - tr.ndim = (uint32_t) sizeof...(Is); - tr.req_shape = true; - } -}; - -template struct tensor_arg> { - static constexpr size_t size = 0; - static constexpr auto name = - const_name("order='") + const_name(O) + const_name('\''); - - static void apply(TensorReq &tr) { - tr.req_order = O; - } -}; - -NAMESPACE_END(detail) - -template class tensor { -public: - tensor() = default; - - explicit tensor(detail::TensorHandle *handle) : m_handle(handle) { - if (handle) - m_tensor = *detail::tensor_inc_ref(handle); - } - - ~tensor() { - detail::tensor_dec_ref(m_handle); - } - - tensor(const tensor &t) : m_handle(t.m_handle), m_tensor(t.m_tensor) { - detail::tensor_inc_ref(m_handle); - } - - tensor(tensor &&t) noexcept : m_handle(t.m_handle), m_tensor(t.m_tensor) { - t.m_handle = nullptr; - t.m_tensor = dlpack::Tensor(); - } - - tensor &operator=(tensor &&t) noexcept { - detail::tensor_dec_ref(m_handle); - m_handle = t.m_handle; - m_tensor = t.m_tensor; - t.m_handle = nullptr; - t.m_tensor = dlpack::Tensor(); - return *this; - } - - tensor &operator=(const tensor &t) { - detail::tensor_inc_ref(t.m_handle); - detail::tensor_dec_ref(m_handle); - m_handle = t.m_handle; - m_tensor = t.m_tensor; - } - - dlpack::DataType dtype() const { return m_tensor.dtype; } - size_t ndim() const { return m_tensor.ndim; } - size_t shape(size_t i) const { return m_tensor.shape[i]; } - size_t strides(size_t i) const { return m_tensor.strides[i]; } - bool is_valid() const { return m_handle != nullptr; } - -private: - detail::TensorHandle *m_handle = nullptr; - dlpack::Tensor m_tensor; -}; - -NAMESPACE_BEGIN(detail) - -template struct type_caster> { - NB_TYPE_CASTER(tensor, const_name("tensor[") + - concat(detail::tensor_arg::name...) + - const_name("]")); - - bool from_python(handle src, uint8_t, cleanup_list *) noexcept { - constexpr size_t size = (0 + ... 
+ detail::tensor_arg::size); - size_t shape[size + 1]; - detail::TensorReq req; - req.shape = shape; - (detail::tensor_arg::apply(req), ...); - value = tensor(tensor_create(src.ptr(), &req)); - return value.is_valid(); - } - - static handle from_cpp(const tensor &, rv_policy, - cleanup_list *) noexcept { - return handle(); - } -}; - -NAMESPACE_END(detail) -NAMESPACE_END(NB_NAMESPACE) diff --git a/include/nanobind/nb_descr.h b/include/nanobind/nb_descr.h index 7f22a0f9..00077227 100644 --- a/include/nanobind/nb_descr.h +++ b/include/nanobind/nb_descr.h @@ -85,16 +85,31 @@ auto constexpr const_name() -> std::remove_cv_t constexpr descr<1, Type> const_name() { return {'%'}; } constexpr descr<0> concat() { return {}; } +constexpr descr<0> concat_maybe() { return {}; } template constexpr descr concat(const descr &descr) { return descr; } +template +constexpr descr concat_maybe(const descr &descr) { return descr; } + template constexpr auto concat(const descr &d, const Args &...args) -> decltype(std::declval>() + concat(args...)) { return d + const_name(", ") + concat(args...); } +template +constexpr auto concat_maybe(const descr &d, const Args &... args) + -> decltype( + std::declval>() + + concat_maybe(args...)) { + if constexpr (N + sizeof...(Ts) == 0) + return concat_maybe(args...); + else + return d + const_name(", ") + concat_maybe(args...); +} + template constexpr descr type_descr(const descr &descr) { return const_name("{") + descr + const_name("}"); diff --git a/include/nanobind/nb_lib.h b/include/nanobind/nb_lib.h index 1a323b34..c5d57986 100644 --- a/include/nanobind/nb_lib.h +++ b/include/nanobind/nb_lib.h @@ -10,13 +10,13 @@ NAMESPACE_BEGIN(NB_NAMESPACE) // Forward declarations for types in dlpack.h (1) -namespace dlpack { struct Tensor; } +namespace dlpack { struct tensor; struct dtype; } NAMESPACE_BEGIN(detail) // Forward declarations for types in dlpack.h (2) -struct TensorHandle; -struct TensorReq; +struct tensor_handle; +struct tensor_req; /** * Helper class to clean temporaries created by function dispatch. 
@@ -317,15 +317,26 @@ NB_CORE PyObject *module_new_submodule(PyObject *base, const char *name, // ======================================================================== -// Try to create a reference-counted tensor object via DLPack -NB_CORE TensorHandle *tensor_create(PyObject *o, const TensorReq *req) noexcept; +// Try to import a reference-counted tensor object via DLPack +NB_CORE tensor_handle *tensor_import(PyObject *o, const tensor_req *req, + bool convert) noexcept; + +// Describe a local tensor object using a DLPack capsule +NB_CORE tensor_handle *tensor_create(void *value, size_t ndim, + const size_t *shape, PyObject *owner, + const int64_t *strides, + dlpack::dtype *dtype, int32_t device, + int32_t device_id); /// Increase the reference count of the given tensor object; returns a pointer -/// to the underlying DLTensor -NB_CORE dlpack::Tensor *tensor_inc_ref(TensorHandle *) noexcept; +/// to the underlying DLtensor +NB_CORE dlpack::tensor *tensor_inc_ref(tensor_handle *) noexcept; /// Decrease the reference count of the given tensor object -NB_CORE void tensor_dec_ref(TensorHandle *) noexcept; +NB_CORE void tensor_dec_ref(tensor_handle *) noexcept; + +/// Wrap a tensor_handle* into a PyCapsule +NB_CORE PyObject *tensor_wrap(tensor_handle *, int framework) noexcept; // ======================================================================== diff --git a/include/nanobind/tensor.h b/include/nanobind/tensor.h new file mode 100644 index 00000000..83c5e83a --- /dev/null +++ b/include/nanobind/tensor.h @@ -0,0 +1,336 @@ +/* + nanobind/tensor.h: functionality to input/output tensors via DLPack + + Copyright (c) 2022 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
+ + The API below is based on the DLPack project + (https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h) +*/ + +#include + +NAMESPACE_BEGIN(NB_NAMESPACE) + +NAMESPACE_BEGIN(device) +#define NB_DEVICE(enum_name, enum_value) \ + struct enum_name { \ + static constexpr auto name = detail::const_name(#enum_name); \ + static constexpr int32_t value = enum_value; \ + static constexpr bool is_device = true; \ + } +NB_DEVICE(none, 0); NB_DEVICE(cpu, 1); NB_DEVICE(cuda, 2); +NB_DEVICE(cuda_host, 3); NB_DEVICE(opencl, 4); NB_DEVICE(vulkan, 7); +NB_DEVICE(metal, 8); NB_DEVICE(rocm, 10); NB_DEVICE(rocm_host, 11); +NB_DEVICE(cuda_managed, 13); NB_DEVICE(oneapi, 14); +#undef NB_DEVICE +NAMESPACE_END(device) + +NAMESPACE_BEGIN(dlpack) + +enum class dtype_code : uint8_t { + Int = 0, UInt = 1, Float = 2, Bfloat = 4, Complex = 5 +}; + +struct device { + int32_t device_type = 0; + int32_t device_id = 0; +}; + +struct dtype { + uint8_t code = 0; + uint8_t bits = 0; + uint16_t lanes = 0; + + bool operator==(const dtype &o) const { + return code == o.code && bits == o.bits && lanes == o.lanes; + } + bool operator!=(const dtype &o) const { return !operator==(o); } +}; + +struct tensor { + void *data = nullptr; + nanobind::dlpack::device device; + int32_t ndim = 0; + nanobind::dlpack::dtype dtype; + int64_t *shape = nullptr; + int64_t *strides = nullptr; + uint64_t byte_offset = 0; +}; + +NAMESPACE_END(dlpack) + +constexpr size_t any = (size_t) -1; + +template struct shape { + static constexpr size_t size = sizeof...(Is); +}; + +struct c_contig { }; +struct f_contig { }; +struct numpy { }; +struct tensorflow { }; +struct pytorch { }; +struct jax { }; + +template constexpr dlpack::dtype dtype() { + static_assert( + std::is_floating_point_v || std::is_integral_v, + "nanobind::dtype: T must be a floating point or integer variable!" + ); + + dlpack::dtype result; + + if constexpr (std::is_floating_point_v) + result.code = (uint8_t) dlpack::dtype_code::Float; + else if constexpr (std::is_signed_v) + result.code = (uint8_t) dlpack::dtype_code::Int; + else + result.code = (uint8_t) dlpack::dtype_code::UInt; + + result.bits = sizeof(T) * 8; + result.lanes = 1; + + return result; +} + + +NAMESPACE_BEGIN(detail) + +enum class tensor_framework : int { none, numpy, tensorflow, pytorch, jax }; + +struct tensor_req { + dlpack::dtype dtype; + uint32_t ndim = 0; + size_t *shape = nullptr; + bool req_shape = false; + bool req_dtype = false; + char req_order = '\0'; + uint8_t req_device = 0; +}; + +template struct tensor_arg { + static constexpr size_t size = 0; + static constexpr auto name = descr<0>{ }; + static void apply(tensor_req &) { } +}; + +template struct tensor_arg>> { + static constexpr size_t size = 0; + + static constexpr auto name = + const_name("dtype=float") + const_name(); + + static void apply(tensor_req &tr) { + tr.dtype = dtype(); + tr.req_dtype = true; + } +}; + +template struct tensor_arg>> { + static constexpr size_t size = 0; + + static constexpr auto name = + const_name("dtype=") + const_name>("u", "") + + const_name("int") + const_name(); + + static void apply(tensor_req &tr) { + tr.dtype = dtype(); + tr.req_dtype = true; + } +}; + +template struct tensor_arg> { + static constexpr size_t size = sizeof...(Is); + static constexpr auto name = + const_name("shape=(") + + concat(const_name(const_name("*"), const_name())...) 
+ + const_name(")"); + + static void apply(tensor_req &tr) { + size_t i = 0; + ((tr.shape[i++] = Is), ...); + tr.ndim = (uint32_t) sizeof...(Is); + tr.req_shape = true; + } +}; + +template <> struct tensor_arg { + static constexpr size_t size = 0; + static constexpr auto name = const_name("order='C'"); + static void apply(tensor_req &tr) { tr.req_order = 'C'; } +}; + +template <> struct tensor_arg { + static constexpr size_t size = 0; + static constexpr auto name = const_name("order='F'"); + static void apply(tensor_req &tr) { tr.req_order = 'F'; } +}; + +template struct tensor_arg> { + static constexpr size_t size = 0; + static constexpr auto name = const_name("device='") + T::name + const_name('\''); + static void apply(tensor_req &tr) { tr.req_device = (uint8_t) T::value; } +}; + +template struct tensor_info { + using scalar_type = void; + using shape_type = void; + constexpr static auto name = const_name("tensor"); + constexpr static tensor_framework framework = tensor_framework::none; +}; + +template struct tensor_info : tensor_info { + using scalar_type = + std::conditional_t, T, + typename tensor_info::scalar_type>; +}; + +template struct tensor_info, Ts...> : tensor_info { + using shape_type = shape; +}; + +template struct tensor_info : tensor_info { + constexpr static auto name = const_name("numpy.ndarray"); + constexpr static tensor_framework framework = tensor_framework::numpy; +}; + +template struct tensor_info : tensor_info { + constexpr static auto name = const_name("torch.Tensor"); + constexpr static tensor_framework framework = tensor_framework::pytorch; +}; + +template struct tensor_info : tensor_info { + constexpr static auto name = const_name("tensorflow.python.framework.ops.EagerTensor"); + constexpr static tensor_framework framework = tensor_framework::tensorflow; +}; + +template struct tensor_info : tensor_info { + constexpr static auto name = const_name("jaxlib.xla_extension.DeviceArray"); + constexpr static tensor_framework framework = tensor_framework::jax; +}; + +NAMESPACE_END(detail) + +template class tensor { +public: + using Info = detail::tensor_info; + using Scalar = typename Info::scalar_type; + + tensor() = default; + + explicit tensor(detail::tensor_handle *handle) : m_handle(handle) { + if (handle) + m_tensor = *detail::tensor_inc_ref(handle); + } + + tensor(void *value, + size_t ndim, + const size_t *shape, + handle owner = nanobind::handle(), + const int64_t *strides = nullptr, + dlpack::dtype dtype = nanobind::dtype(), + int32_t device_type = device::cpu::value, + int32_t device_id = 0) { + m_handle = + detail::tensor_create(value, ndim, shape, owner.ptr(), strides, + &dtype, device_type, device_id); + m_tensor = *detail::tensor_inc_ref(m_handle); + } + + ~tensor() { + detail::tensor_dec_ref(m_handle); + } + + tensor(const tensor &t) : m_handle(t.m_handle), m_tensor(t.m_tensor) { + detail::tensor_inc_ref(m_handle); + } + + tensor(tensor &&t) noexcept : m_handle(t.m_handle), m_tensor(t.m_tensor) { + t.m_handle = nullptr; + t.m_tensor = dlpack::tensor(); + } + + tensor &operator=(tensor &&t) noexcept { + detail::tensor_dec_ref(m_handle); + m_handle = t.m_handle; + m_tensor = t.m_tensor; + t.m_handle = nullptr; + t.m_tensor = dlpack::tensor(); + return *this; + } + + tensor &operator=(const tensor &t) { + detail::tensor_inc_ref(t.m_handle); + detail::tensor_dec_ref(m_handle); + m_handle = t.m_handle; + m_tensor = t.m_tensor; + } + + dlpack::dtype dtype() const { return m_tensor.dtype; } + size_t ndim() const { return m_tensor.ndim; } + size_t shape(size_t 
i) const { return m_tensor.shape[i]; } + int64_t stride(size_t i) const { return m_tensor.strides[i]; } + bool is_valid() const { return m_handle != nullptr; } + int32_t device_type() const { return m_tensor.device.device_type; } + int32_t device_id() const { return m_tensor.device.device_id; } + detail::tensor_handle *handle() const { return m_handle; } + + const void *data() const { + return (const uint8_t *) m_tensor.data + m_tensor.byte_offset; + } + void *data() { return (uint8_t *) m_tensor.data + m_tensor.byte_offset; } + + template + NB_INLINE auto& operator()(Ts... indices) { + static_assert( + !std::is_same_v, + "To use nb::tensor::operator(), you must add a scalar type " + "annotation (e.g. 'float') to the tensor template parameters."); + static_assert( + !std::is_same_v, + "To use nb::tensor::operator(), you must add a nb::shape<> " + "annotation to the tensor template parameters."); + static_assert(sizeof...(Ts) == Info::shape_type::size, + "nb::tensor::operator(): invalid number of arguments"); + + int64_t counter = 0, index = 0; + ((index += int64_t(indices) * m_tensor.strides[counter++]), ...); + return (typename Info::scalar_type &) *( + (uint8_t *) m_tensor.data + m_tensor.byte_offset + + index * sizeof(typename Info::scalar_type)); + } + +private: + detail::tensor_handle *m_handle = nullptr; + dlpack::tensor m_tensor; +}; + +NAMESPACE_BEGIN(detail) + +template struct type_caster> { + NB_TYPE_CASTER(tensor, Value::Info::name + const_name("[") + + concat_maybe(detail::tensor_arg::name...) + + const_name("]")); + + bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept { + constexpr size_t size = (0 + ... + detail::tensor_arg::size); + size_t shape[size + 1]; + detail::tensor_req req; + req.shape = shape; + (detail::tensor_arg::apply(req), ...); + value = tensor(tensor_import( + src.ptr(), &req, flags & (uint8_t) cast_flags::convert)); + return value.is_valid(); + } + + static handle from_cpp(const tensor &tensor, rv_policy, + cleanup_list *) noexcept { + return tensor_wrap(tensor.handle(), int(Value::Info::framework)); + } +}; + +NAMESPACE_END(detail) +NAMESPACE_END(NB_NAMESPACE) diff --git a/src/common.cpp b/src/common.cpp index ccb07807..daf192c2 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -8,15 +8,25 @@ */ #include -#include -#include #include "internals.h" NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) _Py_static_string(id_stdout, "stdout"); -_Py_static_string(id_dlpack, "__dlpack__"); + +#if PY_VERSION_HEX < 0x03090000 +PyObject *nb_vectorcall_method(PyObject *name, PyObject *const *args, + size_t nargsf, PyObject *kwnames) { + PyObject *obj = PyObject_GetAttr(args[0], name); + if (!obj) + return obj; + PyObject *result = NB_VECTORCALL(obj, args + 1, nargsf - 1, kwnames); + Py_DECREF(obj); + return result; +} +#endif + #if defined(__GNUC__) __attribute__((noreturn, __format__ (__printf__, 1, 2))) @@ -34,19 +44,13 @@ void raise(const char *fmt, ...) 
{ if (size < sizeof(buf)) throw std::runtime_error(buf); - char *ptr = (char *) malloc(size + 1); - if (!ptr) { - fprintf(stderr, "nb::detail::raise(): out of memory!"); - abort(); - } + scoped_pymalloc temp(size + 1); va_start(args, fmt); - vsnprintf(ptr, size + 1, fmt, args); + vsnprintf(temp.get(), size + 1, fmt, args); va_end(args); - std::runtime_error err(ptr); - free(ptr); - throw err; + throw std::runtime_error(temp.get()); } /// Abort the process with a fatal error @@ -209,18 +213,6 @@ PyObject *obj_op_2(PyObject *a, PyObject *b, return res; } -#if PY_VERSION_HEX < 0x03090000 -static PyObject *nb_vectorcall_method(PyObject *name, PyObject *const *args, - size_t nargsf, PyObject *kwnames) { - PyObject *obj = PyObject_GetAttr(args[0], name); - if (!obj) - return obj; - PyObject *result = NB_VECTORCALL(obj, args + 1, nargsf - 1, kwnames); - Py_DECREF(obj); - return result; -} -#endif - PyObject *obj_vectorcall(PyObject *base, PyObject *const *args, size_t nargsf, PyObject *kwnames, bool method_call) { const char *error = nullptr; @@ -540,160 +532,6 @@ void tuple_check(PyObject *tuple, size_t nargs) { } } -// ======================================================================== - -struct ManagedTensor { - dlpack::Tensor dl_tensor; - void *manager_ctx; - void (*deleter)(ManagedTensor *); -}; - -struct TensorHandle { - ManagedTensor *tensor; - std::atomic refcount; - bool free_strides; -}; - -TensorHandle *tensor_create(PyObject *o, const TensorReq *req) noexcept { - PyObject *temp = nullptr; - - // If this is not a capsule, try calling o.__dlpack__() - if (!PyCapsule_CheckExact(o)) { - o = temp = PyObject_CallMethodNoArgs(o, _PyUnicode_FromId(&id_dlpack)); - - if (!o) { - PyErr_Clear(); - return nullptr; - } - } - - // Extract the pointer underlying the capsule - void *ptr = PyCapsule_GetPointer(o, "dltensor"); - if (!ptr) { - PyErr_Clear(); - Py_XDECREF(temp); - return nullptr; - } - - // Check if the tensor satisfies the requirements - bool valid = true; - dlpack::Tensor &t = ((ManagedTensor *) ptr)->dl_tensor; - int64_t *strides = (int64_t *) PyMem_Malloc(sizeof(int64_t) * (size_t) t.ndim); - - if (!strides) { - Py_XDECREF(temp); - return nullptr; - } - - if (req->req_dtype) - valid &= t.dtype == req->dtype; - - if (req->req_shape) { - valid &= req->ndim == (uint32_t) t.ndim; - - if (valid) { - for (uint32_t i = 0; i < req->ndim; ++i) { - if (req->shape[i] != (size_t) t.shape[i] && - req->shape[i] != nanobind::any) { - valid = false; - break; - } - } - } - } - - if ((req->req_order || t.strides == nullptr) && t.ndim > 0) { - size_t accum = 1; - - if (req->req_order == 'C' || t.strides == nullptr) { - for (uint32_t i = (uint32_t) (t.ndim - 1);;) { - strides[i] = accum; - accum *= t.shape[i]; - if (i == 0) - break; - --i; - } - } else if (req->req_order == 'F') { - valid &= t.strides != nullptr; - - for (uint32_t i = 0; i < (uint32_t) t.ndim; ++i) { - strides[i] = accum; - accum *= t.shape[i]; - } - } else { - valid = false; - } - - if (t.strides) { - for (uint32_t i = 0; i < (uint32_t) t.ndim; ++i) { - if (strides[i] != t.strides[i]) { - valid = false; - break; - } - } - } - } - - if (!valid) { - PyMem_Free(strides); - Py_XDECREF(temp); - return nullptr; - } - - // Create a reference-counted wrapper - TensorHandle *result = (TensorHandle *) PyMem_Malloc(sizeof(TensorHandle)); - if (!result) { - Py_XDECREF(temp); - PyMem_Free(strides); - return nullptr; - } - - result->tensor = (ManagedTensor *) ptr; - result->refcount = 0; - - // Ensure that the strides member is always 
initialized - if (t.strides) { - result->free_strides = false; - PyMem_Free(strides); - } else { - result->free_strides = true; - t.strides = strides; - } - - // Mark the dltensor capsule as "consumed" - if (PyCapsule_SetName(o, "used_dltensor") || - PyCapsule_SetDestructor(o, nullptr)) - fail("nanobind::detail::tensor_create(): could not mark dltensor " - "capsule as consumed!"); - - Py_XDECREF(temp); - - return result; -} - - -dlpack::Tensor *tensor_inc_ref(TensorHandle *th) noexcept { - if (!th) - return nullptr; - ++th->refcount; - return &th->tensor->dl_tensor; -} - -void tensor_dec_ref(TensorHandle *th) noexcept { - if (!th) - return; - if (--th->refcount == 0) { - if (th->free_strides) { - PyMem_Free(th->tensor->dl_tensor.strides); - th->tensor->dl_tensor.strides = nullptr; - } - if (th->tensor->deleter) - th->tensor->deleter(th->tensor); - PyMem_Free(th); - } -} - - // ======================================================================== void print(PyObject *value, PyObject *end, PyObject *file) { diff --git a/src/internals.cpp b/src/internals.cpp index 84528818..bc4a26d5 100644 --- a/src/internals.cpp +++ b/src/internals.cpp @@ -173,6 +173,7 @@ static void internals_make() { nb_enum_type.ob_base.ob_base.ob_refcnt = 1; nb_func_type.ob_base.ob_base.ob_refcnt = 1; nb_meth_type.ob_base.ob_base.ob_refcnt = 1; + nb_tensor_type.ob_base.ob_base.ob_refcnt = 1; nb_static_property_type.ob_base.ob_base.ob_refcnt = 1; nb_type_type.tp_base = &PyType_Type; @@ -187,7 +188,8 @@ static void internals_make() { if (PyType_Ready(&nb_type_type) < 0 || PyType_Ready(&nb_func_type) < 0 || PyType_Ready(&nb_meth_type) < 0 || PyType_Ready(&nb_enum_type) < 0 || - PyType_Ready(&nb_static_property_type)) + PyType_Ready(&nb_static_property_type) < 0 || + PyType_Ready(&nb_tensor_type) < 0) fail("nanobind::detail::internals_make(): type initialization failed!"); if ((nb_type_type.tp_flags & Py_TPFLAGS_HEAPTYPE) != 0 || @@ -207,6 +209,7 @@ static void internals_make() { internals_p->nb_func = &nb_func_type; internals_p->nb_meth = &nb_meth_type; internals_p->nb_enum = &nb_enum_type; + internals_p->nb_tensor = &nb_tensor_type; internals_p->nb_static_property = &nb_static_property_type; } diff --git a/src/internals.h b/src/internals.h index f2ef40ea..34b74e02 100644 --- a/src/internals.h +++ b/src/internals.h @@ -26,6 +26,7 @@ extern PyTypeObject nb_enum_type; extern PyTypeObject nb_func_type; extern PyTypeObject nb_meth_type; +extern PyTypeObject nb_tensor_type; /// Nanobind function metadata (overloads, etc.) struct func_record : func_data<0> { @@ -172,6 +173,9 @@ struct internals { /// Property variant for static attributes PyTypeObject *nb_static_property; + /// Tensor wrpaper + PyTypeObject *nb_tensor; + /// Instance pointer -> Python object mapping py_map, nb_inst *, ptr_type_hash> inst_c2p; @@ -215,6 +219,29 @@ inline void *inst_ptr(nb_inst *self) { return self->direct ? 
ptr : *(void **) ptr; } +template struct scoped_pymalloc { + scoped_pymalloc(size_t size = 1) { + ptr = (T *) PyMem_Malloc(size * sizeof(T)); + if (!ptr) + fail("scoped_pymalloc(): could not allocate %zu bytes of memory!", size); + } + ~scoped_pymalloc() { PyMem_Free(ptr); } + T *release() { + T *temp = ptr; + ptr = nullptr; + return temp; + } + T *get() const { return ptr; } + T &operator[](size_t i) { return ptr[i]; } + T *operator->() { return ptr; } +private: + T *ptr{ nullptr }; +}; + +#if PY_VERSION_HEX < 0x03090000 +extern PyObject *nb_vectorcall_method(PyObject *name, PyObject *const *args, + size_t nargsf, PyObject *kwnames); +#endif NAMESPACE_END(detail) NAMESPACE_END(NB_NAMESPACE) diff --git a/src/nb_func.cpp b/src/nb_func.cpp index 45797287..a56e4fa3 100644 --- a/src/nb_func.cpp +++ b/src/nb_func.cpp @@ -796,7 +796,8 @@ static void nb_func_render_signature(const func_record *f) noexcept { continue; } else { buf.put("arg"); - if (arg_index > is_method || f->nargs > 1 + (uint32_t) is_method) + if (arg_index > size_t(is_method) || + f->nargs > 1 + (uint32_t) is_method) buf.put_uint32((uint32_t) (arg_index - is_method)); } diff --git a/src/tensor.cpp b/src/tensor.cpp new file mode 100644 index 00000000..ed34d17a --- /dev/null +++ b/src/tensor.cpp @@ -0,0 +1,630 @@ +#include +#include +#include "internals.h" + +NAMESPACE_BEGIN(NB_NAMESPACE) +NAMESPACE_BEGIN(detail) + +// ======================================================================== + +_Py_static_string(id_dlpack, "__dlpack__"); + + +struct managed_tensor { + dlpack::tensor dl_tensor; + void *manager_ctx; + void (*deleter)(managed_tensor *); +}; + +struct tensor_handle { + managed_tensor *tensor; + std::atomic refcount; + PyObject *owner; + bool free_shape; + bool free_strides; + bool call_deleter; +}; + +struct nb_tensor { + PyObject_HEAD + PyObject *capsule; +}; + +static PyObject *nb_tensor_new(PyTypeObject *subtype, PyObject *args, + PyObject *kwargs) { + PyObject* result = subtype->tp_alloc(subtype,0); + if (PyTuple_GET_SIZE(args) != 1 || kwargs) + fail("nanobind::detail::nb_tensor_new(): internal error!"); + + PyObject *capsule = PyTuple_GET_ITEM(args, 0); + ((nb_tensor *) result)->capsule = capsule; + Py_INCREF(capsule); + return result; +} + +static void nb_tensor_dealloc(PyObject *self) { + Py_DECREF(((nb_tensor *) self)->capsule); + Py_TYPE(self)->tp_free(self); +} + +static PyObject *nb_tensor_get(PyObject *self, PyObject *) { + PyObject *result = ((nb_tensor *) self)->capsule; + Py_INCREF(result); + return result; +} + +int nb_tensor_getbuffer(PyObject *exporter, Py_buffer *view, int) { + nb_tensor *self = (nb_tensor *) exporter; + + void *ptr = PyCapsule_GetPointer(self->capsule, "dltensor"); + if (!ptr) + fail("nanobind::tensor::nb_tensor_getbuffer(): internal error!"); + + dlpack::tensor &t = ((managed_tensor *) ptr)->dl_tensor; + + if (t.device.device_type != device::cpu::value) { + PyErr_SetString(PyExc_BufferError, "Only CPU-allocated tensors can be " + "accessed via the buffer protocol!"); + return -1; + } + + const char *format = nullptr; + switch ((dlpack::dtype_code) t.dtype.code) { + case dlpack::dtype_code::Int: + switch (t.dtype.bits) { + case 8: format = "b"; break; + case 16: format = "h"; break; + case 32: format = "i"; break; + case 64: format = "q"; break; + } + break; + + case dlpack::dtype_code::UInt: + switch (t.dtype.bits) { + case 8: format = "B"; break; + case 16: format = "H"; break; + case 32: format = "I"; break; + case 64: format = "Q"; break; + } + break; + + case 
dlpack::dtype_code::Float: + switch (t.dtype.bits) { + case 16: format = "e"; break; + case 32: format = "f"; break; + case 64: format = "d"; break; + } + break; + + default: + break; + } + + if (!format || t.dtype.lanes != 1) { + PyErr_SetString( + PyExc_BufferError, + "Don't know how to convert DLPack dtype into buffer protocol format!"); + return -1; + } + + view->format = (char *) format; + view->itemsize = t.dtype.bits / 8; + view->buf = (void *) ((uintptr_t) t.data + t.byte_offset); + view->obj = exporter; + Py_INCREF(exporter); + + Py_ssize_t len = view->itemsize; + scoped_pymalloc strides(t.ndim), shape(t.ndim); + for (int32_t i = 0; i < t.ndim; ++i) { + len *= (Py_ssize_t) t.shape[i]; + strides[i] = (Py_ssize_t) t.strides[i] * view->itemsize; + shape[i] = (Py_ssize_t) t.shape[i]; + } + + view->ndim = t.ndim; + view->len = len; + view->readonly = false; + view->suboffsets = nullptr; + view->internal = nullptr; + view->strides = strides.release(); + view->shape = shape.release(); + + return 0; +} + +void nb_tensor_releasebuffer(PyObject *, Py_buffer *view) { + PyMem_Free(view->shape); + PyMem_Free(view->strides); +} + +static PyMethodDef nb_tensor_methods[] = { + { "__dlpack__", (PyCFunction) nb_tensor_get, METH_NOARGS, nullptr }, + { nullptr, nullptr, 0, nullptr} +}; + +static PyBufferProcs nb_tensor_as_buffer = { + .bf_getbuffer = nb_tensor_getbuffer, + .bf_releasebuffer = nb_tensor_releasebuffer +}; + +PyTypeObject nb_tensor_type = { + .tp_name = "nb_tensor", + .tp_basicsize = sizeof(nb_tensor), + .tp_dealloc = nb_tensor_dealloc, + .tp_as_buffer = &nb_tensor_as_buffer, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_methods = nb_tensor_methods, + .tp_new = nb_tensor_new +}; + +static PyObject *dlpack_from_buffer_protocol(PyObject *o) { + scoped_pymalloc view; + scoped_pymalloc mt; + + if (PyObject_GetBuffer(o, view.get(), PyBUF_RECORDS)) { + PyErr_Clear(); + return nullptr; + } + + char format = 'B'; + const char *format_str = view->format; + if (format_str) + format = *format_str; + + bool skip_first = format == '@' || format == '='; + + int32_t num = 1; + if(*(uint8_t *) &num == 1) { + if (format == '<') + skip_first = true; + } else { + if (format == '!' || format == '>') + skip_first = true; + } + + if (skip_first && format_str) + format = *++format_str; + + dlpack::dtype dt { }; + bool fail = format_str && format_str[1] != '\0'; + + if (!fail) { + switch (format) { + case 'c': + case 'b': + case 'h': + case 'i': + case 'l': + case 'q': + case 'n': dt.code = (uint8_t) dlpack::dtype_code::Int; break; + + case 'B': + case 'H': + case 'I': + case 'L': + case 'Q': + case 'N': dt.code = (uint8_t) dlpack::dtype_code::UInt; break; + + case 'e': + case 'f': + case 'd': dt.code = (uint8_t) dlpack::dtype_code::Float; break; + + default: + fail = true; + } + dt.lanes = 1; + dt.bits = (uint8_t) (view->itemsize * 8); + } + + if (fail) { + PyBuffer_Release(view.get()); + return nullptr; + } + + mt->deleter = [](managed_tensor *mt2) { + gil_scoped_acquire guard; + Py_buffer *buf = (Py_buffer *) mt2->manager_ctx; + PyBuffer_Release(buf); + PyMem_Free(mt2->dl_tensor.shape); + PyMem_Free(mt2->dl_tensor.strides); + PyMem_Free(mt2); + }; + + /* DLPack mandates 256-byte alignment of the 'DLTensor::data' field, but + PyTorch unfortunately ignores the 'byte_offset' value.. 
:-( */ +#if 0 + uintptr_t value_int = (uintptr_t) view->buf, + value_rounded = (value_int / 256) * 256; +#else + uintptr_t value_int = (uintptr_t) view->buf, + value_rounded = value_int; +#endif + + mt->dl_tensor.data = (void *) value_rounded; + mt->dl_tensor.device = { device::cpu::value, 0 }; + mt->dl_tensor.ndim = view->ndim; + mt->dl_tensor.dtype = dt; + mt->dl_tensor.byte_offset = value_int - value_rounded; + + scoped_pymalloc strides(view->ndim); + scoped_pymalloc shape(view->ndim); + for (size_t i = 0; i < (size_t) view->ndim; ++i) { + strides[i] = (int64_t) (view->strides[i] / view->itemsize); + shape[i] = (int64_t) view->shape[i]; + } + + mt->manager_ctx = view.release(); + mt->dl_tensor.shape = shape.release(); + mt->dl_tensor.strides = strides.release(); + + return PyCapsule_New(mt.release(), "dltensor", [](PyObject *o) { + error_scope scope; // temporarily save any existing errors + managed_tensor *mt = + (managed_tensor *) PyCapsule_GetPointer(o, "dltensor"); + if (mt) { + if (mt->deleter) + mt->deleter(mt); + } else { + PyErr_Clear(); + } + }); +} + +tensor_handle *tensor_import(PyObject *o, const tensor_req *req, + bool convert) noexcept { + object capsule; + + // If this is not a capsule, try calling o.__dlpack__() + if (!PyCapsule_CheckExact(o)) { + PyObject *args[1] = { o }, *name = _PyUnicode_FromId(&id_dlpack); + capsule = steal(NB_VECTORCALL_METHOD( + name, args, 1 | PY_VECTORCALL_ARGUMENTS_OFFSET, nullptr)); + + if (!capsule.is_valid()) { + PyErr_Clear(); + PyTypeObject *tp = Py_TYPE(o); + + try { + const char *module_name = + borrow(handle(tp).attr("__module__")).c_str(); + + object package; + if (strncmp(module_name, "tensorflow.", 11) == 0) + package = module_::import_("tensorflow.experimental.dlpack"); + else if (strcmp(module_name, "torch") == 0) + package = module_::import_("torch.utils.dlpack"); + else if (strncmp(module_name, "jaxlib", 6) == 0) + package = module_::import_("jax.dlpack"); + + if (package.is_valid()) + capsule = package.attr("to_dlpack")(handle(o)); + } catch (...) 
{ + capsule.clear(); + } + } + + // Try creating a tensor via the buffer protocol + if (!capsule.is_valid()) + capsule = steal(dlpack_from_buffer_protocol(o)); + + if (!capsule.is_valid()) + return nullptr; + } else { + capsule = borrow(o); + } + + // Extract the pointer underlying the capsule + void *ptr = PyCapsule_GetPointer(capsule.ptr(), "dltensor"); + if (!ptr) { + PyErr_Clear(); + return nullptr; + } + + // Check if the tensor satisfies the requirements + dlpack::tensor &t = ((managed_tensor *) ptr)->dl_tensor; + + bool pass_dtype = true, pass_device = true, + pass_shape = true, pass_order = true; + + if (req->req_dtype) + pass_dtype = t.dtype == req->dtype; + + if (req->req_device) + pass_device = t.device.device_type == req->req_device; + + if (req->req_shape) { + pass_shape &= req->ndim == (uint32_t) t.ndim; + + if (pass_shape) { + for (uint32_t i = 0; i < req->ndim; ++i) { + if (req->shape[i] != (size_t) t.shape[i] && + req->shape[i] != nanobind::any) { + pass_shape = false; + break; + } + } + } + } + + scoped_pymalloc strides(t.ndim); + if ((req->req_order || t.strides == nullptr) && t.ndim > 0) { + size_t accum = 1; + + if (req->req_order == 'C' || t.strides == nullptr) { + for (uint32_t i = (uint32_t) (t.ndim - 1);;) { + strides[i] = accum; + accum *= t.shape[i]; + if (i == 0) + break; + --i; + } + } else if (req->req_order == 'F') { + pass_order &= t.strides != nullptr; + + for (uint32_t i = 0; i < (uint32_t) t.ndim; ++i) { + strides[i] = accum; + accum *= t.shape[i]; + } + } else { + pass_order = false; + } + + if (t.strides) { + for (uint32_t i = 0; i < (uint32_t) t.ndim; ++i) { + if (strides[i] != t.strides[i]) { + pass_order = false; + break; + } + } + } + } + + // Support implicit conversion of 'dtype' and order + if (pass_device && pass_shape && (!pass_dtype || !pass_order) && convert && + capsule.ptr() != o) { + PyTypeObject *tp = Py_TYPE(o); + const char *module_name = + borrow(handle(tp).attr("__module__")).c_str(); + + char order = 'K'; + if (req->req_order != '\0') + order = req->req_order; + + if (req->dtype.lanes != 1) + return nullptr; + + const char *prefix = nullptr; + char dtype[8]; + switch (req->dtype.code) { + case (uint8_t) dlpack::dtype_code::Int: prefix = "int"; break; + case (uint8_t) dlpack::dtype_code::UInt: prefix = "uint"; break; + case (uint8_t) dlpack::dtype_code::Float: prefix = "float"; break; + default: + return nullptr; + } + snprintf(dtype, sizeof(dtype), "%s%u", prefix, req->dtype.bits); + + object converted; + try { + if (strcmp(tp->tp_name, "numpy.ndarray") == 0) { + converted = handle(o).attr("astype")( + dtype, + order + ); + } + + if (strcmp(module_name, "torch") == 0) { + converted = handle(o).attr("to")( + arg("dtype") = module_::import_("torch").attr(dtype), + arg("copy") = true + ); + } else if (strncmp(module_name, "tensorflow.", 11) == 0) { + converted = module_::import_("tensorflow") + .attr("cast")(handle(o), dtype); + } else if (strncmp(module_name, "jaxlib", 6) == 0) { + converted = handle(o).attr("astype")(dtype); + } + } catch (...) 
+        // Potentially try again recursively
+        if (!converted.is_valid())
+            return nullptr;
+        else
+            return tensor_import(converted.ptr(), req, false);
+    }
+
+    if (!pass_dtype || !pass_device || !pass_shape || !pass_order)
+        return nullptr;
+
+    // Create a reference-counted wrapper
+    scoped_pymalloc<tensor_handle> result;
+    result->tensor = (managed_tensor *) ptr;
+    result->refcount = 0;
+    result->owner = nullptr;
+    result->free_shape = false;
+    result->call_deleter = true;
+
+    // Ensure that the strides member is always initialized
+    if (t.strides) {
+        result->free_strides = false;
+    } else {
+        result->free_strides = true;
+        t.strides = strides.release();
+    }
+
+    // Mark the dltensor capsule as "consumed"
+    if (PyCapsule_SetName(capsule.ptr(), "used_dltensor") ||
+        PyCapsule_SetDestructor(capsule.ptr(), nullptr))
+        fail("nanobind::detail::tensor_import(): could not mark dltensor "
+             "capsule as consumed!");
+
+    return result.release();
+}
+
+dlpack::tensor *tensor_inc_ref(tensor_handle *th) noexcept {
+    if (!th)
+        return nullptr;
+    ++th->refcount;
+    return &th->tensor->dl_tensor;
+}
+
+void tensor_dec_ref(tensor_handle *th) noexcept {
+    if (!th)
+        return;
+    size_t rc_value = th->refcount--;
+
+    if (rc_value == 0) {
+        fail("tensor_dec_ref(): reference count became negative!");
+    } else if (rc_value == 1) {
+        Py_XDECREF(th->owner);
+        managed_tensor *mt = th->tensor;
+        if (th->free_shape) {
+            PyMem_Free(mt->dl_tensor.shape);
+            mt->dl_tensor.shape = nullptr;
+        }
+        if (th->free_strides) {
+            PyMem_Free(mt->dl_tensor.strides);
+            mt->dl_tensor.strides = nullptr;
+        }
+        if (th->call_deleter) {
+            if (mt->deleter)
+                mt->deleter(mt);
+        } else {
+            PyMem_Free(mt);
+        }
+        PyMem_Free(th);
+    }
+}
+
+tensor_handle *tensor_create(void *value, size_t ndim, const size_t *shape_in,
+                             PyObject *owner, const int64_t *strides_in,
+                             dlpack::dtype *dtype, int32_t device_type,
+                             int32_t device_id) {
+    /* DLPack mandates 256-byte alignment of the 'DLTensor::data' field, but
+       PyTorch unfortunately ignores the 'byte_offset' value.. :-( */
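+    /* The disabled branch below implements that alignment scheme (round the
+       data pointer down to a multiple of 256 bytes and record the remainder
+       in 'byte_offset'); it is kept for reference until consumers handle
+       'byte_offset' reliably. */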
+#if 0
+    uintptr_t value_int = (uintptr_t) value,
+              value_rounded = (value_int / 256) * 256;
+#else
+    uintptr_t value_int = (uintptr_t) value,
+              value_rounded = value_int;
+#endif
+
+    scoped_pymalloc<managed_tensor> tensor;
+    scoped_pymalloc<tensor_handle> result;
+    scoped_pymalloc<int64_t> shape(ndim), strides(ndim);
+
+    auto deleter = [](managed_tensor *mt) {
+        gil_scoped_acquire guard;
+        tensor_handle *th = (tensor_handle *) mt->manager_ctx;
+        tensor_dec_ref(th);
+    };
+
+    for (size_t i = 0; i < ndim; ++i)
+        shape[i] = (int64_t) shape_in[i];
+
+    int64_t prod = 1;
+    for (size_t i = ndim - 1; ;) {
+        if (strides_in) {
+            strides[i] = strides_in[i];
+        } else {
+            strides[i] = prod;
+            prod *= (int64_t) shape_in[i];
+        }
+        if (i == 0)
+            break;
+        --i;
+    }
+
+    tensor->dl_tensor.data = (void *) value_rounded;
+    tensor->dl_tensor.device.device_type = device_type;
+    tensor->dl_tensor.device.device_id = device_id;
+    tensor->dl_tensor.ndim = (int32_t) ndim;
+    tensor->dl_tensor.dtype = *dtype;
+    tensor->dl_tensor.byte_offset = value_int - value_rounded;
+    tensor->dl_tensor.shape = shape.release();
+    tensor->dl_tensor.strides = strides.release();
+    tensor->manager_ctx = result.get();
+    tensor->deleter = deleter;
+    result->tensor = (managed_tensor *) tensor.release();
+    result->refcount = 0;
+    result->owner = owner;
+    result->free_shape = true;
+    result->free_strides = true;
+    result->call_deleter = false;
+    Py_XINCREF(owner);
+    return result.release();
+}
+
+static void tensor_capsule_destructor(PyObject *o) {
+    error_scope scope; // temporarily save any existing errors
+    managed_tensor *mt =
+        (managed_tensor *) PyCapsule_GetPointer(o, "dltensor");
+    if (mt)
+        tensor_dec_ref((tensor_handle *) mt->manager_ctx);
+    else
+        PyErr_Clear();
+}
+
+PyObject *tensor_wrap(tensor_handle *th, int framework) noexcept {
+    tensor_inc_ref(th);
+    object o = steal(PyCapsule_New(th->tensor, "dltensor", tensor_capsule_destructor)),
+           package;
+
+    switch ((tensor_framework) framework) {
+        case tensor_framework::none:
+            break;
+
+        case tensor_framework::numpy:
+            package = module_::import_("numpy");
+            o = handle(&nb_tensor_type)(o);
+            break;
+
+        case tensor_framework::pytorch:
+            package = module_::import_("torch.utils.dlpack");
+            break;
+
+        case tensor_framework::tensorflow:
+            package = module_::import_("tensorflow.experimental.dlpack");
+            break;
+
+        case tensor_framework::jax:
+            package = module_::import_("jax.dlpack");
+            break;
+
+        default:
+            fail("nanobind::detail::tensor_wrap(): unknown framework "
+                 "specified!");
+    }
+
+    if (package.is_valid()) {
+        try {
+            o = package.attr("from_dlpack")(o);
+        } catch (...) {
+            if ((tensor_framework) framework == tensor_framework::numpy) {
+                try {
+                    // Older numpy versions
+                    o = package.attr("_from_dlpack")(o);
+                } catch (...) {
+                    try {
+                        // Yet older numpy versions
+                        o = package.attr("asarray")(o);
+                    } catch (...) {
+                        return nullptr;
+                    }
+                }
+            } else {
+                return nullptr;
+            }
+        }
+    }
+
+    return o.release().ptr();
+}
+
+NAMESPACE_END(detail)
+NAMESPACE_END(NB_NAMESPACE)
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index abbb3e38..c93a7aa1 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -3,7 +3,7 @@ nanobind_add_module(test_classes_ext test_classes.cpp)
 nanobind_add_module(test_holders_ext test_holders.cpp)
 nanobind_add_module(test_stl_ext test_stl.cpp)
 nanobind_add_module(test_enum_ext test_enum.cpp)
-nanobind_add_module(test_dlpack_ext test_dlpack.cpp)
+nanobind_add_module(test_tensor_ext test_tensor.cpp)
 
 set(TEST_FILES
   test_functions.py
@@ -11,7 +11,7 @@ set(TEST_FILES
   test_holders.py
   test_stl.py
   test_enum.py
-  test_dlpack.py
+  test_tensor.py
 )
diff --git a/tests/test_classes.cpp b/tests/test_classes.cpp
index 1387e23e..ba3f6ebf 100644
--- a/tests/test_classes.cpp
+++ b/tests/test_classes.cpp
@@ -286,7 +286,7 @@ NB_MODULE(test_classes_ext, m) {
     scls.def(nb::init<>());
 
     Supplement &supplement = nb::type_supplement<Supplement>(scls);
-    for (uint16_t i = 0; i < 0xFF; ++i)
+    for (uint8_t i = 0; i < 0xFF; ++i)
         supplement.data[i] = i;
 
     m.def("check_supplement", [](nb::handle h) {
diff --git a/tests/test_classes.py b/tests/test_classes.py
index da4b87c4..0083e72b 100644
--- a/tests/test_classes.py
+++ b/tests/test_classes.py
@@ -40,7 +40,7 @@ def test02_static_overload():
 
 
 def test03_instantiate(clean):
-    s1 = t.Struct()
+    s1 : t.Struct = t.Struct()
     assert s1.value() == 5
     s2 = t.Struct(10)
     assert s2.value() == 10
diff --git a/tests/test_dlpack.cpp b/tests/test_dlpack.cpp
deleted file mode 100644
index 2af1e476..00000000
--- a/tests/test_dlpack.cpp
+++ /dev/null
@@ -1,30 +0,0 @@
-#include <nanobind/nanobind.h>
-#include <nanobind/dlpack.h>
-
-namespace nb = nanobind;
-
-NB_MODULE(test_dlpack_ext, m) {
-    m.def("get_shape", [](const nb::tensor<> &t) {
-        nb::list l;
-        for (size_t i = 0; i < t.ndim(); ++i)
-            l.append(t.shape(i));
-        return l;
-    });
-
-    m.def("check_float", [](const nb::tensor<> &t) {
-        return t.dtype() == nb::dtype<float>();
-    });
-
-    m.def("pass_float32", [](const nb::tensor<float> &) { });
-    m.def("pass_uint32", [](const nb::tensor<uint32_t> &) { });
-    m.def("pass_float32_shaped",
-          [](const nb::tensor<float, nb::shape<3, nb::any, 4>> &) {});
-
-    m.def("pass_float32_shaped_ordered",
-          [](const nb::tensor<float, nb::order<'C'>,
-                              nb::shape<nb::any, nb::any, 4>> &) {});
-
-    m.def("check_order", [](nb::tensor<nb::order<'C'>>) -> char { return 'C'; });
-    m.def("check_order", [](nb::tensor<nb::order<'F'>>) -> char { return 'F'; });
-    m.def("check_order", [](nb::tensor<>) -> char { return '?'; });
-}
diff --git a/tests/test_dlpack.py b/tests/test_dlpack.py
deleted file mode 100644
index d39ba67b..00000000
--- a/tests/test_dlpack.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import test_dlpack_ext as t
-import numpy as np
-import pytest
-
-def test01_metadata():
-    a = np.zeros(shape=())
-    assert t.get_shape(a) == []
-    b = a.__dlpack__()
-    assert t.get_shape(b) == []
-
-    with pytest.raises(TypeError) as excinfo:
-        # Capsule can only be consumed once
-        assert t.get_shape(b) == []
-    assert 'incompatible function arguments' in str(excinfo.value)
-
-    a = np.zeros(shape=(3, 4, 5))
-    assert t.get_shape(a) == [3, 4, 5]
-    assert t.get_shape(a.__dlpack__()) == [3, 4, 5]
-    assert not t.check_float(np.array([1], dtype=np.uint32)) and \
-           t.check_float(np.array([1], dtype=np.float32))
-
-
-def test02_docstr():
-    assert t.get_shape.__doc__ == "get_shape(arg: tensor[], /) -> list"
-    assert t.pass_uint32.__doc__ == "pass_uint32(arg: tensor[dtype=uint32], /) -> None"
-    assert t.pass_float32.__doc__ == "pass_float32(arg: tensor[dtype=float32], /) -> None"
-    assert t.pass_float32_shaped.__doc__ == "pass_float32_shaped(arg: tensor[dtype=float32, shape=(3, *, 4)], /) -> None"
-    assert t.pass_float32_shaped_ordered.__doc__ == "pass_float32_shaped_ordered(arg: tensor[dtype=float32, order='C', shape=(*, *, 4)], /) -> None"
-    assert t.check_device.__doc__ == ("check_device(arg: tensor[device='cpu'], /) -> str\n"
-                                      "check_device(arg: tensor[device='cuda'], /) -> str")
-
-
-def test03_constrain_dtype():
-    a_u32 = np.array([1], dtype=np.uint32)
-    a_f32 = np.array([1], dtype=np.float32)
-
-    t.pass_uint32(a_u32)
-    t.pass_float32(a_f32)
-
-    with pytest.raises(TypeError) as excinfo:
-        t.pass_uint32(a_f32)
-    assert 'incompatible function arguments' in str(excinfo.value)
-
-    with pytest.raises(TypeError) as excinfo:
-        t.pass_float32(a_u32)
-    assert 'incompatible function arguments' in str(excinfo.value)
-
-def test04_constrain_shape():
-    t.pass_float32_shaped(np.zeros((3, 0, 4), dtype=np.float32))
-    t.pass_float32_shaped(np.zeros((3, 5, 4), dtype=np.float32))
-
-    with pytest.raises(TypeError) as excinfo:
-        t.pass_float32_shaped(np.zeros((3, 5), dtype=np.float32))
-
-    with pytest.raises(TypeError) as excinfo:
-        t.pass_float32_shaped(np.zeros((2, 5, 4), dtype=np.float32))
-
-    with pytest.raises(TypeError) as excinfo:
-        t.pass_float32_shaped(np.zeros((3, 5, 6), dtype=np.float32))
-
-    with pytest.raises(TypeError) as excinfo:
-        t.pass_float32_shaped(np.zeros((3, 5, 4, 6), dtype=np.float32))
-
-
-def test04_constrain_order():
-    assert t.check_order(np.zeros((3, 5, 4, 6), order='C')) == 'C'
-    assert t.check_order(np.zeros((3, 5, 4, 6), order='F')) == 'F'
-    assert t.check_order(np.zeros((3, 5, 4, 6), order='C')[:, 2, :, :]) == '?'
-    assert t.check_order(np.zeros((3, 5, 4, 6), order='F')[:, 2, :, :]) == '?'
diff --git a/tests/test_tensor.cpp b/tests/test_tensor.cpp
new file mode 100644
index 00000000..e0b878a7
--- /dev/null
+++ b/tests/test_tensor.cpp
@@ -0,0 +1,128 @@
+#include <nanobind/nanobind.h>
+#include <nanobind/tensor.h>
+#include <algorithm>
+
+namespace nb = nanobind;
+
+using namespace nb::literals;
+
+int destruct_count = 0;
+
+NB_MODULE(test_tensor_ext, m) {
+    m.def("get_shape", [](const nb::tensor<> &t) {
+        nb::list l;
+        for (size_t i = 0; i < t.ndim(); ++i)
+            l.append(t.shape(i));
+        return l;
+    }, "array"_a.noconvert());
+
+    m.def("check_float", [](const nb::tensor<> &t) {
+        return t.dtype() == nb::dtype<float>();
+    });
+
+    m.def("pass_float32", [](const nb::tensor<float> &) { }, "array"_a.noconvert());
+    m.def("pass_uint32", [](const nb::tensor<uint32_t> &) { }, "array"_a.noconvert());
+    m.def("pass_float32_shaped",
+          [](const nb::tensor<float, nb::shape<3, nb::any, 4>> &) {}, "array"_a.noconvert());
+
+    m.def("pass_float32_shaped_ordered",
+          [](const nb::tensor<float, nb::c_contig,
+                              nb::shape<nb::any, nb::any, 4>> &) {}, "array"_a.noconvert());
+
+    m.def("check_order", [](nb::tensor<nb::c_contig>) -> char { return 'C'; });
+    m.def("check_order", [](nb::tensor<nb::f_contig>) -> char { return 'F'; });
+    m.def("check_order", [](nb::tensor<>) -> char { return '?'; });
+
+    m.def("check_device", [](nb::tensor<nb::device::cpu>) -> const char * { return "cpu"; });
+    m.def("check_device", [](nb::tensor<nb::device::cuda>) -> const char * { return "cuda"; });
+
+    m.def("initialize",
+          [](nb::tensor<float, nb::shape<10>, nb::device::cpu> &t) {
+              for (size_t i = 0; i < 10; ++i)
+                  t(i) = (float) i;
+          });
+
+    m.def("initialize",
+          [](nb::tensor<float, nb::shape<10, nb::any>, nb::device::cpu> &t) {
+              int k = 0;
+              for (size_t i = 0; i < 10; ++i)
+                  for (size_t j = 0; j < t.shape(1); ++j)
+                      t(i, j) = (float) k++;
+          });
+
+    m.def(
+        "noimplicit",
+        [](nb::tensor<float, nb::c_contig, nb::shape<2, 2>>) { return 0; },
+        "array"_a.noconvert());
+
+    m.def(
+        "implicit",
+        [](nb::tensor<float, nb::c_contig, nb::shape<2, 2>>) { return 0; },
+        "array"_a);
+
+    m.def("inspect_tensor", [](nb::tensor<> tensor) {
+        printf("Tensor data pointer : %p\n", tensor.data());
+        printf("Tensor dimension : %zu\n", tensor.ndim());
+        for (size_t i = 0; i < tensor.ndim(); ++i) {
+            printf("Tensor dimension [%zu] : %zu\n", i, tensor.shape(i));
+            printf("Tensor stride [%zu] : %zu\n", i, (size_t) tensor.stride(i));
+        }
+        printf("Tensor is on CPU? %i\n", tensor.device_type() == nb::device::cpu::value);
+        printf("Device ID = %u\n", tensor.device_id());
+        printf("Tensor dtype check: int16=%i, uint32=%i, float32=%i\n",
+               tensor.dtype() == nb::dtype<int16_t>(),
+               tensor.dtype() == nb::dtype<uint32_t>(),
+               tensor.dtype() == nb::dtype<float>()
+        );
+    });
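+
+    // The shape/order/device annotations below restrict this binding to
+    // dense (C-contiguous) 8-bit images of shape (*, *, 3) in CPU memory.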
+    m.def("process", [](nb::tensor<uint8_t, nb::shape<nb::any, nb::any, 3>,
+                                   nb::c_contig, nb::device::cpu> tensor) {
+        // Double brightness of the MxNx3 RGB image
+        for (size_t y = 0; y < tensor.shape(0); ++y)
+            for (size_t x = 0; x < tensor.shape(1); ++x)
+                for (size_t ch = 0; ch < 3; ++ch)
+                    tensor(y, x, ch) = (uint8_t) std::min(255, tensor(y, x, ch) * 2);
+
+    });
+
+    m.def("destruct_count", []() { return destruct_count; });
+    m.def("return_dlpack", []() {
+        float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
+        size_t shape[2] = { 2, 4 };
+
+        nb::capsule deleter(f, [](void *data) {
+            destruct_count++;
+            delete[] (float *) data;
+        });
+
+        return nb::tensor<float, nb::shape<2, 4>>(f, 2, shape, deleter);
+    });
+    m.def("passthrough", [](nb::tensor<> a) { return a; });
+
+    m.def("ret_numpy", []() {
+        float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
+        size_t shape[2] = { 2, 4 };
+
+        nb::capsule deleter(f, [](void *data) {
+            destruct_count++;
+            delete[] (float *) data;
+        });
+
+        return nb::tensor<nb::numpy, float, nb::shape<2, 4>>(f, 2, shape,
+                                                             deleter);
+    });
+
+    m.def("ret_pytorch", []() {
+        float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
+        size_t shape[2] = { 2, 4 };
+
+        nb::capsule deleter(f, [](void *data) {
+            destruct_count++;
+            delete[] (float *) data;
+        });
+
+        return nb::tensor<nb::pytorch, float, nb::shape<2, 4>>(f, 2, shape,
+                                                               deleter);
+    });
+}
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
new file mode 100644
index 00000000..53c45e45
--- /dev/null
+++ b/tests/test_tensor.py
@@ -0,0 +1,293 @@
+import test_tensor_ext as t
+import numpy as np
+import pytest
+import warnings
+import gc
+
+def test01_metadata():
+    a = np.zeros(shape=())
+    assert t.get_shape(a) == []
+
+    if hasattr(a, '__dlpack__'):
+        b = a.__dlpack__()
+        assert t.get_shape(b) == []
+    else:
+        b = None
+
+    with pytest.raises(TypeError) as excinfo:
+        # Capsule can only be consumed once
+        assert t.get_shape(b) == []
+    assert 'incompatible function arguments' in str(excinfo.value)
+
+    a = np.zeros(shape=(3, 4, 5))
+    assert t.get_shape(a) == [3, 4, 5]
+    if hasattr(a, '__dlpack__'):
+        assert t.get_shape(a.__dlpack__()) == [3, 4, 5]
+    assert not t.check_float(np.array([1], dtype=np.uint32)) and \
+           t.check_float(np.array([1], dtype=np.float32))
+
+
+def test02_docstr():
+    assert t.get_shape.__doc__ == "get_shape(array: tensor[]) -> list"
+    assert t.pass_uint32.__doc__ == "pass_uint32(array: tensor[dtype=uint32]) -> None"
+    assert t.pass_float32.__doc__ == "pass_float32(array: tensor[dtype=float32]) -> None"
+    assert t.pass_float32_shaped.__doc__ == "pass_float32_shaped(array: tensor[dtype=float32, shape=(3, *, 4)]) -> None"
+    assert t.pass_float32_shaped_ordered.__doc__ == "pass_float32_shaped_ordered(array: tensor[dtype=float32, order='C', shape=(*, *, 4)]) -> None"
+    assert t.check_device.__doc__ == ("check_device(arg: tensor[device='cpu'], /) -> str\n"
+                                      "check_device(arg: tensor[device='cuda'], /) -> str")
+
+
+def test03_constrain_dtype():
+    a_u32 = np.array([1], dtype=np.uint32)
+    a_f32 = np.array([1], dtype=np.float32)
+
+    t.pass_uint32(a_u32)
+    t.pass_float32(a_f32)
+
+    with pytest.raises(TypeError) as excinfo:
+        t.pass_uint32(a_f32)
+    assert 'incompatible function arguments' in str(excinfo.value)
+
+    with pytest.raises(TypeError) as excinfo:
+        t.pass_float32(a_u32)
+    assert 'incompatible function arguments' in str(excinfo.value)
+
+
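+# The pass_float32_shaped binding requires dtype=float32 and shape=(3, *, 4);
+# only the dimension marked '*' (nb::any) may vary, everything else is rejected.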
+def test04_constrain_shape():
+    t.pass_float32_shaped(np.zeros((3, 0, 4), dtype=np.float32))
+    t.pass_float32_shaped(np.zeros((3, 5, 4), dtype=np.float32))
+
+    with pytest.raises(TypeError) as excinfo:
+        t.pass_float32_shaped(np.zeros((3, 5), dtype=np.float32))
+
+    with pytest.raises(TypeError) as excinfo:
+        t.pass_float32_shaped(np.zeros((2, 5, 4), dtype=np.float32))
+
+    with pytest.raises(TypeError) as excinfo:
+        t.pass_float32_shaped(np.zeros((3, 5, 6), dtype=np.float32))
+
+    with pytest.raises(TypeError) as excinfo:
+        t.pass_float32_shaped(np.zeros((3, 5, 4, 6), dtype=np.float32))
+
+
+def test04_constrain_order():
+    assert t.check_order(np.zeros((3, 5, 4, 6), order='C')) == 'C'
+    assert t.check_order(np.zeros((3, 5, 4, 6), order='F')) == 'F'
+    assert t.check_order(np.zeros((3, 5, 4, 6), order='C')[:, 2, :, :]) == '?'
+    assert t.check_order(np.zeros((3, 5, 4, 6), order='F')[:, 2, :, :]) == '?'
+
+
+def test05_constrain_order_jax():
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        try:
+            import jax
+            c = jax.numpy.zeros((3, 5))
+        except:
+            pytest.skip('jax is missing')
+
+        z = jax.numpy.zeros((3, 5, 4, 6))
+        assert t.check_order(z) == 'C'
+
+
+@pytest.mark.filterwarnings
+def test06_constrain_order_pytorch():
+    try:
+        import torch
+        c = torch.zeros(3, 5)
+        c.__dlpack__()
+    except:
+        pytest.skip('pytorch is missing')
+
+    f = c.t().contiguous().t()
+    assert t.check_order(c) == 'C'
+    assert t.check_order(f) == 'F'
+    assert t.check_order(c[:, 2:5]) == '?'
+    assert t.check_order(f[1:3, :]) == '?'
+    assert t.check_device(torch.zeros(3, 5)) == 'cpu'
+    if torch.cuda.is_available():
+        assert t.check_device(torch.zeros(3, 5, device='cuda')) == 'cuda'
+
+
+def test07_constrain_order_tensorflow():
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        try:
+            import tensorflow as tf
+            c = tf.zeros((3, 5))
+        except:
+            pytest.skip('tensorflow is missing')
+
+        assert t.check_order(c) == 'C'
+
+
+def test08_write_from_cpp():
+    x = np.zeros(10, dtype=np.float32)
+    t.initialize(x)
+    assert np.all(x == np.arange(10, dtype=np.float32))
+
+    x = np.zeros((10, 3), dtype=np.float32)
+    t.initialize(x)
+    assert np.all(x == np.arange(30, dtype=np.float32).reshape(10, 3))
+
+
+def test09_implicit_conversion():
+    t.implicit(np.zeros((2, 2), dtype=np.uint32))
+    t.implicit(np.zeros((2, 2, 10), dtype=np.float32)[:, :, 4])
+    t.implicit(np.zeros((2, 2, 10), dtype=np.uint32)[:, :, 4])
+
+    with pytest.raises(TypeError) as excinfo:
+        t.noimplicit(np.zeros((2, 2), dtype=np.uint32))
+
+    with pytest.raises(TypeError) as excinfo:
+        t.noimplicit(np.zeros((2, 2, 10), dtype=np.float32)[:, :, 4])
+
+
+def test10_implicit_conversion_pytorch():
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        try:
+            import torch
+            c = torch.zeros(3, 5)
+            c.__dlpack__()
+        except:
+            pytest.skip('pytorch is missing')
+
+        t.implicit(torch.zeros(2, 2, dtype=torch.int32))
+        t.implicit(torch.zeros(2, 2, 10, dtype=torch.float32)[:, :, 4])
+        t.implicit(torch.zeros(2, 2, 10, dtype=torch.int32)[:, :, 4])
+
+        with pytest.raises(TypeError) as excinfo:
+            t.noimplicit(torch.zeros(2, 2, dtype=torch.int32))
+
+        with pytest.raises(TypeError) as excinfo:
+            t.noimplicit(torch.zeros(2, 2, 10, dtype=torch.float32)[:, :, 4])
+
+
+def test11_implicit_conversion_tensorflow():
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        try:
+            import tensorflow as tf
+            c = tf.zeros((3, 5))
+        except:
+            pytest.skip('tensorflow is missing')
+
+        t.implicit(tf.zeros((2, 2), dtype=tf.int32))
+        t.implicit(tf.zeros((2, 2, 10), dtype=tf.float32)[:, :, 4])
+        t.implicit(tf.zeros((2, 2, 10), dtype=tf.int32)[:, :, 4])
+
+        with pytest.raises(TypeError) as excinfo:
+            t.noimplicit(tf.zeros((2, 2), dtype=tf.int32))
+
+
+def test12_implicit_conversion_jax():
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        try:
+            import jax.numpy as jnp
+            c = jnp.zeros((3, 5))
+        except:
+            pytest.skip('jax is missing')
+
+        t.implicit(jnp.zeros((2, 2), dtype=jnp.int32))
+        t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.float32)[:, :, 4])
+        t.implicit(jnp.zeros((2, 2, 10), dtype=jnp.int32)[:, :, 4])
+
+        with pytest.raises(TypeError) as excinfo:
+            t.noimplicit(jnp.zeros((2, 2), dtype=jnp.int32))
+
+
+def test13_destroy_capsule():
+    gc.collect()
+    dc = t.destruct_count()
+    a = t.return_dlpack()
+    assert dc == t.destruct_count()
+    del a
+    gc.collect()
+    assert t.destruct_count() - dc == 1
+
+
+def test14_consume_numpy():
+    gc.collect()
+    class wrapper:
+        def __init__(self, value):
+            self.value = value
+        def __dlpack__(self):
+            return self.value
+    import numpy as np
+    dc = t.destruct_count()
+    a = t.return_dlpack()
+    if hasattr(np, '_from_dlpack'):
+        x = np._from_dlpack(wrapper(a))
+    elif hasattr(np, 'from_dlpack'):
+        x = np.from_dlpack(wrapper(a))
+    else:
+        pytest.skip('your version of numpy is too old')
+
+    del a
+    gc.collect()
+    assert x.shape == (2, 4)
+    assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]])
+    assert dc == t.destruct_count()
+    del x
+    gc.collect()
+    assert t.destruct_count() - dc == 1
+
+
+def test15_passthrough():
+    gc.collect()
+    class wrapper:
+        def __init__(self, value):
+            self.value = value
+        def __dlpack__(self):
+            return self.value
+    import numpy as np
+    dc = t.destruct_count()
+    a = t.return_dlpack()
+    b = t.passthrough(a)
+    if hasattr(np, '_from_dlpack'):
+        y = np._from_dlpack(wrapper(b))
+    elif hasattr(np, 'from_dlpack'):
+        y = np.from_dlpack(wrapper(b))
+    else:
+        pytest.skip('your version of numpy is too old')
+
+    del a
+    del b
+    gc.collect()
+    assert dc == t.destruct_count()
+    assert y.shape == (2, 4)
+    assert np.all(y == [[1, 2, 3, 4], [5, 6, 7, 8]])
+    del y
+    gc.collect()
+    assert t.destruct_count() - dc == 1
+
+
+def test16_return_numpy():
+    gc.collect()
+    import numpy as np
+    dc = t.destruct_count()
+    x = t.ret_numpy()
+    assert x.shape == (2, 4)
+    assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]])
+    del x
+    gc.collect()
+    assert t.destruct_count() - dc == 1
+
+
+def test17_return_pytorch():
+    try:
+        import torch
+        c = torch.zeros(3, 5)
+    except:
+        pytest.skip('pytorch is missing')
+    gc.collect()
+    import numpy as np
+    dc = t.destruct_count()
+    x = t.ret_pytorch()
+    assert x.shape == (2, 4)
+    assert torch.all(x == torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]))
+    del x
+    gc.collect()
+    assert t.destruct_count() - dc == 1