Skip to content

Commit

Permalink
get_exo_loss_input -> get_high_res_exo_input
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Oct 4, 2023
1 parent fb0c50c commit d78bc08
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 13 deletions.
17 changes: 14 additions & 3 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ def load(cls, model_dir, verbose=True):
model_dir
"""

@abstractmethod
def generate(self,
low_res,
norm_in=True,
un_norm_out=True,
exogenous_data=None):
"""Use the generator model to generate high res data from low res
input. This is the public generate function."""

@staticmethod
def seed(s=0):
"""
Expand Down Expand Up @@ -141,7 +150,9 @@ def input_resolution(self):
return input_resolution

def _get_numerical_resolutions(self):
"""Get the input and output resolutions without units"""
"""Get the input and output resolutions without units. e.g. for
{"spatial": "30km", "temporal": "60min"} this returns
{"spatial": 30, "temporal": 60}"""
ires_num = {k: int(re.search(r'\d+', v).group(0))
for k, v in self.input_resolution.items()}
enhancements = {'spatial': self.s_enhance,
Expand Down Expand Up @@ -797,7 +808,7 @@ def load_saved_params(out_dir, verbose=True):

return params

def get_exo_loss_input(self, high_res):
def get_high_res_exo_input(self, high_res):
"""Get exogenous feature data from high_res
Parameters
Expand Down Expand Up @@ -1418,7 +1429,7 @@ def get_single_grad(self,
with tf.GradientTape(watch_accessed_variables=False) as tape:
tape.watch(training_weights)

hi_res_exo = self.get_exo_loss_input(hi_res_true)
hi_res_exo = self.get_high_res_exo_input(hi_res_true)
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
loss_out = self.calc_loss(hi_res_true, hi_res_gen,
**calc_loss_kwargs)
Expand Down
4 changes: 2 additions & 2 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
logger = logging.getLogger(__name__)


class Sup3rGan(AbstractInterface, AbstractSingleModel):
class Sup3rGan(AbstractSingleModel, AbstractInterface):
"""Basic sup3r GAN model."""

def __init__(self,
Expand Down Expand Up @@ -602,7 +602,7 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details):
logger.debug('Starting end-of-epoch validation loss calculation...')
loss_details['n_obs'] = 0
for val_batch in batch_handler.val_data:
val_exo_data = self.get_exo_loss_input(val_batch.high_res)
val_exo_data = self.get_high_res_exo_input(val_batch.high_res)
high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data)
_, v_loss_details = self.calc_loss(
val_batch.high_res, high_res_gen,
Expand Down
4 changes: 2 additions & 2 deletions sup3r/models/conditional_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logger = logging.getLogger(__name__)


class Sup3rCondMom(AbstractInterface, AbstractSingleModel):
class Sup3rCondMom(AbstractSingleModel, AbstractInterface):
"""Basic Sup3r conditional moments model."""

def __init__(self, gen_layers,
Expand Down Expand Up @@ -283,7 +283,7 @@ def calc_val_loss(self, batch_handler, loss_details):
logger.debug('Starting end-of-epoch validation loss calculation...')
loss_details['n_obs'] = 0
for val_batch in batch_handler.val_data:
val_exo_data = self.get_exo_loss_input(val_batch.high_res)
val_exo_data = self.get_high_res_exo_input(val_batch.high_res)
output_gen = self._tf_generate(val_batch.low_res, val_exo_data)
_, v_loss_details = self.calc_loss(
val_batch.output, output_gen, val_batch.mask)
Expand Down
4 changes: 2 additions & 2 deletions sup3r/models/data_centric.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def calc_val_loss_gen(self, batch_handler, weight_gen_advers):
"""
losses = []
for obs in batch_handler.val_data:
exo_data = self.get_exo_loss_input(obs.high_res)
exo_data = self.get_high_res_exo_input(obs.high_res)
gen = self._tf_generate(obs.low_res, exo_data)
loss, _ = self.calc_loss(obs.high_res, gen,
weight_gen_advers=weight_gen_advers,
Expand Down Expand Up @@ -66,7 +66,7 @@ def calc_val_loss_gen_content(self, batch_handler):
"""
losses = []
for obs in batch_handler.val_data:
exo_data = self.get_exo_loss_input(obs.high_res)
exo_data = self.get_high_res_exo_input(obs.high_res)
gen = self._tf_generate(obs.low_res, exo_data)
loss = self.calc_loss_gen_content(obs.high_res, gen)
losses.append(float(loss))
Expand Down
19 changes: 15 additions & 4 deletions sup3r/preprocessing/data_handling/exogenous_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,29 @@ def __init__(self,
feature : str
Exogenous feature to extract from source_h5
models : list
List of models used with the given steps list
List of models used with the given steps list. This list of models
is used to determine the input and output resolution and
enhancement factors for each model step which is then used to
determine aggregation factors. If agg factors and enhancement
factors are provided in the steps list the model list is not
needed.
steps : list
List of dictionaries containing info on which models to use for a
given step index and what type of exo data the step requires. e.g.
[{'model': 0, 'combine_type': 'input'},
{'model': 0, 'combine_type': 'layer'}]
Each step entry can also contain s_enhance, t_enhance,
s_agg_factor, t_agg_factor. If they are not included they will be
computed using exo_resolution and model attributes
s_agg_factor, t_agg_factor. e.g.
[{'model': 0, 'combine_type': 'input', 's_agg_factor': 900,
's_enhance': 1, 't_agg_factor': 5, 't_enhance': 1},
{'model': 0, 'combine_type': 'layer', 's_agg_factor', 100,
's_enhance': 3, 't_agg_factor': 5, 't_enhance': 1}]
If they are not included they will be computed using exo_resolution
and model attributes.
exo_resolution : dict
Dictionary of spatiotemporal resolution for the given exo data
source. e.g. {'spatial': '4km', 'temporal': '60min'}
source. e.g. {'spatial': '4km', 'temporal': '60min'}. This is used
only if agg factors are not provided in the steps list.
source_file : str
Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or
4km) data from which will be mapped to the enhanced grid of the
Expand Down

0 comments on commit d78bc08

Please sign in to comment.