Skip to content

Commit

Permalink
Merge pull request #211 from NREL/fix/QDM_explicit_args
Browse files Browse the repository at this point in the history
Pipeline requires explicit arguments for QuantileDeltaMappingCorrection
  • Loading branch information
castelao authored May 14, 2024
2 parents 138a006 + aea7b34 commit 96a6f6a
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 12 deletions.
78 changes: 71 additions & 7 deletions sup3r/bias/bias_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,12 +1204,21 @@ def __init__(self,
base_fps,
bias_fps,
bias_fut_fps,
*args,
base_dset,
bias_feature,
distance_upper_bound=None,
target=None,
shape=None,
base_handler='Resource',
bias_handler='DataHandlerNCforCC',
base_handler_kwargs=None,
bias_handler_kwargs=None,
decimals=None,
match_zero_rate=False,
n_quantiles=101,
dist="empirical",
sampling="linear",
log_base=10,
**kwargs):
log_base=10):
"""
Parameters
----------
Expand All @@ -1229,9 +1238,52 @@ def __init__(self,
future in Cannon et. al. (2015)). This is the dataset that would be
corrected, while `bias_fsp` is used to provide a transformation map
with the baseline data.
*(kw)args :
For aditional arguments, check :class:`DataRetrievalBase`.
*Note: The following arguments must be keyworded arguments.*
base_dset : str
A single dataset from the base_fps to retrieve. In the case of wind
components, this can be U_100m or V_100m which will retrieve
windspeed and winddirection and derive the U/V component.
bias_feature : str
This is the biased feature from bias_fps to retrieve. This should
be a single feature name corresponding to base_dset
distance_upper_bound : float
Upper bound on the nearest neighbor distance in decimal degrees.
This should be the approximate resolution of the low-resolution
bias data. None (default) will calculate this based on the median
distance between points in bias_fps
target : tuple
(lat, lon) lower left corner of raster to retrieve from bias_fps.
If None then the lower left corner of the full domain will be used.
shape : tuple
(rows, cols) grid size to retrieve from bias_fps. If None then the
full domain shape will be used.
base_handler : str
Name of rex resource handler or sup3r.preprocessing.data_handling
class to be retrieved from the rex/sup3r library. If a
sup3r.preprocessing.data_handling class is used, all data will be
loaded in this class' initialization and the subsequent bias
calculation will be done in serial
bias_handler : str
Name of the bias data handler class to be retrieved from the
sup3r.preprocessing.data_handling library.
base_handler_kwargs : dict | None
Optional kwargs to send to the initialization of the base_handler
class
bias_handler_kwargs : dict | None
Optional kwargs to send to the initialization of the bias_handler
class
decimals : int | None
Option to round bias and base data to this number of
decimals, this gets passed to np.around(). If decimals
is negative, it specifies the number of positions to
the left of the decimal point.
match_zero_rate : bool
Option to fix the frequency of zero values in the biased data. The
lowest percentile of values in the biased data will be set to zero
to match the percentile of zeros in the base data. If
SkillAssessment is being run and this is True, the distributions
will not be mean-centered. This helps resolve the issue where
global climate models produce too many days with small
precipitation totals e.g., the "drizzle problem" [Polade2014]_.
dist : str, default="empirical",
Define the type of distribution, which can be "empirical" or any
parametric distribution defined in "scipy".
Expand Down Expand Up @@ -1280,7 +1332,19 @@ def __init__(self,
self.sampling = sampling
self.log_base = log_base

super().__init__(base_fps, bias_fps, *args, **kwargs)
super().__init__(base_fps=base_fps,
bias_fps=bias_fps,
base_dset=base_dset,
bias_feature=bias_feature,
distance_upper_bound=distance_upper_bound,
target=target,
shape=shape,
base_handler=base_handler,
bias_handler=bias_handler,
base_handler_kwargs=base_handler_kwargs,
bias_handler_kwargs=bias_handler_kwargs,
decimals=decimals,
match_zero_rate=match_zero_rate)

self.bias_fut_fps = bias_fut_fps

Expand Down
2 changes: 1 addition & 1 deletion sup3r/bias/bias_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def local_qdm_bc(data: np.array,
spatial_slice = (lr_padded_slice[0], lr_padded_slice[1])
base = base[spatial_slice]
bias = bias[spatial_slice]
bias_fut = bias[spatial_slice]
bias_fut = bias_fut[spatial_slice]

if no_trend:
mf = None
Expand Down
15 changes: 13 additions & 2 deletions sup3r/preprocessing/data_handling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,8 @@ def qdm_bc(self,
bc_files,
reference_feature,
relative=True,
threshold=0.1):
threshold=0.1,
no_trend=False):
"""Bias Correction using Quantile Delta Mapping
Bias correct this DataHandler's data with Quantile Delta Mapping. The
Expand Down Expand Up @@ -1361,6 +1362,15 @@ def qdm_bc(self,
Nearest neighbor euclidean distance threshold. If the DataHandler
coordinates are more than this value away from the bias correction
lat/lon, an error is raised.
no_trend: bool, default=False
An option to ignore the trend component of the correction, thus
resulting in an ordinary Quantile Mapping, i.e. corrects the bias
by comparing the distributions of the biased dataset with a
reference datasets. See ``params_mf`` of
:class:`rex.utilities.bc_utils.QuantileDeltaMapping`.
Note that this assumes that "bias_{feature}_params"
(``params_mh``) is the data distribution representative for the
target data.
"""

if isinstance(bc_files, str):
Expand All @@ -1378,7 +1388,8 @@ def qdm_bc(self,
feature,
bias_fp=fp,
threshold=threshold,
relative=relative)
relative=relative,
no_trend=no_trend)
completed.append(feature)


Expand Down
113 changes: 111 additions & 2 deletions tests/bias/test_qdm_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import pytest
import xarray as xr

from sup3r import TEST_DATA_DIR
from sup3r import CONFIG_DIR, TEST_DATA_DIR
from sup3r.models import Sup3rGan
from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy
from sup3r.bias.bias_calc import QuantileDeltaMappingCorrection
from sup3r.bias.bias_transforms import local_qdm_bc
from sup3r.preprocessing.data_handling import DataHandlerNC
from sup3r.preprocessing.data_handling import DataHandlerNC, DataHandlerNCforCC

FP_NSRDB = os.path.join(TEST_DATA_DIR, "test_nsrdb_co_2018.h5")
FP_CC = os.path.join(TEST_DATA_DIR, "rsds_test.nc")
Expand Down Expand Up @@ -359,3 +361,110 @@ def test_bc_trend_same_hist(tmp_path, fp_fut_cc, dist_params):

idx = ~(np.isnan(original) | np.isnan(corrected))
assert np.allclose(corrected[idx], original[idx])


def test_fwp_integration(tmp_path):
"""Integration of the bias correction method into the forward pass
Validate two aspects:
- We should be able to run a forward pass with unbiased data.
- The bias trend should be observed in the predicted output.
"""
fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')
features = ['U_100m', 'V_100m']
target = (13.67, 125.0)
shape = (8, 8)
temporal_slice = slice(None, None, 1)
fwp_chunk_shape = (4, 4, 150)
input_files = [os.path.join(TEST_DATA_DIR, 'ua_test.nc'),
os.path.join(TEST_DATA_DIR, 'va_test.nc'),
os.path.join(TEST_DATA_DIR, 'orog_test.nc'),
os.path.join(TEST_DATA_DIR, 'zg_test.nc')]

n_samples = 101
quantiles = np.linspace(0, 1, n_samples)
params = {}
with xr.open_dataset(os.path.join(TEST_DATA_DIR, 'ua_test.nc')) as ds:
params['bias_U_100m_params'] = ds['ua'].quantile(quantiles).to_numpy()
params['base_Uref_100m_params'] = params['bias_U_100m_params'] - 2.72
params['bias_fut_U_100m_params'] = params['bias_U_100m_params']
with xr.open_dataset(os.path.join(TEST_DATA_DIR, 'va_test.nc')) as ds:
params['bias_V_100m_params'] = ds['va'].quantile(quantiles).to_numpy()
params['base_Vref_100m_params'] = params['bias_V_100m_params'] + 2.72
params['bias_fut_V_100m_params'] = params['bias_V_100m_params']

lat_lon = DataHandlerNCforCC(input_files, features=[], target=target,
shape=shape,
worker_kwargs={'max_workers': 1}).lat_lon

Sup3rGan.seed()
model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4)
_ = model.generate(np.ones((4, 10, 10, 6, len(features))))
model.meta['lr_features'] = features
model.meta['hr_out_features'] = features
model.meta['s_enhance'] = 3
model.meta['t_enhance'] = 4

bias_fp = os.path.join(tmp_path, 'bc.h5')
out_dir = os.path.join(tmp_path, 'st_gan')
model.save(out_dir)

with h5py.File(bias_fp, "w") as f:
f.create_dataset("latitude", data=lat_lon[..., 0])
f.create_dataset("longitude", data=lat_lon[..., 1])

s = lat_lon.shape[:2]
for k, v in params.items():
f.create_dataset(k, data=np.broadcast_to(v, (*s, v.size)))
f.attrs["dist"] = "empirical"
f.attrs["sampling"] = "linear"
f.attrs["log_base"] = 10

bias_correct_kwargs = {'U_100m': {'feature_name': 'U_100m',
'base_dset': 'Uref_100m',
'bias_fp': bias_fp},
'V_100m': {'feature_name': 'V_100m',
'base_dset': 'Vref_100m',
'bias_fp': bias_fp}}

strat = ForwardPassStrategy(
input_files,
model_kwargs={'model_dir': out_dir},
fwp_chunk_shape=fwp_chunk_shape,
spatial_pad=0, temporal_pad=0,
input_handler_kwargs=dict(target=target, shape=shape,
temporal_slice=temporal_slice,
worker_kwargs=dict(max_workers=1)),
out_pattern=os.path.join(tmp_path, 'out_{file_id}.nc'),
worker_kwargs=dict(max_workers=1),
input_handler='DataHandlerNCforCC')
bc_strat = ForwardPassStrategy(
input_files,
model_kwargs={'model_dir': out_dir},
fwp_chunk_shape=fwp_chunk_shape,
spatial_pad=0, temporal_pad=0,
input_handler_kwargs=dict(target=target, shape=shape,
temporal_slice=temporal_slice,
worker_kwargs=dict(max_workers=1)),
out_pattern=os.path.join(tmp_path, 'out_{file_id}.nc'),
worker_kwargs=dict(max_workers=1),
input_handler='DataHandlerNCforCC',
bias_correct_method='local_qdm_bc',
bias_correct_kwargs=bias_correct_kwargs)

for ichunk in range(strat.chunks):
fwp = ForwardPass(strat, chunk_index=ichunk)
bc_fwp = ForwardPass(bc_strat, chunk_index=ichunk)

delta = bc_fwp.input_data - fwp.input_data
assert np.allclose(
delta[..., 0], -2.72, atol=1e-03
), "U reference offset is -1"
assert np.allclose(
delta[..., 1], 2.72, atol=1e-03
), "V reference offset is 1"

delta = bc_fwp.run_chunk() - fwp.run_chunk()
assert delta[..., 0].mean() < 0, "Predicted U should trend <0"
assert delta[..., 1].mean() > 0, "Predicted V should trend >0"

0 comments on commit 96a6f6a

Please sign in to comment.