Merge pull request #113 from PumasAI/allowabstractrng
allow passing `AbstractRNG`s to `init_params`
korsbo authored Sep 29, 2022
2 parents 8c4a7ad + 707fb32 commit 6fbf2b0
Showing 15 changed files with 41 additions and 34 deletions.
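
For reference, a minimal usage sketch of what this change enables (the model, its layer sizes, and the use of `Random.Xoshiro` are illustrative assumptions, not part of this commit):

using SimpleChains, Random

# Any SimpleChain works the same way; this small MLP is only for illustration.
model = SimpleChain(static(4), TurboDense(tanh, 8), TurboDense(identity, 1))

# Passing an explicit AbstractRNG makes parameter initialization reproducible.
p1 = SimpleChains.init_params(model; rng = Random.Xoshiro(42))
p2 = SimpleChains.init_params(model; rng = Random.Xoshiro(42))
p1 == p2  # same seed, same parameters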
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "SimpleChains"
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.3.2"
version = "0.3.3"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1 change: 1 addition & 0 deletions src/SimpleChains.jl
@@ -41,6 +41,7 @@ import ChainRulesCore
import ForwardDiff
import LoopVectorization
import StaticArrays
using Random: AbstractRNG

using LoopVectorization: matmul_params, @turbo
# using LoopVectorization: matmul_params
2 changes: 1 addition & 1 deletion src/activation.jl
@@ -12,7 +12,7 @@ struct Activation{F}
end
parameter_free(::Activation) = true
numparam(::Activation, id) = static(0), id
init_params!(::Activation, p, id) = p, id
init_params!(::Activation, p, id, ::AbstractRNG) = p, id
_check_input_dims(::Activation, _) = nothing

forward_layer_output_size(::Val{T}, a::Activation, s) where {T} =
4 changes: 2 additions & 2 deletions src/conv.jl
@@ -797,9 +797,9 @@ function forward_layer_output_size(::Val{T}, c::Conv, inputdim::Tuple) where {T}
align(static_sizeof(T) * prod(outputdim)), outputdim
end

function init_params!(c::Conv, p, inputdim)
function init_params!(c::Conv, p, inputdim, rng::AbstractRNG)
(K, b), p2 = getparams(c, p, inputdim)
glorot_uniform!(K)
glorot_uniform!(K, rng)
@turbo for i in eachindex(b)
b[i] = 0
end
12 changes: 6 additions & 6 deletions src/dense.jl
@@ -87,22 +87,22 @@ function _getparams(layer::TurboDense{true}, p, inputdim::Tuple)
_, outputdim = numparam(layer, inputdim)
(view(A, :, static(1):K), view(A, :, Kp1)), p, outputdim
end
function init_params!(td::TurboDense, p, inputdim::Tuple)
p, outputdim = _init_params!(td, p, first(inputdim))
function init_params!(td::TurboDense, p, inputdim::Tuple, rng::AbstractRNG)
p, outputdim = _init_params!(td, p, first(inputdim), rng)
p, (outputdim, Base.tail(inputdim)...)
end
function _init_params!(td::TurboDense{true}, p, inputdim::Integer)
function _init_params!(td::TurboDense{true}, p, inputdim::Integer, rng::AbstractRNG)
W, p = getparams(td, p, inputdim)
outputdim = td.outputdim
glorot_normal!(view(W, :, 1:inputdim))
glorot_normal!(view(W, :, 1:inputdim), rng)
@turbo for i = 1:outputdim
W[i, inputdim+1] = 0
end
return p, outputdim
end
function _init_params!(td::TurboDense{false}, p, inputdim::Integer)
function _init_params!(td::TurboDense{false}, p, inputdim::Integer, rng::AbstractRNG)
W, p = getparams(td, p, inputdim)
glorot_normal!(W)
glorot_normal!(W, rng)
return p, td.outputdim
end

2 changes: 1 addition & 1 deletion src/dropout.jl
@@ -27,7 +27,7 @@ gradval(::Val{T}, d::Dropout) where {T} = T(0xffffffff) / (T(0xffffffff) - d.p)
numparam(::Dropout, id) = static(0), id
parameter_free(::Dropout) = true

init_params!(::Dropout, p, id) = p, id
init_params!(::Dropout, p, id, _) = p, id

function (d::Dropout)(B::AbstractVecOrMat{T}, p::Ptr, pu::Ptr{UInt8}) where {T}
x = muladd(T(d.p), -inv(T(typemax(UInt32))), one(T))
2 changes: 1 addition & 1 deletion src/flatten.jl
@@ -34,7 +34,7 @@ function forward_layer_output_size(::Val{T}, ::Flatten{N}, inputdim::Tuple) wher
static(0), getoutputdim(Flatten{N}(), inputdim)
end

init_params!(::Flatten{N}, p, id) where {N} = p, getoutputdim(Flatten{N}(), id)
init_params!(::Flatten{N}, p, id, ::AbstractRNG) where {N} = p, getoutputdim(Flatten{N}(), id)


numparam(::Flatten{N}, inputdim) where {N} = 0, getoutputdim(Flatten{N}(), inputdim)
2 changes: 1 addition & 1 deletion src/loss.jl
@@ -66,7 +66,7 @@ end
(::SquaredLoss)(y) = SquaredLoss(y)
SquaredLoss() = SquaredLoss(nothing)
target(sl::SquaredLoss) = getfield(sl, :y)
init_params!(::AbstractLoss, p, _) = p, 1
init_params!(::AbstractLoss, p, _, ::AbstractRNG) = p, 1

function Base.getindex(sl::SquaredLoss, r)
SquaredLoss(view_slice_last(target(sl), r))
2 changes: 1 addition & 1 deletion src/maxpool.jl
@@ -27,7 +27,7 @@ end
function getoutputdim(::MaxPool{D}, inputdim) where {D}
_maxpooloutputdim(map(StaticInt, D), inputdim)
end
init_params!(::MaxPool{D}, p, id) where {D} = p, getoutputdim(MaxPool{D}(), id)
init_params!(::MaxPool{D}, p, id, ::AbstractRNG) where {D} = p, getoutputdim(MaxPool{D}(), id)

numparam(mp::MaxPool, inputdim) = 0, getoutputdim(mp, inputdim)

4 changes: 2 additions & 2 deletions src/optimize.jl
@@ -694,7 +694,7 @@ end

for t ∈ [:train, :train_batched, :train_unbatched]
t! = Symbol(t, :!)
@eval function $t(chn::Chain, X, opt, iters)
$t!(init_params(chn), chn, X, opt, iters)
@eval function $t(chn::Chain, X, opt, iters; rng::AbstractRNG = local_rng())
$t!(init_params(chn, nothing, eltype(X); rng), chn, X, opt, iters)
end
end
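
A hedged sketch of the convenience wrappers above with the new keyword (the data, model, loss, optimizer, and iteration count are illustrative assumptions; only the `rng` keyword is what this commit adds):

using SimpleChains, Random

X = rand(Float32, 4, 100)   # illustrative inputs
Y = rand(Float32, 1, 100)   # illustrative targets
model = SimpleChain(static(4), TurboDense(tanh, 8), TurboDense(identity, 1))
loss = SimpleChains.add_loss(model, SquaredLoss(Y))

# The non-mutating train wrappers now forward `rng` to init_params,
# so the freshly drawn starting parameters are reproducible.
SimpleChains.train_unbatched(loss, X, SimpleChains.ADAM(), 1_000; rng = Random.Xoshiro(0))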
13 changes: 8 additions & 5 deletions src/penalty.jl
@@ -39,14 +39,17 @@ _type_sym(c::Chain) = __type_sym(remove_loss(c))
function init_params(
Λ::AbstractPenalty,
id::Union{Nothing,InputDim} = nothing,
::Type{T} = Float32,
::Type{T} = Float32;
rng::AbstractRNG=local_rng(),
) where {T}
init_params(getchain(Λ), id, T)
init_params(getchain(Λ), id, T; rng)
end
function init_params(Λ::AbstractPenalty, ::Type{T}; rng::AbstractRNG=local_rng()) where {T}
init_params(getchain(Λ), nothing, T; rng)
end
function init_params(Λ::AbstractPenalty, ::Type{T}) where {T}
init_params(getchain(Λ), nothing, T)
function init_params!(Λ::AbstractPenalty, x, id = nothing; rng::AbstractRNG=local_rng())
init_params!(getchain(Λ), x, id; rng)
end
init_params!(Λ::AbstractPenalty, x, id = nothing) = init_params!(getchain(Λ), x, id)

target(c::AbstractPenalty) = target(getchain(c))

24 changes: 14 additions & 10 deletions src/simple_chain.jl
@@ -387,31 +387,35 @@ end
Randomly initializes parameter vector `p` with input dim `id`. Input dim does not need to be specified if these were provided to the chain object itself.
See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions.
"""
function init_params!(chn::SimpleChain, x::AbstractVector, id = nothing)
GC.@preserve x init_params!(chn.layers, pointer(x), chain_input_dims(chn, id))
function init_params!(
chn::SimpleChain, x::AbstractVector, id = nothing; rng::AbstractRNG
)
GC.@preserve x init_params!(chn.layers, pointer(x), chain_input_dims(chn, id), rng)
return x
end
function init_params!(layers::Tuple, p::Ptr, id)
p, od = init_params!(first(layers), p, id)
init_params!(Base.tail(layers), p, od)
function init_params!(layers::Tuple, p::Ptr, id, rng::AbstractRNG)
p, od = init_params!(first(layers), p, id, rng)
init_params!(Base.tail(layers), p, od, rng)
end
init_params!(::Tuple{}, p::Ptr, _) = nothing
init_params!(::Tuple{}, p::Ptr, _, ::AbstractRNG) = nothing
function init_params(
Λ::SimpleChain,
id::Union{Nothing,InputDim} = nothing,
::Type{T} = Float32,
::Type{T} = Float32;
rng::AbstractRNG=local_rng()
) where {T}
_id = chain_input_dims(Λ, id)
init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id))
init_params!(Λ, StrideArray{T}(undef, numparam(Λ, id)), chain_input_dims(Λ, _id); rng)
end

"""
SimpleChains.init_params(chn[, id = nothing][, ::Type{T} = Float32])
Creates a parameter vector of element type `T` with size matching that by `id` (argument not required if provided to the `chain` object itself).
See the documentation of the individual layers to see how they are initialized, but it is generally via (Xavier) Glorot uniform or normal distributions.
"""
function init_params(Λ::SimpleChain, ::Type{T}) where {T}
init_params(Λ, nothing, T)
function init_params(Λ::SimpleChain, ::Type{T}; rng::AbstractRNG=local_rng()) where {T}
init_params(Λ, nothing, T; rng)
end

@inline function maybe_static_size_arg(s::Tuple, arg)
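
As the docstrings above note, initialization is layer-dependent (Glorot uniform or normal); a short sketch of both entry points with the new keyword, using an illustrative chain and seed:

using SimpleChains, Random

model = SimpleChain(static(4), TurboDense(tanh, 8), TurboDense(identity, 1))
rng = Random.Xoshiro(123)

# Allocating form: builds a fresh parameter vector and fills it.
p = SimpleChains.init_params(model, nothing, Float32; rng)

# In-place form: refills an existing vector of the right length.
SimpleChains.init_params!(model, p; rng)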
1 change: 0 additions & 1 deletion src/utils.jl
@@ -79,7 +79,6 @@ function glorot_uniform!(A::AbstractArray{T}, rng::VectorizedRNG.AbstractVRNG =
end
function glorot_uniform!(A::AbstractArray{T}, rng) where {T}
scale = @fastmath sqrt(T(24) / tssum(nfan(size(A)...)))
@show scale
# (rand()-0.5)*scale === rand()*scale - 0.5scale
rand!(rng, A)
@inbounds @fastmath for i = eachindex(A)
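
For reference, the rescaling comment above amounts to the usual Glorot/Xavier uniform bound; a small worked check with arbitrary fan-in/fan-out values (illustrative, not library code):

# scale = sqrt(24 / (fan_in + fan_out)), so u*scale - 0.5*scale lies in
# [-sqrt(6/(fan_in+fan_out)), +sqrt(6/(fan_in+fan_out))) for u in [0, 1).
fan_in, fan_out = 5, 3
scale = sqrt(24 / (fan_in + fan_out))   # == 2 * sqrt(6 / (fan_in + fan_out))
u = rand()
x = u * scale - 0.5 * scale
abs(x) <= sqrt(6 / (fan_in + fan_out))  # true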
2 changes: 1 addition & 1 deletion test/mnist.jl
@@ -27,7 +27,7 @@ lenetloss = SimpleChains.add_loss(lenet, LogitCrossEntropyLoss(ytrain1));
@test SimpleChains.outputdim(lenet, size(xtest4)) == (10, length(ytest1));

# initialize parameters
@time p = SimpleChains.init_params(lenet);
@time p = SimpleChains.init_params(lenet, rng = SimpleChains.local_rng());
@test all(isfinite, p)

@testset "Cache Corrupting Results" begin
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -80,7 +80,7 @@ InteractiveUtils.versioninfo(verbose=true)
# typename doesn't work on 1.5
@test_broken sprint((io, t) -> show(io, t), scflp) == print_str1
end
p = SimpleChains.init_params(scflp, T)
p = SimpleChains.init_params(scflp, T, rng = Random.default_rng())
g = similar(p)
let sc = SimpleChains.remove_loss(sc)
@test_throws ArgumentError sc(rand(T, 23, 2), p)

2 comments on commit 6fbf2b0

@korsbo (Member, Author) commented on 6fbf2b0, Sep 29, 2022


@JuliaRegistrator commented on 6fbf2b0

Registration pull request created: JuliaRegistries/General/69189

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.3 -m "<description of version>" 6fbf2b040e150498570f480aa491b483ff484a48
git push origin v0.3.3
