
PyTorch function transforms are not compatible with PyroParams #393

Closed
eb8680 opened this issue Nov 27, 2023 · 4 comments
Assignees
Labels
bug Something isn't working module:robust upstream

Comments

@eb8680
Contributor

eb8680 commented Nov 27, 2023

In chirho.robust, we're making extensive use of torch.func, especially the vectorization transform torch.func.vmap and forward and reverse-mode autodiff transforms torch.func.jvp/vjp.
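For reference, the three transforms in question can be sketched minimally in plain PyTorch (no Pyro involved; this is assumed illustrative usage, not code from chirho.robust):

```python
import torch


def f(x):
    return x ** 2  # elementwise square; Jacobian is diag(2 * x)


x = torch.ones(3)
v = torch.ones(3)

# forward-mode: Jacobian-vector product J(x) @ v
_, jvp_out = torch.func.jvp(f, (x,), (v,))

# reverse-mode: vector-Jacobian product u @ J(x)
out, vjp_fn = torch.func.vjp(f, x)
(vjp_out,) = vjp_fn(torch.ones_like(out))

# vectorization over a leading batch dimension
batched = torch.func.vmap(f)(torch.stack([x, 2 * x]))
```

At x = 1 with unit (co)tangents, both products come out to 2 * x = [2, 2, 2], and the vmapped call returns a (2, 3) batch of outputs.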

Unfortunately, the autodiff transforms seem to be fundamentally incompatible with PyroModules that have constrained PyroParams, including most of the standard autoguides in pyro.infer.autoguide - the raw PyTorch Parameter tensors underlying the constrained PyroParams do not seem to receive or propagate gradients. It's not clear what exactly in Pyro is causing the problem or how to fix or work around it, although it at least doesn't seem specific to ChiRho.

This isn't immediately blocking other work, but without compatibility with PyroModules the practical utility of chirho.robust will be significantly diminished, so we should try to resolve this somehow before merging #398 into master.

@eb8680 eb8680 added bug Something isn't working upstream module:robust labels Nov 27, 2023
@eb8680 eb8680 self-assigned this Nov 27, 2023
@agrawalraj
Contributor

agrawalraj commented Dec 6, 2023

Here's an odd interaction with torch.func:

import torch


def f1(x):
    return x**2


def f2(x):
    # torch.tensor(...) copies the values of x[i] ** 2 into a fresh tensor,
    # disconnecting the output from its inputs
    return torch.tensor([x[0] ** 2, x[1] ** 2, x[2] ** 2])


x = 1.0 * torch.ones(3, requires_grad=True)
v = torch.ones(3)

print(torch.func.jvp(f1, (x,), (v,))[1]) # returns tensor([2., 2., 2.], grad_fn=<MulBackward0>)
print(torch.func.jvp(f2, (x,), (v,))[1]) # returns tensor([0., 0., 0.])

Figured I would post in case it was related at all to the composability issue with PyroParams.
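The zero tangents in f2 come from torch.tensor(...) copying raw values rather than tracking the computation. A sketch of a fix (my own rewrite, not from the issue) uses torch.stack, which stays connected to its inputs:

```python
import torch


def f2_fixed(x):
    # torch.stack builds the output from the inputs' computation graph,
    # so jvp tangents propagate; torch.tensor(...) would copy raw values
    return torch.stack([x[0] ** 2, x[1] ** 2, x[2] ** 2])


x = torch.ones(3)
v = torch.ones(3)
tangent = torch.func.jvp(f2_fixed, (x,), (v,))[1]
print(tangent)  # tensor([2., 2., 2.])
```

This suggests the zeros are an expected consequence of torch.tensor's copy semantics rather than a torch.func bug per se, though whether the same mechanism is behind the PyroParam issue is a separate question.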

@eb8680
Contributor Author

eb8680 commented Feb 15, 2024

This should be resolved once a version of Pyro including pyro-ppl/pyro#3328 is released.

@agrawalraj
Contributor

Ok that's awesome, thanks!

@eb8680
Copy link
Contributor Author

eb8680 commented Feb 27, 2024

Resolved in the Pyro 1.9.0 release.

@eb8680 eb8680 closed this as completed Feb 27, 2024