
Commit

Cleanup Max and Argmax
ricardoV94 committed Jul 8, 2024
1 parent 0d12385 commit d1b0d8a
Showing 5 changed files with 105 additions and 182 deletions.
33 changes: 13 additions & 20 deletions pytensor/tensor/elemwise.py
@@ -30,7 +30,11 @@
float_dtypes,
lvector,
)
from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string
from pytensor.tensor.utils import (
broadcast_static_dim_lengths,
import_func_from_string,
normalize_reduce_axis,
)
from pytensor.tensor.variable import TensorVariable
from pytensor.utils import uniq

@@ -1371,7 +1375,6 @@ def _acc_dtype(self, idtype):

def make_node(self, input):
input = as_tensor_variable(input)
inp_dims = input.type.ndim
inp_dtype = input.type.dtype

# We need to redefine make_node so that, if self.dtype is None,
@@ -1383,29 +1386,19 @@ def make_node(self, input):
assert dtype is not None
assert acc_dtype is not None

axis = self.axis
axis = normalize_reduce_axis(input, self.axis)

# scalar inputs are treated as 1D regarding axis in this `Op`
if axis is not None:
try:
axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims))
except np.AxisError:
raise np.AxisError(axis, ndim=inp_dims)
if axis != self.axis or dtype != self.dtype or acc_dtype != self.acc_dtype:
op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype)
else:
op = self

if axis is None:
out_shape = ()
else:
out_shape = tuple(
s for i, s in enumerate(input.type.shape) if i not in axis
)
else:
out_shape = ()

if (
(axis is not None and any(a < 0 for a in axis))
or dtype != self.dtype
or acc_dtype != self.acc_dtype
):
op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype)
else:
op = self

output = TensorType(dtype=dtype, shape=out_shape)()

169 changes: 41 additions & 128 deletions pytensor/tensor/math.py
@@ -8,7 +8,6 @@

from pytensor import config, printing
from pytensor import scalar as ps
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
@@ -26,9 +25,9 @@
cast,
concatenate,
constant,
expand_dims,
stack,
switch,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import (
@@ -45,14 +44,11 @@
continuous_dtypes,
discrete_dtypes,
int_dtypes,
integer_dtypes,
tensor,
uint_dtypes,
)
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import as_list
from pytensor.tensor.utils import as_list, normalize_reduce_axis
from pytensor.tensor.variable import (
TensorConstant,
TensorVariable,
_tensor_py_operators,
)
@@ -157,7 +153,7 @@ class Argmax(COp):

def __init__(self, axis):
if axis is not None:
axis = tuple(axis)
axis = tuple(sorted(axis))
self.axis = axis

def get_params(self, node):
@@ -168,7 +164,7 @@ def get_params(self, node):
c_axis = np.int64(-1)
return self.params_type.get_params(c_axis=c_axis)

def make_node(self, x, axis=None):
def make_node(self, x):
x = as_tensor_variable(x)
if self.axis is None:
all_axes = list(range(x.ndim))
@@ -198,7 +194,9 @@ def perform(self, node, inp, outs):
# Work around
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
# Not-reduced axes in front
transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
transposed_x = np.transpose(
x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64")))
)
kept_shape = transposed_x.shape[: len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes) :]
new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
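
The workaround above implements argmax over several axes by moving the kept axes to the front, collapsing all reduced axes into a single trailing axis, and taking one argmax over it. A standalone NumPy sketch of the same trick, with hypothetical data:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
axes = (1, 2)                                    # axes to reduce over
keep_axes = [i for i in range(x.ndim) if i not in axes]

# Not-reduced axes in front, then flatten the reduced axes into one.
transposed = np.transpose(x, (*keep_axes, *axes))
kept_shape = transposed.shape[: len(keep_axes)]
flat = transposed.reshape(*kept_shape, -1)

max_idx = np.argmax(flat, axis=-1)               # index into the flattened reduced block
print(max_idx.shape)                             # (2,)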
@@ -214,7 +212,7 @@ def c_code(self, node, name, inp, out, sub):
if self.axis is None:
axis_code = "axis = NPY_MAXDIMS;"
else:
if len(self.axis) > 1:
if len(self.axis) != 1:
raise NotImplementedError()
# params is only used here for now
axis_code = """
@@ -253,7 +251,7 @@ def c_code(self, node, name, inp, out, sub):
return ret % locals()

def c_code_cache_version(self):
return (1,)
return (2,)

def infer_shape(self, fgraph, node, shapes):
(ishape,) = shapes
@@ -277,7 +275,7 @@ def grad(self, inp, grads):
return [x.zeros_like()]


def argmax(x, axis=None, keepdims=False):
def argmax(x: TensorLike, axis=None, keepdims: bool = False):
"""
Returns indices of maximum elements obtained by iterating over given axis.
@@ -286,17 +284,29 @@ def argmax(x, axis=None, keepdims=False):
Parameters
----------
x: TensorLike
Array on which to compute argmax.
axis:
Axis along which to compute argmax. Unlike NumPy, multiple partial axes are supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor.
Returns
-------
TensorVariable
TensorVariable representing the argmax operation
"""
argout = max_and_argmax(x, axis)[1]
x = as_tensor_variable(x)
axis = normalize_reduce_axis(x, axis)
out = Argmax(axis)(x)

if keepdims:
argout = makeKeepDims(x, argout, axis)
return argout
out = makeKeepDims(x, out, axis)

return out
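
With the axis normalized up front, argmax now builds the Argmax node directly instead of routing through max_and_argmax. A small usage sketch, assuming a standard PyTensor install (the variable names are illustrative only):

import numpy as np
import pytensor.tensor as pt
from pytensor import function

x = pt.tensor3("x", dtype="float64")
# Multiple partial axes are supported, unlike np.argmax.
idx = pt.argmax(x, axis=(1, 2), keepdims=True)

f = function([x], idx)
print(f(np.arange(24, dtype="float64").reshape(2, 3, 4)).shape)  # (2, 1, 1)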


@_vectorize_node.register(Argmax)
@@ -324,59 +334,6 @@ def makeKeepDims(x, y, axis):
return expand_dims(y, axis)


def check_and_normalize_axes(x, axis):
"""Check axes, normalize and convert them to a Python list of integers.
Parameters
----------
x: TensorVariable
axis: int, tuple or list of integers
Returns
-------
axis: list of integers
Return an empty list if argument is None.
"""
x = as_tensor_variable(x)
if axis is None:
axis = []
elif isinstance(axis, int | np.integer) or (
isinstance(axis, np.ndarray) and axis.ndim == 0
):
axis = [int(axis)]
elif isinstance(axis, tuple | list | np.ndarray):
axis = [int(i) for i in axis]
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = []
elif not isinstance(axis, TensorConstant):
raise TypeError(f"Computation needs a constant axis. Got {axis}")
else:
assert axis.dtype in integer_dtypes
if isinstance(axis.data, int | np.integer) or (
isinstance(axis.data, np.ndarray) and axis.data.ndim == 0
):
axis = [int(axis.data)]
elif isinstance(axis.data, list | np.ndarray):
axis = [int(i) for i in axis.data]
else:
raise TypeError(
f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
)
if len(axis) > 0:
for i in range(len(axis)):
if axis[i] < 0:
axis[i] += x.type.ndim
if axis[i] < 0 or axis[i] >= x.type.ndim:
raise ValueError(
f"Computation needs a valid axis number for {int(x.type.ndim)}-D tensor. Got {int(axis[i])}"
)
axis = list(set(axis))
axis.sort()
return axis


def max_and_argmax(a, axis=None, keepdims=False):
"""
Returns maximum elements and their indices obtained by iterating over
@@ -395,28 +352,10 @@
"""
# Check axis and convert it to a Python list of integers.
# Axis will be used as an op param of Max and Argmax.
a = as_tensor_variable(a)

is_axis_empty = False
if axis == ():
is_axis_empty = True

axis = check_and_normalize_axes(a, axis)

if len(axis) == 0 and not is_axis_empty:
axis = None

out = Max(axis)(a)

if not is_axis_empty:
argout = Argmax(axis)(a)
else:
argout = zeros_like(a, dtype="int64")

if keepdims:
out = makeKeepDims(a, out, axis)
argout = makeKeepDims(a, argout, axis)
return [out, argout]
return [
max(a, axis=axis, keepdims=keepdims),
argmax(a, axis=axis, keepdims=keepdims),
]
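
max_and_argmax is now a thin wrapper that simply pairs the two reductions. A short sketch of its use, under the same assumptions as the argmax example above:

import numpy as np
import pytensor.tensor as pt
from pytensor import function

x = pt.tensor3("x", dtype="float64")
mx, amx = pt.max_and_argmax(x, axis=1)   # delegates to max(...) and argmax(...)

g = function([x], [mx, amx])
vals, idxs = g(np.arange(24, dtype="float64").reshape(2, 3, 4))
print(vals.shape, idxs.shape)            # (2, 4) (2, 4)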


class FixedOpCAReduce(CAReduce):
@@ -465,7 +404,7 @@ def clone(self, **kwargs):
axis = kwargs.get("axis", self.axis)
return type(self)(axis=axis)

def grad(self, inp, grads):
def L_op(self, inputs, outputs, grads):
# The strict sense mathematical gradient of the maximum function is
# not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null
@@ -479,53 +418,27 @@ def grad(self, inp, grads):
# g_max has one less dimension than x, so you need to complete
# g_max to x's shape when axis=0 the broadcasting mechanism
# does it automatically
x = inp[0]
if self.axis is None:
self.axis = tuple(range(x.ndim))
axis = as_tensor_variable(self.axis)
(g_max,) = grads

g_max_disconnected = isinstance(g_max.type, DisconnectedType)
[x] = inputs
[out] = outputs
[g_out] = grads

# if the op is totally disconnected, so are its inputs
if g_max_disconnected:
return [DisconnectedType()()]

# if NoneConst.equals(axis):
if axis is None:
axis_ = list(range(x.ndim))
else:
axis_ = axis
xmax = max(x, axis_)

# Raise the g_max and xmax to the same number of dim as the input.
pattern = []
out_dim = 0
if NoneConst.equals(axis):
# We are taking the max/argmax over all dimensions.
axis = None
for i in range(x.ndim):
if axis is None or i in axis.data:
pattern.append("x")
else:
pattern.append(out_dim)
out_dim += 1
g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
axis = tuple(range(x.ndim)) if self.axis is None else self.axis
out_pad = expand_dims(out, axis)
g_out_pad = expand_dims(g_out, axis)

# Set the grad to the correct position.
g_x = eq(xmax_pad, x) * g_max_pad
g_x = eq(out_pad, x) * g_out_pad
return (g_x,)
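
The rewritten L_op pushes the gradient only to the entries that attain the maximum: the reduced output and incoming gradient are broadcast back to the input's shape with expand_dims, and an equality mask selects those positions, so every entry equal to the maximum receives the gradient. A minimal NumPy sketch of the same rule, with hypothetical data:

import numpy as np

x = np.array([[1.0, 3.0, 2.0],
              [4.0, 4.0, 0.0]])
axis = (1,)

out = x.max(axis=axis)                   # forward max, shape (2,)
g_out = np.ones_like(out)                # incoming gradient, shape (2,)

# Re-insert the reduced axis so everything broadcasts against x.
out_pad = np.expand_dims(out, axis=axis)
g_out_pad = np.expand_dims(g_out, axis=axis)

g_x = (x == out_pad) * g_out_pad         # gradient flows to every maximal entry
print(g_x)
# [[0. 1. 0.]
#  [1. 1. 0.]]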

def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None, None]
if len(self.axis) != 1:
raise ValueError("R_op supported for arg_max only for one axis!")
raise ValueError("R_op supported for max only for one axis!")
if self.axis[0] > 1:
raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
raise ValueError("R_op supported for max only when axis is 0 or 1")
if inputs[0].ndim != 2:
raise ValueError("R_op supported for arg_max only when input is a matrix")
raise ValueError("R_op supported for max only when input is a matrix")
max_pos = Argmax(self.axis).make_node(*inputs).outputs
# print(eval_points[0].eval())
if self.axis[0] == 0:
@@ -564,7 +477,7 @@ def max(x, axis=None, keepdims=False):
Like NumPy, we raise an error when reducing over a dimension with shape 0.
"""
out = max_and_argmax(x, axis)[0]
out = Max(axis=axis)(x)

if keepdims:
out = makeKeepDims(x, out, axis)
23 changes: 23 additions & 0 deletions pytensor/tensor/utils.py
@@ -1,12 +1,18 @@
import re
import typing
from collections.abc import Sequence

import numpy as np
from numpy.core.numeric import normalize_axis_tuple

import pytensor
from pytensor.utils import hash_from_code


if typing.TYPE_CHECKING:
from pytensor.tensor.var import TensorVariable


def hash_from_ndarray(data):
"""
Return a hash from an ndarray.
@@ -222,3 +228,20 @@ def operand_sig(operand_ndim: int, prefix: str) -> str:
operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim)
)
return f"{inputs_sig}->{outputs_sig}"


def normalize_reduce_axis(x: "TensorVariable", axis) -> tuple[int, ...] | None:
"""Normalize the axis parameter for reduce operations."""
if axis is None:
return None

# scalar inputs are treated as 1D regarding axis in reduce operations
x_ndim = x.type.ndim
if axis is not None:
try:
axis = normalize_axis_tuple(axis, ndim=max(1, x_ndim))
except np.AxisError:
raise np.AxisError(axis, ndim=x_ndim)

# TODO: If axis tuple is equivalent to None, return None for more canonicalization?
return axis
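
A short sketch of how the new helper behaves for a few representative inputs (hypothetical variables; the normalization itself comes from NumPy's normalize_axis_tuple):

import pytensor.tensor as pt
from pytensor.tensor.utils import normalize_reduce_axis

x = pt.matrix("x")
print(normalize_reduce_axis(x, None))     # None   -> reduce over all axes
print(normalize_reduce_axis(x, -1))       # (1,)   -> negative axis made non-negative
print(normalize_reduce_axis(x, (0, 1)))   # (0, 1)

s = pt.scalar("s")
print(normalize_reduce_axis(s, 0))        # (0,)   -> scalars treated as 1D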