diff --git a/.travis.yml b/.travis.yml
index e7159a0..3bc3118 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -17,6 +17,7 @@ script:
   - pip install -U .[test]
+  - if [ "$TRAVIS_PYTHON_VERSION" != "3.5" ]; then black . --check --diff; fi
   - python setup.py test
   - pip install -U .[doc]
   - cd doc; make html
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..7593865
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,39 @@
+
+Contributing to graphtools
+============================
+
+There are many ways to contribute to `graphtools`, with the most common ones
+being contribution of code or documentation to the project. Improving the
+documentation is no less important than improving the library itself. If you
+find a typo in the documentation, or have made improvements, do not hesitate to
+submit a GitHub pull request.
+
+But there are many other ways to help. In particular answering queries on the
+[issue tracker](https://github.com/KrishnaswamyLab/graphtools/issues),
+investigating bugs, and [reviewing other developers' pull
+requests](https://github.com/KrishnaswamyLab/graphtools/pulls)
+are very valuable contributions that decrease the burden on the project
+maintainers.
+
+Another way to contribute is to report issues you're facing, and give a "thumbs
+up" on issues that others reported and that are relevant to you. It also helps
+us if you spread the word: reference the project from your blog and articles,
+link to it from your website, or simply star it on GitHub to say "I use it".
+
+Code Style and Testing
+----------------------
+
+`graphtools` is maintained at close to 100% code coverage. Contributors are encouraged to write tests for their code, but if you do not know how to do so, please do not feel discouraged from contributing code! Others can always help you test your contribution.
+
+Code style is dictated by [`black`](https://pypi.org/project/black/#installation-and-usage). To automatically reformat your code when you run `git commit`, you can run `./autoblack.sh` in the root directory of this project to add a pre-commit hook to your `git` repository.
+
+Code of Conduct
+---------------
+
+We abide by the principles of openness, respect, and consideration of others
+of the Python Software Foundation: https://www.python.org/psf/codeofconduct/.
+
+Attribution
+---------------
+
+This `CONTRIBUTING.md` was adapted from [scikit-learn](https://github.com/scikit-learn/scikit-learn/blob/master/CONTRIBUTING.md).
diff --git a/README.rst b/README.rst
index 2859790..133409a 100644
--- a/README.rst
+++ b/README.rst
@@ -23,6 +23,9 @@ graphtools
 .. image:: https://img.shields.io/github/stars/KrishnaswamyLab/graphtools.svg?style=social&label=Stars
     :target: https://github.com/KrishnaswamyLab/graphtools/
     :alt: GitHub stars
+.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
+    :target: https://github.com/psf/black
+    :alt: Code style: black
 
 Tools for building and manipulating graphs in Python.
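Reviewer note: the `black . --check --diff` step added to Travis above fails the build on any file black would rewrite, without modifying it. For contributors who prefer to check from Python rather than the CLI, black also exposes a small programmatic API; a minimal sketch (the snippet below is illustrative only and is not part of this PR):

```python
# Minimal sketch: checking a code snippet with black's Python API
# (black.format_str / black.Mode). The CI above uses only the CLI.
import black

src = "x = { 'a':1 }\n"
formatted = black.format_str(src, mode=black.Mode())
print(formatted)         # x = {"a": 1}
print(formatted == src)  # False -> this snippet is not black-formatted
```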
diff --git a/autoblack.sh b/autoblack.sh
new file mode 100644
index 0000000..cfbaf2b
--- /dev/null
+++ b/autoblack.sh
@@ -0,0 +1,14 @@
+cat <<EOF >> .git/hooks/pre-commit
+#!/bin/sh
+
+set -e
+
+files=\$(git diff --staged --name-only --diff-filter=d -- "*.py")
+
+for file in \$files; do
+  black -q \$file
+  git add \$file
+done
+EOF
+chmod +x .git/hooks/pre-commit
+
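Reviewer note: the hook that `autoblack.sh` installs reformats staged Python files with black and re-stages them before every commit. A rough Python equivalent of the hook body, for readers who don't parse shell (illustrative only — the repository itself uses the shell version above):

```python
# Rough Python equivalent of the pre-commit hook installed by autoblack.sh.
import subprocess

# Staged Python files, excluding deletions (mirrors
# `git diff --staged --name-only --diff-filter=d -- "*.py"`).
files = subprocess.run(
    ["git", "diff", "--staged", "--name-only", "--diff-filter=d", "--", "*.py"],
    capture_output=True, text=True, check=True,
).stdout.split()

for f in files:
    subprocess.run(["black", "-q", f], check=True)  # reformat in place
    subprocess.run(["git", "add", f], check=True)   # re-stage the result
```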
diff --git a/doc/source/conf.py b/doc/source/conf.py
index c5c3d9b..e7303a3 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -19,7 +19,9 @@
 #
 import os
 import sys
-sys.path.insert(0, os.path.abspath('../..'))
+
+root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+sys.path.insert(0, root_dir)
 # print(sys.path)
 
 # -- General configuration ------------------------------------------------
@@ -31,40 +33,43 @@
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ['sphinx.ext.autodoc',
-              'sphinx.ext.autosummary',
-              'sphinx.ext.napoleon',
-              'sphinx.ext.doctest',
-              'sphinx.ext.coverage',
-              'sphinx.ext.mathjax',
-              'sphinx.ext.viewcode',
-              'sphinxcontrib.bibtex']
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.doctest",
+    "sphinx.ext.coverage",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.viewcode",
+    "sphinxcontrib.bibtex",
+]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
 # source_suffix = ['.rst', '.md']
-source_suffix = '.rst'
+source_suffix = ".rst"
 
 # The master toctree document.
-master_doc = 'index'
+master_doc = "index"
 
 # General information about the project.
-project = 'graphtools'
-copyright = '2018 Krishnaswamy Lab, Yale University'
-author = 'Jay Stanley and Scott Gigante, Krishnaswamy Lab, Yale University'
+project = "graphtools"
+copyright = "2018 Krishnaswamy Lab, Yale University"
+author = "Scott Gigante and Jay Stanley, Yale University"
 
 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
 # built documents.
 #
-# The short X.Y version.
-version = '0.1.3'
+version_py = os.path.join(root_dir, "graphtools", "version.py")
 # The full version, including alpha/beta/rc tags.
-release = '0.1.3'
+release = open(version_py).read().strip().split("=")[-1].replace('"', "").strip()
+# The short X.Y version.
+version = release.split("-")[0]
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
@@ -79,7 +84,7 @@
 exclude_patterns = []
 
 # The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
 
 # If true, `todo` and `todoList` produce output, else they produce nothing.
 todo_include_todos = False
@@ -90,7 +95,7 @@
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'default'
+html_theme = "default"
 
 # Theme options are theme-specific and customize the look and feel of a theme
 # further.  For a list of options available for each theme, see the
@@ -101,13 +106,13 @@
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
 
 
 # -- Options for HTMLHelp output ------------------------------------------
 
 # Output file base name for HTML help builder.
-htmlhelp_basename = 'graphtoolsdoc'
+htmlhelp_basename = "graphtoolsdoc"
 
 
 # -- Options for LaTeX output ---------------------------------------------
 
@@ -116,15 +121,12 @@
     # The paper size ('letterpaper' or 'a4paper').
     #
     # 'papersize': 'letterpaper',
-
     # The font size ('10pt', '11pt' or '12pt').
     #
     # 'pointsize': '10pt',
-
     # Additional stuff for the LaTeX preamble.
     #
     # 'preamble': '',
-
     # Latex figure (float) alignment
     #
     # 'figure_align': 'htbp',
@@ -134,8 +136,13 @@
 # (source start file, target name, title,
 #  author, documentclass [howto, manual, or own class]).
 latex_documents = [
-    (master_doc, 'graphtools.tex', 'graphtools Documentation',
-     'Jay Stanley and Scott Gigante, Krishnaswamy Lab, Yale University', 'manual'),
+    (
+        master_doc,
+        "graphtools.tex",
+        "graphtools Documentation",
+        "Scott Gigante and Jay Stanley, Yale University",
+        "manual",
+    ),
 ]
 
 
@@ -143,10 +150,7 @@
 
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [
-    (master_doc, 'graphtools', 'graphtools Documentation',
-     [author], 1)
-]
+man_pages = [(master_doc, "graphtools", "graphtools Documentation", [author], 1)]
 
 
 # -- Options for Texinfo output -------------------------------------------
 
@@ -155,7 +159,13 @@
 # (source start file, target name, title, author,
 #  dir menu entry, description, category)
 texinfo_documents = [
-    (master_doc, 'graphtools', 'graphtools Documentation',
-     author, 'graphtools', 'One line description of project.',
-     'Miscellaneous'),
+    (
+        master_doc,
+        "graphtools",
+        "graphtools Documentation",
+        author,
+        "graphtools",
+        "One line description of project.",
+        "Miscellaneous",
+    ),
 ]
diff --git a/doc/source/index.rst b/doc/source/index.rst
index 0cddedf..e60d936 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -30,6 +30,10 @@ graphtools
     GitHub stars
 
+.. raw:: html
+
+    <a href="https://github.com/psf/black"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-000000.svg"></a>
+
 Tools for building and manipulating graphs in Python.
 
 .. toctree::
@@ -65,4 +69,4 @@ To use `graphtools` with `pygsp`, create a `graphtools.Graph` class with `use_py
 Help
 ====
 
-If you have any questions or require assistance using graphtools, please contact us at https://krishnaswamylab.org/get-help
\ No newline at end of file
+If you have any questions or require assistance using graphtools, please contact us at https://krishnaswamylab.org/get-help
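Reviewer note: `conf.py` now derives the docs version from `graphtools/version.py` instead of hard-coding it. A self-contained demonstration of the parsing expression, assuming `version.py` contains a single assignment such as `__version__ = "1.5.0-alpha"` (the actual file contents are not shown in this diff):

```python
# How conf.py derives `release` and `version`; the version string here
# is an assumed example, not taken from the repository.
line = '__version__ = "1.5.0-alpha"'
release = line.strip().split("=")[-1].replace('"', "").strip()
version = release.split("-")[0]
print(release)  # 1.5.0-alpha
print(version)  # 1.5.0
```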
diff --git a/graphtools/api.py b/graphtools/api.py
index 80b7746..f29e756 100644
--- a/graphtools/api.py
+++ b/graphtools/api.py
@@ -7,34 +7,37 @@ from . import base, graphs
 
-_logger = tasklogger.get_tasklogger('graphtools')
-
-
-def Graph(data,
-          n_pca=None,
-          rank_threshold=None,
-          sample_idx=None,
-          adaptive_k=None,
-          precomputed=None,
-          knn=5,
-          decay=40,
-          bandwidth=None,
-          bandwidth_scale=1.0,
-          anisotropy=0,
-          distance='euclidean',
-          thresh=1e-4,
-          kernel_symm='+',
-          theta=None,
-          n_landmark=None,
-          n_svd=100,
-          beta=1,
-          n_jobs=-1,
-          verbose=False,
-          random_state=None,
-          graphtype='auto',
-          use_pygsp=False,
-          initialize=True,
-          **kwargs):
+_logger = tasklogger.get_tasklogger("graphtools")
+
+
+def Graph(
+    data,
+    n_pca=None,
+    rank_threshold=None,
+    knn=5,
+    decay=40,
+    bandwidth=None,
+    bandwidth_scale=1.0,
+    knn_max=None,
+    anisotropy=0,
+    distance="euclidean",
+    thresh=1e-4,
+    kernel_symm="+",
+    theta=None,
+    precomputed=None,
+    beta=1,
+    sample_idx=None,
+    adaptive_k=None,
+    n_landmark=None,
+    n_svd=100,
+    n_jobs=-1,
+    verbose=False,
+    random_state=None,
+    graphtype="auto",
+    use_pygsp=False,
+    initialize=True,
+    **kwargs
+):
     """Create a graph built on data.
 
     Automatically selects the appropriate DataGraph subclass based on
@@ -88,6 +91,9 @@ def Graph(data,
     bandwidth_scale : `float`, optional (default : 1.0)
         Rescaling factor for bandwidth.
 
+    knn_max : `int` or `None`, optional (default : `None`)
+        Maximum number of neighbors with nonzero affinity
+
     anisotropy : float, optional (default: 0)
         Level of anisotropy between 0 and 1
         (alpha in Coifman & Lafon, 2006)
@@ -176,12 +182,11 @@ def Graph(data,
     """
     _logger.set_level(verbose)
     if sample_idx is not None and len(np.unique(sample_idx)) == 1:
-        warnings.warn("Only one unique sample. "
-                      "Not using MNNGraph")
+        warnings.warn("Only one unique sample. Not using MNNGraph")
         sample_idx = None
-    if graphtype == 'mnn':
-        graphtype = 'auto'
-    if graphtype == 'auto':
+    if graphtype == "mnn":
+        graphtype = "auto"
+    if graphtype == "auto":
         # automatic graph selection
         if sample_idx is not None:
             # only mnn does batch correction
@@ -192,7 +197,7 @@ def Graph(data,
         elif decay is None:
             # knn kernel
             graphtype = "knn"
-        elif thresh == 0 or callable(bandwidth):
+        elif (thresh == 0 and knn_max is None) or callable(bandwidth):
             # compute full distance matrix
             graphtype = "exact"
         else:
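Reviewer note: with the new `knn_max` parameter, a decaying kernel can now be truncated to a fixed number of neighbors without forcing the full-distance-matrix ("exact") code path. A hypothetical usage sketch on synthetic data (the data and parameter values are illustrative, not from this PR):

```python
# Hedged sketch of the new knn_max parameter: alpha-decay kernel whose
# affinities are truncated to at most 10 neighbors per point.
import numpy as np
import graphtools

data = np.random.normal(size=(1000, 50))  # synthetic example data
G = graphtools.Graph(data, knn=5, knn_max=10, decay=40, thresh=1e-4)
print(G.K.shape)  # (1000, 1000) sparse affinity kernel
```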
@@ -203,29 +208,39 @@ def Graph(data,
     if graphtype == "knn":
         basegraph = graphs.kNNGraph
         if precomputed is not None:
-            raise ValueError("kNNGraph does not support precomputed "
-                             "values. Use `graphtype='exact'` or "
-                             "`precomputed=None`")
+            raise ValueError(
+                "kNNGraph does not support precomputed "
+                "values. Use `graphtype='exact'` or "
+                "`precomputed=None`"
+            )
         if sample_idx is not None:
-            raise ValueError("kNNGraph does not support batch "
-                             "correction. Use `graphtype='mnn'` or "
-                             "`sample_idx=None`")
+            raise ValueError(
+                "kNNGraph does not support batch "
+                "correction. Use `graphtype='mnn'` or "
+                "`sample_idx=None`"
+            )
     elif graphtype == "mnn":
         basegraph = graphs.MNNGraph
         if precomputed is not None:
-            raise ValueError("MNNGraph does not support precomputed "
-                             "values. Use `graphtype='exact'` and "
-                             "`sample_idx=None` or `precomputed=None`")
+            raise ValueError(
+                "MNNGraph does not support precomputed "
+                "values. Use `graphtype='exact'` and "
+                "`sample_idx=None` or `precomputed=None`"
+            )
     elif graphtype == "exact":
         basegraph = graphs.TraditionalGraph
         if sample_idx is not None:
-            raise ValueError("TraditionalGraph does not support batch "
-                             "correction. Use `graphtype='mnn'` or "
-                             "`sample_idx=None`")
+            raise ValueError(
+                "TraditionalGraph does not support batch "
+                "correction. Use `graphtype='mnn'` or "
+                "`sample_idx=None`"
+            )
     else:
-        raise ValueError("graphtype '{}' not recognized. Choose from "
-                         "['knn', 'mnn', 'exact', 'auto']")
+        raise ValueError(
+            "graphtype '{}' not recognized. Choose from "
+            "['knn', 'mnn', 'exact', 'auto']".format(graphtype)
+        )
 
     # set add landmarks if necessary
     parent_classes = [basegraph]
@@ -258,11 +273,18 @@ def Graph(data,
             pass
 
     # build graph and return
-    _logger.debug("Initializing {} with arguments {}".format(
-        parent_classes,
-        ", ".join(["{}='{}'".format(key, value)
-                   for key, value in params.items()
-                   if key != "data"])))
+    _logger.debug(
+        "Initializing {} with arguments {}".format(
+            parent_classes,
+            ", ".join(
+                [
+                    "{}='{}'".format(key, value)
+                    for key, value in params.items()
+                    if key != "data"
+                ]
+            ),
+        )
+    )
     return Graph(**params)
 
 
@@ -286,23 +308,25 @@ def from_igraph(G, attribute="weight", **kwargs):
     -------
     G : graphtools.graphs.TraditionalGraph
     """
-    if 'precomputed' in kwargs:
-        if kwargs['precomputed'] != 'adjacency':
+    if "precomputed" in kwargs:
+        if kwargs["precomputed"] != "adjacency":
             warnings.warn(
                 "Cannot build graph from igraph with precomputed={}. "
-                "Use 'adjacency' instead.".format(kwargs['precomputed']),
-                UserWarning)
-        del kwargs['precomputed']
+                "Use 'adjacency' instead.".format(kwargs["precomputed"]),
+                UserWarning,
+            )
+        del kwargs["precomputed"]
     try:
         K = G.get_adjacency(attribute=attribute).data
     except ValueError as e:
         if str(e) == "Attribute does not exist":
-            warnings.warn("Edge attribute {} not found. "
-                          "Returning unweighted graph".format(attribute),
-                          UserWarning)
+            warnings.warn(
+                "Edge attribute {} not found. "
+                "Returning unweighted graph".format(attribute),
+                UserWarning,
+            )
             K = G.get_adjacency(attribute=None).data
-    return Graph(sparse.coo_matrix(K),
-                 precomputed='adjacency', **kwargs)
+    return Graph(sparse.coo_matrix(K), precomputed="adjacency", **kwargs)
 
 
 def read_pickle(path):
@@ -313,12 +337,11 @@ def read_pickle(path):
     path : str
         File path where the pickled object will be loaded.
     """
-    with open(path, 'rb') as f:
+    with open(path, "rb") as f:
         G = pickle.load(f)
 
     if not isinstance(G, base.BaseGraph):
-        warnings.warn(
-            'Returning object that is not a graphtools.base.BaseGraph')
+        warnings.warn("Returning object that is not a graphtools.base.BaseGraph")
     elif isinstance(G, base.PyGSPGraph) and isinstance(G.logger, str):
         G.logger = pygsp.utils.build_logger(G.logger)
     return G
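Reviewer note: a quick sketch of `from_igraph`, whose adjacency/attribute handling is touched above. The random graph is purely illustrative, and this assumes `from_igraph` is exported at package level as `api.py` suggests:

```python
# Sketch: converting an igraph graph into a graphtools graph via
# from_igraph. Without the "weight" edge attribute, from_igraph warns
# and falls back to the unweighted adjacency matrix.
import igraph as ig
import graphtools

g = ig.Graph.Erdos_Renyi(n=100, p=0.1)  # illustrative random graph
g.es["weight"] = 1.0                    # give every edge a weight
G = graphtools.from_igraph(g, attribute="weight")
```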
diff --git a/graphtools/base.py b/graphtools/base.py
index a960d02..734ecb3 100644
--- a/graphtools/base.py
+++ b/graphtools/base.py
@@ -29,7 +29,7 @@ from . import utils
 
-_logger = tasklogger.get_tasklogger('graphtools')
+_logger = tasklogger.get_tasklogger("graphtools")
 
 
 class Base(object):
@@ -45,7 +45,7 @@ def _get_param_names(cls):
         """Get parameter names for the estimator"""
         # fetch the constructor or the original constructor before
         # deprecation wrapping if any
-        init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
+        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
         if init is object.__init__:
             # No explicit constructor to introspect
             return []
@@ -54,8 +54,11 @@ def _get_param_names(cls):
         # to represent
         init_signature = signature(init)
         # Consider the constructor parameters excluding 'self'
-        parameters = [p for p in init_signature.parameters.values()
-                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]
+        parameters = [
+            p
+            for p in init_signature.parameters.values()
+            if p.name != "self" and p.kind != p.VAR_KEYWORD
+        ]
 
         # Extract and sort argument names excluding 'self'
         parameters = set([p.name for p in parameters])
@@ -117,12 +120,12 @@ class Data(Base):
         sklearn PCA operator
     """
 
-    def __init__(self, data, n_pca=None, rank_threshold=None,
-                 random_state=None, **kwargs):
+    def __init__(
+        self, data, n_pca=None, rank_threshold=None, random_state=None, **kwargs
+    ):
 
         self._check_data(data)
-        n_pca, rank_threshold = self._parse_n_pca_threshold(
-            data, n_pca, rank_threshold)
+        n_pca, rank_threshold = self._parse_n_pca_threshold(data, n_pca, rank_threshold)
 
         try:
             if isinstance(data, pd.SparseDataFrame):
                 data = data.to_coo()
@@ -152,70 +155,83 @@ def _parse_n_pca_threshold(self, data, n_pca, rank_threshold):
         if isinstance(n_pca, str):
             n_pca = n_pca.lower()
             if n_pca != "auto":
-                raise ValueError("n_pca must be an integer "
-                                 "0 <= n_pca < min(n_samples,n_features), "
-                                 "or in [None,False,True,'auto'].")
+                raise ValueError(
+                    "n_pca must be an integer "
+                    "0 <= n_pca < min(n_samples,n_features), "
+                    "or in [None,False,True,'auto']."
+                )
         if isinstance(n_pca, numbers.Number):
             if not float(n_pca).is_integer():
                 # cast it to integer
                 n_pcaR = np.round(n_pca).astype(int)
                 warnings.warn(
                     "Cannot perform PCA to fractional {} dimensions. "
-                    "Rounding to {}".format(
-                        n_pca, n_pcaR), RuntimeWarning)
+                    "Rounding to {}".format(n_pca, n_pcaR),
+                    RuntimeWarning,
+                )
                 n_pca = n_pcaR
 
             if n_pca < 0:
                 raise ValueError(
                     "n_pca cannot be negative. "
                     "Please supply an integer "
-                    "0 <= n_pca < min(n_samples,n_features) or None")
+                    "0 <= n_pca < min(n_samples,n_features) or None"
+                )
             elif np.min(data.shape) <= n_pca:
                 warnings.warn(
                     "Cannot perform PCA to {} dimensions on "
                     "data with min(n_samples, n_features) = {}".format(
-                        n_pca, np.min(
-                            data.shape)), RuntimeWarning)
+                        n_pca, np.min(data.shape)
+                    ),
+                    RuntimeWarning,
+                )
                 n_pca = 0
 
         if n_pca in [0, False, None]:
             # cast 0, False to None.
             n_pca = None
         elif n_pca is True:
             # notify that we're going to estimate rank.
-            n_pca = 'auto'
-            _logger.info("Estimating n_pca from matrix rank. "
-                         "Supply an integer n_pca "
-                         "for fixed amount.")
-        if not any([isinstance(n_pca, numbers.Number),
-                    n_pca is None,
-                    n_pca == 'auto']):
+            n_pca = "auto"
+            _logger.info(
+                "Estimating n_pca from matrix rank. "
+                "Supply an integer n_pca "
+                "for fixed amount."
+            )
+        if not any([isinstance(n_pca, numbers.Number), n_pca is None, n_pca == "auto"]):
             raise ValueError(
                 "n_pca was not an instance of numbers.Number, "
" "Please supply an integer " - "0 <= n_pca < min(n_samples,n_features) or None") - if rank_threshold is not None and n_pca != 'auto': - warnings.warn("n_pca = {}, therefore rank_threshold of {} " - "will not be used. To use rank thresholding, " - "set n_pca = True".format(n_pca, rank_threshold), - RuntimeWarning) - if n_pca == 'auto': + "0 <= n_pca < min(n_samples,n_features) or None" + ) + if rank_threshold is not None and n_pca != "auto": + warnings.warn( + "n_pca = {}, therefore rank_threshold of {} " + "will not be used. To use rank thresholding, " + "set n_pca = True".format(n_pca, rank_threshold), + RuntimeWarning, + ) + if n_pca == "auto": if isinstance(rank_threshold, str): rank_threshold = rank_threshold.lower() if rank_threshold is None: - rank_threshold = 'auto' + rank_threshold = "auto" if isinstance(rank_threshold, numbers.Number): if rank_threshold <= 0: raise ValueError( - "rank_threshold must be positive float or 'auto'. ") + "rank_threshold must be positive float or 'auto'. " + ) else: - if rank_threshold != 'auto': + if rank_threshold != "auto": raise ValueError( - "rank_threshold must be positive float or 'auto'. ") + "rank_threshold must be positive float or 'auto'. " + ) return n_pca, rank_threshold def _check_data(self, data): if len(data.shape) != 2: - msg = "ValueError: Expected 2D array, got {}D array " \ + msg = ( + "ValueError: Expected 2D array, got {}D array " "instead (shape: {}.) ".format(len(data.shape), data.shape) + ) if len(data.shape) < 2: msg += "\nReshape your data either using array.reshape(-1, 1) " "if your data has a single feature or array.reshape(1, -1) if " @@ -234,60 +250,67 @@ def _reduce_data(self): ------- Reduced data matrix """ - if self.n_pca is not None and (self.n_pca == 'auto' or self.n_pca < self.data.shape[1]): + if self.n_pca is not None and ( + self.n_pca == "auto" or self.n_pca < self.data.shape[1] + ): with _logger.task("PCA"): - n_pca = self.data.shape[1] - 1 if self.n_pca == 'auto' else self.n_pca + n_pca = self.data.shape[1] - 1 if self.n_pca == "auto" else self.n_pca if sparse.issparse(self.data): - if isinstance(self.data, sparse.coo_matrix) or \ - isinstance(self.data, sparse.lil_matrix) or \ - isinstance(self.data, sparse.dok_matrix): + if ( + isinstance(self.data, sparse.coo_matrix) + or isinstance(self.data, sparse.lil_matrix) + or isinstance(self.data, sparse.dok_matrix) + ): self.data = self.data.tocsr() self.data_pca = TruncatedSVD(n_pca, random_state=self.random_state) else: - self.data_pca = PCA(n_pca, - svd_solver='randomized', - random_state=self.random_state) + self.data_pca = PCA( + n_pca, svd_solver="randomized", random_state=self.random_state + ) self.data_pca.fit(self.data) - if self.n_pca == 'auto': + if self.n_pca == "auto": s = self.data_pca.singular_values_ smax = s.max() - if self.rank_threshold == 'auto': - threshold = smax * \ - np.finfo(self.data.dtype).eps * max(self.data.shape) + if self.rank_threshold == "auto": + threshold = ( + smax * np.finfo(self.data.dtype).eps * max(self.data.shape) + ) self.rank_threshold = threshold threshold = self.rank_threshold gate = np.where(s >= threshold)[0] self.n_pca = gate.shape[0] if self.n_pca == 0: - raise ValueError("Supplied threshold {} was greater than " - "maximum singular value {} " - "for the data matrix".format(threshold, smax)) + raise ValueError( + "Supplied threshold {} was greater than " + "maximum singular value {} " + "for the data matrix".format(threshold, smax) + ) _logger.info( - "Using rank estimate of {} as n_pca".format(self.n_pca)) + "Using 
+                        "Using rank estimate of {} as n_pca".format(self.n_pca)
+                    )
                     # reset the sklearn operator
                     op = self.data_pca  # for line-width brevity..
                     op.components_ = op.components_[gate, :]
                     op.explained_variance_ = op.explained_variance_[gate]
-                    op.explained_variance_ratio_ = op.explained_variance_ratio_[
-                        gate]
+                    op.explained_variance_ratio_ = op.explained_variance_ratio_[gate]
                     op.singular_values_ = op.singular_values_[gate]
-                    self.data_pca = op  # im not clear if this is needed due to assignment rules
+                    self.data_pca = (
+                        op  # im not clear if this is needed due to assignment rules
+                    )
                 data_nu = self.data_pca.transform(self.data)
                 return data_nu
         else:
             data_nu = self.data
             if sparse.issparse(data_nu) and not isinstance(
-                    data_nu, (sparse.csr_matrix,
-                              sparse.csc_matrix,
-                              sparse.bsr_matrix)):
+                data_nu, (sparse.csr_matrix, sparse.csc_matrix, sparse.bsr_matrix)
+            ):
                 data_nu = data_nu.tocsr()
             return data_nu
 
     def get_params(self):
         """Get parameters from this object
         """
-        return {'n_pca': self.n_pca,
-                'random_state': self.random_state}
+        return {"n_pca": self.n_pca, "random_state": self.random_state}
 
     def set_params(self, **params):
         """Set parameters on this object
@@ -306,10 +329,10 @@ def set_params(self, **params):
         -------
         self
         """
-        if 'n_pca' in params and params['n_pca'] != self.n_pca:
+        if "n_pca" in params and params["n_pca"] != self.n_pca:
             raise ValueError("Cannot update n_pca. Please create a new graph")
-        if 'random_state' in params:
-            self.random_state = params['random_state']
+        if "random_state" in params:
+            self.random_state = params["random_state"]
         super().set_params(**params)
         return self
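Reviewer note: the `n_pca="auto"` path above estimates the number of PCA components from a singular-value threshold. A hedged usage sketch on synthetic low-rank data (the data and threshold value are illustrative):

```python
# Sketch of automatic rank estimation: n_pca=True triggers the
# rank_threshold machinery shown above ('auto' threshold by default).
import numpy as np
import graphtools

low_rank = np.random.normal(size=(500, 10)) @ np.random.normal(size=(10, 100))
G = graphtools.Graph(low_rank, n_pca=True)  # estimates n_pca ~ 10
G2 = graphtools.Graph(low_rank, n_pca=True, rank_threshold=1e-6)
```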
" "Setting kernel_symm='mnn'", + FutureWarning, + ) + kernel_symm = "mnn" + if kernel_symm == "theta": + warnings.warn( + "kernel_symm='theta' is deprecated. " "Setting kernel_symm='mnn'", + FutureWarning, + ) + kernel_symm = "mnn" self.kernel_symm = kernel_symm self.theta = theta self._check_symmetrization(kernel_symm, theta) if not (isinstance(anisotropy, numbers.Real) and 0 <= anisotropy <= 1): - raise ValueError("Expected 0 <= anisotropy <= 1. " - "Got {}".format(anisotropy)) + raise ValueError( + "Expected 0 <= anisotropy <= 1. " "Got {}".format(anisotropy) + ) self.anisotropy = anisotropy if initialize: @@ -480,24 +516,30 @@ def __init__(self, super().__init__(**kwargs) def _check_symmetrization(self, kernel_symm, theta): - if kernel_symm not in ['+', '*', 'mnn', None]: + if kernel_symm not in ["+", "*", "mnn", None]: raise ValueError( "kernel_symm '{}' not recognized. Choose from " - "'+', '*', 'mnn', or 'none'.".format(kernel_symm)) - elif kernel_symm != 'mnn' and theta is not None: - warnings.warn("kernel_symm='{}' but theta is not None. " - "Setting kernel_symm='mnn'.".format(kernel_symm)) - self.kernel_symm = kernel_symm = 'mnn' - - if kernel_symm == 'mnn': + "'+', '*', 'mnn', or 'none'.".format(kernel_symm) + ) + elif kernel_symm != "mnn" and theta is not None: + warnings.warn( + "kernel_symm='{}' but theta is not None. " + "Setting kernel_symm='mnn'.".format(kernel_symm) + ) + self.kernel_symm = kernel_symm = "mnn" + + if kernel_symm == "mnn": if theta is None: self.theta = theta = 1 - warnings.warn("kernel_symm='mnn' but theta not given. " - "Defaulting to theta={}.".format(self.theta)) - elif not isinstance(theta, numbers.Number) or \ - theta < 0 or theta > 1: - raise ValueError("theta {} not recognized. Expected " - "a float between 0 and 1".format(theta)) + warnings.warn( + "kernel_symm='mnn' but theta not given. " + "Defaulting to theta={}.".format(self.theta) + ) + elif not isinstance(theta, numbers.Number) or theta < 0 or theta > 1: + raise ValueError( + "theta {} not recognized. Expected " + "a float between 0 and 1".format(theta) + ) def _build_kernel(self): """Private method to build kernel matrix @@ -530,11 +572,11 @@ def symmetrize_kernel(self, K): elif self.kernel_symm == "*": _logger.debug("Using multiplication symmetrization.") K = K.multiply(K.T) - elif self.kernel_symm == 'mnn': - _logger.debug( - "Using mnn symmetrization (theta = {}).".format(self.theta)) - K = self.theta * utils.elementwise_minimum(K, K.T) + \ - (1 - self.theta) * utils.elementwise_maximum(K, K.T) + elif self.kernel_symm == "mnn": + _logger.debug("Using mnn symmetrization (theta = {}).".format(self.theta)) + K = self.theta * utils.elementwise_minimum(K, K.T) + ( + 1 - self.theta + ) * utils.elementwise_maximum(K, K.T) elif self.kernel_symm is None: _logger.debug("Using no symmetrization.") pass @@ -542,7 +584,8 @@ def symmetrize_kernel(self, K): # this should never happen raise ValueError( "Expected kernel_symm in ['+', '*', 'mnn' or None]. 
" - "Got {}".format(self.theta)) + "Got {}".format(self.theta) + ) return K def apply_anisotropy(self, K): @@ -563,9 +606,11 @@ def apply_anisotropy(self, K): def get_params(self): """Get parameters from this object """ - return {'kernel_symm': self.kernel_symm, - 'theta': self.theta, - 'anisotropy': self.anisotropy} + return { + "kernel_symm": self.kernel_symm, + "theta": self.theta, + "anisotropy": self.anisotropy, + } def set_params(self, **params): """Set parameters on this object @@ -585,15 +630,12 @@ def set_params(self, **params): ------- self """ - if 'theta' in params and params['theta'] != self.theta: + if "theta" in params and params["theta"] != self.theta: raise ValueError("Cannot update theta. Please create a new graph") - if 'anisotropy' in params and params['anisotropy'] != self.anisotropy: - raise ValueError( - "Cannot update anisotropy. Please create a new graph") - if 'kernel_symm' in params and \ - params['kernel_symm'] != self.kernel_symm: - raise ValueError( - "Cannot update kernel_symm. Please create a new graph") + if "anisotropy" in params and params["anisotropy"] != self.anisotropy: + raise ValueError("Cannot update anisotropy. Please create a new graph") + if "kernel_symm" in params and params["kernel_symm"] != self.kernel_symm: + raise ValueError("Cannot update kernel_symm. Please create a new graph") super().set_params(**params) return self @@ -613,7 +655,7 @@ def P(self): try: return self._diff_op except AttributeError: - self._diff_op = normalize(self.kernel, 'l1', axis=1) + self._diff_op = normalize(self.kernel, "l1", axis=1) return self._diff_op @property @@ -636,9 +678,13 @@ def diff_aff(self): row_degrees = utils.to_array(self.kernel.sum(axis=1)) if sparse.issparse(self.kernel): # diagonal matrix - degrees = sparse.csr_matrix((1 / np.sqrt(row_degrees.flatten()), - np.arange(len(row_degrees)), - np.arange(len(row_degrees) + 1))) + degrees = sparse.csr_matrix( + ( + 1 / np.sqrt(row_degrees.flatten()), + np.arange(len(row_degrees)), + np.arange(len(row_degrees) + 1), + ) + ) return degrees @ self.kernel @ degrees else: col_degrees = row_degrees.T @@ -709,23 +755,24 @@ def to_pygsp(self, **kwargs): G : graphtools.base.PyGSPGraph, graphtools.graphs.TraditionalGraph """ from . import api - if 'precomputed' in kwargs: - if kwargs['precomputed'] != 'affinity': + + if "precomputed" in kwargs: + if kwargs["precomputed"] != "affinity": warnings.warn( "Cannot build PyGSPGraph with precomputed={}. " - "Using 'affinity' instead.".format(kwargs['precomputed']), - UserWarning) - del kwargs['precomputed'] - if 'use_pygsp' in kwargs: - if kwargs['use_pygsp'] is not True: + "Using 'affinity' instead.".format(kwargs["precomputed"]), + UserWarning, + ) + del kwargs["precomputed"] + if "use_pygsp" in kwargs: + if kwargs["use_pygsp"] is not True: warnings.warn( "Cannot build PyGSPGraph with use_pygsp={}. 
" - "Use True instead.".format(kwargs['use_pygsp']), - UserWarning) - del kwargs['use_pygsp'] - return api.Graph(self.K, - precomputed="affinity", use_pygsp=True, - **kwargs) + "Use True instead.".format(kwargs["use_pygsp"]), + UserWarning, + ) + del kwargs["use_pygsp"] + return api.Graph(self.K, precomputed="affinity", use_pygsp=True, **kwargs) def to_igraph(self, attribute="weight", **kwargs): """Convert to an igraph Graph @@ -740,8 +787,9 @@ def to_igraph(self, attribute="weight", **kwargs): try: import igraph as ig except ImportError: - raise ImportError("Please install igraph with " - "`pip install --user python-igraph`.") + raise ImportError( + "Please install igraph with " "`pip install --user python-igraph`." + ) try: W = self.W except AttributeError: @@ -765,41 +813,45 @@ def to_pickle(self, path): File path where the pickled object will be stored. """ pickle_obj = shallow_copy(self) - is_oldpygsp = all([isinstance(self, pygsp.graphs.Graph), - int(sys.version.split(".")[1]) < 7]) + is_oldpygsp = all( + [isinstance(self, pygsp.graphs.Graph), int(sys.version.split(".")[1]) < 7] + ) if is_oldpygsp: pickle_obj.logger = pickle_obj.logger.name - with open(path, 'wb') as f: + with open(path, "wb") as f: pickle.dump(pickle_obj, f, protocol=pickle.HIGHEST_PROTOCOL) def _check_shortest_path_distance(self, distance): - if distance == 'data' and self.weighted: + if distance == "data" and self.weighted: raise NotImplementedError( "Graph shortest path with constant or data distance only " "implemented for unweighted graphs. " - "For weighted graphs, use `distance='affinity'`.") - elif distance == 'constant' and self.weighted: + "For weighted graphs, use `distance='affinity'`." + ) + elif distance == "constant" and self.weighted: raise NotImplementedError( "Graph shortest path with constant distance only " "implemented for unweighted graphs. " - "For weighted graphs, use `distance='affinity'`.") - elif distance == 'affinity' and not self.weighted: + "For weighted graphs, use `distance='affinity'`." + ) + elif distance == "affinity" and not self.weighted: raise ValueError( "Graph shortest path with affinity distance only " "valid for weighted graphs. " "For unweighted graphs, use `distance='constant'` " - "or `distance='data'`.") + "or `distance='data'`." + ) def _default_shortest_path_distance(self): if not self.weighted: - distance = 'data' + distance = "data" _logger.info("Using ambient data distances.") else: - distance = 'affinity' + distance = "affinity" _logger.info("Using negative log affinity distances.") return distance - def shortest_path(self, method='auto', distance=None): + def shortest_path(self, method="auto", distance=None): """ Find the length of the shortest path between every pair of vertices on the graph @@ -830,19 +882,21 @@ def shortest_path(self, method='auto', distance=None): self._check_shortest_path_distance(distance) - if distance == 'constant': + if distance == "constant": D = self.K - elif distance == 'data': + elif distance == "data": D = sparse.coo_matrix(self.K) - D.data = np.sqrt(np.sum(( - self.data_nu[D.row] - self.data_nu[D.col])**2, axis=1)) - elif distance == 'affinity': + D.data = np.sqrt( + np.sum((self.data_nu[D.row] - self.data_nu[D.col]) ** 2, axis=1) + ) + elif distance == "affinity": D = sparse.csr_matrix(self.K) D.data = -1 * np.log(D.data) else: raise ValueError( "Expected `distance` in ['constant', 'data', 'affinity']. 
" - "Got {}".format(distance)) + "Got {}".format(distance) + ) P = graph_shortest_path(D, method=method) # symmetrize for numerical error @@ -864,16 +918,14 @@ class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)): kernel matrix """ - def __init__(self, lap_type='combinatorial', coords=None, - plotting=None, **kwargs): + def __init__(self, lap_type="combinatorial", coords=None, plotting=None, **kwargs): if plotting is None: plotting = {} W = self._build_weight_from_kernel(self.K) - super().__init__(W, - lap_type=lap_type, - coords=coords, - plotting=plotting, **kwargs) + super().__init__( + W, lap_type=lap_type, coords=coords, plotting=plotting, **kwargs + ) @property @abc.abstractmethod @@ -954,9 +1006,7 @@ class DataGraph(with_metaclass(abc.ABCMeta, Data, BaseGraph)): n_jobs = -2, all CPUs but one are used """ - def __init__(self, data, - verbose=True, - n_jobs=1, **kwargs): + def __init__(self, data, verbose=True, n_jobs=1, **kwargs): # kwargs are ignored self.n_jobs = n_jobs self.verbose = verbose @@ -1016,8 +1066,7 @@ def _check_extension_shape(self, Y): `self.n_pca`. """ if len(Y.shape) != 2: - raise ValueError("Expected a 2D matrix. Y has shape {}".format( - Y.shape)) + raise ValueError("Expected a 2D matrix. Y has shape {}".format(Y.shape)) if not Y.shape[1] == self.data_nu.shape[1]: # try PCA transform if Y.shape[1] == self.data.shape[1]: @@ -1026,13 +1075,12 @@ def _check_extension_shape(self, Y): # wrong shape if self.data.shape[1] != self.data_nu.shape[1]: # PCA is possible - msg = ("Y must be of shape either " - "(n, {}) or (n, {})").format( - self.data.shape[1], self.data_nu.shape[1]) + msg = ("Y must be of shape either " "(n, {}) or (n, {})").format( + self.data.shape[1], self.data_nu.shape[1] + ) else: # no PCA, only one choice of shape - msg = "Y must be of shape (n, {})".format( - self.data.shape[1]) + msg = "Y must be of shape (n, {})".format(self.data.shape[1]) raise ValueError(msg) return Y @@ -1062,7 +1110,7 @@ def extend_to_data(self, Y): """ Y = self._check_extension_shape(Y) kernel = self.build_kernel_to_data(Y) - transitions = normalize(kernel, norm='l1', axis=1) + transitions = normalize(kernel, norm="l1", axis=1) return transitions def interpolate(self, transform, transitions=None, Y=None): @@ -1095,8 +1143,7 @@ def interpolate(self, transform, transitions=None, Y=None): """ if transitions is None: if Y is None: - raise ValueError( - "Either `transitions` or `Y` must be provided.") + raise ValueError("Either `transitions` or `Y` must be provided.") else: transitions = self.extend_to_data(Y) Y_transform = transitions.dot(transform) @@ -1119,10 +1166,10 @@ def set_params(self, **params): ------- self """ - if 'n_jobs' in params: - self.n_jobs = params['n_jobs'] - if 'verbose' in params: - self.verbose = params['verbose'] + if "n_jobs" in params: + self.n_jobs = params["n_jobs"] + if "verbose" in params: + self.verbose = params["verbose"] _logger.set_level(self.verbose) super().set_params(**params) return self diff --git a/graphtools/graphs.py b/graphtools/graphs.py index 6a1a021..b5fe0f5 100644 --- a/graphtools/graphs.py +++ b/graphtools/graphs.py @@ -15,7 +15,7 @@ from . import utils from .base import DataGraph, PyGSPGraph -_logger = tasklogger.get_tasklogger('graphtools') +_logger = tasklogger.get_tasklogger("graphtools") class kNNGraph(DataGraph): @@ -64,36 +64,64 @@ class kNNGraph(DataGraph): between KD tree, ball tree and brute force? 
""" - def __init__(self, data, knn=5, decay=None, - bandwidth=None, bandwidth_scale=1.0, distance='euclidean', - thresh=1e-4, n_pca=None, **kwargs): + def __init__( + self, + data, + knn=5, + decay=None, + knn_max=None, + search_multiplier=20, + bandwidth=None, + bandwidth_scale=1.0, + distance="euclidean", + thresh=1e-4, + n_pca=None, + **kwargs + ): + + if decay is not None: + if thresh <= 0 and knn_max is None: + raise ValueError( + "Cannot instantiate a kNNGraph with `decay=None`, " + "`thresh=0` and `knn_max=None`. Use a TraditionalGraph instead." + ) + elif thresh < np.finfo(float).eps: + thresh = np.finfo(float).eps - if decay is not None and thresh <= 0: - raise ValueError("Cannot instantiate a kNNGraph with `decay=None` " - "and `thresh=0`. Use a TraditionalGraph instead.") if callable(bandwidth): - raise NotImplementedError("Callable bandwidth is only supported by" - " graphtools.graphs.TraditionalGraph.") + raise NotImplementedError( + "Callable bandwidth is only supported by" + " graphtools.graphs.TraditionalGraph." + ) if knn is None and bandwidth is None: - raise ValueError( - "Either `knn` or `bandwidth` must be provided.") + raise ValueError("Either `knn` or `bandwidth` must be provided.") elif knn is None and bandwidth is not None: # implementation requires a knn value knn = 5 if decay is None and bandwidth is not None: - warnings.warn("`bandwidth` is not used when `decay=None`.", - UserWarning) + warnings.warn("`bandwidth` is not used when `decay=None`.", UserWarning) if knn > data.shape[0] - 2: - warnings.warn("Cannot set knn ({k}) to be greater than " - "n_samples ({n}). Setting knn={n}".format( - k=knn, n=data.shape[0] - 2)) + warnings.warn( + "Cannot set knn ({k}) to be greater than " + "n_samples ({n}). Setting knn={n}".format(k=knn, n=data.shape[0] - 2) + ) knn = data.shape[0] - 2 - if n_pca in [None,0,False] and data.shape[1] > 500: - warnings.warn("Building a kNNGraph on data of shape {} is " - "expensive. Consider setting n_pca.".format( - data.shape), UserWarning) + if knn_max is not None and knn_max < knn: + warnings.warn( + "Cannot set knn_max ({knn_max}) to be less than " + "knn ({knn}). Setting knn_max={knn}".format(knn=knn, knn_max=knn_max) + ) + knn_max = knn + if n_pca in [None, 0, False] and data.shape[1] > 500: + warnings.warn( + "Building a kNNGraph on data of shape {} is " + "expensive. 
@@ -105,15 +133,20 @@ def get_params(self):
         """Get parameters from this object
         """
         params = super().get_params()
-        params.update({'knn': self.knn,
-                       'decay': self.decay,
-                       'bandwidth': self.bandwidth,
-                       'bandwidth_scale': self.bandwidth_scale,
-                       'distance': self.distance,
-                       'thresh': self.thresh,
-                       'n_jobs': self.n_jobs,
-                       'random_state': self.random_state,
-                       'verbose': self.verbose})
+        params.update(
+            {
+                "knn": self.knn,
+                "decay": self.decay,
+                "bandwidth": self.bandwidth,
+                "bandwidth_scale": self.bandwidth_scale,
+                "knn_max": self.knn_max,
+                "distance": self.distance,
+                "thresh": self.thresh,
+                "n_jobs": self.n_jobs,
+                "random_state": self.random_state,
+                "verbose": self.verbose,
+            }
+        )
         return params
 
     def set_params(self, **params):
@@ -127,6 +160,7 @@ def set_params(self, **params):
         - verbose
         Invalid parameters: (these would require modifying the kernel matrix)
         - knn
+        - knn_max
         - decay
         - bandwidth
         - bandwidth_scale
@@ -141,31 +175,31 @@ def set_params(self, **params):
         -------
         self
         """
-        if 'knn' in params and params['knn'] != self.knn:
+        if "knn" in params and params["knn"] != self.knn:
             raise ValueError("Cannot update knn. Please create a new graph")
-        if 'decay' in params and params['decay'] != self.decay:
+        if "knn_max" in params and params["knn_max"] != self.knn_max:
+            raise ValueError("Cannot update knn_max. Please create a new graph")
+        if "decay" in params and params["decay"] != self.decay:
             raise ValueError("Cannot update decay. Please create a new graph")
-        if 'bandwidth' in params and params['bandwidth'] != self.bandwidth:
-            raise ValueError(
-                "Cannot update bandwidth. Please create a new graph")
-        if 'bandwidth_scale' in params and \
-                params['bandwidth_scale'] != self.bandwidth_scale:
-            raise ValueError(
-                "Cannot update bandwidth_scale. Please create a new graph")
+        if "bandwidth" in params and params["bandwidth"] != self.bandwidth:
+            raise ValueError("Cannot update bandwidth. Please create a new graph")
+        if (
+            "bandwidth_scale" in params
+            and params["bandwidth_scale"] != self.bandwidth_scale
+        ):
+            raise ValueError("Cannot update bandwidth_scale. Please create a new graph")
-        if 'distance' in params and params['distance'] != self.distance:
-            raise ValueError("Cannot update distance. "
-                             "Please create a new graph")
-        if 'thresh' in params and params['thresh'] != self.thresh \
-                and self.decay != 0:
+        if "distance" in params and params["distance"] != self.distance:
+            raise ValueError("Cannot update distance. " "Please create a new graph")
+        if "thresh" in params and params["thresh"] != self.thresh and self.decay != 0:
+            raise ValueError("Cannot update thresh. " "Please create a new graph")
-        if 'n_jobs' in params:
-            self.n_jobs = params['n_jobs']
+        if "n_jobs" in params:
+            self.n_jobs = params["n_jobs"]
             if hasattr(self, "_knn_tree"):
                 self.knn_tree.set_params(n_jobs=self.n_jobs)
-        if 'random_state' in params:
-            self.random_state = params['random_state']
-        if 'verbose' in params:
-            self.verbose = params['verbose']
+        if "random_state" in params:
+            self.random_state = params["random_state"]
+        if "verbose" in params:
+            self.verbose = params["verbose"]
         # update superclass parameters
         super().set_params(**params)
         return self
@@ -188,21 +222,25 @@ def knn_tree(self):
         try:
             self._knn_tree = NearestNeighbors(
                 n_neighbors=self.knn + 1,
-                algorithm='ball_tree',
+                algorithm="ball_tree",
                 metric=self.distance,
-                n_jobs=self.n_jobs).fit(self.data_nu)
+                n_jobs=self.n_jobs,
+            ).fit(self.data_nu)
         except ValueError:
             # invalid metric
             warnings.warn(
                 "Metric {} not valid for `sklearn.neighbors.BallTree`. "
                 "Graph instantiation may be slower than normal.".format(
-                    self.distance),
-                UserWarning)
+                    self.distance
+                ),
+                UserWarning,
+            )
             self._knn_tree = NearestNeighbors(
                 n_neighbors=self.knn + 1,
-                algorithm='auto',
+                algorithm="auto",
                 metric=self.distance,
-                n_jobs=self.n_jobs).fit(self.data_nu)
+                n_jobs=self.n_jobs,
+            ).fit(self.data_nu)
         return self._knn_tree
 
     def build_kernel(self):
@@ -217,37 +255,45 @@ def build_kernel(self):
             symmetric matrix with ones down the diagonal
             with no non-negative entries.
         """
-        K = self.build_kernel_to_data(self.data_nu, knn=self.knn + 1)
+        knn_max = self.knn_max + 1 if self.knn_max else None
+        K = self.build_kernel_to_data(self.data_nu, knn=self.knn + 1, knn_max=knn_max)
         return K
 
     def _check_duplicates(self, distances, indices):
         if np.any(distances[:, 1] == 0):
             has_duplicates = distances[:, 1] == 0
             if np.sum(distances[:, 1:] == 0) < 20:
-                idx = np.argwhere((distances == 0) &
-                                  has_duplicates[:, None])
+                idx = np.argwhere((distances == 0) & has_duplicates[:, None])
                 duplicate_ids = np.array(
-                    [[indices[i[0], i[1]], i[0]]
-                     for i in idx if indices[i[0], i[1]] < i[0]])
-                duplicate_ids = duplicate_ids[
-                    np.argsort(duplicate_ids[:, 0])]
-                duplicate_names = ", ".join(["{} and {}".format(i[0], i[1])
-                                             for i in duplicate_ids])
+                    [
+                        [indices[i[0], i[1]], i[0]]
+                        for i in idx
+                        if indices[i[0], i[1]] < i[0]
+                    ]
+                )
+                duplicate_ids = duplicate_ids[np.argsort(duplicate_ids[:, 0])]
+                duplicate_names = ", ".join(
+                    ["{} and {}".format(i[0], i[1]) for i in duplicate_ids]
+                )
                 warnings.warn(
                     "Detected zero distance between samples {}. "
                     "Consider removing duplicates to avoid errors in "
                     "downstream processing.".format(duplicate_names),
-                    RuntimeWarning)
+                    RuntimeWarning,
+                )
             else:
                 warnings.warn(
                     "Detected zero distance between {} pairs of samples. "
                     "Consider removing duplicates to avoid errors in "
                     "downstream processing.".format(
-                        np.sum(np.sum(distances[:, 1:] == 0))),
-                    RuntimeWarning)
-
-    def build_kernel_to_data(self, Y, knn=None, bandwidth=None,
-                             bandwidth_scale=None):
+                        np.sum(np.sum(distances[:, 1:] == 0))
+                    ),
+                    RuntimeWarning,
+                )
+
+    def build_kernel_to_data(
+        self, Y, knn=None, knn_max=None, bandwidth=None, bandwidth_scale=None
+    ):
         """Build a kernel from new input data `Y` to the `self.data`
 
         Parameters
@@ -287,24 +333,29 @@ def build_kernel_to_data(self, Y, knn=None, bandwidth=None,
         if bandwidth_scale is None:
             bandwidth_scale = self.bandwidth_scale
         if knn > self.data.shape[0]:
-            warnings.warn("Cannot set knn ({k}) to be greater than "
-                          "n_samples ({n}). Setting knn={n}".format(
-                              k=knn, n=self.data.shape[0]))
+            warnings.warn(
+                "Cannot set knn ({k}) to be greater than "
Setting knn={n}".format( - k=knn, n=self.data.shape[0])) + warnings.warn( + "Cannot set knn ({k}) to be greater than " + "n_samples ({n}). Setting knn={n}".format( + k=knn, n=self.data_nu.shape[0] + ) + ) + knn = self.data_nu.shape[0] + if knn_max is None: + knn_max = self.data_nu.shape[0] Y = self._check_extension_shape(Y) if self.decay is None or self.thresh == 1: with _logger.task("KNN search"): # binary connectivity matrix K = self.knn_tree.kneighbors_graph( - Y, n_neighbors=knn, - mode='connectivity') + Y, n_neighbors=knn, mode="connectivity" + ) else: with _logger.task("KNN search"): # sparse fast alpha decay knn_tree = self.knn_tree - search_knn = min(knn * 20, self.data_nu.shape[0]) - distances, indices = knn_tree.kneighbors( - Y, n_neighbors=search_knn) + search_knn = min(knn * self.search_multiplier, knn_max) + distances, indices = knn_tree.kneighbors(Y, n_neighbors=search_knn) self._check_duplicates(distances, indices) with _logger.task("affinities"): if bandwidth is None: @@ -315,57 +366,86 @@ def build_kernel_to_data(self, Y, knn=None, bandwidth=None, # check for zero bandwidth bandwidth = np.maximum(bandwidth, np.finfo(float).eps) - radius = bandwidth * np.power(-1 * np.log(self.thresh), - 1 / self.decay) - update_idx = np.argwhere( - np.max(distances, axis=1) < radius).reshape(-1) - _logger.debug("search_knn = {}; {} remaining".format( - search_knn, len(update_idx))) + radius = bandwidth * np.power(-1 * np.log(self.thresh), 1 / self.decay) + update_idx = np.argwhere(np.max(distances, axis=1) < radius).reshape(-1) + _logger.debug( + "search_knn = {}; {} remaining".format(search_knn, len(update_idx)) + ) if len(update_idx) > 0: distances = [d for d in distances] indices = [i for i in indices] - while len(update_idx) > Y.shape[0] // 10 and \ - search_knn < self.data_nu.shape[0] / 2: - # increase the knn search - search_knn = min(search_knn * 20, self.data_nu.shape[0]) + # increase the knn search + search_knn = min(search_knn * self.search_multiplier, knn_max) + while ( + len(update_idx) > Y.shape[0] // 10 + and search_knn < self.data_nu.shape[0] / 2 + and search_knn < knn_max + ): dist_new, ind_new = knn_tree.kneighbors( - Y[update_idx], n_neighbors=search_knn) + Y[update_idx], n_neighbors=search_knn + ) for i, idx in enumerate(update_idx): distances[idx] = dist_new[i] indices[idx] = ind_new[i] - update_idx = [i for i, d in enumerate(distances) if np.max(d) < - (radius if isinstance(bandwidth, numbers.Number) - else radius[i])] - _logger.debug("search_knn = {}; {} remaining".format( - search_knn, - len(update_idx))) + update_idx = [ + i + for i, d in enumerate(distances) + if np.max(d) + < ( + radius + if isinstance(bandwidth, numbers.Number) + else radius[i] + ) + ] + _logger.debug( + "search_knn = {}; {} remaining".format( + search_knn, len(update_idx) + ) + ) + # increase the knn search + search_knn = min(search_knn * self.search_multiplier, knn_max) if search_knn > self.data_nu.shape[0] / 2: knn_tree = NearestNeighbors( - search_knn, algorithm='brute', - n_jobs=self.n_jobs).fit(self.data_nu) + search_knn, algorithm="brute", n_jobs=self.n_jobs + ).fit(self.data_nu) if len(update_idx) > 0: - _logger.debug( - "radius search on {}".format(len(update_idx))) - # give up - radius search - dist_new, ind_new = knn_tree.radius_neighbors( - Y[update_idx, :], - radius=radius - if isinstance(bandwidth, numbers.Number) - else np.max(radius[update_idx])) - for i, idx in enumerate(update_idx): - distances[idx] = dist_new[i] - indices[idx] = ind_new[i] + if search_knn == knn_max: + 
+                        _logger.debug(
+                            "knn search to knn_max ({}) on {}".format(
+                                knn_max, len(update_idx)
+                            )
+                        )
+                        # give up - search out to knn_max
+                        dist_new, ind_new = knn_tree.kneighbors(
+                            Y[update_idx], n_neighbors=search_knn
+                        )
+                        for i, idx in enumerate(update_idx):
+                            distances[idx] = dist_new[i]
+                            indices[idx] = ind_new[i]
+                    else:
+                        _logger.debug("radius search on {}".format(len(update_idx)))
+                        # give up - radius search
+                        dist_new, ind_new = knn_tree.radius_neighbors(
+                            Y[update_idx, :],
+                            radius=radius
+                            if isinstance(bandwidth, numbers.Number)
+                            else np.max(radius[update_idx]),
+                        )
+                        for i, idx in enumerate(update_idx):
+                            distances[idx] = dist_new[i]
+                            indices[idx] = ind_new[i]
                 if isinstance(bandwidth, numbers.Number):
                     data = np.concatenate(distances) / bandwidth
                 else:
-                    data = np.concatenate([distances[i] / bandwidth[i]
-                                           for i in range(len(distances))])
+                    data = np.concatenate(
+                        [distances[i] / bandwidth[i] for i in range(len(distances))]
+                    )
                 indices = np.concatenate(indices)
-                indptr = np.concatenate(
-                    [[0], np.cumsum([len(d) for d in distances])])
-                K = sparse.csr_matrix((data, indices, indptr),
-                                      shape=(Y.shape[0], self.data_nu.shape[0]))
+                indptr = np.concatenate([[0], np.cumsum([len(d) for d in distances])])
+                K = sparse.csr_matrix(
+                    (data, indices, indptr), shape=(Y.shape[0], self.data_nu.shape[0])
+                )
                 K.data = np.exp(-1 * np.power(K.data, self.decay))
                 # handle nan
                 K.data = np.where(np.isnan(K.data), 1, K.data)
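Reviewer note: the kernel built above is the alpha-decay affinity K = exp(-(d / bandwidth) ** decay), truncated at `thresh` to keep it sparse. Restated with plain numpy for a single point's distance vector (toy values, for illustration):

```python
# The alpha-decay affinity from build_kernel_to_data, for one point.
import numpy as np

decay, thresh = 40, 1e-4
d = np.linspace(0, 2, 9)  # toy distances to a point's neighbors
bandwidth = d[4]          # e.g. distance to the knn-th neighbor
K = np.exp(-1 * (d / bandwidth) ** decay)
K[K < thresh] = 0         # matches `K.data[K.data < self.thresh] = 0`
```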
@@ -429,12 +509,14 @@ def __init__(self, data, n_landmark=2000, n_svd=100, **kwargs):
         if n_landmark >= data.shape[0]:
             raise ValueError(
                 "n_landmark ({}) >= n_samples ({}). Use "
-                "kNNGraph instead".format(n_landmark, data.shape[0]))
+                "kNNGraph instead".format(n_landmark, data.shape[0])
+            )
         if n_svd >= data.shape[0]:
-            warnings.warn("n_svd ({}) >= n_samples ({}) Consider "
-                          "using kNNGraph or lower n_svd".format(
-                              n_svd, data.shape[0]),
-                          RuntimeWarning)
+            warnings.warn(
+                "n_svd ({}) >= n_samples ({}) Consider "
+                "using kNNGraph or lower n_svd".format(n_svd, data.shape[0]),
+                RuntimeWarning,
+            )
         self.n_landmark = n_landmark
         self.n_svd = n_svd
         super().__init__(data, **kwargs)
@@ -443,8 +525,7 @@ def get_params(self):
         """Get parameters from this object
         """
         params = super().get_params()
-        params.update({'n_landmark': self.n_landmark,
-                       'n_pca': self.n_pca})
+        params.update({"n_landmark": self.n_landmark, "n_pca": self.n_pca})
         return params
 
     def set_params(self, **params):
@@ -466,11 +547,11 @@ def set_params(self, **params):
         """
         # update parameters
         reset_landmarks = False
-        if 'n_landmark' in params and params['n_landmark'] != self.n_landmark:
-            self.n_landmark = params['n_landmark']
+        if "n_landmark" in params and params["n_landmark"] != self.n_landmark:
+            self.n_landmark = params["n_landmark"]
             reset_landmarks = True
-        if 'n_svd' in params and params['n_svd'] != self.n_svd:
-            self.n_svd = params['n_svd']
+        if "n_svd" in params and params["n_svd"] != self.n_svd:
+            self.n_svd = params["n_svd"]
             reset_landmarks = True
         # update superclass parameters
         super().set_params(**params)
@@ -549,15 +630,19 @@ def _landmarks_to_data(self):
         landmarks = np.unique(self.clusters)
         if sparse.issparse(self.kernel):
             pmn = sparse.vstack(
-                [sparse.csr_matrix(self.kernel[self.clusters == i, :].sum(
-                    axis=0)) for i in landmarks])
+                [
+                    sparse.csr_matrix(self.kernel[self.clusters == i, :].sum(axis=0))
+                    for i in landmarks
+                ]
+            )
         else:
-            pmn = np.array([np.sum(self.kernel[self.clusters == i, :], axis=0)
-                            for i in landmarks])
+            pmn = np.array(
+                [np.sum(self.kernel[self.clusters == i, :], axis=0) for i in landmarks]
+            )
         return pmn
 
     def _data_transitions(self):
-        return normalize(self._landmarks_to_data(), 'l1', axis=1)
+        return normalize(self._landmarks_to_data(), "l1", axis=1)
 
     def build_landmark_op(self):
         """Build the landmark operator
@@ -570,25 +655,27 @@ def build_landmark_op(self):
         is_sparse = sparse.issparse(self.kernel)
         # spectral clustering
         with _logger.task("SVD"):
-            _, _, VT = randomized_svd(self.diff_aff,
-                                      n_components=self.n_svd,
-                                      random_state=self.random_state)
+            _, _, VT = randomized_svd(
+                self.diff_aff,
+                n_components=self.n_svd,
+                random_state=self.random_state,
+            )
         with _logger.task("KMeans"):
             kmeans = MiniBatchKMeans(
                 self.n_landmark,
                 init_size=3 * self.n_landmark,
                 batch_size=10000,
-                random_state=self.random_state)
-            self._clusters = kmeans.fit_predict(
-                self.diff_op.dot(VT.T))
+                random_state=self.random_state,
+            )
+            self._clusters = kmeans.fit_predict(self.diff_op.dot(VT.T))
 
         # transition matrices
         pmn = self._landmarks_to_data()
         # row normalize
         pnm = pmn.transpose()
-        pmn = normalize(pmn, norm='l1', axis=1)
-        pnm = normalize(pnm, norm='l1', axis=1)
+        pmn = normalize(pmn, norm="l1", axis=1)
+        pnm = normalize(pnm, norm="l1", axis=1)
         landmark_op = pmn.dot(pnm)  # sparsity agnostic matrix multiplication
         if is_sparse:
             # no need to have a sparse landmark operator
@@ -624,13 +711,19 @@ def extend_to_data(self, data, **kwargs):
         kernel = self.build_kernel_to_data(data, **kwargs)
         if sparse.issparse(kernel):
             pnm = sparse.hstack(
-                [sparse.csr_matrix(kernel[:, self.clusters == i].sum(
-                    axis=1)) for i in np.unique(self.clusters)])
+                [
+                    sparse.csr_matrix(kernel[:, self.clusters == i].sum(axis=1))
+                    for i in np.unique(self.clusters)
+                ]
+            )
         else:
-            pnm = np.array([np.sum(
-                kernel[:, self.clusters == i],
-                axis=1).T for i in np.unique(self.clusters)]).transpose()
-        pnm = normalize(pnm, norm='l1', axis=1)
+            pnm = np.array(
+                [
+                    np.sum(kernel[:, self.clusters == i], axis=1).T
+                    for i in np.unique(self.clusters)
+                ]
+            ).transpose()
+        pnm = normalize(pnm, norm="l1", axis=1)
         return pnm
 
     def interpolate(self, transform, transitions=None, Y=None):
@@ -728,44 +821,58 @@ class TraditionalGraph(DataGraph):
         Only one of `precomputed` and `n_pca` can be set.
     """
 
-    def __init__(self, data,
-                 knn=5, decay=40,
-                 bandwidth=None,
-                 bandwidth_scale=1.0,
-                 distance='euclidean',
-                 n_pca=None,
-                 thresh=1e-4,
-                 precomputed=None, **kwargs):
-        if decay is None and precomputed not in ['affinity', 'adjacency']:
+    def __init__(
+        self,
+        data,
+        knn=5,
+        decay=40,
+        bandwidth=None,
+        bandwidth_scale=1.0,
+        distance="euclidean",
+        n_pca=None,
+        thresh=1e-4,
+        precomputed=None,
+        **kwargs
+    ):
+        if decay is None and precomputed not in ["affinity", "adjacency"]:
             # decay high enough is basically a binary kernel
             raise ValueError(
                 "`decay` must be provided for a "
                 "TraditionalGraph. For kNN kernel, use kNNGraph."
            )
-        if precomputed is not None and n_pca not in [None,0,False]:
+        if precomputed is not None and n_pca not in [None, 0, False]:
             # the data itself is a matrix of distances / affinities
             n_pca = None
-            warnings.warn("n_pca cannot be given on a precomputed graph."
-                          " Setting n_pca=None", RuntimeWarning)
+            warnings.warn(
" Setting n_pca=None", + RuntimeWarning, + ) if knn is None and bandwidth is None: - raise ValueError( - "Either `knn` or `bandwidth` must be provided.") + raise ValueError("Either `knn` or `bandwidth` must be provided.") if knn is not None and knn > data.shape[0] - 2: - warnings.warn("Cannot set knn ({k}) to be greater than " - " n_samples - 2 ({n}). Setting knn={n}".format( - k=knn, n=data.shape[0] - 2)) + warnings.warn( + "Cannot set knn ({k}) to be greater than " + " n_samples - 2 ({n}). Setting knn={n}".format( + k=knn, n=data.shape[0] - 2 + ) + ) knn = data.shape[0] - 2 if precomputed is not None: if precomputed not in ["distance", "affinity", "adjacency"]: - raise ValueError("Precomputed value {} not recognized. " - "Choose from ['distance', 'affinity', " - "'adjacency']") + raise ValueError( + "Precomputed value {} not recognized. " + "Choose from ['distance', 'affinity', " + "'adjacency']" + ) elif data.shape[0] != data.shape[1]: - raise ValueError("Precomputed {} must be a square matrix. " - "{} was given".format(precomputed, - data.shape)) + raise ValueError( + "Precomputed {} must be a square matrix. " + "{} was given".format(precomputed, data.shape) + ) elif (data < 0).sum() > 0: - raise ValueError("Precomputed {} should be " - "non-negative".format(precomputed)) + raise ValueError( + "Precomputed {} should be " "non-negative".format(precomputed) + ) self.knn = knn self.decay = decay self.bandwidth = bandwidth @@ -774,19 +881,22 @@ def __init__(self, data, self.thresh = thresh self.precomputed = precomputed - super().__init__(data, n_pca=n_pca, - **kwargs) + super().__init__(data, n_pca=n_pca, **kwargs) def get_params(self): """Get parameters from this object """ params = super().get_params() - params.update({'knn': self.knn, - 'decay': self.decay, - 'bandwidth': self.bandwidth, - 'bandwidth_scale': self.bandwidth_scale, - 'distance': self.distance, - 'precomputed': self.precomputed}) + params.update( + { + "knn": self.knn, + "decay": self.decay, + "bandwidth": self.bandwidth, + "bandwidth_scale": self.bandwidth_scale, + "distance": self.distance, + "precomputed": self.precomputed, + } + ) return params def set_params(self, **params): @@ -810,29 +920,33 @@ def set_params(self, **params): ------- self """ - if 'precomputed' in params and \ - params['precomputed'] != self.precomputed: - raise ValueError("Cannot update precomputed. " - "Please create a new graph") - if 'distance' in params and params['distance'] != self.distance and \ - self.precomputed is None: - raise ValueError("Cannot update distance. " - "Please create a new graph") - if 'knn' in params and params['knn'] != self.knn and \ - self.precomputed is None: + if "precomputed" in params and params["precomputed"] != self.precomputed: + raise ValueError("Cannot update precomputed. " "Please create a new graph") + if ( + "distance" in params + and params["distance"] != self.distance + and self.precomputed is None + ): + raise ValueError("Cannot update distance. " "Please create a new graph") + if "knn" in params and params["knn"] != self.knn and self.precomputed is None: raise ValueError("Cannot update knn. Please create a new graph") - if 'decay' in params and params['decay'] != self.decay and \ - self.precomputed is None: + if ( + "decay" in params + and params["decay"] != self.decay + and self.precomputed is None + ): raise ValueError("Cannot update decay. 
-        if 'bandwidth' in params and \
-                params['bandwidth'] != self.bandwidth and \
-                self.precomputed is None:
-            raise ValueError(
-                "Cannot update bandwidth. Please create a new graph")
-        if 'bandwidth_scale' in params and \
-                params['bandwidth_scale'] != self.bandwidth_scale:
-            raise ValueError(
-                "Cannot update bandwidth_scale. Please create a new graph")
+        if (
+            "bandwidth" in params
+            and params["bandwidth"] != self.bandwidth
+            and self.precomputed is None
+        ):
+            raise ValueError("Cannot update bandwidth. Please create a new graph")
+        if (
+            "bandwidth_scale" in params
+            and params["bandwidth_scale"] != self.bandwidth_scale
+        ):
+            raise ValueError("Cannot update bandwidth_scale. Please create a new graph")
         # update superclass parameters
         super().set_params(**params)
         return self
@@ -863,9 +977,9 @@ def build_kernel(self):
         elif self.precomputed == "adjacency":
             # need to set diagonal to one to make it an affinity matrix
             K = self.data_nu
-            if sparse.issparse(K) and \
-                not (isinstance(K, sparse.dok_matrix) or
-                     isinstance(K, sparse.lil_matrix)):
+            if sparse.issparse(K) and not (
+                isinstance(K, sparse.dok_matrix) or isinstance(K, sparse.lil_matrix)
+            ):
                 K = K.tolil()
             K = utils.set_diagonal(K, 1)
         else:
@@ -879,25 +993,29 @@ def build_kernel(self):
             if np.any(pdx == 0):
                 pdx = squareform(pdx)
                 duplicate_ids = np.array(
-                    [i for i in np.argwhere(pdx == 0)
-                     if i[1] > i[0]])
-                duplicate_names = ", ".join(["{} and {}".format(i[0], i[1])
-                                             for i in duplicate_ids])
+                    [i for i in np.argwhere(pdx == 0) if i[1] > i[0]]
+                )
+                duplicate_names = ", ".join(
+                    ["{} and {}".format(i[0], i[1]) for i in duplicate_ids]
+                )
                 warnings.warn(
                     "Detected zero distance between samples {}. "
                     "Consider removing duplicates to avoid errors in "
                     "downstream processing.".format(duplicate_names),
-                    RuntimeWarning)
+                    RuntimeWarning,
+                )
             else:
                 pdx = squareform(pdx)
         else:
             raise ValueError(
                 "precomputed='{}' not recognized. "
                 "Choose from ['affinity', 'adjacency', 'distance', "
-                "None]".format(self.precomputed))
+                "None]".format(self.precomputed)
+            )
         if self.bandwidth is None:
-            knn_dist = np.partition(
-                pdx, self.knn + 1, axis=1)[:, :self.knn + 1]
+            knn_dist = np.partition(pdx, self.knn + 1, axis=1)[
+                :, : self.knn + 1
+            ]
             bandwidth = np.max(knn_dist, axis=1)
         elif callable(self.bandwidth):
             bandwidth = self.bandwidth(pdx)
@@ -910,9 +1028,11 @@ def build_kernel(self):
         K = np.where(np.isnan(K), 1, K)
         # truncate
         if sparse.issparse(K):
-            if not (isinstance(K, sparse.csr_matrix) or
-                    isinstance(K, sparse.csc_matrix) or
-                    isinstance(K, sparse.bsr_matrix)):
+            if not (
+                isinstance(K, sparse.csr_matrix)
+                or isinstance(K, sparse.csc_matrix)
+                or isinstance(K, sparse.bsr_matrix)
+            ):
                 K = K.tocsr()
             K.data[K.data < self.thresh] = 0
             K = K.tocoo()
@@ -971,7 +1091,7 @@ def build_kernel_to_data(self, Y, knn=None, bandwidth=None, bandwidth_scale=None
             bandwidth = bandwidth(pdx)
         bandwidth = bandwidth_scale * bandwidth
         pdx = (pdx.T / bandwidth).T
-        K = np.exp(-1 * pdx**self.decay)
+        K = np.exp(-1 * pdx ** self.decay)
         # handle nan
         K = np.where(np.isnan(K), 1, K)
         K[K < self.thresh] = 0
@@ -986,17 +1106,18 @@ def weighted(self):

     def _check_shortest_path_distance(self, distance):
         if self.precomputed is not None:
-            if distance == 'data':
+            if distance == "data":
                 raise ValueError(
                     "Graph shortest path with data distance not "
                     "valid for precomputed graphs. For precomputed graphs, "
                     "use `distance='constant'` for unweighted graphs and "
-                    "`distance='affinity'` for weighted graphs.")
+                    "`distance='affinity'` for weighted graphs."
+                )
         super()._check_shortest_path_distance(distance)

     def _default_shortest_path_distance(self):
         if self.precomputed is not None and not self.weighted:
-            distance = 'constant'
+            distance = "constant"
             _logger.info("Using constant distances.")
         else:
             distance = super()._default_shortest_path_distance()
@@ -1034,19 +1155,24 @@ class MNNGraph(DataGraph):
     Graphs representing each batch separately
     """

-    def __init__(self, data, sample_idx,
-                 knn=5, beta=1, n_pca=None,
-                 decay=None,
-                 adaptive_k=None,
-                 bandwidth=None,
-                 distance='euclidean',
-                 thresh=1e-4,
-                 n_jobs=1,
-                 **kwargs):
+    def __init__(
+        self,
+        data,
+        sample_idx,
+        knn=5,
+        beta=1,
+        n_pca=None,
+        decay=None,
+        adaptive_k=None,
+        bandwidth=None,
+        distance="euclidean",
+        thresh=1e-4,
+        n_jobs=1,
+        **kwargs
+    ):
         self.beta = beta
         self.sample_idx = sample_idx
-        self.samples, self.n_cells = np.unique(
-            self.sample_idx, return_counts=True)
+        self.samples, self.n_cells = np.unique(self.sample_idx, return_counts=True)
         self.knn = knn
         self.decay = decay
         self.distance = distance
@@ -1055,27 +1181,33 @@ def __init__(self, data, sample_idx,
         self.n_jobs = n_jobs

         if sample_idx is None:
-            raise ValueError("sample_idx must be given. For a graph without"
-                             " batch correction, use kNNGraph.")
+            raise ValueError(
+                "sample_idx must be given. For a graph without"
+                " batch correction, use kNNGraph."
+            )
         elif len(sample_idx) != data.shape[0]:
-            raise ValueError("sample_idx ({}) must be the same length as "
-                             "data ({})".format(len(sample_idx),
-                                                data.shape[0]))
-        elif len(self.samples) == 1:
             raise ValueError(
-                "sample_idx must contain more than one unique value")
+                "sample_idx ({}) must be the same length as "
+                "data ({})".format(len(sample_idx), data.shape[0])
+            )
+        elif len(self.samples) == 1:
+            raise ValueError("sample_idx must contain more than one unique value")
         if adaptive_k is not None:
-            warnings.warn("`adaptive_k` has been deprecated. Using fixed knn.",
-                          DeprecationWarning)
+            warnings.warn(
+                "`adaptive_k` has been deprecated. Using fixed knn.", DeprecationWarning
+            )

         super().__init__(data, n_pca=n_pca, **kwargs)

     def _check_symmetrization(self, kernel_symm, theta):
-        if (kernel_symm == 'theta' or kernel_symm == 'mnn') \
-                and theta is not None and \
-                not isinstance(theta, numbers.Number):
-            raise TypeError("Expected `theta` as a float. "
-                            "Got {}.".format(type(theta)))
+        if (
+            (kernel_symm == "theta" or kernel_symm == "mnn")
+            and theta is not None
+            and not isinstance(theta, numbers.Number)
+        ):
+            raise TypeError(
" "Got {}.".format(type(theta)) + ) else: super()._check_symmetrization(kernel_symm, theta) @@ -1083,13 +1215,17 @@ def get_params(self): """Get parameters from this object """ params = super().get_params() - params.update({'beta': self.beta, - 'knn': self.knn, - 'decay': self.decay, - 'bandwidth': self.bandwidth, - 'distance': self.distance, - 'thresh': self.thresh, - 'n_jobs': self.n_jobs}) + params.update( + { + "beta": self.beta, + "knn": self.knn, + "decay": self.decay, + "bandwidth": self.bandwidth, + "distance": self.distance, + "thresh": self.thresh, + "n_jobs": self.n_jobs, + } + ) return params def set_params(self, **params): @@ -1118,16 +1254,17 @@ def set_params(self, **params): self """ # mnn specific arguments - if 'beta' in params and params['beta'] != self.beta: + if "beta" in params and params["beta"] != self.beta: raise ValueError("Cannot update beta. Please create a new graph") # knn arguments - knn_kernel_args = ['knn', 'decay', 'distance', 'thresh', 'bandwidth'] - knn_other_args = ['n_jobs', 'random_state', 'verbose'] + knn_kernel_args = ["knn", "decay", "distance", "thresh", "bandwidth"] + knn_other_args = ["n_jobs", "random_state", "verbose"] for arg in knn_kernel_args: if arg in params and params[arg] != getattr(self, arg): - raise ValueError("Cannot update {}. " - "Please create a new graph".format(arg)) + raise ValueError( + "Cannot update {}. " "Please create a new graph".format(arg) + ) for arg in knn_other_args: if arg in params: self.__setattr__(arg, params[arg]) @@ -1152,57 +1289,71 @@ def build_kernel(self): with _logger.task("subgraphs"): self.subgraphs = [] from .api import Graph + # iterate through sample ids for i, idx in enumerate(self.samples): - _logger.debug("subgraph {}: sample {}, " - "n = {}, knn = {}".format( - i, idx, np.sum(self.sample_idx == idx), - self.knn)) + _logger.debug( + "subgraph {}: sample {}, " + "n = {}, knn = {}".format( + i, idx, np.sum(self.sample_idx == idx), self.knn + ) + ) # select data for sample data = self.data_nu[self.sample_idx == idx] # build a kNN graph for cells within sample - graph = Graph(data, n_pca=None, - knn=self.knn, - decay=self.decay, - bandwidth=self.bandwidth, - distance=self.distance, - thresh=self.thresh, - verbose=self.verbose, - random_state=self.random_state, - n_jobs=self.n_jobs, - kernel_symm='+', - initialize=True) + graph = Graph( + data, + n_pca=None, + knn=self.knn, + decay=self.decay, + bandwidth=self.bandwidth, + distance=self.distance, + thresh=self.thresh, + verbose=self.verbose, + random_state=self.random_state, + n_jobs=self.n_jobs, + kernel_symm="+", + initialize=True, + ) self.subgraphs.append(graph) # append to list of subgraphs with _logger.task("MNN kernel"): if self.thresh > 0 or self.decay is None: - K = sparse.lil_matrix( - (self.data_nu.shape[0], self.data_nu.shape[0])) + K = sparse.lil_matrix((self.data_nu.shape[0], self.data_nu.shape[0])) else: K = np.zeros([self.data_nu.shape[0], self.data_nu.shape[0]]) for i, X in enumerate(self.subgraphs): K = utils.set_submatrix( - K, self.sample_idx == self.samples[i], - self.sample_idx == self.samples[i], X.K) + K, + self.sample_idx == self.samples[i], + self.sample_idx == self.samples[i], + X.K, + ) within_batch_norm = np.array(np.sum(X.K, 1)).flatten() for j, Y in enumerate(self.subgraphs): if i == j: continue with _logger.task( - "kernel from sample {} to {}".format(self.samples[i], self.samples[j])): - Kij = Y.build_kernel_to_data( - X.data_nu, - knn=self.knn) + "kernel from sample {} to {}".format( + self.samples[i], self.samples[j] 
+                        )
+                    ):
+                        Kij = Y.build_kernel_to_data(X.data_nu, knn=self.knn)
                         between_batch_norm = np.array(np.sum(Kij, 1)).flatten()
-                        scale = np.minimum(1, within_batch_norm /
-                                           between_batch_norm) * self.beta
+                        scale = (
+                            np.minimum(1, within_batch_norm / between_batch_norm)
+                            * self.beta
+                        )
                         if sparse.issparse(Kij):
                             Kij = Kij.multiply(scale[:, None])
                         else:
                             Kij = Kij * scale[:, None]
                         K = utils.set_submatrix(
-                            K, self.sample_idx == self.samples[i],
-                            self.sample_idx == self.samples[j], Kij)
+                            K,
+                            self.sample_idx == self.samples[i],
+                            self.sample_idx == self.samples[j],
+                            Kij,
+                        )
         return K

     def build_kernel_to_data(self, Y, theta=None):
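The `scale` computation reformatted above is the heart of the MNN batch correction: each row of the between-batch kernel is capped so that no point gains more total affinity across batches than it has within its own batch, then damped by `beta`. A minimal sketch of that step in isolation, with toy values (the array contents here are hypothetical, not taken from the library):

    import numpy as np

    within_batch_norm = np.array([2.0, 1.5])   # row sums of X's within-batch kernel
    between_batch_norm = np.array([4.0, 1.0])  # row sums of Kij from batch X to batch Y
    beta = 0.5

    # cap the ratio at 1, then damp by beta -> array([0.25, 0.5])
    scale = np.minimum(1, within_batch_norm / between_batch_norm) * beta

    Kij = np.array([[0.4, 0.6], [0.2, 0.8]])
    Kij_scaled = Kij * scale[:, None]  # each row of the between-batch kernel rescaled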
diff --git a/graphtools/utils.py b/graphtools/utils.py
index db511ad..64d1b08 100644
--- a/graphtools/utils.py
+++ b/graphtools/utils.py
@@ -6,7 +6,7 @@ def if_sparse(sparse_func, dense_func, *args, **kwargs):
     if sparse.issparse(args[0]):
         for arg in args[1:]:
-            assert(sparse.issparse(arg))
+            assert sparse.issparse(arg)
         return sparse_func(*args, **kwargs)
     else:
         return dense_func(*args, **kwargs)
@@ -51,8 +51,9 @@ def set_submatrix(X, i, j, values):


 def sparse_nonzero_discrete(X, values):
-    if isinstance(X, (sparse.bsr_matrix, sparse.dia_matrix,
-                      sparse.dok_matrix, sparse.lil_matrix)):
+    if isinstance(
+        X, (sparse.bsr_matrix, sparse.dia_matrix, sparse.dok_matrix, sparse.lil_matrix)
+    ):
         X = X.tocsr()
     return dense_nonzero_discrete(X.data, values)
diff --git a/graphtools/version.py b/graphtools/version.py
index 9c73af2..3e8d9f9 100644
--- a/graphtools/version.py
+++ b/graphtools/version.py
@@ -1 +1 @@
-__version__ = "1.3.1"
+__version__ = "1.4.0"
diff --git a/setup.py b/setup.py
index cf358d5..4d7be50 100644
--- a/setup.py
+++ b/setup.py
@@ -3,78 +3,71 @@ from setuptools import setup

 install_requires = [
-    'numpy>=1.14.0',
-    'scipy>=1.1.0',
-    'pygsp>=0.5.1',
-    'scikit-learn>=0.20.0',
-    'future',
-    'tasklogger>=1.0',
+    "numpy>=1.14.0",
+    "scipy>=1.1.0",
+    "pygsp>=0.5.1",
+    "scikit-learn>=0.20.0",
+    "future",
+    "tasklogger>=1.0",
 ]

 test_requires = [
-    'nose',
-    'nose2',
-    'pandas',
-    'coverage',
-    'coveralls',
-    'python-igraph',
-    'parameterized'
+    "nose",
+    "nose2",
+    "pandas",
+    "coverage",
+    "coveralls",
+    "python-igraph",
+    "parameterized",
 ]

 if sys.version_info[0] == 3:
-    test_requires += ['anndata']
+    test_requires += ["anndata"]

-doc_requires = [
-    'sphinx',
-    'sphinxcontrib-napoleon',
-    'sphinxcontrib-bibtex'
-]
+doc_requires = ["sphinx", "sphinxcontrib-napoleon", "sphinxcontrib-bibtex"]

 if sys.version_info[:2] < (3, 5):
     raise RuntimeError("Python version >=3.5 required.")
+elif sys.version_info[:2] >= (3, 6):
+    test_requires += ["black"]

-version_py = os.path.join(os.path.dirname(
-    __file__), 'graphtools', 'version.py')
-version = open(version_py).read().strip().split(
-    '=')[-1].replace('"', '').strip()
+version_py = os.path.join(os.path.dirname(__file__), "graphtools", "version.py")
+version = open(version_py).read().strip().split("=")[-1].replace('"', "").strip()

-readme = open('README.rst').read()
+readme = open("README.rst").read()

-setup(name='graphtools',
-      version=version,
-      description='graphtools',
-      author='Scott Gigante, Daniel Burkhardt, and Jay Stanley, Yale University',
-      author_email='scott.gigante@yale.edu',
-      packages=['graphtools', ],
-      license='GNU General Public License Version 2',
-      install_requires=install_requires,
-      extras_require={'test': test_requires,
-                      'doc': doc_requires},
-      test_suite='nose2.collector.collector',
-      long_description=readme,
-      url='https://github.com/KrishnaswamyLab/graphtools',
-      download_url="https://github.com/KrishnaswamyLab/graphtools/archive/v{}.tar.gz".format(
-          version),
-      keywords=['graphs',
-                'big-data',
-                'signal processing',
-                'manifold-learning',
-                ],
-      classifiers=[
-          'Development Status :: 4 - Beta',
-          'Environment :: Console',
-          'Framework :: Jupyter',
-          'Intended Audience :: Developers',
-          'Intended Audience :: Science/Research',
-          'Natural Language :: English',
-          'Operating System :: MacOS :: MacOS X',
-          'Operating System :: Microsoft :: Windows',
-          'Operating System :: POSIX :: Linux',
-          'Programming Language :: Python :: 2',
-          'Programming Language :: Python :: 2.7',
-          'Programming Language :: Python :: 3',
-          'Programming Language :: Python :: 3.5',
-          'Programming Language :: Python :: 3.6',
-          'Topic :: Scientific/Engineering :: Mathematics',
-      ]
-      )
+setup(
+    name="graphtools",
+    version=version,
+    description="graphtools",
+    author="Scott Gigante, Daniel Burkhardt, and Jay Stanley, Yale University",
+    author_email="scott.gigante@yale.edu",
+    packages=["graphtools",],
+    license="GNU General Public License Version 2",
+    install_requires=install_requires,
+    extras_require={"test": test_requires, "doc": doc_requires},
+    test_suite="nose2.collector.collector",
+    long_description=readme,
+    url="https://github.com/KrishnaswamyLab/graphtools",
+    download_url="https://github.com/KrishnaswamyLab/graphtools/archive/v{}.tar.gz".format(
+        version
+    ),
+    keywords=["graphs", "big-data", "signal processing", "manifold-learning",],
+    classifiers=[
+        "Development Status :: 4 - Beta",
+        "Environment :: Console",
+        "Framework :: Jupyter",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Science/Research",
+        "Natural Language :: English",
+        "Operating System :: MacOS :: MacOS X",
+        "Operating System :: Microsoft :: Windows",
+        "Operating System :: POSIX :: Linux",
+        "Programming Language :: Python :: 2",
+        "Programming Language :: Python :: 2.7",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.5",
+        "Programming Language :: Python :: 3.6",
+        "Topic :: Scientific/Engineering :: Mathematics",
+    ],
+)
diff --git a/test/load_tests/__init__.py b/test/load_tests/__init__.py
index 8105a4a..d6a3c1a 100644
--- a/test/load_tests/__init__.py
+++ b/test/load_tests/__init__.py
@@ -22,32 +22,42 @@ def reset_warnings():

 def ignore_numpy_warning():
     warnings.filterwarnings(
-        "ignore", category=PendingDeprecationWarning,
+        "ignore",
+        category=PendingDeprecationWarning,
         message="the matrix subclass is not the recommended way to represent "
-        "matrices or deal with linear algebra ")
+        "matrices or deal with linear algebra ",
+    )


 def ignore_igraph_warning():
     warnings.filterwarnings(
-        "ignore", category=DeprecationWarning,
+        "ignore",
+        category=DeprecationWarning,
         message="The SafeConfigParser class has been renamed to ConfigParser "
         "in Python 3.2. This alias will be removed in future versions. Use "
Use " - "ConfigParser directly instead") + "ConfigParser directly instead", + ) warnings.filterwarnings( - "ignore", category=DeprecationWarning, + "ignore", + category=DeprecationWarning, message="Using or importing the ABCs from 'collections' instead of from " - "'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working") + "'collections.abc' is deprecated since Python 3.3,and in 3.9 it will stop working", + ) warnings.filterwarnings( - "ignore", category=DeprecationWarning, + "ignore", + category=DeprecationWarning, message="Using or importing the ABCs from 'collections' instead of from " - "'collections.abc' is deprecated, and in 3.8 it will stop working") + "'collections.abc' is deprecated, and in 3.8 it will stop working", + ) def ignore_joblib_warning(): warnings.filterwarnings( - "ignore", category=DeprecationWarning, + "ignore", + category=DeprecationWarning, message="check_pickle is deprecated in joblib 0.12 and will be removed" - " in 0.13") + " in 0.13", + ) reset_warnings() @@ -55,7 +65,7 @@ def ignore_joblib_warning(): global digits global data digits = datasets.load_digits() -data = digits['data'] +data = digits["data"] def generate_swiss_roll(n_samples=1000, noise=0.5, seed=42): @@ -73,19 +83,30 @@ def generate_swiss_roll(n_samples=1000, noise=0.5, seed=42): return X, sample_idx -def build_graph(data, n_pca=20, thresh=0, - decay=10, knn=3, - random_state=42, - sparse=False, - graph_class=graphtools.Graph, - verbose=0, - **kwargs): +def build_graph( + data, + n_pca=20, + thresh=0, + decay=10, + knn=3, + random_state=42, + sparse=False, + graph_class=graphtools.Graph, + verbose=0, + **kwargs +): if sparse: data = sp.coo_matrix(data) - return graph_class(data, thresh=thresh, n_pca=n_pca, - decay=decay, knn=knn, - random_state=42, verbose=verbose, - **kwargs) + return graph_class( + data, + thresh=thresh, + n_pca=n_pca, + decay=decay, + knn=knn, + random_state=42, + verbose=verbose, + **kwargs + ) def warns(*warns): @@ -98,7 +119,7 @@ def test_raises_type_error(): def test_that_fails_by_passing(): pass """ - valid = ' or '.join([w.__name__ for w in warns]) + valid = " or ".join([w.__name__ for w in warns]) def decorate(func): name = func.__name__ @@ -118,6 +139,8 @@ def newfunc(*arg, **kw): else: message = "%s() did not raise %s" % (name, valid) raise AssertionError(message) + newfunc = make_decorator(func)(newfunc) return newfunc + return decorate diff --git a/test/test_api.py b/test/test_api.py index e5ae0d8..7c17664 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -22,7 +22,7 @@ def test_from_igraph(): K[e[0], e[1]] = K[e[1], e[0]] = 1 g = igraph.Graph.Adjacency(K.tolist()) G = graphtools.from_igraph(g, attribute=None) - G2 = graphtools.Graph(K, precomputed='adjacency') + G2 = graphtools.Graph(K, precomputed="adjacency") assert np.all(G.K == G2.K) @@ -35,7 +35,7 @@ def test_from_igraph_weighted(): K[e[0], e[1]] = K[e[1], e[0]] = np.random.uniform(0, 1) g = igraph.Graph.Weighted_Adjacency(K.tolist()) G = graphtools.from_igraph(g) - G2 = graphtools.Graph(K, precomputed='adjacency') + G2 = graphtools.Graph(K, precomputed="adjacency") assert np.all(G.K == G2.K) @@ -48,7 +48,7 @@ def test_from_igraph_invalid_precomputed(): e = np.random.choice(n, 2, replace=False) K[e[0], e[1]] = K[e[1], e[0]] = 1 g = igraph.Graph.Adjacency(K.tolist()) - G = graphtools.from_igraph(g, attribute=None, precomputed='affinity') + G = graphtools.from_igraph(g, attribute=None, precomputed="affinity") @warns(UserWarning) @@ -74,19 +74,17 @@ def test_to_igraph(): G = 
     G2 = G.to_igraph()
     assert isinstance(G2, igraph.Graph)
-    assert np.all(np.array(G2.get_adjacency(
-        attribute="weight").data) == G.W)
+    assert np.all(np.array(G2.get_adjacency(attribute="weight").data) == G.W)
     G3 = build_graph(data, use_pygsp=False)
     G2 = G3.to_igraph()
     assert isinstance(G2, igraph.Graph)
-    assert np.all(np.array(G2.get_adjacency(
-        attribute="weight").data) == G.W)
+    assert np.all(np.array(G2.get_adjacency(attribute="weight").data) == G.W)


 def test_pickle_io_knngraph():
     G = build_graph(data, knn=5, decay=None)
     with tempfile.TemporaryDirectory() as tempdir:
-        path = os.path.join(tempdir, 'tmp.pkl')
+        path = os.path.join(tempdir, "tmp.pkl")
         G.to_pickle(path)
         G_prime = graphtools.read_pickle(path)
     assert isinstance(G_prime, type(G))
@@ -95,18 +93,17 @@ def test_pickle_io_knngraph():
 def test_pickle_io_traditionalgraph():
     G = build_graph(data, knn=5, decay=10, thresh=0)
     with tempfile.TemporaryDirectory() as tempdir:
-        path = os.path.join(tempdir, 'tmp.pkl')
+        path = os.path.join(tempdir, "tmp.pkl")
         G.to_pickle(path)
         G_prime = graphtools.read_pickle(path)
     assert isinstance(G_prime, type(G))


 def test_pickle_io_landmarkgraph():
-    G = build_graph(data, knn=5, decay=None,
-                    n_landmark=data.shape[0] // 2)
+    G = build_graph(data, knn=5, decay=None, n_landmark=data.shape[0] // 2)
     L = G.landmark_op
     with tempfile.TemporaryDirectory() as tempdir:
-        path = os.path.join(tempdir, 'tmp.pkl')
+        path = os.path.join(tempdir, "tmp.pkl")
         G.to_pickle(path)
         G_prime = graphtools.read_pickle(path)
     assert isinstance(G_prime, type(G))
@@ -116,7 +113,7 @@ def test_pickle_io_landmarkgraph():
 def test_pickle_io_pygspgraph():
     G = build_graph(data, knn=5, decay=None, use_pygsp=True)
     with tempfile.TemporaryDirectory() as tempdir:
-        path = os.path.join(tempdir, 'tmp.pkl')
+        path = os.path.join(tempdir, "tmp.pkl")
         G.to_pickle(path)
         G_prime = graphtools.read_pickle(path)
     assert isinstance(G_prime, type(G))
@@ -126,17 +123,18 @@ def test_pickle_io_pygspgraph():
 @warns(UserWarning)
 def test_pickle_bad_pickle():
     import pickle
+
     with tempfile.TemporaryDirectory() as tempdir:
-        path = os.path.join(tempdir, 'tmp.pkl')
-        with open(path, 'wb') as f:
-            pickle.dump('hello world', f)
+        path = os.path.join(tempdir, "tmp.pkl")
+        with open(path, "wb") as f:
+            pickle.dump("hello world", f)
         G = graphtools.read_pickle(path)


 @warns(UserWarning)
 def test_to_pygsp_invalid_precomputed():
     G = build_graph(data)
-    G2 = G.to_pygsp(precomputed='adjacency')
+    G2 = G.to_pygsp(precomputed="adjacency")


 @warns(UserWarning)
@@ -144,6 +142,7 @@ def test_to_pygsp_invalid_use_pygsp():
     G = build_graph(data)
     G2 = G.to_pygsp(use_pygsp=False)

+
 #####################################################
 # Check parameters
 #####################################################
@@ -151,9 +150,9 @@ def test_to_pygsp_invalid_use_pygsp():

 @raises(TypeError)
 def test_unknown_parameter():
-    build_graph(data, hello='world')
+    build_graph(data, hello="world")


 @raises(ValueError)
 def test_invalid_graphtype():
-    build_graph(data, graphtype='hello world')
+    build_graph(data, graphtype="hello world")
diff --git a/test/test_data.py b/test/test_data.py
index f809478..cac8990 100644
--- a/test/test_data.py
+++ b/test/test_data.py
@@ -47,7 +47,7 @@ def test_0_n_pca():

 @raises(ValueError)
 def test_badstring_n_pca():
-    build_graph(data, n_pca='foobar')
+    build_graph(data, n_pca="foobar")


 @raises(ValueError)
@@ -62,7 +62,7 @@ def test_negative_n_pca():

 @raises(ValueError)
 def test_badstring_rank_threshold():
-    build_graph(data, n_pca=True,
-                rank_threshold='foobar')
+    build_graph(data, n_pca=True, rank_threshold="foobar")


 @raises(ValueError)
@@ -73,8 +73,7 @@ def test_negative_rank_threshold():

 @raises(ValueError)
 @warns(RuntimeWarning)
 def test_True_n_pca_large_threshold():
-    build_graph(data, n_pca=True,
-                rank_threshold=np.linalg.norm(data)**2)
+    build_graph(data, n_pca=True, rank_threshold=np.linalg.norm(data) ** 2)


 @warns(RuntimeWarning)
@@ -97,15 +96,13 @@ def test_True_n_pca():


 def test_True_n_pca_manual_rank_threshold():
-    g = build_graph(data, n_pca=True,
-                    rank_threshold=0.1)
+    g = build_graph(data, n_pca=True, rank_threshold=0.1)
     assert isinstance(g.n_pca, numbers.Number)
     assert isinstance(g.rank_threshold, numbers.Number)


 def test_True_n_pca_auto_rank_threshold():
-    g = build_graph(data, n_pca=True,
-                    rank_threshold='auto')
+    g = build_graph(data, n_pca=True, rank_threshold="auto")
     assert isinstance(g.n_pca, numbers.Number)
     assert isinstance(g.rank_threshold, numbers.Number)
     next_threshold = np.sort(g.data_pca.singular_values_)[2]
@@ -114,13 +111,13 @@ def test_True_n_pca_auto_rank_threshold():


 def test_goodstring_rank_threshold():
-    build_graph(data, n_pca=True, rank_threshold='auto')
-    build_graph(data, n_pca=True, rank_threshold='AUTO')
+    build_graph(data, n_pca=True, rank_threshold="auto")
+    build_graph(data, n_pca=True, rank_threshold="AUTO")


 def test_string_n_pca():
-    build_graph(data, n_pca='auto')
-    build_graph(data, n_pca='AUTO')
+    build_graph(data, n_pca="auto")
+    build_graph(data, n_pca="AUTO")


 @warns(RuntimeWarning)
@@ -135,15 +132,12 @@ def test_too_many_n_pca():

 @warns(RuntimeWarning)
 def test_too_many_n_pca2():
-    build_graph(data[:data.shape[1] - 1],
-                n_pca=data.shape[1] - 1)
+    build_graph(data[: data.shape[1] - 1], n_pca=data.shape[1] - 1)


 @warns(RuntimeWarning)
 def test_precomputed_with_pca():
-    build_graph(squareform(pdist(data)),
-                precomputed='distance',
-                n_pca=20)
+    build_graph(squareform(pdist(data)), precomputed="distance", n_pca=20)


 #####################################################
@@ -196,7 +190,7 @@ def test_anndata_sparse():

 def test_transform_dense_pca():
     G = build_graph(data, n_pca=20)
-    assert(np.all(G.data_nu == G.transform(G.data)))
+    assert np.all(G.data_nu == G.transform(G.data))
     assert_raises(ValueError, G.transform, G.data[:, 0])
     assert_raises(ValueError, G.transform, G.data[:, None, :15])
     assert_raises(ValueError, G.transform, G.data[:, :15])
@@ -204,7 +198,7 @@ def test_transform_dense_pca():

 def test_transform_dense_no_pca():
     G = build_graph(data, n_pca=None)
-    assert(np.all(G.data_nu == G.transform(G.data)))
+    assert np.all(G.data_nu == G.transform(G.data))
     assert_raises(ValueError, G.transform, G.data[:, 0])
     assert_raises(ValueError, G.transform, G.data[:, None, :15])
     assert_raises(ValueError, G.transform, G.data[:, :15])
@@ -212,14 +206,14 @@ def test_transform_dense_no_pca():

 def test_transform_sparse_pca():
     G = build_graph(data, sparse=True, n_pca=20)
-    assert(np.all(G.data_nu == G.transform(G.data)))
+    assert np.all(G.data_nu == G.transform(G.data))
     assert_raises(ValueError, G.transform, sp.csr_matrix(G.data)[:, 0])
     assert_raises(ValueError, G.transform, sp.csr_matrix(G.data)[:, :15])


 def test_transform_sparse_no_pca():
     G = build_graph(data, sparse=True, n_pca=None)
-    assert(np.sum(G.data_nu != G.transform(G.data)) == 0)
+    assert np.sum(G.data_nu != G.transform(G.data)) == 0
     assert_raises(ValueError, G.transform, sp.csr_matrix(G.data)[:, 0])
     assert_raises(ValueError, G.transform, sp.csr_matrix(G.data)[:, :15])
@@ -231,16 +225,14 @@ def test_transform_sparse_no_pca():

 def test_inverse_transform_dense_pca():
     G = build_graph(data, n_pca=data.shape[1] - 1)
+    np.testing.assert_allclose(G.data, G.inverse_transform(G.data_nu), atol=1e-12)
     np.testing.assert_allclose(
-        G.data, G.inverse_transform(G.data_nu), atol=1e-12)
-    np.testing.assert_allclose(G.data[:, -1, None],
-                               G.inverse_transform(G.data_nu, columns=-1),
-                               atol=1e-12)
-    np.testing.assert_allclose(G.data[:, 5:7],
-                               G.inverse_transform(G.data_nu, columns=[5, 6]),
-                               atol=1e-12)
-    assert_raises(IndexError, G.inverse_transform,
-                  G.data_nu, columns=data.shape[1])
+        G.data[:, -1, None], G.inverse_transform(G.data_nu, columns=-1), atol=1e-12
+    )
+    np.testing.assert_allclose(
+        G.data[:, 5:7], G.inverse_transform(G.data_nu, columns=[5, 6]), atol=1e-12
+    )
+    assert_raises(IndexError, G.inverse_transform, G.data_nu, columns=data.shape[1])
     assert_raises(ValueError, G.inverse_transform, G.data[:, 0])
     assert_raises(ValueError, G.inverse_transform, G.data[:, None, :15])
     assert_raises(ValueError, G.inverse_transform, G.data[:, :15])
@@ -248,30 +240,26 @@ def test_inverse_transform_dense_pca():

 def test_inverse_transform_sparse_svd():
     G = build_graph(data, sparse=True, n_pca=data.shape[1] - 1)
+    np.testing.assert_allclose(data, G.inverse_transform(G.data_nu), atol=1e-12)
+    np.testing.assert_allclose(
+        data[:, -1, None], G.inverse_transform(G.data_nu, columns=-1), atol=1e-12
+    )
     np.testing.assert_allclose(
-        data, G.inverse_transform(G.data_nu), atol=1e-12)
-    np.testing.assert_allclose(data[:, -1, None],
-                               G.inverse_transform(G.data_nu, columns=-1),
-                               atol=1e-12)
-    np.testing.assert_allclose(data[:, 5:7],
-                               G.inverse_transform(G.data_nu, columns=[5, 6]),
-                               atol=1e-12)
-    assert_raises(IndexError, G.inverse_transform,
-                  G.data_nu, columns=data.shape[1])
+        data[:, 5:7], G.inverse_transform(G.data_nu, columns=[5, 6]), atol=1e-12
+    )
+    assert_raises(IndexError, G.inverse_transform, G.data_nu, columns=data.shape[1])
     assert_raises(TypeError, G.inverse_transform, sp.csr_matrix(G.data)[:, 0])
-    assert_raises(TypeError, G.inverse_transform,
-                  sp.csr_matrix(G.data)[:, :15])
+    assert_raises(TypeError, G.inverse_transform, sp.csr_matrix(G.data)[:, :15])
     assert_raises(ValueError, G.inverse_transform, data[:, 0])
-    assert_raises(ValueError, G.inverse_transform,
-                  data[:, :15])
+    assert_raises(ValueError, G.inverse_transform, data[:, :15])


 def test_inverse_transform_dense_no_pca():
     G = build_graph(data, n_pca=None)
-    np.testing.assert_allclose(data[:, 5:7],
-                               G.inverse_transform(G.data_nu, columns=[5, 6]),
-                               atol=1e-12)
-    assert(np.all(G.data == G.inverse_transform(G.data_nu)))
+    np.testing.assert_allclose(
+        data[:, 5:7], G.inverse_transform(G.data_nu, columns=[5, 6]), atol=1e-12
+    )
+    assert np.all(G.data == G.inverse_transform(G.data_nu))
     assert_raises(ValueError, G.inverse_transform, G.data[:, 0])
     assert_raises(ValueError, G.inverse_transform, G.data[:, None, :15])
     assert_raises(ValueError, G.inverse_transform, G.data[:, :15])
@@ -279,10 +267,9 @@ def test_inverse_transform_dense_no_pca():

 def test_inverse_transform_sparse_no_pca():
     G = build_graph(data, sparse=True, n_pca=None)
-    assert(np.sum(G.data != G.inverse_transform(G.data_nu)) == 0)
+    assert np.sum(G.data != G.inverse_transform(G.data_nu)) == 0
     assert_raises(ValueError, G.inverse_transform, sp.csr_matrix(G.data)[:, 0])
-    assert_raises(ValueError, G.inverse_transform,
-                  sp.csr_matrix(G.data)[:, :15])
+    assert_raises(ValueError, G.inverse_transform, sp.csr_matrix(G.data)[:, :15])


 #####################################################
@@ -292,36 +279,36 @@ def test_inverse_transform_sparse_no_pca():
 def test_transform_adaptive_pca():
     G = build_graph(data, n_pca=True, random_state=42)
-    assert(np.all(G.data_nu == G.transform(G.data)))
+    assert np.all(G.data_nu == G.transform(G.data))
     assert_raises(ValueError, G.transform, G.data[:, 0])
     assert_raises(ValueError, G.transform, G.data[:, None, :15])
     assert_raises(ValueError, G.transform, G.data[:, :15])

-    G2 = build_graph(data, n_pca=True,
-                     rank_threshold=G.rank_threshold, random_state=42)
-    assert(np.allclose(G2.data_nu, G2.transform(G2.data)))
-    assert(np.allclose(G2.data_nu, G.transform(G.data)))
+    G2 = build_graph(data, n_pca=True, rank_threshold=G.rank_threshold, random_state=42)
+    assert np.allclose(G2.data_nu, G2.transform(G2.data))
+    assert np.allclose(G2.data_nu, G.transform(G.data))

     G3 = build_graph(data, n_pca=G2.n_pca, random_state=42)
-    assert(np.allclose(G3.data_nu, G3.transform(G3.data)))
-    assert(np.allclose(G3.data_nu, G2.transform(G2.data)))
+    assert np.allclose(G3.data_nu, G3.transform(G3.data))
+    assert np.allclose(G3.data_nu, G2.transform(G2.data))


 def test_transform_sparse_adaptive_pca():
     G = build_graph(data, sparse=True, n_pca=True, random_state=42)
-    assert(np.all(G.data_nu == G.transform(G.data)))
+    assert np.all(G.data_nu == G.transform(G.data))
     assert_raises(ValueError, G.transform, sp.csr_matrix(G.data)[:, 0])
     assert_raises(ValueError, G.transform, sp.csr_matrix(G.data)[:, :15])

-    G2 = build_graph(data, sparse=True, n_pca=True,
-                     rank_threshold=G.rank_threshold, random_state=42)
-    assert(np.allclose(G2.data_nu, G2.transform(G2.data)))
-    assert(np.allclose(G2.data_nu, G.transform(G.data)))
+    G2 = build_graph(
+        data, sparse=True, n_pca=True, rank_threshold=G.rank_threshold, random_state=42
+    )
+    assert np.allclose(G2.data_nu, G2.transform(G2.data))
+    assert np.allclose(G2.data_nu, G.transform(G.data))

     G3 = build_graph(data, sparse=True, n_pca=G2.n_pca, random_state=42)
-    assert(np.allclose(G3.data_nu, G3.transform(G3.data)))
-    assert(np.allclose(G3.data_nu, G2.transform(G2.data)))
+    assert np.allclose(G3.data_nu, G3.transform(G3.data))
+    assert np.allclose(G3.data_nu, G2.transform(G2.data))


 #############
@@ -331,7 +318,7 @@ def test_transform_sparse_adaptive_pca():

 def test_set_params():
     G = graphtools.base.Data(data, n_pca=20)
-    assert G.get_params() == {'n_pca': 20, 'random_state': None}
+    assert G.get_params() == {"n_pca": 20, "random_state": None}
     G.set_params(random_state=13)
     assert G.random_state == 13
     assert_raises(ValueError, G.set_params, n_pca=10)
diff --git a/test/test_exact.py b/test/test_exact.py
index 8dfbc66..a36f2b3 100644
--- a/test/test_exact.py
+++ b/test/test_exact.py
@@ -14,7 +14,7 @@
     squareform,
     pdist,
     PCA,
-    TruncatedSVD
+    TruncatedSVD,
 )

 #####################################################
@@ -24,88 +24,70 @@

 @raises(ValueError)
 def test_sample_idx_and_precomputed():
-    build_graph(squareform(pdist(data)), n_pca=None,
-                sample_idx=np.arange(10),
-                precomputed='distance',
-                decay=10)
+    build_graph(
+        squareform(pdist(data)),
+        n_pca=None,
+        sample_idx=np.arange(10),
+        precomputed="distance",
+        decay=10,
+    )


 @raises(ValueError)
 def test_invalid_precomputed():
-    build_graph(squareform(pdist(data)), n_pca=None,
-                precomputed='hello world',
-                decay=10)
+    build_graph(
+        squareform(pdist(data)), n_pca=None, precomputed="hello world", decay=10
+    )


 @raises(ValueError)
 def test_precomputed_not_square():
-    build_graph(data, n_pca=None, precomputed='distance',
-                decay=10)
+    build_graph(data, n_pca=None, precomputed="distance", decay=10)
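The three tests above pin down the `precomputed` contract: the input must be square and non-negative, `precomputed` must name one of 'distance', 'affinity', or 'adjacency', and `n_pca` must be None. A minimal sketch of a valid call, assuming only the public `graphtools.Graph` entry point exercised throughout this suite:

    import numpy as np
    import graphtools
    from scipy.spatial.distance import pdist, squareform

    X = np.random.normal(0, 1, (100, 10))
    D = squareform(pdist(X))  # square, non-negative pairwise distance matrix
    # decay=10 gives an alpha-decay kernel; n_pca must stay None for precomputed input
    G = graphtools.Graph(D, precomputed="distance", n_pca=None, decay=10)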
 @raises(ValueError)
 def test_build_exact_with_sample_idx():
-    build_graph(data, graphtype='exact', sample_idx=np.arange(len(data)),
-                decay=10)
+    build_graph(data, graphtype="exact", sample_idx=np.arange(len(data)), decay=10)


 @warns(RuntimeWarning)
 def test_precomputed_with_pca():
-    build_graph(squareform(pdist(data)),
-                precomputed='distance',
-                n_pca=20,
-                decay=10)
+    build_graph(squareform(pdist(data)), precomputed="distance", n_pca=20, decay=10)


 @raises(ValueError)
 def test_exact_no_decay():
-    build_graph(data, graphtype='exact',
-                decay=None)
+    build_graph(data, graphtype="exact", decay=None)


 @raises(ValueError)
 def test_exact_no_knn_no_bandwidth():
-    build_graph(data, graphtype='exact',
-                knn=None, bandwidth=None)
+    build_graph(data, graphtype="exact", knn=None, bandwidth=None)


 @raises(ValueError)
 def test_precomputed_negative():
-    build_graph(np.random.normal(0, 1, [200, 200]),
-                precomputed='distance',
-                n_pca=None)
+    build_graph(np.random.normal(0, 1, [200, 200]), precomputed="distance", n_pca=None)


 @raises(ValueError)
 def test_precomputed_invalid():
-    build_graph(np.random.uniform(0, 1, [200, 200]),
-                precomputed='invalid',
-                n_pca=None)
+    build_graph(np.random.uniform(0, 1, [200, 200]), precomputed="invalid", n_pca=None)


 @warns(RuntimeWarning)
 def test_duplicate_data():
-    build_graph(np.vstack([data, data[:10]]),
-                n_pca=20,
-                decay=10,
-                thresh=0)
+    build_graph(np.vstack([data, data[:10]]), n_pca=20, decay=10, thresh=0)


 @warns(RuntimeWarning)
 def test_many_duplicate_data():
-    build_graph(np.vstack([data, data]),
-                n_pca=20,
-                decay=10,
-                thresh=0)
+    build_graph(np.vstack([data, data]), n_pca=20, decay=10, thresh=0)


 @warns(UserWarning)
 def test_k_too_large():
-    build_graph(data,
-                n_pca=20,
-                decay=10,
-                knn=len(data) - 1,
-                thresh=0)
+    build_graph(data, n_pca=20, decay=10, knn=len(data) - 1, thresh=0)


 #####################################################
@@ -118,60 +100,76 @@ def test_exact_graph():
     k = 3
     a = 13
     n_pca = 20
     bandwidth_scale = 1.3
-    data_small = data[np.random.choice(
-        len(data), len(data) // 2, replace=False)]
-    pca = PCA(n_pca, svd_solver='randomized', random_state=42).fit(data_small)
+    data_small = data[np.random.choice(len(data), len(data) // 2, replace=False)]
+    pca = PCA(n_pca, svd_solver="randomized", random_state=42).fit(data_small)
     data_small_nu = pca.transform(data_small)
-    pdx = squareform(pdist(data_small_nu, metric='euclidean'))
+    pdx = squareform(pdist(data_small_nu, metric="euclidean"))
     knn_dist = np.partition(pdx, k, axis=1)[:, :k]
     epsilon = np.max(knn_dist, axis=1) * bandwidth_scale
     weighted_pdx = (pdx.T / epsilon).T
-    K = np.exp(-1 * weighted_pdx**a)
+    K = np.exp(-1 * weighted_pdx ** a)
     W = K + K.T
     W = np.divide(W, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data_small, thresh=0, n_pca=n_pca,
-                     decay=a, knn=k - 1, random_state=42,
-                     bandwidth_scale=bandwidth_scale,
-                     use_pygsp=True)
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        data_small,
+        thresh=0,
+        n_pca=n_pca,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        bandwidth_scale=bandwidth_scale,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(pdx, n_pca=None, precomputed='distance',
-                     bandwidth_scale=bandwidth_scale,
-                     decay=a, knn=k - 1, random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        pdx,
+        n_pca=None,
+        precomputed="distance",
+        bandwidth_scale=bandwidth_scale,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(sp.coo_matrix(K), n_pca=None,
-                     precomputed='affinity',
-                     random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        sp.coo_matrix(K),
+        n_pca=None,
+        precomputed="affinity",
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(K, n_pca=None,
-                     precomputed='affinity',
-                     random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        K, n_pca=None, precomputed="affinity", random_state=42, use_pygsp=True
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(W, n_pca=None,
-                     precomputed='adjacency',
-                     random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        W, n_pca=None, precomputed="adjacency", random_state=42, use_pygsp=True
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)


@@ -179,55 +177,70 @@ def test_truncated_exact_graph():
     k = 3
     a = 13
     n_pca = 20
     thresh = 1e-4
-    data_small = data[np.random.choice(
-        len(data), len(data) // 2, replace=False)]
-    pca = PCA(n_pca, svd_solver='randomized', random_state=42).fit(data_small)
+    data_small = data[np.random.choice(len(data), len(data) // 2, replace=False)]
+    pca = PCA(n_pca, svd_solver="randomized", random_state=42).fit(data_small)
     data_small_nu = pca.transform(data_small)
-    pdx = squareform(pdist(data_small_nu, metric='euclidean'))
+    pdx = squareform(pdist(data_small_nu, metric="euclidean"))
     knn_dist = np.partition(pdx, k, axis=1)[:, :k]
     epsilon = np.max(knn_dist, axis=1)
     weighted_pdx = (pdx.T / epsilon).T
-    K = np.exp(-1 * weighted_pdx**a)
+    K = np.exp(-1 * weighted_pdx ** a)
     K[K < thresh] = 0
     W = K + K.T
     W = np.divide(W, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data_small, thresh=thresh,
-                     graphtype='exact',
-                     n_pca=n_pca,
-                     decay=a, knn=k - 1, random_state=42,
-                     use_pygsp=True)
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        data_small,
+        thresh=thresh,
+        graphtype="exact",
+        n_pca=n_pca,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(pdx, n_pca=None, precomputed='distance',
-                     thresh=thresh,
-                     decay=a, knn=k - 1, random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        pdx,
+        n_pca=None,
+        precomputed="distance",
+        thresh=thresh,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(K, n_pca=None,
-                     precomputed='affinity',
-                     thresh=thresh,
-                     random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        K,
+        n_pca=None,
+        precomputed="affinity",
+        thresh=thresh,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(W, n_pca=None,
-                     precomputed='adjacency',
-                     random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        W, n_pca=None, precomputed="adjacency", random_state=42, use_pygsp=True
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)


@@ -235,54 +248,72 @@ def test_truncated_exact_graph_sparse():
     k = 3
     a = 13
     n_pca = 20
     thresh = 1e-4
-    data_small = data[np.random.choice(
-        len(data), len(data) // 2, replace=False)]
-    pca = TruncatedSVD(n_pca,
-                       random_state=42).fit(data_small)
+    data_small = data[np.random.choice(len(data), len(data) // 2, replace=False)]
+    pca = TruncatedSVD(n_pca, random_state=42).fit(data_small)
     data_small_nu = pca.transform(data_small)
-    pdx = squareform(pdist(data_small_nu, metric='euclidean'))
+    pdx = squareform(pdist(data_small_nu, metric="euclidean"))
     knn_dist = np.partition(pdx, k, axis=1)[:, :k]
     epsilon = np.max(knn_dist, axis=1)
     weighted_pdx = (pdx.T / epsilon).T
-    K = np.exp(-1 * weighted_pdx**a)
+    K = np.exp(-1 * weighted_pdx ** a)
     K[K < thresh] = 0
     W = K + K.T
     W = np.divide(W, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(sp.coo_matrix(data_small), thresh=thresh,
-                     graphtype='exact',
-                     n_pca=n_pca,
-                     decay=a, knn=k - 1, random_state=42,
-                     use_pygsp=True)
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        sp.coo_matrix(data_small),
+        thresh=thresh,
+        graphtype="exact",
+        n_pca=n_pca,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_allclose(G2.W.toarray(), G.W.toarray())
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(sp.bsr_matrix(pdx), n_pca=None, precomputed='distance',
-                     thresh=thresh,
-                     decay=a, knn=k - 1, random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        sp.bsr_matrix(pdx),
+        n_pca=None,
+        precomputed="distance",
+        thresh=thresh,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(sp.lil_matrix(K), n_pca=None,
-                     precomputed='affinity',
-                     thresh=thresh,
-                     random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        sp.lil_matrix(K),
+        n_pca=None,
+        precomputed="affinity",
+        thresh=thresh,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(sp.dok_matrix(W), n_pca=None,
-                     precomputed='adjacency',
-                     random_state=42, use_pygsp=True)
-    assert(G.N == G2.N)
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        sp.dok_matrix(W),
+        n_pca=None,
+        precomputed="adjacency",
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)


@@ -290,38 +321,47 @@ def test_truncated_exact_graph_no_pca():
     k = 3
     a = 13
     n_pca = None
     thresh = 1e-4
-    data_small = data[np.random.choice(
-        len(data), len(data) // 10, replace=False)]
-    pdx = squareform(pdist(data_small, metric='euclidean'))
+    data_small = data[np.random.choice(len(data), len(data) // 10, replace=False)]
+    pdx = squareform(pdist(data_small, metric="euclidean"))
     knn_dist = np.partition(pdx, k, axis=1)[:, :k]
     epsilon = np.max(knn_dist, axis=1)
     weighted_pdx = (pdx.T / epsilon).T
-    K = np.exp(-1 * weighted_pdx**a)
+    K = np.exp(-1 * weighted_pdx ** a)
     K[K < thresh] = 0
     W = K + K.T
     W = np.divide(W, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data_small, thresh=thresh,
-                     graphtype='exact',
-                     n_pca=n_pca,
-                     decay=a, knn=k - 1, random_state=42,
-                     use_pygsp=True)
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        data_small,
+        thresh=thresh,
+        graphtype="exact",
+        n_pca=n_pca,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    G2 = build_graph(sp.csr_matrix(data_small), thresh=thresh,
-                     graphtype='exact',
-                     n_pca=n_pca,
-                     decay=a, knn=k - 1, random_state=42,
-                     use_pygsp=True)
-    assert(G.N == G2.N)
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    G2 = build_graph(
+        sp.csr_matrix(data_small),
+        thresh=thresh,
+        graphtype="exact",
+        n_pca=n_pca,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G.W != G2.W).nnz == 0)
-    assert((G2.W != G.W).sum() == 0)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
+    assert (G.W != G2.W).nnz == 0
+    assert (G2.W != G.W).sum() == 0
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)


@@ -329,38 +369,48 @@ def test_exact_graph_fixed_bandwidth():
     knn = None
     bandwidth = 2
     n_pca = 20
-    pca = PCA(n_pca, svd_solver='randomized', random_state=42).fit(data)
+    pca = PCA(n_pca, svd_solver="randomized", random_state=42).fit(data)
     data_nu = pca.transform(data)
-    pdx = squareform(pdist(data_nu, metric='euclidean'))
-    K = np.exp(-1 * (pdx / bandwidth)**decay)
+    pdx = squareform(pdist(data_nu, metric="euclidean"))
+    K = np.exp(-1 * (pdx / bandwidth) ** decay)
     K = K + K.T
     W = np.divide(K, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data, n_pca=n_pca,
-                     graphtype='exact', knn=knn,
-                     decay=decay, bandwidth=bandwidth,
-                     random_state=42,
-                     thresh=0,
-                     use_pygsp=True)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        data,
+        n_pca=n_pca,
+        graphtype="exact",
+        knn=knn,
+        decay=decay,
+        bandwidth=bandwidth,
+        random_state=42,
+        thresh=0,
+        use_pygsp=True,
+    )
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    assert G.N == G2.N
     np.testing.assert_allclose(G.dw, G2.dw)
     np.testing.assert_allclose((G2.W - G.W).data, 0, atol=1e-14)
     bandwidth = np.random.gamma(5, 0.5, len(data))
-    K = np.exp(-1 * (pdx.T / bandwidth).T**decay)
+    K = np.exp(-1 * (pdx.T / bandwidth).T ** decay)
     K = K + K.T
     W = np.divide(K, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data, n_pca=n_pca,
-                     graphtype='exact', knn=knn,
-                     decay=decay, bandwidth=bandwidth,
-                     random_state=42,
-                     thresh=0,
-                     use_pygsp=True)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        data,
+        n_pca=n_pca,
+        graphtype="exact",
+        knn=knn,
+        decay=decay,
+        bandwidth=bandwidth,
+        random_state=42,
+        thresh=0,
+        use_pygsp=True,
+    )
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    assert G.N == G2.N
     np.testing.assert_allclose(G.dw, G2.dw)
     np.testing.assert_allclose((G2.W - G.W).data, 0, atol=1e-14)


@@ -371,39 +421,49 @@ def test_exact_graph_callable_bandwidth():
     decay = 2
     knn = 5
     bandwidth = lambda x: 2
     n_pca = 20
     thresh = 1e-4
-    pca = PCA(n_pca, svd_solver='randomized', random_state=42).fit(data)
+    pca = PCA(n_pca, svd_solver="randomized", random_state=42).fit(data)
     data_nu = pca.transform(data)
-    pdx = squareform(pdist(data_nu, metric='euclidean'))
-    K = np.exp(-1 * (pdx / bandwidth(pdx))**decay)
+    pdx = squareform(pdist(data_nu, metric="euclidean"))
+    K = np.exp(-1 * (pdx / bandwidth(pdx)) ** decay)
     K[K < thresh] = 0
     K = K + K.T
     W = np.divide(K, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data, n_pca=n_pca, knn=knn - 1,
-                     decay=decay, bandwidth=bandwidth,
-                     random_state=42,
-                     thresh=thresh,
-                     use_pygsp=True)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        data,
+        n_pca=n_pca,
+        knn=knn - 1,
+        decay=decay,
+        bandwidth=bandwidth,
+        random_state=42,
+        thresh=thresh,
+        use_pygsp=True,
+    )
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G2.W != G.W).sum() == 0)
-    assert((G.W != G2.W).nnz == 0)
+    assert (G2.W != G.W).sum() == 0
+    assert (G.W != G2.W).nnz == 0
     bandwidth = lambda x: np.percentile(x, 10, axis=1)
-    K = np.exp(-1 * (pdx / bandwidth(pdx))**decay)
+    K = np.exp(-1 * (pdx / bandwidth(pdx)) ** decay)
     K[K < thresh] = 0
     K = K + K.T
     W = np.divide(K, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data, n_pca=n_pca, knn=knn - 1,
-                     decay=decay, bandwidth=bandwidth,
-                     random_state=42,
-                     thresh=thresh,
-                     use_pygsp=True)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        data,
+        n_pca=n_pca,
+        knn=knn - 1,
+        decay=decay,
+        bandwidth=bandwidth,
+        random_state=42,
+        thresh=thresh,
+        use_pygsp=True,
+    )
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    assert G.N == G2.N
     np.testing.assert_allclose(G.dw, G2.dw)
     np.testing.assert_allclose((G2.W - G.W).data, 0, atol=1e-14)
@@ -412,46 +472,78 @@ def test_exact_graph_callable_bandwidth():

 #####################################################
 # Check anisotropy
 #####################################################

+
 def test_exact_graph_anisotropy():
     k = 3
     a = 13
     n_pca = 20
     anisotropy = 0.9
-    data_small = data[np.random.choice(
-        len(data), len(data) // 2, replace=False)]
-    pca = PCA(n_pca, svd_solver='randomized', random_state=42).fit(data_small)
+    data_small = data[np.random.choice(len(data), len(data) // 2, replace=False)]
+    pca = PCA(n_pca, svd_solver="randomized", random_state=42).fit(data_small)
     data_small_nu = pca.transform(data_small)
-    pdx = squareform(pdist(data_small_nu, metric='euclidean'))
+    pdx = squareform(pdist(data_small_nu, metric="euclidean"))
     knn_dist = np.partition(pdx, k, axis=1)[:, :k]
     epsilon = np.max(knn_dist, axis=1)
     weighted_pdx = (pdx.T / epsilon).T
-    K = np.exp(-1 * weighted_pdx**a)
+    K = np.exp(-1 * weighted_pdx ** a)
     K = K + K.T
     K = np.divide(K, 2)
     d = K.sum(1)
     W = K / (np.outer(d, d) ** anisotropy)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data_small, thresh=0, n_pca=n_pca,
-                     decay=a, knn=k - 1, random_state=42,
-                     use_pygsp=True, anisotropy=anisotropy)
-    assert(isinstance(G2, graphtools.graphs.TraditionalGraph))
-    assert(G.N == G2.N)
+    G2 = build_graph(
+        data_small,
+        thresh=0,
+        n_pca=n_pca,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+        anisotropy=anisotropy,
+    )
+    assert isinstance(G2, graphtools.graphs.TraditionalGraph)
+    assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert((G2.W != G.W).sum() == 0)
-    assert((G.W != G2.W).nnz == 0)
-    assert_raises(ValueError, build_graph,
-                  data_small, thresh=0, n_pca=n_pca,
-                  decay=a, knn=k - 1, random_state=42,
-                  use_pygsp=True, anisotropy=-1)
-    assert_raises(ValueError, build_graph,
-                  data_small, thresh=0, n_pca=n_pca,
-                  decay=a, knn=k - 1, random_state=42,
-                  use_pygsp=True, anisotropy=2)
-    assert_raises(ValueError, build_graph,
-                  data_small, thresh=0, n_pca=n_pca,
-                  decay=a, knn=k - 1, random_state=42,
-                  use_pygsp=True, anisotropy='invalid')
+    assert (G2.W != G.W).sum() == 0
+    assert (G.W != G2.W).nnz == 0
+    assert_raises(
+        ValueError,
+        build_graph,
+        data_small,
+        thresh=0,
+        n_pca=n_pca,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+        anisotropy=-1,
+    )
+    assert_raises(
+        ValueError,
+        build_graph,
+        data_small,
+        thresh=0,
+        n_pca=n_pca,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+        anisotropy=2,
+    )
+    assert_raises(
+        ValueError,
+        build_graph,
+        data_small,
+        thresh=0,
+        n_pca=n_pca,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+        anisotropy="invalid",
+    )
+

 #####################################################
 # Check extra functionality
@@ -459,8 +551,7 @@ def test_exact_graph_anisotropy():


 def test_shortest_path_affinity():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=15)
     D = -1 * np.where(G.K != 0, np.log(np.where(G.K != 0, G.K, np.nan)), 0)
     P = graph_shortest_path(D)
     # sklearn returns 0 if no path exists
@@ -468,57 +559,52 @@ def test_shortest_path_affinity():
     P[np.where(P == 0)] = np.inf
     # diagonal should actually be zero
     np.fill_diagonal(P, 0)
-    np.testing.assert_allclose(P, G.shortest_path(distance='affinity'))
+    np.testing.assert_allclose(P, G.shortest_path(distance="affinity"))
     np.testing.assert_allclose(P, G.shortest_path())
 
 
 def test_shortest_path_affinity_precomputed():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=15)
-    G = graphtools.Graph(G.K, precomputed='affinity')
+    G = graphtools.Graph(G.K, precomputed="affinity")
     D = -1 * np.where(G.K != 0, np.log(np.where(G.K != 0, G.K, np.nan)), 0)
     P = graph_shortest_path(D)
     # sklearn returns 0 if no path exists
     P[np.where(P == 0)] = np.inf
     # diagonal should actually be zero
     np.fill_diagonal(P, 0)
-    np.testing.assert_allclose(P, G.shortest_path(distance='affinity'))
+    np.testing.assert_allclose(P, G.shortest_path(distance="affinity"))
     np.testing.assert_allclose(P, G.shortest_path())
 
 
 @raises(NotImplementedError)
 def test_shortest_path_decay_constant():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=15)
-    G.shortest_path(distance='constant')
+    G.shortest_path(distance="constant")
 
 
 @raises(NotImplementedError)
 def test_shortest_path_precomputed_decay_constant():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=15)
-    G = graphtools.Graph(G.K, precomputed='affinity')
-    G.shortest_path(distance='constant')
+    G = graphtools.Graph(G.K, precomputed="affinity")
+    G.shortest_path(distance="constant")
 
 
 @raises(NotImplementedError)
 def test_shortest_path_decay_data():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=15)
-    G.shortest_path(distance='data')
+    G.shortest_path(distance="data")
 
 
 @raises(ValueError)
 def test_shortest_path_precomputed_data():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=15)
-    G = graphtools.Graph(G.K, precomputed='affinity')
-    G.shortest_path(distance='data')
+    G = graphtools.Graph(G.K, precomputed="affinity")
+    G.shortest_path(distance="data")
 
 
 #####################################################
@@ -529,8 +615,8 @@ def test_shortest_path_precomputed_data():
 
 def test_build_dense_exact_kernel_to_data(**kwargs):
     G = build_graph(data, decay=10, thresh=0)
     n = G.data.shape[0]
-    K = G.build_kernel_to_data(data[:n // 2, :])
-    assert(K.shape == (n // 2, n))
+    K = G.build_kernel_to_data(data[: n // 2, :])
+    assert K.shape == (n // 2, n)
     K = G.build_kernel_to_data(G.data, knn=G.knn + 1)
     np.testing.assert_equal(G.kernel - (K + K.T) / 2, 0)
     K = G.build_kernel_to_data(G.data_nu, knn=G.knn + 1)
@@ -540,8 +626,8 @@ def test_build_dense_exact_kernel_to_data(**kwargs):
 def test_build_dense_exact_callable_bw_kernel_to_data(**kwargs):
     G = build_graph(data, decay=10, thresh=0, bandwidth=lambda x: x.mean(1))
     n = G.data.shape[0]
-    K = G.build_kernel_to_data(data[:n // 2, :])
-    assert(K.shape == (n // 2, n))
+    K = G.build_kernel_to_data(data[: n // 2, :])
+    assert K.shape == (n // 2, n)
     K = G.build_kernel_to_data(G.data, knn=G.knn + 1)
     np.testing.assert_equal(G.kernel - (K + K.T) / 2, 0)
     K = G.build_kernel_to_data(G.data_nu, knn=G.knn + 1)
@@ -551,8 +637,8 @@ def test_build_dense_exact_callable_bw_kernel_to_data(**kwargs):
 def test_build_sparse_exact_kernel_to_data(**kwargs):
     G = build_graph(data, decay=10, thresh=0, sparse=True)
     n = G.data.shape[0]
-    K = G.build_kernel_to_data(data[:n // 2, :])
-    assert(K.shape == (n // 2, n))
+    K = G.build_kernel_to_data(data[: n // 2, :])
+    assert K.shape == (n // 2, n)
     K = G.build_kernel_to_data(G.data, knn=G.knn + 1)
     np.testing.assert_equal(G.kernel - (K + K.T) / 2, 0)
     K = G.build_kernel_to_data(G.data_nu, knn=G.knn + 1)
@@ -564,14 +650,15 @@ def test_exact_interpolate():
     assert_raises(ValueError, G.interpolate, data)
     pca_data = PCA(2).fit_transform(data)
     transitions = G.extend_to_data(data)
-    assert(np.all(G.interpolate(pca_data, Y=data) ==
-                  G.interpolate(pca_data, transitions=transitions)))
+    assert np.all(
+        G.interpolate(pca_data, Y=data)
+        == G.interpolate(pca_data, transitions=transitions)
+    )
 
 
 @raises(ValueError)
 def test_precomputed_interpolate():
-    G = build_graph(squareform(pdist(data)), n_pca=None,
-                    precomputed='distance')
+    G = build_graph(squareform(pdist(data)), n_pca=None, precomputed="distance")
     G.build_kernel_to_data(data)
@@ -588,24 +675,25 @@ def test_verbose():
 
 def test_set_params():
     G = build_graph(data, decay=10, thresh=0)
-    assert G.get_params() == {'n_pca': 20,
-                              'random_state': 42,
-                              'kernel_symm': '+',
-                              'theta': None,
-                              'knn': 3,
-                              'anisotropy': 0,
-                              'decay': 10,
-                              'bandwidth': None,
-                              'bandwidth_scale': 1,
-                              'distance': 'euclidean',
-                              'precomputed': None}
+    assert G.get_params() == {
+        "n_pca": 20,
+        "random_state": 42,
+        "kernel_symm": "+",
+        "theta": None,
+        "knn": 3,
+        "anisotropy": 0,
+        "decay": 10,
+        "bandwidth": None,
+        "bandwidth_scale": 1,
+        "distance": "euclidean",
+        "precomputed": None,
+    }
     assert_raises(ValueError, G.set_params, knn=15)
     assert_raises(ValueError, G.set_params, decay=15)
-    assert_raises(ValueError, G.set_params, distance='manhattan')
-    assert_raises(ValueError, G.set_params, precomputed='distance')
+    assert_raises(ValueError, G.set_params, distance="manhattan")
+    assert_raises(ValueError, G.set_params, precomputed="distance")
     assert_raises(ValueError, G.set_params, bandwidth=5)
     assert_raises(ValueError, G.set_params, bandwidth_scale=5)
-    G.set_params(knn=G.knn,
-                 decay=G.decay,
-                 distance=G.distance,
-                 precomputed=G.precomputed)
+    G.set_params(
+        knn=G.knn, decay=G.decay, distance=G.distance, precomputed=G.precomputed
+    )
diff --git a/test/test_knn.py b/test/test_knn.py
index 52e0223..3640e5b 100644
--- a/test/test_knn.py
+++ b/test/test_knn.py
@@ -1,6 +1,8 @@
 from __future__ import print_function, division
 from sklearn.utils.graph import graph_shortest_path
 from scipy.spatial.distance import pdist, squareform
+from sklearn.utils.testing import assert_raise_message, assert_warns_message
+import warnings
 from load_tests import (
     graphtools,
     np,
@@ -24,67 +26,52 @@
 
 @raises(ValueError)
 def test_build_knn_with_exact_alpha():
-    build_graph(data, graphtype='knn', decay=10, thresh=0)
+    build_graph(data, graphtype="knn", decay=10, thresh=0)
 
 
 @raises(ValueError)
 def test_build_knn_with_precomputed():
-    build_graph(data, n_pca=None, graphtype='knn', precomputed='distance')
+    build_graph(data, n_pca=None, graphtype="knn", precomputed="distance")
 
 
 @raises(ValueError)
 def test_build_knn_with_sample_idx():
-    build_graph(data, graphtype='knn', sample_idx=np.arange(len(data)))
+    build_graph(data, graphtype="knn", sample_idx=np.arange(len(data)))
 
 
 @warns(RuntimeWarning)
 def test_duplicate_data():
-    build_graph(np.vstack([data, data[:10]]),
-                n_pca=20,
-                decay=10,
-                thresh=1e-4)
+    build_graph(np.vstack([data, data[:10]]), n_pca=20, decay=10, thresh=1e-4)
 
 
 @warns(RuntimeWarning)
 def test_duplicate_data_many():
-    build_graph(np.vstack([data, data[:21]]),
-                n_pca=20,
-                decay=10,
-                thresh=1e-4)
+    build_graph(np.vstack([data, data[:21]]), n_pca=20, decay=10, thresh=1e-4)
 
 
 @warns(UserWarning)
 def test_balltree_cosine():
-    build_graph(data,
-                n_pca=20,
-                decay=10,
-                distance='cosine',
-                thresh=1e-4)
+    build_graph(data, n_pca=20, decay=10, distance="cosine", thresh=1e-4)
 
 
 @warns(UserWarning)
 def test_k_too_large():
-    build_graph(data,
-                n_pca=20,
-                decay=10,
-                knn=len(data) - 1,
-                thresh=1e-4)
+    build_graph(data, n_pca=20, decay=10, knn=len(data) - 1, thresh=1e-4)
+
+
+@warns(UserWarning)
+def test_knnmax_too_large():
+    build_graph(data, n_pca=20, decay=10, knn=10, knn_max=9, thresh=1e-4)
 
 
 @warns(UserWarning)
 def test_bandwidth_no_decay():
-    build_graph(data,
-                n_pca=20,
-                decay=None,
-                bandwidth=3,
-                thresh=1e-4)
+    build_graph(data, n_pca=20, decay=None, bandwidth=3, thresh=1e-4)
 
 
 @raises(ValueError)
 def test_knn_no_knn_no_bandwidth():
-    build_graph(data, graphtype='knn',
-                knn=None, bandwidth=None,
-                thresh=1e-4)
+    build_graph(data, graphtype="knn", knn=None, bandwidth=None, thresh=1e-4)
 
 
 #####################################################
@@ -95,9 +82,9 @@ def test_knn_no_knn_no_bandwidth():
 def test_knn_graph():
     k = 3
     n_pca = 20
-    pca = PCA(n_pca, svd_solver='randomized', random_state=42).fit(data)
+    pca = PCA(n_pca, svd_solver="randomized", random_state=42).fit(data)
     data_nu = pca.transform(data)
-    pdx = squareform(pdist(data_nu, metric='euclidean'))
+    pdx = squareform(pdist(data_nu, metric="euclidean"))
     knn_dist = np.partition(pdx, k, axis=1)[:, :k]
     epsilon = np.max(knn_dist, axis=1)
     K = np.empty_like(pdx)
@@ -109,22 +96,40 @@ def test_knn_graph():
     W = np.divide(K, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data, n_pca=n_pca,
-                     decay=None, knn=k - 1, random_state=42,
-                     use_pygsp=True)
+    G2 = build_graph(
+        data, n_pca=n_pca, decay=None, knn=k - 1, random_state=42, use_pygsp=True
+    )
     assert G.N == G2.N
     np.testing.assert_equal(G.dw, G2.dw)
-    assert (G.W != G2.W).nnz == 0
-    assert (G2.W != G.W).sum() == 0
+    assert (G.W - G2.W).nnz == 0
+    assert (G2.W - G.W).sum() == 0
     assert isinstance(G2, graphtools.graphs.kNNGraph)
+    K2 = G2.build_kernel_to_data(G2.data_nu, knn=k)
+    K2 = (K2 + K2.T) / 2
+    assert (G2.K - K2).nnz == 0
+    assert (
+        G2.build_kernel_to_data(G2.data_nu, knn=data.shape[0]).nnz
+        == data.shape[0] * data.shape[0]
+    )
+    assert_warns_message(
+        UserWarning,
+        "Cannot set knn ({}) to be greater than "
Setting knn={}".format( + data.shape[0] + 1, data.shape[0], data.shape[0] + ), + G2.build_kernel_to_data, + Y=G2.data_nu, + knn=data.shape[0] + 1, + ) + def test_knn_graph_sparse(): k = 3 n_pca = 20 pca = TruncatedSVD(n_pca, random_state=42).fit(data) data_nu = pca.transform(data) - pdx = squareform(pdist(data_nu, metric='euclidean')) + pdx = squareform(pdist(data_nu, metric="euclidean")) knn_dist = np.partition(pdx, k, axis=1)[:, :k] epsilon = np.max(knn_dist, axis=1) K = np.empty_like(pdx) @@ -136,9 +141,14 @@ def test_knn_graph_sparse(): W = np.divide(K, 2) np.fill_diagonal(W, 0) G = pygsp.graphs.Graph(W) - G2 = build_graph(sp.coo_matrix(data), n_pca=n_pca, - decay=None, knn=k - 1, random_state=42, - use_pygsp=True) + G2 = build_graph( + sp.coo_matrix(data), + n_pca=n_pca, + decay=None, + knn=k - 1, + random_state=42, + use_pygsp=True, + ) assert G.N == G2.N np.testing.assert_allclose(G2.W.toarray(), G.W.toarray()) assert isinstance(G2, graphtools.graphs.kNNGraph) @@ -150,24 +160,83 @@ def test_sparse_alpha_knn_graph(): a = 0.45 thresh = 0.01 bandwidth_scale = 1.3 - pdx = squareform(pdist(data, metric='euclidean')) + pdx = squareform(pdist(data, metric="euclidean")) knn_dist = np.partition(pdx, k, axis=1)[:, :k] epsilon = np.max(knn_dist, axis=1) * bandwidth_scale pdx = (pdx.T / epsilon).T - K = np.exp(-1 * pdx**a) + K = np.exp(-1 * pdx ** a) K = K + K.T W = np.divide(K, 2) np.fill_diagonal(W, 0) G = pygsp.graphs.Graph(W) - G2 = build_graph(data, n_pca=None, # n_pca, - decay=a, knn=k - 1, thresh=thresh, - bandwidth_scale=bandwidth_scale, - random_state=42, use_pygsp=True) + G2 = build_graph( + data, + n_pca=None, # n_pca, + decay=a, + knn=k - 1, + thresh=thresh, + bandwidth_scale=bandwidth_scale, + random_state=42, + use_pygsp=True, + ) assert np.abs(G.W - G2.W).max() < thresh assert G.N == G2.N assert isinstance(G2, graphtools.graphs.kNNGraph) +def test_knnmax(): + data = datasets.make_swiss_roll()[0] + k = 5 + k_max = 10 + a = 0.45 + thresh = 0 + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "K should be symmetric", RuntimeWarning) + G = build_graph( + data, + n_pca=None, # n_pca, + decay=a, + knn=k - 1, + knn_max=k_max - 1, + thresh=0, + random_state=42, + kernel_symm=None, + ) + assert np.all((G.K > 0).sum(axis=1) == k_max) + + pdx = squareform(pdist(data, metric="euclidean")) + knn_dist = np.partition(pdx, k, axis=1)[:, :k] + knn_max_dist = np.max(np.partition(pdx, k_max, axis=1)[:, :k_max], axis=1) + epsilon = np.max(knn_dist, axis=1) + pdx_scale = (pdx.T / epsilon).T + K = np.where(pdx <= knn_max_dist[:, None], np.exp(-1 * pdx_scale ** a), 0) + K = K + K.T + W = np.divide(K, 2) + np.fill_diagonal(W, 0) + G = pygsp.graphs.Graph(W) + G2 = build_graph( + data, + n_pca=None, # n_pca, + decay=a, + knn=k - 1, + knn_max=k_max - 1, + thresh=0, + random_state=42, + use_pygsp=True, + ) + assert isinstance(G2, graphtools.graphs.kNNGraph) + assert G.N == G2.N + assert np.all(G.dw == G2.dw) + assert (G.W - G2.W).nnz == 0 + + +def test_thresh_small(): + data = datasets.make_swiss_roll()[0] + G = graphtools.Graph(data, thresh=1e-30) + assert G.thresh == np.finfo("float").eps + + def test_knn_graph_fixed_bandwidth(): k = None decay = 5 @@ -175,46 +244,57 @@ def test_knn_graph_fixed_bandwidth(): bandwidth_scale = 1.3 n_pca = 20 thresh = 1e-4 - pca = PCA(n_pca, svd_solver='randomized', random_state=42).fit(data) + pca = PCA(n_pca, svd_solver="randomized", random_state=42).fit(data) data_nu = pca.transform(data) - pdx = squareform(pdist(data_nu, metric='euclidean')) + pdx = 
+    pdx = squareform(pdist(data_nu, metric="euclidean"))
     K = np.exp(-1 * np.power(pdx / (bandwidth * bandwidth_scale), decay))
     K[K < thresh] = 0
     K = K + K.T
     W = np.divide(K, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data, n_pca=n_pca,
-                     decay=decay, bandwidth=bandwidth,
-                     bandwidth_scale=bandwidth_scale,
-                     knn=k, random_state=42,
-                     thresh=thresh,
-                     use_pygsp=True)
+    G2 = build_graph(
+        data,
+        n_pca=n_pca,
+        decay=decay,
+        bandwidth=bandwidth,
+        bandwidth_scale=bandwidth_scale,
+        knn=k,
+        random_state=42,
+        thresh=thresh,
+        search_multiplier=2,
+        use_pygsp=True,
+    )
     assert isinstance(G2, graphtools.graphs.kNNGraph)
     np.testing.assert_array_equal(G.N, G2.N)
     np.testing.assert_array_equal(G.d, G2.d)
     np.testing.assert_allclose(
-        (G.W - G2.W).data,
-        np.zeros_like((G.W - G2.W).data), atol=1e-14)
+        (G.W - G2.W).data, np.zeros_like((G.W - G2.W).data), atol=1e-14
+    )
     bandwidth = np.random.gamma(20, 0.5, len(data))
-    K = np.exp(-1 * (pdx.T / (bandwidth * bandwidth_scale)).T**decay)
+    K = np.exp(-1 * (pdx.T / (bandwidth * bandwidth_scale)).T ** decay)
     K[K < thresh] = 0
     K = K + K.T
     W = np.divide(K, 2)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data, n_pca=n_pca,
-                     decay=decay, bandwidth=bandwidth,
-                     bandwidth_scale=bandwidth_scale,
-                     knn=k, random_state=42,
-                     thresh=thresh,
-                     use_pygsp=True)
+    G2 = build_graph(
+        data,
+        n_pca=n_pca,
+        decay=decay,
+        bandwidth=bandwidth,
+        bandwidth_scale=bandwidth_scale,
+        knn=k,
+        random_state=42,
+        thresh=thresh,
+        use_pygsp=True,
+    )
     assert isinstance(G2, graphtools.graphs.kNNGraph)
     np.testing.assert_array_equal(G.N, G2.N)
     np.testing.assert_allclose(G.dw, G2.dw, atol=1e-14)
     np.testing.assert_allclose(
-        (G.W - G2.W).data,
-        np.zeros_like((G.W - G2.W).data), atol=1e-14)
+        (G.W - G2.W).data, np.zeros_like((G.W - G2.W).data), atol=1e-14
+    )
 
 
 @raises(NotImplementedError)
@@ -224,38 +304,50 @@ def test_knn_graph_callable_bandwidth():
     bandwidth = lambda x: 2
     n_pca = 20
     thresh = 1e-4
-    build_graph(data, n_pca=n_pca, knn=k - 1,
-                decay=decay, bandwidth=bandwidth,
-                random_state=42,
-                thresh=thresh, graphtype='knn')
+    build_graph(
+        data,
+        n_pca=n_pca,
+        knn=k - 1,
+        decay=decay,
+        bandwidth=bandwidth,
+        random_state=42,
+        thresh=thresh,
+        graphtype="knn",
+    )
 
 
 @warns(UserWarning)
 def test_knn_graph_sparse_no_pca():
-    build_graph(sp.coo_matrix(data), n_pca=None,  # n_pca,
-                decay=10, knn=3, thresh=1e-4,
-                random_state=42, use_pygsp=True)
+    build_graph(
+        sp.coo_matrix(data),
+        n_pca=None,  # n_pca,
+        decay=10,
+        knn=3,
+        thresh=1e-4,
+        random_state=42,
+        use_pygsp=True,
+    )
 
 
 #####################################################
 # Check anisotropy
 #####################################################
 
+
 def test_knn_graph_anisotropy():
     k = 3
     a = 13
     n_pca = 20
     anisotropy = 0.9
     thresh = 1e-4
-    data_small = data[np.random.choice(
-        len(data), len(data) // 2, replace=False)]
-    pca = PCA(n_pca, svd_solver='randomized', random_state=42).fit(data_small)
+    data_small = data[np.random.choice(len(data), len(data) // 2, replace=False)]
+    pca = PCA(n_pca, svd_solver="randomized", random_state=42).fit(data_small)
    data_small_nu = pca.transform(data_small)
-    pdx = squareform(pdist(data_small_nu, metric='euclidean'))
+    pdx = squareform(pdist(data_small_nu, metric="euclidean"))
     knn_dist = np.partition(pdx, k, axis=1)[:, :k]
     epsilon = np.max(knn_dist, axis=1)
     weighted_pdx = (pdx.T / epsilon).T
-    K = np.exp(-1 * weighted_pdx**a)
+    K = np.exp(-1 * weighted_pdx ** a)
     K[K < thresh] = 0
     K = K + K.T
     K = np.divide(K, 2)
@@ -263,10 +355,16 @@ def test_knn_graph_anisotropy():
     W = K / (np.outer(d, d) ** anisotropy)
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = build_graph(data_small, n_pca=n_pca,
-                     thresh=thresh,
-                     decay=a, knn=k - 1, random_state=42,
-                     use_pygsp=True, anisotropy=anisotropy)
+    G2 = build_graph(
+        data_small,
+        n_pca=n_pca,
+        thresh=thresh,
+        decay=a,
+        knn=k - 1,
+        random_state=42,
+        use_pygsp=True,
+        anisotropy=anisotropy,
+    )
     assert isinstance(G2, graphtools.graphs.kNNGraph)
     assert G.N == G2.N
     np.testing.assert_allclose(G.dw, G2.dw, atol=1e-14, rtol=1e-14)
@@ -281,7 +379,7 @@ def test_knn_graph_anisotropy():
 def test_build_dense_knn_kernel_to_data():
     G = build_graph(data, decay=None)
     n = G.data.shape[0]
-    K = G.build_kernel_to_data(data[:n // 2, :], knn=G.knn + 1)
+    K = G.build_kernel_to_data(data[: n // 2, :], knn=G.knn + 1)
     assert K.shape == (n // 2, n)
     K = G.build_kernel_to_data(G.data, knn=G.knn + 1)
     assert (G.kernel - (K + K.T) / 2).nnz == 0
@@ -292,7 +390,7 @@ def test_build_dense_knn_kernel_to_data():
 def test_build_sparse_knn_kernel_to_data():
     G = build_graph(data, decay=None, sparse=True)
     n = G.data.shape[0]
-    K = G.build_kernel_to_data(data[:n // 2, :], knn=G.knn + 1)
+    K = G.build_kernel_to_data(data[: n // 2, :], knn=G.knn + 1)
     assert K.shape == (n // 2, n)
     K = G.build_kernel_to_data(G.data, knn=G.knn + 1)
     assert (G.kernel - (K + K.T) / 2).nnz == 0
@@ -305,8 +403,41 @@ def test_knn_interpolate():
     assert_raises(ValueError, G.interpolate, data)
     pca_data = PCA(2).fit_transform(data)
     transitions = G.extend_to_data(data)
-    np.testing.assert_equal(G.interpolate(pca_data, Y=data), G.interpolate(
-        pca_data, transitions=transitions))
+    np.testing.assert_equal(
+        G.interpolate(pca_data, Y=data),
+        G.interpolate(pca_data, transitions=transitions),
+    )
+
+
+def test_knn_interpolate_wrong_shape():
+    G = build_graph(data, n_pca=10, decay=None)
+    transitions = assert_raise_message(
+        ValueError,
+        "Expected a 2D matrix. Y has shape ({},)".format(data.shape[0]),
+        G.extend_to_data,
+        data[:, 0],
+    )
+    transitions = assert_raise_message(
+        ValueError,
+        "Expected a 2D matrix. Y has shape ({}, {}, 1)".format(
+            data.shape[0], data.shape[1]
+        ),
+        G.extend_to_data,
+        data[:, :, None],
+    )
+    transitions = assert_raise_message(
+        ValueError,
+        "Y must be of shape either (n, 64) or (n, 10)",
+        G.extend_to_data,
+        data[:, : data.shape[1] // 2],
+    )
+    G = build_graph(data, n_pca=None, decay=None)
+    transitions = assert_raise_message(
+        ValueError,
+        "Y must be of shape (n, 64)",
+        G.extend_to_data,
+        data[:, : data.shape[1] // 2],
+    )
 
 
 #################################################
@@ -315,34 +446,31 @@ def test_knn_interpolate():
 
 
 def test_shortest_path_constant():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=None)
     P = graph_shortest_path(G.K)
     # sklearn returns 0 if no path exists
     P[np.where(P == 0)] = np.inf
     # diagonal should actually be zero
     np.fill_diagonal(P, 0)
-    np.testing.assert_equal(P, G.shortest_path(distance='constant'))
+    np.testing.assert_equal(P, G.shortest_path(distance="constant"))
 
 
 def test_shortest_path_precomputed_constant():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=None)
-    G = graphtools.Graph(G.K, precomputed='affinity')
+    G = graphtools.Graph(G.K, precomputed="affinity")
     P = graph_shortest_path(G.K)
     # sklearn returns 0 if no path exists
     P[np.where(P == 0)] = np.inf
     # diagonal should actually be zero
     np.fill_diagonal(P, 0)
-    np.testing.assert_equal(P, G.shortest_path(distance='constant'))
+    np.testing.assert_equal(P, G.shortest_path(distance="constant"))
     np.testing.assert_equal(P, G.shortest_path())
 
 
 def test_shortest_path_data():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=None)
     D = squareform(pdist(G.data_nu)) * np.where(G.K.toarray() > 0, 1, 0)
     P = graph_shortest_path(D)
@@ -350,42 +478,38 @@ def test_shortest_path_data():
     P[np.where(P == 0)] = np.inf
     # diagonal should actually be zero
     np.fill_diagonal(P, 0)
-    np.testing.assert_allclose(P, G.shortest_path(distance='data'))
+    np.testing.assert_allclose(P, G.shortest_path(distance="data"))
     np.testing.assert_allclose(P, G.shortest_path())
 
 
 @raises(ValueError)
 def test_shortest_path_no_decay_affinity():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=None)
-    G.shortest_path(distance='affinity')
+    G.shortest_path(distance="affinity")
 
 
 @raises(ValueError)
 def test_shortest_path_precomputed_no_decay_affinity():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=None)
-    G = graphtools.Graph(G.K, precomputed='affinity')
-    G.shortest_path(distance='affinity')
+    G = graphtools.Graph(G.K, precomputed="affinity")
+    G.shortest_path(distance="affinity")
 
 
 @raises(ValueError)
 def test_shortest_path_precomputed_no_decay_data():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=None)
-    G = graphtools.Graph(G.K, precomputed='affinity')
-    G.shortest_path(distance='data')
+    G = graphtools.Graph(G.K, precomputed="affinity")
+    G.shortest_path(distance="data")
 
 
 @raises(ValueError)
 def test_shortest_path_invalid():
-    data_small = data[np.random.choice(
-        len(data), len(data) // 4, replace=False)]
+    data_small = data[np.random.choice(len(data), len(data) // 4, replace=False)]
     G = build_graph(data_small, knn=5, decay=None)
-    G.shortest_path(distance='invalid')
+    G.shortest_path(distance="invalid")
 
 
 ####################
@@ -402,20 +526,21 @@ def test_verbose():
 
 def test_set_params():
     G = build_graph(data, decay=None)
     assert G.get_params() == {
-        'n_pca': 20,
-        'random_state': 42,
-        'kernel_symm': '+',
-        'theta': None,
-        'anisotropy': 0,
-        'knn': 3,
-        'decay': None,
-        'bandwidth': None,
-        'bandwidth_scale': 1,
-        'distance': 'euclidean',
-        'thresh': 0,
-        'n_jobs': -1,
-        'verbose': 0
-    }
+        "n_pca": 20,
+        "random_state": 42,
+        "kernel_symm": "+",
+        "theta": None,
+        "anisotropy": 0,
+        "knn": 3,
+        "knn_max": None,
+        "decay": None,
+        "bandwidth": None,
+        "bandwidth_scale": 1,
+        "distance": "euclidean",
+        "thresh": 0,
+        "n_jobs": -1,
+        "verbose": 0,
+    }, G.get_params()
     G.set_params(n_jobs=4)
     assert G.n_jobs == 4
     assert G.knn_tree.n_jobs == 4
@@ -425,18 +550,21 @@ def test_set_params():
     assert G.verbose == 2
     G.set_params(verbose=0)
     assert_raises(ValueError, G.set_params, knn=15)
+    assert_raises(ValueError, G.set_params, knn_max=15)
     assert_raises(ValueError, G.set_params, decay=10)
-    assert_raises(ValueError, G.set_params, distance='manhattan')
+    assert_raises(ValueError, G.set_params, distance="manhattan")
     assert_raises(ValueError, G.set_params, thresh=1e-3)
     assert_raises(ValueError, G.set_params, theta=0.99)
-    assert_raises(ValueError, G.set_params, kernel_symm='*')
+    assert_raises(ValueError, G.set_params, kernel_symm="*")
     assert_raises(ValueError, G.set_params, anisotropy=0.7)
     assert_raises(ValueError, G.set_params, bandwidth=5)
     assert_raises(ValueError, G.set_params, bandwidth_scale=5)
-    G.set_params(knn=G.knn,
-                 decay=G.decay,
-                 thresh=G.thresh,
-                 distance=G.distance,
-                 theta=G.theta,
-                 anisotropy=G.anisotropy,
-                 kernel_symm=G.kernel_symm)
+    G.set_params(
+        knn=G.knn,
+        decay=G.decay,
+        thresh=G.thresh,
+        distance=G.distance,
+        theta=G.theta,
+        anisotropy=G.anisotropy,
+        kernel_symm=G.kernel_symm,
+    )
diff --git a/test/test_landmark.py b/test/test_landmark.py
index 950af06..2e8d8d7 100644
--- a/test/test_landmark.py
+++ b/test/test_landmark.py
@@ -9,7 +9,7 @@
     assert_raises,
     raises,
     warns,
-    generate_swiss_roll
+    generate_swiss_roll,
 )
 import pygsp
@@ -37,15 +37,21 @@ def test_build_landmark_with_too_few_points():
 def test_landmark_exact_graph():
     n_landmark = 100
     # exact graph
-    G = build_graph(data, n_landmark=n_landmark,
-                    thresh=0, n_pca=20,
-                    decay=10, knn=5 - 1, random_state=42)
-    assert(G.landmark_op.shape == (n_landmark, n_landmark))
-    assert(isinstance(G, graphtools.graphs.TraditionalGraph))
-    assert(isinstance(G, graphtools.graphs.LandmarkGraph))
-    assert(G.transitions.shape == (data.shape[0], n_landmark))
-    assert(G.clusters.shape == (data.shape[0],))
-    assert(len(np.unique(G.clusters)) <= n_landmark)
+    G = build_graph(
+        data,
+        n_landmark=n_landmark,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5 - 1,
+        random_state=42,
+    )
+    assert G.landmark_op.shape == (n_landmark, n_landmark)
+    assert isinstance(G, graphtools.graphs.TraditionalGraph)
+    assert isinstance(G, graphtools.graphs.LandmarkGraph)
+    assert G.transitions.shape == (data.shape[0], n_landmark)
+    assert G.clusters.shape == (data.shape[0],)
+    assert len(np.unique(G.clusters)) <= n_landmark
     signal = np.random.normal(0, 1, [n_landmark, 10])
     interpolated_signal = G.interpolate(signal)
     assert interpolated_signal.shape == (data.shape[0], signal.shape[1])
@@ -57,26 +63,33 @@ def test_landmark_exact_graph():
 
 def test_landmark_knn_graph():
     n_landmark = 500
     # knn graph
-    G = build_graph(data, n_landmark=n_landmark, n_pca=20,
-                    decay=None, knn=5 - 1, random_state=42)
-    assert(G.transitions.shape == (data.shape[0], n_landmark))
-    assert(G.landmark_op.shape == (n_landmark, n_landmark))
-    assert(isinstance(G, graphtools.graphs.kNNGraph))
-    assert(isinstance(G, graphtools.graphs.LandmarkGraph))
+    G = build_graph(
+        data, n_landmark=n_landmark, n_pca=20, decay=None, knn=5 - 1, random_state=42
+    )
+    assert G.transitions.shape == (data.shape[0], n_landmark)
+    assert G.landmark_op.shape == (n_landmark, n_landmark)
+    assert isinstance(G, graphtools.graphs.kNNGraph)
+    assert isinstance(G, graphtools.graphs.LandmarkGraph)
 
 
 def test_landmark_mnn_graph():
     n_landmark = 150
     X, sample_idx = generate_swiss_roll()
     # mnn graph
-    G = build_graph(X, n_landmark=n_landmark,
-                    thresh=1e-5, n_pca=None,
-                    decay=10, knn=5 - 1, random_state=42,
-                    sample_idx=sample_idx)
-    assert(G.clusters.shape == (X.shape[0],))
-    assert(G.landmark_op.shape == (n_landmark, n_landmark))
-    assert(isinstance(G, graphtools.graphs.MNNGraph))
-    assert(isinstance(G, graphtools.graphs.LandmarkGraph))
+    G = build_graph(
+        X,
+        n_landmark=n_landmark,
+        thresh=1e-5,
+        n_pca=None,
+        decay=10,
+        knn=5 - 1,
+        random_state=42,
+        sample_idx=sample_idx,
+    )
+    assert G.clusters.shape == (X.shape[0],)
+    assert G.landmark_op.shape == (n_landmark, n_landmark)
+    assert isinstance(G, graphtools.graphs.MNNGraph)
+    assert isinstance(G, graphtools.graphs.LandmarkGraph)
 
 
 #####################################################
@@ -87,40 +100,59 @@ def test_landmark_mnn_graph():
 def test_landmark_exact_pygsp_graph():
     n_landmark = 100
     # exact graph
-    G = build_graph(data, n_landmark=n_landmark,
-                    thresh=0, n_pca=10,
-                    decay=10, knn=3 - 1, random_state=42,
-                    use_pygsp=True)
-    assert(G.landmark_op.shape == (n_landmark, n_landmark))
-    assert(isinstance(G, graphtools.graphs.TraditionalGraph))
-    assert(isinstance(G, graphtools.graphs.LandmarkGraph))
-    assert(isinstance(G, pygsp.graphs.Graph))
+    G = build_graph(
+        data,
+        n_landmark=n_landmark,
+        thresh=0,
+        n_pca=10,
+        decay=10,
+        knn=3 - 1,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.landmark_op.shape == (n_landmark, n_landmark)
+    assert isinstance(G, graphtools.graphs.TraditionalGraph)
+    assert isinstance(G, graphtools.graphs.LandmarkGraph)
+    assert isinstance(G, pygsp.graphs.Graph)
 
 
 def test_landmark_knn_pygsp_graph():
     n_landmark = 500
     # knn graph
-    G = build_graph(data, n_landmark=n_landmark, n_pca=10,
-                    decay=None, knn=3 - 1, random_state=42,
-                    use_pygsp=True)
-    assert(G.landmark_op.shape == (n_landmark, n_landmark))
-    assert(isinstance(G, graphtools.graphs.kNNGraph))
-    assert(isinstance(G, graphtools.graphs.LandmarkGraph))
-    assert(isinstance(G, pygsp.graphs.Graph))
+    G = build_graph(
+        data,
+        n_landmark=n_landmark,
+        n_pca=10,
+        decay=None,
+        knn=3 - 1,
+        random_state=42,
+        use_pygsp=True,
+    )
+    assert G.landmark_op.shape == (n_landmark, n_landmark)
+    assert isinstance(G, graphtools.graphs.kNNGraph)
+    assert isinstance(G, graphtools.graphs.LandmarkGraph)
+    assert isinstance(G, pygsp.graphs.Graph)
 
 
 def test_landmark_mnn_pygsp_graph():
     n_landmark = 150
     X, sample_idx = generate_swiss_roll()
     # mnn graph
-    G = build_graph(X, n_landmark=n_landmark,
-                    thresh=1e-3, n_pca=None,
-                    decay=10, knn=3 - 1, random_state=42,
-                    sample_idx=sample_idx, use_pygsp=True)
-    assert(G.landmark_op.shape == (n_landmark, n_landmark))
-    assert(isinstance(G, graphtools.graphs.MNNGraph))
-    assert(isinstance(G, graphtools.graphs.LandmarkGraph))
-    assert(isinstance(G, pygsp.graphs.Graph))
+    G = build_graph(
+        X,
+        n_landmark=n_landmark,
+        thresh=1e-3,
+        n_pca=None,
+        decay=10,
+        knn=3 - 1,
+        random_state=42,
+        sample_idx=sample_idx,
+        use_pygsp=True,
+    )
+    assert G.landmark_op.shape == (n_landmark, n_landmark)
+    assert isinstance(G, graphtools.graphs.MNNGraph)
+    assert isinstance(G, graphtools.graphs.LandmarkGraph)
+    assert isinstance(G, pygsp.graphs.Graph)
 
 
 #####################################################
@@ -135,6 +167,7 @@ def test_landmark_mnn_pygsp_graph():
 # Test API
 #############
 
+
 def test_verbose():
     print()
     print("Verbose test: Landmark")
@@ -145,20 +178,22 @@ def test_set_params():
     G = build_graph(data, n_landmark=500, decay=None)
     G.landmark_op
     assert G.get_params() == {
-        'n_pca': 20,
-        'random_state': 42,
-        'kernel_symm': '+',
-        'theta': None,
-        'n_landmark': 500,
-        'anisotropy': 0,
-        'knn': 3,
-        'decay': None,
-        'bandwidth': None,
-        'bandwidth_scale': 1,
-        'distance': 'euclidean',
-        'thresh': 0,
-        'n_jobs': -1,
-        'verbose': 0}
+        "n_pca": 20,
+        "random_state": 42,
+        "kernel_symm": "+",
+        "theta": None,
+        "n_landmark": 500,
+        "anisotropy": 0,
+        "knn": 3,
+        "knn_max": None,
+        "decay": None,
+        "bandwidth": None,
+        "bandwidth_scale": 1,
+        "distance": "euclidean",
+        "thresh": 0,
+        "n_jobs": -1,
+        "verbose": 0,
+    }
     G.set_params(n_landmark=300)
     assert G.landmark_op.shape == (300, 300)
     G.set_params(n_landmark=G.n_landmark, n_svd=G.n_svd)
diff --git a/test/test_mnn.py b/test/test_mnn.py
index 0e91aa4..f1aa8c3 100644
--- a/test/test_mnn.py
+++ b/test/test_mnn.py
@@ -1,4 +1,5 @@
 from __future__ import print_function
+import warnings
 from load_tests import (
     graphtools,
     np,
@@ -24,156 +25,236 @@
 
 @raises(ValueError)
 def test_sample_idx_and_precomputed():
-    build_graph(data, n_pca=None,
-                sample_idx=np.arange(10),
-                precomputed='distance')
+    build_graph(data, n_pca=None, sample_idx=np.arange(10), precomputed="distance")
 
 
 @raises(ValueError)
 def test_sample_idx_wrong_length():
-    build_graph(data, graphtype='mnn',
-                sample_idx=np.arange(10))
+    build_graph(data, graphtype="mnn", sample_idx=np.arange(10))
 
 
 @raises(ValueError)
 def test_sample_idx_unique():
-    build_graph(data, graph_class=graphtools.graphs.MNNGraph,
-                sample_idx=np.ones(len(data)))
+    build_graph(
+        data, graph_class=graphtools.graphs.MNNGraph, sample_idx=np.ones(len(data))
+    )
 
 
 @raises(ValueError)
 def test_sample_idx_none():
-    build_graph(data, graphtype='mnn', sample_idx=None)
+    build_graph(data, graphtype="mnn", sample_idx=None)
 
 
 @raises(ValueError)
 def test_build_mnn_with_precomputed():
-    build_graph(data, n_pca=None, graphtype='mnn', precomputed='distance')
+    build_graph(data, n_pca=None, graphtype="mnn", precomputed="distance")
 
 
 @raises(TypeError)
 def test_mnn_with_matrix_theta():
-    n_sample = len(np.unique(digits['target']))
+    n_sample = len(np.unique(digits["target"]))
     # square matrix theta of the wrong size
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='mnn',
-        theta=np.tile(np.linspace(0, 1, n_sample),
-                      n_sample).reshape(n_sample, n_sample))
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="mnn",
+        theta=np.tile(np.linspace(0, 1, n_sample), n_sample).reshape(
+            n_sample, n_sample
+        ),
+    )
 
 
 @raises(TypeError)
 def test_mnn_with_vector_theta():
-    n_sample = len(np.unique(digits['target']))
+    n_sample = len(np.unique(digits["target"]))
     # vector theta
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='mnn',
-        theta=np.linspace(0, 1, n_sample - 1))
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="mnn",
+        theta=np.linspace(0, 1, n_sample - 1),
+    )
 
 
 @raises(ValueError)
 def test_mnn_with_unbounded_theta():
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='mnn',
-        theta=2)
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="mnn",
+        theta=2,
+    )
 
 
 @raises(TypeError)
 def test_mnn_with_string_theta():
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='mnn',
-        theta='invalid')
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="mnn",
+        theta="invalid",
+    )
 
 
 @warns(FutureWarning)
 def test_mnn_with_gamma():
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='mnn',
-        gamma=0.9)
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="mnn",
+        gamma=0.9,
+    )
 
 
 @warns(FutureWarning)
 def test_mnn_with_kernel_symm_gamma():
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='gamma',
-        theta=0.9)
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="gamma",
+        theta=0.9,
+    )
 
 
 @raises(ValueError)
 def test_mnn_with_kernel_symm_invalid():
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='invalid',
-        theta=0.9)
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="invalid",
+        theta=0.9,
+    )
 
 
 @warns(FutureWarning)
 def test_mnn_with_kernel_symm_theta():
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='theta',
-        theta=0.9)
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="theta",
+        theta=0.9,
+    )
 
 
 @warns(UserWarning)
 def test_mnn_with_theta_and_kernel_symm_not_theta():
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='+',
-        theta=0.9)
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="+",
+        theta=0.9,
+    )
 
 
 @warns(UserWarning)
 def test_mnn_with_kernel_symmm_theta_and_no_theta():
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='mnn')
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="mnn",
+    )
 
 
 @warns(DeprecationWarning)
 def test_mnn_adaptive_k():
     build_graph(
-        data, thresh=0, n_pca=20,
-        decay=10, knn=5, random_state=42,
-        sample_idx=digits['target'],
-        kernel_symm='mnn',
-        theta=0.9, adaptive_k='sqrt')
+        data,
+        thresh=0,
+        n_pca=20,
+        decay=10,
+        knn=5,
+        random_state=42,
+        sample_idx=digits["target"],
+        kernel_symm="mnn",
+        theta=0.9,
+        adaptive_k="sqrt",
+    )
+
+
+@warns(UserWarning)
+def test_single_sample_idx_warning():
+    build_graph(data, sample_idx=np.repeat(1, len(data)))
+
+
+def test_single_sample_idx():
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore", "Only one unique sample. Not using MNNGraph", UserWarning
+        )
+        G = build_graph(data, sample_idx=np.repeat(1, len(data)))
+    G2 = build_graph(data)
+    np.testing.assert_array_equal(G.K, G2.K)
 
 
 def test_mnn_with_non_zero_indexed_sample_idx():
     X, sample_idx = generate_swiss_roll()
-    G = build_graph(X, sample_idx=sample_idx,
-                    kernel_symm='mnn', theta=0.5,
-                    n_pca=None, use_pygsp=True)
+    G = build_graph(
+        X,
+        sample_idx=sample_idx,
+        kernel_symm="mnn",
+        theta=0.5,
+        n_pca=None,
+        use_pygsp=True,
+    )
     sample_idx += 1
-    G2 = build_graph(X, sample_idx=sample_idx,
-                     kernel_symm='mnn', theta=0.5,
-                     n_pca=None, use_pygsp=True)
+    G2 = build_graph(
+        X,
+        sample_idx=sample_idx,
+        kernel_symm="mnn",
+        theta=0.5,
+        n_pca=None,
+        use_pygsp=True,
+    )
     assert G.N == G2.N
     assert np.all(G.d == G2.d)
     assert (G.W != G2.W).nnz == 0
@@ -183,13 +264,23 @@ def test_mnn_with_non_zero_indexed_sample_idx():
 
 def test_mnn_with_string_sample_idx():
     X, sample_idx = generate_swiss_roll()
-    G = build_graph(X, sample_idx=sample_idx,
-                    kernel_symm='mnn', theta=0.5,
-                    n_pca=None, use_pygsp=True)
-    sample_idx = np.where(sample_idx == 0, 'a', 'b')
-    G2 = build_graph(X, sample_idx=sample_idx,
-                     kernel_symm='mnn', theta=0.5,
-                     n_pca=None, use_pygsp=True)
+    G = build_graph(
+        X,
+        sample_idx=sample_idx,
+        kernel_symm="mnn",
+        theta=0.5,
+        n_pca=None,
+        use_pygsp=True,
+    )
+    sample_idx = np.where(sample_idx == 0, "a", "b")
+    G2 = build_graph(
+        X,
+        sample_idx=sample_idx,
+        kernel_symm="mnn",
+        theta=0.5,
+        n_pca=None,
+        use_pygsp=True,
+    )
     assert G.N == G2.N
     assert np.all(G.d == G2.d)
     assert (G.W != G2.W).nnz == 0
@@ -201,12 +292,13 @@ def test_mnn_with_string_sample_idx():
 # Check kernel
 #####################################################
 
+
 def test_mnn_graph_no_decay():
     X, sample_idx = generate_swiss_roll()
     theta = 0.9
     k = 10
     a = None
-    metric = 'euclidean'
+    metric = "euclidean"
     beta = 0.2
     samples = np.unique(sample_idx)
@@ -215,17 +307,16 @@ def test_mnn_graph_no_decay():
     K = pd.DataFrame(K)
 
     for si in samples:
-        X_i = X[sample_idx == si] # get observations in sample i
+        X_i = X[sample_idx == si]  # get observations in sample i
         for sj in samples:
             batch_k = k + 1 if si == sj else k
-            X_j = X[sample_idx == sj] # get observation in sample j
+            X_j = X[sample_idx == sj]  # get observation in sample j
             pdx_ij = cdist(X_i, X_j, metric=metric)  # pairwise distances
             kdx_ij = np.sort(pdx_ij, axis=1)  # get kNN
-            e_ij = kdx_ij[:, batch_k - 1] # dist to kNN
+            e_ij = kdx_ij[:, batch_k - 1]  # dist to kNN
             k_ij = np.where(pdx_ij <= e_ij[:, None], 1, 0)  # apply knn kernel
             if si == sj:
-                K.iloc[sample_idx == si, sample_idx == sj] = (
-                    k_ij + k_ij.T) / 2
+                K.iloc[sample_idx == si, sample_idx == sj] = (k_ij + k_ij.T) / 2
             else:
                 # fill out values in K for NN on diagonal
                 K.iloc[sample_idx == si, sample_idx == sj] = k_ij
@@ -241,18 +332,26 @@ def test_mnn_graph_no_decay():
             curr_K = K.iloc[sample_idx == i, sample_idx == j]
             curr_norm = norm(curr_K, 1, axis=1)
             scale = np.minimum(1, i_norm / curr_norm) * beta
-            Kn.iloc[sample_idx == i,
-                    sample_idx == j] = curr_K.values * scale[:, None]
+            Kn.iloc[sample_idx == i, sample_idx == j] = (
+                curr_K.values * scale[:, None]
+            )
     K = Kn
-    W = np.array((theta * np.minimum(K, K.T)) +
-                 ((1 - theta) * np.maximum(K, K.T)))
+    W = np.array((theta * np.minimum(K, K.T)) + ((1 - theta) * np.maximum(K, K.T)))
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = graphtools.Graph(X, knn=k, decay=a, beta=beta,
-                          kernel_symm='mnn', theta=theta,
-                          distance=metric, sample_idx=sample_idx, thresh=0,
-                          use_pygsp=True)
+    G2 = graphtools.Graph(
+        X,
+        knn=k,
+        decay=a,
+        beta=beta,
+        kernel_symm="mnn",
+        theta=theta,
+        distance=metric,
+        sample_idx=sample_idx,
+        thresh=0,
+        use_pygsp=True,
+    )
     assert G.N == G2.N
     np.testing.assert_array_equal(G.dw, G2.dw)
     np.testing.assert_array_equal((G.W - G2.W).data, 0)
@@ -264,7 +363,7 @@ def test_mnn_graph_decay():
     theta = 0.9
     k = 10
     a = 20
-    metric = 'euclidean'
+    metric = "euclidean"
     beta = 0.2
     samples = np.unique(sample_idx)
@@ -273,18 +372,17 @@ def test_mnn_graph_decay():
     K = pd.DataFrame(K)
 
     for si in samples:
-        X_i = X[sample_idx == si] # get observations in sample i
+        X_i = X[sample_idx == si]  # get observations in sample i
         for sj in samples:
             batch_k = k if si == sj else k - 1
-            X_j = X[sample_idx == sj] # get observation in sample j
+            X_j = X[sample_idx == sj]  # get observation in sample j
             pdx_ij = cdist(X_i, X_j, metric=metric)  # pairwise distances
             kdx_ij = np.sort(pdx_ij, axis=1)  # get kNN
-            e_ij = kdx_ij[:, batch_k] # dist to kNN
+            e_ij = kdx_ij[:, batch_k]  # dist to kNN
             pdxe_ij = pdx_ij / e_ij[:, np.newaxis]  # normalize
             k_ij = np.exp(-1 * (pdxe_ij ** a))  # apply alpha-decaying kernel
             if si == sj:
-                K.iloc[sample_idx == si, sample_idx == sj] = (
-                    k_ij + k_ij.T) / 2
+                K.iloc[sample_idx == si, sample_idx == sj] = (k_ij + k_ij.T) / 2
             else:
                 # fill out values in K for NN on diagonal
                 K.iloc[sample_idx == si, sample_idx == sj] = k_ij
@@ -300,18 +398,26 @@ def test_mnn_graph_decay():
             curr_K = K.iloc[sample_idx == i, sample_idx == j]
             curr_norm = norm(curr_K, 1, axis=1)
             scale = np.minimum(1, i_norm / curr_norm) * beta
-            Kn.iloc[sample_idx == i,
-                    sample_idx == j] = curr_K.values * scale[:, None]
+            Kn.iloc[sample_idx == i, sample_idx == j] = (
+                curr_K.values * scale[:, None]
+            )
     K = Kn
-    W = np.array((theta * np.minimum(K, K.T)) +
-                 ((1 - theta) * np.maximum(K, K.T)))
+    W = np.array((theta * np.minimum(K, K.T)) + ((1 - theta) * np.maximum(K, K.T)))
     np.fill_diagonal(W, 0)
     G = pygsp.graphs.Graph(W)
-    G2 = graphtools.Graph(X, knn=k, decay=a, beta=beta,
-                          kernel_symm='mnn', theta=theta,
-                          distance=metric, sample_idx=sample_idx, thresh=0,
-                          use_pygsp=True)
+    G2 = graphtools.Graph(
+        X,
+        knn=k,
+        decay=a,
+        beta=beta,
+        kernel_symm="mnn",
+        theta=theta,
+        distance=metric,
+        sample_idx=sample_idx,
+        thresh=0,
+        use_pygsp=True,
+    )
     assert G.N == G2.N
     np.testing.assert_array_equal(G.dw, G2.dw)
     np.testing.assert_array_equal((G.W - G2.W).data, 0)
@@ -325,34 +431,34 @@ def test_mnn_graph_decay():
 
 # TODO: add interpolation tests
 
+
 def test_verbose():
     X, sample_idx = generate_swiss_roll()
     print()
     print("Verbose test: MNN")
-    build_graph(X, sample_idx=sample_idx,
-                kernel_symm='mnn', theta=0.5,
-                n_pca=None, verbose=True)
+    build_graph(
+        X, sample_idx=sample_idx, kernel_symm="mnn", theta=0.5, n_pca=None, verbose=True
+    )
 
 
 def test_set_params():
     X, sample_idx = generate_swiss_roll()
-    G = build_graph(X, sample_idx=sample_idx,
-                    kernel_symm='mnn', theta=0.5,
-                    n_pca=None,
-                    thresh=1e-4)
+    G = build_graph(
+        X, sample_idx=sample_idx, kernel_symm="mnn", theta=0.5, n_pca=None, thresh=1e-4
+    )
     assert G.get_params() == {
-        'n_pca': None,
-        'random_state': 42,
-        'kernel_symm': 'mnn',
-        'theta': 0.5,
-        'anisotropy': 0,
-        'beta': 1,
-        'knn': 3,
-        'decay': 10,
-        'bandwidth': None,
-        'distance': 'euclidean',
-        'thresh': 1e-4,
-        'n_jobs': 1
+        "n_pca": None,
+        "random_state": 42,
+        "kernel_symm": "mnn",
+        "theta": 0.5,
+        "anisotropy": 0,
+        "beta": 1,
+        "knn": 3,
+        "decay": 10,
+        "bandwidth": None,
+        "distance": "euclidean",
+        "thresh": 1e-4,
+        "n_jobs": 1,
     }
     G.set_params(n_jobs=4)
     assert G.n_jobs == 4
@@ -370,11 +476,9 @@ def test_set_params():
     G.set_params(verbose=0)
     assert_raises(ValueError, G.set_params, knn=15)
     assert_raises(ValueError, G.set_params, decay=15)
-    assert_raises(ValueError, G.set_params, distance='manhattan')
+    assert_raises(ValueError, G.set_params, distance="manhattan")
     assert_raises(ValueError, G.set_params, thresh=1e-3)
     assert_raises(ValueError, G.set_params, beta=0.2)
-    G.set_params(knn=G.knn,
-                 decay=G.decay,
-                 thresh=G.thresh,
-                 distance=G.distance,
-                 beta=G.beta)
+    G.set_params(
+        knn=G.knn, decay=G.decay, thresh=G.thresh, distance=G.distance, beta=G.beta
+    )
diff --git a/test/test_utils.py b/test/test_utils.py
index d26c612..a104c03 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -7,8 +7,15 @@
 
 
 @parameterized(
-    [(np.array,), (sparse.csr_matrix,), (sparse.csc_matrix,),
-     (sparse.bsr_matrix,), (sparse.lil_matrix,), (sparse.coo_matrix,)])
+    [
+        (np.array,),
+        (sparse.csr_matrix,),
+        (sparse.csc_matrix,),
+        (sparse.bsr_matrix,),
+        (sparse.lil_matrix,),
+        (sparse.coo_matrix,),
+    ]
+)
 def test_nonzero_discrete(matrix_class):
     X = np.random.choice([0, 1, 2], p=[0.95, 0.025, 0.025], size=(100, 100))
     X = matrix_class(X)
@@ -16,15 +23,13 @@ def test_nonzero_discrete(matrix_class):
     assert not graphtools.utils.nonzero_discrete(X, [1, 3])
 
 
-@parameterized(
-    [(0,), (1e-4,)])
+@parameterized([(0,), (1e-4,)])
 def test_nonzero_discrete_knngraph(thresh):
     G = graphtools.Graph(data, n_pca=10, knn=5, decay=None, thresh=thresh)
     assert graphtools.utils.nonzero_discrete(G.K, [0.5, 1])
 
 
-@parameterized(
-    [(0,), (1e-4,)])
+@parameterized([(0,), (1e-4,)])
 def test_nonzero_discrete_decay_graph(thresh):
     G = graphtools.Graph(data, n_pca=10, knn=5, decay=15, thresh=thresh)
     assert not graphtools.utils.nonzero_discrete(G.K, [0.5, 1])