Skip to content

Commit

Permalink
Merge pull request #15 from KrishnaswamyLab/dev
Browse files Browse the repository at this point in the history
Handle sparse matrices
  • Loading branch information
scottgigante authored Jun 21, 2018
2 parents 97de1e4 + d23c72f commit c1e769b
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 12 deletions.
11 changes: 7 additions & 4 deletions graphtools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,13 @@ def _reduce_data(self):
log_complete("PCA")
return data_nu
else:
data = self.data
if sparse.issparse(data):
data = data.toarray()
return data
data_nu = self.data
if sparse.issparse(data_nu) and not isinstance(
data_nu, (sparse.csr_matrix,
sparse.csc_matrix,
sparse.bsr_matrix)):
data_nu = data_nu.tocsr()
return data_nu

def get_params(self):
"""Get parameters from this object
Expand Down
8 changes: 4 additions & 4 deletions graphtools/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def build_kernel_to_data(self, Y, knn=None):
if len(update_idx) > 0:
distances = [d for d in distances]
indices = [i for i in indices]
while len(update_idx) > len(Y) // 10 and \
while len(update_idx) > Y.shape[0] // 10 and \
search_knn < self.data_nu.shape[0] / 2:
# increase the knn search
search_knn = min(search_knn * 20, self.data_nu.shape[0])
Expand Down Expand Up @@ -829,9 +829,9 @@ def __init__(self, data, sample_idx,
if sample_idx is None:
raise ValueError("sample_idx must be given. For a graph without"
" batch correction, use kNNGraph.")
elif len(sample_idx) != len(data):
elif len(sample_idx) != data.shape[0]:
raise ValueError("sample_idx ({}) must be the same length as "
"data ({})".format(len(sample_idx), len(data)))
"data ({})".format(len(sample_idx), data.shape[0]))
elif len(self.samples) == 1:
raise ValueError(
"sample_idx must contain more than one unique value")
Expand Down Expand Up @@ -1092,7 +1092,7 @@ def build_kernel_to_data(self, Y, gamma=None):
kernel_yx = []
# don't really need within Y kernel
Y_graph = kNNGraph(Y, n_pca=None, knn=0, **(self.knn_args))
y_knn = self._weight_knn(sample_size=len(Y))
y_knn = self._weight_knn(sample_size=Y.shape[0])
for i, X in enumerate(self.subgraphs):
kernel_xy.append(X.build_kernel_to_data(
Y, knn=self.weighted_knn[i])) # kernel X -> Y
Expand Down
2 changes: 1 addition & 1 deletion graphtools/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.6"
__version__ = "0.1.7"
2 changes: 1 addition & 1 deletion test/load_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sklearn.decomposition import PCA
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn import datasets
from scipy.spatial.distance import pdist, cdist, squareform
import pygsp
Expand Down
95 changes: 95 additions & 0 deletions test/test_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
squareform,
pdist,
PCA,
TruncatedSVD
)

#####################################################
Expand Down Expand Up @@ -186,6 +187,100 @@ def test_truncated_exact_graph():
assert(isinstance(G2, graphtools.graphs.TraditionalGraph))


def test_truncated_exact_graph_sparse():
    """Compare exact graphs built from sparse inputs with a dense reference.

    A reference adjacency matrix is assembled by hand with NumPy on a random
    half of ``data`` (reduced with ``TruncatedSVD``) and wrapped in a
    ``pygsp`` graph.  ``build_graph`` is then called with four different
    sparse containers -- COO raw data, BSR precomputed distances, LIL
    precomputed affinities and DOK precomputed adjacency -- and each result
    is checked against the reference.
    """
    k = 3
    a = 13
    n_pca = 20
    thresh = 1e-4

    def assert_identical_to_reference(candidate):
        # Exact (not approximate) agreement with the reference graph G.
        assert G.N == candidate.N
        assert np.all(G.d == candidate.d)
        assert (G.W != candidate.W).nnz == 0
        assert (candidate.W != G.W).sum() == 0
        assert isinstance(candidate, graphtools.graphs.TraditionalGraph)

    # Random subsample: half of the rows, without repetition.
    row_subset = np.random.choice(len(data), len(data) // 2, replace=False)
    data_small = data[row_subset]

    # Hand-built reference, computed densely.
    svd = TruncatedSVD(n_pca, random_state=42).fit(data_small)
    pdx = squareform(pdist(svd.transform(data_small), metric='euclidean'))
    # Per-row bandwidth: the largest of the first k partitioned distances.
    epsilon = np.partition(pdx, k, axis=1)[:, :k].max(axis=1)
    K = np.exp(-1 * (pdx / epsilon[:, np.newaxis])**a)
    K[K < thresh] = 0
    W = (K + K.T) / 2
    np.fill_diagonal(W, 0)
    G = pygsp.graphs.Graph(W)

    # Raw data in COO format; compared with a tolerance rather than exactly.
    G2 = build_graph(
        sp.coo_matrix(data_small),
        thresh=thresh,
        graphtype='exact',
        n_pca=n_pca,
        decay=a,
        knn=k,
        random_state=42,
        use_pygsp=True,
    )
    assert G.N == G2.N
    np.testing.assert_allclose(G2.W.toarray(), G.W.toarray())
    assert isinstance(G2, graphtools.graphs.TraditionalGraph)

    # Precomputed distance matrix in BSR format.
    G2 = build_graph(
        sp.bsr_matrix(pdx),
        n_pca=None,
        precomputed='distance',
        thresh=thresh,
        decay=a,
        knn=k,
        random_state=42,
        use_pygsp=True,
    )
    assert_identical_to_reference(G2)

    # Precomputed affinity matrix in LIL format.
    G2 = build_graph(
        sp.lil_matrix(K),
        n_pca=None,
        precomputed='affinity',
        thresh=thresh,
        random_state=42,
        use_pygsp=True,
    )
    assert_identical_to_reference(G2)

    # Precomputed adjacency matrix in DOK format.
    G2 = build_graph(
        sp.dok_matrix(W),
        n_pca=None,
        precomputed='adjacency',
        random_state=42,
        use_pygsp=True,
    )
    assert_identical_to_reference(G2)


def test_truncated_exact_graph_no_pca():
    """Check the exact graph without PCA for dense and CSR-sparse input.

    A reference adjacency matrix is computed with NumPy directly on a random
    tenth of ``data`` (no dimensionality reduction).  ``build_graph`` is then
    run with ``n_pca=None`` first on the dense subsample and then on its CSR
    representation; both results must match the reference exactly.
    """
    k = 3
    a = 13
    n_pca = None
    thresh = 1e-4

    # Random subsample: one tenth of the rows, without repetition.
    row_subset = np.random.choice(len(data), len(data) // 10, replace=False)
    data_small = data[row_subset]

    # Hand-built reference, computed densely.
    pdx = squareform(pdist(data_small, metric='euclidean'))
    # Per-row bandwidth: the largest of the first k partitioned distances.
    epsilon = np.partition(pdx, k, axis=1)[:, :k].max(axis=1)
    K = np.exp(-1 * (pdx / epsilon[:, np.newaxis])**a)
    K[K < thresh] = 0
    W = (K + K.T) / 2
    np.fill_diagonal(W, 0)
    G = pygsp.graphs.Graph(W)

    # Dense input first, then the same data as CSR; the converter is applied
    # inside the loop so each input is created right before it is used.
    for as_input in (lambda dense: dense, sp.csr_matrix):
        G2 = build_graph(
            as_input(data_small),
            thresh=thresh,
            graphtype='exact',
            n_pca=n_pca,
            decay=a,
            knn=k,
            random_state=42,
            use_pygsp=True,
        )
        assert G.N == G2.N
        assert np.all(G.d == G2.d)
        assert (G.W != G2.W).nnz == 0
        assert (G2.W != G.W).sum() == 0
        assert isinstance(G2, graphtools.graphs.TraditionalGraph)


#####################################################
# Check interpolation
#####################################################
Expand Down
35 changes: 35 additions & 0 deletions test/test_knn.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from load_tests import (
graphtools,
np,
sp,
pygsp,
nose2,
data,
datasets,
build_graph,
assert_raises,
warns,
raises,
squareform,
pdist,
PCA,
TruncatedSVD,
)


Expand Down Expand Up @@ -66,6 +69,31 @@ def test_knn_graph():
assert(isinstance(G2, graphtools.graphs.kNNGraph))


def test_knn_graph_sparse():
    """Check the kNN graph built from COO-sparse data against a reference.

    The reference is an unweighted neighbourhood matrix computed with NumPy
    on the ``TruncatedSVD``-reduced ``data``, symmetrised by averaging with
    its transpose and with the diagonal cleared.  The graph returned by
    ``build_graph`` for the sparse input is compared with a tolerance.
    """
    k = 3
    n_pca = 20

    svd = TruncatedSVD(n_pca, random_state=42).fit(data)
    pdx = squareform(pdist(svd.transform(data), metric='euclidean'))
    # Per-row radius: the largest of the first k partitioned distances.
    epsilon = np.partition(pdx, k, axis=1)[:, :k].max(axis=1)
    # 1 where the distance is within the row's radius, 0 elsewhere.
    K = (pdx <= epsilon[:, np.newaxis]).astype(pdx.dtype)

    W = (K + K.T) / 2
    np.fill_diagonal(W, 0)
    G = pygsp.graphs.Graph(W)

    G2 = build_graph(
        sp.coo_matrix(data),
        n_pca=n_pca,
        decay=None,
        knn=k,
        random_state=42,
        use_pygsp=True,
    )
    assert G.N == G2.N
    np.testing.assert_allclose(G2.W.toarray(), G.W.toarray())
    assert isinstance(G2, graphtools.graphs.kNNGraph)


def test_sparse_alpha_knn_graph():
data = datasets.make_swiss_roll()[0]
k = 5
Expand All @@ -88,6 +116,13 @@ def test_sparse_alpha_knn_graph():
assert(isinstance(G2, graphtools.graphs.kNNGraph))


@warns(UserWarning)
def test_knn_graph_sparse_no_pca():
    """Build a graph from COO-sparse data with PCA disabled.

    The function makes no assertions of its own; the check is delegated to
    the ``warns(UserWarning)`` decorator imported from ``load_tests``
    (presumably it requires the call to emit a ``UserWarning`` -- confirm
    against that helper).
    """
    sparse_data = sp.coo_matrix(data)
    build_graph(
        sparse_data,
        n_pca=None,
        decay=10,
        knn=3,
        thresh=1e-4,
        random_state=42,
        use_pygsp=True,
    )


#####################################################
# Check interpolation
#####################################################
Expand Down
4 changes: 2 additions & 2 deletions test/test_landmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def test_landmark_knn_graph():


def test_landmark_mnn_graph():
n_landmark = 500
n_landmark = 150
# mnn graph
select_idx = np.random.choice([True, False], len(data), replace=True)
select_idx = np.random.choice(len(data), len(data) // 5, replace=False)
G = build_graph(data[select_idx], n_landmark=n_landmark,
thresh=1e-5, n_pca=20,
decay=10, knn=5, random_state=42,
Expand Down

0 comments on commit c1e769b

Please sign in to comment.