Skip to content

Commit

Permalink
Merge pull request #2472 from devitocodes/fixup-cire-rotate
Browse files Browse the repository at this point in the history
compiler: Fixup minor cire-rotate bug
  • Loading branch information
mloubout authored Oct 22, 2024
2 parents 154140f + 7af2372 commit afae8af
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 24 deletions.
4 changes: 3 additions & 1 deletion devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,9 @@ def optimize_schedule_rotations(schedule, sregistry):
iis = candidate.lower
iib = candidate.upper

ii = ModuloDimension('%sii' % d.root.name, ds, iis, incr=iib)
name = sregistry.make_name(prefix='%sii' % d.root.name)
ii = ModuloDimension(name, ds, iis, incr=iib)

cd = CustomDimension(name='%sc' % d.root.name, symbolic_min=ii,
symbolic_max=iib, symbolic_size=n)
dsi = ModuloDimension('%si' % ds.root.name, cd, cd + ds - iis, n)
Expand Down
34 changes: 31 additions & 3 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import sympy

from devito.finite_differences import Max, Min
from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes,
FindSymbols, Transformer, Uxreplace, filter_iterations,
retrieve_iteration_tree, pull_dims)
from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder,
FindApplications, FindNodes, FindSymbols, Transformer,
Uxreplace, filter_iterations, retrieve_iteration_tree,
pull_dims)
from devito.passes.iet.engine import iet_pass
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper,
Expand Down Expand Up @@ -231,10 +232,13 @@ def minimize_symbols(iet):
* Remove redundant ModuloDimensions (e.g., due to using the
`save=Buffer(2)` API)
* Simplify Iteration headers (e.g., ModuloDimensions with identical
starting point and step)
* Abridge SubDimension names where possible to declutter generated
loop nests and shrink indices
"""
iet = remove_redundant_moddims(iet)
iet = simplify_iteration_headers(iet)
iet = abridge_dim_names(iet)

return iet, {}
Expand Down Expand Up @@ -264,6 +268,30 @@ def remove_redundant_moddims(iet):
return iet


def simplify_iteration_headers(iet):
    """
    Strip from the Iteration headers those ModuloDimensions whose starting
    point and step are identical, and initialize them instead through
    statements placed at the top of the Iteration body.
    """
    def is_candidate(d):
        # Only ModuloDimensions with identical starting point and step
        if not (d.is_Modulo and d.symbolic_min == d.symbolic_incr):
            return False

        # Don't touch `t0, t1, ...` for codegen aesthetics and to avoid
        # massive changes in the test suite
        return not any(dd.is_Time for dd in d._defines)

    subs = {}
    for it in FindNodes(Iteration).visit(iet):
        hoisted = [d for d in it.uindices if is_candidate(d)]
        if not hoisted:
            continue

        # What remains in the header vs what gets initialized in the body
        kept = [d for d in it.uindices if d not in hoisted]
        inits = tuple(DummyExpr(d, d.symbolic_incr, init=True) for d in hoisted)

        subs[it] = it._rebuild(nodes=inits + it.nodes, uindices=kept)

    return Transformer(subs, nested=True).visit(iet)


@singledispatch
def abridge_dim_names(iet):
    """
    Fallback of the single-dispatch `abridge_dim_names`: the input is handed
    back untouched when no specialization is registered for its type.
    """
    return iet
Expand Down
11 changes: 9 additions & 2 deletions examples/performance/00_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1358,10 +1358,17 @@
" {\n",
" for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0, ys = 0, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = -2; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = 2)\n",
" for (int y = y0_blk0, ys = 0, yii0 = -2; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1, yii0 = 2)\n",
" {\n",
" for (int yc = yii, yi = (yc + ys + 2)%(5); yc <= 2; yc += 1, yi = (yc + ys + 2)%(5))\n",
" int yr0 = (ys)%(5);\n",
" int yr1 = (ys + 3)%(5);\n",
" int yr2 = (ys + 4)%(5);\n",
" int yr3 = (ys + 1)%(5);\n",
"\n",
" for (int yc = yii0; yc <= 2; yc += 1)\n",
" {\n",
" int yi = (yc + ys + 2)%(5);\n",
"\n",
" #pragma omp simd aligned(u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
" {\n",
Expand Down
16 changes: 8 additions & 8 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,17 +1319,17 @@ def test_multiple_subnests_v1(self):
bns, _ = assert_blocking(op, {'x0_blk0'})

trees = retrieve_iteration_tree(bns['x0_blk0'])
assert len(trees) == 2
assert len(trees) == 4

assert trees[0][0] is trees[1][0]
assert trees[0][0].pragmas[0].ccode.value ==\
assert len(set(i.root for i in trees)) == 1
assert trees[-2].root.pragmas[0].ccode.value ==\
'omp for collapse(2) schedule(dynamic,1)'
assert not trees[0][2].pragmas
assert not trees[0][3].pragmas
assert trees[0][4].pragmas[0].ccode.value ==\
assert not trees[-2][2].pragmas
assert not trees[-2][3].pragmas
assert trees[-2][4].pragmas[0].ccode.value ==\
'omp parallel for schedule(dynamic,1) num_threads(nthreads_nested)'
assert not trees[1][2].pragmas
assert trees[1][3].pragmas[0].ccode.value ==\
assert not trees[-1][2].pragmas
assert trees[-1][3].pragmas[0].ccode.value ==\
'omp parallel for schedule(dynamic,1) num_threads(nthreads_nested)'

@pytest.mark.parametrize('blocklevels', [1, 2])
Expand Down
77 changes: 67 additions & 10 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ConditionalDimension, DefaultDimension, Grid, Operator,
norm, grad, div, dimensions, switchconfig, configuration,
centered, first_derivative, solve, transpose, Abs, cos,
sin, sqrt, floor, Ge, Lt)
sin, sqrt, floor, Ge, Lt, Derivative)
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
FindSymbols, ParallelIteration, retrieve_iteration_tree)
Expand Down Expand Up @@ -1131,13 +1131,13 @@ def test_from_different_nests(self, rotate):
# Check code generation
bns, _ = assert_blocking(op1, {'x0_blk0', 'x1_blk0'})
trees = retrieve_iteration_tree(bns['x0_blk0'])
assert len(trees) == 2
assert trees[0][-1].nodes[0].body[0].write.is_Array
assert trees[1][-1].nodes[0].body[0].write is u
assert len(trees) == 4 if rotate else 2
assert trees[-2][-1].nodes[0].body[0].write.is_Array
assert trees[-1][-1].nodes[0].body[0].write is u
trees = retrieve_iteration_tree(bns['x1_blk0'])
assert len(trees) == 2
assert trees[0][-1].nodes[0].body[0].write.is_Array
assert trees[1][-1].nodes[0].body[0].write is v
assert len(trees) == 4 if rotate else 2
assert trees[-2][-1].nodes[0].body[0].write.is_Array
assert trees[-1][-1].nodes[0].body[0].write is v

# Check numerical output
op0(time_M=1)
Expand Down Expand Up @@ -2093,15 +2093,68 @@ def test_maxpar_option(self, rotate):
# Check code generation
bns, _ = assert_blocking(op1, {'x0_blk0'})
trees = retrieve_iteration_tree(bns['x0_blk0'])
assert len(trees) == 2
if rotate:
assert len(trees) == 5
else:
assert len(trees) == 2
assert trees[0][2] is not trees[1][2]
assert trees[0][1] is trees[1][1]
assert trees[0][2] is not trees[1][2]

# Check numerical output
op0.apply(time_M=2)
op1.apply(time_M=2, u=u1)
assert np.isclose(norm(u), norm(u1), rtol=1e-5)

def test_multiple_rotating_dims(self):
    """
    Build the same set of equations with and without `cire-rotate`, then
    check that the rotated Operator can be applied (i.e., it jit-compiles)
    and that it produces the same `u` as the default Operator.
    """
    grid = Grid(shape=(51, 51, 51))
    x, y, z = grid.dimensions

    so = 8
    dt, nt = 0.1, 5

    u, vx, vy = (TimeFunction(name=n, grid=grid, space_order=so)
                 for n in ('u', 'vx', 'vy'))
    f, g = (Function(name=n, grid=grid, space_order=so) for n in ('f', 'g'))

    def shifted(expr, d):
        # `deriv_order=0` Derivative along `d`, with `x0` half a spacing
        # below `d`
        return Derivative(expr, d, deriv_order=0, fd_order=2,
                          x0=d - d.spacing/2)

    coeff0 = 1-cos(f)**2
    coeff1 = sin(f)*cos(f)
    coeff2 = sin(g)*cos(f)
    coeff3 = (1-cos(g))*sin(f)*cos(f)

    stencil0 = ((coeff0*vx.forward).dx(x0=x-x.spacing/2) +
                shifted(coeff1*vx.forward, x).dy(x0=y) +
                shifted(coeff2*vx.forward, x).dz(x0=z))
    stencil1 = shifted(coeff3*vy.forward, y).dx(x0=x)

    eqns = [Eq(vx.forward, u*.1),
            Eq(vy.forward, u*.1),
            Eq(u.forward, stencil0 + stencil1 + .1)]

    op0 = Operator(eqns)
    op1 = Operator(eqns, opt=("advanced", {"cire-rotate": True}))

    f.data_with_halo[:] = .3
    g.data_with_halo[:] = .7

    # Separate output fields for the rotated Operator
    u1 = u.func(name='u1')
    vx1 = vx.func(name='vx1')
    vy1 = vy.func(name='vy1')

    op0.apply(time_m=0, time_M=nt-2, dt=dt)

    # NOTE: the main issue leading to this test was actually failing
    # to jit-compile `op1`. However, we also check numerical correctness
    op1.apply(time_m=0, time_M=nt-2, dt=dt, u=u1, vx=vx1, vy=vy1)

    assert np.allclose(u.data, u1.data, rtol=1e-5)

def test_maxpar_option_v2(self):
"""
Another test for the compiler option `cire-maxpar=True`.
Expand Down Expand Up @@ -2191,7 +2244,11 @@ def test_blocking_options(self, rotate):
if rotate:
assert_structure(
op1,
prefix + ['t,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z',
prefix + ['t,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,yc',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,yc,z',
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,z'],
't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z,y,yc,z,z'
Expand Down

0 comments on commit afae8af

Please sign in to comment.