-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #22 from MatthewSZhang/extend
FEAT add extend by mini-batch
- Loading branch information
Showing
8 changed files
with
221 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ API Reference | |
|
||
FastCan | ||
refine | ||
extend | ||
ssc | ||
ols | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
""" | ||
Extend feature selection | ||
""" | ||
|
||
import math | ||
from copy import deepcopy | ||
from numbers import Integral | ||
|
||
import numpy as np | ||
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads | ||
from sklearn.utils._param_validation import Interval, validate_params | ||
from sklearn.utils.validation import check_is_fitted | ||
|
||
from ._cancorr_fast import _forward_search # type: ignore | ||
from ._fastcan import FastCan, _prepare_search | ||
|
||
|
||
@validate_params( | ||
{ | ||
"selector": [FastCan], | ||
"n_features_to_select": [ | ||
Interval(Integral, 1, None, closed="left"), | ||
], | ||
"batch_size": [ | ||
Interval(Integral, 1, None, closed="left"), | ||
], | ||
}, | ||
prefer_skip_nested_validation=False, | ||
) | ||
def extend(selector, n_features_to_select=1, batch_size=1): | ||
"""Extend FastCan with mini batches. | ||
It is suitable for selecting a very large number of features | ||
even larger than the number of samples. | ||
Similar to the correlation filter which selects each feature without considering | ||
the redundancy, the function selects features in mini-batch and the | ||
redundancy between the two mini-batches will be ignored. | ||
Parameters | ||
---------- | ||
selector : FastCan | ||
FastCan selector. | ||
n_features_to_select : int, default=1 | ||
The parameter is the absolute number of features to select. | ||
batch_size : int, default=1 | ||
The number of features in a mini-batch. | ||
Returns | ||
------- | ||
indices : ndarray of shape (n_features_to_select,), dtype=int | ||
The indices of the selected features. | ||
Examples | ||
-------- | ||
>>> from fastcan import FastCan, extend | ||
>>> X = [[1, 1, 0], [0.01, 0, 0], [-1, 0, 1], [0, 0, 0]] | ||
>>> y = [1, 0, -1, 0] | ||
>>> selector = FastCan(1, verbose=0).fit(X, y) | ||
>>> print(f"Indices: {selector.indices_}") | ||
Indices: [0] | ||
>>> indices = extend(selector, 3, batch_size=2) | ||
>>> print(f"Indices: {indices}") | ||
Indices: [0 2 1] | ||
""" | ||
check_is_fitted(selector) | ||
n_inclusions = selector.indices_include_.size | ||
n_features = selector.n_features_in_ | ||
n_to_select = n_features_to_select - selector.n_features_to_select | ||
batch_size_to_select = batch_size - n_inclusions | ||
|
||
if n_features_to_select > n_features: | ||
raise ValueError( | ||
f"n_features_to_select {n_features_to_select} " | ||
f"must be <= n_features {n_features}." | ||
) | ||
if n_to_select <= 0: | ||
raise ValueError( | ||
f"The number of features to select ({n_to_select}) ", "is less than 0." | ||
) | ||
if batch_size_to_select <= 0: | ||
raise ValueError( | ||
"The size of mini batch without included indices ", | ||
f"({batch_size_to_select}) is less than 0.", | ||
) | ||
|
||
X_transformed_ = deepcopy(selector.X_transformed_) | ||
|
||
indices_include = selector.indices_include_ | ||
indices_exclude = selector.indices_exclude_ | ||
indices_select = selector.indices_[n_inclusions:] | ||
|
||
n_threads = _openmp_effective_n_threads() | ||
|
||
for i in range(math.ceil(n_to_select / batch_size_to_select)): | ||
if i == 0: | ||
batch_size_i = (n_to_select - 1) % batch_size_to_select + 1 + n_inclusions | ||
else: | ||
batch_size_i = batch_size | ||
indices, scores, mask = _prepare_search( | ||
n_features, | ||
batch_size_i, | ||
indices_include, | ||
np.r_[indices_exclude, indices_select], | ||
) | ||
_forward_search( | ||
X=X_transformed_, | ||
V=selector.y_transformed_, | ||
t=batch_size_i, | ||
tol=selector.tol, | ||
num_threads=n_threads, | ||
verbose=0, | ||
mask=mask, | ||
indices=indices, | ||
scores=scores, | ||
) | ||
indices_select = np.r_[indices_select, indices[n_inclusions:]] | ||
return np.r_[indices_include, indices_select] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
"""Test feature selection extend""" | ||
import numpy as np | ||
import pytest | ||
from numpy.testing import ( | ||
assert_array_equal, | ||
) | ||
from sklearn.datasets import make_classification | ||
|
||
from fastcan import FastCan, extend | ||
|
||
|
||
def test_select_extend_cls(): | ||
# Test whether refine work correctly with random samples. | ||
n_samples = 200 | ||
n_features = 30 | ||
n_informative = 20 | ||
n_classes = 8 | ||
n_repeated = 5 | ||
n_to_select = 18 | ||
|
||
X, y = make_classification( | ||
n_samples=n_samples, | ||
n_features=n_features, | ||
n_informative=n_informative, | ||
n_repeated=n_repeated, | ||
n_classes=n_classes, | ||
n_clusters_per_class=1, | ||
flip_y=0.0, | ||
class_sep=10, | ||
shuffle=False, | ||
random_state=0, | ||
) | ||
|
||
n_features_to_select = 2 | ||
selector = FastCan(n_features_to_select).fit(X, y) | ||
indices = extend(selector, n_to_select, batch_size=3) | ||
selector_inc = FastCan(n_features_to_select, indices_include=[10]).fit(X, y) | ||
indices_inc = extend(selector_inc, n_to_select, batch_size=3) | ||
selector_exc = FastCan( | ||
n_features_to_select, indices_include=[10], indices_exclude=[0] | ||
).fit(X, y) | ||
indices_exc = extend(selector_exc, n_to_select, batch_size=3) | ||
|
||
|
||
assert np.unique(indices).size == n_to_select | ||
assert_array_equal(indices[:n_features_to_select], selector.indices_) | ||
assert np.unique(indices_inc).size == n_to_select | ||
assert_array_equal(indices_inc[:n_features_to_select], selector_inc.indices_) | ||
assert np.unique(indices_exc).size == n_to_select | ||
assert_array_equal(indices_exc[:n_features_to_select], selector_exc.indices_) | ||
assert ~np.isin(0, indices_exc) | ||
|
||
|
||
def test_extend_error(): | ||
# Test refine raise error. | ||
n_samples = 200 | ||
n_features = 20 | ||
n_informative = 10 | ||
n_classes = 8 | ||
n_repeated = 5 | ||
|
||
X, y = make_classification( | ||
n_samples=n_samples, | ||
n_features=n_features, | ||
n_informative=n_informative, | ||
n_repeated=n_repeated, | ||
n_classes=n_classes, | ||
n_clusters_per_class=1, | ||
flip_y=0.0, | ||
class_sep=10, | ||
shuffle=False, | ||
random_state=0, | ||
) | ||
|
||
n_features_to_select = 2 | ||
|
||
selector = FastCan(n_features_to_select, indices_include=[0]).fit(X, y) | ||
|
||
with pytest.raises(ValueError, match=r"n_features_to_select .*"): | ||
_ = extend(selector, n_features+1, batch_size=3) | ||
|
||
with pytest.raises(ValueError, match=r"The number of features to select .*"): | ||
_ = extend(selector, n_features_to_select, batch_size=3) | ||
|
||
with pytest.raises(ValueError, match=r"The size of mini batch without .*"): | ||
_ = extend(selector, n_features, batch_size=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters