Consider lifting Subtensor through Joins #919

Open
ricardoV94 opened this issue Jul 10, 2024 · 0 comments
ricardoV94 commented Jul 10, 2024

Description

Example code for a rewrite that lifts a Subtensor through a Join:

@node_rewriter([Subtensor])
def local_subtensor_of_join(fgraph, node):
    """Lift a Subtensor through a Join.

    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
    join(axis=0, x, y, z, w)[2] -> z[0]  (when x and y each have static length 1 along axis 0)
    """
    join_var, *idx = node.inputs

    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
        return None

    join_axis, *join_components = join_var.owner.inputs

    # Rewrite only works when the join axis is a constant
    if not isinstance(join_axis, Constant):
        return None

    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # The join axis is indexed: we can only lift if we can figure out
        # which component is selected by the index along that axis
        axis_index = idx_tuple[axis]
        if isinstance(axis_index, slice):
            # This isn't too hard to support, but it's not implemented yet
            return None
        if not isinstance(axis_index, Constant):
            return None
        axis_index = axis_index.data.item()
        if axis_index < 0:
            return None  # TODO: Just have to iterate from right to left
        for indexed_component in join_components:
            component_axis_length = indexed_component.type.shape[axis]
            if component_axis_length is None:
                # We can't figure out if this component or a later one will be indexed
                return None
            if axis_index >= component_axis_length:
                # Axis index is beyond this component
                axis_index -= component_axis_length
            else:
                # This is the indexed component
                break
        else:  # no-break: the index is out of bounds for every component
            return None
        out = indexed_component[(*idx_tuple[:axis], axis_index, *idx_tuple[axis + 1 :])]

    else:
        # Indexing does not act on the join axis: we can simply index each
        # component and join the results again
        indexed_components = [component[idx_tuple] for component in join_components]
        new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
        out = join(new_axis, *indexed_components)

    return [out]
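
The component search in the loop above, and the axis adjustment in the final branch, can be illustrated in plain Python without any PyTensor machinery. The following is only a sketch under stated assumptions: `locate_component` and `new_join_axis` are hypothetical names introduced here, not PyTensor functions, and the handling of negative indices is one possible reading of the "iterate from right to left" TODO rather than an agreed design. Static lengths are given as a list where `None` means the length is unknown.

```python
def locate_component(lengths, index):
    """Find which joined component an integer index along the join axis hits.

    Hypothetical helper (not part of PyTensor). `lengths` holds the static
    length of each component along the join axis, or None when unknown.
    Returns (component_position, local_index), or None when it cannot be
    decided statically or the index is out of bounds.
    """
    positions = list(enumerate(lengths))
    if index < 0:
        # Assumed reading of the TODO: walk the components from right to left
        positions.reverse()

    for position, length in positions:
        if length is None:
            # Unknown length: this component or a later one may be indexed
            return None
        if index >= 0:
            if index < length:
                return position, index
            index -= length  # index is beyond this component
        else:
            if index >= -length:
                return position, index  # local index stays negative
            index += length  # index is before this component
    return None  # out of bounds for every component


def new_join_axis(idx_tuple, axis):
    """Join axis after indexing each component with `idx_tuple`.

    Hypothetical helper: every non-slice (integer) index to the left of
    `axis` drops one dimension, shifting the join axis left by one.
    """
    dropped = sum(not isinstance(idx, slice) for idx in idx_tuple[:axis])
    return axis - dropped


# join(axis=0, x, y, z, w)[2] with lengths 1, 1, 2, 1 selects z[0]
print(locate_component([1, 1, 2, 1], 2))
# join(axis=1, x, y)[0] becomes join(axis=0, x[0], y[0])
print(new_join_axis((0,), 1))
```

Returning `None` whenever a length is unknown mirrors the conservative behaviour of the rewrite above: it only fires when the selected component can be determined from static shape information.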