Skip to content

Commit

Permalink
Slots (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
timothyas authored Oct 31, 2023
1 parent 4aec638 commit 4db35a9
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 30 deletions.
3 changes: 3 additions & 0 deletions xesn/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from .psd import psd

class CostFunction():

__slots__ = ("ESN", "train_data", "macro_data", "config")

def __init__(self, ESN, train_data, macro_data, config):

self.ESN = ESN
Expand Down
21 changes: 11 additions & 10 deletions xesn/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@ class Driver():
config (str or dict): either a path to a yaml file or dict containing experiment parameters
output_directory (str, optional): directory to save results and write logs to
"""
name = "driver"
config = None
output_directory = None
walltime = None
localtime = None

__slots__ = (
"config", "output_directory",
"walltime", "localtime",
"esn_name", "ESN",
"logfile", "logname", "logger",
)

def __init__(self,
config,
output_directory=None):


self._make_output_directory(output_directory)
self._create_logger()
self.set_config(config)
Expand Down Expand Up @@ -277,7 +278,7 @@ def get_sample_indices(self, data_length, n_samples, n_steps, n_spinup, random_s

def _make_output_directory(self, out_dir):
"""Make provided output directory. If none given, make a unique directory:
output-{self.name}-XX, where XX is 00->99
output-driver-XX, where XX is 00->99
Args:
out_dir (str or None): path to dump output, or None for default
Expand All @@ -289,11 +290,11 @@ def _make_output_directory(self, out_dir):

# make a unique default directory
i=0
out_dir = f"output-{self.name}-{i:02d}"
out_dir = f"output-driver-{i:02d}"
while os.path.isdir(out_dir):
if i>99:
raise ValueError("Hit max number of default output directories...")
out_dir = f"output-{self.name}-{i:02d}"
out_dir = f"output-driver-{i:02d}"
i = i+1
os.makedirs(out_dir)

Expand All @@ -317,7 +318,7 @@ def _create_logger(self):
self.logfile = os.path.join(self.output_directory, 'stdout.log')

# create a logger
self.logname = f'{self.name}_logger'
self.logname = f'driver_logger'
self.logger = logging.getLogger(self.logname)
self.logger.setLevel(logging.DEBUG)

Expand Down
12 changes: 6 additions & 6 deletions xesn/esn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ class ESN():
"""A classic ESN architecture, with no distribution or parallelism.
It is assumed that all data used with this architecture can fit into memory.
"""
W = None
Win = None
Wout = None
input_kwargs = None
adjacency_kwargs = None
bias_kwargs = None
__slots__ = (
"W", "Win", "Wout",
"n_input", "n_output", "n_reservoir",
"leak_rate", "tikhonov_parameter", "bias_vector",
"input_kwargs", "adjacency_kwargs", "bias_kwargs",
)

@property
def input_factor(self):
Expand Down
3 changes: 3 additions & 0 deletions xesn/lazyesn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class LazyESN(ESN):
2. Non-global axes, i.e., axes which is chunked up or made up of patches, are first
3. Can handle multi-dimensional data, but only 2D chunking
"""
__slots__ = (
"esn_chunks", "overlap", "persist", "boundary"
)

@property
def output_chunks(self):
Expand Down
22 changes: 8 additions & 14 deletions xesn/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,11 @@ class RandomMatrix():
random_seed (int, optional): used to control the RNG for matrix generation
"""

# Set by user
n_rows = None
n_cols = None

distribution = None
normalization = "multiply"
factor = 1.0

random_seed = None

# Set automatically
dist_kw = None
__slots__ = (
"n_rows", "n_cols",
"distribution", "normalization", "factor",
"random_seed", "random_state",
)

def __init__(
self,
Expand Down Expand Up @@ -225,8 +218,9 @@ class SparseRandomMatrix(RandomMatrix):
random_seed (int, optional): used to control the RNG for matrix generation
"""

density = None
format = "coo" # scipy's default
__slots__ = (
"density", "format",
)

def __init__(
self,
Expand Down
5 changes: 5 additions & 0 deletions xesn/xdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ class XData():
See :meth:`setup` for the main usage.
"""

__slots__ = (
"field_name", "zstore_path",
"dimensions", "subsampling", "normalization",
)

def __init__(self,
field_name,
zstore_path,
Expand Down

0 comments on commit 4db35a9

Please sign in to comment.