diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 9f127f58..2ad4cba6 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -298,12 +298,18 @@ class _PairsClassifierMixin(BaseMetricLearner): """ Attributes ---------- + classes_ : `list` + The possible labels of the pairs the metric learner can fit on. + `classes_ = [-1, 1]`, where -1 means points in a pair are dissimilar + (negative label), and 1 means they are similar (positive label). + threshold_ : `float` If the distance metric between two points is lower than this threshold, points will be classified as similar, otherwise they will be classified as dissimilar. """ + classes_ = [-1, 1] _tuple_size = 2 # number of points in a tuple, 2 for pairs def predict(self, pairs): diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 6cb34313..ac043a29 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -150,6 +150,11 @@ class ITML(_BaseITML, _PairsClassifierMixin): The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) + classes_ : `list` + The possible labels of the pairs `ITML` can fit on. `classes_ = [-1, 1]`, + where -1 means points in a pair are dissimilar (negative label), and 1 + means they are similar (positive label). + threshold_ : `float` If the distance metric between two points is lower than this threshold, points will be classified as similar, otherwise they will be diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index eb7dc529..910538d5 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -29,7 +29,16 @@ class _BaseMMC(MahalanobisMixin): - """Mahalanobis Metric for Clustering (MMC)""" + """Mahalanobis Metric for Clustering (MMC) + + Attributes + ---------- + + classes_ : `list` + The possible labels of the pairs `MMC` can fit on. `classes_ = [-1, 1]`, + where -1 means points in a pair are dissimilar (negative label), and 1 + means they are similar (positive label). + """ _tuple_size = 2 # constraints are pairs diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index b300b9ac..a7ca8478 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -143,6 +143,11 @@ class SDML(_BaseSDML, _PairsClassifierMixin): The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) + classes_ : `list` + The possible labels of the pairs `SDML` can fit on. `classes_ = [-1, 1]`, + where -1 means points in a pair are dissimilar (negative label), and 1 + means they are similar (positive label). + threshold_ : `float` If the distance metric between two points is lower than this threshold, points will be classified as similar, otherwise they will be diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index 091c56e2..f3d27519 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -1,5 +1,6 @@ import pytest import unittest +from sklearn.calibration import CalibratedClassifierCV from sklearn.utils.estimator_checks import check_estimator from sklearn.base import TransformerMixin from sklearn.pipeline import make_pipeline @@ -18,9 +19,9 @@ from sklearn.metrics.scorer import get_scorer from sklearn.utils.testing import _get_args from test.test_utils import (metric_learners, ids_metric_learners, - mock_preprocessor, tuples_learners, - ids_tuples_learners, pairs_learners, - ids_pairs_learners, remove_y_quadruplets, + mock_preprocessor, pairs_learners, + ids_pairs_learners, tuples_learners, + ids_tuples_learners, remove_y_quadruplets, quadruplets_learners) @@ -106,6 +107,31 @@ def stable_init(self, num_dims=None, pca_comps=None, # ---------------------- Test scikit-learn compatibility ---------------------- +@pytest.mark.parametrize('with_preprocessor', + [True, + # TODO: uncomment the below line as soon as + # https://github.com/scikit-learn/scikit-learn/ + # issues/13077 is solved: + # False, + ]) +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, + ids=ids_pairs_learners) +def test_calibrated_classifier_CV(estimator, build_dataset, + with_preprocessor): + """Tests that metric-learn tuples estimators' work with scikit-learn's + CalibratedClassifierCV. + """ + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + calibrated_clf = CalibratedClassifierCV(estimator) + + # test fit and predict_proba + calibrated_clf.fit(input_data, labels) + calibrated_clf.predict_proba(input_data) + + @pytest.mark.parametrize('with_preprocessor', [True, False]) @pytest.mark.parametrize('estimator, build_dataset', pairs_learners, ids=ids_pairs_learners)