diff --git a/Project.toml b/Project.toml index b03fba6..5b5e312 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleChains" uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" authors = ["Chris Elrod and contributors"] -version = "0.3.1" +version = "0.3.2" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/examples/mnist_lenet.jl b/examples/mnist_lenet.jl index 7358f2c..8ef8ba3 100644 --- a/examples/mnist_lenet.jl +++ b/examples/mnist_lenet.jl @@ -2,6 +2,8 @@ import MLDatasets function get_data() + # xtrain, ytrain = MLDatasets.MNIST.traindata(Float32); + # xtest, ytest = MLDatasets.MNIST.testdata(Float32); xtrain, ytrain = MLDatasets.MNIST(:train)[:] xtest, ytest = MLDatasets.MNIST(:test)[:] @@ -53,7 +55,7 @@ SimpleChains.accuracy_and_loss(lenetloss, xtest, ytest, p) -lenet.memory .= 0; +# lenet.memory .= 0; SimpleChains.init_params!(lenet, p); @time SimpleChains.train_batched!(G, p, lenetloss, xtrain, SimpleChains.ADAM(3e-4), 10); SimpleChains.accuracy_and_loss(lenetloss, xtrain, p), diff --git a/src/utils.jl b/src/utils.jl index 07dd4d9..4ea370f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -72,16 +72,34 @@ nfan(n_out, n_in) = n_in, n_out p * dft, p * dt end # https://github.com/FluxML/Flux.jl/blob/master/LICENSE.md -function glorot_uniform!(A::AbstractArray{T}, rng = local_rng()) where {T} +function glorot_uniform!(A::AbstractArray{T}, rng::VectorizedRNG.AbstractVRNG = local_rng()) where {T} scale = @fastmath sqrt(T(24) / tssum(nfan(size(A)...))) # (rand()-0.5)*scale === rand()*scale - 0.5scale rand!(rng, A, static(0), T(-0.5) * scale, scale) end +function glorot_uniform!(A::AbstractArray{T}, rng) where {T} + scale = @fastmath sqrt(T(24) / tssum(nfan(size(A)...))) + @show scale + # (rand()-0.5)*scale === rand()*scale - 0.5scale + rand!(rng, A) + @inbounds @fastmath for i = eachindex(A) + A[i] = A[i]*scale - 0.5*scale + end + return A +end # https://github.com/FluxML/Flux.jl/blob/master/LICENSE.md -function glorot_normal!(A::AbstractArray{T}, rng = local_rng()) where {T} +function glorot_normal!(A::AbstractArray{T}, rng::VectorizedRNG.AbstractVRNG = local_rng()) where {T} σ = @fastmath sqrt(T(2) / tssum(nfan(size(A)...))) randn!(rng, A, static(0), static(0), σ) end +function glorot_normal!(A::AbstractArray{T}, rng) where {T} + σ = @fastmath sqrt(T(2) / tssum(nfan(size(A)...))) + randn!(rng, A) + @inbounds @fastmath for i = eachindex(A) + A[i] *= σ + end + return A +end function randpermzero!(r::Random.AbstractRNG, a::AbstractArray{<:Integer}) n = length(a) diff --git a/test/Project.toml b/test/Project.toml index 5b522bf..f35f5aa 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,8 +4,10 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +VectorizedRNG = "33b4df10-0173-11e9-2a0c-851a7edac40e" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] diff --git a/test/random.jl b/test/random.jl new file mode 100644 index 0000000..25c0190 --- /dev/null +++ b/test/random.jl @@ -0,0 +1,36 @@ + +using Random, VectorizedRNG + +_mean(x) = sum(x)/length(x) +function _mean_std(x, xbar = _mean(x)) + xbar, sqrt(sum(abs2 ∘ Base.Fix2(-, xbar), x)/(length(x)-1)) +end +x = Vector{Float64}(undef, 2047); +y = Vector{Float64}(undef, 2047); + +vrng = VectorizedRNG.MutableXoshift(3); +rng = VERSION >= v"1.7" ? Random.Xoshiro(3) : Random.MersenneTwister(4); + +SimpleChains.glorot_uniform!(x, vrng); +SimpleChains.glorot_uniform!(y, rng); +mx, sx = _mean_std(x) +@test abs(mx) < 0.01 +@test sx ≈ 0.03125 rtol = 1e-2 +my, sy = _mean_std(y) +@test abs(my) < 0.01 +@test sy ≈ 0.03125 rtol = 1e-2 + +SimpleChains.glorot_normal!(x, vrng); +SimpleChains.glorot_normal!(y, rng); +mx, sx = _mean_std(x) +@test abs(mx) < 0.01 +@test sx ≈ 0.03125 rtol = 1e-2 +my, sy = _mean_std(y) +@test abs(my) < 0.01 +@test sy ≈ 0.03125 rtol = 1e-2 + + + + + + diff --git a/test/runtests.jl b/test/runtests.jl index 20350d3..c03b529 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using SimpleChains -using Test, Aqua, ForwardDiff, Zygote, ChainRules +using Test, Aqua, ForwardDiff, Zygote, ChainRules, Random function countallocations!(g, sc, x, p) @allocated valgrad!(g, sc, x, p) @@ -487,6 +487,9 @@ InteractiveUtils.versioninfo(verbose=true) @testset "SArray" begin include("staticarrays.jl") end + @testset "Glorot" begin + include("random.jl") + end end # TODO: test ambiguities once ForwardDiff fixes them, or once ForwardDiff is dropped # For now, there are the tests at the start.