From 3ce0db06d740c2a7cb1ca19f7208e293d612db08 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 15 Feb 2024 08:30:26 -0500 Subject: [PATCH] Add unit test --- scilpy/gradients/bvec_bval_tools.py | 8 +++++++- scilpy/gradients/tests/test_bvec_bval_tools.py | 18 +++++++++++++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/scilpy/gradients/bvec_bval_tools.py b/scilpy/gradients/bvec_bval_tools.py index 377b6fac4..2bcf53848 100644 --- a/scilpy/gradients/bvec_bval_tools.py +++ b/scilpy/gradients/bvec_bval_tools.py @@ -73,6 +73,13 @@ def check_b0_threshold(min_bval, b0_thr, skip_b0_check): If True, and no b0 is found, only print a warning, do not raise an error. + Returns + ------- + b0_thr: float + Either the unmodified b0_thr, or, in the case where the minimal b-value + is larger than b0_thr, and skip_b0_check is set to True, then returns + min_bval. + Raises ------ ValueError @@ -93,7 +100,6 @@ def check_b0_threshold(min_bval, b0_thr, skip_b0_check): if min_bval > b0_thr: if skip_b0_check: - logging.warning("GOT {} > {}".format(min_bval, b0_thr)) logging.warning( 'Your minimal bvalue ({}), is above the threshold ({})\n' 'Since --skip_b0_check was specified, the script will ' diff --git a/scilpy/gradients/tests/test_bvec_bval_tools.py b/scilpy/gradients/tests/test_bvec_bval_tools.py index 3d82cc188..4085dd5d3 100644 --- a/scilpy/gradients/tests/test_bvec_bval_tools.py +++ b/scilpy/gradients/tests/test_bvec_bval_tools.py @@ -2,9 +2,9 @@ import numpy as np from scilpy.gradients.bvec_bval_tools import ( - identify_shells, is_normalized_bvecs, flip_gradient_sampling, - normalize_bvecs, round_bvals_to_shell, str_to_axis_index, - swap_gradient_axis) + check_b0_threshold, identify_shells, is_normalized_bvecs, + flip_gradient_sampling, normalize_bvecs, round_bvals_to_shell, + str_to_axis_index, swap_gradient_axis) bvecs = np.asarray([[1.0, 1.0, 1.0], [1.0, 0.0, 1.0], @@ -23,8 +23,16 @@ def test_normalize_bvecs(): def test_check_b0_threshold(): - # toDo To be modified (see PR#867). - pass + assert check_b0_threshold(min_bval=0, b0_thr=0, skip_b0_check=False) == 0 + assert check_b0_threshold(min_bval=0, b0_thr=20, skip_b0_check=False) == 20 + assert check_b0_threshold(min_bval=20, b0_thr=0, skip_b0_check=True) == 20 + + error_raised = False + try: + _ = check_b0_threshold(min_bval=20, b0_thr=0, skip_b0_check=False) + except ValueError: + error_raised = True + assert error_raised def test_identify_shells():