Skip to content

Commit

Permalink
DartboardSplitGridded: verbose arg, message
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffjennings committed Dec 4, 2023
1 parent 19ec677 commit 000e435
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/mpol/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ def __init__(
k: int,
dartboard: Dartboard | None = None,
seed: int | None = None,
verbose: bool = True
):
if k <= 0:
raise ValueError("k must be a positive integer")
Expand All @@ -452,6 +453,7 @@ def __init__(
self.griddedDataset = gridded_dataset
self.k = k
self.dartboard = dartboard
self.verbose = verbose

# 2D mask for any UV cells that contain visibilities
# in *any* channel
Expand Down Expand Up @@ -489,6 +491,7 @@ def from_dartboard_properties(
q_edges: NDArray[floating[Any]],
phi_edges: NDArray[floating[Any]],
seed: int | None = None,
verbose: bool = True,
) -> DartboardSplitGridded:
r"""
Alternative method to initialize a DartboardSplitGridded object from Dartboard parameters.
Expand All @@ -499,9 +502,10 @@ def from_dartboard_properties(
q_edges (1D numpy array): an array of radial bin edges to set the dartboard cells in :math:`[\mathrm{k}\lambda]`. If ``None``, defaults to 12 log-linearly radial bins stretching from 0 to the :math:`q_\mathrm{max}` represented by ``coords``.
phi_edges (1D numpy array): an array of azimuthal bin edges to set the dartboard cells in [radians]. If ``None``, defaults to 8 equal-spaced azimuthal bins stretched from :math:`0` to :math:`\pi`.
seed (int): (optional) numpy random seed to use for the permutation, for reproducibility
verbose (bool): whether to print notification messages
"""
dartboard = Dartboard(gridded_dataset.coords, q_edges, phi_edges)
return cls(gridded_dataset, k, dartboard, seed)
return cls(gridded_dataset, k, dartboard, seed, verbose)

def __iter__(self) -> DartboardSplitGridded:
self.n = 0 # the current k-slice we're on
Expand All @@ -511,6 +515,8 @@ def __next__(self) -> tuple[GriddedDataset, GriddedDataset]:
if self.n < self.k:
k_list = self.k_split_cell_list.copy()
if self.k == 1:
if self.verbose is True:
logging.info(" DartboardSplitGridded: only 1 k-fold: splitting dataset as ~80/20 train/test")
# number of cells ~equal to 20% of full cell list
ntest = round(0.2 * len(k_list[0]))
cell_list_test = k_list[0][:ntest]
Expand Down

0 comments on commit 000e435

Please sign in to comment.