diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index 1a2c27cf7f3..1dfb6ba1038 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -248,7 +248,7 @@ void TransformReplay::selfAllocationReplay( // have mismatch reduction IDs on the logical between `self` and // `new_self`. auto new_self_logical = TensorDomain::noReductions(new_self->logical()); - auto self_logical = TensorDomain::noReductions(self->logical()); + auto self_logical = self->logical(); NVF_ERROR( new_self_logical.size() == self_logical.size(), @@ -286,13 +286,8 @@ void TransformReplay::selfAllocationReplay( // Replay producer dimensions. const std::vector& self_allocation = self->maybeAllocation(); const std::vector>& self_contiguity = self->contiguity(); - const std::vector& self_allocation_no_reduction = - TensorDomain::noReductions(self_allocation); - // we replay only non-reduction IDs. The reason is that, we might have - // non-mapping reduction IDs between self and new_self. This is used in - // `RemoveBcastSqueeze`. - ReplaySelf replay(self_allocation_no_reduction, axis_map); + ReplaySelf replay(self_allocation, axis_map); std::vector new_alloc_domain; std::vector> new_contiguity; new_alloc_domain.reserve(self_allocation.size()); @@ -303,7 +298,7 @@ void TransformReplay::selfAllocationReplay( for (auto id : new_self->logical()) { if (id->isReduction()) { new_alloc_domain.push_back(id); - // NOLINTNEXTLINE (modernize-use-emplace) + // NOLINTNEXTLINE(modernize-use-emplace) new_contiguity.push_back(std::nullopt); } } @@ -311,9 +306,6 @@ void TransformReplay::selfAllocationReplay( // Pushing the mapped IDs and corresponding contiguity flags for (size_t i : c10::irange(self_allocation.size())) { IterDomain* id = self_allocation[i]; - if (id->isReduction()) { - continue; - } auto it = replay.getReplay().find(id); NVF_ERROR( it != replay.getReplay().end(), "failed to replay IterDomain: ", id);