From c37f82a7b68018c37703f8a990bdc2817dfd9849 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 3 Oct 2024 00:36:48 +0000 Subject: [PATCH] Adjust to stackless compiler changes Depends on: - https://github.com/JuliaDiff/Diffractor.jl/pull/295 - https://github.com/JuliaLang/julia/pull/55972 --- Manifest.toml | 10 +++-- src/analysis/compiler.jl | 4 +- src/analysis/interpreter.jl | 84 ++++++++++++++++++++----------------- 3 files changed, 53 insertions(+), 45 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index be6aff4..07cbe6b 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -444,9 +444,11 @@ version = "1.15.1" [[deps.Diffractor]] deps = ["AbstractDifferentiation", "ChainRules", "ChainRulesCore", "Combinatorics", "Cthulhu", "InteractiveUtils", "OffsetArrays", "PrecompileTools", "StaticArrays", "StructArrays"] -git-tree-sha1 = "e9472ffeff4ec8958e96cf3ddcae5e700cbeacbd" +git-tree-sha1 = "b69270604edd914f886c1f8f83476c53b19e0101" +repo-rev = "kf/compileradjust" +repo-url = "https://github.com/JuliaDiff/Diffractor.jl.git" uuid = "9f5e2b26-1114-432f-b630-d3fe2085c51c" -version = "0.2.10" +version = "0.2.8" [[deps.Distances]] deps = ["LinearAlgebra", "Statistics", "StatsAPI"] @@ -1235,7 +1237,7 @@ weakdeps = ["Adapt"] [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.27+1" +version = "0.3.28+2" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -1830,7 +1832,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.7.0+0" +version = "7.8.0+0" [[deps.Sundials]] deps = ["CEnum", "DataStructures", "DiffEqBase", "Libdl", "LinearAlgebra", "Logging", "PrecompileTools", "Reexport", "SciMLBase", "SparseArrays", "Sundials_jll"] diff --git a/src/analysis/compiler.jl b/src/analysis/compiler.jl index 971fd74..91217c7 100644 --- a/src/analysis/compiler.jl +++ b/src/analysis/compiler.jl @@ -726,7 +726,7 @@ end analysis_interp = DAEInterpreter(interp; var_to_diff, var_kind, eq_kind, in_analysis=interp.ipo_analysis_mode) irsv = CC.IRInterpretationState(analysis_interp, method_info, ir, mi, argtypes, world, min_world, max_world) - ultimate_rt, _ = CC._ir_abstract_constant_propagation(analysis_interp, irsv; externally_refined) + ultimate_rt, _ = CC.ir_abstract_constant_propagation(analysis_interp, irsv; externally_refined) record_ir!(debug_config, "incidence_propagation", ir) # Encountering a `ddt` during abstract interpretation can add variables, @@ -745,7 +745,7 @@ end # recalculate domtree (inference could have changed the cfg) domtree = CC.construct_domtree(ir.cfg.blocks) - # We use the _ir_abstract_constant_propagation pass for three things: + # We use the ir_abstract_constant_propagation pass for three things: # 1. To establish incidence # 2. To constant propagate scope information that may not have been # available at inference time diff --git a/src/analysis/interpreter.jl b/src/analysis/interpreter.jl index 261b622..93ace4f 100644 --- a/src/analysis/interpreter.jl +++ b/src/analysis/interpreter.jl @@ -7,7 +7,7 @@ using .CC: AbstractInterpreter, NativeInterpreter, InferenceParams, Optimization StmtInfo, MethodCallResult, ConstCallResults, ConstPropResult, MethodTableView, CachedMethodTable, InternalMethodTable, OverlayMethodTable, CallMeta, CallInfo, IRCode, LazyDomtree, IRInterpretationState, set_inlineable!, block_for_inst, - BitSetBoundedMinPrioritySet, AbsIntState + BitSetBoundedMinPrioritySet, AbsIntState, Future using Base: IdSet using StateSelection: DiffGraph @@ -282,13 +282,13 @@ widenincidence(@nospecialize(x)) = x if length(argtypes) == 2 xarg = argtypes[2] if isa(xarg, Union{Incidence, Const}) - return structural_inc_ddt(interp.var_to_diff, interp.var_kind, xarg) + return Future{CallMeta}(structural_inc_ddt(interp.var_to_diff, interp.var_kind, xarg)) end end end if interp.in_analysis && !isa(f, Core.Builtin) && !isa(f, Core.IntrinsicFunction) # We don't want to do new inference here - return CallMeta(Any, Any, CC.Effects(), CC.NoCallInfo()) + return Future{CallMeta}(CallMeta(Any, Any, CC.Effects(), CC.NoCallInfo())) end ret = @invoke CC.abstract_call_known(interp::AbstractInterpreter, f::Any, arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) @@ -306,26 +306,30 @@ widenincidence(@nospecialize(x)) = x end arginfo = ArgInfo(arginfo.fargs, map(widenincidence, arginfo.argtypes)) r = Diffractor.fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret) - r !== nothing && return r - return ret + return Future{CallMeta}(CC.isready(r) ? ret : r, interp, sv) do _, interp, sv + r[] !== nothing && return r[] + return ret[] + end end @override function CC.abstract_call_method(interp::DAEInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, si::StmtInfo, sv::InferenceState) - ret = @invoke CC.abstract_call_method(interp::AbstractInterpreter, + mret = @invoke CC.abstract_call_method(interp::AbstractInterpreter, method::Method, sig::Any, sparams::SimpleVector, hardlimit::Bool, si::StmtInfo, sv::InferenceState) - edge = ret.edge - if edge !== nothing - cache = CC.get(CC.code_cache(interp), edge, nothing) - if cache !== nothing - src = @atomic :monotonic cache.inferred - if isa(src, DAECache) - info = src.info - merge_daeinfo!(interp, sv.result, info) + return Future{MethodCallResult}(mret, interp, sv) do ret, interp, sv + edge = ret.edge + if edge !== nothing + cache = CC.get(CC.code_cache(interp), edge, nothing) + if cache !== nothing + src = @atomic :monotonic cache.inferred + if isa(src, DAECache) + info = src.info + merge_daeinfo!(interp, sv.result, info) + end end end + return ret end - return ret end @override function CC.const_prop_call(interp::DAEInterpreter, @@ -974,34 +978,36 @@ function _abstract_eval_invoke_inst(interp::DAEInterpreter, inst::Union{CC.Instr end @override function CC.abstract_eval_statement_expr(interp::DAEInterpreter, inst::Expr, vtypes::Nothing, irsv::IRInterpretationState) - (; rt, exct, effects) = @invoke CC.abstract_eval_statement_expr(interp::AbstractInterpreter, inst::Expr, vtypes::Nothing, irsv::IRInterpretationState) - - if (!interp.ipo_analysis_mode || interp.in_analysis) && !isa(rt, Const) && !isa(rt, Incidence) && !CC.isType(rt) && !is_all_inc_or_const(Any[rt]) - argtypes = CC.collect_argtypes(interp, inst.args, nothing, irsv) - if argtypes === nothing - return CC.RTEffects(rt, exct, effects) - end - if is_all_inc_or_const(argtypes) - if inst.head in (:call, :invoke) && CC.hasintersect(widenconst(argtypes[inst.head === :call ? 1 : 2]), Union{typeof(variable), typeof(sim_time), typeof(state_ddt)}) - # The `variable` and `state_ddt` intrinsics can source Incidence. For all other - # calls, if there's no incidence in the arguments, there cannot be any incidence - # in the result. + ret = @invoke CC.abstract_eval_statement_expr(interp::AbstractInterpreter, inst::Expr, vtypes::Nothing, irsv::IRInterpretationState) + return Future{CC.RTEffects}(ret, interp, irsv) do ret, interp, irsv + (; rt, exct, effects) = ret + if (!interp.ipo_analysis_mode || interp.in_analysis) && !isa(rt, Const) && !isa(rt, Incidence) && !CC.isType(rt) && !is_all_inc_or_const(Any[rt]) + argtypes = CC.collect_argtypes(interp, inst.args, nothing, irsv) + if argtypes === nothing return CC.RTEffects(rt, exct, effects) end - fb_inci = _fallback_incidence(argtypes) - if fb_inci !== nothing - update_type(t::Type) = Incidence(t, fb_inci.row, fb_inci.eps) - update_type(t::Incidence) = t - update_type(t::Const) = t - update_type(t::CC.PartialTypeVar) = t - update_type(t::PartialStruct) = PartialStruct(t.typ, Any[Base.isvarargtype(f) ? f : update_type(f) for f in t.fields]) - update_type(t::CC.Conditional) = CC.Conditional(t.slot, update_type(t.thentype), update_type(t.elsetype)) - newrt = update_type(rt) - return CC.RTEffects(newrt, exct, effects) + if is_all_inc_or_const(argtypes) + if inst.head in (:call, :invoke) && CC.hasintersect(widenconst(argtypes[inst.head === :call ? 1 : 2]), Union{typeof(variable), typeof(sim_time), typeof(state_ddt)}) + # The `variable` and `state_ddt` intrinsics can source Incidence. For all other + # calls, if there's no incidence in the arguments, there cannot be any incidence + # in the result. + return CC.RTEffects(rt, exct, effects) + end + fb_inci = _fallback_incidence(argtypes) + if fb_inci !== nothing + update_type(t::Type) = Incidence(t, fb_inci.row, fb_inci.eps) + update_type(t::Incidence) = t + update_type(t::Const) = t + update_type(t::CC.PartialTypeVar) = t + update_type(t::PartialStruct) = PartialStruct(t.typ, Any[Base.isvarargtype(f) ? f : update_type(f) for f in t.fields]) + update_type(t::CC.Conditional) = CC.Conditional(t.slot, update_type(t.thentype), update_type(t.elsetype)) + newrt = update_type(rt) + return CC.RTEffects(newrt, exct, effects) + end end end + return CC.RTEffects(rt, exct, effects) end - return CC.RTEffects(rt, exct, effects) end @override function CC.compute_forwarded_argtypes(interp::DAEInterpreter, arginfo::ArgInfo, sv::AbsIntState) @@ -1222,7 +1228,7 @@ function infer_ir!(ir, interp::AbstractInterpreter, mi::MethodInstance) min_world = world = get_inference_world(interp) max_world = get_world_counter() irsv = IRInterpretationState(interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world) - (rt, nothrow) = CC._ir_abstract_constant_propagation(interp, irsv) + (rt, nothrow) = CC.ir_abstract_constant_propagation(interp, irsv) return rt end