Skip to content

Commit

Permalink
Generate 2-to-1 overloads on arbitrary types (#197)
Browse files Browse the repository at this point in the history
* Allow 2-to-1 overloads on arbitrary types

* Simplify methods that avoid type ambiguities using this new tooling
  • Loading branch information
adrhill authored Oct 1, 2024
1 parent 65a52ad commit 702b2d0
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 69 deletions.
2 changes: 1 addition & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ include("operators.jl")
include("overloads/conversion.jl")
include("overloads/gradient_tracer.jl")
include("overloads/hessian_tracer.jl")
include("overloads/ambiguities.jl")
include("overloads/special_cases.jl")
include("overloads/ifelse_global.jl")
include("overloads/dual.jl")
include("overloads/arrays.jl")
include("overloads/utils.jl")
include("overloads/ambiguities.jl")

include("trace_functions.jl")
include("adtypes_interface.jl")
Expand Down
41 changes: 4 additions & 37 deletions src/overloads/ambiguities.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,5 @@
## Special overloads to avoid ambiguity errors
for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::T, ::S) where {T<:GradientTracer} = t
Base.:^(::S, t::T) where {T<:GradientTracer} = t
Base.:^(t::T, ::S) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)
Base.:^(::S, t::T) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)

function Base.:^(d::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}}
x = primal(d)
t = gradient_tracer_1_to_1(tracer(d), false)
return Dual(x^y, t)
end
function Base.:^(x::S, d::D) where {P,T<:GradientTracer,D<:Dual{P,T}}
y = primal(d)
t = gradient_tracer_1_to_1(tracer(d), false)
return Dual(x^y, t)
end

function Base.:^(d::D, y::S) where {P,T<:HessianTracer,D<:Dual{P,T}}
x = primal(d)
t = hessian_tracer_1_to_1(tracer(d), false, false)
return Dual(x^y, t)
end
function Base.:^(x::S, d::D) where {P,T<:HessianTracer,D<:Dual{P,T}}
y = primal(d)
t = hessian_tracer_1_to_1(tracer(d), false, false)
return Dual(x^y, t)
end
end

for TT in (GradientTracer, HessianTracer)
function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:TT,D<:Dual{P,T}}
return isless(primal(dx), y)
end
function Base.isless(x::AbstractFloat, dy::D) where {P<:Real,T<:TT,D<:Dual{P,T}}
return isless(x, primal(dy))
end
end
eval(generate_code_2_to_1(:Base, ^, Integer))
eval(generate_code_2_to_1(:Base, ^, Rational))
eval(generate_code_2_to_1(:Base, ^, Irrational{:ℯ}))
eval(generate_code_2_to_1(:Base, isless, AbstractFloat))
3 changes: 0 additions & 3 deletions src/overloads/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,3 @@ for fn in (
throw(MissingPrimalError($fn, t))
end
end

# In some cases, more specialized methods are needed
Base.isless(dx::D, y::AbstractFloat) where {D<:Dual} = isless(primal(dx), y)
34 changes: 16 additions & 18 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ function gradient_tracer_2_to_1_inner(
end
end

function generate_code_gradient_2_to_1(M::Symbol, f)
function generate_code_gradient_2_to_1(
M::Symbol, # Symbol indicating Module of f, usually `:Base`
f::Function, # function to overload
Z::Type=Real, # external non-tracer-type to overload on
)
fname = nameof(f)
is_der1_arg1_zero_g = is_der1_arg1_zero_global(f)
is_der1_arg2_zero_g = is_der1_arg2_zero_global(f)
Expand All @@ -122,11 +126,11 @@ function generate_code_gradient_2_to_1(M::Symbol, f)
)
end

function $M.$fname(tx::$SCT.GradientTracer, ::Real)
function $M.$fname(tx::$SCT.GradientTracer, ::$Z)
return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
end

function $M.$fname(::Real, ty::$SCT.GradientTracer)
function $M.$fname(::$Z, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
end
end
Expand Down Expand Up @@ -158,20 +162,16 @@ function generate_code_gradient_2_to_1(M::Symbol, f)
end
end
end
expr_dual_real = if is_der1_arg1_zero_g
expr_dual_nondual = if is_der1_arg1_zero_g
quote
function $M.$fname(
dx::D, y::Real
) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
return $M.$fname(x, y)
end
end
else
quote
function $M.$fname(
dx::D, y::Real
) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$fname(x, y)

Expand All @@ -182,20 +182,16 @@ function generate_code_gradient_2_to_1(M::Symbol, f)
end
end
end
expr_real_dual = if is_der1_arg2_zero_g
expr_nondual_dual = if is_der1_arg2_zero_g
quote
function $M.$fname(
x::Real, dy::D
) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
return $M.$fname(x, y)
end
end
else
quote
function $M.$fname(
x::Real, dy::D
) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.GradientTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$fname(x, y)

Expand All @@ -207,7 +203,9 @@ function generate_code_gradient_2_to_1(M::Symbol, f)
end
end

return Expr(:block, expr_gradienttracer, expr_dual_dual, expr_dual_real, expr_real_dual)
return Expr(
:block, expr_gradienttracer, expr_dual_dual, expr_dual_nondual, expr_nondual_dual
)
end

## 1-to-2
Expand Down
26 changes: 16 additions & 10 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ function hessian_tracer_2_to_1_inner(
return P(g_out, h_out) # return pattern
end

function generate_code_hessian_2_to_1(M::Symbol, f)
function generate_code_hessian_2_to_1(
M::Symbol, # Symbol indicating Module of f, usually `:Base`
f::Function, # function to overload
Z::Type=Real, # external non-tracer-type to overload on
)
fname = nameof(f)
is_der1_arg1_zero_g = is_der1_arg1_zero_global(f)
is_der2_arg1_zero_g = is_der2_arg1_zero_global(f)
Expand All @@ -197,11 +201,11 @@ function generate_code_hessian_2_to_1(M::Symbol, f)
)
end

function $M.$fname(tx::$SCT.HessianTracer, y::Real)
function $M.$fname(tx::$SCT.HessianTracer, y::$Z)
return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g)
end

function $M.$fname(x::Real, ty::$SCT.HessianTracer)
function $M.$fname(x::$Z, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g)
end
end
Expand Down Expand Up @@ -251,16 +255,16 @@ function generate_code_hessian_2_to_1(M::Symbol, f)
end
end
end
expr_dual_real = if is_der1_arg1_zero_g && is_der2_arg1_zero_g
expr_dual_nondual = if is_der1_arg1_zero_g && is_der2_arg1_zero_g
quote
function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
return $M.$fname(x, y)
end
end
else
quote
function $M.$fname(dx::D, y::Real) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(dx::D, y::$Z) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
x = $SCT.primal(dx)
p_out = $M.$fname(x, y)

Expand All @@ -272,16 +276,16 @@ function generate_code_hessian_2_to_1(M::Symbol, f)
end
end
end
expr_real_dual = if is_der1_arg2_zero_g && is_der2_arg2_zero_g
expr_nondual_dual = if is_der1_arg2_zero_g && is_der2_arg2_zero_g
quote
function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
return $M.$fname(x, y)
end
end
else
quote
function $M.$fname(x::Real, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
function $M.$fname(x::$Z, dy::D) where {P,T<:$SCT.HessianTracer,D<:$SCT.Dual{P,T}}
y = $SCT.primal(dy)
p_out = $M.$fname(x, y)

Expand All @@ -294,7 +298,9 @@ function generate_code_hessian_2_to_1(M::Symbol, f)
end
end

return Expr(:block, expr_hessiantracer, expr_dual_dual, expr_dual_real, expr_real_dual)
return Expr(
:block, expr_hessiantracer, expr_dual_dual, expr_dual_nondual, expr_nondual_dual
)
end

## 1-to-2
Expand Down
19 changes: 19 additions & 0 deletions src/overloads/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@ for d in dims
end
end

# Overloads of 2-argument functions on arbitrary types
function generate_code_2_to_1(M::Symbol, f, Z::Type)
expr_g = generate_code_gradient_2_to_1(M, f, Z)
expr_h = generate_code_hessian_2_to_1(M, f, Z)
return Expr(:block, expr_g, expr_h)
end
function generate_code_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_2_to_1(M, op, Z) for op in ops]
return Expr(:block, exprs...)
end
function generate_code_gradient_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_gradient_2_to_1(M, op, Z) for op in ops]
return Expr(:block, exprs...)
end
function generate_code_hessian_2_to_1(M::Symbol, ops::Union{AbstractVector,Tuple}, Z::Type)
exprs = [generate_code_hessian_2_to_1(M, op, Z) for op in ops]
return Expr(:block, exprs...)
end

## Overload operators
eval(generate_code_1_to_1(:Base, ops_1_to_1))
eval(generate_code_2_to_1(:Base, ops_2_to_1))
Expand Down

0 comments on commit 702b2d0

Please sign in to comment.