Skip to content

Commit

Permalink
First pass of PR fixes since comeback. Not much left. :)
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentBeaud committed Sep 4, 2024
1 parent 7800eda commit 1005e1d
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 59 deletions.
13 changes: 10 additions & 3 deletions scilpy/image/volume_space_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ class FibertubeDataVolume(DataVolume):
interface for fibertube tracking. Instead of a spherical function,
provides direction and intersection volume of close-by fiber segments.
"""

def __init__(self, centerlines, diameters, mask, voxres, blur_radius,
origin, random_generator):
"""
Expand Down Expand Up @@ -509,7 +510,9 @@ def get_absolute_direction(self, x, y, z):
@staticmethod
@njit
def extract_directions(pos, neighbors, blur_radius, segments_indices,
centerlines, diameters, random_generator):
centerlines, diameters, random_generator,
volume_nb_samples = 1000,
volume_nb_samples_backup = 10000):
directions = []
volumes = []

Expand All @@ -523,13 +526,17 @@ def extract_directions(pos, neighbors, blur_radius, segments_indices,

volume, is_estimated = sphere_cylinder_intersection(
pos, blur_radius, fib_pt1,
fib_pt2, radius, 1000, random_generator)
fib_pt2, radius,
volume_nb_samples,
random_generator)

# Catch estimation error when using very small blur_radius.
if volume == 0 and is_estimated:
volume, _ = sphere_cylinder_intersection(
pos, blur_radius, fib_pt1,
fib_pt2, radius, 10000, random_generator)
fib_pt2, radius,
volume_nb_samples_backup,
random_generator)

if volume > 0:
directions.append(dir / np.linalg.norm(dir))
Expand Down
104 changes: 87 additions & 17 deletions scilpy/tracking/fibertube.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,18 @@ def segment_tractogram(streamlines, verbose=False):
@njit
def rotation_between_vectors_matrix(vec1, vec2):
"""
Rotation matrix that aligns vec1 to vec2. Numba compatible.
Produces a rotation matrix that aligns a 3D vector 'vec1' with another 3D
vector 'vec2'. Numba compatible.
https://math.stackexchange.com/questions/180418/calculate-
rotation-matrix-to-align-vector-a-to-vector-b-in-3d
Parameters
----------
origin : any
destination : any
sft : StatefulTractogram
StatefulTractogram containing the streamlines to segment.
vec1: ndarray
Vector to be rotated
vec2: ndarray
Targeted orientation
Returns
-------
Expand Down Expand Up @@ -95,9 +96,12 @@ def sample_sphere(center, radius: float, amount: int,
Center coordinates of the sphere. Can be [0, 0, 0] if only the
relative displacement interests you.
radius: float
Radius of the sphere.
amount: int
Amount of samples to be produced.
rand_gen: Generator
Numpy random generator used for producing samples within the sphere.
Returns
-------
samples: list
Expand All @@ -123,14 +127,18 @@ def sample_cylinder(center, axis, radius: float, length: float,
Parameters
----------
center: ndarray
Center coordinates of the cylinder
Center coordinates of the cylinder.
axis: ndarray
Center axis of the cylinder, in the form of a vector. Does not have to
be normalized.
radius: float
Radius of the cylinder.
length: float
Length of the cylinder.
sample_count: int
Amount of samples to be produced.
rand_gen: Generator
Numpy random generator used for producing samples within the sphere.
Returns
-------
Expand All @@ -147,14 +155,14 @@ def sample_cylinder(center, axis, radius: float, length: float,
x = random_generator.uniform(-radius, radius)
y = random_generator.uniform(-radius, radius)
z = random_generator.uniform(-half_length, half_length)
sample = np.array([x, y, z], dtype=np.float32)
sample = np.array([x, y, z], dtype=np.float64)

# Rotation
rotation_matrix = np.eye(4, dtype=np.float32)
rotation_matrix = np.eye(4, dtype=np.float64)
rotation_matrix[:3, :3] = rotation_between_vectors_matrix(
reference,
axis).astype(np.float32)
sample = np.dot(rotation_matrix, np.append(sample, np.float32(1.)))[:3]
sample = np.dot(rotation_matrix, np.append(sample, 1.))[:3]

# Translation
sample += center
Expand Down Expand Up @@ -219,7 +227,7 @@ def sphere_cylinder_intersection(sph_p, sph_r: float, cyl_p1, cyl_p2,
return cyl_volume, False

# If cylinder is completely outside the sphere.
_, vector, _, _ = dist_point_segment(cyl_p1, cyl_p2, sph_p)
_, vector, _ = dist_point_segment(cyl_p1, cyl_p2, sph_p)
if np.linalg.norm(vector) >= sph_r + cyl_r:
return 0, False

Expand All @@ -240,8 +248,23 @@ def sphere_cylinder_intersection(sph_p, sph_r: float, cyl_p1, cyl_p2,


@njit
def create_perpendicular(v):
def create_perpendicular(v: np.ndarray):
"""
Generates a vector perpendicular to v.
Parameters
----------
v: ndarray
Vector from which a perpendicular vector will be generated.
Returns
-------
vp: ndarray
Vector perpendicular to v.
"""
vp = np.array([0., 0., 0.])
if v.all() == vp.all():
return vp
for m in range(3):
if v[m] == 0.:
continue
Expand All @@ -253,13 +276,60 @@ def create_perpendicular(v):


@njit
def dist_point_segment(P0, P1, Q):
return dist_segment_segment(P0, P1, Q, Q)
def dist_point_segment(p0, p1, q):
"""
Calculates the shortest distance between a 3D point q and a segment p0-p1.
Parameters
----------
p0: ndarray
Point forming the first end of the segment.
p1: ndarray
Point forming the second end of the segment.
q: ndarray
Point coordinates.
Returns
-------
distance: float
Shortest distance between the two segments
v: ndarray
Vector representing the distance between the two segments.
v = Ps - q and |v| = distance
Ps: ndarray
Point coordinates on segment P that is closest to point q
"""
return dist_segment_segment(p0, p1, q, q)[:3]


@njit
def dist_segment_segment(P0, P1, Q0, Q1):
"""
Calculates the shortest distance between two 3D segments P0-P1 and Q0-Q1.
Parameters
----------
P0: ndarray
Point forming the first end of the P segment.
P1: ndarray
Point forming the second end of the P segment.
Q0: ndarray
Point forming the first end of the Q segment.
Q1: ndarray
Point forming the second end of the Q segment.
Returns
-------
distance: float
Shortest distance between the two segments
v: ndarray
Vector representing the distance between the two segments.
v = Ps - Qt and |v| = distance
Ps: ndarray
Point coordinates on segment P that is closest to segment Q
Qt: ndarray
Point coordinates on segment Q that is closest to segment P
This function is a python version of the following code:
https://www.geometrictools.com/GTE/Mathematics/DistSegmentSegment.h
Expand Down Expand Up @@ -415,7 +485,7 @@ def dist_segment_segment(P0, P1, Q0, Q1):

Ps = P0 + s * P1mP0
Qt = Q0 + t * Q1mQ0
diff = Ps - Qt
sqr_distance = np.dot(diff, diff)
v = Ps - Qt
sqr_distance = np.dot(v, v)
distance = sqrt(sqr_distance)
return (distance, diff, Ps, Qt)
return (distance, v, Ps, Qt)
29 changes: 9 additions & 20 deletions scilpy/tracking/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,23 +578,23 @@ def _get_possible_next_dirs_det(self, pos, previous_direction):

class FibertubePropagator(AbstractPropagator):
"""
Implementation of the scilpy.tracking.propagator.AbstractPropagator
interface for fibertube tracking.
Simplified propagator for using fibertube data. It is probabilistic and
uses the volume of intersection between fibertube segments and the
blurring sphere as a random distribution for picking a segment. This
segment is then used as the propagation direction.
"""
def __init__(self, datavolume: FibertubeDataVolume, step_size, rk_order,
algo, theta, space, origin):
theta, space, origin):
""""
Parameters
----------
datavolume: scilpy.image.volume_space_management.FibertubeDataVolume
datavolume: FibertubeDataVolume
Trackable fibertube dataset object.
step_size: float
The step size for tracking. Important: step size should be in the
same units as the space of the tracking!
rk_order: int
Order for the Runge Kutta integration.
algo: string
Type of algorithm. Choices are 'det' or 'prob'
theta: float
Maximum angle (radians) between two steps.
space: dipy Space
Expand All @@ -616,7 +616,6 @@ def __init__(self, datavolume: FibertubeDataVolume, step_size, rk_order,
self.datavolume = datavolume
self.step_size = step_size
self.rk_order = rk_order
self.algo = algo
self.theta = theta
self.space = space
self.origin = origin
Expand Down Expand Up @@ -648,20 +647,14 @@ def propagate(self, line, v_in):
return super().propagate(line, v_in)

def _sample_next_direction(self, pos, v_in):
if self.algo == 'prob':
directions, volumes = self._get_possible_next_dirs(pos, v_in)

# Sampling one.
if np.sum(volumes) > 0:
v_out = directions[
sample_distribution(volumes, self.line_rng_generator)]
else:
return None
else:
raise ValueError("Tracking algorithm must be 'prob' for " +
" fibertube tracking.")

return v_out
return v_out
return None

def _get_possible_next_dirs(self, pos, v_in):
directions, volumes = (
Expand All @@ -682,11 +675,7 @@ def _get_possible_next_dirs(self, pos, v_in):
cosine = abs(cosine)
dir = -dir

# clip float error to bounds
if cosine > 1:
cosine = 1
if cosine < -1:
cosine = -1
cosine = np.clip(cosine, -1, 1)

if (np.arccos(cosine) > self.theta):
continue
Expand Down
Loading

0 comments on commit 1005e1d

Please sign in to comment.