Performance drop for the batching rule for aten::_sparse_mm #1075
We currently do not support vmap over sparse tensors. Could you tell us a bit more about your use case please? (cc @cpuhrsch)
TLDR: Filtering on topological spaces (i.e. graphs, simplicial/cell complexes) requires K sparse-dense matmul ops (i.e. `sum_k S^k · X · W_k`), where `S^k` is sparse.

My specific use case is to implement a differentiable filtering operation on topological spaces (take graphs as an example; higher-order relational structures in the general case). From looking around, it seems that the only way to do this is with a for loop like this:

```python
out = torch.stack([Sk.mm(X).mm(W[k]) for k, Sk in enumerate(G.adj)], dim=0).sum(dim=0)
```

However, with vmap the for loop above becomes obsolete:

```python
from functorch import vmap

mm = vmap(torch.sparse.mm, in_dims=(0, None))
comp = mm(S, X)
out = torch.bmm(comp, W).sum(dim=0)
```

where `S` is a `K x N x N` sparse tensor, `X` is an `N x Fin` dense matrix, and `W` is a `K x Fin x Fout` dense tensor. Maybe this is too specific to my use case, but I think it can be very useful for all the folks interested in machine learning on graphs.
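For reference, a self-contained sketch of both versions (the sizes and the identity matrices standing in for adjacencies are made up for illustration; depending on the PyTorch/functorch version, the vmap path may warn about the missing batching rule or reject sparse inputs outright, as noted above):

```python
import torch
from functorch import vmap

K, N, Fin, Fout = 3, 8, 4, 5

# K sparse N x N operators (identities as stand-ins for adjacency matrices).
adj = [torch.eye(N).to_sparse() for _ in range(K)]
X = torch.randn(N, Fin)
W = torch.randn(K, Fin, Fout)

# Loop version: sum_k S^k @ X @ W_k
out_loop = torch.stack([adj[k].mm(X).mm(W[k]) for k in range(K)], dim=0).sum(dim=0)

# vmap version over a single K x N x N sparse tensor.
S = torch.stack([a.to_dense() for a in adj]).to_sparse()
mm = vmap(torch.sparse.mm, in_dims=(0, None))
out_vmap = torch.bmm(mm(S, X), W).sum(dim=0)

print(torch.allclose(out_loop, out_vmap, atol=1e-5))
```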
We would also be interested in a performant vmap for sparse-dense matrix-vector multiplications. We use sparse matrices to represent different interpolations of MR data, for example in non-uniform FFTs or volume-to-slice projections.
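A minimal sketch of that pattern, assuming a fixed (hypothetical) interpolation matrix `A` and a batch of dense vectors; even with only the dense side batched, the call still routes through the same `aten::_sparse_mm` fallback today, which is why a native batching rule would help here too:

```python
import torch
from functorch import vmap

M, N, B = 16, 32, 4
A = torch.randn(M, N).relu().to_sparse()  # stand-in for an interpolation matrix
vs = torch.randn(B, N)                    # batch of dense vectors

# vmap over the dense argument only; the sparse matrix is held fixed.
matvec = vmap(lambda v: torch.sparse.mm(A, v.unsqueeze(-1)).squeeze(-1))
out = matvec(vs)  # shape (B, M)

# Sanity check against the unbatched equivalent.
print(torch.allclose(out, torch.sparse.mm(A, vs.T).T, atol=1e-5))
```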
Hello @zou3519, @samdow.

TLDR: I got the following warning:

```
UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation.
```

I was trying to use vmap for batching sparse-dense matrix multiplications:
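(The snippet itself was not preserved in this thread; a plausible minimal reconstruction, assuming the same `vmap(torch.sparse.mm, ...)` pattern as the earlier comment, with hypothetical shapes:)

```python
import torch
from functorch import vmap

# Hypothetical tensors; the original post's shapes were not preserved.
S = torch.stack([torch.eye(4)] * 3).to_sparse()  # K x N x N batch of sparse matrices
X = torch.randn(4, 2)                            # N x F dense matrix

batched_mm = vmap(torch.sparse.mm, in_dims=(0, None))
out = batched_mm(S, X)  # correct result, but falls back to a slow per-slice loop
```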
which yields correct results but:

```
.../functorch/_src/vmap.py:489: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/BatchedFallback.cpp:84.)
```
Are there plans to implement batching for this operation in the near future?
Thanks