
Commit

more fixes
CarloLucibello committed Apr 7, 2024
1 parent 63b7613 commit efedbdf
Showing 16 changed files with 57 additions and 143 deletions.
5 changes: 4 additions & 1 deletion docs/src/models/advanced.md
@@ -84,7 +84,10 @@ There is a second, more severe, kind of restriction possible. This is not recomm

Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there are multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in the machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf).

Naively, we could have a struct that stores the weights along each path and implements the joining/splitting in the forward pass function. But that would mean a new struct any time the operations along each path change. Instead, this guide will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path.
We could have a struct that stores the weights along each path and implements the joining/splitting in the forward pass function. That would mean a new struct for each different block,
e.g. a `TransformerBlock` struct for a transformer block and a `ResNetBlock` struct for a ResNet block, each block composed of smaller sub-blocks. This is often the simplest and cleanest way to implement complex models.
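
For concreteness, here is a minimal sketch of that struct-per-block approach. The `TwoPathBlock` name and the layer sizes below are invented for illustration and are not part of Flux's API:

```julia
using Flux

# One field per path; the forward pass joins the two outputs.
struct TwoPathBlock{P1,P2}
  path1::P1
  path2::P2
end

Flux.@functor TwoPathBlock  # expose the fields to Flux's training machinery

(m::TwoPathBlock)(x) = vcat(m.path1(x), m.path2(x))

block = TwoPathBlock(Dense(4 => 3, relu), Dense(4 => 2))
size(block(rand(Float32, 4, 5)))  # (5, 5): 3 + 2 output rows, batch of 5
```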

This guide will instead show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers, one for each path.

### Multiple inputs: a custom `Join` layer

7 changes: 3 additions & 4 deletions docs/src/models/recurrence.md
@@ -154,7 +154,7 @@ In such a model, only the last two outputs are used to compute the loss, hence t
Alternatively, if one wants to perform some warmup of the sequence, it can be performed once, followed by regular training where all the steps of the sequence are considered for the gradient update:

```julia
function loss(x, y)
function loss(m, x, y)
sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
end

@@ -172,9 +172,8 @@ data = zip(X,Y)
Flux.reset!(m)
[m(x) for x in seq_init]

ps = Flux.params(m)
opt = Adam(1e-3)
Flux.train!(loss, ps, data, opt)
opt = Flux.setup(Adam(1e-3), m)
Flux.train!(loss, m, data, opt)
```

In the previous example, the model's state is first reset with `Flux.reset!`. Then a warmup is performed over a sequence of length 1 by feeding it `seq_init`, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (`seq_1` and `seq_2`) and all the timestep outputs are considered for the loss.
2 changes: 1 addition & 1 deletion docs/src/training/reference.md
@@ -14,7 +14,7 @@ The available optimization rules are listed on the [optimisation rules](@ref man-op

```@docs
Flux.Train.setup
Flux.Train.train!(loss, model, data, state; cb)
Flux.Train.train!(loss, model, data, state)
Optimisers.update!
```
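
As a rough illustration of how these pieces fit together (the model, loss and data below are placeholders, not part of the documented API):

```julia
using Flux, Optimisers

model = Dense(2 => 1)
data = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:10]

opt_state = Flux.setup(Adam(0.01), model)        # Flux.Train.setup
for (x, y) in data
  grads = Flux.gradient(m -> Flux.mse(m(x), y), model)[1]
  Optimisers.update!(opt_state, model, grads)    # mutates opt_state and the model
end
```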

5 changes: 3 additions & 2 deletions src/Flux.jl
@@ -12,6 +12,7 @@ using MLUtils
const stack = MLUtils.stack # now exported by Base
@reexport using Optimisers
import Optimisers: trainable
using Optimisers: update!, trainables
using Random: default_rng
using Zygote, ChainRulesCore
using Zygote: Params, @adjoint, gradient, pullback
@@ -43,7 +44,7 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion

include("train.jl")
using .Train
using .Train: setup
using .Train: setup, train!

using Adapt, Functors, OneHotArrays
include("utils.jl")
@@ -55,7 +56,7 @@ include("functor.jl")
# from Functors.jl
functor, @functor,
# from Train/Optimisers.jl
setup, update!, destructure, freeze!, thaw!, adjust!, params, trainable
setup, update!, destructure, freeze!, thaw!, adjust!, trainable, trainables
))

# Pirate error to catch a common mistake.
8 changes: 2 additions & 6 deletions src/deprecations.jl
@@ -1,13 +1,9 @@
# v0.15 deprecations

# Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc:
# Base.@deprecate_binding Optimiser OptimiserChain
# Base.@deprecate_binding ClipValue ClipGrad

train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
Train.train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError(
"""On Flux 0.15, `train!` no longer accepts implicit `Zygote.Params`.
Instead of `train!(loss_xy, Flux.params(model), data, Adam())`
it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)`
it now needs `opt_state = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt_state)`
where `loss_mxy` accepts the model as its first argument.
"""
))
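
A hedged sketch of the migration this error message asks for; the model, loss and data below are placeholders:

```julia
using Flux

model = Dense(3 => 1)
data = [(rand(Float32, 3, 4), rand(Float32, 1, 4))]

# Old implicit style, now an error:
#   Flux.train!((x, y) -> Flux.mse(model(x), y), Flux.params(model), data, Adam())

# New explicit style: the loss takes the model as its first argument.
loss_mxy(m, x, y) = Flux.mse(m(x), y)
opt_state = Flux.setup(Adam(), model)
Flux.train!(loss_mxy, model, data, opt_state)
```
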
13 changes: 3 additions & 10 deletions src/train.jl
@@ -8,21 +8,14 @@ using ..Flux: Flux # used only in docstring
export setup, train!

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote, Params
using Zygote: Zygote

"""
opt_state = setup(rule, model)
This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
It differs from `Optimisers.setup` in that it:
* has one extra check for mutability (since Flux expects to mutate the model in-place,
while Optimisers.jl is designed to return an updated model)
* has methods which accept Flux's old optimisers, and convert them.
(The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.)
!!! compat "New"
This function was added in Flux 0.13.9. It was not used by the old "implicit"
interface, using `Flux.Optimise` module and [`Flux.params`](@ref).
It differs from `Optimisers.setup` in that it has one extra check for mutability (since Flux expects to mutate the model in-place,
while Optimisers.jl is designed to return an updated model).
# Example
```jldoctest
2 changes: 0 additions & 2 deletions test/optimise.jl → test/TOREMOVE_optimise.jl
@@ -1,5 +1,3 @@
using Flux.Optimise
using Flux.Optimise: runall
using Flux: Params, gradient
import FillArrays, ComponentArrays
import Optimisers
10 changes: 6 additions & 4 deletions test/data.jl
@@ -80,18 +80,20 @@ using Random
# test interaction with `train!`
θ = ones(2)
X = zeros(2, 10)
loss(x) = sum((x .- θ).^2)
loss(θ, x) = sum((x .- θ).^2)
d = DataLoader(X)
Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
opt = Flux.setup(Descent(0.1), θ)
Flux.train!(loss, θ, ncycle(d, 10), opt)
@test norm(θ) < 1e-4

# test interaction with `train!`
θ = zeros(2)
X = ones(2, 10)
Y = fill(2, 10)
loss(x, y) = sum((y - x'*θ).^2)
loss(θ, x, y) = sum((y - x'*θ).^2)
d = DataLoader((X, Y))
Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1))
opt = Flux.setup(Descent(0.1), θ)
Flux.train!(loss, θ, ncycle(d, 10), opt)
@test norm(θ .- 1) < 1e-10

# specify the rng
6 changes: 3 additions & 3 deletions test/layers/basic.jl
@@ -80,7 +80,7 @@ using Flux: activations
@test size(Dense(10 => 5)(randn(10,2))) == (5,2)
@test size(Dense(10 => 5)(randn(10,2,3))) == (5,2,3)
@test size(Dense(10 => 5)(randn(10,2,3,4))) == (5,2,3,4)
@test_throws DimensionMismatch Dense(10, 5)(randn(11,2,3))
@test_throws DimensionMismatch Dense(10 => 5)(randn(11,2,3))
end
@testset "zeros" begin
@test Dense(10 => 1, identity, init = ones)(ones(10,1)) == 10*ones(1, 1)
@@ -156,9 +156,9 @@ using Flux: activations
@test mo(input) == target
end

@testset "params" begin
@testset "trainables" begin
mo = Maxout(()->Dense(32 => 64), 4)
ps = Flux.params(mo)
ps = Flux.trainables(mo)
@test length(ps) == 8 #4 alts, each with weight and bias
end
end
26 changes: 13 additions & 13 deletions test/layers/conv.jl
@@ -36,35 +36,35 @@ end
@test size(m(r)) == (10, 5)

# Test bias switch
bias = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
m2 = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3))
ip = zeros(Float32, 28,28,1,1)

op = bias(ip)
op = m2(ip)
@test sum(op) == prod(size(op))

@testset "No bias mapped through $lmap" for lmap in (identity, cpu, f32)
bias = Conv((2,2), 1=>3, bias = false) |> lmap
op = bias(ip)
m3 = Conv((2,2), 1=>3, bias = false) |> lmap
op = m3(ip)
@test sum(op) ≈ 0.f0
gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
@test bias.bias ∉ gs.params
gs = gradient(m -> sum(m(ip)), m3)[1]
@test gs.bias === nothing
end

# Train w/o bias and make sure no convergence happens
# when only bias can be converged
bias = Conv((2, 2), 1=>3, bias = false);
m4 = Conv((2, 2), 1=>3, bias = false);
ip = zeros(Float32, 28,28,1,1)
op = zeros(Float32, 27,27,3,1) .+ 2.f0
opt = Descent()
opt_state = Flux.setup(Descent(), m4)

for _ = 1:10^3
gs = gradient(Flux.params(bias)) do
Flux.Losses.mse(bias(ip), op)
end
Flux.Optimise.update!(opt, params(bias), gs)
gs = gradient(m4) do m
Flux.mse(m(ip), op)
end[1]
Flux.update!(opt_state, m4, gs)
end

@test Flux.Losses.mse(bias(ip), op) ≈ 4.f0
@test Flux.Losses.mse(m4(ip), op) ≈ 4.f0

@testset "Grouped Conv" begin
ip = rand(Float32, 28, 100, 2)
39 changes: 2 additions & 37 deletions test/layers/recurrent.jl
@@ -1,39 +1,8 @@
using LinearAlgebra

@testset "RNN gradients-implicit" begin
layer = Flux.Recur(Flux.RNNCell(1, 1, identity))
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2

Flux.reset!(layer)
ps = Flux.params(layer)
e, g = Flux.withgradient(ps) do
out = [layer(xi) for xi in x]
sum(out[2])
end

@test primal[1] ≈ e
@test ∇Wi ≈ g[ps[1]]
@test ∇Wh ≈ g[ps[2]]
@test ∇b ≈ g[ps[3]]
@test ∇state0 ≈ g[ps[4]]

end

@testset "RNN gradients-explicit" begin
layer = Flux.Recur(Flux.RNNCell(1, 1, identity))
@testset "RNN gradients" begin
layer = Flux.Recur(Flux.RNNCell(1 => 1, identity))
layer.cell.Wi .= 5.0f0
layer.cell.Wh .= 4.0f0
layer.cell.b .= 0.0f0
@@ -138,19 +107,15 @@ end
@testset for R in [RNN, GRU, LSTM, GRUv3]
m1 = R(3 => 5)
m2 = R(3 => 5)
m3 = R(3, 5) # leave one to test the silently deprecated "," not "=>" notation
x1 = rand(Float32, 3)
x2 = rand(Float32, 3, 1)
x3 = rand(Float32, 3, 1, 2)
Flux.reset!(m1)
Flux.reset!(m2)
Flux.reset!(m3)
@test size(m1(x1)) == (5,)
@test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
@test size(m2(x2)) == (5, 1)
@test size(m2(x2)) == (5, 1)
@test size(m3(x3)) == (5, 1, 2)
@test size(m3(x3)) == (5, 1, 2)
end
end

12 changes: 6 additions & 6 deletions test/losses.jl
@@ -76,7 +76,7 @@ y_dis[1,:], y_dis[2,:] = y_dis[2,:], y_dis[1,:]
@test crossentropy(ŷ, y_smoothed) ≈ lossvalue_smoothed
@test crossentropy(ylp, label_smoothing(yl, 2sf)) ≈ -sum(yls.*log.(ylp))
@test crossentropy(ylp, yl) ≈ -sum(yl.*log.(ylp))
@test iszero(crossentropy(y_same, ya, ϵ=0)) # ε is deprecated
@test iszero(crossentropy(y_same, ya, eps=0)) # ε is deprecated
@test iszero(crossentropy(ya, ya, eps=0))
@test crossentropy(y_sim, ya) < crossentropy(y_sim, ya_smoothed)
@test crossentropy(y_dis, ya) > crossentropy(y_dis, ya_smoothed)
@@ -92,15 +92,15 @@ logŷ, y = randn(3), rand(3)
yls = y.*(1-2sf).+sf

@testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); ϵ=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ))
@test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); eps=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ))
@test binarycrossentropy(σ.(logŷ), y; eps=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ)))
@test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ))))
@test binarycrossentropy([0.1,0.2,0.9], 1) ≈ -mean(log, [0.1,0.2,0.9]) # constant label
end

@testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), label_smoothing(y, 0.2); ϵ=0)
@test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; ϵ=0)
@test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), label_smoothing(y, 0.2); eps=0)
@test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; eps=0)
end

y = onehotbatch([1], 0:1)
@@ -152,7 +152,7 @@ end

@testset "tversky_loss" begin
@test Flux.tversky_loss(ŷ, y) ≈ -0.06772009029345383
@test Flux.tversky_loss(ŷ, y, β=0.8) ≈ -0.09490740740740744
@test Flux.tversky_loss(ŷ, y, beta=0.8) ≈ -0.09490740740740744
@test Flux.tversky_loss(y, y) ≈ -0.5576923076923075
end

@@ -180,7 +180,7 @@ end
0.4 0.7]
@test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
@test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222
@test Flux.binary_focal_loss(ŷ, y; γ=0.0) ≈ Flux.binarycrossentropy(ŷ, y)
@test Flux.binary_focal_loss(ŷ, y; gamma=0.0) ≈ Flux.binarycrossentropy(ŷ, y)
end

@testset "focal_loss" begin
6 changes: 3 additions & 3 deletions test/outputsize.jl
@@ -2,11 +2,11 @@
m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32))
@test outputsize(m, (10, 10, 3, 1)) == (6, 6, 32, 1)

m = Dense(10, 5)
m = Dense(10 => 5)
@test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1)
@test outputsize(m, (10,); padbatch=true) == (5, 1)

m = Chain(Dense(10, 8, σ), Dense(8 => 5), Dense(5 => 2))
m = Chain(Dense(10 => 8, σ), Dense(8 => 5), Dense(5 => 2))
@test outputsize(m, (10,); padbatch=true) == (2, 1)
@test outputsize(m, (10, 30)) == (2, 30)

@@ -168,7 +168,7 @@ end
m = @autosize (3,) Dense(_ => 4)
@test randn(3) |> m |> size == (4,)

m = @autosize (3, 1) Chain(Dense(_, 4), Dense(4 => 10), softmax)
m = @autosize (3, 1) Chain(Dense(_ => 4), Dense(4 => 10), softmax)
@test randn(3, 5) |> m |> size == (10, 5)

m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last
2 changes: 0 additions & 2 deletions test/runtests.jl
@@ -1,6 +1,5 @@
using Flux
using Flux: OneHotArray, OneHotMatrix, OneHotVector
using Flux: params
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
@@ -26,7 +25,6 @@ Random.seed!(0)
end

@testset "Optimise / Train" begin
include("optimise.jl")
include("train.jl")
include("tracker.jl")
end