From b703cf603f34bff03971cfce1d156f0160206a23 Mon Sep 17 00:00:00 2001
From: Neil Shephard
Date: Wed, 1 May 2024 07:06:59 +0100
Subject: [PATCH] Add function to make nx Graph from Skeleton (#224)

Further to initial work by @jni on implementing NetworkX-based pruning, I've finally worked out how to reconstruct a `Skeleton` from the NetworkX graph by storing the `path_coordinates()` in the edge properties.

We save the `.path_coordinates()` to the edge properties of a NetworkX graph, using the `row.Index` _and_ the `node_id_src` / `node_id_dst` to define the nodes. The `row.Index` is required because some of the test instances contain loops whose `src` and `dst` are the same; if only those were used, information would be lost. In early work, where I was recreating the `summarize()` Pandas dataframes, the number of rows differed, which had me scratching my head until I realised information was being over-written because the `src` and `dst` were the same. Including the `row.Index` as a proxy for the graph segment avoids this.

Because the `.path_coordinates()` are saved as edge properties in the NetworkX object, we can now reconstruct the original NumPy array from this information and convert it to a `Skeleton`.

Tests are included for a range of sample graphs, and exceptions are raised (and tested) when...

+ The wrong type of object is passed into `nx_to_skeleton()`.
+ The NetworkX object is missing edge properties completely.
+ The edge properties contain co-ordinates outside of the original image's dimensions.

There may be more edge cases that we could/should test for.

I found problems reconstructing `tinyline` because it was a 1-D "image", so I have made it a 2-D "image" with a single row and updated `test_line()` appropriately. This may be inappropriate, since the intention is to have a 1-D skeleton, but most skeletons will be 2-D; it can be reverted if necessary.

On review, this still includes WIP on iterative pruning. Ideally this would be removed to keep the PR focused on converting to and from NetworkX graph objects.
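Roughly, the intended round trip looks like the sketch below (for illustration only, not part of the patch itself; it assumes the new `csr.skeleton_to_nx()` / `csr.nx_to_skeleton()` functions and uses the `skeleton0` array from `skan._testdata`):

    import numpy as np

    from skan import csr
    from skan._testdata import skeleton0

    # Build a Skeleton and (optionally) its summary table; passing the
    # summary in saves skeleton_to_nx() from recomputing it.
    skeleton = csr.Skeleton(skeleton0)
    summary = csr.summarize(skeleton, separator='_')

    # Each edge of the resulting nx.MultiGraph carries the branch's 'path'
    # coordinates (plus pixel 'indices' and 'values') as edge properties,
    # and the graph itself stores the image 'shape' and 'dtype'.
    g = csr.skeleton_to_nx(skeleton, summary)

    # Rebuild the image from the edge properties and convert back.
    round_trip = csr.nx_to_skeleton(g)
    np.testing.assert_array_equal(skeleton0, round_trip.skeleton_image)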
--------- Co-authored-by: Juan Nunez-Iglesias --- pyproject.toml | 4 + src/skan/_testdata.py | 38 ++++++- src/skan/csr.py | 175 ++++++++++++++++++++++++++++-- src/skan/summary_utils.py | 42 +++++--- src/skan/test/test_csr.py | 218 ++++++++++++++++++++++++++++++++++++-- 5 files changed, 444 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e6bbce6e..eb0077a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,5 +106,9 @@ skan = [ [tool.setuptools_scm] write_to = 'src/skan/_version.py' +[tool.pytest.ini_options] +minversion = "7.4.2" +addopts = "-W ignore" + [project.entry-points.'napari.manifest'] skan-napari = 'skan:napari.yaml' diff --git a/src/skan/_testdata.py b/src/skan/_testdata.py index cb1670fa..43b3e919 100644 --- a/src/skan/_testdata.py +++ b/src/skan/_testdata.py @@ -1,5 +1,7 @@ +import networkx as nx import numpy as np - +from skimage.draw import random_shapes +from skimage.morphology import skeletonize tinycycle = np.array([[0, 1, 0], [1, 0, 1], @@ -71,3 +73,37 @@ [3, 0, 0, 1, 0, 0, 0], [3, 0, 0, 0, 1, 1, 1], [0, 3, 0, 0, 0, 1, 0]], dtype=int) + + +def _generate_random_skeleton(**extra_kwargs): + """Generate random skeletons using skimage.draw's random_shapes.""" + kwargs = {"image_shape": (128, 128), + "max_shapes": 20, + "channel_axis": None, + "shape": None, + "allow_overlap": True} + random_image, _ = random_shapes(**kwargs, **extra_kwargs) + mask = random_image != 255 + return skeletonize(mask) + +# Generate random skeletons: + +# Skeleton with loop to be retained and side-branches +skeleton_loop1 = _generate_random_skeleton(rng=1, min_size=20) +# Skeleton with loop to be retained and side-branches +skeleton_loop2 = _generate_random_skeleton(rng=165103, min_size=60) +# Linear skeleton with lots of large side-branches, some forked +skeleton_linear1 = _generate_random_skeleton(rng=13588686514, min_size=20) +# Linear Skeleton with simple fork at one end +skeleton_linear2 = _generate_random_skeleton(rng=21, min_size=20) +# Linear Skeletons (i.e. multiple) with branches +skeleton_linear3 = _generate_random_skeleton(rng=894632511, min_size=20) + +## Sample NetworkX Graphs... 
+# ...with no edge attributes +nx_graph = nx.Graph() +nx_graph.add_nodes_from([1, 2, 3]) +# ...with edge attributes +nx_graph_edges = nx.Graph() +nx_graph_edges.add_nodes_from([1, 2, 3]) +nx_graph_edges.add_edge(1, 2, **{"path": np.asarray([[4, 4]])}) diff --git a/src/skan/csr.py b/src/skan/csr.py index 1d529cc8..18b7aad5 100644 --- a/src/skan/csr.py +++ b/src/skan/csr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import networkx as nx import numpy as np import pandas as pd from scipy import sparse, ndimage as ndi @@ -11,7 +12,7 @@ import numpy.typing as npt import numba import warnings - +from typing import Tuple, Callable from .nputil import _raveled_offsets_and_distances from .summary_utils import find_main_branches @@ -202,8 +203,11 @@ def csr_to_nbgraph(csr, node_props=None): node_props = np.broadcast_to(1., csr.shape[0]) node_props.flags.writeable = True return NBGraph( - csr.indptr, csr.indices, csr.data, - np.array(csr.shape, dtype=np.int32), node_props + csr.indptr, + csr.indices, + csr.data, + np.array(csr.shape, dtype=np.int32), + node_props, ) @@ -345,8 +349,14 @@ def _build_paths(jgraph, indptr, indices, path_data, visited, degrees): for neighbor in jgraph.neighbors(node): if not visited.edge(node, neighbor): n_steps = _walk_path( - jgraph, node, neighbor, visited, degrees, indices, - path_data, indices_j + jgraph, + node, + neighbor, + visited, + degrees, + indices, + path_data, + indices_j, ) indptr[indptr_i + 1] = indptr[indptr_i] + n_steps indptr_i += 1 @@ -357,8 +367,14 @@ def _build_paths(jgraph, indptr, indices, path_data, visited, degrees): neighbor = jgraph.neighbors(node)[0] if not visited.edge(node, neighbor): n_steps = _walk_path( - jgraph, node, neighbor, visited, degrees, indices, - path_data, indices_j + jgraph, + node, + neighbor, + visited, + degrees, + indices, + path_data, + indices_j, ) indptr[indptr_i + 1] = indptr[indptr_i] + n_steps indptr_i += 1 @@ -395,7 +411,7 @@ def _build_skeleton_path_graph(graph): visited_data = np.zeros(graph.data.shape, dtype=bool) visited = NBGraphBool( graph.indptr, graph.indices, visited_data, graph.shape, - np.broadcast_to(1., graph.shape[0]) + np.broadcast_to(1.0, graph.shape[0]) ) endpoints = (degrees != 2) endpoint_degrees = degrees[endpoints] @@ -521,6 +537,7 @@ def __init__( self._distances_initialized = False self.skeleton_image = None self.skeleton_shape = skeleton_image.shape + self.skeleton_dtype = skeleton_image.dtype self.source_image = None self.degrees = np.diff(self.graph.indptr) self.spacing = ( @@ -1043,6 +1060,7 @@ def make_degree_image(skeleton_image): else: import dask.array as da from dask_image.ndfilters import convolve as dask_convolve + if isinstance(bool_skeleton, da.Array): degree_image = bool_skeleton * dask_convolve( bool_skeleton.astype(int), degree_kernel, mode='constant' @@ -1221,3 +1239,144 @@ def sholl_analysis(skeleton, center=None, shells=None): intersection_counts = np.bincount(shells, minlength=len(shell_radii)) // 2 return center, shell_radii, intersection_counts + + +def skeleton_to_nx( + skeleton: Skeleton, + summary: pd.DataFrame | None = None, + ) -> nx.MultiGraph: + """Convert a Skeleton object to a networkx Graph. + + Parameters + ---------- + skeleton : Skeleton + The skeleton to convert. + summary : pd.DataFrame | None + The summary statistics of the skeleton. Each row in the summary table + is an edge in the networkx graph. 
It is not necessary to pass this in + because it can be computed from the input skeleton, but if it is + already computed, it will speed up this function. + + Returns + ------- + g : nx.MultiGraph + A graph where each node is a junction or endpoint in the skeleton, and + each edge is a path. + """ + if summary is None: + summary = summarize(skeleton, separator='_') + # Ensure underscores in column names + csum = summary.rename(columns=lambda s: s.replace('-', '_')) + g = nx.MultiGraph( + shape=skeleton.skeleton_shape, dtype=skeleton.skeleton_dtype + ) + for row in csum.itertuples(name='Edge'): + index = row.Index + i = row.node_id_src + j = row.node_id_dst + indices, values = skeleton.path_with_data(index) + # Nodes are added if they don't exist so only need to add edges + g.add_edge( + i, j, **{ + 'path': skeleton.path_coordinates(index), + 'indices': indices, + 'values': values, + } + ) + return g + + +def nx_to_skeleton(g: nx.Graph | nx.MultiGraph) -> Skeleton: + """Convert a Networkx Graph to a Skeleton object. + + The NetworkX Graph should have the following graph properties: + - 'shape': the shape of the image from which the graph was created. + - 'dtype': the dtype of the image from which the graph was created. + + and the following edge properties: + - 'path': an (N, Ndim) array of coordinates in the image of the pixels + traced by this edge. + - 'values': an (N,) array of values for the pixels traced by this edge. + + The `skeleton_to_nx` function can produce such a graph from a `Skeleton` + object. + + Parameters + ---------- + g: nx.Graph + Networkx graph to convert to Skeleton. + + Returns + ------- + Skeleton + Skeleton object corresponding to the input graph. + + Notes + ----- + Currently, this method uses the naive, brute-force approach of regenerating + the entire image from the graph, and then generating the Skeleton from the + image. It is probably low-hanging fruit to instead generate the Skeleton + compressed-sparse-row (CSR) data directly. 
+ """ + image = np.zeros(g.graph['shape'], dtype=g.graph['dtype']) + all_coords = np.concatenate([path for _, _, path in g.edges.data('path')], + axis=0) + all_values = np.concatenate([ + values for _, _, values in g.edges.data('values') + ], + axis=0) + + image[tuple(all_coords.T)] = all_values + + return Skeleton(image) + + +def _merge_paths(p1: npt.NDArray, p2: npt.NDArray): + """Join two paths together that have a common endpoint.""" + return np.concatenate([p1[:-1], p2], axis=0) + + +def _merge_edges(g: nx.Graph, e1: tuple[int], e2: tuple[int]): + middle_node = set(e1) & set(e2) + new_edge = sorted( + (set(e1) | set(e2)) - {middle_node}, + key=lambda i: i in e2, + ) + d1 = g.edges[e1] + d2 = g.edges[e2] + p1 = d1['path'] if e1[1] == middle_node else d1['path'][::-1] + p2 = d2['path'] if e2[0] == middle_node else d2['path'][::-1] + n1 = len(d1['path']) + n2 = len(d2['path']) + new_edge_values = { + 'skeleton_id': + g.edges[e1]['skeleton_id'], + 'node_id_src': + new_edge[0], + 'node_id_dst': + new_edge[1], + 'branch_distance': + d1['branch_distance'] + d2['branch_distance'], + 'branch_type': + min(d1['branch_type'], d2['branch_type']), + 'mean_pixel_value': ( + n1 * d1['mean_pixel_value'] + n2 * d2['mean_pixel_value'] + ) / (n1+n2), + 'stdev_pixel_value': + np.sqrt(( + d1['stdev_pixel_value']**2 * + (n1-1) + d2['stdev_pixel_value']**2 * (n2-1) + ) / (n1+n2-1)), + 'path': + _merge_paths(p1, p2), + } + g.add_edge(new_edge[0], new_edge[1], **new_edge_values) + g.remove_node(middle_node) + + +def _remove_simple_path_nodes(g): + """Remove any nodes of degree 2 by merging their incident edges.""" + to_remove = [n for n in g.nodes if g.degree(n) == 2] + for u in to_remove: + v, w = g[u].keys() + _merge_edges(g, (u, v), (u, w)) diff --git a/src/skan/summary_utils.py b/src/skan/summary_utils.py index 45aaf94f..dc7afff5 100644 --- a/src/skan/summary_utils.py +++ b/src/skan/summary_utils.py @@ -4,6 +4,29 @@ import toolz as tz +def find_main_branch_nx(g: nx.Graph, weight='branch-distance', in_place=True): + """Find longest shortest paths in g and annotate edges on the path.""" + if not in_place: + g = g.copy() + for conn in nx.connected_components(g): + curr_val = 0 + curr_pair = None + h = g.subgraph(conn) + p = dict(nx.all_pairs_dijkstra_path_length(h, weight=weight)) + for src in p: + for dst in p[src]: + val = p[src][dst] + if val is not None and np.isfinite(val) and val >= curr_val: + curr_val = val + curr_pair = (src, dst) + for i, j in tz.sliding_window(2, nx.shortest_path(h, + source=curr_pair[0], + target=curr_pair[1], + weight=weight)): + g.edges[i, j]['main'] = True + return g + + def find_main_branches(summary: DataFrame) -> np.ndarray: """Predict the extent of branching. 
@@ -32,21 +55,8 @@ def find_main_branches(summary: DataFrame) -> np.ndarray: g.add_weighted_edges_from(zip(us, vs, ws)) - for conn in nx.connected_components(g): - curr_val = 0 - curr_pair = None - h = g.subgraph(conn) - p = dict(nx.all_pairs_dijkstra_path_length(h)) - for src in p: - for dst in p[src]: - val = p[src][dst] - if (val is not None and np.isfinite(val) and val >= curr_val): - curr_val = val - curr_pair = (src, dst) - for i, j in tz.sliding_window(2, nx.shortest_path(h, - source=curr_pair[0], - target=curr_pair[1], - weight='weight')): - is_main[edge2idx[(i, j)]] = 1 + h = find_main_branch_nx(g, weight='weight') + for i, j, main in h.edges(data='main', default=False): + is_main[edge2idx[(i, j)]] = main return is_main diff --git a/src/skan/test/test_csr.py index fd595972..246db331 100644 --- a/src/skan/test/test_csr.py +++ b/src/skan/test/test_csr.py @@ -1,15 +1,35 @@ +from __future__ import annotations +import sys + from collections import defaultdict from itertools import product +from typing import Any -import pytest import numpy as np +import numpy.typing as npt from numpy.testing import assert_equal, assert_almost_equal +import pandas as pd +import pytest from skimage.draw import line -from skan import csr, summarize +from skan import csr from skan._testdata import ( - tinycycle, tinyline, skeleton0, skeleton1, skeleton2, skeleton3d, - topograph1d, skeleton4, skeletonlabel + tinycycle, + tinyline, + skeleton0, + skeleton1, + skeleton2, + skeleton3d, + topograph1d, + skeleton4, + skeletonlabel, + skeleton_loop1, + skeleton_loop2, + skeleton_linear1, + skeleton_linear2, + skeleton_linear3, + nx_graph, + nx_graph_edges, ) @@ -205,14 +225,14 @@ def test_transpose_image(): [0, 3, 2, 0, 1, 1, 0], [3, 0, 0, 4, 0, 0, 0], [3, 0, 0, 0, 4, 4, 4]]) ), - ] + ], ) def test_prune_paths( skeleton: np.ndarray, prune_branch: int, target: np.ndarray ) -> None: """Test pruning of paths.""" s = csr.Skeleton(skeleton, keep_images=True) - summary = summarize(s, separator='_') + summary = csr.summarize(s, separator='_') indices_to_remove = summary.loc[summary['branch_type'] == prune_branch ].index pruned = s.prune_paths(indices_to_remove) @@ -223,7 +243,7 @@ def test_prune_paths_exception_single_point() -> None: """Test exceptions raised when pruning leaves a single point and Skeleton object can not be created and returned.""" s = csr.Skeleton(skeleton0) - summary = summarize(s, separator='_') + summary = csr.summarize(s, separator='_') indices_to_remove = summary.loc[summary['branch_type'] == 1].index with pytest.raises(ValueError): s.prune_paths(indices_to_remove) @@ -233,7 +253,7 @@ def test_prune_paths_exception_invalid_path_index() -> None: """Test exceptions raised when trying to prune paths that do not exist in the summary.
This can arise if skeletons are not updated correctly during iterative pruning.""" s = csr.Skeleton(skeleton0) - summary = summarize(s, separator='_') + summary = csr.summarize(s, separator='_') indices_to_remove = [6] with pytest.raises(ValueError): s.prune_paths(indices_to_remove) @@ -317,6 +337,12 @@ def test_skeleton_path_image_no_keep_image(): assert np.max(pli) == s.n_paths +def test_skeletonlabel(): + stats = csr.summarize(csr.Skeleton(skeletonlabel)) + assert stats['mean-pixel-value'].max() == skeletonlabel.max() + assert stats['mean-pixel-value'].max() > 1 + + @pytest.mark.parametrize( 'dtype', [ ''.join([pre, 'int', suf]) @@ -336,3 +362,179 @@ def test_default_summarize_separator(): match='separator in column name'): stats = csr.summarize(csr.Skeleton(skeletonlabel)) assert 'skeleton-id' in stats + + +def test_skeletonlabel(): + stats = csr.summarize(csr.Skeleton(skeletonlabel)) + assert stats['mean-pixel-value'].max() == skeletonlabel.max() + assert stats['mean-pixel-value'].max() > 1 + + +@pytest.mark.parametrize( + ('np_skeleton', 'summary', 'nodes', 'edges'), + [ + pytest.param( + tinycycle, None, 1, 1, id='tinycircle (no summary)' + ), + pytest.param(tinyline, None, 2, 1, id='tinyline (no summary)'), + pytest.param( + skeleton0, + csr.summarize(csr.Skeleton(skeleton0), separator='_'), + 4, + 3, + id='skeleton0 (with summary)' + ), + pytest.param( + skeleton1, None, 4, 4, id='skeleton1 (no summary)' + ), + pytest.param( + skeleton1, + csr.summarize(csr.Skeleton(skeleton1)), + 4, + 4, + id='skeleton1 (with summary)' + ), + pytest.param( + skeleton2, + csr.summarize(csr.Skeleton(skeleton2)), + 8, + 8, + id='skeleton2 (with summary)' + ), + pytest.param( + skeleton3d, None, 7, 7, id='skeleton3d (no summary)' + ), + pytest.param( + skeleton_loop1, + None, + 10, + 10, + id='skeleton_loop1 (no summary)' + ), + pytest.param( + skeleton_loop2, + None, + 10, + 10, + id='skeleton_loop2 (no summary)' + ), + pytest.param( + skeleton_linear1, + None, + 24, + 24, + id='skeleton_linear1 (no summary)', + marks=pytest.mark.xfail( + sys.version_info[:2] == (3, 8), + reason='Incorrect edege discovery (#225)' + ) + ), + pytest.param( + skeleton_linear2, + None, + 4, + 3, + id='skeleton_linear2 (no summary)' + ), + pytest.param( + skeleton_linear3, + None, + 20, + 17, + id='skeleton_linear3 (no summary)' + ), + ], + ) +def test_skeleton_to_nx( + np_skeleton: npt.NDArray, summary: pd.DataFrame | None, edges: int, + nodes: int + ) -> None: + """Test creation of NetworkX Graph from skeletons arrays and summary.""" + skeleton = csr.Skeleton(np_skeleton) + skan_nx = csr.skeleton_to_nx(skeleton, summary) + assert skan_nx.number_of_nodes() == nodes + assert skan_nx.number_of_edges() == edges + + +@pytest.mark.parametrize( + ('np_skeleton', 'summary'), + [ + pytest.param( + tinycycle, + csr.summarize(csr.Skeleton(tinycycle)), + id='tinycircle' + ), + pytest.param( + tinyline, + csr.summarize(csr.Skeleton(tinyline)), + id='tinyline' + ), + pytest.param( + skeleton0, + csr.summarize(csr.Skeleton(skeleton0)), + id='skeleton0' + ), + pytest.param( + skeleton1, + csr.summarize(csr.Skeleton(skeleton1)), + id='skeleton1' + ), + pytest.param( + skeleton3d, + csr.summarize(csr.Skeleton(skeleton3d)), + id='skeleton3d (no summary)' + ), + pytest.param( + skeleton_loop1, + csr.summarize(csr.Skeleton(skeleton_loop1)), + id='skeleton_loop1' + ), + pytest.param( + skeleton_loop2, + csr.summarize(csr.Skeleton(skeleton_loop2)), + id='skeleton_loop2' + ), + pytest.param( + skeleton_linear1, + 
csr.summarize(csr.Skeleton(skeleton_linear1)), + id='skeleton_linear1' + ), + pytest.param( + skeleton_linear2, + csr.summarize(csr.Skeleton(skeleton_linear2)), + id='skeleton_linear2' + ), + pytest.param( + skeleton_linear3, + csr.summarize(csr.Skeleton(skeleton_linear3)), + id='skeleton_linear3' + ), + ], + ) +def test_nx_to_skeleton( + np_skeleton: npt.NDArray, + summary: pd.DataFrame | None, + ) -> None: + """Test creation of Skeleton from NetworkX Graph.""" + skeleton = csr.Skeleton(np_skeleton) + skan_nx = csr.skeleton_to_nx(skeleton, summary) + skeleton_nx = csr.nx_to_skeleton(skan_nx) + np.testing.assert_array_equal(np_skeleton, skeleton_nx.skeleton_image) + + +@pytest.mark.parametrize( + 'wrong_skeleton', + [ + pytest.param(skeleton0, id='Numpy Array.'), + pytest.param(csr.Skeleton(skeleton0), id='Skeleton.'), + pytest.param(nx_graph, id='NetworkX Graph without edges.'), + pytest.param( + nx_graph_edges, + id='NetworkX Graph with points outside image.' + ), + ], + ) +def test_nx_to_skeleton_attribute_error(wrong_skeleton: Any) -> None: + """Test various errors are raised by nx_to_skeleton().""" + with pytest.raises(Exception): + csr.nx_to_skeleton(wrong_skeleton)