Skip to content

Commit

Permalink
Merge pull request #2458 from devitocodes/revamp-cross
Browse files Browse the repository at this point in the history
API: revamp cross derivative shortcuts
  • Loading branch information
mloubout authored Sep 26, 2024
2 parents 20a9de8 + db7a10a commit 25d87fc
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 64 deletions.
76 changes: 48 additions & 28 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,8 @@ def __new__(cls, expr, *dims, **kwargs):
obj = Differentiable.__new__(cls, expr, *var_count)
obj._dims = tuple(OrderedDict.fromkeys(new_dims))

skip = kwargs.get('preprocessed', False) or obj.ndims == 1

obj._fd_order = fd_o if skip else DimensionTuple(*fd_o, getters=obj._dims)
obj._deriv_order = orders if skip else DimensionTuple(*orders, getters=obj._dims)
obj._fd_order = DimensionTuple(*as_tuple(fd_o), getters=obj._dims)
obj._deriv_order = DimensionTuple(*as_tuple(orders), getters=obj._dims)
obj._side = kwargs.get("side")
obj._transpose = kwargs.get("transpose", direct)
obj._method = kwargs.get("method", 'FD')
Expand Down Expand Up @@ -137,7 +135,7 @@ def _process_kwargs(cls, expr, *dims, **kwargs):
fd_orders = kwargs.get('fd_order')
deriv_orders = kwargs.get('deriv_order')
if len(dims) == 1:
dims = tuple([dims[0]]*max(1, deriv_orders))
dims = tuple([dims[0]]*max(1, deriv_orders[0]))
variable_count = [sympy.Tuple(s, dims.count(s))
for s in filter_ordered(dims)]
return dims, deriv_orders, fd_orders, variable_count
Expand Down Expand Up @@ -222,25 +220,34 @@ def _process_weights(cls, **kwargs):

def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None):
side = side or self._side
method = method or self._method
weights = weights if weights is not None else self._weights

x0 = self._process_x0(self.dims, x0=x0)
_x0 = frozendict({**self.x0, **x0})
if self.ndims == 1:
fd_order = fd_order or self._fd_order
method = method or self._method
weights = weights if weights is not None else self._weights
return self._rebuild(fd_order=fd_order, side=side, x0=_x0, method=method,
weights=weights)

# Cross derivative

_fd_order = dict(self.fd_order.getters)
try:
_fd_order.update(fd_order or {})
_fd_order = tuple(_fd_order.values())
_fd_order = DimensionTuple(*_fd_order, getters=self.dims)
except TypeError:
assert self.ndims == 1
_fd_order.update({self.dims[0]: fd_order or self.fd_order[0]})
except AttributeError:
raise TypeError("Multi-dimensional Derivative, input expected as a dict")
raise TypeError("fd_order incompatible with dimensions")

if isinstance(self.expr, Derivative):
# In case this was called on a perfect cross-derivative `u.dxdy`
# we need to propagate the call to the nested derivative
x0s = self._filter_dims(self.expr._filter_dims(_x0), neg=True)
expr = self.expr(x0=x0s, fd_order=self.expr._filter_dims(_fd_order),
side=side, method=method)
else:
expr = self.expr

_fd_order = self._filter_dims(_fd_order, as_tuple=True)

return self._rebuild(fd_order=_fd_order, x0=_x0, side=side)
return self._rebuild(fd_order=_fd_order, x0=_x0, side=side, method=method,
weights=weights, expr=expr)

def _rebuild(self, *args, **kwargs):
kwargs['preprocessed'] = True
Expand Down Expand Up @@ -293,15 +300,32 @@ def _xreplace(self, subs):
except AttributeError:
return new, True

# Resolve nested derivatives
dsubs = {k: v for k, v in subs.items() if isinstance(k, Derivative)}
expr = self.expr.xreplace(dsubs)

subs = self._ppsubs + (subs,) # Postponed substitutions
return self._rebuild(subs=subs), True
return self._rebuild(subs=subs, expr=expr), True

@cached_property
def _metadata(self):
ret = [self.dims] + [getattr(self, i) for i in self.__rkwargs__]
ret.append(self.expr.staggered or (None,))
return tuple(ret)

def _filter_dims(self, col, as_tuple=False, neg=False):
"""
Filter collection to only keep the Derivative's dimensions as keys.
"""
if neg:
filtered = {k: v for k, v in col.items() if k not in self.dims}
else:
filtered = {k: v for k, v in col.items() if k in self.dims}
if as_tuple:
return DimensionTuple(*filtered.values(), getters=self.dims)
else:
return filtered

@property
def dims(self):
return self._dims
Expand Down Expand Up @@ -422,13 +446,9 @@ def _eval_fd(self, expr, **kwargs):
"""
# Step 1: Evaluate non-derivative x0. We currently enforce a simple 2nd order
# interpolation to avoid very expensive finite differences on top of it
x0_interp = {}
x0_deriv = {}
for d, v in self.x0.items():
if d in self.dims:
x0_deriv[d] = v
elif not d.is_Time:
x0_interp[d] = v
x0_deriv = self._filter_dims(self.x0)
x0_interp = {d: v for d, v in self.x0.items()
if d not in x0_deriv and not d.is_Time}

if x0_interp and self.method == 'FD':
expr = interp_for_fd(expr, x0_interp, **kwargs)
Expand All @@ -446,7 +466,7 @@ def _eval_fd(self, expr, **kwargs):
# Step 3: Evaluate FD of the new expression
if self.method == 'RSFD':
assert len(self.dims) == 1
assert self.deriv_order == 1
assert self.deriv_order[0] == 1
res = d45(expr, self.dims[0], x0=self.x0, expand=expand)
elif len(self.dims) > 1:
assert self.method == 'FD'
Expand All @@ -455,8 +475,8 @@ def _eval_fd(self, expr, **kwargs):
side=self.side)
else:
assert self.method == 'FD'
res = generic_derivative(expr, self.dims[0], as_tuple(self.fd_order)[0],
self.deriv_order, weights=self.weights,
res = generic_derivative(expr, self.dims[0], self.fd_order[0],
self.deriv_order[0], weights=self.weights,
side=self.side, matvec=self.transpose,
x0=self.x0, expand=expand)

Expand Down
22 changes: 8 additions & 14 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,9 @@ def coefficients(self):
coefficients = {f.coefficients for f in self._functions}
# If there is multiple ones, we have to revert to the highest priority
# i.e we have to remove symbolic
key = lambda x: coeff_priority[x]
key = lambda x: coeff_priority.get(x, -1)
return sorted(coefficients, key=key, reverse=True)[0]

@cached_property
def _coeff_symbol(self, *args, **kwargs):
if self._uses_symbolic_coefficients:
return W
else:
raise ValueError("Couldn't find any symbolic coefficients")

def _eval_at(self, func):
if not func.is_Staggered:
# Cartesian grid, do no waste time
Expand Down Expand Up @@ -427,14 +420,14 @@ def has_free(self, *patterns):


def highest_priority(DiffOp):
prio = lambda x: getattr(x, '_fd_priority', 0)
# We want to get the object with highest priority
# We also need to make sure that the object with the largest
# set of dimensions is used when multiple ones with the same
# priority appear
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]


# Abstract symbol representing a symbolic coefficient
W = sympy.Function('W')


class DifferentiableOp(Differentiable):

__sympy_class__ = None
Expand Down Expand Up @@ -1018,7 +1011,8 @@ def interp_for_fd(expr, x0, **kwargs):

@interp_for_fd.register(sympy.Derivative)
def _(expr, x0, **kwargs):
return expr.func(expr=interp_for_fd(expr.expr, x0, **kwargs))
x0_expr = {d: v for d, v in x0.items() if d not in expr.dims}
return expr.func(expr=interp_for_fd(expr.expr, x0_expr, **kwargs))


@interp_for_fd.register(sympy.Expr)
Expand Down
2 changes: 1 addition & 1 deletion devito/finite_differences/finite_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def cross_derivative(expr, dims, fd_order, deriv_order, x0=None, side=None, **kw
Semantically, this is equivalent to
>>> (f*g).dxdy
Derivative(f(x, y)*g(x, y), x, y)
Derivative(Derivative(f(x, y)*g(x, y), x), y)
The only difference is that in the latter case derivatives remain unevaluated.
The expanded form is obtained via ``evaluate``
Expand Down
15 changes: 12 additions & 3 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,14 @@ def generate_fd_shortcuts(dims, so, to=0):
from devito.finite_differences.derivative import Derivative

def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs):
return Derivative(expr, *as_tuple(dims), deriv_order=deriv_order,
fd_order=fd_order, side=side, **kwargs)
# Separate dimensions to always have cross derivatives return nested
# derivatives. E.g `u.dxdy -> u.dx.dy`
dims = as_tuple(dims)
deriv_order = as_tuple(deriv_order)
fd_order = as_tuple(fd_order)
for (d, do, fo) in zip(dims, deriv_order, fd_order):
expr = Derivative(expr, d, deriv_order=do, fd_order=fo, side=side, **kwargs)
return expr

all_combs = dim_with_order(dims, orders)

Expand Down Expand Up @@ -225,7 +231,8 @@ def numeric_weights(function, deriv_order, indices, x0):
return finite_diff_weights(deriv_order, indices, x0)[-1][-1]


fd_weights_registry = {'taylor': numeric_weights, 'standard': numeric_weights}
fd_weights_registry = {'taylor': numeric_weights, 'standard': numeric_weights,
'symbolic': numeric_weights} # Backward compat for 'symbolic'
coeff_priority = {'taylor': 1, 'standard': 1}


Expand Down Expand Up @@ -318,6 +325,8 @@ def process_weights(weights, expr):
if weights is None:
return 0, None
elif isinstance(weights, Function):
if len(weights.dimensions) == 1:
return weights.shape[0], weights.dimensions[0]
wdim = {d for d in weights.dimensions if d not in expr.dimensions}
assert len(wdim) == 1
wdim = wdim.pop()
Expand Down
8 changes: 7 additions & 1 deletion devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
ConditionalDimension)
from devito.types.array import Array
from devito.types.basic import AbstractFunction
from devito.types.grid import MultiSubDimension

Expand Down Expand Up @@ -135,8 +136,13 @@ def _lower_exprs(expressions, subs):
if dimension_map:
indices = [j.xreplace(dimension_map) for j in indices]

mapper[i] = f.indexed[indices]
# Handle Array
if isinstance(f, Array) and f.initvalue is not None:
initvalue = [_lower_exprs(i, subs) for i in f.initvalue]
# TODO: fix rebuild to avoid new name
f = f._rebuild(name='%si' % f.name, initvalue=initvalue)

mapper[i] = f.indexed[indices]
# Add dimensions map to the mapper in case dimensions are used
# as an expression, i.e. Eq(u, x, subdomain=xleft)
mapper.update(dimension_map)
Expand Down
2 changes: 2 additions & 0 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Iterable
from functools import singledispatch

import numpy as np
from sympy import Pow, Add, Mul, Min, Max, S, SympifyError, Tuple, sympify
from sympy.core.add import _addsort
from sympy.core.mul import _mulsort
Expand Down Expand Up @@ -98,6 +99,7 @@ def _(expr, rule):
return _uxreplace(expr, rule)


@_uxreplace_dispatch.register(np.ndarray)
@_uxreplace_dispatch.register(tuple)
@_uxreplace_dispatch.register(Tuple)
@_uxreplace_dispatch.register(list)
Expand Down
6 changes: 3 additions & 3 deletions devito/symbolics/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ def retrieve_indexed(exprs, mode='all', deep=False):
return search(exprs, q_indexed, mode, 'dfs', deep)


def retrieve_functions(exprs, mode='all'):
def retrieve_functions(exprs, mode='all', deep=False):
"""Shorthand to retrieve the DiscreteFunctions in `exprs`."""
indexeds = search(exprs, q_indexed, mode, 'dfs')
indexeds = search(exprs, q_indexed, mode, 'dfs', deep)

functions = search(exprs, q_function, mode, 'dfs')
functions = search(exprs, q_function, mode, 'dfs', deep)
functions.update({i.function for i in indexeds})

return functions
Expand Down
3 changes: 2 additions & 1 deletion devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,13 @@ def _apply_coeffs(cls, expr, coefficients):
for coeff in coefficients.coefficients:
derivs = [d for d in retrieve_derivatives(expr)
if coeff.dimension in d.dims and
coeff.deriv_order == d.deriv_order]
coeff.deriv_order == d.deriv_order.get(coeff.dimension, None)]
if not derivs:
continue
mapper.update({d: d._rebuild(weights=coeff.weights) for d in derivs})
if not mapper:
return expr

return expr.xreplace(mapper)

def _evaluate(self, **kwargs):
Expand Down
36 changes: 24 additions & 12 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,19 @@ def test_stencil_derivative(self, SymbolType, dim):

@pytest.mark.parametrize('SymbolType, derivative, dim, expected', [
(Function, ['dx2'], 3, 'Derivative(u(x, y, z), (x, 2))'),
(Function, ['dx2dy'], 3, 'Derivative(u(x, y, z), (x, 2), y)'),
(Function, ['dx2dydz'], 3, 'Derivative(u(x, y, z), (x, 2), y, z)'),
(Function, ['dx2dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dydz'], 3,
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), z)'),
(Function, ['dx2', 'dy'], 3, 'Derivative(Derivative(u(x, y, z), (x, 2)), y)'),
(Function, ['dx2dy', 'dz2'], 3,
'Derivative(Derivative(u(x, y, z), (x, 2), y), (z, 2))'),
'Derivative(Derivative(Derivative(u(x, y, z), (x, 2)), y), (z, 2))'),
(TimeFunction, ['dx2'], 3, 'Derivative(u(t, x, y, z), (x, 2))'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(u(t, x, y, z), (x, 2), y)'),
(TimeFunction, ['dx2dy'], 3, 'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx2', 'dy'], 3,
'Derivative(Derivative(u(t, x, y, z), (x, 2)), y)'),
(TimeFunction, ['dx', 'dy', 'dx2', 'dz', 'dydz'], 3,
'Derivative(Derivative(Derivative(Derivative(Derivative(u(t, x, y, z), x), y),' +
' (x, 2)), z), y, z)')
'Derivative(Derivative(Derivative(Derivative(Derivative(Derivative(' +
'u(t, x, y, z), x), y), (x, 2)), z), y), z)')
])
def test_unevaluation(self, SymbolType, derivative, dim, expected):
u = SymbolType(name='u', grid=self.grid, time_order=2, space_order=2)
Expand All @@ -111,13 +112,13 @@ def test_unevaluation(self, SymbolType, derivative, dim, expected):

@pytest.mark.parametrize('expr,expected', [
('u.dx + u.dy', 'Derivative(u, x) + Derivative(u, y)'),
('u.dxdy', 'Derivative(u, x, y)'),
('u.dxdy', 'Derivative(Derivative(u, x), y)'),
('u.laplace',
'Derivative(u, (x, 2)) + Derivative(u, (y, 2)) + Derivative(u, (z, 2))'),
('(u.dx + u.dy).dx', 'Derivative(Derivative(u, x) + Derivative(u, y), x)'),
('((u.dx + u.dy).dx + u.dxdy).dx',
'Derivative(Derivative(Derivative(u, x) + Derivative(u, y), x) +' +
' Derivative(u, x, y), x)'),
' Derivative(Derivative(u, x), y), x)'),
('(u**4).dx', 'Derivative(u**4, x)'),
('(u/4).dx', 'Derivative(u/4, x)'),
('((u.dx + v.dy).dx * v.dx).dy.dz',
Expand Down Expand Up @@ -403,6 +404,11 @@ def test_xderiv_x0(self):
- f.dx(x0=x+h_x/2).dy(x0=y+h_y/2).evaluate
assert simplify(expr) == 0

# Check x0 is correctly set
dfdxdx = f.dx(x0=x+h_x/2).dx(x0=x-h_x/2)
assert dict(dfdxdx.x0) == {x: x-h_x/2}
assert dict(dfdxdx.expr.x0) == {x: x+h_x/2}

def test_fd_new_side(self):
grid = Grid((10,))
u = Function(name="u", grid=grid, space_order=4)
Expand Down Expand Up @@ -659,9 +665,9 @@ def test_zero_spec(self):
drv1 = Derivative(f, (x, 2), (y, 0))
assert drv0.dims == (x,)
assert drv1.dims == (x, y)
assert drv0.fd_order == 2
assert drv0.fd_order == (2,)
assert drv1.fd_order == (2, 2)
assert drv0.deriv_order == 2
assert drv0.deriv_order == (2,)
assert drv1.deriv_order == (2, 0)

assert drv0.evaluate == drv1.evaluate
Expand Down Expand Up @@ -731,6 +737,12 @@ def test_issue_2442(self):
dfdxdy_split = f.dxc.dyc
assert dfdxdy.evaluate == dfdxdy_split.evaluate

def test_cross_newnest(self):
grid = Grid((11, 11))
f = Function(name="f", grid=grid, space_order=2)

assert f.dxdy == f.dx.dy


class TestTwoStageEvaluation:

Expand Down Expand Up @@ -984,8 +996,8 @@ def test_laplacian_opt(self):
df = f.laplacian(order=2, shift=.5)
for (v, d) in zip(df.args, grid.dimensions):
assert v.dims[0] == d
assert v.fd_order == 2
assert v.deriv_order == 2
assert v.fd_order == (2,)
assert v.deriv_order == (2,)
assert d in v.x0


Expand Down
Loading

0 comments on commit 25d87fc

Please sign in to comment.