Skip to content

Commit

Permalink
Adjust forward stage2 to Core.Compiler changes
Browse files Browse the repository at this point in the history
Only what is necessary for Cedar right now. Ordinary stage 2 reverse
mode will need similar changes at a later point.
  • Loading branch information
Keno committed Oct 3, 2024
1 parent d0b3e3e commit ef83e29
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
64 changes: 64 additions & 0 deletions src/analysis/forward.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,65 @@
using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState

if VERSION >= v"1.12.0-DEV.1268"

using Core.Compiler: Future

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta})
if f === ChainRulesCore.frule
# TODO: Currently, we don't have any termination analysis for the non-stratified
# forward analysis, so bail out here.
return Future{Union{CallMeta, Nothing}}(nothing)
end

nargs = length(arginfo.argtypes)-1
frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}]
frule_argtypes = append!(frule_preargtypes, arginfo.argtypes)
frule_atype = CC.argtypes_to_type(frule_argtypes)

local frule_call::Future{CallMeta}
local result::Future{Union{CallMeta, Nothing}} = Future{Union{CallMeta, Nothing}}()
function make_progress(_, sv)
if isa(primal_call[].info, UnionSplitApplyCallInfo)
result[] = nothing
return true
end

ready = false
if !@isdefined(frule_call)
# Here we simply check for the frule existance - we don't want to do a full
# inference with specialized argtypes and everything since the problem is
# likely sparse and we only need to do a full inference on a few calls.
# Thus, here we pick `Any` for the tangent types rather than trying to
# discover what they are. frules should be written in such a way that
# whether or not they return `nothing`, only depends on the non-tangent arguments
frule_arginfo = ArgInfo(nothing, frule_argtypes)
frule_si = StmtInfo(true)
# turn off frule analysis in the frule to avoid cycling
interp′ = disable_forward(interp)
frule_call = CC.abstract_call_gf_by_type(interp′,
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)::Future
isready(frule_call) || return false
end

frc = frule_call[]
pc = primal_call[]

if frc.rt !== Const(nothing)
result[] = CallMeta(pc.rt, pc.exct, pc.effects, FRuleCallInfo(pc.info, frc))
else
result[] = nothing
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end

return true
end
(!isready(primal_call) || !make_progress(interp, sv)) && push!(sv.tasks, make_progress)
return result
end

else

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::CallMeta)
if f === ChainRulesCore.frule
Expand Down Expand Up @@ -38,4 +98,8 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
return nothing
end



end

const frule_mt = methods(ChainRulesCore.frule).mt
6 changes: 6 additions & 0 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

if isdefined(CC, :Future)
Base.isready(future::CC.Future) = CC.isready(future)
Base.getindex(future::CC.Future) = CC.getindex(future)
Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value)
end

Base.copy(ir::IRCode) = CC.copy(ir)

CC.NewInstruction(@nospecialize node) =
Expand Down
4 changes: 4 additions & 0 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
ci.ssaflags = UInt8[0 for i=1:length(new_code)]
ci.method_for_inference_limit_heuristics = meth
ci.edges = MethodInstance[mi]
if hasfield(CodeInfo, :nargs)
ci.nargs = 2
ci.isva = true
end

return ci
end
Expand Down
3 changes: 3 additions & 0 deletions src/stage2/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ end
CC.nsplit_impl(info::FRuleCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::FRuleCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::FRuleCallInfo, idx::Int) = CC.getresult(info.info, idx)
if isdefined(CC, :add_uncovered_edges_impl)
CC.add_uncovered_edges_impl(edges::Vector{Any}, info::FRuleCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype)
end

function Base.show(io::IO, info::FRuleCallInfo)
print(io, "FRuleCallInfo(", typeof(info.info), ", ", typeof(info.frule_call.info), ")")
Expand Down

0 comments on commit ef83e29

Please sign in to comment.