Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix batchnorm in testmode without track stats
In test mode, the CUDA cuDNN implementation of batchnorm was not matching the CPU batchnorm in FLUX. In FLUX, with track_stats=False, the mean and variance of the current batch are used. Here, mean and variance were initialized to 0 and 1, respectively, and passed to cudnnBatchNormalizationForwardInference. To fix this, we need to calculate the mean and variance over the current batch to match the CPU implementation. Unfortunately, cudnnBatchNormalizationForwardInference requires a trained running mean and variance. However, batchnorm train and test should be identical without tracked stats since they both normalize over the current batch. As a result we can use cudnnBatchNormalizationForwardTraining in test mode as well, which works without a running mean and variance.
- Loading branch information