diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index 7b8f1dba..7cb24c72 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -4,6 +4,8 @@ import networkx as nx import numpy as np +import sklearn + from joblib import Parallel, delayed from sklearn.base import BaseEstimator from sklearn.linear_model import LogisticRegression @@ -348,3 +350,12 @@ def _clean_up(self): del self.y_ if self.sample_weight_ is not None: del self.sample_weight_ + + def _change_local_classifier(self, classifier): + if not isinstance(classifier, sklearn.base.BaseEstimator): + raise TypeError( + "Unsupported Classifier: Classifier should be of type sklearn.base.BaseEstimator" + ) + + self.local_classifier = classifier + self.local_classifier_ = classifier diff --git a/tests/test_LocalClassifiers.py b/tests/test_LocalClassifiers.py index 37f1bf46..8364d9a9 100644 --- a/tests/test_LocalClassifiers.py +++ b/tests/test_LocalClassifiers.py @@ -105,7 +105,7 @@ def test_knn(classifier): @pytest.mark.parametrize("classifier", classifiers) def test_fit_multiple_dim_input(classifier): clf = classifier() - X = np.random.rand(1, 275, 3) + X = np.random.rand(1, 1, 275, 3) y = np.array([["a", "b", "c"]]) clf.fit(X, y) check_is_fitted(clf) @@ -119,3 +119,16 @@ def test_predict_multiple_dim_input(classifier): clf.fit(X, y) predictions = clf.predict(X) assert predictions is not None + + +@pytest.mark.parametrize("classifier", classifiers) +def test_change_local_classifier(classifier): + clf = classifier(local_classifier=LogisticRegression()) + y = np.array([["a", "b", "c"], ["a", "b", "d"]]) + X = np.random.randint(1, 11, size=(2, 10)) + + clf.fit(X, y) + assert isinstance(clf.local_classifier_, LogisticRegression) + + clf._change_local_classifier(KNeighborsClassifier()) + assert isinstance(clf.local_classifier_, KNeighborsClassifier)