Implement several subtensor lift rewrites #1158

Draft · wants to merge 10 commits into main from subtensor_lift
Conversation

@ricardoV94 (Member) commented Jan 20, 2025

This allows reducing computations on batch dimensions by lifting simple indexing operations closer to the inputs.

An obvious example is:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

mode = get_default_mode()

x = pt.matrix("x", shape=(512, 512))
x_test = np.random.normal(size=x.type.shape)

x_sum = x.sum(axis=1)
out = x_sum[0]

fn_before = pytensor.function([x], out, mode=mode.excluding("local_subtensor_of_reduce"))
fn_before.dprint(print_type=True)
%timeit fn_before(x_test)
# Subtensor{i} [id A] <Scalar(float64, shape=())> 1
#  ├─ Sum{axis=1} [id B] <Vector(float64, shape=(512,))> 0
#  │  └─ x [id C] <Matrix(float64, shape=(512, 512))>
#  └─ 0 [id D] <uint8>
# 762 μs ± 7.55 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

print()
fn_after = pytensor.function([x], out, mode=mode)
fn_after.dprint(print_type=True)
%timeit fn_after(x_test)
# Sum{axes=None} [id A] <Scalar(float64, shape=())> 1
#  └─ Subtensor{i} [id B] <Vector(float64, shape=(512,))> 0
#     ├─ x [id C] <Matrix(float64, shape=(512, 512))>
#     └─ 0 [id D] <uint8>
# 5.26 μs ± 86 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
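The speedup comes from the rewrite reducing only the data that is actually needed: instead of summing all 512 rows and then discarding all but one result, the lifted graph indexes out the single row first and sums only it. A minimal NumPy sketch of the algebraic equivalence the rewrite exploits (illustrative only, not PyTensor code):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(512, 512))

# Before the rewrite: reduce over all 512 rows, then index one element.
before = x.sum(axis=1)[0]

# After the rewrite: index the single row first, then reduce only it.
after = x[0].sum()

# Both orderings compute the same value; the second touches 512
# elements instead of 512 * 512.
np.testing.assert_allclose(before, after)
```

The same reasoning applies to any reduction axis that is disjoint from the indexed dimension, which is what makes the lift safe in general.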

📚 Documentation preview 📚: https://pytensor--1158.org.readthedocs.build/en/1158/

@ricardoV94 ricardoV94 force-pushed the subtensor_lift branch 3 times, most recently from cbe0c96 to 9b47cee Compare January 20, 2025 17:56
@ricardoV94 changed the title from "Implement sereval subtensor lift rewrites" to "Implement several subtensor lift rewrites" on Jan 21, 2025
@ricardoV94 ricardoV94 force-pushed the subtensor_lift branch 2 times, most recently from d72b5c2 to 23870fa Compare January 30, 2025 14:42
@ricardoV94 ricardoV94 force-pushed the subtensor_lift branch 2 times, most recently from dd44f13 to 4ae95d7 Compare March 10, 2025 16:09
This reduces the number of rewrite passes, since other rewrites don't care about the dtype of the indices and can easily introduce non-uint index operations.
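The reason this is safe is that integer indexing semantics do not depend on the index dtype, so canonicalizing index dtypes can be deferred to a single late pass. A hedged NumPy illustration of the underlying fact (not PyTensor code):

```python
import numpy as np

x = np.arange(10.0)

# Indexing with any integer dtype selects the same element, so a rewrite
# that replaces a uint8 index with an int64 one (or vice versa) does not
# change the result of the graph.
assert x[np.uint8(3)] == x[np.int64(3)] == x[3]
```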
codecov bot commented Mar 11, 2025

Codecov Report

Attention: Patch coverage is 91.64786% with 37 lines in your changes missing coverage. Please review.

Project coverage is 82.04%. Comparing base (00fea0e) to head (580149b).

Files with missing lines                    | Patch % | Lines
pytensor/tensor/rewriting/subtensor_lift.py | 91.29%  | 17 Missing and 20 partials ⚠️

❌ Your patch check has failed because the patch coverage (91.64%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1158      +/-   ##
==========================================
+ Coverage   81.99%   82.04%   +0.05%     
==========================================
  Files         188      189       +1     
  Lines       48582    48806     +224     
  Branches     8688     8727      +39     
==========================================
+ Hits        39833    40043     +210     
- Misses       6585     6592       +7     
- Partials     2164     2171       +7     
Files with missing lines                    | Coverage Δ
pytensor/tensor/rewriting/subtensor.py      | 90.55% <100.00%> (+0.42%) ⬆️
pytensor/tensor/rewriting/subtensor_lift.py | 91.29% <91.29%> (ø)

@ricardoV94 ricardoV94 requested a review from lucianopaz March 11, 2025 16:20