Skip to content

Commit d72b5c2

Browse files
committed
Lift Subtensor over AdvancedSubtensor
1 parent 2c28177 commit d72b5c2

File tree

2 files changed

+125
-2
lines changed

2 files changed

+125
-2
lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
)
99

1010
from pytensor import Variable
11+
from pytensor.compile import optdb
1112
from pytensor.graph import Constant, FunctionGraph, node_rewriter
1213
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
1314
from pytensor.scalar import basic as ps
@@ -42,16 +43,18 @@
4243
)
4344
from pytensor.tensor.special import Softmax, softmax
4445
from pytensor.tensor.subtensor import (
46+
AdvancedSubtensor,
4547
AdvancedSubtensor1,
4648
Subtensor,
49+
_non_contiguous_adv_indexing,
4750
as_index_literal,
4851
get_canonical_form_slice,
4952
get_constant_idx,
5053
get_idx_list,
5154
indices_from_subtensor,
5255
)
5356
from pytensor.tensor.type import TensorType
54-
from pytensor.tensor.type_other import SliceType
57+
from pytensor.tensor.type_other import NoneTypeT, SliceType
5558
from pytensor.tensor.variable import TensorVariable
5659

5760

@@ -819,3 +822,79 @@ def local_subtensor_shape_constant(fgraph, node):
819822
return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
820823
elif shape_parts:
821824
return [as_tensor(1, dtype=np.int64)]
825+
826+
827+
@node_rewriter([Subtensor])
def local_subtensor_of_adv_subtensor(fgraph, node):
    """Lift a simple Subtensor through an AdvancedSubtensor, when basic index dimensions are to the left of any advanced ones.

    x[:, :, vec_idx][i, j] -> x[i, j][vec_idx]
    x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k]

    Restricted to a single advanced indexing dimension.

    An alternative approach could have fused the basic and advanced indices,
    so it is not clear this rewrite should be canonical or a specialization.
    Users must include it manually if it fits their use case.

    Parameters
    ----------
    fgraph : FunctionGraph
        Graph being rewritten; used to count the clients of the intermediate
        AdvancedSubtensor output.
    node : Apply
        A ``Subtensor`` node whose indexed input may be an ``AdvancedSubtensor``
        output.

    Returns
    -------
    list of Variable or None
        The replacement output, or ``None`` when the rewrite does not apply.
    """
    # node computes adv_subtensor[*idxs]; the slice/scalar structure of idxs is
    # recovered below via indices_from_subtensor and node.op.idx_list.
    adv_subtensor, *idxs = node.inputs

    if not (
        adv_subtensor.owner and isinstance(adv_subtensor.owner.op, AdvancedSubtensor)
    ):
        return None

    if len(fgraph.clients[adv_subtensor]) > 1:
        # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
        return None

    x, *adv_idxs = adv_subtensor.owner.inputs

    # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices:
    # bail out on newaxis (NoneTypeT), boolean masks, non-trivial slices, and
    # non-contiguous advanced indices (where NumPy moves the advanced dims to the front).
    if any(
        (
            isinstance(adv_idx.type, NoneTypeT)
            or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
            or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
        )
        for adv_idx in adv_idxs
    ) or _non_contiguous_adv_indexing(adv_idxs):
        return None

    # Locate the position of the first tensor (integer) advanced index.
    for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
        # We already made sure there were only None slices besides integer indexes
        if isinstance(adv_idx.type, TensorType):
            break
    else:  # no-break
        # Not sure if this should ever happen, but better safe than sorry
        return None

    basic_idxs = indices_from_subtensor(idxs, node.op.idx_list)
    # Only basic indices strictly to the left of the first advanced index are
    # lifted; the remainder stay behind, padded with full slices for the
    # dimensions that were lifted.
    basic_idxs_lifted = basic_idxs[:first_adv_idx_dim]
    basic_idxs_kept = ((slice(None),) * len(basic_idxs_lifted)) + basic_idxs[
        first_adv_idx_dim:
    ]

    if all(basic_idx == slice(None) for basic_idx in basic_idxs_lifted):
        # All basic indices happen to the right of the advanced indices
        return None

    [basic_subtensor] = node.outputs
    # Scalar basic indices drop dimensions (per _dims_dropped_by_basic_index);
    # record them so the advanced indices still target the same axes, then
    # squeeze them back out at the end.
    dropped_dims = _dims_dropped_by_basic_index(basic_idxs_lifted)

    x_indexed = x[basic_idxs_lifted]
    copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)

    # Re-insert the dropped dims so adv_idxs index the same axes as before the lift.
    x_after_index_lift = expand_dims(x_indexed, dropped_dims)
    x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
    copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)

    new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
    return [new_out]
894+
895+
896+
# Registered without the database name as a tag: the rewrite only runs when
# requested explicitly by its own name (use_db_name_as_tag=False).
_rewrite = local_subtensor_of_adv_subtensor
for _db in ("canonicalize", "specialize"):
    optdb[_db].register(_rewrite.__name__, _rewrite, use_db_name_as_tag=False)
del _rewrite, _db

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
)
5353
from pytensor.tensor.shape import SpecifyShape, Unbroadcast, _shape
5454
from pytensor.tensor.special import softmax
55-
from pytensor.tensor.subtensor import Subtensor
55+
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
5656

5757

5858
mode_opt = config.mode
@@ -762,3 +762,47 @@ def __eq__(self, other):
762762
x = shape(Variable(MyType(), None, None))[0]
763763

764764
assert not local_subtensor_shape_constant.transform(None, x.owner)
765+
766+
767+
@pytest.mark.parametrize(
    "original_fn, supported",
    [
        (lambda x: x[:, [0, 1]][0], True),
        (lambda x: x[:, [0, 1], [0, 0]][1:], True),
        (lambda x: x[:, [[0, 1], [0, 0]]][1:], True),
        # Not supported, basic indexing on advanced indexing dim
        (lambda x: x[[0, 1]][0], False),
        # Not implemented, basic indexing on the right of advanced indexing
        (lambda x: x[[0, 1]][:, 0], False),
        # Not implemented, complex flavors of advanced indexing
        (lambda x: x[:, None, [0, 1]][0], False),
        (lambda x: x[:, 5:, [0, 1]][0], False),
        (lambda x: x[:, :, np.array([True, False, False])][0], False),
        (lambda x: x[[0, 1], :, [0, 1]][:, 0], False),
    ],
)
def test_local_subtensor_of_adv_subtensor(original_fn, supported):
    """Check that basic indices are lifted through advanced ones when supported.

    When ``supported`` is True the rewritten graph must apply the Subtensor
    before the AdvancedSubtensor; otherwise the original order must remain.
    In both cases the rewritten graph must compute the same values.
    """
    rng = np.random.default_rng(257)
    x = pt.tensor3("x", shape=(7, 5, 3))
    x_test = rng.normal(size=x.type.shape)

    out = original_fn(x)
    opt_out = rewrite_graph(
        out, include=("canonicalize", "local_subtensor_of_adv_subtensor")
    )
    # The graphs generated are too complicated to assert
    # We simply check that the Subtensor happens before the AdvancedSubtensor
    toposort = FunctionGraph(outputs=[opt_out], clone=False).toposort()
    [idx_subtensor] = [
        i for i, node in enumerate(toposort) if isinstance(node.op, Subtensor)
    ]
    [idx_adv_subtensor] = [
        i for i, node in enumerate(toposort) if isinstance(node.op, AdvancedSubtensor)
    ]
    swapped = idx_subtensor < idx_adv_subtensor
    correct = swapped if supported else not swapped
    assert correct, debugprint(opt_out, print_type=True)
    # Rewritten graph must be numerically equivalent to the original
    np.testing.assert_allclose(
        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
    )

0 commit comments

Comments
 (0)