```python
from numpy.lib.array_utils import normalize_axis_index  # NumPy >= 2.0

from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import Join, join
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor

# `_axis_is_indexed_by_basic_index` and `_ndim_dropped_left_of_axis_by_basic_index`
# are assumed to be helpers defined alongside this rewrite.


@node_rewriter([Subtensor])
def local_subtensor_of_join(fgraph, node):
    """Lift a Subtensor through a Join.

    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
    join(axis=0, x, y, z, w)[2] -> z[0]
    """
    join_var, *idx = node.inputs

    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
        return None

    join_axis, *join_components = join_var.owner.inputs

    # Rewrite only works when the join axis is a constant
    if not isinstance(join_axis, Constant):
        return None

    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)

    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # We can only lift if we can figure out which component
        # is selected by the index along axis
        axis_index = idx_tuple[axis]

        if isinstance(axis_index, slice):
            # This isn't too hard to support, but it's not implemented yet
            return None

        if not isinstance(axis_index, Constant):
            return None

        axis_index = axis_index.data.item()
        if axis_index < 0:
            # TODO: Just have to iterate from right to left
            return None

        for indexed_component in join_components:
            component_axis_length = indexed_component.type.shape[axis]
            if component_axis_length is None:
                # We can't figure out if this component or a later one will be indexed
                return None

            if axis_index >= component_axis_length:
                # Axis index is beyond this component
                axis_index -= component_axis_length
            else:
                # This is the indexed component
                break
        else:  # no-break: the index is out of bounds
            return None

        out = indexed_component[(*idx_tuple[:axis], axis_index, *idx_tuple[axis + 1 :])]

    else:
        # Indexing does not act on axis, so we can simply lift it through
        # each component and join again
        indexed_components = [component[idx_tuple] for component in join_components]
        new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
        out = join(new_axis, *indexed_components)

    return [out]
```
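As a quick sanity check of the two identities in the docstring, the NumPy sketch below uses `np.concatenate` as a stand-in for `join` and plain indexing as a stand-in for `Subtensor`. The shapes are arbitrary illustrative choices, and the second case assumes every input has length 1 along axis 0, which the `z[0]` result in the docstring implies.

```python
import numpy as np


def check_index_off_join_axis(seed: int = 0) -> bool:
    """join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])"""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(2, 3))
    y = rng.normal(size=(2, 4))
    before = np.concatenate([x, y], axis=1)[0]
    # x[0] drops axis 0, which lies left of the join axis,
    # so the join axis moves from 1 to 0 after lifting.
    after = np.concatenate([x[0], y[0]], axis=0)
    return np.array_equal(before, after)


def check_index_on_join_axis(seed: int = 0) -> bool:
    """join(axis=0, x, y, z, w)[2] -> z[0], with length-1 inputs on axis 0"""
    rng = np.random.default_rng(seed)
    x, y, z, w = (rng.normal(size=(1, 5)) for _ in range(4))
    before = np.concatenate([x, y, z, w], axis=0)[2]
    # Index 2 skips x (length 1) and y (length 1) and lands on z at local index 0.
    after = z[0]
    return np.array_equal(before, after)


print(check_index_off_join_axis(), check_index_on_join_axis())  # prints: True True
```

The first case exercises the `else` branch, including the axis shift computed by `_ndim_dropped_left_of_axis_by_basic_index`; the second exercises the branch that picks a single component.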
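The component search in the indexed branch only handles non-negative indices, and the `TODO` notes that negative ones just need a right-to-left walk. Below is a pure-Python sketch of that search over static lengths. `locate_component` is a hypothetical helper for illustration, not part of PyTensor, and `None` stands in for an unknown `type.shape[axis]` entry.

```python
from typing import Optional, Sequence, Tuple


def locate_component(
    lengths: Sequence[Optional[int]], index: int
) -> Optional[Tuple[int, int]]:
    """Hypothetical helper: map an index along the join axis to
    (component_position, local_index).

    Returns None if a needed length is unknown or the index is out of bounds.
    """
    if index >= 0:
        # Left-to-right walk, as in the rewrite above
        for pos, length in enumerate(lengths):
            if length is None:
                return None
            if index >= length:
                index -= length
            else:
                return pos, index
        return None  # out of bounds
    # Negative index: walk from the right, as the TODO suggests
    for pos in range(len(lengths) - 1, -1, -1):
        length = lengths[pos]
        if length is None:
            return None
        if -index > length:
            index += length
        else:
            return pos, index  # local index stays negative
    return None  # out of bounds


print(locate_component([2, 3], 2), locate_component([2, 3], -4))  # prints: (1, 0) (0, -1)
```

Walking from the right means an unknown length only blocks the search when it sits to the right of the target element, mirroring how the existing left-to-right loop only gives up when it meets an unknown length before reaching the target.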
Description
Some example rewrite code for lifting a `Subtensor` through a `Join`.