diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 8ee86e6021..f234b46804 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -337,6 +337,7 @@ def local_subtensor_of_dot(fgraph, node):
 @register_useless
 @register_canonicalize
 @register_specialize
+@register_stabilize
 @node_rewriter([Subtensor])
 def local_useless_slice(fgraph, node):
     """
@@ -344,42 +345,77 @@ def local_useless_slice(fgraph, node):
     1. X[0, :] -> X[0]
     2. X[:] -> X
 
+    Also, rewrite Subtensor of the form:
+        X[0:7:1] -> X[None:None:None]
+        where X is a vector of length 7
+
     """
     idxs = get_idx_list(node.inputs, node.op.idx_list)
+    x = node.inputs[0]
 
     if not idxs:
         return [node.inputs[0]]
 
-    last_useless_slice = len(idxs)
-    for s in idxs[::-1]:
-        # check if slice and then check slice indices
-        if (
-            isinstance(s, slice)
-            and s.start is None
-            and s.stop is None
-            and (
-                s.step is None
-                or extract_constant(s.step, only_process_constants=True) == 1
-            )
-        ):
-            last_useless_slice -= 1
-        else:
-            break
-    # check if we removed something
-    if last_useless_slice < len(idxs):
-        new_idxs = idxs[:last_useless_slice]
-        if new_idxs:
-            new_subtensor = Subtensor(new_idxs)
-            new_subtensor_inputs = get_slice_elements(
-                new_idxs, lambda x: isinstance(x, Variable)
-            )
-            out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs, out)
-            return [out]
-        else:
-            # Subtensor is not needed at all
-            return [node.inputs[0]]
+    new_idxs = list(idxs)
+    change_flag = False
+    last_useful_idx = -1
+    for dim, s in enumerate(new_idxs):
+        if not isinstance(s, slice):
+            last_useful_idx = dim
+            continue
+
+        if s == slice(None):
+            continue
+
+        start = s.start
+        stop = s.stop
+        step = s.step
+
+        # Canonicalizing start == 0 or stop == dim_length is only valid when
+        # the step is known to be positive (e.g. x[0::-1] keeps just x[0],
+        # while x[::-1] reverses x).
+        if step is None:
+            positive_step = True
+        else:
+            step_value = extract_constant(step, only_process_constants=True)
+            if isinstance(step_value, Variable):
+                # Symbolic step with unknown sign: leave the slice untouched
+                last_useful_idx = dim
+                continue
+            positive_step = step_value > 0
+            if step_value == 1:
+                change_flag = True
+                step = None
+
+        if (
+            positive_step
+            and start is not None
+            and extract_constant(start, only_process_constants=True) == 0
+        ):
+            change_flag = True
+            start = None
+
+        if (
+            positive_step
+            and stop is not None
+            and x.type.shape[dim] is not None
+            and extract_constant(stop, only_process_constants=True)
+            == x.type.shape[dim]
+        ):
+            change_flag = True
+            stop = None
+
+        if not (start is None and stop is None and step is None):
+            last_useful_idx = dim
+
+        new_idxs[dim] = slice(start, stop, step)
+
+    if change_flag or ((last_useful_idx + 1) < len(idxs)):
+        out = x[tuple(new_idxs[: last_useful_idx + 1])]
+        # Copy over previous output stacktrace
+        copy_stack_trace(node.outputs, out)
+
+        return [out]
 
 
 # fast_compile to allow opt subtensor(cast{float32}(make_vector))
diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py
index f7ea7cdce4..91575bc7da 100644
--- a/tests/tensor/rewriting/test_subtensor.py
+++ b/tests/tensor/rewriting/test_subtensor.py
@@ -10,7 +10,7 @@
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.configdefaults import config
 from pytensor.graph import FunctionGraph, vectorize_graph
-from pytensor.graph.basic import Constant, Variable, ancestors
+from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
 from pytensor.graph.rewriting.basic import check_stack_trace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.rewriting.utils import rewrite_graph
@@ -2402,3 +2402,46 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
     else:
         expected_out[:, :, core_idxs] += test_y
     np.testing.assert_allclose(fn(test_x, test_y), expected_out)
+
+
+def test_slice_canonicalize():
+    rng = np.random.default_rng(43)
+    x = tensor(shape=(3, 5, None, 9))
+    test_x = rng.normal(size=(3, 5, 8, 9))
+    # Test case 1
+    y = x[0:None, 0:5, 0:7, 0:9:1]
+    f = pytensor.function([x], y, allow_input_downcast=True)
+
+    # Get the DeepCopy input and assert that the Op is a DeepCopy
+    test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
+    assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
+
+    expected_y = x[None:None:None, None:None:None, None:7:None]
+
+    assert equal_computations([test_y], [expected_y])
+
+    np.testing.assert_allclose(
+        f(test_x),
+        test_x[
+            0:None, 0:5, 0:7, 0:9:1
+        ],  # Use the unoptimized slice to make sure our rewrite logic is correct
+    )
+
+    # Test case 2
+    y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
+    f1 = pytensor.function([x], y1, allow_input_downcast=True)
+
+    # Get the DeepCopy input and assert that the Op is a DeepCopy
+    test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
+    assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
+
+    # With a negative step the start cannot be dropped: only the useless
+    # step == 1 eliminations apply, so the last slice stays 0:-1:-1
+    expected_y1 = x[None:-1:None, None:None:None, None:7:None, 0:-1:-1]
+
+    assert equal_computations([test_y1], [expected_y1])
+
+    np.testing.assert_allclose(
+        f1(test_x),
+        test_x[0:-1, 0:5, 0:7, 0:-1:-1],
+    )