Cross-val splitter #132

Merged
merged 43 commits into from
Feb 9, 2023

Conversation

jeffjennings
Contributor

@jeffjennings jeffjennings commented Feb 3, 2023

  • Moves datasets.KFoldCrossValidatorGridded to crossval.DartboardSplitGridded and updates its imports throughout the tests and tutorials. I don't like renaming a user-facing class, but with this PR's new splitting class (crossval.RandomCellSplitGridded) and new class that performs the cross-validation (crossval.CrossValidate), the name KFoldCrossValidatorGridded is confusing: it only splits the data, it doesn't do the cross-val.

    • moves tests using crossval.DartboardSplitGridded into crossval_test
  • Adds another dataset splitting class, crossval.RandomCellSplitGridded, which draws random cells from the gridded visibilities to build the (train, test) sets. It returns an iterator of GriddedDataset tuples, consistent with crossval.DartboardSplitGridded.

    • Cells in the top 1% of gridded weight are included in every train set
    • The split sets are even in size across folds, unlike those from DartboardSplitGridded:
      for (train, test) in DartboardSplitGridded:
          print(train.shape, test.shape)

      torch.Size([19580]) torch.Size([5798])
      torch.Size([19492]) torch.Size([5886])
      torch.Size([20734]) torch.Size([4644])
      torch.Size([21266]) torch.Size([4112])
      torch.Size([20440]) torch.Size([4938])

      for (train, test) in RandomCellSplitGridded:
          print(train.shape, test.shape)

      torch.Size([20353]) torch.Size([5025])
      torch.Size([20353]) torch.Size([5025])
      torch.Size([20353]) torch.Size([5025])
      torch.Size([20353]) torch.Size([5025])
      torch.Size([20354]) torch.Size([5024])
  • Updates crossval.CrossValidate workflow to use these classes to split the data (crossval.CrossValidate.split_dataset), within a k-fold cross-val loop (crossval.CrossValidate.run_crossval)

    • generates one train and one test GriddedDataset at a time in the loop to keep memory footprint reasonable
  • Haven't confirmed GPU compatibility yet (will address in Fit runner #128 once I include calls to crossval.CrossValidate in the fit pipeline)

  • Haven't added a stratified k-fold splitter - let's see how RandomCellSplitGridded does
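
For readers who want the idea behind RandomCellSplitGridded without opening the diff, here is a minimal sketch of the splitting rule described above. It is not the code in this PR: the function name `random_cell_split`, the `top_frac` and `seed` parameters, and the use of plain Python lists of cell indices instead of GriddedDataset objects are all assumptions for illustration. It also assumes "top 1%" means the 1% of cells with the largest weight, counted by number of cells. Under those assumptions, 25378 cells give the same fold sizes as the RandomCellSplitGridded output printed above.

```python
import random


def random_cell_split(weights, k=5, top_frac=0.01, seed=0):
    """Yield (train, test) lists of cell indices for k folds.

    Hypothetical sketch: the heaviest `top_frac` of cells always go to train,
    and the remaining cells are shuffled and dealt into k near-equal test sets.
    """
    n = len(weights)
    # Rank cell indices by weight, heaviest first.
    order = sorted(range(n), key=lambda i: weights[i], reverse=True)
    n_top = round(top_frac * n)
    always_train = order[:n_top]
    rest = order[n_top:]
    random.Random(seed).shuffle(rest)

    # Fold sizes differ by at most one cell.
    base, extra = divmod(len(rest), k)
    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)
        test = rest[start:start + size]
        train = always_train + rest[:start] + rest[start + size:]
        start += size
        yield train, test


# Toy weights for 25378 cells, the cell count implied by the sizes above.
rng = random.Random(1)
weights = [rng.random() for _ in range(25378)]
sizes = [(len(train), len(test)) for train, test in random_cell_split(weights)]
print(sizes)
# [(20353, 5025), (20353, 5025), (20353, 5025), (20353, 5025), (20354, 5024)]
```

Because the function is a generator, each (train, test) pair is only built when the loop asks for it, which matches the iterator behaviour described for both splitter classes.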

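
The one-fold-at-a-time loop mentioned for crossval.CrossValidate.run_crossval can be pictured roughly as below. This is only a sketch under assumptions: `kfold_pairs`, `run_crossval_sketch`, `fit` and `score` are hypothetical names, the data are plain floats instead of GriddedDataset objects, and the toy "model" is just the mean of the train set. The point is that the splitter is a generator, so only one train set and one test set exist at any moment.

```python
from statistics import mean


def kfold_pairs(cells, k):
    """Lazily yield (train, test) lists, building one fold per iteration."""
    for fold in range(k):
        test = cells[fold::k]
        train = [c for i, c in enumerate(cells) if i % k != fold]
        yield train, test


def run_crossval_sketch(cells, k, fit, score):
    """Fit on each train set, score on its test set, return (mean, per-fold)."""
    scores = []
    for train, test in kfold_pairs(cells, k):
        model = fit(train)
        scores.append(score(model, test))
        # train and test are replaced on the next iteration, so the
        # previous fold's data can be garbage-collected.
    return mean(scores), scores


def mse(model, test):
    """Mean squared error of a constant prediction against the test cells."""
    return mean([(x - model) ** 2 for x in test])


# Toy usage: the "model" is the mean of the train cells.
cells = [float(i) for i in range(10)]
cv_score, fold_scores = run_crossval_sketch(cells, 5, mean, mse)
print(cv_score, fold_scores)
```

Returning the per-fold scores alongside their mean makes it easy to check whether one fold dominates the cross-validation score.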
@jeffjennings jeffjennings added this to the v0.1.4 milestone Feb 3, 2023
@jeffjennings
Contributor Author

Testing whether review request triggers docs_build workflow once or twice...

@jeffjennings jeffjennings requested a review from a team February 8, 2023 18:55
@jeffjennings
Contributor Author

...just once!

@jeffjennings jeffjennings changed the title [WIP] Cross-val splitter Cross-val splitter Feb 9, 2023
@jeffjennings jeffjennings requested a review from iancze February 9, 2023 08:34
@jeffjennings jeffjennings self-assigned this Feb 9, 2023
@jeffjennings
Contributor Author

Added note in #100 about changes here.

@iancze iancze merged commit b217642 into main Feb 9, 2023
@iancze iancze deleted the crossval_splitter branch February 9, 2023 20:34
@iancze iancze mentioned this pull request Feb 9, 2023
3 participants