Skip to content

Commit

Permalink
Merge pull request #222 from MPoL-dev/dartboard_1kfold
Browse files Browse the repository at this point in the history
Dartboard: handle case for 1 k-fold
  • Loading branch information
jeffjennings authored Dec 6, 2023
2 parents c86d8f9 + 2631c31 commit a8fdb34
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/mpol/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def run_crossval(self, dataset):
self._diagnostics["models"].append(self._model)
self._diagnostics["regularizers"].append(self._regularizers)
self._diagnostics["loss_histories"].append(loss_history)
self._diagnostics["train_figures"].append(self._train_figure)
# self._diagnostics["train_figures"].append(self._train_figure)

# average individual test scores to get the cross-val metric for chosen
# hyperparameters
Expand Down Expand Up @@ -297,7 +297,7 @@ def split_figure(self):

@property
def diagnostics(self):
"""Dict containing diagnostics of the cross-validation loop across all kfolds: models, regularizers, loss values, training figures"""
"""Dict containing diagnostics of the cross-validation loop across all kfolds: models, regularizers, loss values"""
return self._diagnostics


Expand Down Expand Up @@ -443,6 +443,7 @@ def __init__(
k: int,
dartboard: Dartboard | None = None,
seed: int | None = None,
verbose: bool = True
):
if k <= 0:
raise ValueError("k must be a positive integer")
Expand All @@ -453,6 +454,7 @@ def __init__(
self.griddedDataset = gridded_dataset
self.k = k
self.dartboard = dartboard
self.verbose = verbose

# 2D mask for any UV cells that contain visibilities
# in *any* channel
Expand Down Expand Up @@ -490,6 +492,7 @@ def from_dartboard_properties(
q_edges: NDArray[floating[Any]],
phi_edges: NDArray[floating[Any]],
seed: int | None = None,
verbose: bool = True,
) -> DartboardSplitGridded:
r"""
Alternative method to initialize a DartboardSplitGridded object from Dartboard parameters.
Expand All @@ -500,9 +503,10 @@ def from_dartboard_properties(
q_edges (1D numpy array): an array of radial bin edges to set the dartboard cells in :math:`[\mathrm{k}\lambda]`. If ``None``, defaults to 12 log-linearly radial bins stretching from 0 to the :math:`q_\mathrm{max}` represented by ``coords``.
phi_edges (1D numpy array): an array of azimuthal bin edges to set the dartboard cells in [radians]. If ``None``, defaults to 8 equal-spaced azimuthal bins stretched from :math:`0` to :math:`\pi`.
seed (int): (optional) numpy random seed to use for the permutation, for reproducibility
verbose (bool): whether to print notification messages
"""
dartboard = Dartboard(gridded_dataset.coords, q_edges, phi_edges)
return cls(gridded_dataset, k, dartboard, seed)
return cls(gridded_dataset, k, dartboard, seed, verbose)

def __iter__(self) -> DartboardSplitGridded:
self.n = 0 # the current k-slice we're on
Expand All @@ -511,7 +515,16 @@ def __iter__(self) -> DartboardSplitGridded:
def __next__(self) -> tuple[GriddedDataset, GriddedDataset]:
if self.n < self.k:
k_list = self.k_split_cell_list.copy()
cell_list_test = k_list.pop(self.n)
if self.k == 1:
if self.verbose is True:
logging.info(" DartboardSplitGridded: only 1 k-fold: splitting dataset as ~80/20 train/test")
ntest = round(0.2 * len(k_list[0]))
# put ~20% of cells into test set
cell_list_test = k_list[0][:ntest]
# remove cells in test set from train set
k_list[0] = np.delete(k_list[0], range(ntest), axis=0)
else:
cell_list_test = k_list.pop(self.n)

# put the remaining indices back into a full list
cell_list_train = np.concatenate(k_list)
Expand Down
21 changes: 21 additions & 0 deletions test/crossval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,27 @@ def test_crossvalclass_split_dartboard(coords, imager, dataset, generic_paramete
cross_validator.split_dataset(dataset)


def test_crossvalclass_split_dartboard_1kfold(coords, imager, dataset, generic_parameters):
    """Check the 'dartboard' splitter's behaviour when only 1 k-fold is requested.

    Using the CrossValidate class, split a dataset into train/test subsets with
    the 'dartboard' splitter and ``kfolds = 1``. In that case the splitter is
    expected to fall back to a ~80/20 train/test split, so the train subset
    should hold ~80% of the indexed visibilities.
    """
    crossval_pars = generic_parameters["crossval_pars"]
    crossval_pars["split_method"] = "dartboard"
    crossval_pars["kfolds"] = 1

    cross_validator = CrossValidate(coords, imager, **crossval_pars)
    split_iterator = cross_validator.split_dataset(dataset)

    # count the splits so that an empty iterator fails with a clear message
    # rather than a NameError on never-assigned loop variables
    n_splits = 0
    for train_set, test_set in split_iterator:
        n_splits += 1
        ntrain = len(train_set.vis_indexed)
        ntest = len(test_set.vis_indexed)

        # check every split that is produced, not only the last one
        ratio = ntrain / (ntrain + ntest)
        np.testing.assert_allclose(ratio, 0.8, atol=0.05)

    assert n_splits > 0, "dartboard splitter produced no train/test splits"


def test_crossvalclass_split_randomcell(coords, imager, dataset, generic_parameters):
# using the CrossValidate class, split a dataset into train/test subsets
# using 'random_cell' splitter
Expand Down

0 comments on commit a8fdb34

Please sign in to comment.