Skip to content

Consider lifting Subtensor through Joins #919

Open
@ricardoV94

Description

@ricardoV94

Some example code for a rewrite that lifts `Subtensor` through `Join`:

def _find_indexed_join_component(join_components, axis, axis_index):
    """Locate the Join input that holds position ``axis_index`` of the joined axis.

    Parameters
    ----------
    join_components : sequence of TensorVariable
        The tensors being joined, in order.
    axis : int
        Normalized (non-negative) join axis.
    axis_index : int
        Index into the joined axis. Negative values count from the end.

    Returns
    -------
    tuple of (TensorVariable, int) or None
        The component containing the requested position, and the equivalent
        index into that component. The returned index is negative when
        ``axis_index`` is negative.

        Returns None in two cases:

        - A component with unknown static length along ``axis`` is reached
          before the target is found, so the target cannot be determined.
        - The index is out of bounds for the statically known lengths.
    """
    if axis_index >= 0:
        for component in join_components:
            length = component.type.shape[axis]
            if length is None:
                # We can't figure out if this component or a later one will be indexed
                return None
            if axis_index < length:
                return component, axis_index
            # Axis index is beyond this component
            axis_index -= length
    else:
        # Negative indices count from the end, so scan the components right to left
        for component in reversed(join_components):
            length = component.type.shape[axis]
            if length is None:
                return None
            if axis_index >= -length:
                return component, axis_index
            axis_index += length
    # Out of bounds: don't rewrite, so the original graph raises at runtime
    return None


@node_rewriter([Subtensor])
def local_subtensor_of_join(fgraph, node):
    """Lift a Subtensor through a Join.

    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
    join(axis=0, x, y, z, w)[2] -> z[0]   # when x.shape[0] == y.shape[0] == 1 statically
    join(axis=0, x, y)[-1] -> y[-1]       # when y.shape[0] is statically known and >= 1

    Only applies when the join axis is a Constant.

    If the join axis is not indexed, the indices are applied to every
    component and the results are joined again along the adjusted axis.

    If the join axis is indexed by a scalar Constant, the selected component
    is indexed directly. This requires static lengths along the join axis for
    every component up to the selected one. Non-negative indices are scanned
    from the left; negative indices from the right. Slicing the join axis is
    not supported yet.

    Join's output dtype is the upcast of its inputs' dtypes, so a single
    component may have a narrower dtype than the original output. The
    replacement is therefore cast back to the original output dtype.
    """
    join_var, *idx = node.inputs

    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
        return None

    join_axis, *join_components = join_var.owner.inputs

    # Rewrite only works when the join axis is a constant
    if not isinstance(join_axis, Constant):
        return None

    [old_out] = node.outputs
    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # The join axis itself is indexed: figure out which component the index falls into
        axis_index = idx_tuple[axis]
        if isinstance(axis_index, slice):
            # This isn't too hard to support, but it's not implemented yet
            return None
        if not isinstance(axis_index, Constant):
            return None
        found = _find_indexed_join_component(
            join_components, axis, axis_index.data.item()
        )
        if found is None:
            return None
        indexed_component, component_index = found
        out = indexed_component[
            (*idx_tuple[:axis], component_index, *idx_tuple[axis + 1 :])
        ]
    else:
        # Indexing does not act on the join axis: index every component and join again.
        # Scalar indices left of the join axis drop dimensions, shifting the axis.
        indexed_components = [component[idx_tuple] for component in join_components]
        new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
        out = join(new_axis, *indexed_components)

    # A single selected component can have a narrower dtype than the upcast Join output
    if out.type.dtype != old_out.type.dtype:
        out = out.astype(old_out.type.dtype)

    return [out]

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions