Skip to content

Commit

Permalink
Add function to make nx Graph from Skeleton (#224)
Browse files Browse the repository at this point in the history
Further to initial work by @jni to implement NetworkX based pruning I've
finally worked out how to reconstruct `Skeleton` from the NetworkX graph
by storing the `path_coordinates()` in the `edge properties`.

We save the `.path_coordinates()` to the edge properties of a NetworkX
graph, using the `row.Index` _and_ the `node_id_src` / `node_id_dst` to
define the nodes. This is required because in some of the test instances
there are loops with the same `src` and `dst` and if only these are used
we would lose information. In early work I was recreating the `summarize()`
Pandas dataframes and the number of rows differed, which made me
scratch my head until I realised the information was being overwritten
because the `src` and `dst` were the same. Including the `row.Index` as
a proxy for the graph segment avoids this.

As the `.path_coordinates()` are saved as `edge` properties in the NetworkX
object, we can now reconstruct the original Numpy array based on this
information and convert it to a `Skeleton`.

Tests are included for a range of sample graphs and exceptions are
raised (and tested) when...

+ Wrong type of object is passed into `nx_to_skeleton()`.
+ The NetworkX object is missing `edge` properties completely.
+ The `edge` properties are co-ordinates outside of the original image's
dimensions.

There may be more edge cases that we could/should test for.

I found problems in reconstructing `tinyline` as it was a 1-D "image"
and so have made this a 2-D "image" with a single row and updated
`test_line()` appropriately. This may be inappropriate as the intention
is to have a 1-D skeleton but most skeletons will be 2-D, can be
reverted if necessary.

Also, on reviewing, this still includes WIP on iterative pruning. It
would be desirable to remove this to keep the PR more focused on going
from and to NetworkX graph objects.

---------

Co-authored-by: Juan Nunez-Iglesias <[email protected]>
  • Loading branch information
ns-rse and jni authored May 1, 2024
1 parent 63aa8a7 commit b703cf6
Show file tree
Hide file tree
Showing 5 changed files with 444 additions and 33 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,9 @@ skan = [
[tool.setuptools_scm]
write_to = 'src/skan/_version.py'

[tool.pytest.ini_options]
minversion = "7.4.2"
addopts = "-W ignore"

[project.entry-points.'napari.manifest']
skan-napari = 'skan:napari.yaml'
38 changes: 37 additions & 1 deletion src/skan/_testdata.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import networkx as nx
import numpy as np

from skimage.draw import random_shapes
from skimage.morphology import skeletonize

tinycycle = np.array([[0, 1, 0],
[1, 0, 1],
Expand Down Expand Up @@ -71,3 +73,37 @@
[3, 0, 0, 1, 0, 0, 0],
[3, 0, 0, 0, 1, 1, 1],
[0, 3, 0, 0, 0, 1, 0]], dtype=int)


def _generate_random_skeleton(**extra_kwargs):
    """Generate a random skeleton image using skimage.draw's random_shapes.

    Parameters
    ----------
    **extra_kwargs
        Keyword arguments forwarded to ``random_shapes`` (for example ``rng``
        or ``min_size``). They are merged over the defaults defined below, so
        a caller may also override a default such as ``image_shape``.

    Returns
    -------
    The output of ``skeletonize`` applied to the mask of all pixels of the
    random image whose value is not 255.
    """
    defaults = {
            "image_shape": (128, 128),
            "max_shapes": 20,
            "channel_axis": None,
            "shape": None,
            "allow_overlap": True,
            }
    # Merge into a single mapping first: unpacking ``**defaults`` and
    # ``**extra_kwargs`` side by side raises TypeError as soon as a caller
    # repeats one of the default keys.
    random_image, _ = random_shapes(**{**defaults, **extra_kwargs})
    # Every pixel that is not 255 counts as foreground for skeletonization.
    mask = random_image != 255
    return skeletonize(mask)

# Generate random skeletons. Every call passes a fixed integer ``rng``,
# presumably so that the generated skeletons are reproducible between runs.

# Skeleton with loop to be retained and side-branches
skeleton_loop1 = _generate_random_skeleton(rng=1, min_size=20)
# Skeleton with loop to be retained and side-branches
# NOTE(review): same description as skeleton_loop1; this one uses a
# different seed and a larger min_size (60 instead of 20).
skeleton_loop2 = _generate_random_skeleton(rng=165103, min_size=60)
# Linear skeleton with lots of large side-branches, some forked
skeleton_linear1 = _generate_random_skeleton(rng=13588686514, min_size=20)
# Linear Skeleton with simple fork at one end
skeleton_linear2 = _generate_random_skeleton(rng=21, min_size=20)
# Linear Skeletons (i.e. multiple) with branches
skeleton_linear3 = _generate_random_skeleton(rng=894632511, min_size=20)

## Sample NetworkX Graphs...
# ...with no edge attributes: three nodes and no edges at all
nx_graph = nx.Graph()
nx_graph.add_nodes_from([1, 2, 3])
# ...with edge attributes: three nodes and a single edge (1, 2) whose 'path'
# attribute holds one coordinate, (4, 4). Node 3 stays isolated, and no
# 'shape' or 'dtype' graph attributes are set on either graph.
nx_graph_edges = nx.Graph()
nx_graph_edges.add_nodes_from([1, 2, 3])
nx_graph_edges.add_edge(1, 2, **{"path": np.asarray([[4, 4]])})
175 changes: 167 additions & 8 deletions src/skan/csr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import networkx as nx
import numpy as np
import pandas as pd
from scipy import sparse, ndimage as ndi
Expand All @@ -11,7 +12,7 @@
import numpy.typing as npt
import numba
import warnings

from typing import Tuple, Callable
from .nputil import _raveled_offsets_and_distances
from .summary_utils import find_main_branches

Expand Down Expand Up @@ -202,8 +203,11 @@ def csr_to_nbgraph(csr, node_props=None):
node_props = np.broadcast_to(1., csr.shape[0])
node_props.flags.writeable = True
return NBGraph(
csr.indptr, csr.indices, csr.data,
np.array(csr.shape, dtype=np.int32), node_props
csr.indptr,
csr.indices,
csr.data,
np.array(csr.shape, dtype=np.int32),
node_props,
)


Expand Down Expand Up @@ -345,8 +349,14 @@ def _build_paths(jgraph, indptr, indices, path_data, visited, degrees):
for neighbor in jgraph.neighbors(node):
if not visited.edge(node, neighbor):
n_steps = _walk_path(
jgraph, node, neighbor, visited, degrees, indices,
path_data, indices_j
jgraph,
node,
neighbor,
visited,
degrees,
indices,
path_data,
indices_j,
)
indptr[indptr_i + 1] = indptr[indptr_i] + n_steps
indptr_i += 1
Expand All @@ -357,8 +367,14 @@ def _build_paths(jgraph, indptr, indices, path_data, visited, degrees):
neighbor = jgraph.neighbors(node)[0]
if not visited.edge(node, neighbor):
n_steps = _walk_path(
jgraph, node, neighbor, visited, degrees, indices,
path_data, indices_j
jgraph,
node,
neighbor,
visited,
degrees,
indices,
path_data,
indices_j,
)
indptr[indptr_i + 1] = indptr[indptr_i] + n_steps
indptr_i += 1
Expand Down Expand Up @@ -395,7 +411,7 @@ def _build_skeleton_path_graph(graph):
visited_data = np.zeros(graph.data.shape, dtype=bool)
visited = NBGraphBool(
graph.indptr, graph.indices, visited_data, graph.shape,
np.broadcast_to(1., graph.shape[0])
np.broadcast_to(1.0, graph.shape[0])
)
endpoints = (degrees != 2)
endpoint_degrees = degrees[endpoints]
Expand Down Expand Up @@ -521,6 +537,7 @@ def __init__(
self._distances_initialized = False
self.skeleton_image = None
self.skeleton_shape = skeleton_image.shape
self.skeleton_dtype = skeleton_image.dtype
self.source_image = None
self.degrees = np.diff(self.graph.indptr)
self.spacing = (
Expand Down Expand Up @@ -1043,6 +1060,7 @@ def make_degree_image(skeleton_image):
else:
import dask.array as da
from dask_image.ndfilters import convolve as dask_convolve

if isinstance(bool_skeleton, da.Array):
degree_image = bool_skeleton * dask_convolve(
bool_skeleton.astype(int), degree_kernel, mode='constant'
Expand Down Expand Up @@ -1221,3 +1239,144 @@ def sholl_analysis(skeleton, center=None, shells=None):
intersection_counts = np.bincount(shells, minlength=len(shell_radii)) // 2

return center, shell_radii, intersection_counts


def skeleton_to_nx(
        skeleton: Skeleton,
        summary: pd.DataFrame | None = None,
        ) -> nx.MultiGraph:
    """Build a networkx MultiGraph describing the paths of a Skeleton.

    Parameters
    ----------
    skeleton : Skeleton
        The skeleton to convert.
    summary : pd.DataFrame | None
        Summary table of the skeleton, one row per path. Each row becomes
        one edge of the graph. When omitted it is computed with
        ``summarize(skeleton, separator='_')``; passing a precomputed table
        avoids that work.

    Returns
    -------
    g : nx.MultiGraph
        Graph whose nodes are the ``node_id_src`` / ``node_id_dst`` values of
        the summary rows and whose edges are the skeleton paths. The graph
        attributes ``shape`` and ``dtype`` record the skeleton image shape
        and dtype. Every edge carries the attributes ``path`` (from
        ``skeleton.path_coordinates``), ``indices`` and ``values`` (both
        from ``skeleton.path_with_data``).
    """
    table = summarize(skeleton, separator='_') if summary is None else summary
    # Column names are accessed as attributes below, so dashes must become
    # underscores regardless of how the table was produced.
    table = table.rename(columns=lambda name: name.replace('-', '_'))
    graph = nx.MultiGraph(
            shape=skeleton.skeleton_shape,
            dtype=skeleton.skeleton_dtype,
            )
    for edge in table.itertuples(name='Edge'):
        path_id = edge.Index
        src = edge.node_id_src
        dst = edge.node_id_dst
        pixel_indices, pixel_values = skeleton.path_with_data(path_id)
        edge_attrs = {
                'path': skeleton.path_coordinates(path_id),
                'indices': pixel_indices,
                'values': pixel_values,
                }
        # add_edge creates missing nodes itself; a MultiGraph keeps parallel
        # edges, so rows sharing the same src/dst pair do not overwrite
        # each other.
        graph.add_edge(src, dst, **edge_attrs)
    return graph


def nx_to_skeleton(g: nx.Graph | nx.MultiGraph) -> Skeleton:
    """Convert a Networkx Graph to a Skeleton object.

    The NetworkX Graph should have the following graph properties:

    - 'shape': the shape of the image from which the graph was created.
    - 'dtype': the dtype of the image from which the graph was created.

    and the following edge properties:

    - 'path': an (N, Ndim) array of coordinates in the image of the pixels
      traced by this edge.
    - 'values': an (N,) array of values for the pixels traced by this edge.

    The `skeleton_to_nx` function can produce such a graph from a `Skeleton`
    object.

    Parameters
    ----------
    g : nx.Graph | nx.MultiGraph
        Networkx graph to convert to Skeleton.

    Returns
    -------
    Skeleton
        Skeleton object corresponding to the input graph.

    Raises
    ------
    KeyError
        If the graph lacks the 'shape' or 'dtype' graph property.
    ValueError
        If the graph has no edges, or if any edge lacks the 'path' or
        'values' property.

    Notes
    -----
    Currently, this method uses the naive, brute-force approach of regenerating
    the entire image from the graph, and then generating the Skeleton from the
    image. It is probably low-hanging fruit to instead generate the Skeleton
    compressed-sparse-row (CSR) data directly.

    Coordinates are used directly as NumPy indices, so negative coordinates
    wrap around instead of raising.
    """
    image = np.zeros(g.graph['shape'], dtype=g.graph['dtype'])
    # ``edges.data(name)`` yields None for edges without that attribute.
    paths = [path for _, _, path in g.edges.data('path')]
    values = [vals for _, _, vals in g.edges.data('values')]
    # Fail with an explicit message instead of letting np.concatenate raise
    # an opaque ValueError about empty or zero-dimensional inputs.
    if not paths:
        raise ValueError(
                'Cannot build a Skeleton from a graph that has no edges.'
                )
    if any(p is None for p in paths) or any(v is None for v in values):
        raise ValueError(
                "Every edge must have both 'path' and 'values' properties "
                'to build a Skeleton.'
                )
    all_coords = np.concatenate(paths, axis=0)
    all_values = np.concatenate(values, axis=0)

    # One index array per image axis, hence the transpose.
    image[tuple(all_coords.T)] = all_values

    return Skeleton(image)


def _merge_paths(p1: npt.NDArray, p2: npt.NDArray):
    """Concatenate two paths that meet at a common endpoint.

    The last row of ``p1`` is dropped before ``p2`` is appended along the
    first axis, so the shared endpoint appears only once in the result.
    """
    head = p1[:-1]
    return np.concatenate((head, p2), axis=0)


def _merge_edges(g: nx.Graph, e1: tuple[int], e2: tuple[int]):
    """Replace two edges that share exactly one node with a single edge.

    The shared ("middle") node is removed from ``g`` and a new edge joining
    the two remaining nodes is added. Its properties are combined from the
    properties of ``e1`` and ``e2``.

    Parameters
    ----------
    g : nx.Graph
        Graph to modify in place. Both edges must carry the properties
        'skeleton_id', 'branch_distance', 'branch_type', 'mean_pixel_value',
        'stdev_pixel_value' and 'path'.
    e1, e2 : tuple of int
        The two edges, each given as a pair of node ids.

    Raises
    ------
    ValueError
        If the edges do not share exactly one node, or if either edge is a
        self-loop.

    Notes
    -----
    Path orientation is taken from the 'node_id_src' / 'node_id_dst' edge
    properties when present. This assumes 'path' runs from src to dst, which
    is how the merged edge written here is laid out. Without those
    properties the order of the nodes in ``e1`` / ``e2`` is used instead.

    NOTE(review): if an edge between the two outer nodes already exists in a
    plain ``nx.Graph``, ``add_edge`` updates that edge's properties rather
    than adding a second edge.
    """
    shared = set(e1) & set(e2)
    if len(shared) != 1 or len(set(e1)) != 2 or len(set(e2)) != 2:
        raise ValueError(
                f'Edges {e1} and {e2} must share exactly one node and '
                'must not be self-loops.'
                )
    # Unpack to the node itself; a set is unhashable and never compares
    # equal to a node id.
    (middle_node,) = shared
    (start_node,) = set(e1) - shared
    (end_node,) = set(e2) - shared
    d1 = g.edges[e1]
    d2 = g.edges[e2]
    # The first path must end at the middle node and the second must start
    # there; flip whichever is stored the other way round.
    p1_ends_at_middle = d1.get('node_id_dst', e1[1]) == middle_node
    p2_starts_at_middle = d2.get('node_id_src', e2[0]) == middle_node
    p1 = d1['path'] if p1_ends_at_middle else d1['path'][::-1]
    p2 = d2['path'] if p2_starts_at_middle else d2['path'][::-1]
    n1 = len(d1['path'])
    n2 = len(d2['path'])
    new_edge_values = {
            'skeleton_id': d1['skeleton_id'],
            'node_id_src': start_node,
            'node_id_dst': end_node,
            'branch_distance': d1['branch_distance'] + d2['branch_distance'],
            'branch_type': min(d1['branch_type'], d2['branch_type']),
            # Length-weighted mean of the two per-edge means.
            'mean_pixel_value': (
                    n1 * d1['mean_pixel_value'] + n2 * d2['mean_pixel_value']
                    ) / (n1+n2),
            # Combination of the two per-edge variances, weighted by
            # (n - 1) for each edge.
            'stdev_pixel_value': np.sqrt((
                    d1['stdev_pixel_value']**2 * (n1-1)
                    + d2['stdev_pixel_value']**2 * (n2-1)
                    ) / (n1+n2-1)),
            # Drop the duplicated middle pixel of p1, then append p2.
            'path': np.concatenate([p1[:-1], p2], axis=0),
            }
    g.add_edge(start_node, end_node, **new_edge_values)
    # Removing the middle node also removes the two original edges.
    g.remove_node(middle_node)


def _remove_simple_path_nodes(g):
    """Remove nodes of degree 2 by merging their two incident edges.

    The graph is modified in place. Candidate nodes are collected once up
    front. Each is re-checked just before merging, and is skipped if it no
    longer has degree 2 or does not have exactly two distinct neighbours
    other than itself (for example a self-loop).
    """
    to_remove = [n for n in g.nodes if g.degree(n) == 2]
    for u in to_remove:
        neighbors = [n for n in g[u] if n != u]
        # An earlier merge can change this node's neighbourhood, and a
        # self-loop gives degree 2 with no other neighbour; unpacking two
        # neighbours would fail in either case.
        if g.degree(u) != 2 or len(neighbors) != 2:
            continue
        v, w = neighbors
        _merge_edges(g, (u, v), (u, w))
42 changes: 26 additions & 16 deletions src/skan/summary_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,29 @@
import toolz as tz


def find_main_branch_nx(g: nx.Graph, weight='branch-distance', in_place=True):
    """Mark the edges of the longest shortest path in each component.

    For every connected component, the pair of nodes with the largest
    finite shortest-path length is selected; on ties the pair encountered
    last wins. Each edge along the shortest path between that pair gets the
    edge attribute ``'main'`` set to True. Other edges are left untouched.

    Parameters
    ----------
    g : nx.Graph
        Graph to annotate.
    weight : str
        Edge attribute name passed as ``weight`` to the networkx
        shortest-path routines.
    in_place : bool
        When False, a copy of ``g`` is annotated and returned and ``g``
        itself is not modified.

    Returns
    -------
    The annotated graph (``g`` itself when ``in_place`` is True).
    """
    graph = g if in_place else g.copy()
    for component in nx.connected_components(graph):
        sub = graph.subgraph(component)
        best_length = 0
        endpoints = None
        lengths_by_source = dict(
                nx.all_pairs_dijkstra_path_length(sub, weight=weight)
                )
        for start, lengths in lengths_by_source.items():
            for end, length in lengths.items():
                if length is None or not np.isfinite(length):
                    continue
                # ``>=`` so that a later pair replaces an equally long one.
                if length >= best_length:
                    best_length = length
                    endpoints = (start, end)
        main_path = nx.shortest_path(
                sub,
                source=endpoints[0],
                target=endpoints[1],
                weight=weight,
                )
        # Consecutive node pairs along the path are its edges.
        for u, v in tz.sliding_window(2, main_path):
            graph.edges[u, v]['main'] = True
    return graph


def find_main_branches(summary: DataFrame) -> np.ndarray:
"""Predict the extent of branching.
Expand Down Expand Up @@ -32,21 +55,8 @@ def find_main_branches(summary: DataFrame) -> np.ndarray:

g.add_weighted_edges_from(zip(us, vs, ws))

for conn in nx.connected_components(g):
curr_val = 0
curr_pair = None
h = g.subgraph(conn)
p = dict(nx.all_pairs_dijkstra_path_length(h))
for src in p:
for dst in p[src]:
val = p[src][dst]
if (val is not None and np.isfinite(val) and val >= curr_val):
curr_val = val
curr_pair = (src, dst)
for i, j in tz.sliding_window(2, nx.shortest_path(h,
source=curr_pair[0],
target=curr_pair[1],
weight='weight')):
is_main[edge2idx[(i, j)]] = 1
h = find_main_branch_nx(g)
for i, j in h.edges():
is_main[edge2idx[(i, j)]] = 1

return is_main
Loading

0 comments on commit b703cf6

Please sign in to comment.