Skip to content

Commit

Permalink
refactor(transforms)!: use torch instead of numpy
Browse files Browse the repository at this point in the history
* `transforms` are now implemented with PyTorch and expect a tensor as
input. Their randomness is also controlled by PyTorch, removing the need
for users to specify `worker_init_fn` in dataloaders.

* For `RandomRotation`, the lightweight `roma` package is used instead of
`scipy`.

* `upsample_pcd` has moved from `data` to `transforms`.

Additional changes:
* Remove `Identity` from `transforms` (equivalent to `torch.nn.Identity`)
* Remove copying data in `split_pcd`
* Remove default values from `transforms`
* Remove `scipy` and `roma` in dependencies
* Update docs

Fixes #32
  • Loading branch information
adosar committed Jan 2, 2025
1 parent 308cdba commit 0f7918a
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 180 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
"plotly>=5.19.0",
"tqdm>=4.66.2",
"pandas>=2.2.0",
"scipy>=1.12.0",
"roma>=1.5.1",
"lightning>=2.5.0",
"jsonargparse[signatures]>=4.35.0",
"numpy<=1.26.4", # Temporarily, for avoiding dependency conflicts.
Expand Down
13 changes: 6 additions & 7 deletions src/aidsorb/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,22 @@
import pandas as pd


def check_shape(array):
def check_shape(obj):
r"""
Check if ``array`` has valid shape to be considered a point cloud.
Check if ``obj`` has valid shape to be considered a point cloud.
Parameters
----------
array
obj : array/tensor
Raises
------
ValueError
If ``array.shape != (N, 3+C)``.
If ``obj.shape != (N, 3+C)``.
"""
if not ((array.ndim == 2) and (array.shape[1] >= 3)):
if not ((obj.ndim == 2) and (obj.shape[1] >= 3)):
raise ValueError(
'Expecting array of shape (N, 3+C) '
f'but got array of shape {array.shape}!'
f'Expecting shape (N, 3+C) but got shape {tuple(obj.shape)}!'
)


Expand Down
107 changes: 29 additions & 78 deletions src/aidsorb/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torch.utils.data import random_split, Dataset
from torch.nn.utils.rnn import pad_sequence
from . _internal import SEED, pd
from . transforms import upsample_pcd


def prepare_data(source: str, split_ratio: Sequence=(0.8, 0.1, 0.1), seed: int = SEED):
Expand All @@ -50,11 +51,8 @@ def prepare_data(source: str, split_ratio: Sequence=(0.8, 0.1, 0.1), seed: int =
source : str
Absolute or relative path to the directory holding the point clouds.
split_ratio : sequence, default=(0.8, 0.1, 0.1)
Absolute sizes or fractions of splits.
* ``split_ratio[0] == train``
* ``split_ratio[1] == validation``
* ``split_ratio[2] == test``
Absolute sizes or fractions of splits of the form ``(train, val,
test)``.
seed : int, default=1
Controls randomness of the ``rng`` used for splitting.
Expand Down Expand Up @@ -120,49 +118,6 @@ def get_names(filename):
return names


def upsample_pcd(pcd, size):
r"""
Upsample ``pcd`` to a new ``size`` by sampling with replacement from ``pcd``.
Parameters
----------
pcd : tensor of shape (N, C)
Original point cloud of size ``N``.
size : int
Size of the new point cloud.
Returns
-------
new_pcd : tensor of shape (size, C)
Examples
--------
>>> pcd = torch.tensor([[2, 4, 5, 6]])
>>> upsample_pcd(pcd, 3)
tensor([[2, 4, 5, 6],
[2, 4, 5, 6],
[2, 4, 5, 6]])
>>> # New points point must be from pcd.
>>> pcd = torch.randn(10, 4)
>>> new_pcd = upsample_pcd(pcd, 20)
>>> (new_pcd[-1] == pcd).all(1).any() # Check for last point.
tensor(True)
>>> # No upsampling.
>>> pcd = torch.randn(100, 4)
>>> new_pcd = upsample_pcd(pcd, len(pcd))
>>> torch.equal(pcd, new_pcd)
True
"""
n_samples = size - len(pcd)
indices = torch.randint(len(pcd), (n_samples,)) # With replacement.
new_points = pcd[indices]

return torch.cat((pcd, new_points))


def pad_pcds(pcds, channels_first=True, mode='upsample'):
r"""
Pad a sequence of variable size point clouds.
Expand Down Expand Up @@ -233,7 +188,7 @@ def pad_pcds(pcds, channels_first=True, mode='upsample'):
[8, 9]]])
"""
if mode == 'zeropad':
batch = pad_sequence(pcds, batch_first=True, padding_value=0)
batch = pad_sequence(pcds, batch_first=True, padding_value=0.0)

elif mode == 'upsample':
max_len = max(len(i) for i in pcds)
Expand Down Expand Up @@ -328,20 +283,20 @@ class Collator():
tensor([0, 1])
>>> # Label is None, i.e. unlabeled data.
>>> sample1 = (torch.tensor([[1, 0, 1, 0]]), None)
>>> sample2 = (torch.tensor([[5, 2, 2, 0], [9, 0, 0, 1]]), None)
>>> collate_fn = Collator()
>>> sample1 = (torch.tensor([[1., 0., 1., 0.]]), None)
>>> sample2 = (torch.tensor([[5., 2., 2., 0.], [9., 0., 0., 1.]]), None)
>>> collate_fn = Collator(mode='zeropad')
>>> x, y = collate_fn((sample1, sample2))
>>> x
tensor([[[1, 1],
[0, 0],
[1, 1],
[0, 0]],
tensor([[[1., 0.],
[0., 0.],
[1., 0.],
[0., 0.]],
<BLANKLINE>
[[5, 9],
[2, 0],
[2, 0],
[0, 1]]])
[[5., 9.],
[2., 0.],
[2., 0.],
[0., 1.]]])
>>> y
"""
def __init__(self, channels_first=True, mode='upsample'):
Expand All @@ -353,16 +308,13 @@ def __call__(self, samples):
Parameters
----------
samples : sequence of tuples
Each sample is a tuple of tensors ``(pcd, label)`` where
``pcd.shape == (n_points, C)`` and ``label`` has shape
``(n_outputs,)`` or ``()``.
Each sample is a tuple of tensors ``(pcd, label)`` or ``(pcd,
None)``.
Returns
-------
batch : tuple of length 2
* ``batch[0] == x`` with shape ``(B, C, T)`` or ``(B, T, C)``, where
``T`` is the size of the largest point cloud.
* ``batch[1] == y`` with shape ``(B, n_outputs)`` or ``(B,)``.
Batch of the form ``(x, y)`` or ``(x, None)``.
"""
pcds, labels = list(zip(*samples))

Expand All @@ -382,7 +334,7 @@ class PCDDataset(Dataset):
.. note::
* ``x`` and ``y`` are tensors of ``dtype=torch.float``.
* ``y`` has shape ``(len(labels),)``.
* ``transform_x`` and ``transform_y`` expect :class:`~numpy.ndarray` as
* ``transform_x`` and ``transform_y`` expect :class:`~torch.Tensor` as
input.
.. warning::
Expand Down Expand Up @@ -450,21 +402,20 @@ def __len__(self):
def __getitem__(self, idx):
pcd_name = self.pcd_names[idx]
pcd_path = os.path.join(self.path_to_X, f'{pcd_name}.npy')
sample_x = np.load(pcd_path)
sample_y = None

if self.transform_x is not None:
sample_x = self.transform_x(sample_x)
pcd = torch.tensor(np.load(pcd_path), dtype=torch.float)
label = None

sample_x = torch.tensor(sample_x, dtype=torch.float)
if self.transform_x is not None:
pcd = self.transform_x(pcd)

# Only for labeled data.
if self.Y is not None:
sample_y = self.Y.loc[pcd_name].to_numpy()
label = torch.tensor(
self.Y.loc[pcd_name].to_numpy(),
dtype=torch.float,
)

if self.transform_y is not None:
sample_y = self.transform_y(sample_y)

sample_y = torch.tensor(sample_y, dtype=torch.float)
label = self.transform_y(label)

return sample_x, sample_y
return pcd, label
9 changes: 4 additions & 5 deletions src/aidsorb/litmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
"""

from collections.abc import Callable
import torch
import torchmetrics
from torchmetrics import MetricCollection
import lightning as L
from . _litmodels_utils import get_optimizers

Expand Down Expand Up @@ -72,11 +71,11 @@ class PCDLit(L.LightningModule):
--------
>>> from aidsorb.modules import PointNetClsHead
>>> from aidsorb.models import PointNet
>>> from torch.nn import MSELoss
>>> import torch
>>> from torchmetrics import MetricCollection, R2Score, MeanAbsoluteError as MAE
>>> model = PointNet(head=PointNetClsHead(n_outputs=10))
>>> criterion, metric = MSELoss(), MetricCollection(R2Score(), MAE())
>>> criterion, metric = torch.nn.MSELoss(), MetricCollection(R2Score(), MAE())
>>> # Adam optimizer with default hyperparameters, no scheduler.
>>> litmodel = PCDLit(model, criterion, metric)
Expand All @@ -102,7 +101,7 @@ def __init__(
self,
model: Callable,
criterion: Callable,
metric: torchmetrics.MetricCollection,
metric: MetricCollection,
config_optimizer: dict=None,
config_scheduler: dict=None,
):
Expand Down
16 changes: 4 additions & 12 deletions src/aidsorb/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,11 @@ def conv1d_block(in_channels, out_channels, **kwargs):
If ``None``, the ``conv_layer`` is lazy initialized.
out_channels : int
**kwargs
Valid keyword arguments for :class:`torch.nn.Conv1d`.
Valid keyword arguments for :class:`~torch.nn.Conv1d`.
Returns
-------
block : :class:`torch.nn.Sequential`
See Also
--------
:class:`torch.nn.Conv1d` : For a description of the parameters.
block : :class:`~torch.nn.Sequential`
Examples
--------
Expand Down Expand Up @@ -106,15 +102,11 @@ def dense_block(in_features, out_features, **kwargs):
If ``None``, the ``linear_layer`` is lazy initialized.
out_features : int
**kwargs
Valid keyword arguments for :class:`torch.nn.Linear`.
Valid keyword arguments for :class:`~torch.nn.Linear`.
Returns
-------
block : :class:`torch.nn.Sequential`
See Also
--------
:class:`torch.nn.Linear` : For a description of the parameters.
block : :class:`~torch.nn.Sequential`
Examples
--------
Expand Down
Loading

0 comments on commit 0f7918a

Please sign in to comment.