Skip to content

Commit

Permalink
Fix sklearn compat issues
Browse files Browse the repository at this point in the history
  • Loading branch information
perimosocordiae committed Sep 28, 2023
1 parent 4e89e3d commit e5b06fa
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
11 changes: 7 additions & 4 deletions metric_learn/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Base module.
"""

from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.extmath import stable_cumsum
from sklearn.utils.validation import _is_arraylike, check_is_fitted
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve
Expand Down Expand Up @@ -464,7 +464,7 @@ def get_mahalanobis_matrix(self):
return self.components_.T.dot(self.components_)


class _PairsClassifierMixin(BaseMetricLearner):
class _PairsClassifierMixin(BaseMetricLearner, ClassifierMixin):
"""Base class for pairs learners.
Attributes
Expand All @@ -475,6 +475,7 @@ class _PairsClassifierMixin(BaseMetricLearner):
classified as dissimilar.
"""

classes_ = np.array([0, 1])
_tuple_size = 2 # number of points in a tuple, 2 for pairs

def predict(self, pairs):
Expand Down Expand Up @@ -752,11 +753,12 @@ def _validate_calibration_params(strategy='accuracy', min_rate=None,
'Got {} instead.'.format(type(beta)))


class _TripletsClassifierMixin(BaseMetricLearner):
class _TripletsClassifierMixin(BaseMetricLearner, ClassifierMixin):
"""
Base class for triplets learners.
"""

classes_ = np.array([0, 1])
_tuple_size = 3 # number of points in a tuple, 3 for triplets

def predict(self, triplets):
Expand Down Expand Up @@ -837,11 +839,12 @@ def score(self, triplets):
return self.predict(triplets).mean() / 2 + 0.5


class _QuadrupletsClassifierMixin(BaseMetricLearner):
class _QuadrupletsClassifierMixin(BaseMetricLearner, ClassifierMixin):
"""
Base class for quadruplets learners.
"""

classes_ = np.array([0, 1])
_tuple_size = 4 # number of points in a tuple, 4 for quadruplets

def predict(self, quadruplets):
Expand Down
4 changes: 4 additions & 0 deletions metric_learn/sdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def _fit(self, pairs, y):
print("SDML will use skggm's graphical lasso solver.")
pairs, y = self._prepare_inputs(pairs, y,
type_of_inputs='tuples')
n_features = pairs.shape[2]
if n_features < 2:
raise ValueError(f"Cannot fit SDML with {n_features} feature(s)")

# set up (the inverse of) the prior M
# if the prior is the default (None), we raise a warning
Expand Down Expand Up @@ -83,6 +86,7 @@ def _fit(self, pairs, y):
w_mahalanobis, _ = np.linalg.eigh(M)
not_spd = any(w_mahalanobis < 0.)
not_finite = not np.isfinite(M).all()
# TODO: Narrow this to the specific exceptions we expect.
except Exception as e:
raised_error = e
not_spd = False # not_spd not applicable here so we set to False
Expand Down

0 comments on commit e5b06fa

Please sign in to comment.