Skip to content

Commit

Permalink
Merge pull request #18 from KrishnaswamyLab/dev
Browse files Browse the repository at this point in the history
v0.1.9: support AnnData
  • Loading branch information
scottgigante authored Jul 11, 2018
2 parents 7899467 + b27bda9 commit beca21d
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 29 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ __pycache__
build
dist
*egg-info
.coverage

#syncthing
.syncthing.*
Expand Down
13 changes: 13 additions & 0 deletions graphtools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
# pandas not installed
pass

try:
import anndata
except (ImportError, SyntaxError):
# anndata not installed
pass

from .utils import (elementwise_minimum,
elementwise_maximum,
set_diagonal)
Expand Down Expand Up @@ -111,6 +117,13 @@ def __init__(self, data, n_pca=None, random_state=None, **kwargs):
except NameError:
# pandas not installed
pass

try:
if isinstance(data, anndata.AnnData):
data = data.X
except NameError:
# anndata not installed
pass
self.data = data
self.n_pca = n_pca
self.random_state = random_state
Expand Down
2 changes: 1 addition & 1 deletion graphtools/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.8.1"
__version__ = "0.1.9"
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
'coveralls'
]

if sys.version_info[0] == 3:
test_requires += ['anndata']

doc_requires = [
'sphinx',
'sphinxcontrib-napoleon',
Expand Down
4 changes: 0 additions & 4 deletions test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,3 @@ def test_unknown_parameter():
@raises(ValueError)
def test_invalid_graphtype():
build_graph(data, graphtype='hello world')


if __name__ == "__main__":
exit(nose2.run())
39 changes: 35 additions & 4 deletions test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
squareform,
pdist,
)
import warnings

try:
import anndata
except (ImportError, SyntaxError):
# python2 support is missing
with warnings.catch_warnings():
warnings.filterwarnings("always")
warnings.warn("Warning: failed to import anndata", ImportWarning)
pass

#####################################################
# Check parameters
Expand Down Expand Up @@ -57,6 +67,28 @@ def test_pandas_sparse_dataframe():
assert isinstance(G.data, sp.csr_matrix)


def test_anndata():
try:
anndata
except NameError:
# not installed
return
G = build_graph(anndata.AnnData(data))
assert isinstance(G, graphtools.base.BaseGraph)
assert isinstance(G.data, np.ndarray)


def test_anndata_sparse():
try:
anndata
except NameError:
# not installed
return
G = build_graph(anndata.AnnData(sp.csr_matrix(data)))
assert isinstance(G, graphtools.base.BaseGraph)
assert isinstance(G.data, sp.csr_matrix)


#####################################################
# Check transform
#####################################################
Expand Down Expand Up @@ -136,6 +168,9 @@ def test_inverse_transform_sparse_svd():

def test_inverse_transform_dense_no_pca():
G = build_graph(data, n_pca=None)
np.testing.assert_allclose(data[:, 5:7],
G.inverse_transform(G.data_nu, columns=[5, 6]),
atol=1e-12)
assert(np.all(G.data == G.inverse_transform(G.data_nu)))
assert_raises(ValueError, G.inverse_transform, G.data[:, 0])
assert_raises(ValueError, G.inverse_transform, G.data[:, None, :15])
Expand All @@ -148,7 +183,3 @@ def test_inverse_transform_sparse_no_pca():
assert_raises(ValueError, G.inverse_transform, sp.csr_matrix(G.data)[:, 0])
assert_raises(ValueError, G.inverse_transform,
sp.csr_matrix(G.data)[:, :15])


if __name__ == "__main__":
exit(nose2.run())
4 changes: 0 additions & 4 deletions test/test_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,3 @@ def test_verbose():
print()
print("Verbose test: Exact")
build_graph(data, decay=10, thresh=0, verbose=True)


if __name__ == "__main__":
exit(nose2.run())
4 changes: 0 additions & 4 deletions test/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,3 @@ def test_verbose():
print()
print("Verbose test: kNN")
build_graph(data, decay=None, verbose=True)


if __name__ == "__main__":
exit(nose2.run())
13 changes: 5 additions & 8 deletions test/test_landmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
build_graph,
raises,
warns,
generate_swiss_roll
)


Expand Down Expand Up @@ -53,12 +54,12 @@ def test_landmark_knn_graph():

def test_landmark_mnn_graph():
n_landmark = 150
X, sample_idx = generate_swiss_roll()
# mnn graph
select_idx = np.random.choice(len(data), len(data) // 5, replace=False)
G = build_graph(data[select_idx], n_landmark=n_landmark,
thresh=1e-5, n_pca=20,
G = build_graph(X, n_landmark=n_landmark,
thresh=1e-5, n_pca=None,
decay=10, knn=5, random_state=42,
sample_idx=digits['target'][select_idx])
sample_idx=sample_idx)
assert(G.landmark_op.shape == (n_landmark, n_landmark))
assert(isinstance(G, graphtools.graphs.MNNGraph))
assert(isinstance(G, graphtools.graphs.LandmarkGraph))
Expand All @@ -76,7 +77,3 @@ def test_verbose():
print()
print("Verbose test: Landmark")
build_graph(data, decay=None, n_landmark=500, verbose=True).landmark_op


if __name__ == "__main__":
exit(nose2.run())
4 changes: 0 additions & 4 deletions test/test_mnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,3 @@ def test_verbose():
build_graph(X, sample_idx=sample_idx,
kernel_symm='gamma', gamma=0.5,
n_pca=None, verbose=True)


if __name__ == "__main__":
exit(nose2.run())

0 comments on commit beca21d

Please sign in to comment.