Skip to content

Commit

Permalink
edits for multi step spatial only model
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Jul 5, 2023
1 parent eb3fc9c commit d5986cf
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 19 deletions.
18 changes: 11 additions & 7 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1589,16 +1589,20 @@ def _reshape_data_chunk(model, data_chunk, exo_data):
i_lr_s : int
Axis index for the low-resolution spatial_1 dimension
"""

current_model = None
if exo_data is not None:
for i, arr in enumerate(exo_data):
if arr is not None:
current_model = (model if not hasattr(model, 'models')
else model.models[i])
if current_model.input_dims == 4:
exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3))
else:
exo_data[i] = np.expand_dims(arr, axis=0)
if not hasattr(model, 'models'):
current_model = model
elif i < len(model.models):
current_model = model.models[i]

if current_model is not None:
if current_model.input_dims == 4:
exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3))
else:
exo_data[i] = np.expand_dims(arr, axis=0)

if model.input_dims == 4:
i_lr_t = 0
Expand Down
4 changes: 2 additions & 2 deletions sup3r/preprocessing/data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def time_freq_hours(self):
"""Get the time frequency in hours as a float"""
ti_deltas = self.raw_time_index - np.roll(self.raw_time_index, 1)
ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600
time_freq = float(mode(ti_deltas_hours).mode[0])
time_freq = float(mode(ti_deltas_hours).mode)
return time_freq

@property
Expand Down Expand Up @@ -2546,7 +2546,7 @@ def get_clearsky_ghi(self):

ti_deltas = ti_nsrdb - np.roll(ti_nsrdb, 1)
ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600
time_freq = float(mode(ti_deltas_hours).mode[0])
time_freq = float(mode(ti_deltas_hours).mode)
t_start = self.temporal_slice.start or 0
t_end_target = self.temporal_slice.stop or len(self.raw_time_index)
t_start = int(t_start * 24 * (1 / time_freq))
Expand Down
101 changes: 91 additions & 10 deletions tests/forward_pass/test_forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@
"""pytests for data handling"""
import json
import os
import pytest
import tempfile
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np
import pytest
import tensorflow as tf
import xarray as xr
import matplotlib.pyplot as plt
from rex import ResourceX, init_logger

from sup3r import TEST_DATA_DIR, CONFIG_DIR, __version__
from sup3r import CONFIG_DIR, TEST_DATA_DIR, __version__
from sup3r.models import LinearInterp, Sup3rGan, WindGan
from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy
from sup3r.preprocessing.data_handling import DataHandlerNC
from sup3r.pipeline.forward_pass import (ForwardPass, ForwardPassStrategy)
from sup3r.models import Sup3rGan, WindGan, LinearInterp
from sup3r.utilities.pytest import make_fake_nc_files

from rex import ResourceX
from rex import init_logger


FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
TARGET_COORD = (39.01, -105.15)
FEATURES = ['U_100m', 'V_100m', 'BVF2_200m']
Expand Down Expand Up @@ -562,6 +560,89 @@ def test_fwp_multi_step_model_topo_exoskip(log=False):
'topography']


def test_fwp_multi_step_spatial_model_topo_noskip():
"""Test the forward pass with a multi step spatial only model class using
exogenous data for all model steps"""
Sup3rGan.seed()
fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json')
fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json')
s1_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4)
s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography']
s1_model.meta['output_features'] = ['U_100m', 'V_100m']
s1_model.meta['s_enhance'] = 2
s1_model.meta['t_enhance'] = 1
_ = s1_model.generate(np.ones((4, 10, 10, 3)))

s2_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4)
s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography']
s2_model.meta['output_features'] = ['U_100m', 'V_100m']
s2_model.meta['s_enhance'] = 2
s2_model.meta['t_enhance'] = 1
_ = s2_model.generate(np.ones((4, 10, 10, 3)))

with tempfile.TemporaryDirectory() as td:
input_files = make_fake_nc_files(td, INPUT_FILE, 8)

s1_out_dir = os.path.join(td, 's1_gan')
s2_out_dir = os.path.join(td, 's2_gan')
s1_model.save(s1_out_dir)
s2_model.save(s2_out_dir)

max_workers = 1
fwp_chunk_shape = (4, 4, 8)
s_enhancements = [2, 2, 1]
s_enhance = np.product(s_enhancements)

exo_kwargs = {'file_paths': input_files,
'features': ['topography'],
'source_file': FP_WTK,
'target': target,
'shape': shape,
's_enhancements': [1, 2, 2],
'agg_factors': [12, 4, 2]
}

model_kwargs = {'model_dirs': [s1_out_dir, s2_out_dir]}

out_files = os.path.join(td, 'out_{file_id}.h5')
input_handler_kwargs = dict(
target=target, shape=shape,
temporal_slice=temporal_slice,
worker_kwargs=dict(max_workers=max_workers),
overwrite_cache=True)
handler = ForwardPassStrategy(
input_files, model_kwargs=model_kwargs,
model_class='MultiStepGan',
fwp_chunk_shape=fwp_chunk_shape,
spatial_pad=1, temporal_pad=1,
input_handler_kwargs=input_handler_kwargs,
out_pattern=out_files,
worker_kwargs=dict(max_workers=max_workers),
exo_kwargs=exo_kwargs,
max_nodes=1)

forward_pass = ForwardPass(handler)
forward_pass.run(handler, node_index=0)

with ResourceX(handler.out_files[0]) as fh:
assert fh.shape == (
len(input_files),
s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1])
assert all(f in fh.attrs for f in ('windspeed_100m',
'winddirection_100m'))

assert fh.global_attrs['package'] == 'sup3r'
assert fh.global_attrs['version'] == __version__
assert 'full_version_record' in fh.global_attrs
version_record = json.loads(fh.global_attrs['full_version_record'])
assert version_record['tensorflow'] == tf.__version__
assert 'gan_meta' in fh.global_attrs
gan_meta = json.loads(fh.global_attrs['gan_meta'])
assert len(gan_meta) == 2 # two step model
assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m',
'topography']


def test_fwp_multi_step_model_topo_noskip():
"""Test the forward pass with a multi step model class using exogenous data
for all model steps"""
Expand Down

0 comments on commit d5986cf

Please sign in to comment.