This repository has been archived by the owner on Oct 14, 2018. It is now read-only.

cross validation for xarray_filters.MLDataset - Elm PR 221 #61

Closed
17 changes: 13 additions & 4 deletions dask_searchcv/methods.py
@@ -17,7 +17,7 @@
 from sklearn.utils import safe_indexing
 from sklearn.utils.validation import check_consistent_length, _is_arraylike
 
-from .utils import copy_estimator
+from .utils import copy_estimator, _split_Xy
 
 # Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop
 # support for scikit-learn < 0.18.1 or numpy < 1.12.0.
@@ -125,8 +125,15 @@ def _extract_pairwise(self, X, y, n, is_train=True):
 
 
 def cv_split(cv, X, y, groups, is_pairwise, cache):
-    check_consistent_length(X, y, groups)
-    return CVCache(list(cv.split(X, y, groups)), is_pairwise, cache)
+    splits = list(cv.split(X, y, groups))
+    if not cache or isinstance(cache, bool):
+        check_consistent_length(X, y, groups)
+        return CVCache(splits, is_pairwise, cache)
+    params = dict(pairwise=is_pairwise,
+                  cache={},
+                  splits=splits)
+    cache.set_params(**params)
+    return cache
 
 
 def cv_n_samples(cvs):
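
Note: the non-bool branch above only requires that the cache_cv object expose a scikit-learn style set_params accepting pairwise, cache, and splits, plus CVCache's extraction behavior. A minimal sketch of such a CVCache-like object (the subclass name and logging are illustrative, not part of this PR; the extract signature is assumed from methods.py):

from dask_searchcv.methods import CVCache

class LoggingCVCache(CVCache):
    """Hypothetical cache_cv object; cv_split calls set_params on it."""

    def set_params(self, pairwise=None, cache=None, splits=None):
        self.pairwise = pairwise
        self.cache = cache
        self.splits = splits
        return self

    def extract(self, X, y, n, is_x=True, is_train=True):
        # Record which train/test subsets are materialized, then defer
        # to the stock CVCache extraction logic.
        print('extract: split %d, is_x=%r, is_train=%r' % (n, is_x, is_train))
        return super(LoggingCVCache, self).extract(X, y, n, is_x, is_train)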
@@ -239,19 +246,21 @@ def fit_transform(est, X, y, error_score='raise', fields=None, params=None,
             else:
                 est.fit(X, y, **fit_params)
                 Xt = est.transform(X)
+            Xt, y = _split_Xy(Xt, y, typ=(tuple, list))
         except Exception as e:
             if error_score == 'raise':
                 raise
             warn_fit_failure(error_score, e)
             est = Xt = FIT_FAILURE
         fit_time = default_timer() - start_time
 
-    return (est, fit_time), Xt
+    return (est, fit_time), (Xt, y)
 
 
 def _score(est, X, y, scorer):
     if est is FIT_FAILURE:
         return FIT_FAILURE
+    X, y = _split_Xy(X, y)
     return scorer(est, X) if y is None else scorer(est, X, y)
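
For context, the unpacking above matters when a step's transform hands back a 2-tuple rather than a bare array, since the MLDataset-style steps targeted by this PR can derive y while transforming X. A sketch of such a step (this class is illustrative, not from the PR):

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

class FlattenAndLabel(BaseEstimator, TransformerMixin):
    """Illustrative step whose transform returns (Xt, y)."""
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        Xt = np.asarray(X).reshape(len(X), -1)
        y = (Xt[:, 0] > Xt.mean()).astype(int)  # labels derived from the data
        return Xt, y  # downstream, fit_transform splits this via _split_Xy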


101 changes: 60 additions & 41 deletions dask_searchcv/model_selection.py
@@ -69,7 +69,6 @@ def build_graph(estimator, cv, scorer, candidate_params, X, y=None,
     cv = check_cv(cv, y, is_classifier(estimator))
     # "pairwise" estimators require a different graph for CV splitting
     is_pairwise = getattr(estimator, '_pairwise', False)
-
     dsk = {}
     X_name, y_name, groups_name = to_keys(dsk, X, y, groups)
     n_splits = compute_n_splits(cv, X, y, groups)
@@ -119,7 +118,6 @@ def build_graph(estimator, cv, scorer, candidate_params, X, y=None,
     dsk[best_estimator] = (fit_best, clone(estimator), best_params,
                            X_name, y_name, fit_params)
     keys.append(best_estimator)
-
     return dsk, keys, n_splits


@@ -201,7 +199,6 @@ def do_fit_and_score(dsk, main_token, est, cv, fields, tokens, params,
     fit_ests = do_fit(dsk, TokenIterator(main_token), est, cv,
                       fields, tokens, params, X_trains, y_trains,
                       fit_params, n_splits, error_score)
-
     score_name = 'score-' + main_token
 
     scores = []
@@ -215,7 +212,6 @@ def do_fit_and_score(dsk, main_token, est, cv, fields, tokens, params,
 
         xtest = X_test + (n,)
         ytest = y_test + (n,)
-
         for (name, m) in fit_ests:
             dsk[(score_name, m, n)] = (score, (name, m, n), xtest, ytest,
                                        xtrain, ytrain, scorer)
@@ -281,7 +277,8 @@ def do_fit_transform(dsk, next_token, est, cv, fields, tokens, params, Xs, ys,
     token = next_token(est)
     fit_Xt_name = '%s-fit-transform-%s' % (name, token)
     fit_name = '%s-fit-%s' % (name, token)
-    Xt_name = '%s-transform-%s' % (name, token)
+    Xt_name = '%s-transform-X-%s' % (name, token)
+    yt_name = '%s-transform-y-%s' % (name, token)
     est_name = '%s-%s' % (type(est).__name__.lower(), token)
     dsk[est_name] = est
@@ -300,12 +297,16 @@ def do_fit_transform(dsk, next_token, est, cv, fields, tokens, params, Xs, ys,
                                             error_score, fields, p,
                                             fit_params)
                 dsk[(fit_name, m, n)] = (getitem, (fit_Xt_name, m, n), 0)
-                dsk[(Xt_name, m, n)] = (getitem, (fit_Xt_name, m, n), 1)
+                Xty = (getitem, (fit_Xt_name, m, n), 1)
+                dsk[(Xt_name, m, n)] = (getitem, Xty, 0)
+                dsk[(yt_name, m, n)] = (getitem, Xty, 1)
             seen[X, y, t] = m
             out_append(m)
             m += 1
 
-    return [(fit_name, i) for i in out], [(Xt_name, i) for i in out]
+    return ([(fit_name, i) for i in out],
+            [(Xt_name, i) for i in out],
+            [(yt_name, i) for i in out])
 
 
 def _group_subparams(steps, fields, ignore=()):
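
The chained getitem tasks above follow the usual dask graph convention: the fit-transform task evaluates to ((est, fit_time), (Xt, y)), index 1 selects the (Xt, y) pair, and indexes 0 and 1 of that pair become separate X and y keys. A toy graph (key names illustrative) run on the synchronous scheduler:

from operator import getitem
from dask.local import get_sync

dsk = {
    # Stand-in for the fit-transform task: ((est, fit_time), (Xt, y)).
    'fit-xt': (lambda: (('est', 0.01), ('Xt-array', 'y-array')),),
    # Nested tasks mirror the Xty chaining in do_fit_transform above.
    'xt': (getitem, (getitem, 'fit-xt', 1), 0),
    'yt': (getitem, (getitem, 'fit-xt', 1), 1),
}
print(get_sync(dsk, ['xt', 'yt']))  # the Xt and y payloads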
@@ -340,14 +341,14 @@ def new_group():
 def _do_fit_step(dsk, next_token, step, cv, fields, tokens, params, Xs, ys,
                  fit_params, n_splits, error_score, step_fields_lk,
                  fit_params_lk, field_to_index, step_name, none_passthrough,
-                 is_transform):
+                 is_transform, is_featureunion):
     sub_fields, sub_inds = map(list, unzip(step_fields_lk[step_name], 2))
     sub_fit_params = fit_params_lk[step_name]
 
     if step_name in field_to_index:
         # The estimator may change each call
         new_fits = {}
         new_Xs = {}
+        new_ys = {}
         est_index = field_to_index[step_name]
 
         for ids in _group_ids_by_index(est_index, tokens):
@@ -363,8 +364,10 @@ def _do_fit_step(dsk, next_token, step, cv, fields, tokens, params, Xs, ys,
                 if is_transform:
                     if none_passthrough:
                         new_Xs.update(zip(ids, get(ids, Xs)))
+                        new_ys.update(zip(ids, get(ids, ys)))
                     else:
                         new_Xs.update(nones)
+                        new_ys.update(nones)
             else:
                 # Extract the proper subset of Xs, ys
                 sub_Xs = get(ids, Xs)
@@ -377,24 +380,30 @@ def _do_fit_step(dsk, next_token, step, cv, fields, tokens, params, Xs, ys,
                     sub_tokens = sub_params = None
 
                 if is_transform:
-                    sub_fits, sub_Xs = do_fit_transform(dsk, next_token,
-                                                        sub_est, cv, sub_fields,
-                                                        sub_tokens, sub_params,
-                                                        sub_Xs, sub_ys,
-                                                        sub_fit_params,
-                                                        n_splits, error_score)
+                    out = do_fit_transform(dsk, next_token,
+                                           sub_est, cv, sub_fields,
+                                           sub_tokens, sub_params,
+                                           sub_Xs, sub_ys,
+                                           sub_fit_params,
+                                           n_splits, error_score)
+                    if len(out) == 3:
+                        sub_fits, sub_Xs, sub_ys = out
+                        new_ys.update(zip(ids, sub_ys))
+                    else:
+                        sub_fits, sub_Xs = out
                     new_Xs.update(zip(ids, sub_Xs))
+                    new_fits.update(zip(ids, sub_fits))
                 else:
                     sub_fits = do_fit(dsk, next_token, sub_est, cv,
                                       sub_fields, sub_tokens, sub_params,
                                       sub_Xs, sub_ys, sub_fit_params,
                                       n_splits, error_score)
-                new_fits.update(zip(ids, sub_fits))
+                    new_fits.update(zip(ids, sub_fits))
 
         # Extract lists of transformed Xs and fit steps
         all_ids = list(range(len(Xs)))
         if is_transform:
             Xs = get(all_ids, new_Xs)
+            if not is_featureunion:
+                ys = get(all_ids, new_ys)
         fits = get(all_ids, new_fits)
     elif step is None:
         # Nothing to do
@@ -410,15 +419,19 @@ def _do_fit_step(dsk, next_token, step, cv, fields, tokens, params, Xs, ys,
             sub_tokens = sub_params = None
 
         if is_transform:
-            fits, Xs = do_fit_transform(dsk, next_token, step, cv,
-                                        sub_fields, sub_tokens, sub_params,
-                                        Xs, ys, sub_fit_params, n_splits,
-                                        error_score)
+            out = do_fit_transform(dsk, next_token, step, cv,
+                                   sub_fields, sub_tokens, sub_params,
+                                   Xs, ys, sub_fit_params, n_splits,
+                                   error_score)
+            if len(out) == 3:
+                fits, Xs, ys = out
+            else:
+                fits, Xs = out
         else:
             fits = do_fit(dsk, next_token, step, cv, sub_fields,
                           sub_tokens, sub_params, Xs, ys, sub_fit_params,
                           n_splits, error_score)
-    return (fits, Xs) if is_transform else (fits, None)
+    return (fits, Xs, ys) if is_transform else (fits, None, None)


@@ -432,13 +445,19 @@ def _do_pipeline(dsk, next_token, est, cv, fields, tokens, params, Xs, ys,
     # A list of (step, is_transform)
     instrs = [(s, True) for s in est.steps[:-1]]
     instrs.append((est.steps[-1], is_transform))
-
     fit_steps = []
     for (step_name, step), transform in instrs:
-        fits, Xs = _do_fit_step(dsk, next_token, step, cv, fields, tokens,
-                                params, Xs, ys, fit_params, n_splits,
-                                error_score, step_fields_lk, fit_params_lk,
-                                field_to_index, step_name, True, transform)
+        fits, temp_Xs, temp_ys = _do_fit_step(dsk, next_token, step, cv,
+                                              fields, tokens, params, Xs, ys,
+                                              fit_params, n_splits, error_score,
+                                              step_fields_lk, fit_params_lk,
+                                              field_to_index, step_name,
+                                              True, transform, False)
+        if transform:
+            Xs = temp_Xs
+            ys = temp_ys
         fit_steps.append(fits)
 
     # Rebuild the pipelines
@@ -461,7 +480,7 @@ def _do_pipeline(dsk, next_token, est, cv, fields, tokens, params, Xs, ys,
         m += 1
 
     if is_transform:
-        return out_ests, Xs
+        return out_ests, Xs, ys
     return out_ests
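
An eager, plain-Python analogue of what _do_pipeline now does with (Xs, ys): each transform step may replace both X and y, and the final estimator sees the threaded pair. This sketches the semantics only, not the graph construction:

def run_pipeline_steps(steps, X, y):
    transforms, final = steps[:-1], steps[-1]
    for step in transforms:
        out = step.fit(X, y).transform(X)
        if isinstance(out, (tuple, list)) and len(out) == 2:
            X, y = out   # MLDataset-style step returned (Xt, y)
        else:
            X = out      # ordinary scikit-learn transformer
    return final.fit(X, y)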


@@ -501,10 +520,10 @@ def _do_featureunion(dsk, next_token, est, cv, fields, tokens, params, Xs, ys,
     fit_steps = []
     tr_Xs = []
     for (step_name, step) in est.transformer_list:
-        fits, out_Xs = _do_fit_step(dsk, next_token, step, cv, fields, tokens,
-                                    params, Xs, ys, fit_params, n_splits,
-                                    error_score, step_fields_lk, fit_params_lk,
-                                    field_to_index, step_name, False, True)
+        fits, out_Xs, _ = _do_fit_step(dsk, next_token, step, cv, fields, tokens,
+                                       params, Xs, ys, fit_params, n_splits,
+                                       error_score, step_fields_lk, fit_params_lk,
+                                       field_to_index, step_name, False, True, True)
         fit_steps.append(fits)
         tr_Xs.append(out_Xs)

@@ -775,18 +794,17 @@ def fit(self, X, y=None, groups=None, **fit_params):
                 error_score == 'raise'):
             raise ValueError("error_score must be the string 'raise' or a"
                              " numeric value.")
-
         dsk, keys, n_splits = build_graph(estimator, self.cv, self.scorer_,
                                           list(self._get_param_iterator()),
-                                          X, y, groups, fit_params,
+                                          X=X, y=y, groups=groups,
+                                          fit_params=fit_params,
                                           iid=self.iid,
                                           refit=self.refit,
                                           error_score=error_score,
                                           return_train_score=self.return_train_score,
                                           cache_cv=self.cache_cv)
         self.dask_graph_ = dsk
         self.n_splits_ = n_splits
-
         n_jobs = _normalize_n_jobs(self.n_jobs)
         scheduler = _normalize_scheduler(self.scheduler, n_jobs)
@@ -893,11 +911,12 @@ def visualize(self, filename='mydask', format=None, **kwargs):
     distributed schedulers. If ``n_jobs == -1`` [default] all cpus are used.
     For ``n_jobs < -1``, ``(n_cpus + 1 + n_jobs)`` are used.
 
-    cache_cv : bool, default=True
+    cache_cv : bool or CVCache-like object, default=True
         Whether to extract each train/test subset at most once in each worker
         process, or every time that subset is needed. Caching the splits can
-        speed up computation at the cost of increased memory usage per worker
-        process.
+        speed up computation at the cost of increased memory usage per worker
+        process. If cache_cv is a CVCache-like object (rather than a bool),
+        it is used in place of CVCache and extraction is assumed to happen
+        at most once.
 
         If True, worst case memory usage is ``(n_splits + 1) * (X.nbytes +
         y.nbytes)`` per worker. If False, worst case memory usage is
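
Against this branch, passing a CVCache-like object could look like the following sketch (it reuses the hypothetical LoggingCVCache from the methods.py notes above and is not runnable on released dask-searchcv):

import numpy as np
from sklearn.svm import SVC
from dask_searchcv import GridSearchCV

X = np.random.RandomState(0).randn(40, 4)
y = (X[:, 0] > 0).astype(int)

search = GridSearchCV(SVC(), {'C': [0.1, 1.0]},
                      cache_cv=LoggingCVCache(splits=None))
search.fit(X, y)  # cv_split fills in splits/pairwise/cache via set_params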
14 changes: 14 additions & 0 deletions dask_searchcv/utils.py
@@ -1,3 +1,4 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
 import copy
 from distutils.version import LooseVersion
 
@@ -78,3 +79,16 @@ def copy_estimator(est):
 
 def unzip(itbl, n):
     return zip(*itbl) if itbl else [()] * n
+
+
+def _is_xy_tuple(result, typ=tuple):
+    if typ and not isinstance(typ, tuple):
+        typ = (typ,)
+    typ = typ + (tuple,)
+    return isinstance(result, typ) and len(result) == 2
+
+
+def _split_Xy(X, y, typ=tuple):
+    if _is_xy_tuple(X, typ=typ):
+        X, y = X
+    return X, y
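
Behavior sketch for the new helpers: a 2-tuple X (or a 2-list, when lists are allowed via typ) is unpacked into (X, y); anything else passes through with y untouched:

from dask_searchcv.utils import _split_Xy

# Transform output packed as (Xt, y): unpacked, replacing the passed-in y.
assert _split_Xy(('Xt', 'y-new'), y='y-old') == ('Xt', 'y-new')

# Plain output: passed through unchanged alongside the original y.
assert _split_Xy('Xt', y='y-old') == ('Xt', 'y-old')

# Lists only count as (X, y) pairs when requested, as fit_transform does.
assert _split_Xy(['Xt', 'y-new'], y=None, typ=(tuple, list)) == ('Xt', 'y-new')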