From e1fb3834dc60b4fad7b4c98cb35b0fcbc52b9977 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 21:15:51 -0700 Subject: [PATCH] fix: handle spurious erf type promotion --- ext/WeightInitializersGPUArraysExt.jl | 2 +- src/utils.jl | 2 +- test/initializers_tests.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/WeightInitializersGPUArraysExt.jl b/ext/WeightInitializersGPUArraysExt.jl index 6e358a3..5a3c3af 100644 --- a/ext/WeightInitializersGPUArraysExt.jl +++ b/ext/WeightInitializersGPUArraysExt.jl @@ -14,7 +14,7 @@ end ## dispatches for f in (:__rand, :__randn) @eval @inline function WeightInitializers.$(f)( - rng::RNG, ::Type{<:Complex{T}}, args...) where {T <: Number} + rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} real_part = WeightInitializers.$(f)(rng, rng.state, T, args...) imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...) return Complex{T}.(real_part, imag_part) diff --git a/src/utils.jl b/src/utils.jl index 33669d9..3b9c618 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,7 +3,7 @@ @inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices @inline _nfan(dims::Tuple) = _nfan(dims...) @inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels -@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) +@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + T(erf(x / √2))) # erf often doesn't respect the type @inline _default_rng() = Xoshiro(1234) diff --git a/test/initializers_tests.jl b/test/initializers_tests.jl index 6b2d718..f98327f 100644 --- a/test/initializers_tests.jl +++ b/test/initializers_tests.jl @@ -250,7 +250,7 @@ end v = kaiming_normal(rng, n_in, n_out) σ2 = sqrt(2 / n_out) - @test 0.9σ2 < std(v) < 1.1σ2 + @test 0.9σ2 < std(Array(v)) < 1.1σ2 # Just for safety move to Array end # Type @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32