Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new batch filtering method for efficiency purpose #9911

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
877f742
add new batch filtering method for efficiency purpose
leonardcaquot94 Jan 2, 2025
4da4ed1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
9f6ae80
update CHANGELOG.md
leonardcaquot94 Jan 2, 2025
007baa8
Merge remote-tracking branch 'origin/master'
leonardcaquot94 Jan 2, 2025
7b9ec12
support non tensor attributes
leonardcaquot94 Jan 3, 2025
b1f3e84
fix issue when computing new edge index
leonardcaquot94 Jan 3, 2025
606d34a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
929ddcf
handle case where filtered batch is empty
leonardcaquot94 Jan 3, 2025
6d34e00
Merge remote-tracking branch 'origin/master'
leonardcaquot94 Jan 7, 2025
818e4cd
refactor to make code fit in one line
leonardcaquot94 Jan 7, 2025
d5bfc91
fix issue in new edge index computation and add explanations
leonardcaquot94 Jan 7, 2025
080e23c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
66b48d2
fix variable error
leonardcaquot94 Jan 7, 2025
bd296a3
Merge remote-tracking branch 'origin/master'
leonardcaquot94 Jan 7, 2025
06088d5
skip useless repeat_interleave computation in some specific but commo…
leonardcaquot94 Jan 7, 2025
9a8b22d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
c91671f
refactor code to make it compatible with max line size 79 characters
leonardcaquot94 Jan 7, 2025
8eb8ad3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
b843233
refactor test (max line size 79 characters)
leonardcaquot94 Jan 7, 2025
3839fbb
add homogeneous graph support
leonardcaquot94 Jan 9, 2025
723f54d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
9d1cef8
use new filter method as default filter method
leonardcaquot94 Jan 9, 2025
5561121
Merge remote-tracking branch 'origin/master'
leonardcaquot94 Jan 9, 2025
906f559
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
- Added `filter` method for efficient Batch filtering ([#9911](https://github.com/pyg-team/pytorch_geometric/pull/9911))

### Changed

Expand Down
66 changes: 66 additions & 0 deletions test/data/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,3 +608,69 @@ def __inc__(self, key, value, *args, **kwargs) -> int:
batch.x.to_padded_tensor(0.0),
expected.to_padded_tensor(0.0),
)


def test_batch_filtering():
    """Test that indexing a homogeneous ``Batch`` returns a filtered
    ``Batch`` (rather than a single example or a data list), and that
    tensor, list-valued, and bookkeeping attributes stay consistent.
    """
    # Create initial Data objects and batch them together.
    # Example ``i`` has ``2 * i`` nodes and ``4 * i`` edges.
    data_list = []
    for i in range(1, 5):
        data = Data(
            x=torch.randn(2 * i, 2),
            edge_index=torch.stack((torch.arange(2 * i).repeat(2),
                                    torch.arange(2 * i).repeat_interleave(2))),
            edge_attr=torch.randn(4 * i, 3))
        data.info = [i] * i  # Non-tensor attribute of variable size.
        data_list.append(data)
    batch = Batch.from_data_list(data_list)

    # All supported index types must yield a Batch (not a data list):
    assert isinstance(batch[1:], Batch)
    assert isinstance(batch[torch.tensor([0, 1])], Batch)
    assert isinstance(batch[torch.tensor([True, False, True, False])], Batch)
    assert isinstance(batch[np.array([0, 1])], Batch)
    assert isinstance(batch[np.array([True, False, True, False])], Batch)
    assert isinstance(batch[[1, 2]], Batch)

    # Keep only examples 1 and 2 via a boolean mask:
    batch_filtered = batch[[False, True, True, False]]
    assert isinstance(batch, Batch)  # Original batch is left untouched.
    assert isinstance(batch_filtered, Batch)
    assert len(batch) == 4
    assert len(batch_filtered) == 2
    # Per-node tensors of the kept examples must be preserved:
    assert batch_filtered[0].x.shape == batch[1].x.shape
    assert batch_filtered[1].x.shape == batch[2].x.shape
    # The `batch` assignment vector must only reference the two kept graphs:
    assert len(batch_filtered.batch.unique()) == 2
    # List-valued attributes are filtered per example as well:
    assert len(batch_filtered.info) == 2
    assert len(batch_filtered.info[0]) == 2
    assert len(batch_filtered.info[1]) == 3
    # Check if result supports round-trip conversion to and from a data list
    assert Batch.from_data_list(batch_filtered.to_data_list())


def test_hetero_batch_filtering():
    """Test that indexing a heterogeneous ``Batch`` returns a filtered
    ``Batch`` whose per-node-type and per-edge-type stores stay aligned
    with the kept examples.

    Note: renamed from ``test_herero_batch_filtering`` to fix the typo;
    pytest discovers tests by the ``test_`` prefix, so the rename is safe.
    """
    # Create initial HeteroData objects and batch them together.
    # Example ``i`` has ``2 * i`` 'paper' nodes, ``3 * i`` 'author' nodes
    # and ``6 * i`` 'writes' edges.
    data_list = []
    for i in range(1, 5):
        data = HeteroData()
        data.info_1 = torch.ones(3)  # Graph-level attribute.
        data.info_2 = torch.ones(3)  # Graph-level attribute.
        data['paper'].x = torch.randn(2 * i, 16)
        data['paper'].exist = torch.randn(
            4 * i, 1)  # Node attribute with a different length than `x`.
        data['author'].x = torch.randn(3 * i, 2)
        data['author', 'writes', 'paper'].edge_index = torch.stack(
            (torch.arange(3 * i).repeat(2), torch.arange(2 * i).repeat(3)))
        data['author', 'writes', 'paper'].edge_attr = torch.randn(6 * i, 3)
        data_list.append(data)
    batch = Batch.from_data_list(data_list)

    # Keep only the second and fourth examples:
    batch_filtered = batch[[1, 3]]
    assert isinstance(batch, Batch)  # Original batch is left untouched.
    assert isinstance(batch_filtered, Batch)
    assert len(batch) == 4
    assert len(batch_filtered) == 2
    # Per-type tensors of the kept examples must be preserved:
    assert batch_filtered[0]['paper'].x.shape == batch[1]['paper'].x.shape
    assert batch_filtered[0]['paper'].exist.shape == batch[1][
        'paper'].exist.shape
    assert batch_filtered[1]['author'].x.shape == batch[3]['author'].x.shape

    # Check if result supports round-trip conversion to and from a data list
    assert Batch.from_data_list(batch_filtered.to_data_list())
134 changes: 133 additions & 1 deletion torch_geometric/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,138 @@ def index_select(self, idx: IndexType) -> List[BaseData]:

return [self.get_example(i) for i in index]

def filter(self, idx: torch.Tensor) -> Self:
    """Efficiently filters Batch object using a boolean mask or index,
    directly modifying batch attributes instead of converting it to data
    list, masking, and rebuilding it (which can be ~10x slower).

    The provided indices (:obj:`idx`) can be a list, a tuple, a slicing
    object (e.g., :obj:`[2:5]`), or a :obj:`torch.Tensor`/:obj:`np.ndarray`
    of type long or bool, or any sequence of integers (excluding strings).

    Returns a new :class:`Batch` holding only the selected examples; the
    original batch is left untouched.

    NOTE(review): selection is materialized as a boolean mask, so kept
    examples always come out in their original batch order and duplicate
    indices collapse to a single example — this differs from
    :meth:`index_select`, which honors index order/repeats. Confirm this
    is the intended semantics for integer-index inputs.
    """
    # Normalize any supported index type into a boolean keep-mask of
    # length `num_graphs` by scattering `True` at the selected positions.
    mask = torch.zeros(len(self), dtype=torch.bool)
    try:
        mask[idx] = True
    except IndexError as e:
        raise IndexError(
            "Invalid index provided. Ensure the index is "
            "compatible with PyTorch's indexing rules.") from e

    # Create new empty batch that will be filled later
    batch = Batch(_base_cls=self[0].__class__).stores_as(self)
    batch._slice_dict = {}
    batch._inc_dict = {}

    # Update the number of graphs based on the mask
    batch._num_graphs = mask.sum().item()

    # Return empty batch when mask filters all elements
    if batch._num_graphs == 0:
        return batch

    # Main loop to apply the mask at different levels (graph, nodes, edges)
    for old_store, new_store in zip(self.stores, batch.stores):
        # We get slices dictionary from key. If key is None then it means
        # we are dealing with graph level attributes.
        key = old_store._key

        if key is not None:  # Heterogeneous:
            attrs = self._slice_dict[key].keys()
        else:  # Homogeneous:
            # Only keep attributes tracked in `_slice_dict` (this skips
            # derived tensors such as `ptr`/`batch`, rebuilt below).
            attrs = set(old_store.keys())
            attrs = [
                attr for attr in self._slice_dict.keys() if attr in attrs
            ]

        if key:
            batch._slice_dict[key] = {}
            batch._inc_dict[key] = {}

        # All slice and store are updated one by one in following loop
        for attr in attrs:
            # `slc` holds cumulative element offsets per example; `incs`
            # holds the per-example value increments (e.g. node offsets
            # for `edge_index`).
            if key is not None:
                slc = self._slice_dict[key][attr]
                incs = self._inc_dict[key][attr]
            else:
                slc = self._slice_dict[attr]
                incs = self._inc_dict[attr]

            # Number of elements each example contributes to `attr`.
            slice_diff = slc.diff()

            # Reshape mask to align it with attribute shape. Since
            # slice_diff often contains only ones, skip useless
            # computation in such cases
            if torch.any(slice_diff != 1):
                attr_mask = mask[torch.repeat_interleave(slice_diff)]
            else:
                attr_mask = mask

            # Apply mask to attribute
            if attr == 'edge_index':
                # `edge_index` is concatenated along dim 1, not dim 0.
                new_store[attr] = old_store[attr][:, attr_mask]
            elif isinstance(old_store[attr], list):
                # Non-tensor (list-valued) attributes are filtered in
                # plain Python.
                new_store[attr] = [
                    x for x, m in zip(old_store[attr], attr_mask) if m
                ]
            else:
                new_store[attr] = old_store[attr][attr_mask]

            # Compute masked version of slice tensor
            sizes_masked = slice_diff[mask]
            slice_masked = torch.cat(
                (torch.zeros(1, dtype=torch.int), sizes_masked.cumsum(0)))

            # New _inc tensor is zeros tensor by default and can be
            # overwritten later if needed. For now, we only update it when
            # attr is edge_index, but we should do it every time original
            # _inc tensor is not zeros only.
            new_inc = torch.zeros(batch._num_graphs, dtype=torch.int)

            # when attr is 'x', we also update 'ptr' and 'batch' tensors
            # since this attribute provides node number information.
            if attr == 'x':
                new_store['ptr'] = slice_masked
                new_store['batch'] = torch.repeat_interleave(sizes_masked)

            # Reindex edge_index to remove gaps left by removed nodes
            if attr == 'edge_index':

                # Compute diff tensor to get edge_index spans
                old_spans = incs.diff(dim=0, append=incs[-1:])

                # Apply the mask to filter spans
                new_spans = old_spans[mask]

                # Use cumsum to reconstruct masked _inc tensor
                new_inc_tmp = new_spans.cumsum(0)

                # Adjust the result (start from zero, ignore last values)
                new_inc_tmp[-1] = 0
                new_inc = new_inc_tmp.roll(1, dims=0)

                # Map each edge_index element to its batch position
                attr_batch_map = torch.repeat_interleave(sizes_masked)

                # Update edge_index by removing old_inc and add new_inc
                # We do new_inc - old_inc operation before applying the
                # map for efficiency purpose
                shift = (new_inc - incs[mask])[attr_batch_map]

                if shift.ndim == 1:  # Homogeneous
                    new_store[attr] += shift
                else:  # Heterogeneous
                    new_store[attr] += shift.squeeze(-1).T

            # Finally, we update _slice_dict and _inc_dict based on what
            # has been computed in previous steps
            if key:  # Node or edge level attribute
                batch._slice_dict[key][attr] = slice_masked
                batch._inc_dict[key][attr] = new_inc
            else:  # Graph level attribute
                batch._slice_dict[attr] = slice_masked
                batch._inc_dict[attr] = new_inc
    return batch

def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
if (isinstance(idx, (int, np.integer))
or (isinstance(idx, Tensor) and idx.dim() == 0)
Expand All @@ -180,7 +312,7 @@ def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
# Accessing attributes or node/edge types:
return super().__getitem__(idx) # type: ignore
else:
return self.index_select(idx)
return self.filter(idx)

def to_data_list(self) -> List[BaseData]:
r"""Reconstructs the list of :class:`~torch_geometric.data.Data` or
Expand Down
Loading