From 56f36c2594e49c52b9311a5f19b9d8ec9d41a253 Mon Sep 17 00:00:00 2001
From: RohitRathore1
Date: Fri, 11 Aug 2023 17:10:38 +0530
Subject: [PATCH] Support for lecun normal weight initialization

---
 src/utils.jl | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/src/utils.jl b/src/utils.jl
index 082d9dcb1c..7e85dc195d 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -252,6 +252,48 @@ truncated_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwa
 
 ChainRulesCore.@non_differentiable truncated_normal(::Any...)
 
+"""
+    lecun_normal([rng], size...; gain = 1) -> Array
+    lecun_normal([rng]; kw...) -> Function
+
+Return an `Array{Float32}` of the given `size` containing random numbers drawn from a truncated normal
+distribution centered on 0 with stddev `gain * sqrt(1 / fan_in)`, where `fan_in` is the number of input
+units in the weight tensor.
+
+# Examples
+```jldoctest; setup = :(using Random; Random.seed!(0))
+julia> using Statistics
+
+julia> round(std(Flux.lecun_normal(10, 1000)), digits=3)
+0.032f0
+
+julia> round(std(Flux.lecun_normal(1000, 10)), digits=3)
+0.317f0
+
+julia> round(std(Flux.lecun_normal(1000, 1000)), digits=3)
+0.032f0
+
+julia> Dense(10 => 1000, selu; init = Flux.lecun_normal())
+Dense(10 => 1000, selu)  # 11_000 parameters
+
+julia> round(std(ans.weight), sigdigits=3)
+0.319f0
+```
+
+# References
+
+[1] LeCun, Yann, et al. "Efficient backprop." Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 9-48.
+"""
+function lecun_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
+  std = Float32(gain)*sqrt(1.0f0 / first(nfan(dims...))) # stddev = gain * sqrt(1 / fan_in), taking fan_in as the first value returned by `nfan`
+  return truncated_normal(rng, dims...; mean=0, std=std)
+end
+
+lecun_normal(dims::Integer...; kwargs...) = lecun_normal(default_rng(), dims...; kwargs...)
+lecun_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> lecun_normal(rng, dims...; init_kwargs..., kwargs...)
+
+ChainRulesCore.@non_differentiable lecun_normal(::Any...)
+
 """
     orthogonal([rng], size...; gain = 1) -> Array
     orthogonal([rng]; kw...) -> Function