From 5f5db2dfc76440abd702ebaa8ff533a1125fb737 Mon Sep 17 00:00:00 2001 From: "Quan (Andy) Gan" Date: Fri, 13 Jan 2023 14:00:52 +0800 Subject: [PATCH] [Sparse] Make functions compatible with PyTorch scalar tensors (#5163) * use NotImplemented * format * extend to pytorch scalar * reformat * reformat * lint --- python/dgl/sparse/elementwise_op.py | 17 ++++++++------- python/dgl/sparse/elementwise_op_diag.py | 21 +++++++++---------- python/dgl/sparse/elementwise_op_sp.py | 17 ++++++--------- python/dgl/sparse/utils.py | 14 +++++++++++++ .../sparse/test_elementwise_op_diag.py | 4 +++- .../pytorch/sparse/test_elementwise_op_sp.py | 8 +++++-- 6 files changed, 48 insertions(+), 33 deletions(-) create mode 100644 python/dgl/sparse/utils.py diff --git a/python/dgl/sparse/elementwise_op.py b/python/dgl/sparse/elementwise_op.py index 96c33f516405..acd08ad3a93c 100644 --- a/python/dgl/sparse/elementwise_op.py +++ b/python/dgl/sparse/elementwise_op.py @@ -4,6 +4,7 @@ from .diag_matrix import DiagMatrix from .sparse_matrix import SparseMatrix +from .utils import Scalar __all__ = ["add", "sub", "mul", "div", "power"] @@ -95,8 +96,8 @@ def sub(A: Union[DiagMatrix], B: Union[DiagMatrix]) -> Union[DiagMatrix]: def mul( - A: Union[SparseMatrix, DiagMatrix, float, int], - B: Union[SparseMatrix, DiagMatrix, float, int], + A: Union[SparseMatrix, DiagMatrix, Scalar], + B: Union[SparseMatrix, DiagMatrix, Scalar], ) -> Union[SparseMatrix, DiagMatrix]: r"""Elementwise multiplication for ``DiagMatrix`` and ``SparseMatrix``, equivalent to ``A * B``. @@ -115,9 +116,9 @@ def mul( Parameters ---------- - A : SparseMatrix or DiagMatrix or float or int + A : SparseMatrix or DiagMatrix or Scalar Sparse matrix or diagonal matrix or scalar value - B : SparseMatrix or DiagMatrix or float or int + B : SparseMatrix or DiagMatrix or Scalar Sparse matrix or diagonal matrix or scalar value Returns @@ -151,7 +152,7 @@ def mul( def div( - A: Union[DiagMatrix], B: Union[DiagMatrix, float, int] + A: Union[DiagMatrix], B: Union[DiagMatrix, Scalar] ) -> Union[DiagMatrix]: r"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``, equivalent to ``A / B``. @@ -172,7 +173,7 @@ def div( ---------- A : DiagMatrix Diagonal matrix - B : DiagMatrix or float or int + B : DiagMatrix or Scalar Diagonal matrix or scalar value Returns @@ -197,7 +198,7 @@ def div( def power( - A: Union[SparseMatrix, DiagMatrix], scalar: Union[float, int] + A: Union[SparseMatrix, DiagMatrix], scalar: Scalar ) -> Union[SparseMatrix, DiagMatrix]: r"""Elementwise exponentiation for ``DiagMatrix`` and ``SparseMatrix``, equivalent to ``A ** scalar``. @@ -218,7 +219,7 @@ def power( ---------- A : SparseMatrix or DiagMatrix Sparse matrix or diagonal matrix - scalar : float or int + scalar : Scalar Exponent Returns diff --git a/python/dgl/sparse/elementwise_op_diag.py b/python/dgl/sparse/elementwise_op_diag.py index 4b9e4457391b..ff8dba5f9bc4 100644 --- a/python/dgl/sparse/elementwise_op_diag.py +++ b/python/dgl/sparse/elementwise_op_diag.py @@ -3,6 +3,7 @@ from .diag_matrix import diag, DiagMatrix from .sparse_matrix import SparseMatrix +from .utils import is_scalar, Scalar def diag_add( @@ -82,14 +83,14 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix: return NotImplemented -def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: +def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix: """Elementwise multiplication Parameters ---------- D1 : DiagMatrix Diagonal matrix - D2 : DiagMatrix or float or int + D2 : DiagMatrix or Scalar Diagonal matrix or scalar value Returns @@ -113,7 +114,7 @@ def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: f"{D1.shape} and D2 {D2.shape} must match." ) return diag(D1.val * D2.val, D1.shape) - elif isinstance(D2, (float, int)): + elif is_scalar(D2): return diag(D1.val * D2, D1.shape) else: # Python falls back to D2.__rmul__(D1) then TypeError when @@ -121,7 +122,7 @@ def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: return NotImplemented -def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: +def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix: """Elementwise division of a diagonal matrix by a diagonal matrix or a scalar @@ -129,7 +130,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: ---------- D1 : DiagMatrix Diagonal matrix - D2 : DiagMatrix or float or int + D2 : DiagMatrix or Scalar Diagonal matrix or scalar value. If :attr:`D2` is a DiagMatrix, division is only applied to the diagonal elements. @@ -155,7 +156,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: "must match." ) return diag(D1.val / D2.val, D1.shape) - elif isinstance(D2, (float, int)): + elif is_scalar(D2): assert D2 != 0, "Division by zero is not allowed." return diag(D1.val / D2, D1.shape) else: @@ -165,7 +166,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix: # pylint: disable=invalid-name -def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix: +def diag_power(D: DiagMatrix, scalar: Scalar) -> DiagMatrix: """Take the power of each nonzero element and return a diagonal matrix with the result. @@ -173,7 +174,7 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix: ---------- D : DiagMatrix Diagonal matrix - scalar : float or int + scalar : Scalar Exponent Returns @@ -189,9 +190,7 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix: shape=(3, 3)) """ return ( - diag(D.val**scalar, D.shape) - if isinstance(scalar, (float, int)) - else NotImplemented + diag(D.val**scalar, D.shape) if is_scalar(scalar) else NotImplemented ) diff --git a/python/dgl/sparse/elementwise_op_sp.py b/python/dgl/sparse/elementwise_op_sp.py index 5588aaf86a32..f6d1096eb2cc 100644 --- a/python/dgl/sparse/elementwise_op_sp.py +++ b/python/dgl/sparse/elementwise_op_sp.py @@ -1,9 +1,8 @@ """DGL elementwise operators for sparse matrix module.""" -from typing import Union - import torch from .sparse_matrix import SparseMatrix, val_like +from .utils import is_scalar, Scalar def spsp_add(A, B): @@ -46,14 +45,14 @@ def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix: return spsp_add(A, B) if isinstance(B, SparseMatrix) else NotImplemented -def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix: +def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix: """Elementwise multiplication Parameters ---------- A : SparseMatrix First operand - B : float or int + B : Scalar Second operand Returns @@ -81,7 +80,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix: values=tensor([2, 4, 6]), shape=(3, 4), nnz=3) """ - if isinstance(B, (float, int)): + if is_scalar(B): return val_like(A, A.val * B) # Python falls back to B.__rmul__(A) then TypeError when NotImplemented is # returned. @@ -90,7 +89,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix: return NotImplemented -def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix: +def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix: """Take the power of each nonzero element and return a sparse matrix with the result. @@ -120,11 +119,7 @@ def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix: """ # Python falls back to scalar.__rpow__ then TypeError when NotImplemented # is returned. - return ( - val_like(A, A.val**scalar) - if isinstance(scalar, (float, int)) - else NotImplemented - ) + return val_like(A, A.val**scalar) if is_scalar(scalar) else NotImplemented SparseMatrix.__add__ = sp_add diff --git a/python/dgl/sparse/utils.py b/python/dgl/sparse/utils.py new file mode 100644 index 000000000000..c95889362021 --- /dev/null +++ b/python/dgl/sparse/utils.py @@ -0,0 +1,14 @@ +"""Utilities for DGL sparse module.""" +from numbers import Number +from typing import Union + +import torch + + +def is_scalar(x): + """Check if the input is a scalar.""" + return isinstance(x, Number) or (torch.is_tensor(x) and x.dim() == 0) + + +# Scalar type annotation +Scalar = Union[Number, torch.Tensor] diff --git a/tests/pytorch/sparse/test_elementwise_op_diag.py b/tests/pytorch/sparse/test_elementwise_op_diag.py index 898fbedef0e1..ee2d6287c82e 100644 --- a/tests/pytorch/sparse/test_elementwise_op_diag.py +++ b/tests/pytorch/sparse/test_elementwise_op_diag.py @@ -29,7 +29,9 @@ def test_diag_op_diag(op): assert result.shape == D1.shape -@pytest.mark.parametrize("v_scalar", [2, 2.5]) +@pytest.mark.parametrize( + "v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)] +) def test_diag_op_scalar(v_scalar): ctx = F.ctx() shape = (3, 4) diff --git a/tests/pytorch/sparse/test_elementwise_op_sp.py b/tests/pytorch/sparse/test_elementwise_op_sp.py index 54e6df8c9baa..ef2cecb93074 100644 --- a/tests/pytorch/sparse/test_elementwise_op_sp.py +++ b/tests/pytorch/sparse/test_elementwise_op_sp.py @@ -20,7 +20,9 @@ def all_close_sparse(A, row, col, val, shape): assert A.shape == shape -@pytest.mark.parametrize("v_scalar", [2, 2.5]) +@pytest.mark.parametrize( + "v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)] +) def test_mul_scalar(v_scalar): ctx = F.ctx() row = torch.tensor([1, 0, 2]).to(ctx) @@ -65,7 +67,9 @@ def test_pow(val_shape): @pytest.mark.parametrize("op", ["add", "sub"]) -@pytest.mark.parametrize("v_scalar", [2, 2.5]) +@pytest.mark.parametrize( + "v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)] +) def test_error_op_scalar(op, v_scalar): ctx = F.ctx() row = torch.tensor([1, 0, 2]).to(ctx)