diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 9c430d3b..239f0ebc 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -123,19 +123,17 @@ def __matmul__(self, b): True >>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) - >>> xfm1 @ np.eye(4) == xfm1 + >>> xfm1 @ Affine() == xfm1 True """ - if not isinstance(b, self.__class__): - _b = self.__class__(b) - else: - _b = b + if isinstance(b, self.__class__): + return self.__class__( + b.matrix @ self.matrix, + reference=b.reference, + ) - retval = self.__class__(self.matrix.dot(_b.matrix)) - if _b.reference: - retval.reference = _b.reference - return retval + return b @ self @property def matrix(self): diff --git a/nitransforms/manip.py b/nitransforms/manip.py index 233f5adf..58d15058 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -8,7 +8,6 @@ ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Common interface for transforms.""" from collections.abc import Iterable -import numpy as np from .base import ( TransformBase, @@ -140,9 +139,9 @@ def map(self, x, inverse=False): return x - def asaffine(self, indices=None): + def collapse(self): """ - Combine a succession of linear transforms into one. + Combine a succession of transforms into one. Example ------ @@ -150,7 +149,7 @@ def asaffine(self, indices=None): ... Affine.from_matvec(vec=(2, -10, 3)), ... Affine.from_matvec(vec=(-2, 10, -3)), ... ]) - >>> chain.asaffine() + >>> chain.collapse() array([[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], @@ -160,7 +159,7 @@ def asaffine(self, indices=None): ... Affine.from_matvec(vec=(1, 2, 3)), ... Affine.from_matvec(mat=[[0, 1, 0], [0, 0, 1], [1, 0, 0]]), ... ]) - >>> chain.asaffine() + >>> chain.collapse() array([[0., 1., 0., 2.], [0., 0., 1., 3.], [1., 0., 0., 1.], @@ -168,7 +167,7 @@ def asaffine(self, indices=None): >>> np.allclose( ... chain.map((4, -2, 1)), - ... chain.asaffine().map((4, -2, 1)), + ... chain.collapse().map((4, -2, 1)), ... ) True @@ -178,9 +177,8 @@ def asaffine(self, indices=None): The indices of the values to extract. """ - affines = self.transforms if indices is None else np.take(self.transforms, indices) - retval = affines[0] - for xfm in affines[1:]: + retval = self.transforms[-1] + for xfm in reversed(self.transforms[:-1]): retval = xfm @ retval return retval diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index eea77b7f..f3f83b38 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -372,10 +372,10 @@ def test_mulmat_operator(testdata_path): mat2 = from_matvec(np.eye(3), (4, 2, -1)) aff = nitl.Affine(mat1, reference=ref) - composed = aff @ mat2 + composed = aff @ nitl.Affine(mat2) assert composed.reference is None - assert composed == nitl.Affine(mat1.dot(mat2)) + assert composed == nitl.Affine(mat2 @ mat1) composed = nitl.Affine(mat2) @ aff assert composed.reference == aff.reference - assert composed == nitl.Affine(mat2.dot(mat1), reference=ref) + assert composed == nitl.Affine(mat1 @ mat2, reference=ref) diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py index 6dee540e..59f7f3b7 100644 --- a/nitransforms/tests/test_manip.py +++ b/nitransforms/tests/test_manip.py @@ -60,6 +60,12 @@ def test_itk_h5(tmp_path, testdata_path): # A certain tolerance is necessary because of resampling at borders assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL + col_moved = xfm.collapse().apply(img_fname, order=0) + col_moved.to_filename("nt_collapse_resampled.nii.gz") + diff = sw_moved.get_fdata() - col_moved.get_fdata() + # A certain tolerance is necessary because of resampling at borders + assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL + @pytest.mark.parametrize("ext0", ["lta", "tfm"]) @pytest.mark.parametrize("ext1", ["lta", "tfm"]) @@ -81,7 +87,7 @@ def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2): ] ) assert np.allclose( - chain.asaffine().matrix, + chain.collapse().matrix, Affine.from_filename( data_path / "regressions" / f"from-fsnative_to-bold_mode-image.{ext2}", fmt=f"{FMT[ext2]}",