Updates to iterative pruning example
jni committed Jun 29, 2024
1 parent 8969838 commit 20d6470
Showing 1 changed file with 101 additions and 11 deletions.
112 changes: 101 additions & 11 deletions doc/examples/pruning_skeletons.md
@@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.13.6
jupytext_version: 1.16.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -19,7 +19,6 @@ removing to leave the skeleton that represents the backbone of the molecule. Ima
x/y plane but are never as clean as one might like. The number of extraneous branches can be reduced by first applying a
Gaussian blur prior to skeletonisation.


```{code-cell} ipython3
%matplotlib inline
%config InlineBackend.figure_format='retina'
@@ -36,19 +35,26 @@ raw = np.load("../example-data/sample_grain.npy")
plt.imshow(raw)
```

```{code-cell} ipython3
np.max(raw)
```

The mask for this image is rather messy.

```{code-cell} ipython3
raw_mask = np.load("../example-data/sample_grain_mask.npy")
from skimage import filters
raw_mask = filters.threshold_li(raw) < raw
plt.imshow(raw_mask)
```

If it is skeletonised as is then we have a skeleton with a large number of side branches and internal loops.
Keeping only the largest component, skeletonizing as-is will give a skeleton with a large number of side branches and internal loops. (We multiply the mask by the raw image to get the height of each point, since this is an AFM image.)

```{code-cell} ipython3
from skimage import morphology
raw_skeleton = morphology.skeletonize(raw_mask, method="zhang")
plt.imshow(raw_skeleton)
large_comp = morphology.remove_small_objects(raw_mask, min_size=100)
raw_skeleton = morphology.skeletonize(large_comp, method="zhang")
plt.imshow(raw_skeleton * raw)
```

We want to remove all the side-branches which are paths that go from a junction to end-point and we can use the
@@ -67,9 +73,10 @@ each path.
import skan
from skan import Skeleton
height_skeleton = raw_skeleton * raw
# Summarise the skeleton
skeleton = Skeleton(raw_skeleton, keep_images=True)
skeleton_summary = skan.summarize(skeleton)
skeleton = Skeleton(height_skeleton, keep_images=True, value_is_height=True) # TODO: include pixel spacing
skeleton_summary = skan.summarize(skeleton, separator='-')
# Extract the indices of junction-to-endpoint paths (branch type 1)
junction_to_endpoint = skeleton_summary[skeleton_summary["branch-type"] == 1].index
skeleton_pruned = skeleton.prune_paths(junction_to_endpoint).skeleton_image
@@ -82,12 +89,96 @@ end-point. Further we observe some small loops that also need removing.
## Iteratively Prune Paths

To address this issue the `iteratively_prune_paths()` function can be used to repeatedly prune skeletons until only a
single path remains, whether that is circular or linear.
single path remains, whether that is circular or linear. However, this function needs the skeleton in another format,
a networkx MultiGraph, because the native data structures in skan are not easy to update repeatedly. We use the `skeleton_to_nx`
function for this.

```{code-cell} ipython3
from skan.csr import skeleton_to_nx # TODO: update to just skan
```

```{code-cell} ipython3
nxskel = skeleton_to_nx(skeleton, skeleton_summary)
```

Now we can use the `iteratively_prune_paths` function. It takes the graph and a *discard* predicate: a function that takes the graph and an edge ID and returns True if that edge should be removed. To write a useful predicate, it helps to know what attributes an edge has, which we can find by inspecting an arbitrary edge in our graph. Note the keywords `data=True`, which returns the full data on each edge, and `keys=True`, which returns, in addition to the edge source and target, a *key* that distinguishes edges when there are multiple edges between the same two nodes.

```{code-cell} ipython3
next(iter(nxskel.edges(keys=True, data=True)))
```

In short, it has every attribute in the summary table, as well as the path coordinates (in pixel space), the nonzero pixel indices, and the image values under the path.

+++

Now, looking at the above skeleton, we notice:
- single branches that haven't been removed (because the original pruning function wasn't iterative/recursive)
- small self-loops that aren't removed because they aren't technically endpoints
- dual loops where one of two paths joining points is "dimmer" than the other
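
The first observation, that a single pruning pass can leave stubs behind, can be illustrated without any skeleton at all. The sketch below is a toy in plain Python: edges are `(u, v, key)` tuples, and the helper names (`degree`, `is_tip`, `prune_once`, `prune_until_stable`) are hypothetical, made up for this illustration. It is not how skan implements pruning.

```python
# Toy illustration only: plain tuples stand in for skeleton branches.
# The names below are hypothetical and are not part of skan.

def degree(edges, node):
    """Number of edge ends touching `node` (a self-loop counts twice)."""
    return sum((u == node) + (v == node) for u, v, _ in edges)


def is_tip(edges, edge):
    """Discard predicate: True if either end of `edge` has degree 1."""
    u, v, _ = edge
    return degree(edges, u) == 1 or degree(edges, v) == 1


def prune_once(edges, discard):
    """Remove, in a single pass, every edge the predicate flags."""
    edges = set(edges)
    return {e for e in edges if not discard(edges, e)}


def prune_until_stable(edges, discard):
    """Repeat single passes until a pass removes nothing."""
    current = set(edges)
    while True:
        pruned = prune_once(current, discard)
        if pruned == current:
            return current
        current = pruned


# A triangle a-b-c with a two-edge spur c-d-e hanging off it.
toy = [("a", "b", 0), ("b", "c", 0), ("c", "a", 0), ("c", "d", 0), ("d", "e", 0)]
print(len(prune_once(toy, is_tip)))          # one pass leaves the stub c-d
print(len(prune_until_stable(toy, is_tip)))  # repeated passes leave the triangle
```

Removing `d-e` turns `c-d` into a new tip, which only a second pass can see. Note that on a purely linear graph this simple predicate would eventually remove every edge, so a practical predicate needs some stopping condition.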

+++

Let's look at the signature of `iteratively_prune_paths`:

```{code-cell} ipython3
from skan import iteratively_prune_paths
help(iteratively_prune_paths)
```

skeleton_pruned_iteratively = iteratively_prune_paths(skeleton)
So we need to write a function that will identify the edges that we don't want in the graph. To repeat our previous conditions:

- edge branches — one of the endpoints' degrees should be 1.
- self-loops *other* than the final self-loop, which of course we want to keep! 😅
- the "dimmer" edge of a multi-edge pair.

Let's try:

```{code-cell} ipython3
def unwanted(mg, e):
    u, v, k = e
    # first the easy one: the branch is an endpoint
    if mg.degree(u) == 1 or mg.degree(v) == 1:
        return True
    # next, self-loops, other than the final self-loop
    if u == v and len(mg.edges()) > 1:
        return True
    # finally, the dimmer of two of the same edge.
    # We'll use a helper function, 'get_multiedge', that returns
    # a sibling multiedge if it exists and None otherwise
    if (e2 := get_multiedge(mg, e)) is not None:
        # if there is a multiedge, we discard current edge if it's
        # dimmer (lower mean pixel value) than its sibling edge
        return mg.edges[e]['mean_pixel_value'] < mg.edges[e2]['mean_pixel_value']
    return False


def get_multiedge(mg, e):
    u, v, k = e
    edge_keys = set(mg[u][v])  # mg[u][v] is keyed by edge key
    if len(edge_keys) > 1:  # multiedge
        other_key = (edge_keys - {k}).pop()
        return (u, v, other_key)
    return None  # no sibling edge between u and v
```
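
To see how the multi-edge rule behaves in isolation, here is a small stand-in that uses nested dicts laid out like a MultiGraph adjacency (`adj[u][v][key]` maps to the edge attributes). The data values and the helper names `sibling_key` and `is_dimmer` are invented for this sketch; only the `mean_pixel_value` attribute name is taken from the predicate above.

```python
# Stand-in for a MultiGraph adjacency: adj[u][v][key] -> edge attributes.
# The values are made up for illustration.
adj = {
    0: {1: {0: {"mean_pixel_value": 2.5}, 1: {"mean_pixel_value": 4.0}}},
    1: {2: {0: {"mean_pixel_value": 3.0}}},
}


def sibling_key(adj, u, v, k):
    """Return the key of another edge between u and v, or None."""
    others = [key for key in adj[u][v] if key != k]
    return others[0] if others else None


def is_dimmer(adj, u, v, k):
    """True if edge (u, v, k) has a lower mean value than a parallel edge."""
    k2 = sibling_key(adj, u, v, k)
    if k2 is None:
        return False  # no parallel edge, nothing to compare against
    return adj[u][v][k]["mean_pixel_value"] < adj[u][v][k2]["mean_pixel_value"]


print(is_dimmer(adj, 0, 1, 0))  # the 2.5 edge against its 4.0 sibling
print(is_dimmer(adj, 0, 1, 1))  # the 4.0 edge against its 2.5 sibling
print(is_dimmer(adj, 1, 2, 0))  # a lone edge
```

Because the comparison is strict, two parallel edges with exactly equal mean values would both be kept. Whether that matters depends on the data; a tie-breaker on the edge key is one possible way to handle it.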

```{code-cell} ipython3
skeleton_pruned_iteratively = iteratively_prune_paths(nxskel, discard=unwanted)
```

```{code-cell} ipython3
len(skeleton_pruned_iteratively.nodes())
```

```{code-cell} ipython3
len(skeleton_pruned_iteratively.edges())
```

```{code-cell} ipython3
from skan.csr import nx_to_skeleton
skel_pruned2 = nx_to_skeleton(skeleton_pruned_iteratively)
```

```{code-cell} ipython3
plt.imshow(skeleton_pruned_iteratively.skeleton_image)
```

@@ -100,7 +191,6 @@ that.
If we apply a [Gaussian blur](https://en.wikipedia.org/wiki/Gaussian_blur) to the raw image before skeletonising we
obtain a much clearer mask.


```{code-cell} ipython3
from skimage import filters