From fa3ac1640039c4288c94bb68819068c8e2737cd4 Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Jul 2024 16:38:11 +0200 Subject: [PATCH] Limit sklearn version --- hiclass/BinaryPolicy.py | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hiclass/BinaryPolicy.py b/hiclass/BinaryPolicy.py index 46326b29..b6cf3001 100644 --- a/hiclass/BinaryPolicy.py +++ b/hiclass/BinaryPolicy.py @@ -2,7 +2,7 @@ from abc import ABC -from scipy.sparse import vstack, csr_matrix +from scipy.sparse import vstack, csr_matrix, csr_array import networkx as nx import numpy as np @@ -160,7 +160,7 @@ def get_binary_examples(self, node) -> tuple: ) y = np.zeros(len(X)) y[: len(positive_x)] = 1 - elif isinstance(self.X, csr_matrix): + elif isinstance(self.X, csr_matrix) or isinstance(self.X, csr_array): X = vstack([positive_x, negative_x]) sample_weights = ( vstack([positive_weights, negative_weights]) diff --git a/setup.py b/setup.py index dbe5a1fa..39fb9a0d 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ KEYWORDS = ["hierarchical classification"] DACS_SOFTWARE = "https://gitlab.com/dacs-hpi" # What packages are required for this module to be executed? -REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13"] +REQUIRED = ["networkx", "numpy", "scikit-learn<1.5", "scipy<1.13"] # What packages are optional? # 'fancy feature': ['django'],}