Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reshape for numpy arrays #984

Merged
merged 24 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ struct npy_api {
// Unused. Not removed because that affects ABI of the class.
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes an ABI break?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@henryiii shrugs @rwgk thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any virtual here. If there is no vtable, how does adding or removing member functions change the ABI? I'm having doubts about the correctness of the comment in line 203. I could imagine maybe getting into trouble removing a function, but adding? If a newer extension knows about it, the machine code for it will be in that extension for sure. No?


private:
enum functions {
API_PyArray_GetNDArrayCFeatureVersion = 211,
Expand All @@ -212,10 +214,11 @@ struct npy_api {
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
API_PyArray_DescrNewFromType = 96,
API_PyArray_Newshape = 135,
API_PyArray_Squeeze = 136,
API_PyArray_DescrConverter = 174,
API_PyArray_EquivTypes = 182,
API_PyArray_GetArrayParamsFromObject = 278,
API_PyArray_Squeeze = 136,
API_PyArray_SetBaseObject = 282
};

Expand Down Expand Up @@ -243,11 +246,13 @@ struct npy_api {
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
DECL_NPY_API(PyArray_DescrNewFromType);
DECL_NPY_API(PyArray_Newshape);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_DescrConverter);
DECL_NPY_API(PyArray_EquivTypes);
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_SetBaseObject);

#undef DECL_NPY_API
return api;
}
Expand Down Expand Up @@ -785,6 +790,18 @@ class array : public buffer {
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
}

/// Optional `order` parameter omitted, to be added as needed.
array reshape(ShapeContainer new_shape) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while we're at it, I'd also plumb through the order.
even if it's not actually used now (idk), it may be in future numpy versions.
easier to just do it than to explain why not.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rwgk We would need to add support for it as a function arg, which we currently don't. That can be added in a followup PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That can be added in a followup PR.
OK

Copy link
Collaborator

@Skylion007 Skylion007 Aug 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an NORDER Enum that it takes, but I am not sure how to expose that properly. We probably should leave it as is now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, thanks for looking into it!

detail::npy_api::PyArray_Dims d
= {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
auto new_array
= reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
if (!new_array) {
throw error_already_set();
}
return new_array;
}

/// Ensure that the argument is a NumPy array
/// In case of an error, nullptr is returned and the Python error is cleared.
static array ensure(handle h, int ExtraFlags = 0) {
Expand Down
7 changes: 7 additions & 0 deletions tests/test_numpy_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,13 @@ TEST_SUBMODULE(numpy_array, sm) {
return a;
});

sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
return a.reshape({N, M, O});
});
sm.def("reshape_tuple", [](py::array_t<int> a, const std::vector<int> &new_shape) {
return a.reshape(new_shape);
});

sm.def("index_using_ellipsis",
[](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; });

Expand Down
31 changes: 28 additions & 3 deletions tests/test_numpy_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def test_array_unchecked_fixed_dims(msg):
assert m.proxy_auxiliaries2_const_ref(z1)


def test_array_unchecked_dyn_dims(msg):
def test_array_unchecked_dyn_dims():
z1 = np.array([[1, 2], [3, 4]], dtype="float64")
m.proxy_add2_dyn(z1, 10)
assert np.all(z1 == [[11, 12], [13, 14]])
Expand Down Expand Up @@ -444,7 +444,7 @@ def test_initializer_list():
assert m.array_initializer_list4().shape == (1, 2, 3, 4)


def test_array_resize(msg):
def test_array_resize():
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="float64")
m.array_reshape2(a)
assert a.size == 9
Expand All @@ -470,12 +470,37 @@ def test_array_resize(msg):


@pytest.mark.xfail("env.PYPY")
def test_array_create_and_resize(msg):
def test_array_create_and_resize():
a = m.create_and_resize(2)
assert a.size == 4
assert np.all(a == 42.0)


def test_reshape_initializer_list():
a = np.arange(2 * 7 * 3) + 1
x = m.reshape_initializer_list(a, 2, 7, 3)
assert x.shape == (2, 7, 3)
assert list(x[1][4]) == [34, 35, 36]
with pytest.raises(ValueError) as excinfo:
m.reshape_initializer_list(a, 1, 7, 3)
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (1,7,3)"


def test_reshape_tuple():
a = np.arange(3 * 7 * 2) + 1
x = m.reshape_tuple(a, (3, 7, 2))
assert x.shape == (3, 7, 2)
assert list(x[1][4]) == [23, 24]
y = m.reshape_tuple(x, (x.size,))
assert y.shape == (42,)
with pytest.raises(ValueError) as excinfo:
m.reshape_tuple(a, (3, 7, 1))
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (3,7,1)"
with pytest.raises(ValueError) as excinfo:
m.reshape_tuple(a, ())
assert str(excinfo.value) == "cannot reshape array of size 42 into shape ()"


def test_index_using_ellipsis():
a = m.index_using_ellipsis(np.zeros((5, 6, 7)))
assert a.shape == (6,)
Expand Down