Performance drop for the batching rule for aten::_sparse_mm #1075

Open
lrnzgiusti opened this issue Dec 4, 2022 · 3 comments
Comments

@lrnzgiusti

Hello @zou3519 , @samdow.

TL;DR: I got the following warning: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation.

I was trying to use vmap to batch sparse-dense matrix multiplications:

import torch
from functorch import vmap

# Batch of two 4x4 sparse matrices as a single (2, 4, 4) sparse COO tensor
A = torch.sparse_coo_tensor(
        indices=torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],
                              [0, 1, 2, 3, 0, 1, 2, 3],
                              [0, 1, 2, 3, 0, 1, 2, 3]]),
        values=torch.tensor([1., 1., 1., 1., 2., 2., 2., 2.]),
        size=(2, 4, 4))

X = torch.tensor([[-0.0533, -1.3950, -0.2621],
                  [-1.0800,  0.3210,  0.7954],
                  [ 0.7737,  0.3655,  0.5691],
                  [-0.3505, -1.0423, -2.0650]])

bspmm = vmap(torch.sparse.mm, in_dims=(0, None))
Z = bspmm(A, X)

In [1]: A.shape
Out[1]: torch.Size([2, 4, 4])

In [2]: X.shape
Out[2]: torch.Size([4, 3])

In [3]: Z.shape
Out[3]: torch.Size([2, 4, 3])

which yields correct results, but emits:

.../functorch/_src/vmap.py:489: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/BatchedFallback.cpp:84.)
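As far as I understand, the fallback behind this warning is roughly equivalent to looping over the batch dimension in Python, along the lines of this sketch:

# Roughly what the fallback does: apply the op to each batch element
# separately and stack the results (hence the performance drop).
Z_loop = torch.stack([torch.sparse.mm(A[i], X) for i in range(A.shape[0])])
assert torch.allclose(Z, Z_loop)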

Are there plans to implement batching for this operation in the near future?

Thanks

@zou3519
Contributor

zou3519 commented Dec 5, 2022

We currently do not support vmap over sparse tensors. Could you tell us a bit more about your use case, please? (cc @cpuhrsch)

@lrnzgiusti
Author

TL;DR: Filtering on topological spaces (e.g. graphs, simplicial/cell complexes) requires K sparse-dense matmul ops (i.e. sum_k S^k · X · W_k, where S^k is sparse).

My specific use case is implementing a differentiable filtering operation on topological spaces (take graphs as an example; higher-order relational structures in the general case). Looking around, it seems the only way to do this is with a for loop like this:

out = torch.stack([Sk.mm(X).mm(W[k]) for k, Sk in enumerate(G.adj)], dim=0).sum(dim=0)

However, with vmap the for loop above becomes unnecessary:

from functorch import vmap
mm = vmap(torch.sparse.mm, in_dims=(0, None))

comp = mm(S, X)
out = torch.bmm(comp, W).sum(dim=0)

where S is a K×N×N sparse tensor, X is an N×Fin dense matrix, and W is a K×Fin×Fout dense tensor.
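To make the comparison concrete, here is a self-contained sketch of both versions (the K, N, Fin, Fout values and the tensors are made up for illustration; the vmap version still triggers the fallback warning above):

import torch
from functorch import vmap

K, N, Fin, Fout = 3, 5, 4, 2
# K sparse N x N operators stacked into one (K, N, N) sparse COO tensor
S = torch.stack([torch.eye(N) * (k + 1) for k in range(K)]).to_sparse()
X = torch.randn(N, Fin)
W = torch.randn(K, Fin, Fout)

# For-loop version: sum_k S^k @ X @ W_k
out_loop = torch.stack([torch.sparse.mm(S[k], X).mm(W[k])
                        for k in range(K)]).sum(dim=0)

# vmap version
mm = vmap(torch.sparse.mm, in_dims=(0, None))
out_vmap = torch.bmm(mm(S, X), W).sum(dim=0)

assert torch.allclose(out_loop, out_vmap)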

Maybe this is too specific to my use case, but I think it could be very useful for folks interested in machine learning on graphs.

@fzimmermann89

We would also be interested in a performant vmap for sparse-dense matrix-vector multiplications.

We use sparse matrices to represent different interpolations of MR data, for example in non-uniform FFTs or volume-to-slice projections.
In both cases, it is much more convenient (and faster) to construct the sparse matrix once and use one matmul compared to other Python-only approaches.
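For context, a minimal sketch of that pattern (the matrix below is a made-up toy interpolator, not a real NUFFT or projection kernel): build the sparse matrix once, then each application is a single sparse-dense matmul.

import torch

# Hypothetical 3x4 interpolation matrix: each output sample is a weighted
# combination of input samples (weights are made up for illustration).
indices = torch.tensor([[0, 0, 1, 2],   # output rows
                        [1, 2, 0, 3]])  # input columns
weights = torch.tensor([0.5, 0.5, 1.0, 1.0])
P = torch.sparse_coo_tensor(indices, weights, size=(3, 4))

x = torch.randn(4, 1)      # input samples as a column vector
y = torch.sparse.mm(P, x)  # one sparse-dense matmul per application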
