Skip to content

Refactor for recent IR changes + move context definition out of macro #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions src/Cassette.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,14 @@ __precompile__(false)

module Cassette

using Core: CodeInfo, SlotNumber, NewvarNode, LabelNode, GotoNode, SSAValue
using Core: CodeInfo, SlotNumber, NewvarNode, GotoNode, SSAValue

using Logging

abstract type AbstractTag end
struct BottomTag <: AbstractTag end

abstract type AbstractPass end
struct NoPass <: AbstractPass end
(::Type{NoPass})(::Any, ::Any, code_info) = code_info

abstract type AbstractContext{T<:AbstractTag,P<:AbstractPass,B} end

include("utilities.jl")
include("context.jl")
include("tagged.jl")
include("overdub.jl")
include("contextdef.jl")
include("macros.jl")

function __init__()
Expand Down
87 changes: 87 additions & 0 deletions src/context.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
##################
# `AbstractPass` #
##################

abstract type AbstractPass end

struct NoPass <: AbstractPass end

(::Type{NoPass})(::Any, ::Any, code_info) = code_info

#########
# `Tag` #
#########

# this @pure annotation has official vtjnash approval :p
Base.@pure _pure_objectid(x) = objectid(x)

abstract type AbstractContextName end

struct Tag{N<:AbstractContextName,X,E#=<:Union{Nothing,Tag}=#} end

Tag(::Type{N}, ::Type{X}) where {N,X} = Tag(N, X, Nothing)

Tag(::Type{N}, ::Type{X}, ::Type{E}) where {N,X,E} = Tag{N,pure_objectid(X),E}()

#################
# `BindingMeta` #
#################
# We define these here because we need them to define `Context`,
# but most code that works with these types is in src/tagged.jl

mutable struct BindingMeta
data::Any
BindingMeta() = new()
end

const BindingMetaDict = Dict{Symbol,BindingMeta}
const BindingMetaCache = IdDict{Module,BindingMetaDict}

#############
# `Context` #
#############

struct Context{N<:AbstractContextName,
M<:Any,
P<:AbstractPass,
T<:Union{Nothing,Tag},
B<:Union{Nothing,BindingMetaCache}}
name::N
metadata::M
pass::P
tag::T
bindings::B
function Context(name::N, metadata::M, pass::P, ::Nothing, ::Nothing) where {N,M,P}
return new{N,M,P,Nothing,Nothing}(name, metadata, pass, nothing, nothing)
end
function Context(name::N, metadata::M, pass::P, tag::Tag{N}, bindings::BindingMetaCache) where {N,M,P}
return new{N,M,P,typeof(tag),BindingMetaCache}(name, metadata, pass, tag, bindings)
end
end

function Context(name::AbstractContextName; metadata = nothing, pass::AbstractPass = NoPass())
return Context(name, metadata, pass, nothing, nothing)
end

function similarcontext(context::Context;
metadata = context.metadata,
pass = context.pass,
tag = context.tag,
bindings = context.bindings)
return Context(context.name, metadata, pass, tag, bindings)
end

const ContextWithTag{T} = Context{<:AbstractContextName,<:Any,<:AbstractPass,T}
const ContextWithPass{P} = Context{<:AbstractContextName,<:Any,P}

has_tagging_enabled(::Type{<:ContextWithTag{<:Tag}}) = true
has_tagging_enabled(::Type{<:ContextWithTag{Nothing}}) = false

tagtype(::C) where {C<:Context} = tagtype(C)
tagtype(::Type{<:ContextWithTag{T}}) where {T} = T

function withtagfor(context::Context, f)
return similarcontext(context;
tag = Tag(typeof(context), typeof(f)),
bindings = BindingMetaCache())
end
117 changes: 0 additions & 117 deletions src/contextdef.jl

This file was deleted.

80 changes: 73 additions & 7 deletions src/macros.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
const CONTEXT_TYPE_BINDING = Symbol("__CONTEXT__")
const CONTEXT_BINDING = Symbol("__context__")

############
# @context #
############
Expand All @@ -8,7 +11,75 @@
Define a new Cassette context type with the name `Ctx`.
"""
macro context(Ctx)
return esc(generate_context_definition(Ctx))
@assert isa(Ctx, Symbol) "context name must be a Symbol"
CtxName = gensym(string(Ctx, "Name"))
return esc(quote
struct $CtxName <: $Cassette.AbstractContextName end

const $Ctx{M,T<:Union{Nothing,$Cassette.Tag},P<:$Cassette.AbstractPass} = $Cassette.Context{$CtxName,M,P,T}

$Ctx(; kwargs...) = $Cassette.Context($CtxName(); kwargs...)

$Cassette.@primitive function $Cassette.Tag(::Type{N}, ::Type{X}) where {__CONTEXT__<:$Ctx,N,X}
return Tag(N, X, $Cassette.tagtype(__CONTEXT__))
end

$Cassette.@primitive function Core._apply(f, args...) where {__CONTEXT__<:$Ctx}
flattened_args = Core._apply(tuple, args...)
return $Cassette.overdub_execute(__context__, f, flattened_args...)
end

# enforce `T<:Cassette.Tag` to ensure that we only call the below primitive functions
# if the context has the tagging system enabled

$Cassette.@primitive function Array{T,N}(undef::UndefInitializer, args...) where {T,N,__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_new(__context__, Array{T,N}, undef, args...)
end

$Cassette.@primitive function Base.nameof(m) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_nameof(__context__, m)
end

$Cassette.@primitive function Core.getfield(x, name) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_getfield(__context__, x, name)
end

$Cassette.@primitive function Core.setfield!(x, name, y) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_setfield!(__context__, x, name, y)
end

$Cassette.@primitive function Core.arrayref(boundscheck, x, i) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_arrayref(__context__, boundscheck, x, i)
end

$Cassette.@primitive function Core.arrayset(boundscheck, x, y, i) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_arrayset(__context__, boundscheck, x, y, i)
end

$Cassette.@primitive function Base._growbeg!(x, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_growbeg!(__context__, x, delta)
end

$Cassette.@primitive function Base._growend!(x, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_growend!(__context__, x, delta)
end

$Cassette.@primitive function Base._growat!(x, i, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_growat!(__context__, x, i, delta)
end

$Cassette.@primitive function Base._deletebeg!(x, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_deletebeg!(__context__, x, delta)
end

$Cassette.@primitive function Base._deleteend!(x, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_deleteend!(__context__, x, delta)
end

$Cassette.@primitive function Base._deleteat!(x, i, delta) where {__CONTEXT__<:$Ctx{<:Any,<:$Cassette.Tag}}
return $Cassette.tagged_deleteat!(__context__, x, i, delta)
end
end)
end

############
Expand All @@ -20,12 +91,7 @@ end
A convenience macro for overdubbing and executing `expression` within the context `Ctx`.
"""
macro overdub(ctx, expr)
return quote
func = $(esc(CONTEXT_BINDING)) -> $(esc(expr))
ctx = Cassette.similar_context($(esc(ctx));
tag = $Cassette.generate_tag($(esc(ctx)), func))
$Cassette.overdub_recurse(ctx, func, ctx)
end
return :($Cassette.overdub_recurse($(esc(ctx)), () -> $(esc(expr))))
end

#########
Expand Down
Loading