Canonicalize Subtensor slices #761

Merged
merged 3 commits into from
Oct 4, 2024
79 changes: 51 additions & 28 deletions pytensor/tensor/rewriting/subtensor.py
@@ -337,49 +337,72 @@ def local_subtensor_of_dot(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@register_stabilize
@node_rewriter([Subtensor])
def local_useless_slice(fgraph, node):
"""
Remove Subtensor of the form:
1. X[0, :] -> X[0]
2. X[:] -> X

Also, rewrite Subtensor of the form:
X[0:7:1] -> X[None:None:None]
where X is a vector of length 7

"""
idxs = get_idx_list(node.inputs, node.op.idx_list)
x = node.inputs[0]

if not idxs:
return [node.inputs[0]]

last_useless_slice = len(idxs)
for s in idxs[::-1]:
# check if slice and then check slice indices
new_idxs = list(idxs)
change_flag = False
Dhruvanshu-Joshi marked this conversation as resolved.
Show resolved Hide resolved
last_useful_idx = -1
for dim, s in enumerate(new_idxs):
if not isinstance(s, slice):
Dhruvanshu-Joshi marked this conversation as resolved.
Show resolved Hide resolved
last_useful_idx = dim
continue

if s == slice(None):
continue

start = s.start
stop = s.stop
step = s.step
if (
isinstance(s, slice)
and s.start is None
and s.stop is None
and (
s.step is None
or extract_constant(s.step, only_process_constants=True) == 1
)
start is not None
and extract_constant(start, only_process_constants=True) == 0
):
last_useless_slice -= 1
else:
break
# check if we removed something
if last_useless_slice < len(idxs):
new_idxs = idxs[:last_useless_slice]
if new_idxs:
new_subtensor = Subtensor(new_idxs)
new_subtensor_inputs = get_slice_elements(
new_idxs, lambda x: isinstance(x, Variable)
)
out = new_subtensor(node.inputs[0], *new_subtensor_inputs)
# Copy over previous output stacktrace
copy_stack_trace(node.outputs, out)
return [out]
else:
# Subtensor is not needed at all
return [node.inputs[0]]
change_flag = True
start = None

if (
stop is not None
and x.type.shape[dim] is not None
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
):
change_flag = True
stop = None

if (
step is not None
and extract_constant(step, only_process_constants=True) == 1
):
change_flag = True
step = None

if not (start is None and stop is None and step is None):
last_useful_idx = dim

new_idxs[dim] = slice(start, stop, step)
Dhruvanshu-Joshi marked this conversation as resolved.
Show resolved Hide resolved

if change_flag or ((last_useful_idx + 1) < len(idxs)):
out = x[tuple(new_idxs[: last_useful_idx + 1])]
# Copy over previous output stacktrace
copy_stack_trace(node.outputs, out)

return [out]


# fast_compile to allow opt subtensor(cast{float32}(make_vector))
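As a hedged illustration (my own sketch, not part of the PR), the snippet below shows the end-to-end effect of the rewrite above: every slice field that is provably redundant (start 0, stop equal to the static dimension length, step 1) is replaced by None, and trailing all-None slices are dropped so the Subtensor can disappear entirely. The variable names are mine.

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(7,))
y = x[0:7:1]  # start=0, stop == static length 7, step=1: all three are redundant

f = pytensor.function([x], y)
# After rewriting, no Subtensor remains; the compiled graph is just a
# DeepCopyOp of x, because X[0:7:1] canonicalizes to X[:], which is useless.
pytensor.dprint(f)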
43 changes: 42 additions & 1 deletion tests/tensor/rewriting/test_subtensor.py
@@ -10,7 +10,7 @@
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph, vectorize_graph
from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
@@ -2402,3 +2402,44 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
    else:
        expected_out[:, :, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)


def test_slice_canonicalize():
    rng = np.random.default_rng(43)
    x = tensor(shape=(3, 5, None, 9))
    test_x = rng.normal(size=(3, 5, 8, 9))
    # Test case 1
    y = x[0:None, 0:5, 0:7, 0:9:1]
    f = pytensor.function([x], y, allow_input_downcast=True)

    # Get the DeepCopy input and assert that the Op is a DeepCopy
    test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
    assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)

    expected_y = x[None:None:None, None:None:None, None:7:None]

    assert equal_computations([test_y], [expected_y])

    np.testing.assert_allclose(
        f(test_x),
        test_x[
            0:None, 0:5, 0:7, 0:9:1
        ],  # Use the unoptimized slice to make sure our rewrite logic is correct
    )

    # Test case 2
    y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
    f1 = pytensor.function([x], y1, allow_input_downcast=True)

    # Get the DeepCopy input and assert that the Op is a DeepCopy
    test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
    assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp)

    expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]

    assert equal_computations([test_y1], [expected_y1])

    np.testing.assert_allclose(
        f1(test_x),
        test_x[0:-1, 0:5, 0:7, 0:-1:-1],
    )
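A second hedged sketch (mine, not from the PR) of the shape dependence the test exercises: stop can only be dropped when the dimension's static length is known, so with an unknown length only start=0 and step=1 are stripped. It assumes rewrite_graph applies the canonicalize rewrites by default.

import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.tensor("x", shape=(None,))  # length unknown at compile time
y = x[0:7:1]

# start=0 and step=1 are provably redundant, but stop=7 must be kept
# because it cannot be proven equal to the (unknown) length of x.
y_rewritten = rewrite_graph(y)
assert equal_computations([y_rewritten], [x[None:7:None]])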