Skip to content

Commit

Permalink
Merge pull request #171 from NREL/bnb/train_only_features_hot_fix
Browse files Browse the repository at this point in the history
Bnb/train only features hot fix
  • Loading branch information
bnb32 authored Oct 18, 2023
2 parents 344783d + 3540ede commit 69283cd
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 17 deletions.
25 changes: 21 additions & 4 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def _combine_loss_input(self, high_res_true, high_res_gen):
"""
if high_res_true.shape[-1] > high_res_gen.shape[-1]:
for feature in self.exogenous_features:
f_idx = self.training_features.index(feature)
f_idx = self.hr_features.index(feature)
exo_data = high_res_true[..., f_idx: f_idx + 1]
high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1)
return high_res_gen
Expand Down Expand Up @@ -360,6 +360,23 @@ def training_features(self):
trained on."""
return self.meta.get('training_features', None)

@property
def train_only_features(self):
"""Get the list of feature names used only for training (expected as
input but not included in output)."""
return self.meta.get('train_only_features', None)

@property
def hr_features(self):
"""Get the list of features stored in batch.high_res. This is the same
as training_features but without train_only_features. This is used to
select the correct high res exogenous data."""
hr_features = self.training_features
if self.train_only_features is not None:
hr_features = [f for f in self.training_features
if f not in self.train_only_features]
return hr_features

@property
def output_features(self):
"""Get the list of output feature names that the generative model
Expand Down Expand Up @@ -445,7 +462,8 @@ def set_model_params(self, **kwargs):
kwargs = self._check_exo_features(**kwargs)

keys = ('input_resolution', 'training_features', 'output_features',
'smoothed_features', 's_enhance', 't_enhance', 'smoothing')
'train_only_features', 'smoothed_features', 's_enhance',
't_enhance', 'smoothing')
keys = [k for k in keys if k in kwargs]

for var in keys:
Expand Down Expand Up @@ -660,7 +678,6 @@ def norm_input(self, low_res):
warn(msg)
else:
stdevs = self._stdevs

low_res = (low_res.copy() - self._means) / stdevs

return low_res
Expand Down Expand Up @@ -819,7 +836,7 @@ def get_high_res_exo_input(self, high_res):
"""
exo_data = {}
for feature in self.exogenous_features:
f_idx = self.training_features.index(feature)
f_idx = self.hr_features.index(feature)
exo_fdata = high_res[..., f_idx: f_idx + 1]
exo_data[feature] = exo_fdata
return exo_data
Expand Down
1 change: 1 addition & 0 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ def train(self,
t_enhance=batch_handler.t_enhance,
smoothing=batch_handler.smoothing,
training_features=batch_handler.training_features,
train_only_features=batch_handler.train_only_features,
output_features=batch_handler.output_features,
smoothed_features=batch_handler.smoothed_features)

Expand Down
1 change: 1 addition & 0 deletions sup3r/models/conditional_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def train(self, batch_handler,
s_enhance=batch_handler.s_enhance,
t_enhance=batch_handler.t_enhance,
smoothing=batch_handler.smoothing,
train_only_features=batch_handler.train_only_features,
training_features=batch_handler.training_features,
output_features=batch_handler.output_features,
smoothed_features=batch_handler.smoothed_features)
Expand Down
22 changes: 18 additions & 4 deletions sup3r/preprocessing/batch_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,19 @@
from scipy.ndimage.filters import gaussian_filter

from sup3r.preprocessing.data_handling.h5_data_handling import (
DataHandlerDCforH5, )
DataHandlerDCforH5,
)
from sup3r.utilities.utilities import (
estimate_max_workers, nn_fill_array, nsrdb_reduce_daily_data, smooth_data,
spatial_coarsening, temporal_coarsening, uniform_box_sampler,
uniform_time_sampler, weighted_box_sampler, weighted_time_sampler,
estimate_max_workers,
nn_fill_array,
nsrdb_reduce_daily_data,
smooth_data,
spatial_coarsening,
temporal_coarsening,
uniform_box_sampler,
uniform_time_sampler,
weighted_box_sampler,
weighted_time_sampler,
)

np.random.seed(42)
Expand Down Expand Up @@ -577,6 +585,12 @@ def training_features(self):
data handlers"""
return self.data_handlers[0].features

@property
def train_only_features(self):
"""Get the ordered list of feature names used only for training which
will not be stored in batch.high_res"""
return self.data_handlers[0].train_only_features

@property
def output_features(self):
"""Get the ordered list of feature names held in this object's
Expand Down
2 changes: 1 addition & 1 deletion sup3r/preprocessing/data_handling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ def _check_grid_extent(cls, target, grid_shape, lat_lon):
min_lon = np.min(lat_lon[..., 1])
max_lat = np.max(lat_lon[..., 0])
max_lon = np.max(lat_lon[..., 1])
logger.debug('Calculating raster index from WRF file '
logger.debug('Calculating raster index from NETCDF file '
f'for shape {grid_shape} and target {target}')
logger.debug(f'lat/lon (min, max): {min_lat}/{min_lon}, '
f'{max_lat}/{max_lon}')
Expand Down
5 changes: 5 additions & 0 deletions sup3r/preprocessing/data_handling/dual_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ def output_features(self):
GAN"""
return self.hr_dh.output_features

@property
def train_only_features(self):
"""Features to use for training only and not output"""
return self.lr_dh.train_only_features

def _shape_check(self):
"""Check if hr_handler.shape is divisible by s_enhance. If not take
the largest shape that can be."""
Expand Down
35 changes: 28 additions & 7 deletions sup3r/preprocessing/data_handling/nc_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,34 @@

from sup3r.preprocessing.data_handling.base import DataHandler, DataHandlerDC
from sup3r.preprocessing.feature_handling import (
BVFreqMon, BVFreqSquaredNC, ClearSkyRatioCC, Feature, InverseMonNC,
LatLonNC, PotentialTempNC, PressureNC, Rews, Shear, Tas, TasMax, TasMin,
TempNC, TempNCforCC, UWind, UWindPowerLaw, VWind, VWindPowerLaw,
WinddirectionNC, WindspeedNC)
BVFreqMon,
BVFreqSquaredNC,
ClearSkyRatioCC,
Feature,
InverseMonNC,
LatLonNC,
PotentialTempNC,
PressureNC,
Rews,
Shear,
Tas,
TasMax,
TasMin,
TempNC,
TempNCforCC,
UWind,
UWindPowerLaw,
VWind,
VWindPowerLaw,
WinddirectionNC,
WindspeedNC,
)
from sup3r.utilities.interpolation import Interpolator
from sup3r.utilities.utilities import (estimate_max_workers, get_time_dim_name,
np_to_pd_times)
from sup3r.utilities.utilities import (
estimate_max_workers,
get_time_dim_name,
np_to_pd_times,
)

np.random.seed(42)

Expand Down Expand Up @@ -385,7 +406,7 @@ def _check_grid_extent(cls, target, grid_shape, lat_lon):
min_lon = np.min(lat_lon[..., 1])
max_lat = np.max(lat_lon[..., 0])
max_lon = np.max(lat_lon[..., 1])
logger.debug('Calculating raster index from WRF file '
logger.debug('Calculating raster index from NETCDF file '
f'for shape {grid_shape} and target {target}')
logger.debug(f'lat/lon (min, max): {min_lat}/{min_lon}, '
f'{max_lat}/{max_lon}')
Expand Down
96 changes: 95 additions & 1 deletion tests/training/test_train_gan_exo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,107 @@
TARGET_S = (39.01, -105.13)

INPUT_FILE_W = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
FEATURES_W = ['U_100m', 'V_100m', 'temperature_100m', 'topography']
FEATURES_W = ['temperature_100m', 'U_100m', 'V_100m', 'topography']
TARGET_W = (39.01, -105.15)

FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
TARGET_COORD = (39.01, -105.15)


@pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat'])
def test_wind_hi_res_topo_with_train_only(custom_layer, log=False):
"""Test a special wind cc model with the custom Sup3rAdder or Sup3rConcat
layer that adds/concatenates hi-res topography in the middle of the
network. This also includes a train only feature"""

handler = DataHandlerH5WindCC(INPUT_FILE_W,
FEATURES_W,
target=TARGET_W, shape=SHAPE,
temporal_slice=slice(None, None, 2),
time_roll=-7,
val_split=0.1,
sample_shape=(20, 20),
worker_kwargs=dict(max_workers=1),
train_only_features=['temperature_100m'])
batcher = SpatialBatchHandlerCC([handler], batch_size=2, n_batches=2,
s_enhance=2)

if log:
init_logger('sup3r', log_level='DEBUG')

gen_model = [{"class": "FlexiblePadding",
"paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
"mode": "REFLECT"},
{"class": "Conv2DTranspose", "filters": 64, "kernel_size": 3,
"strides": 1, "activation": "relu"},
{"class": "Cropping2D", "cropping": 4},

{"class": "FlexiblePadding",
"paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
"mode": "REFLECT"},
{"class": "Conv2DTranspose", "filters": 64,
"kernel_size": 3, "strides": 1, "activation": "relu"},
{"class": "Cropping2D", "cropping": 4},

{"class": "FlexiblePadding",
"paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
"mode": "REFLECT"},
{"class": "Conv2DTranspose", "filters": 64,
"kernel_size": 3, "strides": 1, "activation": "relu"},
{"class": "Cropping2D", "cropping": 4},
{"class": "SpatialExpansion", "spatial_mult": 2},
{"class": "Activation", "activation": "relu"},

{"class": custom_layer, "name": "topography"},

{"class": "FlexiblePadding",
"paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
"mode": "REFLECT"},
{"class": "Conv2DTranspose", "filters": 2,
"kernel_size": 3, "strides": 1, "activation": "relu"},
{"class": "Cropping2D", "cropping": 4}]

fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json')

Sup3rGan.seed()
model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4)

with tempfile.TemporaryDirectory() as td:
model.train(batcher,
input_resolution={'spatial': '16km',
'temporal': '3600min'},
n_epoch=1,
weight_gen_advers=0.0,
train_gen=True, train_disc=False,
checkpoint_int=None,
out_dir=os.path.join(td, 'test_{epoch}'))

assert model.train_only_features == ['temperature_100m']
assert model.hr_features == ['U_100m', 'V_100m', 'topography']
assert 'test_0' in os.listdir(td)
assert model.meta['output_features'] == ['U_100m', 'V_100m']
assert model.meta['class'] == 'Sup3rGan'
assert 'topography' in batcher.output_features
assert 'topography' not in model.output_features

x = np.random.uniform(0, 1, (4, 30, 30, 4))
hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1))

with pytest.raises(RuntimeError):
y = model.generate(x, exogenous_data=None)

exo_tmp = {
'topography': {
'steps': [
{'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}}
y = model.generate(x, exogenous_data=exo_tmp)

assert y.shape[0] == x.shape[0]
assert y.shape[1] == x.shape[1] * 2
assert y.shape[2] == x.shape[2] * 2
assert y.shape[3] == x.shape[3] - 2


@pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat'])
def test_wind_hi_res_topo(custom_layer, log=False):
"""Test a special wind cc model with the custom Sup3rAdder or Sup3rConcat
Expand Down

0 comments on commit 69283cd

Please sign in to comment.