From 257fc97cd5656fa667ddbf88a693f6ffc6231aac Mon Sep 17 00:00:00 2001
From: Paul Novotny
Date: Sun, 21 Apr 2024 13:27:51 +0000
Subject: [PATCH] Fix batchnorm in test mode without tracked stats

In test mode, the CUDA cuDNN implementation of batchnorm was not matching
the CPU batchnorm in Flux. In Flux, with track_stats=false, the mean and
variance of the current batch are used. Here, mean and variance were
initialized to 0 and 1, respectively, and passed to
cudnnBatchNormalizationForwardInference. To fix this, we need to calculate
the mean and variance over the current batch to match the CPU
implementation.

Unfortunately, cudnnBatchNormalizationForwardInference requires a trained
running mean and variance. However, batchnorm train and test mode should be
identical without tracked stats, since both normalize over the current
batch. As a result, we can use cudnnBatchNormalizationForwardTraining in
test mode as well, which works without a running mean and variance.
---
 ext/NNlibCUDACUDNNExt/batchnorm.jl | 10 +++++++++-
 test/ext_cuda/batchnorm.jl         | 11 +++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/ext/NNlibCUDACUDNNExt/batchnorm.jl b/ext/NNlibCUDACUDNNExt/batchnorm.jl
index 2c38f009e..d74fb3ad8 100644
--- a/ext/NNlibCUDACUDNNExt/batchnorm.jl
+++ b/ext/NNlibCUDACUDNNExt/batchnorm.jl
@@ -84,7 +84,15 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
             cache.ivar = ivar
         end
     else
-        cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
+        if track_stats
+            cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
+        else
+            # cudnnBatchNormalizationForwardInference does not accept CU_NULL for running_mean
+            # and running_var. We could calculate the mean and var of `x` here, but instead use
+            # cudnnBatchNormalizationForwardTraining, which does accept CU_NULL and will
+            # calculate the mean and var itself.
+            cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, CU_NULL, CU_NULL, eps, CU_NULL, CU_NULL)
+        end
     end
     return y
 end
diff --git a/test/ext_cuda/batchnorm.jl b/test/ext_cuda/batchnorm.jl
index 0adea7024..17bce0f36 100644
--- a/test/ext_cuda/batchnorm.jl
+++ b/test/ext_cuda/batchnorm.jl
@@ -1,3 +1,5 @@
+using Statistics
+
 @testset "Batchnorm" begin
     v = CUDA.rand(Float32, 2)
     m = CUDA.rand(Float32, 2, 5)
@@ -24,4 +26,13 @@
             @test_throws ArgumentError batchnorm(v, v, m, α, β, 1.0; kws...)
         end
     end
+    @testset "test mode" begin
+        y_no_track_stats = batchnorm(v, v, m, nothing, nothing, 1.0; training=false, track_stats=false)
+        running_mean = mean(m, dims=[2])
+        running_var = var(m, mean=running_mean, dims=[2], corrected=false)
+        y_track_stats = batchnorm(v, v, m, running_mean, running_var, 1.0; training=false, track_stats=true)
+        # batchnorm without tracked stats should equal batchnorm with tracked stats where the
+        # stats are calculated only on the input.
+        @test y_no_track_stats ≈ y_track_stats
+    end
 end
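
For context, a minimal CPU sketch of the "normalize over the current batch" behavior the commit message refers to, using only Statistics; the names g, b, x and the eps value are illustrative, and the real Flux layer additionally handles reshaping and optional affine parameters, which is not shown here.

    using Statistics

    # Illustrative per-channel batch statistics over a 2x5 input, mirroring the
    # shapes used in the test file (channels along dim 1, batch along dim 2).
    g = rand(Float32, 2)        # scale (gamma), hypothetical values
    b = rand(Float32, 2)        # shift (beta),  hypothetical values
    x = rand(Float32, 2, 5)

    mu     = mean(x, dims=2)                           # batch mean per channel
    sigma2 = var(x; mean=mu, dims=2, corrected=false)  # biased batch variance
    eps    = 1.0f-5

    # Normalize over the current batch, then apply the affine transform.
    y = g .* (x .- mu) ./ sqrt.(sigma2 .+ eps) .+ b

After this patch, calling the CUDA path in the same way as the new test, e.g. batchnorm(g, b, x, nothing, nothing, 1.0; training=false, track_stats=false) on device arrays, is expected to produce the same per-batch normalization as this CPU reference.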