From a830fd49dbc1975c6eb3b2132eec69410eaace1f Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Thu, 7 Mar 2019 10:34:44 -0800 Subject: [PATCH] Allow default values and kwargs in at-unionise (#135) Currently, `at-unionise` fails to parse `f(x::A, y::B=z)` as well as `f(x::A; y::B=z)`. This change permits modifying those signatures such that positional arguments are unionized with the default value(s) preserved, and keyword arguments are allowed but ignored. That is, the aforementioned examples become ```julia f(x::Union{A,Node{<:A}}, y::Union{B,Node{<:B}}=z) f(x::Union{A,Node{<:A}}; y::B=z) ``` respectively. --- src/code_transformation/differentiable.jl | 20 ++++++++++++++------ test/code_transformation/differentiable.jl | 9 +++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/code_transformation/differentiable.jl b/src/code_transformation/differentiable.jl index 6d9b89dd..710a0ba1 100644 --- a/src/code_transformation/differentiable.jl +++ b/src/code_transformation/differentiable.jl @@ -7,12 +7,20 @@ Return an expression in which the argument expression `arg` is replaced with an whos type admits `Node`s. """ unionise_arg(arg::Symbol) = arg -unionise_arg(arg::Expr) = - arg.head == Symbol("::") ? - Expr(Symbol("::"), arg.args[1:end-1]..., unionise_type(arg.args[end])) : - arg.head == Symbol("...") ? - Expr(Symbol("..."), unionise_arg(arg.args[1])) : - throw(ArgumentError("Unrecognised argument in Symbol ($arg).")) + +function unionise_arg(arg::Expr) + if arg.head === :(::) + Expr(:(::), arg.args[1:end-1]..., unionise_type(arg.args[end])) + elseif arg.head === :... + Expr(:..., unionise_arg(arg.args[1])) + elseif arg.head === :kw + Expr(:kw, unionise_arg(arg.args[1]), arg.args[2]) + elseif arg.head === :parameters + arg # Ignore keyword arguments and leave them untouched for now + else + throw(ArgumentError("Unrecognized argument in Symbol ($arg).")) + end +end """ unionise_subtype(arg::Union{Symbol, Expr}) diff --git a/test/code_transformation/differentiable.jl b/test/code_transformation/differentiable.jl index 7e9b6e2f..f899653c 100644 --- a/test/code_transformation/differentiable.jl +++ b/test/code_transformation/differentiable.jl @@ -112,4 +112,13 @@ skip_line_info(ex) = ex @test unionise(:(@eval foo)) ≃ unionise_macro_eval(:(@eval foo)) @test unionise(:(@eval DiffBase foo)) ≃ unionise_macro_eval(:(@eval DiffBase foo)) @test unionise(:(struct Foo{T<:V} end)) == unionise_struct(:(struct Foo{T<:V} end)) + + # @unionise with default values and keywords + UT = unionise_type(:T) + raw = unionise(:(f(x::T, y::T=2; z::T=4) = x + y + z)) + new = :(f(x::$UT, y::$UT=2; z::T=4) = x + y + z) + @test raw ≃ new + + # @unionise error conditions + @test_throws ArgumentError unionise(:(f(@nospecialize x) = x)) end