diff --git a/tedana/utils.py b/tedana/utils.py index 6b88f1da0..356e27b99 100644 --- a/tedana/utils.py +++ b/tedana/utils.py @@ -104,16 +104,21 @@ def make_adaptive_mask(data, mask, threshold=1): - This is the threshold for "good" data. - The 1/3 value is arbitrary. - If there was more than one exemplar voxel, - retain the the highest value across the exemplars for each echo. - d. For each voxel, count the number of echoes that have a mean value greater than the + retain the echo-wise values from the exemplar with the highest total value. + d. For each voxel, identify the last echo with a mean value greater than the corresponding echo's threshold. + + - Preceding echoes (including ones with mean values less than the threshold) + are considered "good" data. """ RepLGR.info( "An adaptive mask was then generated, " "in which each voxel's value reflects the number of echoes with 'good' data." ) - mask = reshape_niimg(mask).astype(bool) - data = data[mask, :, :] + # mask = reshape_niimg(mask).astype(bool) + # data = data[mask, :, :] + + n_samples, n_echoes, _ = data.shape # take temporal mean of echos and extract non-zero values in first echo echo_means = data.mean(axis=-1) # temporal mean of echos @@ -121,40 +126,46 @@ def make_adaptive_mask(data, mask, threshold=1): # get 33rd %ile of `first_echo` and find corresponding index # NOTE: percentile is arbitrary - # TODO: "interpolation" param changed to "method" in numpy 1.22.0 - # confirm method="higher" is the same as interpolation="higher" - # Current minimum version for numpy in tedana is 1.16 where - # there is no "method" parameter. Either wait until we bump - # our minimum numpy version to 1.22 or add a version check - # or try/catch statement. - perc = np.percentile(first_echo, 33, interpolation="higher") - perc_val = echo_means[:, 0] == perc + perc = np.percentile(first_echo, 33, method="higher") + voxels_at_perc = echo_means[:, 0] == perc # extract values from all echos at relevant index # NOTE: threshold of 1/3 voxel value is arbitrary - lthrs = np.squeeze(echo_means[perc_val].T) / 3 + lthrs = np.squeeze(echo_means[voxels_at_perc, :].T) / 3 - # if multiple samples were extracted per echo, keep the one w/the highest signal + # if multiple voxels exactly match the 33rd percentile value in the first echo, + # retain the values from the voxel with the highest total value across echoes if lthrs.ndim > 1: lthrs = lthrs[:, lthrs.sum(axis=0).argmax()] - # determine samples where absolute value is greater than echo-specific thresholds - # and count # of echos that pass criterion - masksum = (np.abs(echo_means) > lthrs).sum(axis=-1) + # Find the last good echo for each voxel + # Add a 1 to the end of the threshold array to match the size of the echo_means array + lthrs = np.hstack((lthrs, 1)) + # Add a 0 to the end of the echo_means array to make a trailing echo "bad". + # This way, argmax can distinguish between all bad echoes and having the last echo be good. + # The former will have a value of n_echoes + 1, while the latter will have a value of n_echoes. + echo_means = np.hstack((echo_means, np.zeros((n_samples, 1)))) + # argmax finds the first instance of the maximum value, so we need to reverse the order + # of the array to find the last instance of the maximum value. + masksum_inverted = np.argmax(np.abs(echo_means[:, ::-1]) > lthrs[::-1], axis=1) + masksum = n_echoes - masksum_inverted + # Replace values of n_echoes + 1 (all bad echoes) with 0 (no good echoes) + masksum[masksum == n_echoes + 1] = 0 # TODO: Use visual report to make checking the reduced mask easier if np.any(masksum < threshold): n_bad_voxels = np.sum(masksum < threshold) LGR.warning( - f"{n_bad_voxels} voxels in user-defined mask do not have good " - "signal. Removing voxels from mask." + f"{n_bad_voxels} voxels in user-defined mask do not have good signal. " + "Removing voxels from mask." ) masksum[masksum < threshold] = 0 + masksum = masksum * mask.astype(bool) modified_mask = masksum.astype(bool) - masksum = unmask(masksum, mask) - modified_mask = unmask(modified_mask, mask) + # masksum = unmask(masksum, mask) + # modified_mask = unmask(modified_mask, mask) return modified_mask, masksum