Simplify Elemwise perform method and issue informative warning when number of operands is too large.

This also clears a hard-to-debug error that occurred when the perform method attempted to fall back to the C implementation.
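For context (not part of the commit): the Python perform path wraps scalar_op.impl with np.frompyfunc, and NumPy ufuncs are limited to 32 operands (inputs plus outputs) on the NumPy versions this code targets. Below is a minimal standalone sketch of that limitation; the helper name and the toy scalar function are illustrative only.

import numpy as np

def try_build_ufunc(nin, nout=1):
    """Hypothetical helper: wrap a toy scalar function as a ufunc with `nin` inputs."""
    def scalar_impl(*args):  # stand-in for scalar_op.impl
        return sum(args)
    try:
        return np.frompyfunc(scalar_impl, nin, nout)
    except ValueError as exc:
        # Some NumPy versions raise ValueError when nin + nout exceeds
        # NPY_MAXARGS (32); per the comment in the old code, others could segfault.
        print(f"Cannot build a ufunc with {nin + nout} operands: {exc}")
        return None

try_build_ufunc(2)   # an ordinary 2-input ufunc
try_build_ufunc(40)  # exceeds the 32-operand limit on affected NumPy versions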
ricardoV94 committed Oct 25, 2024
1 parent 5909d93 commit f1420d7
Showing 1 changed file with 62 additions and 76 deletions.
138 changes: 62 additions & 76 deletions pytensor/tensor/elemwise.py
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Sequence
from copy import copy
from textwrap import dedent
@@ -19,9 +20,9 @@
from pytensor.misc.frozendict import frozendict
from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import Composite, transfer_type, upcast
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -364,6 +365,7 @@ def __init__(
self.name = name
self.scalar_op = scalar_op
self.inplace_pattern = inplace_pattern
self.ufunc = None
self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()}

if nfunc_spec is None:
@@ -375,14 +377,12 @@ def __init__(
def __getstate__(self):
d = copy(self.__dict__)
d.pop("ufunc")
d.pop("nfunc")
d.pop("__epydoc_asRoutine", None)
return d

def __setstate__(self, d):
d.pop("nfunc", None) # This used to be stored in the Op, not anymore
super().__setstate__(d)
self.ufunc = None
self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)

def get_output_info(self, *inputs):
@@ -623,31 +623,47 @@ def transform(r):

return ret

def prepare_node(self, node, storage_map, compute_map, impl):
# Postpone the ufunc building to the last minute due to:
# - NumPy ufuncs support only up to 32 operands (inputs and outputs),
#   but our C code supports more.
# - nfunc is reused for scipy and scipy is optional
if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
impl = "c"

if getattr(self, "nfunc_spec", None) and impl != "c":
self.nfunc = import_func_from_string(self.nfunc_spec[0])

def _create_node_ufunc(self, node) -> None:
if (
(len(node.inputs) + len(node.outputs)) <= 32
and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
and self.ufunc is None
and impl == "py"
self.nfunc_spec is not None
# Some scalar Ops like `Add` allow for a variable number of inputs,
# whereas the numpy counterpart does not.
and len(node.inputs) == self.nfunc_spec[1]
):
ufunc = import_func_from_string(self.nfunc_spec[0])
if ufunc is None:
raise ValueError(
f"Could not import ufunc {self.nfunc_spec[0]} for {self}"
)

elif self.ufunc is not None:
# Cached before
ufunc = self.ufunc

else:
if (len(node.inputs) + len(node.outputs)) > 32:
if isinstance(self.scalar_op, Composite):
warnings.warn(
"Trying to create a Python Composite Elemwise function with more than 32 operands.\n"
"This operation should not have been introduced if the C-backend is not properly setup. "
'Make sure it is, or disable it by setting pytensor.config.cxx = "" (empty string).\n'
"Alternatively, consider using an optional backend like NUMBA or JAX, by setting "
'`pytensor.config.mode = "NUMBA" (or "JAX").'
)
else:
warnings.warn(
f"Trying to create a Python Elemwise function for the scalar Op {self.scalar_op} "
f"with more than 32 operands. This will likely fail."
)

ufunc = np.frompyfunc(
self.scalar_op.impl, len(node.inputs), self.scalar_op.nout
)
if self.scalar_op.nin > 0:
# We can reuse it for many nodes
if self.scalar_op.nin > 0: # Default in base class is -1
# Op has constant signature, so we can reuse ufunc for many nodes. Cache it.
self.ufunc = ufunc
else:
node.tag.ufunc = ufunc

node.tag.ufunc = ufunc

# Numpy ufuncs will sometimes perform operations in
# float16, in particular when the input is int8.
@@ -669,6 +685,11 @@ def prepare_node(self, node, storage_map, compute_map, impl):
char = np.sctype2char(out_dtype)
sig = char * node.nin + "->" + char * node.nout
node.tag.sig = sig

def prepare_node(self, node, storage_map, compute_map, impl):
if impl == "py":
self._create_node_ufunc(node)

node.tag.fake_node = Apply(
self.scalar_op,
[
@@ -684,71 +705,36 @@ def prepare_node(self, node, storage_map, compute_map, impl):
self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)

def perform(self, node, inputs, output_storage):
if (len(node.inputs) + len(node.outputs)) > 32:
# Some versions of NumPy will segfault, others will raise a
# ValueError, if the number of operands in a ufunc is more than 32.
# In that case, the C version should be used, or Elemwise fusion
# should be disabled.
# FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage)
ufunc = getattr(node.tag, "ufunc", None)
if ufunc is None:
self._create_node_ufunc(node)
ufunc = node.tag.ufunc

self._check_runtime_broadcast(node, inputs)

ufunc_args = inputs
ufunc_kwargs = {}
# In the past we supported calling op.perform manually.
# To keep that support we sometimes need to call self.prepare_node
if self.nfunc is None and self.ufunc is None:
self.prepare_node(node, None, None, "py")
if self.nfunc and len(inputs) == self.nfunc_spec[1]:
ufunc = self.nfunc
nout = self.nfunc_spec[2]
if hasattr(node.tag, "sig"):
ufunc_kwargs["sig"] = node.tag.sig
# Unfortunately, the else case does not allow us to
# directly feed the destination arguments to the nfunc
# since it sometimes requires resizing. Doing this
# optimization is probably not worth the effort, since we
# should normally run the C version of the Op.
else:
# the second calling form is used because in certain versions of
# numpy the first (faster) version leads to segfaults
if self.ufunc:
ufunc = self.ufunc
elif not hasattr(node.tag, "ufunc"):
# It happens that make_thunk isn't called, like in
# get_underlying_scalar_constant_value
self.prepare_node(node, None, None, "py")
# prepare_node will add ufunc to self or the tag
# depending on whether we can reuse it or not. So we need to
# test both again.
if self.ufunc:
ufunc = self.ufunc
else:
ufunc = node.tag.ufunc
else:
ufunc = node.tag.ufunc

nout = ufunc.nout
if hasattr(node.tag, "sig"):
ufunc_kwargs["sig"] = node.tag.sig

variables = ufunc(*ufunc_args, **ufunc_kwargs)
outputs = ufunc(*inputs, **ufunc_kwargs)

if nout == 1:
variables = [variables]
if not isinstance(outputs, tuple):
outputs = (outputs,)

for i, (variable, storage, nout) in enumerate(
zip(variables, output_storage, node.outputs)
for i, (out, out_storage, node_out) in enumerate(
zip(outputs, output_storage, node.outputs)
):
storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
# Numpy frompyfunc always returns object arrays
out_storage[0] = out = np.asarray(out, dtype=node_out.dtype)

if i in self.inplace_pattern:
odat = inputs[self.inplace_pattern[i]]
odat[...] = variable
storage[0] = odat
inp = inputs[self.inplace_pattern[i]]
inp[...] = out
out_storage[0] = inp

# numpy.real returns a view!
if not variable.flags.owndata:
storage[0] = variable.copy()
if not out.flags.owndata:
out_storage[0] = out.copy()

@staticmethod
def _check_runtime_broadcast(node, inputs):
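As a usage note rather than part of the diff, the two workarounds named in the new warning correspond to the following configuration settings; a minimal sketch assuming a standard PyTensor installation.

import pytensor

# Workaround 1: declare the C backend unavailable so rewrites do not
# introduce large fused Elemwise operations in the first place.
pytensor.config.cxx = ""

# Workaround 2: compile with an alternative backend instead of the
# Python fallback.
pytensor.config.mode = "NUMBA"  # or "JAX"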
