Skip to content

Commit

Permalink
API: Add numpy2.h instead and make numpy.h safe
Browse files Browse the repository at this point in the history
This means that users of `numpy.h` cannot be broken, but need to
update to `numpy2.h` if they want to compile for NumPy 2.

Using Macros simply and didn't bother to try to remove unnecessary
code paths.
  • Loading branch information
seberg committed Mar 8, 2024
1 parent 9116d69 commit ea9ba5f
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 15 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ set(PYBIND11_HEADERS
include/pybind11/iostream.h
include/pybind11/functional.h
include/pybind11/numpy.h
include/pybind11/numpy2.h
include/pybind11/operators.h
include/pybind11/pybind11.h
include/pybind11/pytypes.h
Expand Down
2 changes: 1 addition & 1 deletion include/pybind11/eigen/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#pragma once

#include "../numpy.h"
#include "../numpy2.h"
#include "common.h"

/* HINT: To suppress warnings originating from the Eigen headers, use -isystem.
Expand Down
2 changes: 1 addition & 1 deletion include/pybind11/eigen/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#pragma once

#include "../numpy.h"
#include "../numpy2.h"
#include "common.h"

#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
Expand Down
70 changes: 57 additions & 13 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,6 @@ struct handle_type_name<array> {
template <typename type, typename SFINAE = void>
struct npy_format_descriptor;

struct PyArrayDescr_Proxy {
PyObject_HEAD
PyObject *typeobj;
char kind;
char type;
char byteorder;
char _former_flags;
int type_num;
/* Additional fields are NumPy version specific. */
};

/* NumPy 1 proxy (always includes legacy fields) */
struct PyArrayDescr1_Proxy {
PyObject_HEAD
Expand All @@ -80,6 +69,22 @@ struct PyArrayDescr1_Proxy {
PyObject *names;
};

#ifdef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
struct PyArrayDescr_Proxy {
PyObject_HEAD
PyObject *typeobj;
char kind;
char type;
char byteorder;
char _former_flags;
int type_num;
/* Additional fields are NumPy version specific. */
};
#else
/* NumPy 1.x only, we can expose all fields */
typedef PyArrayDescr1_Proxy PyArrayDescr_Proxy;
#endif

/* NumPy 2 proxy, including legacy fields */
struct PyArrayDescr2_Proxy {
PyObject_HEAD
Expand Down Expand Up @@ -164,6 +169,13 @@ PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
int major_version = numpy_version.attr("major").cast<int>();

#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
if (major_version >= 2) {
throw std::runtime_error("module compiled without NumPy 2 support. Please modify the "
"`pybind11/numpy.h` include to `pybind11/numpy2.h` and recompile "
"(this remains NumPy 1.x compatible but has minor changes).");
}
#endif
/* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
became a private module. */
std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
Expand Down Expand Up @@ -276,6 +288,16 @@ struct npy_api {
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
PyObject *,
unsigned char,
PyObject **,
int *,
Py_intptr_t *,
PyObject **,
PyObject *);
#endif
PyObject *(*PyArray_Squeeze_)(PyObject *);
// Unused. Not removed because that affects ABI of the class.
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
Expand All @@ -302,6 +324,9 @@ struct npy_api {
API_PyArray_View = 137,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
API_PyArray_GetArrayParamsFromObject = 278,
#endif
API_PyArray_SetBaseObject = 282
};

Expand Down Expand Up @@ -336,6 +361,9 @@ struct npy_api {
DECL_NPY_API(PyArray_View);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
#endif
DECL_NPY_API(PyArray_SetBaseObject);

#undef DECL_NPY_API
Expand Down Expand Up @@ -644,14 +672,21 @@ class dtype : public object {
}

/// Size of the data type in bytes.
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
#else
ssize_t itemsize() const {
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
return detail::array_descriptor1_proxy(m_ptr)->elsize;
}
return detail::array_descriptor2_proxy(m_ptr)->elsize;
}
#endif

/// Returns true for structured data types.
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
bool has_fields() const { return detail::array_descriptor_proxy(m_ptr)->names != nullptr; }
#else
bool has_fields() const {
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
Expand All @@ -661,6 +696,7 @@ class dtype : public object {
}
return detail::array_descriptor2_proxy(m_ptr)->names != nullptr;
}
#endif

/// Single-character code for dtype's kind.
/// For example, floating point types are 'f' and integral types are 'i'.
Expand All @@ -686,21 +722,29 @@ class dtype : public object {
/// Single character for byteorder
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }

/// Alignment of the data type
/// Alignment of the data type
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }
#else
ssize_t alignment() const {
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
return detail::array_descriptor1_proxy(m_ptr)->alignment;
}
return detail::array_descriptor2_proxy(m_ptr)->alignment;
}
#endif

/// Flags for the array descriptor
/// Flags for the array descriptor
#ifndef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }
#else
std::uint64_t flags() const {
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
}
return detail::array_descriptor2_proxy(m_ptr)->flags;
}
#endif

private:
static object &_dtype_from_pep3118() {
Expand Down
5 changes: 5 additions & 0 deletions include/pybind11/numpy2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

#define PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
#include "numpy.h"
#undef PYBIND11_COMPILE_WITH_NUMPY2_SUPPORT
1 change: 1 addition & 0 deletions tests/extra_python_package/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"include/pybind11/gil_safe_call_once.h",
"include/pybind11/iostream.h",
"include/pybind11/numpy.h",
"include/pybind11/numpy2.h",
"include/pybind11/operators.h",
"include/pybind11/options.h",
"include/pybind11/pybind11.h",
Expand Down

0 comments on commit ea9ba5f

Please sign in to comment.