Skip to content

Commit

Permalink
use CassetteBase.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Jun 21, 2024
1 parent 6c4f246 commit 3580dc2
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 129 deletions.
29 changes: 27 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CI
name: CassetteOverlayCI

on:
push:
Expand All @@ -8,7 +8,7 @@ on:

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} /w CassetteBase dev
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false # don't stop CI even when one of them fails
Expand Down Expand Up @@ -38,6 +38,31 @@ jobs:
- version: '1' # x64 macOS
os: macos-latest
arch: x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- name: dev CassetteBase # dev CassetteBase so that CassetteOverlay can be tested against a unreleased version
shell: julia --color=yes --project=. {0} # this is necessary for the next command to work on Windows
run: 'using Pkg; Pkg.develop(;path="./CassetteBase")'
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
file: ./lcov.info

release-test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} /w CassetteBase released
runs-on: ${{ matrix.os }}
strategy:
matrix:
include:
- version: '1' # current stable
os: ubuntu-latest
arch: x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ uuid = "d78b62d4-37fa-4a6f-acd8-2f19986eb9ee"
authors = ["JuliaHub, Inc. and other contributors"]
version = "0.1.11"

[deps]
CassetteBase = "6dd3e646-b1c5-42c7-94be-00277fa12e22"

[compat]
CassetteBase = "0.1"
julia = "1.10"

[extras]
Expand Down
153 changes: 26 additions & 127 deletions src/CassetteOverlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module CassetteOverlay
export @MethodTable, @overlay, @overlaypass, getpass, nonoverlay, @nonoverlay,
AbstractBindingOverlay, Overlay

using Core.IR
using Core: SimpleVector, MethodTable
using CassetteBase

using Core: MethodTable
using Base.Experimental: @MethodTable, @overlay

abstract type OverlayPass end
Expand All @@ -29,129 +30,31 @@ macro nonoverlay(ex)
return esc(out)
end

function overlay_transform!(src::CodeInfo, mi::MethodInstance, nargs::Int)
method = mi.def::Method
mnargs = Int(method.nargs)

src.slotnames = Symbol[Symbol("#self#"), :fargs, src.slotnames[mnargs+1:end]...]
src.slotflags = UInt8[ 0x00, 0x00, src.slotflags[mnargs+1:end]...]

code = src.code
fargsslot = SlotNumber(2)
precode = Any[]
local ssaid = 0
for i = 1:mnargs
if method.isva && i == mnargs
tuplecall = Expr(:call, tuple)
for j = i:nargs
push!(precode, Expr(:call, getfield, fargsslot, j))
ssaid += 1
push!(tuplecall.args, SSAValue(ssaid))
end
push!(precode, tuplecall)
else
push!(precode, Expr(:call, getfield, fargsslot, i))
end
ssaid += 1
end
prepend!(code, precode)
@static if VERSION < v"1.12.0-DEV.173"
prepend!(src.codelocs, [0 for i = 1:ssaid])
else
di = Core.Compiler.DebugInfoStream(mi, src.debuginfo, length(code))
src.debuginfo = Core.DebugInfo(di, length(code))
end
prepend!(src.ssaflags, [0x00 for i = 1:ssaid])
src.ssavaluetypes += ssaid
if @static isdefined(Base, :__has_internal_change) && Base.__has_internal_change(v"1.12-alpha", :codeinfonargs)
src.nargs = 2
src.isva = true
end

function map_slot_number(slot::Int)
@assert slot 1
if 1 slot mnargs
if method.isva && slot == mnargs
return SSAValue(ssaid)
else
return SSAValue(slot)
end
else
return SlotNumber(slot - mnargs + 2)
function make_overlay_generator(selfname::Symbol, fargsname::Symbol)
function overlay_generator(world::UInt, source::LineNumberNode, passtype, fargtypes)
@nospecialize passtype fargtypes
try
return generate_overlay_src(world, source, passtype, fargtypes, selfname, fargsname)
catch err
# internal error happened - return an expression to raise the special exception
return generate_internalerr_ex(
err, #=bt=#catch_backtrace(), #=context=#:overlay_generator, world, source,
#=argnames=#Core.svec(selfname, fargsname), #=spnames=#Core.svec(),
#=metadata=#(; world, source, passtype, fargtypes))
end
end
map_ssa_value(id::Int) = id + ssaid
for i = (ssaid+1:length(code))
code[i] = transform_stmt(code[i], map_slot_number, map_ssa_value, mi.def.sig, mi.sparam_vals)
end

src.edges = MethodInstance[mi]
src.method_for_inference_limit_heuristics = method

return src
end

function transform_stmt(@nospecialize(x), map_slot_number, map_ssa_value, @nospecialize(spsig), sparams::SimpleVector)
transform(@nospecialize x′) = transform_stmt(x′, map_slot_number, map_ssa_value, spsig, sparams)
if isa(x, Expr)
head = x.head
if head === :call
return Expr(:call, SlotNumber(1), map(transform, x.args[1:end])...)
elseif head === :foreigncall
arg1 = x.args[1]
if Meta.isexpr(arg1, :call)
# first argument of :foreigncall may be a magic tuple call, and it should be preserved
arg1 = Expr(:call, map(transform, arg1.args)...)
else
arg1 = transform(x.args[1])
end
arg2 = @ccall jl_instantiate_type_in_env(x.args[2]::Any, spsig::Any, sparams::Ptr{Any})::Any
arg3 = Core.svec(Any[
@ccall jl_instantiate_type_in_env(argt::Any, spsig::Any, sparams::Ptr{Any})::Any
for argt in x.args[3]::SimpleVector ]...)
return Expr(:foreigncall, arg1, arg2, arg3, map(transform, x.args[4:end])...)
elseif head === :enter
return Expr(:enter, map_ssa_value(x.args[1]::Int))
elseif head === :static_parameter
return sparams[x.args[1]::Int]
elseif head === :isdefined
arg1 = x.args[1]
if Meta.isexpr(arg1, :static_parameter)
return 1 arg1.args[1]::Int length(sparams)
end
end
return Expr(head, map(transform, x.args)...)
elseif isa(x, GotoNode)
return GotoNode(map_ssa_value(x.label))
elseif isa(x, GotoIfNot)
return GotoIfNot(transform(x.cond), map_ssa_value(x.dest))
elseif isa(x, ReturnNode)
return ReturnNode(transform(x.val))
elseif isa(x, SlotNumber)
return map_slot_number(x.id)
elseif isa(x, NewvarNode)
return NewvarNode(map_slot_number(x.slot.id))
elseif isa(x, SSAValue)
return SSAValue(map_ssa_value(x.id))
elseif @static @isdefined(EnterNode) && isa(x, EnterNode)
if isdefined(x, :scope)
return EnterNode(map_ssa_value(x.catch_dest), transform(x.scope))
else
return EnterNode(map_ssa_value(x.catch_dest))
end
end
return x
end

function overlay_generator(world::UInt, source::LineNumberNode, passtype, fargtypes)
function generate_overlay_src(world::UInt, source::LineNumberNode, passtype, fargtypes,
selfname::Symbol, fargsname::Symbol)
@nospecialize passtype fargtypes
tt = Base.to_tuple_type(fargtypes)
match = Base._which(tt; method_table=method_table(passtype), raise=false, world)
match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError
mi = Core.Compiler.specialize_method(match)
src = Core.Compiler.retrieve_code_info(mi, world)
src === nothing && return nothing # code generation failed - the fallback implementation will re-raise it
overlay_transform!(src, mi, length(fargtypes))
cassette_transform!(src, mi, length(fargtypes), selfname, fargsname)
return src
end

Expand All @@ -166,11 +69,11 @@ macro overlaypass(args...)
if PassName === nothing
PassName = esc(gensym(string(method_table)))
decl_pass = :(struct $PassName <: $OverlayPass end)
ret = :($PassName())
retval = :($PassName())
else
PassName = esc(PassName)
decl_pass = :(@assert $PassName <: $OverlayPass)
ret = nothing
retval = nothing
end

nonoverlaytype = typeof(CassetteOverlay.nonoverlay)
Expand Down Expand Up @@ -208,7 +111,7 @@ macro overlaypass(args...)
# the main code transformation pass
mainpass = quote
function (pass::$PassName)(fargs...)
$(Expr(:meta, :generated, overlay_generator))
$(Expr(:meta, :generated, make_overlay_generator(:pass, :fargs)))
# also include a fallback implementation that will be used when this method
# is dynamically dispatched with `!isdispatchtuple` signatures.
return first(fargs)(Base.tail(fargs)...)
Expand All @@ -219,20 +122,16 @@ macro overlaypass(args...)
# nonoverlay primitives
nonoverlaypass = quote
@nospecialize
@inline function (pass::$PassName)(::$nonoverlaytype, f, args...; kwargs...)
return f(args...; kwargs...)
end
@inline (pass::$PassName)(::$nonoverlaytype,
f, args...; kwargs...) = f(args...; kwargs...)
@inline (pass::$PassName)(::typeof(Core.kwcall),
kwargs::Any, ::$nonoverlaytype, fargs...) = Core.kwcall(kwargs, fargs...)
@specialize

@inline function (pass::$PassName)(::typeof(Core.kwcall), kwargs::Any, ::$nonoverlaytype, fargs...)
@nospecialize kwargs fargs
return Core.kwcall(kwargs, fargs...)
end

return $ret
end
append!(topblk.args, nonoverlaypass.args)

push!(topblk.args, :(return $retval))

return topblk
end

Expand Down

0 comments on commit 3580dc2

Please sign in to comment.