Skip to content

Commit

Permalink
api: always use conditional dimension for interpolation radius dim
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 11, 2023
1 parent 0d4c099 commit 6ad85cf
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 28 deletions.
28 changes: 17 additions & 11 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,19 @@ def _rdim(self):
-self.r+1, self.r, 2*self.r, parent)
for d in self._gdims]

return DimensionTuple(*dims, getters=self._gdims)
# Make radius dimension conditional to avoid OOB
rdims = []
pos = self.sfunction._position_map.values()

for (d, rd, p) in zip(self._gdims, dims, pos):
# Add conditional to avoid OOB
lb = sympy.And(rd + p >= d.symbolic_min - self.r, evaluate=False)
ub = sympy.And(rd + p <= d.symbolic_max + self.r, evaluate=False)
cond = sympy.And(lb, ub, evaluate=False)
rdims.append(ConditionalDimension(rd.name, rd, condition=cond,
indirect=True))

return DimensionTuple(*rdims, getters=self._gdims)

def _augment_implicit_dims(self, implicit_dims):
if self.sfunction._sparse_position == -1:
Expand All @@ -177,24 +189,17 @@ def _interp_idx(self, variables, implicit_dims=None):
mapper = {}
pos = self.sfunction._position_map.values()

for ((di, d), rd, p) in zip(enumerate(self._gdims), self._rdim, pos):
# Add conditional to avoid OOB
lb = sympy.And(rd + p >= d.symbolic_min - self.r, evaluate=False)
ub = sympy.And(rd + p <= d.symbolic_max + self.r, evaluate=False)
cond = sympy.And(lb, ub, evaluate=False)
mapper[d] = ConditionalDimension(rd.name, rd, condition=cond, indirect=True)

# Temporaries for the position
temps = self._positions(implicit_dims)

# Coefficient symbol expression
temps.extend(self._coeff_temps(implicit_dims))

# Substitution mapper for variables
mapper = self._rdim._getters
idx_subs = {v: v.subs({k: c + p
for ((k, c), p) in zip(mapper.items(), pos)})
for v in variables}
idx_subs.update(dict(zip(self._rdim, mapper.values())))

return idx_subs, temps

Expand Down Expand Up @@ -290,7 +295,7 @@ def _inject(self, field, expr, implicit_dims=None):
injection expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims) + self._rdim
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Make iterable to support inject((u, v), expr=expr)
# or inject((u, v), expr=(expr1, expr2))
Expand Down Expand Up @@ -380,5 +385,6 @@ def interpolation_coeffs(self):
@property
def _weights(self):
ddim, cdim = self.interpolation_coeffs.dimensions[1:]
return Mul(*[self.interpolation_coeffs.subs({ddim: ri, cdim: rd-rd.symbolic_min})
return Mul(*[self.interpolation_coeffs.subs({ddim: ri,
cdim: rd-rd.parent.symbolic_min})
for (ri, rd) in enumerate(self._rdim)])
4 changes: 2 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from devito.symbolics import estimate_cost
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_tuple, flatten,
filter_sorted, frozendict, is_integer, split, timed_pass,
timed_region)
timed_region, contains_val)
from devito.types import Grid, Evaluable

__all__ = ['Operator']
Expand Down Expand Up @@ -560,7 +560,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
# a TimeFunction `usave(t_sub, x, y)`, an override for `fact` is
# supplied w/o overriding `usave`; that's legal
pass
elif is_integer(args[k]) and args[k] not in as_tuple(v):
elif is_integer(args[k]) and not contains_val(args[k], v):
raise ValueError("Default `%s` is incompatible with other args as "
"`%s=%s`, while `%s=%s` is expected. Perhaps you "
"forgot to override `%s`?" %
Expand Down
7 changes: 7 additions & 0 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def compare_to_first(v):
return not v.isdisjoint(first)
else:
return first in v
elif isinstance(v, range):
if isinstance(first, range):
return first.stop > v.start or v.stop > first.start
else:
return first >= v.start and first < v.stop
elif isinstance(first, range):
return v >= first.start and v < first.stop
elif isinstance(first, Set):
return v in first
else:
Expand Down
10 changes: 9 additions & 1 deletion devito/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
'roundm', 'powerset', 'invert', 'flatten', 'single_or', 'filter_ordered',
'as_mapper', 'filter_sorted', 'pprint', 'sweep', 'all_equal', 'as_list',
'indices_to_slices', 'indices_to_sections', 'transitive_closure',
'humanbytes']
'humanbytes', 'contains_val']


def prod(iterable, initial=1):
Expand Down Expand Up @@ -75,6 +75,14 @@ def is_integer(value):
return isinstance(value, (int, np.integer, sympy.Integer))


def contains_val(val, items):
print(val, items)
try:
return val in items
except TypeError:
return val == items


def generator():
"""
Return a function ``f`` that generates integer numbers starting at 0
Expand Down
4 changes: 3 additions & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,13 +838,15 @@ def __new__(cls, *args, **kwargs):
newobj._dimensions = dimensions
newobj._shape = cls.__shape_setup__(**kwargs)
newobj._dtype = cls.__dtype_setup__(**kwargs)
newobj.__init_finalize__(*args, **kwargs)

# All objects created off an existing AbstractFunction `f` (e.g.,
# via .func, or .subs, such as `f(x + 1)`) keep a reference to `f`
# through the `function` field
newobj.function = function or newobj

# Initialization
newobj.__init_finalize__(*args, **kwargs)

return newobj

def __init__(self, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init_finalize__(self, *args, function=None, **kwargs):
# a reference to the user-provided buffer
self._initializer = None
if len(initializer) > 0:
self.data_with_halo[:] = initializer
self.data_with_halo[:] = initializer[:]
else:
# This is a corner case -- we might get here, for example, when
# running with MPI and some processes get 0-size arrays after
Expand Down
33 changes: 21 additions & 12 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,19 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs):
# may represent sets of legal values. If that's the case, here we just
# pick one. Note that we sort for determinism
try:
loc_minv = sorted(loc_minv).pop(0)
except TypeError:
pass
loc_minv = loc_minv.start
except AttributeError:
try:
loc_minv = sorted(loc_minv).pop(0)
except TypeError:
pass
try:
loc_maxv = sorted(loc_maxv).pop(0)
except TypeError:
pass
loc_maxv = loc_maxv.start
except AttributeError:
try:
loc_maxv = sorted(loc_maxv).pop(0)
except TypeError:
pass

return {self.min_name: loc_minv, self.max_name: loc_maxv}

Expand Down Expand Up @@ -853,8 +859,7 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
factor = defaults[dim._factor.name] = dim._factor.data
except AttributeError:
factor = dim._factor
defaults[dim.parent.max_name] = \
frozenset(range(factor*(size - 1), factor*(size)))
defaults[dim.parent.max_name] = range(1, factor*(size))

return defaults

Expand Down Expand Up @@ -977,8 +982,9 @@ def symbolic_incr(self):
def bound_symbols(self):
return set(self.parent.bound_symbols)

def _arg_defaults(self, **kwargs):
return {}
def _arg_defaults(self, alias=None, **kwargs):
dim = alias or self
return {dim.parent.max_name: range(self.symbolic_size, np.iinfo(np.int64).max)}

def _arg_values(self, *args, **kwargs):
return {}
Expand Down Expand Up @@ -1446,7 +1452,7 @@ def symbolic_max(self):
def _arg_names(self):
return (self.min_name, self.max_name, self.name) + self.parent._arg_names

def _arg_defaults(self, _min=None, **kwargs):
def _arg_defaults(self, _min=None, size=None, **kwargs):
"""
A map of default argument values defined by this dimension.
Expand All @@ -1460,7 +1466,10 @@ def _arg_defaults(self, _min=None, **kwargs):
A SteppingDimension does not know its max point and therefore
does not have a size argument.
"""
return {self.parent.min_name: _min}
args = {self.parent.min_name: _min}
if size:
args[self.parent.max_name] = range(size, np.iinfo(np.int32).max)
return args

def _arg_values(self, *args, **kwargs):
"""
Expand Down

0 comments on commit 6ad85cf

Please sign in to comment.