From c247f77efe5ef30255debd4848c004aa666a85c4 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Wed, 7 Feb 2024 13:57:55 -0800 Subject: [PATCH] MFIter in Python Iterate boxes and particle tiles in Python. This simplifies downstream usage with derived data types, avoiding "downcasting" iterators to their base types during iteration. --- src/Base/Iterator.H | 63 -------------------------------- src/Base/MultiFab.cpp | 32 +++++++--------- src/Base/PODVector.cpp | 2 +- src/Particle/ParticleContainer.H | 23 +++++------- src/amrex/Iterator.py | 40 ++++++++++++++++++++ src/amrex/MultiFab.py | 10 +++++ src/amrex/ParticleContainer.py | 15 ++++++++ 7 files changed, 88 insertions(+), 97 deletions(-) delete mode 100644 src/Base/Iterator.H create mode 100644 src/amrex/Iterator.py diff --git a/src/Base/Iterator.H b/src/Base/Iterator.H deleted file mode 100644 index c4693588..00000000 --- a/src/Base/Iterator.H +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2021-2022 The AMReX Community - * - * Authors: Axel Huebl - * License: BSD-3-Clause-LBNL - */ -#pragma once - -#include "pyAMReX.H" - -#include -#include -#include -#include -#include -#include - -#include -#include - - -namespace pyAMReX -{ - /** This is a helper function for the C++ equivalent of void operator++() - * - * In Python, iterators always are called with __next__, even for the - * first access. This means we need to handle the first iterator element - * explicitly, otherwise we will jump directly to the 2nd element. We do - * this the same way as pybind11 does this, via a little state: - * https://github.com/AMReX-Codes/pyamrex/pull/50 - * https://github.com/pybind/pybind11/blob/v2.10.0/include/pybind11/pybind11.h#L2269-L2282 - * - * To avoid unnecessary (and expensive) copies, remember to only call this - * helper always with py::return_value_policy::reference_internal! - * - * - * @tparam T_Iterator This is usally MFIter or Par(Const)Iter or derived classes - * @param it the current iterator - * @return the updated iterator - */ - template< typename T_Iterator > - T_Iterator & - iterator_next( T_Iterator & it ) - { - py::object self = py::cast(it); - if (!py::hasattr(self, "first_or_done")) - self.attr("first_or_done") = true; - - bool first_or_done = self.attr("first_or_done").cast(); - if (first_or_done) { - first_or_done = false; - self.attr("first_or_done") = first_or_done; - } - else - ++it; - if( !it.isValid() ) - { - first_or_done = true; - it.Finalize(); - throw py::stop_iteration(); - } - return it; - } -} diff --git a/src/Base/MultiFab.cpp b/src/Base/MultiFab.cpp index 5853d0f5..09921891 100644 --- a/src/Base/MultiFab.cpp +++ b/src/Base/MultiFab.cpp @@ -5,8 +5,6 @@ */ #include "pyAMReX.H" -#include "Base/Iterator.H" - #include #include #include @@ -60,20 +58,26 @@ void init_MultiFab(py::module &m) return r; } ) - .def(py::init< FabArrayBase const & >()) + .def(py::init< FabArrayBase const & >(), + // while the created iterator (argument 1: this) exists, + // keep the FabArrayBase (argument 2) alive + py::keep_alive<1, 2>() + ) //.def(py::init< FabArrayBase const &, MFItInfo const & >()) - .def(py::init< MultiFab const & >()) + .def(py::init< MultiFab const & >(), + // while the created iterator (argument 1: this) exists, + // keep the MultiFab (argument 2) alive + py::keep_alive<1, 2>() + ) //.def(py::init< MultiFab const &, MFItInfo const & >()) //.def(py::init< iMultiFab const & >()) //.def(py::init< iMultiFab const &, MFItInfo const & >()) - // eq. to void operator++() - .def("__next__", - &pyAMReX::iterator_next, - py::return_value_policy::reference_internal - ) + // helpers for iteration __next__ + .def("_incr", &MFIter::operator++) + .def("finalize", &MFIter::Finalize) .def("tilebox", py::overload_cast< >(&MFIter::tilebox, py::const_)) .def("tilebox", py::overload_cast< IntVect const & >(&MFIter::tilebox, py::const_)) @@ -114,16 +118,6 @@ void init_MultiFab(py::module &m) .def_property_readonly("size", &FabArrayBase::size) .def_property_readonly("n_grow_vect", &FabArrayBase::nGrowVect) - - /* data access in Box index space */ - .def("__iter__", - [](FabArrayBase& fab) { - return MFIter(fab); - }, - // while the returned iterator (argument 0) exists, - // keep the FabArrayBase (argument 1; usually a MultiFab) alive - py::keep_alive<0, 1>() - ) ; py_FabArray_FArrayBox diff --git a/src/Base/PODVector.cpp b/src/Base/PODVector.cpp index ffa43fd7..109745b7 100644 --- a/src/Base/PODVector.cpp +++ b/src/Base/PODVector.cpp @@ -81,7 +81,7 @@ void make_PODVector(py::module &m, std::string typestr, std::string allocstr) ); Gpu::streamSynchronize(); return h_data; - }) + }, py::return_value_policy::move) // front // back diff --git a/src/Particle/ParticleContainer.H b/src/Particle/ParticleContainer.H index 8cc12b91..532542a1 100644 --- a/src/Particle/ParticleContainer.H +++ b/src/Particle/ParticleContainer.H @@ -7,8 +7,6 @@ #include "pyAMReX.H" -#include "Base/Iterator.H" - #include #include #include @@ -56,10 +54,14 @@ void make_Base_Iterators (py::module &m, std::string allocstr) std::string particle_it_base_name = std::string("Par"); if (is_const) particle_it_base_name += "Const"; particle_it_base_name += "IterBase_" + suffix + "_" + allocstr; - auto py_it_base = py::class_(m, particle_it_base_name.c_str()) + auto py_it_base = py::class_(m, particle_it_base_name.c_str(), py::dynamic_attr()) .def(py::init(), + // while the created iterator (argument 1: this) exists, + // keep the ParticleContainer (argument 2) alive + py::keep_alive<1, 2>(), py::arg("particle_container"), py::arg("level")) //.def(py::init(), + // py::keep_alive<1, 2>(), // py::arg("particle_container"), py::arg("level"), py::arg("info")) .def("particle_tile", &iterator_base::GetParticleTile, @@ -75,19 +77,12 @@ void make_Base_Iterators (py::module &m, std::string allocstr) .def_property_readonly("num_neighbor_particles", &iterator_base::numNeighborParticles) .def_property_readonly("level", &iterator_base::GetLevel) .def_property_readonly("pair_index", &iterator_base::GetPairIndex) + .def_property_readonly("is_valid", &iterator_base::isValid) .def("geom", &iterator_base::Geom, py::arg("level")) - // eq. to void operator++() - .def("__next__", - &pyAMReX::iterator_next, - py::return_value_policy::reference_internal - ) - .def("__iter__", - [](iterator_base & it) -> iterator_base & { - return it; - }, - py::return_value_policy::reference_internal - ) + // helpers for iteration __next__ + .def("_incr", &iterator_base::operator++) + .def("finalize", &iterator_base::Finalize) ; // only legacy particle has an AoS data structure for positions and id+cpu diff --git a/src/amrex/Iterator.py b/src/amrex/Iterator.py new file mode 100644 index 00000000..1cf8109a --- /dev/null +++ b/src/amrex/Iterator.py @@ -0,0 +1,40 @@ +""" +This file is part of pyAMReX + +Copyright 2023 AMReX community +Authors: Axel Huebl +License: BSD-3-Clause-LBNL +""" + + +def next(self): + """This is a helper function for the C++ equivalent of void operator++() + + In Python, iterators always are called with __next__, even for the + first access. This means we need to handle the first iterator element + explicitly, otherwise we will jump directly to the 2nd element. We do + this the same way as pybind11 does this, via a little state: + https://github.com/AMReX-Codes/pyamrex/pull/50 + https://github.com/AMReX-Codes/pyamrex/pull/262 + https://github.com/pybind/pybind11/blob/v2.10.0/include/pybind11/pybind11.h#L2269-L2282 + + Important: we must NOT copy the AMReX iterator (unnecessary and expensive). + + self: the current iterator + returns: the updated iterator + """ + if hasattr(self, "first_or_done") is False: + self.first_or_done = True + + first_or_done = self.first_or_done + if first_or_done: + first_or_done = False + self.first_or_done = first_or_done + else: + self._incr() + if self.is_valid is False: + self.first_or_done = True + self.finalize() + raise StopIteration + + return self diff --git a/src/amrex/MultiFab.py b/src/amrex/MultiFab.py index d53f1371..1c84bccc 100644 --- a/src/amrex/MultiFab.py +++ b/src/amrex/MultiFab.py @@ -6,6 +6,8 @@ License: BSD-3-Clause-LBNL """ +from .Iterator import next + def mf_to_numpy(amr, self, copy=False, order="F"): """ @@ -92,7 +94,15 @@ def mf_to_cupy(self, copy=False, order="F"): def register_MultiFab_extension(amr): """MultiFab helper methods""" + # register member functions for the MFIter type + amr.MFIter.__next__ = next + + # FabArrayBase: iterate as data access in Box index space + amr.FabArrayBase.__iter__ = lambda fab: amr.MFIter(fab) + # register member functions for the MultiFab type + amr.MultiFab.__iter__ = lambda mfab: amr.MFIter(mfab) + amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy( amr, self, copy, order ) diff --git a/src/amrex/ParticleContainer.py b/src/amrex/ParticleContainer.py index ab934920..6e37d00e 100644 --- a/src/amrex/ParticleContainer.py +++ b/src/amrex/ParticleContainer.py @@ -6,6 +6,8 @@ License: BSD-3-Clause-LBNL """ +from .Iterator import next + def pc_to_df(self, local=True, comm=None, root_rank=0): """ @@ -116,6 +118,19 @@ def register_ParticleContainer_extension(amr): import inspect import sys + # register member functions for every Par(Const)Iter* type + for _, ParIter_type in inspect.getmembers( + sys.modules[amr.__name__], + lambda member: inspect.isclass(member) + and member.__module__ == amr.__name__ + and ( + member.__name__.startswith("ParIter") + or member.__name__.startswith("ParConstIter") + ), + ): + ParIter_type.__next__ = next + ParIter_type.__iter__ = lambda self: self + # register member functions for every ParticleContainer_* type for _, ParticleContainer_type in inspect.getmembers( sys.modules[amr.__name__],