From 19ec677cde2b92529c278c1969c2b492c3a84429 Mon Sep 17 00:00:00 2001 From: Jeff Jennings Date: Mon, 4 Dec 2023 04:18:57 -0500 Subject: [PATCH 1/5] DartboardSplitGridded: 1 k-fold case --- src/mpol/crossval.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py index 7aa30cee..2a1e64f0 100644 --- a/src/mpol/crossval.py +++ b/src/mpol/crossval.py @@ -252,7 +252,7 @@ def run_crossval(self, dataset): self._diagnostics["models"].append(self._model) self._diagnostics["regularizers"].append(self._regularizers) self._diagnostics["loss_histories"].append(loss_history) - self._diagnostics["train_figures"].append(self._train_figure) + # self._diagnostics["train_figures"].append(self._train_figure) # average individual test scores to get the cross-val metric for chosen # hyperparameters @@ -296,7 +296,7 @@ def split_figure(self): @property def diagnostics(self): - """Dict containing diagnostics of the cross-validation loop across all kfolds: models, regularizers, loss values, training figures""" + """Dict containing diagnostics of the cross-validation loop across all kfolds: models, regularizers, loss values""" return self._diagnostics @@ -510,7 +510,14 @@ def __iter__(self) -> DartboardSplitGridded: def __next__(self) -> tuple[GriddedDataset, GriddedDataset]: if self.n < self.k: k_list = self.k_split_cell_list.copy() - cell_list_test = k_list.pop(self.n) + if self.k == 1: + # number of cells ~equal to 20% of full cell list + ntest = round(0.2 * len(k_list[0])) + cell_list_test = k_list[0][:ntest] + # remove cells place in test set from train set + k_list[0] = np.delete(k_list[0], range(ntest), axis=0) + else: + cell_list_test = k_list.pop(self.n) # put the remaining indices back into a full list cell_list_train = np.concatenate(k_list) From 000e435399ca43b69ecf35d5553bec79f4454261 Mon Sep 17 00:00:00 2001 From: Jeff Jennings Date: Mon, 4 Dec 2023 04:20:01 -0500 Subject: [PATCH 2/5] DartboardSplitGridded: verbose arg, message --- src/mpol/crossval.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py index 2a1e64f0..9847d98d 100644 --- a/src/mpol/crossval.py +++ b/src/mpol/crossval.py @@ -442,6 +442,7 @@ def __init__( k: int, dartboard: Dartboard | None = None, seed: int | None = None, + verbose: bool = True ): if k <= 0: raise ValueError("k must be a positive integer") @@ -452,6 +453,7 @@ def __init__( self.griddedDataset = gridded_dataset self.k = k self.dartboard = dartboard + self.verbose = verbose # 2D mask for any UV cells that contain visibilities # in *any* channel @@ -489,6 +491,7 @@ def from_dartboard_properties( q_edges: NDArray[floating[Any]], phi_edges: NDArray[floating[Any]], seed: int | None = None, + verbose: bool = True, ) -> DartboardSplitGridded: r""" Alternative method to initialize a DartboardSplitGridded object from Dartboard parameters. @@ -499,9 +502,10 @@ def from_dartboard_properties( q_edges (1D numpy array): an array of radial bin edges to set the dartboard cells in :math:`[\mathrm{k}\lambda]`. If ``None``, defaults to 12 log-linearly radial bins stretching from 0 to the :math:`q_\mathrm{max}` represented by ``coords``. phi_edges (1D numpy array): an array of azimuthal bin edges to set the dartboard cells in [radians]. If ``None``, defaults to 8 equal-spaced azimuthal bins stretched from :math:`0` to :math:`\pi`. seed (int): (optional) numpy random seed to use for the permutation, for reproducibility + verbose (bool): whether to print notification messages """ dartboard = Dartboard(gridded_dataset.coords, q_edges, phi_edges) - return cls(gridded_dataset, k, dartboard, seed) + return cls(gridded_dataset, k, dartboard, seed, verbose) def __iter__(self) -> DartboardSplitGridded: self.n = 0 # the current k-slice we're on @@ -511,6 +515,8 @@ def __next__(self) -> tuple[GriddedDataset, GriddedDataset]: if self.n < self.k: k_list = self.k_split_cell_list.copy() if self.k == 1: + if self.verbose is True: + logging.info(" DartboardSplitGridded: only 1 k-fold: splitting dataset as ~80/20 train/test") # number of cells ~equal to 20% of full cell list ntest = round(0.2 * len(k_list[0])) cell_list_test = k_list[0][:ntest] From 05473f09cc68f0f6429ae709ce0845e337dad48a Mon Sep 17 00:00:00 2001 From: Jeff Jennings Date: Mon, 4 Dec 2023 04:25:13 -0500 Subject: [PATCH 3/5] DartboardSplitGridded: comments --- src/mpol/crossval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py index 9847d98d..6a320e4c 100644 --- a/src/mpol/crossval.py +++ b/src/mpol/crossval.py @@ -517,10 +517,10 @@ def __next__(self) -> tuple[GriddedDataset, GriddedDataset]: if self.k == 1: if self.verbose is True: logging.info(" DartboardSplitGridded: only 1 k-fold: splitting dataset as ~80/20 train/test") - # number of cells ~equal to 20% of full cell list ntest = round(0.2 * len(k_list[0])) + # put ~20% of cells into test set cell_list_test = k_list[0][:ntest] - # remove cells place in test set from train set + # remove cells in test set from train set k_list[0] = np.delete(k_list[0], range(ntest), axis=0) else: cell_list_test = k_list.pop(self.n) From ef57810587239ce677df66231c480d84806a78f8 Mon Sep 17 00:00:00 2001 From: Jeff Jennings Date: Wed, 6 Dec 2023 12:53:33 -0500 Subject: [PATCH 4/5] crossval_test: add test for dartboard 1 k-fold --- test/crossval_test.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/crossval_test.py b/test/crossval_test.py index f83830ef..bcc90f68 100644 --- a/test/crossval_test.py +++ b/test/crossval_test.py @@ -21,6 +21,23 @@ def test_crossvalclass_split_dartboard(coords, imager, dataset, generic_paramete cross_validator = CrossValidate(coords, imager, **crossval_pars) cross_validator.split_dataset(dataset) +def test_crossvalclass_split_dartboard_1kfold(coords, imager, dataset, generic_parameters): + + crossval_pars = generic_parameters["crossval_pars"] + crossval_pars["split_method"] = "dartboard" + crossval_pars['kfolds'] = 1 + + cross_validator = CrossValidate(coords, imager, **crossval_pars) + split_iterator = cross_validator.split_dataset(dataset) + + for (train_set, test_set) in split_iterator: + ntrain = len(train_set.vis_indexed) + ntest = len(test_set.vis_indexed) + + ratio = ntrain / (ntrain + ntest) + + np.testing.assert_allclose(ratio, 0.8, atol=0.05) + def test_crossvalclass_split_randomcell(coords, imager, dataset, generic_parameters): # using the CrossValidate class, split a dataset into train/test subsets From 2631c31fd522a9ef301aab7ce159aa561ab3d7c7 Mon Sep 17 00:00:00 2001 From: Jeff Jennings Date: Wed, 6 Dec 2023 12:53:52 -0500 Subject: [PATCH 5/5] crossval_test: docstring --- test/crossval_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/crossval_test.py b/test/crossval_test.py index bcc90f68..6a1451bf 100644 --- a/test/crossval_test.py +++ b/test/crossval_test.py @@ -21,7 +21,11 @@ def test_crossvalclass_split_dartboard(coords, imager, dataset, generic_paramete cross_validator = CrossValidate(coords, imager, **crossval_pars) cross_validator.split_dataset(dataset) + def test_crossvalclass_split_dartboard_1kfold(coords, imager, dataset, generic_parameters): + # using the CrossValidate class, split a dataset into train/test subsets + # using 'dartboard' splitter with only 1 k-fold; check that the train set + # has ~80% of the model visibilities crossval_pars = generic_parameters["crossval_pars"] crossval_pars["split_method"] = "dartboard"