Support pin_memory() in Multi{Embedding,Nested}Tensor and `TensorFrame` (#437)
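As a usage sketch (not part of the original commit message): the frame construction below mirrors the tests added in this commit, and the pinned check is done per feature tensor because `TensorFrame` gains only `pin_memory()` here, not `is_pinned()`.

```python
import torch
import torch_frame
from torch_frame import TensorFrame

# Minimal single-stype frame: four rows, two numerical columns.
tf = TensorFrame(
    feat_dict={torch_frame.numerical: torch.randn(4, 2)},
    col_names_dict={torch_frame.numerical: ['a', 'b']},
)

if torch.cuda.is_available():  # pinning page-locked memory needs a CUDA runtime
    tf = tf.pin_memory()  # returns a frame whose tensors live in pinned memory
    assert tf.feat_dict[torch_frame.numerical].is_pinned()
```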

Co-authored-by: Zecheng Zhang <[email protected]>
akihironitta and zechengz authored Dec 30, 2024
1 parent 655730c commit febb5e4
Showing 8 changed files with 82 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444))
- Added `is_floating_point` method to `MultiNestedTensor` and `MultiEmbeddingTensor` ([#445](https://github.com/pyg-team/pytorch-frame/pull/445))
- Added support for inferring `stype.categorical` from boolean columns in `utils.infer_series_stype` ([#421](https://github.com/pyg-team/pytorch-frame/pull/421))
- Added `pin_memory()` to `TensorFrame`, `MultiEmbeddingTensor`, and `MultiNestedTensor` ([#437](https://github.com/pyg-team/pytorch-frame/pull/437))

### Changed

17 changes: 16 additions & 1 deletion test/data/test_multi_embedding_tensor.py
@@ -5,7 +5,7 @@
import torch

from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor
from torch_frame.testing import withCUDA
from torch_frame.testing import onlyCUDA, withCUDA


def assert_equal(
@@ -487,3 +487,18 @@ def test_cat(device):
# case: list of non-MultiEmbeddingTensor should raise error
with pytest.raises(AssertionError):
MultiEmbeddingTensor.cat([object()], dim=0)


@onlyCUDA
def test_pin_memory():
met, _ = get_fake_multi_embedding_tensor(
num_rows=2,
num_cols=3,
)
assert not met.is_pinned()
assert not met.values.is_pinned()
assert not met.offset.is_pinned()
met = met.pin_memory()
assert met.is_pinned()
assert met.values.is_pinned()
assert met.offset.is_pinned()
23 changes: 20 additions & 3 deletions test/data/test_multi_nested_tensor.py
@@ -6,7 +6,7 @@
from torch import Tensor

from torch_frame.data import MultiNestedTensor
from torch_frame.testing import withCUDA
from torch_frame.testing import onlyCUDA, withCUDA


def assert_equal(tensor_mat: list[list[Tensor]],
@@ -95,8 +95,8 @@ def test_fillna_col():
torch.tensor([100], dtype=torch.float32)))


@withCUDA
def test_multi_nested_tensor_basics(device):
@withCUDA
def test_basics(device):
num_rows = 8
num_cols = 10
max_value = 100
@@ -326,7 +326,7 @@ def test_multi_nested_tensor_basics(device):
cloned_multi_nested_tensor)


def test_multi_nested_tensor_different_num_rows():
def test_different_num_rows():
tensor_mat = [
[torch.tensor([1, 2, 3]),
torch.tensor([4, 5])],
@@ -340,3 +340,20 @@ def test_multi_nested_tensor_different_num_rows():
match="The length of each row must be the same",
):
MultiNestedTensor.from_tensor_mat(tensor_mat)


@onlyCUDA
def test_pin_memory():
num_rows = 10
num_cols = 3
tensor = MultiNestedTensor.from_tensor_mat(
[[torch.randn(random.randint(0, 10)) for _ in range(num_cols)]
for _ in range(num_rows)])

assert not tensor.is_pinned()
assert not tensor.values.is_pinned()
assert not tensor.offset.is_pinned()
tensor = tensor.pin_memory()
assert tensor.is_pinned()
assert tensor.values.is_pinned()
assert tensor.offset.is_pinned()
17 changes: 17 additions & 0 deletions test/data/test_tensor_frame.py
@@ -7,6 +7,7 @@
from torch_frame import TensorFrame
from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor
from torch_frame.data.multi_nested_tensor import MultiNestedTensor
from torch_frame.testing import onlyCUDA


def test_tensor_frame_basics(get_fake_tensor_frame):
@@ -253,3 +254,19 @@ def test_non_list_col_names_dict():
col_names_dict = {torch_frame.categorical: 'cat_1'}
with pytest.raises(ValueError, match='must be a list of column names'):
TensorFrame(feat_dict, col_names_dict)


@onlyCUDA
def test_pin_memory(get_fake_tensor_frame):
    def assert_is_pinned(tf: TensorFrame, expected: bool) -> None:
for value in tf.feat_dict.values():
if isinstance(value, dict):
for v in value.values():
assert v.is_pinned() is expected
else:
assert value.is_pinned() is expected

tf = get_fake_tensor_frame(10)
assert_is_pinned(tf, expected=False)
tf = tf.pin_memory()
assert_is_pinned(tf, expected=True)
6 changes: 6 additions & 0 deletions torch_frame/data/multi_tensor.py
@@ -97,6 +97,12 @@ def cpu(self, *args, **kwargs):
def cuda(self, *args, **kwargs):
return self._apply(lambda x: x.cuda(*args, **kwargs))

def pin_memory(self, *args, **kwargs):
return self._apply(lambda x: x.pin_memory(*args, **kwargs))

def is_pinned(self) -> bool:
return self.values.is_pinned() and self.offset.is_pinned()

# Helper Functions ########################################################

def _apply(self, fn: Callable[[Tensor], Tensor]) -> _MultiTensor:
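The hunk above elides the body of `_apply`; the reconstruction below is an assumption based on the signature shown and on how `cpu()`/`cuda()` delegate to it, using a hypothetical stand-in class rather than the real `_MultiTensor`.

```python
from typing import Callable

from torch import Tensor


class _MultiTensorSketch:
    """Hypothetical stand-in for torch_frame's _MultiTensor."""
    def __init__(self, values: Tensor, offset: Tensor) -> None:
        self.values = values  # flat storage of all elements
        self.offset = offset  # boundaries between rows/columns

    def _apply(self, fn: Callable[[Tensor], Tensor]) -> '_MultiTensorSketch':
        # Build a new instance whose two backing tensors are transformed.
        return _MultiTensorSketch(values=fn(self.values), offset=fn(self.offset))

    def pin_memory(self, *args, **kwargs) -> '_MultiTensorSketch':
        return self._apply(lambda x: x.pin_memory(*args, **kwargs))

    def is_pinned(self) -> bool:
        # The composite counts as pinned only if both backing tensors are.
        return self.values.is_pinned() and self.offset.is_pinned()
```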
11 changes: 11 additions & 0 deletions torch_frame/data/tensor_frame.py
@@ -356,6 +356,17 @@ def fn(x):

return self._apply(fn)

def pin_memory(self, *args, **kwargs):
def fn(x):
if isinstance(x, dict):
for key in x:
x[key] = x[key].pin_memory(*args, **kwargs)
else:
x = x.pin_memory(*args, **kwargs)
return x

return self._apply(fn)

# Helper Functions ########################################################

def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame:
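A likely motivation, though the diff does not state it: `torch.utils.data.DataLoader(pin_memory=True)` falls back to calling `.pin_memory()` on batch objects that expose one, so `TensorFrame` batches can now be pinned by the loader. A sketch under that assumption, with a hypothetical toy dataset and collate function; the final `.to()` call assumes `TensorFrame.to` forwards keyword arguments such as `non_blocking` to the underlying tensors:

```python
import torch
import torch_frame
from torch.utils.data import DataLoader, Dataset
from torch_frame import TensorFrame


class ToyRows(Dataset):
    """Ten rows of two numerical features."""
    def __init__(self) -> None:
        self.x = torch.randn(10, 2)

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, i: int) -> torch.Tensor:
        return self.x[i]


def collate(rows: list) -> TensorFrame:
    # Stack the sampled rows back into a single-stype TensorFrame.
    return TensorFrame(
        feat_dict={torch_frame.numerical: torch.stack(rows)},
        col_names_dict={torch_frame.numerical: ['a', 'b']},
    )


loader = DataLoader(ToyRows(), batch_size=4, collate_fn=collate,
                    pin_memory=True)
for tf in loader:
    if torch.cuda.is_available():  # the loader only pins when CUDA is present
        assert tf.feat_dict[torch_frame.numerical].is_pinned()
        tf = tf.to('cuda', non_blocking=True)  # pinned memory -> async copy
```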
2 changes: 2 additions & 0 deletions torch_frame/testing/__init__.py
@@ -3,10 +3,12 @@
has_package,
withPackage,
withCUDA,
onlyCUDA,
)

__all__ = [
'has_package',
'withPackage',
'withCUDA',
'onlyCUDA',
]
9 changes: 9 additions & 0 deletions torch_frame/testing/decorators.py
@@ -48,3 +48,12 @@ def withCUDA(func: Callable):
devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))

return pytest.mark.parametrize('device', devices)(func)


def onlyCUDA(func: Callable) -> Callable:
r"""A decorator to skip tests if CUDA is not found."""
import pytest
return pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA not available",
)(func)
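For contrast between the two decorators (hypothetical test names; `withCUDA` parametrizes a `device` argument across available devices, while `onlyCUDA` merely skips on CPU-only machines and passes no argument):

```python
import torch

from torch_frame.testing import onlyCUDA, withCUDA


@withCUDA  # runs on cpu, and additionally on cuda:0 when available
def test_runs_everywhere(device: torch.device) -> None:
    assert torch.zeros(1, device=device).sum() == 0


@onlyCUDA  # skipped entirely unless torch.cuda.is_available()
def test_needs_cuda() -> None:
    assert torch.ones(1).pin_memory().is_pinned()
```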
