diff --git a/src/Cassette.jl b/src/Cassette.jl index 6ebbd8d..61396a4 100644 --- a/src/Cassette.jl +++ b/src/Cassette.jl @@ -2,23 +2,14 @@ __precompile__(false) module Cassette -using Core: CodeInfo, SlotNumber, NewvarNode, LabelNode, GotoNode, SSAValue +using Core: CodeInfo, SlotNumber, NewvarNode, GotoNode, SSAValue using Logging -abstract type AbstractTag end -struct BottomTag <: AbstractTag end - -abstract type AbstractPass end -struct NoPass <: AbstractPass end -(::Type{NoPass})(::Any, ::Any, code_info) = code_info - -abstract type AbstractContext{T<:AbstractTag,P<:AbstractPass,B} end - include("utilities.jl") +include("context.jl") include("tagged.jl") include("overdub.jl") -include("contextdef.jl") include("macros.jl") function __init__() diff --git a/src/context.jl b/src/context.jl new file mode 100644 index 0000000..ac150f2 --- /dev/null +++ b/src/context.jl @@ -0,0 +1,87 @@ +################## +# `AbstractPass` # +################## + +abstract type AbstractPass end + +struct NoPass <: AbstractPass end + +(::Type{NoPass})(::Any, ::Any, code_info) = code_info + +######### +# `Tag` # +######### + +# this @pure annotation has official vtjnash approval :p +Base.@pure _pure_objectid(x) = objectid(x) + +abstract type AbstractContextName end + +struct Tag{N<:AbstractContextName,X,E#=<:Union{Nothing,Tag}=#} end + +Tag(::Type{N}, ::Type{X}) where {N,X} = Tag(N, X, Nothing) + +Tag(::Type{N}, ::Type{X}, ::Type{E}) where {N,X,E} = Tag{N,pure_objectid(X),E}() + +################# +# `BindingMeta` # +################# +# We define these here because we need them to define `Context`, +# but most code that works with these types is in src/tagged.jl + +mutable struct BindingMeta + data::Any + BindingMeta() = new() +end + +const BindingMetaDict = Dict{Symbol,BindingMeta} +const BindingMetaCache = IdDict{Module,BindingMetaDict} + +############# +# `Context` # +############# + +struct Context{N<:AbstractContextName, + M<:Any, + P<:AbstractPass, + T<:Union{Nothing,Tag}, + B<:Union{Nothing,BindingMetaCache}} + name::N + metadata::M + pass::P + tag::T + bindings::B + function Context(name::N, metadata::M, pass::P, ::Nothing, ::Nothing) where {N,M,P} + return new{N,M,P,Nothing,Nothing}(name, metadata, pass, nothing, nothing) + end + function Context(name::N, metadata::M, pass::P, tag::Tag{N}, bindings::BindingMetaCache) where {N,M,P} + return new{N,M,P,typeof(tag),BindingMetaCache}(name, metadata, pass, tag, bindings) + end +end + +function Context(name::AbstractContextName; metadata = nothing, pass::AbstractPass = NoPass()) + return Context(name, metadata, pass, nothing, nothing) +end + +function similarcontext(context::Context; + metadata = context.metadata, + pass = context.pass, + tag = context.tag, + bindings = context.bindings) + return Context(context.name, metadata, pass, tag, bindings) +end + +const ContextWithTag{T} = Context{<:AbstractContextName,<:Any,<:AbstractPass,T} +const ContextWithPass{P} = Context{<:AbstractContextName,<:Any,P} + +has_tagging_enabled(::Type{<:ContextWithTag{<:Tag}}) = true +has_tagging_enabled(::Type{<:ContextWithTag{Nothing}}) = false + +tagtype(::C) where {C<:Context} = tagtype(C) +tagtype(::Type{<:ContextWithTag{T}}) where {T} = T + +function withtagfor(context::Context, f) + return similarcontext(context; + tag = Tag(typeof(context), typeof(f)), + bindings = BindingMetaCache()) +end diff --git a/src/contextdef.jl b/src/contextdef.jl deleted file mode 100644 index 26acb62..0000000 --- a/src/contextdef.jl +++ /dev/null @@ -1,117 +0,0 @@ -####################### -# unhygeniec bindings # -####################### - -const CONTEXT_TYPE_BINDING = Symbol("__CONTEXT__") -const CONTEXT_BINDING = Symbol("__context__") - -################### -# stubs/utilities # -################### - -# these stubs are only overloaded on a per-context basis -function generate_tag end -function similar_context end - -# this @pure annotations has official vtjnash approval :p -Base.@pure pure_objectid(x) = objectid(x) - -########################################### -# context definition generation from name # -########################################### - -function generate_context_definition(Ctx) - @assert isa(Ctx, Symbol) "context name must be a Symbol" - CtxTag = gensym(string(Ctx, "Tag")) - return quote - struct $CtxTag{E,H} <: $Cassette.AbstractTag end - - $CtxTag(x) = $CtxTag($Cassette.BottomTag(), x) - $CtxTag(::E, ::X) where {E,X} = $CtxTag{E,$Cassette.pure_objectid(X)}() - - struct $Ctx{M,T<:$CtxTag,P<:$Cassette.AbstractPass,B<:Union{Nothing,$Cassette.BindingMetaCache}} <: $Cassette.AbstractContext{T,P,B} - metadata::M - tag::T - pass::P - bindings::B # tagging functionality is considered enabled if this field is of type `BindingMetaCache` - end - - function $Ctx(; - metadata = nothing, - pass::$Cassette.AbstractPass = $Cassette.NoPass(), - tagging_enabled::Bool = false) - bindings = tagging_enabled ? $Cassette.BindingMetaCache() : nothing - return $Ctx(metadata, $CtxTag(nothing), pass, bindings) - end - - $Cassette.generate_tag(ctx::$Ctx, f) = $CtxTag(f) - - function $Cassette.similar_context(ctx::$Ctx; - metadata = ctx.metadata, - tag = ctx.tag, - pass = ctx.pass, - bindings = ctx.bindings) - return $Ctx(metadata, tag, pass, bindings) - end - - #=== default primitives ===# - - $Cassette.@primitive function $CtxTag(x) where {__CONTEXT__<:$Ctx} - return $CtxTag(__context__.tag, x) - end - - $Cassette.@primitive function Core._apply(f, args...) where {__CONTEXT__<:$Ctx} - flattened_args = Core._apply(tuple, args...) - return $Cassette.overdub_execute(__context__, f, flattened_args...) - end - - # dispatch on `B` to ensure that we don't call this when tagging is disabled - $Cassette.@primitive function Array{T,N}(undef::UndefInitializer, args...) where {T,N,__CONTEXT__<:$Ctx{<:Any,<:Any,<:Any,$Cassette.BindingMetaCache}} - return $Cassette.tagged_new(__context__, Array{T,N}, undef, args...) - end - - $Cassette.@primitive function Base.nameof(m) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_nameof(__context__, m) - end - - $Cassette.@primitive function Core.getfield(x, name) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_getfield(__context__, x, name) - end - - $Cassette.@primitive function Core.setfield!(x, name, y) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_setfield!(__context__, x, name, y) - end - - $Cassette.@primitive function Core.arrayref(boundscheck, x, i) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_arrayref(__context__, boundscheck, x, i) - end - - $Cassette.@primitive function Core.arrayset(boundscheck, x, y, i) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_arrayset(__context__, boundscheck, x, y, i) - end - - $Cassette.@primitive function Base._growbeg!(x, delta) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_growbeg!(__context__, x, delta) - end - - $Cassette.@primitive function Base._growend!(x, delta) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_growend!(__context__, x, delta) - end - - $Cassette.@primitive function Base._growat!(x, i, delta) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_growat!(__context__, x, i, delta) - end - - $Cassette.@primitive function Base._deletebeg!(x, delta) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_deletebeg!(__context__, x, delta) - end - - $Cassette.@primitive function Base._deleteend!(x, delta) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_deleteend!(__context__, x, delta) - end - - $Cassette.@primitive function Base._deleteat!(x, i, delta) where {__CONTEXT__<:$Ctx} - return $Cassette.tagged_deleteat!(__context__, x, i, delta) - end - end -end diff --git a/src/macros.jl b/src/macros.jl index 4001156..458d4c3 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -1,3 +1,6 @@ +const CONTEXT_TYPE_BINDING = Symbol("__CONTEXT__") +const CONTEXT_BINDING = Symbol("__context__") + ############ # @context # ############ @@ -8,7 +11,75 @@ Define a new Cassette context type with the name `Ctx`. """ macro context(Ctx) - return esc(generate_context_definition(Ctx)) + @assert isa(Ctx, Symbol) "context name must be a Symbol" + CtxName = gensym(string(Ctx, "Name")) + return esc(quote + struct $CtxName <: $Cassette.AbstractContextName end + + const $Ctx{M,T<:Union{Nothing,$Cassette.Tag},P<:$Cassette.AbstractPass} = $Cassette.Context{$CtxName,M,P,T} + + $Ctx(; kwargs...) = $Cassette.Context($CtxName(); kwargs...) + + $Cassette.@primitive function $Cassette.Tag(::Type{N}, ::Type{X}) where {__CONTEXT__<:$Ctx,N,X} + return Tag(N, X, $Cassette.tagtype(__CONTEXT__)) + end + + $Cassette.@primitive function Core._apply(f, args...) where {__CONTEXT__<:$Ctx} + flattened_args = Core._apply(tuple, args...) + return $Cassette.overdub_execute(__context__, f, flattened_args...) + end + + # enforce `T<:Cassette.Tag` to ensure that we only call the below primitive functions + # if the context has the tagging system enabled + + $Cassette.@primitive function Array{T,N}(undef::UndefInitializer, args...) where {T,N,__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_new(__context__, Array{T,N}, undef, args...) + end + + $Cassette.@primitive function Base.nameof(m) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_nameof(__context__, m) + end + + $Cassette.@primitive function Core.getfield(x, name) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_getfield(__context__, x, name) + end + + $Cassette.@primitive function Core.setfield!(x, name, y) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_setfield!(__context__, x, name, y) + end + + $Cassette.@primitive function Core.arrayref(boundscheck, x, i) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_arrayref(__context__, boundscheck, x, i) + end + + $Cassette.@primitive function Core.arrayset(boundscheck, x, y, i) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_arrayset(__context__, boundscheck, x, y, i) + end + + $Cassette.@primitive function Base._growbeg!(x, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_growbeg!(__context__, x, delta) + end + + $Cassette.@primitive function Base._growend!(x, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_growend!(__context__, x, delta) + end + + $Cassette.@primitive function Base._growat!(x, i, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_growat!(__context__, x, i, delta) + end + + $Cassette.@primitive function Base._deletebeg!(x, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_deletebeg!(__context__, x, delta) + end + + $Cassette.@primitive function Base._deleteend!(x, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_deleteend!(__context__, x, delta) + end + + $Cassette.@primitive function Base._deleteat!(x, i, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}} + return $Cassette.tagged_deleteat!(__context__, x, i, delta) + end + end) end ############ @@ -20,12 +91,7 @@ end A convenience macro for overdubbing and executing `expression` within the context `Ctx`. """ macro overdub(ctx, expr) - return quote - func = $(esc(CONTEXT_BINDING)) -> $(esc(expr)) - ctx = Cassette.similar_context($(esc(ctx)); - tag = $Cassette.generate_tag($(esc(ctx)), func)) - $Cassette.overdub_recurse(ctx, func, ctx) - end + return :($Cassette.overdub_recurse($(esc(ctx)), () -> $(esc(expr)))) end ######### diff --git a/src/overdub.jl b/src/overdub.jl index b99441f..89f6b31 100644 --- a/src/overdub.jl +++ b/src/overdub.jl @@ -2,13 +2,13 @@ # contextual operations # ######################### -@inline prehook(::AbstractContext, ::Vararg{Any}) = nothing -@inline posthook(::AbstractContext, ::Vararg{Any}) = nothing -@inline is_user_primitive(::AbstractContext, ::Vararg{Any}) = false -@inline is_core_primitive(ctx::AbstractContext, args...) = _is_core_primitive(ctx, args...) -@inline execution(::AbstractContext, f, args...) = f(args...) +@inline prehook(::Context, ::Vararg{Any}) = nothing +@inline posthook(::Context, ::Vararg{Any}) = nothing +@inline is_user_primitive(::Context, ::Vararg{Any}) = false +@inline is_core_primitive(ctx::Context, args...) = _is_core_primitive(ctx, args...) +@inline execution(::Context, f, args...) = f(args...) -@generated function _is_core_primitive(::C, args...) where {C<:AbstractContext} +@generated function _is_core_primitive(::C, args...) where {C<:Context} # TODO: this is slow, we should try to check whether the reflection is possible # without going through the whole process of actually computing it untagged_args = ((untagtype(args[i], C) for i in 1:nfields(args))...,) @@ -27,7 +27,7 @@ end # overdub_execute # ################### -@inline function overdub_execute(ctx::AbstractContext, args...) +@inline function overdub_execute(ctx::Context, args...) prehook(ctx, args...) if is_user_primitive(ctx, args...) output = execution(ctx, args...) @@ -45,10 +45,13 @@ end const OVERDUB_CTX_SYMBOL = gensym("overdub_context") const OVERDUB_ARGS_SYMBOL = gensym("overdub_arguments") -# Note that this pass emits code in which LHS SSAValues are not monotonically increasing. -# This currently isn't a problem, but in the future, valid IR might require monotonically -# increasing LHS SSAValues, in which case we'll have to add an extra SSA-remapping pass to -# this function. +# The `overdub_recurse` pass has four intertwined tasks: +# 1. Apply the user-provided pass, if one is given +# 2. Munge the reflection-generated IR into a valid form for returning from +# `overdub_recurse_generator` (i.e. add new argument slots, substitute static +# parameters, destructure overdub arguments into underlying method slots, etc.) +# 3. Translate all function calls to `overdub_execute` calls +# 4. If tagging is enabled, do the necessary IR transforms for the metadata tagging system function overdub_recurse_pass!(reflection::Reflection, context_type::DataType, pass_type::DataType = NoPass) @@ -57,9 +60,12 @@ function overdub_recurse_pass!(reflection::Reflection, static_params = reflection.static_params code_info = reflection.code_info - # execute user-provided pass (is a no-op by default) + #=== 1. Execute user-provided pass (is a no-op by default) ===# + code_info = pass_type(context_type, signature, code_info) + #=== 2. Munge the code into a valid form for `overdub_recurse_generator` ===# + # construct new slotnames/slotflags for added slots code_info.slotnames = Any[:overdub_recurse, OVERDUB_CTX_SYMBOL, OVERDUB_ARGS_SYMBOL, code_info.slotnames...] code_info.slotflags = UInt8[0x00, 0x00, 0x00, code_info.slotflags...] @@ -67,18 +73,11 @@ function overdub_recurse_pass!(reflection::Reflection, overdub_ctx_slot = SlotNumber(2) overdub_args_slot = SlotNumber(3) - # substitute static parameters and offset slotnumbers by number of added slots - code_expr = Expr(:block) - code_expr.args = code_info.code - Core.Compiler.substitute!(code_expr, 0, Any[], method.sig, static_params, n_overdub_slots, :propagate) - - # Instantiate a new code array containing the same preceding `Nothing`s, `NewvarNode`s, - # etc. as `code_info`'s code array. The rest of this pass will translate statements from - # `code_info.code` to `overdubbed_code`, instead of updating `code_info.code` in-place - # (just for the sake of convenience). Then, at the end, we'll set `code_info.code` to - # `overdubbed_code`. - overdubbed_code = copy_prelude_code(code_info.code) - prelude_length = length(overdubbed_code) + # For the sake of convenience, the rest of this pass will translate `code_info`'s fields + # into these overdubbed equivalents instead of updating `code_info` in-place. Then, at + # the end of the pass, we'll reset `code_info` fields accordingly. + overdubbed_code = Any[] + overdubbed_codelocs = Int32[] # destructure the generated argument slots into the overdubbed method's argument slots. n_actual_args = fieldcount(signature) @@ -87,56 +86,78 @@ function overdub_recurse_pass!(reflection::Reflection, slot = i + n_overdub_slots actual_argument = Expr(:call, GlobalRef(Core, :getfield), overdub_args_slot, i) push!(overdubbed_code, :($(SlotNumber(slot)) = $actual_argument)) + push!(overdubbed_codelocs, code_info.codelocs[1]) code_info.slotflags[slot] = 0x18 # this slot is now an "SSA slot" end - # If `method` is a varargs method, we have to destructure the original method call's + # If `method` is a varargs method, we have to restructure the original method call's # trailing arguments into a tuple and assign that tuple to the expected argument slot. if method.isva - # remove the final slot reassignment leftover from the previous destructuring - isempty(overdubbed_code) || pop!(overdubbed_code) - final_arguments = Expr(:call, GlobalRef(Core, :tuple)) + if !isempty(overdubbed_code) + # remove the final slot reassignment leftover from the previous destructuring + pop!(overdubbed_code) + pop!(overdubbed_codelocs) + end + trailing_arguments = Expr(:call, GlobalRef(Core, :tuple)) for i in n_method_args:n_actual_args - ssaval = SSAValue(code_info.ssavaluetypes) - actual_argument = Expr(:call, GlobalRef(Core, :getfield), overdub_args_slot, i) - push!(overdubbed_code, :($ssaval = $actual_argument)) - push!(final_arguments.args, ssaval) - code_info.ssavaluetypes += 1 + push!(overdubbed_code, Expr(:call, GlobalRef(Core, :getfield), overdub_args_slot, i)) + push!(overdubbed_codelocs, code_info.codelocs[1]) + push!(trailing_arguments.args, SSAValue(length(overdubbed_code))) end - push!(overdubbed_code, :($(SlotNumber(n_method_args + n_overdub_slots)) = $final_arguments)) + push!(overdubbed_code, Expr(:(=), SlotNumber(n_method_args + n_overdub_slots), trailing_arguments)) + push!(overdubbed_codelocs, code_info.codelocs[1]) end - # Scan the IR for `Module`s in the first argument position for `GlobalRef`s. For every - # unique such `Module`, make a new `SSAValue` at the top of the method body corresponding - # to `Cassette.fetch_tagged_module` called with the given context and module. Then, - # replace all `GlobalRef`-loads with the corresponding `Cassette._tagged_global_ref` - # invocation. All `GlobalRef`-stores must be preserved as-is, but need a follow-up - # statement calling `Cassette._tagged_global_ref_set_meta!` on the relevant arguments. - # TODO + #=== 3. Translate function calls to `overdub_execute` calls ===# + + # substitute static parameters, offset slot numbers by number of added slots, and + # offset statement indices by the number of additional statements + Base.Meta.partially_inline!(code_info.code, Any[], method.sig, static_params, + n_overdub_slots, length(overdubbed_code), :propagate) # For the rest of the statements in `code_info.code`, intercept every applicable call # expression and replace it with a corresponding call to `Cassette.overdub_execute`. - for i in (prelude_length + 1):length(code_info.code) + for i in 1:length(code_info.code) stmnt = code_info.code[i] replace_match!(is_call, stmnt) do call call.args = Any[GlobalRef(Cassette, :overdub_execute), overdub_ctx_slot, call.args...] return call end push!(overdubbed_code, stmnt) + push!(overdubbed_codelocs, code_info.codelocs[i]) end - # TODO: Replace `new` expressions with calls to `Cassette.tagged_new`. + #=== 4. IR transforms for the metadata tagging system ===# - # TODO: appropriately untag all `gotoifnot` conditionals + if has_tagging_enabled(context_type) + changemap = fill(0, length(code_info.code)) + # Scan the IR for `Module`s in the first argument position for `GlobalRef`s. For every + # unique such `Module`, make a new `SSAValue` at the top of the method body corresponding + # to `Cassette.fetch_tagged_module` called with the given context and module. Then, + # replace all `GlobalRef`-loads with the corresponding `Cassette._tagged_global_ref` + # invocation. All `GlobalRef`-stores must be preserved as-is, but need a follow-up + # statement calling `Cassette._tagged_global_ref_set_meta!` on the relevant arguments. + # TODO - code_info.code = fix_labels_and_gotos!(overdubbed_code) + # TODO: Replace `new` expressions with calls to `Cassette.tagged_new`. + + # TODO: appropriately untag all `gotoifnot` conditionals + Core.Compiler.renumber_ir_elements!(overdubbed_code, changemap) + end + + #=== 5. Set `code_info`/`reflection` fields accordingly ===# + + code_info.code = overdubbed_code + code_info.codelocs = overdubbed_codelocs + code_info.ssavaluetypes = length(overdubbed_code) code_info.method_for_inference_limit_heuristics = method reflection.code_info = code_info + return reflection end # `args` is `(typeof(original_function), map(typeof, original_args_tuple)...)` -function overdub_recurse_generator(tag_type, pass_type, self, context_type, args::Tuple) +function overdub_recurse_generator(pass_type, self, context_type, args::Tuple) try untagged_args = ((untagtype(args[i], context_type) for i in 1:nfields(args))...,) reflection = reflect(untagged_args) @@ -153,7 +174,6 @@ function overdub_recurse_generator(tag_type, pass_type, self, context_type, args end return body catch err - @safe_error "error compiling" args context=context_type errmsg = "ERROR COMPILING $args IN CONTEXT $(context_type): \n" * sprint(showerror, err) return quote error($errmsg) @@ -163,14 +183,14 @@ end function overdub_recurse_definition(pass, line, file) return quote - function overdub_recurse($OVERDUB_CTX_SYMBOL::AbstractContext{tag,pass}, $OVERDUB_ARGS_SYMBOL...) where {tag,pass<:$pass} + function overdub_recurse($OVERDUB_CTX_SYMBOL::ContextWithPass{pass}, $OVERDUB_ARGS_SYMBOL...) where {pass<:$pass} $(Expr(:meta, :generated, Expr(:new, Core.GeneratedFunctionStub, :overdub_recurse_generator, Any[:overdub_recurse, OVERDUB_CTX_SYMBOL, OVERDUB_ARGS_SYMBOL], - Any[:tag, :pass], + Any[:pass], line, QuoteNode(Symbol(file)), true))) diff --git a/src/tagged.jl b/src/tagged.jl index dd6cc76..8e90bac 100644 --- a/src/tagged.jl +++ b/src/tagged.jl @@ -46,17 +46,10 @@ const NOMETA = Meta(NoMetaData(), NoMetaMeta()) Base.convert(::Type{M}, meta::M) where {M<:Meta} = meta Base.convert(::Type{Meta{D,M}}, meta::Meta) where {D,M} = Meta{D,M}(meta.data, meta.meta) -############################## -# `ModuleMeta`/`BindingMeta` # -############################## - -mutable struct BindingMeta - data::Any - BindingMeta() = new() -end - -const BindingMetaDict = Dict{Symbol,BindingMeta} -const BindingMetaCache = IdDict{Module,BindingMetaDict} +################ +# `ModuleMeta` # +################ +# note that `BindingMeta` was defined earlier in src/context.jl struct ModuleMeta{D,M} name::Meta{D,M} @@ -67,12 +60,12 @@ end # invocation. We easily have the module at compile time, but we don't have access to the # actual context object. This `@pure` is vtjnash-approved. It should allow the compiler to # optimize away the fetch once we have support for it, e.g. loop invariant code motion. -Base.@pure @noinline function fetch_tagged_module(context::AbstractContext, m::Module) +Base.@pure @noinline function fetch_tagged_module(context::Context, m::Module) bindings = get!(() -> BindingMetaDict(), context.bindings, m) return Tagged(context.tag, m, Meta(NoMetaData(), ModuleMeta(NOMETA, bindings))) end -Base.@pure @noinline function _fetch_binding_meta!(context::AbstractContext, +Base.@pure @noinline function _fetch_binding_meta!(context::Context, m::Module, bindings::BindingMetaDict, name::Symbol) @@ -90,7 +83,7 @@ Base.@pure @noinline function _fetch_binding_meta!(context::AbstractContext, end end -function fetch_binding_meta!(context::AbstractContext, +function fetch_binding_meta!(context::Context, m::Module, bindings::BindingMetaDict, name::Symbol, @@ -123,7 +116,7 @@ benefit that metatype computation is very fast. If, in the future, `metatype` is parameterized on world age, then we can call `subtypes` at compile time, and compute a more optimally bounded metatype. =# -function metatype(::Type{C}, ::Type{T}) where {C<:AbstractContext,T} +function metatype(::Type{C}, ::Type{T}) where {C<:Context,T} if isconcretetype(T) return Meta{metadatatype(C, T),metametatype(C, T)} end @@ -132,11 +125,11 @@ end #=== metadatatype ===# -metadatatype(::Type{<:AbstractContext}, ::DataType) = NoMetaData +metadatatype(::Type{<:Context}, ::DataType) = NoMetaData #=== metametatype ===# -@generated function metametatype(::Type{C}, ::Type{T}) where {C<:AbstractContext,T} +@generated function metametatype(::Type{C}, ::Type{T}) where {C<:Context,T} if !(isconcretetype(T)) return :(error("cannot call metametatype on non-concrete type ", $T)) end @@ -159,24 +152,24 @@ metadatatype(::Type{<:AbstractContext}, ::DataType) = NoMetaData end end -@generated function metametatype(::Type{C}, ::Type{T}) where {C<:AbstractContext,T<:Array} +@generated function metametatype(::Type{C}, ::Type{T}) where {C<:Context,T<:Array} return quote $(Expr(:meta, :inline)) Array{metatype(C, $(eltype(T))),$(ndims(T))} end end -@inline function metametatype(::Type{C}, ::Type{Module}) where {C<:AbstractContext} +@inline function metametatype(::Type{C}, ::Type{Module}) where {C<:Context} return ModuleMeta{metadatatype(C, Symbol), metametatype(C, Symbol)} end #=== initmetameta ===# -@inline initmetameta(context::AbstractContext, value::Module) = fetch_tagged_module(context, value).meta +@inline initmetameta(context::Context, value::Module) = fetch_tagged_module(context, value).meta -@inline initmetameta(context::C, value::Array{T}) where {C<:AbstractContext,T} = similar(value, metatype(C, T)) +@inline initmetameta(context::C, value::Array{T}) where {C<:Context,T} = similar(value, metatype(C, T)) -function initmetameta(context::C, value::V) where {C<:AbstractContext,V} +function initmetameta(context::C, value::V) where {C<:Context,V} if fieldcount(V) == 0 metameta_expr = :(NoMetaMeta()) else @@ -199,7 +192,7 @@ end #=== initmeta ===# -@inline function initmeta(context::C, value::V, metadata::D) where {C<:AbstractContext,V,D} +@inline function initmeta(context::C, value::V, metadata::D) where {C<:Context,V,D} return Meta{metadatatype(C, V),metametatype(C, V)}(metadata, initmetameta(context, value)) end @@ -212,11 +205,11 @@ Here, `U` is the innermost, "underlying" type of the value being wrapped. This p precomputed so that Cassette can directly dispatch on it in signatures generated for contextual primitives. =# -struct Tagged{T<:AbstractTag,U,V,D,M} +struct Tagged{T<:Tag,U,V,D,M} tag::T value::V meta::Meta{D,M} - function Tagged(tag::T, value::V, meta::Meta{D,M}) where {T<:AbstractTag,V,D,M} + function Tagged(tag::T, value::V, meta::Meta{D,M}) where {T<:Tag,V,D,M} return new{T,_underlying_type(V),V,D,M}(tag, value, meta) end end @@ -224,49 +217,49 @@ end #=== `Tagged` internals ===# _underlying_type(::Type{V}) where {V} = V -_underlying_type(::Type{<:Tagged{<:AbstractTag,U}}) where {U} = U +_underlying_type(::Type{<:Tagged{<:Tag,U}}) where {U} = U #=== `Tagged` API ===# -function tag(context::AbstractContext, value, metadata = NoMetaData()) +function tag(value, context::Context, metadata = NoMetaData()) return Tagged(context.tag, value, initmeta(context, value, metadata)) end -untag(x, context::AbstractContext) = untag(x, context.tag) -untag(x::Tagged{T}, tag::T) where {T<:AbstractTag} = x.value -untag(x, tag::AbstractTag) = x +untag(x, context::Context) = untag(x, context.tag) +untag(x::Tagged{T}, tag::T) where {T<:Tag} = x.value +untag(x, ::Union{Tag,Nothing}) = x -untagtype(::Type{X}, ::Type{<:AbstractContext{T}}) where {X,T} = untagtype(X, T) -untagtype(::Type{<:Tagged{T,U,V}}, ::Type{T}) where {T<:AbstractTag,U,V} = V -untagtype(::Type{X}, ::Type{<:AbstractTag}) where {X} = X +untagtype(::Type{X}, ::Type{C}) where {X,C<:Context} = untagtype(X, tagtype(C)) +untagtype(::Type{<:Tagged{T,U,V}}, ::Type{T}) where {T<:Tag,U,V} = V +untagtype(::Type{X}, ::Type{<:Union{Tag,Nothing}}) where {X} = X -metadata(x, context::AbstractContext) = metadata(x, context.tag) -metadata(x::Tagged{T}, tag::T) where {T<:AbstractTag} = x.meta.data -metadata(x, tag::AbstractTag) = NoMetaData() +metadata(x, context::Context) = metadata(x, context.tag) +metadata(x::Tagged{T}, tag::T) where {T<:Tag} = x.meta.data +metadata(::Any, ::Union{Tag,Nothing}) = NoMetaData() -metameta(x, context::AbstractContext) = metameta(x, context.tag) -metameta(x::Tagged{T}, tag::T) where {T<:AbstractTag} = x.meta.meta -metameta(x, tag::AbstractTag) = NoMetaMeta() +metameta(x, context::Context) = metameta(x, context.tag) +metameta(x::Tagged{T}, tag::T) where {T<:Tag} = x.meta.meta +metameta(::Any, ::Union{Tag,Nothing}) = NoMetaMeta() -istagged(x, context::AbstractContext) = istagged(x, context.tag) -istagged(x::Tagged{T}, tag::T) where {T<:AbstractTag} = true -istagged(x, tag::AbstractTag) = false +istagged(x, context::Context) = istagged(x, context.tag) +istagged(x::Tagged{T}, tag::T) where {T<:Tag} = true +istagged(::Any, ::Union{Tag,Nothing}) = false -istaggedtype(::Type{X}, ::Type{<:AbstractContext{T}}) where {X,T} = istaggedtype(X, T) -istaggedtype(::Type{<:Tagged{T}}, ::Type{T}) where {T<:AbstractTag} = true -istaggedtype(::Type{<:Any}, ::Type{<:AbstractTag}) = false +istaggedtype(::Type{X}, ::Type{C}) where {X,C<:Context} = istaggedtype(X, tagtype(C)) +istaggedtype(::Type{<:Tagged{T}}, ::Type{T}) where {T<:Tag} = true +istaggedtype(::DataType, ::Type{<:Union{Tag,Nothing}}) = false -hasmetadata(x, context::AbstractContext) = hasmetadata(x, context.tag) -hasmetadata(x, tag::AbstractTag) = !isa(metadata(x, tag), NoMetaData) +hasmetadata(x, context::Context) = hasmetadata(x, context.tag) +hasmetadata(x, tag::Union{Tag,Nothing}) = !isa(metadata(x, tag), NoMetaData) -hasmetameta(x, context::AbstractContext) = hasmetameta(x, context.tag) -hasmetameta(x, tag::AbstractTag) = !isa(metameta(x, tag), NoMetaMeta) +hasmetameta(x, context::Context) = hasmetameta(x, context.tag) +hasmetameta(x, tag::Union{Tag,Nothing}) = !isa(metameta(x, tag), NoMetaMeta) ################ # `tagged_new` # ################ -@generated function tagged_new(context::C, ::Type{T}, args...) where {C<:AbstractContext,T} +@generated function tagged_new(context::C, ::Type{T}, args...) where {C<:Context,T} tagged_count = 0 fields = Expr(:tuple) ftypes = [fieldtype(T, i) for i in 1:fieldcount(T)] @@ -316,15 +309,15 @@ hasmetameta(x, tag::AbstractTag) = !isa(metameta(x, tag), NoMetaMeta) end end -@generated function tagged_new(context::C, ::Type{T}, args...) where {C<:AbstractContext,T<:Array} +@generated function tagged_new(context::C, ::Type{T}, args...) where {C<:Context,T<:Array} untagged_args = [:(untagged(args[$i], context)) for i in 1:nfields(args)] return quote $(Expr(:meta, :inline)) - return tag(context, $(T)($(untagged_args...))) + return tag($(T)($(untagged_args...)), context) end end -@generated function tagged_new(context::C, ::Type{Module}, args...) where {C<:AbstractContext} +@generated function tagged_new(context::C, ::Type{Module}, args...) where {C<:Context} if istaggedtype(args[1], C) return_expr = quote Tagged(tagged_module.tag, tagged_module.value, @@ -348,9 +341,9 @@ end #=== tagged_nameof ===# -tagged_nameof(context::AbstractContext, x) = nameof(untag(x, context)) +tagged_nameof(context::Context, x) = nameof(untag(x, context)) -function tagged_nameof(context::AbstractContext{T}, x::Tagged{T,Module}) where {T} +function tagged_nameof(context::ContextWithTag{T}, x::Tagged{T,Module}) where {T} name_value = nameof(x.value) name_meta = hasmetameta(x, context) ? x.meta.meta.name : NOMETA return Tagged(context.tag, name_value, name_meta) @@ -358,7 +351,7 @@ end #=== tagged_globalref ===# -function tagged_globalref(context::AbstractContext{T}, +function tagged_globalref(context::ContextWithTag{T}, m::Tagged{T}, name, primal) where {T} @@ -369,7 +362,7 @@ function tagged_globalref(context::AbstractContext{T}, end end -function _tagged_globalref(context::AbstractContext{T}, +function _tagged_globalref(context::ContextWithTag{T}, m::Tagged{T}, name, primal) where {T} @@ -389,7 +382,7 @@ end # TODO: try to inline these operations into the IR so that the primal binding and meta # binding mutations occur directly next to one another -function tagged_global_set!(context::AbstractContext{T}, m::Tagged{T}, name::Symbol, primal) where {T} +function tagged_global_set!(context::ContextWithTag{T}, m::Tagged{T}, name::Symbol, primal) where {T} binding = fetch_binding!(context, m.value, m.meta.meta.bindings, name) meta = istagged(primal, context) ? primal.meta : NOMETA # this line is where the primal binding assignment should happen order-wise @@ -399,14 +392,14 @@ end #=== tagged_getfield ===# -tagged_getfield(context::AbstractContext, x, name) = getfield(x, untag(name, context)) +tagged_getfield(context::Context, x, name) = getfield(x, untag(name, context)) -function tagged_getfield(context::AbstractContext{T}, x::Tagged{T,Module}, name) where {T} +function tagged_getfield(context::ContextWithTag{T}, x::Tagged{T,Module}, name) where {T} untagged_name = untag(name, context) return tagged_global_ref(context, x, untagged_name, getfield(x.value, untagged_name)) end -function tagged_getfield(context::AbstractContext{T}, x::Tagged{T}, name) where {T} +function tagged_getfield(context::ContextWithTag{T}, x::Tagged{T}, name) where {T} untagged_name = untag(name, context) y_value = getfield(untag(x, context), untagged_name) y_meta = hasmetameta(x, context) ? load(getfield(x.meta.meta, untagged_name)) : NOMETA @@ -415,9 +408,9 @@ end #=== tagged_setfield! ===# -tagged_setfield!(context::AbstractContext, x, name, y) = setfield!(x, untag(name, context), y) +tagged_setfield!(context::Context, x, name, y) = setfield!(x, untag(name, context), y) -function tagged_setfield!(context::AbstractContext{T}, x::Tagged{T}, name, y) where {T} +function tagged_setfield!(context::ContextWithTag{T}, x::Tagged{T}, name, y) where {T} untagged_name = untag(name, context) y_value = untag(y, context) y_meta = istagged(y, context) ? y.meta : NOMETA @@ -428,11 +421,11 @@ end #=== tagged_arrayref ===# -function tagged_arrayref(context::AbstractContext, boundscheck, x, i) +function tagged_arrayref(context::Context, boundscheck, x, i) return Core.arrayref(untag(boundscheck, context), x, untag(i, context)) end -function tagged_arrayref(context::AbstractContext{T}, boundscheck, x::Tagged{T}, i) where {T} +function tagged_arrayref(context::ContextWithTag{T}, boundscheck, x::Tagged{T}, i) where {T} untagged_boundscheck = untag(boundscheck, context) untagged_i = untag(i, context) y_value = Core.arrayref(untagged_boundscheck, untag(x, context), untagged_i) @@ -442,11 +435,11 @@ end #=== tagged_arrayset ===# -function tagged_arrayset(context::AbstractContext, boundscheck, x, y, i) +function tagged_arrayset(context::Context, boundscheck, x, y, i) return Core.arrayset(untag(boundscheck, context), x, y, untag(i, context)) end -function tagged_arrayset(context::AbstractContext{T}, boundscheck, x::Tagged{T}, y, i) where {T} +function tagged_arrayset(context::ContextWithTag{T}, boundscheck, x::Tagged{T}, y, i) where {T} untagged_boundscheck = untag(boundscheck, context) untagged_i = untag(i, context) y_value = untag(y, context) @@ -458,9 +451,9 @@ end #=== tagged_growbeg! ===# -tagged_growbeg!(context::AbstractContext, x, delta) = Base._growbeg!(x, untag(delta, context)) +tagged_growbeg!(context::Context, x, delta) = Base._growbeg!(x, untag(delta, context)) -function tagged_growbeg!(context::AbstractContext{T}, x::Tagged{T}, delta) where {T} +function tagged_growbeg!(context::ContextWithTag{T}, x::Tagged{T}, delta) where {T} untagged_delta = untag(delta, context) Base._growbeg!(x.value, delta_untagged) hasmetameta(x, context) && Base._growbeg!(x.meta.meta, delta_untagged) @@ -469,9 +462,9 @@ end #=== tagged_growend! ===# -tagged_growend!(context::AbstractContext, x, delta) = Base._growend!(x, untag(delta, context)) +tagged_growend!(context::Context, x, delta) = Base._growend!(x, untag(delta, context)) -function tagged_growend!(context::AbstractContext{T}, x::Tagged{T}, delta) where {T} +function tagged_growend!(context::ContextWithTag{T}, x::Tagged{T}, delta) where {T} untagged_delta = untag(delta, context) Base._growend!(x.value, delta_untagged) hasmetameta(x, context) && Base._growend!(x.meta.meta, delta_untagged) @@ -480,11 +473,11 @@ end #=== tagged_growat! ===# -function tagged_growat!(context::AbstractContext, x, i, delta) +function tagged_growat!(context::Context, x, i, delta) return Base._growat!(x, untag(i, context), untag(delta, context)) end -function tagged_growat!(context::AbstractContext{T}, x::Tagged{T}, i, delta) where {T} +function tagged_growat!(context::ContextWithTag{T}, x::Tagged{T}, i, delta) where {T} i_untagged = untag(i, context) delta_untagged = untag(delta, context) Base._growat!(x.value, i_untagged, delta_untagged) @@ -494,9 +487,9 @@ end #=== tagged_deletebeg! ===# -tagged_deletebeg!(context::AbstractContext, x, delta) = Base._deletebeg!(x, untag(delta, context)) +tagged_deletebeg!(context::Context, x, delta) = Base._deletebeg!(x, untag(delta, context)) -function tagged_deletebeg!(context::AbstractContext{T}, x::Tagged{T}, delta) where {T} +function tagged_deletebeg!(context::ContextWithTag{T}, x::Tagged{T}, delta) where {T} untagged_delta = untag(delta, context) Base._deletebeg!(x.value, delta_untagged) hasmetameta(x, context) && Base._deletebeg!(x.meta.meta, delta_untagged) @@ -505,9 +498,9 @@ end #=== tagged_deleteend! ===# -tagged_deleteend!(context::AbstractContext, x, delta) = Base._deleteend!(x, untag(delta, context)) +tagged_deleteend!(context::Context, x, delta) = Base._deleteend!(x, untag(delta, context)) -function tagged_deleteend!(context::AbstractContext{T}, x::Tagged{T}, delta) where {T} +function tagged_deleteend!(context::ContextWithTag{T}, x::Tagged{T}, delta) where {T} untagged_delta = untag(delta, context) Base._deleteend!(x.value, delta_untagged) hasmetameta(x, context) && Base._deleteend!(x.meta.meta, delta_untagged) @@ -516,11 +509,11 @@ end #=== tagged_deleteat! ===# -function tagged_deleteat!(context::AbstractContext, x, i, delta) +function tagged_deleteat!(context::Context, x, i, delta) return Base._deleteat!(x, untag(i, context), untag(delta, context)) end -function tagged_deleteat!(context::AbstractContext{T}, x::Tagged{T}, i, delta) where {T} +function tagged_deleteat!(context::ContextWithTag{T}, x::Tagged{T}, i, delta) where {T} i_untagged = untag(i, context) delta_untagged = untag(delta, context) Base._deleteat!(x.value, i_untagged, delta_untagged) diff --git a/src/utilities.jl b/src/utilities.jl index 16f83b8..651ec70 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -38,9 +38,11 @@ function replace_match!(replace, ismatch, x) return x end -####################### -# Julia IR/Reflection # -####################### +############ +# Julia IR # +############ + +#=== reflection ===# mutable struct Reflection signature::DataType @@ -77,6 +79,9 @@ function reflect(@nospecialize(sigtypes::Tuple), world::UInt = typemax(UInt)) return Reflection(S, method, static_params, code_info) end +#=== IR Repair ===# + +# TODO: update this function fix_labels_and_gotos!(code::Vector) changes = Dict{Int,Int}() for (i, stmnt) in enumerate(code) @@ -97,18 +102,6 @@ function fix_labels_and_gotos!(code::Vector) return code end -function copy_prelude_code(code::Vector) - prelude_code = Any[] - for stmnt in code - if isa(stmnt, Nothing) || isa(stmnt, NewvarNode) - push!(prelude_code, stmnt) - else - break - end - end - return prelude_code -end - ################# # Miscellaneous # ################# diff --git a/test/ExampleTests.jl b/test/ExampleTests.jl index a31b277..c480a19 100644 --- a/test/ExampleTests.jl +++ b/test/ExampleTests.jl @@ -108,7 +108,6 @@ c = Count{Union{String,Int}}(0) ############################################################################################ -#= XXX: This test requires Cassette's world age problems to be fixed (https://github.com/jrevels/Cassette.jl/issues/6) @context WorldCtx worldtest = 0 @@ -121,12 +120,12 @@ Cassette.overdub_recurse(WorldCtx(), sin, 1) tmp = worldtest Cassette.overdub_recurse(oldctx, sin, 1) -@test tmp === worldtest +@test tmp < worldtest +tmp = worldtest @prehook (f::Any)(args...) where {__CONTEXT__<:WorldCtx} = nothing Cassette.overdub_recurse(WorldCtx(), sin, 1) @test tmp === worldtest -=# ############################################################################################ @@ -138,7 +137,7 @@ Cassette.overdub_recurse(WorldCtx(), sin, 1) if Cassette.is_core_primitive(__context__, f, args...) return f(args...) else - newctx = Cassette.similar_context(__context__, metadata = subtrace) + newctx = Cassette.similarcontext(__context__, metadata = subtrace) return Cassette.overdub_recurse(newctx, f, args...) end end @@ -157,6 +156,17 @@ trtest(x, y, z) = x*y + y*z ############################################################################################ +@context NestedReflectCtx +r_pre = Cassette.reflect((typeof(sin), Int)) +r_post = Cassette.reflect((typeof(Cassette.overdub_recurse), typeof(NestedReflectCtx()), typeof(sin), Int)) +@test isa(r_pre, Cassette.Reflection) && isa(r_post, Cassette.Reflection) +Cassette.overdub_recurse_pass!(r_pre, typeof(NestedReflectCtx())) +@test r_pre.code_info.code == r_post.code_info.code + +#= TODO: The rest of the tests below should be restored for the metadata tagging system + +############################################################################################ + @context NestedCtx function nested_test(n, x) @@ -168,7 +178,7 @@ function nested_test(n, x) end x = rand() -tags = Cassette.AbstractTag[] +tags = Cassette.Tag[] tag_id = objectid(typeof(nested_test)) @prehook function (::Any)(args...) where {__CONTEXT__<:NestedCtx} @@ -188,17 +198,6 @@ end ############################################################################################ -@context NestedReflectCtx -r_pre = Cassette.reflect((typeof(sin), Int)) -r_post = Cassette.reflect((typeof(Cassette.overdub_recurse), typeof(NestedReflectCtx()), typeof(sin), Int)) -@test isa(r_pre, Cassette.Reflection) && isa(r_post, Cassette.Reflection) -Cassette.overdub_recurse_pass!(r_pre) -@test r_pre.code_info.code == r_post.code_info.code - -############################################################################################ - -# TODO: restore this test once Boxes are restored -#= struct Baz x::Int y::Float64 @@ -221,11 +220,9 @@ result = @overdub BazCtx(boxes=Val(true)) begin end @test x === result[1] @test n === result[2] -=# ############################################################################################ -# TODO: restore this test once Boxes are restored -#= + struct Bar{X,Y,Z} x::X y::Y @@ -260,13 +257,12 @@ result = @overdub FooBarCtx(boxes=Val(true)) begin end @test x === result[1] @test n === result[2] -=# + ############################################################################################ # TODO: The below is a highly pathological function for metadata propagation; we should turn # it into an actual test -#= const const_binding = Float64[] global global_binding = 1.0 @@ -303,6 +299,7 @@ function f(x::Vector{Float64}, y::Vector{Float64}) global global_binding = 1.0 return z end + =# end # module