From 467de7ffff7e2131e31efc31f2d31bce11638839 Mon Sep 17 00:00:00 2001
From: George Bisbas
Date: Mon, 4 Nov 2024 13:08:56 +0200
Subject: [PATCH] compiler: Drop redundant code

---
 devito/ir/iet/nodes.py    |   7 +-
 devito/mpi/halo_scheme.py |   8 +-
 devito/passes/iet/mpi.py  | 160 ++++++++------------------------------
 tests/test_dse.py         |   1 -
 tests/test_mpi.py         |  20 +++++
 5 files changed, 57 insertions(+), 139 deletions(-)

diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py
index fc85f3c545..9def607678 100644
--- a/devito/ir/iet/nodes.py
+++ b/devito/ir/iet/nodes.py
@@ -1461,16 +1461,13 @@ def __repr__(self):
         for f in self.functions:
             loc_indices = set().union(*[self.halo_scheme.fmapper[f].loc_indices.values()])
             loc_indices = list(loc_indices)
-            if loc_indices:
-                loc_indices_str = str(loc_indices)
-            else:
-                loc_indices_str = ""
+            loc_indices_str = str(list(loc_indices)) if loc_indices else ""
             fstrings.append(f"{f.name}{loc_indices_str}")

         functions = ",".join(fstrings)

-        return "<%s(%s)>" % (self.__class__.__name__, functions)
+        return f"<{self.__class__.__name__}({functions})>"

     @property
     def halo_scheme(self):
diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py
index e00ad204f0..f6afa3a07e 100644
--- a/devito/mpi/halo_scheme.py
+++ b/devito/mpi/halo_scheme.py
@@ -130,17 +130,13 @@ def __repr__(self):
         fstrings = []
         for f in self.fmapper:
             loc_indices = set().union(*[self._mapper[f].loc_indices.values()])
-            loc_indices = list(loc_indices)
-            if loc_indices:
-                loc_indices_str = str(loc_indices)
-            else:
-                loc_indices_str = ""
+            loc_indices_str = str(list(loc_indices)) if loc_indices else ""
             fstrings.append(f"{f.name}{loc_indices_str}")

         functions = ",".join(fstrings)

-        return "%s<%s>" % (self.__class__.__name__, functions)
+        return f"<{self.__class__.__name__}({functions})>"

     def __eq__(self, other):
         return (isinstance(other, HaloScheme) and
diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py
index e8b0651f94..ccc3dc729f 100644
--- a/devito/passes/iet/mpi.py
+++ b/devito/passes/iet/mpi.py
@@ -60,42 +60,25 @@ def _hoist_halospots(iet):
     Hoist HaloSpots from inner to outer Iterations where all data dependencies
     would be honored.
-    This function looks for hoisting opportunities in two ways:
-
-    - `_process_halo_to_halo`: The first one (more aggressive) examines halo to halo
-      opportunites looking for halospots that redundantly update the same slice of
-      a `TimeFunction`
-      Example:
-
-      haloupd v[t0]
-      for time                          for time
-        write to v[t1]
-        read from vx[t0]                  write to v[t1]
-        read from v[t0]
-        haloupd v[t1]                     haloupd v[t1]
-        read from v[t1]                   read from v[t1]
-        haloupd v[t0]                     read from v[t0]
-        read from v[t0]
-
-    - `_process_halo`: The second one (more general) examines the data dependencies
-      of HaloSpots and if honored it hoists HaloSpots from inner to outer iterations
-
-      haloupd vx[t0]
-      for time                          for time
-        write to v[t1]
-        read from vx[t0]                  write to v[t1]
-        read from v[t0]
-        haloupd v[t1]                     haloupd v[t1]
-        read from v[t1]                   read from v[t1]
-        haloupd v[t0]                     read from v[t0]
-        read from v[t0]
+      haloupd v[t0]
+      for time                          for time
+        W v[t1]- R v[t0]                  W v[t1]- R v[t0]
+        haloupd v[t1]                     haloupd v[t1]
+        R v[t1]                           R v[t1]
+        haloupd v[t0]                     R v[t0]
+        R v[t0]
     """

     # Precompute scopes to save time
     scopes = {i: Scope([e.expr for e in v])
               for i, v in MapNodes().visit(iet).items()}

-    cond_mapper = _create_cond_mapper(iet)
-    i_mapper = defaultdict(list)
+    cond_mapper = _create_cond_mapper(iet)

-    hs_mapper = {}
+    # Analysis
+    hsmapper = {}
+    imapper = defaultdict(list)

     # Look for parent Iterations of children HaloSpots
     for iters, halo_spots in MapNodes(Iteration, HaloSpot, 'groupby').visit(iet).items():
@@ -116,29 +99,21 @@ def _hoist_halospots(iet):
                     for i in hs1.halo_scheme.loc_values):
                 continue

-            # Look for halo-to-halo optimization possibilities
-            _process_halo_to_halo(hs0, hs1, iters, scopes, hs_mapper, i_mapper)
-
-            # If the HaloSpot is already processed, skip
-            if hs0 in hs_mapper:
-                continue
-
-            # Look for halo-specific optimization possibilities
-            _process_halo(hs0, iters, scopes, hs_mapper, i_mapper)
+            # Compare hs0 to subsequent halo_spots, looking for optimization
+            # possibilities
+            _process_halo_to_halo(hs0, hs1, iters, scopes, hsmapper, imapper)

     mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
-              for i, hss in i_mapper.items()}
-
+              for i, hss in imapper.items()}
     mapper.update({i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
-                   for i, hs in hs_mapper.items()})
+                   for i, hs in hsmapper.items()})

-    if mapper:
-        iet = Transformer(mapper, nested=True).visit(iet)
+    iet = Transformer(mapper, nested=True).visit(iet)

     return iet


-def _process_halo_to_halo(hs0, hs1, iters, scopes, hs_mapper, i_mapper):
+def _process_halo_to_halo(hs0, hs1, iters, scopes, hsmapper, imapper):

     # Loop over the functions in the HaloSpots
     for f, v in hs1.fmapper.items():
@@ -151,51 +126,27 @@ def _process_halo_to_halo(hs0, hs1, iters, scopes, hs_mapper, i_mapper):
             continue

         for iter in iters:
-            # Ensure they are merge-able
+            # If also merge-able we can start hoisting the latter
             for dep in scopes[iter].d_flow.project(f):
                 if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
                     break
             else:
-                fm = hs1.halo_scheme.fmapper[f]
+                hse = hs1.halo_scheme.fmapper[f]

                 raw_loc_indices = {}
-                # Since lifted, we need to update the loc_indices with
-                # known values
+                # Entering here means we can lift, and we need to update
+                # the loc_indices with known values
                 # TODO: Can I get this in a more elegant way?
-                for d in fm.loc_indices:
-                    root_min = fm.loc_indices[d].symbolic_min
-                    new_min = root_min.subs(fm.loc_indices[d].root,
-                                            fm.loc_indices[d].root.symbolic_min)
+                for d in hse.loc_indices:
+                    root_min = hse.loc_indices[d].symbolic_min
+                    new_min = root_min.subs(hse.loc_indices[d].root,
+                                            hse.loc_indices[d].root.symbolic_min)
                     raw_loc_indices[d] = new_min

-                hs_entry = fm.rebuild(loc_indices=frozendict(raw_loc_indices))
-                hs1.halo_scheme.fmapper[f] = hs_entry
-
-                hs_mapper[hs1] = hs_mapper.get(hs1, hs1.halo_scheme).drop(f)
-                i_mapper[iter].append(hs1.halo_scheme.project(f))
+                hse = hse.rebuild(loc_indices=frozendict(raw_loc_indices))
+                hs1.halo_scheme.fmapper[f] = hse

-
-def _process_halo(hs0, iters, scopes, h_mapper, imapper):
-    # Look for halo-specific optimization possibilities
-    for f, v in hs0.fmapper.items():
-        loc_dims = frozenset().union([q for d in v.loc_indices
-                                      for q in d._defines])
-
-        for n, iter in enumerate(iters):
-            if iter not in scopes:
-                continue
-
-            candidates = [i.dim._defines for i in iters[n:]]
-            all_candidates = set().union(*candidates)
-            reads = scopes[iter].getreads(f)
-            if any(set(a.ispace.dimensions) & all_candidates for a in reads):
-                continue
-
-            for dep in scopes[iter].d_flow.project(f):
-                if not any(r(dep, candidates, loc_dims)
-                           for r in hoist_rules()):
-                    break
-            else:
-                h_mapper[hs0] = h_mapper.get(hs0, hs0.halo_scheme).drop(f)
-                imapper[iter].append(hs0.halo_scheme.project(f))
+                hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
+                imapper[iter].append(hs1.halo_scheme.project(f))


 def _merge_halospots(iet):
@@ -214,10 +165,6 @@ def _merge_halospots(iet):

         scope = Scope([e.expr for e in FindNodes(Expression).visit(iter)])

-        # TOFIX: Why only comparing against the first HaloSpot?
-        # We could similary to `_hoist_halospots` compare all pairs of HaloSpots
-        # and merge them based upon some priority ranking on which pair we prefer to
-        # merge
         hs0 = halo_spots[0]
         mapper[hs0] = hs0.halo_scheme

@@ -235,14 +182,9 @@ def _merge_halospots(iet):
                    if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
                        break
                else:
-                    try:
-                        hs = hs1.halo_scheme.project(f)
-                        mapper[hs0] = HaloScheme.union([mapper[hs0], hs])
-                        mapper[hs1] = mapper[hs1].drop(f)
-                    except ValueError:
-                        # `hs1.loc_indices=
-        # E.g., `dep=W -> R` and `candidates=({time}, {x})` => False
-        # E.g., `dep=W -> R`, `dep.cause={t,time}` and
-        #       `candidates=({x},)` => True
-        return (all(i & set(dep.distance_mapper) for i in candidates) and
-                not any(i & dep.cause for i in candidates) and
-                not any(i & loc_dims for i in candidates))
-
-    def rule1(dep, candidates, loc_dims):
-        # A reduction isn't a stopper to hoisting
-        return dep.write is not None and dep.write.is_reduction
-
-    rules = [rule0, rule1]
-
-    return rules
diff --git a/tests/test_dse.py b/tests/test_dse.py
index 74e9cf72f3..c82848a28a 100644
--- a/tests/test_dse.py
+++ b/tests/test_dse.py
@@ -13,7 +13,6 @@
                     norm, grad, div, dimensions, switchconfig, configuration,
                     centered, first_derivative, solve, transpose, Abs, cos,
                     sin, sqrt, floor, Ge, Lt, Derivative, solve)
-
 from devito.exceptions import InvalidArgument, InvalidOperator
 from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
                        FindSymbols, ParallelIteration, retrieve_iteration_tree)
diff --git a/tests/test_mpi.py b/tests/test_mpi.py
index b16f4c41f5..e51a8f06ee 100644
--- a/tests/test_mpi.py
+++ b/tests/test_mpi.py
@@ -23,6 +23,7 @@
 from examples.seismic.acoustic import acoustic_setup
 from examples.seismic import Receiver, TimeAxis, demo_model

+from tests.test_dse import TestTTI


 class TestDistributor:
@@ -1116,6 +1117,7 @@ def test_avoid_haloupdate_with_constant_index(self, mode):

         eq = Eq(u.forward, u[t, 1] + u[t, 1 + x.symbolic_min] + u[t, x])
         op = Operator(eq)
         calls = FindNodes(Call).visit(op)
+        assert len(calls) == 0

     @pytest.mark.parallel(mode=1)
@@ -1211,6 +1213,7 @@ def test_hoist_haloupdate_from_innerloop(self, mode):

         calls = FindNodes(Call).visit(op)
         assert len(calls) == 1
+        # Also make sure the Call is at the right place in the IET
         assert op.body.body[-1].body[1].body[0].body[0].body[0].body[0].is_Call
         assert op.body.body[-1].body[1].body[0].body[0].body[1].is_Iteration

@@ -1236,6 +1239,7 @@ def test_unhoist_haloupdate_if_invariant(self, mode):

         op = Operator(eqns)
         op.apply(time=1)

+        calls = FindNodes(Call).visit(op)
         assert len(calls) == 2

@@ -2860,6 +2864,22 @@ def test_elastic_structure(self, mode):
         assert calls[4].arguments[1] is v[1]


+class TestTTI_w_MPI:
+
+    @pytest.mark.parallel(mode=[(1)])
+    def test_halo_structure(self, mode):
+
+        mytest = TestTTI()
+        solver = mytest.tti_operator(opt='advanced', space_order=8)
+        op = solver.op_fwd(save=False)
+
+        calls = [i for i in FindNodes(Call).visit(op) if isinstance(i, HaloUpdateCall)]
+
+        assert len(calls) == 1
+        assert calls[0].functions[0].name == 'u'
+        assert calls[0].functions[1].name == 'v'
+
+
 if __name__ == "__main__":
     # configuration['mpi'] = 'overlap'
     # TestDecomposition().test_reshape_left_right()
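For reference, the behaviour targeted by the reworked `_hoist_halospots` docstring can be
observed from plain user code. The snippet below is an illustrative sketch only, not part
of the patch; it assumes mpi4py is installed, that MPI code generation is enabled through
`configuration['mpi']`, and that `Call` is importable from `devito.ir` alongside
`FindNodes` (the latter is visible in the tests/test_dse.py hunk above).

    # Illustrative sketch (not part of the patch): build an MPI-enabled Operator
    # and inspect where the generated halo-exchange Calls end up, in the same
    # spirit as the assertions in tests/test_mpi.py.
    from devito import Grid, TimeFunction, Eq, Operator, configuration
    from devito.ir import Call, FindNodes  # assumed import path

    configuration['mpi'] = True  # requires mpi4py; e.g. run `mpirun -n 2 python <script>`

    grid = Grid(shape=(32, 32))
    u = TimeFunction(name='u', grid=grid, space_order=2)

    # u.forward reads neighbouring values of u, so off-rank points must be
    # exchanged; the passes in devito/passes/iet/mpi.py decide how far out of
    # the loop nest each haloupd call can be hoisted or merged.
    op = Operator(Eq(u.forward, u + 0.1 * u.laplace))

    calls = FindNodes(Call).visit(op)
    print(len(calls), [call.name for call in calls])

How many of these Calls survive and where they sit in the generated loop nest is what the
assertions in the tests/test_mpi.py hunks above verify.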