Skip to content

Commit

Permalink
Add extras type test
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Apr 25, 2024
1 parent cc94ef9 commit 0a484ac
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
12 changes: 11 additions & 1 deletion DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,17 @@ using DifferentiationInterface:
mutation_support,
pushforward_performance,
pullback_performance
using DifferentiationInterface: NoPullbackExtras, NoPushforwardExtras
using DifferentiationInterface:
DerivativeExtras,
GradientExtras,
HessianExtras,
HVPExtras,
JacobianExtras,
PullbackExtras,
PushforwardExtras,
NoPullbackExtras,
NoPushforwardExtras,
SecondDerivativeExtras
using DocStringExtensions
import DifferentiationInterface as DI
using JET: @test_call, @test_opt
Expand Down
72 changes: 72 additions & 0 deletions DifferentiationInterfaceTest/src/tests/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ function test_correctness(
dy2 = pushforward(f, ba, x, dx, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa PushforwardExtras
end
@testset "Primal value" begin
@test y1 y
end
Expand Down Expand Up @@ -66,6 +69,9 @@ function test_correctness(
dy2 = pushforward!(f, dy2_in, ba, x, dx, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa PushforwardExtras
end
@testset "Primal value" begin
@test y1 y
end
Expand Down Expand Up @@ -106,6 +112,9 @@ function test_correctness(
dy2 = pushforward(f!, y2_in, ba, x, dx, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa PushforwardExtras
end
@testset "Primal value" begin
@test y1_in y
@test y1 y
Expand Down Expand Up @@ -145,6 +154,9 @@ function test_correctness(
dy2 = pushforward!(f!, y2_in, dy2_in, ba, x, dx, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa PushforwardExtras
end
@testset "Primal value" begin
@test y1_in y
@test y1 y
Expand Down Expand Up @@ -187,6 +199,9 @@ function test_correctness(
dx3 = pullbackfunc(dy)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa PullbackExtras
end
@testset "Primal value" begin
@test y1 y
@test y3 y
Expand Down Expand Up @@ -229,6 +244,9 @@ function test_correctness(
dx3 = pullbackfunc!(dx3_in, dy)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa PullbackExtras
end
@testset "Primal value" begin
@test y1 y
@test y3 y
Expand Down Expand Up @@ -278,6 +296,9 @@ function test_correctness(
dx3 = pullbackfunc(y3_in2, dy)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa PullbackExtras
end
@testset "Primal value" begin
@test y1_in y
@test y1 y
Expand Down Expand Up @@ -326,6 +347,9 @@ function test_correctness(
dx3 = pullbackfunc!(y3_in2, dx3_in, dy)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa PullbackExtras
end
@testset "Primal value" begin
@test y1_in y
@test y1 y
Expand Down Expand Up @@ -368,6 +392,9 @@ function test_correctness(
der2 = derivative(f, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa DerivativeExtras
end
@testset "Primal value" begin
@test y1 y
end
Expand Down Expand Up @@ -403,6 +430,9 @@ function test_correctness(
der2 = derivative!(f, der2_in, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa DerivativeExtras
end
@testset "Primal value" begin
@test y1 y
end
Expand Down Expand Up @@ -441,6 +471,9 @@ function test_correctness(
der2 = derivative(f!, y2_in, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa DerivativeExtras
end
@testset "Primal value" begin
@test y1_in y
@test y1 y
Expand Down Expand Up @@ -478,6 +511,9 @@ function test_correctness(
der2 = derivative!(f!, y2_in, der2_in, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa DerivativeExtras
end
@testset "Primal value" begin
@test y1_in y
@test y1 y
Expand Down Expand Up @@ -516,6 +552,9 @@ function test_correctness(
grad2 = gradient(f, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa GradientExtras
end
@testset "Primal value" begin
@test y1 y
end
Expand Down Expand Up @@ -551,6 +590,9 @@ function test_correctness(
grad2 = gradient!(f, grad2_in, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa GradientExtras
end
@testset "Primal value" begin
@test y1 y
end
Expand Down Expand Up @@ -588,6 +630,9 @@ function test_correctness(
jac2 = jacobian(f, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa JacobianExtras
end
@testset "Primal value" begin
@test y1 y
end
Expand Down Expand Up @@ -623,6 +668,9 @@ function test_correctness(
jac2 = jacobian!(f, jac2_in, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa JacobianExtras
end
@testset "Primal value" begin
@test y1 y
end
Expand Down Expand Up @@ -661,6 +709,9 @@ function test_correctness(
jac2 = jacobian(f!, y2_in, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa JacobianExtras
end
@testset "Primal value" begin
@test y1_in y
@test y1 y
Expand Down Expand Up @@ -698,6 +749,9 @@ function test_correctness(
jac2 = jacobian!(f!, y2_in, jac2_in, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa JacobianExtras
end
@testset "Primal value" begin
@test y1_in y
@test y1 y
Expand Down Expand Up @@ -734,6 +788,9 @@ function test_correctness(
der21 = second_derivative(f, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa SecondDerivativeExtras
end
@testset "Second derivative value" begin
@test der21 der2_true
end
Expand Down Expand Up @@ -762,6 +819,9 @@ function test_correctness(
der21 = second_derivative!(f, der21_in, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa SecondDerivativeExtras
end
@testset "Second derivative value" begin
@test der21_in der2_true
@test der21 der2_true
Expand Down Expand Up @@ -792,6 +852,9 @@ function test_correctness(
p1 = hvp(f, ba, x, dx, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa HVPExtras
end
@testset "HVP value" begin
@test p1 p_true
end
Expand Down Expand Up @@ -820,6 +883,9 @@ function test_correctness(
p1 = hvp!(f, p1_in, ba, x, dx, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa HVPExtras
end
@testset "HVP value" begin
@test p1_in p_true
@test p1 p_true
Expand Down Expand Up @@ -850,6 +916,9 @@ function test_correctness(
hess1 = hessian(f, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa HessianExtras
end
@testset "Hessian value" begin
@test hess1 hess_true
end
Expand Down Expand Up @@ -878,6 +947,9 @@ function test_correctness(
hess1 = hessian!(f, hess1_in, ba, x, extras)

let ()(x, y) = isapprox(x, y; atol, rtol)
@testset "Extras type" begin
@test extras isa HessianExtras
end
@testset "Hessian value" begin
@test hess1_in hess_true
@test hess1 hess_true
Expand Down

0 comments on commit 0a484ac

Please sign in to comment.