Skip to content

Commit

Permalink
Debug Flux tests (#371)
Browse files Browse the repository at this point in the history
* Add Tracker tests

* Remove Tracker
  • Loading branch information
gdalle authored Jul 19, 2024
1 parent a78c820 commit 5419f9a
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 5 deletions.
1 change: 1 addition & 0 deletions DifferentiationInterface/test/Down/Flux/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5 changes: 3 additions & 2 deletions DifferentiationInterface/test/Down/Flux/test.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using Enzyme: Enzyme
using FiniteDifferences: FiniteDifferences
using Flux: Flux
using Enzyme: Enzyme
using Random
using Zygote: Zygote
using Test

Enzyme.API.runtimeActivity!(true)
Random.seed!(0)

test_differentiation(
[
Expand Down
3 changes: 2 additions & 1 deletion DifferentiationInterfaceTest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Expand All @@ -76,4 +77,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"]
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"]
7 changes: 5 additions & 2 deletions DifferentiationInterfaceTest/test/weird.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using FiniteDifferences: FiniteDifferences
using Flux: Flux
using ForwardDiff: ForwardDiff
using JLArrays: JLArrays
using Random
using SparseConnectivityTracer
using SparseMatrixColorings
using StaticArrays: StaticArrays
Expand All @@ -23,12 +24,14 @@ test_differentiation(
AutoZygote(), gpu_scenarios(); correctness=true, second_order=false, logging=LOGGING
)

Random.seed!(0)

test_differentiation(
AutoZygote(),
DIT.flux_scenarios();
isequal=DIT.flux_isequal,
isapprox=DIT.flux_isapprox,
rtol=5e-2,
atol=1e-2,
rtol=1e-2,
atol=1e-6,
logging=LOGGING,
)

0 comments on commit 5419f9a

Please sign in to comment.