Skip to content

Commit

Permalink
Update Python code for recent changes on dev.
Browse files Browse the repository at this point in the history
  • Loading branch information
dchristiaens committed Sep 12, 2024
1 parent f1f9770 commit 3d8f36f
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 64 deletions.
8 changes: 2 additions & 6 deletions python/mrtrix3/commands/dwimotioncorrect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#!/usr/bin/env python

# Copyright (c) 2017-2019 Daan Christiaens
#
# MRtrix and this add-on module are distributed in the hope
Expand All @@ -17,7 +15,6 @@
# [email protected]
#

import mrtrix3
from mrtrix3 import app, image, path, run, MRtrixError
import json

Expand Down Expand Up @@ -150,7 +147,7 @@ def get_max_sh_degree(N, oversampling_factor=1.3):
for l in range(0,10,2):
if (l+3)*(l+4)/2 * oversampling_factor > N:
return l

lmax = [(b>10) * get_max_sh_degree(n) for b, n in zip(shells, shell_sizes)]
if app.ARGS.lmax:
lmax = [int(l) for l in app.ARGS.lmax.split(',')]
Expand Down Expand Up @@ -204,7 +201,7 @@ def get_max_sh_degree(N, oversampling_factor=1.3):
nthr = ''
if app.ARGS.nthreads:
nthr = ' -nthreads ' + str(app.ARGS.nthreads)

# Set multiband factor
mb = 1
if app.ARGS.mb:
Expand Down Expand Up @@ -332,5 +329,4 @@ def fieldalignstep(k):
run.command('cp sliceweights.txt ' + path.from_user(app.ARGS.export_weights, True))


mrtrix3.execute()

74 changes: 46 additions & 28 deletions python/mrtrix3/commands/motionfilter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
#!/usr/bin/env python
# Copyright (c) 2017-2019 Daan Christiaens
#
# MRtrix and this add-on module are distributed in the hope
# that it will be useful, but WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.
#
# MOTION CORRECTION FOR DWI VOLUME SERIES
#
# This script performs volume-to-series and slice-to-series registration
# of diffusion-weighted images for motion correction in the brain.
#
# Author: Daan Christiaens
# King's College London
# [email protected]
#

import argparse
import math
import numpy as np
from scipy.linalg import logm, expm
from scipy.signal import medfilt
Expand All @@ -11,31 +24,37 @@ def getsliceorder(n, p=2, s=1):
return np.array([j for k in range(0,p) for j in range((k*s)%p,n,p)], dtype=int)


if __name__ == '__main__':
# arguments
parser = argparse.ArgumentParser(description='Filtering a series of rigid motion parameters.')
parser.add_argument('input', type=str, help='input motion file')
parser.add_argument('weights', type=str, help='input weight matrix')
parser.add_argument('output', type=str, help='output motion file')
#parser.add_argument('-mb', type=int, default=1, help='multiband factor')
parser.add_argument('-packs', type=int, default=2, help='no. slice packs')
parser.add_argument('-shift', type=int, default=1, help='slice shift')
parser.add_argument('-medfilt', type=int, default=1, help='median filtering kernel size (default = 1; disabled)')
parser.add_argument('-driftfilt', action='store_true', help='drift filter slice packs')
# mandatory MRtrix options (unused)
parser.add_argument('-nthreads', type=int, default=1, help='no. threads (unused)')
parser.add_argument('-info', help='(unused)')
parser.add_argument('-debug', help='(unused)')
parser.add_argument('-quiet', help='(unused)')
parser.add_argument('-force', help='(unused)')
args = parser.parse_args()
def usage(cmdline): #pylint: disable=unused-variable
from mrtrix3 import app #pylint: disable=no-name-in-module, import-outside-toplevel
cmdline.set_author('Daan Christiaens ([email protected])')
cmdline.set_synopsis('Filtering a series of rigid motion parameters')
cmdline.add_description('This command applies a filter on a timeseries of rigid motion parameters.'
' This is used in dwimotioncorrect to correct severe registration errors.')
cmdline.add_argument('input',
type=app.Parser.FileIn(),
help='The input motion file')
cmdline.add_argument('weights',
type=app.Parser.FileIn(),
help='The input weight matrix')
cmdline.add_argument('output',
type=app.Parser.FileOut(),
help='The output motion file')
cmdline.add_argument('-packs', type=int, default=2, help='no. slice packs')
cmdline.add_argument('-shift', type=int, default=1, help='slice shift')
cmdline.add_argument('-medfilt', type=int, default=1, help='median filtering kernel size (default = 1; disabled)')
cmdline.add_argument('-driftfilt', action='store_true', help='drift filter slice packs')


def execute(): #pylint: disable=unused-variable
from mrtrix3 import MRtrixError #pylint: disable=no-name-in-module, import-outside-toplevel
from mrtrix3 import app, image, run #pylint: disable=no-name-in-module, import-outside-toplevel
# read inputs
M = np.loadtxt(args.input)
W = np.clip(np.loadtxt(args.weights), 1e-6, None)
M = np.loadtxt(app.ARGS.input)
W = np.clip(np.loadtxt(app.ARGS.weights), 1e-6, None)
# set up slice order
nv = W.shape[1]
ne = M.shape[0]//nv
sliceorder = getsliceorder(ne, args.packs, args.shift)
sliceorder = getsliceorder(ne, app.ARGS.packs, app.ARGS.shift)
isliceorder = np.argsort(sliceorder)
# reorder
M1 = np.reshape(M.reshape((nv,ne,6))[:,sliceorder,:], (-1,6))
Expand All @@ -47,11 +66,10 @@ def getsliceorder(n, p=2, s=1):
M2 = np.linalg.solve(A, W1 * M1)
# median filtering
if ne > 1:
M2 = medfilt(M2, (args.medfilt, 1))
if args.driftfilt:
M2 = medfilt(M2, (app.ARGS.medfilt, 1))
if app.ARGS.driftfilt:
M2 = M2.reshape((nv,ne,6)) - np.median(M2.reshape((nv,ne,6)), 0)
# reorder output
M3 = np.reshape(M2.reshape((nv,ne,6))[:,isliceorder,:], (-1,6))
np.savetxt(args.output, M3, fmt='%.6f')

np.savetxt(app.ARGS.output, M3, fmt='%.6f')

71 changes: 41 additions & 30 deletions python/mrtrix3/commands/motionstats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
#!/usr/bin/env python
# Copyright (c) 2017-2019 Daan Christiaens
#
# MRtrix and this add-on module are distributed in the hope
# that it will be useful, but WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.
#
# Author: Daan Christiaens
# King's College London
# [email protected]
#

import argparse
import math
import numpy as np
from scipy.linalg import logm, expm
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -46,32 +54,35 @@ def tr2euler(T):
return np.array([T[0,3], T[1,3], T[2,3], yaw, pitch, roll])


def usage(cmdline): #pylint: disable=unused-variable
from mrtrix3 import app #pylint: disable=no-name-in-module, import-outside-toplevel
cmdline.set_author('Daan Christiaens ([email protected])')
cmdline.set_synopsis('Calculate motion and outlier statistics')
cmdline.add_description('This command calculates statistics of the subject translation and rotation'
' and of the detected slice outliers.')
cmdline.add_argument('input',
type=app.Parser.FileIn(),
help='The input motion file')
cmdline.add_argument('weights',
type=app.Parser.FileIn(),
help='The input weight matrix')
cmdline.add_argument('-packs', type=int, default=2, help='no. slice packs')
cmdline.add_argument('-shift', type=int, default=1, help='slice shift')
cmdline.add_argument('-plot', action='store_true', help='plot motion trajectory')
cmdline.add_argument('-grad', type=app.Parser.FileIn(), help='dMRI gradient table')
cmdline.add_argument('-dispersion', type=app.Parser.FileIn(), help='output gradient dispersion to file')

if __name__ == '__main__':
# arguments
parser = argparse.ArgumentParser(description='Calculate motion and outlier statistics.')
parser.add_argument('input', type=str, help='input motion file')
parser.add_argument('weights', type=str, help='input weight matrix')
#parser.add_argument('-mb', type=int, default=1, help='multiband factor')
parser.add_argument('-packs', type=int, default=2, help='no. slice packs')
parser.add_argument('-shift', type=int, default=1, help='slice shift')
parser.add_argument('-plot', action='store_true', help='plot motion trajectory')
parser.add_argument('-grad', type=str, help='dMRI gradient table')
parser.add_argument('-dispersion', type=str, help='output gradient dispersion to file')
# mandatory MRtrix options (unused)
parser.add_argument('-nthreads', type=int, default=1, help='no. threads (unused)')
parser.add_argument('-info', action='store_true', help='(unused)')
parser.add_argument('-debug', action='store_true', help='(unused)')
parser.add_argument('-quiet', action='store_true', help='(unused)')
parser.add_argument('-force', action='store_true', help='(unused)')
args = parser.parse_args()

def execute(): #pylint: disable=unused-variable
from mrtrix3 import MRtrixError #pylint: disable=no-name-in-module, import-outside-toplevel
from mrtrix3 import app, image #pylint: disable=no-name-in-module, import-outside-toplevel
# read inputs
M0 = np.loadtxt(args.input)
W = np.loadtxt(args.weights)
M0 = np.loadtxt(app.ARGS.input)
W = np.loadtxt(app.ARGS.weights)
# set up slice order
nv = W.shape[1]
ne = M0.shape[0]//nv
sliceorder = getsliceorder(ne, args.packs, args.shift)
sliceorder = getsliceorder(ne, app.ARGS.packs, app.ARGS.shift)
isliceorder = np.argsort(sliceorder)
# reorder
M = np.reshape(M0.reshape((nv,ne,6))[:,sliceorder,:], (-1,6))
Expand All @@ -84,22 +95,22 @@ def tr2euler(T):
# print stats
print('{:f} {:f} {:f}'.format(mtra, mrot, orratio))
# plot trajectory
if args.plot:
if app.ARGS.plot:
T = np.array([tr2euler(lie2tr(m)) for m in M])
ax1 = plt.subplot(2,1,1); plt.plot(T[:,:3]); plt.ylabel('translation'); plt.legend(['x', 'y', 'z']);
ax2 = plt.subplot(2,1,2, sharex=ax1); plt.plot(T[:,3:]); plt.ylabel('rotation'); plt.legend(['yaw', 'pitch', 'roll']);
plt.xlim(0, nv*ne); plt.xlabel('time'); plt.tight_layout();
plt.show();
# intra-volume gradient scatter
if args.dispersion:
if args.grad:
grad = np.loadtxt(args.grad)
if app.ARGS.dispersion:
if app.ARGS.grad:
grad = np.loadtxt(app.ARGS.grad)
bvec = grad[:,:3] / np.linalg.norm(grad[:,:3], axis=1)[:,np.newaxis]
r = np.array([[np.dot(lie2tr(m)[:3,:3], v) for m in mvol] for mvol, v in zip(M.reshape((nv,ne,6)), bvec)])
rm = np.sum(r, axis=1); rm /= np.linalg.norm(rm, axis=1)[:,np.newaxis]
rd = np.einsum('vzi,vi->vz', r, rm)
dispersion = np.degrees(2*np.arccos(np.sqrt(np.mean(rd**2, axis=1))))
np.savetxt(args.dispersion, dispersion[np.newaxis,:], fmt='%.4f')
np.savetxt(app.ARGS.dispersion, dispersion[np.newaxis,:], fmt='%.4f')
else:
print('error: gradient table not provided.')
raise MRtrixError('No diffusion gradient table provided')

0 comments on commit 3d8f36f

Please sign in to comment.