diff --git a/wavenet/model.py b/wavenet/model.py index b3e6f4758..c6c563647 100644 --- a/wavenet/model.py +++ b/wavenet/model.py @@ -578,7 +578,7 @@ def predict_proba(self, waveform, global_condition=None, name='wavenet'): with tf.name_scope(name): if self.scalar_input: encoded = tf.cast(waveform, tf.float32) - encoded = tf.reshape(encoded, [-1, 1]) + encoded = tf.reshape(encoded, [self.batch_size, -1, 1]) else: encoded = self._one_hot(waveform)