Skip to content

Commit

Permalink
compiler: Add wrapper for subs vs uxreplace
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Sep 11, 2024
1 parent 85aa136 commit 31359c1
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
from sympy.core.add import _addsort
from sympy.core.mul import _mulsort

from devito.finite_differences.differentiable import EvalDerivative
from devito.finite_differences.differentiable import (
EvalDerivative, IndexDerivative
)
from devito.symbolics.extended_sympy import DefFunction, rfunc
from devito.symbolics.queries import q_leaf
from devito.symbolics.search import retrieve_indexed, retrieve_functions
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
from devito.types.basic import Basic
from devito.types.basic import Basic, Indexed
from devito.types.array import ComponentAccess
from devito.types.equation import Eq
from devito.types.relational import Le, Lt, Gt, Ge

__all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args',
'normalize_args', 'uxreplace', 'Uxmapper', 'reuse_if_untouched',
'evalrel', 'flatten_args']
'normalize_args', 'uxreplace', 'Uxmapper', 'subs_if_composite',
'reuse_if_untouched', 'evalrel', 'flatten_args']


def uxreplace(expr, rule):
Expand Down Expand Up @@ -244,6 +246,20 @@ def add(self, expr, make, terms=None):
self[base] = self.extracted[base] = make()


def subs_if_composite(expr, subs):
"""
Call `expr.subs(subs)` if `subs` contain composite expressions, that is
expressions that can be part of larger expressions of the same type (e.g.,
`a*b` could be part of `a*b*c`, while `a[1]` cannot be part of a "larger
Indexed"). Instead, if `subs` consists of just "primitive" expressions, then
resort to the much faster `uxreplace`.
"""
if all(isinstance(i, (Indexed, IndexDerivative)) for i in subs):
return uxreplace(expr, subs)
else:
return expr.subs(subs)


def xreplace_indices(exprs, mapper, key=None):
"""
Replace array indices in expressions.
Expand Down

0 comments on commit 31359c1

Please sign in to comment.