From 04a304c3422c57e31e3168f67b0186d48210d8fd Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Tue, 17 Sep 2024 17:48:15 +0200 Subject: [PATCH] [MAINT] Reduce copying of params (#10) --- src/pyparrm/_utils/_plotting.py | 9 +++---- src/pyparrm/parrm.py | 47 ++++++++++++++++----------------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/src/pyparrm/_utils/_plotting.py b/src/pyparrm/_utils/_plotting.py index 634317f..8b2aa51 100644 --- a/src/pyparrm/_utils/_plotting.py +++ b/src/pyparrm/_utils/_plotting.py @@ -3,7 +3,6 @@ # Author(s): # Thomas Samuel Binns | github.com/tsbinns -from copy import deepcopy from multiprocessing import cpu_count from matplotlib import pyplot as plt @@ -112,7 +111,7 @@ def _check_sort_init_inputs( "PyPARRM Internal Error: `_ParamSelection` should only be called if the " "period has been estimated. Please contact the PyPARRM developers." ) - self.parrm = deepcopy(parrm) + self.parrm = parrm self.parrm._verbose = False # time_range @@ -163,14 +162,14 @@ def _check_sort_init_inputs( ) if freq_range[0] >= freq_range[1]: raise ValueError("`freq_range[1]` must be > `freq_range[0]`.") - self.freq_range = deepcopy(freq_range) + self.freq_range = freq_range # freq_res if not isinstance(freq_res, (int, float)): raise TypeError("`freq_res` must be an int or a float.") if freq_res <= 0 or freq_res > self.parrm._sampling_freq / 2: raise ValueError("`freq_res` must lie in the range (0, Nyquist frequency].") - self.freq_res = deepcopy(freq_res) + self.freq_res = freq_res # n_jobs if not isinstance(n_jobs, int): @@ -181,7 +180,7 @@ def _check_sort_init_inputs( raise ValueError("If `n_jobs` is <= 0, it must be -1.") if n_jobs == -1: n_jobs = cpu_count() - self.n_jobs = deepcopy(n_jobs) + self.n_jobs = n_jobs self.parrm._check_sort_create_filter_inputs(None, 0, "both", None) self.current_period_half_width = self.parrm._period_half_width diff --git a/src/pyparrm/parrm.py b/src/pyparrm/parrm.py index f0d871b..ff80e6f 100644 --- a/src/pyparrm/parrm.py +++ b/src/pyparrm/parrm.py @@ -3,7 +3,6 @@ # Author(s): # Thomas Samuel Binns | github.com/tsbinns -from copy import deepcopy from multiprocessing import cpu_count import numpy as np @@ -123,23 +122,23 @@ def _check_init_inputs( raise TypeError("`data` must be a NumPy array.") if data.ndim != 2: raise ValueError("`data` must be a 2D array.") - self._data = data.copy() + self._data = data if not isinstance(sampling_freq, (int, float)): raise TypeError("`sampling_freq` must be an int or a float.") if sampling_freq <= 0: raise ValueError("`sampling_freq` must be > 0.") - self._sampling_freq = deepcopy(sampling_freq) + self._sampling_freq = sampling_freq if not isinstance(artefact_freq, (int, float)): raise TypeError("`artefact_freq` must be an int or a float.") if artefact_freq <= 0: raise ValueError("`artefact_freq` must be > 0.") - self._artefact_freq = deepcopy(artefact_freq) + self._artefact_freq = artefact_freq if not isinstance(verbose, bool): raise TypeError("`verbose` must be a bool.") - self._verbose = deepcopy(verbose) + self._verbose = verbose def __repr__(self) -> str: # noqa D107 return ( @@ -232,7 +231,7 @@ def _check_sort_find_stim_period_inputs( raise ValueError( "Entries of `search_samples` must lie in the range [0, n_samples)." ) - self._search_samples = search_samples.copy() + self._search_samples = search_samples if assumed_periods is not None and not isinstance( assumed_periods, (int, float, tuple) @@ -248,18 +247,18 @@ def _check_sort_find_stim_period_inputs( raise TypeError( "If a tuple, entries of `assumed_periods` must be ints or floats." ) - self._assumed_periods = deepcopy(assumed_periods) + self._assumed_periods = assumed_periods if not isinstance(outlier_boundary, (int, float)): raise TypeError("`outlier_boundary` must be an int or a float.") if outlier_boundary <= 0: raise ValueError("`outlier_boundary` must be > 0.") - self._outlier_boundary = deepcopy(outlier_boundary) + self._outlier_boundary = outlier_boundary if random_seed is not None and not isinstance(random_seed, int): raise TypeError("`random_seed` must be an int or None.") if random_seed is not None: - self._random_seed = deepcopy(random_seed) + self._random_seed = random_seed if not isinstance(n_jobs, int): raise TypeError("`n_jobs` must be an int.") @@ -269,7 +268,7 @@ def _check_sort_find_stim_period_inputs( raise ValueError("If `n_jobs` is <= 0, it must be -1.") if n_jobs == -1: n_jobs = cpu_count() - self._n_jobs = deepcopy(n_jobs) + self._n_jobs = n_jobs def _standardise_data(self) -> None: """Take derivatives of data, set S.D. to 1, and clip outliers.""" @@ -285,7 +284,7 @@ def _optimise_period_estimate(self) -> None: """Optimise artefact period estimate.""" random_state = np.random.RandomState(self._random_seed) - estimated_period = deepcopy(self._assumed_periods) + estimated_period = self._assumed_periods opt_sample_lens = np.unique( [ @@ -752,7 +751,7 @@ def _check_sort_create_filter_inputs( raise ValueError( "`omit_n_samples` must lie in the range [0, (no. of samples - 1) // 2)." ) - self._omit_n_samples = deepcopy(omit_n_samples) + self._omit_n_samples = omit_n_samples if period_half_width is None: period_half_width = self._period / 50 @@ -762,7 +761,7 @@ def _check_sort_create_filter_inputs( raise ValueError( "`period_half_width` must be lie in the range (0, period]." ) - self._period_half_width = deepcopy(period_half_width) + self._period_half_width = period_half_width # Must come after `omit_n_samples` and `period_half_width` set! if filter_half_width is None: @@ -776,7 +775,7 @@ def _check_sort_create_filter_inputs( "`filter_half_width` must lie in the range (`omit_n_samples`, " "(no. of samples - 1) // 2]." ) - self._filter_half_width = deepcopy(filter_half_width) + self._filter_half_width = filter_half_width if not isinstance(filter_direction, str): raise TypeError("`filter_direction` must be a str.") @@ -785,11 +784,11 @@ def _check_sort_create_filter_inputs( raise ValueError( f"`filter_direction` must be one of {valid_filter_directions}." ) - self._filter_direction = deepcopy(filter_direction) + self._filter_direction = filter_direction def _get_filter_half_width(self) -> int: """Get appropriate `filter_half_width`, if None given.""" - filter_half_width = deepcopy(self._omit_n_samples) + filter_half_width = self._omit_n_samples check = 0 while check < 50 and filter_half_width < (self._n_samples - 1) // 2: filter_half_width += 1 @@ -890,29 +889,29 @@ def _check_sort_filter_data_inputs(self, data: np.ndarray | None) -> np.ndarray: @property def data(self) -> np.ndarray: - """Return a copy of the data.""" - return self._data.copy() + """Return the data.""" + return self._data @property def period(self) -> float: - """Return a copy of the estimated stimulation period.""" + """Return the estimated stimulation period.""" if self._period is None: raise AttributeError("No period has been computed yet.") - return deepcopy(self._period) + return self._period @property def filter(self) -> np.ndarray: - """Return a copy of the PARRM filter.""" + """Return the PARRM filter.""" if self._filter is None: raise AttributeError("No filter has been computed yet.") - return self._filter.copy() + return self._filter @property def filtered_data(self) -> np.ndarray: - """Return a copy of the most recently filtered data.""" + """Return the most recently filtered data.""" if self._filtered_data is None: raise AttributeError("No data has been filtered yet.") - return deepcopy(self._filtered_data) + return self._filtered_data @property def settings(self) -> dict: