Apply forwarding in pjit linearization rule to avoid intermediate copies #27735

Merged: 1 commit, Apr 7, 2025


@dfm dfm commented Apr 4, 2025

Many linearization rules need to forward inputs as residuals, and we should avoid instantiating copies of those arguments. We have discussed adding forwarding information to the signature of the linearize rules but, in the meantime, it's probably useful to explicitly apply our input forwarding logic when linearizing pjit.
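To illustrate what "forwarding" means here, the following is a minimal sketch that uses only the public `jax.make_jaxpr` API. It is not the code from this PR: the helper `forwarded_inputs` and the toy function `f` are hypothetical names chosen for illustration. The idea is that when a jaxpr output is the very same variable as one of its inputs, the caller can reuse the input buffer instead of materializing a copy of it as a residual.

```python
import jax
from jax import lax


def forwarded_inputs(jaxpr):
    """Hypothetical helper: for each output of `jaxpr`, return the index of
    the input it forwards unchanged, or None if it is actually computed."""
    result = []
    for out in jaxpr.outvars:
        # An output is "forwarded" when it is the same variable object as an input.
        idx = next((i for i, inp in enumerate(jaxpr.invars) if out is inp), None)
        result.append(idx)
    return result


def f(x, y):
    # The first and third outputs are just the inputs passed straight through;
    # only the middle output requires computation.
    return x, lax.mul(lax.sin(y), x), y


closed = jax.make_jaxpr(f)(1.0, 2.0)
fwd = forwarded_inputs(closed.jaxpr)
print(fwd)  # [0, None, 1]
```

With this information, outputs 0 and 2 can be dropped from the inner computation and read directly from the caller's arguments, which is the kind of copy this change aims to avoid.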

This PR also adds a DCE pass to linearize_jaxpr to prune some other unused residuals.
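As a rough picture of what a DCE (dead code elimination) pass does, here is a simplified sketch over a jaxpr's equations. It is an assumption-laden toy, not the pass added to `linearize_jaxpr`: the helper `live_equations` is a hypothetical name, and it ignores side effects and nested jaxprs, which a real pass has to handle.

```python
import jax
from jax import lax


def live_equations(jaxpr):
    """Hypothetical helper: keep only equations whose results are needed
    (directly or transitively) by the jaxpr outputs."""
    # Track liveness by object identity so literals and variables are
    # handled uniformly.
    live = {id(v) for v in jaxpr.outvars}
    kept = []
    # Walk backwards so a consumer is seen before its producers.
    for eqn in reversed(jaxpr.eqns):
        if any(id(v) in live for v in eqn.outvars):
            kept.append(eqn)
            live.update(id(v) for v in eqn.invars)
    kept.reverse()
    return kept


def g(x):
    unused = lax.cos(x)  # never contributes to the output
    s = lax.sin(x)
    return lax.mul(s, s)


closed = jax.make_jaxpr(g)(1.0)
kept_names = [eqn.primitive.name for eqn in live_equations(closed.jaxpr)]
print(kept_names)  # ['sin', 'mul']
```

In the linearization setting, the same idea applies to residuals: any residual that no equation in the tangent computation consumes does not need to be produced or stored.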

These two changes fix some OOM crashes with real workloads when using direct linearize for AD.

@dfm dfm self-assigned this Apr 4, 2025
@dfm dfm added the pull ready Ready for copybara import and testing label Apr 4, 2025
@dfm dfm requested a review from dougalm April 4, 2025 16:45
@copybara-service copybara-service bot merged commit 5581e7d into jax-ml:main Apr 7, 2025
23 checks passed
Labels
kokoro:force-run pull ready Ready for copybara import and testing
2 participants