Skip to content

Commit

Permalink
[MAINT] Reduce copying of params (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns authored Sep 17, 2024
1 parent 060ccab commit 04a304c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 29 deletions.
9 changes: 4 additions & 5 deletions src/pyparrm/_utils/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Author(s):
# Thomas Samuel Binns | github.com/tsbinns

from copy import deepcopy
from multiprocessing import cpu_count

from matplotlib import pyplot as plt
Expand Down Expand Up @@ -112,7 +111,7 @@ def _check_sort_init_inputs(
"PyPARRM Internal Error: `_ParamSelection` should only be called if the "
"period has been estimated. Please contact the PyPARRM developers."
)
self.parrm = deepcopy(parrm)
self.parrm = parrm
self.parrm._verbose = False

# time_range
Expand Down Expand Up @@ -163,14 +162,14 @@ def _check_sort_init_inputs(
)
if freq_range[0] >= freq_range[1]:
raise ValueError("`freq_range[1]` must be > `freq_range[0]`.")
self.freq_range = deepcopy(freq_range)
self.freq_range = freq_range

# freq_res
if not isinstance(freq_res, (int, float)):
raise TypeError("`freq_res` must be an int or a float.")
if freq_res <= 0 or freq_res > self.parrm._sampling_freq / 2:
raise ValueError("`freq_res` must lie in the range (0, Nyquist frequency].")
self.freq_res = deepcopy(freq_res)
self.freq_res = freq_res

# n_jobs
if not isinstance(n_jobs, int):
Expand All @@ -181,7 +180,7 @@ def _check_sort_init_inputs(
raise ValueError("If `n_jobs` is <= 0, it must be -1.")
if n_jobs == -1:
n_jobs = cpu_count()
self.n_jobs = deepcopy(n_jobs)
self.n_jobs = n_jobs

self.parrm._check_sort_create_filter_inputs(None, 0, "both", None)
self.current_period_half_width = self.parrm._period_half_width
Expand Down
47 changes: 23 additions & 24 deletions src/pyparrm/parrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Author(s):
# Thomas Samuel Binns | github.com/tsbinns

from copy import deepcopy
from multiprocessing import cpu_count

import numpy as np
Expand Down Expand Up @@ -123,23 +122,23 @@ def _check_init_inputs(
raise TypeError("`data` must be a NumPy array.")
if data.ndim != 2:
raise ValueError("`data` must be a 2D array.")
self._data = data.copy()
self._data = data

if not isinstance(sampling_freq, (int, float)):
raise TypeError("`sampling_freq` must be an int or a float.")
if sampling_freq <= 0:
raise ValueError("`sampling_freq` must be > 0.")
self._sampling_freq = deepcopy(sampling_freq)
self._sampling_freq = sampling_freq

if not isinstance(artefact_freq, (int, float)):
raise TypeError("`artefact_freq` must be an int or a float.")
if artefact_freq <= 0:
raise ValueError("`artefact_freq` must be > 0.")
self._artefact_freq = deepcopy(artefact_freq)
self._artefact_freq = artefact_freq

if not isinstance(verbose, bool):
raise TypeError("`verbose` must be a bool.")
self._verbose = deepcopy(verbose)
self._verbose = verbose

def __repr__(self) -> str: # noqa D107
return (
Expand Down Expand Up @@ -232,7 +231,7 @@ def _check_sort_find_stim_period_inputs(
raise ValueError(
"Entries of `search_samples` must lie in the range [0, n_samples)."
)
self._search_samples = search_samples.copy()
self._search_samples = search_samples

if assumed_periods is not None and not isinstance(
assumed_periods, (int, float, tuple)
Expand All @@ -248,18 +247,18 @@ def _check_sort_find_stim_period_inputs(
raise TypeError(
"If a tuple, entries of `assumed_periods` must be ints or floats."
)
self._assumed_periods = deepcopy(assumed_periods)
self._assumed_periods = assumed_periods

if not isinstance(outlier_boundary, (int, float)):
raise TypeError("`outlier_boundary` must be an int or a float.")
if outlier_boundary <= 0:
raise ValueError("`outlier_boundary` must be > 0.")
self._outlier_boundary = deepcopy(outlier_boundary)
self._outlier_boundary = outlier_boundary

if random_seed is not None and not isinstance(random_seed, int):
raise TypeError("`random_seed` must be an int or None.")
if random_seed is not None:
self._random_seed = deepcopy(random_seed)
self._random_seed = random_seed

if not isinstance(n_jobs, int):
raise TypeError("`n_jobs` must be an int.")
Expand All @@ -269,7 +268,7 @@ def _check_sort_find_stim_period_inputs(
raise ValueError("If `n_jobs` is <= 0, it must be -1.")
if n_jobs == -1:
n_jobs = cpu_count()
self._n_jobs = deepcopy(n_jobs)
self._n_jobs = n_jobs

def _standardise_data(self) -> None:
"""Take derivatives of data, set S.D. to 1, and clip outliers."""
Expand All @@ -285,7 +284,7 @@ def _optimise_period_estimate(self) -> None:
"""Optimise artefact period estimate."""
random_state = np.random.RandomState(self._random_seed)

estimated_period = deepcopy(self._assumed_periods)
estimated_period = self._assumed_periods

opt_sample_lens = np.unique(
[
Expand Down Expand Up @@ -752,7 +751,7 @@ def _check_sort_create_filter_inputs(
raise ValueError(
"`omit_n_samples` must lie in the range [0, (no. of samples - 1) // 2)."
)
self._omit_n_samples = deepcopy(omit_n_samples)
self._omit_n_samples = omit_n_samples

if period_half_width is None:
period_half_width = self._period / 50
Expand All @@ -762,7 +761,7 @@ def _check_sort_create_filter_inputs(
raise ValueError(
"`period_half_width` must be lie in the range (0, period]."
)
self._period_half_width = deepcopy(period_half_width)
self._period_half_width = period_half_width

# Must come after `omit_n_samples` and `period_half_width` set!
if filter_half_width is None:
Expand All @@ -776,7 +775,7 @@ def _check_sort_create_filter_inputs(
"`filter_half_width` must lie in the range (`omit_n_samples`, "
"(no. of samples - 1) // 2]."
)
self._filter_half_width = deepcopy(filter_half_width)
self._filter_half_width = filter_half_width

if not isinstance(filter_direction, str):
raise TypeError("`filter_direction` must be a str.")
Expand All @@ -785,11 +784,11 @@ def _check_sort_create_filter_inputs(
raise ValueError(
f"`filter_direction` must be one of {valid_filter_directions}."
)
self._filter_direction = deepcopy(filter_direction)
self._filter_direction = filter_direction

def _get_filter_half_width(self) -> int:
"""Get appropriate `filter_half_width`, if None given."""
filter_half_width = deepcopy(self._omit_n_samples)
filter_half_width = self._omit_n_samples
check = 0
while check < 50 and filter_half_width < (self._n_samples - 1) // 2:
filter_half_width += 1
Expand Down Expand Up @@ -890,29 +889,29 @@ def _check_sort_filter_data_inputs(self, data: np.ndarray | None) -> np.ndarray:

@property
def data(self) -> np.ndarray:
"""Return a copy of the data."""
return self._data.copy()
"""Return the data."""
return self._data

@property
def period(self) -> float:
"""Return a copy of the estimated stimulation period."""
"""Return the estimated stimulation period."""
if self._period is None:
raise AttributeError("No period has been computed yet.")
return deepcopy(self._period)
return self._period

@property
def filter(self) -> np.ndarray:
"""Return a copy of the PARRM filter."""
"""Return the PARRM filter."""
if self._filter is None:
raise AttributeError("No filter has been computed yet.")
return self._filter.copy()
return self._filter

@property
def filtered_data(self) -> np.ndarray:
"""Return a copy of the most recently filtered data."""
"""Return the most recently filtered data."""
if self._filtered_data is None:
raise AttributeError("No data has been filtered yet.")
return deepcopy(self._filtered_data)
return self._filtered_data

@property
def settings(self) -> dict:
Expand Down

0 comments on commit 04a304c

Please sign in to comment.