Skip to content

Commit

Permalink
[Sparse] Refactor matmul interface. (dmlc#5162)
Browse files Browse the repository at this point in the history
* [Sparse] Refactor matmul interface.

* Update
  • Loading branch information
czkkkkkk authored Jan 13, 2023
1 parent 9334421 commit b5c5c86
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 73 deletions.
2 changes: 1 addition & 1 deletion docs/source/api/python/dgl.sparse_v0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ Matrix Multiplication
.. autosummary::
:toctree: ../../generated/

matmul
spmm
bspmm
spspmm
mm
sddmm
bsddmm

Expand Down
187 changes: 125 additions & 62 deletions python/dgl/sparse/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

from .sparse_matrix import SparseMatrix, val_like

__all__ = ["spmm", "bspmm", "spspmm", "mm"]
__all__ = ["spmm", "bspmm", "spspmm", "matmul"]


def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
"""Multiply a sparse matrix by a dense matrix.
"""Multiply a sparse matrix by a dense matrix, equivalent to ``A @ X``.
Parameters
----------
Expand Down Expand Up @@ -54,7 +54,8 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:


def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
"""Multiply a sparse matrix by a dense matrix by batches.
"""Multiply a sparse matrix by a dense matrix by batches, equivalent to
``A @ X``.
Parameters
----------
Expand Down Expand Up @@ -91,30 +92,30 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
return spmm(A, X)


def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
def _diag_diag_mm(A: DiagMatrix, B: DiagMatrix) -> DiagMatrix:
"""Internal function for multiplying a diagonal matrix by a diagonal matrix
Parameters
----------
A1 : DiagMatrix
A : DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
A2 : DiagMatrix
B : DiagMatrix
Matrix of shape (M, P), with values of shape (nnz2)
Returns
-------
DiagMatrix
The result of multiplication.
"""
M, N = A1.shape
N, P = A2.shape
M, N = A.shape
N, P = B.shape
common_diag_len = min(M, N, P)
new_diag_len = min(M, P)
diag_val = torch.zeros(new_diag_len)
diag_val[:common_diag_len] = (
A1.val[:common_diag_len] * A2.val[:common_diag_len]
A.val[:common_diag_len] * B.val[:common_diag_len]
)
return diag(diag_val.to(A1.device), (M, P))
return diag(diag_val.to(A.device), (M, P))


def _sparse_diag_mm(A, D):
Expand Down Expand Up @@ -174,16 +175,17 @@ def _diag_sparse_mm(D, A):


def spspmm(
A1: Union[SparseMatrix, DiagMatrix], A2: Union[SparseMatrix, DiagMatrix]
A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
"""Multiply a sparse matrix by a sparse matrix. The non-zero values of the
two sparse matrices must be 1D.
"""Multiply a sparse matrix by a sparse matrix, equivalent to ``A @ B``.
The non-zero values of the two sparse matrices must be 1D.
Parameters
----------
A1 : SparseMatrix or DiagMatrix
A : SparseMatrix or DiagMatrix
Sparse matrix of shape (N, M) with values of shape (nnz)
A2 : SparseMatrix or DiagMatrix
B : SparseMatrix or DiagMatrix
Sparse matrix of shape (M, P) with values of shape (nnz)
Returns
Expand All @@ -198,87 +200,148 @@ def spspmm(
>>> row1 = torch.tensor([0, 1, 1])
>>> col1 = torch.tensor([1, 0, 1])
>>> val1 = torch.ones(len(row1))
>>> A1 = from_coo(row1, col1, val1)
>>> A = from_coo(row1, col1, val1)
>>> row2 = torch.tensor([0, 1, 1])
>>> col2 = torch.tensor([0, 2, 1])
>>> val2 = torch.ones(len(row2))
>>> A2 = from_coo(row2, col2, val2)
>>> result = dgl.sparse.spspmm(A1, A2)
>>> B = from_coo(row2, col2, val2)
>>> result = dgl.sparse.spspmm(A, B)
>>> print(result)
SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
[1, 2, 0, 1, 2]]),
values=tensor([1., 1., 1., 1., 1.]),
shape=(2, 3), nnz=5)
"""
assert isinstance(
A1, (SparseMatrix, DiagMatrix)
), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A1)}"
A, (SparseMatrix, DiagMatrix)
), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A)}"
assert isinstance(
A2, (SparseMatrix, DiagMatrix)
), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(A2)}"

if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
return _diag_diag_mm(A1, A2)
if isinstance(A1, DiagMatrix):
return _diag_sparse_mm(A1, A2)
if isinstance(A2, DiagMatrix):
return _sparse_diag_mm(A1, A2)
B, (SparseMatrix, DiagMatrix)
), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(B)}"

if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
return _diag_diag_mm(A, B)
if isinstance(A, DiagMatrix):
return _diag_sparse_mm(A, B)
if isinstance(B, DiagMatrix):
return _sparse_diag_mm(A, B)
return SparseMatrix(
torch.ops.dgl_sparse.spspmm(A1.c_sparse_matrix, A2.c_sparse_matrix)
torch.ops.dgl_sparse.spspmm(A.c_sparse_matrix, B.c_sparse_matrix)
)


def mm(
A1: Union[SparseMatrix, DiagMatrix],
A2: Union[torch.Tensor, SparseMatrix, DiagMatrix],
def matmul(
A: Union[torch.Tensor, SparseMatrix, DiagMatrix],
B: Union[torch.Tensor, SparseMatrix, DiagMatrix],
) -> Union[torch.Tensor, SparseMatrix, DiagMatrix]:
"""Multiply a sparse/diagonal matrix by a dense/sparse/diagonal matrix.
If an input is a SparseMatrix or DiagMatrix, its non-zero values should
be 1-D.
"""Multiply two dense/sparse/diagonal matrices, equivalent to ``A @ B``.
The supported combinations are shown as follows.
+--------------+--------+------------+--------------+
| A \\ B | Tensor | DiagMatrix | SparseMatrix |
+--------------+--------+------------+--------------+
| Tensor | ✅ | 🚫 | 🚫 |
+--------------+--------+------------+--------------+
| SparseMatrix | ✅ | ✅ | ✅ |
+--------------+--------+------------+--------------+
| DiagMatrix | ✅ | ✅ | ✅ |
+--------------+--------+------------+--------------+
* If both matrices are torch.Tensor, it calls \
:func:`torch.matmul()`. The result is a dense matrix.
* If both matrices are sparse or diagonal, it calls \
:func:`dgl.sparse.spspmm`. The result is a sparse matrix.
* If :attr:`A` is sparse or diagonal while :attr:`B` is dense, it \
calls :func:`dgl.sparse.spmm`. The result is a dense matrix.
* The operator supports batched sparse-dense matrix multiplication. In \
this case, the sparse or diagonal matrix :attr:`A` should have shape \
:math:`(L, M)`, where the non-zero values have a batch dimension \
:math:`K`. The dense matrix :attr:`B` should have shape \
:math:`(M, N, K)`. The output is a dense matrix of shape \
:math:`(L, N, K)`.
* Sparse-sparse matrix multiplication does not support batched computation.
Parameters
----------
A1 : SparseMatrix or DiagMatrix
Matrix of shape (N, M), with values of shape (nnz1)
A2 : torch.Tensor, SparseMatrix, or DiagMatrix
Matrix of shape (M, P). If it is a SparseMatrix or DiagMatrix,
it should have values of shape (nnz2).
A : torch.Tensor, SparseMatrix or DiagMatrix
The first matrix.
B : torch.Tensor, SparseMatrix, or DiagMatrix
The second matrix.
Returns
-------
torch.Tensor or DiagMatrix or SparseMatrix
The result of multiplication of shape (N, P)
* It is a dense torch tensor if :attr:`A2` is so.
* It is a DiagMatrix object if both :attr:`A1` and :attr:`A2` are so.
* It is a SparseMatrix object otherwise.
torch.Tensor, SparseMatrix or DiagMatrix
The result matrix
Examples
--------
Multiply a diagonal matrix with a dense matrix.
>>> val = torch.randn(3)
>>> A1 = diag(val)
>>> A2 = torch.randn(3, 2)
>>> result = dgl.sparse.mm(A1, A2)
>>> A = diag(val)
>>> B = torch.randn(3, 2)
>>> result = dgl.sparse.matmul(A, B)
>>> print(type(result))
<class 'torch.Tensor'>
>>> print(result.shape)
torch.Size([3, 2])
Multiply a sparse matrix with a dense matrix.
>>> row = torch.tensor([0, 1, 1])
>>> col = torch.tensor([1, 0, 1])
>>> val = torch.randn(len(row))
>>> A = from_coo(row, col, val)
>>> X = torch.randn(2, 3)
>>> result = dgl.sparse.matmul(A, X)
>>> print(type(result))
<class 'torch.Tensor'>
>>> print(result.shape)
torch.Size([2, 3])
Multiply a sparse matrix with a sparse matrix.
>>> row1 = torch.tensor([0, 1, 1])
>>> col1 = torch.tensor([1, 0, 1])
>>> val1 = torch.ones(len(row1))
>>> A = from_coo(row1, col1, val1)
>>> row2 = torch.tensor([0, 1, 1])
>>> col2 = torch.tensor([0, 2, 1])
>>> val2 = torch.ones(len(row2))
>>> B = from_coo(row2, col2, val2)
>>> result = dgl.sparse.matmul(A, B)
>>> print(type(result))
<class 'dgl.sparse.sparse_matrix.SparseMatrix'>
>>> print(result.shape)
(2, 3)
"""
assert isinstance(
A1, (SparseMatrix, DiagMatrix)
), f"Expect arg1 to be a SparseMatrix, or DiagMatrix object, got {type(A1)}."
assert isinstance(A2, (torch.Tensor, SparseMatrix, DiagMatrix)), (
assert isinstance(A, (torch.Tensor, SparseMatrix, DiagMatrix)), (
f"Expect arg1 to be a torch.Tensor, SparseMatrix, or DiagMatrix object,"
f"got {type(A)}."
)
assert isinstance(B, (torch.Tensor, SparseMatrix, DiagMatrix)), (
f"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix"
f"object, got {type(A2)}."
f"object, got {type(B)}."
)
if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
return torch.matmul(A, B)
assert not isinstance(A, torch.Tensor), (
f"Expect arg2 to be a torch Tensor if arg 1 is torch Tensor, "
f"got {type(B)}."
)
if isinstance(A2, torch.Tensor):
return spmm(A1, A2)
if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
return _diag_diag_mm(A1, A2)
return spspmm(A1, A2)
if isinstance(B, torch.Tensor):
return spmm(A, B)
if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
return _diag_diag_mm(A, B)
return spspmm(A, B)


SparseMatrix.__matmul__ = mm
DiagMatrix.__matmul__ = mm
SparseMatrix.__matmul__ = matmul
DiagMatrix.__matmul__ = matmul
19 changes: 9 additions & 10 deletions tests/pytorch/sparse/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pytest
import torch

from dgl.sparse import bspmm, diag, from_coo, mm, val_like
from dgl.sparse import bspmm, diag, from_coo, val_like
from dgl.sparse.matmul import matmul

from .utils import (
clone_detach_and_grad,
Expand Down Expand Up @@ -33,7 +34,7 @@ def test_spmm(create_func, shape, nnz, out_dim):
else:
X = torch.randn(shape[1], requires_grad=True, device=dev)

sparse_result = A @ X
sparse_result = matmul(A, X)
grad = torch.randn_like(sparse_result)
sparse_result.backward(grad)

Expand All @@ -60,7 +61,7 @@ def test_bspmm(create_func, shape, nnz):
A = create_func(shape, nnz, dev, 2)
X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)

sparse_result = bspmm(A, X)
sparse_result = matmul(A, X)
grad = torch.randn_like(sparse_result)
sparse_result.backward(grad)

Expand Down Expand Up @@ -92,7 +93,7 @@ def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
shape2 = (shape_n_m[1], shape_k)
A1 = create_func1(shape1, nnz1, dev)
A2 = create_func2(shape2, nnz2, dev)
A3 = A1 @ A2
A3 = matmul(A1, A2)
grad = torch.randn_like(A3.val)
A3.val.backward(grad)

Expand Down Expand Up @@ -132,14 +133,14 @@ def test_spspmm_duplicate():
A2 = from_coo(row, col, val, shape)

try:
A1 @ A2
matmul(A1, A2)
except:
pass
else:
assert False, "Should raise error."

try:
A2 @ A1
matmul(A2, A1)
except:
pass
else:
Expand All @@ -155,8 +156,7 @@ def test_sparse_diag_mm(create_func, sparse_shape, nnz):
A = create_func(sparse_shape, nnz, dev)
diag_val = torch.randn(sparse_shape[1], device=dev, requires_grad=True)
D = diag(diag_val, diag_shape)
# (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
B = mm(A, D)
B = matmul(A, D)
grad = torch.randn_like(B.val)
B.val.backward(grad)

Expand Down Expand Up @@ -189,8 +189,7 @@ def test_diag_sparse_mm(create_func, sparse_shape, nnz):
A = create_func(sparse_shape, nnz, dev)
diag_val = torch.randn(sparse_shape[0], device=dev, requires_grad=True)
D = diag(diag_val, diag_shape)
# (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
B = mm(D, A)
B = matmul(D, A)
grad = torch.randn_like(B.val)
B.val.backward(grad)

Expand Down

0 comments on commit b5c5c86

Please sign in to comment.