Skip to content

Commit

Permalink
Allow visit_custom recursion to change the order (#179)
Browse files Browse the repository at this point in the history
* Allow visit_custom recursion to change the order

To implement order-changing intrinsics.

* Add ddt test
  • Loading branch information
Keno authored Jul 14, 2023
1 parent 3d5bee0 commit 24f047c
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 11 deletions.
16 changes: 10 additions & 6 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vecto
ssa_orders[ssa.id] = order => ssa_orders[ssa.id][2]
inst = ir[ssa]
stmt = inst[:inst]
recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!)
recurse(@nospecialize(val), new_order=order) = forward_visit!(ir, val, new_order, ssa_orders, visit_custom!)
if visit_custom!(ir, ssa, order, recurse)
ssa_orders[ssa.id] = order => true
return
Expand Down Expand Up @@ -220,10 +220,15 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
visit_custom! = (@nospecialize args...)->false,
transform! = (@nospecialize args...)->error())
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
ssa_orders = [0=>false for i = 1:length(ir.stmts)]
ssa_orders = [-1=>false for i = 1:length(ir.stmts)]
for (ssa, order) in to_diff
forward_visit!(ir, ssa, order, ssa_orders, visit_custom!)
end
for (ssa, (order, custom)) in enumerate(ssa_orders)
if order == -1
ssa_orders[ssa] = 0 => custom
end
end

truncation_map = Dict{Pair{SSAValue, Int}, SSAValue}()

Expand Down Expand Up @@ -266,7 +271,9 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
end

for (ssa, (order, custom)) in enumerate(ssa_orders)
if order == 0
if custom
transform!(ir, SSAValue(ssa), order, maparg)
elseif order == 0
inst = ir[SSAValue(ssa)]
stmt = inst[:inst]
urs = userefs(stmt)
Expand All @@ -275,9 +282,6 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
end
inst[:inst] = urs[]
continue
end
if custom
transform!(ir, SSAValue(ssa), order, maparg)
else
inst = ir[SSAValue(ssa)]
stmt = inst[:inst]
Expand Down
2 changes: 1 addition & 1 deletion src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
return generate_lambda_ex(world, source,
Core.svec(:ff, :args), Core.svec(), :(∂☆builtin(args)))
end

mthds = Base._methods_by_ftype(sig, -1, world)
if mthds === nothing || length(mthds) != 1
# Core.println("[perform_fwd_transform] ", sig, " => ", mthds)
Expand Down
40 changes: 37 additions & 3 deletions src/stage2/forward.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
using .CC: compact!

function is_known_invoke_or_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,IncrementalCompact})
return is_known_invoke(x, func, ir) || CC.is_known_call(x, func, ir)
end

function is_known_invoke(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,IncrementalCompact})
isexpr(x, :invoke) || return false
ft = argextype(x.args[2], ir)
return singleton_type(ft) === func
end

@noinline function dont_use_ddt_intrinsic(x::Float64)
if Base.inferencebarrier(true)
error("Intrinsic not transformed")
end
return Base.inferencebarrier(0.0)::Float64
end

# Engineering entry point for the 2nd-order forward AD functionality. This is
# unlikely to be the actual interface. For now, it is used for testing.
function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
Expand Down Expand Up @@ -28,24 +45,41 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
if isa(stmt, ReturnNode)
recurse(stmt.val)
return true
elseif is_known_invoke_or_call(stmt, dont_use_ddt_intrinsic, ir)
recurse(stmt.args[end], order+1)
return true
else
return false
end
end

function transform!(ir::IRCode, ssa::SSAValue, _, _)
function transform!(ir::IRCode, ssa::SSAValue, _, maparg)
inst = ir[ssa]
stmt = inst[:inst]
if isa(stmt, ReturnNode)
if order == 0
return
end
nr = insert_node!(ir, ssa, NewInstruction(Expr(:call, getindex, stmt.val, TaylorTangentIndex(order)), Any))
inst[:inst] = ReturnNode(nr)
elseif is_known_invoke_or_call(stmt, dont_use_ddt_intrinsic, ir)
arg = maparg(stmt.args[end], ssa, order+1)
if order > 0
replace_call!(ir, ssa, Expr(:call, error, "Only order 0 implemented here"))
else
replace_call!(ir, ssa, Expr(:call, getindex, arg, TaylorTangentIndex(1)))
end
else
error()
end
end

function transform!(ir::IRCode, arg::Argument, _, _)
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
function transform!(ir::IRCode, arg::Argument, order, _)
if order == 0
return arg
else
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
end
end

ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!)
Expand Down
2 changes: 1 addition & 1 deletion src/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ end
Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val

"""
CompositeBundle{N, B <: Tuple}
CompositeBundle{N, B, B <: Tuple}
Represents the tagent bundle where the base space is some tuple or struct type.
Mathematically, this tangent bundle is the product bundle of the individual
Expand Down
9 changes: 9 additions & 0 deletions test/stage2_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,13 @@ module stage2_fwd
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10f0, (1.0,)))
end

@testset "ddt intrinsic" begin
function my_cos_ddt(x)
return Diffractor.dont_use_ddt_intrinsic(sin(x))
end
let my_cos_ddt_transformed = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(my_cos_ddt), Float64}, 0)
@test my_cos_ddt_transformed(1.0) == cos(1.0)
end
end
end

0 comments on commit 24f047c

Please sign in to comment.