From efedbdf5dd67bbec65364111db8c2b2c2008d803 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 7 Apr 2024 08:15:45 +0200
Subject: [PATCH] more fixes

---
 docs/src/models/advanced.md                |  5 ++-
 docs/src/models/recurrence.md              |  7 ++-
 docs/src/training/reference.md             |  2 +-
 src/Flux.jl                                |  5 ++-
 src/deprecations.jl                        |  8 +---
 src/train.jl                               | 13 ++----
 test/{optimise.jl => TOREMOVE_optimise.jl} |  2 -
 test/data.jl                               | 10 +++--
 test/layers/basic.jl                       |  6 +--
 test/layers/conv.jl                        | 26 +++++------
 test/layers/recurrent.jl                   | 39 +----------------
 test/losses.jl                             | 12 ++---
 test/outputsize.jl                         |  6 +--
 test/runtests.jl                           |  2 -
 test/train.jl                              | 51 +++-------------------
 test/utils.jl                              |  6 +--
 16 files changed, 57 insertions(+), 143 deletions(-)
 rename test/{optimise.jl => TOREMOVE_optimise.jl} (99%)

diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md
index 6d78e1efef..9569944b2e 100644
--- a/docs/src/models/advanced.md
+++ b/docs/src/models/advanced.md
@@ -84,7 +84,10 @@ There is a second, more severe, kind of restriction possible. This is not recomm
 
 Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf).
 
-Naively, we could have a struct that stores the weights of along each path and implement the joining/splitting in the forward pass function. But that would mean a new struct any time the operations along each path changes. Instead, this guide will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path.
+We could have a struct that stores the weights along each path and implement the joining/splitting in the forward pass function. That would mean a new struct for each different block,
+e.g. one would have a `TransformerBlock` struct for a transformer block, and a `ResNetBlock` struct for a ResNet block, each block being composed of smaller sub-blocks. This is often the simplest and cleanest way to implement complex models.
+
+This guide instead will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path.
 
 ### Multiple inputs: a custom `Join` layer
 
diff --git a/docs/src/models/recurrence.md b/docs/src/models/recurrence.md
index d962f838d3..87cd944f4f 100644
--- a/docs/src/models/recurrence.md
+++ b/docs/src/models/recurrence.md
@@ -154,7 +154,7 @@ In such a model, only the last two outputs are used to compute the loss, hence t
 Alternatively, if one wants to perform some warmup of the sequence, it could be performed once, followed with a regular training where all the steps of the sequence would be considered for the gradient update:
 
 ```julia
-function loss(x, y)
+function loss(m, x, y)
   sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
 end
 
@@ -172,9 +172,8 @@ data = zip(X,Y)
 Flux.reset!(m)
 [m(x) for x in seq_init]
 
-ps = Flux.params(m)
-opt = Adam(1e-3)
-Flux.train!(loss, ps, data, opt)
+opt = Flux.setup(Adam(1e-3), m)
+Flux.train!(loss, m, data, opt)
 ```
 
 In this previous example, model's state is first reset with `Flux.reset!`. Then, there's a warmup that is performed over a sequence of length 1 by feeding it with `seq_init`, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (`seq_1` and `seq_2`) and all the timesteps outputs are considered for the loss.
diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md
index ee192cbbfa..67980831f9 100644
--- a/docs/src/training/reference.md
+++ b/docs/src/training/reference.md
@@ -14,7 +14,7 @@ The available optimization rules are listed the [optimisation rules](@ref man-op
 
 ```@docs
 Flux.Train.setup
-Flux.Train.train!(loss, model, data, state; cb)
+Flux.Train.train!(loss, model, data, state)
 Optimisers.update!
 ```
 
diff --git a/src/Flux.jl b/src/Flux.jl
index 487f2cba8e..af4c35cbb4 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -12,6 +12,7 @@ using MLUtils
 const stack = MLUtils.stack # now exported by Base
 @reexport using Optimisers
 import Optimisers: trainable
+using Optimisers: update!, trainables
 using Random: default_rng
 using Zygote, ChainRulesCore
 using Zygote: Params, @adjoint, gradient, pullback
@@ -43,7 +44,7 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion
 
 include("train.jl")
 using .Train
-using .Train: setup
+using .Train: setup, train!
 
 using Adapt, Functors, OneHotArrays
 include("utils.jl")
@@ -55,7 +56,7 @@ include("functor.jl")
   # from Functors.jl
   functor, @functor,
   # from Train/Optimisers.jl
-  setup, update!, destructure, freeze!, thaw!, adjust!, params, trainable
+  setup, update!, destructure, freeze!, thaw!, adjust!, trainable, trainables
 ))
 
 # Pirate error to catch a common mistake.
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 9f3db4c3c4..7fe25f1d0e 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -1,13 +1,9 @@
 # v0.15 deprecations
 
-# Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc:
-# Base.@deprecate_binding Optimiser OptimiserChain
-# Base.@deprecate_binding ClipValue ClipGrad
-
-train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
+Train.train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
   """On Flux 0.15, `train!` no longer accepts implicit `Zygote.Params`.
   Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
-  it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
+  it now needs `opt_state = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt_state)`
   where `loss_mxy` accepts the model as its first argument.
   """
 ))
diff --git a/src/train.jl b/src/train.jl
index 5e71f41f80..5bc0c1b58b 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -8,21 +8,14 @@ using ..Flux: Flux # used only in docstring
 export setup, train!
 
 using ProgressLogging: @progress, @withprogress, @logprogress
-using Zygote: Zygote, Params
+using Zygote: Zygote
 
 """
     opt_state = setup(rule, model)
 
 This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
-It differs from `Optimisers.setup` in that it:
-* has one extra check for mutability (since Flux expects to mutate the model in-place,
-  while Optimisers.jl is designed to return an updated model)
-* has methods which accept Flux's old optimisers, and convert them.
-  (The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.)
-
-!!! compat "New"
-    This function was added in Flux 0.13.9. It was not used by the old "implicit"
-    interface, using `Flux.Optimise` module and [`Flux.params`](@ref).
+It differs from `Optimisers.setup` in that it has one extra check for mutability (since Flux expects to mutate the model in-place,
+  while Optimisers.jl is designed to return an updated model).
 
 # Example
 ```jldoctest
diff --git a/test/optimise.jl b/test/TOREMOVE_optimise.jl
similarity index 99%
rename from test/optimise.jl
rename to test/TOREMOVE_optimise.jl
index 68240ec25e..9702b0906d 100644
--- a/test/optimise.jl
+++ b/test/TOREMOVE_optimise.jl
@@ -1,5 +1,3 @@
-using Flux.Optimise
-using Flux.Optimise: runall
 using Flux: Params, gradient
 import FillArrays, ComponentArrays
 import Optimisers
diff --git a/test/data.jl b/test/data.jl
index b97c4dae80..1274bdcf1d 100644
--- a/test/data.jl
+++ b/test/data.jl
@@ -80,18 +80,20 @@ using Random
     # test interaction with `train!`
     θ = ones(2)
     X = zeros(2, 10)
-    loss(x) = sum((x .- θ).^2)
+    loss(θ, x) = sum((x .- θ).^2)
     d = DataLoader(X)
-    Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
+    opt = Flux.setup(Descent(0.1), θ)
+    Flux.train!(loss, θ, ncycle(d, 10), opt)
     @test norm(θ) < 1e-4
 
     # test interaction with `train!`
    θ = zeros(2)
     X = ones(2, 10)
     Y = fill(2, 10)
-    loss(x, y) = sum((y - x'*θ).^2)
+    loss(θ, x, y) = sum((y - x'*θ).^2)
     d = DataLoader((X, Y))
-    Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
+    opt = Flux.setup(Descent(0.1), θ)
+    Flux.train!(loss, θ, ncycle(d, 10), opt)
     @test norm(θ .- 1) < 1e-10
 
     # specify the rng
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index a1ad34b093..16c08a0f54 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -80,7 +80,7 @@ using Flux: activations
     @test size(Dense(10 => 5)(randn(10,2))) == (5,2)
     @test size(Dense(10 => 5)(randn(10,2,3))) == (5,2,3)
     @test size(Dense(10 => 5)(randn(10,2,3,4))) == (5,2,3,4)
-    @test_throws DimensionMismatch Dense(10, 5)(randn(11,2,3))
+    @test_throws DimensionMismatch Dense(10 => 5)(randn(11,2,3))
   end
   @testset "zeros" begin
     @test Dense(10 => 1, identity, init = ones)(ones(10,1)) == 10*ones(1, 1)
@@ -156,9 +156,9 @@ using Flux: activations
       @test mo(input) == target
     end
 
-  @testset "params" begin
+  @testset "trainables" begin
     mo = Maxout(()->Dense(32 => 64), 4)
-    ps = Flux.params(mo)
+    ps = Flux.trainables(mo)
     @test length(ps) == 8  #4 alts, each with weight and bias
   end
 end
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index 920231bbdf..bd180ce893 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -36,35 +36,35 @@ end
   @test size(m(r)) == (10, 5)
 
   # Test bias switch
-  bias = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
+  m2 = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
   ip = zeros(Float32, 28,28,1,1)
 
-  op = bias(ip)
+  op = m2(ip)
   @test sum(op) == prod(size(op))
 
   @testset "No bias mapped through $lmap" for lmap in (identity, cpu, f32)
-    bias = Conv((2,2), 1=>3, bias = false) |> lmap
-    op = bias(ip)
+    m3 = Conv((2,2), 1=>3, bias = false) |> lmap
+    op = m3(ip)
     @test sum(op) ≈ 0.f0
-    gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
-    @test bias.bias ∉ gs.params
+    gs = gradient(m -> sum(m(ip)), m3)[1]
+    @test gs.bias === nothing
   end
 
   # Train w/o bias and make sure no convergence happens
   # when only bias can be converged
-  bias = Conv((2, 2), 1=>3, bias = false);
+  m4 = Conv((2, 2), 1=>3, bias = false);
   ip = zeros(Float32, 28,28,1,1)
   op = zeros(Float32, 27,27,3,1) .+ 2.f0
-  opt = Descent()
+  opt_state = Flux.setup(Descent(), m4)
 
   for _ = 1:10^3
-    gs = gradient(Flux.params(bias)) do
-      Flux.Losses.mse(bias(ip), op)
-    end
-    Flux.Optimise.update!(opt, params(bias), gs)
+    gs = gradient(m4) do m
+      Flux.mse(m(ip), op)
+    end[1]
+    Flux.update!(opt_state, m4, gs)
   end
 
-  @test Flux.Losses.mse(bias(ip), op) ≈ 4.f0
+  @test Flux.Losses.mse(m4(ip), op) ≈ 4.f0
 
   @testset "Grouped Conv" begin
     ip = rand(Float32, 28, 100, 2)
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 7df8b0d4c2..ab8cfb3bca 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -1,39 +1,8 @@
 using LinearAlgebra
 
-@testset "RNN gradients-implicit" begin
-  layer = Flux.Recur(Flux.RNNCell(1, 1, identity))
-  layer.cell.Wi .= 5.0
-  layer.cell.Wh .= 4.0
-  layer.cell.b .= 0.0f0
-  layer.cell.state0 .= 7.0
-  x = [[2.0f0], [3.0f0]]
-
-  # theoretical primal gradients
-  primal =
-    layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
-    x[2] .* layer.cell.Wi
-  ∇Wi = x[1] .* layer.cell.Wh .+ x[2]
-  ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
-  ∇b = layer.cell.Wh .+ 1
-  ∇state0 = layer.cell.Wh .^ 2
-
-  Flux.reset!(layer)
-  ps = Flux.params(layer)
-  e, g = Flux.withgradient(ps) do
-    out = [layer(xi) for xi in x]
-    sum(out[2])
-  end
-
-  @test primal[1] ≈ e
-  @test ∇Wi ≈ g[ps[1]]
-  @test ∇Wh ≈ g[ps[2]]
-  @test ∇b ≈ g[ps[3]]
-  @test ∇state0 ≈ g[ps[4]]
-
-end
-@testset "RNN gradients-explicit" begin
-  layer = Flux.Recur(Flux.RNNCell(1, 1, identity))
+@testset "RNN gradients" begin
+  layer = Flux.Recur(Flux.RNNCell(1 => 1, identity))
   layer.cell.Wi .= 5.0f0
   layer.cell.Wh .= 4.0f0
   layer.cell.b .= 0.0f0
@@ -138,19 +107,15 @@ end
 
 @testset for R in [RNN, GRU, LSTM, GRUv3]
   m1 = R(3 => 5)
   m2 = R(3 => 5)
-  m3 = R(3, 5)  # leave one to test the silently deprecated "," not "=>" notation
   x1 = rand(Float32, 3)
   x2 = rand(Float32, 3, 1)
   x3 = rand(Float32, 3, 1, 2)
   Flux.reset!(m1)
   Flux.reset!(m2)
-  Flux.reset!(m3)
   @test size(m1(x1)) == (5,)
   @test size(m1(x1)) == (5,)  # repeat in case of effect from change in state shape
   @test size(m2(x2)) == (5, 1)
   @test size(m2(x2)) == (5, 1)
-  @test size(m3(x3)) == (5, 1, 2)
-  @test size(m3(x3)) == (5, 1, 2)
 end
 end
diff --git a/test/losses.jl b/test/losses.jl
index a5ce1139df..1d745c9d20 100644
--- a/test/losses.jl
+++ b/test/losses.jl
@@ -76,7 +76,7 @@ y_dis[1,:], y_dis[2,:] = y_dis[2,:], y_dis[1,:]
   @test crossentropy(ŷ, y_smoothed) ≈ lossvalue_smoothed
   @test crossentropy(ylp, label_smoothing(yl, 2sf)) ≈ -sum(yls.*log.(ylp))
   @test crossentropy(ylp, yl) ≈ -sum(yl.*log.(ylp))
-  @test iszero(crossentropy(y_same, ya, ϵ=0)) # ε is deprecated
+  @test iszero(crossentropy(y_same, ya, eps=0)) # ε is deprecated
   @test iszero(crossentropy(ya, ya, eps=0))
   @test crossentropy(y_sim, ya) < crossentropy(y_sim, ya_smoothed)
   @test crossentropy(y_dis, ya) > crossentropy(y_dis, ya_smoothed)
@@ -92,15 +92,15 @@ logŷ, y = randn(3), rand(3)
 yls = y.*(1-2sf).+sf
 
 @testset "binarycrossentropy" begin
-  @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); ϵ=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ))
+  @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); eps=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ))
   @test binarycrossentropy(σ.(logŷ), y; eps=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ)))
   @test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ))))
   @test binarycrossentropy([0.1,0.2,0.9], 1) ≈ -mean(log, [0.1,0.2,0.9]) # constant label
 end
 
 @testset "logitbinarycrossentropy" begin
-  @test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), label_smoothing(y, 0.2); ϵ=0)
-  @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; ϵ=0)
+  @test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), label_smoothing(y, 0.2); eps=0)
+  @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; eps=0)
 end
 
 y = onehotbatch([1], 0:1)
@@ -152,7 +152,7 @@ end
 
 @testset "tversky_loss" begin
   @test Flux.tversky_loss(ŷ, y) ≈ -0.06772009029345383
-  @test Flux.tversky_loss(ŷ, y, β=0.8) ≈ -0.09490740740740744
+  @test Flux.tversky_loss(ŷ, y, beta=0.8) ≈ -0.09490740740740744
   @test Flux.tversky_loss(y, y) ≈ -0.5576923076923075
 end
 
@@ -180,7 +180,7 @@ end
         0.4 0.7]
   @test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
   @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222
-  @test Flux.binary_focal_loss(ŷ, y; γ=0.0) ≈ Flux.binarycrossentropy(ŷ, y)
+  @test Flux.binary_focal_loss(ŷ, y; gamma=0.0) ≈ Flux.binarycrossentropy(ŷ, y)
 end
 
 @testset "focal_loss" begin
diff --git a/test/outputsize.jl b/test/outputsize.jl
index 6249043bb3..fe2316cdf0 100644
--- a/test/outputsize.jl
+++ b/test/outputsize.jl
@@ -2,11 +2,11 @@
   m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
   @test outputsize(m, (10, 10, 3, 1)) == (6, 6, 32, 1)
 
-  m = Dense(10, 5)
+  m = Dense(10 => 5)
   @test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1)
   @test outputsize(m, (10,); padbatch=true) == (5, 1)
 
-  m = Chain(Dense(10, 8, σ), Dense(8 => 5), Dense(5 => 2))
+  m = Chain(Dense(10 => 8, σ), Dense(8 => 5), Dense(5 => 2))
   @test outputsize(m, (10,); padbatch=true) == (2, 1)
   @test outputsize(m, (10, 30)) == (2, 30)
@@ -168,7 +168,7 @@ end
   m = @autosize (3,) Dense(_ => 4)
   @test randn(3) |> m |> size == (4,)
 
-  m = @autosize (3, 1) Chain(Dense(_, 4), Dense(4 => 10), softmax)
+  m = @autosize (3, 1) Chain(Dense(_ => 4), Dense(4 => 10), softmax)
   @test randn(3, 5) |> m |> size == (10, 5)
 
   m = @autosize (2, 3, 4, 5) Dense(_ => 10)  # goes by first dim, not 2nd-last
diff --git a/test/runtests.jl b/test/runtests.jl
index 39013b84e6..1a999868ed 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,5 @@
 using Flux
 using Flux: OneHotArray, OneHotMatrix, OneHotVector
-using Flux: params
 using Test
 using Random, Statistics, LinearAlgebra
 using IterTools: ncycle
@@ -26,7 +25,6 @@ Random.seed!(0)
   end
 
   @testset "Optimise / Train" begin
-    include("optimise.jl")
     include("train.jl")
     include("tracker.jl")
   end
diff --git a/test/train.jl b/test/train.jl
index 1d938649d0..5a326dd553 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -17,8 +17,8 @@ using Random
     model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
     @test loss(model, rand(10, 10)) > 1
 
-    opt = Flux.setup(rule, model)
-    Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
+    opt_state = Flux.setup(rule, model)
+    Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt_state)
     @test loss(model, rand(10, 10)) < 0.01
   end
@@ -54,49 +54,8 @@ end
     Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt)
     @test loss(model, rand(10, 10)) < 0.01
   end
-
-  @testset "callbacks give helpful error" begin
-    m1 = Dense(1 => 1)
-    cb = () -> println("this should not be printed")
-    @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb)
-  end
 end
 
-@testset "Explicit Flux.update! features" begin
-  m = Chain(Dense(2=>3, tanh), Dense(3=>1), only)
-  x = rand(2)
-  y1 = m(x) # before
-
-  # Implicit gradient
-  gold = gradient(() -> m(x), Flux.params(m))
-  @test gold isa Flux.Zygote.Grads
-  @test_throws ErrorException Flux.update!(Flux.Adam(), m, gold) # friendly
-  Flux.update!(Flux.Adam(), Flux.params(m), gold)
-  y2 = m(x)
-  @test y2 < y1
-
-  # Explicit gradient
-  gs = gradient(marg -> marg(x), m)
-  @test gs isa Tuple
-  @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs) # friendly
-  @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1]) # friendly
-  @test_throws ErrorException Flux.update!(Flux.Adam(), m, gs) # friendly
-  @test_throws ErrorException Flux.update!(Flux.Adam(), m, gs[1]) # friendly
-  s = Flux.setup(Adam(), m)
-  @info "ignore this warning, just testing an upgrade path:"
-  Flux.update!(s, m, gs) # Chain + Tuple can be unambiguously sorted out
-  y3 = m(x)
-  @test y3 < y2
-  Flux.update!(s, m, gs[1]) # finally, this is the correct thing
-  y4 = m(x)
-  @test y4 < y3
-
-  # Also check that if you import the new Adam, then Flux.setup does still work!
-  s2 = Flux.setup(Optimisers.Adam(), m)
-  Flux.update!(s2, m, gs[1])
-  y5 = m(x)
-  @test y5 < y4
-end
 
 @testset "L2 regularisation" begin
   # New docs claim an exact equivalent. It's a bit long to put the example in there,
@@ -115,14 +74,14 @@ end
   end
   diff1 = model.weight .- init_weight
 
-  # Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
+  # Take 2: the same, but with Flux.trainables.
   model.weight .= init_weight
   model.bias .= 0
   pen2(x::AbstractArray) = sum(abs2, x)/2
   opt = Flux.setup(Adam(0.1), model)
   Flux.train!(model, data, opt) do m, x, y
     err = Flux.mse(m(x), y)
-    l2 = sum(pen2, Flux.params(m))
+    l2 = sum(pen2, Flux.trainables(m))
     err + 0.33 * l2
   end
   diff2 = model.weight .- init_weight
@@ -143,6 +102,6 @@ end
   # https://github.com/FluxML/Flux.jl/issues/2144
   @test Flux.setup(Flux.Adam(), Embedding(3 => 1)).weight isa Optimisers.Leaf
   # Typo in 0.13.9's deprecation
-  @test Flux.setup(Flux.ClipValue(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipGrad
+  @test Flux.setup(Flux.ClipGrad(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipGrad
   @test Flux.setup(Flux.ClipNorm(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipNorm
 end
diff --git a/test/utils.jl b/test/utils.jl
index b03696ae9b..2bfaad07bd 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -469,7 +469,7 @@ end
   mod_par = Flux.modules(Parallel(Flux.Bilinear((2,2) => 2,cbrt), Dense(2 => 2,abs), Dense(2 => 2,abs2)))
   @test length(mod_par) == 5
 
-  mod_rnn = Flux.modules(Chain(Dense(2 => 3), BatchNorm(3), LSTM(3,4)))
+  mod_rnn = Flux.modules(Chain(Dense(2 => 3), BatchNorm(3), LSTM(3 => 4)))
   @test length(mod_rnn) == 6
   @test mod_rnn[end] isa Flux.LSTMCell
 
@@ -648,7 +648,7 @@ end
   data = rand(Float32, n_input, n_batch)
 
   model = Chain(
-    Dense(n_input, n_shared),
+    Dense(n_input => n_shared),
     Split(Dense(n_shared => n_outputs[1]), Dense(n_shared => n_outputs[2]))
   )
 
@@ -662,6 +662,6 @@ end
 
 # make sure rng_from_array is non_differentiable
 @testset "rng_from_array" begin
-  m(x) = (rand(rng_from_array(x)) * x)[1]
+  m(x) = (rand(Flux.rng_from_array(x)) * x)[1]
   gradient(m, ones(2))
 end
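
For reference, the explicit training pattern that these doc and test changes converge on looks roughly like the sketch below. It is illustrative only (the model, data, and hyperparameters are made up, not taken from the patch); it uses the public `Flux.setup`, `Flux.train!`, `gradient`, and `Flux.update!` API rather than the removed implicit `Flux.params` path.

```julia
using Flux

# Toy model and synthetic data, purely for illustration.
model = Chain(Dense(4 => 8, relu), Dense(8 => 1))
data  = [(randn(Float32, 4, 16), randn(Float32, 1, 16)) for _ in 1:100]

# The loss now takes the model as its first argument.
loss(m, x, y) = Flux.mse(m(x), y)

# Build the optimiser state once, from a rule and the model.
opt_state = Flux.setup(Adam(1e-3), model)

# Either let train! run the loop ...
Flux.train!(loss, model, data, opt_state)

# ... or write it out explicitly with gradient/update!.
for (x, y) in data
    grads = gradient(m -> loss(m, x, y), model)[1]  # explicit gradient w.r.t. the model
    Flux.update!(opt_state, model, grads)           # mutates model and opt_state in place
end
```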