Skip to content

Commit

Permalink
Updated PCA test to include ncomp input as list
Browse files Browse the repository at this point in the history
  • Loading branch information
VChristiaens committed Sep 6, 2023
1 parent e0fc0d9 commit 49745f4
Showing 1 changed file with 25 additions and 29 deletions.
54 changes: 25 additions & 29 deletions tests/post_3_10/test_objects_pppca.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tests.snapshots.snapshot_psfsub import PSFADI_PATH

NO_FRAME_CASE = ["pca_drot", "pca_ann_auto"]
PREV_CASE = ["pca_grid_list"]

# Note : this function comes from the former test for adi psfsub, I did not write it,
# and I didn't found the author (feel free to add the author if you know them)
Expand Down Expand Up @@ -49,17 +50,14 @@ def test_pca_object(injected_cube_position):
Generate a frame with ``vip_hci.objects.pppca`` and ensure they match with
their procedural counterpart. This is done by getting the snapshots of the
``vip_hci.psfsub.pca`` and ``vip_hci.psfsub.pca_annular`` functions, generated
preemptively with ``tests..snapshots.snapshot_psfsub``. There are a lot more
of cases to test than in regular post-processing methods because PCA is the
"go-to" method of the package.
``vip_hci.psfsub.pca`` and ``vip_hci.psfsub.pca_annular`` functions,
generated preemptively with ``tests.snapshots.snapshot_psfsub``. There are a
lot more cases to test than in regular post-processing methods because PCA
is the "go-to" method of the package.
"""
betapic = injected_cube_position

imlib_rot = "vip-fft"
interpolation = None

# Testing the basic version of PCA

position = np.load(f"{PSFADI_PATH}pca_adi_detect.npy").copy()
Expand All @@ -71,17 +69,9 @@ def test_pca_object(injected_cube_position):
pca_obj.make_snrmap()

check_detection(pca_obj.snr_map, position, betapic.fwhm, snr_thresh=2)
assert np.allclose(np.abs(pca_obj.frame_final), np.abs(exp_frame), atol=1e-2)

# Testing the basic version of PCA with ncomp as list
update_params = {"ncomp": [1, 2],
"verbose": False}
update(pca_obj, PCABuilder(**update_params))
pca_obj.run(runmode="classic")
assert np.allclose(np.abs(pca_obj.frames_final[0]), np.abs(exp_frame),
assert np.allclose(np.abs(pca_obj.frame_final), np.abs(exp_frame),
atol=1e-2)


pca_test_set = [
{
"case_name": "pca_left_eigv",
Expand Down Expand Up @@ -156,23 +146,29 @@ def test_pca_object(injected_cube_position):
]
# Testing all alternatives of PCA listed above

for pca_test in pca_test_set:
for pp, pca_test in enumerate(pca_test_set):
case_name = pca_test["case_name"]
update_params = pca_test["update_params"]
runmode = pca_test["runmode"]

position = np.load(f"{PSFADI_PATH}{case_name}_adi_detect.npy").copy()

if case_name not in NO_FRAME_CASE:
exp_frame = np.load(f"{PSFADI_PATH}{case_name}_adi.npy").copy()
if case_name in PREV_CASE:
# compare to frame_final from previous iteration
assert np.allclose(
np.abs(pca_obj.frame_final), np.abs(pca_obj.frames_final[-1]),
atol=1e-2
)
else:
pos = np.load(f"{PSFADI_PATH}{case_name}_adi_detect.npy").copy()
if case_name not in NO_FRAME_CASE:
exp_frame = np.load(f"{PSFADI_PATH}{case_name}_adi.npy").copy()

update(pca_obj, PCABuilder(**update_params))
update(pca_obj, PCABuilder(**update_params))

pca_obj.run(runmode=runmode)
pca_obj.make_snrmap()
pca_obj.run(runmode=runmode)
pca_obj.make_snrmap()

check_detection(pca_obj.snr_map, position, betapic.fwhm, snr_thresh=2)
if case_name not in NO_FRAME_CASE:
assert np.allclose(
np.abs(pca_obj.frame_final), np.abs(exp_frame), atol=1e-2
)
check_detection(pca_obj.snr_map, pos, betapic.fwhm, snr_thresh=2)
if case_name not in NO_FRAME_CASE:
assert np.allclose(
np.abs(pca_obj.frame_final), np.abs(exp_frame), atol=1e-2
)

0 comments on commit 49745f4

Please sign in to comment.