Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use CassetteBase.jl #55

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CI
name: CassetteOverlayCI

on:
push:
Expand All @@ -8,7 +8,7 @@ on:

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} /w CassetteBase dev
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false # don't stop CI even when one of them fails
Expand Down Expand Up @@ -38,6 +38,31 @@ jobs:
- version: '1' # x64 macOS
os: macos-latest
arch: x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- name: dev CassetteBase # dev CassetteBase so that CassetteOverlay can be tested against a unreleased version
shell: julia --color=yes --project=. {0} # this is necessary for the next command to work on Windows
run: 'using Pkg; Pkg.develop(;path="./CassetteBase")'
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
file: ./lcov.info

release-test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} /w CassetteBase released
runs-on: ${{ matrix.os }}
strategy:
matrix:
include:
- version: '1' # current stable
os: ubuntu-latest
arch: x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ uuid = "d78b62d4-37fa-4a6f-acd8-2f19986eb9ee"
authors = ["JuliaHub, Inc. and other contributors"]
version = "0.1.11"

[deps]
CassetteBase = "6dd3e646-b1c5-42c7-94be-00277fa12e22"

[compat]
CassetteBase = "0.1"
julia = "1.10"

[extras]
Expand Down
153 changes: 26 additions & 127 deletions src/CassetteOverlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module CassetteOverlay
export @MethodTable, @overlay, @overlaypass, getpass, nonoverlay, @nonoverlay,
AbstractBindingOverlay, Overlay

using Core.IR
using Core: SimpleVector, MethodTable
using CassetteBase

using Core: MethodTable
using Base.Experimental: @MethodTable, @overlay

abstract type OverlayPass end
Expand All @@ -29,129 +30,31 @@ macro nonoverlay(ex)
return esc(out)
end

function overlay_transform!(src::CodeInfo, mi::MethodInstance, nargs::Int)
method = mi.def::Method
mnargs = Int(method.nargs)

src.slotnames = Symbol[Symbol("#self#"), :fargs, src.slotnames[mnargs+1:end]...]
src.slotflags = UInt8[ 0x00, 0x00, src.slotflags[mnargs+1:end]...]

code = src.code
fargsslot = SlotNumber(2)
precode = Any[]
local ssaid = 0
for i = 1:mnargs
if method.isva && i == mnargs
tuplecall = Expr(:call, tuple)
for j = i:nargs
push!(precode, Expr(:call, getfield, fargsslot, j))
ssaid += 1
push!(tuplecall.args, SSAValue(ssaid))
end
push!(precode, tuplecall)
else
push!(precode, Expr(:call, getfield, fargsslot, i))
end
ssaid += 1
end
prepend!(code, precode)
@static if VERSION < v"1.12.0-DEV.173"
prepend!(src.codelocs, [0 for i = 1:ssaid])
else
di = Core.Compiler.DebugInfoStream(mi, src.debuginfo, length(code))
src.debuginfo = Core.DebugInfo(di, length(code))
end
prepend!(src.ssaflags, [0x00 for i = 1:ssaid])
src.ssavaluetypes += ssaid
if @static isdefined(Base, :__has_internal_change) && Base.__has_internal_change(v"1.12-alpha", :codeinfonargs)
src.nargs = 2
src.isva = true
end

function map_slot_number(slot::Int)
@assert slot ≥ 1
if 1 ≤ slot ≤ mnargs
if method.isva && slot == mnargs
return SSAValue(ssaid)
else
return SSAValue(slot)
end
else
return SlotNumber(slot - mnargs + 2)
function make_overlay_generator(selfname::Symbol, fargsname::Symbol)
function overlay_generator(world::UInt, source::LineNumberNode, passtype, fargtypes)
@nospecialize passtype fargtypes
try
return generate_overlay_src(world, source, passtype, fargtypes, selfname, fargsname)
catch err
# internal error happened - return an expression to raise the special exception
return generate_internalerr_ex(
err, #=bt=#catch_backtrace(), #=context=#:overlay_generator, world, source,
#=argnames=#Core.svec(selfname, fargsname), #=spnames=#Core.svec(),
#=metadata=#(; world, source, passtype, fargtypes))
end
end
map_ssa_value(id::Int) = id + ssaid
for i = (ssaid+1:length(code))
code[i] = transform_stmt(code[i], map_slot_number, map_ssa_value, mi.def.sig, mi.sparam_vals)
end

src.edges = MethodInstance[mi]
src.method_for_inference_limit_heuristics = method

return src
end

function transform_stmt(@nospecialize(x), map_slot_number, map_ssa_value, @nospecialize(spsig), sparams::SimpleVector)
transform(@nospecialize x′) = transform_stmt(x′, map_slot_number, map_ssa_value, spsig, sparams)
if isa(x, Expr)
head = x.head
if head === :call
return Expr(:call, SlotNumber(1), map(transform, x.args[1:end])...)
elseif head === :foreigncall
arg1 = x.args[1]
if Meta.isexpr(arg1, :call)
# first argument of :foreigncall may be a magic tuple call, and it should be preserved
arg1 = Expr(:call, map(transform, arg1.args)...)
else
arg1 = transform(x.args[1])
end
arg2 = @ccall jl_instantiate_type_in_env(x.args[2]::Any, spsig::Any, sparams::Ptr{Any})::Any
arg3 = Core.svec(Any[
@ccall jl_instantiate_type_in_env(argt::Any, spsig::Any, sparams::Ptr{Any})::Any
for argt in x.args[3]::SimpleVector ]...)
return Expr(:foreigncall, arg1, arg2, arg3, map(transform, x.args[4:end])...)
elseif head === :enter
return Expr(:enter, map_ssa_value(x.args[1]::Int))
elseif head === :static_parameter
return sparams[x.args[1]::Int]
elseif head === :isdefined
arg1 = x.args[1]
if Meta.isexpr(arg1, :static_parameter)
return 1 ≤ arg1.args[1]::Int ≤ length(sparams)
end
end
return Expr(head, map(transform, x.args)...)
elseif isa(x, GotoNode)
return GotoNode(map_ssa_value(x.label))
elseif isa(x, GotoIfNot)
return GotoIfNot(transform(x.cond), map_ssa_value(x.dest))
elseif isa(x, ReturnNode)
return ReturnNode(transform(x.val))
elseif isa(x, SlotNumber)
return map_slot_number(x.id)
elseif isa(x, NewvarNode)
return NewvarNode(map_slot_number(x.slot.id))
elseif isa(x, SSAValue)
return SSAValue(map_ssa_value(x.id))
elseif @static @isdefined(EnterNode) && isa(x, EnterNode)
if isdefined(x, :scope)
return EnterNode(map_ssa_value(x.catch_dest), transform(x.scope))
else
return EnterNode(map_ssa_value(x.catch_dest))
end
end
return x
end

function overlay_generator(world::UInt, source::LineNumberNode, passtype, fargtypes)
function generate_overlay_src(world::UInt, #=source=#::LineNumberNode, passtype, fargtypes,
selfname::Symbol, fargsname::Symbol)
@nospecialize passtype fargtypes
tt = Base.to_tuple_type(fargtypes)
match = Base._which(tt; method_table=methodtable(passtype), raise=false, world)
match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError
mi = Core.Compiler.specialize_method(match)
src = Core.Compiler.retrieve_code_info(mi, world)
src === nothing && return nothing # code generation failed - the fallback implementation will re-raise it
overlay_transform!(src, mi, length(fargtypes))
cassette_transform!(src, mi, length(fargtypes), selfname, fargsname)
return src
end

Expand All @@ -166,11 +69,11 @@ macro overlaypass(args...)
if PassName === nothing
PassName = esc(gensym(string(method_table)))
decl_pass = :(struct $PassName <: $OverlayPass end)
ret = :($PassName())
retval = :($PassName())
else
PassName = esc(PassName)
decl_pass = :(@assert $PassName <: $OverlayPass)
ret = nothing
retval = nothing
end

nonoverlaytype = typeof(CassetteOverlay.nonoverlay)
Expand Down Expand Up @@ -208,7 +111,7 @@ macro overlaypass(args...)
# the main code transformation pass
mainpass = quote
function (pass::$PassName)(fargs...)
$(Expr(:meta, :generated, overlay_generator))
$(Expr(:meta, :generated, make_overlay_generator(:pass, :fargs)))
# also include a fallback implementation that will be used when this method
# is dynamically dispatched with `!isdispatchtuple` signatures.
return first(fargs)(Base.tail(fargs)...)
Expand All @@ -219,20 +122,16 @@ macro overlaypass(args...)
# nonoverlay primitives
nonoverlaypass = quote
@nospecialize
@inline function (pass::$PassName)(::$nonoverlaytype, f, args...; kwargs...)
return f(args...; kwargs...)
end
@inline (pass::$PassName)(::$nonoverlaytype,
f, args...; kwargs...) = f(args...; kwargs...)
@inline (pass::$PassName)(::typeof(Core.kwcall),
kwargs::Any, ::$nonoverlaytype, fargs...) = Core.kwcall(kwargs, fargs...)
@specialize

@inline function (pass::$PassName)(::typeof(Core.kwcall), kwargs::Any, ::$nonoverlaytype, fargs...)
@nospecialize kwargs fargs
return Core.kwcall(kwargs, fargs...)
end

return $ret
end
append!(topblk.args, nonoverlaypass.args)

push!(topblk.args, :(return $retval))

return topblk
end

Expand Down
Loading