Skip to content

Commit

Permalink
MFIter in Python
Browse files Browse the repository at this point in the history
Iterate boxes and particle tiles in Python. This simplifies downstream usage
with derived data types, avoiding "downcasting" iterators to their base
types during iteration.
  • Loading branch information
ax3l committed Feb 8, 2024
1 parent e442401 commit c247f77
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 97 deletions.
63 changes: 0 additions & 63 deletions src/Base/Iterator.H

This file was deleted.

32 changes: 13 additions & 19 deletions src/Base/MultiFab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
*/
#include "pyAMReX.H"

#include "Base/Iterator.H"

#include <AMReX_BoxArray.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_FArrayBox.H>
Expand Down Expand Up @@ -60,20 +58,26 @@ void init_MultiFab(py::module &m)
return r;
}
)
.def(py::init< FabArrayBase const & >())
.def(py::init< FabArrayBase const & >(),
// while the created iterator (argument 1: this) exists,
// keep the FabArrayBase (argument 2) alive
py::keep_alive<1, 2>()
)
//.def(py::init< FabArrayBase const &, MFItInfo const & >())

.def(py::init< MultiFab const & >())
.def(py::init< MultiFab const & >(),
// while the created iterator (argument 1: this) exists,
// keep the MultiFab (argument 2) alive
py::keep_alive<1, 2>()
)
//.def(py::init< MultiFab const &, MFItInfo const & >())

//.def(py::init< iMultiFab const & >())
//.def(py::init< iMultiFab const &, MFItInfo const & >())

// eq. to void operator++()
.def("__next__",
&pyAMReX::iterator_next<MFIter>,
py::return_value_policy::reference_internal
)
// helpers for iteration __next__
.def("_incr", &MFIter::operator++)
.def("finalize", &MFIter::Finalize)

.def("tilebox", py::overload_cast< >(&MFIter::tilebox, py::const_))
.def("tilebox", py::overload_cast< IntVect const & >(&MFIter::tilebox, py::const_))
Expand Down Expand Up @@ -114,16 +118,6 @@ void init_MultiFab(py::module &m)
.def_property_readonly("size", &FabArrayBase::size)

.def_property_readonly("n_grow_vect", &FabArrayBase::nGrowVect)

/* data access in Box index space */
.def("__iter__",
[](FabArrayBase& fab) {
return MFIter(fab);
},
// while the returned iterator (argument 0) exists,
// keep the FabArrayBase (argument 1; usually a MultiFab) alive
py::keep_alive<0, 1>()
)
;

py_FabArray_FArrayBox
Expand Down
2 changes: 1 addition & 1 deletion src/Base/PODVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void make_PODVector(py::module &m, std::string typestr, std::string allocstr)
);
Gpu::streamSynchronize();
return h_data;
})
}, py::return_value_policy::move)

// front
// back
Expand Down
23 changes: 9 additions & 14 deletions src/Particle/ParticleContainer.H
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

#include "pyAMReX.H"

#include "Base/Iterator.H"

#include <AMReX_BoxArray.H>
#include <AMReX_GpuAllocators.H>
#include <AMReX_IntVect.H>
Expand Down Expand Up @@ -56,10 +54,14 @@ void make_Base_Iterators (py::module &m, std::string allocstr)
std::string particle_it_base_name = std::string("Par");
if (is_const) particle_it_base_name += "Const";
particle_it_base_name += "IterBase_" + suffix + "_" + allocstr;
auto py_it_base = py::class_<iterator_base, MFIter>(m, particle_it_base_name.c_str())
auto py_it_base = py::class_<iterator_base, MFIter>(m, particle_it_base_name.c_str(), py::dynamic_attr())
.def(py::init<container&, int>(),
// while the created iterator (argument 1: this) exists,
// keep the ParticleContainer (argument 2) alive
py::keep_alive<1, 2>(),
py::arg("particle_container"), py::arg("level"))
//.def(py::init<container&, int, MFItInfo&>(),
// py::keep_alive<1, 2>(),
// py::arg("particle_container"), py::arg("level"), py::arg("info"))

.def("particle_tile", &iterator_base::GetParticleTile,
Expand All @@ -75,19 +77,12 @@ void make_Base_Iterators (py::module &m, std::string allocstr)
.def_property_readonly("num_neighbor_particles", &iterator_base::numNeighborParticles)
.def_property_readonly("level", &iterator_base::GetLevel)
.def_property_readonly("pair_index", &iterator_base::GetPairIndex)
.def_property_readonly("is_valid", &iterator_base::isValid)
.def("geom", &iterator_base::Geom, py::arg("level"))

// eq. to void operator++()
.def("__next__",
&pyAMReX::iterator_next<iterator_base>,
py::return_value_policy::reference_internal
)
.def("__iter__",
[](iterator_base & it) -> iterator_base & {
return it;
},
py::return_value_policy::reference_internal
)
// helpers for iteration __next__
.def("_incr", &iterator_base::operator++)
.def("finalize", &iterator_base::Finalize)
;

// only legacy particle has an AoS data structure for positions and id+cpu
Expand Down
40 changes: 40 additions & 0 deletions src/amrex/Iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
This file is part of pyAMReX
Copyright 2023 AMReX community
Authors: Axel Huebl
License: BSD-3-Clause-LBNL
"""


def next(self):
"""This is a helper function for the C++ equivalent of void operator++()
In Python, iterators always are called with __next__, even for the
first access. This means we need to handle the first iterator element
explicitly, otherwise we will jump directly to the 2nd element. We do
this the same way as pybind11 does this, via a little state:
https://github.com/AMReX-Codes/pyamrex/pull/50
https://github.com/AMReX-Codes/pyamrex/pull/262
https://github.com/pybind/pybind11/blob/v2.10.0/include/pybind11/pybind11.h#L2269-L2282
Important: we must NOT copy the AMReX iterator (unnecessary and expensive).
self: the current iterator
returns: the updated iterator
"""
if hasattr(self, "first_or_done") is False:
self.first_or_done = True

first_or_done = self.first_or_done
if first_or_done:
first_or_done = False
self.first_or_done = first_or_done
else:
self._incr()
if self.is_valid is False:
self.first_or_done = True
self.finalize()
raise StopIteration

return self
10 changes: 10 additions & 0 deletions src/amrex/MultiFab.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
License: BSD-3-Clause-LBNL
"""

from .Iterator import next


def mf_to_numpy(amr, self, copy=False, order="F"):
"""
Expand Down Expand Up @@ -92,7 +94,15 @@ def mf_to_cupy(self, copy=False, order="F"):
def register_MultiFab_extension(amr):
"""MultiFab helper methods"""

# register member functions for the MFIter type
amr.MFIter.__next__ = next

# FabArrayBase: iterate as data access in Box index space
amr.FabArrayBase.__iter__ = lambda fab: amr.MFIter(fab)

# register member functions for the MultiFab type
amr.MultiFab.__iter__ = lambda mfab: amr.MFIter(mfab)

amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy(
amr, self, copy, order
)
Expand Down
15 changes: 15 additions & 0 deletions src/amrex/ParticleContainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
License: BSD-3-Clause-LBNL
"""

from .Iterator import next


def pc_to_df(self, local=True, comm=None, root_rank=0):
"""
Expand Down Expand Up @@ -116,6 +118,19 @@ def register_ParticleContainer_extension(amr):
import inspect
import sys

# register member functions for every Par(Const)Iter* type
for _, ParIter_type in inspect.getmembers(
sys.modules[amr.__name__],
lambda member: inspect.isclass(member)
and member.__module__ == amr.__name__
and (
member.__name__.startswith("ParIter")
or member.__name__.startswith("ParConstIter")
),
):
ParIter_type.__next__ = next
ParIter_type.__iter__ = lambda self: self

# register member functions for every ParticleContainer_* type
for _, ParticleContainer_type in inspect.getmembers(
sys.modules[amr.__name__],
Expand Down

0 comments on commit c247f77

Please sign in to comment.