
Commit

Fix links
avik-pal committed Aug 26, 2023
1 parent 4f2aa0a commit 33e400a
Showing 2 changed files with 45 additions and 108 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -1,8 +1,8 @@
# LuxTestUtils.jl

[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](http://lux.csail.mit.edu/dev/api/)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/api/)

[![CI](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxTestUtils.jl/actions/workflows/CI.yml)

149 changes: 43 additions & 106 deletions src/LuxTestUtils.jl
@@ -28,18 +28,20 @@ or julia version is < 1.7, then the macro will be a no-op.
All additional arguments will be forwarded to `@JET.test_call` and `@JET.test_opt`.
!!! note
::: note
Instead of specifying `target_modules` with every call, you can set preferences for
`target_modules` using `Preferences.jl`. For example, to set `target_modules` to
`(Lux, LuxLib)` we can run:
Instead of specifying `target_modules` with every call, you can set preferences for
`target_modules` using `Preferences.jl`. For example, to set `target_modules` to
`(Lux, LuxLib)` we can run:
```julia
using Preferences
```julia
using Preferences
set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"),
"target_modules" => ["Lux", "LuxLib"])
```
set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"),
"target_modules" => ["Lux", "LuxLib"])
```
:::
## Example
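
(The example block itself is collapsed in this diff view; below is only an illustrative sketch of how `@jet` is typically invoked, assuming a simple standalone function. It is not the collapsed content of the file, and it requires JET to load successfully on Julia ≥ 1.7, per the docstring above.)

```julia
using LuxTestUtils, Test

f(x) = sum(abs2, x)

@testset "@jet sketch" begin
    # Runs JET's call analysis and optimization analysis on this call;
    # a no-op if JET fails to compile or the Julia version is < 1.7.
    @jet f(rand(3))
end
```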
@@ -81,16 +83,10 @@ macro jet(expr, args...)

push!(all_args, expr)

ex_call = JET.call_test_ex(:report_call,
Symbol("@test_call"),
vcat(call_extras, all_args),
__module__,
__source__)
ex_opt = JET.call_test_ex(:report_opt,
Symbol("@test_opt"),
vcat(opt_extras, all_args),
__module__,
__source__)
ex_call = JET.call_test_ex(:report_call, Symbol("@test_call"),
vcat(call_extras, all_args), __module__, __source__)
ex_opt = JET.call_test_ex(:report_opt, Symbol("@test_opt"),
vcat(opt_extras, all_args), __module__, __source__)

return Expr(:block, ex_call, ex_opt)
end
@@ -110,8 +106,7 @@ struct GradientComputationSkipped end
end
end

function check_approx(x::LuxCore.AbstractExplicitLayer,
y::LuxCore.AbstractExplicitLayer;
function check_approx(x::LuxCore.AbstractExplicitLayer, y::LuxCore.AbstractExplicitLayer;
kwargs...)
return x == y
end
@@ -122,8 +117,7 @@ function check_approx(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...)
check_approx(x.state, y.state; kwargs...)
end

function check_approx(nt1::NamedTuple{fields},
nt2::NamedTuple{fields};
function check_approx(nt1::NamedTuple{fields}, nt2::NamedTuple{fields};
kwargs...) where {fields}
_check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...)
_check_approx(t::Tuple{Nothing, Nothing}) = true
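
(For orientation, here is a standalone sketch of the fieldwise comparison idea this method implements; the helper name `fieldwise_approx` is illustrative and not part of the package.)

```julia
# Illustrative only: compare two NamedTuples with identical fields entry-by-entry.
fieldwise_approx(a::NamedTuple{F}, b::NamedTuple{F}; kwargs...) where {F} =
    all(isapprox(getfield(a, f), getfield(b, f); kwargs...) for f in F)

fieldwise_approx((; weight=[1.0, 2.0], bias=0.5),
    (; weight=[1.0, 2.0 + 1e-9], bias=0.5); atol=1e-6)  # true
```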
@@ -227,10 +221,7 @@ macro test_gradients(all_args...)
return test_gradients_expr(__module__, __source__, args...; kwargs...)
end

function test_gradients_expr(__module__,
__source__,
f,
args...;
function test_gradients_expr(__module__, __source__, f, args...;
gpu_testing::Bool=false,
soft_fail::Bool=false,
# Skip Gradient Computation
@@ -255,29 +246,20 @@ function test_gradients_expr(__module__,
nans::Bool=false,
kwargs...)
orig_exprs = map(x -> QuoteNode(Expr(:macrocall,
GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")),
__source__,
f,
args...)),
GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), __source__, f, args...)),
("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences"))
len = length(args)
__source__ = QuoteNode(__source__)
return quote
gs_zygote = __gradient(Zygote.gradient,
$(esc(f)),
$(esc.(args)...);
gs_zygote = __gradient(Zygote.gradient, $(esc(f)), $(esc.(args)...);
skip=$skip_zygote)

gs_tracker = __gradient(Base.Fix1(broadcast, Tracker.data) ∘ Tracker.gradient,
$(esc(f)),
$(esc.(args)...);
skip=$skip_tracker)
$(esc(f)), $(esc.(args)...); skip=$skip_tracker)
tracker_broken = $(tracker_broken && !skip_tracker)

skip_reverse_diff = $(skip_reverse_diff || gpu_testing)
gs_rdiff = __gradient(_rdiff_gradient,
$(esc(f)),
$(esc.(args)...);
gs_rdiff = __gradient(_rdiff_gradient, $(esc(f)), $(esc.(args)...);
skip=skip_reverse_diff)
reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff

@@ -289,82 +271,38 @@ function test_gradients_expr(__module__,
@debug "Large arrays detected. Skipping some tests based on keyword arguments."
end

skip_forward_diff = $skip_forward_diff ||
$gpu_testing ||
skip_forward_diff = $skip_forward_diff || $gpu_testing ||
(large_arrays && $large_arrays_skip_forward_diff)
gs_fdiff = __gradient(_fdiff_gradient,
$(esc(f)),
$(esc.(args)...);
gs_fdiff = __gradient(_fdiff_gradient, $(esc(f)), $(esc.(args)...);
skip=skip_forward_diff)
forward_diff_broken = $forward_diff_broken && !skip_forward_diff

skip_finite_differences = $skip_finite_differences ||
$gpu_testing ||
skip_finite_differences = $skip_finite_differences || $gpu_testing ||
(large_arrays && $large_arrays_skip_finite_differences)
gs_finite_diff = __gradient(_finitedifferences_gradient,
$(esc(f)),
$(esc.(args)...);
skip=skip_finite_differences)
gs_finite_diff = __gradient(_finitedifferences_gradient, $(esc(f)),
$(esc.(args)...); skip=skip_finite_differences)
finite_differences_broken = $finite_differences_broken && !skip_finite_differences

for idx in 1:($len)
__test_gradient_pair_check($__source__,
$(orig_exprs[1]),
gs_zygote[idx],
gs_tracker[idx],
"Zygote",
"Tracker";
broken=tracker_broken,
soft_fail=$soft_fail,
atol=$atol,
rtol=$rtol,
nans=$nans)
__test_gradient_pair_check($__source__,
$(orig_exprs[2]),
gs_zygote[idx],
gs_rdiff[idx],
"Zygote",
"ReverseDiff";
broken=reverse_diff_broken,
soft_fail=$soft_fail,
atol=$atol,
rtol=$rtol,
nans=$nans)
__test_gradient_pair_check($__source__,
$(orig_exprs[3]),
gs_zygote[idx],
gs_fdiff[idx],
"Zygote",
"ForwardDiff";
broken=forward_diff_broken,
soft_fail=$soft_fail,
atol=$atol,
rtol=$rtol,
nans=$nans)
__test_gradient_pair_check($__source__,
$(orig_exprs[4]),
gs_zygote[idx],
gs_finite_diff[idx],
"Zygote",
"FiniteDifferences";
broken=finite_differences_broken,
soft_fail=$soft_fail,
atol=$atol,
rtol=$rtol,
nans=$nans)
__test_gradient_pair_check($__source__, $(orig_exprs[1]), gs_zygote[idx],
gs_tracker[idx], "Zygote", "Tracker"; broken=tracker_broken,
soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans)
__test_gradient_pair_check($__source__, $(orig_exprs[2]), gs_zygote[idx],
gs_rdiff[idx], "Zygote", "ReverseDiff"; broken=reverse_diff_broken,
soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans)
__test_gradient_pair_check($__source__, $(orig_exprs[3]), gs_zygote[idx],
gs_fdiff[idx], "Zygote", "ForwardDiff"; broken=forward_diff_broken,
soft_fail=$soft_fail, atol=$atol, rtol=$rtol, nans=$nans)
__test_gradient_pair_check($__source__, $(orig_exprs[4]), gs_zygote[idx],
gs_finite_diff[idx], "Zygote", "FiniteDifferences";
broken=finite_differences_broken, soft_fail=$soft_fail, atol=$atol,
rtol=$rtol, nans=$nans)
end
end
end
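
(A hedged usage sketch of the macro generated above, assuming the `@test_gradients f args... kwargs...` surface syntax and the keyword arguments visible in this diff, such as `atol` and `rtol`; exact defaults and behavior may differ from this illustration.)

```julia
using LuxTestUtils, Test

loss(w, x) = sum(abs2, w .* x)
w, x = rand(4), rand(4)

@testset "gradient cross-check sketch" begin
    # Zygote gradients for each argument of `loss` are compared against
    # Tracker, ReverseDiff, ForwardDiff and FiniteDifferences.
    @test_gradients loss w x atol=1.0e-5 rtol=1.0e-5
end
```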

function __test_gradient_pair_check(__source__,
orig_expr,
v1,
v2,
name1,
name2;
broken::Bool=false,
soft_fail::Bool=false,
kwargs...)
function __test_gradient_pair_check(__source__, orig_expr, v1, v2, name1, name2;
broken::Bool=false, soft_fail::Bool=false, kwargs...)
match = check_approx(v1, v2; kwargs...)
test_type = Symbol("@test_gradients{$name1, $name2}")

@@ -452,8 +390,7 @@ function _fdiff_gradient(f, args...)
end

function _finitedifferences_gradient(f, args...)
return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1),
f,
return _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), f,
args...))
end
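
(A small standalone check of what the `central_fdm(3, 1)` rule used above computes: a 3-point central estimate of the first derivative. `FiniteDifferences.grad` returns one gradient per positional argument, which the helper then converts with `_named_tuple`.)

```julia
using FiniteDifferences

f(x) = sum(abs2, x)

# grad returns a tuple with one gradient per argument; unpack the first.
g, = FiniteDifferences.grad(FiniteDifferences.central_fdm(3, 1), f, [1.0, 2.0, 3.0])
g ≈ [2.0, 4.0, 6.0]  # true, since ∇f(x) = 2x
```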

