Skip to content

Commit

Permalink
Fix binomial registration
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Oct 4, 2023
1 parent afd8e80 commit a87bb48
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/extra_functions.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@register_symbolic Base.binomial(n,k)
@register_symbolic Base.binomial(n, k)::Int true [Integer]

@register_symbolic Base.sign(x)::Int
derivative(::typeof(sign), args::NTuple{1,Any}, ::Val{1}) = 0
Expand Down
11 changes: 8 additions & 3 deletions src/register.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using SymbolicUtils: Symbolic

"""
@register_symbolic(expr, define_promotion = true, Ts = [Num, Symbolic, Real])
@register_symbolic(expr, define_promotion = true, Ts = [Real])
Overload appropriate methods so that Symbolics can stop tracing into the
registered function. If `define_promotion` is true, then a promotion method in
Expand All @@ -22,7 +22,7 @@ overwriting.
@register_symbolic hoo(x, y)::Int # `hoo` returns `Int`
```
"""
macro register_symbolic(expr, define_promotion = true, Ts = [])
macro register_symbolic(expr, define_promotion = true, Ts = :([]))
if expr.head === :(::)
ret_type = expr.args[2]
expr = expr.args[1]
Expand All @@ -31,6 +31,8 @@ macro register_symbolic(expr, define_promotion = true, Ts = [])
end

@assert expr.head === :call
@assert Ts.head === :vect
Ts = Ts.args

f = expr.args[1]
args = expr.args[2:end]
Expand All @@ -41,7 +43,10 @@ macro register_symbolic(expr, define_promotion = true, Ts = [])

types = map(args) do x
if x isa Symbol
:(($Real, $wrapper_type($Real), $Symbolic{<:$Real}))
if isempty(Ts)
Ts = [Real]
end
:(($(Ts...), $wrapper_type($Real), $Symbolic{<:$Real}))
elseif Meta.isexpr(x, :(::))
T = x.args[2]
:($has_symwrapper($T) ?
Expand Down
2 changes: 2 additions & 0 deletions test/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,5 @@ stringcontent = string(d.content)
for f in [<, <=, >, >=, isless]
@test_nowarn f(t, 1.0)
end

@test_nowarn binomial(t, 1)

0 comments on commit a87bb48

Please sign in to comment.