Skip to content

Commit

Permalink
Merge pull request dipy#3211 from skoudoro/pmf-output
Browse files Browse the repository at this point in the history
[RF] PMF Gen: from memoryview to pointer
  • Loading branch information
Garyfallidis authored May 7, 2024
2 parents fe5dd0c + be8c4e7 commit eec0ad8
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 68 deletions.
10 changes: 2 additions & 8 deletions dipy/direction/closest_peak_direction_getter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,8 @@ cdef class BasePmfDirectionGetter(DirectionGetter):
self,
double[::1] point)

cdef double[:] _get_pmf(
self,
double[::1] point) nogil

cdef int get_direction_c(
self,
double[::1] point,
double[::1] direction)
cdef double* _get_pmf(self, double[::1] point) nogil
cdef int get_direction_c(self, double[::1] point, double[::1] direction)


cdef class BaseDirectionGetter(BasePmfDirectionGetter):
Expand Down
15 changes: 8 additions & 7 deletions dipy/direction/closest_peak_direction_getter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,19 @@ cdef class BasePmfDirectionGetter(DirectionGetter):
directions should be unique.
"""
cdef double[:] pmf = self._get_pmf(point)
return self._get_peak_directions(pmf)
cdef double* pmf = self._get_pmf(point)
return self._get_peak_directions(<double[:self.len_pmf]>pmf)

cdef double[:] _get_pmf(self, double[::1] point) nogil:
cdef double* _get_pmf(self, double[::1] point) nogil:
cdef:
cnp.npy_intp i
cnp.npy_intp _len = self.len_pmf
double[:] pmf
double* pmf = &self.pmf_gen.pmf[0]
double pmf_threshold=self.pmf_threshold
double absolute_pmf_threshold
double max_pmf=0

pmf = self.pmf_gen.get_pmf_c(point)
pmf = self.pmf_gen.get_pmf_c(&point[0], pmf)
for i in range(_len):
if pmf[i] > max_pmf:
max_pmf = pmf[i]
Expand Down Expand Up @@ -236,12 +236,13 @@ cdef class ClosestPeakDirectionGetter(PmfGenDirectionGetter):
"""
cdef:
cnp.npy_intp _len = self.len_pmf
double[:] pmf
double* pmf
cnp.ndarray[cnp.float_t, ndim=2] peaks

pmf = self._get_pmf(point)

peaks = self._get_peak_directions(pmf)
peaks = self._get_peak_directions(<double[:_len]>pmf)
if len(peaks) == 0:
return 1

return closest_peak(peaks, direction, self.cos_similarity)
6 changes: 3 additions & 3 deletions dipy/direction/pmf.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ cdef class PmfGen:
double[:, :] vertices
object sphere

cdef double[:] get_pmf_c(self, double[::1] point) noexcept nogil
cdef double* get_pmf_c(self, double* point, double* out) noexcept nogil
cdef int find_closest(self, double* xyz) noexcept nogil
cdef double get_pmf_value_c(self, double[::1] point, double[::1] xyz) noexcept nogil
cdef void __clear_pmf(self) noexcept nogil
cdef double get_pmf_value_c(self, double* point, double* xyz,
double* out) noexcept nogil
pass


Expand Down
76 changes: 39 additions & 37 deletions dipy/direction/pmf.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ cimport numpy as cnp
from dipy.reconst import shm

from dipy.core.interpolation cimport trilinear_interpolate4d_c
from libc.stdlib cimport malloc, free

cdef extern from "stdlib.h" nogil:
void *memset(void *ptr, int value, size_t num)


cdef class PmfGen:
Expand All @@ -18,11 +22,14 @@ cdef class PmfGen:
self.data = np.asarray(data, dtype=float, order='C')
self.sphere = sphere
self.vertices = np.asarray(sphere.vertices, dtype=float)
self.pmf = np.zeros(self.vertices.shape[0])

def get_pmf(self, double[::1] point):
return self.get_pmf_c(point)
def get_pmf(self, double[::1] point, double[:] out=None):
if out is None:
out = self.pmf
return <double[:len(self.vertices)]>self.get_pmf_c(&point[0], &out[0])

cdef double[:] get_pmf_c(self, double[::1] point) noexcept nogil:
cdef double* get_pmf_c(self, double* point, double* out) noexcept nogil:
pass

cdef int find_closest(self, double* xyz) noexcept nogil:
Expand All @@ -44,24 +51,20 @@ cdef class PmfGen:
idx = i
return idx

def get_pmf_value(self, double[::1] point, double[::1] xyz):
return self.get_pmf_value_c(point, xyz)
def get_pmf_value(self, double[::1] point, double[::1] xyz,
double[:] pmf_buffer=None):
if pmf_buffer is None:
pmf_buffer = self.pmf
return self.get_pmf_value_c(&point[0], &xyz[0], &pmf_buffer[0])

cdef double get_pmf_value_c(self, double[::1] point, double[::1] xyz) noexcept nogil:
cdef double get_pmf_value_c(self, double* point, double* xyz,
double* pmf_buffer) noexcept nogil:
"""
Return the pmf value corresponding to the closest vertex to the
direction xyz.
"""
cdef int idx = self.find_closest(&xyz[0])
return self.get_pmf_c(point)[idx]

cdef void __clear_pmf(self) noexcept nogil:
cdef:
cnp.npy_intp len_pmf = self.pmf.shape[0]
cnp.npy_intp i

for i in range(len_pmf):
self.pmf[i] = 0.0
cdef int idx = self.find_closest(xyz)
return self.get_pmf_c(point, pmf_buffer)[idx]


cdef class SimplePmfGen(PmfGen):
Expand All @@ -70,33 +73,33 @@ cdef class SimplePmfGen(PmfGen):
double[:, :, :, :] pmf_array,
object sphere):
PmfGen.__init__(self, pmf_array, sphere)
self.pmf = np.empty(pmf_array.shape[3])
if np.min(pmf_array) < 0:
raise ValueError("pmf should not have negative values.")
if not pmf_array.shape[3] == sphere.vertices.shape[0]:
raise ValueError("pmf should have the same number of values as the"
+ " number of vertices of sphere.")

cdef double[:] get_pmf_c(self, double[::1] point) noexcept nogil:
if trilinear_interpolate4d_c(self.data, &point[0], &self.pmf[0]) != 0:
PmfGen.__clear_pmf(self)
return self.pmf
cdef double* get_pmf_c(self, double* point, double* out) noexcept nogil:
if trilinear_interpolate4d_c(self.data, point, out) != 0:
memset(out, 0, self.pmf.shape[0] * sizeof(double))
return out

cdef double get_pmf_value_c(self, double[::1] point, double[::1] xyz) noexcept nogil:
cdef double get_pmf_value_c(self, double* point, double* xyz,
double* pmf_buffer) noexcept nogil:
"""
Return the pmf value corresponding to the closest vertex to the
direction xyz.
"""
cdef:
int idx

idx = self.find_closest(&xyz[0])
idx = self.find_closest(xyz)

if trilinear_interpolate4d_c(self.data[:,:,:,idx:idx+1],
&point[0],
&self.pmf[0]) != 0:
PmfGen.__clear_pmf(self)
return self.pmf[0]
point,
pmf_buffer) != 0:
memset(pmf_buffer, 0, self.pmf.shape[0] * sizeof(double))
return pmf_buffer[0]


cdef class SHCoeffPmfGen(PmfGen):
Expand All @@ -117,24 +120,23 @@ cdef class SHCoeffPmfGen(PmfGen):
except KeyError:
raise ValueError("%s is not a known basis type." % basis_type)
self.B, _, _ = basis(sh_order, sphere.theta, sphere.phi, legacy=legacy)
self.coeff = np.empty(shcoeff_array.shape[3])
self.pmf = np.empty(self.B.shape[0])

cdef double[:] get_pmf_c(self, double[::1] point) noexcept nogil:
cdef double* get_pmf_c(self, double* point, double* out) noexcept nogil:
cdef:
cnp.npy_intp i, j
cnp.npy_intp len_pmf = self.pmf.shape[0]
cnp.npy_intp len_B = self.B.shape[1]
double _sum
# TODO: Maybe a better to do this
double *coeff = <double*>malloc(self.data.shape[3] * sizeof(double))

if trilinear_interpolate4d_c(self.data,
&point[0],
&self.coeff[0]) != 0:
PmfGen.__clear_pmf(self)
if trilinear_interpolate4d_c(self.data, point, coeff) != 0:
memset(out, 0, len_pmf * sizeof(double))
else:
for i in range(len_pmf):
_sum = 0
for j in range(len_B):
_sum = _sum + (self.B[i, j] * self.coeff[j])
self.pmf[i] = _sum
return self.pmf
_sum = _sum + (self.B[i, j] * coeff[j])
out[i] = _sum
free(coeff)
return out
8 changes: 4 additions & 4 deletions dipy/direction/probabilistic_direction_getter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ cdef class ProbabilisticDirectionGetter(PmfGenDirectionGetter):
cdef:
cnp.npy_intp i, idx, _len
double[:] newdir
double[:] pmf
double* pmf
double last_cdf, cos_sim

_len = self.len_pmf
Expand All @@ -103,12 +103,12 @@ cdef class ProbabilisticDirectionGetter(PmfGenDirectionGetter):
if cos_sim < self.cos_similarity:
pmf[i] = 0

cumsum(&pmf[0], &pmf[0], _len)
cumsum(pmf, pmf, _len)
last_cdf = pmf[_len - 1]
if last_cdf == 0:
return 1

idx = where_to_insert(&pmf[0], random() * last_cdf, _len)
idx = where_to_insert(pmf, random() * last_cdf, _len)

newdir = self.vertices[idx]
# Update direction and return 0 for error
Expand Down Expand Up @@ -151,7 +151,7 @@ cdef class DeterministicMaximumDirectionGetter(ProbabilisticDirectionGetter):
cdef:
cnp.npy_intp _len, max_idx
double[:] newdir
double[:] pmf
double* pmf
double max_value, cos_sim

pmf = self._get_pmf(point)
Expand Down
14 changes: 10 additions & 4 deletions dipy/direction/ptt_direction_getter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Aydogan DB, Shi Y. Parallel Transport Tractography. IEEE Trans

cimport numpy as cnp
from libc.math cimport M_PI, pow, sin, cos, fabs
from libc.stdlib cimport malloc, free

from dipy.direction.probabilistic_direction_getter cimport \
ProbabilisticDirectionGetter
Expand Down Expand Up @@ -161,7 +162,8 @@ cdef class PTTDirectionGetter(ProbabilisticDirectionGetter):

if self.probe_count == 1:
self.last_val = self.pmf_gen.get_pmf_value_c(self.position,
self.frame[0])
self.frame[0],
&self.pmf_gen.pmf[0])
else:
for count in range(self.probe_count):
for i in range(3):
Expand All @@ -175,7 +177,8 @@ cdef class PTTDirectionGetter(ProbabilisticDirectionGetter):
* self.inv_voxel_size[i])

self.last_val += self.pmf_gen.get_pmf_value_c(position,
self.frame[0])
self.frame[0],
&self.pmf_gen.pmf[0])


cdef void prepare_propagator(self, double arclength) nogil:
Expand Down Expand Up @@ -269,7 +272,8 @@ cdef class PTTDirectionGetter(ProbabilisticDirectionGetter):
copy_point(&binormal[0], &frame[2][0])

if self.probe_count == 1:
fod_amp = self.pmf_gen.get_pmf_value_c(position, tangent)
fod_amp = self.pmf_gen.get_pmf_value_c(position, tangent,
&self.pmf_gen.pmf[0])
fod_amp = fod_amp if fod_amp > self.pmf_threshold else 0
self.last_val_cand = fod_amp
likelihood += self.last_val_cand
Expand All @@ -291,7 +295,8 @@ cdef class PTTDirectionGetter(ProbabilisticDirectionGetter):
+ binormal[i] * self.probe_radius
* sin(c * self.angular_separation)
* self.inv_voxel_size[i])
fod_amp = self.pmf_gen.get_pmf_value_c(new_position, tangent)
fod_amp = self.pmf_gen.get_pmf_value_c(new_position, tangent,
&self.pmf_gen.pmf[0])
fod_amp = fod_amp if fod_amp > self.pmf_threshold else 0
self.last_val_cand += fod_amp

Expand Down Expand Up @@ -344,6 +349,7 @@ cdef class PTTDirectionGetter(ProbabilisticDirectionGetter):
if (random() * max_posterior <= self.calculate_data_support()):
self.last_val = self.last_val_cand
return 0

return 1

cdef int propagate(self):
Expand Down
18 changes: 13 additions & 5 deletions dipy/direction/tests/test_pmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
import numpy as np
import numpy.testing as npt

from dipy.core.gradients import gradient_table
from dipy.core.sphere import HemiSphere, unit_octahedron
from dipy.data import default_sphere, get_sphere
from dipy.direction.pmf import SHCoeffPmfGen, SimplePmfGen
from dipy.reconst import shm
from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel
from dipy.reconst.dti import TensorModel
from dipy.testing.decorators import set_random_number_generator

response = (np.array([1.5e3, 0.3e3, 0.3e3]), 1)
Expand All @@ -25,8 +22,13 @@ def test_pmf_val(rng):
pmfgen = SHCoeffPmfGen(rng.random([2, 2, 2, 28]), sphere, None)
point = np.array([1, 1, 1], dtype='float')

out = np.ones(len(sphere.vertices))
for idx in [0, 5, 15, -1]:
pmf = pmfgen.get_pmf(point)
pmf_2 = pmfgen.get_pmf(point, out)

npt.assert_array_almost_equal(pmf, out)
npt.assert_array_almost_equal(pmf, pmf_2)
# Create a direction vector close to the vertex idx
xyz = sphere.vertices[idx] + rng.random([3]) / 100
pmf_idx = pmfgen.get_pmf_value(point, xyz)
Expand All @@ -42,9 +44,12 @@ def test_pmf_from_sh():
category=PendingDeprecationWarning)
pmfgen = SHCoeffPmfGen(np.ones([2, 2, 2, 28]), sphere, None)

out = np.zeros(len(sphere.vertices))
# Test that the pmf is greater than 0 for a valid point
pmf = pmfgen.get_pmf(np.array([0, 0, 0], dtype='float'))
out = pmfgen.get_pmf(np.array([0, 0, 0], dtype='float'), out)
npt.assert_equal(np.sum(pmf) > 0, True)
npt.assert_array_almost_equal(pmf, out)

# Test that the pmf is 0 for invalid Points
npt.assert_array_equal(pmfgen.get_pmf(np.array([-1, 0, 0], dtype='float')),
Expand All @@ -57,14 +62,17 @@ def test_pmf_from_array():
sphere = HemiSphere.from_sphere(unit_octahedron)
pmfgen = SimplePmfGen(np.ones([2, 2, 2, len(sphere.vertices)]), sphere)

out = np.zeros(len(sphere.vertices))
# Test that the pmf is greater than 0 for a valid point
pmf = pmfgen.get_pmf(np.array([0, 0, 0], dtype='float'))
out = pmfgen.get_pmf(np.array([0, 0, 0], dtype='float'), out)
npt.assert_equal(np.sum(pmf) > 0, True)
npt.assert_array_almost_equal(pmf, out)

# Test that the pmf is 0 for invalid Points
npt.assert_array_equal(pmfgen.get_pmf(np.array([-1, 0, 0], dtype='float')),
npt.assert_array_equal(pmfgen.get_pmf(np.array([-1, 0, 0], dtype=float)),
np.zeros(len(sphere.vertices)))
npt.assert_array_equal(pmfgen.get_pmf(np.array([0, 0, 10], dtype='float')),
npt.assert_array_equal(pmfgen.get_pmf(np.array([0, 0, 10], dtype=float)),
np.zeros(len(sphere.vertices)))

# Test ValueError for negative pmf
Expand Down
9 changes: 9 additions & 0 deletions dipy/utils/fast_numpy.pyx
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# cython: boundscheck=False
# cython: initializedcheck=False
# cython: wraparound=False
from libc.stdio cimport printf


cdef int where_to_insert(cnp.float_t* arr, cnp.float_t number, int size) noexcept nogil:
cdef:
Expand Down Expand Up @@ -194,3 +196,10 @@ cpdef void seed(cnp.npy_uint32 s) noexcept nogil:
random seed.
"""
srand(s)


cdef void print_c_array_pointer(double* arr, int size) noexcept nogil:
cdef int i
for i in range(size):
printf("%f, ", arr[i])
printf("\n\n\n")

0 comments on commit eec0ad8

Please sign in to comment.