Skip to content

Commit

Permalink
Updated jem (#31)
Browse files Browse the repository at this point in the history
Updated filters, features, and cli to include functions to generate feature set used for FCD paper
  • Loading branch information
snydek1 authored Jul 2, 2020
1 parent f5bf6d4 commit 1c282e4
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 3 deletions.
48 changes: 46 additions & 2 deletions jem/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
fef_rotational_invariants,
NUM_SCALES,
NORMALIZATION_SCALE,
gaussian_pyramid_features
)


Expand Down Expand Up @@ -158,7 +159,7 @@ def compute_laplacian_pyramid(input_image, num_scales, normalization_scale, outp
)
@click.option("--lowpass/--no-lowpass", default=False)
def riff(input_image, num_scales, normalization_scale, output, lowpass):
"""Compute rotationally invariant features."""
"""Compute rotationally invariant bandpass features."""

click.echo(f"Compute rotationally invariant features for {input_image}.")

Expand Down Expand Up @@ -210,4 +211,47 @@ def riff(input_image, num_scales, normalization_scale, output, lowpass):
out_im.set_data_dtype(np.float32)
out_im.to_filename(output)

click.echo(f"Wrote rotationally invariant features to {output}.")
click.echo(f"Wrote rotationally invariant bandpass features to {output}.")

@click.command()
@click.option(
"--output",
type=click.STRING,
default="out.nii",
help="Output filename for the features image.",
)
@click.argument("input_image", type=click.STRING)
@click.option(
"--num_scales", type=click.INT, default=NUM_SCALES, help="number of spatial scales"
)
@click.option(
"--normalization_scale",
type=click.INT,
default=NORMALIZATION_SCALE,
help="scale for input gain control",
)
def compute_features(input_image, num_scales, normalization_scale, output):
"""Compute rotationally invariant features."""

click.echo(f"Compute rotationally invariant features for {input_image}.")

# open the images
im = nibabel.load(input_image)
data = im.get_data().astype(np.float32)

# Global scaling, signal likelihood, noise level
f, w, sigma = global_scale(data)

# Local scale normalization
f = local_scale_normalization(
f, w=w, sigma=sigma, normalization_scale=normalization_scale
)

# compute features
feats = gaussian_pyramid_features(f, num_scales=num_scales, w=w, sigma=sigma)
feats = np.stack(feats, axis=-1)

# write out the result in the same format and preserve the header
out_im = type(im)(feats, affine=None, header=im.header)
out_im.set_data_dtype(np.float32)
out_im.to_filename(output)
16 changes: 15 additions & 1 deletion jem/features.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from .filters import _pinv, dog, dog_rotational_invariants
from .filters import _pinv, dog, dog_rotational_invariants, gradient_amplitude, hessian_amplitude, hessian_trace

# Number of spatial scales
NUM_SCALES = 4
Expand Down Expand Up @@ -215,3 +215,17 @@ def fef_rotational_invariants(fef, inplace=True):
rfef["two"][n] = dog_rotational_invariants(fef["two"][n], order=2)

return rfef

def gaussian_pyramid_features(d, w, sigma, num_scales=NUM_SCALES):
"""Rotational invariants of the gaussian pyramid features
"""

gauss = gaussian_pyramid(d, order=0, num_scales=num_scales, w=w, sigma=sigma)
grad = gaussian_pyramid(d, order=1, num_scales=num_scales, w=w, sigma=sigma)
grad = [gradient_amplitude(x) for x in grad]
hess = gaussian_pyramid(d, order=2, num_scales=num_scales, w=w, sigma=sigma)
lap = [hessian_trace(x) for x in hess]
norm = [hessian_amplitude(x) for x in hess]
features = [d]+gauss+grad+lap+norm

return features
27 changes: 27 additions & 0 deletions jem/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,21 @@ def hessian_amplitude(h):
raise RuntimeError("Unsupported number of dimensions {}.".format(len(h)))
return a

def hessian_det(h):
"""
Determinant in the hessian filter band
"""
if len(h) == 3:
det = h[0] * h[2] - h[1] * h[1]
elif len(h) == 6:
det = (
h[0] * (h[3] * h[5] - h[4] * h[4])
- h[1] * (h[1] * h[5] - h[4] * h[2])
+ h[2] * (h[1] * h[4] - h[3] * h[2])
)
else:
raise RuntimeError("Unsupported number of dimensions {}.".format(len(h)))
return det

def hessian_rot(h):
"""
Expand Down Expand Up @@ -494,6 +509,18 @@ def hessian_rot(h):
else:
raise RuntimeError("Unsupported number of dimensions {}.".format(len(h)))

def hessian_trace(h):
"""
Trace in the hessian filter band
Laplacian
"""
if len(h) == 3:
trace = h[0] + h[2]
elif len(h) == 6:
trace = h[0] + h[3] + h[5]
else:
raise RuntimeError("Unsupported number of dimensions {}.".format(len(h)))
return trace

def rotate_gradient_2d(gx, gy, z=0.0):
"""Rotate a 2d gradient or band pass gradient.
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"coil_correction=jem.cli:coil_correction",
"laplacian_pyramid=jem.cli:compute_laplacian_pyramid",
"riff=jem.cli:riff",
"compute_features=jem.cli:compute_features",
]
},
include_package_data=True,
Expand Down

0 comments on commit 1c282e4

Please sign in to comment.