Skip to content

Commit

Permalink
fix pandas dep warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
yoachim committed Mar 29, 2024
1 parent 88eb4e9 commit 343069a
Showing 1 changed file with 40 additions and 17 deletions.
57 changes: 40 additions & 17 deletions rubin_sim/maf/metrics/sn_n_sn_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,10 @@ def season_length(self, seasons, data_slice, zseason):
if season_info.empty:
return [], pd.DataFrame()

season_info["season_copy"] = season_info["season"].values.copy()
dur_z = (
season_info.groupby("season", group_keys=False)
.apply(lambda x: self.nsn_expected_z(x))
season_info.groupby("season_copy", group_keys=False)
.apply(lambda x: self.nsn_expected_z(x), include_groups=False)
.reset_index(drop=True)
)
return season_info["season"].to_list(), dur_z
Expand All @@ -384,10 +385,10 @@ def get_season_info(self, dfa, zseason, min_duration=60.0):
season_info : `pd.DataFrame` with season length infos
"""

dfa["season_copy"] = dfa["season"].values.copy()
season_info = (
dfa.groupby("season")
.apply(lambda x: self.season_info(x, min_duration=min_duration))
.apply(lambda x: self.season_info(x, min_duration=min_duration), include_groups=False)
.reset_index()
)

Expand Down Expand Up @@ -424,7 +425,10 @@ def step_lc(self, obs, gen_par, x1=-2.0, color=0.2):
SN light curves (astropy table)
"""
lc = obs.groupby(["season"]).apply(lambda x: self.gen_lc(x, gen_par, x1, color))
obs["season_copy"] = obs["season"].values.copy()
lc = obs.groupby(["season_copy"]).apply(
lambda x: self.gen_lc(x, gen_par, x1, color), include_groups=False
)
return lc

def step_efficiencies(self, lc):
Expand All @@ -441,11 +445,11 @@ def step_efficiencies(self, lc):
`pd.DataFrame` with efficiencies
"""
sn_effis = (
lc.groupby(["season", "z", "x1", "color", "sntype"])
.apply(lambda x: self.sn_effi(x))
.reset_index()
)
cols = ["season", "z", "x1", "color", "sntype"]
for col in cols:
lc[col + "_copy"] = lc[col].values.copy()

sn_effis = lc.groupby(cols).apply(lambda x: self.sn_effi(x), include_groups=False).reset_index()

# estimate efficiencies
sn_effis["season"] = sn_effis["season"].astype(int)
Expand Down Expand Up @@ -599,7 +603,7 @@ def calc_daymax(self, grp, daymax_step):
t0_min = grp["T0_min"].values
num = (t0_max - t0_min) / daymax_step
if t0_max - t0_min > 10:
df = pd.DataFrame(np.linspace(t0_min, t0_max, int(num)), columns=["daymax"])
df = pd.DataFrame(np.linspace(t0_min, t0_max, int(np.max(num))), columns=["daymax"])
else:
df = pd.DataFrame([-1], columns=["daymax"])
else:
Expand Down Expand Up @@ -709,7 +713,12 @@ def sn_effi(self, lc):
resdf["nepochs_bef"] = self.get_epochs(nights, flag, flagph)

# replace NaN by 0
resdf = resdf.fillna(0)
# solution from: https://stackoverflow.com/
# questions/77900971/
# pandas-futurewarning-downcasting-
# object-dtype-arrays-on-fillna-ffill-bfill
with pd.option_context("future.no_silent_downcasting", True):
resdf = resdf.fillna(0).infer_objects(copy=False)

# get selection efficiencies
effis = self.efficiencies(resdf)
Expand Down Expand Up @@ -904,14 +913,19 @@ def metric(self, data_slice, zseason, x1=-2.0, color=0.2, zlim=-1, metric="zlim"

# get simulation parameters
if np.size(np.unique(seasons)) > 1:
dur_z["z_copy"] = dur_z["z"].values.copy()
dur_z["season_copy"] = dur_z["season"].values.copy()
gen_par = (
dur_z.groupby(["z", "season"])
.apply(lambda x: self.calc_daymax(x, self.daymax_step))
.apply(lambda x: self.calc_daymax(x, self.daymax_step), include_groups=False)
.reset_index()
)
else:
dur_z["z_copy"] = dur_z["z"].values.copy()
gen_par = (
dur_z.groupby(["z"]).apply(lambda x: self.calc_daymax(x, self.daymax_step)).reset_index()
dur_z.groupby(["z"])
.apply(lambda x: self.calc_daymax(x, self.daymax_step), include_groups=False)
.reset_index()
)
gen_par["season"] = gen_par["level_1"] * 0 + np.unique(seasons)

Expand Down Expand Up @@ -958,11 +972,17 @@ def metric(self, data_slice, zseason, x1=-2.0, color=0.2, zlim=-1, metric="zlim"

# estimate redshift completeness
if metric == "zlim":
metric_values = sn.groupby(["season"]).apply(lambda x: self.zlim(x)).reset_index()
sn["season_copy"] = sn["season"].values.copy()
metric_values = (
sn.groupby(["season"]).apply(lambda x: self.zlim(x), include_groups=False).reset_index()
)

if metric == "nsn":
sn = sn.merge(zlim, left_on=["season"], right_on=["season"])
metric_values = sn.groupby(["season"]).apply(lambda x: self.nsn(x)).reset_index()
sn["season_copy"] = sn["season"].values.copy()
metric_values = (
sn.groupby(["season"]).apply(lambda x: self.nsn(x), include_groups=False).reset_index()
)

return metric_values

Expand Down Expand Up @@ -991,12 +1011,15 @@ def z_season(self, seasons, data_slice):
return zseason

def z_season_allz(self, zseason):

zseason["season_copy"] = zseason["season"].values.copy()
zseason_allz = (
zseason.groupby(["season"])
.apply(
lambda x: pd.DataFrame(
{"z": list(np.arange(x["zmin"].mean(), x["zmax"].mean(), x["zstep"].mean()))}
)
),
include_groups=False,
)
.reset_index()
)
Expand Down

0 comments on commit 343069a

Please sign in to comment.