Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reshape for numpy arrays #984

Merged
merged 24 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ struct npy_api {
// Unused. Not removed because that affects ABI of the class.
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes an ABI break?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@henryiii shrugs @rwgk thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any virtual here. If there is no vtable, how does adding or removing member functions change the ABI? I'm having doubts about the correctness of the comment in line 203. I could imagine maybe getting into trouble removing a function, but adding? If a newer extension knows about it, the machine code for it will be in that extension for sure. No?


private:
enum functions {
API_PyArray_GetNDArrayCFeatureVersion = 211,
Expand All @@ -213,6 +215,7 @@ struct npy_api {
API_PyArray_DescrFromScalar = 57,
API_PyArray_FromAny = 69,
API_PyArray_Resize = 80,
API_PyArray_Newshape = 135,
Skylion007 marked this conversation as resolved.
Show resolved Hide resolved
API_PyArray_CopyInto = 82,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94,
Expand Down Expand Up @@ -244,6 +247,7 @@ struct npy_api {
DECL_NPY_API(PyArray_DescrFromScalar);
DECL_NPY_API(PyArray_FromAny);
DECL_NPY_API(PyArray_Resize);
DECL_NPY_API(PyArray_Newshape);
DECL_NPY_API(PyArray_CopyInto);
DECL_NPY_API(PyArray_NewCopy);
DECL_NPY_API(PyArray_NewFromDescr);
Expand All @@ -253,6 +257,7 @@ struct npy_api {
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
DECL_NPY_API(PyArray_Squeeze);
DECL_NPY_API(PyArray_SetBaseObject);

#undef DECL_NPY_API
return api;
}
Expand Down Expand Up @@ -790,6 +795,15 @@ class array : public buffer {
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
}

array reshape(ShapeContainer new_shape) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while we're at it, I'd also plumb through the order.
even if it's not actually used now (idk), it may be in future numpy versions.
easier to just do it than to explain why not.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rwgk We would need to add support for it as a function arg, which we currently don't. That can be added in a followup PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That can be added in a followup PR.
OK

Copy link
Collaborator

@Skylion007 Skylion007 Aug 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an NORDER Enum that it takes, but I am not sure how to expose that properly. We probably should leave it as is now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, thanks for looking into it!

detail::npy_api::PyArray_Dims d
= {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
// try to reshape, set ordering param to 0 cause it's not used anyway
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's best to remove this comment, and add a comment just above the array reshape(ShapeContainer new_shape) { line, e.g.

// Optional `order` parameter omitted here, to be added as needed.

return reinterpret_steal<array>(
detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
}


/// Ensure that the argument is a NumPy array
/// In case of an error, nullptr is returned and the Python error is cleared.
static array ensure(handle h, int ExtraFlags = 0) {
Expand Down
13 changes: 13 additions & 0 deletions tests/test_numpy_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,19 @@ TEST_SUBMODULE(numpy_array, sm) {
return a;
});

sm.def("array_reshape1", [](py::array_t<double> a, size_t N) {
Skylion007 marked this conversation as resolved.
Show resolved Hide resolved
return a.reshape({N, N, N});
});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To keep the tests lean, I'd remove this function completely (and test_array_reshape with it). The exact same interface is covered by test_array_reshape. (I realize it takes away from PyPy coverage because of the xfail for test_array_reshape, but that's not enough reason in my mind to keep this one.)
If you feel differently, I'd at least change this to something like {N, N+5, N+7} to get a little bit more value.


sm.def("create_and_reshape", [](size_t N, size_t M, size_t O) {
py::array_t<double> a;
a.resize({N*M*O});
std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.);
return a.reshape({N, M, O});
});
sm.def("reshape_tuple", [](py::array_t<double> a, const std::vector<int> &new_shape) {
return a.reshape(new_shape);
});
sm.def("index_using_ellipsis",
[](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; });

Expand Down
18 changes: 18 additions & 0 deletions tests/test_numpy_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,24 @@ def test_array_create_and_resize(msg):
assert np.all(a == 42.0)


def test_array_reshape(msg):
a = np.random.randn(10 * 10 * 10).astype("float64")
x = m.array_reshape1(a, 10)
assert x.shape == (10, 10, 10)


@pytest.mark.xfail("env.PYPY")
def test_create_and_reshape(msg):
x = m.create_and_reshape(10, 20, 30)
Skylion007 marked this conversation as resolved.
Show resolved Hide resolved
assert x.shape == (10, 20, 30)


def test_reshape_tuple(msg):
a = np.random.randn(10 * 10 * 10).astype("float64")
x = m.reshape_tuple(a, (10, 10, 10))
Skylion007 marked this conversation as resolved.
Show resolved Hide resolved
assert x.shape == (10, 10, 10)


def test_index_using_ellipsis():
a = m.index_using_ellipsis(np.zeros((5, 6, 7)))
assert a.shape == (6,)
Expand Down