Skip to content

Commit

Permalink
add on_nested test
Browse files Browse the repository at this point in the history
  • Loading branch information
smcguire-cmu committed Sep 5, 2024
1 parent c81981c commit eb257f4
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/lsdb/catalog/test_nested.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import nested_dask as nd
import pandas as pd

Expand All @@ -13,3 +14,24 @@ def test_dropna(small_sky_with_nested_sources):
filtered_compute = filtered_cat.compute()
assert len(drop_na_compute) < len(filtered_compute)
pd.testing.assert_frame_equal(drop_na_compute, filtered_compute.dropna())


def test_dropna_on_nested(small_sky_with_nested_sources):
def add_na_values_nested(df):
"""replaces the first source_ra value in each nested df with NaN"""
for i in range(len(df)):
first_ra_value = df.iloc[i]["sources"].iloc[0]["source_ra"]
df["sources"].array[i] = df["sources"].array[i].replace(first_ra_value, np.NaN)
return df

filtered_cat = small_sky_with_nested_sources.map_partitions(add_na_values_nested)
drop_na_cat = filtered_cat.dropna(on_nested="sources")
assert isinstance(drop_na_cat, Catalog)
assert isinstance(drop_na_cat._ddf, nd.NestedFrame)
drop_na_sources_compute = drop_na_cat["sources"].compute()
filtered_sources_compute = filtered_cat["sources"].compute()
assert len(drop_na_sources_compute) == len(filtered_sources_compute)
assert sum(map(len, drop_na_sources_compute)) < sum(map(len, filtered_sources_compute))
pd.testing.assert_frame_equal(
drop_na_cat.compute(), filtered_cat._ddf.dropna(on_nested="sources").compute()
)

0 comments on commit eb257f4

Please sign in to comment.