diff --git a/src/code_transformation/differentiable.jl b/src/code_transformation/differentiable.jl index 710a0ba1..4b023066 100644 --- a/src/code_transformation/differentiable.jl +++ b/src/code_transformation/differentiable.jl @@ -49,8 +49,8 @@ get_quote_body(code::QuoteNode) = code.value Unionise the code inside a call to `eval`, such that when the `eval` call actually occurs the code inside will be unionised. """ -function unionise_eval(code::Expr) - body = Expr(:macrocall, Symbol("@unionise"), nothing, deepcopy(get_quote_body(code.args[end]))) +function unionise_eval(code::Expr, linfo::LineNumberNode=LineNumberNode(0)) + body = Expr(:macrocall, Symbol("@unionise"), linfo, deepcopy(get_quote_body(code.args[end]))) return length(code.args) == 3 ? Expr(:call, :eval, deepcopy(code.args[2]), quot(body)) : Expr(:call, :eval, quot(body)) @@ -62,11 +62,11 @@ end Unionise the code in a call to @eval, such that when the `eval` call actually occurs, the code inside will be unionised. """ -function unionise_macro_eval(code::Expr) - body = Expr(:macrocall, Symbol("@unionise"), nothing, deepcopy(code.args[end])) +function unionise_macro_eval(code::Expr, linfo::LineNumberNode=LineNumberNode(0)) + body = Expr(:macrocall, Symbol("@unionise"), linfo, deepcopy(code.args[end])) return length(code.args) == 4 ? - Expr(:macrocall, Symbol("@eval"), nothing, deepcopy(code.args[3]), body) : - Expr(:macrocall, Symbol("@eval"), nothing, body) + Expr(:macrocall, Symbol("@eval"), linfo, deepcopy(code.args[3]), body) : + Expr(:macrocall, Symbol("@eval"), linfo, body) end """ @@ -125,24 +125,24 @@ arguments. This should not affect the existing functionality of the code. function unionise end # If we get a symbol then we cannot have found a function definition, so ignore it. -unionise(code) = code +unionise(code, linfo::LineNumberNode=LineNumberNode(0)) = code # Recurse through an expression, bottoming out if we find a function definition or a # quoted expression to be `eval`-ed. -function unionise(code::Expr) +function unionise(code::Expr, linfo::LineNumberNode=LineNumberNode(0)) if code.head in (:function, Symbol("->")) return Expr(code.head, unionise_sig(code.args[1]), code.args[2]) elseif code.head == Symbol("=") && !isa(code.args[1], Symbol) && (get_body(code.args[1]).head == :tuple || get_body(code.args[1]).head isa Symbol) return Expr(code.head, unionise_sig(code.args[1]), code.args[2]) elseif code.head == :call && code.args[1] == :eval - return unionise_eval(code) + return unionise_eval(code, linfo) elseif code.head == :macrocall && code.args[1] == Symbol("@eval") - return unionise_macro_eval(code) + return unionise_macro_eval(code, linfo) elseif code.head == :struct return unionise_struct(code) else - return Expr(code.head, [unionise(arg) for arg in code.args]...) + return Expr(code.head, [unionise(arg, linfo) for arg in code.args]...) end end @@ -153,5 +153,5 @@ Transform code such that each function definition accepts `Node` objects as argu without effecting dispatch in other ways. """ macro unionise(code) - return esc(unionise(code)) + return esc(unionise(code, __source__)) end diff --git a/src/sensitivity.jl b/src/sensitivity.jl index 71233b8c..e6590ccc 100644 --- a/src/sensitivity.jl +++ b/src/sensitivity.jl @@ -76,11 +76,11 @@ macro explicit_intercepts( end insert!(oldcall.args, 2, params) # The actual function definition - def = Expr(:function, newcall, oldcall) + def = Expr(:function, newcall, Expr(:block, __source__, oldcall)) end # NOTE: If kws is nonempty, explicit_intercepts will add methods to both f and _f # See boxed_method - ex = explicit_intercepts(f, type_tuple, isnode; kws...) + ex = explicit_intercepts(f, type_tuple, isnode, __source__; kws...) # The result contains all method definitions generated for f (and _f if applicable) return esc(Expr(:block, def, ex)) end @@ -92,10 +92,10 @@ Return a `:block` expression which evaluates to declare all of the combinations that could be required to catch if a `Node` is ever passed to the function specified in `expr`. """ -function explicit_intercepts(f::SymOrExpr, types::Expr, is_node::Vector{Bool}; kwargs...) +function explicit_intercepts(f::SymOrExpr, types::Expr, is_node::Vector{Bool}, linfo::LineNumberNode; kwargs...) function explicit_intercepts_(states::Vector{Bool}) if length(states) == length(is_node) - return any(states) ? boxed_method(f, types, states; kwargs...) : [] + return any(states) ? boxed_method(f, types, states, linfo; kwargs...) : [] else return vcat( explicit_intercepts_(vcat(states, false)), @@ -145,7 +145,8 @@ function boxed_method( f::SymOrExpr, type_tuple::Expr, is_node::Vector{Bool}, - arg_names::Vector{Symbol}; + arg_names::Vector{Symbol}=[gensym() for _ in is_node], + linfo::LineNumberNode=LineNumberNode(0); kwargs... ) # Get the argument types and create the function call. @@ -161,7 +162,7 @@ function boxed_method( body = Expr(:call, :Branch, f, tuple_expr, tape_expr) # Combine call signature with the body to create a new function. - return Expr(:(=), call, body) + return Expr(:function, call, Expr(:block, linfo, body)) else _type_tuple = copy(type_tuple) _is_node = copy(is_node) @@ -171,15 +172,15 @@ function boxed_method( push!(_is_node, false) push!(_arg_names, k) end - kw_def = Expr(:function, call, Expr(:call, kwfname(f), _arg_names...)) + kw_def = Expr(:function, call, Expr(:block, linfo, Expr(:call, kwfname(f), _arg_names...))) # Recurse on the internal function to get a Branch call - branch_def = boxed_method(kwfname(f), _type_tuple, _is_node, _arg_names) + branch_def = boxed_method(kwfname(f), _type_tuple, _is_node, _arg_names, linfo) return Expr(:block, kw_def, branch_def) end end -boxed_method(f, t, n; kwargs...) = boxed_method(f, t, n, [gensym() for _ in n]; kwargs...) +boxed_method(f, t, n, l; kwargs...) = boxed_method(f, t, n, [gensym() for _ in n], l; kwargs...) """ get_sig(f::SymOrExpr, arg_names::Vector{Symbol}, types::Vector; kwargs...) diff --git a/test/sensitivity.jl b/test/sensitivity.jl index 01652b3a..a6af8a82 100644 --- a/test/sensitivity.jl +++ b/test/sensitivity.jl @@ -1,7 +1,11 @@ -@testset "sensitivity" begin +using Base.Meta +using Nabla: boxed_method - import Base.Meta.quot +function expected_func(sig::Expr, body::Expr) + return Expr(:function, sig, Expr(:block, LineNumberNode(0), body)) +end +@testset "sensitivity" begin # # "Test" `Nabla.get_body`. (Not currently unit testing this as it is awkward. Will # # change this at some point in the future to be more unit-testable.) # let @@ -23,60 +27,52 @@ # println(full_expr) # end - # Test `Nabla.boxed_method`. - import Nabla.Nabla.boxed_method - let - from_func = boxed_method(:foo, :(Tuple{Any}), [true], [:x1]) - expected = Expr(Symbol("="), - :(foo(x1::Node{<:Any})), - :(Branch(foo, (x1,), getfield(x1, $(quot(:tape)))))) - @test from_func == expected - end - let - from_func = boxed_method(:foo, :(Tuple{T{V}}), [true], [:x1]) - expected = Expr(Symbol("="), - :(foo(x1::Node{<:T{V}})), - :(Branch(foo, (x1,), getfield(x1, $(quot(:tape)))))) - @test from_func == expected - end - let - from_func = boxed_method(:foo, :(Tuple{Any, Any}), [true, false], [:x1, :x2]) - expected = Expr(Symbol("="), - :(foo(x1::Node{<:Any}, x2::Any)), - :(Branch(foo, (x1, x2), getfield(x1, $(quot(:tape)))))) - @test from_func == expected - end - let - from_func = boxed_method(:foo, :(Tuple{Any, Any}), [true, true], [:x1, :x2]) - expected = Expr(Symbol("="), - :(foo(x1::Node{<:Any}, x2::Node{<:Any})), - :(Branch(foo, (x1, x2), getfield(x1, $(quot(:tape)))))) - @test from_func == expected - end - let - from_func = boxed_method(:foo, :(Tuple{Any, Any}), [false, true], [:x1, :x2]) - expected = Expr(Symbol("="), - :(foo(x1::Any, x2::Node{<:Any})), - :(Branch(foo, (x1, x2), getfield(x2, $(quot(:tape)))))) - @test from_func == expected - end - let - from_func = boxed_method(:foo, :(Tuple{T} where T), [true], [:x1]) - expected = Expr(Symbol("="), - :(foo(x1::Node{<:T}) where T), - :(Branch(foo, (x1,), getfield(x1, $(quot(:tape)))))) - @test from_func == expected - end - let - from_func = boxed_method(:foo, :(Tuple{Any, Any}), [false, true], [:x1, :x2]; a=1, b=2) - expected = Expr(:block, - Expr(:function, - :(foo(x1::Any, x2::Node{<:Any}; a=1, b=2)), - :(_foo(x1, x2, a, b))), - Expr(:(=), - :(_foo(x1::Any, x2::Node{<:Any}, a::Any, b::Any)), - :(Branch(_foo, (x1, x2, a, b), getfield(x2, $(quot(:tape))))))) - @test from_func == expected + @testset "boxed_method" begin + let + from_func = boxed_method(:foo, :(Tuple{Any}), [true], [:x1]) + expected = expected_func(:(foo(x1::Node{<:Any})), + :(Branch(foo, (x1,), getfield(x1, $(quot(:tape)))))) + @test from_func == expected + end + let + from_func = boxed_method(:foo, :(Tuple{T{V}}), [true], [:x1]) + expected = expected_func(:(foo(x1::Node{<:T{V}})), + :(Branch(foo, (x1,), getfield(x1, $(quot(:tape)))))) + @test from_func == expected + end + let + from_func = boxed_method(:foo, :(Tuple{Any, Any}), [true, false], [:x1, :x2]) + expected = expected_func(:(foo(x1::Node{<:Any}, x2::Any)), + :(Branch(foo, (x1, x2), getfield(x1, $(quot(:tape)))))) + @test from_func == expected + end + let + from_func = boxed_method(:foo, :(Tuple{Any, Any}), [true, true], [:x1, :x2]) + expected = expected_func(:(foo(x1::Node{<:Any}, x2::Node{<:Any})), + :(Branch(foo, (x1, x2), getfield(x1, $(quot(:tape)))))) + @test from_func == expected + end + let + from_func = boxed_method(:foo, :(Tuple{Any, Any}), [false, true], [:x1, :x2]) + expected = expected_func(:(foo(x1::Any, x2::Node{<:Any})), + :(Branch(foo, (x1, x2), getfield(x2, $(quot(:tape)))))) + @test from_func == expected + end + let + from_func = boxed_method(:foo, :(Tuple{T} where T), [true], [:x1]) + expected = expected_func(:(foo(x1::Node{<:T}) where T), + :(Branch(foo, (x1,), getfield(x1, $(quot(:tape)))))) + @test from_func == expected + end + let + from_func = boxed_method(:foo, :(Tuple{Any, Any}), [false, true], [:x1, :x2]; a=1, b=2) + expected = Expr(:block, + expected_func(:(foo(x1::Any, x2::Node{<:Any}; a=1, b=2)), + :(_foo(x1, x2, a, b))), + expected_func(:(_foo(x1::Any, x2::Node{<:Any}, a::Any, b::Any)), + :(Branch(_foo, (x1, x2, a, b), getfield(x2, $(quot(:tape))))))) + @test from_func == expected + end end # Test `Nabla.branch_expr`.