From 79f19742a31a0417fdde1acb0d1e882d2116ca27 Mon Sep 17 00:00:00 2001 From: Olivier Binette Date: Fri, 17 Nov 2023 17:01:17 -0500 Subject: [PATCH] fix compress_memberships bug, add test --- er_evaluation/data_structures/_data_structures.py | 4 ++-- .../test_compress_memberships.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 tests/test_data_structures/test_compress_memberships.py diff --git a/er_evaluation/data_structures/_data_structures.py b/er_evaluation/data_structures/_data_structures.py index d0251e0..157dcd7 100644 --- a/er_evaluation/data_structures/_data_structures.py +++ b/er_evaluation/data_structures/_data_structures.py @@ -28,9 +28,9 @@ def compress_memberships(*memberships): Name: 0, dtype: int8 """ compressed = pd.concat(memberships, axis=1) - compressed.index = np.where(compressed.index.isna(), np.nan, pd.Categorical(compressed.index).codes) + compressed.index = pd.Categorical(compressed.index).codes for col in compressed.columns: - compressed[col] = pd.Categorical(compressed[col]).codes + compressed[col] = np.where(compressed[col].isna(), np.nan, pd.Categorical(compressed[col]).codes) return [compressed[col] for col in compressed.columns] diff --git a/tests/test_data_structures/test_compress_memberships.py b/tests/test_data_structures/test_compress_memberships.py new file mode 100644 index 0000000..be3409e --- /dev/null +++ b/tests/test_data_structures/test_compress_memberships.py @@ -0,0 +1,12 @@ +import pandas as pd + +from er_evaluation.data_structures import compress_memberships + +def test_keep_na_values_in_index(): + series1 = pd.Series(index=[-1, 0, 4, 7], data=[pd.NA, 1, 2, 3]) + series2 = pd.Series(index=[1, 0, 4, 8], data=[1, pd.NA, 2, 3]) + cs1, cs2 = compress_memberships(series1, series2) + + assert cs1.isna().sum() == 3 + assert cs2.isna().sum() == 3 +