Skip to content

Commit

Permalink
Merge pull request #32 from NREL/gb/st_expand_bug
Browse files Browse the repository at this point in the history
Gb/st expand bug
  • Loading branch information
grantbuster authored Mar 16, 2022
2 parents a20ad65 + fad066d commit cc76892
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 16 deletions.
14 changes: 3 additions & 11 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def __init__(self, spatial_mult=1):
"""
super().__init__()
self._spatial_mult = int(spatial_mult)
self._n_spatial_1 = None

@staticmethod
def _check_shape(input_shape):
Expand Down Expand Up @@ -225,9 +224,6 @@ def __init__(self, spatial_mult=1, temporal_mult=1,
self._spatial_mult = int(spatial_mult)
self._temporal_mult = int(temporal_mult)
self._temporal_meth = temporal_method
self._n_spatial_1 = None
self._n_temporal = None
self._temp_expand_shape = None

@staticmethod
def _check_shape(input_shape):
Expand All @@ -254,17 +250,13 @@ def build(self, input_shape):
"""
self._check_shape(input_shape)

# desired final shape of the 2nd and 3rd axes for temporal expansion
self._n_spatial_1 = input_shape[2]
self._n_temporal = input_shape[3]
self._temp_expand_shape = tf.stack([
self._n_spatial_1, self._n_temporal * self._temporal_mult])

def _temporal_expand(self, x):
"""Expand the temporal dimension (axis=3) of a 5D tensor"""
temp_expand_shape = tf.stack(
[x.shape[2], x.shape[3] * self._temporal_mult])
out = []
for x_unstack in tf.unstack(x, axis=1):
out.append(tf.image.resize(x_unstack, self._temp_expand_shape,
out.append(tf.image.resize(x_unstack, temp_expand_shape,
method=self._temporal_meth))

return tf.stack(out, axis=1)
Expand Down
2 changes: 1 addition & 1 deletion phygnn/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Physics Guided Neural Network version."""

__version__ = '0.0.15'
__version__ = '0.0.16'
6 changes: 3 additions & 3 deletions tests/test_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_norm_df():
label_names=LABELS.columns, normalize=True)

baseline, means, stdevs = PreProcess.normalize(FEATURES)
test = model._parse_features(FEATURES)
test = model.parse_features(FEATURES)
assert np.allclose(baseline.values, test)
assert np.allclose(means, model.feature_means)
assert np.allclose(stdevs, model.feature_stdevs)
Expand All @@ -50,7 +50,7 @@ def test_norm_arr():
label_names=label_names, normalize=True)

baseline, means, stdevs = PreProcess.normalize(features)
test = model._parse_features(features, names=feature_names)
test = model.parse_features(features, names=feature_names)
assert np.allclose(baseline, test)
assert np.allclose(means, model.feature_means)
assert np.allclose(stdevs, model.feature_stdevs)
Expand All @@ -77,7 +77,7 @@ def test_OHE():

baseline, means, stdevs = \
PreProcess.normalize(FEATURES.values.astype('float32'))
test = model._parse_features(ohe_features)
test = model.parse_features(ohe_features)

assert np.allclose(baseline, test[:, :2])
assert np.allclose(means,
Expand Down
27 changes: 26 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,33 @@ def test_st_expansion(t_mult, s_mult):
assert y.shape[4] == x.shape[4] / (s_mult**2)


def test_st_expansion_new_shape():
"""Test that the spatiotemporal expansion layer can expand multiple shapes
and is not bound to the shape it was built on (bug found on 3/16/2022.)"""
s_mult = 3
t_mult = 6
layer = SpatioTemporalExpansion(spatial_mult=s_mult, temporal_mult=t_mult)
n_filters = 2 * s_mult**2
x = np.ones((32, 10, 10, 24, n_filters))
y = layer(x)
assert y.shape[0] == x.shape[0]
assert y.shape[1] == s_mult * x.shape[1]
assert y.shape[2] == s_mult * x.shape[2]
assert y.shape[3] == t_mult * x.shape[3]
assert y.shape[4] == x.shape[4] / (s_mult**2)

x = np.ones((32, 11, 11, 36, n_filters))
y = layer(x)
assert y.shape[0] == x.shape[0]
assert y.shape[1] == s_mult * x.shape[1]
assert y.shape[2] == s_mult * x.shape[2]
assert y.shape[3] == t_mult * x.shape[3]
assert y.shape[4] == x.shape[4] / (s_mult**2)


def test_st_expansion_bad():
"""Test an illegal spatial expansion request."""
"""Test an illegal spatial expansion request due to number of channels not
able to unpack into spatiotemporal dimensions."""
layer = SpatioTemporalExpansion(spatial_mult=2, temporal_mult=2)
x = np.ones((123, 10, 10, 24, 3))
with pytest.raises(RuntimeError):
Expand Down

0 comments on commit cc76892

Please sign in to comment.