diff --git a/devito/passes/clusters/aliases.py b/devito/passes/clusters/aliases.py index ee3892f1f4..8b7f610564 100644 --- a/devito/passes/clusters/aliases.py +++ b/devito/passes/clusters/aliases.py @@ -779,7 +779,9 @@ def optimize_schedule_rotations(schedule, sregistry): iis = candidate.lower iib = candidate.upper - ii = ModuloDimension('%sii' % d.root.name, ds, iis, incr=iib) + name = sregistry.make_name(prefix='%sii' % d.root.name) + ii = ModuloDimension(name, ds, iis, incr=iib) + cd = CustomDimension(name='%sc' % d.root.name, symbolic_min=ii, symbolic_max=iib, symbolic_size=n) dsi = ModuloDimension('%si' % ds.root.name, cd, cd + ds - iis, n) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 4617b47336..e9d8308336 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -5,9 +5,10 @@ import sympy from devito.finite_differences import Max, Min -from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes, - FindSymbols, Transformer, Uxreplace, filter_iterations, - retrieve_iteration_tree, pull_dims) +from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder, + FindApplications, FindNodes, FindSymbols, Transformer, + Uxreplace, filter_iterations, retrieve_iteration_tree, + pull_dims) from devito.passes.iet.engine import iet_pass from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, @@ -231,10 +232,13 @@ def minimize_symbols(iet): * Remove redundant ModuloDimensions (e.g., due to using the `save=Buffer(2)` API) + * Simplify Iteration headers (e.g., ModuloDimensions with identical + starting point and step) * Abridge SubDimension names where possible to declutter generated loop nests and shrink indices """ iet = remove_redundant_moddims(iet) + iet = simplify_iteration_headers(iet) iet = abridge_dim_names(iet) return iet, {} @@ -264,6 +268,30 @@ def remove_redundant_moddims(iet): return iet +def simplify_iteration_headers(iet): + mapper = {} + for i in FindNodes(Iteration).visit(iet): + candidates = [d for d in i.uindices + if d.is_Modulo and d.symbolic_min == d.symbolic_incr] + + # Don't touch `t0, t1, ...` for codegen aesthetics and to avoid + # massive changes in the test suite + candidates = [d for d in candidates + if not any(dd.is_Time for dd in d._defines)] + + if not candidates: + continue + + uindices = [d for d in i.uindices if d not in candidates] + stmts = [DummyExpr(d, d.symbolic_incr, init=True) for d in candidates] + + mapper[i] = i._rebuild(nodes=tuple(stmts) + i.nodes, uindices=uindices) + + iet = Transformer(mapper, nested=True).visit(iet) + + return iet + + @singledispatch def abridge_dim_names(iet): return iet diff --git a/examples/performance/00_overview.ipynb b/examples/performance/00_overview.ipynb index 3ed1a45014..71dc1a5569 100644 --- a/examples/performance/00_overview.ipynb +++ b/examples/performance/00_overview.ipynb @@ -1358,10 +1358,17 @@ " {\n", " for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n", " {\n", - " for (int y = y0_blk0, ys = 0, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = -2; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = 2)\n", + " for (int y = y0_blk0, ys = 0, yii0 = -2; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1, yii0 = 2)\n", " {\n", - " for (int yc = yii, yi = (yc + ys + 2)%(5); yc <= 2; yc += 1, yi = (yc + ys + 2)%(5))\n", + " int yr0 = (ys)%(5);\n", + " int yr1 = (ys + 3)%(5);\n", + " int yr2 = (ys + 4)%(5);\n", + " int yr3 = (ys + 1)%(5);\n", + "\n", + " for (int yc = yii0; yc <= 2; yc += 1)\n", " {\n", + " int yi = (yc + ys + 2)%(5);\n", + "\n", " #pragma omp simd aligned(u:32)\n", " for (int z = z_m; z <= z_M; z += 1)\n", " {\n", diff --git a/tests/test_dle.py b/tests/test_dle.py index efc41bd2bc..45583d4346 100644 --- a/tests/test_dle.py +++ b/tests/test_dle.py @@ -1319,17 +1319,17 @@ def test_multiple_subnests_v1(self): bns, _ = assert_blocking(op, {'x0_blk0'}) trees = retrieve_iteration_tree(bns['x0_blk0']) - assert len(trees) == 2 + assert len(trees) == 4 - assert trees[0][0] is trees[1][0] - assert trees[0][0].pragmas[0].ccode.value ==\ + assert len(set(i.root for i in trees)) == 1 + assert trees[-2].root.pragmas[0].ccode.value ==\ 'omp for collapse(2) schedule(dynamic,1)' - assert not trees[0][2].pragmas - assert not trees[0][3].pragmas - assert trees[0][4].pragmas[0].ccode.value ==\ + assert not trees[-2][2].pragmas + assert not trees[-2][3].pragmas + assert trees[-2][4].pragmas[0].ccode.value ==\ 'omp parallel for schedule(dynamic,1) num_threads(nthreads_nested)' - assert not trees[1][2].pragmas - assert trees[1][3].pragmas[0].ccode.value ==\ + assert not trees[-1][2].pragmas + assert trees[-1][3].pragmas[0].ccode.value ==\ 'omp parallel for schedule(dynamic,1) num_threads(nthreads_nested)' @pytest.mark.parametrize('blocklevels', [1, 2]) diff --git a/tests/test_dse.py b/tests/test_dse.py index 7fd0298d93..49cece3b34 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -12,7 +12,7 @@ ConditionalDimension, DefaultDimension, Grid, Operator, norm, grad, div, dimensions, switchconfig, configuration, centered, first_derivative, solve, transpose, Abs, cos, - sin, sqrt, floor, Ge, Lt) + sin, sqrt, floor, Ge, Lt, Derivative) from devito.exceptions import InvalidArgument, InvalidOperator from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes, FindSymbols, ParallelIteration, retrieve_iteration_tree) @@ -1131,13 +1131,13 @@ def test_from_different_nests(self, rotate): # Check code generation bns, _ = assert_blocking(op1, {'x0_blk0', 'x1_blk0'}) trees = retrieve_iteration_tree(bns['x0_blk0']) - assert len(trees) == 2 - assert trees[0][-1].nodes[0].body[0].write.is_Array - assert trees[1][-1].nodes[0].body[0].write is u + assert len(trees) == 4 if rotate else 2 + assert trees[-2][-1].nodes[0].body[0].write.is_Array + assert trees[-1][-1].nodes[0].body[0].write is u trees = retrieve_iteration_tree(bns['x1_blk0']) - assert len(trees) == 2 - assert trees[0][-1].nodes[0].body[0].write.is_Array - assert trees[1][-1].nodes[0].body[0].write is v + assert len(trees) == 4 if rotate else 2 + assert trees[-2][-1].nodes[0].body[0].write.is_Array + assert trees[-1][-1].nodes[0].body[0].write is v # Check numerical output op0(time_M=1) @@ -2093,15 +2093,68 @@ def test_maxpar_option(self, rotate): # Check code generation bns, _ = assert_blocking(op1, {'x0_blk0'}) trees = retrieve_iteration_tree(bns['x0_blk0']) - assert len(trees) == 2 + if rotate: + assert len(trees) == 5 + else: + assert len(trees) == 2 + assert trees[0][2] is not trees[1][2] assert trees[0][1] is trees[1][1] - assert trees[0][2] is not trees[1][2] # Check numerical output op0.apply(time_M=2) op1.apply(time_M=2, u=u1) assert np.isclose(norm(u), norm(u1), rtol=1e-5) + def test_multiple_rotating_dims(self): + space_order = 8 + grid = Grid(shape=(51, 51, 51)) + x, y, z = grid.dimensions + + dt = 0.1 + nt = 5 + + u = TimeFunction(name="u", grid=grid, space_order=space_order) + vx = TimeFunction(name="vx", grid=grid, space_order=space_order) + vy = TimeFunction(name="vy", grid=grid, space_order=space_order) + + f = Function(name='f', grid=grid, space_order=space_order) + g = Function(name='g', grid=grid, space_order=space_order) + + expr0 = 1-cos(f)**2 + expr1 = sin(f)*cos(f) + expr2 = sin(g)*cos(f) + expr3 = (1-cos(g))*sin(f)*cos(f) + + stencil0 = ((expr0*vx.forward).dx(x0=x-x.spacing/2) + + Derivative(expr1*vx.forward, x, deriv_order=0, fd_order=2, + x0=x-x.spacing/2).dy(x0=y) + + Derivative(expr2*vx.forward, x, deriv_order=0, fd_order=2, + x0=x-x.spacing/2).dz(x0=z)) + stencil1 = Derivative(expr3*vy.forward, y, deriv_order=0, fd_order=2, + x0=y-y.spacing/2).dx(x0=x) + + eqns = [Eq(vx.forward, u*.1), + Eq(vy.forward, u*.1), + Eq(u.forward, stencil0 + stencil1 + .1)] + + op0 = Operator(eqns) + op1 = Operator(eqns, opt=("advanced", {"cire-rotate": True})) + + f.data_with_halo[:] = .3 + g.data_with_halo[:] = .7 + + u1 = u.func(name='u1') + vx1 = vx.func(name='vx1') + vy1 = vy.func(name='vy1') + + op0.apply(time_m=0, time_M=nt-2, dt=dt) + + # NOTE: the main issue leading to this test was actually failing + # to jit-compile `op1`. However, we also check numerical correctness + op1.apply(time_m=0, time_M=nt-2, dt=dt, u=u1, vx=vx1, vy=vy1) + + assert np.allclose(u.data, u1.data, rtol=1e-5) + def test_maxpar_option_v2(self): """ Another test for the compiler option `cire-maxpar=True`. @@ -2191,7 +2244,11 @@ def test_blocking_options(self, rotate): if rotate: assert_structure( op1, - prefix + ['t,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z', + prefix + ['t,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x', + 't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc', + 't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z', + 't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y', + 't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,yc', 't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,yc,z', 't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,z'], 't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z,y,yc,z,z'