Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] [ENH] Add sample_indices_ for SMOTE/ADASYN classes #933

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3ccd2dc
[ENH] Add sample_indices_ for SMOTE/ADASYN classes
JurajSlivka Oct 28, 2022
1771f6d
Fix of Black Linting errors.
JurajSlivka Oct 28, 2022
4c69088
Fix of black linting error.
JurajSlivka Oct 28, 2022
006a832
Fix of black linting error.
JurajSlivka Oct 28, 2022
50057ee
Fix of some build errors.
JurajSlivka Oct 28, 2022
03132a2
Fix of error that "Summary must start with infinitive verb"
JurajSlivka Oct 28, 2022
7e24dbb
Fix of third person error in summary.
JurajSlivka Oct 28, 2022
d63da62
Fix of third person error in summary.
JurajSlivka Oct 28, 2022
c5e79c9
Added keepdims=True for stats.mode()
JurajSlivka Oct 28, 2022
d6207e6
Black linter error.
JurajSlivka Oct 28, 2022
98c41b3
Back to previous commit
JurajSlivka Oct 28, 2022
e7627c7
Trying to solve FutureWarning error which of keepdims param in scipy.…
JurajSlivka Oct 29, 2022
5d14358
Solve of FutureWarning error
JurajSlivka Oct 30, 2022
7406d85
Fix for black error
JurajSlivka Oct 30, 2022
cf78c53
[FIX] Sample_indices wrong return when balanced dataset is sent
JurajSlivka Nov 7, 2022
2549e62
[FIX] Run black error
JurajSlivka Nov 7, 2022
ce497ad
Deleted keepdims attribute as it is already being solved in PR938 and…
JurajSlivka Nov 15, 2022
fa7cf53
Fix of black errors
JurajSlivka Nov 15, 2022
6471999
Fix for black and one more test added
JurajSlivka Nov 15, 2022
bbe36ec
Fix for black
JurajSlivka Nov 15, 2022
29b053c
Fix for unused import
JurajSlivka Nov 15, 2022
4c72219
Test of code after PR#946 was merged
JurajSlivka Nov 15, 2022
02e31aa
Resolve of conflict
JurajSlivka Dec 3, 2022
8e13811
FIX of "issort" error during checks
JurajSlivka Dec 3, 2022
e660386
Update test_smote.py
JurajSlivka Dec 3, 2022
cc55dd8
Update test_smote.py
JurajSlivka Dec 3, 2022
fc2cde9
Fix of linting
JurajSlivka Dec 3, 2022
80062d9
Merge branch 'master' into issue772
JurajSlivka Feb 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions imblearn/over_sampling/_adasyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def _fit_resample(self, X, y):
X_resampled = [X.copy()]
y_resampled = [y.copy()]

self._sample_indices = np.stack((np.arange(len(y)), np.zeros(len(y)))).T

for class_sample, n_samples in self.sampling_strategy_.items():
if n_samples == 0:
continue
Expand Down Expand Up @@ -207,6 +209,14 @@ def _fit_resample(self, X, y):
diffs = X_class[nns[rows, cols]] - X_class[rows]
steps = random_state.uniform(size=(n_samples, 1))

self._sample_indices = np.concatenate(
(
np.stack((np.arange(len(y)), np.zeros(len(y)))).T,
np.stack((rows, cols)).T,
),
axis=0,
)

if sparse.issparse(X):
sparse_func = type(X).__name__
steps = getattr(sparse, sparse_func)(steps)
Expand All @@ -231,3 +241,21 @@ def _more_tags(self):
return {
"X_types": ["2darray"],
}

def get_sample_indices(self):
"""Return a tuple of indexes of the samples used to generate the new point.

Usable with ADASYN.

Returns
-------
_sample_indices : ndarray of shape (mother_sample_index, random_neighbour_index)
If the sample belongs to original dataset:
mother_sample : index of the original sample
random_neighbour_index : index of the neighbour sample
"""
try:
self._sample_indices
except AttributeError:
return None
return self._sample_indices
31 changes: 31 additions & 0 deletions imblearn/over_sampling/_smote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
else:
X_new = X[rows] + steps * diffs

self._sample_indices = np.stack((rows, nn_num[rows, cols])).T

return X_new.astype(X.dtype)

def _in_danger_noise(self, nn_estimator, samples, target_class, y, kind="danger"):
Expand Down Expand Up @@ -365,8 +367,37 @@ def _fit_resample(self, X, y):
X_resampled = np.vstack(X_resampled)
y_resampled = np.hstack(y_resampled)

try:
self._sample_indices
except AttributeError:
self._sample_indices = np.stack((np.arange(len(y)), np.zeros(len(y)))).T
return X_resampled, y_resampled

self._sample_indices = np.concatenate(
(np.stack((np.arange(len(y)), np.zeros(len(y)))).T, self._sample_indices),
axis=0,
)

return X_resampled, y_resampled

def get_sample_indices(self):
"""Return a tuple of indexes of the samples used to generate the new point.

Usable with SMOTE.

Returns
-------
_sample_indices : ndarray of shape (mother_sample_index, random_neighbour_index)
If the sample belongs to original dataset:
mother_sample : index of the original sample
random_neighbour_index : index of the neighbour sample
"""
try:
self._sample_indices
except AttributeError:
return None
return self._sample_indices


@Substitution(
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
Expand Down
21 changes: 21 additions & 0 deletions imblearn/over_sampling/_smote/tests/test_borderline_smote.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,24 @@ def test_borderline_smote(kind, data):

assert_allclose(X_res_1, X_res_2)
assert_array_equal(y_res_1, y_res_2)


@pytest.mark.parametrize("kind", ["borderline-1", "borderline-2"])
def test_borderline_smote_FutureWarning(kind, data):
bsmote = BorderlineSMOTE(kind=kind, random_state=42, n_jobs=1)
bsmote_nn = BorderlineSMOTE(
kind=kind,
random_state=42,
k_neighbors=NearestNeighbors(n_neighbors=6),
m_neighbors=NearestNeighbors(n_neighbors=11),
)
with pytest.warns(FutureWarning) as record:
bsmote.fit_resample(*data)
bsmote_nn.fit_resample(*data)
assert len(record) == 1
assert (
record[0].message.args[0]
== "The parameter `n_jobs` has been deprecated in 0.10"
" and will be removed in 0.12. You can pass an nearest"
" neighbors estimator where `n_jobs` is already set instead."
)
82 changes: 82 additions & 0 deletions imblearn/over_sampling/_smote/tests/test_smote.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@
Y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
R_TOL = 1e-4

XX = np.array(
[
[0.11622591, -0.0317206],
[0.77481731, 0.60935141],
[1.25192108, -0.22367336],
[0.53366841, -0.30312976],
[1.52091956, -0.49283504],
[-0.28162401, -2.10400981],
[0.83680821, 1.72827342],
[0.3084254, 0.33299982],
[0.70472253, -0.73309052],
[0.28893132, -0.38761769],
[1.15514042, 0.0129463],
[0.88407872, 0.35454207],
]
)
YY = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1])


def test_sample_regular():
smote = SMOTE(random_state=RND_SEED)
Expand Down Expand Up @@ -147,3 +165,67 @@ def test_sample_regular_with_nn():
)
assert_allclose(X_resampled, X_gt, rtol=R_TOL)
assert_array_equal(y_resampled, y_gt)


def test_sample_indices():
smote = SMOTE(random_state=RND_SEED)
smote.fit_resample(X, Y)
indices = smote.get_sample_indices()
indices_gt = np.array(
[
[0, 0],
[1, 0],
[2, 0],
[3, 0],
[4, 0],
[5, 0],
[6, 0],
[7, 0],
[8, 0],
[9, 0],
[10, 0],
[11, 0],
[12, 0],
[13, 0],
[14, 0],
[15, 0],
[16, 0],
[17, 0],
[18, 0],
[19, 0],
[0, 2],
[0, 1],
[0, 1],
[7, 2],
]
)
assert_array_equal(indices, indices_gt)


def test_sample_indices_balanced_dataset():
smote = SMOTE(random_state=RND_SEED)
smote.fit_resample(XX, YY)
indices = smote.get_sample_indices()
indices_gt = np.array(
[
[0, 0],
[1, 0],
[2, 0],
[3, 0],
[4, 0],
[5, 0],
[6, 0],
[7, 0],
[8, 0],
[9, 0],
[10, 0],
[11, 0],
]
)
assert_array_equal(indices, indices_gt)


def test_sample_indices_is_none():
smote = SMOTE(random_state=RND_SEED)
indices = smote.get_sample_indices()
assert_array_equal(indices, None)
37 changes: 37 additions & 0 deletions imblearn/over_sampling/_smote/tests/test_smoten.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,40 @@ def test_smoten_resampling():
X_generated, y_generated = X_res[X.shape[0] :], y_res[X.shape[0] :]
np.testing.assert_array_equal(X_generated, "blue")
np.testing.assert_array_equal(y_generated, "not apple")


def test_smoten_FutureWarning(data):
# check that SMOTEN throws FutureWarning for "n_jobs" and "keepdims"
X, y = data
sampler = SMOTEN(random_state=0, n_jobs=0)
with pytest.warns(FutureWarning) as record:
sampler.fit_resample(X, y)
assert (
record[0].message.args[0]
== "The parameter `n_jobs` has been deprecated in 0.10"
" and will be removed in 0.12. You can pass an nearest"
" neighbors estimator where `n_jobs` is already set instead."
)


@pytest.fixture
def data_balanced():
rng = np.random.RandomState(0)

feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30
feature_2 = ["A"] * 40 + ["B"] * 20
feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10
X = np.array([feature_1, feature_2, feature_3], dtype=object).T
rng.shuffle(X)
y = np.array([0] * 30 + [1] * 30, dtype=np.int32)
y_labels = np.array(["not apple", "apple"], dtype=object)
y = y_labels[y]
return X, y


def test_smoten_balanced_data(data_balanced):
X, y = data_balanced
sampler = SMOTEN(random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
assert X_res.shape == (60, 3)
assert y_res.shape == (60,)
31 changes: 31 additions & 0 deletions imblearn/over_sampling/tests/test_adasyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,37 @@
Y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
R_TOL = 1e-4

XX = np.array(
[
[0.11622591, -0.0317206],
[0.77481731, 0.60935141],
[1.25192108, -0.22367336],
[0.53366841, -0.30312976],
[1.52091956, -0.49283504],
[-0.28162401, -2.10400981],
[0.83680821, 1.72827342],
[0.3084254, 0.33299982],
[0.70472253, -0.73309052],
[0.28893132, -0.38761769],
[1.15514042, 0.0129463],
[0.88407872, 0.35454207],
]
)
YY = np.array([0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0])

XXX = np.array(
[
[0.915, 0.892],
[0.926, 0.959],
[0.917, 0.983],
[0.945, 0.967],
[-0.844, -0.925],
[-0.987, -0.946],
[-0.962, -0.948],
]
)
YYY = np.array([1, 1, 1, 1, 0, 0, 0])


def test_ada_init():
sampling_strategy = "auto"
Expand Down