diff --git a/docs/source/api/python/dgl.sparse_v0.rst b/docs/source/api/python/dgl.sparse_v0.rst
index 8f7ff4a711c3..3d7cdd8c7a81 100644
--- a/docs/source/api/python/dgl.sparse_v0.rst
+++ b/docs/source/api/python/dgl.sparse_v0.rst
@@ -171,10 +171,10 @@ Matrix Multiplication
 .. autosummary::
     :toctree: ../../generated/
 
+    matmul
     spmm
     bspmm
     spspmm
-    mm
     sddmm
     bsddmm
 
diff --git a/python/dgl/sparse/matmul.py b/python/dgl/sparse/matmul.py
index 3d6ee66f93d3..1c8182e5dd23 100644
--- a/python/dgl/sparse/matmul.py
+++ b/python/dgl/sparse/matmul.py
@@ -8,11 +8,11 @@
 
 from .sparse_matrix import SparseMatrix, val_like
 
-__all__ = ["spmm", "bspmm", "spspmm", "mm"]
+__all__ = ["spmm", "bspmm", "spspmm", "matmul"]
 
 
 def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
-    """Multiply a sparse matrix by a dense matrix.
+    """Multiply a sparse matrix by a dense matrix, equivalent to ``A @ X``.
 
     Parameters
     ----------
@@ -54,7 +54,8 @@ def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
 
 
 def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
-    """Multiply a sparse matrix by a dense matrix by batches.
+    """Multiply a sparse matrix by a dense matrix by batches, equivalent to
+    ``A @ X``.
 
     Parameters
     ----------
@@ -91,14 +92,14 @@ def bspmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
     return spmm(A, X)
 
 
-def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
+def _diag_diag_mm(A: DiagMatrix, B: DiagMatrix) -> DiagMatrix:
     """Internal function for multiplying a diagonal matrix by a diagonal matrix
 
     Parameters
     ----------
-    A1 : DiagMatrix
+    A : DiagMatrix
         Matrix of shape (N, M), with values of shape (nnz1)
-    A2 : DiagMatrix
+    B : DiagMatrix
         Matrix of shape (M, P), with values of shape (nnz2)
 
     Returns
@@ -106,15 +107,15 @@ def _diag_diag_mm(A1: DiagMatrix, A2: DiagMatrix) -> DiagMatrix:
     DiagMatrix
         The result of multiplication.
     """
-    M, N = A1.shape
-    N, P = A2.shape
+    M, N = A.shape
+    N, P = B.shape
     common_diag_len = min(M, N, P)
     new_diag_len = min(M, P)
     diag_val = torch.zeros(new_diag_len)
     diag_val[:common_diag_len] = (
-        A1.val[:common_diag_len] * A2.val[:common_diag_len]
+        A.val[:common_diag_len] * B.val[:common_diag_len]
     )
-    return diag(diag_val.to(A1.device), (M, P))
+    return diag(diag_val.to(A.device), (M, P))
 
 
 def _sparse_diag_mm(A, D):
@@ -174,16 +175,17 @@ def _diag_sparse_mm(D, A):
 
 
 def spspmm(
-    A1: Union[SparseMatrix, DiagMatrix], A2: Union[SparseMatrix, DiagMatrix]
+    A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
 ) -> Union[SparseMatrix, DiagMatrix]:
-    """Multiply a sparse matrix by a sparse matrix. The non-zero values of the
-    two sparse matrices must be 1D.
+    """Multiply a sparse matrix by a sparse matrix, equivalent to ``A @ B``.
+
+    The non-zero values of the two sparse matrices must be 1D.
 
     Parameters
     ----------
-    A1 : SparseMatrix or DiagMatrix
+    A : SparseMatrix or DiagMatrix
         Sparse matrix of shape (N, M) with values of shape (nnz)
-    A2 : SparseMatrix or DiagMatrix
+    B : SparseMatrix or DiagMatrix
         Sparse matrix of shape (M, P) with values of shape (nnz)
 
     Returns
@@ -198,13 +200,13 @@ def spspmm(
     >>> row1 = torch.tensor([0, 1, 1])
     >>> col1 = torch.tensor([1, 0, 1])
     >>> val1 = torch.ones(len(row1))
-    >>> A1 = from_coo(row1, col1, val1)
+    >>> A = from_coo(row1, col1, val1)
 
     >>> row2 = torch.tensor([0, 1, 1])
     >>> col2 = torch.tensor([0, 2, 1])
     >>> val2 = torch.ones(len(row2))
-    >>> A2 = from_coo(row2, col2, val2)
-    >>> result = dgl.sparse.spspmm(A1, A2)
+    >>> B = from_coo(row2, col2, val2)
+    >>> result = dgl.sparse.spspmm(A, B)
     >>> print(result)
     SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
                                  [1, 2, 0, 1, 2]]),
@@ -212,73 +214,134 @@ def spspmm(
     shape=(2, 3), nnz=5)
     """
     assert isinstance(
-        A1, (SparseMatrix, DiagMatrix)
-    ), f"Expect A1 to be a SparseMatrix or DiagMatrix object, got {type(A1)}"
+        A, (SparseMatrix, DiagMatrix)
+    ), f"Expect A to be a SparseMatrix or DiagMatrix object, got {type(A)}"
     assert isinstance(
-        A2, (SparseMatrix, DiagMatrix)
-    ), f"Expect A2 to be a SparseMatrix or DiagMatrix object, got {type(A2)}"
-
-    if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
-        return _diag_diag_mm(A1, A2)
-    if isinstance(A1, DiagMatrix):
-        return _diag_sparse_mm(A1, A2)
-    if isinstance(A2, DiagMatrix):
-        return _sparse_diag_mm(A1, A2)
+        B, (SparseMatrix, DiagMatrix)
+    ), f"Expect B to be a SparseMatrix or DiagMatrix object, got {type(B)}"
+
+    if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
+        return _diag_diag_mm(A, B)
+    if isinstance(A, DiagMatrix):
+        return _diag_sparse_mm(A, B)
+    if isinstance(B, DiagMatrix):
+        return _sparse_diag_mm(A, B)
     return SparseMatrix(
-        torch.ops.dgl_sparse.spspmm(A1.c_sparse_matrix, A2.c_sparse_matrix)
+        torch.ops.dgl_sparse.spspmm(A.c_sparse_matrix, B.c_sparse_matrix)
     )
 
 
-def mm(
-    A1: Union[SparseMatrix, DiagMatrix],
-    A2: Union[torch.Tensor, SparseMatrix, DiagMatrix],
+def matmul(
+    A: Union[torch.Tensor, SparseMatrix, DiagMatrix],
+    B: Union[torch.Tensor, SparseMatrix, DiagMatrix],
 ) -> Union[torch.Tensor, SparseMatrix, DiagMatrix]:
-    """Multiply a sparse/diagonal matrix by a dense/sparse/diagonal matrix.
-    If an input is a SparseMatrix or DiagMatrix, its non-zero values should
-    be 1-D.
+    """Multiply two dense/sparse/diagonal matrices, equivalent to ``A @ B``.
+
+    The supported combinations are shown as follows.
+
+    +--------------+--------+------------+--------------+
+    |    A \\ B    | Tensor | DiagMatrix | SparseMatrix |
+    +--------------+--------+------------+--------------+
+    |    Tensor    |   ✅   |     🚫     |      🚫      |
+    +--------------+--------+------------+--------------+
+    | SparseMatrix |   ✅   |     ✅     |      ✅      |
+    +--------------+--------+------------+--------------+
+    |  DiagMatrix  |   ✅   |     ✅     |      ✅      |
+    +--------------+--------+------------+--------------+
+
+    * If both matrices are torch.Tensor, it calls \
+        :func:`torch.matmul()`. The result is a dense matrix.
+
+    * If both matrices are sparse or diagonal, it calls \
+        :func:`dgl.sparse.spspmm`. The result is a sparse matrix.
+
+    * If :attr:`A` is sparse or diagonal while :attr:`B` is dense, it \
+        calls :func:`dgl.sparse.spmm`. The result is a dense matrix.
+
+    * The operator supports batched sparse-dense matrix multiplication. In \
+        this case, the sparse or diagonal matrix :attr:`A` should have shape \
+        :math:`(L, M)`, where the non-zero values have a batch dimension \
+        :math:`K`. The dense matrix :attr:`B` should have shape \
+        :math:`(M, N, K)`. The output is a dense matrix of shape \
+        :math:`(L, N, K)`.
+
+    * Sparse-sparse matrix multiplication does not support batched computation.
 
     Parameters
     ----------
-    A1 : SparseMatrix or DiagMatrix
-        Matrix of shape (N, M), with values of shape (nnz1)
-    A2 : torch.Tensor, SparseMatrix, or DiagMatrix
-        Matrix of shape (M, P). If it is a SparseMatrix or DiagMatrix,
-        it should have values of shape (nnz2).
+    A : torch.Tensor, SparseMatrix or DiagMatrix
+        The first matrix.
+    B : torch.Tensor, SparseMatrix, or DiagMatrix
+        The second matrix.
 
     Returns
     -------
-    torch.Tensor or DiagMatrix or SparseMatrix
-        The result of multiplication of shape (N, P)
-
-        * It is a dense torch tensor if :attr:`A2` is so.
-        * It is a DiagMatrix object if both :attr:`A1` and :attr:`A2` are so.
-        * It is a SparseMatrix object otherwise.
+    torch.Tensor, SparseMatrix or DiagMatrix
+        The result matrix.
 
     Examples
     --------
 
+    Multiply a diagonal matrix with a dense matrix.
+
     >>> val = torch.randn(3)
-    >>> A1 = diag(val)
-    >>> A2 = torch.randn(3, 2)
-    >>> result = dgl.sparse.mm(A1, A2)
+    >>> A = diag(val)
+    >>> B = torch.randn(3, 2)
+    >>> result = dgl.sparse.matmul(A, B)
     >>> print(type(result))
     <class 'torch.Tensor'>
     >>> print(result.shape)
     torch.Size([3, 2])
+
+    Multiply a sparse matrix with a dense matrix.
+
+    >>> row = torch.tensor([0, 1, 1])
+    >>> col = torch.tensor([1, 0, 1])
+    >>> val = torch.randn(len(row))
+    >>> A = from_coo(row, col, val)
+    >>> X = torch.randn(2, 3)
+    >>> result = dgl.sparse.matmul(A, X)
+    >>> print(type(result))
+    <class 'torch.Tensor'>
+    >>> print(result.shape)
+    torch.Size([2, 3])
+
+    Multiply a sparse matrix with a sparse matrix.
+
+    >>> row1 = torch.tensor([0, 1, 1])
+    >>> col1 = torch.tensor([1, 0, 1])
+    >>> val1 = torch.ones(len(row1))
+    >>> A = from_coo(row1, col1, val1)
+    >>> row2 = torch.tensor([0, 1, 1])
+    >>> col2 = torch.tensor([0, 2, 1])
+    >>> val2 = torch.ones(len(row2))
+    >>> B = from_coo(row2, col2, val2)
+    >>> result = dgl.sparse.matmul(A, B)
+    >>> print(type(result))
+    <class 'dgl.sparse.sparse_matrix.SparseMatrix'>
+    >>> print(result.shape)
+    (2, 3)
     """
-    assert isinstance(
-        A1, (SparseMatrix, DiagMatrix)
-    ), f"Expect arg1 to be a SparseMatrix, or DiagMatrix object, got {type(A1)}."
-    assert isinstance(A2, (torch.Tensor, SparseMatrix, DiagMatrix)), (
+    assert isinstance(A, (torch.Tensor, SparseMatrix, DiagMatrix)), (
+        f"Expect arg1 to be a torch.Tensor, SparseMatrix, or DiagMatrix object,"
+        f" got {type(A)}."
+    )
+    assert isinstance(B, (torch.Tensor, SparseMatrix, DiagMatrix)), (
         f"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix"
-        f"object, got {type(A2)}."
+        f" object, got {type(B)}."
+    )
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        return torch.matmul(A, B)
+    assert not isinstance(A, torch.Tensor), (
+        f"Expect arg2 to be a torch Tensor if arg 1 is torch Tensor, "
+        f"got {type(B)}."
     )
-    if isinstance(A2, torch.Tensor):
-        return spmm(A1, A2)
-    if isinstance(A1, DiagMatrix) and isinstance(A2, DiagMatrix):
-        return _diag_diag_mm(A1, A2)
-    return spspmm(A1, A2)
+    if isinstance(B, torch.Tensor):
+        return spmm(A, B)
+    if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
+        return _diag_diag_mm(A, B)
+    return spspmm(A, B)
 
 
-SparseMatrix.__matmul__ = mm
-DiagMatrix.__matmul__ = mm
+SparseMatrix.__matmul__ = matmul
+DiagMatrix.__matmul__ = matmul
diff --git a/tests/pytorch/sparse/test_matmul.py b/tests/pytorch/sparse/test_matmul.py
index eccdbe841167..0cfc3448fc55 100644
--- a/tests/pytorch/sparse/test_matmul.py
+++ b/tests/pytorch/sparse/test_matmul.py
@@ -4,7 +4,8 @@
 import pytest
 import torch
 
-from dgl.sparse import bspmm, diag, from_coo, mm, val_like
+from dgl.sparse import bspmm, diag, from_coo, val_like
+from dgl.sparse.matmul import matmul
 
 from .utils import (
     clone_detach_and_grad,
@@ -33,7 +34,7 @@ def test_spmm(create_func, shape, nnz, out_dim):
     else:
         X = torch.randn(shape[1], requires_grad=True, device=dev)
 
-    sparse_result = A @ X
+    sparse_result = matmul(A, X)
     grad = torch.randn_like(sparse_result)
     sparse_result.backward(grad)
 
@@ -60,7 +61,7 @@ def test_bspmm(create_func, shape, nnz):
     A = create_func(shape, nnz, dev, 2)
     X = torch.randn(shape[1], 10, 2, requires_grad=True, device=dev)
 
-    sparse_result = bspmm(A, X)
+    sparse_result = matmul(A, X)
     grad = torch.randn_like(sparse_result)
     sparse_result.backward(grad)
 
@@ -92,7 +93,7 @@ def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
     shape2 = (shape_n_m[1], shape_k)
     A1 = create_func1(shape1, nnz1, dev)
     A2 = create_func2(shape2, nnz2, dev)
-    A3 = A1 @ A2
+    A3 = matmul(A1, A2)
     grad = torch.randn_like(A3.val)
     A3.val.backward(grad)
 
@@ -132,14 +133,14 @@ def test_spspmm_duplicate():
     A2 = from_coo(row, col, val, shape)
 
     try:
-        A1 @ A2
+        matmul(A1, A2)
     except:
         pass
     else:
         assert False, "Should raise error."
 
     try:
-        A2 @ A1
+        matmul(A2, A1)
     except:
         pass
     else:
@@ -155,8 +156,7 @@ def test_sparse_diag_mm(create_func, sparse_shape, nnz):
     A = create_func(sparse_shape, nnz, dev)
     diag_val = torch.randn(sparse_shape[1], device=dev, requires_grad=True)
     D = diag(diag_val, diag_shape)
-    # (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
-    B = mm(A, D)
+    B = matmul(A, D)
     grad = torch.randn_like(B.val)
     B.val.backward(grad)
 
@@ -189,8 +189,7 @@ def test_diag_sparse_mm(create_func, sparse_shape, nnz):
     A = create_func(sparse_shape, nnz, dev)
     diag_val = torch.randn(sparse_shape[0], device=dev, requires_grad=True)
     D = diag(diag_val, diag_shape)
-    # (TODO) Need to use dgl.sparse.matmul after rename mm to matmul
-    B = mm(D, A)
+    B = matmul(D, A)
     grad = torch.randn_like(B.val)
     B.val.backward(grad)
 