Skip to content

Commit

Permalink
Get rid of expensive Blockwise(Reshape)
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 4, 2024
1 parent c76d626 commit ca8bf54
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 91 deletions.
64 changes: 55 additions & 9 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from pytensor import Variable
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.uncanonicalize import local_dimshuffle_alloc
from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor


Expand Down Expand Up @@ -70,7 +74,7 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
Dot | Alloc | ARange | Subtensor | AdvancedSubtensor | AdvancedIncSubtensor,
):
# Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
# These other Ops can't always be trivially vectored at runtime,
# These other Ops can't always be trivially vectorized at runtime,
# since their inputs may imply non-rectangular shapes.
return local_useless_unbatched_blockwise.fn(fgraph, node)

Expand All @@ -86,6 +90,18 @@ def _squeeze_left(x, stop_at_dim: int | None = None):
return x.squeeze(axis=tuple(range(squeeze_ndim)))


def alloc_or_expand_dims_of_alloc(var: Variable) -> bool:
return var.owner and (
isinstance(var.owner.op, Alloc)
or (
isinstance(var.owner.op, DimShuffle)
and var.owner.inputs[0].owner
and isinstance(var.owner.inputs[0].owner.op, Alloc)
)
)


@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_alloc(fgraph, node):
Expand All @@ -97,19 +113,25 @@ def local_blockwise_alloc(fgraph, node):
BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
"""

if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
return None

op: Blockwise = node.op # type: ignore

batch_ndim = op.batch_ndim(node)
if not batch_ndim:
return None

if not any(alloc_or_expand_dims_of_alloc(var) for var in node.inputs):
return None

new_inputs = []
batch_shapes = []
can_push_any_alloc = False
for inp, inp_sig in zip(node.inputs, op.inputs_sig):
if inp.owner and isinstance(inp.owner.op, DimShuffle):
# Convert DimShuffle of Alloc to Alloc
new_inp = local_dimshuffle_alloc.transform(None, inp.owner)
if new_inp:
[inp] = new_inp

if inp.owner and isinstance(inp.owner.op, Alloc):
# Push batch dims from Alloc
value, *shape = inp.owner.inputs
Expand Down Expand Up @@ -167,17 +189,15 @@ def local_blockwise_alloc(fgraph, node):
missing_ndim = old_out_type.ndim - new_out_type.ndim
batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
for i, batch_dims in enumerate(zip(*batch_shapes)): # Transpose shape tuples
if old_out_type.broadcastable[i]:
continue
for batch_dim in batch_dims:
if batch_dim == 1:
continue
batch_shape[i] = batch_dim
if isinstance(batch_dim, Constant):
# Give preference to Constants
batch_shape[i] = batch_dim
break
elif old_out_type.broadcastable[i]:
# Only use non Constant shapes if absolutely necessary
# Otherwise, we use the shape of the non-alloc output
batch_shape[i] = batch_dim

copy_stack_trace(node.outputs, new_outs)
new_outs = [
Expand All @@ -190,3 +210,29 @@ def local_blockwise_alloc(fgraph, node):
]
copy_stack_trace(node.outputs, new_outs)
return new_outs


@register_canonicalize
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_reshape(fgraph, node):
"""Rewrite away square Blockwise reshapes.
Reshape is tricky to vectorize eagerly, because a graph like
`x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
that must be vectorized before we arrize at the reshape operation.
For the square Reshape case, we must wait for all the intemediate
operations to be lifted as Allocs
"""
if not isinstance(node.op.core_op, Reshape):
return None

x, output_shape = node.inputs
batch_ndim = node.op.batch_ndim(node)
if all(output_shape.type.broadcastable[:batch_ndim]):
batched_shape = x.shape[:batch_ndim]
core_reshape = _squeeze_left(output_shape, batch_ndim)
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
copy_stack_trace(node.outputs[0], new_out)
return [new_out]
121 changes: 56 additions & 65 deletions pytensor/tensor/rewriting/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
register_stabilize,
register_useless,
topo_constant_folding,
)
Expand Down Expand Up @@ -749,51 +748,43 @@ def apply(self, fgraph):
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)


def local_reshape_chain(op):
@node_rewriter([op])
def f(fgraph, node):
"""
Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
"""
if not check_chain(node, op, op):
return False

# TODO: this can permit a failing program to run by eliminating
# the lower reshape
rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])

# Copy over stacktrace from previous output node, as any error
# in new computational graph would have been caused by last op
# in the old computational graph.
copy_stack_trace(node.outputs, rval)

# It might happen that the desired output of this node has a
# broadcastable pattern that does not match that of 'rval'. This is
# when originally, we were able to figure out that one of the
# dimensions of the reshape is one, but some other transformation
# replaced the shape by one for which this cannot be guessed.
# We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this
# rewrite.
if rval.type.ndim == node.outputs[0].type.ndim and all(
s1 == s2
for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
if s1 == 1 or s2 == 1
):
return [rval]
else:
return False

return f
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Reshape])
def local_reshape_chain(fgraph, node):
"""
Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)
"""
if not check_chain(node, Reshape, Reshape):
return False

register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])

# Copy over stacktrace from previous output node, as any error
# in new computational graph would have been caused by last op
# in the old computational graph.
copy_stack_trace(node.outputs, rval)

# It might happen that the desired output of this node has a
# broadcastable pattern that does not match that of 'rval'. This is
# when originally, we were able to figure out that one of the
# dimensions of the reshape is one, but some other transformation
# replaced the shape by one for which this cannot be guessed.
# We should try to figure out why we lost the information about this
# constant value... but in the meantime, better not apply this
# rewrite.
if rval.type.ndim == node.outputs[0].type.ndim and all(
s1 == s2
for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
if s1 == 1 or s2 == 1
):
return [rval]


@register_useless
@register_canonicalize
@register_stabilize
@register_useless("shape_unsafe")
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Reshape])
def local_useless_reshape(fgraph, node):
"""Remove two kinds of useless `Reshape`.
Expand All @@ -802,24 +793,17 @@ def local_useless_reshape(fgraph, node):
- Remove `Reshape` when reshaping to the shape of the input.
"""
inp = node.inputs[0]
output = node.outputs[0]
output_shape = node.inputs[1]
inp, output_shape = node.inputs
[output] = node.outputs

if inp.type.ndim != output.type.ndim:
return False

# Simple case: both input and output have a single dimension.
# TODO FIXME XXX: This could hide errors if the user provides inconsistent
# shapes.
if (
inp.type.ndim == 1
and output.type.ndim == 1
and all(
s1 == s2
for s1, s2 in zip(inp.type.shape, output.type.shape)
if s1 == 1 or s2 == 1
)
and inp.type.broadcastable == output.type.broadcastable
):
return [inp]

Expand All @@ -832,8 +816,15 @@ def local_useless_reshape(fgraph, node):

# Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
# broadcastable and constant dimensions
if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
output_shape_is = output_shape.owner.inputs
if isinstance(output_shape, Constant) or (
output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
):
if isinstance(output_shape, Constant):
output_shape_is = [
as_tensor_variable(dim, ndim=0) for dim in output_shape.data
]
else:
output_shape_is = output_shape.owner.inputs

shape_feature = getattr(fgraph, "shape_feature", None)

Expand Down Expand Up @@ -865,9 +856,9 @@ def local_useless_reshape(fgraph, node):
shape_match[dim] = True
continue

# Match 1 if input.type.shape[dim] == 1
# Match constant if input.type.shape[dim] == constant
cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
if inp.type.shape[dim] == cst_outshp_i:
shape_match[dim] = True
continue

Expand All @@ -881,17 +872,18 @@ def local_useless_reshape(fgraph, node):
if shape_feature:
inpshp_i = shape_feature.get_shape(inp, dim)
if inpshp_i == outshp_i or (
extract_constant(inpshp_i, only_process_constants=1)
== extract_constant(outshp_i, only_process_constants=1)
extract_constant(inpshp_i, only_process_constants=True)
== extract_constant(outshp_i, only_process_constants=True)
):
shape_match[dim] = True
continue

if all(shape_match) and nb_m1 <= 1:
if nb_m1 <= 1 and all(shape_match):
return [inp]

if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
return [inp]

# TODO later: if all the shapes except one match, we may want to
# consider it useless as well, like we do in the 1-dim case.
return False


Expand All @@ -910,9 +902,8 @@ def local_reshape_to_dimshuffle(fgraph, node):
-> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
"""
op = node.op
inp = node.inputs[0]
output = node.outputs[0]
output_shape = node.inputs[1]
inp, output_shape = node.inputs
[output] = node.outputs

dimshuffle_new_order = []
new_output_shape = []
Expand Down Expand Up @@ -944,7 +935,7 @@ def local_reshape_to_dimshuffle(fgraph, node):


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
"""
Expand Down
31 changes: 29 additions & 2 deletions tests/tensor/rewriting/test_blockwise.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from functools import partial

from pytensor import function
from pytensor.graph import FunctionGraph, rewrite_graph
import numpy as np

from pytensor import Mode, function
from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
from pytensor.graph.basic import equal_computations
from pytensor.scalar import log as scalar_log
from pytensor.tensor import add, alloc, matrix, tensor, tensor3
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.nlinalg import MatrixPinv
from pytensor.tensor.rewriting.blockwise import local_useless_blockwise
from pytensor.tensor.shape import Reshape


def test_useless_blockwise_of_elemwise():
Expand Down Expand Up @@ -118,3 +121,27 @@ def test_blockwise_alloc():
out = vector_add(x, alloc(y, 5))
expected_out = out
assert equal([rewrite(out)], [expected_out])


def test_blockwise_reshape():
x = tensor("x", shape=(None, None, None))
y = x.reshape([x.shape[0] * x.shape[1], -1])

new_x = tensor("x", shape=(None, None, None, None))
new_y = vectorize_graph(y, {x: new_x})
assert not isinstance(new_y.owner.op, Reshape)
assert isinstance(new_y.owner.op, Blockwise) and isinstance(
new_y.owner.op.core_op, Reshape
)

rewritten_y = rewrite_graph(
new_y, include=("canonicalize", "specialize"), clone=True
)
assert isinstance(rewritten_y.owner.op, Reshape)

no_rewrites = Mode(linker="py", optimizer=None)
test_x = np.arange(5 * 4 * 3 * 2).reshape(5, 4, 3, 2)
np.testing.assert_allclose(
new_y.eval({"x": test_x}, mode=no_rewrites),
rewritten_y.eval({"x": test_x}, mode=no_rewrites),
)
Loading

0 comments on commit ca8bf54

Please sign in to comment.