Skip to content

Commit

Permalink
Merge pull request #1103 from 36000/FIX_FIND_FILE
Browse files Browse the repository at this point in the history
[ENH/FIX] Smarter file finding
  • Loading branch information
36000 authored Feb 7, 2024
2 parents 0b25b02 + e30cb73 commit 78f4c3e
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 62 deletions.
42 changes: 16 additions & 26 deletions AFQ/definitions/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
from dipy.segment.mask import median_otsu
from dipy.align import resample

import AFQ.utils.volume as auv
from AFQ.definitions.utils import Definition, find_file, name_from_path

from skimage.morphology import convex_hull_image, binary_opening
from scipy.linalg import blas

__all__ = [
"ImageFile", "FullImage", "RoiImage", "B0Image", "LabelledImageFile",
Expand Down Expand Up @@ -128,16 +126,20 @@ def __init__(self, path=None, suffix=None, filters={}):
self.filters = filters
self.fnames = {}

def find_path(self, bids_layout, from_path, subject, session):
def find_path(self, bids_layout, from_path,
subject, session, required=True):
if self._from_path:
return
if session not in self.fnames:
self.fnames[session] = {}

nearest_image = find_file(
bids_layout, from_path, self.filters, self.suffix, session,
subject)
subject, required=required)

if nearest_image is None:
return False

if session not in self.fnames:
self.fnames[session] = {}
self.fnames[session][subject] = nearest_image

def get_path_data_affine(self, bids_info):
Expand Down Expand Up @@ -196,9 +198,6 @@ class FullImage(ImageDefinition):
def __init__(self):
pass

def find_path(self, bids_layout, from_path, subject, session):
pass

def get_name(self):
return "entire_volume"

Expand Down Expand Up @@ -269,9 +268,6 @@ def __init__(self,
"One of use_waypoints, use_presegment, "
"use_endpoints, must be True"))

def find_path(self, bids_layout, from_path, subject, session):
pass

def get_name(self):
return "roi"

Expand Down Expand Up @@ -364,9 +360,6 @@ class GQImage(ImageDefinition):
def __init__(self):
pass

def find_path(self, bids_layout, from_path, subject, session):
pass

def get_name(self):
return "GQ"

Expand Down Expand Up @@ -410,9 +403,6 @@ class B0Image(ImageDefinition):
def __init__(self, median_otsu_kwargs={}):
self.median_otsu_kwargs = median_otsu_kwargs

def find_path(self, bids_layout, from_path, subject, session):
pass

def get_name(self):
return "b0"

Expand Down Expand Up @@ -616,9 +606,6 @@ class ScalarImage(ImageDefinition):
def __init__(self, scalar):
self.scalar = scalar

def find_path(self, bids_layout, from_path, subject, session):
pass

def get_name(self):
return self.scalar

Expand Down Expand Up @@ -705,9 +692,15 @@ class PFTImage(ImageDefinition):
def __init__(self, WM_probseg, GM_probseg, CSF_probseg):
self.probsegs = (WM_probseg, GM_probseg, CSF_probseg)

def find_path(self, bids_layout, from_path, subject, session):
def find_path(self, bids_layout, from_path,
subject, session, required=True):
if required == False:
raise ValueError(
"PFTImage cannot be used in this context")
for probseg in self.probsegs:
probseg.find_path(bids_layout, from_path, subject, session)
probseg.find_path(
bids_layout, from_path, subject, session,
required=required)

def get_name(self):
return "pft"
Expand Down Expand Up @@ -739,9 +732,6 @@ class TemplateImage(ImageDefinition):
def __init__(self, path):
self.path = path

def find_path(self, bids_layout, from_path, subject, session):
pass

def get_name(self):
return name_from_path(self.path)

Expand Down
19 changes: 4 additions & 15 deletions AFQ/definitions/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,20 @@ def __init__(self, warp_path=None, space_path=None,
self.space_filters = space_filters
self.fnames = {}

def find_path(self, bids_layout, from_path, subject, session):
def find_path(self, bids_layout, from_path,
subject, session, required=True):
if self._from_path:
return
if session not in self.fnames:
self.fnames[session] = {}

nearest_warp = find_file(
bids_layout, from_path, self.warp_filters, self.warp_suffix,
session, subject)
session, subject, required=required)

nearest_space = find_file(
bids_layout, from_path, self.space_filters, self.space_suffix,
session, subject)
session, subject, required=required)

self.fnames[session][subject] = (nearest_warp, nearest_space)

Expand Down Expand Up @@ -181,9 +182,6 @@ class IdentityMap(Definition):
def __init__(self):
pass

def find_path(self, bids_layout, from_path, subject, session):
pass

def get_for_subses(self, base_fname, dwi, bids_info, reg_subject,
reg_template):
return ConformedAffineMapping(
Expand Down Expand Up @@ -314,9 +312,6 @@ def __init__(self, use_prealign=True, affine_kwargs={}, syn_kwargs={}):
self.syn_kwargs = syn_kwargs
self.extension = ".nii.gz"

def find_path(self, bids_layout, from_path, subject, session):
pass

def gen_mapping(self, base_fname, reg_subject, reg_template,
subject_sls, template_sls,
reg_prealign):
Expand Down Expand Up @@ -362,9 +357,6 @@ def __init__(self, slr_kwargs={}):
self.use_prealign = False
self.extension = ".npy"

def find_path(self, bids_layout, from_path, subject, session):
pass

def gen_mapping(self, base_fname, reg_template, reg_subject,
subject_sls, template_sls, reg_prealign):
return reg.slr_registration(
Expand Down Expand Up @@ -402,9 +394,6 @@ def __init__(self, affine_kwargs={}):
self.affine_kwargs = affine_kwargs
self.extension = ".npy"

def find_path(self, bids_layout, from_path, subject, session):
pass

def gen_mapping(self, base_fname, reg_subject, reg_template,
subject_sls, template_sls,
reg_prealign):
Expand Down
48 changes: 39 additions & 9 deletions AFQ/definitions/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import os.path as op
import logging

from AFQ.utils.path import drop_extension


logger = logging.getLogger('AFQ')


__all__ = ["Definition", "find_file", "name_from_path"]


Expand All @@ -20,8 +25,9 @@ class Definition(object):
def __init__(self):
raise NotImplementedError("Please implement an __init__ method")

def find_path(self, bids_layout, from_path, subject, session):
raise NotImplementedError("Please implement a find_path method")
def find_path(self, bids_layout, from_path,
subject, session, required=True):
pass

def str_for_toml(self):
"""
Expand Down Expand Up @@ -71,8 +77,16 @@ def name_from_path(path):
return file_name


def _ff_helper(required, err_msg):
if required:
raise ValueError(err_msg)
else:
logger.warning(err_msg)
return None


def find_file(bids_layout, path, filters, suffix, session, subject,
extension=".nii.gz"):
extension=".nii.gz", required=True):
"""
Helper function
Generic calls to get_nearest to find a file
Expand All @@ -92,8 +106,9 @@ def find_file(bids_layout, path, filters, suffix, session, subject,
strict=False,
)

# If that fails, loosen session restriction
# in order to find scans that are not session specific
if nearest is None:
# If that fails, loosen session restriction
nearest = bids_layout.get_nearest(
path,
**filters,
Expand All @@ -102,9 +117,9 @@ def find_file(bids_layout, path, filters, suffix, session, subject,
strict=False,
)

# Nothing is found
if nearest is None:
# If nothing is found still, raise an error
raise ValueError((
return _ff_helper(required, (
"No file found with these parameters:\n"
f"suffix: {suffix},\n"
f"session (searched with and without): {session},\n"
Expand All @@ -118,12 +133,27 @@ def find_file(bids_layout, path, filters, suffix, session, subject,
file_subject = bids_layout.parse_file_entities(nearest).get(
"subject", None
)
path_session = bids_layout.parse_file_entities(path).get(
"session", None
)
file_session = bids_layout.parse_file_entities(nearest).get(
"session", None
)

# found file is wrong subject
if path_subject != file_subject:
raise ValueError(
return _ff_helper(required, (
f"Expected subject IDs to match for the retrieved image file "
f"and the supplied `from_path` file. Got sub-{file_subject} "
f"from image file {nearest} and sub-{path_subject} "
f"from `from_path` file {path}."
)
f"from `from_path` file {path}."))

# found file is wrong session
if (file_session is not None) and (path_session != file_session):
return _ff_helper(required, (
f"Expected session IDs to match for the retrieved image file "
f"and the supplied `from_path` file. Got ses-{file_session} "
f"from image file {nearest} and ses-{path_session} "
f"from `from_path` file {path}."))

return nearest
23 changes: 12 additions & 11 deletions AFQ/tasks/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,21 +214,22 @@ def get_mapping_plan(kwargs, use_sls=False):
None,
kwargs["dwi_path"],
None,
None
)
None)
scalar_found = True
else:
scalar.find_path(
scalar_found = scalar.find_path(
bids_info["bids_layout"],
kwargs["dwi_path"],
bids_info["subject"],
bids_info["session"]
)
mapping_tasks[f"{scalar.get_name()}_res"] =\
pimms.calc(f"{scalar.get_name()}")(
as_file((
f'_desc-{str_to_desc(scalar.get_name())}'
'_dwi.nii.gz'))(
scalar.get_image_getter("mapping")))
bids_info["session"],
required=False)
if scalar_found != False:
mapping_tasks[f"{scalar.get_name()}_res"] =\
pimms.calc(f"{scalar.get_name()}")(
as_file((
f'_desc-{str_to_desc(scalar.get_name())}'
'_dwi.nii.gz'))(
scalar.get_image_getter("mapping")))

if use_sls:
mapping_tasks["mapping_res"] = sls_mapping
Expand Down
2 changes: 1 addition & 1 deletion AFQ/tasks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def get_scalar_dict(data_imap, mapping_imap, scalars=["dti_fa", "dti_md"]):
if isinstance(scalar, str):
sc = scalar.lower()
scalar_dict[sc] = data_imap[f"{sc}"]
else:
elif f"{scalar.get_name()}" in mapping_imap:
scalar_dict[scalar.get_name()] = mapping_imap[
f"{scalar.get_name()}"]
return {"scalar_dict": scalar_dict}
Expand Down
3 changes: 3 additions & 0 deletions AFQ/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def test_AFQ_fury():
myafq.export("all_bundles_figure")


@pytest.mark.nightly_pft
def test_AFQ_trx():
tmpdir = tempfile.TemporaryDirectory()
bids_path = op.join(tmpdir.name, "stanford_hardi")
Expand All @@ -320,6 +321,8 @@ def test_AFQ_trx():
myafq = GroupAFQ(
bids_path=bids_path,
preproc_pipeline='vistasoft',
# should throw warning but not error
scalars=["dti_fa", "dti_md", ImageFile(suffix="DNE")],
tracking_params={"trx": True})
myafq.export("all_bundles_figure")

Expand Down

0 comments on commit 78f4c3e

Please sign in to comment.