Skip to content

Commit

Permalink
[Sparse] Make functions compatible with PyTorch scalar tensors (dmlc#…
Browse files Browse the repository at this point in the history
…5163)

* use NotImplemented

* format

* extend to pytorch scalar

* reformat

* reformat

* lint
  • Loading branch information
BarclayII authored Jan 13, 2023
1 parent 1d1b08b commit 5f5db2d
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 33 deletions.
17 changes: 9 additions & 8 deletions python/dgl/sparse/elementwise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix
from .utils import Scalar

__all__ = ["add", "sub", "mul", "div", "power"]

Expand Down Expand Up @@ -95,8 +96,8 @@ def sub(A: Union[DiagMatrix], B: Union[DiagMatrix]) -> Union[DiagMatrix]:


def mul(
A: Union[SparseMatrix, DiagMatrix, float, int],
B: Union[SparseMatrix, DiagMatrix, float, int],
A: Union[SparseMatrix, DiagMatrix, Scalar],
B: Union[SparseMatrix, DiagMatrix, Scalar],
) -> Union[SparseMatrix, DiagMatrix]:
r"""Elementwise multiplication for ``DiagMatrix`` and ``SparseMatrix``,
equivalent to ``A * B``.
Expand All @@ -115,9 +116,9 @@ def mul(
Parameters
----------
A : SparseMatrix or DiagMatrix or float or int
A : SparseMatrix or DiagMatrix or Scalar
Sparse matrix or diagonal matrix or scalar value
B : SparseMatrix or DiagMatrix or float or int
B : SparseMatrix or DiagMatrix or Scalar
Sparse matrix or diagonal matrix or scalar value
Returns
Expand Down Expand Up @@ -151,7 +152,7 @@ def mul(


def div(
A: Union[DiagMatrix], B: Union[DiagMatrix, float, int]
A: Union[DiagMatrix], B: Union[DiagMatrix, Scalar]
) -> Union[DiagMatrix]:
r"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``, equivalent
to ``A / B``.
Expand All @@ -172,7 +173,7 @@ def div(
----------
A : DiagMatrix
Diagonal matrix
B : DiagMatrix or float or int
B : DiagMatrix or Scalar
Diagonal matrix or scalar value
Returns
Expand All @@ -197,7 +198,7 @@ def div(


def power(
A: Union[SparseMatrix, DiagMatrix], scalar: Union[float, int]
A: Union[SparseMatrix, DiagMatrix], scalar: Scalar
) -> Union[SparseMatrix, DiagMatrix]:
r"""Elementwise exponentiation for ``DiagMatrix`` and ``SparseMatrix``,
equivalent to ``A ** scalar``.
Expand All @@ -218,7 +219,7 @@ def power(
----------
A : SparseMatrix or DiagMatrix
Sparse matrix or diagonal matrix
scalar : float or int
scalar : Scalar
Exponent
Returns
Expand Down
21 changes: 10 additions & 11 deletions python/dgl/sparse/elementwise_op_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .diag_matrix import diag, DiagMatrix
from .sparse_matrix import SparseMatrix
from .utils import is_scalar, Scalar


def diag_add(
Expand Down Expand Up @@ -82,14 +83,14 @@ def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
return NotImplemented


def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix:
"""Elementwise multiplication
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or float or int
D2 : DiagMatrix or Scalar
Diagonal matrix or scalar value
Returns
Expand All @@ -113,23 +114,23 @@ def diag_mul(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
f"{D1.shape} and D2 {D2.shape} must match."
)
return diag(D1.val * D2.val, D1.shape)
elif isinstance(D2, (float, int)):
elif is_scalar(D2):
return diag(D1.val * D2, D1.shape)
else:
# Python falls back to D2.__rmul__(D1) then TypeError when
# NotImplemented is returned.
return NotImplemented


def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, Scalar]) -> DiagMatrix:
"""Elementwise division of a diagonal matrix by a diagonal matrix or a
scalar
Parameters
----------
D1 : DiagMatrix
Diagonal matrix
D2 : DiagMatrix or float or int
D2 : DiagMatrix or Scalar
Diagonal matrix or scalar value. If :attr:`D2` is a DiagMatrix,
division is only applied to the diagonal elements.
Expand All @@ -155,7 +156,7 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:
"must match."
)
return diag(D1.val / D2.val, D1.shape)
elif isinstance(D2, (float, int)):
elif is_scalar(D2):
assert D2 != 0, "Division by zero is not allowed."
return diag(D1.val / D2, D1.shape)
else:
Expand All @@ -165,15 +166,15 @@ def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float, int]) -> DiagMatrix:


# pylint: disable=invalid-name
def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
def diag_power(D: DiagMatrix, scalar: Scalar) -> DiagMatrix:
"""Take the power of each nonzero element and return a diagonal matrix with
the result.
Parameters
----------
D : DiagMatrix
Diagonal matrix
scalar : float or int
scalar : Scalar
Exponent
Returns
Expand All @@ -189,9 +190,7 @@ def diag_power(D: DiagMatrix, scalar: Union[float, int]) -> DiagMatrix:
shape=(3, 3))
"""
return (
diag(D.val**scalar, D.shape)
if isinstance(scalar, (float, int))
else NotImplemented
diag(D.val**scalar, D.shape) if is_scalar(scalar) else NotImplemented
)


Expand Down
17 changes: 6 additions & 11 deletions python/dgl/sparse/elementwise_op_sp.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""DGL elementwise operators for sparse matrix module."""
from typing import Union

import torch

from .sparse_matrix import SparseMatrix, val_like
from .utils import is_scalar, Scalar


def spsp_add(A, B):
Expand Down Expand Up @@ -46,14 +45,14 @@ def sp_add(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
return spsp_add(A, B) if isinstance(B, SparseMatrix) else NotImplemented


def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix:
"""Elementwise multiplication
Parameters
----------
A : SparseMatrix
First operand
B : float or int
B : Scalar
Second operand
Returns
Expand Down Expand Up @@ -81,7 +80,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
values=tensor([2, 4, 6]),
shape=(3, 4), nnz=3)
"""
if isinstance(B, (float, int)):
if is_scalar(B):
return val_like(A, A.val * B)
# Python falls back to B.__rmul__(A) then TypeError when NotImplemented is
# returned.
Expand All @@ -90,7 +89,7 @@ def sp_mul(A: SparseMatrix, B: Union[float, int]) -> SparseMatrix:
return NotImplemented


def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
"""Take the power of each nonzero element and return a sparse matrix with
the result.
Expand Down Expand Up @@ -120,11 +119,7 @@ def sp_power(A: SparseMatrix, scalar: Union[float, int]) -> SparseMatrix:
"""
# Python falls back to scalar.__rpow__ then TypeError when NotImplemented
# is returned.
return (
val_like(A, A.val**scalar)
if isinstance(scalar, (float, int))
else NotImplemented
)
return val_like(A, A.val**scalar) if is_scalar(scalar) else NotImplemented


SparseMatrix.__add__ = sp_add
Expand Down
14 changes: 14 additions & 0 deletions python/dgl/sparse/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Utilities for DGL sparse module."""
from numbers import Number
from typing import Union

import torch


def is_scalar(x):
"""Check if the input is a scalar."""
return isinstance(x, Number) or (torch.is_tensor(x) and x.dim() == 0)


# Scalar type annotation
Scalar = Union[Number, torch.Tensor]
4 changes: 3 additions & 1 deletion tests/pytorch/sparse/test_elementwise_op_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def test_diag_op_diag(op):
assert result.shape == D1.shape


@pytest.mark.parametrize("v_scalar", [2, 2.5])
@pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_diag_op_scalar(v_scalar):
ctx = F.ctx()
shape = (3, 4)
Expand Down
8 changes: 6 additions & 2 deletions tests/pytorch/sparse/test_elementwise_op_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def all_close_sparse(A, row, col, val, shape):
assert A.shape == shape


@pytest.mark.parametrize("v_scalar", [2, 2.5])
@pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_mul_scalar(v_scalar):
ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx)
Expand Down Expand Up @@ -65,7 +67,9 @@ def test_pow(val_shape):


@pytest.mark.parametrize("op", ["add", "sub"])
@pytest.mark.parametrize("v_scalar", [2, 2.5])
@pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_error_op_scalar(op, v_scalar):
ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx)
Expand Down

0 comments on commit 5f5db2d

Please sign in to comment.