Skip to content

Commit

Permalink
Bug fix for 4D PCA-RDI case
Browse files Browse the repository at this point in the history
  • Loading branch information
VChristiaens committed Jul 18, 2024
1 parent 8778553 commit 54e70c6
Showing 1 changed file with 62 additions and 94 deletions.
156 changes: 62 additions & 94 deletions vip_hci/psfsub/pca_fullfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,92 +508,75 @@ def pca(*all_args: List, **all_kwargs: dict):
final_residuals_cube = []
recon_cube = []
medians = []

# (ADI+)RDI
if algo_params.cube_ref is not None:
for ch in range(nch):
table = []
pclist = []
grid_case = False

# ADI or RDI
for ch in range(nch):
add_params = {
"start_time": start_time,
"cube": algo_params.cube[ch],
"ncomp": ncomp[ch], # algo_params.ncomp[ch],
"fwhm": algo_params.fwhm[ch],
"full_output": True,
}

# RDI
if algo_params.cube_ref is not None:
if algo_params.cube_ref[ch].ndim != 3:
msg = "Ref cube has wrong format for 4d input cube"
raise TypeError(msg)
add_params["cube_ref"] = algo_params.cube_ref[ch]

add_params = {
"start_time": start_time,
"cube": algo_params.cube[ch],
"cube_ref": algo_params.cube_ref[ch],
"ncomp": ncomp[ch],
}
func_params = setup_parameters(
params_obj=algo_params, fkt=_adi_rdi_pca, **add_params
)
res_pca = _adi_rdi_pca(
**func_params,
**rot_options,
)
pcs.append(res_pca[0])
recon.append(res_pca[1])
residuals_cube.append(res_pca[2])
residuals_cube_.append(res_pca[3])
ifs_adi_frames[ch] = res_pca[-1]
# ADI
else:
table = []
pclist = []
for ch in range(nch):
add_params = {
"start_time": start_time,
"cube": algo_params.cube[ch],
"ncomp": ncomp[ch], # algo_params.ncomp[ch],
"fwhm": algo_params.fwhm[ch],
"full_output": True,
}
func_params = setup_parameters(
params_obj=algo_params, fkt=_adi_rdi_pca, **add_params
)
res_pca = _adi_rdi_pca(
**func_params,
**rot_options,
)
grid_case = False
if algo_params.batch is None:
if algo_params.source_xy is not None:
# PCA grid, computing S/Ns
if isinstance(ncomp[ch], (tuple, list)):
final_residuals_cube.append(res_pca[0])
ifs_adi_frames[ch] = res_pca[1]
table.append(res_pca[2])
# full-frame PCA with rotation threshold
else:
recon_cube.append(res_pca[0])
residuals_cube.append(res_pca[1])
residuals_cube_.append(res_pca[2])
ifs_adi_frames[ch] = res_pca[-1]
func_params = setup_parameters(
params_obj=algo_params, fkt=_adi_rdi_pca, **add_params
)
res_pca = _adi_rdi_pca(
**func_params,
**rot_options,
)

if algo_params.batch is None:
if algo_params.source_xy is not None:
# PCA grid, computing S/Ns
if isinstance(ncomp[ch], (tuple, list)):
final_residuals_cube.append(res_pca[0])
ifs_adi_frames[ch] = res_pca[1]
table.append(res_pca[2])
# full-frame PCA with rotation threshold
else:
# PCA grid
if isinstance(ncomp[ch], (tuple, list)):
ifs_adi_frames[ch] = res_pca[0]
pclist.append(res_pca[1])
grid_case = True
# full-frame standard PCA
else:
pcs.append(res_pca[0])
recon.append(res_pca[1])
residuals_cube.append(res_pca[2])
residuals_cube_.append(res_pca[3])
ifs_adi_frames[ch] = res_pca[-1]
# full-frame incremental PCA
recon_cube.append(res_pca[0])
residuals_cube.append(res_pca[1])
residuals_cube_.append(res_pca[2])
ifs_adi_frames[ch] = res_pca[-1]
else:
ifs_adi_frames[ch] = res_pca[0]
pcs.append(res_pca[2])
medians.append(res_pca[3])

if grid_case:
for i in range(len(ncomp[0])):
frame = cube_collapse(ifs_adi_frames[:, i],
mode=algo_params.collapse_ifs)
final_residuals_cube.append(frame)
# PCA grid
if isinstance(ncomp[ch], (tuple, list)):
ifs_adi_frames[ch] = res_pca[0]
pclist.append(res_pca[1])
grid_case = True
# full-frame standard PCA
else:
pcs.append(res_pca[0])
recon.append(res_pca[1])
residuals_cube.append(res_pca[2])
residuals_cube_.append(res_pca[3])
ifs_adi_frames[ch] = res_pca[-1]
# full-frame incremental PCA
else:
frame = cube_collapse(ifs_adi_frames,
ifs_adi_frames[ch] = res_pca[0]
pcs.append(res_pca[2])
medians.append(res_pca[3])

if grid_case:
for i in range(len(ncomp[0])):
frame = cube_collapse(ifs_adi_frames[:, i],
mode=algo_params.collapse_ifs)
final_residuals_cube.append(frame)
else:
frame = cube_collapse(ifs_adi_frames,
mode=algo_params.collapse_ifs)

# convert to numpy arrays when relevant
if len(pcs) > 0:
Expand All @@ -611,22 +594,7 @@ def pca(*all_args: List, **all_kwargs: dict):
if len(medians) > 0:
medians = np.array(medians)

# ADI + RDI
# elif algo_params.cube_ref is not None:
# add_params = {
# "start_time": start_time,
# "full_output": True,
# }
# func_params = setup_parameters(
# params_obj=algo_params, fkt=_adi_rdi_pca, **add_params
# )
# res_pca = _adi_rdi_pca(
# **func_params,
# **rot_options,
# )
# pcs, recon, residuals_cube, residuals_cube_, frame = res_pca

# ADI+RDI OR ADI. Shape of cube: (n_adi_frames, y, x)
# 3D RDI or ADI. Shape of cube: (n_adi_frames, y, x)
else:
add_params = {
"start_time": start_time,
Expand Down

0 comments on commit 54e70c6

Please sign in to comment.