From 93fa9abdd7ad8fa7752f1374fed4a845267ff961 Mon Sep 17 00:00:00 2001
From: Chris Elrod
Date: Wed, 28 Sep 2022 12:02:35 -0400
Subject: [PATCH 1/5] allow passing `AbstractRNG`s to init_params

---
 src/SimpleChains.jl |  1 +
 src/activation.jl   |  2 +-
 src/conv.jl         |  4 ++--
 src/dense.jl        | 12 ++++++------
 src/dropout.jl      |  2 +-
 src/flatten.jl      |  2 +-
 src/loss.jl         |  2 +-
 src/maxpool.jl      |  2 +-
 src/optimize.jl     |  4 ++--
 src/penalty.jl      | 11 +++++++----
 src/simple_chain.jl | 21 ++++++++++++---------
 src/utils.jl        |  1 -
 12 files changed, 35 insertions(+), 29 deletions(-)

diff --git a/src/SimpleChains.jl b/src/SimpleChains.jl
index 1425dc4..dc326fc 100644
--- a/src/SimpleChains.jl
+++ b/src/SimpleChains.jl
@@ -41,6 +41,7 @@ import ChainRulesCore
 import ForwardDiff
 import LoopVectorization
 import StaticArrays
+using Random: AbstractRNG
 using LoopVectorization: matmul_params, @turbo
 # using LoopVectorization: matmul_params
diff --git a/src/activation.jl b/src/activation.jl
index 25e2eed..65d007a 100644
--- a/src/activation.jl
+++ b/src/activation.jl
@@ -12,7 +12,7 @@ struct Activation{F}
 end
 parameter_free(::Activation) = true
 numparam(::Activation, id) = static(0), id
-init_params!(::Activation, p, id) = p, id
+init_params!(::Activation, p, id, ::AbstractRNG) = p, id
 _check_input_dims(::Activation, _) = nothing

 forward_layer_output_size(::Val{T}, a::Activation, s) where {T} =
diff --git a/src/conv.jl b/src/conv.jl
index 8c6b591..a261978 100644
--- a/src/conv.jl
+++ b/src/conv.jl
@@ -797,9 +797,9 @@ function forward_layer_output_size(::Val{T}, c::Conv, inputdim::Tuple) where {T}
   align(static_sizeof(T) * prod(outputdim)), outputdim
 end

-function init_params!(c::Conv, p, inputdim)
+function init_params!(c::Conv, p, inputdim, rng::AbstractRNG)
   (K, b), p2 = getparams(c, p, inputdim)
-  glorot_uniform!(K)
+  glorot_uniform!(K, rng)
   @turbo for i in eachindex(b)
     b[i] = 0
   end
diff --git a/src/dense.jl b/src/dense.jl
index 1a0c96d..8bfe7dc 100644
--- a/src/dense.jl
+++ b/src/dense.jl
@@ -87,22 +87,22 @@ function _getparams(layer::TurboDense{true}, p, inputdim::Tuple)
   _, outputdim = numparam(layer, inputdim)
   (view(A, :, static(1):K), view(A, :, Kp1)), p, outputdim
 end
-function init_params!(td::TurboDense, p, inputdim::Tuple)
-  p, outputdim = _init_params!(td, p, first(inputdim))
+function init_params!(td::TurboDense, p, inputdim::Tuple, rng::AbstractRNG)
+  p, outputdim = _init_params!(td, p, first(inputdim), rng)
   p, (outputdim, Base.tail(inputdim)...)
 end
-function _init_params!(td::TurboDense{true}, p, inputdim::Integer)
+function _init_params!(td::TurboDense{true}, p, inputdim::Integer, rng::AbstractRNG)
   W, p = getparams(td, p, inputdim)
   outputdim = td.outputdim
-  glorot_normal!(view(W, :, 1:inputdim))
+  glorot_normal!(view(W, :, 1:inputdim), rng)
   @turbo for i = 1:outputdim
     W[i, inputdim+1] = 0
   end
   return p, outputdim
 end
-function _init_params!(td::TurboDense{false}, p, inputdim::Integer)
+function _init_params!(td::TurboDense{false}, p, inputdim::Integer, rng::AbstractRNG)
   W, p = getparams(td, p, inputdim)
-  glorot_normal!(W)
+  glorot_normal!(W, rng)
   return p, td.outputdim
 end
diff --git a/src/dropout.jl b/src/dropout.jl
index df71b65..82bf5dc 100644
--- a/src/dropout.jl
+++ b/src/dropout.jl
@@ -27,7 +27,7 @@ gradval(::Val{T}, d::Dropout) where {T} = T(0xffffffff) / (T(0xffffffff) - d.p)

 numparam(::Dropout, id) = static(0), id
 parameter_free(::Dropout) = true
-init_params!(::Dropout, p, id) = p, id
+init_params!(::Dropout, p, id, _) = p, id

 function (d::Dropout)(B::AbstractVecOrMat{T}, p::Ptr, pu::Ptr{UInt8}) where {T}
   x = muladd(T(d.p), -inv(T(typemax(UInt32))), one(T))
diff --git a/src/flatten.jl b/src/flatten.jl
index 843d264..397f70c 100644
--- a/src/flatten.jl
+++ b/src/flatten.jl
@@ -34,7 +34,7 @@ function forward_layer_output_size(::Val{T}, ::Flatten{N}, inputdim::Tuple) wher
   static(0), getoutputdim(Flatten{N}(), inputdim)
 end

-init_params!(::Flatten{N}, p, id) where {N} = p, getoutputdim(Flatten{N}(), id)
+init_params!(::Flatten{N}, p, id, ::AbstractRNG) where {N} = p, getoutputdim(Flatten{N}(), id)

 numparam(::Flatten{N}, inputdim) where {N} = 0, getoutputdim(Flatten{N}(), inputdim)
diff --git a/src/loss.jl b/src/loss.jl
index cbd1a1e..7606baf 100644
--- a/src/loss.jl
+++ b/src/loss.jl
@@ -66,7 +66,7 @@ end
 (::SquaredLoss)(y) = SquaredLoss(y)
 SquaredLoss() = SquaredLoss(nothing)
 target(sl::SquaredLoss) = getfield(sl, :y)
-init_params!(::AbstractLoss, p, _) = p, 1
+init_params!(::AbstractLoss, p, _, ::AbstractRNG) = p, 1

 function Base.getindex(sl::SquaredLoss, r)
   SquaredLoss(view_slice_last(target(sl), r))
diff --git a/src/maxpool.jl b/src/maxpool.jl
index 077b303..1d9868f 100644
--- a/src/maxpool.jl
+++ b/src/maxpool.jl
@@ -27,7 +27,7 @@ end
 function getoutputdim(::MaxPool{D}, inputdim) where {D}
   _maxpooloutputdim(map(StaticInt, D), inputdim)
 end
-init_params!(::MaxPool{D}, p, id) where {D} = p, getoutputdim(MaxPool{D}(), id)
+init_params!(::MaxPool{D}, p, id, ::AbstractRNG) where {D} = p, getoutputdim(MaxPool{D}(), id)

 numparam(mp::MaxPool, inputdim) = 0, getoutputdim(mp, inputdim)
diff --git a/src/optimize.jl b/src/optimize.jl
index 984e151..139c39a 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -694,7 +694,7 @@ end

 for t ∈ [:train, :train_batched, :train_unbatched]
   t! = Symbol(t, :!)
-  @eval function $t(chn::Chain, X, opt, iters)
-    $t!(init_params(chn), chn, X, opt, iters)
+  @eval function $t(chn::Chain, X, opt, iters; rng::AbstractRNG = local_rng())
+    $t!(init_params(chn, nothing, eltype(X), rng), chn, X, opt, iters)
   end
 end
diff --git a/src/penalty.jl b/src/penalty.jl
index ff90a35..0dc5ed9 100644
--- a/src/penalty.jl
+++ b/src/penalty.jl
@@ -40,13 +40,16 @@ function init_params(
   Λ::AbstractPenalty,
   id::Union{Nothing,InputDim} = nothing,
   ::Type{T} = Float32,
+  rng::AbstractRNG=local_rng(),
 ) where {T}
-  init_params(getchain(Λ), id, T)
+  init_params(getchain(Λ), id, T, rng)
 end
-function init_params(Λ::AbstractPenalty, ::Type{T}) where {T}
-  init_params(getchain(Λ), nothing, T)
+function init_params(Λ::AbstractPenalty, ::Type{T}, rng::AbstractRNG=local_rng()) where {T}
+  init_params(getchain(Λ), nothing, T, rng)
+end
+function init_params!(Λ::AbstractPenalty, x, id = nothing, rng::AbstractRNG=local_rng())
+  init_params!(getchain(Λ), x, id, rng)
 end
-init_params!(Λ::AbstractPenalty, x, id = nothing) = init_params!(getchain(Λ), x, id)

 target(c::AbstractPenalty) = target(getchain(c))
diff --git a/src/simple_chain.jl b/src/simple_chain.jl
index f451ff9..57f306a 100644
--- a/src/simple_chain.jl
+++ b/src/simple_chain.jl
@@ -387,22 +387,25 @@ end
 Randomly initializes parameter vector `p` with input dim `id`. Input dim does not need to be specified if these were provided to the chain object itself.
 See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions.
 """
-function init_params!(chn::SimpleChain, x::AbstractVector, id = nothing)
-  GC.@preserve x init_params!(chn.layers, pointer(x), chain_input_dims(chn, id))
+function init_params!(
+  chn::SimpleChain, x::AbstractVector, id = nothing, rng::AbstractRNG=local_rng()
+)
+  GC.@preserve x init_params!(chn.layers, pointer(x), chain_input_dims(chn, id), rng)
   return x
 end
-function init_params!(layers::Tuple, p::Ptr, id)
-  p, od = init_params!(first(layers), p, id)
-  init_params!(Base.tail(layers), p, od)
+function init_params!(layers::Tuple, p::Ptr, id, rng::AbstractRNG=local_rng())
+  p, od = init_params!(first(layers), p, id, rng)
+  init_params!(Base.tail(layers), p, od, rng)
 end
-init_params!(::Tuple{}, p::Ptr, _) = nothing
+init_params!(::Tuple{}, p::Ptr, _, ::AbstractRNG) = nothing
 function init_params(
   Λ::SimpleChain,
   id::Union{Nothing,InputDim} = nothing,
   ::Type{T} = Float32,
+  rng::AbstractRNG=local_rng()
 ) where {T}
   _id = chain_input_dims(Λ, id)
-  init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id))
+  init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id), rng)
 end
 """
     SimpleChains.init_params(chn[, id = nothing][, ::Type{T} = Float32])
@@ -410,8 +413,8 @@ end
 Creates a parameter vector of element type `T` with size matching that by `id` (argument not required if provided to the `chain` object itself).
 See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions.
""" -function init_params(Λ::SimpleChain, ::Type{T}) where {T} - init_params(Λ, nothing, T) +function init_params(Λ::SimpleChain, ::Type{T}, rng::AbstractRNG=local_rng()) where {T} + init_params(Λ, nothing, T, rng) end @inline function maybe_static_size_arg(s::Tuple, arg) diff --git a/src/utils.jl b/src/utils.jl index 4ea370f..aeec9e2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -79,7 +79,6 @@ function glorot_uniform!(A::AbstractArray{T}, rng::VectorizedRNG.AbstractVRNG = end function glorot_uniform!(A::AbstractArray{T}, rng) where {T} scale = @fastmath sqrt(T(24) / tssum(nfan(size(A)...))) - @show scale # (rand()-0.5)*scale === rand()*scale - 0.5scale rand!(rng, A) @inbounds @fastmath for i = eachindex(A) From 1011f7f0c78f83d91b149e9d27e8ded9ccd5b82a Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Wed, 28 Sep 2022 12:07:24 -0400 Subject: [PATCH 2/5] make rng a kwarg and pass it in a couple places --- src/optimize.jl | 2 +- src/simple_chain.jl | 9 +++++---- test/mnist.jl | 2 +- test/runtests.jl | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 139c39a..e46e078 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -695,6 +695,6 @@ end for t ∈ [:train, :train_batched, :train_unbatched] t! = Symbol(t, :!) @eval function $t(chn::Chain, X, opt, iters; rng::AbstractRNG = local_rng()) - $t!(init_params(chn, nothing, eltype(X), rng), chn, X, opt, iters) + $t!(init_params(chn, nothing, eltype(X); rng), chn, X, opt, iters) end end diff --git a/src/simple_chain.jl b/src/simple_chain.jl index 57f306a..00c8bf1 100644 --- a/src/simple_chain.jl +++ b/src/simple_chain.jl @@ -388,12 +388,12 @@ Randomly initializes parameter vector `p` with input dim `id`. Input dim does no See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions. """ function init_params!( - chn::SimpleChain, x::AbstractVector, id = nothing, rng::AbstractRNG=local_rng() + chn::SimpleChain, x::AbstractVector, id = nothing; rng::AbstractRNG ) GC.@preserve x init_params!(chn.layers, pointer(x), chain_input_dims(chn, id), rng) return x end -function init_params!(layers::Tuple, p::Ptr, id, rng::AbstractRNG=local_rng()) +function init_params!(layers::Tuple, p::Ptr, id, rng::AbstractRNG) p, od = init_params!(first(layers), p, id, rng) init_params!(Base.tail(layers), p, od, rng) end @@ -401,19 +401,20 @@ init_params!(::Tuple{}, p::Ptr, _, ::AbstractRNG) = nothing function init_params( Λ::SimpleChain, id::Union{Nothing,InputDim} = nothing, - ::Type{T} = Float32, + ::Type{T} = Float32; rng::AbstractRNG=local_rng() ) where {T} _id = chain_input_dims(Λ, id) init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id), rng) end + """ SimpleChains.init_params(chn[, id = nothing][, ::Type{T} = Float32]) Creates a parameter vector of element type `T` with size matching that by `id` (argument not required if provided to the `chain` object itself). See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions. 
""" -function init_params(Λ::SimpleChain, ::Type{T}, rng::AbstractRNG=local_rng()) where {T} +function init_params(Λ::SimpleChain, ::Type{T}; rng::AbstractRNG=local_rng()) where {T} init_params(Λ, nothing, T, rng) end diff --git a/test/mnist.jl b/test/mnist.jl index bfdfc7d..28a9892 100644 --- a/test/mnist.jl +++ b/test/mnist.jl @@ -27,7 +27,7 @@ lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain1)); @test SimpleChains.outputdim(lenet, size(xtest4)) == (10, length(ytest1)); # initialize parameters -@time p = SimpleChains.init_params(lenet); +@time p = SimpleChains.init_params(lenet, rng = SimpleChains.local_rng()); @test all(isfinite, p) @testset "Cache Corrupting Results" begin diff --git a/test/runtests.jl b/test/runtests.jl index c03b529..e854893 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,7 +80,7 @@ InteractiveUtils.versioninfo(verbose=true) # typename doesn't work on 1.5 @test_broken sprint((io, t) -> show(io, t), scflp) == print_str1 end - p = SimpleChains.init_params(scflp, T) + p = SimpleChains.init_params(scflp, T, rng = Random.default_rng()) g = similar(p) let sc = SimpleChains.remove_loss(sc) @test_throws ArgumentError sc(rand(T, 23, 2), p) From 4e2f226f8e65d182ff73242e7b94d7f6bebb7a20 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Wed, 28 Sep 2022 12:18:33 -0400 Subject: [PATCH 3/5] missed a couple kwargs --- src/simple_chain.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/simple_chain.jl b/src/simple_chain.jl index 00c8bf1..df6a981 100644 --- a/src/simple_chain.jl +++ b/src/simple_chain.jl @@ -405,7 +405,7 @@ function init_params( rng::AbstractRNG=local_rng() ) where {T} _id = chain_input_dims(Λ, id) - init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id), rng) + init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id); rng) end """ @@ -415,7 +415,7 @@ Creates a parameter vector of element type `T` with size matching that by `id` ( See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions. 
""" function init_params(Λ::SimpleChain, ::Type{T}; rng::AbstractRNG=local_rng()) where {T} - init_params(Λ, nothing, T, rng) + init_params(Λ, nothing, T; rng) end @inline function maybe_static_size_arg(s::Tuple, arg) From 5496ddb16a85d3cb20824888b27de3ba4f84eeb4 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Wed, 28 Sep 2022 12:30:53 -0400 Subject: [PATCH 4/5] penalty rng kwarg --- src/penalty.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/penalty.jl b/src/penalty.jl index 0dc5ed9..fd0755e 100644 --- a/src/penalty.jl +++ b/src/penalty.jl @@ -39,16 +39,16 @@ _type_sym(c::Chain) = __type_sym(remove_loss(c)) function init_params( Λ::AbstractPenalty, id::Union{Nothing,InputDim} = nothing, - ::Type{T} = Float32, + ::Type{T} = Float32; rng::AbstractRNG=local_rng(), ) where {T} - init_params(getchain(Λ), id, T, rng) + init_params(getchain(Λ), id, T; rng) end -function init_params(Λ::AbstractPenalty, ::Type{T}, rng::AbstractRNG=local_rng()) where {T} - init_params(getchain(Λ), nothing, T, rng) +function init_params(Λ::AbstractPenalty, ::Type{T}; rng::AbstractRNG=local_rng()) where {T} + init_params(getchain(Λ), nothing, T; rng) end -function init_params!(Λ::AbstractPenalty, x, id = nothing, rng::AbstractRNG=local_rng()) - init_params!(getchain(Λ), x, id, rng) +function init_params!(Λ::AbstractPenalty, x, id = nothing; rng::AbstractRNG=local_rng()) + init_params!(getchain(Λ), x, id; rng) end target(c::AbstractPenalty) = target(getchain(c)) From 707fb321a2ce692bc05086007f21b4cc57f2044b Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Wed, 28 Sep 2022 14:57:01 -0400 Subject: [PATCH 5/5] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5b5e312..70341de 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleChains" uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" authors = ["Chris Elrod and contributors"] -version = "0.3.2" +version = "0.3.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"