Skip to content

Commit

Permalink
Merge pull request #31 from KrishnaswamyLab/feature/callable_bw
Browse files Browse the repository at this point in the history
Feature/callable bw
  • Loading branch information
scottgigante authored Jan 26, 2019
2 parents de4a123 + 30b0cd8 commit 352e8ff
Show file tree
Hide file tree
Showing 6 changed files with 320 additions and 63 deletions.
23 changes: 17 additions & 6 deletions graphtools/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def Graph(data,
knn=5,
decay=10,
bandwidth=None,
bandwidth_scale=1.0,
anisotropy=0,
distance='euclidean',
thresh=1e-4,
Expand Down Expand Up @@ -64,10 +65,14 @@ def Graph(data,
decay : `int` or `None`, optional (default: 10)
Rate of alpha decay to use. If `None`, alpha decay is not used.
bandwidth : `float`, list-like or `None`, optional (default: `None`)
bandwidth : `float`, list-like,`callable`, or `None`, optional (default: `None`)
Fixed bandwidth to use. If given, overrides `knn`. Can be a single
bandwidth or a list-like (shape=[n_samples]) of bandwidths for each
sample.
bandwidth, list-like (shape=[n_samples]) of bandwidths for each
sample, or a `callable` that takes in a `n x m` matrix and returns a
a single value or list-like of length n (shape=[n_samples])
bandwidth_scale : `float`, optional (default : 1.0)
Rescaling factor for bandwidth.
anisotropy : float, optional (default: 0)
Level of anisotropy between 0 and 1
Expand Down Expand Up @@ -161,12 +166,18 @@ def Graph(data,
if sample_idx is not None:
# only mnn does batch correction
graphtype = "mnn"
elif precomputed is None and (decay is None or thresh > 0):
elif precomputed is not None:
# precomputed requires exact graph
# no decay or threshold decay require knngraph
graphtype = "exact"
elif decay is None:
# knn kernel
graphtype = "knn"
else:
elif thresh == 0 or callable(bandwidth):
# compute full distance matrix
graphtype = "exact"
else:
# decay kernel with nonzero threshold - knn is more efficient
graphtype = "knn"

# set base graph type
if graphtype == "knn":
Expand Down
5 changes: 3 additions & 2 deletions graphtools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sklearn.utils.fixes import signature
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.preprocessing import normalize
from sklearn.utils.graph import graph_shortest_path
from scipy import sparse
import warnings
import numbers
Expand Down Expand Up @@ -643,13 +644,13 @@ class PyGSPGraph(with_metaclass(abc.ABCMeta, pygsp.graphs.Graph, Base)):
kernel matrix
"""

def __init__(self, gtype='unknown', lap_type='combinatorial', coords=None,
def __init__(self, lap_type='combinatorial', coords=None,
plotting=None, **kwargs):
if plotting is None:
plotting = {}
W = self._build_weight_from_kernel(self.K)

super().__init__(W=W, gtype=gtype,
super().__init__(W=W,
lap_type=lap_type,
coords=coords,
plotting=plotting, **kwargs)
Expand Down
Loading

0 comments on commit 352e8ff

Please sign in to comment.