Skip to content

Commit

Permalink
Merge pull request #1036 from asterycs/feature/resolve_dual_ambiguities2
Browse files Browse the repository at this point in the history
Resolve Symbolics.Num - ForwardDiff.Dual  ambiguities
  • Loading branch information
ChrisRackauckas authored Jan 8, 2024
2 parents df4b171 + 1b90690 commit 493e276
Show file tree
Hide file tree
Showing 4 changed files with 367 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
[extensions]
SymbolicsGroebnerExt = "Groebner"
SymbolicsPreallocationToolsExt = ["ForwardDiff", "PreallocationTools"]
SymbolicsForwardDiffExt = "ForwardDiff"
SymbolicsSymPyExt = "SymPy"

[compat]
Expand Down
256 changes: 256 additions & 0 deletions ext/SymbolicsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
module SymbolicsForwardDiffExt

using ForwardDiff
using ForwardDiff.NaNMath
using ForwardDiff.DiffRules
using ForwardDiff: value, Dual, partials
using Symbolics

# The method generation in this file have been adapted from
# https://github.com/JuliaDiff/ForwardDiff.jl/blob/v0.10.36/src/dual.jl

const AMBIGUOUS_TYPES = (Num,)

####################################
# N-ary Operation Definition Tools #
####################################

macro define_binary_dual_op(f, xy_body, x_body, y_body, Ts)
FD = ForwardDiff
defs = quote end
for R in Ts
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}) where {Ty} = $y_body
end
append!(defs.args, expr.args)
end
return esc(defs)
end

macro define_ternary_dual_op(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_body, z_body, Ts)
FD = ForwardDiff
defs = quote end
for R in Ts
expr = quote
@inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$R) where {Txy} = $xy_body
@inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$R) where {Tx, Ty} = Ty Tx ? $x_body : $y_body
@inline $(f)(x::$FD.Dual{Txz}, y::$R, z::$FD.Dual{Txz}) where {Txz} = $xz_body
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$FD.Dual{Tz}) where {Tx,Tz} = Tz Tx ? $x_body : $z_body
@inline $(f)(x::$R, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tyz} = $yz_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Ty,Tz} = Tz Ty ? $y_body : $z_body
end
append!(defs.args, expr.args)
for Q in Ts
Q === R && continue
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$Q) where {Ty} = $y_body
@inline $(f)(x::$R, y::$Q, z::$FD.Dual{Tz}) where {Tz} = $z_body
end
append!(defs.args, expr.args)
end
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$R) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$R) where {Ty} = $y_body
@inline $(f)(x::$R, y::$R, z::$FD.Dual{Tz}) where {Tz} = $z_body
end
append!(defs.args, expr.args)
end
return esc(defs)
end

function binary_dual_definition(M, f, Ts)
FD = ForwardDiff
dvx, dvy = DiffRules.diffrule(M, f, :vx, :vy)
Mf = M == :Base ? f : :($M.$f)
xy_work = FD.qualified_cse!(quote
val = $Mf(vx, vy)
dvx = $dvx
dvy = $dvy
end)
dvx, _ = DiffRules.diffrule(M, f, :vx, :y)
x_work = FD.qualified_cse!(quote
val = $Mf(vx, y)
dvx = $dvx
end)
_, dvy = DiffRules.diffrule(M, f, :x, :vy)
y_work = FD.qualified_cse!(quote
val = $Mf(x, vy)
dvy = $dvy
end)
expr = quote
@define_binary_dual_op(
$M.$f,
begin
vx, vy = $FD.value(x), $FD.value(y)
$xy_work
return $FD.dual_definition_retval(Val{Txy}(), val, dvx, $FD.partials(x), dvy, $FD.partials(y))
end,
begin
vx = $FD.value(x)
$x_work
return $FD.dual_definition_retval(Val{Tx}(), val, dvx, $FD.partials(x))
end,
begin
vy = $FD.value(y)
$y_work
return $FD.dual_definition_retval(Val{Ty}(), val, dvy, $FD.partials(y))
end,
$Ts
)
end
return expr
end

###################################
# General Mathematical Operations #
###################################

for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-), (:Base, :sin), (:Base, :cos))
continue # Skip methods which we define elsewhere.
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
continue # Skip rules for methods not defined in the current scope
end
if arity == 1
# no-op
elseif arity == 2
eval(binary_dual_definition(M, f, AMBIGUOUS_TYPES))
else
# no-op
end
end

#################
# Special Cases #
#################

# +/- #
#-----#

@eval begin
@define_binary_dual_op(
Base.:+,
begin
vx, vy = value(x), value(y)
Dual{Txy}(vx + vy, partials(x) + partials(y))
end,
Dual{Tx}(value(x) + y, partials(x)),
Dual{Ty}(x + value(y), partials(y)),
$AMBIGUOUS_TYPES
)
end

@eval begin
@define_binary_dual_op(
Base.:-,
begin
vx, vy = value(x), value(y)
Dual{Txy}(vx - vy, partials(x) - partials(y))
end,
Dual{Tx}(value(x) - y, partials(x)),
Dual{Ty}(x - value(y), -partials(y)),
$AMBIGUOUS_TYPES
)
end

# / #
#---#

# We can't use the normal diffrule autogeneration for this because (x/y) === (x * (1/y))
# doesn't generally hold true for floating point; see issue #264
@eval begin
@define_binary_dual_op(
Base.:/,
begin
vx, vy = value(x), value(y)
Dual{Txy}(vx / vy, _div_partials(partials(x), partials(y), vx, vy))
end,
Dual{Tx}(value(x) / y, partials(x) / y),
begin
v = value(y)
divv = x / v
Dual{Ty}(divv, -(divv / v) * partials(y))
end,
$AMBIGUOUS_TYPES
)
end

# exponentiation #
#----------------#

for f in (:(Base.:^), :(NaNMath.pow))
@eval begin
@define_binary_dual_op(
$f,
begin
vx, vy = value(x), value(y)
expv = ($f)(vx, vy)
powval = vy * ($f)(vx, vy - 1)
if isconstant(y)
logval = one(expv)
elseif iszero(vx) && vy > 0
logval = zero(vx)
else
logval = expv * log(vx)
end
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
return Dual{Txy}(expv, new_partials)
end,
begin
v = value(x)
expv = ($f)(v, y)
if y == zero(y) || iszero(partials(x))
new_partials = zero(partials(x))
else
new_partials = partials(x) * y * ($f)(v, y - 1)
end
return Dual{Tx}(expv, new_partials)
end,
begin
v = value(y)
expv = ($f)(x, v)
deriv = (iszero(x) && v > 0) ? zero(expv) : expv*log(x)
return Dual{Ty}(expv, deriv * partials(y))
end,
$AMBIGUOUS_TYPES
)
end
end

# hypot #
#-------#

@eval begin
@define_ternary_dual_op(
Base.hypot,
calc_hypot(x, y, z, Txyz),
calc_hypot(x, y, z, Txy),
calc_hypot(x, y, z, Txz),
calc_hypot(x, y, z, Tyz),
calc_hypot(x, y, z, Tx),
calc_hypot(x, y, z, Ty),
calc_hypot(x, y, z, Tz),
$AMBIGUOUS_TYPES
)
end

# muladd #
#--------#

@eval begin
@define_ternary_dual_op(
Base.muladd,
calc_muladd_xyz(x, y, z), # xyz_body
calc_muladd_xy(x, y, z), # xy_body
calc_muladd_xz(x, y, z), # xz_body
Base.muladd(y, x, z), # yz_body
Dual{Tx}(muladd(value(x), y, z), partials(x) * y), # x_body
Base.muladd(y, x, z), # y_body
Dual{Tz}(muladd(x, y, value(z)), partials(z)), # z_body
$AMBIGUOUS_TYPES
)
end

end
109 changes: 109 additions & 0 deletions test/forwarddiff_symbolic_dual_ops.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using ForwardDiff
using Symbolics
using Symbolics.SymbolicUtils
using Symbolics.SymbolicUtils.SpecialFunctions
using Symbolics.NaNMath
using Test

SF = SymbolicUtils.SpecialFunctions

@variables x

# Test functions from Symbolics #
#-------------------------------#

for f SymbolicUtils.basic_monadic
fun = eval(:(ξ ->($f)(ξ)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

for f SymbolicUtils.monadic
# The polygamma and trigamma functions seem to be missing rules in ForwardDiff.
# The abs rule uses conditionals and cannot be used with Symbolics.Num.
# acsc, asech, NanMath.log2 and NaNMath.log10 are tested separately
if f (abs, SF.polygamma, SF.trigamma, acsc, acsch, asech, NaNMath.log2, NaNMath.log10)
continue
end

fun = eval(:(ξ ->($f)(ξ)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

# These are evaluated numerically. For some reason isequal evaluates to false for the symbolic expressions.
for f (acsc, asech, NaNMath.log2, NaNMath.log10)
fun = eval(:(ξ ->($f)(ξ)))

fd = ForwardDiff.derivative(fun, 1.0)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test fd substitute(sym, Dict(x => 1.0))
end

for f SymbolicUtils.basic_diadic
if f (//,)
continue
end

fun = eval(:(ξ ->($f)(ξ, 2.0)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

for f SymbolicUtils.diadic
if f (max, min, NaNMath.atanh, mod, rem, copysign, besselj, bessely, besseli, besselk)
continue
end

fun = eval(:(ξ ->($f)(ξ, 2.0)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

for f (NaNMath.atanh,)
fun = eval(:(ξ ->($f)(ξ)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

for f (besselj, bessely, besseli, besselk)
fun = eval(:(ξ ->($f)(ξ, 2)))

fd = ForwardDiff.derivative(fun, x)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test isequal(fd, sym)
end

# Additionally test these definitions from ForwardDiff #
#------------------------------------------------------#

# https://github.com/JuliaDiff/ForwardDiff.jl/blob/d3002093beb88ff0b98ed178377961dfd55c1247/src/dual.jl#L599
# and
# https://github.com/JuliaDiff/ForwardDiff.jl/blob/d3002093beb88ff0b98ed178377961dfd55c1247/src/dual.jl#L683
for f (hypot, muladd)
fun = eval(:(ξ ->($f)(ξ, 2.0, 3.0)))

fd = ForwardDiff.derivative(fun, 5.0)
sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives

@test fd substitute(sym, Dict(x => 5.0))
end

# fma is not defined for Symbolics.Num
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Linear Solver Test" begin include("linear_solver.jl") end
@safetestset "Algebraic Solver Test" begin include("solver.jl") end
@safetestset "Overloading Test" begin include("overloads.jl") end
@safetestset "ForwardDiff Extension Test" begin include("forwarddiff_symbolic_dual_ops.jl") end
@safetestset "Nested ForwardDiff Sparsity Test" begin include("nested_forwarddiff_sparsity.jl") end
@safetestset "Build Function Test" begin include("build_function.jl") end
@safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end
Expand Down

0 comments on commit 493e276

Please sign in to comment.