diff --git a/hiclass/BinaryPolicy.py b/hiclass/BinaryPolicy.py index 46326b29..b6cf3001 100644 --- a/hiclass/BinaryPolicy.py +++ b/hiclass/BinaryPolicy.py @@ -2,7 +2,7 @@ from abc import ABC -from scipy.sparse import vstack, csr_matrix +from scipy.sparse import vstack, csr_matrix, csr_array import networkx as nx import numpy as np @@ -160,7 +160,7 @@ def get_binary_examples(self, node) -> tuple: ) y = np.zeros(len(X)) y[: len(positive_x)] = 1 - elif isinstance(self.X, csr_matrix): + elif isinstance(self.X, csr_matrix) or isinstance(self.X, csr_array): X = vstack([positive_x, negative_x]) sample_weights = ( vstack([positive_weights, negative_weights]) diff --git a/setup.py b/setup.py index dbe5a1fa..39fb9a0d 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ KEYWORDS = ["hierarchical classification"] DACS_SOFTWARE = "https://gitlab.com/dacs-hpi" # What packages are required for this module to be executed? -REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13"] +REQUIRED = ["networkx", "numpy", "scikit-learn<1.5", "scipy<1.13"] # What packages are optional? # 'fancy feature': ['django'],}