Skip to content

Commit

Permalink
Merge pull request #36 from KrishnaswamyLab/dev
Browse files Browse the repository at this point in the history
graphtools v1.1
  • Loading branch information
scottgigante authored Mar 2, 2019
2 parents 29b6f90 + 40fefb1 commit 3edc962
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 12 deletions.
2 changes: 1 addition & 1 deletion graphtools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .api import Graph, from_igraph
from .api import Graph, from_igraph, read_pickle
from .version import __version__
21 changes: 21 additions & 0 deletions graphtools/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import warnings
import tasklogger
from scipy import sparse
import pickle
import pygsp

from . import base
from . import graphs
Expand Down Expand Up @@ -283,3 +285,22 @@ def from_igraph(G, attribute="weight", **kwargs):
K = G.get_adjacency(attribute=None).data
return Graph(sparse.coo_matrix(K),
precomputed='adjacency', **kwargs)


def read_pickle(path):
"""Load pickled Graphtools object (or any object) from file.
Parameters
----------
path : str
File path where the pickled object will be loaded.
"""
with open(path, 'rb') as f:
G = pickle.load(f)

if not isinstance(G, base.BaseGraph):
warnings.warn(
'Returning object that is not a graphtools.base.BaseGraph')
elif isinstance(G, base.PyGSPGraph) and isinstance(G.logger, str):
G.logger = pygsp.utils.build_logger(G.logger)
return G
32 changes: 25 additions & 7 deletions graphtools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import numpy as np
import abc
import pygsp
from sklearn.utils.fixes import signature
from inspect import signature
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.preprocessing import normalize
from sklearn.utils.graph import graph_shortest_path
from scipy import sparse
import warnings
import numbers
import tasklogger
import pickle
import sys

try:
import pandas as pd
Expand Down Expand Up @@ -106,10 +107,10 @@ class Data(Base):
def __init__(self, data, n_pca=None, random_state=None, **kwargs):

self._check_data(data)
if n_pca is not None and data.shape[1] <= n_pca:
if n_pca is not None and np.min(data.shape) <= n_pca:
warnings.warn("Cannot perform PCA to {} dimensions on "
"data with {} dimensions".format(n_pca,
data.shape[1]),
"data with min(n_samples, n_features) = {}".format(
n_pca, np.min(data.shape)),
RuntimeWarning)
n_pca = None
try:
Expand Down Expand Up @@ -316,7 +317,7 @@ class BaseGraph(with_metaclass(abc.ABCMeta, Base)):
'theta' : min-max
'none' : no symmetrization
theta: float (default: 0.5)
theta: float (default: 1)
Min-max symmetrization constant.
K = `theta * min(K, K.T) + (1 - theta) * max(K, K.T)`
Expand Down Expand Up @@ -385,7 +386,7 @@ def _check_symmetrization(self, kernel_symm, theta):
if theta is None:
warnings.warn("kernel_symm='theta' but theta not given. "
"Defaulting to theta=0.5.")
self.theta = theta = 0.5
self.theta = theta = 1
elif not isinstance(theta, numbers.Number) or \
theta < 0 or theta > 1:
raise ValueError("theta {} not recognized. Expected "
Expand Down Expand Up @@ -636,6 +637,23 @@ def to_igraph(self, attribute="weight", **kwargs):
return ig.Graph.Weighted_Adjacency(utils.to_dense(W).tolist(),
attr=attribute, **kwargs)

def to_pickle(self, path):
"""Save the current Graph to a pickle.
Parameters
----------
path : str
File path where the pickled object will be stored.
"""
if int(sys.version.split(".")[1]) < 7 and isinstance(self, pygsp.graphs.Graph):
# python 3.5, 3.6
logger = self.logger
self.logger = logger.name
with open(path, 'wb') as f:
pickle.dump(self, f)
if int(sys.version.split(".")[1]) < 7 and isinstance(self, pygsp.graphs.Graph):
self.logger = logger


class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)):
"""Interface between BaseGraph and PyGSP.
Expand Down
4 changes: 1 addition & 3 deletions graphtools/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import tasklogger

from .utils import (set_diagonal,
elementwise_minimum,
elementwise_maximum,
set_submatrix)
from .base import DataGraph, PyGSPGraph

Expand Down Expand Up @@ -245,7 +243,7 @@ def _check_duplicates(self, distances, indices):
"Detected zero distance between {} pairs of samples. "
"Consider removing duplicates to avoid errors in "
"downstream processing.".format(
np.sum(np.sum(distances[:, 1:]))),
np.sum(np.sum(distances[:, 1:] == 0))),
RuntimeWarning)

def build_kernel_to_data(self, Y, knn=None, bandwidth=None,
Expand Down
2 changes: 1 addition & 1 deletion graphtools/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.0"
__version__ = "1.1.0"
52 changes: 52 additions & 0 deletions test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import igraph
import numpy as np
import graphtools
import tempfile
import os


def test_from_igraph():
Expand Down Expand Up @@ -81,6 +83,56 @@ def test_to_igraph():
attribute="weight").data) == G.W)


def test_pickle_io_knngraph():
G = build_graph(data, knn=5, decay=None)
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
G.to_pickle(path)
G_prime = graphtools.read_pickle(path)
assert isinstance(G_prime, type(G))


def test_pickle_io_traditionalgraph():
G = build_graph(data, knn=5, decay=10, thresh=0)
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
G.to_pickle(path)
G_prime = graphtools.read_pickle(path)
assert isinstance(G_prime, type(G))


def test_pickle_io_landmarkgraph():
G = build_graph(data, knn=5, decay=None,
n_landmark=data.shape[0] // 2)
L = G.landmark_op
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
G.to_pickle(path)
G_prime = graphtools.read_pickle(path)
assert isinstance(G_prime, type(G))
np.testing.assert_array_equal(L, G_prime._landmark_op)


def test_pickle_io_pygspgraph():
G = build_graph(data, knn=5, decay=None, use_pygsp=True)
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
G.to_pickle(path)
G_prime = graphtools.read_pickle(path)
assert isinstance(G_prime, type(G))
assert G_prime.logger.name == G.logger.name


@warns(UserWarning)
def test_pickle_bad_pickle():
import pickle
with tempfile.TemporaryDirectory() as tempdir:
path = os.path.join(tempdir, 'tmp.pkl')
with open(path, 'wb') as f:
pickle.dump('hello world', f)
G = graphtools.read_pickle(path)


@warns(UserWarning)
def test_to_pygsp_invalid_precomputed():
G = build_graph(data)
Expand Down
6 changes: 6 additions & 0 deletions test/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def test_too_many_n_pca():
build_graph(data, n_pca=data.shape[1])


@warns(RuntimeWarning)
def test_too_many_n_pca():
build_graph(data[:data.shape[1] - 1],
n_pca=data.shape[1] - 1)


@warns(RuntimeWarning)
def test_precomputed_with_pca():
build_graph(squareform(pdist(data)),
Expand Down

0 comments on commit 3edc962

Please sign in to comment.