From 877f7426eea4813135a1b28c8676aba312b93db8 Mon Sep 17 00:00:00 2001
From: "Leonard.Caquot"
Date: Thu, 2 Jan 2025 12:37:01 +0100
Subject: [PATCH 01/20] add new batch filtering method for efficiency purposes

---
 test/data/test_batch.py       |  30 ++++++++++
 torch_geometric/data/batch.py | 110 ++++++++++++++++++++++++++++++++++
 2 files changed, 140 insertions(+)

diff --git a/test/data/test_batch.py b/test/data/test_batch.py
index b893c1e4dd11..0199b7b105f8 100644
--- a/test/data/test_batch.py
+++ b/test/data/test_batch.py
@@ -608,3 +608,33 @@ def __inc__(self, key, value, *args, **kwargs) -> int:
         batch.x.to_padded_tensor(0.0),
         expected.to_padded_tensor(0.0),
     )
+
+
+def test_batch_filtering():
+    # Create initial HeteroData objects and batch
+    data_list = []
+    for i in range(1, 5):
+        data = HeteroData()
+        data.info_1 = torch.ones(3)  # graph attribute
+        data.info_2 = torch.ones(3)  # graph attribute
+        data['paper'].x = torch.randn(2*i, 16)
+        data['paper'].exist = torch.randn(4*i, 1)  # node attributes of different length than x
+        data['author'].x = torch.randn(3*i, 2)
+        data['author', 'writes', 'paper'].edge_index = torch.stack((
+            torch.arange(3*i).repeat(2),
+            torch.arange(2*i).repeat(3)
+        ))
+        data['author', 'writes', 'paper'].edge_attr = torch.randn(6*i, 3)
+        data_list.append(data)
+    batch = Batch.from_data_list(data_list)
+    batch_filtered = batch.filter([1, 3])
+    assert isinstance(batch, Batch)
+    assert isinstance(batch_filtered, Batch)
+    assert len(batch) == 4
+    assert len(batch_filtered) == 2
+    assert batch_filtered[0]['paper'].x.shape == batch[1]['paper'].x.shape
+    assert batch_filtered[0]['paper'].exist.shape == batch[1]['paper'].exist.shape
+    assert batch_filtered[1]['author'].x.shape == batch[3]['author'].x.shape
+
+    # Verify the filtered batch supports round-trip conversion to and from a data list
+    assert Batch.from_data_list(batch_filtered.to_data_list())

diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index 411639a228d2..260a1eb865ba 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -170,6 +170,116 @@ def index_select(self, idx: IndexType) -> List[BaseData]:
 
         return [self.get_example(i) for i in index]
 
+    def filter(self, idx: torch.Tensor) -> Self:
+        """
+        Efficiently filters the object using a boolean mask or index, directly modifying
+        batch attributes instead of rebuilding the batch.
+
+        This method is ~10x faster than calling Batch.from_data_list(batch[mask]).
+
+        The provided indices (:obj:`idx`) can be a slicing object (e.g., :obj:`[2:5]`),
+        a list, tuple, or a :obj:`torch.Tensor`/:obj:`np.ndarray` of type long or bool,
+        or any sequence of integers (excluding strings).
+
+        This implementation currently focuses on HeteroData, but handling HomogeneousData
+        needs to be addressed. Additionally, the default filtering from __get_item__ still
+        uses the index_select method, which could be replaced with this approach for
+        improved efficiency, avoiding conversion to list objects.
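+
+        A minimal usage sketch (illustrative only; assumes :obj:`batch` was
+        built via :meth:`from_data_list`)::
+
+            batch_filtered = batch.filter([1, 3])  # keep graphs 1 and 3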
+ """ + + mask: torch.Tensor + if isinstance(idx, slice): + mask = torch.zeros(len(self), dtype=torch.bool) + mask[idx] = True + + elif isinstance(idx, Tensor) and idx.dtype == torch.long: + mask = torch.zeros(len(self), dtype=torch.bool) + mask[idx.flatten()] = True + + elif isinstance(idx, Tensor) and idx.dtype == torch.bool: + mask = idx.flatten() + + elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: + mask = torch.zeros(len(self), dtype=torch.bool) + mask[idx.flatten()] = True + + elif isinstance(idx, np.ndarray) and idx.dtype == bool: + mask = torch.tensor(idx.flatten()) + + elif isinstance(idx, Sequence) and not isinstance(idx, str): + mask = torch.zeros(len(self), dtype=torch.bool) + mask[idx] = True + + else: + raise IndexError( + f"Only slices (':'), list, tuples, torch.tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')") + + # Create new empty batch that will be filled later + batch = Batch(_base_cls=self[0].__class__).stores_as(self) + batch._slice_dict = {} + batch._inc_dict = {} + + # Update the number of graphs based on the mask + batch._num_graphs = mask.sum().item() + + # Mask application works the same way for all attribute levels (graph, nodes, edges) + for old_store, new_store in zip(self.stores, batch.stores): + # We get slices dictionary from key. If key is None then we are dealing with graph level attributes. + key = old_store._key + slices = self._slice_dict[key] if key else {attr: self._slice_dict[attr] for attr in old_store} + + if key: + batch._slice_dict[key] = {} + batch._inc_dict[key] = {} + + # All slice and store are updated one by one in following loop + for attr, slc in slices.items(): + slice_diff = slc.diff() + + # Reshape mask to align it with attribute shape + attr_mask = mask[torch.repeat_interleave(slice_diff)] + + # Apply mask to attribute + new_store[attr] = old_store[attr][:, attr_mask] if attr == 'edge_index' else old_store[attr][attr_mask] + + # Compute masked version of slice tensor + sizes_masked = slice_diff[mask] + slice_masked = torch.cat((torch.zeros(1, dtype=torch.int), sizes_masked.cumsum(0))) + + # By default, new inc tensor is zero tensor, unless it is overwritten later + new_inc = torch.zeros(batch._num_graphs, dtype=torch.int) + + # x attribute provides num_node info to update 'ptr' and 'batch' tensors + if attr == 'x': + batch[key].ptr = slice_masked + batch[key].batch = torch.repeat_interleave(sizes_masked) + + if attr == 'edge_index': + # Then we reindex edge_index to remove gaps left by removed nodes + # We assume x node attributes to be changed before edge attributes + # so that mapping_idx_dict is already available. 
+ old_inc = self._inc_dict[key][attr].squeeze(-1).T + old_inc_diff = old_inc.diff(prepend=torch.zeros((2, 1), dtype=torch.int))[:, mask] + old_inc_diff[:, 0] = 0 + new_inc = old_inc_diff.cumsum(1) + shift_inc = new_inc - old_inc[:, mask] + + edge_index_batch = torch.repeat_interleave(sizes_masked) + batch[key].edge_index += shift_inc[:, edge_index_batch] + + new_inc = new_inc.T.unsqueeze(-1) + + # Update _slice_dict and _inc_dict based on what has been computed before + if key: # Node or edge level attribute + batch._slice_dict[key][attr] = slice_masked + batch._inc_dict[key][attr] = new_inc + else: # Graph level attribute + batch._slice_dict[attr] = slice_masked + batch._inc_dict[attr] = new_inc + return batch + def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any: if (isinstance(idx, (int, np.integer)) or (isinstance(idx, Tensor) and idx.dim() == 0) From 4da4ed1408c924ec15594d685a72026b20a2337f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:47:40 +0000 Subject: [PATCH 02/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/data/test_batch.py | 18 +++++++++--------- torch_geometric/data/batch.py | 19 ++++++++++++------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/test/data/test_batch.py b/test/data/test_batch.py index 0199b7b105f8..ae3a44c45976 100644 --- a/test/data/test_batch.py +++ b/test/data/test_batch.py @@ -617,14 +617,13 @@ def test_batch_filtering(): data = HeteroData() data.info_1 = torch.ones(3) # graph attribute data.info_2 = torch.ones(3) # graph attribute - data['paper'].x = torch.randn(2*i, 16) - data['paper'].exist = torch.randn(4*i, 1) # node attributes of different length than x - data['author'].x = torch.randn(3*i, 2) - data['author', 'writes', 'paper'].edge_index = torch.stack(( - torch.arange(3*i).repeat(2), - torch.arange(2*i).repeat(3) - )) - data['author', 'writes', 'paper'].edge_attr = torch.randn(6*i, 3) + data['paper'].x = torch.randn(2 * i, 16) + data['paper'].exist = torch.randn( + 4 * i, 1) # node attributes of different length than x + data['author'].x = torch.randn(3 * i, 2) + data['author', 'writes', 'paper'].edge_index = torch.stack( + (torch.arange(3 * i).repeat(2), torch.arange(2 * i).repeat(3))) + data['author', 'writes', 'paper'].edge_attr = torch.randn(6 * i, 3) data_list.append(data) batch = Batch.from_data_list(data_list) batch_filtered = batch.filter([1, 3]) @@ -633,7 +632,8 @@ def test_batch_filtering(): assert len(batch) == 4 assert len(batch_filtered) == 2 assert batch_filtered[0]['paper'].x.shape == batch[1]['paper'].x.shape - assert batch_filtered[0]['paper'].exist.shape == batch[1]['paper'].exist.shape + assert batch_filtered[0]['paper'].exist.shape == batch[1][ + 'paper'].exist.shape assert batch_filtered[1]['author'].x.shape == batch[3]['author'].x.shape # Verify the filtered batch supports round-trip conversion to and from a data list diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index 260a1eb865ba..234305f7cd94 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -171,8 +171,7 @@ def index_select(self, idx: IndexType) -> List[BaseData]: return [self.get_example(i) for i in index] def filter(self, idx: torch.Tensor) -> Self: - """ - Efficiently filters the object using a boolean mask or index, directly modifying + """Efficiently filters the object using a boolean mask or index, directly modifying 
batch attributes instead of rebuilding the batch. This method is ~10x faster than calling Batch.from_data_list(batch[mask]). @@ -186,7 +185,6 @@ def filter(self, idx: torch.Tensor) -> Self: uses the index_select method, which could be replaced with this approach for improved efficiency, avoiding conversion to list objects. """ - mask: torch.Tensor if isinstance(idx, slice): mask = torch.zeros(len(self), dtype=torch.bool) @@ -228,7 +226,10 @@ def filter(self, idx: torch.Tensor) -> Self: for old_store, new_store in zip(self.stores, batch.stores): # We get slices dictionary from key. If key is None then we are dealing with graph level attributes. key = old_store._key - slices = self._slice_dict[key] if key else {attr: self._slice_dict[attr] for attr in old_store} + slices = self._slice_dict[key] if key else { + attr: self._slice_dict[attr] + for attr in old_store + } if key: batch._slice_dict[key] = {} @@ -242,11 +243,14 @@ def filter(self, idx: torch.Tensor) -> Self: attr_mask = mask[torch.repeat_interleave(slice_diff)] # Apply mask to attribute - new_store[attr] = old_store[attr][:, attr_mask] if attr == 'edge_index' else old_store[attr][attr_mask] + new_store[attr] = old_store[ + attr][:, attr_mask] if attr == 'edge_index' else old_store[ + attr][attr_mask] # Compute masked version of slice tensor sizes_masked = slice_diff[mask] - slice_masked = torch.cat((torch.zeros(1, dtype=torch.int), sizes_masked.cumsum(0))) + slice_masked = torch.cat( + (torch.zeros(1, dtype=torch.int), sizes_masked.cumsum(0))) # By default, new inc tensor is zero tensor, unless it is overwritten later new_inc = torch.zeros(batch._num_graphs, dtype=torch.int) @@ -261,7 +265,8 @@ def filter(self, idx: torch.Tensor) -> Self: # We assume x node attributes to be changed before edge attributes # so that mapping_idx_dict is already available. old_inc = self._inc_dict[key][attr].squeeze(-1).T - old_inc_diff = old_inc.diff(prepend=torch.zeros((2, 1), dtype=torch.int))[:, mask] + old_inc_diff = old_inc.diff( + prepend=torch.zeros((2, 1), dtype=torch.int))[:, mask] old_inc_diff[:, 0] = 0 new_inc = old_inc_diff.cumsum(1) shift_inc = new_inc - old_inc[:, mask] From 9f6ae80e31ee69e91b3a8037c0d5c4652b0f0268 Mon Sep 17 00:00:00 2001 From: "Leonard.Caquot" Date: Thu, 2 Jan 2025 17:57:55 +0100 Subject: [PATCH 03/20] update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e6789b9a86d..d405114a42db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722)) - Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737)) - Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467)) +- Added `filter` method for efficient Batch filtering ([#9911](https://github.com/pyg-team/pytorch_geometric/pull/9911)) ### Changed From 7b9ec12c1109e66809521c6f37aea7188e307a51 Mon Sep 17 00:00:00 2001 From: "Leonard.Caquot" Date: Fri, 3 Jan 2025 16:54:57 +0100 Subject: [PATCH 04/20] support non tensor attributes --- torch_geometric/data/batch.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index 234305f7cd94..ea9e5a11af28 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -243,9 +243,12 @@ def filter(self, idx: torch.Tensor) -> Self: attr_mask = mask[torch.repeat_interleave(slice_diff)] # Apply mask to attribute - new_store[attr] = old_store[ - attr][:, attr_mask] if attr == 'edge_index' else old_store[ - attr][attr_mask] + if attr == 'edge_index': + new_store[attr] = old_store[attr][:, attr_mask] + elif isinstance(old_store[attr], list): + new_store[attr] = [item for item, m in zip(old_store[attr], attr_mask) if m] + else: + new_store[attr] = old_store[attr][attr_mask] # Compute masked version of slice tensor sizes_masked = slice_diff[mask] From b1f3e84adefb2f4bc83c221467e654a721828b7a Mon Sep 17 00:00:00 2001 From: "Leonard.Caquot" Date: Fri, 3 Jan 2025 16:55:33 +0100 Subject: [PATCH 05/20] fix issue when computing new edge index --- torch_geometric/data/batch.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index ea9e5a11af28..0459fd0a1aa8 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -268,14 +268,13 @@ def filter(self, idx: torch.Tensor) -> Self: # We assume x node attributes to be changed before edge attributes # so that mapping_idx_dict is already available. 
old_inc = self._inc_dict[key][attr].squeeze(-1).T - old_inc_diff = old_inc.diff( - prepend=torch.zeros((2, 1), dtype=torch.int))[:, mask] - old_inc_diff[:, 0] = 0 - new_inc = old_inc_diff.cumsum(1) - shift_inc = new_inc - old_inc[:, mask] + new_inc = old_inc.diff()[:, mask[:-1]].cumsum(1) + new_inc = torch.cat((torch.zeros((2, 1), dtype=torch.int), new_inc), dim=1) + + shift = new_inc - old_inc[:, mask] edge_index_batch = torch.repeat_interleave(sizes_masked) - batch[key].edge_index += shift_inc[:, edge_index_batch] + batch[key].edge_index += shift[:, edge_index_batch] new_inc = new_inc.T.unsqueeze(-1) From 606d34a8079521eaf440ea1dfaf14819b29d9ecb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Jan 2025 15:57:13 +0000 Subject: [PATCH 06/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/data/batch.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index 0459fd0a1aa8..65989bd1c52d 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -246,7 +246,10 @@ def filter(self, idx: torch.Tensor) -> Self: if attr == 'edge_index': new_store[attr] = old_store[attr][:, attr_mask] elif isinstance(old_store[attr], list): - new_store[attr] = [item for item, m in zip(old_store[attr], attr_mask) if m] + new_store[attr] = [ + item for item, m in zip(old_store[attr], attr_mask) + if m + ] else: new_store[attr] = old_store[attr][attr_mask] @@ -269,7 +272,8 @@ def filter(self, idx: torch.Tensor) -> Self: # so that mapping_idx_dict is already available. old_inc = self._inc_dict[key][attr].squeeze(-1).T new_inc = old_inc.diff()[:, mask[:-1]].cumsum(1) - new_inc = torch.cat((torch.zeros((2, 1), dtype=torch.int), new_inc), dim=1) + new_inc = torch.cat((torch.zeros( + (2, 1), dtype=torch.int), new_inc), dim=1) shift = new_inc - old_inc[:, mask] From 929ddcfcc40bfc4c1556e5ace510e8bd99a3f1f9 Mon Sep 17 00:00:00 2001 From: "Leonard.Caquot" Date: Fri, 3 Jan 2025 18:13:48 +0100 Subject: [PATCH 07/20] handle case where filtered batch is empty --- torch_geometric/data/batch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index 0459fd0a1aa8..3d55c5f63125 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -222,6 +222,10 @@ def filter(self, idx: torch.Tensor) -> Self: # Update the number of graphs based on the mask batch._num_graphs = mask.sum().item() + # Return empty batch when mask filters all elements + if batch._num_graphs == 0: + return batch + # Mask application works the same way for all attribute levels (graph, nodes, edges) for old_store, new_store in zip(self.stores, batch.stores): # We get slices dictionary from key. If key is None then we are dealing with graph level attributes. 
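A standalone sketch of the guard added in PATCH 07 above (illustrative only; it assumes the filter() method introduced in PATCH 01 is available on Batch):

    import torch
    from torch_geometric.data import Batch, Data

    data_list = [Data(x=torch.randn(4, 8)) for _ in range(3)]
    batch = Batch.from_data_list(data_list)

    # An all-False mask now returns an empty Batch early, instead of
    # running the per-attribute masking loop on zero graphs.
    empty = batch.filter(torch.zeros(len(batch), dtype=torch.bool))
    assert empty.num_graphs == 0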
From 818e4cda357f220a4cd5db901c19d4f852518b46 Mon Sep 17 00:00:00 2001
From: "Leonard.Caquot"
Date: Tue, 7 Jan 2025 11:43:17 +0100
Subject: [PATCH 08/20] refactor to make code fit in one line

---
 torch_geometric/data/batch.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index 26933aa8d6dc..1473ed36abba 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -250,10 +250,7 @@ def filter(self, idx: torch.Tensor) -> Self:
                 if attr == 'edge_index':
                     new_store[attr] = old_store[attr][:, attr_mask]
                 elif isinstance(old_store[attr], list):
-                    new_store[attr] = [
-                        item for item, m in zip(old_store[attr], attr_mask)
-                        if m
-                    ]
+                    new_store[attr] = [x for x, m in zip(old_store[attr], attr_mask) if m]
                 else:
                     new_store[attr] = old_store[attr][attr_mask]
 
From d5bfc915d336976b46d5aa46f9a8fd81f79ea6f1 Mon Sep 17 00:00:00 2001
From: "Leonard.Caquot"
Date: Tue, 7 Jan 2025 11:44:00 +0100
Subject: [PATCH 09/20] fix issue in new edge index computation and add
 explanations

---
 torch_geometric/data/batch.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index 1473ed36abba..2ed6d7a82462 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -268,19 +268,25 @@ def filter(self, idx: torch.Tensor) -> Self:
                     batch[key].batch = torch.repeat_interleave(sizes_masked)
 
                 if attr == 'edge_index':
-                    # Then we reindex edge_index to remove gaps left by removed nodes
-                    # We assume x node attributes to be changed before edge attributes
-                    # so that mapping_idx_dict is already available.
+                    # Reindex edge_index to remove gaps from removed nodes. This involves:
+                    # 1. Computing the difference (diff) to get edge index spans
+                    # 2. Applying the mask to filter the spans
+                    # 3. Using cumsum to reconstruct the _inc tensor
+                    # 4.
Adjusting the result to start from zero and ignore last _inc values old_inc = self._inc_dict[key][attr].squeeze(-1).T - new_inc = old_inc.diff()[:, mask[:-1]].cumsum(1) - new_inc = torch.cat((torch.zeros( - (2, 1), dtype=torch.int), new_inc), dim=1) - + old_edge_index_spans = old_inc.diff(append=old_inc[:, -1:]) + new_edge_index_spans = old_edge_index_spans[:, mask] + new_inc_tmp = new_edge_index_spans.cumsum(1) + new_inc_tmp[:, -1] = 0 + new_inc = new_inc.roll(1, dims=1) + + # Map each edge_index element to its batch position + edge_index_batch_map = torch.repeat_interleave(sizes_masked) + # Remove old_inc and add new_inc to each edge_index element using shift tensor shift = new_inc - old_inc[:, mask] + batch[key].edge_index += shift[:, edge_index_batch_map] - edge_index_batch = torch.repeat_interleave(sizes_masked) - batch[key].edge_index += shift[:, edge_index_batch] - + # Reshape new_inc before saving in _inc_dict new_inc = new_inc.T.unsqueeze(-1) # Update _slice_dict and _inc_dict based on what has been computed before From 080e23c18ab1c670790d83b89fc77cbaf08da5e5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:47:14 +0000 Subject: [PATCH 10/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/data/batch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index 2ed6d7a82462..c35e15aa7541 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -250,7 +250,9 @@ def filter(self, idx: torch.Tensor) -> Self: if attr == 'edge_index': new_store[attr] = old_store[attr][:, attr_mask] elif isinstance(old_store[attr], list): - new_store[attr] = [x for x, m in zip(old_store[attr], attr_mask) if m] + new_store[attr] = [ + x for x, m in zip(old_store[attr], attr_mask) if m + ] else: new_store[attr] = old_store[attr][attr_mask] @@ -281,7 +283,8 @@ def filter(self, idx: torch.Tensor) -> Self: new_inc = new_inc.roll(1, dims=1) # Map each edge_index element to its batch position - edge_index_batch_map = torch.repeat_interleave(sizes_masked) + edge_index_batch_map = torch.repeat_interleave( + sizes_masked) # Remove old_inc and add new_inc to each edge_index element using shift tensor shift = new_inc - old_inc[:, mask] batch[key].edge_index += shift[:, edge_index_batch_map] From 66b48d267c921c650b071b81073efa49d1964807 Mon Sep 17 00:00:00 2001 From: "Leonard.Caquot" Date: Tue, 7 Jan 2025 12:02:47 +0100 Subject: [PATCH 11/20] fix variable error --- torch_geometric/data/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index 2ed6d7a82462..a775f780412d 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -278,7 +278,7 @@ def filter(self, idx: torch.Tensor) -> Self: new_edge_index_spans = old_edge_index_spans[:, mask] new_inc_tmp = new_edge_index_spans.cumsum(1) new_inc_tmp[:, -1] = 0 - new_inc = new_inc.roll(1, dims=1) + new_inc = new_inc_tmp.roll(1, dims=1) # Map each edge_index element to its batch position edge_index_batch_map = torch.repeat_interleave(sizes_masked) From 06088d57e35ab295e5a2f772a1f71cef97a95612 Mon Sep 17 00:00:00 2001 From: "Leonard.Caquot" Date: Tue, 7 Jan 2025 15:15:52 +0100 Subject: [PATCH 12/20] skip useless repeat_interleave computation in some specific but common scenarios --- 
 torch_geometric/data/batch.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index 5be12c423e7f..3fe47870f680 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -243,8 +243,9 @@ def filter(self, idx: torch.Tensor) -> Self:
             for attr, slc in slices.items():
                 slice_diff = slc.diff()
 
-                # Reshape mask to align it with attribute shape
-                attr_mask = mask[torch.repeat_interleave(slice_diff)]
+                # Reshape mask to align it with attribute shape.
+                # Since slice_diff often contains only ones, skip useless computation in such cases
+                attr_mask = mask[torch.repeat_interleave(slice_diff)] if torch.any(slice_diff != 1) else mask
 
                 # Apply mask to attribute
                 if attr == 'edge_index':
 
From 9a8b22d53c5345ec5f16603d18b943405e1a6ebf Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 7 Jan 2025 14:17:19 +0000
Subject: [PATCH 13/20] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 torch_geometric/data/batch.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index 3fe47870f680..f909622df2d0 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -245,7 +245,8 @@ def filter(self, idx: torch.Tensor) -> Self:
 
                 # Reshape mask to align it with attribute shape.
                 # Since slice_diff often contains only ones, skip useless computation in such cases
-                attr_mask = mask[torch.repeat_interleave(slice_diff)] if torch.any(slice_diff != 1) else mask
+                attr_mask = mask[torch.repeat_interleave(
+                    slice_diff)] if torch.any(slice_diff != 1) else mask
 
                 # Apply mask to attribute
                 if attr == 'edge_index':
 
From c91671f930293ee46c57befc5151ecb6fa72dc94 Mon Sep 17 00:00:00 2001
From: "Leonard.Caquot"
Date: Tue, 7 Jan 2025 17:03:43 +0100
Subject: [PATCH 14/20] refactor code to make it compatible with max line size
 79 characters

---
 torch_geometric/data/batch.py | 75 +++++++++++++++++++++--------------
 1 file changed, 45 insertions(+), 30 deletions(-)

diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index f909622df2d0..751318d7d20c 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -171,19 +171,20 @@ def index_select(self, idx: IndexType) -> List[BaseData]:
 
         return [self.get_example(i) for i in index]
 
     def filter(self, idx: torch.Tensor) -> Self:
-        """Efficiently filters the object using a boolean mask or index, directly modifying
-        batch attributes instead of rebuilding the batch.
+        """Efficiently filters the object using a boolean mask or index,
+        directly modifying batch attributes instead of rebuilding the batch.
 
-        This method is ~10x faster than calling Batch.from_data_list(batch[mask]).
+        This method can be ~10x faster than Batch.from_data_list(batch[mask]).
 
-        The provided indices (:obj:`idx`) can be a slicing object (e.g., :obj:`[2:5]`),
-        a list, tuple, or a :obj:`torch.Tensor`/:obj:`np.ndarray` of type long or bool,
-        or any sequence of integers (excluding strings).
+        The provided indices (:obj:`idx`) can be a list, a tuple, a slicing
+        object (e.g., :obj:`[2:5]`), or a :obj:`torch.Tensor`/:obj:`np.ndarray`
+        of type long or bool, or any sequence of integers (excluding strings).
 
-        This implementation currently focuses on HeteroData, but handling HomogeneousData
-        needs to be addressed.
Additionally, the default filtering from __get_item__ still - uses the index_select method, which could be replaced with this approach for - improved efficiency, avoiding conversion to list objects. + This implementation currently focuses on HeteroData, so handling + HomogeneousData needs to be addressed. Additionally, the default + filtering from __get_item__ still uses the index_select method, which + could be replaced with this approach for improved efficiency, avoiding + conversion to list objects. """ mask: torch.Tensor if isinstance(idx, slice): @@ -226,9 +227,10 @@ def filter(self, idx: torch.Tensor) -> Self: if batch._num_graphs == 0: return batch - # Mask application works the same way for all attribute levels (graph, nodes, edges) + # Main loop to apply the mask at different levels (graph, nodes, edges) for old_store, new_store in zip(self.stores, batch.stores): - # We get slices dictionary from key. If key is None then we are dealing with graph level attributes. + # We get slices dictionary from key. If key is None then it means + # we are dealing with graph level attributes. key = old_store._key slices = self._slice_dict[key] if key else { attr: self._slice_dict[attr] @@ -243,18 +245,20 @@ def filter(self, idx: torch.Tensor) -> Self: for attr, slc in slices.items(): slice_diff = slc.diff() - # Reshape mask to align it with attribute shape. - # Since slice_diff often contains only ones, skip useless computation in such cases - attr_mask = mask[torch.repeat_interleave( - slice_diff)] if torch.any(slice_diff != 1) else mask + # Reshape mask to align it with attribute shape. Since + # slice_diff often contains only ones, skip useless + # computation in such cases + if torch.any(slice_diff != 1): + attr_mask = mask[torch.repeat_interleave(slice_diff)] + else: + attr_mask = mask # Apply mask to attribute if attr == 'edge_index': new_store[attr] = old_store[attr][:, attr_mask] elif isinstance(old_store[attr], list): new_store[attr] = [ - x for x, m in zip(old_store[attr], attr_mask) if m - ] + x for x, m in zip(old_store[attr], attr_mask) if m] else: new_store[attr] = old_store[attr][attr_mask] @@ -263,38 +267,49 @@ def filter(self, idx: torch.Tensor) -> Self: slice_masked = torch.cat( (torch.zeros(1, dtype=torch.int), sizes_masked.cumsum(0))) - # By default, new inc tensor is zero tensor, unless it is overwritten later + # New _inc tensor is zeros tensor by default and can be + # overwritten later if needed. For now, we only update it when + # attr is edge_index, but we should do it every time original + # _inc tensor is not zeros only. new_inc = torch.zeros(batch._num_graphs, dtype=torch.int) - # x attribute provides num_node info to update 'ptr' and 'batch' tensors + # when attr is 'x', we also update 'ptr' and 'batch' tensors + # since this attribute provides node number information. if attr == 'x': batch[key].ptr = slice_masked batch[key].batch = torch.repeat_interleave(sizes_masked) + # Reindex edge_index to remove gaps left by removed nodes if attr == 'edge_index': - # Reindex edge_index to remove gaps from removed nodes. This involves: - # 1. Computing the difference (diff) to get edge index spans - # 2. Applying the mask to filter the spans - # 3. Using cumsum to reconstruct the _inc tensor - # 4. 
Adjusting the result to start from zero and ignore last _inc values
+                    # Reshape tensor to match edge_index size
                     old_inc = self._inc_dict[key][attr].squeeze(-1).T
+
+                    # Compute diff tensor to get edge_index spans
                     old_edge_index_spans = old_inc.diff(append=old_inc[:, -1:])
+
+                    # Apply the mask to filter spans
                     new_edge_index_spans = old_edge_index_spans[:, mask]
+
+                    # Use cumsum to reconstruct masked _inc tensor
                     new_inc_tmp = new_edge_index_spans.cumsum(1)
+
+                    # Adjust the result (start from zero, ignore last values)
                     new_inc_tmp[:, -1] = 0
                     new_inc = new_inc_tmp.roll(1, dims=1)
 
                     # Map each edge_index element to its batch position
-                    edge_index_batch_map = torch.repeat_interleave(
-                        sizes_masked)
-                    # Remove old_inc and add new_inc to each edge_index element using shift tensor
+                    attr_batch_map = torch.repeat_interleave(sizes_masked)
+
+                    # Remove old_inc and add new_inc to each edge_index
+                    # element using shift tensor
                     shift = new_inc - old_inc[:, mask]
-                    batch[key].edge_index += shift[:, edge_index_batch_map]
+                    batch[key].edge_index += shift[:, attr_batch_map]
 
                     # Reshape new_inc before saving in _inc_dict
                     new_inc = new_inc.T.unsqueeze(-1)
 
-                # Update _slice_dict and _inc_dict based on what has been computed before
+                # Finally, we update _slice_dict and _inc_dict based on what
+                # has been computed in previous steps
                 if key:  # Node or edge level attribute
                     batch._slice_dict[key][attr] = slice_masked
                     batch._inc_dict[key][attr] = new_inc
 
From 8eb8ad3a9733538ce8ca1a8b3c497da079d4d315 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 7 Jan 2025 16:05:07 +0000
Subject: [PATCH 15/20] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 torch_geometric/data/batch.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index 751318d7d20c..ee409a1d560f 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -258,7 +258,8 @@ def filter(self, idx: torch.Tensor) -> Self:
                     new_store[attr] = old_store[attr][:, attr_mask]
                 elif isinstance(old_store[attr], list):
                     new_store[attr] = [
-                        x for x, m in zip(old_store[attr], attr_mask) if m]
+                        x for x, m in zip(old_store[attr], attr_mask) if m
+                    ]
                 else:
                     new_store[attr] = old_store[attr][attr_mask]
 
From b84323308f503b69b8d8d6cace5035348a09089b Mon Sep 17 00:00:00 2001
From: "Leonard.Caquot"
Date: Tue, 7 Jan 2025 17:15:37 +0100
Subject: [PATCH 16/20] refactor test (max line size 79 characters)

---
 test/data/test_batch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/data/test_batch.py b/test/data/test_batch.py
index ae3a44c45976..e17666404177 100644
--- a/test/data/test_batch.py
+++ b/test/data/test_batch.py
@@ -636,5 +636,5 @@ def test_batch_filtering():
         'paper'].exist.shape
     assert batch_filtered[1]['author'].x.shape == batch[3]['author'].x.shape
 
-    # Verify the filtered batch supports round-trip conversion to and from a data list
+    # Check if result supports round-trip conversion to and from a data list
     assert Batch.from_data_list(batch_filtered.to_data_list())
 
From 3839fbb5680e5ba30d652f6d47f6ba6cbc4df23d Mon Sep 17 00:00:00 2001
From: "Leonard.Caquot"
Date: Thu, 9 Jan 2025 14:02:48 +0100
Subject: [PATCH 17/20] add homogeneous graph support

---
 test/data/test_batch.py      | 41 +++++++++++++++++----
 torch_geometric/data/batch.py | 67 ++++++++++++++++++++---------------
 2 files changed, 73 insertions(+), 35 deletions(-)

diff --git a/test/data/test_batch.py b/test/data/test_batch.py
index e17666404177..b7a9d56e1749 100644
--- a/test/data/test_batch.py
+++ b/test/data/test_batch.py
@@ -376,8 +376,8 @@ def test_recursive_batch():
         [data1.edge_index[0], data2.edge_index[0] + 30], dim=1).tolist())
     assert (batch.edge_index[1].tolist() == torch.cat(
         [data1.edge_index[1], data2.edge_index[1] + 30], dim=1).tolist())
-    assert batch.batch.size() == (90, )
-    assert batch.ptr.size() == (3, )
+    assert batch.batch.size() == (90,)
+    assert batch.ptr.size() == (3,)
 
     out1 = batch[0]
     assert len(out1) == 3
@@ -451,10 +451,10 @@ def test_hetero_batch():
     assert torch.allclose(
         batch[e2].edge_attr,
         torch.cat([data1[e2].edge_attr, data2[e2].edge_attr], 0))
-    assert batch['p'].batch.size() == (150, )
-    assert batch['p'].ptr.size() == (3, )
-    assert batch['a'].batch.size() == (300, )
-    assert batch['a'].ptr.size() == (3, )
+    assert batch['p'].batch.size() == (150,)
+    assert batch['p'].ptr.size() == (3,)
+    assert batch['a'].batch.size() == (300,)
+    assert batch['a'].ptr.size() == (3,)
 
     out1 = batch[0]
     assert len(out1) == 3
@@ -611,6 +611,35 @@ def __inc__(self, key, value, *args, **kwargs) -> int:
 
 
 def test_batch_filtering():
+    # Create initial Data objects and batch
+    data_list = []
+    for i in range(1, 5):
+        data = Data(
+            x=torch.randn(2 * i, 2),
+            edge_index=torch.stack((
+                torch.arange(2 * i).repeat(2),
+                torch.arange(2 * i).repeat_interleave(2))),
+            edge_attr=torch.randn(4 * i, 3)
+        )
+        data.info = [i]*i  # Add argument of variable size
+        data_list.append(data)
+    batch = Batch.from_data_list(data_list)
+    batch_filtered = batch.filter([1, 2])
+    assert isinstance(batch, Batch)
+    assert isinstance(batch_filtered, Batch)
+    assert len(batch) == 4
+    assert len(batch_filtered) == 2
+    assert batch_filtered[0].x.shape == batch[1].x.shape
+    assert batch_filtered[1].x.shape == batch[2].x.shape
+    assert len(batch_filtered.batch.unique()) == 2
+    assert len(batch_filtered.info) == 2
+    assert len(batch_filtered.info[0]) == 2
+    assert len(batch_filtered.info[1]) == 3
+    # Check if result supports round-trip conversion to and from a data list
+    assert Batch.from_data_list(batch_filtered.to_data_list())
+
+
+def test_hetero_batch_filtering():
     # Create initial HeteroData objects and batch
     data_list = []
     for i in range(1, 5):
         data = HeteroData()

diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index ee409a1d560f..d7e70b2ed8a2 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -79,12 +79,13 @@ class Batch(metaclass=DynamicInheritance):
     Furthermore, :meth:`~Data.__cat_dim__` defines in which dimension graph
     tensors of the same attribute should be concatenated together.
     """
+
     @classmethod
     def from_data_list(
-        cls,
-        data_list: List[BaseData],
-        follow_batch: Optional[List[str]] = None,
-        exclude_keys: Optional[List[str]] = None,
+            cls,
+            data_list: List[BaseData],
+            follow_batch: Optional[List[str]] = None,
+            exclude_keys: Optional[List[str]] = None,
     ) -> Self:
         r"""Constructs a :class:`~torch_geometric.data.Batch` object from a
         list of :class:`~torch_geometric.data.Data` or
@@ -180,11 +181,9 @@ def filter(self, idx: torch.Tensor) -> Self:
         object (e.g., :obj:`[2:5]`), or a :obj:`torch.Tensor`/:obj:`np.ndarray`
         of type long or bool, or any sequence of integers (excluding strings).
 
-        This implementation currently focuses on HeteroData, so handling
-        HomogeneousData needs to be addressed. Additionally, the default
-        filtering from __get_item__ still uses the index_select method, which
-        could be replaced with this approach for improved efficiency, avoiding
-        conversion to list objects.
+        For now, the default filtering from __getitem__ still uses the
+        index_select method, but it could be replaced with this approach to
+        improve efficiency and avoid unwanted conversions to list objects.
         """
 
         mask: torch.Tensor
         if isinstance(idx, slice):
@@ -232,17 +231,27 @@ def filter(self, idx: torch.Tensor) -> Self:
             # We get slices dictionary from key. If key is None then it means
             # we are dealing with graph level attributes.
             key = old_store._key
-            slices = self._slice_dict[key] if key else {
-                attr: self._slice_dict[attr]
-                for attr in old_store
-            }
+
+            if key is not None:  # Heterogeneous:
+                attrs = self._slice_dict[key].keys()
+            else:  # Homogeneous:
+                attrs = set(old_store.keys())
+                attrs = [attr for attr in self._slice_dict.keys() if
+                         attr in attrs]
 
             if key:
                 batch._slice_dict[key] = {}
                 batch._inc_dict[key] = {}
 
             # All slice and store are updated one by one in following loop
-            for attr, slc in slices.items():
+            for attr in attrs:
+                if key is not None:
+                    slc = self._slice_dict[key][attr]
+                    incs = self._inc_dict[key][attr]
+                else:
+                    slc = self._slice_dict[attr]
+                    incs = self._inc_dict[attr]
+
                 slice_diff = slc.diff()
 
                 # Reshape mask to align it with attribute shape. Since
@@ -277,37 +286,37 @@ def filter(self, idx: torch.Tensor) -> Self:
                 # when attr is 'x', we also update 'ptr' and 'batch' tensors
                 # since this attribute provides node number information.
                 if attr == 'x':
-                    batch[key].ptr = slice_masked
-                    batch[key].batch = torch.repeat_interleave(sizes_masked)
+                    new_store['ptr'] = slice_masked
+                    new_store['batch'] = torch.repeat_interleave(sizes_masked)
 
                 # Reindex edge_index to remove gaps left by removed nodes
                 if attr == 'edge_index':
-                    # Reshape tensor to match edge_index size
-                    old_inc = self._inc_dict[key][attr].squeeze(-1).T
 
                     # Compute diff tensor to get edge_index spans
-                    old_edge_index_spans = old_inc.diff(append=old_inc[:, -1:])
+                    old_spans = incs.diff(dim=0, append=incs[-1:])
 
                     # Apply the mask to filter spans
-                    new_edge_index_spans = old_edge_index_spans[:, mask]
+                    new_spans = old_spans[mask]
 
                     # Use cumsum to reconstruct masked _inc tensor
-                    new_inc_tmp = new_edge_index_spans.cumsum(1)
+                    new_inc_tmp = new_spans.cumsum(0)
 
                     # Adjust the result (start from zero, ignore last values)
-                    new_inc_tmp[:, -1] = 0
-                    new_inc = new_inc_tmp.roll(1, dims=1)
+                    new_inc_tmp[-1] = 0
+                    new_inc = new_inc_tmp.roll(1, dims=0)
 
                     # Map each edge_index element to its batch position
                     attr_batch_map = torch.repeat_interleave(sizes_masked)
 
-                    # Remove old_inc and add new_inc to each edge_index
-                    # element using shift tensor
-                    shift = new_inc - old_inc[:, mask]
-                    batch[key].edge_index += shift[:, attr_batch_map]
+                    # Update edge_index by removing old_inc and adding new_inc.
+                    # We do the new_inc - old_inc operation before applying
+                    # the map for efficiency purposes
+                    shift = (new_inc - incs[mask])[attr_batch_map]
 
-                    # Reshape new_inc before saving in _inc_dict
-                    new_inc = new_inc.T.unsqueeze(-1)
+                    if shift.ndim == 1:  # Homogeneous
+                        new_store[attr] += shift
+                    else:  # Heterogeneous
+                        new_store[attr] += shift.squeeze(-1).T
 
                 # Finally, we update _slice_dict and _inc_dict based on what
                 # has been computed in previous steps
 
From 723f54d569f9334480145d6d3c96b110bb05daa0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 9 Jan 2025 13:05:28 +0000
Subject: [PATCH 18/20]
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/data/test_batch.py | 22 ++++++++++------------ torch_geometric/data/batch.py | 14 +++++++------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/test/data/test_batch.py b/test/data/test_batch.py index b7a9d56e1749..07f1f333dc5f 100644 --- a/test/data/test_batch.py +++ b/test/data/test_batch.py @@ -376,8 +376,8 @@ def test_recursive_batch(): [data1.edge_index[0], data2.edge_index[0] + 30], dim=1).tolist()) assert (batch.edge_index[1].tolist() == torch.cat( [data1.edge_index[1], data2.edge_index[1] + 30], dim=1).tolist()) - assert batch.batch.size() == (90,) - assert batch.ptr.size() == (3,) + assert batch.batch.size() == (90, ) + assert batch.ptr.size() == (3, ) out1 = batch[0] assert len(out1) == 3 @@ -451,10 +451,10 @@ def test_hetero_batch(): assert torch.allclose( batch[e2].edge_attr, torch.cat([data1[e2].edge_attr, data2[e2].edge_attr], 0)) - assert batch['p'].batch.size() == (150,) - assert batch['p'].ptr.size() == (3,) - assert batch['a'].batch.size() == (300,) - assert batch['a'].ptr.size() == (3,) + assert batch['p'].batch.size() == (150, ) + assert batch['p'].ptr.size() == (3, ) + assert batch['a'].batch.size() == (300, ) + assert batch['a'].ptr.size() == (3, ) out1 = batch[0] assert len(out1) == 3 @@ -616,12 +616,10 @@ def test_batch_filtering(): for i in range(1, 5): data = Data( x=torch.randn(2 * i, 2), - edge_index=torch.stack(( - torch.arange(2 * i).repeat(2), - torch.arange(2 * i).repeat_interleave(2))), - edge_attr=torch.randn(4 * i, 3) - ) - data.info = [i]*i # Add argument of variable size + edge_index=torch.stack((torch.arange(2 * i).repeat(2), + torch.arange(2 * i).repeat_interleave(2))), + edge_attr=torch.randn(4 * i, 3)) + data.info = [i] * i # Add argument of variable size data_list.append(data) batch = Batch.from_data_list(data_list) batch_filtered = batch.filter([1, 2]) diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index d7e70b2ed8a2..ae9f728aafd4 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -79,13 +79,12 @@ class Batch(metaclass=DynamicInheritance): Furthermore, :meth:`~Data.__cat_dim__` defines in which dimension graph tensors of the same attribute should be concatenated together. 
""" - @classmethod def from_data_list( - cls, - data_list: List[BaseData], - follow_batch: Optional[List[str]] = None, - exclude_keys: Optional[List[str]] = None, + cls, + data_list: List[BaseData], + follow_batch: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, ) -> Self: r"""Constructs a :class:`~torch_geometric.data.Batch` object from a list of :class:`~torch_geometric.data.Data` or @@ -236,8 +235,9 @@ def filter(self, idx: torch.Tensor) -> Self: attrs = self._slice_dict[key].keys() else: # Homogeneous: attrs = set(old_store.keys()) - attrs = [attr for attr in self._slice_dict.keys() if - attr in attrs] + attrs = [ + attr for attr in self._slice_dict.keys() if attr in attrs + ] if key: batch._slice_dict[key] = {} From 9d1cef8fdc5681e24278db13f60f6c224ecc7be1 Mon Sep 17 00:00:00 2001 From: "Leonard.Caquot" Date: Thu, 9 Jan 2025 17:11:31 +0100 Subject: [PATCH 19/20] use new filter method as default filter method --- test/data/test_batch.py | 13 +++++++++-- torch_geometric/data/batch.py | 43 ++++++++--------------------------- 2 files changed, 20 insertions(+), 36 deletions(-) diff --git a/test/data/test_batch.py b/test/data/test_batch.py index b7a9d56e1749..a39b7401b2ca 100644 --- a/test/data/test_batch.py +++ b/test/data/test_batch.py @@ -624,7 +624,16 @@ def test_batch_filtering(): data.info = [i]*i # Add argument of variable size data_list.append(data) batch = Batch.from_data_list(data_list) - batch_filtered = batch.filter([1, 2]) + + # Check different filtering methods: + assert isinstance(batch[1:], Batch) + assert isinstance(batch[torch.tensor([0, 1])], Batch) + assert isinstance(batch[torch.tensor([True, False, True, False])], Batch) + assert isinstance(batch[np.array([0, 1])], Batch) + assert isinstance(batch[np.array([True, False, True, False])], Batch) + assert isinstance(batch[[1, 2]], Batch) + + batch_filtered = batch[[False, True, True, False]] assert isinstance(batch, Batch) assert isinstance(batch_filtered, Batch) assert len(batch) == 4 @@ -655,7 +664,7 @@ def test_herero_batch_filtering(): data['author', 'writes', 'paper'].edge_attr = torch.randn(6 * i, 3) data_list.append(data) batch = Batch.from_data_list(data_list) - batch_filtered = batch.filter([1, 3]) + batch_filtered = batch[[1, 3]] assert isinstance(batch, Batch) assert isinstance(batch_filtered, Batch) assert len(batch) == 4 diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index d7e70b2ed8a2..2fe7eebb9ea4 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -172,47 +172,22 @@ def index_select(self, idx: IndexType) -> List[BaseData]: return [self.get_example(i) for i in index] def filter(self, idx: torch.Tensor) -> Self: - """Efficiently filters the object using a boolean mask or index, - directly modifying batch attributes instead of rebuilding the batch. - - This method can be ~10x faster than Batch.from_data_list(batch[mask]). + """Efficiently filters Batch object using a boolean mask or index, + directly modifying batch attributes instead of converting it to data + list, masking, and rebuilding it (which can be ~10x slower). The provided indices (:obj:`idx`) can be a list, a tuple, a slicing object (e.g., :obj:`[2:5]`), or a :obj:`torch.Tensor`/:obj:`np.ndarray` of type long or bool, or any sequence of integers (excluding strings). 
-
-        For now, the default filtering from __getitem__ still uses the
-        index_select method, but it could be replaced with this approach to
-        improve efficiency and avoid unwanted conversions to list objects.
         """
 
-        mask: torch.Tensor
-        if isinstance(idx, slice):
-            mask = torch.zeros(len(self), dtype=torch.bool)
-            mask[idx] = True
-
-        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
-            mask = torch.zeros(len(self), dtype=torch.bool)
-            mask[idx.flatten()] = True
-
-        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
-            mask = idx.flatten()
-
-        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
-            mask = torch.zeros(len(self), dtype=torch.bool)
-            mask[idx.flatten()] = True
-
-        elif isinstance(idx, np.ndarray) and idx.dtype == bool:
-            mask = torch.tensor(idx.flatten())
-
-        elif isinstance(idx, Sequence) and not isinstance(idx, str):
-            mask = torch.zeros(len(self), dtype=torch.bool)
+        mask = torch.zeros(len(self), dtype=torch.bool)
+        try:
             mask[idx] = True
-
-        else:
-            raise IndexError(
-                f"Only slices (':'), list, tuples, torch.tensor and "
-                f"np.ndarray of dtype long or bool are valid indices (got "
-                f"'{type(idx).__name__}')")
+        except IndexError as e:
+            raise IndexError(
+                "Invalid index provided. Ensure the index is "
+                "compatible with PyTorch's indexing rules.") from e
 
         # Create new empty batch that will be filled later
         batch = Batch(_base_cls=self[0].__class__).stores_as(self)
@@ -338,7 +313,7 @@ def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
             # Accessing attributes or node/edge types:
             return super().__getitem__(idx)  # type: ignore
         else:
-            return self.index_select(idx)
+            return self.filter(idx)
 
     def to_data_list(self) -> List[BaseData]:
         r"""Reconstructs the list of :class:`~torch_geometric.data.Data` or
 
From 906f5592c078a3d1dc6228035c72fdc45a2e7cc1 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 9 Jan 2025 16:13:02 +0000
Subject: [PATCH 20/20] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 torch_geometric/data/batch.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py
index 3bcbc3b4b1cc..99d0574c66e9 100644
--- a/torch_geometric/data/batch.py
+++ b/torch_geometric/data/batch.py
@@ -179,7 +179,6 @@ def filter(self, idx: torch.Tensor) -> Self:
         object (e.g., :obj:`[2:5]`), or a :obj:`torch.Tensor`/:obj:`np.ndarray`
         of type long or bool, or any sequence of integers (excluding strings).
         """
-
         mask = torch.zeros(len(self), dtype=torch.bool)
         try:
             mask[idx] = True
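A short standalone sketch of the end state of this series (illustrative only; it mirrors the assertions in test_batch_filtering):

    import numpy as np
    import torch
    from torch_geometric.data import Batch, Data

    data_list = [Data(x=torch.randn(2 * i, 8)) for i in range(1, 5)]
    batch = Batch.from_data_list(data_list)

    # After PATCH 19, __getitem__ routes index-style selections through
    # filter(), so each of these returns a Batch directly:
    assert isinstance(batch[1:3], Batch)
    assert isinstance(batch[np.array([0, 2])], Batch)
    assert isinstance(batch[torch.tensor([True, False, True, False])], Batch)

    # Filtered batches still support round-trip conversion:
    filtered = batch[[1, 3]]
    assert len(Batch.from_data_list(filtered.to_data_list())) == 2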