From 5419f9adac89d8b881a0f91c32a10de015a45c4b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 19 Jul 2024 10:23:42 +0200 Subject: [PATCH] Debug Flux tests (#371) * Add Tracker tests * Remove Tracker --- DifferentiationInterface/test/Down/Flux/Project.toml | 1 + DifferentiationInterface/test/Down/Flux/test.jl | 5 +++-- DifferentiationInterfaceTest/Project.toml | 3 ++- DifferentiationInterfaceTest/test/weird.jl | 7 +++++-- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/DifferentiationInterface/test/Down/Flux/Project.toml b/DifferentiationInterface/test/Down/Flux/Project.toml index 865584771..2fa78f6ac 100644 --- a/DifferentiationInterface/test/Down/Flux/Project.toml +++ b/DifferentiationInterface/test/Down/Flux/Project.toml @@ -4,4 +4,5 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl index 30e42a00c..a09d8be8f 100644 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ b/DifferentiationInterface/test/Down/Flux/test.jl @@ -1,12 +1,13 @@ using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT +using Enzyme: Enzyme using FiniteDifferences: FiniteDifferences using Flux: Flux -using Enzyme: Enzyme +using Random using Zygote: Zygote using Test -Enzyme.API.runtimeActivity!(true) +Random.seed!(0) test_differentiation( [ diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 8cb7764d2..c12be09e5 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -68,6 +68,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" @@ -76,4 +77,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] +test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index 0be6680ee..6171e3a11 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -7,6 +7,7 @@ using FiniteDifferences: FiniteDifferences using Flux: Flux using ForwardDiff: ForwardDiff using JLArrays: JLArrays +using Random using SparseConnectivityTracer using SparseMatrixColorings using StaticArrays: StaticArrays @@ -23,12 +24,14 @@ test_differentiation( AutoZygote(), gpu_scenarios(); correctness=true, second_order=false, logging=LOGGING ) +Random.seed!(0) + test_differentiation( AutoZygote(), DIT.flux_scenarios(); isequal=DIT.flux_isequal, isapprox=DIT.flux_isapprox, - rtol=5e-2, - atol=1e-2, + rtol=1e-2, + atol=1e-6, logging=LOGGING, )