Skip to content

Commit

Permalink
Adapted cube_collapse function to also accept an input 4D cube
Browse files Browse the repository at this point in the history
  • Loading branch information
VChristiaens committed Mar 20, 2024
1 parent 5780354 commit a65d4ba
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions vip_hci/preproc/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@


def cube_collapse(cube, mode='median', n=50, w=None):
""" Collapses a cube into a frame (3D array -> 2D array) depending on the
parameter ``mode``. It's possible to perform a trimmed mean combination of
the frames based on description in [BRA13]_.
"""Collapse a 3D or 4D cube into a 2D frame or 3D cube, respectively.
The ``mode`` parameter determines how the collapse should be done. It is
possible to perform a trimmed mean combination of the frames, as in
[BRA13]_. In case of a 4D input cube, it is assumed to be an IFS dataset
with the zero-th axis being the spectral dimension, and the first axis the
temporal dimension.
Parameters
----------
Expand All @@ -47,7 +52,12 @@ def cube_collapse(cube, mode='median', n=50, w=None):
Output array, cube combined.
"""
arr = cube
if arr.ndim != 3:
if arr.ndim == 3:
ax = 0
elif arr.ndim == 4:
nch = arr.shape[0]
ax = 1
else:
raise TypeError('The input array is not a cube or 3d array.')

if mode == 'wmean':
Expand All @@ -60,27 +70,39 @@ def cube_collapse(cube, mode='median', n=50, w=None):
w = np.array(w)

if mode == 'mean':
frame = np.nanmean(arr, axis=0)
frame = np.nanmean(arr, axis=ax)
elif mode == 'median':
frame = np.nanmedian(arr, axis=0)
frame = np.nanmedian(arr, axis=ax)
elif mode == 'sum':
frame = np.nansum(arr, axis=0)
frame = np.nansum(arr, axis=ax)
elif mode == 'max':
frame = np.nanmax(arr, axis=0)
frame = np.nanmax(arr, axis=ax)
elif mode == 'trimmean':
N = arr.shape[0]
N = arr.shape[ax]
k = (N - n)//2
if N % 2 != n % 2:
n += 1
frame = np.empty_like(arr[0])
for index, _ in np.ndenumerate(arr[0]):
sort = np.sort(arr[:, index[0], index[1]])
frame[index] = np.nanmean(sort[k:k+n])
if ax == 0:
frame = np.empty_like(arr[0])
for index, _ in np.ndenumerate(arr[0]):
sort = np.sort(arr[:, index[0], index[1]])
frame[index] = np.nanmean(sort[k:k+n])
else:
frame = np.empty_like(arr[:, 0])
for j in range(nch):
for index, _ in np.ndenumerate(arr[:, 0]):
sort = np.sort(arr[j, :, index[0], index[1]])
frame[j][index] = np.nanmean(sort[k:k+n])
elif mode == 'wmean':
arr[np.where(np.isnan(arr))] = 0 # to avoid product with nan
frame = np.inner(w, np.moveaxis(arr, 0, -1))
if ax == 0:
frame = np.inner(w, np.moveaxis(arr, 0, -1))
else:
frame = np.empty_like(arr[:, 0])
for j in range(nch):
frame[j] = np.inner(w, np.moveaxis(arr[j], 0, -1))
elif mode == 'absmean':
frame = np.mean(np.abs(arr), axis=0)
frame = np.nanmean(np.abs(arr), axis=ax)
else:
raise TypeError("mode not recognized")

Expand Down Expand Up @@ -170,8 +192,9 @@ def cube_subsample(array, n, mode="mean", w=None, parallactic=None,


def cube_subsample_trimmean(arr, n, m):
"""Performs a trimmed mean combination every m frames in a cube. Based on
description in Brandt+ 2012.
"""Perform a trimmed mean combination every m frames in a cube.
Details in [BRA13]_.
Parameters
----------
Expand Down

0 comments on commit a65d4ba

Please sign in to comment.