Description
TL;DR: I got the following warning: `UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation.`
I was trying to use `vmap` to batch sparse-dense matrix multiplications:
```python
import torch
from functorch import vmap

# batch of two sparse 4x4 matrices (identity and 2 * identity)
A = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],
                          [0, 1, 2, 3, 0, 1, 2, 3],
                          [0, 1, 2, 3, 0, 1, 2, 3]]),
    values=torch.tensor([1., 1., 1., 1., 2., 2., 2., 2.]),
    size=(2, 4, 4))
X = torch.tensor([[-0.0533, -1.3950, -0.2621],
                  [-1.0800,  0.3210,  0.7954],
                  [ 0.7737,  0.3655,  0.5691],
                  [-0.3505, -1.0423, -2.0650]])

# map over the leading (batch) dim of A, broadcast X to every slice
bspmm = vmap(torch.sparse.mm, in_dims=(0, None))
Z = bspmm(A, X)
```
```
In [1]: A.shape
Out[1]: torch.Size([2, 4, 4])

In [2]: X.shape
Out[2]: torch.Size([4, 3])

In [3]: Z.shape
Out[3]: torch.Size([2, 4, 3])
```
This yields the correct result, but emits:
```
.../functorch/_src/vmap.py:489: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/BatchedFallback.cpp:84.)
```
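For reference, the fallback behavior the warning describes (running the op slice by slice) can be reproduced explicitly. This is just a sketch of a workaround, not an official API path: it slices the batched sparse tensor along dim 0 with integer indexing, applies `torch.sparse.mm` (which expects 2-D operands) to each slice, and stacks the dense results.

```python
import torch

# same batch of sparse matrices as above: A[0] = I, A[1] = 2 * I
A = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],
                          [0, 1, 2, 3, 0, 1, 2, 3],
                          [0, 1, 2, 3, 0, 1, 2, 3]]),
    values=torch.tensor([1., 1., 1., 1., 2., 2., 2., 2.]),
    size=(2, 4, 4))
X = torch.randn(4, 3)

# explicit loop over the batch dimension instead of vmap;
# A[i] selects one 2-D sparse matrix from the batched COO tensor
Z = torch.stack([torch.sparse.mm(A[i], X) for i in range(A.shape[0])])
print(Z.shape)  # torch.Size([2, 4, 3])
```

This avoids the warning entirely, at the cost of a Python-level loop, which is essentially the same work the batched fallback does internally.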
Are there plans to implement the batching rule for this operation in the near future?
Thanks