Skip to content

Commit

Permalink
Add flat classifier #minor (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
mirand863 authored Jul 23, 2024
1 parent 6f37990 commit 5ee9b59
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 3 deletions.
15 changes: 15 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,18 @@ pre-commit install
```

If black is not executed locally and there are formatting errors, the CI/CD pipeline will fail.

## Building the documentation locally

To build the documentation locally, you need to install an additional set of dependencies that are specific to the documentation. It is easiest to create a separate conda environment and run the following command:

```
pip install -r docs/requirements.txt
```

To build the documentation, change to the `docs` directory and run the following commands:

```
cd docs
make html
```
9 changes: 9 additions & 0 deletions docs/source/api/classifiers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ LocalClassifierPerParentNode
:show-inheritance:
:inherited-members:
:special-members: __init__

Flat Classifier
===============

FlatClassifier
^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: FlatClassifier.FlatClassifier
:members:
:special-members: __init__
4 changes: 2 additions & 2 deletions hiclass/BinaryPolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC

from scipy.sparse import vstack, csr_matrix
from scipy.sparse import vstack, csr_matrix, csr_array
import networkx as nx
import numpy as np

Expand Down Expand Up @@ -160,7 +160,7 @@ def get_binary_examples(self, node) -> tuple:
)
y = np.zeros(len(X))
y[: len(positive_x)] = 1
elif isinstance(self.X, csr_matrix):
elif isinstance(self.X, csr_matrix) or isinstance(self.X, csr_array):
X = vstack([positive_x, negative_x])
sample_weights = (
vstack([positive_weights, negative_weights])
Expand Down
99 changes: 99 additions & 0 deletions hiclass/FlatClassifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
Flat classifier approach, used for comparison purposes.
Implementation by @lpfgarcia
"""

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import check_is_fitted


class FlatClassifier(BaseEstimator):
    """
    A flat classifier utility that accepts as input a hierarchy and flattens it internally.

    Every hierarchical label (one row of ``y``) is joined into a single flat
    string using an internal separator, a single classifier is trained on
    those flat strings, and predictions are split back into one label per
    level.

    Examples
    --------
    >>> from hiclass import FlatClassifier
    >>> y = [['1', '1.1'], ['2', '2.1']]
    >>> X = [[1, 2], [3, 4]]
    >>> flat = FlatClassifier()
    >>> _ = flat.fit(X, y)
    >>> flat.predict(X)
    array([['1', '1.1'],
           ['2', '2.1']])
    """

    def __init__(
        self,
        local_classifier: BaseEstimator = LogisticRegression(),
    ):
        """
        Initialize a flat classifier.

        Parameters
        ----------
        local_classifier : BaseEstimator, default=LogisticRegression
            The scikit-learn model used for the flat classification.
            Must implement ``fit`` and ``predict``.
        """
        # NOTE(review): the default estimator instance is created once, when
        # the class body is evaluated, and is therefore shared by every
        # FlatClassifier built without an explicit ``local_classifier``.
        # Pass a fresh estimator when independent models are required.
        self.local_classifier = local_classifier

    def fit(self, X, y, sample_weight=None):
        """
        Fit a flat classifier.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples. They are handed to
            ``local_classifier.fit`` unchanged.
        y : array-like of shape (n_samples, n_levels)
            The target values, i.e., hierarchical class labels for
            classification. Every label must be a string, because the levels
            of each row are combined with ``str.join``.
        sample_weight : array-like of shape (n_samples,), default=None
            Array of weights that are assigned to individual samples.
            When ``None``, no ``sample_weight`` argument is forwarded, so
            estimators whose ``fit`` does not accept it can still be used.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        # Convert from hierarchical labels to flat labels
        self.separator_ = "::HiClass::Separator::"
        flat_labels = [self.separator_.join(levels) for levels in y]

        # Fit flat classifier. Only forward sample_weight when the caller
        # supplied it: not every estimator's fit() has that parameter.
        if sample_weight is None:
            self.local_classifier.fit(X, flat_labels)
        else:
            self.local_classifier.fit(X, flat_labels, sample_weight=sample_weight)

        # Return the classifier
        return self

    def predict(self, X):
        """
        Predict classes for the given data.

        Hierarchical labels are returned.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. They are handed to
            ``local_classifier.predict`` unchanged.

        Returns
        -------
        y : ndarray
            The predicted classes, one row per sample and one entry per
            hierarchy level (assuming all predicted labels have the same
            number of levels).
        """
        # Check if fit has been called (fit sets the ``separator_`` attribute)
        check_is_fitted(self)

        # Predict and split each flat label back into its levels
        flat_predictions = self.local_classifier.predict(X)
        predictions = [label.split(self.separator_) for label in flat_predictions]

        return np.array(predictions)
2 changes: 2 additions & 0 deletions hiclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .LocalClassifierPerLevel import LocalClassifierPerLevel
from .LocalClassifierPerNode import LocalClassifierPerNode
from .LocalClassifierPerParentNode import LocalClassifierPerParentNode
from .FlatClassifier import FlatClassifier
from .MultiLabelLocalClassifierPerNode import MultiLabelLocalClassifierPerNode
from .MultiLabelLocalClassifierPerParentNode import (
MultiLabelLocalClassifierPerParentNode,
Expand All @@ -19,6 +20,7 @@
"LocalClassifierPerNode",
"LocalClassifierPerParentNode",
"LocalClassifierPerLevel",
"FlatClassifier",
"Explainer",
"MultiLabelLocalClassifierPerNode",
"MultiLabelLocalClassifierPerParentNode",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
KEYWORDS = ["hierarchical classification"]
DACS_SOFTWARE = "https://gitlab.com/dacs-hpi"
# What packages are required for this module to be executed?
REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13"]
REQUIRED = ["networkx", "numpy", "scikit-learn<1.5", "scipy<1.13"]

# What packages are optional?
# 'fancy feature': ['django'],}
Expand Down
13 changes: 13 additions & 0 deletions tests/test_FlatClassifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numpy as np
from numpy.testing import assert_array_equal

from hiclass import FlatClassifier


def test_fit_predict():
    """Fit on two samples and check the training labels are predicted back."""
    features = np.array([[1, 2], [3, 4]])
    labels = np.array([["a", "b"], ["b", "c"]])

    classifier = FlatClassifier()
    classifier.fit(features, labels)

    # Expected value first, actual value second, as in the original call.
    assert_array_equal(labels, classifier.predict(features))

0 comments on commit 5ee9b59

Please sign in to comment.