Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Commit

Permalink
Allow default values and kwargs in at-unionise (#135)
Browse files Browse the repository at this point in the history
Currently, `at-unionise` fails to parse `f(x::A, y::B=z)` as well as
`f(x::A; y::B=z)`. This change permits modifying those signatures such
that positional arguments are unionized with the default value(s)
preserved, and keyword arguments are allowed but ignored. That is, the
aforementioned examples become
```julia
f(x::Union{A,Node{<:A}}, y::Union{B,Node{<:B}}=z)
f(x::Union{A,Node{<:A}}; y::B=z)
```
respectively.
  • Loading branch information
ararslan authored Mar 7, 2019
1 parent 21275e3 commit a830fd4
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/code_transformation/differentiable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@ Return an expression in which the argument expression `arg` is replaced with an
whos type admits `Node`s.
"""
unionise_arg(arg::Symbol) = arg
unionise_arg(arg::Expr) =
arg.head == Symbol("::") ?
Expr(Symbol("::"), arg.args[1:end-1]..., unionise_type(arg.args[end])) :
arg.head == Symbol("...") ?
Expr(Symbol("..."), unionise_arg(arg.args[1])) :
throw(ArgumentError("Unrecognised argument in Symbol ($arg)."))

function unionise_arg(arg::Expr)
if arg.head === :(::)
Expr(:(::), arg.args[1:end-1]..., unionise_type(arg.args[end]))
elseif arg.head === :...
Expr(:..., unionise_arg(arg.args[1]))
elseif arg.head === :kw
Expr(:kw, unionise_arg(arg.args[1]), arg.args[2])
elseif arg.head === :parameters
arg # Ignore keyword arguments and leave them untouched for now
else
throw(ArgumentError("Unrecognized argument in Symbol ($arg)."))
end
end

"""
unionise_subtype(arg::Union{Symbol, Expr})
Expand Down
9 changes: 9 additions & 0 deletions test/code_transformation/differentiable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,13 @@ skip_line_info(ex) = ex
@test unionise(:(@eval foo)) unionise_macro_eval(:(@eval foo))
@test unionise(:(@eval DiffBase foo)) unionise_macro_eval(:(@eval DiffBase foo))
@test unionise(:(struct Foo{T<:V} end)) == unionise_struct(:(struct Foo{T<:V} end))

# @unionise with default values and keywords
UT = unionise_type(:T)
raw = unionise(:(f(x::T, y::T=2; z::T=4) = x + y + z))
new = :(f(x::$UT, y::$UT=2; z::T=4) = x + y + z)
@test raw new

# @unionise error conditions
@test_throws ArgumentError unionise(:(f(@nospecialize x) = x))
end

0 comments on commit a830fd4

Please sign in to comment.