Skip to content

Commit

Permalink
MNT add sparse input support and complete documentation (scikit-learn…
Browse files Browse the repository at this point in the history
  • Loading branch information
joaopfonseca committed Dec 18, 2021
1 parent 0e80574 commit 2db667f
Showing 1 changed file with 49 additions and 10 deletions.
59 changes: 49 additions & 10 deletions imblearn/over_sampling/_smote/geometric.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Class to perform over-sampling using Geometric SMOTE."""

# Author: Georgios Douzas <[email protected]>
# Joao Fonseca <[email protected]>
# License: BSD 3 clause

import numpy as np
from numpy.linalg import norm
from scipy import sparse
from sklearn.utils import check_random_state
from imblearn.over_sampling.base import BaseOverSampler
from ..base import BaseOverSampler
from imblearn.utils import check_neighbors_object, Substitution
from imblearn.utils._docstring import _random_state_docstring

Expand Down Expand Up @@ -119,6 +121,33 @@ class GeometricSMOTE(BaseOverSampler):
n_jobs : int, optional (default=1)
The number of threads to open if possible.
Attributes
----------
sampling_strategy_ : dict
Dictionary containing the information to sample the dataset. The keys
correspond to the class labels from which to sample and the values
are the number of samples to sample.
n_features_in_ : int
Number of features in the input dataset.
nns_pos_ : estimator object
Validated k-nearest neighbors created from the `k_neighbors` parameter. It is
used to find the nearest neighbors of the same class of a selected
observation.
nn_neg_ : estimator object
Validated k-nearest neighbors created from the `k_neighbors` parameter. It is
used to find the nearest neighbor of the remaining classes (k=1) of a selected
observation.
random_state_ : instance of RandomState
If the `random_state` parameter is None, it is a RandomState singleton used by
np.random. If `random_state` is an int, it is a RandomState instance seeded with
that value. If `random_state` is already a RandomState instance, it is the same
object.
Notes
-----
See the original paper: [1]_ for more details.
Expand All @@ -142,7 +171,8 @@ class GeometricSMOTE(BaseOverSampler):
>>> from collections import Counter
>>> from sklearn.datasets import make_classification
>>> from gsmote import GeometricSMOTE # doctest: +NORMALIZE_WHITESPACE
>>> from imblearn.over_sampling import \
GeometricSMOTE # doctest: +NORMALIZE_WHITESPACE
>>> X, y = make_classification(n_classes=2, class_sep=2,
... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
Expand Down Expand Up @@ -237,7 +267,7 @@ def _make_geometric_samples(self, X, y, pos_class_label, n_samples):

# Force minority strategy if no negative class samples are present
self.selection_strategy_ = (
'minority' if len(X) == len(X_pos) else self.selection_strategy
'minority' if X.shape[0] == X_pos.shape[0] else self.selection_strategy
)

# Minority or combined strategy
Expand Down Expand Up @@ -306,19 +336,28 @@ def _fit_resample(self, X, y):
# Validate estimator's parameters
self._validate_estimator()

# Ensure the input data is dense
X_dense = X.toarray() if sparse.issparse(X) else X

# Copy data
X_resampled, y_resampled = X.copy(), y.copy()
X_resampled, y_resampled = [X_dense.copy()], [y.copy()]

# Resample data
for class_label, n_samples in self.sampling_strategy_.items():

# Apply gsmote mechanism
X_new, y_new = self._make_geometric_samples(X, y, class_label, n_samples)

# Append new data
X_resampled, y_resampled = (
np.vstack((X_resampled, X_new)),
np.hstack((y_resampled, y_new)),
X_new, y_new = self._make_geometric_samples(
X_dense, y, class_label, n_samples
)

X_resampled.append(X_new)
y_resampled.append(y_new)

# Append new data
if sparse.issparse(X):
X_resampled = sparse.vstack(X_resampled, format=X.format)
else:
X_resampled = np.vstack(X_resampled).astype(X.dtype)
y_resampled = np.hstack(y_resampled).astype(y.dtype)

return X_resampled, y_resampled

0 comments on commit 2db667f

Please sign in to comment.