diff --git a/graphtools/api.py b/graphtools/api.py index 68c236e..4282972 100644 --- a/graphtools/api.py +++ b/graphtools/api.py @@ -16,10 +16,11 @@ def Graph(data, decay=10, distance='euclidean', thresh=1e-4, + kernel_symm='+', + gamma=None, n_landmark=None, n_svd=100, beta=1, - gamma=0.5, n_jobs=-1, verbose=False, random_state=None, @@ -71,6 +72,17 @@ def Graph(data, All affinities below `thresh` will be set to zero in order to save on time and memory constraints. + kernel_symm : string, optional (default: '+') + Defines method of MNN symmetrization. + '+' : additive + '*' : multiplicative + 'gamma' : min-max + 'none' : no symmetrization + + gamma: float (default: None) + Min-max symmetrization constant or matrix. Only used if kernel_symm='gamma'. + K = `gamma * min(K, K.T) + (1 - gamma) * max(K, K.T)` + precomputed : {'distance', 'affinity', 'adjacency', `None`}, optional (default: `None`) If the graph is precomputed, this variable denotes which graph matrix is provided as `data`. @@ -79,11 +91,6 @@ def Graph(data, beta: float, optional(default: 1) Multiply within - batch connections by(1 - beta) - gamma: float or {'+', '*'} (default: 0.99) - Symmetrization method. If '+', use `(K + K.T) / 2`, - if '*', use `K * K.T`, if a float, use - `gamma * min(K, K.T) + (1 - gamma) * max(K, K.T)` - sample_idx: array-like Batch index for MNN kernel @@ -205,44 +212,19 @@ class Graph(parent_classes[0], parent_classes[1]): class Graph(parent_classes[0], parent_classes[1], parent_classes[2]): pass else: - raise RuntimeError("unknown graph classes") + raise RuntimeError("unknown graph classes {}".format(parent_classes)) + + params = kwargs + for parent_class in parent_classes: + for param in parent_class._get_param_names(): + try: + params[param] = eval(param) + except NameError: + # keyword argument not specified above - no problem + pass # build graph and return log_debug("Initializing {} with arguments {}".format( parent_classes, - { - 'n_pca': n_pca, - 'sample_idx': sample_idx, - 'adaptive_k': adaptive_k, - 'precomputed': precomputed, - 'knn': knn, - 'decay': decay, - 'distance': distance, - 'thresh': thresh, - 'n_landmark': n_landmark, - 'n_svd': n_svd, - 'beta': beta, - 'gamma': gamma, - 'n_jobs': n_jobs, - 'verbose': verbose, - 'random_state': random_state, - 'initialize': initialize - })) - return Graph(data, - n_pca=n_pca, - sample_idx=sample_idx, - adaptive_k=adaptive_k, - precomputed=precomputed, - knn=knn, - decay=decay, - distance=distance, - thresh=thresh, - n_landmark=n_landmark, - n_svd=n_svd, - beta=beta, - gamma=gamma, - n_jobs=n_jobs, - verbose=verbose, - random_state=random_state, - initialize=initialize, - **kwargs) + params)) + return Graph(**params) diff --git a/graphtools/base.py b/graphtools/base.py index a76feee..dd7233c 100644 --- a/graphtools/base.py +++ b/graphtools/base.py @@ -3,17 +3,21 @@ import numpy as np import abc import pygsp +from sklearn.utils.fixes import signature from sklearn.decomposition import PCA, TruncatedSVD from sklearn.preprocessing import normalize from scipy import sparse import warnings +import numbers try: import pandas as pd except ImportError: # pandas not installed pass -from .utils import set_diagonal +from .utils import (elementwise_minimum, + elementwise_maximum, + set_diagonal) from .logging import (set_logging, log_start, log_complete, @@ -25,9 +29,38 @@ class Base(object): just an object. 
""" - def __init__(self, **kwargs): + def __init__(self): super().__init__() + @classmethod + def _get_param_names(cls): + """Get parameter names for the estimator""" + # fetch the constructor or the original constructor before + # deprecation wrapping if any + init = getattr(cls.__init__, 'deprecated_original', cls.__init__) + if init is object.__init__: + # No explicit constructor to introspect + return [] + + # introspect the constructor arguments to find the model parameters + # to represent + init_signature = signature(init) + # Consider the constructor parameters excluding 'self' + parameters = [p for p in init_signature.parameters.values() + if p.name != 'self' and p.kind != p.VAR_KEYWORD] + # Extract and sort argument names excluding 'self' + parameters = set([p.name for p in parameters]) + + # recurse + for superclass in cls.__bases__: + try: + parameters.update(superclass._get_param_names()) + except AttributeError: + # object and pygsp.graphs.Graph don't have this method + pass + + return parameters + class Data(Base): """Parent class that handles the import and dimensionality reduction of data @@ -36,7 +69,7 @@ class Data(Base): ---------- data : array-like, shape=[n_samples,n_features] accepted types: `numpy.ndarray`, `scipy.sparse.spmatrix`. - TODO: accept pandas dataframes + `pandas.DataFrame`, `pandas.SparseDataFrame`. n_pca : `int` or `None`, optional (default: `None`) number of PC dimensions to retain for graph building. @@ -100,16 +133,18 @@ def _reduce_data(self): if self.n_pca is not None and self.n_pca < self.data.shape[1]: log_start("PCA") if sparse.issparse(self.data): + if isinstance(self.data, sparse.coo_matrix) or \ + isinstance(self.data, sparse.lil_matrix) or \ + isinstance(self.data, sparse.dok_matrix): + self.data = self.data.tocsr() self.data_pca = TruncatedSVD(self.n_pca, - random_state=self.random_state) - self.data_pca.fit(self.data) - data_nu = self.data_pca.transform(self.data) + random_state=self.random_state) else: self.data_pca = PCA(self.n_pca, - svd_solver='randomized', - random_state=self.random_state) - self.data_pca.fit(self.data) - data_nu = self.data_pca.transform(self.data) + svd_solver='randomized', + random_state=self.random_state) + self.data_pca.fit(self.data) + data_nu = self.data_pca.transform(self.data) log_complete("PCA") return data_nu else: @@ -144,7 +179,6 @@ def set_params(self, **params): self.random_state = params['random_state'] return self - def transform(self, Y): """Transform input data `Y` to reduced data space defined by `self.data` @@ -168,15 +202,15 @@ def transform(self, Y): # try PCA first return self.data_pca.transform(Y) - except AttributeError: #no pca, try to return data - try: - if Y.shape[1] != self.data.shape[1]: - # shape is wrong - raise ValueError - return Y - except IndexError: - # len(Y.shape) < 2 + except AttributeError: # no pca, try to return data + try: + if Y.shape[1] != self.data.shape[1]: + # shape is wrong raise ValueError + return Y + except IndexError: + # len(Y.shape) < 2 + raise ValueError except ValueError: # more informative error raise ValueError("data of shape {} cannot be transformed" @@ -206,14 +240,14 @@ def inverse_transform(self, Y): # try PCA first return self.data_pca.inverse_transform(Y) except AttributeError: - try: - if Y.shape[1] != self.data_nu.shape[1]: - # shape is wrong - raise ValueError - return Y - except IndexError: - # len(Y.shape) < 2 + try: + if Y.shape[1] != self.data_nu.shape[1]: + # shape is wrong raise ValueError + return Y + except IndexError: + # len(Y.shape) < 2 
+ raise ValueError except ValueError: # more informative error raise ValueError("data of shape {} cannot be inverse transformed" @@ -227,6 +261,17 @@ class BaseGraph(with_metaclass(abc.ABCMeta, Base)): Parameters ---------- + kernel_symm : string, optional (default: '+') + Defines method of MNN symmetrization. + '+' : additive + '*' : multiplicative + 'gamma' : min-max + 'none' : no symmetrization + + gamma: float (default: 0.5) + Min-max symmetrization constant. + K = `gamma * min(K, K.T) + (1 - gamma) * max(K, K.T)` + initialize : `bool`, optional (default : `True`) if false, don't create the kernel matrix. @@ -245,7 +290,13 @@ class BaseGraph(with_metaclass(abc.ABCMeta, Base)): diff_op : synonym for `P` """ - def __init__(self, initialize=True, **kwargs): + def __init__(self, kernel_symm='+', + gamma=None, + initialize=True, **kwargs): + self.kernel_symm = kernel_symm + self.gamma = gamma + self._check_symmetrization(kernel_symm, gamma) + if initialize: log_debug("Initializing kernel...") self.K @@ -253,6 +304,25 @@ def __init__(self, initialize=True, **kwargs): log_debug("Not initializing kernel.") super().__init__(**kwargs) + def _check_symmetrization(self, kernel_symm, gamma): + if kernel_symm not in ['+', '*', 'gamma', 'none']: + raise ValueError( + "kernel_symm '{}' not recognized. Choose from " + "'+', '*', 'gamma', or 'none'.".format(kernel_symm)) + elif kernel_symm != 'gamma' and gamma is not None: + warnings.warn("kernel_symm='{}' but gamma is not None. " + "Setting kernel_symm='gamma'.".format(kernel_symm)) + self.kernel_symm = kernel_symm = 'gamma' + + if kernel_symm == 'gamma': + if gamma is None: + warnings.warn("kernel_symm='gamma' but gamma not given. " + "Defaulting to gamma=0.5.") + self.gamma = gamma = 0.5 + elif not isinstance(gamma, numbers.Number) or gamma < 0 or gamma > 1: + raise ValueError("gamma {} not recognized. Expected " + "a float between 0 and 1".format(gamma)) + def _build_kernel(self): """Private method to build kernel matrix @@ -268,16 +338,41 @@ def _build_kernel(self): RuntimeWarning : if K is not symmetric """ kernel = self.build_kernel() + kernel = self.symmetrize_kernel(kernel) if (kernel - kernel.T).max() > 1e-5: warnings.warn("K should be symmetric", RuntimeWarning) if np.any(kernel.diagonal == 0): warnings.warn("K should have a non-zero diagonal", RuntimeWarning) return kernel + def symmetrize_kernel(self, K): + # symmetrize + if self.kernel_symm == "+": + log_debug("Using addition symmetrization.") + K = (K + K.T) / 2 + elif self.kernel_symm == "*": + log_debug("Using multiplication symmetrization.") + K = K.multiply(K.T) + elif self.kernel_symm == 'gamma': + log_debug( + "Using gamma symmetrization (gamma = {}).".format(self.gamma)) + K = self.gamma * elementwise_minimum(K, K.T) + \ + (1 - self.gamma) * elementwise_maximum(K, K.T) + elif self.kernel_symm == 'none': + log_debug("Using no symmetrization.") + pass + else: + # this should never happen + raise ValueError( + "Expected kernel_symm in ['+', '*', 'gamma' or 'none']. " + "Got {}".format(self.gamma)) + return K + def get_params(self): """Get parameters from this object """ - return {} + return {'kernel_symm': self.kernel_symm, + 'gamma': self.gamma} def set_params(self, **params): """Set parameters on this object @@ -285,6 +380,9 @@ def set_params(self, **params): Safe setter method - attributes should not be modified directly as some changes are not valid. 
Valid parameters: + Invalid parameters: (these would require modifying the kernel matrix) + - kernel_symm + - gamma Parameters ---------- @@ -294,6 +392,11 @@ def set_params(self, **params): ------- self """ + if 'gamma' in params and params['gamma'] != self.gamma: + raise ValueError("Cannot update gamma. Please create a new graph") + if 'kernel_symm' in params and params['kernel_symm'] != self.kernel_symm: + raise ValueError( + "Cannot update kernel_symm. Please create a new graph") return self @property @@ -359,7 +462,7 @@ def build_kernel(self): raise NotImplementedError -class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph)): +class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)): """Interface between BaseGraph and PyGSP. All graphs should possess these matrices. We inherit a lot @@ -369,17 +472,16 @@ class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph)): kernel matrix """ - def __init__(self, **kwargs): + def __init__(self, gtype='unknown', lap_type='combinatorial', coords=None, + plotting=None, **kwargs): + if plotting is None: + plotting = {} W = self._build_weight_from_kernel(self.K) - # delete non-pygsp keywords - # TODO: is there a better way? - keywords = [k for k in kwargs.keys()] - for kw in keywords: - if kw not in ['gtype', 'lap_type', 'coords', 'plotting']: - del kwargs[kw] - - super().__init__(W=W, **kwargs) + super().__init__(W=W, gtype=gtype, + lap_type=lap_type, + coords=coords, + plotting=plotting, **kwargs) @property @abc.abstractmethod diff --git a/graphtools/graphs.py b/graphtools/graphs.py index 81c0080..267d5cf 100644 --- a/graphtools/graphs.py +++ b/graphtools/graphs.py @@ -1,6 +1,6 @@ from builtins import super import numpy as np -from sklearn.neighbors import NearestNeighbors, kneighbors_graph +from sklearn.neighbors import NearestNeighbors from scipy.spatial.distance import pdist, cdist from scipy.spatial.distance import squareform from sklearn.utils.extmath import randomized_svd @@ -10,9 +10,9 @@ import numbers import warnings -from .utils import (elementwise_minimum, +from .utils import (set_diagonal, + elementwise_minimum, elementwise_maximum, - set_diagonal, set_submatrix) from .logging import (log_start, log_complete, @@ -173,8 +173,6 @@ def build_kernel(self): with no non-negative entries. """ K = self.build_kernel_to_data(self.data_nu) - # symmetrize - K = (K + K.T) / 2 return K def build_kernel_to_data(self, Y, knn=None): @@ -696,8 +694,7 @@ def build_kernel(self): pdx = (pdx.T / epsilon).T K = np.exp(-1 * np.power(pdx, self.decay)) log_complete("affinities") - # symmetrize - K = (K + K.T) / 2 + # truncate if sparse.issparse(K): K.data[K.data < self.thresh] = 0 K = K.tocoo() @@ -749,6 +746,7 @@ def build_kernel_to_data(self, Y, knn=None): epsilon = np.max(knn_dist, axis=1) pdx = (pdx.T / epsilon).T K = np.exp(-1 * pdx**self.decay) + K[K < self.thresh] = 0 log_complete("affinities") return K @@ -773,13 +771,6 @@ class MNNGraph(DataGraph): beta: `float`, optional (default: 1) Downweight within-batch affinities by beta - gamma: `float` or {'+', '*'} (default: 0.99) - Symmetrization method. - If '+', use `(K + K.T) / 2`; - if '*', use `K * K.T`; - if a float, use - `gamma * min(K, K.T) + (1 - gamma) * max(K, K.T)` - adaptive_k : `{'min', 'mean', 'sqrt', 'none'}` (default: 'sqrt') Weights MNN kernel adaptively using the number of cells in each sample according to the selected method. 
@@ -791,20 +782,23 @@ class MNNGraph(DataGraph): """ def __init__(self, data, sample_idx, - knn=5, beta=1, gamma=0.99, n_pca=None, + knn=5, beta=1, n_pca=None, adaptive_k='sqrt', + decay=None, + distance='euclidean', + thresh=1e-4, **kwargs): self.beta = beta - self.gamma = gamma self.sample_idx = sample_idx self.samples, self.n_cells = np.unique( self.sample_idx, return_counts=True) self.adaptive_k = adaptive_k self.knn = knn + self.decay = decay + self.distance = distance + self.thresh = thresh self.weighted_knn = self._weight_knn() - self.knn_args = kwargs - if sample_idx is None: raise ValueError("sample_idx must be given. For a graph without" " batch correction, use kNNGraph.") @@ -815,24 +809,21 @@ def __init__(self, data, sample_idx, raise ValueError( "sample_idx must contain more than one unique value") - if isinstance(gamma, str): - if gamma not in ['+', '*']: - raise ValueError( - "gamma '{}' not recognized. Choose from " - "'+', '*', a float between 0 and 1, " - "or a matrix of floats between 0 " - "and 1.".format(gamma)) - elif isinstance(gamma, numbers.Number): - if (gamma < 0 or gamma > 1): - raise ValueError( - "gamma '{}' invalid. Choose from " - "'+', '*', a float between 0 and 1, " - "or a matrix of floats between 0 " - "and 1.".format(gamma)) - else: - # matrix - if not np.shape(self.gamma) == (len(self.samples), - len(self.samples)): + super().__init__(data, n_pca=n_pca, **kwargs) + + def _check_symmetrization(self, kernel_symm, gamma): + if kernel_symm == 'gamma' and gamma is not None and \ + not isinstance(gamma, numbers.Number): + # matrix gamma + try: + gamma.shape + except AttributeError: + raise ValueError("gamma {} not recognized. " + "Expected a float between 0 and 1 " + "or a [n_batch,n_batch] matrix of " + "floats between 0 and 1".format(gamma)) + if not np.shape(gamma) == (len(self.samples), + len(self.samples)): raise ValueError( "Matrix gamma must be of shape " "({}), got ({})".format( @@ -845,8 +836,8 @@ def __init__(self, data, sample_idx, np.max(gamma), np.min(gamma))) elif np.any(gamma != gamma.T): raise ValueError("gamma must be a symmetric matrix") - - super().__init__(data, n_pca=n_pca, **kwargs) + else: + super()._check_symmetrization(kernel_symm, gamma) def _weight_knn(self, sample_size=None): """Select adaptive values of knn @@ -888,7 +879,7 @@ def get_params(self): """ params = super().get_params() params.update({'beta': self.beta, - 'gamma': self.gamma}) + 'adaptive_k': self.adaptive_k}) params.update(self.knn_args) return params @@ -908,7 +899,6 @@ def set_params(self, **params): - distance - thresh - beta - - gamma Parameters ---------- @@ -921,8 +911,6 @@ def set_params(self, **params): # mnn specific arguments if 'beta' in params and params['beta'] != self.beta: raise ValueError("Cannot update beta. Please create a new graph") - if 'gamma' in params and params['gamma'] != self.gamma: - raise ValueError("Cannot update gamma. Please create a new graph") if 'adaptive_k' in params and params['adaptive_k'] != self.adaptive_k: raise ValueError( "Cannot update adaptive_k. Please create a new graph") @@ -936,8 +924,7 @@ def set_params(self, **params): raise ValueError("Cannot update {}. 
" "Please create a new graph".format(arg)) for arg in knn_other_args: - if arg in params: - self.knn_args[arg] = params[arg] + self.__setattr__(arg, params[arg]) # update subgraph parameters [g.set_params(**knn_other_args) for g in self.subgraphs] @@ -960,10 +947,6 @@ def build_kernel(self): log_start("subgraphs") self.subgraphs = [] from .api import Graph - remove_args = ['n_landmark', 'initialize'] - for arg in remove_args: - if arg in self.knn_args: - del self.knn_args[arg] # iterate through sample ids for i, idx in enumerate(self.samples): log_debug("subgraph {}: sample {}".format(i, idx)) @@ -972,12 +955,16 @@ def build_kernel(self): # build a kNN graph for cells within sample graph = Graph(data, n_pca=None, knn=self.weighted_knn[i], - initialize=False, - **(self.knn_args)) + decay=self.decay, + distance=self.distance, + thresh=self.thresh, + verbose=self.verbose, + random_state=self.random_state, + initialize=False) self.subgraphs.append(graph) # append to list of subgraphs log_complete("subgraphs") - if isinstance(self.subgraphs[0], kNNGraph): + if self.thresh > 0 or self.decay is None: K = sparse.lil_matrix( (self.data_nu.shape[0], self.data_nu.shape[0])) else: @@ -998,13 +985,17 @@ def build_kernel(self): log_complete( "kernel from sample {} to {}".format(self.samples[i], self.samples[j])) + return K - if not (isinstance(self.gamma, str) or - isinstance(self.gamma, numbers.Number)): + def symmetrize_kernel(self, K): + if self.kernel_symm == 'gamma' and not isinstance(self.gamma, + numbers.Number): # matrix gamma # Gamma can be a matrix with specific values transitions for # each batch. This allows for technical replicates and # experimental samples to be corrected simultaneously + log_debug("Using gamma symmetrization. " + "Gamma:\n{}".format(self.gamma)) for i in range(len(self.samples)): for j in range(i, len(self.samples)): Kij = K[self.sample_idx == i, :][:, self.sample_idx == j] @@ -1019,18 +1010,7 @@ def build_kernel(self): K = set_submatrix(K, self.sample_idx == j, self.sample_idx == i, Kij_symm.T) else: - # symmetrize - if isinstance(self.gamma, str): - if self.gamma == "+": - K = (K + K.T) / 2 - elif self.gamma == "*": - K = K.multiply(K.T) - elif isinstance(self.gamma, numbers.Number): - K = self.gamma * elementwise_minimum(K, K.T) + \ - (1 - self.gamma) * elementwise_maximum(K, K.T) - else: - # this should never happen - raise ValueError("invalid gamma") + K = super().symmetrize_kernel(K) return K def build_kernel_to_data(self, Y, gamma=None): diff --git a/graphtools/version.py b/graphtools/version.py index ae73625..bbab024 100644 --- a/graphtools/version.py +++ b/graphtools/version.py @@ -1 +1 @@ -__version__ = "0.1.3" +__version__ = "0.1.4" diff --git a/test/__init__.py b/test/load_tests/__init__.py similarity index 95% rename from test/__init__.py rename to test/load_tests/__init__.py index e876143..4ae5e10 100644 --- a/test/__init__.py +++ b/test/load_tests/__init__.py @@ -38,12 +38,14 @@ def build_graph(data, n_pca=20, thresh=0, random_state=42, sparse=False, graph_class=graphtools.Graph, + verbose=0, **kwargs): if sparse: data = sp.coo_matrix(data) return graph_class(data, thresh=thresh, n_pca=n_pca, decay=decay, knn=knn, - random_state=42, **kwargs) + random_state=42, verbose=verbose, + **kwargs) def warns(*warns): diff --git a/test/test_api.py b/test/test_api.py index 00b7a50..882cbae 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -1,4 +1,4 @@ -from . 
import ( +from load_tests import ( nose2, data, build_graph, @@ -11,6 +11,11 @@ ##################################################### +@raises(TypeError) +def test_unknown_parameter(): + build_graph(data, hello='world') + + @raises(ValueError) def test_invalid_graphtype(): build_graph(data, graphtype='hello world') diff --git a/test/test_data.py b/test/test_data.py index d09a502..27fe250 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -1,4 +1,4 @@ -from . import ( +from load_tests import ( np, sp, pd, @@ -54,7 +54,7 @@ def test_pandas_dataframe(): def test_pandas_sparse_dataframe(): G = build_graph(pd.SparseDataFrame(data)) assert isinstance(G, graphtools.base.BaseGraph) - assert isinstance(G.data, sp.coo_matrix) + assert isinstance(G.data, sp.csr_matrix) ##################################################### diff --git a/test/test_exact.py b/test/test_exact.py index 3fb9fe6..6f9ed3b 100644 --- a/test/test_exact.py +++ b/test/test_exact.py @@ -1,4 +1,4 @@ -from . import ( +from load_tests import ( graphtools, np, sp, @@ -144,9 +144,9 @@ def test_truncated_exact_graph(): epsilon = np.max(knn_dist, axis=1) weighted_pdx = (pdx.T / epsilon).T K = np.exp(-1 * weighted_pdx**a) + K[K < thresh] = 0 W = K + K.T W = np.divide(W, 2) - W[W < thresh] = 0 np.fill_diagonal(W, 0) G = pygsp.graphs.Graph(W) G2 = build_graph(data_small, thresh=thresh, @@ -230,6 +230,8 @@ def test_precomputed_interpolate(): def test_verbose(): + print() + print("Verbose test: Exact") build_graph(data, decay=10, thresh=0, verbose=True) diff --git a/test/test_knn.py b/test/test_knn.py index 3c7aef1..d221934 100644 --- a/test/test_knn.py +++ b/test/test_knn.py @@ -1,4 +1,4 @@ -from . import ( +from load_tests import ( graphtools, np, pygsp, @@ -125,6 +125,8 @@ def test_knn_interpolate(): def test_verbose(): + print() + print("Verbose test: kNN") build_graph(data, decay=None, verbose=True) diff --git a/test/test_landmark.py b/test/test_landmark.py index 82e6b27..18c5b0e 100644 --- a/test/test_landmark.py +++ b/test/test_landmark.py @@ -1,4 +1,4 @@ -from . import ( +from load_tests import ( graphtools, np, nose2, @@ -73,6 +73,8 @@ def test_landmark_mnn_graph(): def test_verbose(): + print() + print("Verbose test: Landmark") build_graph(data, decay=None, n_landmark=500, verbose=True).landmark_op diff --git a/test/test_mnn.py b/test/test_mnn.py index 76b4f89..43292f2 100644 --- a/test/test_mnn.py +++ b/test/test_mnn.py @@ -1,4 +1,4 @@ -from . 
import ( +from load_tests import ( graphtools, np, pd, @@ -21,13 +21,15 @@ @raises(ValueError) def test_sample_idx_and_precomputed(): - build_graph(data, n_pca=None, sample_idx=np.arange(10), + build_graph(data, n_pca=None, + sample_idx=np.arange(10), precomputed='distance') @raises(ValueError) def test_sample_idx_wrong_length(): - build_graph(data, graphtype='mnn', sample_idx=np.arange(10)) + build_graph(data, graphtype='mnn', + sample_idx=np.arange(10)) @raises(ValueError) @@ -46,6 +48,31 @@ def test_build_mnn_with_precomputed(): build_graph(data, n_pca=None, graphtype='mnn', precomputed='distance') +@raises(ValueError) +def test_mnn_with_square_gamma_wrong_length(): + n_sample = len(np.unique(digits['target'])) + # square matrix gamma of the wrong size + build_graph( + data, thresh=0, n_pca=20, + decay=10, knn=5, random_state=42, + sample_idx=digits['target'], + kernel_symm='gamma', + gamma=np.tile(np.linspace(0, 1, n_sample - 1), + n_sample).reshape(n_sample - 1, n_sample)) + + +@raises(ValueError) +def test_mnn_with_vector_gamma(): + n_sample = len(np.unique(digits['target'])) + # vector gamma + build_graph( + data, thresh=0, n_pca=20, + decay=10, knn=5, random_state=42, + sample_idx=digits['target'], + kernel_symm='gamma', + gamma=np.linspace(0, 1, n_sample - 1)) + + ##################################################### # Check kernel ##################################################### @@ -84,14 +111,12 @@ def test_mnn_graph_float_gamma(): ((1 - gamma) * np.maximum(K, K.T))) np.fill_diagonal(W, 0) G = pygsp.graphs.Graph(W) - G2 = graphtools.Graph(X, knn=k + 1, decay=a, beta=1 - beta, gamma=gamma, + G2 = graphtools.Graph(X, knn=k + 1, decay=a, beta=1 - beta, + kernel_symm='gamma', gamma=gamma, distance=metric, sample_idx=sample_idx, thresh=0, use_pygsp=True) assert G.N == G2.N - assert np.all(G.d == G2.d), "{} ({}, {})".format( - np.where(G.d != G2.d), - G.d[np.argwhere(G.d != G2.d).reshape(-1)], - G2.d[np.argwhere(G.d != G2.d).reshape(-1)]) + assert np.all(G.d == G2.d) assert (G.W != G2.W).nnz == 0 assert (G2.W != G.W).sum() == 0 assert isinstance(G2, graphtools.graphs.MNNGraph) @@ -140,34 +165,17 @@ def test_mnn_graph_matrix_gamma(): ((1 - matrix_gamma) * np.maximum(K, K.T))) np.fill_diagonal(W, 0) G = pygsp.graphs.Graph(W) - G2 = graphtools.Graph(X, knn=k + 1, decay=a, beta=1 - beta, gamma=gamma, + G2 = graphtools.Graph(X, knn=k + 1, decay=a, beta=1 - beta, + kernel_symm='gamma', gamma=gamma, distance=metric, sample_idx=sample_idx, thresh=0, use_pygsp=True) assert G.N == G2.N - assert np.all(G.d == G2.d), "{} ({}, {})".format( - np.where(G.d != G2.d), - G.d[np.argwhere(G.d != G2.d).reshape(-1)], - G2.d[np.argwhere(G.d != G2.d).reshape(-1)]) + assert np.all(G.d == G2.d) assert (G.W != G2.W).nnz == 0 assert (G2.W != G.W).sum() == 0 assert isinstance(G2, graphtools.graphs.MNNGraph) -def test_mnn_graph_error(): - n_sample = len(np.unique(digits['target'])) - assert_raises(ValueError, build_graph, - data, thresh=0, n_pca=20, - decay=10, knn=5, random_state=42, - sample_idx=digits['target'], - gamma=np.tile(np.linspace(0, 1, n_sample - 1), - n_sample).reshape(n_sample - 1, n_sample)) - assert_raises(ValueError, build_graph, - data, thresh=0, n_pca=20, - decay=10, knn=5, random_state=42, - sample_idx=digits['target'], - gamma=np.linspace(0, 1, n_sample - 1)) - - ##################################################### # Check interpolation ##################################################### @@ -177,7 +185,11 @@ def test_mnn_graph_error(): def test_verbose(): X, sample_idx = 
generate_swiss_roll() - build_graph(X, sample_idx=sample_idx, n_pca=None, verbose=True) + print() + print("Verbose test: MNN") + build_graph(X, sample_idx=sample_idx, + kernel_symm='gamma', gamma=0.5, + n_pca=None, verbose=True) if __name__ == "__main__":
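
For reviewers, a minimal usage sketch of the symmetrization API introduced in this diff. The data, batch labels, and parameter values below are made up for illustration; only the keyword names (kernel_symm, gamma, sample_idx, knn, decay, beta) come from the changed signatures and tests.

import numpy as np
import graphtools

# toy data: two synthetic "batches" of 100 observations x 50 features
data = np.random.normal(size=(200, 50))
sample_idx = np.repeat([0, 1], 100)

# MNN graph with min-max ('gamma') symmetrization; previously the method was
# selected by passing a float or '+'/'*' to `gamma` alone, now `kernel_symm`
# selects the method and `gamma` only supplies the min-max constant
G = graphtools.Graph(data,
                     sample_idx=sample_idx,
                     knn=5, decay=10, beta=1,
                     kernel_symm='gamma',   # one of '+', '*', 'gamma', 'none'
                     gamma=0.9)

Per the MNNGraph changes and the new tests in test_mnn.py, a symmetric [n_batch, n_batch] matrix of floats in [0, 1] may be passed as gamma instead of a scalar to control each pair of batches separately.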
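
The four modes that BaseGraph.symmetrize_kernel now dispatches on can be summarized with a dense NumPy sketch; the real code uses the sparse-aware elementwise_minimum/elementwise_maximum helpers from graphtools.utils and K.multiply for sparse inputs.

import numpy as np

def symmetrize(K, kernel_symm='+', gamma=0.5):
    if kernel_symm == '+':
        return (K + K.T) / 2                    # additive
    elif kernel_symm == '*':
        return K * K.T                          # multiplicative (elementwise)
    elif kernel_symm == 'gamma':
        # min-max: gamma * min(K, K.T) + (1 - gamma) * max(K, K.T)
        return gamma * np.minimum(K, K.T) + (1 - gamma) * np.maximum(K, K.T)
    elif kernel_symm == 'none':
        return K
    raise ValueError("kernel_symm '{}' not recognized".format(kernel_symm))

K = np.random.uniform(size=(5, 5))
# gamma = 0.5 averages min and max, which is exactly the additive case
assert np.allclose(symmetrize(K, 'gamma', 0.5), symmetrize(K, '+'))

In the min-max mode, gamma = 1 keeps only the smaller of the two directed affinities (strict mutual-nearest-neighbor behaviour), while gamma = 0 keeps the larger (a union-like kernel).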
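
The factory in api.py now forwards arguments by introspecting each parent class's constructor instead of listing every keyword by hand. Below is a self-contained sketch of that pattern with made-up class and parameter names; the real code imports signature from sklearn.utils.fixes for compatibility and uses eval(param) to pick up the factory's local variables.

from inspect import signature

class Base(object):
    @classmethod
    def _get_param_names(cls):
        if cls.__init__ is object.__init__:
            return set()                     # no explicit constructor to introspect
        params = {p.name for p in signature(cls.__init__).parameters.values()
                  if p.name != 'self' and p.kind != p.VAR_KEYWORD}
        for superclass in cls.__bases__:
            try:
                params.update(superclass._get_param_names())
            except AttributeError:
                pass                         # plain `object` has no _get_param_names
        return params

class KernelMixin(Base):
    def __init__(self, kernel_symm='+', gamma=None, **kwargs):
        self.kernel_symm, self.gamma = kernel_symm, gamma

class ToyGraph(KernelMixin):
    def __init__(self, knn=5, decay=10, **kwargs):
        self.knn, self.decay = knn, decay
        super().__init__(**kwargs)

print(sorted(ToyGraph._get_param_names()))
# ['decay', 'gamma', 'kernel_symm', 'knn']

api.Graph collects params[param] = eval(param) for each introspected name and then calls the dynamically assembled class as Graph(**params), so unrecognized keyword arguments now surface as a TypeError, which the new test_unknown_parameter test exercises.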