diff --git a/src/mpol/crossval.py b/src/mpol/crossval.py index 9847d98d..6a320e4c 100644 --- a/src/mpol/crossval.py +++ b/src/mpol/crossval.py @@ -517,10 +517,10 @@ def __next__(self) -> tuple[GriddedDataset, GriddedDataset]: if self.k == 1: if self.verbose is True: logging.info(" DartboardSplitGridded: only 1 k-fold: splitting dataset as ~80/20 train/test") - # number of cells ~equal to 20% of full cell list ntest = round(0.2 * len(k_list[0])) + # put ~20% of cells into test set cell_list_test = k_list[0][:ntest] - # remove cells place in test set from train set + # remove cells in test set from train set k_list[0] = np.delete(k_list[0], range(ntest), axis=0) else: cell_list_test = k_list.pop(self.n)