
Add function to make nx Graph from Skeleton #224

Merged: 22 commits, May 1, 2024
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -106,5 +106,9 @@ skan = [
[tool.setuptools_scm]
write_to = 'src/skan/_version.py'

[tool.pytest.ini_options]
minversion = "7.4.2"
addopts = "-W ignore"
Owner

Why did you add this? (i.e. what warnings are we ignoring here, and is that wise? 😅)

Contributor Author

Currently it's our own deprecation warnings, recently added for the change of the column-name separator to `_`.

src/skan/test/test_csr.py:407
  /home/neil/work/git/hub/ns-rse/skan/src/skan/test/test_csr.py:407: VisibleDeprecationWarning: separator in column name will change to _ in version 0.13; to silence this warning, use `separator='-'` to maintain current behavior and use `separator='_'` to switch to the new default behavior.
    csr.summarize(csr.Skeleton(skeleton1)),

Happy to remove this, though; I only added it so the test output was easier to read while writing the tests.

Owner

I'll leave it for now and make an issue to either wrap those in pytest.warns or pass the separator arg.
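
For reference, a minimal sketch of the two options mentioned above (illustrative only; it assumes `skeleton1` is importable from `skan._testdata` as in the existing tests, and that the warning class is NumPy's VisibleDeprecationWarning, as named in the output above):

import numpy as np
import pytest

from skan import csr
from skan._testdata import skeleton1


def test_summarize_warns_on_default_separator():
    # Option 1: assert the deprecation warning explicitly instead of ignoring it.
    with pytest.warns(np.VisibleDeprecationWarning, match='separator in column name'):
        csr.summarize(csr.Skeleton(skeleton1))


def test_summarize_with_explicit_separator():
    # Option 2: pass the separator argument so no warning is raised.
    summary = csr.summarize(csr.Skeleton(skeleton1), separator='_')
    assert 'branch_distance' in summary.columns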


[project.entry-points.'napari.manifest']
skan-napari = 'skan:napari.yaml'
38 changes: 37 additions & 1 deletion src/skan/_testdata.py
@@ -1,5 +1,7 @@
import networkx as nx
import numpy as np

from skimage.draw import random_shapes
from skimage.morphology import skeletonize

tinycycle = np.array([[0, 1, 0],
                      [1, 0, 1],
@@ -71,3 +73,37 @@
[3, 0, 0, 1, 0, 0, 0],
[3, 0, 0, 0, 1, 1, 1],
[0, 3, 0, 0, 0, 1, 0]], dtype=int)


def _generate_random_skeleton(**extra_kwargs):
    """Generate random skeletons using skimage.draw's random_shapes."""
    kwargs = {"image_shape": (128, 128),
              "max_shapes": 20,
              "channel_axis": None,
              "shape": None,
              "allow_overlap": True}
    random_image, _ = random_shapes(**kwargs, **extra_kwargs)
    mask = random_image != 255
    return skeletonize(mask)

# Generate random skeletons:

# Skeleton with loop to be retained and side-branches
skeleton_loop1 = _generate_random_skeleton(rng=1, min_size=20)
# Skeleton with loop to be retained and side-branches
skeleton_loop2 = _generate_random_skeleton(rng=165103, min_size=60)
# Linear skeleton with lots of large side-branches, some forked
skeleton_linear1 = _generate_random_skeleton(rng=13588686514, min_size=20)
# Linear Skeleton with simple fork at one end
skeleton_linear2 = _generate_random_skeleton(rng=21, min_size=20)
# Linear Skeletons (i.e. multiple) with branches
skeleton_linear3 = _generate_random_skeleton(rng=894632511, min_size=20)

## Sample NetworkX Graphs...
# ...with no edge attributes
nx_graph = nx.Graph()
nx_graph.add_nodes_from([1, 2, 3])
# ...with edge attributes
nx_graph_edges = nx.Graph()
nx_graph_edges.add_nodes_from([1, 2, 3])
nx_graph_edges.add_edge(1, 2, **{"path": np.asarray([[4, 4]])})
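
A hypothetical test using the fixtures above (illustrative only; `skeleton_to_nx` is the function added to `skan.csr` later in this diff, and the assertion is just a sanity check):

import pytest

from skan._testdata import (
        skeleton_loop1, skeleton_loop2, skeleton_linear1, skeleton_linear2,
        skeleton_linear3,
        )
from skan.csr import Skeleton, skeleton_to_nx


@pytest.mark.parametrize(
        'skeleton_image', [
                skeleton_loop1, skeleton_loop2, skeleton_linear1,
                skeleton_linear2, skeleton_linear3,
                ]
        )
def test_random_skeletons_produce_edges(skeleton_image):
    # Every generated skeleton should yield a graph with at least one path.
    g = skeleton_to_nx(Skeleton(skeleton_image))
    assert g.number_of_edges() > 0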
175 changes: 167 additions & 8 deletions src/skan/csr.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import networkx as nx
import numpy as np
import pandas as pd
from scipy import sparse, ndimage as ndi
@@ -11,7 +12,7 @@
import numpy.typing as npt
import numba
import warnings

from typing import Tuple, Callable
from .nputil import _raveled_offsets_and_distances
from .summary_utils import find_main_branches

@@ -202,8 +203,11 @@ def csr_to_nbgraph(csr, node_props=None):
node_props = np.broadcast_to(1., csr.shape[0])
node_props.flags.writeable = True
return NBGraph(
csr.indptr, csr.indices, csr.data,
np.array(csr.shape, dtype=np.int32), node_props
csr.indptr,
csr.indices,
csr.data,
np.array(csr.shape, dtype=np.int32),
node_props,
)


@@ -345,8 +349,14 @@ def _build_paths(jgraph, indptr, indices, path_data, visited, degrees):
for neighbor in jgraph.neighbors(node):
if not visited.edge(node, neighbor):
n_steps = _walk_path(
jgraph, node, neighbor, visited, degrees, indices,
path_data, indices_j
jgraph,
node,
neighbor,
visited,
degrees,
indices,
path_data,
indices_j,
)
indptr[indptr_i + 1] = indptr[indptr_i] + n_steps
indptr_i += 1
@@ -357,8 +367,14 @@ def _build_paths(jgraph, indptr, indices, path_data, visited, degrees):
neighbor = jgraph.neighbors(node)[0]
if not visited.edge(node, neighbor):
n_steps = _walk_path(
jgraph, node, neighbor, visited, degrees, indices,
path_data, indices_j
jgraph,
node,
neighbor,
visited,
degrees,
indices,
path_data,
indices_j,
)
indptr[indptr_i + 1] = indptr[indptr_i] + n_steps
indptr_i += 1
@@ -395,7 +411,7 @@ def _build_skeleton_path_graph(graph):
visited_data = np.zeros(graph.data.shape, dtype=bool)
visited = NBGraphBool(
graph.indptr, graph.indices, visited_data, graph.shape,
np.broadcast_to(1., graph.shape[0])
np.broadcast_to(1.0, graph.shape[0])
)
endpoints = (degrees != 2)
endpoint_degrees = degrees[endpoints]
@@ -521,6 +537,7 @@ def __init__(
self._distances_initialized = False
self.skeleton_image = None
self.skeleton_shape = skeleton_image.shape
self.skeleton_dtype = skeleton_image.dtype
self.source_image = None
self.degrees = np.diff(self.graph.indptr)
self.spacing = (
@@ -1043,6 +1060,7 @@ def make_degree_image(skeleton_image):
else:
import dask.array as da
from dask_image.ndfilters import convolve as dask_convolve

if isinstance(bool_skeleton, da.Array):
degree_image = bool_skeleton * dask_convolve(
bool_skeleton.astype(int), degree_kernel, mode='constant'
@@ -1221,3 +1239,144 @@ def sholl_analysis(skeleton, center=None, shells=None):
intersection_counts = np.bincount(shells, minlength=len(shell_radii)) // 2

return center, shell_radii, intersection_counts


def skeleton_to_nx(
        skeleton: Skeleton,
        summary: pd.DataFrame | None = None,
        ) -> nx.MultiGraph:
    """Convert a Skeleton object to a networkx Graph.

    Parameters
    ----------
    skeleton : Skeleton
        The skeleton to convert.
    summary : pd.DataFrame | None
        The summary statistics of the skeleton. Each row in the summary table
        is an edge in the networkx graph. It is not necessary to pass this in
        because it can be computed from the input skeleton, but if it is
        already computed, it will speed up this function.

    Returns
    -------
    g : nx.MultiGraph
        A graph where each node is a junction or endpoint in the skeleton, and
        each edge is a path.
    """
    if summary is None:
        summary = summarize(skeleton, separator='_')
    # Ensure underscores in column names
    csum = summary.rename(columns=lambda s: s.replace('-', '_'))
    g = nx.MultiGraph(
            shape=skeleton.skeleton_shape, dtype=skeleton.skeleton_dtype
            )
    for row in csum.itertuples(name='Edge'):
        index = row.Index
        i = row.node_id_src
        j = row.node_id_dst
        indices, values = skeleton.path_with_data(index)
        # Nodes are added if they don't exist so only need to add edges
        g.add_edge(
                i, j, **{
                        'path': skeleton.path_coordinates(index),
                        'indices': indices,
                        'values': values,
                        }
                )
    return g
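
A short usage sketch for skeleton_to_nx (illustrative only; it assumes the small skeleton1 array from skan._testdata, and the edge attribute names follow the docstring above):

from skan._testdata import skeleton1
from skan.csr import Skeleton, skeleton_to_nx

g = skeleton_to_nx(Skeleton(skeleton1))
for u, v, data in g.edges(data=True):
    # Each edge carries the pixel coordinates and values of its path.
    print(u, v, data['path'].shape, data['values'].shape)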


def nx_to_skeleton(g: nx.Graph | nx.MultiGraph) -> Skeleton:
    """Convert a Networkx Graph to a Skeleton object.

    The NetworkX Graph should have the following graph properties:
    - 'shape': the shape of the image from which the graph was created.
    - 'dtype': the dtype of the image from which the graph was created.

    and the following edge properties:
    - 'path': an (N, Ndim) array of coordinates in the image of the pixels
      traced by this edge.
    - 'values': an (N,) array of values for the pixels traced by this edge.

    The `skeleton_to_nx` function can produce such a graph from a `Skeleton`
    object.

    Parameters
    ----------
    g : nx.Graph
        Networkx graph to convert to Skeleton.

    Returns
    -------
    Skeleton
        Skeleton object corresponding to the input graph.

    Notes
    -----
    Currently, this method uses the naive, brute-force approach of regenerating
    the entire image from the graph, and then generating the Skeleton from the
    image. It is probably low-hanging fruit to instead generate the Skeleton
    compressed-sparse-row (CSR) data directly.
    """
    image = np.zeros(g.graph['shape'], dtype=g.graph['dtype'])
    all_coords = np.concatenate(
            [path for _, _, path in g.edges.data('path')], axis=0
            )
    all_values = np.concatenate(
            [values for _, _, values in g.edges.data('values')], axis=0
            )

    image[tuple(all_coords.T)] = all_values

    return Skeleton(image)
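
And a round-trip sketch for nx_to_skeleton (illustrative only; the reconstructed Skeleton's shape comes from the graph's 'shape' attribute, so it should match the original):

from skan._testdata import skeleton1
from skan.csr import Skeleton, skeleton_to_nx, nx_to_skeleton

skel = Skeleton(skeleton1)
g = skeleton_to_nx(skel)
skel2 = nx_to_skeleton(g)
assert skel2.skeleton_shape == skel.skeleton_shape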


def _merge_paths(p1: npt.NDArray, p2: npt.NDArray):
    """Join two paths together that have a common endpoint."""
    return np.concatenate([p1[:-1], p2], axis=0)


def _merge_edges(g: nx.Graph, e1: tuple[int], e2: tuple[int]):
    """Merge edges e1 and e2 that share a node, removing the shared node."""
    (middle_node,) = set(e1) & set(e2)
    new_edge = sorted(
            (set(e1) | set(e2)) - {middle_node},
            key=lambda i: i in e2,
            )
    d1 = g.edges[e1]
    d2 = g.edges[e2]
    p1 = d1['path'] if e1[1] == middle_node else d1['path'][::-1]
    p2 = d2['path'] if e2[0] == middle_node else d2['path'][::-1]
    n1 = len(d1['path'])
    n2 = len(d2['path'])
    new_edge_values = {
            'skeleton_id': g.edges[e1]['skeleton_id'],
            'node_id_src': new_edge[0],
            'node_id_dst': new_edge[1],
            'branch_distance': d1['branch_distance'] + d2['branch_distance'],
            'branch_type': min(d1['branch_type'], d2['branch_type']),
            'mean_pixel_value': (
                    n1 * d1['mean_pixel_value'] + n2 * d2['mean_pixel_value']
                    ) / (n1 + n2),
            'stdev_pixel_value': np.sqrt(
                    (d1['stdev_pixel_value']**2 * (n1-1)
                     + d2['stdev_pixel_value']**2 * (n2-1)) / (n1 + n2 - 1)
                    ),
            'path': _merge_paths(p1, p2),
            }
    g.add_edge(new_edge[0], new_edge[1], **new_edge_values)
    g.remove_node(middle_node)


def _remove_simple_path_nodes(g):
    """Remove any nodes of degree 2 by merging their incident edges."""
    to_remove = [n for n in g.nodes if g.degree(n) == 2]
    for u in to_remove:
        v, w = g[u].keys()
        _merge_edges(g, (u, v), (u, w))
42 changes: 26 additions & 16 deletions src/skan/summary_utils.py
@@ -4,6 +4,29 @@
import toolz as tz


def find_main_branch_nx(g: nx.Graph, weight='branch-distance', in_place=True):
    """Find longest shortest paths in g and annotate edges on the path."""
    if not in_place:
        g = g.copy()
    for conn in nx.connected_components(g):
        curr_val = 0
        curr_pair = None
        h = g.subgraph(conn)
        p = dict(nx.all_pairs_dijkstra_path_length(h, weight=weight))
        for src in p:
            for dst in p[src]:
                val = p[src][dst]
                if val is not None and np.isfinite(val) and val >= curr_val:
                    curr_val = val
                    curr_pair = (src, dst)
        for i, j in tz.sliding_window(2,
                                      nx.shortest_path(h,
                                                       source=curr_pair[0],
                                                       target=curr_pair[1],
                                                       weight=weight)):
            g.edges[i, j]['main'] = True
    return g
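
A small usage sketch for find_main_branch_nx (illustrative only): on a Y-shaped graph the long arm should be marked as the main branch, while the short spur should not.

import networkx as nx

from skan.summary_utils import find_main_branch_nx

g = nx.Graph()
g.add_edge(0, 1, **{'branch-distance': 2.0})
g.add_edge(1, 2, **{'branch-distance': 2.0})
g.add_edge(2, 3, **{'branch-distance': 2.0})
g.add_edge(2, 4, **{'branch-distance': 0.5})  # short side branch
find_main_branch_nx(g)
main_edges = [e for e, main in nx.get_edge_attributes(g, 'main').items() if main]
# expected: the edges (0, 1), (1, 2) and (2, 3), but not (2, 4)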


def find_main_branches(summary: DataFrame) -> np.ndarray:
"""Predict the extent of branching.

@@ -32,21 +55,8 @@ def find_main_branches(summary: DataFrame) -> np.ndarray:

g.add_weighted_edges_from(zip(us, vs, ws))

for conn in nx.connected_components(g):
curr_val = 0
curr_pair = None
h = g.subgraph(conn)
p = dict(nx.all_pairs_dijkstra_path_length(h))
for src in p:
for dst in p[src]:
val = p[src][dst]
if (val is not None and np.isfinite(val) and val >= curr_val):
curr_val = val
curr_pair = (src, dst)
for i, j in tz.sliding_window(2, nx.shortest_path(h,
source=curr_pair[0],
target=curr_pair[1],
weight='weight')):
is_main[edge2idx[(i, j)]] = 1
    h = find_main_branch_nx(g)
    for i, j in h.edges():
        is_main[edge2idx[(i, j)]] = 1

return is_main