Canonicalize subtensor slices
Dhruvanshu-Joshi committed May 18, 2024
1 parent 30b760f commit 36645b7
Showing 2 changed files with 69 additions and 1 deletion.
pytensor/tensor/rewriting/subtensor.py: 47 additions, 0 deletions
@@ -367,6 +367,7 @@ def local_useless_slice(fgraph, node):
    # check if we removed something
    if last_useless_slice < len(idxs):
        new_idxs = idxs[:last_useless_slice]

        if new_idxs:
            new_subtensor = Subtensor(new_idxs)
            new_subtensor_inputs = get_slice_elements(
@@ -381,6 +382,52 @@ def local_useless_slice(fgraph, node):
            return [node.inputs[0]]


@register_useless
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Subtensor])
def local_replace_slice(fgraph, node):
    """
    Rewrite Subtensor of the form:
    X[0:7:1] -> X[None:None:None]
    where X is a vector of length 7
    """
    idxs = get_idx_list(node.inputs, node.op.idx_list)
    x = node.inputs[0]

    if not idxs:
        return [x]

    new_idxs = list(idxs)
    idx_flag = False
    for dim, s in enumerate(new_idxs):
        if not isinstance(s, slice):
            continue

        start = s.start
        stop = s.stop
        step = s.step
        if extract_constant(start, only_process_constants=True) == 0:
            idx_flag = True
            start = None

        if extract_constant(stop, only_process_constants=True) == x.type.shape[dim]:
            idx_flag = True
            stop = None

        if extract_constant(step, only_process_constants=True) == 1:
            idx_flag = True
            step = None

        new_idxs[dim] = slice(start, stop, step)

    assert node.outputs[0].type == x[tuple(new_idxs)].type
    if idx_flag is True:
        return [x[tuple(new_idxs)]]


# fast_compile to allow opt subtensor(cast{float32}(make_vector))
@register_canonicalize("fast_compile")
@node_rewriter([Subtensor])
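To illustrate the intended effect of the new rewrite, here is a minimal sketch (not part of the commit; the variable names and the toposort printout are only illustrative): a slice with constant start 0, stop equal to the static dimension length, and step 1 should be canonicalized to all-None bounds, after which the existing useless-slice rewrite can drop the Subtensor entirely.

import pytensor
import pytensor.tensor as pt

# A vector with a static length of 7, so x[0:7:1] selects every element.
x = pt.tensor(shape=(7,))
y = x[0:7:1]

# After compilation, the constant start/stop/step should have been
# canonicalized to None, so the explicit slice bounds no longer appear
# in the optimized graph.
f = pytensor.function([x], y)
print(f.maker.fgraph.toposort())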
tests/tensor/rewriting/test_subtensor.py: 22 additions, 1 deletion
@@ -10,7 +10,7 @@
from pytensor.compile.ops import DeepCopyOp
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph, vectorize_graph
-from pytensor.graph.basic import Constant, Variable, ancestors
+from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
@@ -2402,3 +2402,24 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
    else:
        expected_out[:, :, core_idxs] += test_y
    np.testing.assert_allclose(fn(test_x, test_y), expected_out)


@pytest.mark.parametrize("fstop, lstop, lstep", [(None, 9, 1), (-1, -1, -1)])
def test_slice_canonicalize(fstop, lstop, lstep):
    x = tensor(shape=(3, 5, None, 9))
    y = x[0:fstop, 0:5, 0:7, 0:lstop:lstep]
    f = pytensor.function([x], y)
    test_y = f.maker.fgraph.toposort()

    y1 = x[None:None:None, None:None:None, None:7:None, None:None:None]

    if fstop == -1 and lstop == -1 and lstep == -1:
        y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]

    f1 = pytensor.function([x], y1)
    expected_y = f1.maker.fgraph.toposort()

    assert all(
        equal_computations([x1], [y1])
        for x1, y1 in zip(test_y[0].inputs, expected_y[0].inputs)
    )
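
As a quick usage note (not part of the diff), the new test can be run on its own with pytest:

pytest tests/tensor/rewriting/test_subtensor.py::test_slice_canonicalize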
