diff --git a/test/composite/test_logsumexp.py b/test/composite/test_logsumexp.py
index 49e7a9c4..a196ed69 100644
--- a/test/composite/test_logsumexp.py
+++ b/test/composite/test_logsumexp.py
@@ -1,31 +1,64 @@
+from itertools import product
+
+import pytest
 import torch
 from torch_scatter import scatter_logsumexp
+from torch_scatter.testing import float_dtypes, assert_equal
+
+edge_values = [0.0, 1.0, -1e33, 1e33, float("nan"), float("-inf"),
+               float("inf")]
+
+tests = [
+    [0.5, -2.1, 3.2],
+    [],
+    *map(list, product(edge_values, edge_values)),
+]
+
+
+@pytest.mark.parametrize('src,dtype', product(tests, float_dtypes))
+def test_logsumexp(src, dtype):
+    src = torch.tensor(src, dtype=dtype)
+    index = torch.zeros_like(src, dtype=torch.long)
+    out_scatter = scatter_logsumexp(src, index, dim_size=1)
+    out_torch = torch.logsumexp(src, dim=0, keepdim=True)
+    assert_equal(out_scatter, out_torch, equal_nan=True)
+
+@pytest.mark.parametrize('src,out', product(tests, edge_values))
+def test_logsumexp_inplace(src, out):
+    src = torch.tensor(src)
+    out = torch.tensor([out])
+    out_scatter = out.clone()
+    index = torch.zeros_like(src, dtype=torch.long)
+    scatter_logsumexp(src, index, out=out_scatter)
+    out_torch = torch.logsumexp(torch.cat([out, src]), dim=0, keepdim=True)
+    assert_equal(out_scatter, out_torch, equal_nan=True)
 
 
-def test_logsumexp():
-    inputs = torch.tensor([
-        0.5,
-        0.5,
-        0.0,
-        -2.1,
-        3.2,
-        7.0,
-        -1.0,
-        -100.0,
-    ])
-    inputs.requires_grad_()
-    index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4])
-    splits = [2, 3, 1, 0, 2]
-
-    outputs = scatter_logsumexp(inputs, index)
-
-    for src, out in zip(inputs.split(splits), outputs.unbind()):
-        if src.numel() > 0:
-            assert out.tolist() == torch.logsumexp(src, dim=0).tolist()
-        else:
-            assert out.item() == 0.0
+
+def test_logsumexp_parallel_backward_jit():
+    splits = [len(src) for src in tests]
+    srcs = torch.tensor(sum(tests, start=[]))
+    index = torch.repeat_interleave(torch.tensor(splits))
+
+    srcs.requires_grad_()
+    outputs = scatter_logsumexp(srcs, index)
+
+    for src, out_scatter in zip(srcs.split(splits), outputs.unbind()):
+        out_torch = torch.logsumexp(src, dim=0)
+        assert_equal(out_scatter, out_torch, equal_nan=True)
 
     outputs.backward(torch.randn_like(outputs))
 
     jit = torch.jit.script(scatter_logsumexp)
-    assert jit(inputs, index).tolist() == outputs.tolist()
+    assert_equal(jit(srcs, index), outputs, equal_nan=True)
+
+
+def test_logsumexp_inplace_dimsize():
+    # if both `out` and `dim_size` are provided, they should match
+    src = torch.zeros(3)
+    index = src.to(torch.long)
+    out = torch.zeros(1)
+
+    scatter_logsumexp(src, index, 0, out, dim_size=1)
+    with pytest.raises(AssertionError):
+        scatter_logsumexp(src, index, 0, out, dim_size=2)
diff --git a/torch_scatter/composite/logsumexp.py b/torch_scatter/composite/logsumexp.py
index 355d0c0e..e61c9b84 100644
--- a/torch_scatter/composite/logsumexp.py
+++ b/torch_scatter/composite/logsumexp.py
@@ -8,34 +8,36 @@
 
 def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                       out: Optional[torch.Tensor] = None,
-                      dim_size: Optional[int] = None,
-                      eps: float = 1e-12) -> torch.Tensor:
+                      dim_size: Optional[int] = None) -> torch.Tensor:
     if not torch.is_floating_point(src):
         raise ValueError('`scatter_logsumexp` can only be computed over '
                          'tensors with floating point data types.')
 
     index = broadcast(index, src, dim)
 
-    if out is not None:
-        dim_size = out.size(dim)
-    else:
-        if dim_size is None:
+    if dim_size is None:
+        if out is not None:
+            dim_size = out.size(dim)
+        else:
             dim_size = int(index.max()) + 1
+    elif out is not None:
+        assert dim_size == out.size(dim)
 
     size = list(src.size())
     size[dim] = dim_size
-    max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
-                                     device=src.device)
-    scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
+
+    if out is None:
+        max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
+                                         device=src.device)
+    else:
+        max_value_per_index = out.clone()
+    scatter_max(src, index, dim, max_value_per_index)
+    max_value_per_index.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
     max_per_src_element = max_value_per_index.gather(dim, index)
-    recentered_score = src - max_per_src_element
-    recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))
+    src_sub_max = src - max_per_src_element
 
     if out is not None:
-        out = out.sub_(max_value_per_index).exp_()
-
-    sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
-                                dim_size)
+        out.sub_(max_value_per_index).exp_()
 
-    out = sum_per_index.add_(eps).log_().add_(max_value_per_index)
-    return out.nan_to_num_(neginf=0.0)
+    sum_per_index = scatter_sum(src_sub_max.exp_(), index, dim, out, dim_size)
+    return sum_per_index.log_().add_(max_value_per_index)
diff --git a/torch_scatter/testing.py b/torch_scatter/testing.py
index 2407b8a0..24ad3877 100644
--- a/torch_scatter/testing.py
+++ b/torch_scatter/testing.py
@@ -8,6 +8,7 @@
     torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
     torch.long
 ]
+float_dtypes = list(filter(lambda x: x.is_floating_point, dtypes))
 grad_dtypes = [torch.float, torch.double]
 
 devices = [torch.device('cpu')]
@@ -17,3 +18,9 @@
 
 def tensor(x: Any, dtype: torch.dtype, device: torch.device):
     return None if x is None else torch.tensor(x, device=device).to(dtype)
+
+
+def assert_equal(actual: torch.Tensor, expected: torch.Tensor,
+                 equal_nan=False):
+    torch.testing.assert_close(actual, expected, equal_nan=equal_nan, rtol=0,
+                               atol=0)
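
For reference, the composition the patched scatter_logsumexp performs (per-index max shift, exponentiate, scatter-sum, log, add the shift back, with non-finite maxima zeroed before the shift) can be sketched with plain PyTorch ops. This is an illustrative sketch only, not code from this PR; the helper name grouped_logsumexp and the 1-D restriction are assumptions made for brevity, while src, index and dim_size mirror the torch_scatter signature.

# Illustrative sketch (not part of the patch): numerically stable grouped
# logsumexp via the max-shift identity logsumexp(x) = log(sum(exp(x - m))) + m.
import torch


def grouped_logsumexp(src: torch.Tensor, index: torch.Tensor,
                      dim_size: int) -> torch.Tensor:
    # Per-group maximum; '-inf' marks groups that receive no elements.
    max_per_index = torch.full((dim_size,), float('-inf'), dtype=src.dtype)
    max_per_index = max_per_index.scatter_reduce(0, index, src, reduce='amax')
    # Keep non-finite maxima out of the shift, mirroring the patch's
    # nan_to_num_ guard.
    shift = max_per_index.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)
    # Sum the shifted exponentials per group, then undo the shift in log space.
    sum_exp = torch.zeros(dim_size, dtype=src.dtype)
    sum_exp = sum_exp.index_add(0, index, (src - shift[index]).exp())
    return sum_exp.log() + shift


src = torch.tensor([0.5, -2.1, 3.2])
index = torch.zeros(3, dtype=torch.long)
assert torch.allclose(grouped_logsumexp(src, index, dim_size=1),
                      torch.logsumexp(src, dim=0, keepdim=True))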