Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MAINT] Reduce copying of params #10

Merged
merged 2 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/pyparrm/_utils/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Author(s):
# Thomas Samuel Binns | github.com/tsbinns

from copy import deepcopy
from multiprocessing import cpu_count

from matplotlib import pyplot as plt
Expand Down Expand Up @@ -112,7 +111,7 @@ def _check_sort_init_inputs(
"PyPARRM Internal Error: `_ParamSelection` should only be called if the "
"period has been estimated. Please contact the PyPARRM developers."
)
self.parrm = deepcopy(parrm)
self.parrm = parrm
self.parrm._verbose = False

# time_range
Expand Down Expand Up @@ -163,14 +162,14 @@ def _check_sort_init_inputs(
)
if freq_range[0] >= freq_range[1]:
raise ValueError("`freq_range[1]` must be > `freq_range[0]`.")
self.freq_range = deepcopy(freq_range)
self.freq_range = freq_range

# freq_res
if not isinstance(freq_res, (int, float)):
raise TypeError("`freq_res` must be an int or a float.")
if freq_res <= 0 or freq_res > self.parrm._sampling_freq / 2:
raise ValueError("`freq_res` must lie in the range (0, Nyquist frequency].")
self.freq_res = deepcopy(freq_res)
self.freq_res = freq_res

# n_jobs
if not isinstance(n_jobs, int):
Expand All @@ -181,7 +180,7 @@ def _check_sort_init_inputs(
raise ValueError("If `n_jobs` is <= 0, it must be -1.")
if n_jobs == -1:
n_jobs = cpu_count()
self.n_jobs = deepcopy(n_jobs)
self.n_jobs = n_jobs

self.parrm._check_sort_create_filter_inputs(None, 0, "both", None)
self.current_period_half_width = self.parrm._period_half_width
Expand Down
47 changes: 23 additions & 24 deletions src/pyparrm/parrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Author(s):
# Thomas Samuel Binns | github.com/tsbinns

from copy import deepcopy
from multiprocessing import cpu_count

import numpy as np
Expand Down Expand Up @@ -123,23 +122,23 @@ def _check_init_inputs(
raise TypeError("`data` must be a NumPy array.")
if data.ndim != 2:
raise ValueError("`data` must be a 2D array.")
self._data = data.copy()
self._data = data

if not isinstance(sampling_freq, (int, float)):
raise TypeError("`sampling_freq` must be an int or a float.")
if sampling_freq <= 0:
raise ValueError("`sampling_freq` must be > 0.")
self._sampling_freq = deepcopy(sampling_freq)
self._sampling_freq = sampling_freq

if not isinstance(artefact_freq, (int, float)):
raise TypeError("`artefact_freq` must be an int or a float.")
if artefact_freq <= 0:
raise ValueError("`artefact_freq` must be > 0.")
self._artefact_freq = deepcopy(artefact_freq)
self._artefact_freq = artefact_freq

if not isinstance(verbose, bool):
raise TypeError("`verbose` must be a bool.")
self._verbose = deepcopy(verbose)
self._verbose = verbose

def __repr__(self) -> str: # noqa D107
return (
Expand Down Expand Up @@ -232,7 +231,7 @@ def _check_sort_find_stim_period_inputs(
raise ValueError(
"Entries of `search_samples` must lie in the range [0, n_samples)."
)
self._search_samples = search_samples.copy()
self._search_samples = search_samples

if assumed_periods is not None and not isinstance(
assumed_periods, (int, float, tuple)
Expand All @@ -248,18 +247,18 @@ def _check_sort_find_stim_period_inputs(
raise TypeError(
"If a tuple, entries of `assumed_periods` must be ints or floats."
)
self._assumed_periods = deepcopy(assumed_periods)
self._assumed_periods = assumed_periods

if not isinstance(outlier_boundary, (int, float)):
raise TypeError("`outlier_boundary` must be an int or a float.")
if outlier_boundary <= 0:
raise ValueError("`outlier_boundary` must be > 0.")
self._outlier_boundary = deepcopy(outlier_boundary)
self._outlier_boundary = outlier_boundary

if random_seed is not None and not isinstance(random_seed, int):
raise TypeError("`random_seed` must be an int or None.")
if random_seed is not None:
self._random_seed = deepcopy(random_seed)
self._random_seed = random_seed

if not isinstance(n_jobs, int):
raise TypeError("`n_jobs` must be an int.")
Expand All @@ -269,7 +268,7 @@ def _check_sort_find_stim_period_inputs(
raise ValueError("If `n_jobs` is <= 0, it must be -1.")
if n_jobs == -1:
n_jobs = cpu_count()
self._n_jobs = deepcopy(n_jobs)
self._n_jobs = n_jobs

def _standardise_data(self) -> None:
"""Take derivatives of data, set S.D. to 1, and clip outliers."""
Expand All @@ -285,7 +284,7 @@ def _optimise_period_estimate(self) -> None:
"""Optimise artefact period estimate."""
random_state = np.random.RandomState(self._random_seed)

estimated_period = deepcopy(self._assumed_periods)
estimated_period = self._assumed_periods

opt_sample_lens = np.unique(
[
Expand Down Expand Up @@ -752,7 +751,7 @@ def _check_sort_create_filter_inputs(
raise ValueError(
"`omit_n_samples` must lie in the range [0, (no. of samples - 1) // 2)."
)
self._omit_n_samples = deepcopy(omit_n_samples)
self._omit_n_samples = omit_n_samples

if period_half_width is None:
period_half_width = self._period / 50
Expand All @@ -762,7 +761,7 @@ def _check_sort_create_filter_inputs(
raise ValueError(
"`period_half_width` must be lie in the range (0, period]."
)
self._period_half_width = deepcopy(period_half_width)
self._period_half_width = period_half_width

# Must come after `omit_n_samples` and `period_half_width` set!
if filter_half_width is None:
Expand All @@ -776,7 +775,7 @@ def _check_sort_create_filter_inputs(
"`filter_half_width` must lie in the range (`omit_n_samples`, "
"(no. of samples - 1) // 2]."
)
self._filter_half_width = deepcopy(filter_half_width)
self._filter_half_width = filter_half_width

if not isinstance(filter_direction, str):
raise TypeError("`filter_direction` must be a str.")
Expand All @@ -785,11 +784,11 @@ def _check_sort_create_filter_inputs(
raise ValueError(
f"`filter_direction` must be one of {valid_filter_directions}."
)
self._filter_direction = deepcopy(filter_direction)
self._filter_direction = filter_direction

def _get_filter_half_width(self) -> int:
"""Get appropriate `filter_half_width`, if None given."""
filter_half_width = deepcopy(self._omit_n_samples)
filter_half_width = self._omit_n_samples
check = 0
while check < 50 and filter_half_width < (self._n_samples - 1) // 2:
filter_half_width += 1
Expand Down Expand Up @@ -890,29 +889,29 @@ def _check_sort_filter_data_inputs(self, data: np.ndarray | None) -> np.ndarray:

@property
def data(self) -> np.ndarray:
"""Return a copy of the data."""
return self._data.copy()
"""Return the data."""
return self._data

@property
def period(self) -> float:
"""Return a copy of the estimated stimulation period."""
"""Return the estimated stimulation period."""
if self._period is None:
raise AttributeError("No period has been computed yet.")
return deepcopy(self._period)
return self._period

@property
def filter(self) -> np.ndarray:
"""Return a copy of the PARRM filter."""
"""Return the PARRM filter."""
if self._filter is None:
raise AttributeError("No filter has been computed yet.")
return self._filter.copy()
return self._filter

@property
def filtered_data(self) -> np.ndarray:
"""Return a copy of the most recently filtered data."""
"""Return the most recently filtered data."""
if self._filtered_data is None:
raise AttributeError("No data has been filtered yet.")
return deepcopy(self._filtered_data)
return self._filtered_data

@property
def settings(self) -> dict:
Expand Down