From 77fbd94824496430cb64d30b71f53353f9ad1865 Mon Sep 17 00:00:00 2001 From: jiangxinglei Date: Tue, 10 Sep 2024 11:30:11 +0800 Subject: [PATCH] hotfix mean, var assign --- tensornet/layers/normalization_layer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensornet/layers/normalization_layer.py b/tensornet/layers/normalization_layer.py index 8853f08..0749f85 100644 --- a/tensornet/layers/normalization_layer.py +++ b/tensornet/layers/normalization_layer.py @@ -120,9 +120,10 @@ def _increment_and_check_count(): _increment_and_check_count() else: self.bn_statistics_push(False) - else: - mean = self.moving_mean - var = self.moving_variance + self.update_moments() + + mean = self.moving_mean + var = self.moving_variance outputs = tf.nn.batch_normalization(x=inputs, mean=mean, variance=var, offset=self.beta, scale=self.gamma, variance_epsilon=self.epsilon)