-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Lux tests
* Set up extension
* Test deps
* LuxTestUtils
* Add scenarios
* Fix x
* Lux tests working with ComponentArrays
* Fix Flux
* More tolerant lux tests
* Adapt to extensions
* Rng
* Chill rtol
- Loading branch information
Showing 7 changed files with 214 additions and 4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
using Pkg | ||
Pkg.add(["FiniteDiff", "Lux", "LuxTestUtils", "Zygote"]) | ||
|
||
using ComponentArrays: ComponentArrays | ||
using DifferentiationInterface, DifferentiationInterfaceTest | ||
import DifferentiationInterfaceTest as DIT | ||
using FiniteDiff: FiniteDiff | ||
using Lux: Lux | ||
using LuxTestUtils: LuxTestUtils | ||
using Random | ||
|
||
# Seed the global RNG; the scenarios themselves are built from the explicit
# `Xoshiro(63)` generator below.
Random.seed!(0)

# Check the Zygote backend on the Lux scenarios, using the Lux-specific
# comparison functions and Float32 tolerances.
# NOTE(review): `LOGGING` is not defined in this file; presumably it is set by
# the script that includes this one — confirm.
let backend = AutoZygote(), scenarios = DIT.lux_scenarios(Random.Xoshiro(63))
    test_differentiation(
        backend,
        scenarios;
        isequal=DIT.lux_isequal,
        isapprox=DIT.lux_isapprox,
        rtol=1.0f-2,
        atol=1.0f-3,
        logging=LOGGING,
    )
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
138 changes: 138 additions & 0 deletions
138
...nterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
module DifferentiationInterfaceTestLuxExt | ||
|
||
using Compat: @compat | ||
using ComponentArrays: ComponentArray | ||
import DifferentiationInterface as DI | ||
using DifferentiationInterfaceTest | ||
import DifferentiationInterfaceTest as DIT | ||
using FiniteDiff: FiniteDiff | ||
using Lux | ||
using LuxTestUtils | ||
using LuxTestUtils: check_approx | ||
using Random: AbstractRNG, default_rng | ||
|
||
#= | ||
Relevant discussions: | ||
- https://github.com/LuxDL/Lux.jl/issues/769 | ||
=# | ||
|
||
"""
    DIT.lux_isequal(a, b)

Compare `a` and `b` by delegating to `LuxTestUtils.check_approx` with both the
absolute and the relative tolerance set to zero.
"""
DIT.lux_isequal(a, b) = check_approx(a, b; atol=0, rtol=0)
|
||
"""
    DIT.lux_isapprox(a, b; atol, rtol)

Compare `a` and `b` by delegating to `LuxTestUtils.check_approx`, forwarding the
tolerances unchanged. Both `atol` and `rtol` are required keyword arguments.
"""
DIT.lux_isapprox(a, b; atol, rtol) = check_approx(a, b; atol=atol, rtol=rtol)
|
||
"""
    SquareLoss(model, x, st)

Callable container bundling a model with a fixed input `x` and a fixed state
`st`, so that the only remaining argument of the loss is the parameter object.
"""
struct SquareLoss{M,X,S}
    # Model that is called as `model(x, ps, st)`.
    model::M
    # Input passed to the model on every call.
    x::X
    # State passed to the model on every call.
    st::S
end
|
||
"""
    (sql::SquareLoss)(ps)

Call the stored model as `model(x, ps, st)`, keep the first element of what it
returns, and return the sum of the squared absolute values of that element.
"""
function (sql::SquareLoss)(ps)
    # Only the first element of the model's return value is used; the rest is dropped.
    output = first(sql.model(sql.x, ps, sql.st))
    return sum(abs2, output)
end
|
||
"""
    DIT.lux_scenarios(rng::AbstractRNG=default_rng())

Build a `Vector{Scenario}` with one out-of-place `GradientScenario` per
`(model, input)` pair from a fixed list of Lux layers.

For each pair, parameters and state come from `Lux.setup(rng, model)`, the
parameters are wrapped in a `ComponentArray`, and the objective is a
`SquareLoss` over the model, input and state. The reference gradient stored in
the scenario is computed with `DI.AutoFiniteDiff()`.

`rng` is used first to draw every input array, then to initialise each model in
list order.
"""
function DIT.lux_scenarios(rng::AbstractRNG=default_rng())
    # All inputs are drawn here, before any model is set up, so the order in
    # which `rng` is consumed is fixed by this list.
    models_and_xs = [
        (Dense(2, 4), randn(rng, Float32, 2, 3)),
        (Dense(2, 4, gelu), randn(rng, Float32, 2, 3)),
        (Dense(2, 4, gelu; use_bias=false), randn(rng, Float32, 2, 3)),
        (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(rng, Float32, 2, 3)),
        (Scale(2), randn(rng, Float32, 2, 3)),
        (Conv((3, 3), 2 => 3), randn(rng, Float32, 3, 3, 2, 2)),
        (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(rng, Float32, 3, 3, 2, 2)),
        (
            Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()),
            randn(rng, Float32, 3, 3, 2, 2),
        ),
        (
            Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)),
            rand(rng, Float32, 5, 5, 2, 2),
        ),
        (
            Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())),
            rand(rng, Float32, 5, 5, 2, 2),
        ),
        (
            Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))),
            rand(rng, Float32, 5, 5, 2, 2),
        ),
        (Maxout(() -> Dense(5 => 4, tanh), 3), randn(rng, Float32, 5, 2)),
        (Bilinear((2, 2) => 3), randn(rng, Float32, 2, 3)),
        (SkipConnection(Dense(2 => 2), vcat), randn(rng, Float32, 2, 3)),
        (ConvTranspose((3, 3), 3 => 2; stride=2), rand(rng, Float32, 5, 5, 3, 1)),
        (StatefulRecurrentCell(RNNCell(3 => 5)), rand(rng, Float32, 3, 2)),
        (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(rng, Float32, 3, 2)),
        (
            StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)),
            rand(rng, Float32, 3, 2),
        ),
        (
            Chain(
                StatefulRecurrentCell(RNNCell(3 => 5)),
                StatefulRecurrentCell(RNNCell(5 => 3)),
            ),
            rand(rng, Float32, 3, 2),
        ),
        (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(rng, Float32, 3, 2)),
        (
            Chain(
                StatefulRecurrentCell(LSTMCell(3 => 5)),
                StatefulRecurrentCell(LSTMCell(5 => 3)),
            ),
            rand(rng, Float32, 3, 2),
        ),
        (StatefulRecurrentCell(GRUCell(3 => 5)), rand(rng, Float32, 3, 10)),
        (
            Chain(
                StatefulRecurrentCell(GRUCell(3 => 5)),
                StatefulRecurrentCell(GRUCell(5 => 3)),
            ),
            rand(rng, Float32, 3, 10),
        ),
        (Chain(Dense(2, 4), BatchNorm(4)), randn(rng, Float32, 2, 3)),
        (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(rng, Float32, 2, 3)),
        (
            Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)),
            randn(rng, Float32, 2, 3),
        ),
        (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 2)),
        (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 2)),
        (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(rng, Float32, 2, 3)),
        (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(rng, Float32, 2, 3)),
        (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(rng, Float32, 6, 6, 2, 2)),
        (
            Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)),
            randn(rng, Float32, 6, 6, 2, 2),
        ),
        (
            Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))),
            randn(rng, Float32, 4, 4, 2, 2),
        ),
        (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(rng, Float32, 6, 6, 2, 2)),
        (
            Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)),
            randn(rng, Float32, 6, 6, 2, 2),
        ),
    ]

    # Turn one (model, input) pair into a gradient scenario over its parameters.
    function build_scenario(model, input)
        raw_params, state = Lux.setup(rng, model)
        params = ComponentArray(raw_params)
        objective = SquareLoss(model, input, state)
        value = objective(params)
        # Reference gradient obtained by finite differences.
        reference = DI.gradient(objective, DI.AutoFiniteDiff(), params)
        return GradientScenario(
            objective; x=params, y=value, grad=reference, nb_args=1, place=:outofplace
        )
    end

    # The typed comprehension runs sequentially, so models are initialised in list order.
    return Scenario[build_scenario(model, input) for (model, input) in models_and_xs]
end
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters