Skip to content

Commit

Permalink
Better error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Apr 19, 2024
1 parent 8ed0129 commit 32ca659
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ DocMeta.setdocmeta!(
@testset "JET tests" begin
JET.test_package(SparseConnectivityTracer; target_defined_modules=true)
end
@testset verbose = true "Order classification" begin
include("order.jl")
@testset verbose = true "Classification of operators by diff'ability" begin
include("test_differentiability.jl")
end
@testset "Connectivity" begin
x = rand(3)
Expand Down
29 changes: 15 additions & 14 deletions test/order.jl → test/test_differentiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function differentiability(∂f∂x, ∂²f∂x²; atol=DEFAULT_ATOL)
end
end

## 1 to 1
## 1-to-1

function classify_1_to_1(f, x; atol=DEFAULT_ATOL)
∂f∂x = derivative(f, x)
Expand All @@ -82,12 +82,12 @@ function classify_1_to_1(f, x; atol=DEFAULT_ATOL)
return order
end

classify_1_to_1(op::Symbol; kwargs...) = classify_1_to_1(sym2fn(op); kwargs...)
function classify_1_to_1(f::Function; atol=DEFAULT_ATOL, trials=100)
function classify_1_to_1(op::Symbol; atol=DEFAULT_ATOL, trials=100)
f = sym2fn(op)
try
return maximum(classify_1_to_1(f, random_input(f); atol) for _ in 1:trials)
catch e
@warn "Classification failed" e
@warn "Classification of 1-to-1 operator $op failed" e
return error_order
end
end
Expand All @@ -98,7 +98,7 @@ const TEST_1_TO_1 = (
("Zero order", ops_1_to_1_z, zero_order),
("Constant", ops_1_to_1_const, zero_order),
)
@testset verbose = true "1 to 1" begin
@testset verbose = true "1-to-1" begin
@testset "All operators covered" begin
all_ops = union([ops for (name, ops, ref_order) in TEST_1_TO_1]...)
@test Set(all_ops) == Set(ops_1_to_1)
Expand All @@ -114,7 +114,7 @@ const TEST_1_TO_1 = (
end
end;

## 2 to 1
## 2-to-1

function classify_2_to_1(f, x, y; atol)
g = gradient(Base.splat(f), [x, y])
Expand All @@ -137,14 +137,15 @@ function classify_2_to_1(f, x, y; atol)
end

classify_2_to_1(op::Symbol; kwargs...) = classify_2_to_1(sym2fn(op); kwargs...)
function classify_2_to_1(f; atol=1e-5, trials=100)
function classify_2_to_1(op::Symbol; atol=1e-5, trials=100)
f = sym2fn(op)
try
return maximum(
classify_2_to_1(f, random_first_input(f), random_second_input(f); atol) for
_ in 1:trials
)
catch e
@warn "Classification failed" e
@warn "Classification of 2-to-1 operator `$op` failed" e
return (error_order, error_order, error_order)
end
end
Expand All @@ -164,7 +165,7 @@ const TEST_2_TO_1 = (
("zfz", ops_2_to_1_zfz, (zero_order, second_order, zero_order)),
("zzz", ops_2_to_1_zzz, (zero_order, zero_order, zero_order)),
)
@testset verbose = true "2 to 1" begin
@testset verbose = true "2-to-1" begin
@testset "All operators covered" begin
all_ops = union([ops for (name, ops, ref_order) in TEST_2_TO_1]...)
@test Set(all_ops) == Set(ops_2_to_1)
Expand All @@ -180,7 +181,7 @@ const TEST_2_TO_1 = (
end
end;

## 1 to 2
## 1-to-2

function classify_1_to_2(f, x; atol)
d1 = derivative(f, x)
Expand All @@ -198,12 +199,12 @@ function classify_1_to_2(f, x; atol)
return (first_arg, second_arg)
end

classify_1_to_2(op::Symbol; kwargs...) = classify_1_to_2(sym2fn(op); kwargs...)
function classify_1_to_2(f; atol=1e-5, trials=100)
function classify_1_to_2(op::Symbol; atol=1e-5, trials=100)
f = sym2fn(op)
try
return maximum(classify_1_to_1(f, random_input(f); atol) for _ in 1:trials)
catch e
@warn "Classification failed" e
@warn "Classification of 1-to-2 operator `$op` failed" e
return (error_order, error_order)
end
end
Expand All @@ -219,7 +220,7 @@ const TEST_1_TO_2 = (
("zf", ops_1_to_2_zf, (zero_order, second_order)),
("zz", ops_1_to_2_zz, (zero_order, zero_order)),
)
@testset verbose = true "1 to 2" begin
@testset verbose = true "1-to-2" begin
@testset "All operators covered" begin
all_ops = union([ops for (name, ops, ref_order) in TEST_1_TO_2]...)
@test Set(all_ops) == Set(ops_1_to_2)
Expand Down

0 comments on commit 32ca659

Please sign in to comment.