From 699120e2584077b0113eb9519e58c5e3dbcc1ec6 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sat, 2 Dec 2023 12:24:23 +0100 Subject: [PATCH] More in-depth tests for `EdgeIndex` (#8510) --- test/data/test_edge_index.py | 338 +++++++++++++++++++---------- test/utils/test_index_sort.py | 16 ++ torch_geometric/data/edge_index.py | 97 +++++---- torch_geometric/utils/sort.py | 6 +- 4 files changed, 292 insertions(+), 165 deletions(-) create mode 100644 test/utils/test_index_sort.py diff --git a/test/data/test_edge_index.py b/test/data/test_edge_index.py index 31b57b12c0bb..0dbd10ea7626 100644 --- a/test/data/test_edge_index.py +++ b/test/data/test_edge_index.py @@ -4,10 +4,10 @@ import pytest import torch -from torch import Tensor +from torch import Tensor, tensor import torch_geometric -from torch_geometric.data.edge_index import EdgeIndex +from torch_geometric.data.edge_index import SUPPORTED_DTYPES, EdgeIndex from torch_geometric.profile import benchmark from torch_geometric.testing import ( disableExtensions, @@ -19,13 +19,23 @@ from torch_geometric.typing import SparseTensor from torch_geometric.utils import scatter +DTYPES = [pytest.param(dtype, id=str(dtype)[6:]) for dtype in SUPPORTED_DTYPES] +IS_UNDIRECTED = [ + pytest.param(False, id='directed'), + pytest.param(True, id='undirected'), +] -def test_basic(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3)) + +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +def test_basic(dtype, device): + kwargs = dict(dtype=dtype, device=device, sparse_size=(3, 3)) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) adj.validate() assert isinstance(adj, EdgeIndex) - assert str(adj) == ('EdgeIndex([[0, 1, 1, 2],\n' - ' [1, 0, 2, 1]])') + assert str(adj).startswith('EdgeIndex([[0, 1, 1, 2],') + assert adj.dtype == dtype + assert adj.device == device assert adj.sparse_size == (3, 3) assert adj.sort_order is None @@ -35,13 +45,22 @@ def test_basic(): assert not adj.is_undirected - assert not isinstance(adj.as_tensor(), EdgeIndex) + out = adj.as_tensor() + assert not isinstance(out, EdgeIndex) + assert out.dtype == dtype + assert out.device == device - assert not isinstance(adj + 1, EdgeIndex) + out = adj + 1 + assert not isinstance(out, EdgeIndex) + assert out.dtype == dtype + assert out.device == device -def test_undirected(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], is_undirected=True) +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +def test_undirected(dtype, device): + kwargs = dict(dtype=dtype, device=device, is_undirected=True) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) assert isinstance(adj, EdgeIndex) assert adj.is_undirected @@ -52,63 +71,119 @@ def test_undirected(): adj.validate() with pytest.raises(ValueError, match="'EdgeIndex' is not undirected"): - EdgeIndex([[0, 1, 1, 2], [0, 0, 1, 1]], is_undirected=True).validate() + EdgeIndex([[0, 1, 1, 2], [0, 0, 1, 1]], **kwargs).validate() -def test_fill_cache_(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row') +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_fill_cache_(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) adj.validate().fill_cache_() assert adj.sparse_size == (3, 3) - assert torch.equal(adj._rowptr, torch.tensor([0, 1, 3, 4])) - - assert adj.sort_order == 'row' - assert adj.is_sorted - assert adj.is_sorted_by_row - assert not adj.is_sorted_by_col - - adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col') + assert adj._rowptr.dtype == dtype + assert adj._rowptr.equal(tensor([0, 1, 3, 4], device=device)) + assert adj._csr_col is None + assert adj._csr2csc.dtype == torch.int64 + assert (adj._csr2csc.equal(tensor([1, 0, 3, 2], device=device)) + or adj._csr2csc.equal(tensor([1, 3, 0, 2], device=device))) + if is_undirected: + assert adj._colptr is None + else: + assert adj._colptr.dtype == dtype + assert adj._colptr.equal(tensor([0, 1, 3, 4], device=device)) + assert adj._csc_row.dtype == dtype + assert (adj._csc_row.equal(tensor([1, 0, 2, 1], device=device)) + or adj._csc_row.equal(tensor([1, 2, 0, 1], device=device))) + assert adj._csc2csr is None + + adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs) adj.validate().fill_cache_() assert adj.sparse_size == (3, 3) - assert torch.equal(adj._colptr, torch.tensor([0, 1, 3, 4])) - - assert adj.sort_order == 'col' - assert adj.is_sorted - assert not adj.is_sorted_by_row - assert adj.is_sorted_by_col + assert adj._colptr.dtype == dtype + assert adj._colptr.equal(tensor([0, 1, 3, 4], device=device)) + assert adj._csc_row is None + assert (adj._csc2csr.equal(tensor([1, 0, 3, 2], device=device)) + or adj._csc2csr.equal(tensor([1, 3, 0, 2], device=device))) + if is_undirected: + assert adj._rowptr is None + else: + assert adj._rowptr.dtype == dtype + assert adj._rowptr.equal(tensor([0, 1, 3, 4], device=device)) + assert adj._csr_col.dtype == dtype + assert (adj._csr_col.equal(tensor([1, 0, 2, 1], device=device)) + or adj._csr_col.equal(tensor([1, 2, 0, 1], device=device))) + assert adj._csr2csc is None -def test_clone(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row') +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_clone(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) out = adj.clone() assert isinstance(out, EdgeIndex) + assert out.dtype == dtype + assert out.device == device assert out.is_sorted_by_row + assert out.is_undirected == is_undirected out = torch.clone(adj) assert isinstance(out, EdgeIndex) + assert out.dtype == dtype + assert out.device == device assert out.is_sorted_by_row + assert out.is_undirected == is_undirected @withCUDA -def test_to(device): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]]) +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_to(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + adj.fill_cache_() + + adj = adj.to(device) + assert isinstance(adj, EdgeIndex) + assert adj.device == device + assert adj._rowptr.device == device + assert adj._csr2csc.device == device out = adj.to(torch.int) - assert isinstance(out, EdgeIndex) assert out.dtype == torch.int + if torch_geometric.typing.WITH_PT113: + assert isinstance(out, EdgeIndex) + assert out._rowptr.dtype == torch.int + assert out._csr2csc.dtype == torch.int + else: + assert not isinstance(out, EdgeIndex) out = adj.to(torch.float) assert not isinstance(out, EdgeIndex) assert out.dtype == torch.float - out = adj.to(device) + out = adj.long() assert isinstance(out, EdgeIndex) - assert out.device == device + assert out.dtype == torch.int64 + + out = adj.int() + assert out.dtype == torch.int + if torch_geometric.typing.WITH_PT113: + assert isinstance(out, EdgeIndex) + else: + assert not isinstance(out, EdgeIndex) @onlyCUDA -def test_cpu_cuda(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]]) +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_cpu_cuda(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) out = adj.cuda() assert isinstance(out, EdgeIndex) @@ -119,8 +194,12 @@ def test_cpu_cuda(): assert not out.is_cuda -def test_share_memory(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row') +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_share_memory(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) adj.fill_cache_() adj = adj.share_memory_() @@ -129,8 +208,11 @@ def test_share_memory(): assert adj._rowptr.is_shared() -def test_contiguous(): - data = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]]).t() +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +def test_contiguous(dtype, device): + kwargs = dict(dtype=dtype, device=device) + data = tensor([[0, 1], [1, 0], [1, 2], [2, 1]], **kwargs).t() with pytest.raises(ValueError, match="needs to be contiguous"): EdgeIndex(data) @@ -140,135 +222,163 @@ def test_contiguous(): assert adj.is_contiguous() -@pytest.mark.parametrize('is_undirected', [False, True]) -def test_sort_by(is_undirected): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', - is_undirected=is_undirected) +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_sort_by(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) out = adj.sort_by('row') assert isinstance(out, torch.return_types.sort) assert isinstance(out.values, EdgeIndex) assert not isinstance(out.indices, EdgeIndex) - assert torch.equal(out.values, adj) + assert out.values.equal(adj) assert out.indices == slice(None, None, None) - adj = EdgeIndex([[0, 1, 2, 1], [1, 0, 1, 2]], is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 2, 1], [1, 0, 1, 2]], **kwargs) out = adj.sort_by('row') assert isinstance(out, torch.return_types.sort) assert isinstance(out.values, EdgeIndex) assert not isinstance(out.indices, EdgeIndex) - assert torch.equal(out.values, torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])) - assert torch.equal(out.indices, torch.tensor([0, 1, 3, 2])) - - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', - is_undirected=is_undirected) - adj.fill_cache_() - - out = adj.sort_by('col') - assert torch.equal(out.values, torch.tensor([[1, 0, 2, 1], [0, 1, 1, 2]])) - assert torch.equal(out.indices, torch.tensor([1, 0, 3, 2])) - assert torch.equal(out.values._csr2csc, torch.tensor([1, 0, 3, 2])) - - out = out.values.sort_by('row') - assert torch.equal(out.values, torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])) - assert torch.equal(out.indices, torch.tensor([1, 0, 3, 2])) - assert torch.equal(out.values._csr2csc, torch.tensor([1, 0, 3, 2])) - assert torch.equal(out.values._csc2csr, torch.tensor([1, 0, 3, 2])) + assert out.values[0].equal(tensor([0, 1, 1, 2], device=device)) + assert (out.values[1].equal(tensor([1, 0, 2, 1], device=device)) + or out.values[1].equal(tensor([1, 2, 0, 1], device=device))) + assert (out.indices.equal(tensor([0, 1, 3, 2], device=device)) + or out.indices.equal(tensor([0, 3, 1, 2], device=device))) + + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + + out, perm = adj.sort_by('col') + assert adj._csr2csc is not None # Check caches. + assert adj._csc_row is not None + assert (out[0].equal(tensor([1, 0, 2, 1], device=device)) + or out[0].equal(tensor([1, 2, 0, 1], device=device))) + assert out[1].equal(tensor([0, 1, 1, 2], device=device)) + assert (perm.equal(tensor([1, 0, 3, 2], device=device)) + or perm.equal(tensor([1, 3, 0, 2], device=device))) + assert out._csr2csc is None + assert out._csc2csr is None + + out, perm = out.sort_by('row') + assert out[0].equal(tensor([0, 1, 1, 2], device=device)) + assert (out[1].equal(tensor([1, 0, 2, 1], device=device)) + or out[1].equal(tensor([1, 2, 0, 1], device=device))) + assert (perm.equal(tensor([1, 0, 3, 2], device=device)) + or perm.equal(tensor([2, 3, 0, 1], device=device))) + assert out._csr2csc is None + assert out._csc2csr is None - # Do another round to sort based on `_csr2csc` and `_csc2csr`: - out = out.values.sort_by('col') - assert torch.equal(out.values, torch.tensor([[1, 0, 2, 1], [0, 1, 1, 2]])) - assert torch.equal(out.indices, torch.tensor([1, 0, 3, 2])) - assert torch.equal(out.values._csr2csc, torch.tensor([1, 0, 3, 2])) - assert torch.equal(out.values._csc2csr, torch.tensor([1, 0, 3, 2])) - out = out.values.sort_by('row') - assert torch.equal(out.values, torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])) - assert torch.equal(out.indices, torch.tensor([1, 0, 3, 2])) - assert torch.equal(out.values._csr2csc, torch.tensor([1, 0, 3, 2])) - assert torch.equal(out.values._csc2csr, torch.tensor([1, 0, 3, 2])) - - -def test_cat(): - adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3)) - adj2 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], sparse_size=(4, 4)) +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_cat(dtype, device, is_undirected): + args = dict(dtype=dtype, device=device, is_undirected=is_undirected) + adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **args) + adj2 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], sparse_size=(4, 4), **args) out = torch.cat([adj1, adj2], dim=1) assert out.size() == (2, 8) assert isinstance(out, EdgeIndex) assert out.sparse_size == (4, 4) assert not out.is_sorted + assert out.is_undirected == is_undirected out = torch.cat([adj1, adj2], dim=0) assert out.size() == (4, 4) assert not isinstance(out, EdgeIndex) -def test_flip(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row') +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_flip(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) adj.fill_cache_() out = adj.flip(0) assert isinstance(out, EdgeIndex) - assert torch.equal(out, torch.tensor([[1, 0, 2, 1], [0, 1, 1, 2]])) + assert out.equal(tensor([[1, 0, 2, 1], [0, 1, 1, 2]], device=device)) assert out.sparse_size == (3, 3) assert out.is_sorted_by_col - assert torch.equal(out._colptr, torch.tensor([0, 1, 3, 4])) + assert out.is_undirected == is_undirected + assert out._colptr.equal(tensor([0, 1, 3, 4], device=device)) out = adj.flip([0, 1]) assert isinstance(out, EdgeIndex) - assert torch.equal(out, torch.tensor([[1, 2, 0, 1], [2, 1, 1, 0]])) + assert out.equal(tensor([[1, 2, 0, 1], [2, 1, 1, 0]], device=device)) assert out.sparse_size == (3, 3) assert not out.is_sorted + assert out.is_undirected == is_undirected assert out._colptr is None -def test_index_select(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row') - - out = adj.index_select(1, torch.tensor([1, 3])) - assert torch.equal(out, torch.tensor([[1, 2], [0, 1]])) +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_index_select(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + + out = adj.index_select(1, tensor([1, 3], device=device)) + assert out.equal(tensor([[1, 2], [0, 1]], device=device)) assert isinstance(out, EdgeIndex) + assert not out.is_sorted + assert not out.is_undirected - out = adj.index_select(0, torch.tensor([0])) - assert torch.equal(out, torch.tensor([[0, 1, 1, 2]])) + out = adj.index_select(0, tensor([0], device=device)) + assert out.equal(tensor([[0, 1, 1, 2]], device=device)) assert not isinstance(out, EdgeIndex) -def test_narrow(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row') +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_narrow(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) out = adj.narrow(dim=1, start=1, length=2) - assert torch.equal(out, torch.tensor([[1, 1], [0, 2]])) assert isinstance(out, EdgeIndex) + assert out.equal(tensor([[1, 1], [0, 2]], device=device)) assert out.is_sorted_by_row + assert not out.is_undirected out = adj.narrow(dim=0, start=0, length=1) - assert torch.equal(out, torch.tensor([[0, 1, 1, 2]])) assert not isinstance(out, EdgeIndex) + assert out.equal(tensor([[0, 1, 1, 2]], device=device)) -def test_getitem(): - adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row') +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_getitem(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) - out = adj[:, torch.tensor([False, True, False, True])] + out = adj[:, tensor([False, True, False, True], device=device)] assert isinstance(out, EdgeIndex) - assert torch.equal(out, torch.tensor([[1, 2], [0, 1]])) + assert out.equal(tensor([[1, 2], [0, 1]], device=device)) assert out.is_sorted_by_row + assert not out.is_undirected - out = adj[..., torch.tensor([1, 3])] + out = adj[..., tensor([1, 3], device=device)] assert isinstance(out, EdgeIndex) - assert torch.equal(out, torch.tensor([[1, 2], [0, 1]])) + assert out.equal(tensor([[1, 2], [0, 1]], device=device)) assert not out.is_sorted + assert not out.is_undirected out = adj[..., 1::2] assert isinstance(out, EdgeIndex) - assert torch.equal(out, torch.tensor([[1, 2], [0, 1]])) + assert out.equal(tensor([[1, 2], [0, 1]], device=device)) assert out.is_sorted_by_row + assert not out.is_undirected out = adj[:, 0] assert not isinstance(out, EdgeIndex) - out = adj[torch.tensor([0])] + out = adj[tensor([0], device=device)] assert not isinstance(out, EdgeIndex) @@ -306,7 +416,7 @@ def test_to_sparse_coo(): assert isinstance(out, Tensor) assert out.layout == torch.sparse_coo assert out.size() == (3, 3) - assert torch.equal(adj, out._indices()) + assert adj.equal(out._indices()) # Test clunky dispatch logic for `to_sparse_coo()`: adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]]) @@ -314,7 +424,7 @@ def test_to_sparse_coo(): assert isinstance(out, Tensor) assert out.layout == torch.sparse_coo assert out.size() == (3, 3) - assert torch.equal(adj, out._indices()) + assert adj.equal(out._indices()) def test_to_sparse_csr(): @@ -329,8 +439,8 @@ def test_to_sparse_csr(): assert isinstance(out, Tensor) assert out.layout == torch.sparse_csr assert out.size() == (3, 3) - assert torch.equal(adj._rowptr, out.crow_indices()) - assert torch.equal(adj[1], out.col_indices()) + assert adj._rowptr.equal(out.crow_indices()) + assert adj[1].equal(out.col_indices()) def test_to_sparse_csc(): @@ -345,8 +455,8 @@ def test_to_sparse_csc(): assert isinstance(out, Tensor) assert out.layout == torch.sparse_csc assert out.size() == (3, 3) - assert torch.equal(adj._colptr, out.ccol_indices()) - assert torch.equal(adj[0], out.row_indices()) + assert adj._colptr.equal(out.ccol_indices()) + assert adj[0].equal(out.row_indices()) def test_matmul_forward(): @@ -427,8 +537,8 @@ def test_to_sparse_tensor(): assert isinstance(adj, SparseTensor) assert adj.sizes() == [3, 3] row, col, _ = adj.coo() - assert torch.equal(row, torch.tensor([0, 1, 1, 2])) - assert torch.equal(col, torch.tensor([1, 0, 2, 1])) + assert row.equal(tensor([0, 1, 1, 2])) + assert col.equal(tensor([1, 0, 2, 1])) def test_save_and_load(tmp_path): @@ -436,16 +546,16 @@ def test_save_and_load(tmp_path): adj.fill_cache_() assert adj.sort_order == 'row' - assert torch.equal(adj._rowptr, torch.tensor([0, 1, 3, 4])) + assert adj._rowptr.equal(tensor([0, 1, 3, 4])) path = osp.join(tmp_path, 'edge_index.pt') torch.save(adj, path) out = torch.load(path) assert isinstance(out, EdgeIndex) - assert torch.equal(out, torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])) + assert out.equal(tensor([[0, 1, 1, 2], [1, 0, 2, 1]])) assert out.sort_order == 'row' - assert torch.equal(out._rowptr, torch.tensor([0, 1, 3, 4])) + assert out._rowptr.equal(tensor([0, 1, 3, 4])) @pytest.mark.parametrize('num_workers', [0, 2]) diff --git a/test/utils/test_index_sort.py b/test/utils/test_index_sort.py new file mode 100644 index 000000000000..c52e02f9307f --- /dev/null +++ b/test/utils/test_index_sort.py @@ -0,0 +1,16 @@ +import torch + +from torch_geometric.testing import withCUDA +from torch_geometric.utils import index_sort + + +@withCUDA +def test_index_sort_stable(device): + for _ in range(100): + inputs = torch.randint(0, 4, size=(10, ), device=device) + + out = index_sort(inputs, stable=True) + expected = torch.sort(inputs, stable=True) + + assert torch.equal(out[0], expected[0]) + assert torch.equal(out[1], expected[1]) diff --git a/torch_geometric/data/edge_index.py b/torch_geometric/data/edge_index.py index a26657fb1383..53878f009e2e 100644 --- a/torch_geometric/data/edge_index.py +++ b/torch_geometric/data/edge_index.py @@ -21,13 +21,15 @@ HANDLED_FUNCTIONS: Dict[Callable, Callable] = {} -SUPPORTED_DTYPES: Set[torch.dtype] = { - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.int64, -} +if torch_geometric.typing.WITH_PT113: + SUPPORTED_DTYPES: Set[torch.dtype] = { + torch.int32, + torch.int64, + } +else: + SUPPORTED_DTYPES: Set[torch.dtype] = { + torch.int64, + } ReduceType = Literal['sum'] @@ -318,12 +320,9 @@ def get_csr(self) -> Tuple[Tensor, Tensor, Union[Tensor, slice]]: rowptr = self._colptr else: # Otherwise, fill cache: - self._rowptr = torch._convert_indices_from_coo_to_csr( - self[0], - self.get_num_rows(), - out_int32=self.dtype != torch.int64, - ).to(self.dtype) - rowptr = self._rowptr + self._rowptr = rowptr = torch._convert_indices_from_coo_to_csr( + self[0], self.get_num_rows(), out_int32=self.dtype + != torch.int64) return rowptr, self[1], slice(None, None, None) @@ -349,12 +348,9 @@ def get_csr(self) -> Tuple[Tensor, Tensor, Union[Tensor, slice]]: if row is None: row = self[0][self._csc2csr] - self._rowptr = torch._convert_indices_from_coo_to_csr( - row, - self.get_num_rows(), - out_int32=self.dtype != torch.int64, - ).to(self.dtype) - rowptr = self._rowptr + self._rowptr = rowptr = torch._convert_indices_from_coo_to_csr( + row, self.get_num_rows(), out_int32=self.dtype + != torch.int64) return rowptr, self._csr_col, self._csc2csr @@ -376,12 +372,9 @@ def get_csc(self) -> Tuple[Tensor, Tensor, Union[Tensor, slice]]: colptr = self._rowptr else: # Otherwise, fill cache: - self._colptr = torch._convert_indices_from_coo_to_csr( - self[1], - self.get_num_cols(), - out_int32=self.dtype != torch.int64, - ).to(self.dtype) - colptr = self._colptr + self._colptr = colptr = torch._convert_indices_from_coo_to_csr( + self[1], self.get_num_cols(), out_int32=self.dtype + != torch.int64) return colptr, self[0], slice(None, None, None) @@ -407,12 +400,9 @@ def get_csc(self) -> Tuple[Tensor, Tensor, Union[Tensor, slice]]: if col is None: col = self[1][self._csr2csc] - self._colptr = torch._convert_indices_from_coo_to_csr( - col, - self.get_num_cols(), - out_int32=self.dtype != torch.int64, - ).to(self.dtype) - colptr = self._colptr + self._colptr = colptr = torch._convert_indices_from_coo_to_csr( + col, self.get_num_cols(), out_int32=self.dtype + != torch.int64) return colptr, self._csc_row, self._csr2csc @@ -465,12 +455,16 @@ def as_tensor(self) -> Tensor: def sort_by( self, sort_order: Union[str, SortOrder], + stable: bool = False, ) -> torch.return_types.sort: r"""Sorts the elements by row or column indices. Args: sort_order (str): The sort order, either :obj:`"row"` or :obj:`"col"`. + stable (bool, optional): Makes the sorting routine stable, which + guarantees that the order of equivalent elements is preserved. + (default: :obj:`False`) """ sort_order = SortOrder(sort_order) @@ -485,9 +479,9 @@ def sort_by( edge_index = torch.stack([self._csc_row, self[0]], dim=0) elif perm is None: - col, perm = index_sort(self[1], self.get_num_cols()) - edge_index = torch.stack([self[0][perm], col], dim=0) - self._csc_row = edge_index[0] + col, perm = index_sort(self[1], self.get_num_cols(), stable) + self._csc_row = self[0][perm] + edge_index = torch.stack([self._csc_row, col], dim=0) self._csr2csc = perm else: edge_index = self.as_tensor()[:, perm] @@ -501,9 +495,9 @@ def sort_by( edge_index = torch.stack([self[1], self._csr_col], dim=0) elif perm is None: - row, perm = index_sort(self[0], self.get_num_rows()) - edge_index = torch.stack([row, self[1][perm]], dim=0) - self._csr_col = edge_index[1] + row, perm = index_sort(self[0], self.get_num_rows(), stable) + self._csr_col = self[1][perm] + edge_index = torch.stack([row, self._csr_col], dim=0) self._csc2csr = perm else: edge_index = self.as_tensor()[:, perm] @@ -511,28 +505,25 @@ def sort_by( # Otherwise, perform sorting: elif sort_order == SortOrder.ROW: - row, perm = index_sort(self[0], self.get_num_rows()) + row, perm = index_sort(self[0], self.get_num_rows(), stable) edge_index = torch.stack([row, self[1][perm]], dim=0) else: - col, perm = index_sort(self[1], self.get_num_cols()) + col, perm = index_sort(self[1], self.get_num_cols(), stable) edge_index = torch.stack([self[0][perm], col], dim=0) out = self.__class__(edge_index) - # We can fully inherit metadata and cache: + # We can mostly inherit metadata and cache: out._sparse_size = self.sparse_size out._sort_order = sort_order out._is_undirected = self.is_undirected out._rowptr = self._rowptr - out._csr_col = self._csr_col - out._colptr = self._colptr - out._csc_row = self._csc_row - out._csr2csc = self._csr2csc - out._csc2csr = self._csc2csr + # NOTE We cannot copy CSR<>CSC permutations since we don't require that + # local neighborhoods are sorted, and thus they may run out of sync. out._value = self._value @@ -711,7 +702,7 @@ def clone(tensor: EdgeIndex) -> EdgeIndex: @implements(Tensor.to) -def to(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex: +def to(tensor: EdgeIndex, *args, **kwargs) -> Union[EdgeIndex, Tensor]: out = apply_(tensor, Tensor.to, *args, **kwargs) if out.dtype not in SUPPORTED_DTYPES: @@ -720,6 +711,16 @@ def to(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex: return out +@implements(Tensor.int) +def _int(tensor: EdgeIndex) -> EdgeIndex: + return to(tensor, torch.int32) + + +@implements(Tensor.long) +def long(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex: + return to(tensor, torch.int64) + + @implements(Tensor.cpu) def cpu(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex: return apply_(tensor, Tensor.cpu, *args, **kwargs) @@ -884,7 +885,8 @@ def is_last_dim_select(i: Any) -> bool: is_valid = is_last_dim_select(index) # 1. `edge_index[:, mask]` or `edge_index[..., mask]`. - if is_valid and isinstance(index[1], (torch.BoolTensor, torch.ByteTensor)): + if (is_valid and isinstance(index[1], Tensor) + and index[1].dtype in (torch.bool, torch.uint8)): out = out.as_subclass(EdgeIndex) out._sparse_size = input.sparse_size out._sort_order = input._sort_order @@ -1072,7 +1074,6 @@ def matmul( rowptr, col = out.crow_indices(), out.col_indices() edge_index = torch._convert_indices_from_csr_to_coo( rowptr, col, out_int32=rowptr.dtype != torch.int64) - edge_index = edge_index.to(rowptr.device) edge_index = edge_index.as_subclass(EdgeIndex) edge_index._sort_order = SortOrder.ROW diff --git a/torch_geometric/utils/sort.py b/torch_geometric/utils/sort.py index 39ae96366fc8..1b63851c5d18 100644 --- a/torch_geometric/utils/sort.py +++ b/torch_geometric/utils/sort.py @@ -1,16 +1,16 @@ from typing import Optional, Tuple -import torch +from torch import Tensor import torch_geometric.typing from torch_geometric.typing import pyg_lib def index_sort( - inputs: torch.Tensor, + inputs: Tensor, max_value: Optional[int] = None, stable: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[Tensor, Tensor]: r"""Sorts the elements of the :obj:`inputs` tensor in ascending order. It is expected that :obj:`inputs` is one-dimensional and that it only contains positive integer values. If :obj:`max_value` is given, it can