diff --git a/examples/sup3rcc/run_configs/solar/config_fwp.json b/examples/sup3rcc/run_configs/solar/config_fwp.json index 74746f1e8c..1b30ba65cc 100755 --- a/examples/sup3rcc/run_configs/solar/config_fwp.json +++ b/examples/sup3rcc/run_configs/solar/config_fwp.json @@ -49,15 +49,15 @@ }, "max_nodes": 10, "exo_kwargs": { - "file_paths": ["/scratch/gbuster/sup3r/source_gcm_data/wind_solar_day_MRI-ESM2-0_ssp585_r1i1p1f1_gn_20150101-20151231.nc"], - "features": ["topography"], - "source_file": "/scratch/gbuster/sup3r/source_gcm_data/nsrdb_clearsky.h5", - "target": [23.2, -129], - "shape": [26, 59], - "s_enhancements": [1, 5, 5], - "agg_factors": [625, 25, 1], - "input_handler": "DataHandlerNCforCC", - "exo_steps": [0, 1, 2] + "topography": { + "steps": [{"model": 0, "combine_type": "input"}, + {"model": 0, "combine_type": "layer"}, + {"model": 1, "combine_type": "input"}], + "file_paths": ["/scratch/gbuster/sup3r/source_gcm_data/wind_solar_day_MRI-ESM2-0_ssp585_r1i1p1f1_gn_20150101-20151231.nc"], + "source_file": "/scratch/gbuster/sup3r/source_gcm_data/nsrdb_clearsky.h5", + "target": [23.2, -129], + "shape": [26, 59], + "input_handler": "DataHandlerNCforCC"} }, "execution_control": { "option": "eagle", diff --git a/examples/sup3rcc/run_configs/trh/config_fwp.json b/examples/sup3rcc/run_configs/trh/config_fwp.json index 3fdcd2504a..c423be1d59 100755 --- a/examples/sup3rcc/run_configs/trh/config_fwp.json +++ b/examples/sup3rcc/run_configs/trh/config_fwp.json @@ -1,14 +1,9 @@ { "file_paths": "PLACEHOLDER", "model_kwargs": { - "surface_model_kwargs": { - "model_dir": "./sup3rcc_models_202303/sup3rcc_trh_step1_25x_1x_2f/" - }, - "temporal_model_kwargs": { - "model_dirs": [ - "./sup3rcc_models_202303/sup3rcc_trh_step2_1x_24x_2f/" - ] - } + "model_dirs": [ + "./sup3rcc_models_202303/sup3rcc_trh_step1_25x_1x_2f/", + "./sup3rcc_models_202303/sup3rcc_trh_step2_1x_24x_2f/"] }, "model_class": "MultiStepSurfaceMetGan", "out_pattern": "./chunks/sup3r_chunk_{file_id}.h5", @@ -42,15 +37,14 @@ }, "max_nodes": 10, "exo_kwargs": { - "file_paths": ["/datasets/sup3rcc/source/temp_humi_day_MRI-ESM2-0_ssp245_r1i1p1f1_gn_20500101-20501231.nc"], - "features": ["topography"], - "source_file": "/datasets/sup3rcc/source/nsrdb_clearsky.h5", - "target": [23.2, -129], - "shape": [26, 59], - "s_enhancements": [1, 25], - "agg_factors": [625, 1], - "input_handler": "DataHandlerNCforCC", - "exo_steps": [0, 1] + "topography": { + "steps": [{"model": 0, "combine_type": "input"}, + {"model": 0, "combine_type": "output"}], + "file_paths": ["/datasets/sup3rcc/source/temp_humi_day_MRI-ESM2-0_ssp245_r1i1p1f1_gn_20500101-20501231.nc"], + "source_file": "/datasets/sup3rcc/source/nsrdb_clearsky.h5", + "target": [23.2, -129], + "shape": [26, 59], + "input_handler": "DataHandlerNCforCC"} }, "execution_control": { "option": "eagle", diff --git a/examples/sup3rcc/run_configs/wind/config_fwp.json b/examples/sup3rcc/run_configs/wind/config_fwp.json index acf0e5564d..ae908a7a63 100755 --- a/examples/sup3rcc/run_configs/wind/config_fwp.json +++ b/examples/sup3rcc/run_configs/wind/config_fwp.json @@ -1,13 +1,12 @@ { "file_paths": "PLACEHOLDER", "model_kwargs": { - "spatial_model_dirs": [ + "model_dirs": [ "./sup3rcc_models_202303/sup3rcc_wind_step1_5x_1x_6f/", - "./sup3rcc_models_202303/sup3rcc_wind_step2_5x_1x_6f/" - ], - "temporal_model_dirs": "./sup3rcc_models_202303/sup3rcc_wind_step3_1x_24x_6f/" + "./sup3rcc_models_202303/sup3rcc_wind_step2_5x_1x_6f/", + "./sup3rcc_models_202303/sup3rcc_wind_step3_1x_24x_6f/"] }, - "model_class": 
"SpatialThenTemporalGan", + "model_class": "MultiStepGan", "out_pattern": "./chunks/sup3r_chunk_{file_id}.h5", "log_pattern": "./logs/sup3r_fwp_log_{node_index}.log", "bias_correct_method": "monthly_local_linear_bc", @@ -39,15 +38,15 @@ }, "max_nodes": 50, "exo_kwargs": { - "file_paths": ["/scratch/gbuster/sup3r/source_gcm_data/wind_solar_day_MRI-ESM2-0_ssp585_r1i1p1f1_gn_20150101-20151231.nc"], - "features": ["topography"], - "source_file": "/scratch/gbuster/sup3r/source_gcm_data/nsrdb_clearsky.h5", - "target": [23.2, -129], - "shape": [26, 59], - "s_enhancements": [1, 5, 5], - "agg_factors": [625, 25, 1], - "input_handler": "DataHandlerNCforCC", - "exo_steps": [0, 1, 2] + "topography": { + "file_paths": ["/scratch/gbuster/sup3r/source_gcm_data/wind_solar_day_MRI-ESM2-0_ssp585_r1i1p1f1_gn_20150101-20151231.nc"], + "source_file": "/scratch/gbuster/sup3r/source_gcm_data/nsrdb_clearsky.h5", + "target": [23.2, -129], + "shape": [26, 59], + "steps": [{"model": 0, "combine_type": "input"}, + {"model": 0, "combine_type": "layer"}, + {"model": 1, "combine_type": "input"}], + "input_handler": "DataHandlerNCforCC"} }, "execution_control": { "option": "eagle", diff --git a/requirements.txt b/requirements.txt index 06de3fb18f..ec27570ce9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ matplotlib>=3.1 -NREL-rex>=0.2.82 +NREL-rex>=0.2.84 NREL-phygnn>=0.0.23 NREL-rev<0.8.0 NREL-gaps>=0.4.0 diff --git a/sup3r/bias/bias_calc_cli.py b/sup3r/bias/bias_calc_cli.py index 9e34ee571a..d20b16f094 100644 --- a/sup3r/bias/bias_calc_cli.py +++ b/sup3r/bias/bias_calc_cli.py @@ -3,15 +3,15 @@ sup3r bias correction calculation CLI entry points. """ import copy -import click import logging import os +import click + import sup3r.bias.bias_calc from sup3r.utilities import ModuleName -from sup3r.version import __version__ from sup3r.utilities.cli import BaseCLI - +from sup3r.version import __version__ logger = logging.getLogger(__name__) diff --git a/sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json b/sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json index 870c363f2f..9785abf0e2 100644 --- a/sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json +++ b/sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json @@ -34,7 +34,7 @@ {"class": "SpatioTemporalExpansion", "spatial_mult": 3}, {"alpha": 0.2, "class": "LeakyReLU"}, - {"class": "Sup3rConcat"}, + {"class": "Sup3rConcat", "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [3,3], [0,0]], "mode": "REFLECT"}, {"class": "Conv3D", "filters": 2, "kernel_size": 3, "strides": 1}, diff --git a/sup3r/configs/sup3rcc/gen_wind_5x_1x_6f.json b/sup3r/configs/sup3rcc/gen_wind_5x_1x_6f.json index 53ce739a03..b90b9f39c7 100644 --- a/sup3r/configs/sup3rcc/gen_wind_5x_1x_6f.json +++ b/sup3r/configs/sup3rcc/gen_wind_5x_1x_6f.json @@ -31,7 +31,7 @@ {"class": "SpatialExpansion", "spatial_mult": 5}, {"alpha": 0.2, "class": "LeakyReLU"}, - {"class": "Sup3rConcat"}, + {"class": "Sup3rConcat", "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [0,0]], "mode": "REFLECT"}, {"class": "Conv2D", "filters": 64, "kernel_size": 3, "strides": 1}, {"class": "Cropping2D", "cropping": 2}, diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 9304404393..1179231d9e 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -1,17 +1,12 @@ # -*- coding: utf-8 -*- """Sup3r Model Software""" from .base import Sup3rGan -from .wind import WindGan -from .solar_cc import SolarCC +from .conditional_moments import Sup3rCondMom from 
.data_centric import Sup3rGanDC -from .multi_step import (MultiStepGan, - SpatialThenTemporalGan, TemporalThenSpatialGan, - MultiStepSurfaceMetGan, SolarMultiStepGan) -from .surface import SurfaceSpatialMetModel from .linear import LinearInterp -from .conditional_moments import Sup3rCondMom -from .wind_conditional_moments import WindCondMom +from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan +from .solar_cc import SolarCC +from .surface import SurfaceSpatialMetModel -SPATIAL_FIRST_MODELS = (SpatialThenTemporalGan, - MultiStepSurfaceMetGan, +SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 26d7f4be76..a2dc0264e0 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -6,6 +6,7 @@ import logging import os import pprint +import re import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -14,10 +15,10 @@ import numpy as np import tensorflow as tf -from tensorflow.keras import optimizers from phygnn import CustomNetwork from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat from rex.utilities.utilities import safe_json_load +from tensorflow.keras import optimizers import sup3r.utilities.loss_metrics from sup3r.utilities import VERSION_RECORD @@ -54,6 +55,15 @@ def load(cls, model_dir, verbose=True): model_dir """ + @abstractmethod + def generate(self, + low_res, + norm_in=True, + un_norm_out=True, + exogenous_data=None): + """Use the generator model to generate high res data from low res + input. This is the public generate function.""" + @staticmethod def seed(s=0): """ @@ -66,27 +76,6 @@ def seed(s=0): """ CustomNetwork.seed(s=s) - @abstractmethod - def generate(self, low_res): - """Use the generator model to generate high res data from low res - input. This is the public generate function. - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - - Returns - ------- - hi_res : ndarray - Synthetically generated high-resolution data, usually a 4D or 5D - array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - @property def input_dims(self): """Get dimension of model generator input. This is usually 4D for @@ -106,17 +95,264 @@ def input_dims(self): else: return 5 + # pylint: disable=E1101 + def get_s_enhance_from_layers(self): + """Compute factor by which model will enhance spatial resolution from + layer attributes. Used in model training during high res coarsening""" + s_enhance = None + if hasattr(self, '_gen'): + s_enhancements = [getattr(layer, '_spatial_mult', 1) + for layer in self._gen.layers] + s_enhance = int(np.product(s_enhancements)) + return s_enhance + + # pylint: disable=E1101 + def get_t_enhance_from_layers(self): + """Compute factor by which model will enhance temporal resolution from + layer attributes. Used in model training during high res coarsening""" + t_enhance = None + if hasattr(self, '_gen'): + t_enhancements = [getattr(layer, '_temporal_mult', 1) + for layer in self._gen.layers] + t_enhance = int(np.product(t_enhancements)) + return t_enhance + @property def s_enhance(self): """Factor by which model will enhance spatial resolution. 
Used in model training during high res coarsening""" - return self.meta.get('s_enhance', None) + s_enhance = self.meta.get('s_enhance', None) + if s_enhance is None: + s_enhance = self.get_s_enhance_from_layers() + self.meta['s_enhance'] = s_enhance + return s_enhance @property def t_enhance(self): """Factor by which model will enhance temporal resolution. Used in model training during high res coarsening""" - return self.meta.get('t_enhance', None) + t_enhance = self.meta.get('t_enhance', None) + if t_enhance is None: + t_enhance = self.get_t_enhance_from_layers() + self.meta['t_enhance'] = t_enhance + return t_enhance + + @property + def input_resolution(self): + """Resolution of input data. Given as a dictionary {'spatial': '...km', + 'temporal': '...min'}. The numbers are required to be integers in the + units specified. The units are not strict as long as the resolution + of the exogenous data is specified in the same units when that data + is extracted.""" + input_resolution = self.meta.get('input_resolution', None) + msg = 'model.input_resolution is None. This needs to be set.' + assert input_resolution is not None, msg + return input_resolution + + def _get_numerical_resolutions(self): + """Get the input and output resolutions without units. e.g. for + {"spatial": "30km", "temporal": "60min"} this returns + {"spatial": 30, "temporal": 60}""" + ires_num = {k: int(re.search(r'\d+', v).group(0)) + for k, v in self.input_resolution.items()} + enhancements = {'spatial': self.s_enhance, + 'temporal': self.t_enhance} + ores_num = {k: v // enhancements[k] for k, v in ires_num.items()} + return ires_num, ores_num + + def _ensure_valid_input_resolution(self): + """Ensure enhancement factors evenly divide input_resolution""" + + if self.input_resolution is None: + return + + ires_num, ores_num = self._get_numerical_resolutions() + s_enhance = self.meta['s_enhance'] + t_enhance = self.meta['t_enhance'] + check = ( + ires_num['temporal'] / ores_num['temporal'] == t_enhance + and ires_num['spatial'] / ores_num['spatial'] == s_enhance) + msg = (f'Enhancement factors (s_enhance={s_enhance}, ' + f't_enhance={t_enhance}) do not evenly divide ' + f'input resolution ({self.input_resolution})') + if not check: + logger.error(msg) + raise RuntimeError(msg) + + def _ensure_valid_enhancement_factors(self): + """Ensure user-provided enhancement factors are the same as those + computed from layer attributes""" + t_enhance = self.meta.get('t_enhance', None) + s_enhance = self.meta.get('s_enhance', None) + if s_enhance is None or t_enhance is None: + return + + layer_se = self.get_s_enhance_from_layers() + layer_te = self.get_t_enhance_from_layers() + layer_se = layer_se if layer_se is not None else self.meta['s_enhance'] + layer_te = layer_te if layer_te is not None else self.meta['t_enhance'] + msg = (f'Enhancement factors computed from layer attributes ' + f'(s_enhance={layer_se}, t_enhance={layer_te}) ' + f'conflict with user-provided values (s_enhance={s_enhance}, ' + f't_enhance={t_enhance})') + check = layer_se == s_enhance and layer_te == t_enhance + if not check: + logger.error(msg) + raise RuntimeError(msg) + + @property + def output_resolution(self): + """Resolution of output data. Given as a dictionary + {'spatial': '...km', 'temporal': '...min'}. 
This is computed from the + input resolution and the enhancement factors.""" + output_res = self.meta.get('output_resolution', None) + if self.input_resolution is not None and output_res is None: + ires_num, ores_num = self._get_numerical_resolutions() + output_res = {k: v.replace(str(ires_num[k]), str(ores_num[k])) + for k, v in self.input_resolution.items()} + self.meta['output_resolution'] = output_res + return output_res + + def _combine_fwp_input(self, low_res, exogenous_data=None): + """Combine exogenous_data at input resolution with low_res data prior + to forward pass through generator + + Parameters + ---------- + low_res : np.ndarray + Low-resolution input data, usually a 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. This doesn't have to include the 'model' key since + this data is for a single step model. e.g. + {'topography': {'steps': [ + {'combine_type': 'input', 'data': ..., 'resolution': ...}, + {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}} + + Returns + ------- + low_res : np.ndarray + Low-resolution input data combined with exogenous_data, usually a + 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + """ + if exogenous_data is None: + return low_res + + training_features = ([] if self.training_features is None + else self.training_features) + fnum_diff = len(training_features) - low_res.shape[-1] + exo_feats = ([] if fnum_diff <= 0 + else self.training_features[-fnum_diff:]) + msg = ('Provided exogenous_data is missing some required features ' + f'({exo_feats})') + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + entry = exogenous_data[feature] + combine_types = [step['combine_type'] + for step in entry['steps']] + if 'input' in combine_types: + idx = combine_types.index('input') + low_res = np.concatenate((low_res, + entry['steps'][idx]['data']), + axis=-1) + return low_res + + def _combine_fwp_output(self, hi_res, exogenous_data=None): + """Combine exogenous_data at output resolution with generated hi_res + data following forward pass output. + + Parameters + ---------- + hi_res : np.ndarray + High-resolution output data, usually a 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. This doesn't have to include the 'model' key since + this data is for a single step model. e.g. 
+ {'topography': {'steps': [ + {'combine_type': 'input', 'data': ..., 'resolution': ...}, + {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}} + + Returns + ------- + hi_res : np.ndarray + High-resolution output data combined with exogenous_data, usually a + 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + """ + if exogenous_data is None: + return hi_res + + output_features = ([] if self.output_features is None + else self.output_features) + fnum_diff = len(output_features) - hi_res.shape[-1] + exo_feats = ([] if fnum_diff <= 0 + else self.output_features[-fnum_diff:]) + msg = ('Provided exogenous_data is missing some required features ' + f'({exo_feats})') + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + entry = exogenous_data[feature] + combine_types = [step['combine_type'] + for step in entry['steps']] + if 'output' in combine_types: + idx = combine_types.index('output') + hi_res = np.concatenate((hi_res, + entry['steps'][idx]['data']), + axis=-1) + return hi_res + + def _combine_loss_input(self, high_res_true, high_res_gen): + """Combine exogenous feature data from high_res_true with high_res_gen + for loss calculation + + Parameters + ---------- + high_res_true : tf.Tensor + Ground truth high resolution spatiotemporal data. + high_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. + + Returns + ------- + high_res_gen : tf.Tensor + Same as input with exogenous data combined with high_res input + """ + if high_res_true.shape[-1] > high_res_gen.shape[-1]: + for feature in self.exogenous_features: + f_idx = self.training_features.index(feature) + exo_data = high_res_true[..., f_idx: f_idx + 1] + high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) + return high_res_gen + + @property + def exogenous_features(self): + """Get list of exogenous filter names the model uses. If the model has + N concat or add layers this list will be the last N features in the + training features list. The ordering is assumed to be the same as the + order of concat or add layers. If training features is [..., topo, + sza], and the model has 2 concat or add layers, exo features will be + [topo, sza]. 
Topo will then be used in the first concat layer and sza + will be used in the second""" + # pylint: disable=E1101 + features = [] + if hasattr(self, '_gen'): + for layer in self._gen.layers: + if isinstance(layer, (Sup3rAdder, Sup3rConcat)): + features.append(layer.name) + return features @property @abstractmethod @@ -169,31 +405,69 @@ def version_record(self): """ return VERSION_RECORD - def set_model_params(self, **kwargs): - """Set parameters used for training the model + def _check_exo_features(self, **kwargs): + """Make sure exogenous features have the correct ordering and are + included in training_features Parameters ---------- kwargs : dict Keyword arguments including 'training_features', 'output_features', 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' + + Returns + ------- + kwargs : dict + Same as input but with exogenous_features removed from output + features """ + if 'output_features' not in kwargs: + return kwargs - keys = ('training_features', 'output_features', 'smoothed_features', - 's_enhance', 't_enhance', 'smoothing') + output_features = kwargs['output_features'] + msg = (f'Last {len(self.exogenous_features)} output features from the ' + f'data handler must be {self.exogenous_features} ' + 'to train the Exo model, but received output features: {}'. + format(output_features)) + exo_features = ([] if len(self.exogenous_features) == 0 + else output_features[-len(self.exogenous_features):]) + assert exo_features == self.exogenous_features, msg + for f in self.exogenous_features: + output_features.remove(f) + kwargs['output_features'] = output_features + return kwargs + + def set_model_params(self, **kwargs): + """Set parameters used for training the model + + Parameters + ---------- + kwargs : dict + Keyword arguments including 'input_resolution', + 'training_features', 'output_features', 'smoothed_features', + 's_enhance', 't_enhance', 'smoothing' + """ + kwargs = self._check_exo_features(**kwargs) + + keys = ('input_resolution', 'training_features', 'output_features', + 'smoothed_features', 's_enhance', 't_enhance', 'smoothing') keys = [k for k in keys if k in kwargs] for var in keys: - val = getattr(self, var, None) + val = self.meta.get(var, None) if val is None: self.meta[var] = kwargs[var] elif val != kwargs[var]: msg = ('Model was previously trained with {var}={} but ' - 'received new {var}={}' - .format(val, kwargs[var], var=var)) + 'received new {var}={}'.format(val, + kwargs[var], + var=var)) logger.warning(msg) warn(msg) + self._ensure_valid_enhancement_factors() + self._ensure_valid_input_resolution() + def save_params(self, out_dir): """ Parameters @@ -255,16 +529,14 @@ def load_network(self, model, name): self._meta[f'config_{name}'] = model if 'hidden_layers' in model: model = model['hidden_layers'] - elif ('meta' in model - and f'config_{name}' in model['meta'] + elif ('meta' in model and f'config_{name}' in model['meta'] and 'hidden_layers' in model['meta'][f'config_{name}']): model = model['meta'][f'config_{name}']['hidden_layers'] else: msg = ('Could not load model from json config, need ' '"hidden_layers" key or ' f'"meta/config_{name}/hidden_layers" ' - ' at top level but only found: {}' - .format(model.keys())) + ' at top level but only found: {}'.format(model.keys())) logger.error(msg) raise KeyError(msg) @@ -277,8 +549,8 @@ def load_network(self, model, name): if not isinstance(model, CustomNetwork): msg = ('Something went wrong. 
Tried to load a custom network ' - 'but ended up with a model of type "{}"' - .format(type(model))) + 'but ended up with a model of type "{}"'.format( + type(model))) logger.error(msg) raise TypeError(msg) @@ -308,23 +580,27 @@ def stdevs(self): def output_stdevs(self): """Get the data normalization standard deviation values for only the output features + Returns ------- np.ndarray """ - indices = [self.training_features.index(f) - for f in self.output_features] + indices = [ + self.training_features.index(f) for f in self.output_features + ] return self._stdevs[indices] @property def output_means(self): """Get the data normalization mean values for only the output features + Returns ------- np.ndarray """ - indices = [self.training_features.index(f) - for f in self.output_features] + indices = [ + self.training_features.index(f) for f in self.output_features + ] return self._means[indices] def set_norm_stats(self, new_means, new_stdevs): @@ -341,10 +617,10 @@ def set_norm_stats(self, new_means, new_stdevs): if self._means is not None: logger.info('Setting new normalization statistics...') - logger.info("Model's previous data mean values: {}" - .format(self._means)) - logger.info("Model's previous data stdev values: {}" - .format(self._stdevs)) + logger.info("Model's previous data mean values: {}".format( + self._means)) + logger.info("Model's previous data stdev values: {}".format( + self._stdevs)) self._means = new_means self._stdevs = new_stdevs @@ -354,10 +630,10 @@ def set_norm_stats(self, new_means, new_stdevs): if not isinstance(self._stdevs, np.ndarray): self._stdevs = np.array(self._stdevs) - logger.info('Set data normalization mean values: {}' - .format(self._means)) - logger.info('Set data normalization stdev values: {}' - .format(self._stdevs)) + logger.info('Set data normalization mean values: {}'.format( + self._means)) + logger.info('Set data normalization stdev values: {}'.format( + self._stdevs)) def norm_input(self, low_res): """Normalize low resolution data being input to the generator. @@ -384,7 +660,7 @@ def norm_input(self, low_res): if any(self._stdevs == 0): stdevs = np.where(self._stdevs == 0, 1, self._stdevs) - msg = ('Some standard deviations are zero.') + msg = 'Some standard deviations are zero.' logger.warning(msg) warn(msg) else: @@ -459,25 +735,6 @@ def generator_weights(self): """ return self.generator.weights - def _needs_lr_exo(self, low_res): - """Determine whether or not the sup3r model needs low-res exogenous - data - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - - Returns - ------- - needs_lr_exo : bool - True if the model requires low-resolution exogenous data. - """ - - return low_res.shape[-1] < len(self.training_features) - @staticmethod def init_optimizer(optimizer, learning_rate): """Initialize keras optimizer object. 
@@ -500,8 +757,10 @@ def init_optimizer(optimizer, learning_rate): class_name = optimizer['name'] OptimizerClass = getattr(optimizers, class_name) sig = signature(OptimizerClass) - optimizer_kwargs = {k: v for k, v in optimizer.items() - if k in sig.parameters} + optimizer_kwargs = { + k: v + for k, v in optimizer.items() if k in sig.parameters + } optimizer = OptimizerClass.from_config(optimizer_kwargs) elif optimizer is None: optimizer = optimizers.Adam(learning_rate=learning_rate) @@ -544,11 +803,32 @@ def load_saved_params(out_dir, verbose=True): if verbose: logger.info('Loading model from disk ' 'that was created with the ' - 'following package versions: \n{}' - .format(pprint.pformat(version_record, indent=2))) + 'following package versions: \n{}'.format( + pprint.pformat(version_record, indent=2))) return params + def get_high_res_exo_input(self, high_res): + """Get exogenous feature data from high_res + + Parameters + ---------- + high_res : tf.Tensor + Ground truth high resolution spatiotemporal data. + + Returns + ------- + exo_data : dict + Dictionary of exogenous feature data used as input to tf_generate. + e.g. {'topography': tf.Tensor(...)} + """ + exo_data = {} + for feature in self.exogenous_features: + f_idx = self.training_features.index(feature) + exo_fdata = high_res[..., f_idx: f_idx + 1] + exo_data[feature] = exo_fdata + return exo_data + @staticmethod def get_loss_fun(loss): """Get the initialized loss function class from the sup3r loss library @@ -575,8 +855,8 @@ def get_loss_fun(loss): if out is None: msg = ('Could not find requested loss function "{}" in ' - 'sup3r.utilities.loss_metrics or tf.keras.losses.' - .format(loss)) + 'sup3r.utilities.loss_metrics or tf.keras.losses.'.format( + loss)) logger.error(msg) raise KeyError(msg) @@ -619,7 +899,7 @@ def update_loss_details(loss_details, new_data, batch_len, prefix=None): Namespace of the breakdown of loss components for a single new batch. batch_len : int - Length of the incomming batch. + Length of the incoming batch. prefix : None | str Option to prefix the names of the loss data when saving to the loss_details dictionary. @@ -635,8 +915,8 @@ def update_loss_details(loss_details, new_data, batch_len, prefix=None): for key, new_value in new_data.items(): key = key if prefix is None else prefix + key - new_value = (new_value if not isinstance(new_value, tf.Tensor) - else new_value.numpy()) + new_value = (new_value if not isinstance(new_value, tf.Tensor) else + new_value.numpy()) if key in loss_details: saved_value = loss_details[key] @@ -710,8 +990,8 @@ def early_stop(history, column, threshold=0.005, n_epoch=5): stop = True logger.info('Found early stop condition, loss values "{}" ' 'have absolute relative differences less than ' - 'threshold {}: {}' - .format(column, threshold, diffs[-n_epoch:])) + 'threshold {}: {}'.format(column, threshold, + diffs[-n_epoch:])) return stop @@ -726,10 +1006,17 @@ def save(self, out_dir): if it does not already exist. 
""" - def finish_epoch(self, epoch, epochs, t0, loss_details, - checkpoint_int, out_dir, - early_stop_on, early_stop_threshold, - early_stop_n_epoch, extras=None): + def finish_epoch(self, + epoch, + epochs, + t0, + loss_details, + checkpoint_int, + out_dir, + early_stop_on, + early_stop_threshold, + early_stop_n_epoch, + extras=None): """Perform finishing checks after an epoch is done training Parameters @@ -790,7 +1077,8 @@ def finish_epoch(self, epoch, epochs, t0, loss_details, stop = False if early_stop_on is not None and early_stop_on in self._history: - stop = self.early_stop(self._history, early_stop_on, + stop = self.early_stop(self._history, + early_stop_on, threshold=early_stop_threshold, n_epoch=early_stop_n_epoch) if stop: @@ -802,57 +1090,12 @@ def finish_epoch(self, epoch, epochs, t0, loss_details, return stop - @tf.function() - def get_single_grad(self, low_res, hi_res_true, training_weights, - device_name=None, **calc_loss_kwargs): - """Run gradient descent for one mini-batch of (low_res, hi_res_true), - do not update weights, just return gradient details. - - Parameters - ---------- - low_res : np.ndarray - Real low-resolution data in a 4D or 5D array: - (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) - hi_res_true : np.ndarray - Real high-resolution data in a 4D or 5D array: - (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) - training_weights : list - A list of layer weights that are to-be-trained based on the - current loss weight values. - device_name : None | str - Optional tensorflow device name for GPU placement. Note that if a - GPU is available, variables will be placed on that GPU even if - device_name=None. 
- calc_loss_kwargs : dict - Kwargs to pass to the self.calc_loss() method - - Returns - ------- - grad : list - a list or nested structure of Tensors (or IndexedSlices, or None, - or CompositeTensor) representing the gradients for the - training_weights - loss_details : dict - Namespace of the breakdown of loss components - """ - - with tf.device(device_name): - with tf.GradientTape(watch_accessed_variables=False) as tape: - tape.watch(training_weights) - - hi_res_gen = self._tf_generate(low_res) - loss_out = self.calc_loss(hi_res_true, hi_res_gen, - **calc_loss_kwargs) - loss, loss_details = loss_out - - grad = tape.gradient(loss, training_weights) - - return grad, loss_details - - def run_gradient_descent(self, low_res, hi_res_true, training_weights, - optimizer=None, multi_gpu=False, + def run_gradient_descent(self, + low_res, + hi_res_true, + training_weights, + optimizer=None, + multi_gpu=False, **calc_loss_kwargs): # pylint: disable=E0602 """Run gradient descent for one mini-batch of (low_res, hi_res_true) @@ -909,6 +1152,7 @@ def run_gradient_descent(self, low_res, hi_res_true, training_weights, lr_chunks = np.array_split(low_res, len(self.gpu_list)) hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) split_mask = False + mask_chunks = None if 'mask' in calc_loss_kwargs: split_mask = True mask_chunks = np.array_split(calc_loss_kwargs['mask'], @@ -918,13 +1162,14 @@ def run_gradient_descent(self, low_res, hi_res_true, training_weights, for i in range(len(self.gpu_list)): if split_mask: calc_loss_kwargs['mask'] = mask_chunks[i] - futures.append(exe.submit(self.get_single_grad, - lr_chunks[i], - hr_true_chunks[i], - training_weights, - device_name=f'/gpu:{i}', - **calc_loss_kwargs)) - for i, future in enumerate(futures): + futures.append( + exe.submit(self.get_single_grad, + lr_chunks[i], + hr_true_chunks[i], + training_weights, + device_name=f'/gpu:{i}', + **calc_loss_kwargs)) + for _, future in enumerate(futures): grad, loss_details = future.result() optimizer.apply_gradients(zip(grad, training_weights)) @@ -934,45 +1179,7 @@ def run_gradient_descent(self, low_res, hi_res_true, training_weights, return loss_details - -# pylint: disable=E1101,W0201,E0203 -class AbstractWindInterface(ABC): - """ - Abstract class to define the required training interface - for Sup3r wind model subclasses - """ - - # pylint: disable=E0211 - @staticmethod - def set_model_params(**kwargs): - """Set parameters used for training the model - - Parameters - ---------- - kwargs : dict - Keyword arguments including 'training_features', 'output_features', - 'smoothed_features', 's_enhance', 't_enhance', 'smoothing'. For the - Wind classes, the last entry in "output_features" must be - "topography" - - Returns - ------- - kwargs : dict - Same as input but with topography removed from "output_features", - this is because topography is concatenated mid-network in the - WindGan generators and is not an output feature but is required in - the hi-res training set. 
- """ - output_features = kwargs['output_features'] - msg = ('Last output feature from the data handler must be topography ' - 'to train the WindCC model, but received output features: {}' - .format(output_features)) - assert output_features[-1] == 'topography', msg - output_features.remove('topography') - kwargs['output_features'] = output_features - return kwargs - - def _reshape_norm_topo(self, hi_res, hi_res_topo, norm_in=True): + def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True): """Reshape the hi_res_topo to match the hi_res tensor (if necessary) and normalize (if requested). @@ -983,16 +1190,18 @@ def _reshape_norm_topo(self, hi_res, hi_res_topo, norm_in=True): array with shape: (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) - hi_res_topo : np.ndarray + hi_res_exo : np.ndarray This should be a 4D array for spatial enhancement model or 5D array for a spatiotemporal enhancement model (obs, spatial_1, spatial_2, (temporal), features) corresponding to the high-resolution - spatial_1 and spatial_2. This data will be input to the custom - phygnn Sup3rAdder or Sup3rConcat layer if found in the generative - network. This differs from the exogenous_data input in that - exogenous_data always matches the low-res input. For this function, - hi_res_topo can also be a 2D array (spatial_1, spatial_2). Note - that this input gets normalized if norm_in=True. + spatial_1, spatial_2, temporal. This data will be input to the + custom phygnn Sup3rAdder or Sup3rConcat layer if found in the + generative network. This differs from the exogenous_data input in + that exogenous_data always matches the low-res input. For this + function, hi_res_exo can also be a 3D array (spatial_1, spatial_2, + 1). Note that this input gets normalized if norm_in=True. + exo_name : str + Name of feature corresponding to hi_res_exo data. norm_in : bool Flag to normalize low_res input data if the self._means, self._stdevs attributes are available. 
The generator should always @@ -1005,37 +1214,66 @@ def _reshape_norm_topo(self, hi_res, hi_res_topo, norm_in=True): Same as input but reshaped to match hi_res (if necessary) and normalized (if requested) """ - if hi_res_topo is None: - return hi_res_topo + if hi_res_exo is None: + return hi_res_exo if norm_in and self._means is not None: - idf = self.training_features.index('topography') - hi_res_topo = ((hi_res_topo.copy() - self._means[idf]) - / self._stdevs[idf]) - - if len(hi_res_topo.shape) > 2: - slicer = [0] * len(hi_res_topo.shape) - slicer[1] = slice(None) - slicer[2] = slice(None) - hi_res_topo = hi_res_topo[tuple(slicer)] - - if len(hi_res.shape) == 4: - hi_res_topo = np.expand_dims(hi_res_topo, axis=(0, 3)) - hi_res_topo = np.repeat(hi_res_topo, hi_res.shape[0], axis=0) - elif len(hi_res.shape) == 5: - hi_res_topo = np.expand_dims(hi_res_topo, axis=(0, 3, 4)) - hi_res_topo = np.repeat(hi_res_topo, hi_res.shape[0], axis=0) - hi_res_topo = np.repeat(hi_res_topo, hi_res.shape[3], axis=3) - - if len(hi_res_topo.shape) != len(hi_res.shape): - msg = ('hi_res and hi_res_topo arrays are not of the same rank: ' - '{} and {}'.format(hi_res.shape, hi_res_topo.shape)) + idf = self.training_features.index(exo_name) + hi_res_exo = ((hi_res_exo.copy() - self._means[idf]) + / self._stdevs[idf]) + + if len(hi_res_exo.shape) == 3: + hi_res_exo = np.expand_dims(hi_res_exo, axis=0) + hi_res_exo = np.repeat(hi_res_exo, hi_res.shape[0], axis=0) + if len(hi_res_exo.shape) == 4 and len(hi_res.shape) == 5: + hi_res_exo = np.expand_dims(hi_res_exo, axis=3) + hi_res_exo = np.repeat(hi_res_exo, hi_res.shape[3], axis=3) + + if len(hi_res_exo.shape) != len(hi_res.shape): + msg = ('hi_res and hi_res_exo arrays are not of the same rank: ' + '{} and {}'.format(hi_res.shape, hi_res_exo.shape)) logger.error(msg) raise RuntimeError(msg) - return hi_res_topo + return hi_res_exo + + def _get_layer_exo_input(self, layer_name, exogenous_data): + """Get the high-resolution exo data for the given layer name from the + full exogenous_data dictionary. - def generate(self, low_res, norm_in=True, un_norm_out=True, + Parameters + ---------- + layer_name : str + Name of Sup3rAdder or Sup3rConcat layer. This should match a + feature key in exogenous_data + exogenous_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. This doesn't have to include the 'model' key since + this data is for a single step model. e.g. + {'topography': {'steps': [ + {'combine_type': 'input', 'data': ..., 'resolution': ...}, + {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}} + + """ + msg = (f'layer.name = {layer_name} does not match any ' + 'features in exogenous_data ' + f'({list(exogenous_data)})') + assert layer_name in exogenous_data, msg + steps = exogenous_data[layer_name]['steps'] + combine_types = [step['combine_type'] for step in steps] + msg = ('Received exogenous_data without any combine_type ' + '= "layer" steps, for a model with an Adder/Concat ' + 'layer.') + assert 'layer' in combine_types, msg + idx = combine_types.index('layer') + hi_res_exo = steps[idx]['data'] + return hi_res_exo + + def generate(self, + low_res, + norm_in=True, + un_norm_out=True, exogenous_data=None): """Use the generator model to generate high res data from low res input. This is the public generate function. 
@@ -1054,12 +1292,14 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, un_norm_out : bool Flag to un-normalize synthetically generated output data to physical units - exogenous_data : ndarray | list | None - Exogenous data for topography inputs. The first entry in this list - (or only entry) is a low-resolution topography array that can be - concatenated to the low_res input array. The second entry is - high-resolution topography (either 2D or 4D/5D depending on if - spatial or spatiotemporal super res). + exogenous_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. This doesn't have to include the 'model' key since + this data is for a single step model. e.g. + {'topography': {'steps': [ + {'combine_type': 'input', 'data': ..., 'resolution': ...}, + {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}} Returns ------- @@ -1069,35 +1309,26 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ - low_res_topo = None - hi_res_topo = None - if isinstance(exogenous_data, np.ndarray): - low_res_topo = exogenous_data - elif isinstance(exogenous_data, (list, tuple)): - low_res_topo = exogenous_data[0] - if len(exogenous_data) > 1: - hi_res_topo = exogenous_data[1] - - exo_check = (low_res is None or not self._needs_lr_exo(low_res)) - low_res = (low_res if exo_check - else np.concatenate((low_res, low_res_topo), axis=-1)) - + low_res = self._combine_fwp_input(low_res, exogenous_data) if norm_in and self._means is not None: low_res = self.norm_input(low_res) hi_res = self.generator.layers[0](low_res) for i, layer in enumerate(self.generator.layers[1:]): try: - if (isinstance(layer, (Sup3rAdder, Sup3rConcat)) - and hi_res_topo is not None): - hi_res_topo = self._reshape_norm_topo(hi_res, hi_res_topo, - norm_in=norm_in) - hi_res = layer(hi_res, hi_res_topo) + if isinstance(layer, (Sup3rAdder, Sup3rConcat)): + hi_res_exo = self._get_layer_exo_input(layer.name, + exogenous_data) + hi_res_exo = self._reshape_norm_exo(hi_res, + hi_res_exo, + layer.name, + norm_in=norm_in) + hi_res = layer(hi_res, hi_res_exo) else: hi_res = layer(hi_res) except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, hi_res.shape)) + msg = ('Could not run layer #{} "{}" on tensor of shape {}'. + format(i + 1, layer, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) from e @@ -1106,25 +1337,30 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, if un_norm_out and self._means is not None: hi_res = self.un_norm_output(hi_res) + hi_res = self._combine_fwp_output(hi_res, exogenous_data) + return hi_res @tf.function - def _tf_generate(self, low_res, hi_res_topo): - """Use the generator model to generate high res data from los res input + def _tf_generate(self, low_res, hi_res_exo=None): + """Use the generator model to generate high res data from low res input Parameters ---------- low_res : np.ndarray Real low-resolution data. The generator should always received normalized data with mean=0 stdev=1. - hi_res_topo : np.ndarray - This should be a 4D array for spatial enhancement model or 5D array - for a spatiotemporal enhancement model (obs, spatial_1, spatial_2, - (temporal), features) corresponding to the high-resolution - spatial_1 and spatial_2. 
This data will be input to the custom - phygnn Sup3rAdder or Sup3rConcat layer if found in the generative - network. This differs from the exogenous_data input in that - exogenous_data always matches the low-res input. + hi_res_exo : dict + Dictionary of exogenous_data with same resolution as high_res data + e.g. {'topography': np.array} + The arrays in this dictionary should be a 4D array for spatial + enhancement model or 5D array for a spatiotemporal enhancement + model (obs, spatial_1, spatial_2, (temporal), features) + corresponding to the high-resolution spatial_1 and spatial_2. This + data will be input to the custom phygnn Sup3rAdder or Sup3rConcat + layer if found in the generative network. This differs from the + exogenous_data input in that exogenous_data always matches the + low-res input. Returns ------- @@ -1134,23 +1370,29 @@ def _tf_generate(self, low_res, hi_res_topo): hi_res = self.generator.layers[0](low_res) for i, layer in enumerate(self.generator.layers[1:]): try: - if (isinstance(layer, (Sup3rAdder, Sup3rConcat)) - and hi_res_topo is not None): - hi_res = layer(hi_res, hi_res_topo) - + if isinstance(layer, (Sup3rAdder, Sup3rConcat)): + msg = (f'layer.name = {layer.name} does not match any ' + f'features in exogenous_data ({list(hi_res_exo)})') + assert layer.name in hi_res_exo, msg + hr_exo = hi_res_exo[layer.name] + hi_res = layer(hi_res, hr_exo) else: hi_res = layer(hi_res) except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, hi_res.shape)) + msg = ('Could not run layer #{} "{}" on tensor of shape {}'. + format(i + 1, layer, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) from e return hi_res @tf.function() - def get_single_grad(self, low_res, hi_res_true, training_weights, - device_name=None, **calc_loss_kwargs): + def get_single_grad(self, + low_res, + hi_res_true, + training_weights, + device_name=None, + **calc_loss_kwargs): """Run gradient descent for one mini-batch of (low_res, hi_res_true), do not update weights, just return gradient details. @@ -1183,14 +1425,12 @@ def get_single_grad(self, low_res, hi_res_true, training_weights, loss_details : dict Namespace of the breakdown of loss components """ - - hi_res_topo = hi_res_true[..., -1:] - with tf.device(device_name): with tf.GradientTape(watch_accessed_variables=False) as tape: tape.watch(training_weights) - hi_res_gen = self._tf_generate(low_res, hi_res_topo) + hi_res_exo = self.get_high_res_exo_input(hi_res_true) + hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss_out = self.calc_loss(hi_res_true, hi_res_gen, **calc_loss_kwargs) loss, loss_details = loss_out diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 9e97b18ec3..0878ecac25 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -class Sup3rGan(AbstractInterface, AbstractSingleModel): +class Sup3rGan(AbstractSingleModel, AbstractInterface): """Basic sup3r GAN model.""" def __init__(self, @@ -179,91 +179,6 @@ def load(cls, model_dir, verbose=True): return cls(fp_gen, fp_disc, **params) - def generate(self, - low_res, - norm_in=True, - un_norm_out=True, - exogenous_data=None): - """Use the generator model to generate high res data from low res - input. This is the public generate function. 
- - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - norm_in : bool - Flag to normalize low_res input data if the self._means, - self._stdevs attributes are available. The generator should always - received normalized data with mean=0 stdev=1. - un_norm_out : bool - Flag to un-normalize synthetically generated output data to physical - units - exogenous_data : ndarray | None - Exogenous data array, usually a 4D or 5D array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - - Returns - ------- - hi_res : ndarray - Synthetically generated high-resolution data, usually a 4D or 5D - array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - exo_check = (exogenous_data is None or not self._needs_lr_exo(low_res)) - low_res = (low_res if exo_check else np.concatenate( - (low_res, exogenous_data), axis=-1)) - - if norm_in and self._means is not None: - low_res = self.norm_input(low_res) - - hi_res = self.generator.layers[0](low_res) - for i, layer in enumerate(self.generator.layers[1:]): - try: - hi_res = layer(hi_res) - except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}'. - format(i + 1, layer, hi_res.shape)) - logger.error(msg) - raise RuntimeError(msg) from e - - hi_res = hi_res.numpy() - - if un_norm_out and self._means is not None: - hi_res = self.un_norm_output(hi_res) - - return hi_res - - @tf.function - def _tf_generate(self, low_res): - """Use the generator model to generate high res data from los res input - - Parameters - ---------- - low_res : np.ndarray - Real low-resolution data. The generator should always - received normalized data with mean=0 stdev=1. - - Returns - ------- - hi_res : tf.Tensor - Synthetically generated high-resolution data - """ - hi_res = self.generator.layers[0](low_res) - for i, layer in enumerate(self.generator.layers[1:]): - try: - hi_res = layer(hi_res) - except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}'. - format(i + 1, layer, hi_res.shape)) - logger.error(msg) - raise RuntimeError(msg) from e - - return hi_res - @property def discriminator(self): """Get the discriminator model. @@ -462,8 +377,15 @@ def init_weights(self, lr_shape, hr_shape, device=None): logger.info('Initializing model weights on device "{}"'.format(device)) low_res = np.ones(lr_shape).astype(np.float32) hi_res = np.ones(hr_shape).astype(np.float32) + + hr_exo_shape = hr_shape[:-1] + (1,) + hr_exo = np.ones(hr_exo_shape).astype(np.float32) + with tf.device(device): - _ = self._tf_generate(low_res) + hr_exo_data = {} + for feature in self.exogenous_features: + hr_exo_data[feature] = hr_exo + _ = self._tf_generate(low_res, hr_exo_data) _ = self._tf_discriminate(hi_res) @staticmethod @@ -525,7 +447,7 @@ def calc_loss_gen_content(self, hi_res_true, hi_res_gen): 0D tensor generator model loss for the content loss comparing the hi res ground truth to the hi res synthetically generated output. 
""" - + hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen) loss_gen_content = self.loss_fun(hi_res_true, hi_res_gen) return loss_gen_content @@ -624,6 +546,7 @@ def calc_loss(self, loss_details : dict Namespace of the breakdown of loss components """ + hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen) if hi_res_gen.shape != hi_res_true.shape: msg = ('The tensor shapes of the synthetic output {} and ' @@ -639,7 +562,7 @@ def calc_loss(self, loss_gen_content = self.calc_loss_gen_content(hi_res_true, hi_res_gen) loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen) - loss_gen = (loss_gen_content + weight_gen_advers * loss_gen_advers) + loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen) @@ -679,19 +602,17 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - output_gen = self._tf_generate(val_batch.low_res) + val_exo_data = self.get_high_res_exo_input(val_batch.high_res) + high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) _, v_loss_details = self.calc_loss( - val_batch.high_res, - output_gen, + val_batch.high_res, high_res_gen, weight_gen_advers=weight_gen_advers, - train_gen=False, - train_disc=False) + train_gen=False, train_disc=False) loss_details = self.update_loss_details(loss_details, v_loss_details, len(val_batch), prefix='val_') - return loss_details def train_epoch(self, @@ -851,6 +772,7 @@ def update_adversarial_weights(self, history, adaptive_update_fraction, def train(self, batch_handler, + input_resolution, n_epoch, weight_gen_advers=0.001, train_gen=True, @@ -870,6 +792,9 @@ def train(self, ---------- batch_handler : sup3r.data_handling.preprocessing.BatchHandler BatchHandler object to iterate through + input_resolution : dict + Dictionary specifying spatiotemporal input resolution. e.g. + {'temporal': '60min', 'spatial': '30km'} n_epoch : int Number of epochs to train on weight_gen_advers : float @@ -925,6 +850,7 @@ def train(self, self.set_norm_stats(batch_handler.means, batch_handler.stds) self.set_model_params( + input_resolution=input_resolution, s_enhance=batch_handler.s_enhance, t_enhance=batch_handler.t_enhance, smoothing=batch_handler.smoothing, diff --git a/sup3r/models/conditional_moments.py b/sup3r/models/conditional_moments.py index 4a3a2fc460..f97872d694 100644 --- a/sup3r/models/conditional_moments.py +++ b/sup3r/models/conditional_moments.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- """Sup3r conditional moment model software""" +import logging import os +import pprint import time -import logging + import numpy as np -import pprint import pandas as pd import tensorflow as tf from tensorflow.keras import optimizers @@ -12,11 +13,10 @@ from sup3r.models.abstract import AbstractInterface, AbstractSingleModel from sup3r.utilities import VERSION_RECORD - logger = logging.getLogger(__name__) -class Sup3rCondMom(AbstractInterface, AbstractSingleModel): +class Sup3rCondMom(AbstractSingleModel, AbstractInterface): """Basic Sup3r conditional moments model.""" def __init__(self, gen_layers, @@ -142,89 +142,6 @@ def load(cls, model_dir, verbose=True): params = cls.load_saved_params(model_dir, verbose=verbose) return cls(fp_gen, **params) - def generate(self, low_res, norm_in=True, un_norm_out=True, - exogenous_data=None): - """Use the generator model to generate high res data from low res - input. 
This is the public generate function. - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - norm_in : bool - Flag to normalize low_res input data if the self._means, - self._stdevs attributes are available. The generator should always - received normalized data with mean=0 stdev=1. - un_norm_out : bool - Flag to un-normalize synthetically generated output data to physical - units - exogenous_data : ndarray | None - Exogenous data array, usually a 4D or 5D array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - - Returns - ------- - output : ndarray - Synthetically generated high-resolution data, usually a 4D or 5D - array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - exo_check = (exogenous_data is None or not self._needs_lr_exo(low_res)) - low_res = (low_res if exo_check - else np.concatenate((low_res, exogenous_data), axis=-1)) - - if norm_in and self._means is not None: - low_res = self.norm_input(low_res) - - output = self.generator.layers[0](low_res) - for i, layer in enumerate(self.generator.layers[1:]): - try: - output = layer(output) - except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, output.shape)) - logger.error(msg) - raise RuntimeError(msg) from e - - output = output.numpy() - - if un_norm_out and self._means is not None: - output = self.un_norm_output(output) - - return output - - @tf.function - def _tf_generate(self, low_res): - """Use the generator model to generate high res data from los res input - - Parameters - ---------- - low_res : np.ndarray - Real low-resolution data. The generator should always - received normalized data with mean=0 stdev=1. 
- - Returns - ------- - output : tf.Tensor - Synthetically generated high-resolution data - """ - - output = self.generator.layers[0](low_res) - for i, layer in enumerate(self.generator.layers[1:]): - try: - output = layer(output) - except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, output.shape)) - logger.error(msg) - raise RuntimeError(msg) from e - - return output - def update_optimizer(self, **kwargs): """Update optimizer by changing current configuration @@ -331,6 +248,7 @@ def calc_loss(self, output_true, output_gen, mask): loss_details : dict Namespace of the breakdown of loss components """ + output_gen = self._combine_loss_input(output_true, output_gen) if output_gen.shape != output_true.shape: msg = ('The tensor shapes of the synthetic output {} and ' @@ -365,7 +283,8 @@ def calc_val_loss(self, batch_handler, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - output_gen = self._tf_generate(val_batch.low_res) + val_exo_data = self.get_high_res_exo_input(val_batch.high_res) + output_gen = self._tf_generate(val_batch.low_res, val_exo_data) _, v_loss_details = self.calc_loss( val_batch.output, output_gen, val_batch.mask) @@ -420,7 +339,9 @@ def train_epoch(self, batch_handler, multi_gpu=False): return loss_details - def train(self, batch_handler, n_epoch, + def train(self, batch_handler, + input_resolution, + n_epoch, checkpoint_int=None, out_dir='./condMom_{epoch}', early_stop_on=None, @@ -433,6 +354,9 @@ def train(self, batch_handler, n_epoch, ---------- batch_handler : sup3r.data_handling.preprocessing.BatchHandler BatchHandler object to iterate through + input_resolution : dict + Dictionary specifying spatiotemporal input resolution. e.g. 
+ {'temporal': '60min', 'spatial': '30km'} n_epoch : int Number of epochs to train on checkpoint_int : int | None @@ -467,6 +391,7 @@ def train(self, batch_handler, n_epoch, """ self.set_norm_stats(batch_handler.means, batch_handler.stds) self.set_model_params( + input_resolution=input_resolution, s_enhance=batch_handler.s_enhance, t_enhance=batch_handler.t_enhance, smoothing=batch_handler.smoothing, diff --git a/sup3r/models/data_centric.py b/sup3r/models/data_centric.py index a91b957859..a53f3954dd 100644 --- a/sup3r/models/data_centric.py +++ b/sup3r/models/data_centric.py @@ -6,7 +6,6 @@ import numpy as np from sup3r.models.base import Sup3rGan -from sup3r.models.wind import WindGan from sup3r.utilities.utilities import round_array logger = logging.getLogger(__name__) @@ -39,7 +38,8 @@ def calc_val_loss_gen(self, batch_handler, weight_gen_advers): """ losses = [] for obs in batch_handler.val_data: - gen = self._tf_generate(obs.low_res) + exo_data = self.get_high_res_exo_input(obs.high_res) + gen = self._tf_generate(obs.low_res, exo_data) loss, _ = self.calc_loss(obs.high_res, gen, weight_gen_advers=weight_gen_advers, train_gen=True, train_disc=True) @@ -66,7 +66,8 @@ def calc_val_loss_gen_content(self, batch_handler): """ losses = [] for obs in batch_handler.val_data: - gen = self._tf_generate(obs.low_res) + exo_data = self.get_high_res_exo_input(obs.high_res) + gen = self._tf_generate(obs.low_res, exo_data) loss = self.calc_loss_gen_content(obs.high_res, gen) losses.append(float(loss)) return losses @@ -93,7 +94,6 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): Updated loss_details with mean validation loss calculated using the validation samples across the time bins """ - total_losses = self.calc_val_loss_gen(batch_handler, weight_gen_advers) content_losses = self.calc_val_loss_gen_content(batch_handler) @@ -138,71 +138,6 @@ def calc_temporal_losses(total_losses, content_losses, batch_handler): f'{round_array(new_temporal_weights)}') -class WindGanDC(WindGan, Sup3rGanDC): - """Data-centric model using loss across time bins to select training - observations with handling of low and high res topography - inputs.""" - - def calc_val_loss_gen(self, batch_handler, weight_gen_advers): - """Calculate the validation total loss across the validation - samples. e.g. If the sample domain has 100 steps and the - validation set has 10 bins then this will get a list of losses across - step 0 to 10, 10 to 20, etc. Use this to determine performance - within bins and to update how observations are selected from these - bins. Use the _tf_generate function from WindGan to include the high - resolution topography. - - Parameters - ---------- - batch_handler : sup3r.data_handling.preprocessing.BatchHandlerDC - BatchHandler object to iterate through - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. - - Returns - ------- - list - List of total losses for all sample bins - """ - losses = [] - for obs in batch_handler.val_data: - gen = self._tf_generate(obs.low_res, - obs.high_res[..., -1:]) - loss, _ = self.calc_loss(obs.high_res, gen, - weight_gen_advers=weight_gen_advers, - train_gen=True, train_disc=True) - losses.append(float(loss)) - return losses - - def calc_val_loss_gen_content(self, batch_handler): - """Calculate the validation content loss across the validation - samples. e.g. 
If the sample domain has 100 steps and the - validation set has 10 bins then this will get a list of losses across - step 0 to 10, 10 to 20, etc. Use this to determine performance - within bins and to update how observations are selected from these - bins. Use the _tf_generate function from WindGan to include high - resolution topography. - - Parameters - ---------- - batch_handler : sup3r.data_handling.preprocessing.BatchHandlerDC - BatchHandler object to iterate through - - Returns - ------- - list - List of content losses for all sample bins - """ - losses = [] - for obs in batch_handler.val_data: - gen = self._tf_generate(obs.low_res, - obs.high_res[..., -1:]) - loss = self.calc_loss_gen_content(obs.high_res, gen) - losses.append(float(loss)) - return losses - - class Sup3rGanSpatialDC(Sup3rGanDC): """Data-centric model using loss across time bins to select training observations""" diff --git a/sup3r/models/linear.py b/sup3r/models/linear.py index e049a3efd3..378b068c3c 100644 --- a/sup3r/models/linear.py +++ b/sup3r/models/linear.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- """Simple models for super resolution such as linear interp models.""" -import numpy as np +import json import logging -from inspect import signature import os -import json -from sup3r.utilities.utilities import st_interp +from inspect import signature + +import numpy as np + from sup3r.models.abstract import AbstractInterface +from sup3r.utilities.utilities import st_interp logger = logging.getLogger(__name__) @@ -59,7 +61,7 @@ class init args. """ fp_params = os.path.join(model_dir, 'model_params.json') assert os.path.exists(fp_params), f'Could not find: {fp_params}' - with open(fp_params, 'r') as f: + with open(fp_params) as f: params = json.load(f) meta = params['meta'] diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 43d94d848f..c76fd47dfa 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -1,17 +1,17 @@ # -*- coding: utf-8 -*- """Sup3r multi step model frameworks""" -import os +import copy import json import logging +import os + import numpy as np -from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat # pylint: disable=cyclic-import import sup3r.models from sup3r.models.abstract import AbstractInterface from sup3r.models.base import Sup3rGan - logger = logging.getLogger(__name__) @@ -33,25 +33,6 @@ def __len__(self): """Get number of model steps""" return len(self._models) - @staticmethod - def _needs_hr_exo(model): - """Determine whether or not the sup3r model needs hi-res exogenous data - - Parameters - ---------- - model : Sup3rGan | WindGan - Sup3r GAN model based on Sup3rGan with a .generator attribute - - Returns - ------- - needs_hr_exo : bool - True if the model requires high-resolution exogenous data, - typically because of the use of Sup3rAdder or Sup3rConcat layers. 
- """ - return (hasattr(model, 'generator') - and any(isinstance(layer, (Sup3rAdder, Sup3rConcat)) - for layer in model.generator.layers)) - @classmethod def load(cls, model_dirs, verbose=True): """Load the GANs with its sub-networks from a previously saved-to @@ -80,7 +61,7 @@ def load(cls, model_dirs, verbose=True): for model_dir in model_dirs: fp_params = os.path.join(model_dir, 'model_params.json') assert os.path.exists(fp_params), f'Could not find: {fp_params}' - with open(fp_params, 'r') as f: + with open(fp_params) as f: params = json.load(f) meta = params.get('meta', {'class': 'Sup3rGan'}) @@ -131,6 +112,82 @@ def seed(s=0): """ Sup3rGan.seed(s=s) + def _get_model_step_exo(self, model_step, exogenous_data=None): + """Get the exogenous data for the given model_step from the full + exogenous data dictionary + + Parameters + ---------- + model_step : int + Index of the model to get exogenous data for. + exogenous_data : dict + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. e.g. + {'topography': {'steps': [ + {'combine_type': 'input', 'model': 0, 'data': ..., + 'resolution': ...}, + {'combine_type': 'layer', 'model': 0, 'data': ..., + 'resolution': ...}]}} + Each array in in 'data' key has 3D or 4D shape: + (spatial_1, spatial_2, 1) + (spatial_1, spatial_2, n_temporal, 1) + + Returns + ------- + exogenous_data : dict + Same as input dictionary but with only entries with 'model': + model_step + """ + model_step_exo = None + if exogenous_data is not None: + model_step_exo = {} + for feature in exogenous_data: + steps = [step for step in exogenous_data[feature]['steps'] + if step['model'] == model_step] + if steps: + model_step_exo[feature] = {'steps': steps} + return model_step_exo + + def _transpose_model_input(self, model, hi_res): + """Transpose input data according to mdel input dimensions. 
+ + NOTE: If hi_res.shape == 4, it is assumed that the dimensions have the + ordering (n_obs, spatial_1, spatial_2, features) + + If hi_res.shape == 5, it is assumed that the dimensions have the + ordering (1, spatial_1, spatial_2, temporal, features) + + Parameters + ---------- + model : Sup3rGan + A single step model with the attribute model.input_dims + hi_res : ndarray + Synthetically generated high-resolution data, usually a 4D or 5D + array with shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + + Returns + ------- + hi_res : ndarray + Synthetically generated high-resolution data transposed according + to the number of model input dimensions + """ + if model.input_dims == 5 and len(hi_res.shape) == 4: + hi_res = np.transpose( + hi_res, axes=(1, 2, 0, 3))[np.newaxis] + elif model.input_dims == 4 and len(hi_res.shape) == 5: + msg = ('Recieved 5D input data with shape ' + f'({hi_res.shape}) to a 4D model.') + assert hi_res.shape[0] == 1, msg + hi_res = np.transpose(hi_res[0], axes=(2, 0, 1, 3)) + else: + msg = ('Recieved input data with shape ' + f'{hi_res.shape} to a {model.input_dims}D model.') + assert model.input_dims == len(hi_res.shape), msg + return hi_res + def generate(self, low_res, norm_in=True, un_norm_out=True, exogenous_data=None): """Use the generator model to generate high res data from low res @@ -149,16 +206,18 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, un_norm_out : bool Flag to un-normalize synthetically generated output data to physical units - exogenous_data : list - List of arrays of exogenous_data with length equal to the - number of model steps. e.g. If we want to include topography as - an exogenous feature in a spatial + temporal multistep model then - we need to provide a list of length=2 with topography at the low - spatial resolution and at the high resolution. If we include more - than one exogenous feature the ordering must be consistent. - Each array in the list has 3D or 4D shape: - (spatial_1, spatial_2, n_features) - (spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. e.g. 
+ {'topography': {'steps': [ + {'combine_type': 'input', 'model': 0, 'data': ..., + 'resolution': ...}, + {'combine_type': 'layer', 'model': 0, 'data': ..., + 'resolution': ...}]}} + Each array in in 'data' key has 3D or 4D shape: + (spatial_1, spatial_2, 1) + (spatial_1, spatial_2, n_temporal, 1) Returns ------- @@ -168,10 +227,6 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ - - exo_data = ([None] * len(self.models) if not exogenous_data - else exogenous_data) - hi_res = low_res.copy() for i, model in enumerate(self.models): @@ -181,11 +236,10 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, if (i + 1 == len(self.models) and not un_norm_out) else True) - i_exo_data = exo_data[i] - if self._needs_hr_exo(model): - i_exo_data = [exo_data[i], exo_data[i + 1]] + i_exo_data = self._get_model_step_exo(i, exogenous_data) try: + hi_res = self._transpose_model_input(model, hi_res) logger.debug('Data input to model #{} of {} has shape {}' .format(i + 1, len(self.models), hi_res.shape)) hi_res = model.generate(hi_res, norm_in=i_norm_in, @@ -320,273 +374,59 @@ def load(cls, spatial_model_dirs, temporal_model_dirs, verbose=True): return cls(s_models, t_models) - -class SpatialThenTemporalGan(SpatialThenTemporalBase): - """A two-step GAN where the first step is a spatial-only enhancement on a - 4D tensor and the second step is a (spatio)temporal enhancement on a 5D - tensor. - - NOTE: The low res input to the spatial enhancement should be a 4D tensor of - the shape (temporal, spatial_1, spatial_2, features) where temporal - (usually the observation index) is a series of sequential timesteps that - will be transposed to a 5D tensor of shape - (1, spatial_1, spatial_2, temporal, features) tensor and then fed to the - 2nd-step (spatio)temporal GAN. - """ - - @property - def models(self): - """Get an ordered tuple of the Sup3rGan models that are part of this - MultiStepGan - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.models - else: - spatial_models = [self.spatial_models] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.models - else: - temporal_models = [self.temporal_models] - - return (*spatial_models, *temporal_models) - - @property - def meta(self): - """Get a tuple of meta data dictionaries for all models - - Returns - ------- - tuple - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.meta - else: - spatial_models = [self.spatial_models.meta] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.meta - else: - temporal_models = [self.temporal_models.meta] - return (*spatial_models, *temporal_models) - - @property - def training_features(self): - """Get the list of input feature names that the first spatial - generative model in this SpatialThenTemporalGan model requires as - input.""" - return self.spatial_models.training_features - - @property - def output_features(self): - """Get the list of output feature names that the last spatiotemporal - interpolation model in this SpatialThenTemporalGan model outputs.""" - return self.temporal_models.output_features - - def generate(self, low_res, norm_in=True, un_norm_out=True, - exogenous_data=None): - """Use the generator model to generate high res data from low res - input. This is the public generate function. 
- - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data to the 1st step spatial GAN, which is a - 4D array of shape: (temporal, spatial_1, spatial_2, n_features) - norm_in : bool - Flag to normalize low_res input data if the self.means, - self.stdevs attributes are available. The generator should always - received normalized data with mean=0 stdev=1. - un_norm_out : bool - Flag to un-normalize synthetically generated output data to physical - units - exogenous_data : list - List of arrays of exogenous_data with length equal to the - number of model steps. e.g. If we want to include topography as - an exogenous feature in a spatial + temporal multistep model then - we need to provide a list of length=2 with topography at the low - spatial resolution and at the high resolution. If we include more - than one exogenous feature the ordering must be consistent. - Each array in the list has 3D or 4D shape: - (spatial_1, spatial_2, n_features) - (temporal, spatial_1, spatial_2, n_features) - - Returns - ------- - hi_res : ndarray - Synthetically generated high-resolution data output from the 2nd - step (spatio)temporal GAN with a 5D array shape: - (1, spatial_1, spatial_2, n_temporal, n_features) - """ - logger.debug('Data input to the 1st step spatial-only ' - 'enhancement has shape {}'.format(low_res.shape)) - t_exogenous = None - if exogenous_data is not None: - t_exogenous = exogenous_data[len(self.spatial_models):] - - try: - hi_res = self.spatial_models.generate( - low_res, norm_in=norm_in, un_norm_out=True, - exogenous_data=exogenous_data) - except Exception as e: - msg = ('Could not run the 1st step spatial-only GAN on input ' - 'shape {}'.format(low_res.shape)) - logger.exception(msg) - raise RuntimeError(msg) from e - - logger.debug('Data output from the 1st step spatial-only ' - 'enhancement has shape {}'.format(hi_res.shape)) - hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3)) - hi_res = np.expand_dims(hi_res, axis=0) - logger.debug('Data from the 1st step spatial-only enhancement has ' - 'been reshaped to {}'.format(hi_res.shape)) - - try: - hi_res = self.temporal_models.generate( - hi_res, norm_in=True, un_norm_out=un_norm_out, - exogenous_data=t_exogenous) - except Exception as e: - msg = ('Could not run the 2nd step (spatio)temporal GAN on input ' - 'shape {}'.format(low_res.shape)) - logger.exception(msg) - raise RuntimeError(msg) from e - - logger.debug('Final multistep GAN output has shape: {}' - .format(hi_res.shape)) - - return hi_res - - -class TemporalThenSpatialGan(SpatialThenTemporalBase): - """A two-step GAN where the first step is a spatiotemporal enhancement on a - 5D tensor and the second step is a spatial enhancement on a 4D tensor. 
- """ - - @property - def models(self): - """Get an ordered tuple of the Sup3rGan models that are part of this - MultiStepGan - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.models - else: - spatial_models = [self.spatial_models] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.models - else: - temporal_models = [self.temporal_models] - - return (*temporal_models, *spatial_models) - - @property - def meta(self): - """Get a tuple of meta data dictionaries for all models - - Returns - ------- - tuple - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.meta - else: - spatial_models = [self.spatial_models.meta] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.meta - else: - temporal_models = [self.temporal_models.meta] - - return (*temporal_models, *spatial_models) - - @property - def training_features(self): - """Get the list of input feature names that the first temporal - generative model in this TemporalThenSpatialGan model requires as - input.""" - return self.temporal_models.training_features - - @property - def output_features(self): - """Get the list of output feature names that the last spatial - interpolation model in this TemporalThenSpatialGan model outputs.""" - return self.spatial_models.output_features - - def generate(self, low_res, norm_in=True, un_norm_out=True, - exogenous_data=None): - """Use the generator model to generate high res data from low res - input. This is the public generate function. + def _split_exo_dict(self, split_step, exogenous_data=None): + """Split exogenous_data into two dicts based on split_step. The first + dict has only model steps less than split_step. The second dict has + only model steps greater than or equal to split_step. Parameters ---------- - low_res : np.ndarray - Low-resolution input data, a 5D array of shape: - (1, spatial_1, spatial_2, n_temporal, n_features) - norm_in : bool - Flag to normalize low_res input data if the self.means, - self.stdevs attributes are available. The generator should always - received normalized data with mean=0 stdev=1. - un_norm_out : bool - Flag to un-normalize synthetically generated output data to physical - units - exogenous_data : list - List of arrays of exogenous_data with length equal to the - number of model steps. e.g. If we want to include topography as - an exogenous feature in a temporal + spatial multistep model then - we need to provide a list of length=2 with topography at the low - spatial resolution and at the high resolution. If we include more - than one exogenous feature the ordering must be consistent. - Each array in the list has 3D or 4D shape: - (spatial_1, spatial_2, n_features) - (temporal, spatial_1, spatial_2, n_features) + split_step : int + Step index to use for splitting. If this is for a + SpatialThenTemporal model split_step should be len(spatial_models). + If this is for a TemporalThenSpatial model split_step should be + len(temporal_models). + exogenous_data : dict + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. e.g. 
+            {'topography': {'steps': [
+                {'combine_type': 'input', 'model': 0, 'data': ...,
+                 'resolution': ...},
+                {'combine_type': 'layer', 'model': 0, 'data': ...,
+                 'resolution': ...}]}}
+            Each array in the 'data' key has 3D or 4D shape:
+            (spatial_1, spatial_2, 1)
+            (spatial_1, spatial_2, n_temporal, 1)

         Returns
         -------
-        hi_res : ndarray
-            Synthetically generated high-resolution data output from the 2nd
-            step (spatio)temporal GAN with a 5D array shape:
-            (1, spatial_1, spatial_2, n_temporal, n_features)
-        """
-        logger.debug('Data input to the 1st step (spatio)temporal '
-                     'enhancement has shape {}'.format(low_res.shape))
-        s_exogenous = None
+        split_exo_1 : dict
+            Same as input dictionary but with only entries with 'model':
+            model_step where model_step is less than split_step
+        split_exo_2 : dict
+            Same as input dictionary but with only entries with 'model':
+            model_step where model_step is greater than or equal to split_step
+        """
+        split_exo_1 = {}
+        split_exo_2 = {}
         if exogenous_data is not None:
-            s_exogenous = exogenous_data[len(self.temporal_models):]
-
-        assert low_res.shape[0] == 1, 'Low res input can only have 1 obs!'
-
-        try:
-            hi_res = self.temporal_models.generate(
-                low_res, norm_in=norm_in, un_norm_out=True,
-                exogenous_data=exogenous_data)
-        except Exception as e:
-            msg = ('Could not run the 1st step (spatio)temporal GAN on input '
-                   'shape {}'.format(low_res.shape))
-            logger.exception(msg)
-            raise RuntimeError(msg) from e
-
-        logger.debug('Data output from the 1st step (spatio)temporal '
-                     'enhancement has shape {}'.format(hi_res.shape))
-        hi_res = np.transpose(hi_res[0], axes=(2, 0, 1, 3))
-        logger.debug('Data from the 1st step (spatio)temporal enhancement has '
-                     'been reshaped to {}'.format(hi_res.shape))
-
-        try:
-            hi_res = self.spatial_models.generate(
-                hi_res, norm_in=True, un_norm_out=un_norm_out,
-                exogenous_data=s_exogenous)
-        except Exception as e:
-            msg = ('Could not run the 2nd step spatial GAN on input '
-                   'shape {}'.format(low_res.shape))
-            logger.exception(msg)
-            raise RuntimeError(msg) from e
-
-        hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3))
-        hi_res = np.expand_dims(hi_res, axis=0)
-
-        logger.debug('Final multistep GAN output has shape: {}'
-                     .format(hi_res.shape))
-
-        return hi_res
-
-
-class MultiStepSurfaceMetGan(SpatialThenTemporalGan):
+            exo_data = copy.deepcopy(exogenous_data)
+            for feature in exo_data:
+                steps = [step for step in exo_data[feature]['steps']
+                         if step['model'] < split_step]
+                if steps:
+                    split_exo_1[feature] = {'steps': steps}
+                steps = [step for step in exo_data[feature]['steps']
+                         if step['model'] >= split_step]
+                for step in steps:
+                    step.update({'model': step['model'] - split_step})
+                if steps:
+                    split_exo_2[feature] = {'steps': steps}
+        return split_exo_1, split_exo_2
+
+
+class MultiStepSurfaceMetGan(MultiStepGan):
     """A two-step GAN where the first step is a spatial-only enhancement on a
     4D tensor of near-surface temperature and relative humidity data, and the
     second step is a (spatio)temporal enhancement on a 5D tensor.
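For reference, a minimal standalone sketch (not part of the patch) of how the
new nested exogenous_data dictionary is partitioned across model steps,
mirroring MultiStepGan._split_exo_dict above; the lr_topo/hr_topo arrays and
their shapes are hypothetical placeholders:

import copy

import numpy as np

# Hypothetical placeholder rasters; shapes are illustrative only.
lr_topo = np.zeros((26, 59, 1))
hr_topo = np.zeros((130, 295, 1))

exogenous_data = {'topography': {'steps': [
    {'model': 0, 'combine_type': 'input', 'data': lr_topo},
    {'model': 0, 'combine_type': 'layer', 'data': hr_topo},
    {'model': 1, 'combine_type': 'input', 'data': hr_topo}]}}


def split_exo_dict(split_step, exogenous_data):
    # Standalone mirror of MultiStepGan._split_exo_dict: steps before
    # split_step go to the first dict; later steps are re-indexed relative
    # to split_step and go to the second dict.
    exo_1, exo_2 = {}, {}
    exo_data = copy.deepcopy(exogenous_data)
    for feature, entry in exo_data.items():
        first = [s for s in entry['steps'] if s['model'] < split_step]
        second = [s for s in entry['steps'] if s['model'] >= split_step]
        for step in second:
            step['model'] -= split_step
        if first:
            exo_1[feature] = {'steps': first}
        if second:
            exo_2[feature] = {'steps': second}
    return exo_1, exo_2


s_exo, t_exo = split_exo_dict(1, exogenous_data)
# The spatial model (step 0) keeps its 'input' and 'layer' entries; the
# temporal model's entry is re-indexed from model 1 to model 0.
assert len(s_exo['topography']['steps']) == 2
assert t_exo['topography']['steps'][0]['model'] == 0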
@@ -623,12 +463,18 @@
         un_norm_out : bool
             Flag to un-normalize synthetically generated output data to
             physical units
-        exogenous_data : list
-            For the MultiStepSurfaceMetGan model, this must be a 2-entry list
-            where the first entry is a 2D (lat, lon) array of low-resolution
-            surface elevation data in meters (must match spatial_1, spatial_2
-            from low_res), and the second entry is a 2D (lat, lon) array of
-            high-resolution surface elevation data in meters.
+        exogenous_data : dict
+            For the MultiStepSurfaceMetGan, this must be a nested dictionary
+            with a main 'topography' key and two entries for
+            exogenous_data['topography']['steps']. The first entry includes a
+            2D (lat, lon) array of low-resolution surface elevation data in
+            meters (must match spatial_1, spatial_2 from low_res), and the
+            second entry includes a 2D (lat, lon) array of high-resolution
+            surface elevation data in meters. e.g.
+            {'topography': {
+                'steps': [
+                    {'model': 0, 'combine_type': 'input', 'data': lr_topo},
+                    {'model': 0, 'combine_type': 'output', 'data': hr_topo}]}}

         Returns
         -------
@@ -642,51 +488,13 @@
         logger.debug('Data input to the 1st step spatial-only '
                      'enhancement has shape {}'.format(low_res.shape))

-        msg = ('MultiStepSurfaceMetGan needs exogenous_data with two '
-               'entries for low and high res topography inputs.')
-        assert exogenous_data is not None, msg
-        assert isinstance(exogenous_data, (list, tuple)), msg
-        exogenous_data = [d for d in exogenous_data if d is not None]
-        assert len(exogenous_data) == 2, msg
-
-        # SurfaceSpatialMetModel needs a 2D array for exo topography input
-        for i, i_exo in enumerate(exogenous_data):
-            if len(i_exo.shape) == 3:
-                exogenous_data[i] = i_exo[:, :, 0]
-            elif len(i_exo.shape) == 4:
-                exogenous_data[i] = i_exo[0, :, :, 0]
-            elif len(i_exo.shape) == 5:
-                exogenous_data[i] = i_exo[0, :, :, 0, 0]
-
-        try:
-            hi_res = self.spatial_models.generate(
-                low_res, exogenous_data=exogenous_data)
-        except Exception as e:
-            msg = ('Could not run the 1st step spatial-only GAN on input '
-                   'shape {}'.format(low_res.shape))
-            logger.exception(msg)
-            raise RuntimeError(msg) from e
-
-        logger.debug('Data output from the 1st step spatial-only '
-                     'enhancement has shape {}'.format(hi_res.shape))
-        hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3))
-        hi_res = np.expand_dims(hi_res, axis=0)
-        logger.debug('Data from the 1st step spatial-only enhancement has '
-                     'been reshaped to {}'.format(hi_res.shape))
+        msg = ('MultiStepSurfaceMetGan needs exogenous_data with two '
+               'topography steps, for low and high res topography inputs.')
+        exo_check = (exogenous_data is not None
+                     and len(exogenous_data['topography']['steps']) == 2)
+        assert exo_check, msg

-        try:
-            hi_res = self.temporal_models.generate(
-                hi_res, norm_in=True, un_norm_out=un_norm_out)
-        except Exception as e:
-            msg = ('Could not run the 2nd step (spatio)temporal GAN on input '
-                   'shape {}'.format(low_res.shape))
-            logger.exception(msg)
-            raise RuntimeError(msg) from e
-
-        logger.debug('Final multistep GAN output has shape: {}'
-                     .format(hi_res.shape))
-
-        return hi_res
+        return super().generate(low_res, norm_in, un_norm_out, exogenous_data)

     @classmethod
     def load(cls, surface_model_class='SurfaceSpatialMetModel',
@@ -732,10 +540,12 @@
         t_models = TemporalModelClass.load(verbose=verbose,
                                            **temporal_model_kwargs)

-        return cls(s_models, t_models)
+        s_models = getattr(s_models, 'models', [s_models])
+        t_models = getattr(t_models, 'models', [t_models])
+        return cls([*s_models, *t_models])


-class SolarMultiStepGan(SpatialThenTemporalGan):
+class SolarMultiStepGan(SpatialThenTemporalBase):
     """Special multi step model for solar clearsky ratio super resolution.

     This model takes in two parallel models for wind-only and solar-only
@@ -751,11 +561,11 @@ def __init__(self, spatial_solar_models, spatial_wind_models,
         ----------
         spatial_solar_models : MultiStepGan
             A loaded MultiStepGan object representing the one or more spatial
-            super resolution steps in this composite SpatialThenTemporalGan
+            super resolution steps in this composite MultiStepGan
             model that inputs and outputs clearsky_ratio
         spatial_wind_models : MultiStepGan
             A loaded MultiStepGan object representing the one or more spatial
-            super resolution steps in this composite SpatialThenTemporalGan
+            super resolution steps in this composite MultiStepGan
             model that inputs and outputs wind u/v features and must include
             U_200m + V_200m as output features.
         temporal_solar_models : MultiStepGan
@@ -824,13 +634,13 @@ def preflight(self):

     @property
     def spatial_models(self):
-        """Alias for spatial_solar_models to preserve SpatialThenTemporalGan
+        """Alias for spatial_solar_models to preserve MultiStepGan
        interface."""
         return self.spatial_solar_models

     @property
     def temporal_models(self):
-        """Alias for temporal_solar_models to preserve SpatialThenTemporalGan
+        """Alias for temporal_solar_models to preserve MultiStepGan
        interface."""
         return self.temporal_solar_models

@@ -962,15 +772,14 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
         logger.debug('Data input to the SolarMultiStepGan has shape {} which '
                      'will be split up for solar- and wind-only features.'
                      .format(low_res.shape))
-        t_exogenous = None
-        if exogenous_data is not None:
-            t_exogenous = exogenous_data[len(self.spatial_wind_models):]
-
+        s_exo, t_exo = self._split_exo_dict(
+            split_step=len(self.spatial_models),
+            exogenous_data=exogenous_data)
         try:
             hi_res_wind = self.spatial_wind_models.generate(
                 low_res[..., self.idf_wind],
                 norm_in=norm_in, un_norm_out=True,
-                exogenous_data=exogenous_data)
+                exogenous_data=s_exo)
         except Exception as e:
             msg = ('Could not run the 1st step spatial-wind-only GAN on '
                    'input shape {}'.format(low_res.shape))
@@ -1006,7 +815,7 @@
         try:
             hi_res = self.temporal_solar_models.generate(
                 hi_res, norm_in=True, un_norm_out=un_norm_out,
-                exogenous_data=t_exogenous)
+                exogenous_data=t_exo)
         except Exception as e:
             msg = ('Could not run the 2nd step (spatio)temporal solar GAN on '
                    'input shape {}'.format(low_res.shape))
diff --git a/sup3r/models/surface.py b/sup3r/models/surface.py
index 0d9881b0cf..f0b93b8a5e 100644
--- a/sup3r/models/surface.py
+++ b/sup3r/models/surface.py
@@ -2,10 +2,11 @@
 """Special models for surface meteorological data."""
 import logging
 from fnmatch import fnmatch
+from warnings import warn
+
 import numpy as np
 from PIL import Image
 from sklearn import linear_model
-from warnings import warn

 from sup3r.models.linear import LinearInterp
 from sup3r.utilities.utilities import spatial_coarsening
@@ -106,6 +107,7 @@ def __init__(self, features, s_enhance, noise_adders=None,
         self._pres_div = pres_div or self.PRES_DIV
         self._pres_exp = pres_exp or self.PRES_EXP
         self._fix_bias = fix_bias
+        self._input_resolution = None
         self._interp_method = getattr(Image.Resampling, interp_method)

         if isinstance(self._noise_adders, (int, float)):
@@ -470,12 +472,15 @@
         un_norm_out : bool
             This doesn't do anything for this SurfaceSpatialMetModel, but is
             kept to keep the same interface as Sup3rGan
-        exogenous_data : list
-            For the SurfaceSpatialMetModel, this must be a 2-entry list where
-            the first entry is a 2D (lat, lon) array of low-resolution surface
-            elevation data in meters (must match spatial_1, spatial_2 from
-            low_res), and the second entry is a 2D (lat, lon) array of
-            high-resolution surface elevation data in meters.
+        exogenous_data : dict
+            For the SurfaceSpatialMetModel, this must be a nested dictionary
+            with a main 'topography' key and two entries for
+            exogenous_data['topography']['steps']. The first entry includes a
+            2D (lat, lon) array of low-resolution surface elevation data in
+            meters (must match spatial_1, spatial_2 from low_res), and the
+            second entry includes a 2D (lat, lon) array of high-resolution
+            surface elevation data in meters. e.g.
+            {'topography': {'steps': [{'data': lr_topo}, {'data': hr_topo}]}}

         Returns
         -------
@@ -485,16 +490,17 @@
             channel can include temperature_*m, relativehumidity_*m, and/or
             pressure_*m
         """
-
+        exo_data = [step['data']
+                    for step in exogenous_data['topography']['steps']]
         msg = ('exogenous_data is of a bad type {}!'
-               .format(type(exogenous_data)))
-        assert isinstance(exogenous_data, (list, tuple)), msg
+               .format(type(exo_data)))
+        assert isinstance(exo_data, (list, tuple)), msg
         msg = ('exogenous_data is of a bad length {}!'
-               .format(len(exogenous_data)))
-        assert len(exogenous_data) == 2, msg
+               .format(len(exo_data)))
+        assert len(exo_data) == 2, msg

-        topo_lr = exogenous_data[0]
-        topo_hr = exogenous_data[1]
+        topo_lr = exo_data[0]
+        topo_hr = exo_data[1]
         logger.debug('SurfaceSpatialMetModel received low/high res topo '
                      'shapes of {} and {}'
                      .format(topo_lr.shape, topo_hr.shape))
@@ -556,6 +562,7 @@
                 's_enhance': self._s_enhance,
                 't_enhance': 1,
                 'noise_adders': self._noise_adders,
+                'input_resolution': self._input_resolution,
                 'weight_for_delta_temp': self._w_delta_temp,
                 'weight_for_delta_topo': self._w_delta_topo,
                 'pressure_divisor': self._pres_div,
@@ -567,10 +574,10 @@
                 'class': self.__class__.__name__,
                 }

-    def train(self, true_hr_temp, true_hr_rh, true_hr_topo):
-        """This method trains the relative humidity linear model. The
-        temperature and surface lapse rate models are parameterizations taken
-        from the NSRDB and are not trained.
+    def train(self, true_hr_temp, true_hr_rh, true_hr_topo, input_resolution):
+        """Trains the relative humidity linear model. The temperature and
+        surface lapse rate models are parameterizations taken from the NSRDB
+        and are not trained.

         Parameters
         ----------
@@ -583,6 +590,9 @@
         true_hr_topo : np.ndarray
             High-resolution surface elevation data in meters with shape
             (lat, lon)
+        input_resolution : dict
+            Dictionary of spatial and temporal input resolution. e.g.
+            {'spatial': '20km', 'temporal': '60min'}

         Returns
         -------
@@ -593,7 +603,7 @@
             Weight for the delta-topography feature for the relative humidity
             linear regression model.
         """
-
+        self._input_resolution = input_resolution
         assert len(true_hr_temp.shape) == 3, 'Bad true_hr_temp shape'
         assert len(true_hr_rh.shape) == 3, 'Bad true_hr_rh shape'
         assert len(true_hr_topo.shape) == 2, 'Bad true_hr_topo shape'
diff --git a/sup3r/models/wind.py b/sup3r/models/wind.py
deleted file mode 100644
index 9aa0c2f9a7..0000000000
--- a/sup3r/models/wind.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# -*- coding: utf-8 -*-
-"""Wind super resolution GAN with handling of low and high res topography
-inputs."""
-import numpy as np
-import logging
-import tensorflow as tf
-
-from sup3r.models.base import Sup3rGan
-from sup3r.models.abstract import AbstractWindInterface
-
-
-logger = logging.getLogger(__name__)
-
-
-class WindGan(AbstractWindInterface, Sup3rGan):
-    """Wind super resolution GAN with handling of low and high res topography
-    inputs.
-
-    Modifications to standard Sup3rGan:
-        - Hi res topography is expected as the last feature channel in the true
-          data in the true batch observation. This topo channel is appended to
-          the generated output so the discriminator can look at the wind fields
-          compared to the associated hi res topo.
-        - If a custom Sup3rAdder or Sup3rConcat layer (from phygnn) is present
-          in the network, the hi-res topography will be added or concatenated
-          to the data at that point in the network during either training or
-          the forward pass.
-    """
-
-    def init_weights(self, lr_shape, hr_shape, device=None):
-        """Initialize the generator and discriminator weights with device
-        placement.
-
-        Parameters
-        ----------
-        lr_shape : tuple
-            Shape of one batch of low res input data for sup3r resolution. Note
-            that the batch size (axis=0) must be included, but the actual batch
-            size doesnt really matter.
-        hr_shape : tuple
-            Shape of one batch of high res input data for sup3r resolution.
- Note that the batch size (axis=0) must be included, but the actual - batch size doesnt really matter. - device : str | None - Option to place model weights on a device. If None, - self.default_device will be used. - """ - - if device is None: - device = self.default_device - - logger.info('Initializing model weights on device "{}"'.format(device)) - low_res = np.ones(lr_shape).astype(np.float32) - hi_res = np.ones(hr_shape).astype(np.float32) - - hr_topo_shape = hr_shape[:-1] + (1,) - hr_topo = np.ones(hr_topo_shape).astype(np.float32) - - with tf.device(device): - _ = self._tf_generate(low_res, hr_topo) - _ = self._tf_discriminate(hi_res) - - def set_model_params(self, **kwargs): - """Set parameters used for training the model - - Parameters - ---------- - kwargs : dict - Keyword arguments including 'training_features', 'output_features', - 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' - """ - AbstractWindInterface.set_model_params(**kwargs) - Sup3rGan.set_model_params(self, **kwargs) - - @tf.function - def calc_loss(self, hi_res_true, hi_res_gen, **kwargs): - """Calculate the GAN loss function using generated and true high - resolution data. - - Parameters - ---------- - hi_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. - hi_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. - kwargs : dict - Key word arguments for: - Sup3rGan.calc_loss(hi_res_true, hi_res_gen, **kwargs) - - Returns - ------- - loss : tf.Tensor - 0D tensor representing the loss value for the network being trained - (either generator or one of the discriminators) - loss_details : dict - Namespace of the breakdown of loss components - """ - - # append the true topography to the generated synthetic wind data - hi_res_gen = tf.concat((hi_res_gen, hi_res_true[..., -1:]), axis=-1) - - return super().calc_loss(hi_res_true, hi_res_gen, **kwargs) - - def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): - """Calculate the validation loss at the current state of model training - - Parameters - ---------- - batch_handler : sup3r.data_handling.preprocessing.BatchHandler - BatchHandler object to iterate through - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. 
- loss_details : dict - Namespace of the breakdown of loss components - - Returns - ------- - loss_details : dict - Same as input but now includes val_* loss info - """ - logger.debug('Starting end-of-epoch validation loss calculation...') - loss_details['n_obs'] = 0 - for val_batch in batch_handler.val_data: - high_res_gen = self._tf_generate(val_batch.low_res, - val_batch.high_res[..., -1:]) - _, v_loss_details = self.calc_loss( - val_batch.high_res, high_res_gen, - weight_gen_advers=weight_gen_advers, - train_gen=False, train_disc=False) - - loss_details = self.update_loss_details(loss_details, - v_loss_details, - len(val_batch), - prefix='val_') - return loss_details diff --git a/sup3r/models/wind_conditional_moments.py b/sup3r/models/wind_conditional_moments.py deleted file mode 100644 index cf3a3dc2cb..0000000000 --- a/sup3r/models/wind_conditional_moments.py +++ /dev/null @@ -1,98 +0,0 @@ -# -*- coding: utf-8 -*- -"""Wind conditional moment estimator with handling of low and -high res topography inputs.""" -import logging -import tensorflow as tf - -from sup3r.models.abstract import AbstractWindInterface -from sup3r.models.conditional_moments import Sup3rCondMom - - -logger = logging.getLogger(__name__) - - -class WindCondMom(AbstractWindInterface, Sup3rCondMom): - """Wind conditional moment estimator with handling of low and - high res topography inputs. - - Modifications to standard Sup3rCondMom: - - Hi res topography is expected as the last feature channel in the true - data in the true batch observation. - - If a custom Sup3rAdder or Sup3rConcat layer (from phygnn) is present - in the network, the hi-res topography will be added or concatenated - to the data at that point in the network during either training or - the forward pass. - """ - - def set_model_params(self, **kwargs): - """Set parameters used for training the model - - Parameters - ---------- - kwargs : dict - Keyword arguments including 'training_features', 'output_features', - 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' - """ - AbstractWindInterface.set_model_params(**kwargs) - Sup3rCondMom.set_model_params(self, **kwargs) - - @tf.function - def calc_loss(self, hi_res_true, hi_res_gen, mask, **kwargs): - """Calculate the loss function using generated and true high - resolution data. - - Parameters - ---------- - hi_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. - hi_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. 
- mask : tf.Tensor - Mask to apply - kwargs : dict - Key word arguments for: - Sup3rGan.calc_loss(hi_res_true, hi_res_gen, **kwargs) - - Returns - ------- - loss : tf.Tensor - 0D tensor representing the loss value for the network being trained - (either generator or one of the discriminators) - loss_details : dict - Namespace of the breakdown of loss components - """ - - # append the true topography to the generated synthetic wind data - hi_res_gen = tf.concat((hi_res_gen, hi_res_true[..., -1:]), axis=-1) - - return super().calc_loss(hi_res_true, hi_res_gen, mask, **kwargs) - - def calc_val_loss(self, batch_handler, loss_details): - """Calculate the validation loss at the current state of model training - - Parameters - ---------- - batch_handler : sup3r.data_handling.preprocessing.BatchHandler - BatchHandler object to iterate through - loss_details : dict - Namespace of the breakdown of loss components - - Returns - ------- - loss_details : dict - Same as input but now includes val_* loss info - """ - logger.debug('Starting end-of-epoch validation loss calculation...') - loss_details['n_obs'] = 0 - for val_batch in batch_handler.val_data: - high_res_gen = self._tf_generate(val_batch.low_res, - val_batch.high_res[..., -1:]) - _, v_loss_details = self.calc_loss( - val_batch.output, high_res_gen, val_batch.mask) - - loss_details = self.update_loss_details(loss_details, - v_loss_details, - len(val_batch), - prefix='val_') - return loss_details diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index bcf86fc7e7..1cd24aa965 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -11,6 +11,7 @@ from concurrent.futures import as_completed from datetime import datetime as dt from inspect import signature +from typing import ClassVar import numpy as np import psutil @@ -19,16 +20,21 @@ import sup3r.bias.bias_transforms import sup3r.models -from sup3r.postprocessing.file_handling import (OutputHandler, OutputHandlerH5, - OutputHandlerNC) +from sup3r.postprocessing.file_handling import ( + OutputHandler, + OutputHandlerH5, + OutputHandlerNC, +) from sup3r.preprocessing.data_handling import ExogenousDataHandler from sup3r.preprocessing.data_handling.base import InputMixIn from sup3r.utilities import ModuleName from sup3r.utilities.cli import BaseCLI from sup3r.utilities.execution import DistributedProcess -from sup3r.utilities.utilities import (get_chunk_slices, - get_input_handler_class, - get_source_type) +from sup3r.utilities.utilities import ( + get_chunk_slices, + get_input_handler_class, + get_source_type, +) np.random.seed(42) @@ -38,17 +44,15 @@ class ForwardPassSlicer: """Get slices for sending data chunks through model.""" - def __init__( - self, - coarse_shape, - time_steps, - temporal_slice, - chunk_shape, - s_enhancements, - t_enhancements, - spatial_pad, - temporal_pad, - ): + def __init__(self, + coarse_shape, + time_steps, + temporal_slice, + chunk_shape, + s_enhancements, + t_enhancements, + spatial_pad, + temporal_pad): """ Parameters ---------- @@ -88,12 +92,6 @@ def __init__( passes for subsequent temporal stitching. This overlap will pad both sides of the fwp_chunk_shape. Note that the first and last chunks in the temporal dimension will not be padded. - exo_s_enhancements : list - List of spatial enhancement steps specific to the exogenous_data - inputs. 
This differs from s_enhancements in that s_enhancements[0]
-            will be the spatial enhancement of the first model, but
-            exo_s_enhancements[0] may be 1 to signify exo data is required for
-            the first non-enhanced spatial input resolution.
-        """
         self.grid_shape = coarse_shape
         self.time_steps = time_steps
@@ -316,17 +314,18 @@ def s_lr_crop_slices(self):
         if self._s_lr_crop_slices is None:
             self._s_lr_crop_slices = []
             s1_crop_slices = self.get_cropped_slices(self.s1_lr_slices,
-                                                     self.s1_lr_pad_slices, 1)
+                                                     self.s1_lr_pad_slices,
+                                                     1)
             s2_crop_slices = self.get_cropped_slices(self.s2_lr_slices,
-                                                     self.s2_lr_pad_slices, 1)
+                                                     self.s2_lr_pad_slices,
+                                                     1)
             for i, _ in enumerate(self.s1_lr_slices):
                 for j, _ in enumerate(self.s2_lr_slices):
-                    lr_crop_slice = (
-                        s1_crop_slices[i],
-                        s2_crop_slices[j],
-                        slice(None),
-                        slice(None),
-                    )
+                    lr_crop_slice = (s1_crop_slices[i],
+                                     s2_crop_slices[j],
+                                     slice(None),
+                                     slice(None),
+                                     )
                     self._s_lr_crop_slices.append(lr_crop_slice)
         return self._s_lr_crop_slices
@@ -594,24 +593,22 @@ class ForwardPassStrategy(InputMixIn, DistributedProcess):
     crop generator output to stitch the chunks back together.
     """

-    def __init__(
-        self,
-        file_paths,
-        model_kwargs,
-        fwp_chunk_shape,
-        spatial_pad,
-        temporal_pad,
-        model_class='Sup3rGan',
-        out_pattern=None,
-        input_handler=None,
-        input_handler_kwargs=None,
-        incremental=True,
-        worker_kwargs=None,
-        exo_kwargs=None,
-        bias_correct_method=None,
-        bias_correct_kwargs=None,
-        max_nodes=None,
-    ):
+    def __init__(self,
+                 file_paths,
+                 model_kwargs,
+                 fwp_chunk_shape,
+                 spatial_pad,
+                 temporal_pad,
+                 model_class='Sup3rGan',
+                 out_pattern=None,
+                 input_handler=None,
+                 input_handler_kwargs=None,
+                 incremental=True,
+                 worker_kwargs=None,
+                 exo_kwargs=None,
+                 bias_correct_method=None,
+                 bias_correct_kwargs=None,
+                 max_nodes=None):
         """Use these inputs to initialize data handlers on different nodes and
         to define the size of the data chunks that will be passed through the
         generator.
@@ -691,9 +688,15 @@
            used to get the full time index. Doing this in parallel can be
            helpful when there are a large number of input files.
         exo_kwargs : dict | None
-            Dictionary of args to pass to ExogenousDataHandler for extracting
-            exogenous features such as topography for future multistep foward
-            pass
+            Dictionary of args to pass to :class:`ExogenousDataHandler` for
+            extracting exogenous features for multistep forward pass. This
+            should be a nested dictionary with keys for each exogenous
+            feature. The dictionaries corresponding to the feature names
+            should include the path to exogenous data source, the resolution
+            of the exogenous data, and how the exogenous data should be used
+            in the model. e.g. {'topography': {'file_paths': 'path to input
+            files', 'source_file': 'path to exo data', 'exo_resolution':
+            {'spatial': '1km', 'temporal': None}, 'steps': [..]}}.
         bias_correct_method : str | None
            Optional bias correction function name that can be imported from
            the :mod:`sup3r.bias.bias_transforms` module.
This will transform @@ -718,14 +721,12 @@ def __init__( raster_index = self._input_handler_kwargs.get('raster_index', None) temporal_slice = self._input_handler_kwargs.get( 'temporal_slice', slice(None, None, 1)) - InputMixIn.__init__( - self, - target=target, - shape=grid_shape, - raster_file=raster_file, - raster_index=raster_index, - temporal_slice=temporal_slice, - ) + InputMixIn.__init__(self, + target=target, + shape=grid_shape, + raster_file=raster_file, + raster_index=raster_index, + temporal_slice=temporal_slice) self.file_paths = file_paths self.model_kwargs = model_kwargs @@ -777,23 +778,19 @@ def __init__( self.t_enhance = np.product(self.t_enhancements) self.output_features = model.output_features - self.fwp_slicer = ForwardPassSlicer( - self.grid_shape, - self.raw_tsteps, - self.temporal_slice, - self.fwp_chunk_shape, - self.s_enhancements, - self.t_enhancements, - self.spatial_pad, - self.temporal_pad, - ) + self.fwp_slicer = ForwardPassSlicer(self.grid_shape, + self.raw_tsteps, + self.temporal_slice, + self.fwp_chunk_shape, + self.s_enhancements, + self.t_enhancements, + self.spatial_pad, + self.temporal_pad) - DistributedProcess.__init__( - self, - max_nodes=max_nodes, - max_chunks=self.fwp_slicer.n_chunks, - incremental=self.incremental, - ) + DistributedProcess.__init__(self, + max_nodes=max_nodes, + max_chunks=self.fwp_slicer.n_chunks, + incremental=self.incremental) self.preflight() @@ -821,10 +818,9 @@ def preflight(self): logger.warning(msg) warnings.warn(msg) - hr_data_shape = ( - self.grid_shape[0] * self.s_enhance, - self.grid_shape[1] * self.s_enhance, - ) + hr_data_shape = (self.grid_shape[0] * self.s_enhance, + self.grid_shape[1] * self.s_enhance, + ) self.gids = np.arange(np.product(hr_data_shape)) self.gids = self.gids.reshape(hr_data_shape) @@ -845,13 +841,10 @@ def init_handler(self): """Get initial input handler used for extracting handler features and low res grid""" if self._init_handler is None: - out = self.input_handler_class( - self.file_paths[0], - [], - target=self.target, - shape=self.grid_shape, - worker_kwargs=dict(ti_workers=1), - ) + out = self.input_handler_class(self.file_paths[0], [], + target=self.target, + shape=self.grid_shape, + worker_kwargs=dict(ti_workers=1)) self._init_handler = out return self._init_handler @@ -1043,6 +1036,9 @@ class ForwardPass: through the GAN generator to produce high resolution output. """ + OUTPUT_HANDLER_CLASS: ClassVar = {'nc': OutputHandlerNC, + 'h5': OutputHandlerH5} + def __init__(self, strategy, chunk_index=0, node_index=0): """Initialize ForwardPass with ForwardPassStrategy. 
The strategy provides the data chunks to run forward passes on
@@ -1093,30 +1089,14 @@
         self.pass_workers = strategy.pass_workers
         self.output_workers = strategy.output_workers
         self.exo_kwargs = strategy.exo_kwargs
-
-        self.exogenous_handler = None
-        self.exogenous_data = None
-        if self.exo_kwargs:
-            exo_features = self.exo_kwargs.get('features', [])
-            exo_kwargs = copy.deepcopy(self.exo_kwargs)
-            exo_kwargs['target'] = self.target
-            exo_kwargs['shape'] = self.shape
-            self.features = [f for f in self.features if f not in exo_features]
-            self.exogenous_handler = ExogenousDataHandler(**exo_kwargs)
-            self.exogenous_data = self.exogenous_handler.data
-            shapes = [
-                None if d is None else d.shape for d in self.exogenous_data
-            ]
-            logger.info(
-                'Got exogenous_data of length {} with shapes: {}'.format(
-                    len(self.exogenous_data), shapes))
-
+        self.exo_features = ([]
+                             if not self.exo_kwargs else list(self.exo_kwargs))
+        self.exogenous_data = self.load_exo_data()
         self.input_handler_class = strategy.input_handler_class
-
-        if strategy.output_type == 'nc':
-            self.output_handler_class = OutputHandlerNC
-        elif strategy.output_type == 'h5':
-            self.output_handler_class = OutputHandlerH5
+        msg = f'Received bad output type {strategy.output_type}'
+        if strategy.output_type in list(self.OUTPUT_HANDLER_CLASS):
+            self.output_handler_class = self.OUTPUT_HANDLER_CLASS[
+                strategy.output_type]

         input_handler_kwargs = self.update_input_handler_kwargs(strategy)
@@ -1128,13 +1108,48 @@
         self.input_data = self.bias_correct_source_data(
             self.input_data, self.strategy.lr_lat_lon)

-        exo_s_en = self.exo_kwargs.get('s_enhancements', None)
-        out = self.pad_source_data(self.input_data, self.pad_width,
-                                   self.exogenous_data, exo_s_en)
+        out = self.pad_source_data(self.input_data,
+                                   self.pad_width,
+                                   self.exogenous_data)
         self.input_data, self.exogenous_data = out
         self.unpadded_input_data = self.data_handler.data[self.lr_slice[0],
                                                           self.lr_slice[1]]

+    def load_exo_data(self):
+        """Extract exogenous data for each exo feature and store data in
+        dictionary with key for each exo feature
+
+        Returns
+        -------
+        exo_data : dict
+            Same as exo_kwargs dictionary with data arrays added to a 'data'
+            key for each feature
+        """
+        exo_data = None
+        if self.exo_kwargs:
+            exo_data = self.exo_kwargs.copy()
+            self.features = [f for f in self.features
+                             if f not in self.exo_features]
+            for feature in self.exo_features:
+                exo_kwargs = copy.deepcopy(self.exo_kwargs[feature])
+                exo_kwargs['feature'] = feature
+                exo_kwargs['target'] = self.target
+                exo_kwargs['shape'] = self.shape
+                exo_kwargs['temporal_slice'] = self.ti_pad_slice
+                exo_kwargs['models'] = getattr(self.model, 'models',
+                                               [self.model])
+                sig = signature(ExogenousDataHandler)
+                exo_kwargs = {k: v for k, v in exo_kwargs.items()
+                              if k in sig.parameters}
+                data = ExogenousDataHandler(**exo_kwargs).data
+                for i, _ in enumerate(exo_kwargs['steps']):
+                    exo_data[feature]['steps'][i]['data'] = data[i]
+                shapes = [None if d is None else d.shape for d in data]
+                logger.info(
+                    'Got exogenous_data of length {} with shapes: {}'.format(
+                        len(data), shapes))
+        return exo_data
+
     def update_input_handler_kwargs(self, strategy):
         """Update the kwargs for the input handler for the current forward
         pass chunk
@@ -1420,12 +1435,41 @@ def pad_width(self):
         return ((pad_s1_start, pad_s1_end), (pad_s2_start, pad_s2_end),
                 (pad_t_start, pad_t_end))

-    @staticmethod
-    def
pad_source_data(input_data, - pad_width, - exo_data, - exo_s_enhancements, - mode='reflect'): + def _get_step_enhance(self, step): + """Get enhancement factors for a given step and combine type. + + Parameters + ---------- + step : dict + Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'} + + Returns + ------- + s_enhance : int + Spatial enhancement factor for given step and combine type + t_enhance : int + Temporal enhancement factor for given step and combine type + """ + combine_type = step['combine_type'] + model_step = step['model'] + if combine_type.lower() == 'input': + if model_step == 0: + s_enhance = 1 + t_enhance = 1 + else: + s_enhance = np.product( + self.strategy.s_enhancements[:model_step]) + t_enhance = np.product( + self.strategy.t_enhancements[:model_step]) + + elif combine_type.lower() in ('output', 'layer'): + s_enhance = np.product( + self.strategy.s_enhancements[:model_step + 1]) + t_enhance = np.product( + self.strategy.t_enhancements[:model_step + 1]) + return s_enhance, t_enhance + + def pad_source_data(self, input_data, pad_width, exo_data, mode='reflect'): """Pad the edges of the source data from the data handler. Parameters @@ -1433,56 +1477,49 @@ def pad_source_data(input_data, input_data : np.ndarray Source input data from data handler class, shape is: (spatial_1, spatial_2, temporal, features) - spatial_pad : int - Size of spatial overlap between coarse chunks passed to forward - passes for subsequent spatial stitching. This overlap will pad both - sides of the fwp_chunk_shape. Note that the first and last chunks - in any of the spatial dimension will not be padded. pad_width : tuple Tuple of tuples with padding width for spatial and temporal dimensions. Each tuple includes the start and end of padding for that dimension. Ordering is spatial_1, spatial_2, temporal. - exo_data : None | list - List of exogenous data arrays for each step of the sup3r resolution - model. List entries can be None if not exo data is requested for a - given model step. - exo_s_enhancements : list - List of spatial enhancement factors for each step of the sup3r - resolution model corresponding to the exo_data order. + exo_data: dict + Full exo_kwargs dictionary with all feature entries. + e.g. {'topography': {'exo_resolution': {'spatial': '1km', + 'temporal': None}, 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}]}} mode : str - Padding mode for np.pad(). Reflect is a good default for the - convolutional sup3r work. + Mode to use for padding. e.g. 'reflect'. Returns ------- out : np.ndarray Padded copy of source input data from data handler class, shape is: (spatial_1, spatial_2, temporal, features) - exo_data : list | None - Padded copy of exo_data input. 
+ exo_data : dict + Same as input dictionary with s_agg_factor, t_agg_factor, + s_enhance, t_enhance added to each step entry for all features + """ out = np.pad(input_data, (*pad_width, (0, 0)), mode=mode) logger.info('Padded input data shape from {} to {} using mode "{}" ' 'with padding argument: {}'.format(input_data.shape, - out.shape, mode, + out.shape, + mode, pad_width)) if exo_data is not None: - for i, i_exo_data in enumerate(exo_data): - if i_exo_data is not None: - total_s_enhance = exo_s_enhancements[:i + 1] - total_s_enhance = [ - s for s in total_s_enhance if s is not None - ] - total_s_enhance = np.product(total_s_enhance) - exo_pad_width = ((total_s_enhance * pad_width[0][0], - total_s_enhance * pad_width[0][1]), - (total_s_enhance * pad_width[1][0], - total_s_enhance * pad_width[1][1]), (0, - 0)) - exo_data[i] = np.pad(i_exo_data, exo_pad_width, mode=mode) - + for feature in exo_data: + for i, step in enumerate(exo_data[feature]['steps']): + s_enhance, t_enhance = self._get_step_enhance(step) + exo_pad_width = ((s_enhance * pad_width[0][0], + s_enhance * pad_width[0][1]), + (s_enhance * pad_width[1][0], + s_enhance * pad_width[1][1]), + (t_enhance * pad_width[2][0], + t_enhance * pad_width[2][1]), + (0, 0)) + new_exo = np.pad(step['data'], exo_pad_width, mode=mode) + exo_data[feature]['steps'][i]['data'] = new_exo return out, exo_data def bias_correct_source_data(self, data, lat_lon): @@ -1522,62 +1559,22 @@ def bias_correct_source_data(self, data, lat_lon): 'using function: {} with kwargs: {}'.format( feature, idf, method, feature_kwargs)) - data[..., idf] = method(data[..., idf], lat_lon, + data[..., idf] = method(data[..., idf], + lat_lon, **feature_kwargs) return data - def _prep_exogenous_input(self, chunk_shape): - """Shape exogenous data according to model type and model steps - - Parameters - ---------- - chunk_shape : tuple - Shape of data chunk going through forward pass - - Returns - ------- - exo_data : list - List of arrays of exogenous data. If there are 2 spatial - enhancement steps this is a list of 3 arrays each with the - appropriate shape based on the enhancement factor - """ - exo_data = [] - if self.exogenous_data is not None: - for arr in self.exogenous_data: - if arr is not None: - og_shape = arr.shape - arr = np.expand_dims(arr, axis=2) - arr = np.repeat(arr, chunk_shape[2], axis=2) - - target_shape = ( - arr.shape[0], - arr.shape[1], - chunk_shape[2], - arr.shape[-1], - ) - msg = ('Target shape for exogenous data in forward pass ' - 'chunk was {}, but something went wrong and i ' - 'resized original data shape from {} to {}'.format( - target_shape, og_shape, arr.shape)) - assert arr.shape == target_shape, msg - - exo_data.append(arr) - - return exo_data - @classmethod - def _run_generator( - cls, - data_chunk, - hr_crop_slices, - model=None, - model_kwargs=None, - model_class=None, - s_enhance=None, - t_enhance=None, - exo_data=None, - ): + def _run_generator(cls, + data_chunk, + hr_crop_slices, + model=None, + model_kwargs=None, + model_class=None, + s_enhance=None, + t_enhance=None, + exo_data=None): """Run forward pass of the generator on smallest data chunk. Each chunk has a maximum shape given by self.strategy.fwp_chunk_shape. @@ -1611,10 +1608,15 @@ def _run_generator( Factor by which to enhance temporal resolution s_enhance : int Factor by which to enhance spatial resolution - exo_data : list | None - List of arrays of exogenous data for each model step. 
- If there are two spatial enhancement steps this is a list of length - 3 with arrays for each intermediate spatial resolution. + exo_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. e.g. + {'topography': {'steps': [ + {'combine_type': 'input', 'model': 0, 'data': ..., + 'resolution': ...}, + {'combine_type': 'layer', 'model': 0, 'data': ..., + 'resolution': ...}]}} Returns ------- @@ -1672,9 +1674,15 @@ def _reshape_data_chunk(model, data_chunk, exo_data): data_chunk : np.ndarray Low resolution data for a single spatiotemporal chunk that is going to be passed to the model generate function. - exo_data : list | None - Optional exogenous data which can be a list of arrays of exogenous - inputs to complement data_chunk + exo_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. e.g. + {'topography': {'steps': [ + {'combine_type': 'input', 'model': 0, 'data': ..., + 'resolution': ...}, + {'combine_type': 'layer', 'model': 0, 'data': ..., + 'resolution': ...}]}} Returns ------- @@ -1683,27 +1691,27 @@ def _reshape_data_chunk(model, data_chunk, exo_data): features) if the model is a spatial-first model or (n_obs, spatial_1, spatial_2, temporal, features) if the model is spatiotemporal - exo_data : list | None - Same reshaping procedure as for data_chunk + exo_data : dict | None + Same reshaping procedure as for data_chunk applied to + exo_data[feature]['steps'][...]['data'] i_lr_t : int Axis index for the low-resolution temporal dimension i_lr_s : int Axis index for the low-resolution spatial_1 dimension """ - current_model = None if exo_data is not None: - for i, arr in enumerate(exo_data): - if arr is not None: - if not hasattr(model, 'models'): - current_model = model - elif i < len(model.models): - current_model = model.models[i] - - if current_model is not None: - if current_model.input_dims == 4: - exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3)) - else: - exo_data[i] = np.expand_dims(arr, axis=0) + for feature in exo_data: + for i, entry in enumerate(exo_data[feature]['steps']): + models = getattr(model, 'models', [model]) + msg = (f'model index ({entry["model"]}) for exo step {i} ' + 'exceeds the number of model steps') + assert entry['model'] < len(models), msg + current_model = models[entry['model']] + if current_model.input_dims == 4: + out = np.transpose(entry['data'], axes=(2, 0, 1, 3)) + else: + out = np.expand_dims(entry['data'], axis=0) + exo_data[feature]['steps'][i]['data'] = out if model.input_dims == 4: i_lr_t = 0 @@ -1857,11 +1865,10 @@ def _run_serial(cls, strategy, node_index): 'serial.') for i, chunk_index in enumerate(strategy.node_chunks[node_index]): now = dt.now() - cls._single_proc_run( - strategy=strategy, - node_index=node_index, - chunk_index=chunk_index, - ) + cls._single_proc_run(strategy=strategy, + node_index=node_index, + chunk_index=chunk_index, + ) mem = psutil.virtual_memory() logger.info('Finished forward pass on chunk_index=' f'{chunk_index} in {dt.now() - now}. 
{i + 1} of ' @@ -1898,15 +1905,13 @@ def _run_parallel(cls, strategy, node_index): with SpawnProcessPool(**pool_kws) as exe: now = dt.now() for _i, chunk_index in enumerate(strategy.node_chunks[node_index]): - fut = exe.submit( - cls._single_proc_run, - strategy=strategy, - node_index=node_index, - chunk_index=chunk_index, - ) + fut = exe.submit(cls._single_proc_run, + strategy=strategy, + node_index=node_index, + chunk_index=chunk_index, + ) futures[fut] = { - 'chunk_index': chunk_index, - 'start_time': dt.now(), + 'chunk_index': chunk_index, 'start_time': dt.now(), } logger.info(f'Started {len(futures)} forward pass runs in ' @@ -1943,20 +1948,15 @@ def run_chunk(self): f'{self.strategy.temporal_pad}.') logger.info(msg) - data_chunk = self.input_data - exo_data = None - if self.exogenous_data is not None: - exo_data = self._prep_exogenous_input(data_chunk.shape) - self.output_data = self._run_generator( - data_chunk, + self.input_data, hr_crop_slices=self.hr_crop_slice, model=self.model, model_kwargs=self.model_kwargs, model_class=self.model_class, s_enhance=self.s_enhance, t_enhance=self.t_enhance, - exo_data=exo_data, + exo_data=self.exogenous_data, ) self._constant_output_check(self.output_data) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py new file mode 100644 index 0000000000..6d8c6bdfe3 --- /dev/null +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -0,0 +1,480 @@ +"""Sup3r topography utilities""" + +import logging +from abc import ABC, abstractmethod + +import numpy as np +from rex import Resource +from rex.utilities.solar_position import SolarPosition +from scipy.spatial import KDTree + +import sup3r.preprocessing.data_handling +from sup3r.postprocessing.file_handling import OutputHandler +from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 +from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC +from sup3r.utilities.utilities import get_source_type + +logger = logging.getLogger(__name__) + + +class ExoExtract(ABC): + """Class to extract high-res (4km+) data rasters for new + spatially-enhanced datasets (e.g. GCM files after spatial enhancement) + using nearest neighbor mapping and aggregation from NREL datasets + (e.g. WTK or NSRDB) + """ + + def __init__(self, + file_paths, + exo_source, + s_enhance, + t_enhance, + s_agg_factor, + t_agg_factor, + target=None, + shape=None, + temporal_slice=None, + raster_file=None, + max_delta=20, + input_handler=None, + ti_workers=1): + """ + Parameters + ---------- + file_paths : str | list + A single source h5 file to extract raster data from or a list + of netcdf files with identical grid. The string can be a unix-style + file path which will be passed through glob.glob. This is + typically low-res WRF output or GCM netcdf data files that is + source low-resolution data intended to be sup3r resolved. + exo_source : str + Filepath to source wtk or nsrdb file to get hi-res (2km or 4km) + elevation data from which will be mapped to the enhanced grid of + the file_paths input + s_enhance : int + Factor by which the Sup3rGan model will enhance the spatial + dimensions of low resolution data from file_paths input. 
For
+            example, if getting topography data, file_paths has 100km data, and
+            s_enhance is 4, this class will output a topography raster
+            corresponding to the file_paths grid enhanced 4x to ~25km
+        t_enhance : int
+            Factor by which the Sup3rGan model will enhance the temporal
+            dimension of low resolution data from file_paths input. For
+            example, if getting sza data, file_paths has hourly data, and
+            t_enhance is 4, this class will output a sza raster
+            corresponding to the file_paths temporally enhanced 4x to 15 min
+        s_agg_factor : int
+            Factor by which to aggregate the exo_source data to the resolution
+            of the file_paths input enhanced by s_enhance. For example, if
+            getting topography data, file_paths has 100km data, and s_enhance
+            is 4 resulting in a desired resolution of ~25km and exo_source
+            has a resolution of 4km, the s_agg_factor should be 36 so that 6x6
+            4km cells are averaged to the ~25km enhanced grid.
+        t_agg_factor : int
+            Factor by which to aggregate the exo_source data to the resolution
+            of the file_paths input enhanced by t_enhance. For example, if
+            getting sza data, file_paths has hourly data, and t_enhance
+            is 4 resulting in a desired resolution of 15 min and exo_source
+            has a resolution of 5 min, the t_agg_factor should be 3 so that
+            every third timestep in the exo_source data is sampled.
+        target : tuple
+            (lat, lon) lower left corner of raster. Either need target+shape or
+            raster_file.
+        shape : tuple
+            (rows, cols) grid size. Either need target+shape or raster_file.
+        temporal_slice : slice | None
+            slice used to extract interval from temporal dimension for input
+            data and source data
+        raster_file : str | None
+            File for raster_index array for the corresponding target and shape.
+            If specified the raster_index will be loaded from the file if it
+            exists or written to the file if it does not yet exist. If None
+            raster_index will be calculated directly. Either need target+shape
+            or raster_file.
+        max_delta : int, optional
+            Optional maximum limit on the raster shape that is retrieved at
+            once. If shape is (20, 20) and max_delta=10, the full raster will
+            be retrieved in four chunks of (10, 10). This helps adapt to
+            non-regular grids that curve over large distances, by default 20
+        input_handler : str
+            data handler class to use for input data. Provide a string name to
+            match a class in data_handling.py. If None the correct handler will
+            be guessed based on file type and time series properties.
+        ti_workers : int | None
+            max number of workers to use to get full time index. Useful when
+            there are many input files each with a single time step. If this is
+            greater than one, time indices for input files will be extracted in
+            parallel and then concatenated to get the full time index. If input
+            files do not all have time indices or if there are few input files
+            this should be set to one.
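+
+        Examples
+        --------
+        A rough sketch of how the spatial factors above relate (values are
+        illustrative, not defaults): 100km input data enhanced 4x targets a
+        ~25km grid, and a 4km exo_source then implies roughly
+        (25 / 4)**2 ~= 36x spatial aggregation::
+
+            s_enhance = 4
+            input_km, source_km = 100, 4
+            enhanced_km = input_km / s_enhance  # 25km target grid
+            s_agg_factor = round(enhanced_km / source_km)**2  # 6**2 = 36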
+ """ + + logger.info(f'Initializing {self.__class__.__name__} utility.') + + self.ti_workers = ti_workers + self._exo_source = exo_source + self._s_enhance = s_enhance + self._t_enhance = t_enhance + self._s_agg_factor = s_agg_factor + self._t_agg_factor = t_agg_factor + self._tree = None + self._hr_lat_lon = None + self._source_lat_lon = None + self._hr_time_index = None + self._src_time_index = None + + if input_handler is None: + in_type = get_source_type(file_paths) + if in_type == 'nc': + input_handler = DataHandlerNC + elif in_type == 'h5': + input_handler = DataHandlerH5 + else: + msg = (f'Did not recognize input type "{in_type}" for file ' + f'paths: {file_paths}') + logger.error(msg) + raise RuntimeError(msg) + elif isinstance(input_handler, str): + input_handler = getattr(sup3r.preprocessing.data_handling, + input_handler, None) + if input_handler is None: + msg = ('Could not find requested data handler class ' + f'"{input_handler}" in ' + 'sup3r.preprocessing.data_handling.') + logger.error(msg) + raise KeyError(msg) + + self.input_handler = input_handler( + file_paths, [], + target=target, + shape=shape, + temporal_slice=temporal_slice, + raster_file=raster_file, + max_delta=max_delta, + worker_kwargs=dict(ti_workers=ti_workers), + ) + + @property + @abstractmethod + def source_data(self): + """Get the 1D array of source data from the exo_source_h5""" + + @property + def source_temporal_slice(self): + """Get the temporal slice for the exo_source data corresponding to the + input file temporal slice""" + start_index = self.source_time_index.get_indexer( + [self.input_handler.hr_time_index[0]], method='nearest')[0] + end_index = self.source_time_index.get_indexer( + [self.input_handler.hr_time_index[-1]], method='nearest')[0] + return slice(start_index, end_index + 1, self._t_agg_factor) + + @property + def source_lat_lon(self): + """Get the 2D array (n, 2) of lat, lon data for the exo source""" + if self._source_lat_lon is None: + src_enhance = int(np.sqrt(self._s_agg_factor)) + src_shape = (self.hr_shape[0] * src_enhance, + self.hr_shape[1] * src_enhance) + self._source_lat_lon = OutputHandler.get_lat_lon( + self.lr_lat_lon, src_shape).reshape((-1, 2)) + return self._source_lat_lon + + @property + def lr_shape(self): + """Get the low-resolution spatial shape tuple""" + return (self.lr_lat_lon.shape[0], self.lr_lat_lon.shape[1], + len(self.input_handler.time_index)) + + @property + def hr_shape(self): + """Get the high-resolution spatial shape tuple""" + return (self._s_enhance * self.lr_lat_lon.shape[0], + self._s_enhance * self.lr_lat_lon.shape[1], + self._t_enhance * len(self.input_handler.time_index)) + + @property + def lr_lat_lon(self): + """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon + array with same ordering in last dimension. This corresponds to the raw + low-resolution meta data from the file_paths input. + + Returns + ------- + ndarray + """ + return self.input_handler.lat_lon + + @property + def hr_lat_lon(self): + """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon + array with same ordering in last dimension. This corresponds to the + enhanced high-res meta data from the file_paths input * s_enhance. 
+ + Returns + ------- + ndarray + """ + if self._hr_lat_lon is None: + if self._s_enhance > 1: + self._hr_lat_lon = OutputHandler.get_lat_lon( + self.lr_lat_lon, self.hr_shape[:-1]) + else: + self._hr_lat_lon = self.lr_lat_lon + return self._hr_lat_lon + + @property + def source_time_index(self): + """Get the full time index of the exo_source data""" + if self._src_time_index is None: + if self._t_agg_factor > 1: + self._src_time_index = OutputHandler.get_times( + self.input_handler.time_index, + self.hr_shape[-1] * self._t_agg_factor) + else: + self._src_time_index = self.hr_time_index + return self._src_time_index + + @property + def hr_time_index(self): + """Get the full time index for aggregated source data""" + if self._hr_time_index is None: + if self._t_enhance > 1: + self._hr_time_index = OutputHandler.get_times( + self.input_handler.time_index, self.hr_shape[-1]) + else: + self._hr_time_index = self.input_handler.time_index + return self._hr_time_index + + @property + def tree(self): + """Get the KDTree built on the source lat lon data""" + if self._tree is None: + self._tree = KDTree(self.source_lat_lon) + return self._tree + + @property + def nn(self): + """Get the nearest neighbor indices""" + ll2 = np.vstack( + (self.hr_lat_lon[:, :, 0].flatten(), + self.hr_lat_lon[:, :, 1].flatten())).T + _, nn = self.tree.query(ll2, k=self._s_agg_factor) + if len(nn.shape) == 1: + nn = np.expand_dims(nn, 1) + return nn + + @property + def data(self): + """Get a raster of source values corresponding to the + high-resolution grid (the file_paths input grid * s_enhance * + t_enhance). The shape is (lats, lons, temporal, 1) + """ + nn = self.nn + hr_data = [] + for j in range(self._s_agg_factor): + out = self.source_data[nn[:, j], self.source_temporal_slice] + out = out.reshape(self.hr_shape) + hr_data.append(out[..., np.newaxis]) + hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) + logger.info('Finished mapping raster from {}'.format(self._exo_source)) + return hr_data[..., np.newaxis] + + @classmethod + def get_exo_raster(cls, + file_paths, + s_enhance, + t_enhance, + s_agg_factor, + t_agg_factor, + exo_source=None, + target=None, + shape=None, + temporal_slice=None, + raster_file=None, + max_delta=20, + input_handler=None): + """Get the exo feature raster corresponding to the spatially enhanced + grid from the file_paths input + + Parameters + ---------- + file_paths : str | list + A single source h5 file to extract raster data from or a list + of netcdf files with identical grid. The string can be a unix-style + file path which will be passed through glob.glob + s_enhance : int + Factor by which the Sup3rGan model will enhance the spatial + dimensions of low resolution data from file_paths input. For + example, if file_paths has 100km data and s_enhance is 4, this + class will output a topography raster corresponding to the + file_paths grid enhanced 4x to ~25km + t_enhance : int + Factor by which the Sup3rGan model will enhance the temporal + dimension of low resolution data from file_paths input. For + example, if getting sza data, file_paths has hourly data, and + t_enhance is 4, this class will output a sza raster + corresponding to the file_paths temporally enhanced 4x to 15 min + s_agg_factor : int + Factor by which to aggregate the exo_source data to the resolution + of the file_paths input enhanced by s_enhance. 
For example, if
+            getting topography data, file_paths has 100km data, and s_enhance
+            is 4 resulting in a desired resolution of ~25km and exo_source
+            has a resolution of 4km, the s_agg_factor should be 36 so that 6x6
+            4km cells are averaged to the ~25km enhanced grid.
+        t_agg_factor : int
+            Factor by which to aggregate the exo_source data to the resolution
+            of the file_paths input enhanced by t_enhance. For example, if
+            getting sza data, file_paths has hourly data, and t_enhance
+            is 4 resulting in a desired resolution of 15 min and exo_source
+            has a resolution of 5 min, the t_agg_factor should be 3 so that
+            every third timestep in the exo_source data is sampled.
+        exo_source : str
+            Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or
+            4km) data from which will be mapped to the enhanced grid of the
+            file_paths input
+        target : tuple
+            (lat, lon) lower left corner of raster. Either need target+shape or
+            raster_file.
+        shape : tuple
+            (rows, cols) grid size. Either need target+shape or raster_file.
+        temporal_slice : slice | None
+            slice used to extract interval from temporal dimension for input
+            data and source data
+        raster_file : str | None
+            File for raster_index array for the corresponding target and shape.
+            If specified the raster_index will be loaded from the file if it
+            exists or written to the file if it does not yet exist. If None
+            raster_index will be calculated directly. Either need target+shape
+            or raster_file.
+        max_delta : int, optional
+            Optional maximum limit on the raster shape that is retrieved at
+            once. If shape is (20, 20) and max_delta=10, the full raster will
+            be retrieved in four chunks of (10, 10). This helps adapt to
+            non-regular grids that curve over large distances, by default 20
+        input_handler : str
+            data handler class to use for input data. Provide a string name to
+            match a class in data_handling.py. If None the correct handler will
+            be guessed based on file type and time series properties.
+
+        Returns
+        -------
+        exo_raster : np.ndarray
+            Exo feature raster with shape (hr_rows, hr_cols, hr_temporal)
+            corresponding to the shape of the spatiotemporally enhanced data
+            from file_paths * s_enhance * t_enhance. The data units correspond
+            to the source units in exo_source.
This is usually meters when
+            feature='topography'
+        """
+        exo = cls(file_paths,
+                  exo_source,
+                  s_enhance,
+                  t_enhance,
+                  s_agg_factor,
+                  t_agg_factor,
+                  target=target,
+                  shape=shape,
+                  temporal_slice=temporal_slice,
+                  raster_file=raster_file,
+                  max_delta=max_delta,
+                  input_handler=input_handler)
+        return exo.data
+
+
+class TopoExtractH5(ExoExtract):
+    """TopoExtract for H5 files"""
+
+    @property
+    def source_data(self):
+        """Get the 1D array of elevation data from the exo_source_h5"""
+        with Resource(self._exo_source) as res:
+            elev = res.get_meta_arr('elevation')
+            elev = np.repeat(elev[:, np.newaxis], self.hr_shape[-1], axis=-1)
+        return elev
+
+    @property
+    def source_lat_lon(self):
+        """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5"""
+        with Resource(self._exo_source) as res:
+            source_lat_lon = res.lat_lon
+        return source_lat_lon
+
+    @property
+    def source_time_index(self):
+        """Time index of the source exo data"""
+        if self._src_time_index is None:
+            with Resource(self._exo_source) as res:
+                self._src_time_index = res.time_index
+        return self._src_time_index
+
+    @property
+    def data(self):
+        """Get a raster of source values corresponding to the
+        high-resolution grid (the file_paths input grid * s_enhance *
+        t_enhance). The shape is (lats, lons, temporal, 1)
+        """
+        nn = self.nn
+        hr_data = []
+        for j in range(self._s_agg_factor):
+            out = self.source_data[nn[:, j]]
+            out = out.reshape(self.hr_shape)
+            hr_data.append(out[..., np.newaxis])
+        hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1)
+        logger.info('Finished mapping raster from {}'.format(self._exo_source))
+        return hr_data[..., np.newaxis]
+
+
+class TopoExtractNC(TopoExtractH5):
+    """TopoExtract for netCDF files"""
+
+    def __init__(self, *args, **kwargs):
+        """
+        Parameters
+        ----------
+        args : list
+            Same positional arguments as ExoExtract
+        kwargs : dict
+            Same keyword arguments as ExoExtract
+        """
+
+        super().__init__(*args, **kwargs)
+        logger.info('Getting topography for full domain from '
+                    f'{self._exo_source}')
+        self.source_handler = DataHandlerNC(
+            self._exo_source,
+            features=['topography'],
+            worker_kwargs=dict(ti_workers=self.ti_workers),
+            val_split=0.0,
+        )
+
+    @property
+    def source_data(self):
+        """Get the 1D array of elevation data from the exo_source_nc"""
+        elev = self.source_handler.data.reshape((-1, self.lr_shape[-1]))
+        return elev
+
+    @property
+    def source_lat_lon(self):
+        """Get the 2D array (n, 2) of lat, lon data from the exo_source_nc"""
+        source_lat_lon = self.source_handler.lat_lon.reshape((-1, 2))
+        return source_lat_lon
+
+
+class SzaExtract(ExoExtract):
+    """SzaExtract for H5 files"""
+
+    @property
+    def source_data(self):
+        """Get the 1D array of sza data from the exo_source_h5"""
+        return SolarPosition(self.hr_time_index,
+                             self.hr_lat_lon.reshape((-1, 2))).zenith.T
+
+    @property
+    def data(self):
+        """Get a raster of source values corresponding to the
+        high-resolution grid (the file_paths input grid * s_enhance *
+        t_enhance).
The shape is (lats, lons, temporal, 1)
+        """
+        hr_data = self.source_data.reshape(self.hr_shape)
+        logger.info('Finished computing SZA data')
+        return hr_data[..., np.newaxis]
diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
index c54b44acb7..41791006aa 100644
--- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py
+++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
@@ -1,14 +1,19 @@
 """Sup3r exogenous data handling"""
+import logging
 import os
+import pickle
+import re
 import shutil
-import logging
+from typing import ClassVar
+
 import numpy as np
-import pickle
-from warnings import warn
-from sup3r.utilities.topo import TopoExtractH5, TopoExtractNC
-import sup3r.preprocessing.data_handling
-import sup3r.utilities.topo
+from sup3r.preprocessing.data_handling import exo_extraction
+from sup3r.preprocessing.data_handling.exo_extraction import (
+    SzaExtract,
+    TopoExtractH5,
+    TopoExtractNC,
+)
 from sup3r.utilities.utilities import get_source_type

 logger = logging.getLogger(__name__)
@@ -19,10 +24,33 @@ class ExogenousDataHandler:
     Multiple topography arrays at different resolutions for multiple spatial
     enhancement steps."""

-    def __init__(self, file_paths, features, source_file, s_enhancements,
-                 agg_factors, target=None, shape=None, raster_file=None,
-                 max_delta=20, input_handler=None, topo_handler=None,
-                 exo_steps=None, cache_data=True):
+    AVAILABLE_HANDLERS: ClassVar[dict] = {
+        'topography': {
+            'h5': TopoExtractH5,
+            'nc': TopoExtractNC
+        },
+        'sza': {
+            'h5': SzaExtract,
+            'nc': SzaExtract
+        }
+    }
+
+    def __init__(self,
+                 file_paths,
+                 feature,
+                 steps,
+                 models=None,
+                 exo_resolution=None,
+                 source_file=None,
+                 target=None,
+                 shape=None,
+                 temporal_slice=None,
+                 raster_file=None,
+                 max_delta=20,
+                 input_handler=None,
+                 exo_handler=None,
+                 cache_data=True,
+                 cache_dir='./exo_cache'):
         """
         Parameters
         ----------
@@ -32,195 +60,455 @@ def __init__(self, file_paths, features, source_file, s_enhancements,
             through glob.glob. This is typically low-res WRF output or GCM
             netcdf data that is source low-resolution data intended to be
             sup3r resolved.
-        features : list
-            List of exogenous features to extract from source_h5
+        feature : str
+            Exogenous feature to extract from source_file
+        steps : list
+            List of dictionaries containing info on which models to use for a
+            given step index and what type of exo data the step requires. e.g.
+            [{'model': 0, 'combine_type': 'input'},
+             {'model': 0, 'combine_type': 'layer'}]
+            Each step entry can also contain s_enhance, t_enhance,
+            s_agg_factor, t_agg_factor. e.g.
+            [{'model': 0, 'combine_type': 'input', 's_agg_factor': 900,
+              's_enhance': 1, 't_agg_factor': 5, 't_enhance': 1},
+             {'model': 0, 'combine_type': 'layer', 's_agg_factor': 100,
+              's_enhance': 3, 't_agg_factor': 5, 't_enhance': 1}]
+            If they are not included they will be computed using
+            exo_resolution and model attributes.
+        models : list
+            List of models used with the given steps list. This list of models
+            is used to determine the input and output resolution and
+            enhancement factors for each model step which is then used to
+            determine aggregation factors. If agg factors and enhancement
+            factors are provided in the steps list the model list is not
+            needed.
+        exo_resolution : dict
+            Dictionary of spatiotemporal resolution for the given exo data
+            source. e.g. {'spatial': '4km', 'temporal': '60min'}. This is used
+            only if agg factors are not provided in the steps list.
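+            As an illustrative sketch (hypothetical values): a model step
+            with input resolution {'spatial': '100km', 'temporal': '60min'}
+            and exo_resolution {'spatial': '4km', 'temporal': '60min'}
+            implies s_agg_factor = (100 / 4)**2 = 625 and t_agg_factor = 1.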
source_file : str Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or 4km) data from which will be mapped to the enhanced grid of the file_paths input - s_enhancements : list - List of factors by which the Sup3rGan model will enhance the - spatial dimensions of low resolution data from file_paths input - where the total spatial enhancement is the product of these - factors. For example, if file_paths has 100km data and there are 2 - spatial enhancement steps of 4x and 5x to a nominal resolution of - 5km, s_enhancements should be [1, 4, 5] and exo_steps should be - [0, 1, 2] so that the input to the 4x model gets exogenous data - at 100km (s_enhance=1, exo_step=0), the input to the 5x model gets - exogenous data at 25km (s_enhance=4, exo_step=1), and there is a - 20x (1*4*5) exogeneous data layer available if the second model can - receive a high-res input feature (e.g. WindGan). The length of this - list should be equal to the number of agg_factors and the number of - exo_steps - agg_factors : list - List of factors by which to aggregate the topo_source_h5 elevation - data to the resolution of the file_paths input enhanced by - s_enhance. The length of this list should be equal to the number of - s_enhancements and the number of exo_steps target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. shape : tuple (rows, cols) grid size. Either need target+shape or raster_file. + temporal_slice : slice | None + slice used to extract interval from temporal dimension for input + data and source data raster_file : str | None File for raster_index array for the corresponding target and shape. If specified the raster_index will be loaded from the file if it exists or written to the file if it does not yet exist. If None raster_index will be calculated directly. Either need target+shape or raster_file. + max_delta : int, optional + Optional maximum limit on the raster shape that is retrieved at + once. If shape is (20, 20) and max_delta=10, the full raster will + be retrieved in four chunks of (10, 10). This helps adapt to + non-regular grids that curve over large distances, by default 20 input_handler : str data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will be guessed based on file type and time series properties. - topo_handler : str - topo extract class to use for source data. Provide a string name to - match a class in topo.py. If None the correct handler will - be guessed based on file type and time series properties. - exo_steps : list - List of model step indices for which exogenous data is required. - e.g. If we have two model steps which take exo data and one which - does not exo_steps = [0, 1]. The length of this list should be - equal to the number of s_enhancements and the number of agg_factors + exo_handler : str + Feature extract class to use for source data. For example, if + feature='topography' this should be either TopoExtractH5 or + TopoExtractNC. If None the correct handler will be guessed based on + file type and time series properties. cache_data : bool - Flag to cache exogeneous data in ./exo_cache/ this can speed up - forward passes with large temporal extents + Flag to cache exogeneous data in /exo_cache/ this can + speed up forward passes with large temporal extents + cache_dir : str + Directory for storing cache data. 
Default is './exo_cache' """ - self.features = features - self.s_enhancements = s_enhancements - self.agg_factors = agg_factors + self.feature = feature + self.steps = steps + self.models = models + self.exo_res = exo_resolution self.source_file = source_file self.file_paths = file_paths - self.topo_handler = topo_handler + self.exo_handler = exo_handler + self.temporal_slice = temporal_slice self.target = target self.shape = shape self.raster_file = raster_file self.max_delta = max_delta self.input_handler = input_handler self.cache_data = cache_data + self.cache_dir = cache_dir self.data = [] - exo_steps = exo_steps or np.arange(len(self.s_enhancements)) - if self.s_enhancements[0] != 1: - msg = ('s_enhancements typically starts with 1 so the first ' - 'exogenous data input matches the spatial resolution of ' - 'the source low-res input data, but received ' - 's_enhancements: {}'.format(self.s_enhancements)) - logger.warning(msg) - warn(msg) + self.input_check() + agg_enhance = self._get_all_agg_and_enhancement() + self.s_enhancements = agg_enhance['s_enhancements'] + self.t_enhancements = agg_enhance['t_enhancements'] + self.s_agg_factors = agg_enhance['s_agg_factors'] + self.t_agg_factors = agg_enhance['t_agg_factors'] msg = ('Need to provide the same number of enhancement factors and ' - f'agg factors. Received s_enhancements={s_enhancements} and ' - f'agg_factors={agg_factors}.') - assert len(self.s_enhancements) == len(self.agg_factors), msg + f'agg factors. Received s_enhancements={self.s_enhancements}, ' + f'and s_agg_factors={self.s_agg_factors}.') + assert len(self.s_enhancements) == len(self.s_agg_factors), msg + msg = ('Need to provide the same number of enhancement factors and ' + f'agg factors. Received t_enhancements={self.t_enhancements}, ' + f'and t_agg_factors={self.t_agg_factors}.') + assert len(self.t_enhancements) == len(self.t_agg_factors), msg msg = ('Need to provide an integer enhancement factor for each model' 'step. If the step is temporal enhancement then s_enhance=1') assert not any(s is None for s in self.s_enhancements), msg - for i in range(len(self.s_enhancements)): - s_enhance = np.product(self.s_enhancements[:i + 1]) - agg_factor = self.agg_factors[i] - fdata = [] - if i in exo_steps: - for f in features: - if f == 'topography': - data = self.get_topo_data(s_enhance, agg_factor) - fdata.append(data) - else: - msg = (f"Can only extract topography. Recived {f}.") - raise NotImplementedError(msg) - self.data.append(np.stack(fdata, axis=-1)) + for i, _ in enumerate(self.s_enhancements): + s_enhance = self.s_enhancements[i] + t_enhance = self.t_enhancements[i] + s_agg_factor = self.s_agg_factors[i] + t_agg_factor = self.t_agg_factors[i] + if feature in list(self.AVAILABLE_HANDLERS): + data = self.get_exo_data(feature=feature, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor) + self.data.append(data) else: - self.data.append(None) + msg = (f"Can only extract {list(self.AVAILABLE_HANDLERS)}." + f" Received {feature}.") + raise NotImplementedError(msg) - def get_topo_data(self, s_enhance, agg_factor): - """Get the exogenous topography data + def input_check(self): + """Make sure agg factors are provided or exo_resolution and models are + provided. 
Make sure enhancement factors are provided or models are + provided""" + agg_check = all('s_agg_factor' in v for v in self.steps) + agg_check = agg_check and all('t_agg_factor' in v for v in self.steps) + agg_check = (agg_check + or self.models is not None and self.exo_res is not None) + msg = ("ExogenousDataHandler needs s_agg_factor and t_agg_factor " + "provided in each step in steps list or models and " + "exo_resolution") + assert agg_check, msg + en_check = all('s_enhance' in v for v in self.steps) + en_check = en_check and all('t_enhance' in v for v in self.steps) + en_check = en_check or self.models is not None + msg = ("ExogenousDataHandler needs s_enhance and t_enhance " + "provided in each step in steps list or models") + assert en_check, msg + + def _get_res_ratio(self, input_res, exo_res): + """Compute resolution ratio given input and output resolution + + Parameters + ---------- + input_res : str | None + Input resolution. e.g. '30km' or '60min' + exo_res : str | None + Exo resolution. e.g. '1km' or '5min' + + Returns + ------- + res_ratio : int | None + Ratio of input / exo resolution + """ + ires_num = (None if input_res is None + else int(re.search(r'\d+', input_res).group(0))) + eres_num = (None if exo_res is None + else int(re.search(r'\d+', exo_res).group(0))) + i_units = (None if input_res is None + else input_res.replace(str(ires_num), '')) + e_units = (None if exo_res is None + else exo_res.replace(str(eres_num), '')) + msg = 'Received conflicting units for input and exo resolution' + if e_units is not None: + assert i_units == e_units, msg + if ires_num is not None and eres_num is not None: + res_ratio = int(ires_num / eres_num) + else: + res_ratio = None + return res_ratio + + def get_agg_factors(self, input_res, exo_res): + """Compute aggregation ratio for exo data given input and output + resolution + + Parameters + ---------- + input_res : dict | None + Input resolution. e.g. {'spatial': '30km', 'temporal': '60min'} + exo_res : dict | None + Exogenous data resolution. e.g. + {'spatial': '1km', 'temporal': '5min'} + + Returns + ------- + s_agg_factor : int + Spatial aggregation factor for exogenous data extraction. + t_agg_factor : int + Temporal aggregation factor for exogenous data extraction. + """ + input_s_res = None if input_res is None else input_res['spatial'] + exo_s_res = None if exo_res is None else exo_res['spatial'] + s_res_ratio = self._get_res_ratio(input_s_res, exo_s_res) + s_agg_factor = None if s_res_ratio is None else int(s_res_ratio)**2 + input_t_res = None if input_res is None else input_res['temporal'] + exo_t_res = None if exo_res is None else exo_res['temporal'] + t_agg_factor = self._get_res_ratio(input_t_res, exo_t_res) + return s_agg_factor, t_agg_factor + + def _get_single_step_agg(self, step): + """Compute agg factors for exogenous data extraction + using exo_kwargs single model step. These factors are computed using + exo_resolution and the input/output resolution of each model step. If + agg factors are already provided in step they are not overwritten. + + Parameters + ---------- + step : dict + Model step dictionary. e.g. 
{'model': 0, 'combine_type': 'input'}
+
+        Returns
+        -------
+        updated_step : dict
+            Same as input dictionary with s_agg_factor, t_agg_factor added
+        """
+        if all(key in step for key in ['s_agg_factor', 't_agg_factor']):
+            return step
+
+        model_step = step['model']
+        combine_type = step.get('combine_type', None)
+        msg = (f'Model index from exo_kwargs ({model_step}) exceeds number '
+               f'of model steps ({len(self.models)})')
+        assert len(self.models) > model_step, msg
+        model = self.models[model_step]
+        input_res = model.input_resolution
+        output_res = model.output_resolution
+        if combine_type and combine_type.lower() == 'input':
+            s_agg_factor, t_agg_factor = self.get_agg_factors(
+                input_res, self.exo_res)
+
+        elif combine_type and combine_type.lower() in ('output', 'layer'):
+            s_agg_factor, t_agg_factor = self.get_agg_factors(
+                output_res, self.exo_res)
+
+        else:
+            msg = ('Received exo_kwargs entry without valid combine_type '
+                   '(input/layer/output)')
+            raise ValueError(msg)
+
+        step.update({'s_agg_factor': s_agg_factor,
+                     't_agg_factor': t_agg_factor})
+        return step
+
+    def _get_single_step_enhance(self, step):
+        """Get enhancement factors for exogenous data extraction
+        using exo_kwargs single model step. These factors are computed using
+        stored enhance attributes of each model and the model step provided.
+        If enhancement factors are already provided in step they are not
+        overwritten.
+
+        Parameters
+        ----------
+        step : dict
+            Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'}
+
+        Returns
+        -------
+        updated_step : dict
+            Same as input dictionary with s_enhance, t_enhance added
+        """
+        if all(key in step for key in ['s_enhance', 't_enhance']):
+            return step
+
+        model_step = step['model']
+        combine_type = step.get('combine_type', None)
+        msg = (f'Model index from exo_kwargs ({model_step}) exceeds number '
+               f'of model steps ({len(self.models)})')
+        assert len(self.models) > model_step, msg
+
+        s_enhancements = [model.s_enhance for model in self.models]
+        t_enhancements = [model.t_enhance for model in self.models]
+        if combine_type and combine_type.lower() == 'input':
+            if model_step == 0:
+                s_enhance = 1
+                t_enhance = 1
+            else:
+                s_enhance = np.product(s_enhancements[:model_step])
+                t_enhance = np.product(t_enhancements[:model_step])
+
+        elif combine_type and combine_type.lower() in ('output', 'layer'):
+            s_enhance = np.product(s_enhancements[:model_step + 1])
+            t_enhance = np.product(t_enhancements[:model_step + 1])
+
+        else:
+            msg = ('Received exo_kwargs entry without valid combine_type '
+                   '(input/layer/output)')
+            raise ValueError(msg)
+
+        step.update({'s_enhance': s_enhance, 't_enhance': t_enhance})
+        return step
+
+    def _get_all_agg_and_enhancement(self):
+        """Compute agg and enhancement factors for all model steps for all
+        features.
+
+        Returns
+        -------
+        agg_enhance_dict : dict
+            Dictionary with list of agg and enhancement factors for each model
+            step
+        """
+        agg_enhance_dict = {}
+        for i, step in enumerate(self.steps):
+            out = self._get_single_step_agg(step)
+            out = self._get_single_step_enhance(out)
+            self.steps[i] = out
+        agg_enhance_dict['s_agg_factors'] = [step['s_agg_factor']
+                                             for step in self.steps]
+        agg_enhance_dict['t_agg_factors'] = [step['t_agg_factor']
+                                             for step in self.steps]
+        agg_enhance_dict['s_enhancements'] = [step['s_enhance']
+                                              for step in self.steps]
+        agg_enhance_dict['t_enhancements'] = [step['t_enhance']
+                                              for step in self.steps]
+        return agg_enhance_dict
+
+    def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor,
+                       t_agg_factor):
+        """Get cache file name

         Parameters
         ----------
+        feature : str
+            Name of feature to get cache file for
         s_enhance : int
             Spatial enhancement for this exogeneous data step (cumulative for
             all model steps up to the current step).
-        agg_factor : int
-            Factor by which to aggregate the topo_source_h5 elevation
-            data to the resolution of the file_paths input enhanced by
-            s_enhance.
+        t_enhance : int
+            Temporal enhancement for this exogenous data step (cumulative for
+            all model steps up to the current step).
+        s_agg_factor : int
+            Factor by which to aggregate the exo_source data to the spatial
+            resolution of the file_paths input enhanced by s_enhance.
+        t_agg_factor : int
+            Factor by which to aggregate the exo_source data to the temporal
+            resolution of the file_paths input enhanced by t_enhance.

         Returns
         -------
-        data : np.ndarray
-            2D array of elevation data with shape (lat, lon)
+        cache_fp : str
+            Name of cache file
         """
-
-        cache_dir = './exo_cache/'
-        fn = f'exo_{self.target}_{self.shape}_agg{agg_factor}_{s_enhance}x.pkl'
+        fn = f'exo_{feature}_{self.target}_{self.shape}_sagg{s_agg_factor}_'
+        fn += f'tagg{t_agg_factor}_{s_enhance}x_{t_enhance}x.pkl'
         fn = fn.replace('(', '').replace(')', '')
         fn = fn.replace('[', '').replace(']', '')
         fn = fn.replace(',', 'x').replace(' ', '')
-        cache_fp = os.path.join(cache_dir, fn)
-        temp_fp = cache_fp + '.tmp'
+        cache_fp = os.path.join(self.cache_dir, fn)
+        if self.cache_data:
+            os.makedirs(self.cache_dir, exist_ok=True)
+        return cache_fp
+
+    def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor,
+                     t_agg_factor):
+        """Get the exogenous data for the given feature
+
+        Parameters
+        ----------
+        feature : str
+            Name of feature to get exo data for
+        s_enhance : int
+            Spatial enhancement for this exogenous data step (cumulative for
+            all model steps up to the current step).
+        t_enhance : int
+            Temporal enhancement for this exogenous data step (cumulative for
+            all model steps up to the current step).
+        s_agg_factor : int
+            Factor by which to aggregate the exo_source data to the spatial
+            resolution of the file_paths input enhanced by s_enhance.
+        t_agg_factor : int
+            Factor by which to aggregate the exo_source data to the temporal
+            resolution of the file_paths input enhanced by t_enhance.
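+
+        Note that when cache_data is True the result is cached in cache_dir
+        under a name built from these inputs by get_cache_file; e.g. with the
+        hypothetical values feature='topography', target=(39.0, -105.0),
+        shape=(20, 20), s_agg_factor=625, t_agg_factor=1, s_enhance=5 and
+        t_enhance=1 the cache file would be
+        'exo_topography_39.0x-105.0_20x20_sagg625_tagg1_5x_1x.pkl'.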
+ Returns + ------- + data : np.ndarray + 2D or 3D array of exo data with shape (lat, lon) or (lat, + lon, temporal) + """ + + cache_fp = self.get_cache_file(feature=feature, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor) + tmp_fp = cache_fp + '.tmp' if os.path.exists(cache_fp): with open(cache_fp, 'rb') as f: data = pickle.load(f) else: - topo_handler = self.get_topo_handler(self.source_file, - self.topo_handler) - data = topo_handler(self.file_paths, self.source_file, s_enhance, - agg_factor, target=self.target, - shape=self.shape, - raster_file=self.raster_file, - max_delta=self.max_delta, - input_handler=self.input_handler) - data = data.hr_elev + exo_handler = self.get_exo_handler(feature, self.source_file, + self.exo_handler) + data = exo_handler(self.file_paths, + self.source_file, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor, + target=self.target, + shape=self.shape, + temporal_slice=self.temporal_slice, + raster_file=self.raster_file, + max_delta=self.max_delta, + input_handler=self.input_handler).data if self.cache_data: - os.makedirs(cache_dir, exist_ok=True) - with open(temp_fp, 'wb') as f: + with open(tmp_fp, 'wb') as f: pickle.dump(data, f) - shutil.move(temp_fp, cache_fp) - + shutil.move(tmp_fp, cache_fp) return data - @staticmethod - def get_topo_handler(source_file, topo_handler): - """Get topo extraction class for source file + @classmethod + def get_exo_handler(cls, feature, source_file, exo_handler): + """Get exogenous feature extraction class for source file Parameters ---------- + feature : str + Name of feature to get exo handler for source_file : str Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or 4km) data from which will be mapped to the enhanced grid of the file_paths input - topo_handler : str - topo extract class to use for source data. Provide a string name to - match a class in topo.py. If None the correct handler will - be guessed based on file type and time series properties. + exo_handler : str + Feature extract class to use for source data. For example, if + feature='topography' this should be either TopoExtractH5 or + TopoExtractNC. If None the correct handler will be guessed based on + file type and time series properties. Returns ------- - topo_handler : str - topo extract class to use for source data. + exo_handler : str + Exogenous feature extraction class to use for source data. """ - if topo_handler is None: + if exo_handler is None: in_type = get_source_type(source_file) - if in_type == 'nc': - topo_handler = TopoExtractNC - elif in_type == 'h5': - topo_handler = TopoExtractH5 - else: - msg = ('Did not recognize input type "{}" for file paths: {}' - .format(in_type, source_file)) + if in_type not in ('h5', 'nc'): + msg = ('Did not recognize input type "{}" for file paths: {}'. 
+ format(in_type, source_file)) logger.error(msg) raise RuntimeError(msg) - elif isinstance(topo_handler, str): - topo_handler = getattr(sup3r.utilities.topo, topo_handler, None) - if topo_handler is None: - msg = ('Could not find requested topo handler class ' - f'"{topo_handler}" in ' - 'sup3r.utilities.topo.') + check = (feature in cls.AVAILABLE_HANDLERS + and in_type in cls.AVAILABLE_HANDLERS[feature]) + if check: + exo_handler = cls.AVAILABLE_HANDLERS[feature][in_type] + else: + msg = ('Could not find exo handler class for ' + f'feature={feature} and input_type={in_type}.') logger.error(msg) raise KeyError(msg) - - return topo_handler + elif isinstance(exo_handler, str): + exo_handler = getattr(exo_extraction, exo_handler, None) + return exo_handler diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index 6f29d178dd..c52c3640dd 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -15,10 +15,12 @@ import pandas as pd from scipy.stats import mode -from sup3r.utilities.utilities import (get_source_type, ignore_case_path_fetch, - uniform_box_sampler, - uniform_time_sampler, - ) +from sup3r.utilities.utilities import ( + get_source_type, + ignore_case_path_fetch, + uniform_box_sampler, + uniform_time_sampler, +) np.random.seed(42) @@ -601,6 +603,8 @@ def temporal_slice(self, temporal_slice): elements and no more than three, corresponding to the inputs of slice() """ + if temporal_slice is None: + temporal_slice = slice(None) msg = 'temporal_slice must be tuple, list, or slice' assert isinstance(temporal_slice, (tuple, list, slice)), msg if isinstance(temporal_slice, slice): diff --git a/sup3r/preprocessing/data_handling/nc_data_handling.py b/sup3r/preprocessing/data_handling/nc_data_handling.py index afa2f79ace..92ae0dc0e0 100644 --- a/sup3r/preprocessing/data_handling/nc_data_handling.py +++ b/sup3r/preprocessing/data_handling/nc_data_handling.py @@ -19,34 +19,13 @@ from sup3r.preprocessing.data_handling.base import DataHandler, DataHandlerDC from sup3r.preprocessing.feature_handling import ( - BVFreqMon, - BVFreqSquaredNC, - ClearSkyRatioCC, - Feature, - InverseMonNC, - LatLonNC, - PotentialTempNC, - PressureNC, - Rews, - Shear, - Tas, - TasMax, - TasMin, - TempNC, - TempNCforCC, - UWind, - UWindPowerLaw, - VWind, - VWindPowerLaw, - WinddirectionNC, - WindspeedNC, -) + BVFreqMon, BVFreqSquaredNC, ClearSkyRatioCC, Feature, InverseMonNC, + LatLonNC, PotentialTempNC, PressureNC, Rews, Shear, Tas, TasMax, TasMin, + TempNC, TempNCforCC, UWind, UWindPowerLaw, VWind, VWindPowerLaw, + WinddirectionNC, WindspeedNC) from sup3r.utilities.interpolation import Interpolator -from sup3r.utilities.utilities import ( - estimate_max_workers, - get_time_dim_name, - np_to_pd_times, -) +from sup3r.utilities.utilities import (estimate_max_workers, get_time_dim_name, + np_to_pd_times) np.random.seed(42) @@ -86,26 +65,6 @@ class DataHandlerNC(DataHandler): Chunk sizes that approximately match the data volume being extracted typically results in the most efficient IO.""" - def __init__(self, *args, xr_chunks=None, **kwargs): - """Initialize NETCDF data handler. - - Parameters - ---------- - *args : list - Same ordered required arguments as DataHandler parent class. - xr_chunks : int | "auto" | tuple | dict | None - kwarg that goes to xr.DataArray.chunk(chunks=xr_chunks). Chunk - sizes that approximately match the data volume being extracted - typically results in the most efficient IO. 
If not provided, this - defaults to the class CHUNKS attribute. - **kwargs : list - Same optional keyword arguments as DataHandler parent class. - """ - if xr_chunks is not None: - self.CHUNKS = xr_chunks - - super().__init__(*args, **kwargs) - @property def extract_workers(self): """Get upper bound for extract workers based on memory limits. Used to @@ -288,9 +247,9 @@ def extract_feature(cls, fdata = cls.direct_extract(handle, feat_key, raster_index, time_slice) - elif interp_height is not None and (cls.has_multilevel_feature( - feature, handle) or cls.has_surrounding_features( - feature, handle)): + elif interp_height is not None and ( + cls.has_multilevel_feature(feature, handle) + or cls.has_surrounding_features(feature, handle)): fdata = Interpolator.interp_var_to_height( handle, feature, raster_index, np.float32(interp_height), time_slice) diff --git a/sup3r/preprocessing/wind_conditional_moment_batch_handling.py b/sup3r/preprocessing/wind_conditional_moment_batch_handling.py index 537935d82c..0b554cca89 100644 --- a/sup3r/preprocessing/wind_conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/wind_conditional_moment_batch_handling.py @@ -3,17 +3,21 @@ Sup3r wind conditional moment batch_handling module. """ import logging -import tensorflow as tf + import numpy as np +import tensorflow as tf -from sup3r.utilities.utilities import (spatial_simple_enhancing, - temporal_simple_enhancing) from sup3r.preprocessing.batch_handling import Batch from sup3r.preprocessing.conditional_moment_batch_handling import ( - SpatialBatchHandlerMom1, + BatchHandlerMom1, BatchMom1, + SpatialBatchHandlerMom1, ValidationDataMom1, - BatchHandlerMom1) +) +from sup3r.utilities.utilities import ( + spatial_simple_enhancing, + temporal_simple_enhancing, +) np.random.seed(42) @@ -122,7 +126,8 @@ def make_output(low_res, high_res, HR is high-res and LR is low-res """ # Remove first moment from HR and square it - out = model_mom1._tf_generate(low_res, high_res[..., -1:]).numpy() + out = model_mom1._tf_generate( + low_res, {'topography': high_res[..., -1:]}).numpy() out = tf.concat((out, high_res[..., -1:]), axis=-1) return (high_res - out)**2 @@ -176,7 +181,8 @@ def make_output(low_res, high_res, SF = HR - LR """ # Remove LR and first moment from HR and square it - out = model_mom1._tf_generate(low_res, high_res[..., -1:]).numpy() + out = model_mom1._tf_generate( + low_res, {'topography': high_res[..., -1:]}).numpy() out = tf.concat((out, high_res[..., -1:]), axis=-1) enhanced_lr = spatial_simple_enhancing(low_res, s_enhance=s_enhance) @@ -251,6 +257,7 @@ class WindBatchHandlerMom1(BatchHandlerMom1): class WindSpatialBatchHandlerMom1(SpatialBatchHandlerMom1): """Sup3r spatial batch handling class""" + # Classes to use for handling an individual batch obj. 
VAL_CLASS = ValidationDataMom1 BATCH_CLASS = WindBatchMom1 @@ -260,36 +267,42 @@ class WindSpatialBatchHandlerMom1(SpatialBatchHandlerMom1): class ValidationDataWindMom1SF(ValidationDataMom1): """Iterator for validation wind data for first conditional moment of subfilter velocity""" + BATCH_CLASS = WindBatchMom1SF class ValidationDataWindMom2(ValidationDataMom1): """Iterator for subfilter validation wind data for second conditional moment""" + BATCH_CLASS = WindBatchMom2 class ValidationDataWindMom2Sep(ValidationDataMom1): """Iterator for subfilter validation wind data for second conditional moment separate from first moment""" + BATCH_CLASS = WindBatchMom2Sep class ValidationDataWindMom2SF(ValidationDataMom1): """Iterator for validation wind data for second conditional moment of subfilter velocity""" + BATCH_CLASS = WindBatchMom2SF class ValidationDataWindMom2SepSF(ValidationDataMom1): """Iterator for validation wind data for second conditional moment of subfilter velocity separate from first moment""" + BATCH_CLASS = WindBatchMom2SepSF class WindBatchHandlerMom1SF(WindBatchHandlerMom1): """Sup3r batch handling class for first conditional moment of subfilter velocity using topography as input""" + VAL_CLASS = ValidationDataWindMom1SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -297,6 +310,7 @@ class WindBatchHandlerMom1SF(WindBatchHandlerMom1): class WindSpatialBatchHandlerMom1SF(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for first conditional moment of subfilter velocity using topography as input""" + VAL_CLASS = ValidationDataWindMom1SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -304,6 +318,7 @@ class WindSpatialBatchHandlerMom1SF(WindSpatialBatchHandlerMom1): class WindBatchHandlerMom2(WindBatchHandlerMom1): """Sup3r batch handling class for second conditional moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2 BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -311,6 +326,7 @@ class WindBatchHandlerMom2(WindBatchHandlerMom1): class WindBatchHandlerMom2Sep(WindBatchHandlerMom1): """Sup3r batch handling class for second conditional moment separate from first moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2Sep BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -318,6 +334,7 @@ class WindBatchHandlerMom2Sep(WindBatchHandlerMom1): class WindSpatialBatchHandlerMom2(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2 BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -325,6 +342,7 @@ class WindSpatialBatchHandlerMom2(WindSpatialBatchHandlerMom1): class WindSpatialBatchHandlerMom2Sep(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment separate from first moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2Sep BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -332,6 +350,7 @@ class WindSpatialBatchHandlerMom2Sep(WindSpatialBatchHandlerMom1): class WindBatchHandlerMom2SF(WindBatchHandlerMom1): """Sup3r batch handling class for second conditional moment of subfilter velocity""" + VAL_CLASS = ValidationDataWindMom2SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -339,6 +358,7 @@ class WindBatchHandlerMom2SF(WindBatchHandlerMom1): class WindBatchHandlerMom2SepSF(WindBatchHandlerMom1): """Sup3r batch handling class for second conditional moment of subfilter velocity separate from first moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2SepSF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -346,6 +366,7 @@ 
class WindBatchHandlerMom2SepSF(WindBatchHandlerMom1): class WindSpatialBatchHandlerMom2SF(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment of subfilter velocity using topography as input""" + VAL_CLASS = ValidationDataWindMom2SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -353,5 +374,6 @@ class WindSpatialBatchHandlerMom2SF(WindSpatialBatchHandlerMom1): class WindSpatialBatchHandlerMom2SepSF(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment of subfilter velocity separate from first moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2SepSF BATCH_CLASS = VAL_CLASS.BATCH_CLASS diff --git a/sup3r/utilities/pytest.py b/sup3r/utilities/pytest.py index 7c8581c74d..899bd3b993 100644 --- a/sup3r/utilities/pytest.py +++ b/sup3r/utilities/pytest.py @@ -34,13 +34,14 @@ def make_fake_nc_files(td, input_file, n_files): ] fake_files = [os.path.join(td, f'input_{date}') for date in fake_dates] for i in range(n_files): - input_dset = xr.open_dataset(input_file) - with xr.Dataset(input_dset) as dset: - dset['Times'][:] = np.array( - [fake_times[i].encode('ASCII')], dtype='|S19' - ) - dset['XTIME'][:] = i - dset.to_netcdf(fake_files[i]) + if os.path.exists(fake_files[i]): + os.remove(fake_files[i]) + with xr.open_dataset(input_file) as input_dset: + with xr.Dataset(input_dset) as dset: + dset['Times'][:] = np.array( + [fake_times[i].encode('ASCII')], dtype='|S19') + dset['XTIME'][:] = i + dset.to_netcdf(fake_files[i]) return fake_files @@ -68,12 +69,12 @@ def make_fake_multi_time_nc_files(td, input_file, n_steps, n_files): dummy_files = [] for i, files in enumerate(fake_files): dummy_file = os.path.join( - td, f'multi_timestep_file_{str(i).zfill(3)}.nc' - ) + td, f'multi_timestep_file_{str(i).zfill(3)}.nc') + if os.path.exists(dummy_file): + os.remove(dummy_file) dummy_files.append(dummy_file) with xr.open_mfdataset( - files, combine='nested', concat_dim='Time' - ) as dset: + files, combine='nested', concat_dim='Time') as dset: dset.to_netcdf(dummy_file) return dummy_files @@ -104,14 +105,15 @@ def make_fake_era_files(td, input_file, n_files): ] fake_files = [os.path.join(td, f'input_{date}') for date in fake_dates] for i in range(n_files): - input_dset = xr.open_dataset(input_file) - with xr.Dataset(input_dset) as dset: - dset['Times'][:] = np.array( - [fake_times[i].encode('ASCII')], dtype='|S19' - ) - dset['XTIME'][:] = i - dset = dset.rename({'U': 'u', 'V': 'v'}) - dset.to_netcdf(fake_files[i]) + if os.path.exists(fake_files[i]): + os.remove(fake_files[i]) + with xr.open_dataset(input_file) as input_dset: + with xr.Dataset(input_dset) as dset: + dset['Times'][:] = np.array( + [fake_times[i].encode('ASCII')], dtype='|S19') + dset['XTIME'][:] = i + dset = dset.rename({'U': 'u', 'V': 'v'}) + dset.to_netcdf(fake_files[i]) return fake_files diff --git a/sup3r/utilities/topo.py b/sup3r/utilities/topo.py deleted file mode 100644 index ac01cecc1f..0000000000 --- a/sup3r/utilities/topo.py +++ /dev/null @@ -1,358 +0,0 @@ -"""Sup3r topography utilities""" - -import logging -from abc import ABC, abstractmethod - -import numpy as np -from rex import Resource -from scipy.spatial import KDTree - -import sup3r.preprocessing.data_handling -from sup3r.postprocessing.file_handling import OutputHandler -from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 -from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC -from sup3r.utilities.utilities import get_source_type - 
-logger = logging.getLogger(__name__) - - -class TopoExtract(ABC): - """Class to extract high-res (4km+) topography rasters for new - spatially-enhanced datasets (e.g. GCM files after spatial enhancement) - using nearest neighbor mapping and aggregation from NREL datasets - (e.g. WTK or NSRDB) - """ - - def __init__( - self, - file_paths, - topo_source, - s_enhance, - agg_factor, - target=None, - shape=None, - raster_file=None, - max_delta=20, - input_handler=None, - ti_workers=1, - ): - """ - Parameters - ---------- - file_paths : str | list - A single source h5 file to extract raster data from or a list - of netcdf files with identical grid. The string can be a unix-style - file path which will be passed through glob.glob. This is - typically low-res WRF output or GCM netcdf data files that is - source low-resolution data intended to be sup3r resolved. - topo_source : str - Filepath to source wtk or nsrdb file to get hi-res (2km or 4km) - elevation data from which will be mapped to the enhanced grid of - the file_paths input - s_enhance : int - Factor by which the Sup3rGan model will enhance the spatial - dimensions of low resolution data from file_paths input. For - example, if file_paths has 100km data and s_enhance is 4, this - class will output a topography raster corresponding to the - file_paths grid enhanced 4x to ~25km - agg_factor : int - Factor by which to aggregate the topo_source_h5 elevation - data to the resolution of the file_paths input enhanced by - s_enhance. For example, if file_paths has 100km data and s_enhance - is 4 resulting in a desired resolution of ~25km and topo_source_h5 - has a resolution of 4km, the agg_factor should be 36 so that 6x6 - 4km cells are averaged to the ~25km enhanced grid. - target : tuple - (lat, lon) lower left corner of raster. Either need target+shape or - raster_file. - shape : tuple - (rows, cols) grid size. Either need target+shape or raster_file. - raster_file : str | None - File for raster_index array for the corresponding target and shape. - If specified the raster_index will be loaded from the file if it - exists or written to the file if it does not yet exist. If None - raster_index will be calculated directly. Either need target+shape - or raster_file. - input_handler : str - data handler class to use for input data. Provide a string name to - match a class in data_handling.py. If None the correct handler will - be guessed based on file type and time series properties. - ti_workers : int | None - max number of workers to use to get full time index. Useful when - there are many input files each with a single time step. If this is - greater than one, time indices for input files will be extracted in - parallel and then concatenated to get the full time index. If input - files do not all have time indices or if there are few input files - this should be set to one. 
- - """ - - logger.info('Initializing TopoExtract utility.') - - self._topo_source = topo_source - self._s_enhance = s_enhance - self._agg_factor = agg_factor - self._tree = None - self.ti_workers = ti_workers - self._hr_lat_lon = None - - if input_handler is None: - in_type = get_source_type(file_paths) - if in_type == 'nc': - input_handler = DataHandlerNC - elif in_type == 'h5': - input_handler = DataHandlerH5 - else: - msg = 'Did not recognize input type "{}" for file paths: {}'.format( - in_type, file_paths - ) - logger.error(msg) - raise RuntimeError(msg) - elif isinstance(input_handler, str): - input_handler = getattr( - sup3r.preprocessing.data_handling, input_handler, None - ) - if input_handler is None: - msg = ( - 'Could not find requested data handler class ' - f'"{input_handler}" in ' - 'sup3r.preprocessing.data_handling.' - ) - logger.error(msg) - raise KeyError(msg) - - self.input_handler = input_handler( - file_paths, - [], - target=target, - shape=shape, - raster_file=raster_file, - max_delta=max_delta, - worker_kwargs=dict(ti_workers=ti_workers), - ) - - @property - @abstractmethod - def source_elevation(self): - """Get the 1D array of elevation data from the topo_source_h5""" - - @property - @abstractmethod - def source_lat_lon(self): - """Get the 2D array (n, 2) of lat, lon data from the topo_source_h5""" - - @property - def lr_shape(self): - """Get the low-resolution spatial shape tuple""" - return (self.lr_lat_lon.shape[0], self.lr_lat_lon.shape[1]) - - @property - def hr_shape(self): - """Get the high-resolution spatial shape tuple""" - return ( - self._s_enhance * self.lr_lat_lon.shape[0], - self._s_enhance * self.lr_lat_lon.shape[1], - ) - - @property - def lr_lat_lon(self): - """lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon - array with same ordering in last dimension. This corresponds to the raw - low-resolution meta data from the file_paths input. - - Returns - ------- - ndarray - """ - return self.input_handler.lat_lon - - @property - def hr_lat_lon(self): - """lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon - array with same ordering in last dimension. This corresponds to the - enhanced high-res meta data from the file_paths input * s_enhance. - - Returns - ------- - ndarray - """ - if self._hr_lat_lon is None: - if self._s_enhance > 1: - self._hr_lat_lon = OutputHandler.get_lat_lon( - self.lr_lat_lon, self.hr_shape - ) - else: - self._hr_lat_lon = self.lr_lat_lon - return self._hr_lat_lon - - @property - def tree(self): - """Get the KDTree built on the source lat lon data""" - if self._tree is None: - self._tree = KDTree(self.source_lat_lon) - return self._tree - - @property - def nn(self): - """Get the nearest neighbor indices""" - ll2 = np.vstack( - ( - self.hr_lat_lon[:, :, 0].flatten(), - self.hr_lat_lon[:, :, 1].flatten(), - ) - ).T - _, nn = self.tree.query(ll2, k=self._agg_factor) - if len(nn.shape) == 1: - nn = np.expand_dims(nn, 1) - return nn - - @property - def hr_elev(self): - """Get a raster of elevation values corresponding to the - high-resolution grid (the file_paths input grid * s_enhance). 
The shape - is (rows, cols) - """ - nn = self.nn - hr_elev = [] - for j in range(self._agg_factor): - elev = self.source_elevation[nn[:, j]] - elev = elev.reshape(self.hr_shape) - hr_elev.append(elev) - hr_elev = np.dstack(hr_elev).mean(axis=-1) - logger.info( - 'Finished mapping topo raster from {}'.format(self._topo_source) - ) - return hr_elev - - @classmethod - def get_topo_raster( - cls, - file_paths, - topo_source, - s_enhance, - agg_factor, - target=None, - shape=None, - raster_file=None, - max_delta=20, - input_handler=None, - ): - """Get the topography raster corresponding to the spatially enhanced - grid from the file_paths input - - Parameters - ---------- - file_paths : str | list - A single source h5 wind file to extract raster data from or a list - of netcdf files with identical grid. The string can be a unix-style - file path which will be passed through glob.glob - topo_source : str - Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or - 4km) data from which will be mapped to the enhanced grid of the - file_paths input - s_enhance : int - Factor by which the Sup3rGan model will enhance the spatial - dimensions of low resolution data from file_paths input. For - example, if file_paths has 100km data and s_enhance is 4, this - class will output a topography raster corresponding to the - file_paths grid enhanced 4x to ~25km - agg_factor : int - Factor by which to aggregate the topo_source_h5 elevation data to - the resolution of the file_paths input enhanced by s_enhance. For - example, if file_paths has 100km data and s_enhance is 4 resulting - in a desired resolution of ~25km and topo_source_h5 has a - resolution of 4km, the agg_factor should be 36 so that 6x6 4km - cells are averaged to the ~25km enhanced grid. - target : tuple - (lat, lon) lower left corner of raster. Either need target+shape or - raster_file. - shape : tuple - (rows, cols) grid size. Either need target+shape or raster_file. - raster_file : str | None - File for raster_index array for the corresponding target and shape. - If specified the raster_index will be loaded from the file if it - exists or written to the file if it does not yet exist. If None - raster_index will be calculated directly. Either need target+shape - or raster_file. - input_handler : str - data handler class to use for input data. Provide a string name to - match a class in data_handling.py. If None the correct handler will - be guessed based on file type and time series properties. - - Returns - ------- - topo_raster : np.ndarray - Elevation raster with shape (hr_rows, hr_cols) corresponding to the - shape of the spatially enhanced grid from file_paths * s_enhance. - The elevation units correspond to the source units in - topo_source_h5, usually meters. 
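The mapping that the docstrings above describe reduces to a k-nearest-neighbor average. The following is a hedged, self-contained sketch with synthetic coordinates, independent of sup3r, showing the same lookup that the deleted `nn` and `hr_elev` properties implement:

import numpy as np
from scipy.spatial import KDTree

# For every cell of the enhanced grid, find the agg_factor nearest source
# pixels and average their elevations.
agg_factor = 4
rng = np.random.default_rng(0)
source_lat_lon = rng.random((1000, 2))   # (n, 2) source coordinates
source_elevation = rng.random(1000)      # (n,) source elevations
hr_lat_lon = rng.random((20, 20, 2))     # (rows, cols, 2) enhanced grid

_, nn = KDTree(source_lat_lon).query(hr_lat_lon.reshape(-1, 2), k=agg_factor)
hr_elev = source_elevation[nn].mean(axis=-1).reshape(20, 20)
assert hr_elev.shape == (20, 20)         # one aggregated value per hr cell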
- """ - - te = cls( - file_paths, - topo_source, - s_enhance, - agg_factor, - target=target, - shape=shape, - raster_file=raster_file, - max_delta=max_delta, - input_handler=input_handler, - ) - - return te.hr_elev - - -class TopoExtractH5(TopoExtract): - """TopoExtract for H5 files""" - - @property - def source_elevation(self): - """Get the 1D array of elevation data from the topo_source_h5""" - with Resource(self._topo_source) as res: - elev = res.get_meta_arr('elevation') - return elev - - @property - def source_lat_lon(self): - """Get the 2D array (n, 2) of lat, lon data from the topo_source_h5""" - with Resource(self._topo_source) as res: - source_lat_lon = res.lat_lon - return source_lat_lon - - -class TopoExtractNC(TopoExtract): - """TopoExtract for netCDF files""" - - def __init__(self, *args, **kwargs): - """ - Parameters - ---------- - args : list - Same positional arguments as TopoExtract - kwargs : dict - Same keyword arguments as TopoExtract - """ - - super().__init__(*args, **kwargs) - logger.info( - 'Getting topography for full domain from ' f'{self._topo_source}' - ) - self.source_handler = DataHandlerNC( - self._topo_source, - features=['topography'], - worker_kwargs=dict(ti_workers=self.ti_workers), - val_split=0.0, - ) - - @property - def source_elevation(self): - """Get the 1D array of elevation data from the topo_source_h5""" - elev = self.source_handler.data.reshape((-1)) - return elev - - @property - def source_lat_lon(self): - """Get the 2D array (n, 2) of lat, lon data from the topo_source_h5""" - source_lat_lon = self.source_handler.lat_lon.reshape((-1, 2)) - return source_lat_lon diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index 13697d50a9..fe42324a7a 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -9,12 +9,14 @@ from sup3r import TEST_DATA_DIR from sup3r.preprocessing.data_handling.dual_data_handling import ( - DualDataHandler, ) + DualDataHandler, +) from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC -from sup3r.preprocessing.dual_batch_handling import (DualBatchHandler, - SpatialDualBatchHandler, - ) +from sup3r.preprocessing.dual_batch_handling import ( + DualBatchHandler, + SpatialDualBatchHandler, +) from sup3r.utilities.utilities import spatial_coarsening FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') @@ -23,7 +25,7 @@ FEATURES = ['U_100m', 'V_100m'] -def test_dual_data_handler(log=True, +def test_dual_data_handler(log=False, full_shape=(20, 20), sample_shape=(10, 10, 1), plot=True): @@ -71,7 +73,7 @@ def test_dual_data_handler(log=True, bbox_inches='tight') -def test_regrid_caching(log=True, +def test_regrid_caching(log=False, full_shape=(20, 20), sample_shape=(10, 10, 1)): """Test caching and loading of regridded data""" diff --git a/tests/data_handling/test_exo_data_handling.py b/tests/data_handling/test_exo_data_handling.py index 0a147fdd6e..93091733c1 100644 --- a/tests/data_handling/test_exo_data_handling.py +++ b/tests/data_handling/test_exo_data_handling.py @@ -4,6 +4,7 @@ import shutil import numpy as np +import pytest from sup3r import TEST_DATA_DIR from sup3r.preprocessing.data_handling import ExogenousDataHandler @@ -14,19 +15,30 @@ os.path.join(TEST_DATA_DIR, 'va_test.nc'), os.path.join(TEST_DATA_DIR, 'orog_test.nc'), os.path.join(TEST_DATA_DIR, 'zg_test.nc')] -FEATURES = ['topography'] TARGET = 
(13.67, 125.0) SHAPE = (8, 8) S_ENHANCE = [1, 4] -AGG_FACTORS = [4, 1] +T_ENHANCE = [1, 1] +S_AGG_FACTORS = [4, 1] +T_AGG_FACTORS = [1, 1] -def test_exo_cache(): +@pytest.mark.parametrize('feature', ['topography', 'sza']) +def test_exo_cache(feature): """Test exogenous data caching and re-load""" # no cached data + steps = [] + for s_en, t_en, s_agg, t_agg in zip(S_ENHANCE, T_ENHANCE, S_AGG_FACTORS, + T_AGG_FACTORS): + steps.append({'s_enhance': s_en, + 't_enhance': t_en, + 's_agg_factor': s_agg, + 't_agg_factor': t_agg}) try: - base = ExogenousDataHandler(FILE_PATHS, FEATURES, FP_WTK, S_ENHANCE, - AGG_FACTORS, target=TARGET, shape=SHAPE, + base = ExogenousDataHandler(FILE_PATHS, feature, + source_file=FP_WTK, + steps=steps, + target=TARGET, shape=SHAPE, input_handler='DataHandlerNCforCC') for i, arr in enumerate(base.data): assert arr.shape[0] == SHAPE[0] * S_ENHANCE[i] @@ -41,8 +53,10 @@ def test_exo_cache(): # load cached data try: - cache = ExogenousDataHandler(FILE_PATHS, FEATURES, FP_WTK, S_ENHANCE, - AGG_FACTORS, target=TARGET, shape=SHAPE, + cache = ExogenousDataHandler(FILE_PATHS, feature, + source_file=FP_WTK, + steps=steps, + target=TARGET, shape=SHAPE, input_handler='DataHandlerNCforCC') except Exception as e: if os.path.exists('./exo_cache/'): diff --git a/tests/data_handling/test_utils_topo.py b/tests/data_handling/test_utils_topo.py index 3b3fa83987..294bede190 100644 --- a/tests/data_handling/test_utils_topo.py +++ b/tests/data_handling/test_utils_topo.py @@ -1,14 +1,17 @@ # -*- coding: utf-8 -*- """pytests for topography utilities""" import os + +import matplotlib.pyplot as plt import numpy as np import pytest from scipy.spatial import KDTree -import matplotlib.pyplot as plt from sup3r import TEST_DATA_DIR -from sup3r.utilities.topo import TopoExtractNC, TopoExtractH5 - +from sup3r.preprocessing.data_handling.exo_extraction import ( + TopoExtractH5, + TopoExtractNC, +) FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET = (39.001, -105.15) @@ -22,32 +25,33 @@ def test_topo_extraction_h5(agg_factor, plot=False): """Test the spatial enhancement of a test grid and then the lookup of the elevation data to a reference WTK file (also the same file for the test)""" - te = TopoExtractH5(FP_WTK, FP_WTK, s_enhance=2, agg_factor=agg_factor, + te = TopoExtractH5(FP_WTK, FP_WTK, s_enhance=2, t_enhance=1, + t_agg_factor=1, s_agg_factor=agg_factor, target=TARGET, shape=SHAPE) - hr_elev = te.hr_elev + hr_elev = te.data tree = KDTree(te.source_lat_lon) # bottom left _, i = tree.query(TARGET, k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[-1, 0]) # top right _, i = tree.query((39.35, -105.2), k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[0, 0]) for idy in range(10, 20): for idx in range(10, 20): lat, lon = te.hr_lat_lon[idy, idx, :] _, i = tree.query((lat, lon), k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[idy, idx]) if plot: a = plt.scatter(te.source_lat_lon[:, 1], te.source_lat_lon[:, 0], - c=te.source_elevation, marker='s', s=150) + c=te.source_data, marker='s', s=150) plt.colorbar(a) plt.savefig('./source_elevation.png') plt.close() @@ -62,32 +66,33 @@ def test_topo_extraction_h5(agg_factor, plot=False): def test_topo_extraction_nc(agg_factor, plot=False): """Test the spatial enhancement of a test grid and then the lookup of the elevation data to a 
reference WRF file (also the same file for the test)""" - te = TopoExtractNC(FP_WRF, FP_WRF, s_enhance=2, agg_factor=agg_factor, + te = TopoExtractNC(FP_WRF, FP_WRF, s_enhance=2, t_enhance=1, + s_agg_factor=agg_factor, t_agg_factor=1, target=WRF_TARGET, shape=WRF_SHAPE) - hr_elev = te.hr_elev + hr_elev = te.data tree = KDTree(te.source_lat_lon) # bottom left _, i = tree.query(WRF_TARGET, k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[-1, 0]) # top right _, i = tree.query((19.4, -123.6), k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[0, 0]) for idy in range(4, 8): for idx in range(4, 8): lat, lon = te.hr_lat_lon[idy, idx, :] _, i = tree.query((lat, lon), k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[idy, idx]) if plot: a = plt.scatter(te.source_lat_lon[:, 1], te.source_lat_lon[:, 0], - c=te.source_elevation, marker='s', s=150) + c=te.source_data, marker='s', s=150) plt.colorbar(a) plt.savefig('./source_elevation.png') plt.close() diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index d6c54612fd..04f99081c8 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -6,17 +6,18 @@ import matplotlib.pyplot as plt import numpy as np -import pytest import tensorflow as tf import xarray as xr from rex import ResourceX, init_logger from sup3r import CONFIG_DIR, TEST_DATA_DIR, __version__ -from sup3r.models import LinearInterp, Sup3rGan, WindGan +from sup3r.models import Sup3rGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.preprocessing.data_handling import DataHandlerNC -from sup3r.utilities.pytest import (make_fake_nc_files, - make_fake_multi_time_nc_files) +from sup3r.utilities.pytest import ( + make_fake_multi_time_nc_files, + make_fake_nc_files, +) FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) @@ -43,10 +44,12 @@ def test_fwp_nc_cc(log=False): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - input_files = [os.path.join(TEST_DATA_DIR, 'ua_test.nc'), - os.path.join(TEST_DATA_DIR, 'va_test.nc'), - os.path.join(TEST_DATA_DIR, 'orog_test.nc'), - os.path.join(TEST_DATA_DIR, 'zg_test.nc')] + input_files = [ + os.path.join(TEST_DATA_DIR, 'ua_test.nc'), + os.path.join(TEST_DATA_DIR, 'va_test.nc'), + os.path.join(TEST_DATA_DIR, 'orog_test.nc'), + os.path.join(TEST_DATA_DIR, 'zg_test.nc') + ] features = ['U_100m', 'V_100m'] target = (13.67, 125.0) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) @@ -63,16 +66,20 @@ def test_fwp_nc_cc(log=False): # 1st forward pass max_workers = 1 input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, cache_pattern=cache_pattern, overwrite_cache=True, worker_kwargs=dict(max_workers=max_workers)) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers), input_handler='DataHandlerNCforCC') forward_pass = ForwardPass(handler) @@ 
-84,14 +91,14 @@ def test_fwp_nc_cc(log=False): forward_pass.run(handler, node_index=0) with xr.open_dataset(handler.out_files[0]) as fh: - assert fh[FEATURES[0]].shape == ( - t_enhance * len(handler.time_index), - s_enhance * fwp_chunk_shape[0], - s_enhance * fwp_chunk_shape[1]) - assert fh[FEATURES[1]].shape == ( - t_enhance * len(handler.time_index), - s_enhance * fwp_chunk_shape[0], - s_enhance * fwp_chunk_shape[1]) + assert fh[FEATURES[0]].shape == (t_enhance + * len(handler.time_index), + s_enhance * fwp_chunk_shape[0], + s_enhance * fwp_chunk_shape[1]) + assert fh[FEATURES[1]].shape == (t_enhance + * len(handler.time_index), + s_enhance * fwp_chunk_shape[0], + s_enhance * fwp_chunk_shape[1]) def test_fwp_single_ts_vs_multi_ts_input_files(): @@ -117,16 +124,20 @@ def test_fwp_single_ts_vs_multi_ts_input_files(): max_workers = 1 input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) single_ts_handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) single_ts_forward_pass = ForwardPass(single_ts_handler) single_ts_forward_pass.run(single_ts_handler, node_index=0) @@ -138,16 +149,20 @@ def test_fwp_single_ts_vs_multi_ts_input_files(): max_workers = 1 input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) multi_ts_handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) multi_ts_forward_pass = ForwardPass(multi_ts_handler) multi_ts_forward_pass.run(multi_ts_handler, node_index=0) @@ -185,16 +200,20 @@ def test_fwp_spatial_only(): max_workers = 1 input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) forward_pass = ForwardPass(handler) assert forward_pass.output_workers == max_workers @@ -205,14 +224,12 @@ def test_fwp_spatial_only(): forward_pass.run(handler, node_index=0) with xr.open_dataset(handler.out_files[0]) as fh: - assert fh[FEATURES[0]].shape == ( - len(handler.time_index), - 2 * fwp_chunk_shape[0], - 2 * fwp_chunk_shape[1]) - assert fh[FEATURES[1]].shape == ( - 
len(handler.time_index), - 2 * fwp_chunk_shape[0], - 2 * fwp_chunk_shape[1]) + assert fh[FEATURES[0]].shape == (len(handler.time_index), + 2 * fwp_chunk_shape[0], + 2 * fwp_chunk_shape[1]) + assert fh[FEATURES[1]].shape == (len(handler.time_index), + 2 * fwp_chunk_shape[0], + 2 * fwp_chunk_shape[1]) def test_fwp_nc(): @@ -238,16 +255,20 @@ def test_fwp_nc(): max_workers = 1 input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) forward_pass = ForwardPass(handler) assert forward_pass.output_workers == max_workers @@ -258,14 +279,14 @@ def test_fwp_nc(): forward_pass.run(handler, node_index=0) with xr.open_dataset(handler.out_files[0]) as fh: - assert fh[FEATURES[0]].shape == ( - t_enhance * len(handler.time_index), - s_enhance * fwp_chunk_shape[0], - s_enhance * fwp_chunk_shape[1]) - assert fh[FEATURES[1]].shape == ( - t_enhance * len(handler.time_index), - s_enhance * fwp_chunk_shape[0], - s_enhance * fwp_chunk_shape[1]) + assert fh[FEATURES[0]].shape == (t_enhance + * len(handler.time_index), + s_enhance * fwp_chunk_shape[0], + s_enhance * fwp_chunk_shape[1]) + assert fh[FEATURES[1]].shape == (t_enhance + * len(handler.time_index), + s_enhance * fwp_chunk_shape[0], + s_enhance * fwp_chunk_shape[1]) def test_fwp_temporal_slice(): @@ -295,16 +316,20 @@ def test_fwp_temporal_slice(): raw_time_index = np.arange(20) n_tsteps = len(raw_time_index[temporal_slice]) input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) forward_pass = ForwardPass(handler) assert forward_pass.output_workers == max_workers @@ -315,11 +340,10 @@ def test_fwp_temporal_slice(): forward_pass.run(handler, node_index=0) with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - t_enhance * n_tsteps, - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) + assert fh.shape == (t_enhance * n_tsteps, s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) assert fh.global_attrs['package'] == 'sup3r' assert fh.global_attrs['version'] == __version__ @@ -355,15 +379,18 @@ def test_fwp_handler(): max_workers = 1 cache_pattern = os.path.join(td, 'cache') input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, 
worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, + spatial_pad=1, + temporal_pad=1, input_handler_kwargs=input_handler_kwargs, worker_kwargs=dict(max_workers=max_workers)) forward_pass = ForwardPass(handler) @@ -405,32 +432,39 @@ def test_fwp_chunking(log=False, plot=False): temporal_pad = 20 cache_pattern = os.path.join(td, 'cache') fwp_shape = (4, 4, len(input_files) // 2) - handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, - fwp_chunk_shape=fwp_shape, - worker_kwargs=dict(max_workers=1), - spatial_pad=spatial_pad, temporal_pad=temporal_pad, - input_handler_kwargs=dict(target=target, shape=shape, - temporal_slice=temporal_slice, - cache_pattern=cache_pattern, - overwrite_cache=True, - worker_kwargs=dict(max_workers=1))) - data_chunked = np.zeros((shape[0] * s_enhance, shape[1] * s_enhance, - len(input_files) * t_enhance, - len(model.output_features))) - handlerNC = DataHandlerNC(input_files, FEATURES, target=target, - val_split=0.0, shape=shape, + handler = ForwardPassStrategy(input_files, + model_kwargs={'model_dir': out_dir}, + fwp_chunk_shape=fwp_shape, + worker_kwargs=dict(max_workers=1), + spatial_pad=spatial_pad, + temporal_pad=temporal_pad, + input_handler_kwargs=dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + cache_pattern=cache_pattern, + overwrite_cache=True, + worker_kwargs=dict(max_workers=1))) + data_chunked = np.zeros( + (shape[0] * s_enhance, shape[1] * s_enhance, + len(input_files) * t_enhance, len(model.output_features))) + handlerNC = DataHandlerNC(input_files, + FEATURES, + target=target, + val_split=0.0, + shape=shape, worker_kwargs=dict(ti_workers=1)) pad_width = ((spatial_pad, spatial_pad), (spatial_pad, spatial_pad), (temporal_pad, temporal_pad), (0, 0)) hr_crop = (slice(s_enhance * spatial_pad, -s_enhance * spatial_pad), slice(s_enhance * spatial_pad, -s_enhance * spatial_pad), - slice(t_enhance * temporal_pad, -t_enhance * temporal_pad), - slice(None)) - input_data = np.pad(handlerNC.data, pad_width=pad_width, + slice(t_enhance * temporal_pad, + -t_enhance * temporal_pad), slice(None)) + input_data = np.pad(handlerNC.data, + pad_width=pad_width, mode='reflect') - data_nochunk = model.generate( - np.expand_dims(input_data, axis=0))[0][hr_crop] + data_nochunk = model.generate(np.expand_dims(input_data, + axis=0))[0][hr_crop] for i in range(handler.chunks): fwp = ForwardPass(handler, chunk_index=i) out = fwp.run_chunk() @@ -438,7 +472,7 @@ def test_fwp_chunking(log=False, plot=False): fwp.ti_slice.stop * t_enhance) data_chunked[fwp.hr_slice][..., t_hr_slice, :] = out - err = (data_chunked - data_nochunk) + err = data_chunked - data_nochunk err /= data_nochunk if plot: for ifeature in range(data_nochunk.shape[-1]): @@ -448,17 +482,23 @@ def test_fwp_chunking(log=False, plot=False): ax3 = fig.add_subplot(133) vmin = np.min(data_nochunk) vmax = np.max(data_nochunk) - nc = ax1.imshow(data_nochunk[..., 0, ifeature], vmin=vmin, + nc = ax1.imshow(data_nochunk[..., 0, ifeature], + vmin=vmin, vmax=vmax) - ch = ax2.imshow(data_chunked[..., 0, ifeature], vmin=vmin, + ch = ax2.imshow(data_chunked[..., 0, ifeature], + vmin=vmin, vmax=vmax) diff = ax3.imshow(err[..., 0, ifeature]) ax1.set_title('Non chunked output') ax2.set_title('Chunked output') ax3.set_title('Difference') - 
fig.colorbar(nc, ax=ax1, shrink=0.6, + fig.colorbar(nc, + ax=ax1, + shrink=0.6, label=f'{model.output_features[ifeature]}') - fig.colorbar(ch, ax=ax2, shrink=0.6, + fig.colorbar(ch, + ax=ax2, + shrink=0.6, label=f'{model.output_features[ifeature]}') fig.colorbar(diff, ax=ax3, shrink=0.6, label='Difference') plt.savefig(f'./chunk_vs_nochunk_{ifeature}.png') @@ -489,23 +529,27 @@ def test_fwp_nochunking(): model.save(out_dir) cache_pattern = os.path.join(td, 'cache') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - cache_pattern=cache_pattern, - overwrite_cache=True) + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + cache_pattern=cache_pattern, + overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=(shape[0], shape[1], list_chunk_size), - spatial_pad=0, temporal_pad=0, + spatial_pad=0, + temporal_pad=0, input_handler_kwargs=input_handler_kwargs, worker_kwargs=dict(max_workers=1)) forward_pass = ForwardPass(handler) data_chunked = forward_pass.run_chunk() - handlerNC = DataHandlerNC(input_files, FEATURES, - target=target, shape=shape, + handlerNC = DataHandlerNC(input_files, + FEATURES, + target=target, + shape=shape, temporal_slice=temporal_slice, cache_pattern=None, time_chunk_size=100, @@ -513,308 +557,12 @@ def test_fwp_nochunking(): val_split=0.0, worker_kwargs=dict(max_workers=1)) - data_nochunk = model.generate( - np.expand_dims(handlerNC.data, axis=0))[0] + data_nochunk = model.generate(np.expand_dims(handlerNC.data, + axis=0))[0] assert np.array_equal(data_chunked, data_nochunk) -def test_fwp_multi_step_model_topo_exoskip(log=False): - """Test the forward pass with a multi step model class using exogenous data - for the first two steps and not the last""" - - if log: - init_logger('sup3r', log_level='DEBUG') - - Sup3rGan.seed() - fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - _ = s1_model.generate(np.ones((4, 10, 10, 3))) - - s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - _ = s2_model.generate(np.ones((4, 10, 10, 3))) - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - st_model.save(st_out_dir) - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - 
max_workers = 1 - fwp_chunk_shape = (4, 4, 8) - s_enhance = 12 - t_enhance = 4 - - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 16], - 'exo_steps': [0, 1] - } - - model_kwargs = {'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir} - - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=max_workers), - overwrite_cache=True) - handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=fwp_chunk_shape, - input_handler_kwargs=input_handler_kwargs, - spatial_pad=0, temporal_pad=0, - out_pattern=out_files, - worker_kwargs=dict(max_workers=max_workers), - exo_kwargs=exo_kwargs, - max_nodes=1) - - forward_pass = ForwardPass(handler) - - assert forward_pass.output_workers == max_workers - assert forward_pass.pass_workers == max_workers - assert forward_pass.max_workers == max_workers - assert forward_pass.data_handler.max_workers == max_workers - assert forward_pass.data_handler.compute_workers == max_workers - assert forward_pass.data_handler.load_workers == max_workers - assert forward_pass.data_handler.norm_workers == max_workers - assert forward_pass.data_handler.extract_workers == max_workers - - forward_pass.run(handler, node_index=0) - - with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - t_enhance * len(input_files), - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) - - assert fh.global_attrs['package'] == 'sup3r' - assert fh.global_attrs['version'] == __version__ - assert 'full_version_record' in fh.global_attrs - version_record = json.loads(fh.global_attrs['full_version_record']) - assert version_record['tensorflow'] == tf.__version__ - assert 'gan_meta' in fh.global_attrs - gan_meta = json.loads(fh.global_attrs['gan_meta']) - assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m', - 'topography'] - - -def test_fwp_multi_step_spatial_model_topo_noskip(): - """Test the forward pass with a multi step spatial only model class using - exogenous data for all model steps""" - Sup3rGan.seed() - fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - _ = s1_model.generate(np.ones((4, 10, 10, 3))) - - s2_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - _ = s2_model.generate(np.ones((4, 10, 10, 3))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - max_workers = 1 - fwp_chunk_shape = (4, 4, 8) - s_enhancements = [2, 2, 1] - s_enhance = np.product(s_enhancements) - - 
exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [12, 4, 2] - } - - model_kwargs = {'model_dirs': [s1_out_dir, s2_out_dir]} - - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=max_workers), - overwrite_cache=True) - handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, - model_class='MultiStepGan', - fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=max_workers), - exo_kwargs=exo_kwargs, - max_nodes=1) - - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - len(input_files), - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) - - assert fh.global_attrs['package'] == 'sup3r' - assert fh.global_attrs['version'] == __version__ - assert 'full_version_record' in fh.global_attrs - version_record = json.loads(fh.global_attrs['full_version_record']) - assert version_record['tensorflow'] == tf.__version__ - assert 'gan_meta' in fh.global_attrs - gan_meta = json.loads(fh.global_attrs['gan_meta']) - assert len(gan_meta) == 2 # two step model - assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m', - 'topography'] - - -def test_fwp_multi_step_model_topo_noskip(): - """Test the forward pass with a multi step model class using exogenous data - for all model steps""" - Sup3rGan.seed() - fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - _ = s1_model.generate(np.ones((4, 10, 10, 3))) - - s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - _ = s2_model.generate(np.ones((4, 10, 10, 3))) - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - st_model.save(st_out_dir) - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - max_workers = 1 - fwp_chunk_shape = (4, 4, 8) - s_enhancements = [2, 2, 3] - s_enhance = np.product(s_enhancements) - t_enhance = 4 - - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 
's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12] - } - - model_kwargs = {'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir} - - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=max_workers), - overwrite_cache=True) - handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=max_workers), - exo_kwargs=exo_kwargs, - max_nodes=1) - - forward_pass = ForwardPass(handler) - - assert forward_pass.output_workers == max_workers - assert forward_pass.data_handler.compute_workers == max_workers - assert forward_pass.data_handler.load_workers == max_workers - assert forward_pass.data_handler.norm_workers == max_workers - assert forward_pass.data_handler.extract_workers == max_workers - - forward_pass.run(handler, node_index=0) - - with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - t_enhance * len(input_files), - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) - - assert fh.global_attrs['package'] == 'sup3r' - assert fh.global_attrs['version'] == __version__ - assert 'full_version_record' in fh.global_attrs - version_record = json.loads(fh.global_attrs['full_version_record']) - assert version_record['tensorflow'] == tf.__version__ - assert 'gan_meta' in fh.global_attrs - gan_meta = json.loads(fh.global_attrs['gan_meta']) - assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m', - 'topography'] - - def test_fwp_multi_step_model(): """Test the forward pass with a multi step model class""" Sup3rGan.seed() @@ -823,8 +571,8 @@ def test_fwp_multi_step_model(): s_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) s_model.meta['training_features'] = ['U_100m', 'V_100m'] s_model.meta['output_features'] = ['U_100m', 'V_100m'] - s_model.meta['s_enhance'] = 2 - s_model.meta['t_enhance'] = 1 + assert s_model.s_enhance == 2 + assert s_model.t_enhance == 1 _ = s_model.generate(np.ones((4, 10, 10, 2))) fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') @@ -832,8 +580,8 @@ def test_fwp_multi_step_model(): st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) st_model.meta['training_features'] = ['U_100m', 'V_100m'] st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 + assert st_model.s_enhance == 3 + assert st_model.t_enhance == 4 _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) with tempfile.TemporaryDirectory() as td: @@ -851,27 +599,31 @@ def test_fwp_multi_step_model(): s_enhance = 6 t_enhance = 4 - model_kwargs = {'spatial_model_dirs': s_out_dir, - 'temporal_model_dirs': st_out_dir} + model_kwargs = { + 'model_dirs': [s_out_dir, st_out_dir] + } input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=0, 
temporal_pad=0, + spatial_pad=0, + temporal_pad=0, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers), max_nodes=1) forward_pass = ForwardPass(handler) - ones = np.ones((fwp_chunk_shape[2], fwp_chunk_shape[0], - fwp_chunk_shape[1], 2)) + ones = np.ones( + (fwp_chunk_shape[2], fwp_chunk_shape[0], fwp_chunk_shape[1], 2)) out = forward_pass.model.generate(ones) assert out.shape == (1, 24, 24, 32, 2) @@ -884,11 +636,10 @@ def test_fwp_multi_step_model(): forward_pass.run(handler, node_index=0) with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - t_enhance * len(input_files), - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) assert fh.global_attrs['package'] == 'sup3r' assert fh.global_attrs['version'] == __version__ @@ -928,22 +679,26 @@ def test_slicing_no_pad(log=False): st_out_dir = os.path.join(td, 'st_gan') st_model.save(st_out_dir) - handler = DataHandlerNC(input_files, features, - target=target, shape=shape, + handler = DataHandlerNC(input_files, + features, + target=target, + shape=shape, sample_shape=(1, 1, 1), val_split=0.0, worker_kwargs=dict(max_workers=1)) - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) strategy = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': st_out_dir}, + input_files, + model_kwargs={'model_dir': st_out_dir}, model_class='Sup3rGan', fwp_chunk_shape=(3, 2, 4), - spatial_pad=0, temporal_pad=0, + spatial_pad=0, + temporal_pad=0, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=1), @@ -953,8 +708,7 @@ def test_slicing_no_pad(log=False): forward_pass = ForwardPass(strategy, chunk_index=ichunk) s_slices = strategy.lr_pad_slices[forward_pass.spatial_chunk_index] lr_data_slice = (s_slices[0], s_slices[1], - forward_pass.ti_pad_slice, - slice(None)) + forward_pass.ti_pad_slice, slice(None)) truth = handler.data[lr_data_slice] assert np.allclose(forward_pass.input_data, truth) @@ -987,23 +741,27 @@ def test_slicing_pad(log=False): st_out_dir = os.path.join(td, 'st_gan') st_model.save(st_out_dir) - handler = DataHandlerNC(input_files, features, - target=target, shape=shape, + handler = DataHandlerNC(input_files, + features, + target=target, + shape=shape, sample_shape=(1, 1, 1), val_split=0.0, worker_kwargs=dict(max_workers=1)) - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) strategy = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': st_out_dir}, + input_files, + model_kwargs={'model_dir': st_out_dir}, model_class='Sup3rGan', fwp_chunk_shape=(2, 1, 4), input_handler_kwargs=input_handler_kwargs, - spatial_pad=2, temporal_pad=2, + spatial_pad=2, + temporal_pad=2, out_pattern=out_files, worker_kwargs=dict(max_workers=1), max_nodes=1) @@ 
-1027,8 +785,7 @@ def test_slicing_pad(log=False): s_slices = strategy.lr_pad_slices[forward_pass.spatial_chunk_index] lr_data_slice = (s_slices[0], s_slices[1], - forward_pass.ti_pad_slice, - slice(None)) + forward_pass.ti_pad_slice, slice(None)) # do a manual calculation of what the padding should be. # s1 and t axes should have padding of 2 and the borders and @@ -1053,347 +810,11 @@ def test_slicing_pad(log=False): pad_t_end = end_t_pad_lookup.get(idt, 0) pad_width = ((pad_s1_start, pad_s1_end), - (pad_s2_start, pad_s2_end), - (pad_t_start, pad_t_end), (0, 0)) + (pad_s2_start, pad_s2_end), (pad_t_start, + pad_t_end), (0, 0)) truth = handler.data[lr_data_slice] padded_truth = np.pad(truth, pad_width, mode='reflect') assert forward_pass.input_data.shape == padded_truth.shape assert np.allclose(forward_pass.input_data, padded_truth) - - -def test_fwp_single_step_wind_hi_res_topo(plot=False): - """Test the forward pass with a single spatiotemporal WindGan model - requiring high-resolution topography input from the exogenous_data - feature.""" - Sup3rGan.seed() - gen_model = [{"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv3D", "filters": 64, "kernel_size": 3, - "strides": 1}, - {"class": "Cropping3D", "cropping": 2}, - {"class": "SpatioTemporalExpansion", "temporal_mult": 2, - "temporal_method": "nearest"}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv3D", "filters": 64, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping3D", "cropping": 2}, - {"class": "SpatioTemporalExpansion", "spatial_mult": 2}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv3D", "filters": 64, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping3D", "cropping": 2}, - {"alpha": 0.2, "class": "LeakyReLU"}, - - {"class": "Sup3rConcat"}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv3D", "filters": 2, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping3D", "cropping": 2}] - - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - model = WindGan(gen_model, fp_disc, learning_rate=1e-4) - model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - model.meta['output_features'] = ['U_100m', 'V_100m'] - model.meta['s_enhance'] = 2 - model.meta['t_enhance'] = 2 - _ = model.generate(np.random.rand(4, 10, 10, 6, 3), - exogenous_data=(None, np.random.rand(4, 20, 20, 6, 1))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - model.save(st_out_dir) - - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2], - 'agg_factors': [2, 4], - } - - model_kwargs = {'model_dir': st_out_dir} - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) - - # should get an error on a bad tensorflow concatenation - with pytest.raises(RuntimeError): - exo_kwargs['s_enhancements'] = [1, 1] - handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, - model_class='WindGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, temporal_pad=1, - 
input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - exo_kwargs['s_enhancements'] = [1, 2] - handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, - model_class='WindGan', - fwp_chunk_shape=(8, 8, 8), - spatial_pad=4, temporal_pad=4, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - - if plot: - for ifeature, feature in enumerate(forward_pass.output_features): - fig = plt.figure(figsize=(15, 5)) - ax1 = fig.add_subplot(111) - vmin = np.min(forward_pass.input_data[..., ifeature]) - vmax = np.max(forward_pass.input_data[..., ifeature]) - nc = ax1.imshow(forward_pass.input_data[..., 0, ifeature], - vmin=vmin, vmax=vmax) - fig.colorbar(nc, ax=ax1, shrink=0.6, label=f'{feature}') - plt.savefig(f'./input_{feature}.png') - plt.close() - - forward_pass.run(handler, node_index=0) - - for fp in handler.out_files: - assert os.path.exists(fp) - - -def test_fwp_multi_step_wind_hi_res_topo(): - """Test the forward pass with multiple WindGan models requiring - high-resolution topograph input from the exogenous_data feature.""" - Sup3rGan.seed() - gen_model = [{"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, "kernel_size": 3, - "strides": 1, "activation": "relu"}, - {"class": "Cropping2D", "cropping": 4}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, - "kernel_size": 3, "strides": 1, "activation": "relu"}, - {"class": "Cropping2D", "cropping": 4}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, - "kernel_size": 3, "strides": 1, "activation": "relu"}, - {"class": "Cropping2D", "cropping": 4}, - {"class": "SpatialExpansion", "spatial_mult": 2}, - {"class": "Activation", "activation": "relu"}, - - {"class": "Sup3rConcat"}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 2, - "kernel_size": 3, "strides": 1, "activation": "relu"}, - {"class": "Cropping2D", "cropping": 4}] - - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - _ = s1_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) - - s2_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - _ = s2_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 
'topography'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - st_model.save(st_out_dir) - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12], - } - - model_kwargs = {'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir} - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) - - # should get an error on a bad tensorflow concatenation - with pytest.raises(RuntimeError): - exo_kwargs['s_enhancements'] = [1, 1, 1] - handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - exo_kwargs['s_enhancements'] = [1, 2, 2] - handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - for fp in handler.out_files: - assert os.path.exists(fp) - - -def test_fwp_wind_hi_res_topo_plus_linear(): - """Test the forward pass with a WindGan model requiring high-res topo input - from exo data for spatial enhancement and a linear interpolation model for - temporal enhancement.""" - - Sup3rGan.seed() - gen_model = [{"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, "kernel_size": 3, - "strides": 1}, - {"class": "Cropping2D", "cropping": 4}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping2D", "cropping": 4}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping2D", "cropping": 4}, - {"class": "SpatialExpansion", "spatial_mult": 2}, - {"alpha": 0.2, "class": "LeakyReLU"}, - - {"class": "Sup3rConcat"}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 2, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping2D", "cropping": 4}] - - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) - s_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - 
s_model.meta['output_features'] = ['U_100m', 'V_100m'] - s_model.meta['s_enhance'] = 2 - s_model.meta['t_enhance'] = 1 - _ = s_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) - - t_model = LinearInterp(features=['U_100m', 'V_100m'], s_enhance=1, - t_enhance=4) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - s_out_dir = os.path.join(td, 's_gan') - t_out_dir = os.path.join(td, 't_interp') - s_model.save(s_out_dir) - t_model.save(t_out_dir) - - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2], - 'agg_factors': [2, 4], - } - - model_kwargs = {'spatial_model_dirs': s_out_dir, - 'temporal_model_dirs': t_out_dir} - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) - - exo_kwargs['s_enhancements'] = [1, 2] - handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - for fp in handler.out_files: - assert os.path.exists(fp) diff --git a/tests/forward_pass/test_forward_pass_exo.py b/tests/forward_pass/test_forward_pass_exo.py new file mode 100644 index 0000000000..1bf24b2869 --- /dev/null +++ b/tests/forward_pass/test_forward_pass_exo.py @@ -0,0 +1,1193 @@ +# -*- coding: utf-8 -*- +"""pytests for data handling""" +import json +import os +import tempfile + +import matplotlib.pyplot as plt +import numpy as np +import pytest +import tensorflow as tf +from rex import ResourceX, init_logger + +from sup3r import CONFIG_DIR, TEST_DATA_DIR, __version__ +from sup3r.models import LinearInterp, Sup3rGan +from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy +from sup3r.utilities.pytest import make_fake_nc_files + +FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') +TARGET_COORD = (39.01, -105.15) +FEATURES = ['U_100m', 'V_100m', 'BVF2_200m'] +INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00') +target = (19.3, -123.5) +shape = (8, 8) +sample_shape = (8, 8, 6) +temporal_slice = slice(None, None, 1) +list_chunk_size = 10 +fwp_chunk_shape = (4, 4, 150) +s_enhance = 3 +t_enhance = 4 + + +def test_fwp_multi_step_model_topo_exoskip(log=False): + """Test the forward pass with a multi step model class using exogenous data + for the first two steps and not the last""" + + if log: + init_logger('sup3r', log_level='DEBUG') + + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + 
s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['training_features'] = ['U_100m', 'V_100m'] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhance = 12 + t_enhance = 4 + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [ + {'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'} + ] + } + } + + model_kwargs = { + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] + } + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=fwp_chunk_shape, + input_handler_kwargs=input_handler_kwargs, + spatial_pad=0, + temporal_pad=0, + out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + + assert forward_pass.output_workers == max_workers + assert forward_pass.pass_workers == max_workers + assert forward_pass.max_workers == max_workers + assert forward_pass.data_handler.max_workers == max_workers + assert forward_pass.data_handler.compute_workers == max_workers + assert forward_pass.data_handler.load_workers == max_workers + assert forward_pass.data_handler.norm_workers == max_workers + assert forward_pass.data_handler.extract_workers == max_workers + + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 3 # three step model + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] + + +def test_fwp_multi_step_spatial_model_topo_noskip(): + """Test the forward pass with a multi step spatial only model class using + exogenous 
data for all model steps""" + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '16km', + 'temporal': '60min'} + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '8km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhancements = [2, 2, 1] + s_enhance = np.product(s_enhancements) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [ + {'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'}, + ] + } + } + + model_kwargs = {'model_dirs': [s1_out_dir, s2_out_dir]} + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == (len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 2 # two step model + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] + + +def test_fwp_multi_step_model_topo_noskip(): + """Test the forward pass with a multi step model class using exogenous data + for all model steps""" + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 
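The spatial-only test above pins down the output-shape bookkeeping worth spelling out: rows in the output H5 are time steps (spatial-only models leave the time axis alone, so the test asserts rows == len(input_files)), and columns are high-res pixels of one forward-pass chunk. A quick check with the test's own numbers; note that np.prod is the maintained spelling of the np.product alias used in the test:

import numpy as np

s_enhancements = [2, 2, 1]              # per-model spatial factors from the test
s_enhance = int(np.prod(s_enhancements))
n_files = 8                             # make_fake_nc_files(td, INPUT_FILE, 8)
fwp_chunk_shape = (4, 4, 8)
expected = (n_files, s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1])
assert expected == (8, 256)             # matches the fh.shape assertion
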
+ s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhancements = [2, 2, 3] + s_enhance = np.product(s_enhancements) + t_enhance = 4 + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'}, + {'model': 2, 'combine_type': 'input'}] + } + } + + model_kwargs = { + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] + } + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + + assert forward_pass.output_workers == max_workers + assert forward_pass.data_handler.compute_workers == max_workers + assert forward_pass.data_handler.load_workers == max_workers + assert forward_pass.data_handler.norm_workers == max_workers + assert forward_pass.data_handler.extract_workers == max_workers + + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 3 # three step 
model + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] + + +def test_fwp_single_step_wind_hi_res_topo(plot=False): + """Test the forward pass with a single spatiotemporal Sup3rGan model + requiring high-resolution topography input from the exogenous_data + feature.""" + Sup3rGan.seed() + gen_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }, { + "class": "SpatioTemporalExpansion", + "temporal_mult": 2, + "temporal_method": "nearest" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }, { + "class": "SpatioTemporalExpansion", + "spatial_mult": 2 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }, { + "alpha": 0.2, + "class": "LeakyReLU" + }, { + "class": "Sup3rConcat", + "name": "topography" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 2, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }] + + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + model.meta['output_features'] = ['U_100m', 'V_100m'] + model.meta['s_enhance'] = 2 + model.meta['t_enhance'] = 2 + model.meta['input_resolution'] = {'spatial': '8km', + 'temporal': '60min'} + exo_tmp = { + 'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'layer', + 'data': np.random.rand(4, 20, 20, 12, 1)}]}} + _ = model.generate(np.random.rand(4, 10, 10, 6, 3), exogenous_data=exo_tmp) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + model.save(st_out_dir) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [ + {'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'} + ]}} + + model_kwargs = {'model_dir': st_out_dir} + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) + + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='Sup3rGan', + fwp_chunk_shape=(8, 8, 8), + spatial_pad=4, + temporal_pad=4, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + + if plot: + for ifeature, feature in enumerate(forward_pass.output_features): + fig = plt.figure(figsize=(15, 5)) + ax1 = fig.add_subplot(111) + vmin = np.min(forward_pass.input_data[..., ifeature]) + vmax = np.max(forward_pass.input_data[..., ifeature]) + nc = ax1.imshow(forward_pass.input_data[..., 
0, ifeature], + vmin=vmin, + vmax=vmax) + fig.colorbar(nc, ax=ax1, shrink=0.6, label=f'{feature}') + plt.savefig(f'./input_{feature}.png') + plt.close() + + forward_pass.run(handler, node_index=0) + + for fp in handler.out_files: + assert os.path.exists(fp) + + +def test_fwp_multi_step_wind_hi_res_topo(): + """Test the forward pass with multiple Sup3rGan models requiring + high-resolution topograph input from the exogenous_data feature.""" + Sup3rGan.seed() + gen_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "SpatialExpansion", + "spatial_mult": 2 + }, { + "class": "Activation", + "activation": "relu" + }, { + "class": "Sup3rConcat", + "name": "topography" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 2, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }] + + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} + + exo_tmp = { + 'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'layer', + 'data': np.random.rand(4, 20, 20, 1)}]}} + _ = s1_model.generate(np.ones((4, 10, 10, 3)), exogenous_data=exo_tmp) + + s2_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3)), exogenous_data=exo_tmp) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + 
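The inline generator configs in these tests bind exogenous features to layers by name: {"class": "Sup3rConcat", "name": "topography"} marks the injection point, and a step with combine_type 'layer' supplies that feature mid-network. Judging from the shapes used above, 'layer' data must already sit on the post-expansion grid, which is why a 2x spatial model fed (4, 10, 10, 3) input gets (4, 20, 20, 1) topography. A small sanity check under that assumption:

import numpy as np

batch, lr_size, s_mult = 4, 10, 2       # values from the surrounding tests
# Sup3rConcat sits after SpatialExpansion in these configs, so the
# concatenated feature must match the enhanced grid plus a channel axis
hi_res_topo = np.ones((batch, lr_size * s_mult, lr_size * s_mult, 1))
assert hi_res_topo.shape == (4, 20, 20, 1)
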
s2_model.save(s2_out_dir) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + } + } + + model_kwargs = { + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] + } + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) + + with pytest.raises(RuntimeError): + # should raise error since steps doesn't include + # {'model': 2, 'combine_type': 'input'} + steps = [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}, + {'model': 1, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'layer'}] + exo_kwargs['topography']['steps'] = steps + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=(4, 4, 8), + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + steps = [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}, + {'model': 1, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'layer'}, + {'model': 2, 'combine_type': 'input'}] + exo_kwargs['topography']['steps'] = steps + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=(4, 4, 8), + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + for fp in handler.out_files: + assert os.path.exists(fp) + + +def test_fwp_wind_hi_res_topo_plus_linear(): + """Test the forward pass with a Sup3rGan model requiring high-res topo + input from exo data for spatial enhancement and a linear interpolation + model for temporal enhancement.""" + + Sup3rGan.seed() + gen_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "SpatialExpansion", + "spatial_mult": 2 + }, { + "alpha": 0.2, + "class": "LeakyReLU" + }, { + "class": "Sup3rConcat", + "name": "topography" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 2, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }] + + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + s_model.meta['training_features'] = 
['U_100m', 'V_100m', 'topography'] + s_model.meta['output_features'] = ['U_100m', 'V_100m'] + s_model.meta['s_enhance'] = 2 + s_model.meta['t_enhance'] = 1 + s_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + exo_tmp = { + 'topography': { + 'steps': [ + {'combine_type': 'layer', 'data': np.ones((4, 20, 20, 1))}]}} + _ = s_model.generate(np.ones((4, 10, 10, 3)), + exogenous_data=exo_tmp) + + t_model = LinearInterp(features=['U_100m', 'V_100m'], + s_enhance=1, + t_enhance=4) + t_model.meta['input_resolution'] = {'spatial': '4km', + 'temporal': '60min'} + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + s_out_dir = os.path.join(td, 's_gan') + t_out_dir = os.path.join(td, 't_interp') + s_model.save(s_out_dir) + t_model.save(t_out_dir) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}] + } + } + + model_kwargs = { + 'model_dirs': [s_out_dir, t_out_dir] + } + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) + + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=(4, 4, 8), + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + for fp in handler.out_files: + assert os.path.exists(fp) + + +def test_fwp_multi_step_model_multi_exo(): + """Test the forward pass with a multi step model class using 2 exogenous + data features""" + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography' + ] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography' + ] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + st_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'sza' + ] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 
8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhancements = [2, 2, 3] + s_enhance = np.product(s_enhancements) + t_enhance = 4 + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'}] + }, + 'sza': { + 'file_paths': input_files, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_handler': 'SzaExtract', + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [{'model': 2, 'combine_type': 'input'}] + } + } + + model_kwargs = { + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] + } + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + + assert forward_pass.output_workers == max_workers + assert forward_pass.data_handler.compute_workers == max_workers + assert forward_pass.data_handler.load_workers == max_workers + assert forward_pass.data_handler.norm_workers == max_workers + assert forward_pass.data_handler.extract_workers == max_workers + + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 3 # three step model + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] + + +def test_fwp_multi_step_exo_hi_res_topo_and_sza(): + """Test the forward pass with multiple ExoGan models requiring + high-resolution topography and sza input from the exogenous_data + feature.""" + Sup3rGan.seed() + gen_s_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + 
"class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "SpatialExpansion", + "spatial_mult": 2 + }, { + "class": "Activation", + "activation": "relu" + }, { + "class": "Sup3rConcat", + "name": "topography" + }, { + "class": "Sup3rConcat", + "name": "sza" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 2, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }] + + gen_t_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", "filters": 1, "kernel_size": 3, "strides": 1 + }, { + "class": "Cropping3D", "cropping": 2 + }, { + "alpha": 0.2, "class": "LeakyReLU" + }, { + "class": "SpatioTemporalExpansion", "temporal_mult": 2, + "temporal_method": "nearest" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", "filters": 1, "kernel_size": 3, "strides": 1 + }, { + "class": "Cropping3D", "cropping": 2 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", "filters": 36, "kernel_size": 3, "strides": 1 + }, { + "class": "Cropping3D", "cropping": 2 + }, { + "class": "SpatioTemporalExpansion", "spatial_mult": 3 + }, { + "alpha": 0.2, "class": "LeakyReLU" + }, { + "class": "Sup3rConcat", "name": "sza" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", "filters": 2, "kernel_size": 3, "strides": 1 + }, { + "class": "Cropping3D", "cropping": 2 + }] + + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(gen_s_model, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} + exo_tmp = { + 'topography': { + 'steps': [{'model': 0, 'combine_type': 'layer', + 'data': np.ones((4, 20, 20, 1))}]}, + 'sza': { + 'steps': [{'model': 0, 'combine_type': 'layer', + 'data': np.ones((4, 20, 20, 1))}]} + } + _ = s1_model.generate(np.ones((4, 10, 10, 4)), exogenous_data=exo_tmp) + + s2_model = Sup3rGan(gen_s_model, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 4)), exogenous_data=exo_tmp) + + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(gen_t_model, fp_disc, learning_rate=1e-4) + st_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'sza' + ] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 2 + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + exo_tmp = { + 'sza': { + 'steps': [{'model': 0, 'combine_type': 'layer', + 'data': np.ones((4, 30, 30, 12, 1))}]} + } + 
_ = st_model.generate(np.ones((4, 10, 10, 6, 3)), exogenous_data=exo_tmp) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}, + {'model': 1, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'layer'}] + }, + 'sza': { + 'file_paths': input_files, + 'exo_handler': 'SzaExtract', + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}, + {'model': 1, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'layer'}, + {'model': 2, 'combine_type': 'input'}, + {'model': 2, 'combine_type': 'layer'}] + } + } + + model_kwargs = { + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] + } + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) + + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=(4, 4, 8), + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + for fp in handler.out_files: + assert os.path.exists(fp) diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index 6c9901e719..3da773b44c 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -1,14 +1,18 @@ # -*- coding: utf-8 -*- """Test forward passes through multi-step GAN models""" import os +import tempfile + import numpy as np import pytest -import tempfile from sup3r import CONFIG_DIR -from sup3r.models import (Sup3rGan, MultiStepGan, - SpatialThenTemporalGan, TemporalThenSpatialGan, - SolarMultiStepGan, LinearInterp) +from sup3r.models import ( + LinearInterp, + MultiStepGan, + SolarMultiStepGan, + Sup3rGan, +) FEATURES = ['U_100m', 'V_100m'] @@ -65,6 +69,9 @@ def test_multi_step_norm(norm_option): model2.set_norm_stats([0.1, 0.8], [0.04, 0.02]) model3.set_norm_stats([0.1, 0.8], [0.04, 0.02]) + model1.meta['input_resolution'] = {'spatial': '27km', 'temporal': '64min'} + model2.meta['input_resolution'] = {'spatial': '9km', 'temporal': '16min'} + model3.meta['input_resolution'] = {'spatial': '3km', 'temporal': '4min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) model2.set_model_params(training_features=FEATURES, @@ -113,6 +120,8 @@ def test_spatial_then_temporal_gan(): model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) model2.set_norm_stats([0.3, 0.9], [0.02, 0.07]) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} + model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) 
model2.set_model_params(training_features=FEATURES, @@ -124,7 +133,7 @@ def test_spatial_then_temporal_gan(): model1.save(fp1) model2.save(fp2) - ms_model = SpatialThenTemporalGan.load(fp1, fp2) + ms_model = MultiStepGan.load([fp1, fp2]) x = np.ones((4, 10, 10, len(FEATURES))) out = ms_model.generate(x) @@ -145,6 +154,8 @@ def test_temporal_then_spatial_gan(): model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) model2.set_norm_stats([0.3, 0.9], [0.02, 0.07]) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} + model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) model2.set_model_params(training_features=FEATURES, @@ -156,11 +167,11 @@ def test_temporal_then_spatial_gan(): model1.save(fp1) model2.save(fp2) - ms_model = TemporalThenSpatialGan.load(fp1, fp2) + ms_model = MultiStepGan.load([fp2, fp1]) x = np.ones((1, 10, 10, 4, len(FEATURES))) out = ms_model.generate(x) - assert out.shape == (1, 60, 60, 16, 2) + assert out.shape == (16, 60, 60, 2) def test_spatial_gan_then_linear_interp(): @@ -173,6 +184,7 @@ def test_spatial_gan_then_linear_interp(): model2 = LinearInterp(features=FEATURES, s_enhance=3, t_enhance=4) model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '60min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) @@ -182,7 +194,7 @@ def test_spatial_gan_then_linear_interp(): model1.save(fp1) model2.save(fp2) - ms_model = SpatialThenTemporalGan.load(fp1, fp2) + ms_model = MultiStepGan.load([fp1, fp2]) x = np.ones((4, 10, 10, len(FEATURES))) out = ms_model.generate(x) @@ -199,6 +211,7 @@ def test_solar_multistep(): model1 = Sup3rGan(fp_gen, fp_disc) _ = model1.generate(np.ones((4, 10, 10, len(features1)))) model1.set_norm_stats([0.7], [0.04]) + model1.meta['input_resolution'] = {'spatial': '8km', 'temporal': '40min'} model1.set_model_params(training_features=features1, output_features=features1) @@ -208,6 +221,7 @@ def test_solar_multistep(): model2 = Sup3rGan(fp_gen, fp_disc) _ = model2.generate(np.ones((4, 10, 10, len(features2)))) model2.set_norm_stats([4.2, 5.6], [1.1, 1.3]) + model2.meta['input_resolution'] = {'spatial': '4km', 'temporal': '40min'} model2.set_model_params(training_features=features2, output_features=features2) @@ -218,6 +232,7 @@ def test_solar_multistep(): model3 = Sup3rGan(fp_gen, fp_disc) _ = model3.generate(np.ones((4, 10, 10, 3, len(features_in_3)))) model3.set_norm_stats([0.7, 4.2, 5.6], [0.04, 1.1, 1.3]) + model3.meta['input_resolution'] = {'spatial': '2km', 'temporal': '40min'} model3.set_model_params(training_features=features_in_3, output_features=features_out_3) diff --git a/tests/forward_pass/test_solar_module.py b/tests/forward_pass/test_solar_module.py index 29865bf789..428509b165 100644 --- a/tests/forward_pass/test_solar_module.py +++ b/tests/forward_pass/test_solar_module.py @@ -1,23 +1,23 @@ # -*- coding: utf-8 -*- """Test the custom sup3r solar module that converts GAN clearsky ratio outputs to irradiance data.""" -import pytest -from click.testing import CliRunner import glob import json import os -import numpy as np import tempfile +from pathlib import Path + import matplotlib.pyplot as plt +import numpy as np +import pytest +from click.testing import CliRunner from rex import Resource -from pathlib import Path from sup3r import TEST_DATA_DIR from sup3r.solar import Solar -from sup3r.utilities.utilities import pd_date_range -from 
sup3r.utilities.pytest import make_fake_cs_ratio_files from sup3r.solar.solar_cli import from_config as solar_main - +from sup3r.utilities.pytest import make_fake_cs_ratio_files +from sup3r.utilities.utilities import pd_date_range NSRDB_FP = os.path.join(TEST_DATA_DIR, 'test_nsrdb_clearsky_2018.h5') GAN_META = {'s_enhance': 4, 't_enhance': 24} @@ -119,7 +119,7 @@ def test_chunk_file_parser(): with tempfile.TemporaryDirectory() as td: for idt in id_temporal: for ids in id_spatial: - fn = ('sup3r_chunk_out_{}_{}.h5'.format(idt, ids)) + fn = 'sup3r_chunk_out_{}_{}.h5'.format(idt, ids) fp = os.path.join(td, fn) Path(fp).touch() all_st_ids.append('{}_{}'.format(idt, ids)) @@ -169,13 +169,14 @@ def test_solar_cli(runner): log_file = os.path.join(td, 'logs/sup3r_solar.log') if os.path.exists(log_file): - with open(log_file, 'r') as f: + with open(log_file) as f: logs = ''.join(list(f.readlines())) msg += '\nlogs:\n{}'.format(logs) raise RuntimeError(msg) - status_files = glob.glob(os.path.join(td, 'jobstatus_*.json')) + status_files = glob.glob(os.path.join(f'{td}/.gaps/', + '*jobstatus*.json')) assert len(status_files) == len(fps) out_files = glob.glob(os.path.join(td, 'chunks/*_irradiance.h5')) diff --git a/tests/forward_pass/test_surface_model.py b/tests/forward_pass/test_surface_model.py index 0373394e2c..d4358e51e4 100644 --- a/tests/forward_pass/test_surface_model.py +++ b/tests/forward_pass/test_surface_model.py @@ -4,15 +4,15 @@ import json import os import tempfile -import pytest -import numpy as np +import numpy as np +import pytest from rex import Resource -from sup3r.models import Sup3rGan from sup3r import CONFIG_DIR, TEST_DATA_DIR -from sup3r.models.surface import SurfaceSpatialMetModel +from sup3r.models import Sup3rGan from sup3r.models.multi_step import MultiStepSurfaceMetGan +from sup3r.models.surface import SurfaceSpatialMetModel from sup3r.utilities.utilities import spatial_coarsening INPUT_FILE_W = os.path.join(TEST_DATA_DIR, 'test_wtk_surface_vars.h5') @@ -61,8 +61,8 @@ def test_surface_model(s_enhance=5): json.dump(kwargs, f) model = SurfaceSpatialMetModel.load(model_dir=td) - - hi_res = model.generate(low_res, exogenous_data=[topo_lr, topo_hr]) + exo_tmp = {'topography': {'steps': [{'data': topo_lr}, {'data': topo_hr}]}} + hi_res = model.generate(low_res, exogenous_data=exo_tmp) diff = true_hi_res - hi_res @@ -86,7 +86,9 @@ def test_train_rh_model(s_enhance=10): true_hr_rh = np.transpose(true_hi_res[..., 1], axes=(1, 2, 0)) model = SurfaceSpatialMetModel(FEATURES, s_enhance=s_enhance) - w_delta_temp, w_delta_topo = model.train(true_hr_temp, true_hr_rh, topo_hr) + w_delta_temp, w_delta_topo = model.train( + true_hr_temp, true_hr_rh, topo_hr, + input_resolution={'spatial': '3km', 'temporal': '60min'}) # pretty generous tolerances because the training dataset is so small assert np.allclose(w_delta_temp, SurfaceSpatialMetModel.W_DELTA_TEMP, @@ -120,6 +122,8 @@ def test_multi_step_surface(s_enhance=2, t_enhance=2): model.set_norm_stats([0.3, 0.9, 0.1], [0.02, 0.07, 0.03]) model.set_model_params(training_features=FEATURES, output_features=FEATURES, + input_resolution={'spatial': '30km', + 'temporal': '60min'}, s_enhance=1, t_enhance=t_enhance) @@ -155,7 +159,13 @@ def test_multi_step_surface(s_enhance=2, t_enhance=2): topo_lr = topo_lr[:4, :4] topo_hr = topo_hr[:8, :8] - hi_res = ms_model.generate(low_res, exogenous_data=[topo_lr, topo_hr]) + exo_tmp = { + 'topography': { + 'steps': [{'model': 0, 'combine_type': 'input', + 'data': topo_lr}, + {'model': 0, 'combine_type': 
'output', + 'data': topo_hr}]}} + hi_res = ms_model.generate(low_res, exogenous_data=exo_tmp) target_shape = (1, low_res.shape[1] * s_enhance, diff --git a/tests/output/test_qa.py b/tests/output/test_qa.py index d7a4747d39..06ebb0a049 100644 --- a/tests/output/test_qa.py +++ b/tests/output/test_qa.py @@ -1,21 +1,21 @@ # -*- coding: utf-8 -*- """pytests for data handling""" import os +import pickle import tempfile -import pandas as pd + import numpy as np -from rex import Resource, init_logger +import pandas as pd import xarray as xr -import pickle +from rex import Resource, init_logger -from sup3r import TEST_DATA_DIR, CONFIG_DIR -from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy +from sup3r import CONFIG_DIR, TEST_DATA_DIR from sup3r.models import Sup3rGan -from sup3r.utilities.pytest import make_fake_nc_files +from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.qa.qa import Sup3rQa from sup3r.qa.stats import Sup3rStatsMulti from sup3r.qa.utilities import continuous_dist - +from sup3r.utilities.pytest import make_fake_nc_files FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) @@ -201,7 +201,7 @@ def test_qa_h5(): assert np.allclose(test_diff, qa_diff, atol=0.01) -def test_stats(log=True): +def test_stats(log=False): """Test the WindStats module with forward pass output to h5 file.""" if log: diff --git a/tests/pipeline/test_cli.py b/tests/pipeline/test_cli.py index 396e8ccdbd..547ae1c747 100644 --- a/tests/pipeline/test_cli.py +++ b/tests/pipeline/test_cli.py @@ -1,25 +1,24 @@ # -*- coding: utf-8 -*- """pytests for sup3r cli""" +import glob import json import os import tempfile -import pytest -import glob -import numpy as np -from rex import ResourceX -from rex import init_logger +import numpy as np +import pytest from click.testing import CliRunner +from rex import ResourceX, init_logger -from sup3r.pipeline.pipeline_cli import from_config as pipe_main +from sup3r import CONFIG_DIR, TEST_DATA_DIR +from sup3r.models.base import Sup3rGan from sup3r.pipeline.forward_pass_cli import from_config as fwp_main -from sup3r.preprocessing.data_extract_cli import from_config as dh_main +from sup3r.pipeline.pipeline_cli import from_config as pipe_main from sup3r.postprocessing.data_collect_cli import from_config as dc_main +from sup3r.preprocessing.data_extract_cli import from_config as dh_main from sup3r.qa.visual_qa_cli import from_config as vqa_main -from sup3r.models.base import Sup3rGan -from sup3r.utilities.pytest import make_fake_nc_files, make_fake_h5_chunks +from sup3r.utilities.pytest import make_fake_h5_chunks, make_fake_nc_files from sup3r.utilities.utilities import correct_path -from sup3r import TEST_DATA_DIR, CONFIG_DIR INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00') FEATURES = ['U_100m', 'V_100m', 'BVF2_200m'] @@ -34,7 +33,7 @@ def runner(): return CliRunner() -def test_pipeline_fwp_collect(runner, log=True): +def test_pipeline_fwp_collect(runner, log=False): """Test pipeline with forward pass and data collection""" if log: init_logger('sup3r', log_level='DEBUG') @@ -97,8 +96,8 @@ def test_pipeline_fwp_collect(runner, log=True): with open(pipe_config_path, 'w') as fh: json.dump(pipe_config, fh) - result = runner.invoke(pipe_main, ['-c', pipe_config_path, - '-v', '--monitor']) + result = runner.invoke(pipe_main, ['-c', pipe_config_path, '-v', + '--monitor']) if result.exit_code != 0: import traceback msg = ('Failed with error {}' @@ -197,7 +196,7 @@ def 
test_data_collection_cli(runner): assert np.allclose(wd_true, fh['winddirection_100m'], atol=0.1) -def test_fwd_pass_cli(runner, log=True): +def test_fwd_pass_cli(runner, log=False): """Test cli call to run forward pass""" if log: init_logger('sup3r', log_level='DEBUG') @@ -210,8 +209,8 @@ def test_fwd_pass_cli(runner, log=True): _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) model.meta['training_features'] = FEATURES model.meta['output_features'] = FEATURES[:2] - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + assert model.s_enhance == 3 + assert model.t_enhance == 4 with tempfile.TemporaryDirectory() as td: input_files = make_fake_nc_files(td, INPUT_FILE, 8) @@ -289,20 +288,27 @@ def test_data_extract_cli(runner): assert len(glob.glob(f'{log_file}')) == 1 -def test_pipeline_fwp_qa(runner): +def test_pipeline_fwp_qa(runner, log=True): """Test the sup3r pipeline with Forward Pass and QA modules via pipeline cli""" + if log: + init_logger('sup3r', log_level='DEBUG') + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + input_resolution = {'spatial': '12km', 'temporal': '60min'} + model.meta['input_resolution'] = input_resolution + assert model.input_resolution == input_resolution + assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) model.meta['training_features'] = FEATURES model.meta['output_features'] = FEATURES[:2] - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + assert model.s_enhance == 3 + assert model.t_enhance == 4 with tempfile.TemporaryDirectory() as td: input_files = make_fake_nc_files(td, INPUT_FILE, 8) @@ -354,8 +360,8 @@ def test_pipeline_fwp_qa(runner): with open(pipe_config_path, 'w') as fh: json.dump(pipe_config, fh) - result = runner.invoke(pipe_main, ['-c', pipe_config_path, - '-v', '--monitor']) + result = runner.invoke(pipe_main, ['-c', pipe_config_path, '-v', + '--monitor']) if result.exit_code != 0: import traceback msg = ('Failed with error {}' @@ -365,31 +371,34 @@ def test_pipeline_fwp_qa(runner): assert len(glob.glob(f'{td}/fwp_log*.log')) == 1 assert len(glob.glob(f'{td}/out*.h5')) == 1 assert len(glob.glob(f'{td}/qa.h5')) == 1 - assert len(glob.glob(f'{td}/*_status.json')) == 1 - - status_fp = glob.glob(f'{td}/*_status.json')[0] - with open(status_fp, 'r') as f: + status_fps = glob.glob(f'{td}/.gaps/*status*.json') + assert len(status_fps) == 1 + status_fp = status_fps[0] + with open(status_fp) as f: status = json.load(f) assert len(status) == 2 assert len(status['forward-pass']) == 2 fwp_status = status['forward-pass'] del fwp_status['pipeline_index'] - fwp_status = list(fwp_status.values())[0] + fwp_status = next(iter(fwp_status.values())) assert fwp_status['job_status'] == 'successful' assert fwp_status['time'] > 0 assert len(status['qa']) == 2 qa_status = status['qa'] del qa_status['pipeline_index'] - qa_status = list(qa_status.values())[0] + qa_status = next(iter(qa_status.values())) assert qa_status['job_status'] == 'successful' assert qa_status['time'] > 0 -def test_visual_qa(runner): +def test_visual_qa(runner, log=False): """Make sure visual qa module creates the right number of plots""" + if log: + init_logger('sup3r', log_level='DEBUG') + time_step = 500 plot_features = ['windspeed_100m', 'winddirection_100m'] with ResourceX(FP_WTK) as res: diff --git a/tests/pipeline/test_pipeline.py 
b/tests/pipeline/test_pipeline.py index 7274d7a396..fd03b989d5 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -1,17 +1,17 @@ """Sup3r pipeline tests""" -import tempfile -import os +import glob import json +import os import shutil -import numpy as np -import glob +import tempfile +import numpy as np from rex import ResourceX -from sup3r.pipeline.pipeline import Sup3rPipeline as Pipeline +from sup3r import CONFIG_DIR, TEST_DATA_DIR from sup3r.models.base import Sup3rGan +from sup3r.pipeline.pipeline import Sup3rPipeline as Pipeline from sup3r.utilities.pytest import make_fake_nc_files -from sup3r import TEST_DATA_DIR, CONFIG_DIR INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00') FEATURES = ['U_100m', 'V_100m', 'BVF2_200m'] @@ -26,10 +26,15 @@ def test_fwp_pipeline(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) + input_resolution = {'spatial': '12km', 'temporal': '60min'} + model.meta['input_resolution'] = input_resolution + assert model.input_resolution == input_resolution + assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} + _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) model.meta['training_features'] = FEATURES model.meta['output_features'] = FEATURES[:2] - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 + assert model.s_enhance == 3 + assert model.t_enhance == 4 with tempfile.TemporaryDirectory() as td: input_files = make_fake_nc_files(td, INPUT_FILE, 20) @@ -98,8 +103,10 @@ def test_fwp_pipeline(): with ResourceX(fp_out) as f: assert len(f.time_index) == t_enhance * n_tsteps - status_file = glob.glob(os.path.join(td, '*_status.json'))[0] - with open(status_file, 'r') as fh: + status_files = glob.glob(os.path.join(td, '.gaps', '*status.json')) + assert len(status_files) == 1 + status_file = status_files[0] + with open(status_file) as fh: status = json.load(fh) assert all(s in status for s in ('forward-pass', 'data-collect')) assert all(s not in str(status) diff --git a/tests/training/test_train_conditional_moments.py b/tests/training/test_train_conditional_moments.py index 927da3753a..71fbfb2814 100644 --- a/tests/training/test_train_conditional_moments.py +++ b/tests/training/test_train_conditional_moments.py @@ -1,33 +1,32 @@ # -*- coding: utf-8 -*- """Test the basic training of super resolution GAN""" import os +import tempfile + # import json import numpy as np import pytest -import tempfile import tensorflow as tf -from tensorflow.python.framework.errors_impl import InvalidArgumentError - from rex import init_logger +from tensorflow.python.framework.errors_impl import InvalidArgumentError -from sup3r import TEST_DATA_DIR -from sup3r import CONFIG_DIR +from sup3r import CONFIG_DIR, TEST_DATA_DIR from sup3r.models import Sup3rCondMom -from sup3r.preprocessing.data_handling import DataHandlerH5 from sup3r.preprocessing.conditional_moment_batch_handling import ( - SpatialBatchHandlerMom1, - SpatialBatchHandlerMom1SF, - SpatialBatchHandlerMom2, - SpatialBatchHandlerMom2Sep, - SpatialBatchHandlerMom2SF, - SpatialBatchHandlerMom2SepSF, BatchHandlerMom1, BatchHandlerMom1SF, BatchHandlerMom2, BatchHandlerMom2Sep, + BatchHandlerMom2SepSF, BatchHandlerMom2SF, - BatchHandlerMom2SepSF) - + SpatialBatchHandlerMom1, + SpatialBatchHandlerMom1SF, + SpatialBatchHandlerMom2, + SpatialBatchHandlerMom2Sep, + SpatialBatchHandlerMom2SepSF, + SpatialBatchHandlerMom2SF, +) +from sup3r.preprocessing.data_handling import 
DataHandlerH5 FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) @@ -79,7 +78,9 @@ def test_train_s_mom1(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -167,7 +168,9 @@ def test_train_s_mom1_sf(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -227,7 +230,10 @@ def test_train_s_mom2(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '8km', + 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -288,7 +294,10 @@ def test_train_s_mom2_sf(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '8km', + 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -340,7 +349,10 @@ def test_train_s_mom2_sep(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '8km', + 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -392,7 +404,10 @@ def test_train_s_mom2_sep_sf(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '8km', + 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -436,7 +451,9 @@ def test_train_st_mom1(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -485,7 +502,9 @@ def test_train_st_mom1_sf(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -541,7 +560,10 @@ def test_train_st_mom2(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '12km', + 
'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -599,7 +621,10 @@ def test_train_st_mom2_sf(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '12km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -645,7 +670,10 @@ def test_train_st_mom2_sep(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '12km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -691,7 +719,10 @@ def test_train_st_mom2_sep_sf(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '12km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 17bbfb75ec..418dc30611 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -52,8 +52,13 @@ def test_train_spatial(log=False, full_shape=(20, 20), with tempfile.TemporaryDirectory() as td: # test that training works and reduces loss - model.train(batch_handler, n_epoch=n_epoch, weight_gen_advers=0.0, - train_gen=True, train_disc=False, checkpoint_int=1, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, + weight_gen_advers=0.0, + train_gen=True, + train_disc=False, + checkpoint_int=1, out_dir=os.path.join(td, 'test_{epoch}')) assert len(model.history) == n_epoch @@ -122,7 +127,9 @@ def test_train_st_weight_update(n_epoch=2, log=False): adaptive_update_bounds = (0.9, 0.99) with tempfile.TemporaryDirectory() as td: - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + n_epoch=n_epoch, weight_gen_advers=1e-6, train_gen=True, train_disc=True, checkpoint_int=10, @@ -178,7 +185,9 @@ def test_train_spatial_dc(log=False, full_shape=(20, 20), with tempfile.TemporaryDirectory() as td: # test that the normalized number of samples from each bin is close # to the weight for that bin - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=2, @@ -226,7 +235,9 @@ def test_train_st_dc(n_epoch=2, log=False): with tempfile.TemporaryDirectory() as td: # test that the normalized number of samples from each bin is close # to the weight for that bin - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + n_epoch=n_epoch, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=2, @@ -274,7 +285,9 @@ def test_train_st(n_epoch=2, log=False): with tempfile.TemporaryDirectory() as td: # test that training works and reduces loss - 
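The input_resolution metadata threaded through all of these train() calls feeds the new validation exercised below in test_input_res_check and test_enhancement_check, and the pipeline tests assert that a 12km/60min model with s_enhance=3 and t_enhance=4 reports a 4km/15min output resolution. A hypothetical helper (divide_res is an illustration of the arithmetic those asserts rely on, not a sup3r API):

def divide_res(input_resolution, s_enhance, t_enhance):
    """Illustrative only: divide the numeric part of each resolution
    string by its enhancement factor, as the assertions above imply."""
    skm = float(input_resolution['spatial'].rstrip('km'))
    tmin = float(input_resolution['temporal'].rstrip('min'))
    return {'spatial': f'{skm / s_enhance:g}km',
            'temporal': f'{tmin / t_enhance:g}min'}

assert (divide_res({'spatial': '12km', 'temporal': '60min'}, 3, 4)
        == {'spatial': '4km', 'temporal': '15min'})
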
model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + n_epoch=n_epoch, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=1, @@ -369,3 +382,34 @@ def test_optimizer_update(): assert model.optimizer.learning_rate == 0.1 assert model.optimizer_disc.learning_rate == 0.1 + + +def test_input_res_check(): + """Make sure error is raised for invalid input resolution""" + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + + Sup3rGan.seed() + model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4, + learning_rate_disc=4e-4) + + with pytest.raises(RuntimeError): + model.set_model_params( + input_resolution={'spatial': '22km', 'temporal': '9min'}) + + +def test_enhancement_check(): + """Make sure error is raised for invalid enhancement factor inputs""" + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + + Sup3rGan.seed() + model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4, + learning_rate_disc=4e-4) + + with pytest.raises(RuntimeError): + model.set_model_params( + input_resolution={'spatial': '12km', 'temporal': '60min'}, + s_enhance=7, t_enhance=3) diff --git a/tests/training/test_train_wind.py b/tests/training/test_train_gan_exo.py similarity index 68% rename from tests/training/test_train_wind.py rename to tests/training/test_train_gan_exo.py index 0eff729434..741019feb7 100644 --- a/tests/training/test_train_wind.py +++ b/tests/training/test_train_gan_exo.py @@ -8,10 +8,9 @@ from rex import init_logger from sup3r import CONFIG_DIR, TEST_DATA_DIR -from sup3r.models import WindGan -from sup3r.models.data_centric import WindGanDC +from sup3r.models import Sup3rGan +from sup3r.models.data_centric import Sup3rGanDC from sup3r.preprocessing.batch_handling import ( - BatchHandlerCC, BatchHandlerDC, SpatialBatchHandler, SpatialBatchHandlerCC, @@ -36,102 +35,6 @@ TARGET_COORD = (39.01, -105.15) -def test_wind_cc_model(log=False): - """Test the wind climate change wtk super res model. - - NOTE that the full 10x model is too big to train on the 20x20 test data. 
- """ - - handler = DataHandlerH5WindCC(INPUT_FILE_W, FEATURES_W, - target=TARGET_W, shape=SHAPE, - temporal_slice=slice(None, None, 2), - time_roll=-7, - sample_shape=(20, 20, 96), - worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) - - batcher = BatchHandlerCC([handler], batch_size=2, n_batches=2, - s_enhance=4, sub_daily_shape=None) - - if log: - init_logger('sup3r', log_level='DEBUG') - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_4x_24x_3f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - - WindGan.seed() - model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) - - with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, - weight_gen_advers=0.0, - train_gen=True, train_disc=False, - checkpoint_int=None, - out_dir=os.path.join(td, 'test_{epoch}')) - - assert 'test_0' in os.listdir(td) - assert model.meta['class'] == 'WindGan' - assert 'topography' in batcher.output_features - assert 'topography' not in model.output_features - assert len(model.output_features) == len(FEATURES_W) - 1 - - x = np.random.uniform(0, 1, (1, 4, 4, 4, 4)) - y = model.generate(x) - assert y.shape[0] == x.shape[0] - assert y.shape[1] == x.shape[1] * 4 - assert y.shape[2] == x.shape[2] * 4 - assert y.shape[3] == x.shape[3] * 24 - assert y.shape[4] == x.shape[4] - 1 - - -def test_wind_cc_model_spatial(log=False): - """Test the wind climate change wtk super res model with spatial - enhancement only. - """ - handler = DataHandlerH5WindCC(INPUT_FILE_W, - ('U_100m', 'V_100m', 'topography'), - target=TARGET_W, shape=SHAPE, - temporal_slice=slice(None, None, 2), - time_roll=-7, - val_split=0.1, - sample_shape=(20, 20), - worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) - - batcher = SpatialBatchHandlerCC([handler], batch_size=2, n_batches=2, - s_enhance=2) - - if log: - init_logger('sup3r', log_level='DEBUG') - - fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - - WindGan.seed() - model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) - - with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, - weight_gen_advers=0.0, - train_gen=True, train_disc=False, - checkpoint_int=None, - out_dir=os.path.join(td, 'test_{epoch}')) - - assert 'test_0' in os.listdir(td) - assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'WindGan' - assert 'topography' in batcher.output_features - assert 'topography' not in model.output_features - - x = np.random.uniform(0, 1, (4, 30, 30, 3)) - y = model.generate(x) - assert y.shape[0] == x.shape[0] - assert y.shape[1] == x.shape[1] * 2 - assert y.shape[2] == x.shape[2] * 2 - assert y.shape[3] == x.shape[3] - 1 - - @pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat']) def test_wind_hi_res_topo(custom_layer, log=False): """Test a special wind cc model with the custom Sup3rAdder or Sup3rConcat @@ -146,7 +49,7 @@ def test_wind_hi_res_topo(custom_layer, log=False): val_split=0.1, sample_shape=(20, 20), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) + train_only_features=()) batcher = SpatialBatchHandlerCC([handler], batch_size=2, n_batches=2, s_enhance=2) @@ -177,7 +80,7 @@ def test_wind_hi_res_topo(custom_layer, log=False): {"class": "SpatialExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer}, + {"class": custom_layer, "name": 
"topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], @@ -188,11 +91,14 @@ def test_wind_hi_res_topo(custom_layer, log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - WindGan.seed() - model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + Sup3rGan.seed() + model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '16km', + 'temporal': '3600min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, @@ -200,17 +106,21 @@ def test_wind_hi_res_topo(custom_layer, log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'WindGan' + assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features x = np.random.uniform(0, 1, (4, 30, 30, 3)) - hi_res_topo = np.random.uniform(0, 1, (60, 60)) + hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1)) with pytest.raises(RuntimeError): y = model.generate(x, exogenous_data=None) - y = model.generate(x, exogenous_data=(None, hi_res_topo)) + exo_tmp = { + 'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}} + y = model.generate(x, exogenous_data=exo_tmp) assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 @@ -262,7 +172,7 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): {"class": "SpatialExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer}, + {"class": custom_layer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], @@ -273,11 +183,14 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - WindGan.seed() - model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + Sup3rGan.seed() + model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '16km', + 'temporal': '3600min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, @@ -285,17 +198,21 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'WindGan' + assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features x = np.random.uniform(0, 1, (4, 30, 30, 3)) - hi_res_topo = np.random.uniform(0, 1, (60, 60)) + hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1)) with pytest.raises(RuntimeError): y = model.generate(x, exogenous_data=None) - y = model.generate(x, exogenous_data=(None, hi_res_topo)) + exo_tmp = { + 'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}} + y = model.generate(x, exogenous_data=exo_tmp) assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 @@ -347,7 +264,7 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): {"class": "SpatioTemporalExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer}, + {"class": custom_layer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [3, 
diff --git a/tests/training/test_train_gan_lr_era.py b/tests/training/test_train_gan_lr_era.py
index 59ef88105a..52c1974ff4 100644
--- a/tests/training/test_train_gan_lr_era.py
+++ b/tests/training/test_train_gan_lr_era.py
@@ -73,6 +73,7 @@ def test_train_spatial(
         # test that training works and reduces loss
         model.train(
             batch_handler,
+            input_resolution={'spatial': '30km', 'temporal': '60min'},
             n_epoch=n_epoch,
             weight_gen_advers=0.0,
             train_gen=True,
@@ -122,7 +123,7 @@ def test_train_spatial(
     assert loss_og.numpy() < loss_dummy.numpy()
 
 
-def test_train_st(n_epoch=3, log=True):
+def test_train_st(n_epoch=3, log=False):
    """Test basic spatiotemporal model training with only gen content loss."""
    if log:
        init_logger('sup3r', log_level='DEBUG')
@@ -173,6 +174,7 @@ def test_train_st(n_epoch=3, log=False):
         # test that training works and reduces loss
         model.train(
             batch_handler,
+            input_resolution={'spatial': '30km', 'temporal': '60min'},
             n_epoch=n_epoch,
             weight_gen_advers=0.0,
             train_gen=True,
diff --git a/tests/training/test_train_solar.py b/tests/training/test_train_solar.py
index 2ac26fc108..9b3a14bcab 100644
--- a/tests/training/test_train_solar.py
+++ b/tests/training/test_train_solar.py
@@ -2,19 +2,19 @@
 """Test the basic training of super resolution GAN for solar climate change
 applications"""
 import os
-import numpy as np
 import tempfile
-from tensorflow.keras.losses import MeanAbsoluteError
+import numpy as np
 from rex import init_logger
+from tensorflow.keras.losses import MeanAbsoluteError
 
-from sup3r import TEST_DATA_DIR
-from sup3r import CONFIG_DIR
-from sup3r.models import Sup3rGan, SolarCC
+from sup3r import CONFIG_DIR, TEST_DATA_DIR
+from sup3r.models import SolarCC, Sup3rGan
+from sup3r.preprocessing.batch_handling import (
+    BatchHandlerCC,
+    SpatialBatchHandlerCC,
+)
 from sup3r.preprocessing.data_handling import DataHandlerH5SolarCC
-from sup3r.preprocessing.batch_handling import (BatchHandlerCC,
-                                                SpatialBatchHandlerCC)
-
 
 SHAPE = (20, 20)
@@ -54,7 +54,9 @@ def test_solar_cc_model(log=False):
                      loss='MeanAbsoluteError')
 
     with tempfile.TemporaryDirectory() as td:
-        model.train(batcher, n_epoch=1,
+        model.train(batcher,
+                    input_resolution={'spatial': '4km', 'temporal': '40min'},
+                    n_epoch=1,
                     weight_gen_advers=0.0,
                     train_gen=True, train_disc=False,
                     checkpoint_int=None,
@@ -108,7 +110,9 @@ def test_solar_cc_model_spatial(log=False):
     model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4)
 
     with tempfile.TemporaryDirectory() as td:
-        model.train(batcher, n_epoch=1,
+        model.train(batcher,
+                    input_resolution={'spatial': '25km', 'temporal': '15min'},
+                    n_epoch=1,
                     weight_gen_advers=0.0,
                     train_gen=True, train_disc=False,
                     checkpoint_int=None,
@@ -149,7 +153,9 @@ def test_solar_custom_loss(log=False):
                      loss='MeanAbsoluteError')
 
     with tempfile.TemporaryDirectory() as td:
-        model.train(batcher, n_epoch=1,
+        model.train(batcher,
+                    input_resolution={'spatial': '4km', 'temporal': '40min'},
+                    n_epoch=1,
                     weight_gen_advers=0.0,
                     train_gen=True, train_disc=False,
                     checkpoint_int=None,
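In the conditional-moment tests below, WindCondMom is likewise replaced by the generic Sup3rCondMom, and the resolution passed to train() is recorded in the model metadata. A minimal sketch of the round trip, assuming model, batcher, and td set up as in these tests:

    # input_resolution is stored in model.meta after training
    input_resolution = {'spatial': '8km', 'temporal': '60min'}
    model.train(batcher,
                input_resolution=input_resolution,
                n_epoch=1,
                checkpoint_int=None,
                out_dir=os.path.join(td, 'test_{epoch}'))
    assert model.meta['input_resolution'] == input_resolution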
diff --git a/tests/training/test_train_wind_conditional_moments.py b/tests/training/test_train_wind_conditional_moments.py
index 7ac2ae7a2b..20cd092d23 100644
--- a/tests/training/test_train_wind_conditional_moments.py
+++ b/tests/training/test_train_wind_conditional_moments.py
@@ -2,30 +2,29 @@
 """Test the basic training of super resolution GAN for solar climate change
 applications"""
 import os
-import pytest
-import numpy as np
 import tempfile
 
+import numpy as np
+import pytest
 from rex import init_logger
 
-from sup3r import CONFIG_DIR
-from sup3r import TEST_DATA_DIR
-from sup3r.models import WindCondMom
+from sup3r import CONFIG_DIR, TEST_DATA_DIR
+from sup3r.models import Sup3rCondMom
 from sup3r.preprocessing.data_handling import DataHandlerH5
 from sup3r.preprocessing.wind_conditional_moment_batch_handling import (
+    WindBatchHandlerMom1,
+    WindBatchHandlerMom1SF,
+    WindBatchHandlerMom2,
+    WindBatchHandlerMom2Sep,
+    WindBatchHandlerMom2SepSF,
+    WindBatchHandlerMom2SF,
     WindSpatialBatchHandlerMom1,
     WindSpatialBatchHandlerMom1SF,
     WindSpatialBatchHandlerMom2,
-    WindSpatialBatchHandlerMom2SF,
     WindSpatialBatchHandlerMom2Sep,
     WindSpatialBatchHandlerMom2SepSF,
-    WindBatchHandlerMom1,
-    WindBatchHandlerMom1SF,
-    WindBatchHandlerMom2,
-    WindBatchHandlerMom2SF,
-    WindBatchHandlerMom2Sep,
-    WindBatchHandlerMom2SepSF)
-
+    WindSpatialBatchHandlerMom2SF,
+)
 SHAPE = (20, 20)
@@ -67,7 +66,7 @@ def make_s_gen_model(custom_layer):
                 {"class": "SpatialExpansion", "spatial_mult": 2},
                 {"class": "Activation", "activation": "relu"},
 
-                {"class": custom_layer},
+                {"class": custom_layer, "name": "topography"},
 
                 {"class": "FlexiblePadding",
                  "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
@@ -108,24 +107,32 @@ def test_wind_non_cc_hi_res_topo_mom1(custom_layer, batch_class,
 
     gen_model = make_s_gen_model(custom_layer)
 
-    WindCondMom.seed()
-    model = WindCondMom(gen_model, learning_rate=1e-4)
+    Sup3rCondMom.seed()
+    model = Sup3rCondMom(gen_model, learning_rate=1e-4)
+    input_resolution = {'spatial': '8km', 'temporal': '60min'}
 
     with tempfile.TemporaryDirectory() as td:
         if out_dir_root is None:
             out_dir_root = td
-        model.train(batcher, n_epoch=n_epoch,
+        model.train(batcher,
+                    input_resolution={'spatial': '8km', 'temporal': '60min'},
+                    n_epoch=n_epoch,
                     checkpoint_int=None,
                     out_dir=os.path.join(out_dir_root, 'test_{epoch}'))
 
         assert f'test_{n_epoch-1}' in os.listdir(out_dir_root)
         assert model.meta['output_features'] == ['U_100m', 'V_100m']
-        assert model.meta['class'] == 'WindCondMom'
+        assert model.meta['class'] == 'Sup3rCondMom'
+        assert model.meta['input_resolution'] == input_resolution
         assert 'topography' in batcher.output_features
         assert 'topography' not in model.output_features
 
     x = np.random.uniform(0, 1, (4, 30, 30, 3))
-    hi_res_topo = np.random.uniform(0, 1, (60, 60))
+    hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1))
+    exo_tmp = {
+        'topography': {
+            'steps': [
+                {'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}}
 
-    y = model.generate(x, exogenous_data=(None, hi_res_topo))
+    y = model.generate(x, exogenous_data=exo_tmp)
 
     assert y.shape[0] == x.shape[0]
     assert y.shape[1] == x.shape[1] * 2
@@ -159,8 +166,8 @@ def test_wind_non_cc_hi_res_st_topo_mom1(batch_class, log=False,
                           'sup3rcc',
                           'gen_wind_3x_4x_2f.json')
 
-    WindCondMom.seed()
-    model_mom1 = WindCondMom(fp_gen, learning_rate=1e-4)
+    Sup3rCondMom.seed()
+    model_mom1 = Sup3rCondMom(fp_gen, learning_rate=1e-4)
 
     batcher = batch_class([handler],
                           batch_size=batch_size,
@@ -171,7 +178,10 @@ def test_wind_non_cc_hi_res_st_topo_mom1(batch_class, log=False,
     with tempfile.TemporaryDirectory() as td:
         if out_dir_root is None:
             out_dir_root = td
-        model_mom1.train(batcher, n_epoch=n_epoch,
+        model_mom1.train(batcher,
+                         input_resolution={'spatial': '12km',
+                                           'temporal': '60min'},
+                         n_epoch=n_epoch,
                          checkpoint_int=None,
                          out_dir=os.path.join(out_dir_root,
                                               'test_{epoch}'))
@@ -204,9 +214,9 @@ def test_wind_non_cc_hi_res_topo_mom2(custom_layer, batch_class,
 
     gen_model = make_s_gen_model(custom_layer)
 
-    WindCondMom.seed()
-    model_mom1 = WindCondMom(gen_model, learning_rate=1e-4)
-    model_mom2 = WindCondMom(gen_model, learning_rate=1e-4)
+    Sup3rCondMom.seed()
+    model_mom1 = Sup3rCondMom(gen_model, learning_rate=1e-4)
+    model_mom2 = Sup3rCondMom(gen_model, learning_rate=1e-4)
 
     batcher = batch_class([handler],
                           batch_size=batch_size,
@@ -217,7 +227,10 @@ def test_wind_non_cc_hi_res_topo_mom2(custom_layer, batch_class,
     with tempfile.TemporaryDirectory() as td:
         if out_dir_root is None:
             out_dir_root = td
-        model_mom2.train(batcher, n_epoch=n_epoch,
+        model_mom2.train(batcher,
+                         input_resolution={'spatial': '8km',
+                                           'temporal': '60min'},
+                         n_epoch=n_epoch,
                          checkpoint_int=None,
                          out_dir=os.path.join(out_dir_root, 'test_{epoch}'))
@@ -251,9 +264,9 @@ def test_wind_non_cc_hi_res_st_topo_mom2(batch_class, log=False,
                           'sup3rcc',
                           'gen_wind_3x_4x_2f.json')
 
-    WindCondMom.seed()
-    model_mom1 = WindCondMom(fp_gen, learning_rate=1e-4)
-    model_mom2 = WindCondMom(fp_gen, learning_rate=1e-4)
+    Sup3rCondMom.seed()
+    model_mom1 = Sup3rCondMom(fp_gen, learning_rate=1e-4)
+    model_mom2 = Sup3rCondMom(fp_gen, learning_rate=1e-4)
 
     batcher = batch_class([handler],
                           batch_size=batch_size,
@@ -264,7 +277,10 @@ def test_wind_non_cc_hi_res_st_topo_mom2(batch_class, log=False,
     with tempfile.TemporaryDirectory() as td:
         if out_dir_root is None:
             out_dir_root = td
-        model_mom2.train(batcher, n_epoch=n_epoch,
+        model_mom2.train(batcher,
+                         input_resolution={'spatial': '12km',
+                                           'temporal': '60min'},
+                         n_epoch=n_epoch,
                          checkpoint_int=None,
                          out_dir=os.path.join(out_dir_root,
                                               'test_{epoch}'))
diff --git a/tests/utilities/test_utilities.py b/tests/utilities/test_utilities.py
index ec52c93dd7..93ae8be823 100644
--- a/tests/utilities/test_utilities.py
+++ b/tests/utilities/test_utilities.py
@@ -15,19 +15,22 @@
 from sup3r.postprocessing.file_handling import OutputHandler
 from sup3r.utilities.interpolate_log_profile import LogLinInterpolator
 from sup3r.utilities.regridder import RegridOutput
-from sup3r.utilities.utilities import (get_chunk_slices, spatial_coarsening,
-                                       st_interp, transform_rotate_wind,
-                                       uniform_box_sampler,
-                                       uniform_time_sampler,
-                                       weighted_box_sampler,
-                                       weighted_time_sampler,
-                                       )
+from sup3r.utilities.utilities import (
+    get_chunk_slices,
+    spatial_coarsening,
+    st_interp,
+    transform_rotate_wind,
+    uniform_box_sampler,
+    uniform_time_sampler,
+    weighted_box_sampler,
+    weighted_time_sampler,
+)
 
 FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
 FP_ERA = os.path.join(TEST_DATA_DIR, 'test_era5_co_2012.nc')
 
 
-def test_log_interp(log=True):
+def test_log_interp(log=False):
     """Make sure log interp generates reasonable output (e.g. between input
     levels)"""
     if log: