diff --git a/src/sorcha/modules/PPStats.py b/src/sorcha/modules/PPStats.py
index 1307842e..324108dd 100644
--- a/src/sorcha/modules/PPStats.py
+++ b/src/sorcha/modules/PPStats.py
@@ -44,6 +44,9 @@ def stats(observations, statsfilename, outpath, configs):
         linked = group_by["object_linked"].agg("all").to_frame("object_linked")
         date_linked = group_by["date_linked_MJD"].agg("first").to_frame("date_linked_MJD")
         joined_stats = num_obs.join([mag, phase_deg, linked, date_linked])
+    elif configs["SSP_linking_on"]:
+        date_linked = group_by["date_linked_MJD"].agg("first").to_frame("date_linked_MJD")
+        joined_stats = num_obs.join([mag, phase_deg, date_linked])
     else:
         joined_stats = num_obs.join([mag, phase_deg])
 
diff --git a/tests/sorcha/test_PPStats.py b/tests/sorcha/test_PPStats.py
index eebbec32..f60ab6e4 100644
--- a/tests/sorcha/test_PPStats.py
+++ b/tests/sorcha/test_PPStats.py
@@ -123,3 +123,54 @@ def test_PPStats_nolinking(tmp_path):
     # the previous test checks all rows so it's fine to just check one here, this test is mostly to make
     # sure that the stats file works correctly if linking is off
     assert_equal(expected_row_one, stats_df.iloc[0].values)
+
+
+def test_PPStats_justlinking(tmp_path):
+    # tests behaviour when linking is on but drop_unlinked=True
+
+    ObjID = (["object_one"] * 10) + (["object_two"] * 5)
+    optFilter = (["r"] * 6) + (["g"] * 4) + (["r"] * 5)
+    trailedSourceMag = np.concatenate(
+        (np.linspace(18, 21, 6), np.linspace(19, 22, 4), np.linspace(20, 23, 5))
+    )
+    phase_deg = np.concatenate((np.linspace(3, 10, 6), np.linspace(4, 11, 4), np.linspace(5, 10, 5)))
+    obj_date = np.array(([666.0] * 10) + ([np.nan] * 5), dtype=object)
+
+    test_dict = {
+        "ObjID": ObjID,
+        "optFilter": optFilter,
+        "trailedSourceMag": trailedSourceMag,
+        "phase_deg": phase_deg,
+        "date_linked_MJD": obj_date,
+    }
+    test_df = pd.DataFrame(test_dict)
+
+    configs = {"SSP_linking_on": True, "drop_unlinked": True}
+
+    filename_stats = "test_stats"
+    stats(test_df, filename_stats, tmp_path, configs)
+
+    stats_df = pd.read_csv(os.path.join(tmp_path, filename_stats + ".csv"))
+
+    assert len(stats_df) == 3
+
+    expected_columns = np.array(
+        [
+            "ObjID",
+            "optFilter",
+            "number_obs",
+            "min_apparent_mag",
+            "max_apparent_mag",
+            "median_apparent_mag",
+            "min_phase",
+            "max_phase",
+            "date_linked_MJD",
+        ],
+        dtype=object,
+    )
+
+    assert_equal(expected_columns, stats_df.columns.values)
+
+    expected_row_one = np.array(["object_one", "g", 4, 19.0, 22.0, 20.5, 4.0, 11.0, 666.0], dtype=object)
+
+    assert_equal(expected_row_one, stats_df.iloc[0].values)