You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This feature was originally drafted by @hgrzy in PR #93. The codebase has evolved significantly since that PR was opened and our thinking about how to run cross validation loops has evolved slightly. I'm closing PR #93 and excerpting the relevant draft code below. The hope is that this new issue could serve as a starting point for future stratified K-fold implementations.
In the original PR, Hannah said,
"Perform k fold cross validation where each k fold has the same ratio of low : high spatial frequency data points as the overall dataset.
Currently working on the for loop. I'm skeptical with the number of rows in the np array "pairs." While iterating through the while loop in the init of a stratkfold object it stops producing any flags that I've incorporated though jupyter notebook says it is still running the cell."
This most likely ties into the GriddedDataset and crossval discussions in #163#166, too.
class StratKCV:
def __init__(self, gridder, griddedDataset, k, npseed=None):
self.griddedDataset = griddedDataset
self.gridder = gridder
self.cartesian_us = self.griddedDataset.coords.packed_u_centers_2D
self.cartesian_vs = self.griddedDataset.coords.packed_v_centers_2D
assert k > 0, "k must be a positive integer"
self.k = k
# 2d mask for any UV cells that contain visibilities
# in *any* channel
stacked_mask = np.any(self.griddedDataset.mask.detach().numpy(), axis=0)
# add
self.stacked_mask = stacked_mask
# get u's and v's from dataset amd turn into 1D lists
uu = self.gridder.uu
vv = self.gridder.vv
if npseed is not None:
np.random.seed(npseed)
# pairing u's and v's
pairs = np.vstack((uu, vv)).T
self.pairs = pairs
# splitting
l5000 = np.empty(2)
g5000 = np.empty(2)
pair_ind = 0
while pair_ind < len(pairs):
q = np.sqrt((pairs[pair_ind, 0]) ** 2 + (pairs[pair_ind, 1]) ** 2)
if q < 5000:
l5000 = np.vstack([l5000, pairs[pair_ind]])
if q > 5000:
g5000 = np.vstack([g5000, pairs[pair_ind]])
pair_ind += 1
print(pair_ind)
# Doing this to get split function to work - should change in the long run
# With the current dataset and criterion the l5000 array will be 3 too long to split evenly
print("G1")
l5000 = l5000[:-3]
g5000 = g5000[:-1]
self.l5000 = l5000
self.g5000 = g5000
print("G2")
numSamplesFold = len(pairs) / k
percentSmallSpatFreqFold = len(l5000) / len(pairs)
numl5000perFold = numSamplesFold * percentSmallSpatFreqFold
print("G3")
percentLargeSpatFreqFold = len(g5000) / len(pairs)
numg5000perFold = numSamplesFold * percentLargeSpatFreqFold
# randomize each list while maintaining [uu, vv] pairs
np.random.shuffle(l5000)
np.random.shuffle(g5000)
# get how many sections to partition l5000 into based on
# numl5000perFold are necessary
print("g4")
numPairsl5000 = len(l5000)
numSectionsl5000 = round(numPairsl5000 / numl5000perFold)
self.numPairsl5000 = numPairsl5000
self.numSectionsl5000 = numSectionsl5000
# same for g5000
print("G5")
numPairsg5000 = len(g5000)
numSectionsg5000 = round(numPairsg5000 / numg5000perFold)
self.numPairsg5000 = numPairsg5000
self.numSectionsg5000 = numSectionsg5000
print("g6")
# Partition low and high spat freq lists into groups
lowSpatGroups = np.vsplit(l5000, numSectionsl5000)
highSpatGroups = np.vsplit(g5000, numSectionsg5000)
self.lowSpatGroups = lowSpatGroups
self.highSpatGroups = highSpatGroups
print("g7")
# Create sets for kfolds
# np array 7 X 983 X 2
# 7 kfolds, 983 visibilities per fold, 2 coordinates [u, v]
self.k_split_cell_list = np.hstack([lowSpatGroups, highSpatGroups])
def build_grid_mask_from_cells(self, cell_index_list):
# need to add to init self.cartesian_us = self.coords.pack_u_centers_2D
mask = np.zeros_like(self.cartesian_us, dtype="bool")
# problem, the following has actual u,v readings but does grid have problem with this?
# for cell_index in cell_index_list:
# u, v = cell_index
# mask[u,v] = True
for cell_index in cell_index_list:
u, v = cell_index
# u_min, u_max = self.u_edges[u : u + 2] #change bc u_edges is from coordinates
# v_min, v_max = self.v_edges[v : v + 2]
u_min, u_max = self.coords.u_bin_min, self.coords.u_bin_max
v_min, v_max = self.coords.v_bin_min, self.coords.v_bin_max
# whether or not the u and v values of the coordinate array
# fit in the u cell and v cell
ind = (
(self.cartesian_us >= u_min)
& (self.cartesian_us < u_max)
& (self.cartesian_vs >= v_min)
& (self.cartesian_vs < v_max)
)
mask[ind] = True
return mask
def __iter__(self):
self.n = 0 # the current k-slice we're on
return self
def __next__(self):
print("Entered next")
if self.n < self.k:
print("Entered if")
k_list = self.k_split_cell_list.copy()
cell_list_test = k_list[self.n]
self.cell_list_test = cell_list_test
# put remaining indices back into a full list
print("F1")
cell_list_train = np.vstack(
(k_list[: self.n, :, :], k_list[self.n + 1 :, :, :])
)
self.cell_list_train = cell_list_train
# create the masks for each cell list
print("F2")
train_mask = self.build_grid_mask_from_cells(cell_list_train)
test_mask = self.build_grid_mask_from_cells(cell_list_test)
# copy origial dataset
print("F3")
train = copy.deepcopy(self.griddedDataset)
test = copy.deepcopy(self.griddedDataset)
# use these masks to limit new datasets to only unmasked cells
print("F4")
train.add_mask(train_mask)
test.add_mask(test_mask)
self.n += 1
print(self.n)
return train, test
else:
raise StopIteration
The text was updated successfully, but these errors were encountered:
This feature was originally drafted by @hgrzy in PR #93. The codebase has evolved significantly since that PR was opened and our thinking about how to run cross validation loops has evolved slightly. I'm closing PR #93 and excerpting the relevant draft code below. The hope is that this new issue could serve as a starting point for future stratified K-fold implementations.
In the original PR, Hannah said,
"Perform k fold cross validation where each k fold has the same ratio of low : high spatial frequency data points as the overall dataset.
Currently working on the for loop. I'm skeptical with the number of rows in the np array "pairs." While iterating through the while loop in the init of a stratkfold object it stops producing any flags that I've incorporated though jupyter notebook says it is still running the cell."
This most likely ties into the GriddedDataset and crossval discussions in #163 #166, too.
The text was updated successfully, but these errors were encountered: