Add connectivity_from_labels in SFTData
EmmaRenauld committed Feb 26, 2024
1 parent 96d5744 commit 6dd4a80
Showing 4 changed files with 80 additions and 40 deletions.
75 changes: 52 additions & 23 deletions dwi_ml/data/dataset/streamline_containers.py
@@ -46,6 +46,27 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group):
return streamlines


def _load_connectivity_info(hdf_group: h5py.Group):
connectivity_nb_blocs = None
connectivity_labels = None
if 'connectivity_matrix' in hdf_group:
contains_connectivity = True
if 'connectivity_nb_blocs' in hdf_group.attrs:
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']
elif 'connectivity_labels' in hdf_group:
connectivity_labels = np.asarray(
hdf_group['connectivity_labels'], dtype=int)
else:
raise ValueError(
"The hdf5 group contains a connectivity matrix, but no "
"information on how it was created. Either "
"'connectivity_nb_blocs' or 'connectivity_labels' should "
"be set.")
else:
contains_connectivity = False
return contains_connectivity, connectivity_nb_blocs, connectivity_labels
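For illustration, a minimal sketch (hypothetical file and dataset names, not from this commit) of an hdf5 group that this loader can parse; exactly one of the two keys should be present:

import h5py
import numpy as np

with h5py.File('subject.hdf5', 'a') as f:
    group = f.create_group('streamlines')
    group.create_dataset('connectivity_matrix',
                         data=np.zeros((8, 8), dtype=int))
    # Option A: the matrix was built from blocs.
    group.attrs['connectivity_nb_blocs'] = [2, 2, 2]
    # Option B (instead of A): the matrix was built from a labels volume.
    # group.create_dataset('connectivity_labels', data=labels_volume)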


class _LazyStreamlinesGetter(object):
def __init__(self, hdf_group):
self.hdf_group = hdf_group
@@ -141,27 +162,38 @@ class SFTDataAbstract(object):
"""
def __init__(self, space_attributes: Tuple, space: Space, origin: Origin,
contains_connectivity: bool,
connectivity_nb_blocs: List):
connectivity_nb_blocs: List = None,
connectivity_labels: np.ndarray = None):
"""
Params
------
group: str
The current streamlines group id, as loaded in the hdf5 file (it
had type "streamlines"). Probabaly 'streamlines'.
The lazy/non-lazy versions will have more parameters, such as the
streamlines, the connectivity_matrix. In the case of the lazy version,
through the LazyStreamlinesGetter.
Parameters
----------
space_attributes: Tuple
The space attributes consist of a tuple:
(affine, dimensions, voxel_sizes, voxel_order)
space: Space
The space from dipy's Space format.
subject_id: str
The subject's name
origin: Origin
The origin from dipy's Origin format.
contains_connectivity: bool
If true, will expect either the connectivity_nb_blocs or the
connectivity_labels information.
connectivity_nb_blocs: List
The number of blocs per dimension used to create the connectivity
matrix.
connectivity_labels: np.ndarray
The 3D volume of labels used to create the connectivity matrix.
(ToDo: could be managed lazily)
"""
self.space_attributes = space_attributes
self.space = space
self.origin = origin
self.is_lazy = None
self.contains_connectivity = contains_connectivity
self.connectivity_nb_blocs = connectivity_nb_blocs
self.connectivity_labels = connectivity_labels

def __len__(self):
raise NotImplementedError
@@ -195,7 +227,7 @@ def get_connectivity_matrix_and_info(self, ind=None):
(_, ref_volume_shape, _, _) = self.space_attributes

return (self._access_connectivity_matrix(ind), ref_volume_shape,
self.connectivity_nb_blocs)
self.connectivity_nb_blocs, self.connectivity_labels)
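Callers now unpack four values instead of three. A hedged sketch, assuming subj_sft_data is an instance of a concrete subclass:

matrix, volume_shape, nb_blocs, labels = \
    subj_sft_data.get_connectivity_matrix_and_info()
# Exactly one of nb_blocs / labels is expected to be set; the other is None.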

def _access_connectivity_matrix(self, ind):
raise NotImplementedError
@@ -277,15 +309,14 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
streamlines = _load_all_streamlines_from_hdf(hdf_group)
# Adding non-hidden parameters for nicer later access
lengths_mm = hdf_group['euclidean_lengths']
if 'connectivity_matrix' in hdf_group:
contains_connectivity = True
connectivity_matrix = np.asarray(hdf_group['connectivity_matrix'],
dtype=int)
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']

contains_connectivity, connectivity_nb_blocs, connectivity_labels = \
_load_connectivity_info(hdf_group)
if contains_connectivity:
connectivity_matrix = np.asarray(
hdf_group['connectivity_matrix'], dtype=int) # int or bool?
else:
contains_connectivity = False
connectivity_matrix = None
connectivity_nb_blocs = None

space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group)

@@ -296,7 +327,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space_attributes=space_attributes,
space=space, origin=origin,
contains_connectivity=contains_connectivity,
connectivity_nb_blocs=connectivity_nb_blocs)
connectivity_nb_blocs=connectivity_nb_blocs,
connectivity_labels=connectivity_labels)

def _get_streamlines_as_list(self, streamline_ids):
if streamline_ids is not None:
@@ -337,12 +369,9 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None):
@classmethod
def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group)
if 'connectivity_matrix' in hdf_group:
contains_connectivity = True
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']
else:
contains_connectivity = False
connectivity_nb_blocs = None

contains_connectivity, connectivity_nb_blocs, connectivity_labels = \
_load_connectivity_info(hdf_group)

streamlines = _LazyStreamlinesGetter(hdf_group)

12 changes: 5 additions & 7 deletions dwi_ml/data/processing/streamlines/post_processing.py
@@ -330,6 +330,10 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels,
Else, shape (nb_labels, nb_labels)
labels: List
The list of labels
start_labels: List
For each streamline, the label at its starting point.
end_labels: List
For each streamline, the label at its ending point.
"""
real_labels = list(np.sort(np.unique(data_labels)))
nb_labels = len(real_labels)
@@ -388,8 +392,7 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels,
return matrix, real_labels, start_labels, end_labels
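A toy usage sketch (synthetic data; assumes streamline coordinates are in voxel space and that the remaining keyword arguments keep their defaults):

import numpy as np

data_labels = np.ones((4, 4, 4), dtype=int)  # two labels, no background
data_labels[2:, :, :] = 2
streamlines = [np.array([[0.5, 1.0, 1.0], [3.2, 1.0, 1.0]]),  # 1 -> 2
               np.array([[0.5, 2.0, 2.0], [1.0, 2.0, 2.0]])]  # stays in 1
matrix, real_labels, start_labels, end_labels = \
    compute_triu_connectivity_from_labels(streamlines, data_labels)
# real_labels == [1, 2]; each streamline adds one count to the
# upper-triangular matrix, so matrix.sum() == len(streamlines).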


def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs,
binary: bool = False):
def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs):
"""
Compute a connectivity matrix.
@@ -405,8 +408,6 @@ def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs,
In 3D, with 20x20x20 blocs, this is an 8000 x 8000 (triangular)
matrix. It probably contains a lot of zeros, since the background
is included. It can be saved as sparse.
binary: bool
If true, return a binary matrix.
"""
nb_blocs = np.asarray(nb_blocs)
start_block, end_block = _compute_origin_finish_blocs(
@@ -425,9 +426,6 @@ def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs,
matrix = np.triu(matrix)
assert matrix.sum() == len(streamlines)

if binary:
matrix = matrix.astype(bool)

return matrix, start_block, end_block
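A matching sketch for the blocs variant (hypothetical sizes): nb_blocs divides the volume into a coarse grid, and each streamline adds one count to its (start bloc, end bloc) cell:

volume_size = [120, 120, 60]  # in voxels
nb_blocs = [6, 6, 6]          # 216 blocs -> 216 x 216 (triangular) matrix
matrix, start_blocks, end_blocks = compute_triu_connectivity_from_blocs(
    streamlines, volume_size, nb_blocs)
assert matrix.sum() == len(streamlines)  # as asserted in the function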


7 changes: 5 additions & 2 deletions dwi_ml/training/with_generation/batch_loader.py
@@ -24,6 +24,7 @@ def load_batch_connectivity_matrices(
matrices = [None] * nb_subjs
volume_sizes = [None] * nb_subjs
connectivity_nb_blocs = [None] * nb_subjs
connectivity_labels = [None] * nb_subjs
for i, subj in enumerate(subjs):
# No cache for the sft data. Accessing it directly.
# Note: If this is used through the dataloader, multiprocessing
@@ -34,7 +35,9 @@

# We could maybe access it only at the required index. Loading the
# whole matrix here.
matrices[i], volume_sizes[i], connectivity_nb_blocs[i] = \
(matrices[i], volume_sizes[i],
connectivity_nb_blocs[i], connectivity_labels[i]) = \
subj_sft_data.get_connectivity_matrix_and_info()

return matrices, volume_sizes, connectivity_nb_blocs
return (matrices, volume_sizes,
connectivity_nb_blocs, connectivity_labels)
26 changes: 18 additions & 8 deletions dwi_ml/training/with_generation/trainer.py
@@ -43,7 +43,7 @@
from torch.nn import PairwiseDistance

from dwi_ml.data.processing.streamlines.post_processing import \
compute_triu_connectivity_from_blocs
compute_triu_connectivity_from_blocs, compute_triu_connectivity_from_labels
from dwi_ml.models.main_models import ModelWithDirectionGetter
from dwi_ml.tracking.propagation import propagate_multiple_lines
from dwi_ml.tracking.io_utils import prepare_tracking_mask
@@ -317,28 +317,38 @@ def _compare_connectivity(self, lines, ids_per_subj):
compares with expected values for the subject.
"""
if self.compute_connectivity:
connectivity_matrices, volume_sizes, connectivity_nb_blocs = \
(connectivity_matrices, volume_sizes,
connectivity_nb_blocs, connectivity_labels) = \
self.batch_loader.load_batch_connectivity_matrices(ids_per_subj)

score = 0.0
for i, subj in enumerate(ids_per_subj.keys()):
real_matrix = connectivity_matrices[i]
volume_size = volume_sizes[i]
nb_blocs = connectivity_nb_blocs[i]
labels = connectivity_labels[i]
_lines = lines[ids_per_subj[subj]]

batch_matrix, _, _ = compute_triu_connectivity_from_blocs(
_lines, volume_size, nb_blocs, binary=False)
# Reference matrices are saved as binary in create_hdf5, but
# ensuring anyway.
real_matrix = real_matrix > 0

# But our batch matrix here won't be binary!
if nb_blocs is not None:
batch_matrix, _, _ = compute_triu_connectivity_from_blocs(
_lines, volume_size, nb_blocs)
else:
# ToDo. Allow use_scilpy?
batch_matrix, _, _ = compute_triu_connectivity_from_labels(
_lines, labels, use_scilpy=False)

# Where our batch has a 0: not important, maybe it was simply
# not in this batch.
# Where our batch has a 1, if there was really a one: score
# should be 0. = 1 - 1.
# Else, score should be high (1). = 1 - 0.
# should be 0. = 1 - 1 = 1 - real
# Else, score should be high (1). = 1 - 0 = 1 - real
# If two streamlines have the same connection, score is
# either 0 or 2 for that voxel. ==> nb * (1 - real).

# Reference matrices are saved as binary in create_hdf5.
where_one = np.where(batch_matrix > 0)
score += np.sum(batch_matrix[where_one] *
(1.0 - real_matrix[where_one]))
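A numeric sketch of this scoring rule (toy matrices, illustration only): only connections found in the batch but absent from the reference are penalized, weighted by their streamline count:

import numpy as np

batch_matrix = np.array([[0, 3],    # 3 streamlines on a real connection
                         [0, 2]])   # 2 streamlines on a false one
real_matrix = np.array([[0, 1],
                        [0, 0]]) > 0
where_one = np.where(batch_matrix > 0)
score = np.sum(batch_matrix[where_one] * (1.0 - real_matrix[where_one]))
# score == 3 * (1 - 1) + 2 * (1 - 0) == 2.0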
