Skip to content

Commit

Permalink
fix: handle spurious erf type promotion
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 3, 2024
1 parent 3e3adda commit e1fb383
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ext/WeightInitializersGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end
## dispatches
for f in (:__rand, :__randn)
@eval @inline function WeightInitializers.$(f)(
rng::RNG, ::Type{<:Complex{T}}, args...) where {T <: Number}
rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number}
real_part = WeightInitializers.$(f)(rng, rng.state, T, args...)
imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...)
return Complex{T}.(real_part, imag_part)
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
@inline _nfan(dims::Tuple) = _nfan(dims...)
@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels
@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / 2))
@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + T(erf(x / 2))) # erf often doesn't respect the type

@inline _default_rng() = Xoshiro(1234)

Expand Down
2 changes: 1 addition & 1 deletion test/initializers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ end

v = kaiming_normal(rng, n_in, n_out)
σ2 = sqrt(2 / n_out)
@test 0.9σ2 < std(v) < 1.1σ2
@test 0.9σ2 < std(Array(v)) < 1.1σ2 # Just for safety move to Array
end
# Type
@test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32
Expand Down

0 comments on commit e1fb383

Please sign in to comment.