This repository has been archived by the owner on Oct 14, 2018. It is now read-only.

cross validation for xarray_filters.MLDataset - Elm PR 221 #61

Closed
dask_searchcv/methods.py: 69 changes (54 additions, 15 deletions)
@@ -1,7 +1,7 @@
from __future__ import absolute_import, division, print_function

import warnings
-from collections import defaultdict
+from collections import defaultdict, Sequence
from threading import Lock
from timeit import default_timer
from distutils.version import LooseVersion
@@ -15,9 +15,11 @@
from sklearn.exceptions import FitFailedWarning
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.utils import safe_indexing
-from sklearn.utils.validation import check_consistent_length, _is_arraylike
+from sklearn.utils.validation import _is_arraylike, check_consistent_length

+from .utils import copy_estimator, _split_Xy, _is_xy_tuple
-from .utils import copy_estimator

# Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop
# support for scikit-learn < 0.18.1 or numpy < 1.12.0.
@@ -63,7 +65,6 @@ def warn_fit_failure(error_score, e):
# Functions in the graphs #
# ----------------------- #

-
class CVCache(object):
    def __init__(self, splits, pairwise=False, cache=True):
        self.splits = splits
@@ -101,10 +102,14 @@ def _extract(self, X, y, n, is_x=True, is_train=True):
            return self.cache[n, is_x, is_train]

        inds = self.splits[n][0] if is_train else self.splits[n][1]
-        result = safe_indexing(X if is_x else y, inds)
-
-        if self.cache is not None:
-            self.cache[n, is_x, is_train] = result
+        post_splits = getattr(self, '_post_splits', None)
+        if post_splits:
+            result = post_splits(np.array(X)[inds])
+            self.cache[n, True, is_train] = result
Member:

Does post_split imply that you always have a cache? If so this can be rewriting as

if post_splits:
    result = ...
else:
    result = safe_indexing(...)

if self.cache is not None:
    self.cache[n, is_x, is_train] = result

Member:

And what happens if a user passes GridSearchCV(..., cache_cv=False)?

Collaborator Author:

Good question - I'm not sure how usage of a post_splits callable relates to cache_cv=False. My Dataset/MLDataset work so far has considered only cache_cv=True.

The idea of post_splits is shown more fully in CVCacheSampler. It is a way of calling a sampler function on the argument X given to fit. See the usage of the sampler argument to EaSearchCV in this elm PR 221. That example uses a sampler that expects X to be a list of dates, and the sampler function determines how to build a Dataset/MLDataset for the list of dates. In other cases, X could be a list of file names, and the cv argument to EaSearchCV controls how to split those file names into test/train groups. I like the idea of a sampler function for cross validation / hyperparameterization of Dataset/MLDataset workflows because there is no assumption about the shapes of the DataArrays (unlike cross validation in typical sklearn workflows, where the cv object is used to subset train/test row groupings of a large matrix). A sketch of the idea follows the TODO list below.

TODO (for me):

  • Add a test for cache_cv=False with a sampler function - fix or document requirements there.
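A minimal sketch of such a sampler, assuming cache_cv=True (the load_dataset helper, estimator, and the EaSearchCV call are illustrative assumptions, not code from this PR):

import xarray as xr

def date_sampler(dates, y=None, **kw):
    # CVCacheSampler calls the sampler as func(X, y=y, is_x=..., is_train=...),
    # so **kw absorbs the extra keywords here. load_dataset is a hypothetical
    # user function mapping one date to a file read or a remote query.
    arrays = [load_dataset(d) for d in dates]
    return xr.concat(arrays, dim='time')

# Hypothetical usage - the X given to fit() is just the list of dates, and the
# cv splitter only indexes into that list:
# ea = EaSearchCV(estimator, param_distributions={...}, sampler=date_sampler, cv=3)
# ea.fit(list_of_dates)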

Member:

> the sampler function determines how to build a Dataset/MLDataset for the list of dates.

Is the returned X (and y?) constrained to be the same shape as the input?

Collaborator Author:

More thought needs to go into the sampler idea in general - currently it avoids the call to check_consistent_length. The sampler may return X or an (X, y) tuple; X may be a Dataset/MLDataset or array-like, and y is currently expected to be array-like.

Collaborator Author:

I think I should just make sure I call check_consistent_length on the X or (X, y) samples returned by the sampler, and return FIT_FAILURE if there are inconsistent lengths. I'll experiment with that today/tomorrow.
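A rough sketch of that validation (the helper name _check_sampled is hypothetical; check_consistent_length and the FIT_FAILURE sentinel are the ones already used in methods.py):

from sklearn.utils.validation import check_consistent_length

FIT_FAILURE = object()  # stand-in for the module-level sentinel in methods.py

def _check_sampled(result):
    # The sampler may return X alone or an (X, y) tuple.
    X, y = result if isinstance(result, tuple) else (result, None)
    try:
        if y is not None:
            check_consistent_length(X, y)
    except ValueError:
        return FIT_FAILURE
    return X, y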

+        else:
+            result = safe_indexing(X if is_x else y, inds)
+            if self.cache is not None:
+                self.cache[n, is_x, is_train] = result
        return result

    def _extract_pairwise(self, X, y, n, is_train=True):
@@ -117,16 +122,47 @@ def _extract_pairwise(self, X, y, n, is_train=True):
        if X.shape[0] != X.shape[1]:
            raise ValueError("X should be a square kernel matrix")
        train, test = self.splits[n]
+        post_splits = getattr(self, '_post_splits', None)
        result = X[np.ix_(train if is_train else test, train)]
-
-        if self.cache is not None:
-            self.cache[n, True, is_train] = result
+        if post_splits:
+            result = post_splits(result)
+            if _is_xy_tuple(result):
+                if self.cache is not None:
+                    self.cache[n, True, is_train], self.cache[n, False, is_train] = result
+            elif self.cache is not None:
+                self.cache[n, True, is_train] = result
+        elif self.cache is not None:
+            self.cache[n, True, is_train] = result
        return result


-def cv_split(cv, X, y, groups, is_pairwise, cache):
-    check_consistent_length(X, y, groups)
-    return CVCache(list(cv.split(X, y, groups)), is_pairwise, cache)
+class CVCacheSampler(CVCache):
+    def __init__(self, sampler, splits, pairwise=False, cache=True):
+        self.sampler = sampler
+        super(CVCacheSampler, self).__init__(splits, pairwise=pairwise,
+                                             cache=cache)
+
+    def _post_splits(self, X, y=None, n=None, is_x=True, is_train=False):
+        if y is not None:
+            raise ValueError('Expected y to be None (returned by a Sampler() instance or similar).')
+        func = getattr(self.sampler, 'fit_transform', getattr(self.sampler, 'transform', self.sampler))
+        return func(X, y=y, is_x=is_x, is_train=is_train)


+def cv_split(cv, X, y, groups, is_pairwise, cache, sampler):
+    kw = dict(pairwise=is_pairwise, cache=cache)
+    if sampler:
+        cls = CVCacheSampler
+        kw['cache'] = True
+    else:
+        cls = CVCache
+    check_consistent_length(X, y, groups)
+    splits = list(cv.split(X, y, groups))
+    if sampler:
+        args = (sampler, splits,)
+    else:
+        args = (splits,)
+    return cls(*args, **kw)


def cv_n_samples(cvs):
@@ -226,6 +262,7 @@ def fit(est, X, y, error_score='raise', fields=None, params=None,

def fit_transform(est, X, y, error_score='raise', fields=None, params=None,
                  fit_params=None):
+    new_y = None
Member:

Where is this used?

Collaborator Author:

Leftover from an old effort (new_y) - I just ran a PEP8 check and found that and a few other things to fix.

    if X is FIT_FAILURE:
        est, fit_time, Xt = FIT_FAILURE, 0.0, FIT_FAILURE
    else:
@@ -239,19 +276,21 @@
            else:
                est.fit(X, y, **fit_params)
                Xt = est.transform(X)
+                Xt, y = _split_Xy(Xt, y, typ=(tuple, list))
        except Exception as e:
            if error_score == 'raise':
                raise
            warn_fit_failure(error_score, e)
            est = Xt = FIT_FAILURE
        fit_time = default_timer() - start_time

-    return (est, fit_time), Xt
+    return (est, fit_time), (Xt, y)


def _score(est, X, y, scorer):
    if est is FIT_FAILURE:
        return FIT_FAILURE
+    X, y = _split_Xy(X, y)
    return scorer(est, X) if y is None else scorer(est, X, y)

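The _split_Xy and _is_xy_tuple helpers are imported from .utils, which is not part of this diff. Judging only from the call sites above, a plausible shape for _split_Xy is the following sketch (an assumption, not the PR's code); _is_xy_tuple presumably just checks for a 2-element tuple:

def _split_Xy(X, y, typ=(tuple, list)):
    # If an upstream step returned an (X, y) pair, unpack it; otherwise
    # pass X through and keep the caller's y.
    if isinstance(X, typ) and len(X) == 2:
        return X[0], X[1]
    return X, y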
