diff --git a/.gitignore b/.gitignore index 9751fa8..c2f0954 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__ build dist *egg-info +.coverage #syncthing .syncthing.* diff --git a/graphtools/base.py b/graphtools/base.py index 04b0c90..5427041 100644 --- a/graphtools/base.py +++ b/graphtools/base.py @@ -15,6 +15,12 @@ # pandas not installed pass +try: + import anndata +except (ImportError, SyntaxError): + # anndata not installed + pass + from .utils import (elementwise_minimum, elementwise_maximum, set_diagonal) @@ -111,6 +117,13 @@ def __init__(self, data, n_pca=None, random_state=None, **kwargs): except NameError: # pandas not installed pass + + try: + if isinstance(data, anndata.AnnData): + data = data.X + except NameError: + # anndata not installed + pass self.data = data self.n_pca = n_pca self.random_state = random_state diff --git a/graphtools/version.py b/graphtools/version.py index 54c9e35..c11f861 100644 --- a/graphtools/version.py +++ b/graphtools/version.py @@ -1 +1 @@ -__version__ = "0.1.8.1" +__version__ = "0.1.9" diff --git a/setup.py b/setup.py index 445a8f9..12c1c18 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,9 @@ 'coveralls' ] +if sys.version_info[0] == 3: + test_requires += ['anndata'] + doc_requires = [ 'sphinx', 'sphinxcontrib-napoleon', diff --git a/test/test_api.py b/test/test_api.py index 882cbae..c099086 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -19,7 +19,3 @@ def test_unknown_parameter(): @raises(ValueError) def test_invalid_graphtype(): build_graph(data, graphtype='hello world') - - -if __name__ == "__main__": - exit(nose2.run()) diff --git a/test/test_data.py b/test/test_data.py index 08a4231..6907821 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -12,6 +12,16 @@ squareform, pdist, ) +import warnings + +try: + import anndata +except (ImportError, SyntaxError): + # python2 support is missing + with warnings.catch_warnings(): + warnings.filterwarnings("always") + warnings.warn("Warning: failed to import anndata", ImportWarning) + pass ##################################################### # Check parameters @@ -57,6 +67,28 @@ def test_pandas_sparse_dataframe(): assert isinstance(G.data, sp.csr_matrix) +def test_anndata(): + try: + anndata + except NameError: + # not installed + return + G = build_graph(anndata.AnnData(data)) + assert isinstance(G, graphtools.base.BaseGraph) + assert isinstance(G.data, np.ndarray) + + +def test_anndata_sparse(): + try: + anndata + except NameError: + # not installed + return + G = build_graph(anndata.AnnData(sp.csr_matrix(data))) + assert isinstance(G, graphtools.base.BaseGraph) + assert isinstance(G.data, sp.csr_matrix) + + ##################################################### # Check transform ##################################################### @@ -136,6 +168,9 @@ def test_inverse_transform_sparse_svd(): def test_inverse_transform_dense_no_pca(): G = build_graph(data, n_pca=None) + np.testing.assert_allclose(data[:, 5:7], + G.inverse_transform(G.data_nu, columns=[5, 6]), + atol=1e-12) assert(np.all(G.data == G.inverse_transform(G.data_nu))) assert_raises(ValueError, G.inverse_transform, G.data[:, 0]) assert_raises(ValueError, G.inverse_transform, G.data[:, None, :15]) @@ -148,7 +183,3 @@ def test_inverse_transform_sparse_no_pca(): assert_raises(ValueError, G.inverse_transform, sp.csr_matrix(G.data)[:, 0]) assert_raises(ValueError, G.inverse_transform, sp.csr_matrix(G.data)[:, :15]) - - -if __name__ == "__main__": - exit(nose2.run()) diff --git a/test/test_exact.py b/test/test_exact.py index 5f99145..f43f859 100644 --- a/test/test_exact.py +++ b/test/test_exact.py @@ -336,7 +336,3 @@ def test_verbose(): print() print("Verbose test: Exact") build_graph(data, decay=10, thresh=0, verbose=True) - - -if __name__ == "__main__": - exit(nose2.run()) diff --git a/test/test_knn.py b/test/test_knn.py index 8ae5b0a..1ea6e12 100644 --- a/test/test_knn.py +++ b/test/test_knn.py @@ -171,7 +171,3 @@ def test_verbose(): print() print("Verbose test: kNN") build_graph(data, decay=None, verbose=True) - - -if __name__ == "__main__": - exit(nose2.run()) diff --git a/test/test_landmark.py b/test/test_landmark.py index 896388d..17f386c 100644 --- a/test/test_landmark.py +++ b/test/test_landmark.py @@ -7,6 +7,7 @@ build_graph, raises, warns, + generate_swiss_roll ) @@ -53,12 +54,12 @@ def test_landmark_knn_graph(): def test_landmark_mnn_graph(): n_landmark = 150 + X, sample_idx = generate_swiss_roll() # mnn graph - select_idx = np.random.choice(len(data), len(data) // 5, replace=False) - G = build_graph(data[select_idx], n_landmark=n_landmark, - thresh=1e-5, n_pca=20, + G = build_graph(X, n_landmark=n_landmark, + thresh=1e-5, n_pca=None, decay=10, knn=5, random_state=42, - sample_idx=digits['target'][select_idx]) + sample_idx=sample_idx) assert(G.landmark_op.shape == (n_landmark, n_landmark)) assert(isinstance(G, graphtools.graphs.MNNGraph)) assert(isinstance(G, graphtools.graphs.LandmarkGraph)) @@ -76,7 +77,3 @@ def test_verbose(): print() print("Verbose test: Landmark") build_graph(data, decay=None, n_landmark=500, verbose=True).landmark_op - - -if __name__ == "__main__": - exit(nose2.run()) diff --git a/test/test_mnn.py b/test/test_mnn.py index 5adb6ec..51d58e5 100644 --- a/test/test_mnn.py +++ b/test/test_mnn.py @@ -222,7 +222,3 @@ def test_verbose(): build_graph(X, sample_idx=sample_idx, kernel_symm='gamma', gamma=0.5, n_pca=None, verbose=True) - - -if __name__ == "__main__": - exit(nose2.run())