feat: reduce memory footprint (#55)
* feat: reduce memory footprint of windowing and tensor conversion

Windowing can now unload results to a memmap on disk. The tensor conversions now use `as_tensor`, which avoids copying NumPy arrays when possible (see the first sketch after this change list).

* refactor: avoid concatenating runs by using custom dataset

Concatenating the runs is memory-intensive and can be avoided by putting them into a custom dataset instead of a TensorDataset (see the dataset sketch after this change list).

* fix: make RulDataset more robust

* refactor: cleanup windowing code

* feat: convert data to tensor only when leaving dataset

This makes it possible to hold the data in more exotic NumPy arrays, e.g., the strided views produced by `lib.stride_tricks.sliding_window_view`, which can lower memory consumption (see the windowing sketch after this change list).

* fix: linting issues

* refactor: avoid unnecessary copies

* feat: add flag to force copy tensors

* fix: linting issues
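
The memmap and `as_tensor` points are easiest to see in code. Below is a minimal, self-contained sketch, not the library's actual windowing code (shapes and the file name are made up): windowed features are unloaded to an np.memmap, and torch.as_tensor reuses the NumPy buffer when dtype and layout allow, whereas torch.tensor always copies.

import numpy as np
import torch

# Write windowed features to a memmap on disk instead of holding them in RAM.
# Shape and file name are invented for illustration.
shape = (1000, 30, 14)  # (num_windows, window_size, num_channels)
features = np.memmap("windows.dat", dtype=np.float32, mode="w+", shape=shape)
features[:] = np.random.randn(*shape)  # stand-in for the real windowing step

# torch.as_tensor reuses the underlying NumPy buffer when possible,
# whereas torch.tensor always makes a copy.
no_copy = torch.as_tensor(features)
forced_copy = torch.tensor(features)

assert no_copy.shape == forced_copy.shape == torch.Size(shape)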
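For the custom-dataset bullet, here is a hypothetical stand-in for the idea behind RulDataset, not its actual implementation: keep the runs as a list of arrays and convert only the requested sample to a tensor in __getitem__, so a concatenated copy of all runs never exists.

from typing import List, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset


class PerRunDataset(Dataset):
    """Hypothetical sketch: index per-run arrays lazily instead of concatenating."""

    def __init__(self, features: List[np.ndarray], targets: List[np.ndarray]) -> None:
        self.features = features
        self.targets = targets
        # Map a flat sample index to (run index, index inside the run).
        self._index = [
            (run_idx, i) for run_idx, run in enumerate(features) for i in range(len(run))
        ]

    def __len__(self) -> int:
        return len(self._index)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        run_idx, inner_idx = self._index[idx]
        feature = torch.as_tensor(self.features[run_idx][inner_idx])
        target = torch.as_tensor(self.targets[run_idx][inner_idx])
        return feature, target

A DataLoader can batch from such a dataset directly; only the windows of the current batch are materialized as tensors.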
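Converting to tensors only when data leaves the dataset is what makes strided views usable as backing storage. A small illustration (array sizes invented) of why sliding_window_view lowers memory consumption: it returns a view into the original run instead of copying every window.

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# One run of 10,000 time steps with 14 channels; numbers are illustrative.
run = np.random.randn(10_000, 14).astype(np.float32)

# Window along the time axis; the result is a strided view, not a copy.
windows = sliding_window_view(run, window_shape=30, axis=0)
print(windows.shape)                  # (9971, 14, 30)
print(np.shares_memory(windows, run)) # True: no per-window copies are made

# Materializing the windows (windows.copy()) would need roughly
# window_size times the memory footprint of `run`.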
tilman151 authored Jan 30, 2024
1 parent 8a32183 commit ac70405
Showing 9 changed files with 305 additions and 185 deletions.
55 changes: 25 additions & 30 deletions rul_datasets/adaption.py
@@ -2,16 +2,14 @@

 import warnings
 from copy import deepcopy
-from typing import List, Optional, Any, Tuple, Callable, Sequence, Union, cast
+from typing import List, Optional, Any, Tuple, Callable, Sequence, cast

 import numpy as np
 import pytorch_lightning as pl
 import torch
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.dataset import ConcatDataset, TensorDataset
+from torch.utils.data import DataLoader, Dataset, ConcatDataset

-from rul_datasets import utils
-from rul_datasets.core import PairedRulDataset, RulDataModule
+from rul_datasets.core import PairedRulDataset, RulDataModule, RulDataset


 class DomainAdaptionDataModule(pl.LightningDataModule):
@@ -235,7 +233,7 @@ class LatentAlignDataModule(DomainAdaptionDataModule):
         >>> fd2 = rul_datasets.CmapssReader(fd=2, percent_broken=0.8)
         >>> src = rul_datasets.RulDataModule(fd1, 32)
         >>> trg = rul_datasets.RulDataModule(fd2, 32)
-        >>> dm = rul_datasets.LatentAlignDataModule(src, trg, split_by_max_rul=125)
+        >>> dm = rul_datasets.LatentAlignDataModule(src, trg, split_by_max_rul=True)
         >>> dm.prepare_data()
         >>> dm.setup()
         >>> train_1_2 = dm.train_dataloader()
@@ -286,12 +284,15 @@ def __init__(

     def _get_training_dataset(self) -> "AdaptionDataset":
         source_healthy, source_degraded = split_healthy(
-            *self.source.load_split("dev"), by_max_rul=True
+            *self.source.data["dev"], by_max_rul=True
         )
+        target_features, target_labels = (  # reload only if needed to save memory
+            self.target.data["dev"]
+            if not self.inductive
+            else self.target.load_split("test", alias="dev")
+        )
         target_healthy, target_degraded = split_healthy(
-            *self.target.load_split("test" if self.inductive else "dev", alias="dev"),
-            self.split_by_max_rul,
-            self.split_by_steps,
+            target_features, target_labels, self.split_by_max_rul, self.split_by_steps
         )
         healthy: Dataset = ConcatDataset([source_healthy, target_healthy])
         dataset = AdaptionDataset(source_degraded, target_degraded, healthy)
@@ -300,11 +301,11 @@ def _get_training_dataset(self) -> "AdaptionDataset":


 def split_healthy(
-    features: Union[List[np.ndarray], List[torch.Tensor]],
-    targets: Union[List[np.ndarray], List[torch.Tensor]],
+    features: List[np.ndarray],
+    targets: List[np.ndarray],
     by_max_rul: bool = False,
     by_steps: Optional[int] = None,
-) -> Tuple[TensorDataset, TensorDataset]:
+) -> Tuple[RulDataset, RulDataset]:
     """
     Split the feature and target time series into healthy and degrading parts and
     return a dataset of each.
@@ -329,19 +330,13 @@ def split_healthy(
     if not by_max_rul and (by_steps is None):
         raise ValueError("Either 'by_max_rul' or 'by_steps' need to be set.")

-    if isinstance(features[0], np.ndarray):
-        features, targets = cast(Tuple[List[np.ndarray], ...], (features, targets))
-        _features, _targets = utils.to_tensor(features, targets)
-    else:
-        _features, _targets = cast(Tuple[List[torch.Tensor], ...], (features, targets))
-
     healthy = []
     degraded = []
-    for feature, target in zip(_features, _targets):
+    for feature, target in zip(features, targets):
         sections = _get_sections(by_max_rul, by_steps, target)
-        healthy_feat, degraded_feat = torch.split(feature, sections)
-        healthy_target, degraded_target = torch.split(target, sections)
-        degradation_steps = torch.arange(1, len(degraded_target) + 1)
+        healthy_feat, degraded_feat = np.split(feature, sections)
+        healthy_target, degraded_target = np.split(target, sections)
+        degradation_steps = np.arange(1, len(degraded_target) + 1)
         healthy.append((healthy_feat, healthy_target))
         degraded.append((degraded_feat, degradation_steps, degraded_target))

@@ -352,22 +347,22 @@


 def _get_sections(
-    by_max_rul: bool, by_steps: Optional[int], target: torch.Tensor
+    by_max_rul: bool, by_steps: Optional[int], target: np.ndarray
 ) -> List[int]:
     # cast is needed for mypy and has no runtime effect
     if by_max_rul:
-        split_idx = cast(int, target.flip(0).argmax().item())
-        sections = [len(target) - split_idx, split_idx]
+        split_idx = cast(int, np.flip(target, axis=0).argmax())
+        sections = [len(target) - split_idx]
     else:
         by_steps = min(cast(int, by_steps), len(target))
-        sections = [by_steps, len(target) - by_steps]
+        sections = [by_steps]

     return sections


-def _to_dataset(data: Sequence[Tuple[torch.Tensor, ...]]) -> TensorDataset:
-    tensor_data = [torch.cat(h) for h in zip(*data)]
-    dataset = TensorDataset(*tensor_data)
+def _to_dataset(data: Sequence[Tuple[np.ndarray, ...]]) -> RulDataset:
+    features, *targets = list(zip(*data))
+    dataset = RulDataset(features, *targets)

     return dataset

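A subtlety in the split_healthy and _get_sections hunks above: torch.split takes section lengths, while np.split takes split indices, which is why the new code returns a single boundary instead of two lengths. A short illustration with made-up values:

import numpy as np
import torch

# Illustrative target: RUL stays at its maximum, then degrades.
target = np.array([125, 125, 125, 124, 123, 122], dtype=np.float64)

split_idx = int(np.flip(target, axis=0).argmax())  # 3: length of the degraded tail
healthy_len = len(target) - split_idx              # 3: length of the max-RUL plateau

# torch.split expects section *lengths* ...
t_healthy, t_degraded = torch.split(torch.as_tensor(target), [healthy_len, split_idx])

# ... while np.split expects split *indices*, so a single boundary is enough.
n_healthy, n_degraded = np.split(target, [healthy_len])

assert np.array_equal(t_healthy.numpy(), n_healthy)
assert np.array_equal(t_degraded.numpy(), n_degraded)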
8 changes: 4 additions & 4 deletions rul_datasets/baseline.py
@@ -140,7 +140,6 @@ def __init__(
         self.min_distance = min_distance
         self.distance_mode = distance_mode
         self.window_size = self.unfailed.reader.window_size
-        self.source = unfailed_data_module

         self._check_loaders()

@@ -209,7 +208,8 @@ def prepare_data(self, *args, **kwargs):
         self.unfailed.reader.prepare_data()

     def setup(self, stage: Optional[str] = None):
-        self.source.setup(stage)
+        self.unfailed.setup(stage)
+        self.failed.setup(stage)

     def train_dataloader(self, *args, **kwargs) -> DataLoader:
         return DataLoader(
@@ -220,9 +220,9 @@ def val_dataloader(self, *args, **kwargs) -> List[DataLoader]:
         combined_loader = DataLoader(
             self._get_paired_dataset("val"), batch_size=self.batch_size, pin_memory=True
         )
-        source_loader = self.source.val_dataloader()
+        unfailed_loader = self.unfailed.val_dataloader()

-        return [combined_loader, source_loader]
+        return [combined_loader, unfailed_loader]

     def _get_paired_dataset(self, split: str) -> PairedRulDataset:
         deterministic = split == "val"
(Diffs of the remaining 7 changed files are not shown.)
