Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add classes attribute and test for CalibratedClassifierCV #187

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions metric_learn/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,18 @@ class _PairsClassifierMixin(BaseMetricLearner):
"""
Attributes
----------
classes_ : `list`
The possible labels of the pairs the metric learner can fit on.
`classes_ = [-1, 1]`, where -1 means points in a pair are dissimilar
(negative label), and 1 means they are similar (positive label).

threshold_ : `float`
If the distance metric between two points is lower than this threshold,
points will be classified as similar, otherwise they will be
classified as dissimilar.
"""

classes_ = [-1, 1]
_tuple_size = 2 # number of points in a tuple, 2 for pairs

def predict(self, pairs):
Expand Down
5 changes: 5 additions & 0 deletions metric_learn/itml.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ class ITML(_BaseITML, _PairsClassifierMixin):
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See function `transformer_from_metric`.)

classes_ : `list`
The possible labels of the pairs `ITML` can fit on. `classes_ = [-1, 1]`,
where -1 means points in a pair are dissimilar (negative label), and 1
means they are similar (positive label).

threshold_ : `float`
If the distance metric between two points is lower than this threshold,
points will be classified as similar, otherwise they will be
Expand Down
11 changes: 10 additions & 1 deletion metric_learn/mmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,16 @@


class _BaseMMC(MahalanobisMixin):
"""Mahalanobis Metric for Clustering (MMC)"""
"""Mahalanobis Metric for Clustering (MMC)

Attributes
----------

classes_ : `list`
The possible labels of the pairs `MMC` can fit on. `classes_ = [-1, 1]`,
where -1 means points in a pair are dissimilar (negative label), and 1
means they are similar (positive label).
"""

_tuple_size = 2 # constraints are pairs

Expand Down
5 changes: 5 additions & 0 deletions metric_learn/sdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ class SDML(_BaseSDML, _PairsClassifierMixin):
The linear transformation ``L`` deduced from the learned Mahalanobis
metric (See function `transformer_from_metric`.)

classes_ : `list`
The possible labels of the pairs `SDML` can fit on. `classes_ = [-1, 1]`,
where -1 means points in a pair are dissimilar (negative label), and 1
means they are similar (positive label).

threshold_ : `float`
If the distance metric between two points is lower than this threshold,
points will be classified as similar, otherwise they will be
Expand Down
32 changes: 29 additions & 3 deletions test/test_sklearn_compat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import unittest
from sklearn.calibration import CalibratedClassifierCV
from sklearn.utils.estimator_checks import check_estimator
from sklearn.base import TransformerMixin
from sklearn.pipeline import make_pipeline
Expand All @@ -18,9 +19,9 @@
from sklearn.metrics.scorer import get_scorer
from sklearn.utils.testing import _get_args
from test.test_utils import (metric_learners, ids_metric_learners,
mock_preprocessor, tuples_learners,
ids_tuples_learners, pairs_learners,
ids_pairs_learners, remove_y_quadruplets,
mock_preprocessor, pairs_learners,
ids_pairs_learners, tuples_learners,
ids_tuples_learners, remove_y_quadruplets,
quadruplets_learners)


Expand Down Expand Up @@ -106,6 +107,31 @@ def stable_init(self, num_dims=None, pca_comps=None,
# ---------------------- Test scikit-learn compatibility ----------------------


@pytest.mark.parametrize('with_preprocessor',
                         [True,
                          # TODO: uncomment the below line as soon as
                          #  https://github.com/scikit-learn/scikit-learn/
                          #  issues/13077 is solved:
                          # False,
                          ])
@pytest.mark.parametrize('estimator, build_dataset', pairs_learners,
                         ids=ids_pairs_learners)
def test_calibrated_classifier_CV(estimator, build_dataset,
                                  with_preprocessor):
    """Tests that metric-learn pairs estimators work with scikit-learn's
    CalibratedClassifierCV.

    The estimator is cloned, given the dataset's preprocessor, wrapped in a
    ``CalibratedClassifierCV`` and fitted on the pairs and their labels.
    ``predict_proba`` must then run and return one row per input pair and
    one column per class known to the calibrated classifier.
    """
    input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
    # Work on a clone so the shared parametrized estimator is not mutated.
    estimator = clone(estimator)
    estimator.set_params(preprocessor=preprocessor)
    set_random_state(estimator)
    calibrated_clf = CalibratedClassifierCV(estimator)

    # test fit and predict_proba
    calibrated_clf.fit(input_data, labels)
    probabilities = calibrated_clf.predict_proba(input_data)

    # predict_proba is documented to return (n_samples, n_classes); checking
    # the shape makes the test fail on a malformed output instead of only
    # checking that no exception is raised.
    expected_shape = (len(input_data), len(calibrated_clf.classes_))
    assert probabilities.shape == expected_shape


@pytest.mark.parametrize('with_preprocessor', [True, False])
@pytest.mark.parametrize('estimator, build_dataset', pairs_learners,
ids=ids_pairs_learners)
Expand Down