Skip to content

Commit

Permalink
compiler: Drop redundant code
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Nov 4, 2024
1 parent 0a8c49f commit 467de7f
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 139 deletions.
7 changes: 2 additions & 5 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,16 +1461,13 @@ def __repr__(self):
for f in self.functions:
loc_indices = set().union(*[self.halo_scheme.fmapper[f].loc_indices.values()])
loc_indices = list(loc_indices)
if loc_indices:
loc_indices_str = str(loc_indices)
else:
loc_indices_str = ""
loc_indices_str = str(list(loc_indices)) if loc_indices else ""

fstrings.append(f"{f.name}{loc_indices_str}")

functions = ",".join(fstrings)

return "<%s(%s)>" % (self.__class__.__name__, functions)
return f"<{self.__class__.__name__}({functions})>"

@property
def halo_scheme(self):
Expand Down
8 changes: 2 additions & 6 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,13 @@ def __repr__(self):
fstrings = []
for f in self.fmapper:
loc_indices = set().union(*[self._mapper[f].loc_indices.values()])
loc_indices = list(loc_indices)
if loc_indices:
loc_indices_str = str(loc_indices)
else:
loc_indices_str = ""
loc_indices_str = str(list(loc_indices)) if loc_indices else ""

fstrings.append(f"{f.name}{loc_indices_str}")

functions = ",".join(fstrings)

return "%s<%s>" % (self.__class__.__name__, functions)
return f"<{self.__class__.__name__}({functions})>"

def __eq__(self, other):
return (isinstance(other, HaloScheme) and
Expand Down
160 changes: 33 additions & 127 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,42 +60,25 @@ def _hoist_halospots(iet):
Hoist HaloSpots from inner to outer Iterations where all data dependencies
would be honored.
This function looks for hoisting opportunities in two ways:
- `_process_halo_to_halo`: The first one (more aggressive) examines halo to halo
opportunities looking for halospots that redundantly update the same slice of
a `TimeFunction`
Example:
haloupd v[t0]
for time for time
write to v[t1] - read from vx[t0] write to v[t1] - read from v[t0]
haloupd v[t1] haloupd v[t1]
read from v[t1] read from v[t1]
haloupd v[t0] read from v[t0]
read from v[t0]
- `_process_halo`: The second one (more general) examines the data dependencies
of HaloSpots and if honored it hoists HaloSpots from inner to outer iterations
haloupd vx[t0]
for time for time
write to v[t1] - read from vx[t0] write to v[t1] - read from v[t0]
haloupd v[t1] haloupd v[t1]
read from v[t1] read from v[t1]
haloupd v[t0] read from v[t0]
read from v[t0]
haloupd v[t0]
for time for time
W v[t1]- R v[t0] W v[t1]- R v[t0]
haloupd v[t1] haloupd v[t1]
R v[t1] R v[t1]
haloupd v[t0] R v[t0]
R v[t0]
"""

# Precompute scopes to save time
scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()}
cond_mapper = _create_cond_mapper(iet)

i_mapper = defaultdict(list)
cond_mapper = _create_cond_mapper(iet)

hs_mapper = {}
# Analysis
hsmapper = {}
imapper = defaultdict(list)

# Look for parent Iterations of children HaloSpots
for iters, halo_spots in MapNodes(Iteration, HaloSpot, 'groupby').visit(iet).items():
Expand All @@ -116,29 +99,21 @@ def _hoist_halospots(iet):
for i in hs1.halo_scheme.loc_values):
continue

# Look for halo-to-halo optimization possibilities
_process_halo_to_halo(hs0, hs1, iters, scopes, hs_mapper, i_mapper)

# If the HaloSpot is already processed, skip
if hs0 in hs_mapper:
continue

# Look for halo-specific optimization possibilities
_process_halo(hs0, iters, scopes, hs_mapper, i_mapper)
# Compare hs0 to subsequent halo_spots, looking for optimization
# possibilities
_process_halo_to_halo(hs0, hs1, iters, scopes, hsmapper, imapper)

mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
for i, hss in i_mapper.items()}

for i, hss in imapper.items()}
mapper.update({i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
for i, hs in hs_mapper.items()})
for i, hs in hsmapper.items()})

if mapper:
iet = Transformer(mapper, nested=True).visit(iet)
iet = Transformer(mapper, nested=True).visit(iet)

return iet


def _process_halo_to_halo(hs0, hs1, iters, scopes, hs_mapper, i_mapper):
def _process_halo_to_halo(hs0, hs1, iters, scopes, hsmapper, imapper):

# Loop over the functions in the HaloSpots
for f, v in hs1.fmapper.items():
Expand All @@ -151,51 +126,27 @@ def _process_halo_to_halo(hs0, hs1, iters, scopes, hs_mapper, i_mapper):
continue

for iter in iters:
# Ensure they are merge-able
# If also merge-able we can start hoisting the latter
for dep in scopes[iter].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
break
else:
fm = hs1.halo_scheme.fmapper[f]
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}
# Since lifted, we need to update the loc_indices with
# known values
# Entering here means we can lift, and we need to update
# the loc_indices with known values
# TODO: Can I get this in a more elegant way?
for d in fm.loc_indices:
root_min = fm.loc_indices[d].symbolic_min
new_min = root_min.subs(fm.loc_indices[d].root,
fm.loc_indices[d].root.symbolic_min)
for d in hse.loc_indices:
root_min = hse.loc_indices[d].symbolic_min
new_min = root_min.subs(hse.loc_indices[d].root,
hse.loc_indices[d].root.symbolic_min)
raw_loc_indices[d] = new_min

hs_entry = fm.rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hs_entry

hs_mapper[hs1] = hs_mapper.get(hs1, hs1.halo_scheme).drop(f)
i_mapper[iter].append(hs1.halo_scheme.project(f))
hse = hse.rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse


def _process_halo(hs0, iters, scopes, h_mapper, imapper):
# Look for halo-specific optimization possibilities
for f, v in hs0.fmapper.items():
loc_dims = frozenset().union([q for d in v.loc_indices
for q in d._defines])
for n, iter in enumerate(iters):
if iter not in scopes:
continue

candidates = [i.dim._defines for i in iters[n:]]
all_candidates = set().union(*candidates)
reads = scopes[iter].getreads(f)
if any(set(a.ispace.dimensions) & all_candidates for a in reads):
continue

for dep in scopes[iter].d_flow.project(f):
if not any(r(dep, candidates, loc_dims)
for r in hoist_rules()):
break
else:
h_mapper[hs0] = h_mapper.get(hs0, hs0.halo_scheme).drop(f)
imapper[iter].append(hs0.halo_scheme.project(f))
hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
imapper[iter].append(hs1.halo_scheme.project(f))


def _merge_halospots(iet):
Expand All @@ -214,10 +165,6 @@ def _merge_halospots(iet):

scope = Scope([e.expr for e in FindNodes(Expression).visit(iter)])

# TOFIX: Why only comparing against the first HaloSpot?
# We could, similarly to `_hoist_halospots`, compare all pairs of HaloSpots
# and merge them based upon some priority ranking on which pair we prefer to
# merge
hs0 = halo_spots[0]
mapper[hs0] = hs0.halo_scheme

Expand All @@ -235,14 +182,9 @@ def _merge_halospots(iet):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
break
else:
try:
hs = hs1.halo_scheme.project(f)
mapper[hs0] = HaloScheme.union([mapper[hs0], hs])
mapper[hs1] = mapper[hs1].drop(f)
except ValueError:
# `hs1.loc_indices=<frozendict {t: t1}` and
# `hs0.loc_indices=<frozendict {t: t0}`
pass
hs = hs1.halo_scheme.project(f)
mapper[hs0] = HaloScheme.union([mapper[hs0], hs])
mapper[hs1] = mapper[hs1].drop(f)

# Post-process analysis
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
Expand Down Expand Up @@ -416,21 +358,6 @@ def mpiize(graph, **kwargs):

# Utility functions to avoid code duplication

def denest_halospots(iet):
"""
Denest nested HaloSpots.
# TOFIX: This also merges HaloSpots that have different loc_indices
"""
mapper = {}
for hs in FindNodes(HaloSpot).visit(iet):
if hs.body.is_HaloSpot:
halo_scheme = HaloScheme.union([hs.halo_scheme, hs.body.halo_scheme])
mapper[hs] = hs._rebuild(halo_scheme=halo_scheme, body=hs.body.body)
iet = Transformer(mapper, nested=True).visit(iet)

return iet


def _create_cond_mapper(iet):
cond_mapper = MapHaloSpots().visit(iet)
return {hs: {i for i in v if i.is_Conditional and
Expand Down Expand Up @@ -461,24 +388,3 @@ def rule2(dep, hs, loc_indices):
rules = [rule0, rule1, rule2]

return rules


def hoist_rules():
# Hoisting rules -- if the retval is True, then it means the input `dep` is not
# a stopper to halo hoisting

def rule0(dep, candidates, loc_dims):
# E.g., `dep=W<f,[x]> -> R<f,[x-1]>` and `candidates=({time}, {x})` => False
# E.g., `dep=W<f,[t1, x, y]> -> R<f,[t0, x-1, y+1]>`, `dep.cause={t,time}` and
# `candidates=({x},)` => True
return (all(i & set(dep.distance_mapper) for i in candidates) and
not any(i & dep.cause for i in candidates) and
not any(i & loc_dims for i in candidates))

def rule1(dep, candidates, loc_dims):
# A reduction isn't a stopper to hoisting
return dep.write is not None and dep.write.is_reduction

rules = [rule0, rule1]

return rules
1 change: 0 additions & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
norm, grad, div, dimensions, switchconfig, configuration,
centered, first_derivative, solve, transpose, Abs, cos,
sin, sqrt, floor, Ge, Lt, Derivative, solve)

from devito.exceptions import InvalidArgument, InvalidOperator
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
FindSymbols, ParallelIteration, retrieve_iteration_tree)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from examples.seismic.acoustic import acoustic_setup
from examples.seismic import Receiver, TimeAxis, demo_model
from tests.test_dse import TestTTI


class TestDistributor:
Expand Down Expand Up @@ -1116,6 +1117,7 @@ def test_avoid_haloupdate_with_constant_index(self, mode):
eq = Eq(u.forward, u[t, 1] + u[t, 1 + x.symbolic_min] + u[t, x])
op = Operator(eq)
calls = FindNodes(Call).visit(op)

assert len(calls) == 0

@pytest.mark.parallel(mode=1)
Expand Down Expand Up @@ -1211,6 +1213,7 @@ def test_hoist_haloupdate_from_innerloop(self, mode):

calls = FindNodes(Call).visit(op)
assert len(calls) == 1

# Also make sure the Call is at the right place in the IET
assert op.body.body[-1].body[1].body[0].body[0].body[0].body[0].is_Call
assert op.body.body[-1].body[1].body[0].body[0].body[1].is_Iteration
Expand All @@ -1236,6 +1239,7 @@ def test_unhoist_haloupdate_if_invariant(self, mode):

op = Operator(eqns)
op.apply(time=1)

calls = FindNodes(Call).visit(op)
assert len(calls) == 2

Expand Down Expand Up @@ -2860,6 +2864,22 @@ def test_elastic_structure(self, mode):
assert calls[4].arguments[1] is v[1]


class TestTTI_w_MPI:

@pytest.mark.parallel(mode=[(1)])
def test_halo_structure(self, mode):

mytest = TestTTI()
solver = mytest.tti_operator(opt='advanced', space_order=8)
op = solver.op_fwd(save=False)

calls = [i for i in FindNodes(Call).visit(op) if isinstance(i, HaloUpdateCall)]

assert len(calls) == 1
assert calls[0].functions[0].name == 'u'
assert calls[0].functions[1].name == 'v'


if __name__ == "__main__":
# configuration['mpi'] = 'overlap'
# TestDecomposition().test_reshape_left_right()
Expand Down

0 comments on commit 467de7f

Please sign in to comment.