Performance drop for the batching rule for aten::_sparse_mm #1075

Open
@lrnzgiusti

Hello @zou3519, @samdow.

TL;DR: I get the following warning (not an error): UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation.

I was trying to use vmap to batch sparse-dense matrix multiplications, with one sparse matrix per batch entry and a single shared dense matrix:

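import torch
from functorch import vmap

# Batched sparse COO tensor: A[0] is the 4x4 identity, A[1] is twice the identity.
indices = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],   # batch index
                        [0, 1, 2, 3, 0, 1, 2, 3],   # row index
                        [0, 1, 2, 3, 0, 1, 2, 3]])  # column index
values = torch.tensor([1., 1., 1., 1., 2., 2., 2., 2.])
A = torch.sparse_coo_tensor(indices, values, size=(2, 4, 4))

# Dense right-hand side, shared across the batch.
X = torch.tensor([[-0.0533, -1.3950, -0.2621],
                  [-1.0800,  0.3210,  0.7954],
                  [ 0.7737,  0.3655,  0.5691],
                  [-0.3505, -1.0423, -2.0650]])

# Map over the batch dimension of A only; X is not batched.
bspmm = vmap(torch.sparse.mm, in_dims=(0, None))
Z = bspmm(A, X)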

In [1]: A.shape
Out[1]: torch.Size([2, 4, 4])

In [2]: X.shape
Out[2]: torch.Size([4, 3])

In [3]: Z.shape
Out[3]: torch.Size([2, 4, 3])

This yields correct results, but it emits the following warning:

.../functorch/_src/vmap.py:489: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/BatchedFallback.cpp:84.)
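For reference, the result that the batched call is expected to reproduce can be sketched in plain Python, with no PyTorch dependency. This is only a minimal illustration of batched sparse COO times dense multiplication on the data above; the helper name `batched_coo_mm` is hypothetical and is not part of functorch or PyTorch, and it says nothing about how an actual batching rule would be implemented.

```python
# Hypothetical reference helper (not a PyTorch/functorch API): batched
# sparse COO @ dense matmul written out entry by entry.
def batched_coo_mm(indices, values, size, dense):
    """indices holds three lists (batch, row, col); dense is shared by all batches."""
    n_batch, n_rows, _ = size
    n_out = len(dense[0])
    out = [[[0.0] * n_out for _ in range(n_rows)] for _ in range(n_batch)]
    for b, i, j, v in zip(indices[0], indices[1], indices[2], values):
        # Each stored entry A[b, i, j] = v adds v * X[j, :] to Z[b, i, :].
        for k in range(n_out):
            out[b][i][k] += v * dense[j][k]
    return out


indices = [[0, 0, 0, 0, 1, 1, 1, 1],
           [0, 1, 2, 3, 0, 1, 2, 3],
           [0, 1, 2, 3, 0, 1, 2, 3]]
values = [1., 1., 1., 1., 2., 2., 2., 2.]
X = [[-0.0533, -1.3950, -0.2621],
     [-1.0800,  0.3210,  0.7954],
     [ 0.7737,  0.3655,  0.5691],
     [-0.3505, -1.0423, -2.0650]]

Z = batched_coo_mm(indices, values, (2, 4, 4), X)
```

Because A[0] is the identity and A[1] is twice the identity, Z[0] equals X and Z[1] equals 2 * X, which gives a quick way to sanity-check the vmap output.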

Are there plans to implement a batching rule for this operation in the near future?

Thanks
