Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust forward stage2 to Core.Compiler changes #295

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/analysis/forward.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,65 @@
using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState

if VERSION >= v"1.12.0-DEV.1268"

using Core.Compiler: Future

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta})
if f === ChainRulesCore.frule
# TODO: Currently, we don't have any termination analysis for the non-stratified
# forward analysis, so bail out here.
return Future{Union{CallMeta, Nothing}}(nothing)
end

nargs = length(arginfo.argtypes)-1
frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}]
frule_argtypes = append!(frule_preargtypes, arginfo.argtypes)
frule_atype = CC.argtypes_to_type(frule_argtypes)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Be careful not to closure capture any types, as your performance may suffer quite badly, but still just fast enough you won't notice (e.g. the sysimage could build still when I missed one of these cases, it just took several times longer)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So put it in a Ref{Any}?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, or a Core.Box equivalently. I'd found several places where we had a MethodMatch object that was needed anyways, so that also happened to work sometimes


local frule_call::Future{CallMeta}
local result::Future{Union{CallMeta, Nothing}} = Future{Union{CallMeta, Nothing}}()
function make_progress(_, sv)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vtjnash, please confirm that this is the intended way to use this.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You seem to have closure captured interp instead of using the argument? The interp struct is commonly quite large, so that can increase memory usage quite a bit

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The argument is the wrong interp

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this way of defining a make_progress seems fine. There isn't really one right answer about how to code this, so base itself already uses probably 3 or 4 different patterns, depending on what kept the original code control flow seemed least distorted. I hadn't used the @isdefined trick, but it is essentially equivalent to the nextstate pattern I'd used for manual stackless state machine conversion

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think capturing the interp will give you the behavior you want. I think you might need to mutate sv instead of re-using sv with different interp

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not re-using sv with a different interp. sv here is an IRInterpretationState, which doesn't have an interp argument, so when the callback later gets scheduled, there's just some random interp in there.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sv claims to be a AbsIntState here? For IRInterpretationState currently it passes the interp here that was originally used to construct the IRInterpretationState, since everything is on the stack there and doesn't handle recursion

I think the behavior here is also probably fine, but that no other callback will be using the right interp, since none other are expecting the interp to be different from the one used to allocate the state object

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and doesn't handle recursion

Doesn't handle recursion in Base ;).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair, I know almost nothing about this code, so I am reviewing without really knowing how this integrates. The current implementation in Base would potentially break here: https://github.com/JuliaLang/julia/blob/be401635fe02b28ce994e2e3cae0733d101f8927/base/compiler/ssair/irinterp.jl#L154
since it was not tracking if the return type changed to reschedule this instruction if it became part of cycle (I believe it should detect and @assert though if that attempts to happen)

if isa(primal_call[].info, UnionSplitApplyCallInfo)
result[] = nothing
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type appears to be wrong here. The intended behavior appears to be returning result[] = primal_call[] in this case (

r = fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret)
if r !== nothing
return r
end
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the function that calls this. Potentially it should be refactored to just do that, but I just wanted to only make the refactoring change.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type is required to be Future{CallMeta} though, or the caller's caller will be unhappy

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The call-site there isn't updated yet. This function is called directly from DAECompiler and I adjusted the call-site there to work with Future{Union{Nothing, CallMeta}}

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is fine if it still branches on r!==nothing there, it will just be dead code now, as it appears you you must handle that case here now, instead of being able to handle it there

return true
end

ready = false
if !@isdefined(frule_call)
# Here we simply check for the frule existance - we don't want to do a full
# inference with specialized argtypes and everything since the problem is
# likely sparse and we only need to do a full inference on a few calls.
# Thus, here we pick `Any` for the tangent types rather than trying to
# discover what they are. frules should be written in such a way that
# whether or not they return `nothing`, only depends on the non-tangent arguments
frule_arginfo = ArgInfo(nothing, frule_argtypes)
frule_si = StmtInfo(true)
# turn off frule analysis in the frule to avoid cycling
interp′ = disable_forward(interp)
frule_call = CC.abstract_call_gf_by_type(interp′,
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)::Future
isready(frule_call) || return false
end

frc = frule_call[]
pc = primal_call[]

if frc.rt !== Const(nothing)
result[] = CallMeta(pc.rt, pc.exct, pc.effects, FRuleCallInfo(pc.info, frc))
else
result[] = nothing
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
end

return true
end
(!isready(primal_call) || !make_progress(interp, sv)) && push!(sv.tasks, make_progress)
return result
end

else

function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::CallMeta)
if f === ChainRulesCore.frule
Expand Down Expand Up @@ -38,4 +98,8 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
return nothing
end



end

const frule_mt = methods(ChainRulesCore.frule).mt
6 changes: 6 additions & 0 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ if VERSION < v"1.11.0-DEV.258"
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
end

if isdefined(CC, :Future)
Base.isready(future::CC.Future) = CC.isready(future)
Base.getindex(future::CC.Future) = CC.getindex(future)
Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value)
end

Base.copy(ir::IRCode) = CC.copy(ir)

CC.NewInstruction(@nospecialize node) =
Expand Down
4 changes: 4 additions & 0 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
ci.ssaflags = UInt8[0 for i=1:length(new_code)]
ci.method_for_inference_limit_heuristics = meth
ci.edges = MethodInstance[mi]
if hasfield(CodeInfo, :nargs)
ci.nargs = 2
ci.isva = true
end

return ci
end
Expand Down
3 changes: 3 additions & 0 deletions src/stage2/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ end
CC.nsplit_impl(info::FRuleCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::FRuleCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::FRuleCallInfo, idx::Int) = CC.getresult(info.info, idx)
if isdefined(CC, :add_uncovered_edges_impl)
CC.add_uncovered_edges_impl(edges::Vector{Any}, info::FRuleCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype)
end

function Base.show(io::IO, info::FRuleCallInfo)
print(io, "FRuleCallInfo(", typeof(info.info), ", ", typeof(info.frule_call.info), ")")
Expand Down
Loading