Skip to content

Commit

Permalink
Fix batchnorm in testmode without track stats
Browse files Browse the repository at this point in the history
In test mode, the CUDA cuDNN implementation of batchnorm was not
matching the CPU batchnorm in FLUX. In FLUX, with track_stats=False, the
mean and variance of the current batch are used. Here, mean and variance
were initialized to 0 and 1, respectively, and passed to
cudnnBatchNormalizationForwardInference.

To fix this, we need to calculate the mean and variance over the current
batch to match the CPU implementation. Unfortunately,
cudnnBatchNormalizationForwardInference requires a trained running mean
and variance. However, batchnorm train and test should be identical
without tracked stats since they both normalize over the current
batch. As a result we can use cudnnBatchNormalizationForwardTraining in
test mode as well, which works without a running mean and variance.
  • Loading branch information
paulnovo committed Apr 21, 2024
1 parent 0783363 commit 257fc97
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
10 changes: 9 additions & 1 deletion ext/NNlibCUDACUDNNExt/batchnorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,15 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
cache.ivar = ivar
end
else
cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
if track_stats
cudnnBatchNormalizationForwardInference(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, running_mean, running_var, eps)
else
# cudnnBatchNormalizationForwardInference does not accept CV_NULL for running_mean
# and running_var. We could calculate mean and var of `x` here, but instead use
# cudnnBatchNormalizationFowardTraining. cudnnBatchNormalizationForwardTraining does
# accept CV_NULL and will calculate mean and var itself.
cudnnBatchNormalizationForwardTraining(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), xd, x, yd, y, gd, g, b, momentum, CU_NULL, CU_NULL, eps, CU_NULL, CU_NULL)
end
end
return y
end
Expand Down
11 changes: 11 additions & 0 deletions test/ext_cuda/batchnorm.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Statistics

@testset "Batchnorm" begin
v = CUDA.rand(Float32, 2)
m = CUDA.rand(Float32, 2, 5)
Expand All @@ -24,4 +26,13 @@
@test_throws ArgumentError batchnorm(v, v, m, α, β, 1.0; kws...)
end
end
@testset "test mode" begin
y_no_track_stats = batchnorm(v, v, m, nothing, nothing, 1.0; training=false, track_stats=false)
running_mean = mean(m, dims=[2])
running_var = var(m, mean=running_mean, dims=[2], corrected=false)
y_track_stats = batchnorm(v, v, m, running_mean, running_var, 1.0; training=false, track_stats=true)
# batchnorm without tracked stats should equal bathnorm with tracked stats where the
# stats are calculated only on the input.
@test y_no_track_stats y_track_stats
end
end

0 comments on commit 257fc97

Please sign in to comment.