Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cross-val splitter #132

Merged
merged 43 commits into from
Feb 9, 2023
Merged
Changes from 1 commit
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
039ff6b
Merge branch 'main' into crossval_splitter
jeffjennings Feb 3, 2023
eb7ba74
Merge branch 'crossval_class' into crossval_splitter
jeffjennings Feb 3, 2023
6dcd7ef
move splitter utils from datasets to crossval
jeffjennings Feb 3, 2023
b2b96ec
CrossValidate: add split_method arg
jeffjennings Feb 3, 2023
3f21e3a
split_dataset: remove args
jeffjennings Feb 3, 2023
49efc43
split_dataset: start of 'random cell' split method
jeffjennings Feb 3, 2023
837d5cc
conftest: add crossval par
jeffjennings Feb 3, 2023
b237356
crossval_test: add uniform cell splitter test
jeffjennings Feb 3, 2023
34111dd
Merge branch 'main' into crossval_splitter
jeffjennings Feb 6, 2023
cc66a59
typo
jeffjennings Feb 6, 2023
a35bbab
Merge branch 'main' into crossval_splitter
jeffjennings Feb 6, 2023
b1f7988
apply main changes to moved dataset routines
jeffjennings Feb 6, 2023
4f2dc66
run_crossval: return mean and std of scores
jeffjennings Feb 6, 2023
29334bd
Merge branch 'main' into crossval_splitter
jeffjennings Feb 7, 2023
6e00ea1
propagate new Dartboard, KFold into crossval
jeffjennings Feb 7, 2023
3e7226f
crossval: change default arg
jeffjennings Feb 7, 2023
795a2a8
move Dartboard back to datasets.py
jeffjennings Feb 7, 2023
34e7dc2
rename KFoldCross... to DartboardSplitGridded
jeffjennings Feb 7, 2023
2455da3
move tests from datasets_test to crossval_test
jeffjennings Feb 7, 2023
c513f1d
update crossval test name
jeffjennings Feb 7, 2023
8875cd6
update crossval test args
jeffjennings Feb 7, 2023
318fea6
common_data: update import of DartboardSplit...
jeffjennings Feb 7, 2023
0a7648d
crossval tutorial: update import of DartboardSplit
jeffjennings Feb 7, 2023
66f3825
HD143006_pt2: update import of DartboardSplit...
jeffjennings Feb 7, 2023
c3c20b3
tests.yml: update activity types
jeffjennings Feb 7, 2023
f3b7ee6
docs_build.yml: update activity types
jeffjennings Feb 7, 2023
c70103c
Add crossval to API
jeffjennings Feb 9, 2023
0d3f247
crossval.py: add random cell split class
jeffjennings Feb 9, 2023
a63adb9
crossval.py: update imports
jeffjennings Feb 9, 2023
b431530
CrossValidate: update args
jeffjennings Feb 9, 2023
aa06b99
split_dataset: update supported methods
jeffjennings Feb 9, 2023
27c9f43
run_crossval: update args
jeffjennings Feb 9, 2023
62ea137
run_crossval: return dict
jeffjennings Feb 9, 2023
e64138c
run_crossval: accept either splitter
jeffjennings Feb 9, 2023
aa11463
run_crossval: GPU TODOs
jeffjennings Feb 9, 2023
0effac0
run_crossval: add diagnostics dict
jeffjennings Feb 9, 2023
1784ff1
add tests for new splitter class
jeffjennings Feb 9, 2023
7b848f0
update default params for crossval tests
jeffjennings Feb 9, 2023
ebe31d8
CrossValidate: remove old arg
jeffjennings Feb 9, 2023
80f0e43
DartboardSplitGridded: mysterious indent
jeffjennings Feb 9, 2023
85f01b3
crossval.py: add imports for type hint
jeffjennings Feb 9, 2023
5947598
crossval.py: import TODO
jeffjennings Feb 9, 2023
17933f7
TrainTest: add arg for compat with CrossValidate
jeffjennings Feb 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
crossval.py: update imports
jeffjennings committed Feb 9, 2023

Verified

This commit was signed with the committer’s verified signature.
kayabaNerve Luke Parker
commit a63adb9beb0edb59f918db82ee119ce55fccb3da
6 changes: 5 additions & 1 deletion src/mpol/crossval.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import numpy as np
import copy
from collections import defaultdict
import logging
import torch

from mpol.precomposed import SimpleNet
from mpol.training import TrainTest
from mpol.datasets import Dartboard
from mpol.datasets import Dartboard, GriddedDataset

class CrossValidate:
r"""