This repository has been archived by the owner on Oct 14, 2018. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 42
cross validation for xarray_filters.MLDataset - Elm PR 221 #61
Closed
PeterDSteinberg
wants to merge
17
commits into
dask:master
from
PeterDSteinberg:cv-xarray-elm-issue-204
Closed
Changes from 8 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
47834f4
cross validation for xarray_filters.MLDataset - Elm PR 221
6c61b04
diagnostic printing
ce34117
remove print statements
013b3ad
fix usage of isinstance dask.base.Base -> is_dask_collection
6cb7c8d
sampler related cross validation changes
af16ad4
resolve merge conflicts with master
557e40d
deduplicate X,y splitting logic
ea512ae
fix test failures related to transformers of None and feature union
632ba83
fix pep8 issues found in CI checks
9ad1d74
fix pep8 issues
ec1e287
refactor to simplify changes in dask-searchcv
6906e83
refactor to simplify changes in dask-searchcv
5940ff1
refactor to simplify changes in dask-searchcv
2e1edc9
reduce diff in dask-searchcv -> move cv stuff to elm
e73381d
pep8 fixes
c58293f
remove _get_est_type
6367c21
pep8 fixes
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import warnings | ||
from collections import defaultdict | ||
from collections import defaultdict, Sequence | ||
from threading import Lock | ||
from timeit import default_timer | ||
from distutils.version import LooseVersion | ||
|
@@ -15,9 +15,11 @@ | |
from sklearn.exceptions import FitFailedWarning | ||
from sklearn.pipeline import Pipeline, FeatureUnion | ||
from sklearn.utils import safe_indexing | ||
from sklearn.utils.validation import check_consistent_length, _is_arraylike | ||
from sklearn.utils.validation import _is_arraylike, check_consistent_length | ||
|
||
from .utils import copy_estimator, _split_Xy, _is_xy_tuple | ||
|
||
|
||
from .utils import copy_estimator | ||
|
||
# Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop | ||
# support for scikit-learn < 0.18.1 or numpy < 1.12.0. | ||
|
@@ -63,7 +65,6 @@ def warn_fit_failure(error_score, e): | |
# Functions in the graphs # | ||
# ----------------------- # | ||
|
||
|
||
class CVCache(object): | ||
def __init__(self, splits, pairwise=False, cache=True): | ||
self.splits = splits | ||
|
@@ -101,10 +102,14 @@ def _extract(self, X, y, n, is_x=True, is_train=True): | |
return self.cache[n, is_x, is_train] | ||
|
||
inds = self.splits[n][0] if is_train else self.splits[n][1] | ||
result = safe_indexing(X if is_x else y, inds) | ||
|
||
if self.cache is not None: | ||
self.cache[n, is_x, is_train] = result | ||
post_splits = getattr(self, '_post_splits', None) | ||
if post_splits: | ||
result = post_splits(np.array(X)[inds]) | ||
self.cache[n, True, is_train] = result | ||
else: | ||
result = safe_indexing(X if is_x else y, inds) | ||
if self.cache is not None: | ||
self.cache[n, is_x, is_train] = result | ||
return result | ||
|
||
def _extract_pairwise(self, X, y, n, is_train=True): | ||
|
@@ -117,16 +122,47 @@ def _extract_pairwise(self, X, y, n, is_train=True): | |
if X.shape[0] != X.shape[1]: | ||
raise ValueError("X should be a square kernel matrix") | ||
train, test = self.splits[n] | ||
post_splits = getattr(self, '_post_splits', None) | ||
result = X[np.ix_(train if is_train else test, train)] | ||
|
||
if self.cache is not None: | ||
self.cache[n, True, is_train] = result | ||
if post_splits: | ||
result = post_splits(result) | ||
if _is_xy_tuple(result): | ||
if self.cache is not None: | ||
self.cache[n, True, is_train], self.cache[n, False, is_train] = result | ||
elif self.cache is not None: | ||
self.cache[n, True, is_train] = result | ||
elif self.cache is not None: | ||
self.cache[n, True, is_train] = result | ||
return result | ||
|
||
|
||
def cv_split(cv, X, y, groups, is_pairwise, cache): | ||
check_consistent_length(X, y, groups) | ||
return CVCache(list(cv.split(X, y, groups)), is_pairwise, cache) | ||
class CVCacheSampler(CVCache):
    """CVCache variant that runs a user-supplied ``sampler`` on each split.

    Instead of plain row indexing of X/y, the parent class's
    ``_post_splits`` hook (looked up via ``getattr`` in CVCache) is
    implemented here to call ``sampler`` on the selected samples of X.
    The sampler is expected to build/return X itself, or an (X, y) tuple.

    Parameters
    ----------
    sampler : callable or transformer-like object
        If it exposes ``fit_transform`` (or ``transform``) that method is
        used; otherwise the object itself is called.
    splits : list of (train_indices, test_indices) pairs
    pairwise : bool, default False
        Whether X is a square kernel (pairwise) matrix.
    cache : bool or dict, default True
        Passed to CVCache; sampling assumes caching is enabled
        (see cv_split, which forces ``cache=True`` when a sampler is used).
    """
    def __init__(self, sampler, splits, pairwise=False, cache=True):
        self.sampler = sampler
        super(CVCacheSampler, self).__init__(splits, pairwise=pairwise,
                                             cache=cache)

    def _post_splits(self, X, y=None, n=None, is_x=True, is_train=False):
        """Apply ``self.sampler`` to the already-selected samples ``X``.

        ``y`` must be None: the sampler itself is responsible for
        producing y (or an (X, y) tuple) if one is needed.
        """
        if y is not None:
            # Bug fix: the original message had an unbalanced parenthesis.
            raise ValueError('Expected y to be None (returned by Sampler() '
                             'instance or similar).')
        # Prefer fit_transform, then transform, else call the object itself.
        func = getattr(self.sampler, 'fit_transform',
                       getattr(self.sampler, 'transform', self.sampler))
        return func(X, y=y, is_x=is_x, is_train=is_train)
|
||
|
||
def cv_split(cv, X, y, groups, is_pairwise, cache, sampler):
    """Materialize the cross-validation splits of ``cv`` into a cache object.

    Returns a CVCacheSampler when ``sampler`` is truthy (caching is forced
    on in that case, since sampled results must be reusable), otherwise a
    plain CVCache honoring the caller's ``cache`` argument.

    NOTE(review): check_consistent_length runs on the raw X/y even when a
    sampler is used — confirm that is intended for non-array-like X.
    """
    check_consistent_length(X, y, groups)
    splits = list(cv.split(X, y, groups))
    if sampler:
        return CVCacheSampler(sampler, splits, pairwise=is_pairwise,
                              cache=True)
    return CVCache(splits, pairwise=is_pairwise, cache=cache)
|
||
|
||
def cv_n_samples(cvs): | ||
|
@@ -226,6 +262,7 @@ def fit(est, X, y, error_score='raise', fields=None, params=None, | |
|
||
def fit_transform(est, X, y, error_score='raise', fields=None, params=None, | ||
fit_params=None): | ||
new_y = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is this used? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leftover from old effort ( |
||
if X is FIT_FAILURE: | ||
est, fit_time, Xt = FIT_FAILURE, 0.0, FIT_FAILURE | ||
else: | ||
|
@@ -239,19 +276,21 @@ def fit_transform(est, X, y, error_score='raise', fields=None, params=None, | |
else: | ||
est.fit(X, y, **fit_params) | ||
Xt = est.transform(X) | ||
Xt, y = _split_Xy(Xt, y, typ=(tuple, list)) | ||
except Exception as e: | ||
if error_score == 'raise': | ||
raise | ||
warn_fit_failure(error_score, e) | ||
est = Xt = FIT_FAILURE | ||
fit_time = default_timer() - start_time | ||
|
||
return (est, fit_time), Xt | ||
return (est, fit_time), (Xt, y) | ||
|
||
|
||
def _score(est, X, y, scorer):
    """Score ``est`` with ``scorer``, propagating the FIT_FAILURE sentinel.

    X may carry y packed inside it (an (X, y) tuple from a sampler);
    _split_Xy normalizes that before scoring.
    """
    if est is FIT_FAILURE:
        return FIT_FAILURE
    X, y = _split_Xy(X, y)
    if y is None:
        return scorer(est, X)
    return scorer(est, X, y)
|
||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does
post_split
imply that you always have a cache? If so, this can be rewritten as: There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And what happens if a users passes
GridSearchCV(..., cache_cv=False)
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question - not sure how usage of a
post_splits
callable relates tocache_cv=False
. My Dataset/MLDatatset work so far considered onlycache_cv=True
.The idea of
post_splits
is shown more in the CVCacheSampler. It is a way of calling asampler
function on the argumentX
given to fit. See the usage ofsampler
argument toEaSearchCV
in this elm PR 221. That example uses a sampler that expectsX
to be a list of dates, and thesampler
function determines how to build an Dataset/MLDataset for the list of dates. In other cases,X
could be a list of file names and thecv
argument toEaSearchCV
controls how to split up those file names into test / train groups. I like the idea of asampler
function from cross validation / hyperparameterization of Dataset/MLDataset workflows because there is no assumption about the shapes of the DataArrays (unlike cross validation in typicalsklearn
workflows where thecv
object is used to subset train/test row groupings of a large matrix).TODO (for me):
cache_cv=False
with asampler
function - fix or document requirements there.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the returned
X
(anyy
?) constrained to be the same shape as the input?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More thought needs to be done on the
sampler
idea in general - currently it avoids the call tocheck_consistent_length
. Thesampler
may returnX
or anX, y
tuple. X may be a Dataset/MLDataset or array-like and y is currently expected to be array-like.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I should just make sure I call
check_consistent_length
on samples of X
or (X, y)
returned by sampler
and returnFIT_FAILURE
if there are inconsistent lengths. I'll experiment with that today/tomorrow.