Skip to content

Commit

Permalink
standardize split class args
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffjennings committed Feb 21, 2023
1 parent 2b60981 commit b5f3436
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/ci-tutorials/crossvalidation.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ dartboard = datasets.Dartboard(coords=coords)
# create cross validator using this "dartboard"
k = 5
cv = crossval.DartboardSplitGridded(dset, k, dartboard=dartboard, npseed=42)
cv = crossval.DartboardSplitGridded(dset, k, dartboard=dartboard, seed=42)
# ``cv`` is a Python iterator, it will return a ``(train, test)`` pair of ``GriddedDataset``s for each iteration.
# Because we'll want to revisit the individual datasets
Expand Down
8 changes: 4 additions & 4 deletions docs/large-tutorials/HD143006_part_2.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/large-tutorials/HD143006_part_2.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ dartboard = datasets.Dartboard(coords=coords)
# create cross validator using this "dartboard"
k = 5
cv = crossval.DartboardSplitGridded(dataset, k, dartboard=dartboard, npseed=42)
cv = crossval.DartboardSplitGridded(dataset, k, dartboard=dartboard, seed=42)
# ``cv`` is a Python iterator, it will return a ``(train, test)`` pair of ``GriddedDataset``s for each iteration.
# Because we'll want to revisit the individual datasets
Expand Down
2 changes: 1 addition & 1 deletion examples/HD143006/common_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

# create cross validator using this "dartboard"
k = 5
cv = crossval.DartboardSplitGridded(dataset, k, dartboard=dartboard, npseed=42)
cv = crossval.DartboardSplitGridded(dataset, k, dartboard=dartboard, seed=42)
k_fold_datasets = [(train, test) for (train, test) in cv]

# create the model
Expand Down
34 changes: 17 additions & 17 deletions src/mpol/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def split_dataset(self, dataset):
"""
if self._split_method == "random_cell":
split_iterator = RandomCellSplitGridded(
dataset=dataset, kfolds=self._kfolds, seed=self._seed
dataset=dataset, k=self._kfolds, seed=self._seed
)

elif self._split_method == "dartboard":
Expand All @@ -151,7 +151,7 @@ def split_dataset(self, dataset):

# use 'dartboard' to split full dataset into train/test subsets
split_iterator = DartboardSplitGridded(
dataset, k=self._kfolds, dartboard=dartboard, npseed=self._seed
dataset, k=self._kfolds, dartboard=dartboard, seed=self._seed
)

else:
Expand Down Expand Up @@ -251,7 +251,7 @@ class RandomCellSplitGridded:
----------
dataset : PyTorch dataset object
Instance of the `mpol.datasets.GriddedDataset` class
kfolds : int, default=5
k : int, default=5
Number of k-folds (partitions) of `dataset`
seed : int, default=None
Seed for PyTorch random number generator used to shuffle data before
Expand All @@ -260,9 +260,9 @@ class RandomCellSplitGridded:
Channel of the dataset to use in determining the splits
Once initialized, iterate through the datasets like:
>>> split_iterator = crossval.RandomCellSplitGridded(dataset, kfolds)
>>> for (train, test) in split_iterator: # iterate through `kfolds` datasets
>>> ... # working with the n-th slice of `kfolds` datasets
>>> split_iterator = crossval.RandomCellSplitGridded(dataset, k)
>>> for (train, test) in split_iterator: # iterate through `k` datasets
>>> ... # working with the n-th slice of `k` datasets
>>> ... # do operations with train dataset
>>> ... # do operations with test dataset
Expand All @@ -273,9 +273,9 @@ class RandomCellSplitGridded:
The splitting doesn't select (preserve) Hermitian pairs of visibilities.
"""

def __init__(self, dataset, kfolds=5, seed=None, channel=0):
def __init__(self, dataset, k=5, seed=None, channel=0):
self.dataset = dataset
self.kfolds = kfolds
self.k = k
self.channel = channel

# get indices for cells in the top 1% of gridded weight
Expand Down Expand Up @@ -306,15 +306,15 @@ def __init__(self, dataset, kfolds=5, seed=None, channel=0):
split_idx = split_idx[:, shuffle]

# split indices into k subsets
self.splits = torch.tensor_split(split_idx, self.kfolds, dim=1)
self.splits = torch.tensor_split(split_idx, self.k, dim=1)

def __iter__(self):
# current k-slice
self._n = 0
return self

def __next__(self):
if self._n < self.kfolds:
if self._n < self.k:
test_idx = self.splits[self._n]
train_idx = torch.cat(
([self.splits[x] for x in range(len(self.splits)) if x != self._n]),
Expand Down Expand Up @@ -358,7 +358,7 @@ class DartboardSplitGridded:
dartboard (:class:`~mpol.datasets.Dartboard`): a pre-initialized Dartboard instance. If ``dartboard`` is provided, do not provide ``q_edges`` or ``phi_edges``.
q_edges (1D numpy array): an array of radial bin edges to set the dartboard cells in :math:`[\mathrm{k}\lambda]`. If ``None``, defaults to 12 log-linearly radial bins stretching from 0 to the :math:`q_\mathrm{max}` represented by ``coords``.
phi_edges (1D numpy array): an array of azimuthal bin edges to set the dartboard cells in [radians]. If ``None``, defaults to 8 equal-spaced azimuthal bins stretched from :math:`0` to :math:`\pi`.
npseed (int): (optional) numpy random seed to use for the permutation, for reproducibility
seed (int): (optional) numpy random seed to use for the permutation, for reproducibility
Once initialized, iterate through the datasets like
Expand All @@ -375,7 +375,7 @@ def __init__(
gridded_dataset: GriddedDataset,
k: int,
dartboard: Dartboard | None = None,
npseed: int | None = None,
seed: int | None = None,
):
if k <= 0:
raise ValueError("k must be a positive integer")
Expand All @@ -401,8 +401,8 @@ def __init__(
# partition the cell_list into k pieces
# first, randomly permute the sequence to make sure
# we don't get structured radial/azimuthal patterns
if npseed is not None:
np.random.seed(npseed)
if seed is not None:
np.random.seed(seed)

self.k_split_cell_list = np.array_split(
np.random.permutation(self.cell_list), k
Expand All @@ -415,7 +415,7 @@ def from_dartboard_properties(
k: int,
q_edges: NDArray[floating[Any]],
phi_edges: NDArray[floating[Any]],
npseed: int | None = None,
seed: int | None = None,
) -> DartboardSplitGridded:
"""
Alternative method to initialize a DartboardSplitGridded object from Dartboard parameters.
Expand All @@ -425,10 +425,10 @@ def from_dartboard_properties(
k (int): the number of subpartitions of the dataset
q_edges (1D numpy array): an array of radial bin edges to set the dartboard cells in :math:`[\mathrm{k}\lambda]`. If ``None``, defaults to 12 log-linearly radial bins stretching from 0 to the :math:`q_\mathrm{max}` represented by ``coords``.
phi_edges (1D numpy array): an array of azimuthal bin edges to set the dartboard cells in [radians]. If ``None``, defaults to 8 equal-spaced azimuthal bins stretched from :math:`0` to :math:`\pi`.
npseed (int): (optional) numpy random seed to use for the permutation, for reproducibility
seed (int): (optional) numpy random seed to use for the permutation, for reproducibility
"""
dartboard = Dartboard(gridded_dataset.coords, q_edges, phi_edges)
return cls(gridded_dataset, k, dartboard, npseed)
return cls(gridded_dataset, k, dartboard, seed)

def __iter__(self) -> DartboardSplitGridded:
self.n = 0 # the current k-slice we're on
Expand Down
2 changes: 1 addition & 1 deletion src/mpol/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def splitter_diagnostics_fig(splitter, channel=0, save_prefix=None):
No assumption or correction is made concerning whether the (u,v) distances
are projected or deprojected.
"""
fig, axes = plt.subplots(nrows=splitter.kfolds, ncols=2, figsize=(4, 10))
fig, axes = plt.subplots(nrows=splitter.k, ncols=2, figsize=(4, 10))

for ii, (train, test) in enumerate(splitter):
train_mask = torch2npy(train.ground_mask[channel])
Expand Down

0 comments on commit b5f3436

Please sign in to comment.