Open
Description
Description
Some example rewrite code
@node_rewriter([Subtensor])
def local_subtensor_of_join(fgraph, node):
    """Lift a Subtensor through a Join.

    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
    join(axis=0, x, y, z, w)[2] -> z[0]  (when x and y have static length 1 on axis 0)

    Two situations are handled:

    * The join axis is indexed by a constant integer. The statically known
      lengths of the components along the join axis are used to find the
      single component that holds the requested entry, and the index is
      applied to that component alone. Non-negative indices are resolved by
      walking the components left to right, negative indices by walking them
      right to left.
    * The join axis is not touched by the index. The index is applied to every
      component and the results are joined again, along an axis shifted to
      account for the dimensions the index dropped to the left of it.

    Parameters
    ----------
    fgraph
        Not used by this rewrite.
    node
        The Subtensor node; its first input is the indexed variable and the
        remaining inputs are the index variables.

    Returns
    -------
    list or None
        A single-element list with the replacement output, or ``None`` when
        the rewrite does not apply: the input is not the output of a Join, the
        join axis is not constant, the join axis is indexed by a slice or a
        non-constant scalar, a component length needed to locate the index is
        not statically known, or the index falls outside all components.
    """
    join_var, *idx = node.inputs

    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
        return None

    join_axis, *join_components = join_var.owner.inputs

    # Rewrite only works when the join axis is a constant
    if not isinstance(join_axis, Constant):
        return None

    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)

    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # The join axis itself is indexed: we can only lift if we can figure
        # out which component is selected by the index along that axis.
        axis_index = idx_tuple[axis]
        if isinstance(axis_index, slice):
            # This isn't too hard to support, but it's not implemented yet
            return None
        if not isinstance(axis_index, Constant):
            return None

        axis_index = axis_index.data.item()
        # Non-negative indices count from the first component, negative ones
        # from the last, so walk the components from the matching end.
        if axis_index >= 0:
            candidates = join_components
        else:
            candidates = reversed(join_components)

        for indexed_component in candidates:
            component_axis_length = indexed_component.type.shape[axis]
            if component_axis_length is None:
                # We can't figure out if this component or a later one will be indexed
                return None
            if axis_index >= component_axis_length:
                # Non-negative index lies beyond this component
                axis_index -= component_axis_length
            elif axis_index < -component_axis_length:
                # Negative index lies before the start of this component
                axis_index += component_axis_length
            else:
                # This is the indexed component; axis_index is now relative to it
                break
        else:  # no-break: the index is out of bounds for the joined variable
            return None

        out = indexed_component[
            (*idx_tuple[:axis], axis_index, *idx_tuple[axis + 1 :])
        ]
    else:
        # Indexing does not act on the join axis: index every component and
        # join the results again.
        indexed_components = [component[idx_tuple] for component in join_components]
        # NOTE(review): relies on the helper returning how many dimensions the
        # index drops to the left of `axis` (as its name suggests) - confirm.
        new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
        out = join(new_axis, *indexed_components)

    return [out]