From 1d7a6a4a7f3e1e5c1d1b8da3957aa83013a9964f Mon Sep 17 00:00:00 2001 From: timothyas Date: Mon, 30 Oct 2023 16:17:35 -0600 Subject: [PATCH] use slots in all classes --- xesn/cost.py | 3 +++ xesn/driver.py | 21 +++++++++++---------- xesn/esn.py | 12 ++++++------ xesn/lazyesn.py | 3 +++ xesn/matrix.py | 22 ++++++++-------------- xesn/xdata.py | 5 +++++ 6 files changed, 36 insertions(+), 30 deletions(-) diff --git a/xesn/cost.py b/xesn/cost.py index 1790518..9724dc9 100644 --- a/xesn/cost.py +++ b/xesn/cost.py @@ -19,6 +19,9 @@ from .psd import psd class CostFunction(): + + __slots__ = ("ESN", "train_data", "macro_data", "config") + def __init__(self, ESN, train_data, macro_data, config): self.ESN = ESN diff --git a/xesn/driver.py b/xesn/driver.py index a9ed323..9e56c8a 100644 --- a/xesn/driver.py +++ b/xesn/driver.py @@ -30,17 +30,18 @@ class Driver(): config (str or dict): either a path to a yaml file or dict containing experiment parameters output_directory (str, optional): directory to save results and write logs to """ - name = "driver" - config = None - output_directory = None - walltime = None - localtime = None + + __slots__ = ( + "config", "output_directory", + "walltime", "localtime", + "esn_name", "ESN", + "logfile", "logname", "logger", + ) def __init__(self, config, output_directory=None): - self._make_output_directory(output_directory) self._create_logger() self.set_config(config) @@ -277,7 +278,7 @@ def get_sample_indices(self, data_length, n_samples, n_steps, n_spinup, random_s def _make_output_directory(self, out_dir): """Make provided output directory. If none given, make a unique directory: - output-{self.name}-XX, where XX is 00->99 + output-driver-XX, where XX is 00->99 Args: out_dir (str or None): path to dump output, or None for default @@ -289,11 +290,11 @@ def _make_output_directory(self, out_dir): # make a unique default directory i=0 - out_dir = f"output-{self.name}-{i:02d}" + out_dir = f"output-driver-{i:02d}" while os.path.isdir(out_dir): if i>99: raise ValueError("Hit max number of default output directories...") - out_dir = f"output-{self.name}-{i:02d}" + out_dir = f"output-driver-{i:02d}" i = i+1 os.makedirs(out_dir) @@ -317,7 +318,7 @@ def _create_logger(self): self.logfile = os.path.join(self.output_directory, 'stdout.log') # create a logger - self.logname = f'{self.name}_logger' + self.logname = f'driver_logger' self.logger = logging.getLogger(self.logname) self.logger.setLevel(logging.DEBUG) diff --git a/xesn/esn.py b/xesn/esn.py index a310057..a171ec6 100644 --- a/xesn/esn.py +++ b/xesn/esn.py @@ -18,12 +18,12 @@ class ESN(): """A classic ESN architecture, with no distribution or parallelism. It is assumed that all data used with this architecture can fit into memory. """ - W = None - Win = None - Wout = None - input_kwargs = None - adjacency_kwargs = None - bias_kwargs = None + __slots__ = ( + "W", "Win", "Wout", + "n_input", "n_output", "n_reservoir", + "leak_rate", "tikhonov_parameter", "bias_vector", + "input_kwargs", "adjacency_kwargs", "bias_kwargs", + ) @property def input_factor(self): diff --git a/xesn/lazyesn.py b/xesn/lazyesn.py index 2ae72a1..32247ff 100644 --- a/xesn/lazyesn.py +++ b/xesn/lazyesn.py @@ -23,6 +23,9 @@ class LazyESN(ESN): 2. Non-global axes, i.e., axes which is chunked up or made up of patches, are first 3. Can handle multi-dimensional data, but only 2D chunking """ + __slots__ = ( + "esn_chunks", "overlap", "persist", "boundary" + ) @property def output_chunks(self): diff --git a/xesn/matrix.py b/xesn/matrix.py index c522f63..6f008c2 100644 --- a/xesn/matrix.py +++ b/xesn/matrix.py @@ -71,18 +71,11 @@ class RandomMatrix(): random_seed (int, optional): used to control the RNG for matrix generation """ - # Set by user - n_rows = None - n_cols = None - - distribution = None - normalization = "multiply" - factor = 1.0 - - random_seed = None - - # Set automatically - dist_kw = None + __slots__ = ( + "n_rows", "n_cols", + "distribution", "normalization", "factor", + "random_seed", "random_state", + ) def __init__( self, @@ -225,8 +218,9 @@ class SparseRandomMatrix(RandomMatrix): random_seed (int, optional): used to control the RNG for matrix generation """ - density = None - format = "coo" # scipy's default + __slots__ = ( + "density", "format", + ) def __init__( self, diff --git a/xesn/xdata.py b/xesn/xdata.py index 317386d..1b17b4b 100644 --- a/xesn/xdata.py +++ b/xesn/xdata.py @@ -8,6 +8,11 @@ class XData(): See :meth:`setup` for the main usage. """ + __slots__ = ( + "field_name", "zstore_path", + "dimensions", "subsampling", "normalization", + ) + def __init__(self, field_name, zstore_path,