Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make @register_array_symbolic overload promote_symtype #1129

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions src/register.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function destructure_registration_expr(expr, Ts)
end


function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :())
function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs = :(), define_promotion = true)
def_assignments = MacroTools.rmlines(partial_defs).args
defs = map(def_assignments) do ex
@assert ex.head == :(=)
Expand All @@ -90,7 +90,7 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs


args′ = map((a, T) -> :($a::$T), argnames, Ts)
quote
fexpr = quote
@wrapped function $f($(args′...))
args = [$(argnames...),]
unwrapped_args = map($unwrap, args)
Expand All @@ -109,10 +109,44 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs
end
end
end |> esc

if define_promotion
container_type = get(defs, :container_type, :($propagate_atype(f, $(argnames...))))
etype = get(defs, :eltype, :($propagate_eltype(f, $(argnames...))))
ndims = get(defs, :ndims, nothing)
is_callable_struct = f isa Expr && f.head == :(::)
fn_arg = if is_callable_struct
f
else
:(f::$ftype)
end
fn_arg_name = if is_callable_struct
f.args[1]
else
:f
end
promote_expr = quote
function (::$typeof($promote_symtype))($fn_arg, $(argnames...))
f = $fn_arg_name
container_type = $container_type
etype = $etype
$(
if ndims === nothing
:(return container_type{etype})
else
:(ndims = $ndims; return container_type{etype, ndims})
end
)
end
end |> esc
fexpr = :($fexpr; $promote_expr)
end

return fexpr
end

"""
@register_array_symbolic(expr)
@register_array_symbolic(expr, define_promotion = true)

Example:

Expand All @@ -132,8 +166,18 @@ You can also register calls on callable structs:
eltype=promote_type(eltype(x), eltype(c))
end
```

If `define_promotion = true` then a promotion method in the form of
```julia
SymbolicUtils.promote_symtype(::typeof(f_registered), args...) = # inferred or annotated return type
```

is defined for the register function. Note that when defining multiple register
overloads for one function, all the rest of the registers must set
`define_promotion` to `false` except for the first one, to avoid method
overwriting.
"""
macro register_array_symbolic(expr, block)
macro register_array_symbolic(expr, block, define_promotion = true)
f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, :([]))
return register_array_symbolic(f, ftype, argnames, Ts, ret_type, block)
register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion)
end
74 changes: 61 additions & 13 deletions test/macro.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Symbolics
import Symbolics: getsource, getdefaultval, wrap, unwrap, getname
import SymbolicUtils: Term, symtype, FnType, BasicSymbolic
import SymbolicUtils: Term, symtype, FnType, BasicSymbolic, promote_symtype
using LinearAlgebra
using Test

Expand All @@ -9,34 +9,57 @@ Symbolics.@register_symbolic fff(t)
@test isequal(fff(t), Symbolics.Num(Symbolics.Term{Real}(fff, [Symbolics.value(t)])))

const SymMatrix{T,N} = Symmetric{T, AbstractArray{T, N}}
many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4]

let
@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
eltype=eltype(x)
end false

## @variables

gg = ggg(x)

@test ndims(gg) == 2
@test size(gg) == (8,8)
@test eltype(gg) == Real
@test symtype(unwrap(gg)) == SymMatrix{Real, 2}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this SymMatrix{Real, 2} and the other SymMatrix{Real}?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't infer ndims if it's not specified in the macrocall, since promote_symtype only gets the types of arguments to the function

@test promote_symtype(ggg, symtype(unwrap(x))) == Any # no promote_symtype defined
end
let
# redefine with promote_symtype
@register_array_symbolic ggg(x::AbstractVector) begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to use a different function so that we can re-include the file.

container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
eltype=eltype(x)
end
@test promote_symtype(ggg, symtype(unwrap(x))) == SymMatrix{Real}
end

# ndims specified
@register_array_symbolic ggg(x::AbstractVector) begin
container_type=SymMatrix
size=(length(x) * 2, length(x) * 2)
eltype=eltype(x)
ndims = 2
end

## @variables

many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4]

@test promote_symtype(ggg, symtype(unwrap(x))) == SymMatrix{Real, 2}

gg = ggg(x)

@test ndims(gg) == 2
@test size(gg) == (8,8)
@test eltype(gg) == Real
@test symtype(unwrap(gg)) == SymMatrix{Real, 2}

struct CanCallWithArray{T}
params::T
end

ccwa = CanCallWithArray((length=10,))
@register_array_symbolic (c::CanCallWithArray)(x::AbstractArray, b::AbstractVector) begin
size=(size(x, 1), length(b), c.params.length)
eltype=Real
end
end false # without promote_symtype

hh = CanCallWithArray((length=10,))(gg, x)
hh = ccwa(gg, x)
@test size(hh) == (8,4,10)
@test eltype(hh) == Real
@test isequal(arguments(unwrap(hh)), unwrap.([gg, x]))
Expand All @@ -52,9 +75,34 @@ hh = CanCallWithArray((length=10,))(gg, x)
@test getdefaultval(z[3]) == 4

@test symtype(p) <: FnType{Tuple, Array{Real,1}}
@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == Any
@test p(t)[1] isa Symbolics.Num


struct CanCallWithArray2{T}
params::T
end

ccwa = CanCallWithArray2((length=10,))
@register_array_symbolic (c::CanCallWithArray2)(x::AbstractArray, b::AbstractVector) begin
size=(size(x, 1), length(b), c.params.length)
eltype=Real
end
@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == AbstractArray{Real}

struct CanCallWithArray3{T}
params::T
end

ccwa = CanCallWithArray3((length=10,))
# ndims specified
@register_array_symbolic (c::CanCallWithArray3)(x::AbstractArray, b::AbstractVector) begin
size=(size(x, 1), length(b), c.params.length)
eltype=Real
ndims = 3
end
@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == AbstractArray{Real, 3}

## Wrapper types

abstract type AbstractFoo{T} end
Expand Down
Loading