Skip to content

Commit

Permalink
[Sparse] Support sparse matrix dividing scalar (dmlc#5173)
Browse files Browse the repository at this point in the history
* use NotImplemented

* format

* extend to pytorch scalar

* sparse div scalar

* oops

* Apply suggestions from code review

Co-authored-by: Mufei Li <[email protected]>

Co-authored-by: Mufei Li <[email protected]>
  • Loading branch information
BarclayII and mufeili authored Jan 13, 2023
1 parent b5c5c86 commit acc567a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 6 deletions.
20 changes: 15 additions & 5 deletions python/dgl/sparse/elementwise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ def mul(


def div(
A: Union[DiagMatrix], B: Union[DiagMatrix, Scalar]
) -> Union[DiagMatrix]:
A: Union[SparseMatrix, DiagMatrix], B: Union[DiagMatrix, Scalar]
) -> Union[SparseMatrix, DiagMatrix]:
r"""Elementwise division for ``DiagMatrix`` and ``SparseMatrix``, equivalent
to ``A / B``.
Expand All @@ -164,15 +164,15 @@ def div(
+--------------+------------+--------------+--------+
| DiagMatrix | ✅ | 🚫 | ✅ |
+--------------+------------+--------------+--------+
| SparseMatrix | 🚫 | 🚫 | 🚫 |
| SparseMatrix | 🚫 | 🚫 | |
+--------------+------------+--------------+--------+
| scalar | 🚫 | 🚫 | 🚫 |
+--------------+------------+--------------+--------+
Parameters
----------
A : DiagMatrix
Diagonal matrix
A : SparseMatrix or DiagMatrix
Sparse or diagonal matrix
B : DiagMatrix or Scalar
Diagonal matrix or scalar value
Expand All @@ -193,6 +193,16 @@ def div(
>>> div(A, 2)
DiagMatrix(val=tensor([0.5000, 1.0000, 1.5000]),
shape=(3, 3))
>>> row = torch.tensor([1, 0, 2])
>>> col = torch.tensor([0, 3, 2])
>>> val = torch.tensor([1, 2, 3])
>>> A = from_coo(row, col, val, shape=(3, 4))
>>> A / 2
SparseMatrix(indices=tensor([[1, 0, 2],
[0, 3, 2]]),
values=tensor([0.5000, 1.0000, 1.5000]),
shape=(3, 4), nnz=3)
"""
return A / B

Expand Down
35 changes: 35 additions & 0 deletions python/dgl/sparse/elementwise_op_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,40 @@ def sp_mul(A: SparseMatrix, B: Scalar) -> SparseMatrix:
return NotImplemented


def sp_div(A: SparseMatrix, B: Scalar) -> SparseMatrix:
"""Elementwise division
Parameters
----------
A : SparseMatrix
First operand
B : Scalar
Second operand
Returns
-------
SparseMatrix
Result of A / B
Examples
--------
>>> row = torch.tensor([1, 0, 2])
>>> col = torch.tensor([0, 3, 2])
>>> val = torch.tensor([1, 2, 3])
>>> A = from_coo(row, col, val, shape=(3, 4))
>>> A / 2
SparseMatrix(indices=tensor([[1, 0, 2],
[0, 3, 2]]),
values=tensor([0.5000, 1.0000, 1.5000]),
shape=(3, 4), nnz=3)
"""
if is_scalar(B):
return val_like(A, A.val / B)
# Python falls back to B.__rtruediv__(A) then TypeError when NotImplemented
# is returned.
return NotImplemented


def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
"""Take the power of each nonzero element and return a sparse matrix with
the result.
Expand Down Expand Up @@ -125,4 +159,5 @@ def sp_power(A: SparseMatrix, scalar: Scalar) -> SparseMatrix:
SparseMatrix.__add__ = sp_add
SparseMatrix.__mul__ = sp_mul
SparseMatrix.__rmul__ = sp_mul
SparseMatrix.__truediv__ = sp_div
SparseMatrix.__pow__ = sp_power
11 changes: 10 additions & 1 deletion tests/pytorch/sparse/test_elementwise_op_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def all_close_sparse(A, row, col, val, shape):
@pytest.mark.parametrize(
"v_scalar", [2, 2.5, torch.tensor(2), torch.tensor(2.5)]
)
def test_mul_scalar(v_scalar):
def test_muldiv_scalar(v_scalar):
ctx = F.ctx()
row = torch.tensor([1, 0, 2]).to(ctx)
col = torch.tensor([0, 3, 2]).to(ctx)
Expand All @@ -40,6 +40,15 @@ def test_mul_scalar(v_scalar):
assert torch.allclose(A1.val * v_scalar, A2.val, rtol=1e-4, atol=1e-4)
assert A1.shape == A2.shape

# A / v
A2 = A1 / v_scalar
assert torch.allclose(A1.val / v_scalar, A2.val, rtol=1e-4, atol=1e-4)
assert A1.shape == A2.shape

# v / A
with pytest.raises(TypeError):
v_scalar / A1


@pytest.mark.parametrize("val_shape", [(3,), (3, 2)])
def test_pow(val_shape):
Expand Down

0 comments on commit acc567a

Please sign in to comment.