diff --git a/graphtools/base.py b/graphtools/base.py index 8ec1b57..edd678a 100644 --- a/graphtools/base.py +++ b/graphtools/base.py @@ -65,6 +65,9 @@ def _get_param_names(cls): return parameters + def set_params(self, **kwargs): + return self + class Data(Base): """Parent class that handles the import and dimensionality reduction of data @@ -202,6 +205,7 @@ def set_params(self, **params): raise ValueError("Cannot update n_pca. Please create a new graph") if 'random_state' in params: self.random_state = params['random_state'] + super().set_params(**params) return self def transform(self, Y): @@ -441,6 +445,7 @@ def set_params(self, **params): params['kernel_symm'] != self.kernel_symm: raise ValueError( "Cannot update kernel_symm. Please create a new graph") + super().set_params(**params) return self @property diff --git a/graphtools/graphs.py b/graphtools/graphs.py index 6915cea..fb1b2a3 100644 --- a/graphtools/graphs.py +++ b/graphtools/graphs.py @@ -681,14 +681,14 @@ def set_params(self, **params): raise ValueError("Cannot update precomputed. " "Please create a new graph") if 'distance' in params and params['distance'] != self.distance and \ - self.precomputed is not None: + self.precomputed is None: raise ValueError("Cannot update distance. " "Please create a new graph") if 'knn' in params and params['knn'] != self.knn and \ - self.precomputed is not None: + self.precomputed is None: raise ValueError("Cannot update knn. Please create a new graph") if 'decay' in params and params['decay'] != self.decay and \ - self.precomputed is not None: + self.precomputed is None: raise ValueError("Cannot update decay. Please create a new graph") # update superclass parameters super().set_params(**params) diff --git a/test/test_landmark.py b/test/test_landmark.py index 04fbd21..42d9025 100644 --- a/test/test_landmark.py +++ b/test/test_landmark.py @@ -132,18 +132,18 @@ def test_verbose(): def test_set_params(): G = build_graph(data, n_landmark=500, decay=None) G.landmark_op - assert G.get_params == {'n_pca': 20, - 'random_state': 42, - 'kernel_symm': '+', - 'gamma': None, - 'n_landmark': 500, - 'knn': 3, - 'decay': None, - 'distance': - 'euclidean', - 'thresh': 0, - 'n_jobs': -1, - 'verbose': 0} + assert G.get_params() == {'n_pca': 20, + 'random_state': 42, + 'kernel_symm': '+', + 'gamma': None, + 'n_landmark': 500, + 'knn': 3, + 'decay': None, + 'distance': + 'euclidean', + 'thresh': 0, + 'n_jobs': -1, + 'verbose': 0} G.set_params(n_landmark=300) assert G.landmark_op.shape == (300, 300) G.set_params(n_landmark=G.n_landmark, n_svd=G.n_svd)