From 023fa9063c8ba95814706014df2e7f6ea7a169eb Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Wed, 30 Aug 2023 18:01:17 +0200 Subject: [PATCH] Add additional check whether downscaled starts from history --- src/aneris/downscaling/core.py | 44 +++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/src/aneris/downscaling/core.py b/src/aneris/downscaling/core.py index 5bd5295..d23c27f 100644 --- a/src/aneris/downscaling/core.py +++ b/src/aneris/downscaling/core.py @@ -212,22 +212,44 @@ def downscale( return self.return_type(downscaled) def check_downscaled(self, downscaled, rtol=1e-05, atol=1e-08): - downscaled = ( + def warn_if_differences(actual, should, message): + actual, should = actual.align(should, join="left") + diff = actual - should + diff_exceeded = abs(diff) > atol + rtol * abs(should) + if diff_exceeded.any(): + logger().warning( + "%s:\n%s", + message, + DataFrame(dict(actual=actual, should=should, diff=diff)) + .loc[diff_exceeded] + .to_string(), + ) + + downscaled_region = ( downscaled.groupby(self.model.index.names, dropna=False) .sum() .rename_axis(columns="year") .stack() ) - model = self.model.rename_axis(columns="year").stack() - diff = downscaled - model - diff_exceeded = abs(diff) + rtol * abs(model) > atol - if diff_exceeded.any(): - logger().warning( - "Difference thresholds exceeded for a few trajectories:\n%s", - DataFrame(dict(model=model, downscaled=downscaled, diff=diff)) - .loc[diff_exceeded] - .to_string(), - ) + model = self.model.loc[:, self.year:].rename_axis(columns="year").stack() + + warn_if_differences( + downscaled_region, + model, + "Downscaled trajectories do not sum up to regional totals", + ) + + hist = self.hist + if isinstance(hist, DataFrame): + hist = hist.loc[:, self.year] + hist = hist.pix.semijoin(downscaled.index, how="right") + downscaled_start = downscaled.loc[:, self.year] + + warn_if_differences( + downscaled_start, + hist, + "Downscaled trajectories do not start from history", + ) def methods(self, method_choice=None, overwrites=None): if method_choice is None: